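This is the implementation for the paper "Learning Sparse Interaction Graphs of Partially Detected Pedestrians for Trajectory Prediction" by Zhe Huang, Ruohua Li, Kazuki Shin, and Katherine Driggs-Campbell, published in RA-L.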
GST stands for Gumbel Social Transformer, the name of our model. All code was developed and tested on Ubuntu 18.04 with CUDA 10.2, Python 3.6.9, and PyTorch 1.7.1.
If you find this repo useful, please cite:
@article{huang2022learning,
title={Learning Sparse Interaction Graphs of Partially Detected Pedestrians for Trajectory Prediction},
author={Huang, Zhe and Li, Ruohua and Shin, Kazuki and Driggs-Campbell, Katherine},
journal={IEEE Robotics and Automation Letters},
year={2022},
volume={7},
number={2},
pages={1198-1205},
doi={10.1109/LRA.2021.3138547}
}
virtualenv -p /usr/bin/python3 myenv
source myenv/bin/activate
You can run either
pip install -r requirements.txt
or
pip install numpy
pip install scipy
pip install matplotlib
pip install tensorboardX
pip install torch==1.7.1
If you want to use tensorboard --logdir results to check training curves, install tensorflow by running
pip install tensorflow
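After installation, a quick check like the sketch below can confirm whether the environment matches the versions this README reports as tested. The helper names `version_matches` and `report` are hypothetical and are not part of this repo; other versions may work but have not been confirmed here.

```python
import sys

# Versions the README reports as tested.
TESTED_PYTHON = "3.6.9"
TESTED_TORCH = "1.7.1"


def version_matches(found, expected):
    """Return True if `found` equals `expected` once any local build
    suffix (for example "+cu102") is dropped."""
    return found.split("+")[0] == expected


def report(name, found, expected):
    """Build a one-line status message for an interpreter or package."""
    status = "matches" if version_matches(found, expected) else "differs from"
    return "{} {} {} the tested version {}".format(name, found, status, expected)


python_version = "{}.{}.{}".format(*sys.version_info[:3])
print(report("Python", python_version, TESTED_PYTHON))

try:
    import torch
except ImportError:
    print("PyTorch is not installed")
else:
    print(report("PyTorch", torch.__version__, TESTED_TORCH))
```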
sh run/make_dirs.sh
sh run/create_datasets.sh
To train and evaluate a model with n=1, i.e., the target pedestrian pays attention to at most one partially observed pedestrian, run
sh run/train_sparse.sh
sh run/eval_sparse.sh
To train and evaluate a model with n=1 and a temporal convolution network as the temporal component, run
sh run/train_sparse_tcn.sh
sh run/eval_sparse_tcn.sh
To train and evaluate a model with full connection, i.e., the target pedestrian pays attention to all partially observed pedestrians in the scene, run
sh run/train_full_connection.sh
sh run/eval_full_connection.sh
To train and evaluate a model in which the target pedestrian pays attention to all fully observed pedestrians in the scene, run
sh run/train_full_connection_fully_observed.sh
sh run/eval_full_connection_fully_observed.sh
--spatial_num_heads_edges: n, i.e., the upper bound on the number of pedestrians that the target pedestrian can pay attention to in the scene. n=0 is defined as full connection, i.e., the target pedestrian pays attention to all pedestrians in the scene. Default is 4.
--only_observe_full_period: The target pedestrian only pays attention to fully observed pedestrians. Default is False.
--temporal: Temporal component. lstm is Masked LSTM, and temporal_convolution_net is temporal convolution network. Default is lstm.
--decode_style: Decoding style. It has to match the option --temporal. recursive matches lstm, and readout matches temporal_convolution_net. Default is recursive.
--ghost: Add a ghost pedestrian in the scene to encourage sparsity. When --spatial_num_heads_edges is set to zero, i.e., the target pedestrian pays attention to all pedestrians in the scene, --ghost has to be set to False. Default is False.
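The constraints among these options can be checked before a run starts. The following is a minimal sketch, not the repo's actual argument parsing: it assumes the flags are read with argparse and that the two boolean options are plain switches, and the names `build_parser` and `check_args` are hypothetical. Only the flag names, defaults, and pairing rules come from the descriptions above.

```python
import argparse

# Pairing stated above: each temporal component has one matching decode style.
DECODE_STYLE_FOR_TEMPORAL = {
    "lstm": "recursive",
    "temporal_convolution_net": "readout",
}


def build_parser():
    """Hypothetical parser that mirrors the documented flags and defaults."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--spatial_num_heads_edges", type=int, default=4)
    # Assumption: boolean options are switches that default to False.
    parser.add_argument("--only_observe_full_period", action="store_true")
    parser.add_argument("--temporal", default="lstm",
                        choices=sorted(DECODE_STYLE_FOR_TEMPORAL))
    parser.add_argument("--decode_style", default="recursive",
                        choices=["recursive", "readout"])
    parser.add_argument("--ghost", action="store_true")
    return parser


def check_args(args):
    """Raise ValueError if the options violate the documented constraints."""
    expected = DECODE_STYLE_FOR_TEMPORAL[args.temporal]
    if args.decode_style != expected:
        raise ValueError(
            "--temporal {} requires --decode_style {}, got {}".format(
                args.temporal, expected, args.decode_style))
    if args.spatial_num_heads_edges == 0 and args.ghost:
        raise ValueError(
            "--ghost must be False when --spatial_num_heads_edges is 0")
    return args


# Example: a temporal convolution network with its matching decode style.
example = check_args(build_parser().parse_args(
    ["--temporal", "temporal_convolution_net", "--decode_style", "readout"]))
print(example.temporal, example.decode_style)
```

Checking the combination up front means a mismatched pair such as --temporal temporal_convolution_net with the default --decode_style recursive fails immediately instead of partway through training.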
Part of the code is based on the following works and repos:
[1] Mohamed, Abduallah, et al. "Social-stgcnn: A social spatio-temporal graph convolutional neural network for human trajectory prediction." Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition. 2020. [GitHub]
[2] PyTorch implementation of Multi-head Attention. [Modules] [Functional]
Please feel free to open an issue or send an email to [email protected].