Which epoch the model uses for prediction

Looking at the code, I can help clarify the model saving and inference setup:

  1. Model Saving:
  • Yes, the model is saved at every epoch in the train() method:
def train(self):
    self.train_dataset = TrajectoryLoader(self.params.train_dataset, self.road_network.node2id, self.params.batch_size, self.params.label_num)
    for i in range(self.params.epochs):
        self.train_epoch(i, 1, self.train_dataset)
        # if i%10==0:  # This comment suggests it used to save every 10 epochs
        self.save(i)
  2. Inference:
  • In your code, you're using epoch 10 for inference since you specify load_model="test_10":
causalTAD = CausalTAD(save_model=None, load_model="test_10", city="xian")

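This means the model will load the weights saved at epoch 10 to perform the testing. If save(i) writes the raw loop index into the file name, that is actually the 11th epoch of training, because range() starts counting at 0. However, this might not be optimal because: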

  1. The model trains for 400 epochs (specified in Params: self.epochs = 400)
  2. Later epochs might perform better because the model has had more time to learn, although they can also start to overfit
  3. You might want to evaluate multiple checkpoints to find the best performing one
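A minimal sketch of point 3. It assumes checkpoints are named <prefix>_<epoch>, as load_model="test_10" suggests. It also assumes evaluate is a hypothetical function you would write: it loads a checkpoint by name and returns a validation score where higher is better (for example ROC-AUC). Scoring on a held-out validation split rather than the test set keeps the final test result unbiased.

```python
from typing import Callable, Dict, Iterable, Tuple


def select_best_checkpoint(
    epochs: Iterable[int],
    evaluate: Callable[[str], float],  # hypothetical: loads a checkpoint by name, returns a validation score
    prefix: str = "test",              # assumed naming scheme: "<prefix>_<epoch>"
) -> Tuple[str, float, Dict[str, float]]:
    """Score every listed checkpoint and return (best_name, best_score, all_scores).

    Assumes a higher score is better.
    """
    scores: Dict[str, float] = {}
    for epoch in epochs:
        name = f"{prefix}_{epoch}"
        scores[name] = evaluate(name)
    best_name = max(scores, key=scores.get)  # checkpoint with the highest score
    return best_name, scores[best_name], scores


# Usage with made-up scores standing in for real validation runs:
fake_scores = {"test_10": 0.81, "test_50": 0.86, "test_100": 0.84}
best, score, all_scores = select_best_checkpoint([10, 50, 100], evaluate=fake_scores.__getitem__)
print(best, score)  # test_50 0.86
```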

I would recommend:

  1. Save checkpoints less frequently (e.g., every 10 epochs, by restoring the commented-out if i%10==0: check) to save storage space
  2. Evaluate multiple checkpoints on a validation set to find the optimal epoch for inference
  3. Consider implementing early stopping based on validation performance
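A minimal sketch combining recommendations 1 and 3 as a standalone loop. It saves every save_every epochs and also saves whenever the validation score improves, so the best weights are always on disk. It stops once the score has not improved for patience epochs. The callbacks train_epoch, validate and save are hypothetical stand-ins for the class's own methods; the original code has no validation step. The patience and min_delta defaults are illustrative, not tuned values.

```python
from typing import Callable, List, Tuple


class EarlyStopping:
    """Track the best validation score; signal a stop after `patience` epochs without improvement.

    Assumes a higher score is better.
    """

    def __init__(self, patience: int = 10, min_delta: float = 0.0):
        self.patience = patience
        self.min_delta = min_delta  # smallest increase that counts as an improvement
        self.best_score = float("-inf")
        self.best_epoch = -1
        self.bad_epochs = 0

    def step(self, epoch: int, score: float) -> bool:
        """Record this epoch's score; return True if it is a new best."""
        if score > self.best_score + self.min_delta:
            self.best_score = score
            self.best_epoch = epoch
            self.bad_epochs = 0
            return True
        self.bad_epochs += 1
        return False

    @property
    def should_stop(self) -> bool:
        return self.bad_epochs >= self.patience


def train_with_early_stopping(
    epochs: int,
    train_epoch: Callable[[int], None],  # hypothetical stand-in for self.train_epoch
    validate: Callable[[int], float],    # hypothetical: returns a validation score for this epoch
    save: Callable[[int], None],         # hypothetical stand-in for self.save
    save_every: int = 10,
    patience: int = 10,
) -> Tuple[List[int], int]:
    """Return (epochs that were saved, best epoch)."""
    stopper = EarlyStopping(patience=patience)
    saved: List[int] = []
    for i in range(epochs):
        train_epoch(i)
        improved = stopper.step(i, validate(i))
        # Periodic checkpoint, plus a checkpoint whenever validation improves.
        if i % save_every == 0 or improved:
            save(i)
            saved.append(i)
        if stopper.should_stop:
            break
    return saved, stopper.best_epoch


# Usage with made-up validation scores:
scores = [0.70, 0.75, 0.74, 0.73, 0.72, 0.71]
saved, best = train_with_early_stopping(
    len(scores), train_epoch=lambda i: None, validate=lambda i: scores[i],
    save=lambda i: None, patience=3,
)
print(saved, best)  # [0, 1] 1
```

Inside the original class, validate would run the model on a held-out split after each self.train_epoch(i) call. Assuming the test_<epoch> naming, the returned best epoch tells you which checkpoint to pass as load_model.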

Would you like me to help you modify the code to implement any of these suggestions?

posted @ 2024-12-17 17:03  GraphL