Add a consumption-analysis mode; to use it, simply change the existing mode option to benchmark
parent 8c839642e1
commit d016dd5980
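
The diff below only touches the model code; per the commit message, the consumption-analysis path is reached by switching the existing mode option to benchmark. As a rough illustration of what such a mode might do, here is a minimal sketch built on torch.profiler; the run_benchmark name, its signature, and the dispatch around it are assumptions for illustration, not part of this commit:

    import torch
    from torch.profiler import profile, ProfilerActivity

    def run_benchmark(model, inputs, device="cuda"):
        # Hypothetical helper: profile the time and memory consumed by
        # one forward pass of the model.
        model = model.to(device).eval()
        inputs = [t.to(device) if torch.is_tensor(t) else t for t in inputs]
        with torch.no_grad(), profile(
            activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
            profile_memory=True,
        ) as prof:
            model(*inputs)
        print(prof.key_averages().table(sort_by="self_cuda_time_total", row_limit=10))

A runner would then gate on something like args.mode == "benchmark" before calling it (again hypothetical; the config handling is not in this diff).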
@@ -7,14 +7,13 @@ class DGCRM(nn.Module):
         super().__init__()
         self.node_num, self.input_dim, self.num_layers = node_num, dim_in, num_layers
         self.cells = nn.ModuleList(
-            [DDGCRNCell(node_num, dim_in if i == 0 else dim_out, dim_out, cheb_k, embed_dim) for i in
-             range(num_layers)])
+            [DDGCRNCell(node_num, dim_in if i == 0 else dim_out, dim_out, cheb_k, embed_dim) for i in range(num_layers)]
+        )
 
     def forward(self, x, init_state, node_embeddings):
         assert x.shape[2] == self.node_num and x.shape[3] == self.input_dim
         for i in range(self.num_layers):
-            state, inner = init_state[i], []
-            state = state.to(x.device)
+            state, inner = init_state[i].to(x.device), []
             for t in range(x.shape[1]):
                 state = self.cells[i](x[:, t, :, :], state, [node_embeddings[0][:, t, :, :], node_embeddings[1]])
                 inner.append(state)
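
For readers skimming the hunk above: DGCRM.forward unrolls each recurrent layer over the time axis of x, whose shape is (B, T, N, D), carrying one hidden state per node; the edit simply folds the device move into the state's initial assignment. A toy stand-in for the loop (ToyCell is illustrative only; the real DDGCRNCell also consumes node embeddings):

    import torch
    import torch.nn as nn

    class ToyCell(nn.Module):
        # Illustrative cell: mixes the current input with the previous state.
        def __init__(self, dim_in, dim_out):
            super().__init__()
            self.lin = nn.Linear(dim_in + dim_out, dim_out)

        def forward(self, x_t, state):
            # x_t: (B, N, dim_in), state: (B, N, dim_out)
            return torch.tanh(self.lin(torch.cat((x_t, state), dim=-1)))

    B, T, N, D, H = 2, 12, 5, 3, 8
    x = torch.randn(B, T, N, D)
    state = torch.zeros(B, N, H)          # plays the role of init_state[i]
    cell, inner = ToyCell(D, H), []
    for t in range(x.shape[1]):           # same time-unrolled loop as above
        state = cell(x[:, t, :, :], state)
        inner.append(state)
    print(torch.stack(inner, dim=1).shape)   # torch.Size([2, 12, 5, 8])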
@@ -69,8 +68,8 @@ class DDGCRNCell(nn.Module):
     def __init__(self, node_num, dim_in, dim_out, cheb_k, embed_dim):
         super().__init__()
         self.node_num, self.hidden_dim = node_num, dim_out
-        self.gate = DGCN(dim_in + dim_out, 2 * dim_out, cheb_k, embed_dim)
-        self.update = DGCN(dim_in + dim_out, dim_out, cheb_k, embed_dim)
+        self.gate = DGCN(dim_in + dim_out, 2 * dim_out, cheb_k, embed_dim, node_num)
+        self.update = DGCN(dim_in + dim_out, dim_out, cheb_k, embed_dim, node_num)
 
     def forward(self, x, state, node_embeddings):
         inp = torch.cat((x, state), -1)
@@ -84,7 +83,7 @@ class DDGCRNCell(nn.Module):
 
 
 class DGCN(nn.Module):
-    def __init__(self, dim_in, dim_out, cheb_k, embed_dim):
+    def __init__(self, dim_in, dim_out, cheb_k, embed_dim, num_nodes):
         super().__init__()
         self.cheb_k, self.embed_dim = cheb_k, embed_dim
         self.weights_pool = nn.Parameter(torch.FloatTensor(embed_dim, cheb_k, dim_in, dim_out))
@@ -98,14 +97,16 @@ class DGCN(nn.Module):
             ('sigmoid2', nn.Sigmoid()),
             ('fc3', nn.Linear(2, embed_dim))
         ]))
+        # Pre-register the constant identity matrix
+        self.register_buffer('eye', torch.eye(num_nodes))
 
     def forward(self, x, node_embeddings):
-        node_num = node_embeddings[0].shape[1]
-        supp1 = torch.eye(node_num).to(node_embeddings[0].device)
+        supp1 = self.eye.to(node_embeddings[0].device)
         filt = self.fc(x)
         nodevec = torch.tanh(node_embeddings[0] * filt)
         supp2 = self.get_laplacian(F.relu(torch.matmul(nodevec, nodevec.transpose(2, 1))), supp1)
-        x_g = torch.stack([torch.einsum("nm,bmc->bnc", supp1, x), torch.einsum("bnm,bmc->bnc", supp2, x)], dim=1)
+        x_g = torch.stack([torch.einsum("nm,bmc->bnc", supp1, x),
+                           torch.einsum("bnm,bmc->bnc", supp2, x)], dim=1)
         weights = torch.einsum('nd,dkio->nkio', node_embeddings[1], self.weights_pool)
         bias = torch.matmul(node_embeddings[1], self.bias_pool)
         return torch.einsum('bnki,nkio->bno', x_g.permute(0, 2, 1, 3), weights) + bias
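
The register_buffer change is the heart of the consumption fix: the num_nodes x num_nodes identity used as the first graph support (supp1) is now allocated once at construction instead of rebuilt by torch.eye on every forward call, and as a buffer it travels with the module through .to(device) and state_dict (the remaining .to(...) in forward becomes a no-op once devices match). A minimal standalone demonstration of the pattern (toy module, not the repo's DGCN):

    import torch
    import torch.nn as nn

    class WithEyeBuffer(nn.Module):
        def __init__(self, num_nodes):
            super().__init__()
            # Allocated once; moves with .to(device) and is saved in
            # state_dict, but is excluded from parameters()/optimizer steps.
            self.register_buffer('eye', torch.eye(num_nodes))

        def forward(self, x):
            # Before the change, the equivalent code paid a fresh
            # torch.eye(...) allocation per call.
            return self.eye @ x

    m = WithEyeBuffer(4)
    print(m.eye.device)                 # cpu
    print(m(torch.randn(4, 3)).shape)   # torch.Size([4, 3])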