# TrafficWheel/model/D2STGNN/dynamic_graph_conv/utils/distance.py
# (original listing metadata: 60 lines, 2.7 KiB, Python)

import math
import torch
import torch.nn as nn
import torch.nn.functional as F
class DistanceFunction(nn.Module):
    """Build two row-normalized, input-dependent node-to-node graphs.

    For each sample, every node's input series (channel 0 of ``X``) is embedded
    by a two-layer MLP (Linear -> ReLU -> BatchNorm -> Dropout -> Linear). That
    embedding is concatenated with the last time step of ``T_D`` and ``D_W`` and
    with a node embedding (``E_d`` for the first graph, ``E_u`` for the second).
    The result is projected to queries/keys, and
    ``softmax(Q K^T / sqrt(num_hidden))`` is returned along the last axis.

    Keys read from ``model_args``: ``num_hidden``, ``node_hidden``, ``seq_len``,
    ``dropout``, ``time_emb_dim``.
    """

    def __init__(self, **model_args):
        super().__init__()
        # attributes
        self.hidden_dim = model_args['num_hidden']
        self.node_dim = model_args['node_hidden']
        self.time_slot_emb_dim = self.hidden_dim
        self.input_seq_len = model_args['seq_len']
        # Time-series feature extraction: seq_len -> 2*hidden_dim -> hidden_dim.
        # NOTE: submodule registration order is kept as-is so parameters() /
        # state_dict ordering (and saved optimizer state) remain compatible.
        self.dropout = nn.Dropout(model_args['dropout'])
        self.fc_ts_emb1 = nn.Linear(self.input_seq_len, self.hidden_dim * 2)
        self.fc_ts_emb2 = nn.Linear(self.hidden_dim * 2, self.hidden_dim)
        self.ts_feat_dim = self.hidden_dim
        # Time-slot embedding. NOTE(review): not used by forward() in this
        # class; kept so existing checkpoints still load.
        self.time_slot_embedding = nn.Linear(model_args['time_emb_dim'], self.time_slot_emb_dim)
        # Distance score. The concatenated feature is
        # [series embedding | T_D | D_W | node embedding]; this presumes T_D and
        # D_W each have last dim == time_emb_dim.
        self.all_feat_dim = self.ts_feat_dim + self.node_dim + model_args['time_emb_dim'] * 2
        self.WQ = nn.Linear(self.all_feat_dim, self.hidden_dim, bias=False)
        self.WK = nn.Linear(self.all_feat_dim, self.hidden_dim, bias=False)
        self.bn = nn.BatchNorm1d(self.hidden_dim * 2)

    def reset_parameters(self):
        """Re-initialize every ``nn.Linear`` inside this module.

        Weights get Xavier-normal initialization. Biases, where present, are
        set to zero. Other submodules (e.g. the BatchNorm) are left untouched.
        This method is not called from ``__init__``.
        """
        for module in self.modules():
            if isinstance(module, nn.Linear):
                nn.init.xavier_normal_(module.weight)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)

    def _attention(self, feat):
        """Return ``softmax(WQ(feat) @ WK(feat)^T / sqrt(hidden_dim))`` over the last dim."""
        query = self.WQ(feat)
        key = self.WK(feat)
        scores = torch.bmm(query, key.transpose(-1, -2)) / math.sqrt(self.hidden_dim)
        return torch.softmax(scores, dim=-1)

    def forward(self, X, E_d, E_u, T_D, D_W):
        """Compute the two dynamic adjacency matrices.

        Args:
            X: tensor whose channel 0 (``X[:, :, :, 0]``) is read as
                ``[batch_size, seq_len, num_nodes]``. Its size along dim 1 must
                equal ``seq_len`` from ``model_args``.
            E_d, E_u: 2-D node embeddings, presumably
                ``[num_nodes, node_hidden]``. They are broadcast over the batch.
            T_D, D_W: 4-D time features. Only the last step along dim 1 is used.

        Returns:
            A list of two tensors of shape ``[batch_size, num_nodes, num_nodes]``.
            The first is built with ``E_d`` and the second with ``E_u``. Each is
            softmax-normalized along its last dimension.
        """
        # Keep only the most recent time step of the time features.
        T_D = T_D[:, -1, :, :]
        D_W = D_W[:, -1, :, :]

        # Dynamic (series) features: [B, L, N] -> [B, N, L] -> [B*N, L].
        series = X[:, :, :, 0].transpose(1, 2).contiguous()
        batch_size, num_nodes, seq_len = series.shape
        series = series.view(batch_size * num_nodes, seq_len)
        hidden = self.bn(F.relu(self.fc_ts_emb1(series)))  # [B*N, 2*hidden_dim]
        dy_feat = self.fc_ts_emb2(self.dropout(hidden))     # [B*N, hidden_dim]
        dy_feat = dy_feat.view(batch_size, num_nodes, -1)   # [B, N, hidden_dim]

        # Features common to both graphs; only the node embedding differs.
        shared = torch.cat([dy_feat, T_D, D_W], dim=-1)
        return [
            self._attention(
                torch.cat([shared, emb.unsqueeze(0).expand(batch_size, -1, -1)], dim=-1)
            )
            for emb in (E_d, E_u)
        ]