Train.py

一、处理流程

  1. class  STsim_Trainer

    (1)初始化def __init__(self)

      (2)  验证方法def ST_eval()---不使用,归并到训练方法中

            (3)  训练方法def ST_train()--核心

  2. def ST_train()

   (1)创建编码器对象---net = STTrajSimEncoder()---来自model_network.py文件

     (2)   数据加载并预处理---dataload = data_utils.DataLoader()---来自data_utils.py文件

     (3)  构建三元组训练样本对---dataload.get_triplets()---来自data_utils.py文件

               (4)  计算获取三元组训练样本对间地面真实距离---data_utils.triplet_groud_truth()---来自data_utils.py文件

        (5)  创建优化器对象optimizer = torch.optim.Adam()

               (6)  创建损失函数对象lossfunction = LossFun()

     (7)  获取batch批次数 bt_num = int(dataload.return_triplets_num() / self.train_batch)【训练集20000/128=156】

     (8)  加载node、d2vec的三元组数据---batch_l = data_utils.batch_list()---apn_node_triplets、apn_d2vec_triplets

              ------------------------------------------------------------------------------------------

   for epoch in range(int(lastepoch), self.epochs):  ---开始训练,设定训练轮数epoch=150次

    net.train()---将模型设置为训练模式

      for bt in range(bt_num):  ---按批次依次执行

      (1)  获取每个批次参与数据---a_node_batch, a_time_batch, p_node_batch, p_time_batch, n_node_batch, n_time_batch, batch_index = batch_l.getbatch_one()

          锚点节点、锚点时间、正样本节点、正样本时间、负样本节点、负样本时间、批次开始索引

      (2)  获取每个锚点、正样本、负样本嵌入学习表示 --- a_embedding/p_embedding / n_embedding  = net(road_network, a_node_batch, a_time_batch)

      (3)  计算损失函数值---loss = lossfunction(a_embedding, p_embedding, n_embedding, batch_index)

                          (4)  学习配置重置  --- optimizer.zero_grad()、loss.backward()、optimizer.step()

         (5)  判断训练轮数是否为偶数次,若是,则将模型设置为验证模式,获取度量指标值  HR10、HR50、HR1050

二、重点函数解读

三、知识扩展

  1. 相关导入数据包

  import torch: PyTorch 是一个机器学习框架,提供了张量计算和构建神经网络的功能。
  import torch.nn as nn: 导入 PyTorch 的神经网络模块,其中包含了神经网络的各种层和函数。
  import torch.nn.functional as F: 导入 PyTorch 中的函数模块,其中包含了各种激活函数和损失函数等。
  from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence: 从 PyTorch 的神经网络工具中导入了两个函数,用于处理可变长度序列的填充和打包操作。
  from torch_geometric.nn import GCNConv: 导入 PyTorch Geometric 库中的 GCNConv 类,用于定义图卷积网络(Graph Convolutional Network,GCN)的层。
  from Model import Date2VecConvert: 这是从名为 Model 的模块中导入了一个自定义的类 Date2VecConvert,这个类可能是作者自己定义的用于日期向量转换的模型。
  import time: 这是导入 Python 的时间模块,用于处理时间相关的操作。
  import datetime: 这是导入 Python 的日期时间模块,用于处理日期时间相关的操作。
  import numpy as np: 这是导入 NumPy 库,并将其重命名为 np。NumPy 是 Python 中用于科学计算的核心库,它提供了多维数组对象和各种数学函数,是许多机器学习和数据处理任务的基础。
 

  2. 文件扩展名 .pth

    通常表示 PyTorch 模型的保存文件。在 PyTorch 中,.pth 文件通常包含了模型的权重参数以及其他必要的信息,可以通过加载这个文件来恢复模型的状态。

  在提供的路径 /d2v_model/d2v_98291_17.169918439404636.pth 中,.pth 文件应该是一个保存了某个模型状态的文件。根据文件名的一般惯例,这个文件可能是一个名为 d2v_98291_17.169918439404636 的模型在某次训练后保存的结果,其中包含了模型的权重参数等信息。

  如果你想要使用这个模型,可以通过 PyTorch 提供的加载模型的函数来加载这个 .pth 文件,然后就可以使用这个模型进行推理或者继续训练。加载模型的代码可能类似于:

  import torch
  # 创建一个与模型结构相同的实例
  model = YourModelClass()
  # 加载模型参数
  model.load_state_dict(torch.load('/d2v_model/d2v_98291_17.169918439404636.pth'))
  # 将模型设置为评估模式(不进行梯度计算)
  model.eval()  

  3. BallTree

    BallTree 是一种数据结构,用于高效地组织和检索多维空间中的数据点。它是一种基于树结构的数据结构,通常用于近似最近邻搜索(Approximate Nearest Neighbor Search)等问题。

  在给定一组数据点后,BallTree 会将这些数据点递归地划分为球形区域(ball)并构建一颗树。每个节点代表一个球形区域,而叶子节点包含了实际的数据点。BallTree 使得在多维空间中搜索最近邻点变得高效,因为它可以通过递归地比较球形区域之间的距离来确定搜索路径,从而避免了对所有数据点进行线性搜索。

  在你提供的代码中,BallTree 可能被用于构建一个 Ball 树,以便在训练数据集 sample_train2D 上执行最近邻搜索或其他空间查询任务。一旦 Ball 树被构建完成,你可以使用它来快速找到训练数据集中某个点的最近邻点,或者执行其他类似的空间查询操作。

  总之,BallTree 是一个用于高效处理多维空间数据的数据结构,常用于近似最近邻搜索和空间查询等任务。【data_utils.py文件中,get_triplet函数,ball_tree = BallTree(sample_train2D)】

   **通过BallTree来获取锚点轨迹的正样本:

    dist, index = ball_tree.query([sample_train2D[i]], j+1) # k nearest neighbors
    p_index = list(index[0])
    p_index = p_index[-1]

    p_sample = train_node_list[p_index]   # positive sample
  
  **通过随机采样来获取锚点轨迹的负样本:

    n_index = random.randint(0,len(train_node_list)-1)
    n_sample = train_node_list[n_index]   # negative sample
    a_sample = train_node_list[i]                # 锚点轨迹

  4.net.train()

     在 PyTorch 中,net.train() 是用于将模型设置为训练模式的方法。当调用 net.train() 时,模型会进入训练模式,这意味着模型中的某些层(如 Dropout 和 Batch Normalization)将以训练模式运行,而不是推断模式。

   在训练模式下,这些层的行为可能会有所不同,例如,在 Dropout 层中,会随机丢弃一些神经元以进行正则化。总之,net.train() 的作用是告诉模型开始训练阶段,以便适当地调整网络中的某些层的行为。

  5. seq_lengths = list(map(len, traj_seqs))

    计算列表 traj_seqs 中每个元素的长度,并将这些长度存储在 seq_lengths 列表中。

  6. 对序列进行填充,使得所有序列的长度相等

    traj_one += [0]*(max(seq_lengths)-len(traj_one)):对于每个序列 traj_one,它会使用 [0] 来填充序列,直到其长度等于最长序列的长度。

四、想法与疑惑

dataload = dataload = data_utils.DataLoader().DataLoader()     # 数据预处理
posted @ 2025-05-07 14:33  才品  阅读(96)  评论(0)    收藏  举报