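# Model Migration Tutorial
This guide uses [STGODE](https://github.com/square-coder/STGODE/tree/main) as the running example.
1. [Creating the Model](#creating-the-model)
2. [Updating Imports](#updating-imports)
3. [Modifying the Outermost Module and Creating the Config File](#modifying-the-outermost-module-and-creating-the-config-file)
4. [Passing in the Graphs](#passing-in-the-graphs)
5. [Adding Your Model to model_selector](#adding-your-model-to-model_selector)
6. [Adjusting Input and Output Shapes](#adjusting-input-and-output-shapes)
7. [Training the Model](#training-the-model)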
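# Creating the Model
Decide on a model name, `{model_name}`. All caps is recommended. Remember this name; it is reused for the config folder and in `model_selector`.
Create a folder named `{model_name}` under the `model` folder to hold the model files.
STGODE has two model files, `model.py` and `odegcn.py`. Copy both into `TrafficWheel/model/STGODE`. Later steps refer to the copied `model.py` as `STGODE.py` (it is imported as `model.STGODE.STGODE`), so rename it accordingly.
# Updating Imports
In the original repository, `model.py` imports the `ODEG` module from `odegcn.py`: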
```python
from odegcn import ODEG
```
After moving to TrafficWheel, the `odegcn` import has to point at the new package path:
```python
from model.STGODE.odegcn import ODEG
```
# Modifying the Outermost Module and Creating the Config File
Only the outermost `nn.Module` class needs attention. For STGODE that is `class ODEGCN(nn.Module):`
```python
class ODEGCN(nn.Module):
    """ the overall network framework """
    def __init__(self, num_nodes, num_features, num_timesteps_input,
                 num_timesteps_output, A_sp_hat, A_se_hat):
        """
        Args:
            num_nodes : number of nodes in the graph
            num_features : number of features at each node in each time step
            num_timesteps_input : number of past time steps fed into the network
            num_timesteps_output : desired number of future time steps output by the network
            A_sp_hat : normalized adjacency spatial matrix
            A_se_hat : normalized adjacency semantic matrix
        """
        super(ODEGCN, self).__init__()
        # spatial graph
        self.sp_blocks = nn.ModuleList(
            [nn.Sequential(
                STGCNBlock(in_channels=num_features, out_channels=[64, 32, 64],
                           num_nodes=num_nodes, A_hat=A_sp_hat),
                STGCNBlock(in_channels=64, out_channels=[64, 32, 64],
                           num_nodes=num_nodes, A_hat=A_sp_hat)) for _ in range(3)
            ])
        # semantic graph
        self.se_blocks = nn.ModuleList([nn.Sequential(
            STGCNBlock(in_channels=num_features, out_channels=[64, 32, 64],
                       num_nodes=num_nodes, A_hat=A_se_hat),
            STGCNBlock(in_channels=64, out_channels=[64, 32, 64],
                       num_nodes=num_nodes, A_hat=A_se_hat)) for _ in range(3)
        ])
        self.pred = nn.Sequential(
            nn.Linear(num_timesteps_input * 64, num_timesteps_output * 32),
            nn.ReLU(),
            nn.Linear(num_timesteps_output * 32, num_timesteps_output)
        )
```
Start with `__init__`, i.e. the model definition, and look at what it takes. Here it takes quite a few arguments:
```python
def __init__(self, num_nodes, num_features, num_timesteps_input,
             num_timesteps_output, A_sp_hat, A_se_hat):
```
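Fortunately the author documented them. TrafficWheel constructs a model by passing in a single dictionary, `args`, and every parameter is read from it.
Change the `__init__` signature to take the `args` dictionary and read each parameter with `args['key']`, like this: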
```python
def __init__(self, args):
    super(ODEGCN, self).__init__()
    num_nodes = args['num_nodes']
    num_features = args['num_features']
    num_timesteps_input = args['history']
    num_timesteps_output = args['horizon']
    A_sp_hat, A_se_hat = get_A_hat(args)  # explained later
```
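`args` is loaded from the config file `config/STGODE/PEMSDX.yaml`. The easiest way to write a config file is to create a folder in `config` with the same name as the model (here `STGODE`), then copy the parameters from the YAML file of an existing model. Here, for example, `PEMSD4.yaml` was copied from STGNCDE into `config/STGODE`. The file name must match the dataset name, which can be any of PEMSD{3,4,7,8}.
The `model` section of the STGNCDE config looks like this: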
```yaml
model:
  type: type1
  g_type: agc
  input_dim: 2
  output_dim: 1
  embed_dim: 10
  hid_dim: 128
  hid_hid_dim: 128
  num_layers: 2
  cheb_k: 2
  solver: rk4
```
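Keep the `input_dim` and `output_dim` attributes. `num_nodes` is copied automatically from the `data` section into `model`, so it does not need to be written.
Delete the remaining parameters, then add the keys that `__init__` reads above, with their values, to the YAML file: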
```
model:
  input_dim: 1
  output_dim: 1
  history: 12
  horizon: 12
  num_features: 1
```
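If a submodule later needs new parameters, add them to the config, pass `args` down to the submodule, and read them from `args` in the submodule's `__init__`.
YAML supports integers, booleans (`{True, False}`), strings, and lists such as `[[64,64,64],[64,64,64],[64,64,64]]`; they are converted to the corresponding Python types automatically when read through `args`.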
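Because `__init__` indexes `args` directly, a key missing from the YAML only surfaces as a `KeyError` when the model is built. A minimal sketch of an early check follows. `missing_keys` and `REQUIRED_KEYS` are hypothetical names used for illustration and are not part of TrafficWheel; the dictionary stands in for the parsed `model` section, with `num_nodes` set to 307 as for PEMSD4.

```python
# Keys that the rewritten ODEGCN.__init__ reads from args.
REQUIRED_KEYS = ("num_nodes", "num_features", "history", "horizon")


def missing_keys(args, required=REQUIRED_KEYS):
    """Return the required keys that are absent from the args dict."""
    return [key for key in required if key not in args]


# Stand-in for the parsed `model` section of the YAML file.
model_args = {
    "input_dim": 1,
    "output_dim": 1,
    "history": 12,
    "horizon": 12,
    "num_features": 1,
    "num_nodes": 307,  # copied from the data section by the framework
}

print(missing_keys(model_args))                      # nothing missing
print(missing_keys({"history": 12, "horizon": 12}))  # two keys missing
```

Running a check like this before constructing the model points at the config file rather than at a line deep inside `__init__`.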
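# Passing in the Graphs
One tricky part of rewriting `__init__` is that the original code passes in two matrices as tensors. YAML has no tensor type, and writing whole matrices into a YAML file is not practical.
Instead, write a function that builds the matrices. Create a new file in `model/STGODE` (named `adj.py` here) and define a `get_A_hat` function in it that takes `args` and returns the two required matrices.
Import and call this function in `STGODE.py`: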
```py
from model.STGODE.adj import get_A_hat
```
```python
def __init__(self, args):
    ...
    A_sp_hat, A_se_hat = get_A_hat(args)
```
Next, look at how the original code builds the matrices and passes them to `__init__`.
Lines 77~87 of `run_stode.py` in the original repository:
```python
data, mean, std, dtw_matrix, sp_matrix = read_data(args)
train_loader, valid_loader, test_loader = generate_dataset(data, args)
A_sp_wave = get_normalized_adj(sp_matrix).to(device)
A_se_wave = get_normalized_adj(dtw_matrix).to(device)
net = ODEGCN(num_nodes=data.shape[1],
             num_features=data.shape[2],
             num_timesteps_input=args.his_length,
             num_timesteps_output=args.pred_length,
             A_sp_hat=A_sp_wave,
             A_se_hat=A_se_wave)
```
The code first calls `read_data(args)`, which returns the two matrices `dtw_matrix` and `sp_matrix`, then normalizes each with `get_normalized_adj(matrix)`.
Both `read_data` and `get_normalized_adj` are imported from `utils.py`, so that is where to find them:
```python
from utils import generate_dataset, read_data, get_normalized_adj
```
Copy these two functions into `adj.py` in TrafficWheel. The original `read_data` and the `files` table it relies on look like this:
```python
files = {
    'pems03': ['PEMS03/pems03.npz', 'PEMS03/distance.csv'],
    'pems04': ['PEMS04/pems04.npz', 'PEMS04/distance.csv'],
    'pems07': ['PEMS07/pems07.npz', 'PEMS07/distance.csv'],
    'pems08': ['PEMS08/pems08.npz', 'PEMS08/distance.csv'],
    'pemsbay': ['PEMSBAY/pems_bay.npz', 'PEMSBAY/distance.csv'],
    'pemsD7M': ['PeMSD7M/PeMSD7M.npz', 'PeMSD7M/distance.csv'],
    'pemsD7L': ['PeMSD7L/PeMSD7L.npz', 'PeMSD7L/distance.csv']
}
def read_data(args):
    """read data, generate spatial adjacency matrix and semantic adjacency matrix by dtw
    Args:
        sigma1: float, default=0.1, sigma for the semantic matrix
        sigma2: float, default=10, sigma for the spatial matrix
        thres1: float, default=0.6, the threshold for the semantic matrix
        thres2: float, default=0.5, the threshold for the spatial matrix
    Returns:
        data: tensor, T * N * 1
        dtw_matrix: array, semantic adjacency matrix
        sp_matrix: array, spatial adjacency matrix
    """
    filename = args.filename
    file = files[filename]
    filepath = "./data/"
    if args.remote:
        filepath = '/home/lantu.lqq/ftemp/data/'
    data = np.load(filepath + file[0])['data']
    num_node = data.shape[1]
    mean_value = np.mean(data, axis=(0, 1)).reshape(1, 1, -1)
    std_value = np.std(data, axis=(0, 1)).reshape(1, 1, -1)
    data = (data - mean_value) / std_value
    mean_value = mean_value.reshape(-1)[0]
    std_value = std_value.reshape(-1)[0]
    if not os.path.exists(f'data/{filename}_dtw_distance.npy'):
        data_mean = np.mean([data[:, :, 0][24*12*i: 24*12*(i+1)] for i in range(data.shape[0]//(24*12))], axis=0)
        data_mean = data_mean.squeeze().T
        dtw_distance = np.zeros((num_node, num_node))
        for i in tqdm(range(num_node)):
            for j in range(i, num_node):
                dtw_distance[i][j] = fastdtw(data_mean[i], data_mean[j], radius=6)[0]
        for i in range(num_node):
            for j in range(i):
                dtw_distance[i][j] = dtw_distance[j][i]
        np.save(f'data/{filename}_dtw_distance.npy', dtw_distance)
    dist_matrix = np.load(f'data/{filename}_dtw_distance.npy')
    mean = np.mean(dist_matrix)
    std = np.std(dist_matrix)
    dist_matrix = (dist_matrix - mean) / std
    sigma = args.sigma1
    dist_matrix = np.exp(-dist_matrix ** 2 / sigma ** 2)
    dtw_matrix = np.zeros_like(dist_matrix)
    dtw_matrix[dist_matrix > args.thres1] = 1
    if not os.path.exists(f'data/{filename}_spatial_distance.npy'):
        with open(filepath + file[1], 'r') as fp:
            dist_matrix = np.zeros((num_node, num_node)) + np.float('inf')
            file = csv.reader(fp)
            for line in file:
                break
            for line in file:
                start = int(line[0])
                end = int(line[1])
                dist_matrix[start][end] = float(line[2])
                dist_matrix[end][start] = float(line[2])
            np.save(f'data/{filename}_spatial_distance.npy', dist_matrix)
    dist_matrix = np.load(f'data/{filename}_spatial_distance.npy')
    # normalization
    std = np.std(dist_matrix[dist_matrix != np.float('inf')])
    mean = np.mean(dist_matrix[dist_matrix != np.float('inf')])
    dist_matrix = (dist_matrix - mean) / std
    sigma = args.sigma2
    sp_matrix = np.exp(- dist_matrix**2 / sigma**2)
    sp_matrix[sp_matrix < args.thres2] = 0
    # np.save(f'data/{filename}_sp_c_matrix.npy', sp_matrix)
    # sp_matrix = np.load(f'data/{filename}_sp_c_matrix.npy')
    print(f'average degree of spatial graph is {np.sum(sp_matrix > 0)/2/num_node}')
    print(f'average degree of semantic graph is {np.sum(dtw_matrix > 0)/2/num_node}')
    return torch.from_numpy(data.astype(np.float32)), mean_value, std_value, dtw_matrix, sp_matrix
```
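The semantic matrix is built from pairwise dynamic time warping (DTW) distances between the nodes' average daily profiles, and `fastdtw` computes an approximation of that distance. For intuition, here is a minimal sketch of exact DTW on two short sequences, using the absolute difference as the point cost. The function name `dtw_distance` is illustrative, and this is the textbook dynamic program, not the algorithm `fastdtw` uses internally.

```python
def dtw_distance(a, b):
    """Exact DTW distance between two 1-D sequences, using |x - y| as point cost."""
    inf = float("inf")
    n, m = len(a), len(b)
    # cost[i][j] = best cumulative cost of aligning a[:i] with b[:j]
    cost = [[inf] * (m + 1) for _ in range(n + 1)]
    cost[0][0] = 0.0
    for i in range(1, n + 1):
        for j in range(1, m + 1):
            step = abs(a[i - 1] - b[j - 1])
            cost[i][j] = step + min(cost[i - 1][j],      # only a advances
                                    cost[i][j - 1],      # only b advances
                                    cost[i - 1][j - 1])  # both advance
    return cost[n][m]


# The same shape delayed by one step aligns perfectly under DTW.
print(dtw_distance([0, 1, 2, 1, 0], [0, 0, 1, 2, 1, 0]))  # 0.0
print(dtw_distance([0, 1, 2], [0, 1, 3]))                 # 1.0
```

Two sensors whose daily curves have the same shape but are slightly shifted in time therefore end up close, which is what makes the resulting graph "semantic" rather than spatial.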
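The function is long, so strip out what is not needed. Only the two matrices `dtw_matrix` and `sp_matrix` have to be returned, so remove everything unrelated to them: the `args.remote` branch, the conversion of `mean_value` and `std_value` to scalars, the debug prints, and `data` in the return statement. The normalized `data` itself stays, because the DTW distances are computed from it. `np.float('inf')` also becomes `float('inf')`, since the `np.float` alias was removed in NumPy 1.24.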
```python
def read_data(args):
    """read data, generate spatial adjacency matrix and semantic adjacency matrix by dtw
    Args:
        sigma1: float, default=0.1, sigma for the semantic matrix
        sigma2: float, default=10, sigma for the spatial matrix
        thres1: float, default=0.6, the threshold for the semantic matrix
        thres2: float, default=0.5, the threshold for the spatial matrix
    Returns:
        dtw_matrix: array, semantic adjacency matrix
        sp_matrix: array, spatial adjacency matrix
    """
    files = {
        'pems03': ['PEMS03/pems03.npz', 'PEMS03/distance.csv'],
        'pems04': ['PEMS04/pems04.npz', 'PEMS04/distance.csv'],
        'pems07': ['PEMS07/pems07.npz', 'PEMS07/distance.csv'],
        'pems08': ['PEMSD8/pems08.npz', 'PEMSD8/distance.csv'],
        'pemsbay': ['PEMSBAY/pems_bay.npz', 'PEMSBAY/distance.csv'],
        'pemsD7M': ['PeMSD7M/PeMSD7M.npz', 'PeMSD7M/distance.csv'],
        'pemsD7L': ['PeMSD7L/PeMSD7L.npz', 'PeMSD7L/distance.csv']
    }
    filename = args.filename
    file = files[filename]
    filepath = "./data/"
    data = np.load(filepath + file[0])['data']
    num_node = data.shape[1]
    # the DTW distances below are computed on the z-scored data
    mean_value = np.mean(data, axis=(0, 1)).reshape(1, 1, -1)
    std_value = np.std(data, axis=(0, 1)).reshape(1, 1, -1)
    data = (data - mean_value) / std_value
    if not os.path.exists(f'data/{filename}_dtw_distance.npy'):
        data_mean = np.mean([data[:, :, 0][24 * 12 * i: 24 * 12 * (i + 1)] for i in range(data.shape[0] // (24 * 12))],
                            axis=0)
        data_mean = data_mean.squeeze().T
        dtw_distance = np.zeros((num_node, num_node))
        for i in tqdm(range(num_node)):
            for j in range(i, num_node):
                dtw_distance[i][j] = fastdtw(data_mean[i], data_mean[j], radius=6)[0]
        for i in range(num_node):
            for j in range(i):
                dtw_distance[i][j] = dtw_distance[j][i]
        np.save(f'data/{filename}_dtw_distance.npy', dtw_distance)
    dist_matrix = np.load(f'data/{filename}_dtw_distance.npy')
    mean = np.mean(dist_matrix)
    std = np.std(dist_matrix)
    dist_matrix = (dist_matrix - mean) / std
    sigma = args.sigma1
    dist_matrix = np.exp(-dist_matrix ** 2 / sigma ** 2)
    dtw_matrix = np.zeros_like(dist_matrix)
    dtw_matrix[dist_matrix > args.thres1] = 1
    if not os.path.exists(f'data/{filename}_spatial_distance.npy'):
        with open(filepath + file[1], 'r') as fp:
            dist_matrix = np.zeros((num_node, num_node)) + float('inf')
            file = csv.reader(fp)
            for line in file:
                break
            for line in file:
                start = int(line[0])
                end = int(line[1])
                dist_matrix[start][end] = float(line[2])
                dist_matrix[end][start] = float(line[2])
            np.save(f'data/{filename}_spatial_distance.npy', dist_matrix)
    dist_matrix = np.load(f'data/{filename}_spatial_distance.npy')
    # normalization
    std = np.std(dist_matrix[dist_matrix != float('inf')])
    mean = np.mean(dist_matrix[dist_matrix != float('inf')])
    dist_matrix = (dist_matrix - mean) / std
    sigma = args.sigma2
    sp_matrix = np.exp(- dist_matrix ** 2 / sigma ** 2)
    sp_matrix[sp_matrix < args.thres2] = 0
    return dtw_matrix, sp_matrix
```
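Both adjacency matrices follow the same recipe: standardize the distances, pass them through a Gaussian kernel, then threshold. Written out from the code above, with $d_{ij}$ the raw distance and $\mu$, $s$ its mean and standard deviation (for the spatial matrix these are computed over the finite entries only):

$$
\tilde d_{ij} = \frac{d_{ij} - \mu}{s}, \qquad
w_{ij} = \exp\!\left(-\frac{\tilde d_{ij}^{\,2}}{\sigma^{2}}\right)
$$

$$
A^{se}_{ij} =
\begin{cases}
1, & w_{ij} > \texttt{thres1} \\
0, & \text{otherwise}
\end{cases}
\quad (\sigma = \texttt{sigma1}),
\qquad
A^{sp}_{ij} =
\begin{cases}
w_{ij}, & w_{ij} \ge \texttt{thres2} \\
0, & \text{otherwise}
\end{cases}
\quad (\sigma = \texttt{sigma2})
$$

The semantic graph is therefore binary, while the spatial graph keeps the kernel weights. Unconnected sensor pairs have an infinite distance, which gives $w_{ij} = 0$ and hence no edge.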
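The original function also receives `args`, but reads parameters as attributes, e.g. `args.sigma1`. TrafficWheel passes a dictionary, so every access of the form `args.sigma1` must become `args['sigma1']`, with the corresponding keys added to the config.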
```python
sigma = args['sigma1']
dtw_matrix[dist_matrix > args['thres1']] = 1
sigma = args['sigma2']
sp_matrix[sp_matrix < args['thres2']] = 0
```
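The difference between the two access styles can be seen in a small sketch. The values are the defaults listed in the `read_data` docstring. Converting with `vars()` only illustrates how the two forms relate; it is not how TrafficWheel builds `args`.

```python
from argparse import Namespace

# Original repository style: parameters are attributes of an argparse Namespace.
ns_args = Namespace(sigma1=0.1, sigma2=10, thres1=0.6, thres2=0.5)
print(ns_args.sigma1)

# TrafficWheel style: parameters are dictionary entries.
dict_args = vars(ns_args)  # the Namespace's attributes as a plain dict
print(dict_args["sigma1"])

# Attribute access does not work on a dict, which is why every
# `args.xxx` in the copied code has to be rewritten as `args['xxx']`.
try:
    dict_args.sigma1
except AttributeError as err:
    print("AttributeError:", err)
```

The four keys shown here (`sigma1`, `sigma2`, `thres1`, `thres2`) are the ones that need to be added to the model's YAML config.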
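Note that in the PEMSD3 distance file, `from` and `to` are sensor IDs rather than zero-based indices, so the `PEMSD3.txt` file in `data` is used to map each sensor ID to an index from 0 to 357.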
```csv
from,to,distance
317842,318711,0.872
318721,315955,1.322
...
```
```py
# compute spatial_distance; reuse the cached file if it already exists
if not os.path.exists(f'data/PEMS0{filename[-1]}/PEMS0{filename[-1]}_spatial_distance.npy'):
    if num_node == 358:
        with open(f'data/PEMS0{filename[-1]}/PEMS0{filename[-1]}.txt', 'r') as f:
            id_dict = {int(i): idx for idx, i in enumerate(f.read().strip().split('\n'))}  # sensor ID -> index
        # read the CSV file with pandas, skipping the header row
        df = pd.read_csv(filepath + file[1], skiprows=1, header=None)
        dist_matrix = np.zeros((num_node, num_node)) + float('inf')
        for _, row in df.iterrows():
            start = int(id_dict[row[0]])
            end = int(id_dict[row[1]])
            dist_matrix[start][end] = float(row[2])
            dist_matrix[end][start] = float(row[2])
        np.save(f'data/PEMS0{filename[-1]}/PEMS0{filename[-1]}_spatial_distance.npy', dist_matrix)
    else:
        # read the CSV file with pandas, skipping the header row
        df = pd.read_csv(filepath + file[1], skiprows=1, header=None)
        dist_matrix = np.zeros((num_node, num_node)) + float('inf')
        for _, row in df.iterrows():
            start = int(row[0])
            end = int(row[1])
            dist_matrix[start][end] = float(row[2])
            dist_matrix[end][start] = float(row[2])
        np.save(f'data/PEMS0{filename[-1]}/PEMS0{filename[-1]}_spatial_distance.npy', dist_matrix)
# load from the cache so that dist_matrix is also defined when the file already existed
dist_matrix = np.load(f'data/PEMS0{filename[-1]}/PEMS0{filename[-1]}_spatial_distance.npy')
```
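The mapping step can be tried in isolation on a toy example. The sketch below uses the two sample rows from the CSV excerpt above and a made-up four-line ID file (the real file lists 358 sensors). It uses the standard `csv` module and nested lists instead of pandas and NumPy so that it runs without dependencies.

```python
import csv
import io

# Made-up stand-in for the sensor ID file: one ID per line, line order = node index.
id_file = io.StringIO("317842\n318711\n318721\n315955\n")
id_dict = {int(sensor_id): idx
           for idx, sensor_id in enumerate(id_file.read().strip().split("\n"))}

# The two sample rows from the distance file excerpt.
distance_csv = io.StringIO("from,to,distance\n317842,318711,0.872\n318721,315955,1.322\n")

num_node = len(id_dict)
inf = float("inf")
dist_matrix = [[inf] * num_node for _ in range(num_node)]

reader = csv.reader(distance_csv)
next(reader)  # skip the header row
for src, dst, dist in reader:
    start, end = id_dict[int(src)], id_dict[int(dst)]
    dist_matrix[start][end] = float(dist)
    dist_matrix[end][start] = float(dist)  # distances are stored symmetrically

print(id_dict)
print(dist_matrix[0][1], dist_matrix[3][2])
```

Without the mapping, `dist_matrix[317842][318711]` would be indexed directly and fail with an out-of-range error on a 358-node matrix.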
With `read_data()` done, `get_normalized_adj()` is simple: it takes an `ndarray` and returns a tensor, so it can be copied as is.
```py
def get_normalized_adj(A):
    """
    Returns a tensor, the degree normalized adjacency matrix.
    """
    alpha = 0.8
    D = np.array(np.sum(A, axis=1)).reshape((-1,))
    D[D <= 10e-5] = 10e-5  # Prevent infs
    diag = np.reciprocal(np.sqrt(D))
    A_wave = np.multiply(np.multiply(diag.reshape((-1, 1)), A),
                         diag.reshape((1, -1)))
    A_reg = alpha / 2 * (np.eye(A.shape[0]) + A_wave)
    return torch.from_numpy(A_reg.astype(np.float32))
```
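In formula form, with the degree matrix $D$ given by $D_{ii} = \max\bigl(\sum_j A_{ij},\ 10^{-4}\bigr)$ (the literal `10e-5` in the code equals $10^{-4}$) and $\alpha = 0.8$, the function returns:

$$
\hat A = \frac{\alpha}{2}\left(I + D^{-1/2}\, A\, D^{-1/2}\right)
$$

This is the symmetrically normalized adjacency matrix with an added identity term, scaled by $\alpha / 2$.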
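Finally, the original code calls `get_normalized_adj()` outside `read_data()` to get the normalized tensors, which is cumbersome. Instead, apply `get_normalized_adj(matrix).to(device)` directly in the `return` of `read_data`, which keeps the model code simple. The original call site: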
```py
data, mean, std, dtw_matrix, sp_matrix = read_data(args)
train_loader, valid_loader, test_loader = generate_dataset(data, args)
A_sp_wave = get_normalized_adj(sp_matrix).to(device)
A_se_wave = get_normalized_adj(dtw_matrix).to(device)
```
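becomes the following `return` statement, with the spatial matrix first so that the order matches `A_sp_hat, A_se_hat = get_A_hat(args)`:
```py
return get_normalized_adj(sp_matrix).to(args['device']), get_normalized_adj(dtw_matrix).to(args['device'])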
```
`args['device']` does not need to be in the config; it is set by the run command.
The final `get_A_hat()` function in `model/STGODE/adj.py` is:
```python
def get_A_hat(args):
    """read data, generate spatial adjacency matrix and semantic adjacency matrix by dtw
    Args (read from the args dict):
        sigma1: float, default=0.1, sigma for the semantic matrix
        sigma2: float, default=10, sigma for the spatial matrix
        thres1: float, default=0.6, the threshold for the semantic matrix
        thres2: float, default=0.5, the threshold for the spatial matrix
    Returns:
        A_sp_hat: tensor, normalized spatial adjacency matrix
        A_se_hat: tensor, normalized semantic adjacency matrix
    """
    filepath = './data/'
    num_node = args['num_nodes']
    # NOTE: `files` is looked up by node count here, so it is assumed to be a
    # module-level table keyed by num_nodes (its definition is not shown in this guide)
    file = files[num_node]
    filename = file[0][:6]
    data = np.load(filepath + file[0])['data']
    num_node = data.shape[1]
    mean_value = np.mean(data, axis=(0, 1)).reshape(1, 1, -1)
    std_value = np.std(data, axis=(0, 1)).reshape(1, 1, -1)
    data = (data - mean_value) / std_value
    # compute dtw_distance; reuse the cached file if it already exists
    if not os.path.exists(f'data/PEMS0{filename[-1]}/PEMS0{filename[-1]}_dtw_distance.npy'):
        data_mean = np.mean([data[:, :, 0][24 * 12 * i: 24 * 12 * (i + 1)] for i in range(data.shape[0] // (24 * 12))],
                            axis=0)
        data_mean = data_mean.squeeze().T
        dtw_distance = np.zeros((num_node, num_node))
        for i in tqdm(range(num_node)):
            for j in range(i, num_node):
                dtw_distance[i][j] = fastdtw(data_mean[i], data_mean[j], radius=6)[0]
        for i in range(num_node):
            for j in range(i):
                dtw_distance[i][j] = dtw_distance[j][i]
        np.save(f'data/PEMS0{filename[-1]}/PEMS0{filename[-1]}_dtw_distance.npy', dtw_distance)
    dist_matrix = np.load(f'data/PEMS0{filename[-1]}/PEMS0{filename[-1]}_dtw_distance.npy')
    mean = np.mean(dist_matrix)
    std = np.std(dist_matrix)
    dist_matrix = (dist_matrix - mean) / std
    sigma = args['sigma1']
    dist_matrix = np.exp(-dist_matrix ** 2 / sigma ** 2)
    dtw_matrix = np.zeros_like(dist_matrix)
    dtw_matrix[dist_matrix > args['thres1']] = 1
    # compute spatial_distance; reuse the cached file if it already exists
    if not os.path.exists(f'data/PEMS0{filename[-1]}/PEMS0{filename[-1]}_spatial_distance.npy'):
        if num_node == 358:
            with open(f'data/PEMS0{filename[-1]}/PEMS0{filename[-1]}.txt', 'r') as f:
                id_dict = {int(i): idx for idx, i in enumerate(f.read().strip().split('\n'))}  # sensor ID -> index
            # read the CSV file with pandas, skipping the header row
            df = pd.read_csv(filepath + file[1], skiprows=1, header=None)
            dist_matrix = np.zeros((num_node, num_node)) + float('inf')
            for _, row in df.iterrows():
                start = int(id_dict[row[0]])
                end = int(id_dict[row[1]])
                dist_matrix[start][end] = float(row[2])
                dist_matrix[end][start] = float(row[2])
            np.save(f'data/PEMS0{filename[-1]}/PEMS0{filename[-1]}_spatial_distance.npy', dist_matrix)
        else:
            # read the CSV file with pandas, skipping the header row
            df = pd.read_csv(filepath + file[1], skiprows=1, header=None)
            dist_matrix = np.zeros((num_node, num_node)) + float('inf')
            for _, row in df.iterrows():
                start = int(row[0])
                end = int(row[1])
                dist_matrix[start][end] = float(row[2])
                dist_matrix[end][start] = float(row[2])
            np.save(f'data/PEMS0{filename[-1]}/PEMS0{filename[-1]}_spatial_distance.npy', dist_matrix)
    # load from the cache; without this, a cache hit would reuse the DTW kernel matrix above
    dist_matrix = np.load(f'data/PEMS0{filename[-1]}/PEMS0{filename[-1]}_spatial_distance.npy')
    # normalization
    std = np.std(dist_matrix[dist_matrix != float('inf')])
    mean = np.mean(dist_matrix[dist_matrix != float('inf')])
    dist_matrix = (dist_matrix - mean) / std
    sigma = args['sigma2']
    sp_matrix = np.exp(- dist_matrix ** 2 / sigma ** 2)
    sp_matrix[sp_matrix < args['thres2']] = 0
    # spatial first, semantic second, matching `A_sp_hat, A_se_hat = get_A_hat(args)`
    return get_normalized_adj(sp_matrix).to(args['device']), get_normalized_adj(dtw_matrix).to(args['device'])
```
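`get_A_hat` looks up `files` by node count rather than by dataset name, but the guide does not show that version of the table. A hypothetical reconstruction is sketched below. The node counts are the standard ones for these four datasets; the paths simply reuse the `PEMS0x` layout of the original table and may differ from the actual `adj.py`. The helper `cache_path` is also illustrative and only shows how the cache file name is derived from a table entry.

```python
# Hypothetical table keyed by node count; adjust the paths to the real data layout.
files = {
    358: ['PEMS03/pems03.npz', 'PEMS03/distance.csv'],
    307: ['PEMS04/pems04.npz', 'PEMS04/distance.csv'],
    883: ['PEMS07/pems07.npz', 'PEMS07/distance.csv'],
    170: ['PEMS08/pems08.npz', 'PEMS08/distance.csv'],
}


def cache_path(num_nodes, kind):
    """Build the cache file path the same way get_A_hat does."""
    file = files[num_nodes]
    filename = file[0][:6]   # first six characters, e.g. 'PEMS04'
    digit = filename[-1]     # dataset digit, e.g. '4'
    return f'data/PEMS0{digit}/PEMS0{digit}_{kind}_distance.npy'


print(cache_path(307, 'dtw'))
print(cache_path(358, 'spatial'))
```

Keying by node count works only while every supported dataset has a distinct number of nodes, which holds for these four.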
Call it in the model's `__init__` to get the two tensors:
```py
A_sp_hat, A_se_hat = get_A_hat(args)
```
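Inner `nn.Module` classes usually need no changes, because the parameters they are initialized with are already defined in the outermost `nn.Module`. The exception is an inner `nn.Module` that needs an additional graph.
# Adding Your Model to model_selector
At this point, register the model in `model/model_selector.py`.
Import the outermost `nn.Module` of the model, then add a case for it following the pattern of the existing models: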
```python
from model.DDGCRN.DDGCRN import DDGCRN
from model.TWDGCN.TWDGCN import TWDGCN
# ...
from model.STGODE.STGODE import ODEGCN  # newly added
def model_selector(model):
    match model['type']:
        case 'DDGCRN': return DDGCRN(model)
        case 'TWDGCN': return TWDGCN(model)
        # ...
        case 'STGODE': return ODEGCN(model)  # newly added
```
The `case` string is the `{model_name}` chosen at the start of this guide.
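The `match` statement requires Python 3.10 or newer. On an older interpreter, the same dispatch can be sketched with a dictionary. The two classes below are empty stand-ins so the example runs without PyTorch, and `MODEL_REGISTRY` is a hypothetical name, not something TrafficWheel defines.

```python
class DDGCRN:
    """Stand-in for the real model class."""
    def __init__(self, args):
        self.args = args


class ODEGCN:
    """Stand-in for the real model class."""
    def __init__(self, args):
        self.args = args


# Maps the model name from the config to its outermost class.
MODEL_REGISTRY = {
    'DDGCRN': DDGCRN,
    'STGODE': ODEGCN,
}


def model_selector(model):
    """Instantiate the class registered under model['type']."""
    try:
        model_cls = MODEL_REGISTRY[model['type']]
    except KeyError:
        raise ValueError(f"unknown model type: {model['type']!r}") from None
    return model_cls(model)


net = model_selector({'type': 'STGODE', 'history': 12, 'horizon': 12})
print(type(net).__name__)
```

Unlike a `match` with no default case, which silently returns `None`, this version fails loudly when the model name is misspelled.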
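# Adjusting Input and Output Shapes
Once the model is registered in `model/model_selector.py`, debugging can begin. First create a run configuration in PyCharm with the following parameters: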
```bash
python run.py --model STGODE --dataset PEMSD4 --mode train --device cuda:0
```
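Then set a breakpoint in the model's `forward`. If the program fails before reaching it, `__init__` is still misconfigured and the model was not initialized properly; go back and check it.
Once initialization succeeds, `forward` can be modified.
TrafficWheel passes the model a tensor of shape **`[batch_size, time_step, num_nodes, dim]`**. Typically `batch_size` is 64, `time_step` is 12, `num_nodes` depends on the dataset, and `dim`, the input dimension, is 3: the first channel is the flow data, the second the day embedding, and the third the week embedding.
Check what tensor shape the `forward` of the original model expects. The day and week embeddings are usually not needed.
STGODE expects `(batch_size, num_nodes, num_timesteps, num_features)`, which differs from what TrafficWheel passes in, so the input must be converted.
Taking PEMSD4 as an example, first slice with `x = x[..., 0:1]`, which changes the shape of `x` from (64,12,307,3) to (64,12,307,1), keeping only the flow channel.
Then `x = x.permute(0, 2, 1, 3)` converts `[batch_size, time_step, num_nodes, dim]` to `(batch_size, num_nodes, num_timesteps, num_features)`. The input dimension `dim` and `num_features` mean the same thing.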
```py
def forward(self, x):
    """
    Args:
        x : input data of shape (batch_size, num_nodes, num_timesteps, num_features) == (B, N, T, F)
    Returns:
        prediction for future of shape (batch_size, num_nodes, num_timesteps_output)
    """
    # TrafficWheel input (B, T, N, 3): keep the flow channel, then reorder to (B, N, T, 1)
    x = x[..., 0:1].permute(0, 2, 1, 3)
```
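Finally, fix the output. Set a breakpoint at the final `return` of `forward`. If the input is right, start debugging directly; with luck the run reaches the breakpoint on the first try. Otherwise various errors may come up, usually caused by the code of the original repository. In that case, step through the `forward` of each submodule and look for mismatched tensor shapes, which is where the problem usually lies. Fix these until the program reaches the `return` breakpoint.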
```python
def forward(self, x):
    """
    Args:
        x : input data of shape (batch_size, num_nodes, num_timesteps, num_features) == (B, N, T, F)
    Returns:
        prediction for future of shape (batch_size, num_nodes, num_timesteps_output)
    """
    ...
    return self.pred(x)
```
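The `return` wraps the result in `self.pred`, so to inspect it more easily, set the breakpoint one level up, usually at line 55 of `trainer/Trainer.py`, and look at the `output` variable. Here, though, the original code already documents the output shape in the docstring, so reading that is enough.
The output shape is (batch_size, num_nodes, num_timesteps_output), while **the output shape TrafficWheel accepts is `[batch_size, time_step, num_nodes, dim]`**, so it must be converted.
On PEMSD4 the original shape is (64,307,12). `permute(0,2,1)` turns it into (64,12,307), and `unsqueeze(dim=-1)` then turns it into (64,12,307,1), which matches the expected output shape.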
```py
return self.pred(x).permute(0,2,1).unsqueeze(dim=-1)
```
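Common shape-adjusting operations are listed below. Note that each one returns a new tensor, so the result must be assigned back to `x`.
For example, `x = x.permute(0,2,1,3)` is correct, whereas a bare `x.permute(0,2,1,3)` leaves `x` unchanged.

| Operation         | Shape before  | Shape after                 |
| ----------------- | ------------- | --------------------------- |
| permute(0,2,1,3)  | (64,12,307,3) | (64,307,12,3)               |
| x[..., 0:1]       | (64,12,307,3) | (64,12,307,1)               |
| x[..., 0]         | (64,12,307,3) | (64,12,307), channel 0 only |
| squeeze(dim=-1)   | (64,12,307,1) | (64,12,307)                 |
| unsqueeze(dim=-1) | (64,12,307)   | (64,12,307,1)               |

See the official PyTorch documentation for more; these are the most commonly used ones.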
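The whole input and output conversion can be traced on plain shape tuples before touching the model. The helpers below are hypothetical and only mimic what `permute`, `unsqueeze(dim=-1)`, and a last-axis slice do to a tensor's shape; they are not PyTorch functions.

```python
def permute_shape(shape, *dims):
    """Shape after tensor.permute(*dims)."""
    return tuple(shape[d] for d in dims)


def unsqueeze_last(shape):
    """Shape after tensor.unsqueeze(dim=-1)."""
    return shape + (1,)


def slice_last(shape, start, stop):
    """Shape after tensor[..., start:stop]."""
    return shape[:-1] + (len(range(shape[-1])[start:stop]),)


# Input side (PEMSD4): TrafficWheel -> STGODE
x = (64, 12, 307, 3)               # [batch_size, time_step, num_nodes, dim]
x = slice_last(x, 0, 1)            # keep the flow channel
x = permute_shape(x, 0, 2, 1, 3)   # (batch_size, num_nodes, num_timesteps, num_features)
print(x)

# Output side: STGODE -> TrafficWheel
y = (64, 307, 12)                  # (batch_size, num_nodes, num_timesteps_output)
y = unsqueeze_last(permute_shape(y, 0, 2, 1))
print(y)
```

If the traced shapes match what the original `forward` expects and what TrafficWheel accepts, the remaining errors are inside the submodules rather than at the model boundary.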
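# Training the Model
Once everything is in place and the input and output shapes of `forward` line up, the model can be trained. Training can be run in batches from a Jupyter notebook or configured as a task in PyCharm, whichever is preferred. The training command is: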
```bash
python run.py --model STGODE --dataset PEMSD4 --mode train --device cuda:0
```