# TrafficWheel/model/STMLP/STMLP.py
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import init
from utils.get_adj import get_adj
import numbers


# --- Basic operators ---
class NConv(nn.Module):
    def forward(self, x, adj):
        return torch.einsum("ncwl,vw->ncvl", (x, adj)).contiguous()


class DyNconv(nn.Module):
    def forward(self, x, adj):
        return torch.einsum("ncvl,nvwl->ncwl", (x, adj)).contiguous()


class Linear(nn.Module):
    def __init__(self, c_in, c_out, bias=True):
        super().__init__()
        self.mlp = nn.Conv2d(c_in, c_out, kernel_size=1, bias=bias)

    def forward(self, x):
        return self.mlp(x)


class Prop(nn.Module):
    def __init__(self, c_in, c_out, gdep, dropout, alpha):
        super().__init__()
        self.nconv = NConv()
        self.mlp = Linear(c_in, c_out)
        self.gdep, self.dropout, self.alpha = gdep, dropout, alpha

    def forward(self, x, adj):
        adj = adj + torch.eye(adj.size(0), device=x.device)
        d = adj.sum(1)
        a = adj / d.view(-1, 1)
        h = x
        for _ in range(self.gdep):
            h = self.alpha * x + (1 - self.alpha) * self.nconv(h, a)
        return self.mlp(h)


class MixProp(nn.Module):
    def __init__(self, c_in, c_out, gdep, dropout, alpha):
        super().__init__()
        self.nconv = NConv()
        self.mlp = Linear((gdep + 1) * c_in, c_out)
        self.gdep, self.dropout, self.alpha = gdep, dropout, alpha

    def forward(self, x, adj):
        adj = adj + torch.eye(adj.size(0), device=x.device)
        d = adj.sum(1)
        a = adj / d.view(-1, 1)
        out = [x]
        h = x
        for _ in range(self.gdep):
            h = self.alpha * x + (1 - self.alpha) * self.nconv(h, a)
            out.append(h)
        return self.mlp(torch.cat(out, dim=1))


class DyMixprop(nn.Module):
    def __init__(self, c_in, c_out, gdep, dropout, alpha):
        super().__init__()
        self.nconv = DyNconv()
        self.mlp1 = Linear((gdep + 1) * c_in, c_out)
        self.mlp2 = Linear((gdep + 1) * c_in, c_out)
        self.gdep, self.dropout, self.alpha = gdep, dropout, alpha
        self.lin1, self.lin2 = Linear(c_in, c_in), Linear(c_in, c_in)

    def forward(self, x):
        x1 = torch.tanh(self.lin1(x))
        x2 = torch.tanh(self.lin2(x))
        adj = self.nconv(x1.transpose(2, 1), x2)
        adj0 = torch.softmax(adj, dim=2)
        adj1 = torch.softmax(adj.transpose(2, 1), dim=2)
        # two propagation branches, one per adjacency direction
        out1, out2 = [x], [x]
        h = x
        for _ in range(self.gdep):
            h = self.alpha * x + (1 - self.alpha) * self.nconv(h, adj0)
            out1.append(h)
        h = x
        for _ in range(self.gdep):
            h = self.alpha * x + (1 - self.alpha) * self.nconv(h, adj1)
            out2.append(h)
        return self.mlp1(torch.cat(out1, dim=1)) + self.mlp2(torch.cat(out2, dim=1))


class DilatedInception(nn.Module):
    def __init__(self, cin, cout, dilation_factor=2):
        super().__init__()
        self.kernels = [2, 3, 6, 7]
        cout_each = int(cout / len(self.kernels))
        self.convs = nn.ModuleList(
            [
                nn.Conv2d(
                    cin, cout_each, kernel_size=(1, k), dilation=(1, dilation_factor)
                )
                for k in self.kernels
            ]
        )

    def forward(self, x):
        outs = [conv(x) for conv in self.convs]
        # crop every branch to the length of the largest-kernel branch
        target_len = outs[-1].size(3)
        outs = [out[..., -target_len:] for out in outs]
        return torch.cat(outs, dim=1)


class GraphConstructor(nn.Module):
    def __init__(self, nnodes, k, dim, device, alpha=3, static_feat=None):
        super().__init__()
        self.nnodes, self.k, self.dim, self.alpha, self.device = (
            nnodes,
            k,
            dim,
            alpha,
            device,
        )
        self.static_feat = static_feat
        if static_feat is not None:
            xd = static_feat.shape[1]
            self.lin1, self.lin2 = nn.Linear(xd, dim), nn.Linear(xd, dim)
        else:
            self.emb1 = nn.Embedding(nnodes, dim)
            self.emb2 = nn.Embedding(nnodes, dim)
            self.lin1, self.lin2 = nn.Linear(dim, dim), nn.Linear(dim, dim)

    def forward(self, idx):
        if self.static_feat is None:
            vec1, vec2 = self.emb1(idx), self.emb2(idx)
        else:
            vec1 = vec2 = self.static_feat[idx, :]
        vec1 = torch.tanh(self.alpha * self.lin1(vec1))
        vec2 = torch.tanh(self.alpha * self.lin2(vec2))
        a = torch.mm(vec1, vec2.transpose(1, 0)) - torch.mm(vec2, vec1.transpose(1, 0))
        adj = F.relu(torch.tanh(self.alpha * a))
        # keep only the top-k strongest neighbours of each node
        mask = torch.zeros(idx.size(0), idx.size(0), device=self.device)
        s1, t1 = adj.topk(self.k, 1)
        mask.scatter_(1, t1, s1.new_ones(s1.size()))
        return adj * mask


class LayerNorm(nn.Module):
    __constants__ = ["normalized_shape", "eps", "elementwise_affine"]

    def __init__(self, normalized_shape, eps=1e-5, elementwise_affine=True):
        super().__init__()
        if isinstance(normalized_shape, numbers.Integral):
            normalized_shape = (normalized_shape,)
        self.normalized_shape, self.eps, self.elementwise_affine = (
            tuple(normalized_shape),
            eps,
            elementwise_affine,
        )
        if elementwise_affine:
            self.weight = nn.Parameter(torch.Tensor(*normalized_shape))
            self.bias = nn.Parameter(torch.Tensor(*normalized_shape))
            init.ones_(self.weight)
            init.zeros_(self.bias)
        else:
            self.register_parameter("weight", None)
            self.register_parameter("bias", None)

    def forward(self, x, idx):
        if self.elementwise_affine:
            return F.layer_norm(
                x,
                tuple(x.shape[1:]),
                self.weight[:, idx, :],
                self.bias[:, idx, :],
                self.eps,
            )
        else:
            return F.layer_norm(x, tuple(x.shape[1:]), self.weight, self.bias, self.eps)

    def extra_repr(self):
        return f"{self.normalized_shape}, eps={self.eps}, elementwise_affine={self.elementwise_affine}"


# --- Unified model class supporting both the 'teacher' and 'stmlp' branches ---
class STMLP(nn.Module):
    def __init__(self, args):
        super().__init__()
        # hyperparameters are read from the args dict
        self.adj_mx = get_adj(args)
        self.num_nodes = args["num_nodes"]
        self.feature_dim = args["input_dim"]
        self.input_window = args["input_window"]
        self.output_window = args["output_window"]
        self.output_dim = args["output_dim"]
        self.device = args["device"]
        self.gcn_true = args["gcn_true"]
        self.buildA_true = args["buildA_true"]
        self.gcn_depth = args["gcn_depth"]
        self.dropout = args["dropout"]
        self.subgraph_size = args["subgraph_size"]
        self.node_dim = args["node_dim"]
        self.dilation_exponential = args["dilation_exponential"]
        self.conv_channels = args["conv_channels"]
        self.residual_channels = args["residual_channels"]
        self.skip_channels = args["skip_channels"]
        self.end_channels = args["end_channels"]
        self.layers = args["layers"]
        self.propalpha = args["propalpha"]
        self.tanhalpha = args["tanhalpha"]
        self.layer_norm_affline = args["layer_norm_affline"]
        self.model_type = args["model_type"]  # 'teacher' or 'stmlp'
        self.idx = torch.arange(self.num_nodes).to(self.device)
        self.predefined_A = (
            None
            if self.adj_mx is None
            else (torch.tensor(self.adj_mx) - torch.eye(self.num_nodes)).to(self.device)
        )
        self.static_feat = None
        # transformer kept from the original structure (not used in forward)
        self.encoder_layer = nn.TransformerEncoderLayer(
            d_model=12, nhead=4, batch_first=True
        )
        self.transformer_encoder = nn.TransformerEncoder(
            self.encoder_layer, num_layers=3
        )
        # build the backbone layers
        self.start_conv = nn.Conv2d(
            self.feature_dim, self.residual_channels, kernel_size=1
        )
        self.gc = GraphConstructor(
            self.num_nodes,
            self.subgraph_size,
            self.node_dim,
            self.device,
            alpha=self.tanhalpha,
            static_feat=self.static_feat,
        )
        # compute the receptive field of the dilated temporal convolutions
        kernel_size = 7
        if self.dilation_exponential > 1:
            self.receptive_field = int(
                self.output_dim
                + (kernel_size - 1)
                * (self.dilation_exponential**self.layers - 1)
                / (self.dilation_exponential - 1)
            )
        else:
            self.receptive_field = self.layers * (kernel_size - 1) + self.output_dim
        self.filter_convs = nn.ModuleList()
        self.gate_convs = nn.ModuleList()
        self.residual_convs = nn.ModuleList()
        self.skip_convs = nn.ModuleList()
        self.norm = nn.ModuleList()
        self.stu_mlp = nn.ModuleList(
            [
                nn.Sequential(nn.Linear(c, c), nn.Linear(c, c), nn.Linear(c, c))
                for c in [13, 7, 1]
            ]
        )
        if self.gcn_true:
            self.gconv1 = nn.ModuleList()
            self.gconv2 = nn.ModuleList()
        new_dilation = 1
        for i in range(1):
            rf_size_i = (
                int(
                    1
                    + i
                    * (kernel_size - 1)
                    * (self.dilation_exponential**self.layers - 1)
                    / (self.dilation_exponential - 1)
                )
                if self.dilation_exponential > 1
                else i * self.layers * (kernel_size - 1) + 1
            )
            for j in range(1, self.layers + 1):
                rf_size_j = (
                    int(
                        rf_size_i
                        + (kernel_size - 1)
                        * (self.dilation_exponential**j - 1)
                        / (self.dilation_exponential - 1)
                    )
                    if self.dilation_exponential > 1
                    else rf_size_i + j * (kernel_size - 1)
                )
                self.filter_convs.append(
                    DilatedInception(
                        self.residual_channels,
                        self.conv_channels,
                        dilation_factor=new_dilation,
                    )
                )
                self.gate_convs.append(
                    DilatedInception(
                        self.residual_channels,
                        self.conv_channels,
                        dilation_factor=new_dilation,
                    )
                )
                self.residual_convs.append(
                    nn.Conv2d(self.conv_channels, self.residual_channels, kernel_size=1)
                )
                k_size = (
                    (1, self.input_window - rf_size_j + 1)
                    if self.input_window > self.receptive_field
                    else (1, self.receptive_field - rf_size_j + 1)
                )
                self.skip_convs.append(
                    nn.Conv2d(
                        self.conv_channels, self.skip_channels, kernel_size=k_size
                    )
                )
                if self.gcn_true:
                    self.gconv1.append(
                        MixProp(
                            self.conv_channels,
                            self.residual_channels,
                            self.gcn_depth,
                            self.dropout,
                            self.propalpha,
                        )
                    )
                    self.gconv2.append(
                        MixProp(
                            self.conv_channels,
                            self.residual_channels,
                            self.gcn_depth,
                            self.dropout,
                            self.propalpha,
                        )
                    )
                norm_size = (
                    (
                        self.residual_channels,
                        self.num_nodes,
                        self.input_window - rf_size_j + 1,
                    )
                    if self.input_window > self.receptive_field
                    else (
                        self.residual_channels,
                        self.num_nodes,
                        self.receptive_field - rf_size_j + 1,
                    )
                )
                self.norm.append(
                    LayerNorm(norm_size, elementwise_affine=self.layer_norm_affline)
                )
                new_dilation *= self.dilation_exponential
        self.end_conv_1 = nn.Conv2d(
            self.skip_channels, self.end_channels, kernel_size=1, bias=True
        )
        self.end_conv_2 = nn.Conv2d(
            self.end_channels, self.output_window, kernel_size=1, bias=True
        )
        k0 = (
            (1, self.input_window)
            if self.input_window > self.receptive_field
            else (1, self.receptive_field)
        )
        self.skip0 = nn.Conv2d(
            self.feature_dim, self.skip_channels, kernel_size=k0, bias=True
        )
        kE = (
            (1, self.input_window - self.receptive_field + 1)
            if self.input_window > self.receptive_field
            else (1, 1)
        )
        self.skipE = nn.Conv2d(
            self.residual_channels, self.skip_channels, kernel_size=kE, bias=True
        )
        # output branch: choose the projection head according to the model type
        if self.model_type == "teacher":
            self.tt_linear1 = nn.Linear(self.residual_channels, self.input_window)
            self.tt_linear2 = nn.Linear(1, 32)
            self.ss_linear1 = nn.Linear(self.residual_channels, self.input_window)
            self.ss_linear2 = nn.Linear(1, 32)
        else:  # stmlp
            self.out_linear1 = nn.Linear(self.residual_channels, self.input_window)
            self.out_linear2 = nn.Linear(1, 32)

    def forward(self, source, idx=None):
        # use only the first input feature
        source = source[..., 0:1]
        sout, tout = [], []
        inputs = source.transpose(1, 3)
        assert inputs.size(3) == self.input_window, "input sequence length mismatch"
        if self.input_window < self.receptive_field:
            inputs = F.pad(inputs, (self.receptive_field - self.input_window, 0, 0, 0))
        if self.gcn_true:
            adp = (
                self.gc(self.idx if idx is None else idx)
                if self.buildA_true
                else self.predefined_A
            )
        x = self.start_conv(inputs)
        skip = self.skip0(F.dropout(inputs, self.dropout, training=self.training))
        for i in range(self.layers):
            residual = x
            # gated dilated temporal convolution
            filters = torch.tanh(self.filter_convs[i](x))
            gate = torch.sigmoid(self.gate_convs[i](x))
            x = F.dropout(filters * gate, self.dropout, training=self.training)
            tout.append(x)
            s = self.skip_convs[i](x)
            skip = s + skip
            # spatial module: graph convolution when gcn_true, otherwise the MLP branch
            if self.gcn_true:
                x = self.gconv1[i](x, adp) + self.gconv2[i](x, adp.transpose(1, 0))
            else:
                x = self.stu_mlp[i](x)
            x = x + residual[:, :, :, -x.size(3) :]
            x = self.norm[i](x, self.idx if idx is None else idx)
            sout.append(x)
        skip = self.skipE(x) + skip
        x = F.relu(skip)
        x = F.relu(self.end_conv_1(x))
        x = self.end_conv_2(x)
        if self.model_type == "teacher":
            # teacher mode also returns projections of the last temporal (tout)
            # and spatial (sout) intermediate features
            ttout = self.tt_linear2(
                self.tt_linear1(tout[-1].transpose(1, 3)).transpose(1, 3)
            )
            ssout = self.ss_linear2(
                self.ss_linear1(sout[-1].transpose(1, 3)).transpose(1, 3)
            )
            return x, ttout, ssout
        else:
            x_ = self.out_linear2(
                self.out_linear1(tout[-1].transpose(1, 3)).transpose(1, 3)
            )
            return x, x_, x
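

if __name__ == "__main__":
    # Minimal smoke test for the self-contained building blocks above (a sketch;
    # the shapes and hyperparameters here are illustrative assumptions, not values
    # taken from TrafficWheel configs). Instantiating STMLP itself additionally
    # requires a full `args` dict and the repo's utils.get_adj on the import path,
    # so it is not exercised here.
    torch.manual_seed(0)
    batch, channels, nodes, steps = 2, 16, 10, 13
    x = torch.randn(batch, channels, nodes, steps)

    # Dilated inception: four kernel sizes concatenated on the channel axis,
    # cropped to the shortest (largest-kernel) branch along the temporal axis.
    di = DilatedInception(channels, 32, dilation_factor=1)
    print("DilatedInception:", di(x).shape)  # torch.Size([2, 32, 10, 7])

    # MixProp graph convolution with a random symmetric adjacency matrix.
    adj = torch.rand(nodes, nodes)
    adj = (adj + adj.t()) / 2
    mp = MixProp(channels, 32, gdep=2, dropout=0.3, alpha=0.05)
    print("MixProp:", mp(x, adj).shape)  # torch.Size([2, 32, 10, 13])

    # Adaptive adjacency learned from node embeddings, sparsified to top-k.
    gc = GraphConstructor(nodes, k=4, dim=8, device="cpu", alpha=3)
    print("GraphConstructor:", gc(torch.arange(nodes)).shape)  # torch.Size([10, 10])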