# FS-TFP/federatedscope/tabular/model/quadratic.py
#
# (Listing metadata: 12 lines, 359 B, Python.)

import torch
class QuadraticModel(torch.nn.Module):
    """Model whose only trainable state is a column vector ``x``.

    ``forward(A)`` evaluates ``x * (A @ x)`` and sums over the last
    dimension. ``x`` is stored with shape ``(in_channels, 1)``, so that
    reduction only removes the trailing size-1 axis. The result therefore
    holds the per-coordinate terms ``x_i * (A x)_i``; summing those terms
    gives the quadratic form ``x^T A x``.

    Args:
        in_channels: Length of the vector ``x``.
        class_num: Not used by this model.
        init_bound: ``x`` is initialised uniformly from
            ``[-init_bound, init_bound)``. Defaults to ``10.0``, the value
            previously hard-coded.
    """

    def __init__(self, in_channels, class_num, init_bound=10.0):
        super().__init__()
        # Column vector, so ``torch.matmul(A, self.x)`` is a matrix-vector
        # product whose result keeps a trailing dimension of size 1.
        x = torch.ones((in_channels, 1))
        # ``.float()`` pins float32 even if the global default dtype differs.
        self.x = torch.nn.parameter.Parameter(
            x.uniform_(-init_bound, init_bound).float())

    def forward(self, A):
        """Return ``x * (A @ x)`` summed over the last dimension.

        Args:
            A: Tensor multiplied with ``x`` via ``torch.matmul``. It is
                presumably square ``(in_channels, in_channels)`` or batched
                ``(..., in_channels, in_channels)`` (TODO: confirm with the
                data loader).

        Returns:
            For a 2-D ``A``, a tensor of shape ``(in_channels,)`` holding the
            terms ``x_i * (A x)_i``. NOTE(review): this is not reduced to a
            scalar; callers presumably sum it to obtain ``x^T A x``. Confirm
            before relying on it.
        """
        return torch.sum(self.x * torch.matmul(A, self.x), -1)