logo
0
0
WeChat Login
添夹中文文档说明

DiffTraj 中文说明

DiffTraj 是一个基于扩散概率模型的 GPS 轨迹生成项目。

环境要求

项目说明中的基础要求如下:

  • Python >= 3.7
  • torch >= 1.7
  • pandas
  • numpy
  • matplotlib
  • colored

说明文档中提到的 pathlibshutildatetimemath 属于 Python 标准库,无需单独安装。

基于当前代码实际导入,项目还使用了:

  • scikit-learn

Conda 环境

当前项目已经在根目录创建了本地 conda 环境:

conda activate /root/DiffTraj/conda_env

如果你需要复现当前环境,可以直接使用项目中的 env.yml

conda env create -f env.yml

训练轨迹生成模型

  1. 准备你自己的 GPS 轨迹数据,并确保具备合法使用权限。
  2. 修改 utils/config.pymain.py 中对应的数据路径与配置。
  3. 运行训练:
python main.py
  1. 训练过程中,代码会每 10 个 epoch 保存一次模型。

轨迹生成使用方式

可以参考 process/Traj_Generation.ipynb

仓库中已经提供了以下示例文件:

  • model.pt:预训练模型
  • heads.npy:示例引导信息

你可以结合 Notebook 查看如何使用已有模型生成轨迹。

生成结果示例

Chengdu

引用

如果你在研究中使用了本项目,请引用:

@inproceedings{zhu2023DiffTraj, author = {Yuanshao Zhu, Yongchao Ye, Shiyao Zhang, Xiangyu Zhao and James, J.Q. Yu}, title = {DiffTraj: Generating GPS Trajectory with Diffusion Probabilistic Model}, booktitle = {Proceedings of the 37th Annual Conference on Neural Information Processing Systems}, year = {2023} }