import torch


class LogisticRegression(torch.nn.Module):
    """Single linear layer mapping input features to per-class scores.

    ``forward`` returns the raw output of the linear layer; no sigmoid or
    softmax is applied here. Any normalization of the scores is left to the
    caller or the loss function.

    Args:
        in_channels: Size of the last dimension of the input tensor.
        class_num: Number of output scores produced per input.
        use_bias: Whether the linear layer learns an additive bias.
    """

    def __init__(self, in_channels: int, class_num: int, use_bias: bool = True) -> None:
        super().__init__()
        self.fc = torch.nn.Linear(in_channels, class_num, bias=use_bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the linear layer to ``x``.

        Args:
            x: Input tensor whose last dimension has size ``in_channels``.

        Returns:
            Tensor with the same leading dimensions as ``x`` and a last
            dimension of size ``class_num``, containing unnormalized scores.
        """
        return self.fc(x)