# Model Migration Tutorial

This tutorial uses [STGODE](https://github.com/square-coder/STGODE/tree/main) as the example.

1. [Create the model](#create-the-model)
2. [Update the imports](#update-the-imports)
3. [Modify the outermost Module and create the config file](#modify-the-outermost-module-and-create-the-config-file)
4. [Passing in the graph](#passing-in-the-graph)
5. [Register your model in model_selector](#register-your-model-in-model_selector)
6. [Adjust the input and output shapes](#adjust-the-input-and-output-shapes)
7. [Start training](#start-training)

# Create the model

Decide on the model name `{model_name}`. All caps is recommended. Remember this name, because it is used again when registering the model.

In the `model` folder, create a folder named `{model_name}` to hold the model files.

STGODE has two model files, `model.py` and `odegcn.py`. Copy both to `TrafficWheel/model/STGODE`. Later steps import `ODEGCN` from `model/STGODE/STGODE.py`, so rename the copied `model.py` to `STGODE.py`.

# Update the imports

In the original repository, `model.py` imports the `ODEG` module from `odegcn.py`:

```python
from odegcn import ODEG
```

After migrating to TrafficWheel, the import has to point at our new folder:

```python
from model.STGODE.odegcn import ODEG
```

# Modify the outermost Module and create the config file

We only care about the outermost `nn.Module` class. For STGODE this is `class ODEGCN(nn.Module):`

```python
class ODEGCN(nn.Module):
    """ the overall network framework """
    def __init__(self, num_nodes, num_features, num_timesteps_input,
                 num_timesteps_output, A_sp_hat, A_se_hat):
        """
        Args:
            num_nodes : number of nodes in the graph
            num_features : number of features at each node in each time step
            num_timesteps_input : number of past time steps fed into the network
            num_timesteps_output : desired number of future time steps output by the network
            A_sp_hat : normalized adjacency spatial matrix
            A_se_hat : normalized adjacency semantic matrix
        """
        super(ODEGCN, self).__init__()
        # spatial graph
        self.sp_blocks = nn.ModuleList(
            [nn.Sequential(
                STGCNBlock(in_channels=num_features, out_channels=[64, 32, 64],
                           num_nodes=num_nodes, A_hat=A_sp_hat),
                STGCNBlock(in_channels=64, out_channels=[64, 32, 64],
                           num_nodes=num_nodes, A_hat=A_sp_hat)) for _ in range(3)
            ])
        # semantic graph
        self.se_blocks = nn.ModuleList([nn.Sequential(
                STGCNBlock(in_channels=num_features, out_channels=[64, 32, 64],
                           num_nodes=num_nodes, A_hat=A_se_hat),
                STGCNBlock(in_channels=64, out_channels=[64, 32, 64],
                           num_nodes=num_nodes, A_hat=A_se_hat)) for _ in range(3)
            ])

        self.pred = nn.Sequential(
            nn.Linear(num_timesteps_input * 64, num_timesteps_output * 32),
            nn.ReLU(),
            nn.Linear(num_timesteps_output * 32, num_timesteps_output)
        )
```

Start with `__init__`, that is, the model definition, and look at what it takes. Here it takes quite a few arguments:

```python
def __init__(self, num_nodes, num_features, num_timesteps_input,
             num_timesteps_output, A_sp_hat, A_se_hat):
```

Fortunately the author documented them. TrafficWheel passes a single dictionary, `args`, when it constructs a model, and every parameter is read from that dictionary.

Change the signature of `__init__` to take `args`, and read each parameter with `args['key']`, like this:

```python
def __init__(self, args):
    super(ODEGCN, self).__init__()
    num_nodes = args['num_nodes']
    num_features = args['num_features']
    num_timesteps_input = args['history']
    num_timesteps_output = args['horizon']
    A_sp_hat, A_se_hat = get_A_hat(args)  # explained in the next section
```

`args` is loaded from the config file `config/STGODE/PEMSDX.yaml`. The easiest way to write this file is to create a folder in `config` named after the model (`STGODE`), then copy the parameters from the yaml file of a model that already exists in `config`. Here, for example, `PEMSD4.yaml` was copied from STGNCDE into `config/STGODE`. The file name must match the dataset name, which can be any of PEMSD{3,4,7,8}.

The `model` section of the STGNCDE config looks like this:

```yaml
model:
  type: type1
  g_type: agc
  input_dim: 2
  output_dim: 1
  embed_dim: 10
  hid_dim: 128
  hid_hid_dim: 128
  num_layers: 2
  cheb_k: 2
  solver: rk4
```

Keep the `input_dim` and `output_dim` entries. `num_nodes` is copied automatically from the `data` section into `model`, so you do not need to write it.

Delete the remaining parameters, then add a value for every key that `__init__` now reads:

```yaml
model:
  input_dim: 1
  output_dim: 1
  history: 12
  horizon: 12
  num_features: 1
```

If a submodule later needs new parameters, add them to the config, pass `args` down to the submodule, and read them from `args` in its `__init__`.

YAML supports integers, booleans (`{True, False}`), strings, and lists such as `[[64,64,64],[64,64,64],[64,64,64]]`. Values are converted to the matching Python type when they are read from `args`.

# Passing in the graph

One part of the `__init__` change is tricky: the original code receives two matrices as tensors, and yaml cannot express tensors (writing a whole matrix into the config is not practical anyway).

We have to write our own function to build the matrices. Create a new file in `model/STGODE` (named `adj.py` here) and define a function `get_A_hat` that takes `args` and returns the two matrices.

Import this function in `STGODE.py` and call it:

```py
from model.STGODE.adj import get_A_hat
```

```python
def __init__(self, args):
    ...  # omitted
    A_sp_hat, A_se_hat = get_A_hat(args)
```

Now look at how the original code builds the matrices and passes them to `__init__`.

Lines 77-87 of `run_stode.py` in the original repository:

```python
data, mean, std, dtw_matrix, sp_matrix = read_data(args)
train_loader, valid_loader, test_loader = generate_dataset(data, args)
A_sp_wave = get_normalized_adj(sp_matrix).to(device)
A_se_wave = get_normalized_adj(dtw_matrix).to(device)

net = ODEGCN(num_nodes=data.shape[1],
             num_features=data.shape[2],
             num_timesteps_input=args.his_length,
             num_timesteps_output=args.pred_length,
             A_sp_hat=A_sp_wave,
             A_se_hat=A_se_wave)
```

The code first calls `read_data(args)`, which returns the two matrices `dtw_matrix` and `sp_matrix` among other values, and then calls `get_normalized_adj(matrix)` to normalize each of them.

Both `read_data` and `get_normalized_adj` are imported from `utils.py`, so that is where to find them:

```python
from utils import generate_dataset, read_data, get_normalized_adj
```

Copy these two functions into `adj.py` in TrafficWheel:

```python
files = {
    'pems03': ['PEMS03/pems03.npz', 'PEMS03/distance.csv'],
    'pems04': ['PEMS04/pems04.npz', 'PEMS04/distance.csv'],
    'pems07': ['PEMS07/pems07.npz', 'PEMS07/distance.csv'],
    'pems08': ['PEMS08/pems08.npz', 'PEMS08/distance.csv'],
    'pemsbay': ['PEMSBAY/pems_bay.npz', 'PEMSBAY/distance.csv'],
    'pemsD7M': ['PeMSD7M/PeMSD7M.npz', 'PeMSD7M/distance.csv'],
    'pemsD7L': ['PeMSD7L/PeMSD7L.npz', 'PeMSD7L/distance.csv']
}

def read_data(args):
    """read data, generate spatial adjacency matrix and semantic adjacency matrix by dtw

    Args:
        sigma1: float, default=0.1, sigma for the semantic matrix
        sigma2: float, default=10, sigma for the spatial matrix
        thres1: float, default=0.6, the threshold for the semantic matrix
        thres2: float, default=0.5, the threshold for the spatial matrix

    Returns:
        data: tensor, T * N * 1
        dtw_matrix: array, semantic adjacency matrix
        sp_matrix: array, spatial adjacency matrix
    """
    filename = args.filename
    file = files[filename]
    filepath = "./data/"
    if args.remote:
        filepath = '/home/lantu.lqq/ftemp/data/'
    data = np.load(filepath + file[0])['data']

    num_node = data.shape[1]
    mean_value = np.mean(data, axis=(0, 1)).reshape(1, 1, -1)
    std_value = np.std(data, axis=(0, 1)).reshape(1, 1, -1)
    data = (data - mean_value) / std_value
    mean_value = mean_value.reshape(-1)[0]
    std_value = std_value.reshape(-1)[0]

    if not os.path.exists(f'data/{filename}_dtw_distance.npy'):
        data_mean = np.mean([data[:, :, 0][24*12*i: 24*12*(i+1)] for i in range(data.shape[0]//(24*12))], axis=0)
        data_mean = data_mean.squeeze().T
        dtw_distance = np.zeros((num_node, num_node))
        for i in tqdm(range(num_node)):
            for j in range(i, num_node):
                dtw_distance[i][j] = fastdtw(data_mean[i], data_mean[j], radius=6)[0]
        for i in range(num_node):
            for j in range(i):
                dtw_distance[i][j] = dtw_distance[j][i]
        np.save(f'data/{filename}_dtw_distance.npy', dtw_distance)

    dist_matrix = np.load(f'data/{filename}_dtw_distance.npy')

    mean = np.mean(dist_matrix)
    std = np.std(dist_matrix)
    dist_matrix = (dist_matrix - mean) / std
    sigma = args.sigma1
    dist_matrix = np.exp(-dist_matrix ** 2 / sigma ** 2)
    dtw_matrix = np.zeros_like(dist_matrix)
    dtw_matrix[dist_matrix > args.thres1] = 1

    if not os.path.exists(f'data/{filename}_spatial_distance.npy'):
        with open(filepath + file[1], 'r') as fp:
            dist_matrix = np.zeros((num_node, num_node)) + np.float('inf')
            file = csv.reader(fp)
            for line in file:
                break
            for line in file:
                start = int(line[0])
                end = int(line[1])
                dist_matrix[start][end] = float(line[2])
                dist_matrix[end][start] = float(line[2])
            np.save(f'data/{filename}_spatial_distance.npy', dist_matrix)

    dist_matrix = np.load(f'data/{filename}_spatial_distance.npy')
    # normalization
    std = np.std(dist_matrix[dist_matrix != np.float('inf')])
    mean = np.mean(dist_matrix[dist_matrix != np.float('inf')])
    dist_matrix = (dist_matrix - mean) / std
    sigma = args.sigma2
    sp_matrix = np.exp(- dist_matrix**2 / sigma**2)
    sp_matrix[sp_matrix < args.thres2] = 0
    # np.save(f'data/{filename}_sp_c_matrix.npy', sp_matrix)
    # sp_matrix = np.load(f'data/{filename}_sp_c_matrix.npy')

    print(f'average degree of spatial graph is {np.sum(sp_matrix > 0)/2/num_node}')
    print(f'average degree of semantic graph is {np.sum(dtw_matrix > 0)/2/num_node}')
    return torch.from_numpy(data.astype(np.float32)), mean_value, std_value, dtw_matrix, sp_matrix
```

The function is long, so trim what we do not need. In the end we only return the two matrices `dtw_matrix` and `sp_matrix`, so remove the code that exists only to produce the other return values (`data`, `mean_value`, `std_value`), together with the `args.remote` branch and the `print` calls. The normalization of `data` stays, because the DTW distances are computed from it. The trimmed version also replaces `np.float('inf')` with the built-in `float('inf')`, because the `np.float` alias was removed in NumPy 1.24.

```python
def read_data(args):
    """read data, generate spatial adjacency matrix and semantic adjacency matrix by dtw

    Args:
        sigma1: float, default=0.1, sigma for the semantic matrix
        sigma2: float, default=10, sigma for the spatial matrix
        thres1: float, default=0.6, the threshold for the semantic matrix
        thres2: float, default=0.5, the threshold for the spatial matrix

    Returns:
        dtw_matrix: array, semantic adjacency matrix
        sp_matrix: array, spatial adjacency matrix
    """
    files = {
        'pems03': ['PEMS03/pems03.npz', 'PEMS03/distance.csv'],
        'pems04': ['PEMS04/pems04.npz', 'PEMS04/distance.csv'],
        'pems07': ['PEMS07/pems07.npz', 'PEMS07/distance.csv'],
        'pems08': ['PEMS08/pems08.npz', 'PEMS08/distance.csv'],
        'pemsbay': ['PEMSBAY/pems_bay.npz', 'PEMSBAY/distance.csv'],
        'pemsD7M': ['PeMSD7M/PeMSD7M.npz', 'PeMSD7M/distance.csv'],
        'pemsD7L': ['PeMSD7L/PeMSD7L.npz', 'PeMSD7L/distance.csv']
    }

    filename = args.filename
    file = files[filename]
    filepath = "./data/"
    data = np.load(filepath + file[0])['data']
    num_node = data.shape[1]
    mean_value = np.mean(data, axis=(0, 1)).reshape(1, 1, -1)
    std_value = np.std(data, axis=(0, 1)).reshape(1, 1, -1)
    data = (data - mean_value) / std_value

    if not os.path.exists(f'data/{filename}_dtw_distance.npy'):
        data_mean = np.mean([data[:, :, 0][24 * 12 * i: 24 * 12 * (i + 1)] for i in range(data.shape[0] // (24 * 12))], axis=0)
        data_mean = data_mean.squeeze().T
        dtw_distance = np.zeros((num_node, num_node))
        for i in tqdm(range(num_node)):
            for j in range(i, num_node):
                dtw_distance[i][j] = fastdtw(data_mean[i], data_mean[j], radius=6)[0]
        for i in range(num_node):
            for j in range(i):
                dtw_distance[i][j] = dtw_distance[j][i]
        np.save(f'data/{filename}_dtw_distance.npy', dtw_distance)

    dist_matrix = np.load(f'data/{filename}_dtw_distance.npy')

    mean = np.mean(dist_matrix)
    std = np.std(dist_matrix)
    dist_matrix = (dist_matrix - mean) / std
    sigma = args.sigma1
    dist_matrix = np.exp(-dist_matrix ** 2 / sigma ** 2)
    dtw_matrix = np.zeros_like(dist_matrix)
    dtw_matrix[dist_matrix > args.thres1] = 1

    if not os.path.exists(f'data/{filename}_spatial_distance.npy'):
        with open(filepath + file[1], 'r') as fp:
            dist_matrix = np.zeros((num_node, num_node)) + float('inf')
            file = csv.reader(fp)
            for line in file:
                break
            for line in file:
                start = int(line[0])
                end = int(line[1])
                dist_matrix[start][end] = float(line[2])
                dist_matrix[end][start] = float(line[2])
            np.save(f'data/{filename}_spatial_distance.npy', dist_matrix)

    dist_matrix = np.load(f'data/{filename}_spatial_distance.npy')
    # normalization
    std = np.std(dist_matrix[dist_matrix != float('inf')])
    mean = np.mean(dist_matrix[dist_matrix != float('inf')])
    dist_matrix = (dist_matrix - mean) / std
    sigma = args.sigma2
    sp_matrix = np.exp(- dist_matrix ** 2 / sigma ** 2)
    sp_matrix[sp_matrix < args.thres2] = 0

    return dtw_matrix, sp_matrix
```

The original function also receives `args`, but reads parameters as attributes, for example `args.sigma1`. TrafficWheel passes a dictionary, so every access of the form `args.sigma1` must become `args['sigma1']`, and the corresponding parameters must be added to the config.

```python
sigma = args['sigma1']
dtw_matrix[dist_matrix > args['thres1']] = 1
sigma = args['sigma2']
sp_matrix[sp_matrix < args['thres2']] = 0
```

One thing to watch out for: in the PEMSD3 distance file, `from` and `to` are sensor IDs, not 0-based indices. They have to be mapped to the indices 0 to 357 using the sensor-ID list file in `data` (`PEMSD3.txt`; the code below reads it from `data/PEMS03/PEMS03.txt`).

```csv
from,to,distance
317842,318711,0.872
318721,315955,1.322
...
```

```py
# Compute spatial_distance unless a cached copy already exists
if not os.path.exists(f'data/PEMS0{filename[-1]}/PEMS0{filename[-1]}_spatial_distance.npy'):
    if num_node == 358:
        # PEMSD3: map each sensor ID to its 0-based index
        with open(f'data/PEMS0{filename[-1]}/PEMS0{filename[-1]}.txt', 'r') as f:
            id_dict = {int(i): idx for idx, i in enumerate(f.read().strip().split('\n'))}
        # Read the CSV file with pandas, skipping the header row
        df = pd.read_csv(filepath + file[1], skiprows=1, header=None)
        dist_matrix = np.zeros((num_node, num_node)) + float('inf')
        for _, row in df.iterrows():
            start = int(id_dict[row[0]])
            end = int(id_dict[row[1]])
            dist_matrix[start][end] = float(row[2])
            dist_matrix[end][start] = float(row[2])
        np.save(f'data/PEMS0{filename[-1]}/PEMS0{filename[-1]}_spatial_distance.npy', dist_matrix)
    else:
        # Read the CSV file with pandas, skipping the header row
        df = pd.read_csv(filepath + file[1], skiprows=1, header=None)
        dist_matrix = np.zeros((num_node, num_node)) + float('inf')
        for _, row in df.iterrows():
            start = int(row[0])
            end = int(row[1])
            dist_matrix[start][end] = float(row[2])
            dist_matrix[end][start] = float(row[2])
        np.save(f'data/PEMS0{filename[-1]}/PEMS0{filename[-1]}_spatial_distance.npy', dist_matrix)
```

With `read_data()` done, `get_normalized_adj()` is simple: it takes an ndarray and returns a tensor, so it can be copied as is.

```py
def get_normalized_adj(A):
    """
    Returns a tensor, the degree normalized adjacency matrix.
    """
    alpha = 0.8
    D = np.array(np.sum(A, axis=1)).reshape((-1,))
    D[D <= 10e-5] = 10e-5    # Prevent infs
    diag = np.reciprocal(np.sqrt(D))
    A_wave = np.multiply(np.multiply(diag.reshape((-1, 1)), A), diag.reshape((1, -1)))
    A_reg = alpha / 2 * (np.eye(A.shape[0]) + A_wave)
    return torch.from_numpy(A_reg.astype(np.float32))
```

Finally, the original code calls `get_normalized_adj()` outside `read_data()` to obtain the normalized tensors, which is cumbersome. To keep the model code simple, we apply `get_normalized_adj(matrix).to(device)` directly in the `return` statement of `read_data`. That is, the original

```py
data, mean, std, dtw_matrix, sp_matrix = read_data(args)
train_loader, valid_loader, test_loader = generate_dataset(data, args)
A_sp_wave = get_normalized_adj(sp_matrix).to(device)
A_se_wave = get_normalized_adj(dtw_matrix).to(device)
```

becomes

```py
# Spatial matrix first, to match the caller: A_sp_hat, A_se_hat = get_A_hat(args)
return get_normalized_adj(sp_matrix).to(args['device']), get_normalized_adj(dtw_matrix).to(args['device'])
```

The spatial matrix is returned first so that the order matches how `__init__` unpacks the result (`A_sp_hat, A_se_hat`). `args['device']` does not need to be written in the config; it is set by the run command.

The final `get_A_hat()` function in `model/STGODE/adj.py` looks like this:

```python
import os

import numpy as np
import pandas as pd
import torch
from fastdtw import fastdtw
from tqdm import tqdm


def get_A_hat(args):
    """read data, generate spatial adjacency matrix and semantic adjacency matrix by dtw

    Args (read from the args dict):
        sigma1: float, default=0.1, sigma for the semantic matrix
        sigma2: float, default=10, sigma for the spatial matrix
        thres1: float, default=0.6, the threshold for the semantic matrix
        thres2: float, default=0.5, the threshold for the spatial matrix

    Returns:
        two tensors on args['device']: the degree-normalized spatial and
        semantic adjacency matrices
    """
    filepath = './data/'
    num_node = args['num_nodes']
    # NOTE: `files` must be keyed by node count here (e.g. 358 for PEMS03,
    # 307 for PEMS04); the dict shown earlier is keyed by dataset name.
    file = files[num_node]
    filename = file[0][:6]  # e.g. 'PEMS03'

    data = np.load(filepath + file[0])['data']
    num_node = data.shape[1]
    mean_value = np.mean(data, axis=(0, 1)).reshape(1, 1, -1)
    std_value = np.std(data, axis=(0, 1)).reshape(1, 1, -1)
    data = (data - mean_value) / std_value

    # Compute dtw_distance unless a cached copy already exists
    if not os.path.exists(f'data/PEMS0{filename[-1]}/PEMS0{filename[-1]}_dtw_distance.npy'):
        data_mean = np.mean([data[:, :, 0][24 * 12 * i: 24 * 12 * (i + 1)] for i in range(data.shape[0] // (24 * 12))], axis=0)
        data_mean = data_mean.squeeze().T
        dtw_distance = np.zeros((num_node, num_node))
        for i in tqdm(range(num_node)):
            for j in range(i, num_node):
                dtw_distance[i][j] = fastdtw(data_mean[i], data_mean[j], radius=6)[0]
        for i in range(num_node):
            for j in range(i):
                dtw_distance[i][j] = dtw_distance[j][i]
        np.save(f'data/PEMS0{filename[-1]}/PEMS0{filename[-1]}_dtw_distance.npy', dtw_distance)

    dist_matrix = np.load(f'data/PEMS0{filename[-1]}/PEMS0{filename[-1]}_dtw_distance.npy')

    mean = np.mean(dist_matrix)
    std = np.std(dist_matrix)
    dist_matrix = (dist_matrix - mean) / std
    sigma = args['sigma1']
    dist_matrix = np.exp(-dist_matrix ** 2 / sigma ** 2)
    dtw_matrix = np.zeros_like(dist_matrix)
    dtw_matrix[dist_matrix > args['thres1']] = 1

    # Compute spatial_distance unless a cached copy already exists
    if not os.path.exists(f'data/PEMS0{filename[-1]}/PEMS0{filename[-1]}_spatial_distance.npy'):
        if num_node == 358:
            # PEMSD3: map each sensor ID to its 0-based index
            with open(f'data/PEMS0{filename[-1]}/PEMS0{filename[-1]}.txt', 'r') as f:
                id_dict = {int(i): idx for idx, i in enumerate(f.read().strip().split('\n'))}
            # Read the CSV file with pandas, skipping the header row
            df = pd.read_csv(filepath + file[1], skiprows=1, header=None)
            dist_matrix = np.zeros((num_node, num_node)) + float('inf')
            for _, row in df.iterrows():
                start = int(id_dict[row[0]])
                end = int(id_dict[row[1]])
                dist_matrix[start][end] = float(row[2])
                dist_matrix[end][start] = float(row[2])
            np.save(f'data/PEMS0{filename[-1]}/PEMS0{filename[-1]}_spatial_distance.npy', dist_matrix)
        else:
            # Read the CSV file with pandas, skipping the header row
            df = pd.read_csv(filepath + file[1], skiprows=1, header=None)
            dist_matrix = np.zeros((num_node, num_node)) + float('inf')
            for _, row in df.iterrows():
                start = int(row[0])
                end = int(row[1])
                dist_matrix[start][end] = float(row[2])
                dist_matrix[end][start] = float(row[2])
            np.save(f'data/PEMS0{filename[-1]}/PEMS0{filename[-1]}_spatial_distance.npy', dist_matrix)

    # Always load from the cache; otherwise dist_matrix would still hold the
    # DTW matrix from above whenever the cache file already existed.
    dist_matrix = np.load(f'data/PEMS0{filename[-1]}/PEMS0{filename[-1]}_spatial_distance.npy')
    # normalization
    std = np.std(dist_matrix[dist_matrix != float('inf')])
    mean = np.mean(dist_matrix[dist_matrix != float('inf')])
    dist_matrix = (dist_matrix - mean) / std
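    sigma = args['sigma2']
    sp_matrix = np.exp(- dist_matrix ** 2 / sigma ** 2)
    sp_matrix[sp_matrix < args['thres2']] = 0

    # Spatial matrix first, to match the caller: A_sp_hat, A_se_hat = get_A_hat(args)
    return get_normalized_adj(sp_matrix).to(args['device']), get_normalized_adj(dtw_matrix).to(args['device'])
```

Call it in the model's `__init__`; it returns the two tensors.

```py
A_sp_hat, A_se_hat = get_A_hat(args)
```

The inner `nn.Module` classes usually need no changes, because the parameters required to initialize them are already defined in the outermost `nn.Module`. The exception is an inner `nn.Module` that needs an additional graph.

# Register your model in model_selector

At this point, add your model to `model/model_selector.py`.

Import the outermost `nn.Module` of your model, then add a case for it following the pattern of the existing models:

```python
from model.DDGCRN.DDGCRN import DDGCRN
from model.TWDGCN.TWDGCN import TWDGCN
# ...
from model.STGODE.STGODE import ODEGCN  # newly added

def model_selector(model):
    match model['type']:
        case 'DDGCRN':
            return DDGCRN(model)
        case 'TWDGCN':
            return TWDGCN(model)
        # ...
        case 'STGODE':
            return ODEGCN(model)  # newly added
```

The `case` value is the `{model_name}` chosen at the start of this tutorial.

# Adjust the input and output shapes

Once the model is registered in `model/model_selector.py`, you can start debugging. First, create a run configuration in PyCharm with the following parameters:

```bash
python run.py --model STGODE --dataset PEMSD4 --mode train --device cuda:0
```

Then set a breakpoint in your `forward`. If the program raises an error before reaching it, `__init__` is still misconfigured and the model could not be initialized; go back and check it.

Once initialization succeeds, you can start modifying `forward`.

TrafficWheel passes the model a tensor of shape **`[batch_size, time_step, num_nodes, dim]`**. Typically `batch_size` is 64, `time_step` is 12, `num_nodes` depends on the dataset, and the input dimension `dim` is 3: the first channel is the traffic flow, the second is the day embedding, and the third is the week embedding.

Check what tensor shape the original model's `forward` expects. In most cases the day and week embeddings are not needed.

STGODE expects `(batch_size, num_nodes, num_timesteps, num_features)`, which differs from what TrafficWheel passes in, so the input has to be converted.

Taking PEMSD4 as the example, first slice with `x = x[..., 0:1]`, which changes the shape of `x` from (64,12,307,3) to (64,12,307,1), keeping only the traffic-flow channel.

Then `x = x.permute(0, 2, 1, 3)` converts `[batch_size, time_step, num_nodes, dim]` to `(batch_size, num_nodes, num_timesteps, num_features)`. The input dimension `dim` and the feature count `num_features` mean the same thing.

```py
def forward(self, x):
    """
    Args:
        x : input data of shape (batch_size, num_nodes, num_timesteps, num_features) == (B, N, T, F)
    Returns:
        prediction for future of shape (batch_size, num_nodes, num_timesteps_output)
    """
    x = x[..., 0:1].permute(0, 2, 1, 3)
```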
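The two-step input conversion can be checked on paper before opening the debugger. The sketch below runs without PyTorch: `slice_last` and `permute` are hypothetical helpers written only for this illustration (they are not part of TrafficWheel or PyTorch) that track how a shape tuple changes, assuming the PEMSD4 batch shape quoted above.

```python
def slice_last(shape, start, stop):
    """Shape after x[..., start:stop]: only the last dimension changes."""
    kept = len(range(*slice(start, stop).indices(shape[-1])))
    return shape[:-1] + (kept,)


def permute(shape, *dims):
    """Shape after x.permute(*dims): output axis i takes input axis dims[i]."""
    if sorted(dims) != list(range(len(shape))):
        raise ValueError(f"{dims} is not a permutation of the axes of {shape}")
    return tuple(shape[d] for d in dims)


# TrafficWheel input for PEMSD4: [batch_size, time_step, num_nodes, dim]
wheel_input = (64, 12, 307, 3)

flow_only = slice_last(wheel_input, 0, 1)      # keep only the traffic-flow channel
stgode_input = permute(flow_only, 0, 2, 1, 3)  # swap time_step and num_nodes

print(flow_only)     # shape after x[..., 0:1]
print(stgode_input)  # shape after .permute(0, 2, 1, 3)
```

The second printed tuple is the `(batch_size, num_nodes, num_timesteps, num_features)` layout that STGODE's `forward` documents.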
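Finally, fix the output. Set a breakpoint on the final `return` of `forward`. If the input is correct you can start debugging right away, and with luck the program reaches the breakpoint on the first try. Otherwise you may run into various problems (usually issues in the original repository). In that case, step through the `forward` of each submodule and look for mismatched tensor shapes, which is where the problem usually lies. Fix these bugs until the program reaches the breakpoint at `return`.

```python
def forward(self, x):
    """
    Args:
        x : input data of shape (batch_size, num_nodes, num_timesteps, num_features) == (B, N, T, F)
    Returns:
        prediction for future of shape (batch_size, num_nodes, num_timesteps_output)
    """
    ...
    return self.pred(x)
```

The `return` here wraps the result in a call to `pred`. To inspect the output, you can set the breakpoint in the outer code instead (usually line 55 of `trainer/Trainer.py`, watching the `output` variable). In this case, however, the original docstring already states the output shape, so reading the docstring is enough.

The output shape is (batch_size, num_nodes, num_timesteps_output), while **TrafficWheel expects the output shape `[batch_size, time_step, num_nodes, dim]`**, so it has to be converted.

For PEMSD4 the original shape is (64,307,12). First `permute(0,2,1)` turns it into (64,12,307), then `unsqueeze(dim=-1)` turns it into (64,12,307,1), which matches the expected output shape.

```py
return self.pred(x).permute(0,2,1).unsqueeze(dim=-1)
```

Common techniques for adjusting shapes are listed below. Note that these operations return a new tensor rather than modifying `x` in place, so the result must be assigned: `x=x.permute(0,2,1,3)` is correct, while a bare `x.permute(0,2,1,3)` has no effect.

| Operation         | Shape before  | Shape after                                        |
| ----------------- | ------------- | -------------------------------------------------- |
| permute(0,2,1,3)  | (64,12,307,3) | (64,307,12,3)                                      |
| x[..., 0:1]       | (64,12,307,3) | (64,12,307,1)                                      |
| x[..., 0]         | (64,12,307,3) | (64,12,307), takes index 0 of the last dimension   |
| squeeze(dim=-1)   | (64,12,307,1) | (64,12,307)                                        |
| unsqueeze(dim=-1) | (64,12,307)   | (64,12,307,1)                                      |

These are the most commonly used ones; see the official PyTorch documentation for more.

# Start training

Once everything is in place (the input and output shapes of `forward` match), you can train the model. You can batch-train from a Jupyter notebook or configure a task in PyCharm, whichever you prefer. The training command is:

```bash
python run.py --model STGODE --dataset PEMSD4 --mode train --device cuda:0
```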
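For batch training, the command above can be generated once per dataset. The following is a minimal sketch, assuming that a config file exists under `config/STGODE` for every dataset listed and that the commands are run from the TrafficWheel root. `DATASETS` and `build_train_command` are illustrative names, not part of TrafficWheel, and the script only prints the commands.

```python
import shlex

# Dataset names as used for the config files (PEMSD{3,4,7,8}).
DATASETS = ["PEMSD3", "PEMSD4", "PEMSD7", "PEMSD8"]


def build_train_command(model, dataset, device="cuda:0"):
    """Return the argument list for one training run, using the flags shown above."""
    return [
        "python", "run.py",
        "--model", model,
        "--dataset", dataset,
        "--mode", "train",
        "--device", device,
    ]


commands = [build_train_command("STGODE", name) for name in DATASETS]
for cmd in commands:
    # Print each command as a single shell line.
    print(shlex.join(cmd))
```

To launch the runs instead of printing them, each `cmd` list can be passed to `subprocess.run(cmd, check=True)`, for example from a notebook cell.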