From 9181413d0929f12671f16a46f2a8cb74180b4f0c Mon Sep 17 00:00:00 2001 From: HengZhang Date: Mon, 3 Mar 2025 15:22:05 +0800 Subject: [PATCH] Update transfer_guide.md --- transfer_guide.md | 658 ++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 658 insertions(+) diff --git a/transfer_guide.md b/transfer_guide.md index e69de29..be7b44c 100644 --- a/transfer_guide.md +++ b/transfer_guide.md @@ -0,0 +1,658 @@ +# 模型迁移教程 + +这里以[STGODE](https://github.com/square-coder/STGODE/tree/main)为例。 + + +1. [创建模型](#创建模型) +2. [修改 import](#修改import) +3. [修改最外层 Module 与创建配置文件](#修改最外层Module与创建配置文件) +4. [图的传入](#图的传入) +5. [在 model_selector 中添加自己的模型](#在model_selector中添加自己的模型) +6. [使用 debug 调试模型,修改 input x, output y 的 shape](#使用debug调试模型,修改input x, output y的shape) +7. [开始训练模型](#开始训练模型) + + + +# 创建模型 + +确定模型的名称 `{model_name}`。推荐全大写,记住这个名称。 + +在model文件夹下创建模型文件夹,以 `{model_name}` 命名。这里存放模型文件。 + +STGODE的模型文件一共有两个,`model.py, odegcn.py`,把他拷贝到`TrafficWheel/model/STGODE`下 + + + +# 修改import + +原仓库中,model.py文件import了odegcn.py文件下的ODEG模块。 + +```python +from odegcn import ODEG +``` + +迁移到TrafficWheel中,我们需要将odegcn指向我们的文件夹。 + +```python +from model.STGODE.odegcn import ODEG +``` + + + +# 修改最外层Module与创建配置文件 + +我们只关注最外层的`nn.Module`类,这里STGODE最外层的`nn.Module`类是`class ODEGCN(nn.Module):` + +```python +class ODEGCN(nn.Module): + """ the overall network framework """ + def __init__(self, num_nodes, num_features, num_timesteps_input, + num_timesteps_output, A_sp_hat, A_se_hat): + """ + Args: + num_nodes : number of nodes in the graph + num_features : number of features at each node in each time step + num_timesteps_input : number of past time steps fed into the network + num_timesteps_output : desired number of future time steps output by the network + A_sp_hat : nomarlized adjacency spatial matrix + A_se_hat : nomarlized adjacency semantic matrix + """ + + super(ODEGCN, self).__init__() + # spatial graph + self.sp_blocks = nn.ModuleList( + [nn.Sequential( + STGCNBlock(in_channels=num_features, out_channels=[64, 32, 64], + num_nodes=num_nodes, A_hat=A_sp_hat), + STGCNBlock(in_channels=64, out_channels=[64, 32, 64], + num_nodes=num_nodes, A_hat=A_sp_hat)) for _ in range(3) + ]) + # semantic graph + self.se_blocks = nn.ModuleList([nn.Sequential( + STGCNBlock(in_channels=num_features, out_channels=[64, 32, 64], + num_nodes=num_nodes, A_hat=A_se_hat), + STGCNBlock(in_channels=64, out_channels=[64, 32, 64], + num_nodes=num_nodes, A_hat=A_se_hat)) for _ in range(3) + ]) + + self.pred = nn.Sequential( + nn.Linear(num_timesteps_input * 64, num_timesteps_output * 32), + nn.ReLU(), + nn.Linear(num_timesteps_output * 32, num_timesteps_output) + ) +``` + +我们先关注init函数,即模型定义。这里,我们要看init传入了什么东西,这里传入了一堆: + +```python +def __init__(self, num_nodes, num_features, num_timesteps_input, + num_timesteps_output, A_sp_hat, A_se_hat): +``` + +好在作者给了注释。我们的TrafficWheel在定义模型时只传入一个字典args,所有的参数都在args中提取。 + +我们把init的传参修改为args字典,并使用args['key']从字典中访问参数,像这样 + +```python + def __init__(self, args): + super(ODEGCN, self).__init__() + num_nodes = args['num_nodes'] + num_features = args['num_features'] + num_timesteps_input = args['history'] + num_timesteps_output = args['horizon'] + A_sp_hat, A_se_hat = get_A_hat(args) # 这个之后讲 +``` + +这里的args读取了`config/STGODE/PEMSDX.yaml`的配置文件。写配置文件最简单的方法是,先在config目录中创建文件夹,与模型名一致为STGODE。再从config中已有模型的yaml文件中复制参数。例如,这里从STGNCDE中复制了一份PEMSD4.yaml的配置文件,放在`config/STGODE`下。文件名需要与数据集名一致,可以是PEMSD{3,4,7,8}中的任一种。 + +STGNCDE在model部分的配置如下: + +```yaml +model: + type: type1 + g_type: agc + input_dim: 2 + output_dim: 1 + embed_dim: 10 + hid_dim: 128 + hid_hid_dim: 128 + num_layers: 2 + cheb_k: 2 + solver: rk4 +``` + +你需要保留`input_dim`和`output_dim`这两个属性,`num_nodes`这个属性会从data中自动拷贝一份到model,不用写。 + +删掉其余参数,然后将刚刚init中访问的key的value写到yaml文件里,如下: + +``` +model: + input_dim: 1 + output_dim: 1 + history: 12 + horizon: 12 + num_features: 1 +``` + +后续如果在子模块中有新的参数,就在config后续补充,并把args传给子模块,在子模块的init中使用args访问参数。 + +yaml支持int值,bool值`{True, False}`,字符串str,列表`[[64,64,64],[64,64,64],[64,64,64]]`,在访问args时会自动转换。 + + + +# 图的传入 + +在修改init时,有一个比较棘手的地方,原版代码中传入了两个矩阵,是tensor类,而yaml不支持tensor(也不可能把矩阵写入tensor)。 + +我们需要自己构造函数来生成矩阵,在`model/STGODE`中创建一个新文件,我命名为`adj.py`。在这里构造一个`get_A_hat`的方法,传参是args。返回所需要的两个矩阵。 + +在`STGODE.py`里面import这个方法,并调用。 + +```py +from model.STGODE.adj import get_A_hat +``` + +```python + def __init__(self, args): + ...略 + A_sp_hat, A_se_hat = get_A_hat(args) +``` + +这里我们需要看原版代码中矩阵是怎么生成并传入init的。 + +原版代码的`run_stode.py`里面的77~87行 + +```python +data, mean, std, dtw_matrix, sp_matrix = read_data(args) +train_loader, valid_loader, test_loader = generate_dataset(data, args) +A_sp_wave = get_normalized_adj(sp_matrix).to(device) +A_se_wave = get_normalized_adj(dtw_matrix).to(device) + +net = ODEGCN(num_nodes=data.shape[1], + num_features=data.shape[2], + num_timesteps_input=args.his_length, + num_timesteps_output=args.pred_length, + A_sp_hat=A_sp_wave, + A_se_hat=A_se_wave) +``` + +这里先用`read_data(args)`读取数据,返回两个矩阵`dtw_matrix, sp_matrix`,再调用`get_normalized_adj(matrix)`标准化两个矩阵。 + +我们发现`read_data`和`get_normalized_adj`这两个函数都在`utils.py`里面,所以在utils里面找这两个函数 + +```python +from utils import generate_dataset, read_data, get_normalized_adj +``` + + + +我们将这两个函数复制到TrafficWheel中的adj.py中 + +```python +files = { + 'pems03': ['PEMS03/pems03.npz', 'PEMS03/distance.csv'], + 'pems04': ['PEMS04/pems04.npz', 'PEMS04/distance.csv'], + 'pems07': ['PEMS07/pems07.npz', 'PEMS07/distance.csv'], + 'pems08': ['PEMS08/pems08.npz', 'PEMS08/distance.csv'], + 'pemsbay': ['PEMSBAY/pems_bay.npz', 'PEMSBAY/distance.csv'], + 'pemsD7M': ['PeMSD7M/PeMSD7M.npz', 'PeMSD7M/distance.csv'], + 'pemsD7L': ['PeMSD7L/PeMSD7L.npz', 'PeMSD7L/distance.csv'] +} + +def read_data(args): + """read data, generate spatial adjacency matrix and semantic adjacency matrix by dtw + + Args: + sigma1: float, default=0.1, sigma for the semantic matrix + sigma2: float, default=10, sigma for the spatial matrix + thres1: float, default=0.6, the threshold for the semantic matrix + thres2: float, default=0.5, the threshold for the spatial matrix + + Returns: + data: tensor, T * N * 1 + dtw_matrix: array, semantic adjacency matrix + sp_matrix: array, spatial adjacency matrix + """ + filename = args.filename + file = files[filename] + filepath = "./data/" + if args.remote: + filepath = '/home/lantu.lqq/ftemp/data/' + data = np.load(filepath + file[0])['data'] + + num_node = data.shape[1] + mean_value = np.mean(data, axis=(0, 1)).reshape(1, 1, -1) + std_value = np.std(data, axis=(0, 1)).reshape(1, 1, -1) + data = (data - mean_value) / std_value + mean_value = mean_value.reshape(-1)[0] + std_value = std_value.reshape(-1)[0] + + if not os.path.exists(f'data/{filename}_dtw_distance.npy'): + data_mean = np.mean([data[:, :, 0][24*12*i: 24*12*(i+1)] for i in range(data.shape[0]//(24*12))], axis=0) + data_mean = data_mean.squeeze().T + dtw_distance = np.zeros((num_node, num_node)) + for i in tqdm(range(num_node)): + for j in range(i, num_node): + dtw_distance[i][j] = fastdtw(data_mean[i], data_mean[j], radius=6)[0] + for i in range(num_node): + for j in range(i): + dtw_distance[i][j] = dtw_distance[j][i] + np.save(f'data/{filename}_dtw_distance.npy', dtw_distance) + + dist_matrix = np.load(f'data/{filename}_dtw_distance.npy') + + mean = np.mean(dist_matrix) + std = np.std(dist_matrix) + dist_matrix = (dist_matrix - mean) / std + sigma = args.sigma1 + dist_matrix = np.exp(-dist_matrix ** 2 / sigma ** 2) + dtw_matrix = np.zeros_like(dist_matrix) + dtw_matrix[dist_matrix > args.thres1] = 1 + + if not os.path.exists(f'data/{filename}_spatial_distance.npy'): + with open(filepath + file[1], 'r') as fp: + dist_matrix = np.zeros((num_node, num_node)) + np.float('inf') + file = csv.reader(fp) + for line in file: + break + for line in file: + start = int(line[0]) + end = int(line[1]) + dist_matrix[start][end] = float(line[2]) + dist_matrix[end][start] = float(line[2]) + np.save(f'data/{filename}_spatial_distance.npy', dist_matrix) + + dist_matrix = np.load(f'data/{filename}_spatial_distance.npy') + # normalization + std = np.std(dist_matrix[dist_matrix != np.float('inf')]) + mean = np.mean(dist_matrix[dist_matrix != np.float('inf')]) + dist_matrix = (dist_matrix - mean) / std + sigma = args.sigma2 + sp_matrix = np.exp(- dist_matrix**2 / sigma**2) + sp_matrix[sp_matrix < args.thres2] = 0 + # np.save(f'data/{filename}_sp_c_matrix.npy', sp_matrix) + # sp_matrix = np.load(f'data/{filename}_sp_c_matrix.npy') + + print(f'average degree of spatial graph is {np.sum(sp_matrix > 0)/2/num_node}') + print(f'average degree of semantic graph is {np.sum(dtw_matrix > 0)/2/num_node}') + return torch.from_numpy(data.astype(np.float32)), mean_value, std_value, dtw_matrix, sp_matrix +``` + +由于文件太长,我们需要删掉不必要的部分,最后我们只需要返回两个矩阵`dtw_matrix, sp_matrix`,因此把无关的data,mean_value, std_value相关函数去掉。 + +```python +def read_data(args): + """read data, generate spatial adjacency matrix and semantic adjacency matrix by dtw + + Args: + sigma1: float, default=0.1, sigma for the semantic matrix + sigma2: float, default=10, sigma for the spatial matrix + thres1: float, default=0.6, the threshold for the semantic matrix + thres2: float, default=0.5, the threshold for the spatial matrix + + Returns: + data: tensor, T * N * 1 + dtw_matrix: array, semantic adjacency matrix + sp_matrix: array, spatial adjacency matrix + """ + files = { + 'pems03': ['PEMS03/pems03.npz', 'PEMS03/distance.csv'], + 'pems04': ['PEMS04/pems04.npz', 'PEMS04/distance.csv'], + 'pems07': ['PEMS07/pems07.npz', 'PEMS07/distance.csv'], + 'pems08': ['PEMS08/pems08.npz', 'PEMS08/distance.csv'], + 'pemsbay': ['PEMSBAY/pems_bay.npz', 'PEMSBAY/distance.csv'], + 'pemsD7M': ['PeMSD7M/PeMSD7M.npz', 'PeMSD7M/distance.csv'], + 'pemsD7L': ['PeMSD7L/PeMSD7L.npz', 'PeMSD7L/distance.csv'] + } + + filename = args.filename + file = files[filename] + filepath = "./data/" + data = np.load(filepath + file[0])['data'] + + num_node = data.shape[1] + mean_value = np.mean(data, axis=(0, 1)).reshape(1, 1, -1) + std_value = np.std(data, axis=(0, 1)).reshape(1, 1, -1) + data = (data - mean_value) / std_value + + if not os.path.exists(f'data/{filename}_dtw_distance.npy'): + data_mean = np.mean([data[:, :, 0][24 * 12 * i: 24 * 12 * (i + 1)] for i in range(data.shape[0] // (24 * 12))], + axis=0) + data_mean = data_mean.squeeze().T + dtw_distance = np.zeros((num_node, num_node)) + for i in tqdm(range(num_node)): + for j in range(i, num_node): + dtw_distance[i][j] = fastdtw(data_mean[i], data_mean[j], radius=6)[0] + for i in range(num_node): + for j in range(i): + dtw_distance[i][j] = dtw_distance[j][i] + np.save(f'data/{filename}_dtw_distance.npy', dtw_distance) + + dist_matrix = np.load(f'data/{filename}_dtw_distance.npy') + + mean = np.mean(dist_matrix) + std = np.std(dist_matrix) + dist_matrix = (dist_matrix - mean) / std + sigma = args.sigma1 + dist_matrix = np.exp(-dist_matrix ** 2 / sigma ** 2) + dtw_matrix = np.zeros_like(dist_matrix) + dtw_matrix[dist_matrix > args.thres1] = 1 + + if not os.path.exists(f'data/{filename}_spatial_distance.npy'): + with open(filepath + file[1], 'r') as fp: + dist_matrix = np.zeros((num_node, num_node)) + float('inf') + file = csv.reader(fp) + for line in file: + break + for line in file: + start = int(line[0]) + end = int(line[1]) + dist_matrix[start][end] = float(line[2]) + dist_matrix[end][start] = float(line[2]) + np.save(f'data/{filename}_spatial_distance.npy', dist_matrix) + + dist_matrix = np.load(f'data/{filename}_spatial_distance.npy') + # normalization + std = np.std(dist_matrix[dist_matrix != float('inf')]) + mean = np.mean(dist_matrix[dist_matrix != float('inf')]) + dist_matrix = (dist_matrix - mean) / std + sigma = args.sigma2 + sp_matrix = np.exp(- dist_matrix ** 2 / sigma ** 2) + sp_matrix[sp_matrix < args.thres2] = 0 + + return dtw_matrix, sp_matrix +``` + +原版函数也传入了args,并通过args.sigma这种成员变量的方式访问参数。而我们是用字典传参,所以要将类似`args.sigma`改成`args['sigma']`,并在config中配置相应参数. + +```python +sigma = args['sigma1'] +dtw_matrix[dist_matrix > args['thres1']] = 1 +sigma = args['sigma2'] +sp_matrix[sp_matrix < args['thres2']] = 0 +``` + +这里需要注意的是,由于PEMSD3的距离文件中,from和to是传感器编号(而不是从0开始的下标),所以要通过data中的`PEMSD3.txt`文件做个映射,将传感器编号映射到0到357的下标。 + +```csv +from,to,distance +317842,318711,0.872 +318721,315955,1.322 +... +``` + +```py + # 计算spatial_distance, 如果存在缓存则直接读取缓存 + if not os.path.exists(f'data/PEMS0{filename[-1]}/PEMS0{filename[-1]}_spatial_distance.npy'): + if num_node == 358: + with open(f'data/PEMS0{filename[-1]}/PEMS0{filename[-1]}.txt', 'r') as f: + id_dict = {int(i): idx for idx, i in enumerate(f.read().strip().split('\n'))} # 建立映射列表 + # 使用 pandas 读取 CSV 文件,跳过标题行 + df = pd.read_csv(filepath + file[1], skiprows=1, header=None) + dist_matrix = np.zeros((num_node, num_node)) + float('inf') + for _, row in df.iterrows(): + start = int(id_dict[row[0]]) + end = int(id_dict[row[1]]) + dist_matrix[start][end] = float(row[2]) + dist_matrix[end][start] = float(row[2]) + np.save(f'data/PEMS0{filename[-1]}/PEMS0{filename[-1]}_spatial_distance.npy', dist_matrix) + else: + # 使用 pandas 读取 CSV 文件,跳过标题行 + df = pd.read_csv(filepath + file[1], skiprows=1, header=None) + dist_matrix = np.zeros((num_node, num_node)) + float('inf') + for _, row in df.iterrows(): + start = int(row[0]) + end = int(row[1]) + dist_matrix[start][end] = float(row[2]) + dist_matrix[end][start] = float(row[2]) + np.save(f'data/PEMS0{filename[-1]}/PEMS0{filename[-1]}_spatial_distance.npy', dist_matrix) +``` + + + +完成对read_data()函数的修改后,get_normalized_adj()函数倒是不复杂,传入一个narray,返回一个tensor,直接复制即可。 + +```py +def get_normalized_adj(A): + """ + Returns a tensor, the degree normalized adjacency matrix. + """ + alpha = 0.8 + D = np.array(np.sum(A, axis=1)).reshape((-1,)) + D[D <= 10e-5] = 10e-5 # Prevent infs + diag = np.reciprocal(np.sqrt(D)) + A_wave = np.multiply(np.multiply(diag.reshape((-1, 1)), A), + diag.reshape((1, -1))) + A_reg = alpha / 2 * (np.eye(A.shape[0]) + A_wave) + return torch.from_numpy(A_reg.astype(np.float32)) +``` + + + +最后,源代码是在read_data()外面调用get_normalized_adj()来返回标准化的tensor矩阵的,这样写太麻烦。我们直接在read_data的return上完成`get_normalized_adj(matrix).to(device)`完成这个操作,保持model代码的简洁。 + +```py +data, mean, std, dtw_matrix, sp_matrix = read_data(args) +train_loader, valid_loader, test_loader = generate_dataset(data, args) +A_sp_wave = get_normalized_adj(sp_matrix).to(device) +A_se_wave = get_normalized_adj(dtw_matrix).to(device) +``` + +修改为 + +```py +return get_normalized_adj(dtw_matrix).to(args['device']), get_normalized_adj(sp_matrix).to(args['device']) +``` + +`args['device']`不用写在config中,由运行命令指定。 + +最终,`model/STGODE/adj.py`文件下的`get_A_hat()`函数如下: + +```python +def get_A_hat(args): + """read data, generate spatial adjacency matrix and semantic adjacency matrix by dtw + + Args: + sigma1: float, default=0.1, sigma for the semantic matrix + sigma2: float, default=10, sigma for the spatial matrix + thres1: float, default=0.6, the threshold for the semantic matrix + thres2: float, default=0.5, the threshold for the spatial matrix + + Returns: + data: tensor, T * N * 1 + dtw_matrix: array, semantic adjacency matrix + sp_matrix: array, spatial adjacency matrix + """ + filepath = './data/' + num_node = args['num_nodes'] + file = files[num_node] + filename = file[0][:6] + + data = np.load(filepath + file[0])['data'] + num_node = data.shape[1] + mean_value = np.mean(data, axis=(0, 1)).reshape(1, 1, -1) + std_value = np.std(data, axis=(0, 1)).reshape(1, 1, -1) + data = (data - mean_value) / std_value + + # 计算dtw_distance, 如果存在缓存则直接读取缓存 + if not os.path.exists(f'data/PEMS0{filename[-1]}/PEMS0{filename[-1]}_dtw_distance.npy'): + data_mean = np.mean([data[:, :, 0][24 * 12 * i: 24 * 12 * (i + 1)] for i in range(data.shape[0] // (24 * 12))], + axis=0) + data_mean = data_mean.squeeze().T + dtw_distance = np.zeros((num_node, num_node)) + for i in tqdm(range(num_node)): + for j in range(i, num_node): + dtw_distance[i][j] = fastdtw(data_mean[i], data_mean[j], radius=6)[0] + for i in range(num_node): + for j in range(i): + dtw_distance[i][j] = dtw_distance[j][i] + np.save(f'data/PEMS0{filename[-1]}/PEMS0{filename[-1]}_dtw_distance.npy', dtw_distance) + + dist_matrix = np.load(f'data/PEMS0{filename[-1]}/PEMS0{filename[-1]}_dtw_distance.npy') + + mean = np.mean(dist_matrix) + std = np.std(dist_matrix) + dist_matrix = (dist_matrix - mean) / std + sigma = args['sigma1'] + dist_matrix = np.exp(-dist_matrix ** 2 / sigma ** 2) + dtw_matrix = np.zeros_like(dist_matrix) + dtw_matrix[dist_matrix > args['thres1']] = 1 + + # 计算spatial_distance, 如果存在缓存则直接读取缓存 + if not os.path.exists(f'data/PEMS0{filename[-1]}/PEMS0{filename[-1]}_spatial_distance.npy'): + if num_node == 358: + with open(f'data/PEMS0{filename[-1]}/PEMS0{filename[-1]}.txt', 'r') as f: + id_dict = {int(i): idx for idx, i in enumerate(f.read().strip().split('\n'))} # 建立映射列表 + # 使用 pandas 读取 CSV 文件,跳过标题行 + df = pd.read_csv(filepath + file[1], skiprows=1, header=None) + dist_matrix = np.zeros((num_node, num_node)) + float('inf') + for _, row in df.iterrows(): + start = int(id_dict[row[0]]) + end = int(id_dict[row[1]]) + dist_matrix[start][end] = float(row[2]) + dist_matrix[end][start] = float(row[2]) + np.save(f'data/PEMS0{filename[-1]}/PEMS0{filename[-1]}_spatial_distance.npy', dist_matrix) + else: + # 使用 pandas 读取 CSV 文件,跳过标题行 + df = pd.read_csv(filepath + file[1], skiprows=1, header=None) + dist_matrix = np.zeros((num_node, num_node)) + float('inf') + for _, row in df.iterrows(): + start = int(row[0]) + end = int(row[1]) + dist_matrix[start][end] = float(row[2]) + dist_matrix[end][start] = float(row[2]) + np.save(f'data/PEMS0{filename[-1]}/PEMS0{filename[-1]}_spatial_distance.npy', dist_matrix) + # normalization + std = np.std(dist_matrix[dist_matrix != float('inf')]) + mean = np.mean(dist_matrix[dist_matrix != float('inf')]) + dist_matrix = (dist_matrix - mean) / std + sigma = args['sigma2'] + sp_matrix = np.exp(- dist_matrix ** 2 / sigma ** 2) + sp_matrix[sp_matrix < args['thres2']] = 0 + + return get_normalized_adj(dtw_matrix).to(args['device']), get_normalized_adj(sp_matrix).to(args['device']) +``` + +在model的init函数中调用,返回两个tensor矩阵。 + +```py +A_sp_hat, A_se_hat = get_A_hat(args) +``` + +里层的`nn.Module`,一般不必修改任何内容。因为初始化里层的`nn.Module`所需要的相关参数已经在最外层的`nn.Module`中有定义。除非里层`nn.Module`需要额外的图。 + + + +# 在model_selector中添加自己的模型 + +到了这一步,需要先在`model/model_selector.py`中添加自己的模型。 + +先import最外层的模型的`nn.Module`,再仿照前面模型的样式添加上自己的 + +```python +from model.DDGCRN.DDGCRN import DDGCRN +from model.TWDGCN.TWDGCN import TWDGCN +# ... +from model.STGODE.STGODE import ODEGCN # 新添加 + + + +def model_selector(model): + match model['type']: + case 'DDGCRN': return DDGCRN(model) + case 'TWDGCN': return TWDGCN(model) + # ... + case 'STGODE': return ODEGCN(model) # 新添加 + +``` + +其中,case就是文首提到的`{model_name}` + + + +# 使用debug调试模型,修改input x, output y的shape + +在`model/model_selector.py`中添加自己的模型之后,就可以进行debug了。首先,你需要在pycharm中编写配置。写参数。 + +```bash +python run.py --model STGODE --dataset PEMSD4 --mode train --device cuda:0 +``` + +之后在自己的forward上打断点。若程序运行时报错,说明init配置的还是有问题,模型未能正常初始化,请重新检查。 + +若初始化完成后,就可以开始修改forward函数了。 + + + +TrafficWheel的外层会对模型传入一个shape为**`[batch_size, time_step, num_nodes, dim]`**的张量,一般情况下,batch_size为64,time_step为12,num_nodes取决于数据集,dim是输入维度为3。其中,第1维为流量数据,第2维为日嵌入,第3维为周嵌入。 + +需要观察原版模型的forward函数里传入什么shape的tensor。一般情况下,不需要日嵌入,周嵌入。 + +这里,STGODE里面传入的shape是 `(batch_size, num_nodes, num_timesteps, num_features)`,与我们传入的shape不一致,所以我们需要修改。 + +以PEMSD4为例,首先,截取 `x = x[..., 0:1]`,会使得x的shape由(64,12,307,3)变为(64,12,307,1),即取流量维度。 + +再通过`x = x.permute(0, 2, 1, 3)` 将`[batch_size, time_step, num_nodes, dim]`转换为 `(batch_size, num_nodes, num_timesteps, num_features)`。输入维度dim和特征features其实是一个意思。 + +```py + def forward(self, x): + """ + Args: + x : input data of shape (batch_size, num_nodes, num_timesteps, num_features) == (B, N, T, F) + Returns: + prediction for future of shape (batch_size, num_nodes, num_timesteps_output) + """ + x = x[..., 0:1].permute(0, 2, 1, 3) +``` + +最后修改输出,在forward最后的return这个地方打断点,如果输入无误,可以直接debug,幸运的话就能一次跑到断点处。不幸运的话可能会遇到各种问题(通常是原代码仓库的问题)。这时候需要debug各个模块的forward中是否有张量不匹配的地方。问题一般就出在这。修改这些bug,直到程序跑到return断点处。 + +```python + def forward(self, x): + """ + Args: + x : input data of shape (batch_size, num_nodes, num_timesteps, num_features) == (B, N, T, F) + Returns: + prediction for future of shape (batch_size, num_nodes, num_timesteps_output) + """ + ... + + return self.pred(x) +``` + +这里在return时还套了个pred函数,为了便于观察,我们可以直接在外层打断点(通常是`trainer/Trainer.py`的第55行,观察output变量)。然而,原版的代码已经用注释告诉我们输出的shape,所以我们只需要看注释就可以了。 + +我们看到输出的shape是(batch_size, num_nodes, num_timesteps_output),而**TrafficWheel接受的输出shape是`[batch_size, time_step, num_nodes, dim]`**,所以我们需要进行转换。 + +在PEMSD4下,原版shape是(64,307,12),首先通过`permute(0,2,1)`调整为(64,12,307),再通过`unsqueeze(dim=-1)`调整为(64,12,307,1)。以符合输出的shape。 + +```py +return self.pred(x).permute(0,2,1).unsqueeze(dim=-1) +``` + + + +调整shape常用的技巧:注意,下述方法需要用一个x接受。不能没有返回值。 + +例如 `x=x.permute(0,2,1,3)`是正确的,不能是`x.permute(0,2,1,3)` + +| 操作 | 原张量shape | 操作后张量shape | +| ----------------- | ------------- | --------------------- | +| permute(0,2,1,3) | (64,12,307,3) | (64,307,12,3) | +| x[..., 0:1] | (64,12,307,3) | (64,12,307,1) | +| x[..., 0] | (64,12,307,3) | (64,12,307) 取第0维 | +| squeeze(dim=-1) | (64,12,307,1) | (64,12,307) | +| unsqueeze(dim=-1) | (64,12,307) | (64,12,307,1) | + +更多的可以查pytorch的官方文档,比较常用的就这几个 + + + +# 开始训练模型 + +在一切就绪后(forward的input, output的shape对得上),就可以训练模型了。可以使用jupyter notebook批量训练,也可以在pycharm中配置任务。取决于个人喜好。训练的代码为: + +```bash +python run.py --model STGODE --dataset PEMSD4 --mode train --device cuda:0 +``` +