Train.py
一、处理流程
1. class STsim_Trainer
(1)初始化def __init__(self)
(2) 验证方法def ST_eval()---不使用,归并到训练方法中
(3) 训练方法def ST_train()--核心
2. def ST_train()
(1)创建编码器对象---net = STTrajSimEncoder()---来自model_network.py文件
(2) 数据加载并预处理---dataload = data_utils.DataLoader()---来自data_utils.py文件
(3) 构建三元组训练样本对---dataload.get_triplets()---来自data_utils.py文件
(4) 计算获取三元组训练样本对间地面真实距离---data_utils.triplet_groud_truth()---来自data_utils.py文件
(5) 创建优化器对象optimizer = torch.optim.Adam()
(6) 创建损失函数对象lossfunction = LossFun()
(7) 获取batch批次数 bt_num = int(dataload.return_triplets_num() / self.train_batch)【训练集20000/128=156】
(8) 加载node、d2vec的三元组数据---batch_l = data_utils.batch_list()---apn_node_triplets、apn_d2vec_triplets
------------------------------------------------------------------------------------------
for epoch in range(int(lastepoch), self.epochs): ---开始训练,设定训练轮数epoch=150次
net.train()---将模型设置为训练模式
for bt in range(bt_num): ---按批次依次执行
(1) 获取每个批次参与数据---a_node_batch, a_time_batch, p_node_batch, p_time_batch, n_node_batch, n_time_batch, batch_index = batch_l.getbatch_one()
锚点节点、锚点时间、正样本节点、正样本时间、负样本节点、负样本时间、批次开始索引
(2) 获取每个锚点、正样本、负样本嵌入学习表示 --- a_embedding/p_embedding / n_embedding = net(road_network, a_node_batch, a_time_batch)
(3) 计算损失函数值---loss = lossfunction(a_embedding, p_embedding, n_embedding, batch_index)
(4) 学习配置重置 --- optimizer.zero_grad()、loss.backward()、optimizer.step()
(5) 判断训练轮数是否为偶数次,若是,则将模型设置为验证模式,获取度量指标值 HR10、HR50、HR1050
二、重点函数解读
三、知识扩展
1. 相关导入数据包
import torch: PyTorch 是一个机器学习框架,提供了张量计算和构建神经网络的功能。 import torch.nn as nn: 导入 PyTorch 的神经网络模块,其中包含了神经网络的各种层和函数。 import torch.nn.functional as F: 导入 PyTorch 中的函数模块,其中包含了各种激活函数和损失函数等。 from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence: 从 PyTorch 的神经网络工具中导入了两个函数,用于处理可变长度序列的填充和打包操作。 from torch_geometric.nn import GCNConv: 导入 PyTorch Geometric 库中的 GCNConv 类,用于定义图卷积网络(Graph Convolutional Network,GCN)的层。 from Model import Date2VecConvert: 这是从名为 Model 的模块中导入了一个自定义的类 Date2VecConvert,这个类可能是作者自己定义的用于日期向量转换的模型。 import time: 这是导入 Python 的时间模块,用于处理时间相关的操作。 import datetime: 这是导入 Python 的日期时间模块,用于处理日期时间相关的操作。 import numpy as np: 这是导入 NumPy 库,并将其重命名为 np。NumPy 是 Python 中用于科学计算的核心库,它提供了多维数组对象和各种数学函数,是许多机器学习和数据处理任务的基础。 2. 文件扩展名 .pth
通常表示 PyTorch 模型的保存文件。在 PyTorch 中,.pth 文件通常包含了模型的权重参数以及其他必要的信息,可以通过加载这个文件来恢复模型的状态。
在提供的路径 /d2v_model/d2v_98291_17.169918439404636.pth 中,.pth 文件应该是一个保存了某个模型状态的文件。根据文件名的一般惯例,这个文件可能是一个名为 d2v_98291_17.169918439404636 的模型在某次训练后保存的结果,其中包含了模型的权重参数等信息。
如果你想要使用这个模型,可以通过 PyTorch 提供的加载模型的函数来加载这个 .pth 文件,然后就可以使用这个模型进行推理或者继续训练。加载模型的代码可能类似于:
import torch
# 创建一个与模型结构相同的实例
model = YourModelClass()
# 加载模型参数
model.load_state_dict(torch.load('/d2v_model/d2v_98291_17.169918439404636.pth'))
# 将模型设置为评估模式(不进行梯度计算)
model.eval()
3. BallTree
BallTree BallTree 是一种数据结构,用于高效地组织和检索多维空间中的数据点。它是一种基于树结构的数据结构,通常用于近似最近邻搜索(Approximate Nearest Neighbor Search)等问题。
在给定一组数据点后,BallTree 会将这些数据点递归地划分为球形区域(ball)并构建一颗树。每个节点代表一个球形区域,而叶子节点包含了实际的数据点。BallTree 使得在多维空间中搜索最近邻点变得高效,因为它可以通过递归地比较球形区域之间的距离来确定搜索路径,从而避免了对所有数据点进行线性搜索。
在你提供的代码中,BallTree 可能被用于构建一个 Ball 树,以便在训练数据集 sample_train2D 上执行最近邻搜索或其他空间查询任务。一旦 Ball 树被构建完成,你可以使用它来快速找到训练数据集中某个点的最近邻点,或者执行其他类似的空间查询操作。
总之,BallTree 是一个用于高效处理多维空间数据的数据结构,常用于近似最近邻搜索和空间查询等任务。【data_utils.py文件中,get_triplet函数,ball_tree = BallTree(sample_train2D)】
**通过BallTree来获取锚点轨迹的正样本:
dist, index = ball_tree.query([sample_train2D[i]], j+1) # k nearest neighbors
p_index = list(index[0])
p_index = p_index[-1]
p_sample = train_node_list[p_index] # positive sample
**通过随机采样来获取锚点轨迹的负样本:
n_index = random.randint(0,len(train_node_list)-1)
n_sample = train_node_list[n_index] # negative sample
a_sample = train_node_list[i] # 锚点轨迹
4.net.train()
在 PyTorch 中,net.train() 是用于将模型设置为训练模式的方法。当调用 net.train() 时,模型会进入训练模式,这意味着模型中的某些层(如 Dropout 和 Batch Normalization)将以训练模式运行,而不是推断模式。
在训练模式下,这些层的行为可能会有所不同,例如,在 Dropout 层中,会随机丢弃一些神经元以进行正则化。总之,net.train() 的作用是告诉模型开始训练阶段,以便适当地调整网络中的某些层的行为。
5. seq_lengths = list(map(len, traj_seqs))
计算列表 traj_seqs 中每个元素的长度,并将这些长度存储在 seq_lengths 列表中。
6. 对序列进行填充,使得所有序列的长度相等
traj_one += [0]*(max(seq_lengths)-len(traj_one)):对于每个序列 traj_one,它会使用 [0] 来填充序列,直到其长度等于最长序列的长度。
四、想法与疑惑
dataload = dataload = data_utils.DataLoader().DataLoader() # 数据预处理

浙公网安备 33010602011771号