优化GWN, ddgcrn的代码,使其更简洁易读

This commit is contained in:
czzhangheng 2025-03-27 20:08:17 +08:00
parent d016dd5980
commit 1b76cc6ce2
1 changed files with 1 additions and 0 deletions

View File

@ -116,3 +116,4 @@ class DGCN(nn.Module):
D_inv = torch.diag_embed(torch.sum(graph, -1) ** (-0.5)) D_inv = torch.diag_embed(torch.sum(graph, -1) ** (-0.5))
return torch.matmul(torch.matmul(D_inv, graph), D_inv) if normalize else torch.matmul( return torch.matmul(torch.matmul(D_inv, graph), D_inv) if normalize else torch.matmul(
torch.matmul(D_inv, graph + I), D_inv) torch.matmul(D_inv, graph + I), D_inv)