12 lines
359 B
Python
12 lines
359 B
Python
import torch
|
|
|
|
|
|
class QuadraticModel(torch.nn.Module):
    """Single learnable column vector ``x`` evaluated against input ``A``.

    The module owns exactly one trainable parameter, ``x``, of shape
    ``(in_channels, 1)``, initialised with values drawn uniformly between
    -10 and 10 and stored as float32.

    Args:
        in_channels: Number of rows of the parameter vector ``x``.
        class_num: Not used by this module.
    """

    def __init__(self, in_channels, class_num):
        super().__init__()
        # uniform_ overwrites every element in place, so the initial contents
        # of the buffer do not matter; .float() keeps the parameter float32
        # even if the global default dtype has been changed.
        init_values = torch.empty(in_channels, 1).uniform_(-10.0, 10.0)
        self.x = torch.nn.Parameter(init_values.float())

    def forward(self, A):
        """Return ``x * (A @ x)`` summed over its last axis.

        NOTE(review): assumes ``A`` is ``(..., in_channels, in_channels)`` —
        confirm against callers. Because ``x`` is ``(in_channels, 1)``, the
        last axis of the product has size 1, so the sum only removes that
        axis: the result holds the per-coordinate terms ``x_i * (A x)_i``
        rather than the scalar ``x^T A x`` (sum the output's last axis to
        obtain that).
        """
        projected = torch.matmul(A, self.x)
        weighted = self.x * projected
        return weighted.sum(dim=-1)
|