TrafficWheel/transfer_guide.md

26 KiB
Executable File
Raw Blame History

模型迁移教程

这里以STGODE为例。

  1. 创建模型
  2. 修改 import
  3. 修改最外层 Module 与创建配置文件
  4. 图的传入
  5. 在 model_selector 中添加自己的模型
  6. 调整输入输出的shape
  7. 开始训练模型

创建模型

确定模型的名称 {model_name}。推荐全大写,记住这个名称。

在model文件夹下创建模型文件夹{model_name} 命名。这里存放模型文件。

STGODE的模型文件一共有两个model.py, odegcn.py,把他拷贝到TrafficWheel/model/STGODE

修改import

原仓库中model.py文件import了odegcn.py文件下的ODEG模块。

from odegcn import ODEG

迁移到TrafficWheel中我们需要将odegcn指向我们的文件夹。

from model.STGODE.odegcn import ODEG

修改最外层Module与创建配置文件

我们只关注最外层的nn.Module这里STGODE最外层的nn.Module类是class ODEGCN(nn.Module):

class ODEGCN(nn.Module):
    """ the overall network framework """
    def __init__(self, num_nodes, num_features, num_timesteps_input,
                 num_timesteps_output, A_sp_hat, A_se_hat):
        """ 
        Args:
            num_nodes : number of nodes in the graph
            num_features : number of features at each node in each time step
            num_timesteps_input : number of past time steps fed into the network
            num_timesteps_output : desired number of future time steps output by the network
            A_sp_hat : nomarlized adjacency spatial matrix
            A_se_hat : nomarlized adjacency semantic matrix
        """        

        super(ODEGCN, self).__init__()
        # spatial graph
        self.sp_blocks = nn.ModuleList(
            [nn.Sequential(
                STGCNBlock(in_channels=num_features, out_channels=[64, 32, 64],
                num_nodes=num_nodes, A_hat=A_sp_hat),
                STGCNBlock(in_channels=64, out_channels=[64, 32, 64],
                num_nodes=num_nodes, A_hat=A_sp_hat)) for _ in range(3)
            ])
        # semantic graph
        self.se_blocks = nn.ModuleList([nn.Sequential(
                STGCNBlock(in_channels=num_features, out_channels=[64, 32, 64],
                num_nodes=num_nodes, A_hat=A_se_hat),
                STGCNBlock(in_channels=64, out_channels=[64, 32, 64],
                num_nodes=num_nodes, A_hat=A_se_hat)) for _ in range(3)
            ]) 

        self.pred = nn.Sequential(
            nn.Linear(num_timesteps_input * 64, num_timesteps_output * 32), 
            nn.ReLU(),
            nn.Linear(num_timesteps_output * 32, num_timesteps_output)
        )

我们先关注init函数即模型定义。这里我们要看init传入了什么东西这里传入了一堆

def __init__(self, num_nodes, num_features, num_timesteps_input,
                 num_timesteps_output, A_sp_hat, A_se_hat):

好在作者给了注释。我们的TrafficWheel在定义模型时只传入一个字典args所有的参数都在args中提取。

我们把init的传参修改为args字典并使用args['key']从字典中访问参数,像这样

    def __init__(self, args):
        super(ODEGCN, self).__init__()
        num_nodes = args['num_nodes']
        num_features = args['num_features']
        num_timesteps_input = args['history']
        num_timesteps_output = args['horizon']
        A_sp_hat, A_se_hat = get_A_hat(args) # 这个之后讲

这里的args读取了config/STGODE/PEMSDX.yaml的配置文件。写配置文件最简单的方法是先在config目录中创建文件夹与模型名一致为STGODE。再从config中已有模型的yaml文件中复制参数。例如这里从STGNCDE中复制了一份PEMSD4.yaml的配置文件放在config/STGODE下。文件名需要与数据集名一致可以是PEMSD{3,4,7,8}中的任一种。

STGNCDE在model部分的配置如下

model:
  type: type1
  g_type: agc
  input_dim: 2
  output_dim: 1
  embed_dim: 10
  hid_dim: 128
  hid_hid_dim: 128
  num_layers: 2
  cheb_k: 2
  solver: rk4

你需要保留input_dimoutput_dim这两个属性,num_nodes这个属性会从data中自动拷贝一份到model不用写。

删掉其余参数然后将刚刚init中访问的key的value写到yaml文件里如下

model:
  input_dim: 1
  output_dim: 1
  history: 12
  horizon: 12
  num_features: 1

后续如果在子模块中有新的参数就在config后续补充并把args传给子模块在子模块的init中使用args访问参数。

yaml支持int值bool值{True, False}字符串str列表[[64,64,64],[64,64,64],[64,64,64]]在访问args时会自动转换。

图的传入

在修改init时有一个比较棘手的地方原版代码中传入了两个矩阵是tensor类而yaml不支持tensor也不可能把矩阵写入tensor

我们需要自己构造函数来生成矩阵,在model/STGODE中创建一个新文件,我命名为adj.py。在这里构造一个get_A_hat的方法传参是args。返回所需要的两个矩阵。

STGODE.py里面import这个方法并调用。

from model.STGODE.adj import get_A_hat
    def __init__(self, args):
       	...
        A_sp_hat, A_se_hat = get_A_hat(args)

这里我们需要看原版代码中矩阵是怎么生成并传入init的。

原版代码的run_stode.py里面的77~87行

data, mean, std, dtw_matrix, sp_matrix = read_data(args)
train_loader, valid_loader, test_loader = generate_dataset(data, args)
A_sp_wave = get_normalized_adj(sp_matrix).to(device)
A_se_wave = get_normalized_adj(dtw_matrix).to(device)

net = ODEGCN(num_nodes=data.shape[1], 
             num_features=data.shape[2], 
             num_timesteps_input=args.his_length, 
             num_timesteps_output=args.pred_length, 
             A_sp_hat=A_sp_wave, 
             A_se_hat=A_se_wave)

这里先用read_data(args)读取数据,返回两个矩阵dtw_matrix, sp_matrix,再调用get_normalized_adj(matrix)标准化两个矩阵。

我们发现read_dataget_normalized_adj这两个函数都在utils.py里面所以在utils里面找这两个函数

from utils import generate_dataset, read_data, get_normalized_adj

我们将这两个函数复制到TrafficWheel中的adj.py中

files = {
    'pems03': ['PEMS03/pems03.npz', 'PEMS03/distance.csv'],
    'pems04': ['PEMS04/pems04.npz', 'PEMS04/distance.csv'],
    'pems07': ['PEMS07/pems07.npz', 'PEMS07/distance.csv'],
    'pems08': ['PEMS08/pems08.npz', 'PEMS08/distance.csv'],
    'pemsbay': ['PEMSBAY/pems_bay.npz', 'PEMSBAY/distance.csv'],
    'pemsD7M': ['PeMSD7M/PeMSD7M.npz', 'PeMSD7M/distance.csv'],
    'pemsD7L': ['PeMSD7L/PeMSD7L.npz', 'PeMSD7L/distance.csv']
}

def read_data(args):
    """read data, generate spatial adjacency matrix and semantic adjacency matrix by dtw

    Args:
        sigma1: float, default=0.1, sigma for the semantic matrix
        sigma2: float, default=10, sigma for the spatial matrix
        thres1: float, default=0.6, the threshold for the semantic matrix
        thres2: float, default=0.5, the threshold for the spatial matrix

    Returns:
        data: tensor, T * N * 1
        dtw_matrix: array, semantic adjacency matrix
        sp_matrix: array, spatial adjacency matrix
    """
    filename = args.filename
    file = files[filename]
    filepath = "./data/"
    if args.remote:
        filepath = '/home/lantu.lqq/ftemp/data/'
    data = np.load(filepath + file[0])['data']
    
    num_node = data.shape[1]
    mean_value = np.mean(data, axis=(0, 1)).reshape(1, 1, -1)
    std_value = np.std(data, axis=(0, 1)).reshape(1, 1, -1)
    data = (data - mean_value) / std_value
    mean_value = mean_value.reshape(-1)[0]
    std_value = std_value.reshape(-1)[0]

    if not os.path.exists(f'data/{filename}_dtw_distance.npy'):
        data_mean = np.mean([data[:, :, 0][24*12*i: 24*12*(i+1)] for i in range(data.shape[0]//(24*12))], axis=0)
        data_mean = data_mean.squeeze().T 
        dtw_distance = np.zeros((num_node, num_node))
        for i in tqdm(range(num_node)):
            for j in range(i, num_node):
                dtw_distance[i][j] = fastdtw(data_mean[i], data_mean[j], radius=6)[0]
        for i in range(num_node):
            for j in range(i):
                dtw_distance[i][j] = dtw_distance[j][i]
        np.save(f'data/{filename}_dtw_distance.npy', dtw_distance)

    dist_matrix = np.load(f'data/{filename}_dtw_distance.npy')

    mean = np.mean(dist_matrix)
    std = np.std(dist_matrix)
    dist_matrix = (dist_matrix - mean) / std
    sigma = args.sigma1
    dist_matrix = np.exp(-dist_matrix ** 2 / sigma ** 2)
    dtw_matrix = np.zeros_like(dist_matrix)
    dtw_matrix[dist_matrix > args.thres1] = 1

    if not os.path.exists(f'data/{filename}_spatial_distance.npy'):
        with open(filepath + file[1], 'r') as fp:
            dist_matrix = np.zeros((num_node, num_node)) + np.float('inf')
            file = csv.reader(fp)
            for line in file:
                break
            for line in file:
                start = int(line[0])
                end = int(line[1])
                dist_matrix[start][end] = float(line[2])
                dist_matrix[end][start] = float(line[2])
            np.save(f'data/{filename}_spatial_distance.npy', dist_matrix)

    dist_matrix = np.load(f'data/{filename}_spatial_distance.npy')
    # normalization
    std = np.std(dist_matrix[dist_matrix != np.float('inf')])
    mean = np.mean(dist_matrix[dist_matrix != np.float('inf')])
    dist_matrix = (dist_matrix - mean) / std
    sigma = args.sigma2
    sp_matrix = np.exp(- dist_matrix**2 / sigma**2)
    sp_matrix[sp_matrix < args.thres2] = 0 
    # np.save(f'data/{filename}_sp_c_matrix.npy', sp_matrix)
    # sp_matrix = np.load(f'data/{filename}_sp_c_matrix.npy')

    print(f'average degree of spatial graph is {np.sum(sp_matrix > 0)/2/num_node}')
    print(f'average degree of semantic graph is {np.sum(dtw_matrix > 0)/2/num_node}')
    return torch.from_numpy(data.astype(np.float32)), mean_value, std_value, dtw_matrix, sp_matrix

由于文件太长,我们需要删掉不必要的部分,最后我们只需要返回两个矩阵dtw_matrix, sp_matrix因此把无关的datamean_value std_value相关函数去掉。

def read_data(args):
    """read data, generate spatial adjacency matrix and semantic adjacency matrix by dtw

    Args:
        sigma1: float, default=0.1, sigma for the semantic matrix
        sigma2: float, default=10, sigma for the spatial matrix
        thres1: float, default=0.6, the threshold for the semantic matrix
        thres2: float, default=0.5, the threshold for the spatial matrix

    Returns:
        data: tensor, T * N * 1
        dtw_matrix: array, semantic adjacency matrix
        sp_matrix: array, spatial adjacency matrix
    """
    files = {
        'pems03': ['PEMS03/pems03.npz', 'PEMS03/distance.csv'],
        'pems04': ['PEMS04/pems04.npz', 'PEMS04/distance.csv'],
        'pems07': ['PEMS07/pems07.npz', 'PEMS07/distance.csv'],
        'pems08': ['PEMSD8/pems08.npz', 'PEMSD8/distance.csv'],
        'pemsbay': ['PEMSBAY/pems_bay.npz', 'PEMSBAY/distance.csv'],
        'pemsD7M': ['PeMSD7M/PeMSD7M.npz', 'PeMSD7M/distance.csv'],
        'pemsD7L': ['PeMSD7L/PeMSD7L.npz', 'PeMSD7L/distance.csv']
    }

    filename = args.filename
    file = files[filename]
    filepath = "./data/"
    data = np.load(filepath + file[0])['data']

    num_node = data.shape[1]
    mean_value = np.mean(data, axis=(0, 1)).reshape(1, 1, -1)
    std_value = np.std(data, axis=(0, 1)).reshape(1, 1, -1)
    data = (data - mean_value) / std_value

    if not os.path.exists(f'data/{filename}_dtw_distance.npy'):
        data_mean = np.mean([data[:, :, 0][24 * 12 * i: 24 * 12 * (i + 1)] for i in range(data.shape[0] // (24 * 12))],
                            axis=0)
        data_mean = data_mean.squeeze().T
        dtw_distance = np.zeros((num_node, num_node))
        for i in tqdm(range(num_node)):
            for j in range(i, num_node):
                dtw_distance[i][j] = fastdtw(data_mean[i], data_mean[j], radius=6)[0]
        for i in range(num_node):
            for j in range(i):
                dtw_distance[i][j] = dtw_distance[j][i]
        np.save(f'data/{filename}_dtw_distance.npy', dtw_distance)

    dist_matrix = np.load(f'data/{filename}_dtw_distance.npy')

    mean = np.mean(dist_matrix)
    std = np.std(dist_matrix)
    dist_matrix = (dist_matrix - mean) / std
    sigma = args.sigma1
    dist_matrix = np.exp(-dist_matrix ** 2 / sigma ** 2)
    dtw_matrix = np.zeros_like(dist_matrix)
    dtw_matrix[dist_matrix > args.thres1] = 1

    if not os.path.exists(f'data/{filename}_spatial_distance.npy'):
        with open(filepath + file[1], 'r') as fp:
            dist_matrix = np.zeros((num_node, num_node)) + float('inf')
            file = csv.reader(fp)
            for line in file:
                break
            for line in file:
                start = int(line[0])
                end = int(line[1])
                dist_matrix[start][end] = float(line[2])
                dist_matrix[end][start] = float(line[2])
            np.save(f'data/{filename}_spatial_distance.npy', dist_matrix)

    dist_matrix = np.load(f'data/{filename}_spatial_distance.npy')
    # normalization
    std = np.std(dist_matrix[dist_matrix != float('inf')])
    mean = np.mean(dist_matrix[dist_matrix != float('inf')])
    dist_matrix = (dist_matrix - mean) / std
    sigma = args.sigma2
    sp_matrix = np.exp(- dist_matrix ** 2 / sigma ** 2)
    sp_matrix[sp_matrix < args.thres2] = 0

    return dtw_matrix, sp_matrix

原版函数也传入了args并通过args.sigma这种成员变量的方式访问参数。而我们是用字典传参所以要将类似args.sigma改成args['sigma']并在config中配置相应参数.

sigma = args['sigma1']
dtw_matrix[dist_matrix > args['thres1']] = 1
sigma = args['sigma2']
sp_matrix[sp_matrix < args['thres2']] = 0

这里需要注意的是由于PEMSD3的距离文件中from和to是传感器编号而不是从0开始的下标所以要通过data中的PEMSD3.txt文件做个映射将传感器编号映射到0到357的下标。

from,to,distance
317842,318711,0.872
318721,315955,1.322
...
    # 计算spatial_distance, 如果存在缓存则直接读取缓存
    if not os.path.exists(f'data/PEMS0{filename[-1]}/PEMS0{filename[-1]}_spatial_distance.npy'):
        if num_node == 358:
            with open(f'data/PEMS0{filename[-1]}/PEMS0{filename[-1]}.txt', 'r') as f:
                id_dict = {int(i): idx for idx, i in enumerate(f.read().strip().split('\n'))}  # 建立映射列表
                # 使用 pandas 读取 CSV 文件,跳过标题行
                df = pd.read_csv(filepath + file[1], skiprows=1, header=None)
                dist_matrix = np.zeros((num_node, num_node)) + float('inf')
                for _, row in df.iterrows():
                    start = int(id_dict[row[0]])
                    end = int(id_dict[row[1]])
                    dist_matrix[start][end] = float(row[2])
                    dist_matrix[end][start] = float(row[2])
                np.save(f'data/PEMS0{filename[-1]}/PEMS0{filename[-1]}_spatial_distance.npy', dist_matrix)
        else:
            # 使用 pandas 读取 CSV 文件,跳过标题行
            df = pd.read_csv(filepath + file[1], skiprows=1, header=None)
            dist_matrix = np.zeros((num_node, num_node)) + float('inf')
            for _, row in df.iterrows():
                start = int(row[0])
                end = int(row[1])
                dist_matrix[start][end] = float(row[2])
                dist_matrix[end][start] = float(row[2])
            np.save(f'data/PEMS0{filename[-1]}/PEMS0{filename[-1]}_spatial_distance.npy', dist_matrix)

完成对read_data()函数的修改后get_normalized_adj()函数倒是不复杂传入一个narray返回一个tensor直接复制即可。

def get_normalized_adj(A):
    """
    Returns a tensor, the degree normalized adjacency matrix.
    """
    alpha = 0.8
    D = np.array(np.sum(A, axis=1)).reshape((-1,))
    D[D <= 10e-5] = 10e-5  # Prevent infs
    diag = np.reciprocal(np.sqrt(D))
    A_wave = np.multiply(np.multiply(diag.reshape((-1, 1)), A),
                         diag.reshape((1, -1)))
    A_reg = alpha / 2 * (np.eye(A.shape[0]) + A_wave)
    return torch.from_numpy(A_reg.astype(np.float32))

最后源代码是在read_data()外面调用get_normalized_adj()来返回标准化的tensor矩阵的这样写太麻烦。我们直接在read_data的return上完成get_normalized_adj(matrix).to(device)完成这个操作保持model代码的简洁。

data, mean, std, dtw_matrix, sp_matrix = read_data(args)
train_loader, valid_loader, test_loader = generate_dataset(data, args)
A_sp_wave = get_normalized_adj(sp_matrix).to(device)
A_se_wave = get_normalized_adj(dtw_matrix).to(device)

修改为

return get_normalized_adj(dtw_matrix).to(args['device']), get_normalized_adj(sp_matrix).to(args['device'])

args['device']不用写在config中由运行命令指定。

最终,model/STGODE/adj.py文件下的get_A_hat()函数如下:

def get_A_hat(args):
    """read data, generate spatial adjacency matrix and semantic adjacency matrix by dtw

    Args:
        sigma1: float, default=0.1, sigma for the semantic matrix
        sigma2: float, default=10, sigma for the spatial matrix
        thres1: float, default=0.6, the threshold for the semantic matrix
        thres2: float, default=0.5, the threshold for the spatial matrix

    Returns:
        data: tensor, T * N * 1
        dtw_matrix: array, semantic adjacency matrix
        sp_matrix: array, spatial adjacency matrix
    """
    filepath = './data/'
    num_node = args['num_nodes']
    file = files[num_node]
    filename = file[0][:6]

    data = np.load(filepath + file[0])['data']
    num_node = data.shape[1]
    mean_value = np.mean(data, axis=(0, 1)).reshape(1, 1, -1)
    std_value = np.std(data, axis=(0, 1)).reshape(1, 1, -1)
    data = (data - mean_value) / std_value

    # 计算dtw_distance, 如果存在缓存则直接读取缓存
    if not os.path.exists(f'data/PEMS0{filename[-1]}/PEMS0{filename[-1]}_dtw_distance.npy'):
        data_mean = np.mean([data[:, :, 0][24 * 12 * i: 24 * 12 * (i + 1)] for i in range(data.shape[0] // (24 * 12))],
                            axis=0)
        data_mean = data_mean.squeeze().T
        dtw_distance = np.zeros((num_node, num_node))
        for i in tqdm(range(num_node)):
            for j in range(i, num_node):
                dtw_distance[i][j] = fastdtw(data_mean[i], data_mean[j], radius=6)[0]
        for i in range(num_node):
            for j in range(i):
                dtw_distance[i][j] = dtw_distance[j][i]
        np.save(f'data/PEMS0{filename[-1]}/PEMS0{filename[-1]}_dtw_distance.npy', dtw_distance)

    dist_matrix = np.load(f'data/PEMS0{filename[-1]}/PEMS0{filename[-1]}_dtw_distance.npy')

    mean = np.mean(dist_matrix)
    std = np.std(dist_matrix)
    dist_matrix = (dist_matrix - mean) / std
    sigma = args['sigma1']
    dist_matrix = np.exp(-dist_matrix ** 2 / sigma ** 2)
    dtw_matrix = np.zeros_like(dist_matrix)
    dtw_matrix[dist_matrix > args['thres1']] = 1

    # 计算spatial_distance, 如果存在缓存则直接读取缓存
    if not os.path.exists(f'data/PEMS0{filename[-1]}/PEMS0{filename[-1]}_spatial_distance.npy'):
        if num_node == 358:
            with open(f'data/PEMS0{filename[-1]}/PEMS0{filename[-1]}.txt', 'r') as f:
                id_dict = {int(i): idx for idx, i in enumerate(f.read().strip().split('\n'))}  # 建立映射列表
                # 使用 pandas 读取 CSV 文件,跳过标题行
                df = pd.read_csv(filepath + file[1], skiprows=1, header=None)
                dist_matrix = np.zeros((num_node, num_node)) + float('inf')
                for _, row in df.iterrows():
                    start = int(id_dict[row[0]])
                    end = int(id_dict[row[1]])
                    dist_matrix[start][end] = float(row[2])
                    dist_matrix[end][start] = float(row[2])
                np.save(f'data/PEMS0{filename[-1]}/PEMS0{filename[-1]}_spatial_distance.npy', dist_matrix)
        else:
            # 使用 pandas 读取 CSV 文件,跳过标题行
            df = pd.read_csv(filepath + file[1], skiprows=1, header=None)
            dist_matrix = np.zeros((num_node, num_node)) + float('inf')
            for _, row in df.iterrows():
                start = int(row[0])
                end = int(row[1])
                dist_matrix[start][end] = float(row[2])
                dist_matrix[end][start] = float(row[2])
            np.save(f'data/PEMS0{filename[-1]}/PEMS0{filename[-1]}_spatial_distance.npy', dist_matrix)
    # normalization
    std = np.std(dist_matrix[dist_matrix != float('inf')])
    mean = np.mean(dist_matrix[dist_matrix != float('inf')])
    dist_matrix = (dist_matrix - mean) / std
    sigma = args['sigma2']
    sp_matrix = np.exp(- dist_matrix ** 2 / sigma ** 2)
    sp_matrix[sp_matrix < args['thres2']] = 0

    return get_normalized_adj(dtw_matrix).to(args['device']), get_normalized_adj(sp_matrix).to(args['device'])

在model的init函数中调用返回两个tensor矩阵。

A_sp_hat, A_se_hat = get_A_hat(args)

里层的nn.Module,一般不必修改任何内容。因为初始化里层的nn.Module所需要的相关参数已经在最外层的nn.Module中有定义。除非里层nn.Module需要额外的图。

在model_selector中添加自己的模型

到了这一步,需要先在model/model_selector.py中添加自己的模型。

先import最外层的模型的nn.Module,再仿照前面模型的样式添加上自己的

from model.DDGCRN.DDGCRN import DDGCRN
from model.TWDGCN.TWDGCN import TWDGCN
# ...
from model.STGODE.STGODE import ODEGCN # 新添加



def model_selector(model):
    match model['type']:
        case 'DDGCRN': return DDGCRN(model)
        case 'TWDGCN': return TWDGCN(model)
		# ...
        case 'STGODE': return ODEGCN(model)  # 新添加

其中case就是文首提到的{model_name}

调整输入输出的shape

model/model_selector.py中添加自己的模型之后就可以进行debug了。首先你需要在pycharm中编写配置。写参数。

python run.py --model STGODE --dataset PEMSD4  --mode train  --device cuda:0

之后在自己的forward上打断点。若程序运行时报错说明init配置的还是有问题模型未能正常初始化请重新检查。

若初始化完成后就可以开始修改forward函数了。

TrafficWheel的外层会对模型传入一个shape为**[batch_size, time_step, num_nodes, dim]**的张量一般情况下batch_size为64time_step为12num_nodes取决于数据集dim是输入维度为3。其中第1维为流量数据第2维为日嵌入第3维为周嵌入。

需要观察原版模型的forward函数里传入什么shape的tensor。一般情况下不需要日嵌入周嵌入。

这里STGODE里面传入的shape是 (batch_size, num_nodes, num_timesteps, num_features)与我们传入的shape不一致所以我们需要修改。

以PEMSD4为例首先截取 x = x[..., 0:1]会使得x的shape由(64,12,307,3)变为(64,12,307,1),即取流量维度。

再通过x = x.permute(0, 2, 1, 3)[batch_size, time_step, num_nodes, dim]转换为 (batch_size, num_nodes, num_timesteps, num_features)。输入维度dim和特征features其实是一个意思。

    def forward(self, x):
        """
        Args:
            x : input data of shape (batch_size, num_nodes, num_timesteps, num_features) == (B, N, T, F)
        Returns:
            prediction for future of shape (batch_size, num_nodes, num_timesteps_output)
        """
        x = x[..., 0:1].permute(0, 2, 1, 3)

最后修改输出在forward最后的return这个地方打断点如果输入无误可以直接debug幸运的话就能一次跑到断点处。不幸运的话可能会遇到各种问题通常是原代码仓库的问题。这时候需要debug各个模块的forward中是否有张量不匹配的地方。问题一般就出在这。修改这些bug直到程序跑到return断点处。

    def forward(self, x):
        """
        Args:
            x : input data of shape (batch_size, num_nodes, num_timesteps, num_features) == (B, N, T, F)
        Returns:
            prediction for future of shape (batch_size, num_nodes, num_timesteps_output)
        """
        ...

        return self.pred(x)

这里在return时还套了个pred函数为了便于观察我们可以直接在外层打断点通常是trainer/Trainer.py的第55行观察output变量。然而原版的代码已经用注释告诉我们输出的shape所以我们只需要看注释就可以了。

我们看到输出的shape是(batch_size, num_nodes, num_timesteps_output),而TrafficWheel接受的输出shape是[batch_size, time_step, num_nodes, dim],所以我们需要进行转换。

在PEMSD4下原版shape是(64,307,12),首先通过permute(0,2,1)调整为(64,12,307),再通过unsqueeze(dim=-1)调整为(64,12,307,1)。以符合输出的shape。

return self.pred(x).permute(0,2,1).unsqueeze(dim=-1)

调整shape常用的技巧注意下述方法需要用一个x接受。不能没有返回值。

例如 x=x.permute(0,2,1,3)是正确的,不能是x.permute(0,2,1,3)

操作 原张量shape 操作后张量shape
permute(0,2,1,3) (64,12,307,3) (64,307,12,3)
x[..., 0:1] (64,12,307,3) (64,12,307,1)
x[..., 0] (64,12,307,3) (64,12,307) 取第0维
squeeze(dim=-1) (64,12,307,1) (64,12,307)
unsqueeze(dim=-1) (64,12,307) (64,12,307,1)

更多的可以查pytorch的官方文档比较常用的就这几个

开始训练模型

在一切就绪后forward的input, output的shape对得上就可以训练模型了。可以使用jupyter notebook批量训练也可以在pycharm中配置任务。取决于个人喜好。训练的代码为

python run.py --model STGODE --dataset PEMSD4  --mode train  --device cuda:0