import torch
import torch.nn.functional as F
from torch import nn
from torch.nn import CrossEntropyLoss
from federatedscope.nlp.hetero_tasks.dataset.utils import setup_tokenizer
from federatedscope.nlp.loss.label_smooth_loss import \
    LabelSmoothingLoss


class ModelOutput(object):
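    """Lightweight container for the values returned by ``ATCModel.forward``.

    All fields default to ``None``; a forward pass only fills in the ones
    relevant to its branch (e.g. ``loss`` during training, ``hidden_states``
    and ``example_indices`` during contrastive preparation).
    """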
    def __init__(self,
                 loss=None,
                 regular_loss=None,
                 contrastive_loss=None,
                 logits=None,
                 hidden_states=None,
                 example_indices=None):
        self.loss = loss
        self.regular_loss = regular_loss
        self.contrastive_loss = contrastive_loss
        self.logits = logits
        self.hidden_states = hidden_states
        self.example_indices = example_indices


class ContrastiveHead(nn.Module):
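    """Projection head used for the contrastive objective.

    Applies dropout -> linear -> tanh -> dropout -> linear to decoder
    hidden states before cosine similarities are computed.
    """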
    def __init__(self, input_dim, inner_dim, out_dim, dropout_prob):
        super().__init__()

        self.dense = nn.Linear(input_dim, inner_dim)
        self.dropout = nn.Dropout(p=dropout_prob)
        self.out_prj = nn.Linear(inner_dim, out_dim)

    def forward(self, x):
        x = self.dense(self.dropout(x))
        x = torch.tanh(x)
        x = self.out_prj(self.dropout(x))
        return x


class ATCModel(nn.Module):
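    """BERT-based encoder-decoder model for the ATC hetero-task setting
    in FederatedScope.

    Supports 'mlm'/'denoise' pre-training, classification ('imdb',
    'agnews'), extractive QA ('squad', 'newsqa') and generation ('cnndm',
    'msqg'), plus an optional client-wise contrastive loss.
    """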
    def __init__(self, config):
        super().__init__()
        from transformers.models.encoder_decoder import EncoderDecoderModel
        from transformers.models.bert.modeling_bert import BertLMPredictionHead

        self.model = EncoderDecoderModel.from_encoder_decoder_pretrained(
            config.model_type, config.model_type)
        self.lm_head = BertLMPredictionHead(self.model.encoder.config)

        self.client_id = None
        self.task = config.task
        self.pt_cfg = self.model.encoder.config
        self.vocab_size = self.pt_cfg.vocab_size
        self.hidden_size = self.pt_cfg.hidden_size
        self.dropout_prob = self.pt_cfg.hidden_dropout_prob
        self.dropout = nn.Dropout(self.dropout_prob)

        setup_tokenizer(config.model_type)  # update global token ids
        from federatedscope.nlp.hetero_tasks.dataset.utils import \
            BOS_TOKEN_ID, EOS_TOKEN_ID, PAD_TOKEN_ID
        self.label_smoothing = config.label_smoothing
        self.padding_idx = PAD_TOKEN_ID
        self.classifier = nn.Linear(self.hidden_size, config.num_labels)

        self.use_contrastive_loss = config.use_contrastive_loss
        if self.use_contrastive_loss:
            self.contrast_topk = config.contrast_topk
            self.contrast_temp = config.contrast_temp
            self.contrast_head = ContrastiveHead(
                input_dim=self.hidden_size,
                inner_dim=self.hidden_size,
                out_dim=self.hidden_size,
                dropout_prob=self.dropout_prob)

        # for eval generation
        self.model.config.decoder_start_token_id = BOS_TOKEN_ID
        self.model.config.eos_token_id = EOS_TOKEN_ID
        self.model.config.pad_token_id = PAD_TOKEN_ID
        self.model.config.vocab_size = self.pt_cfg.vocab_size
        self.model.config.max_length = config.max_length
        self.model.config.min_length = config.min_length
        self.model.config.no_repeat_ngram_size = config.no_repeat_ngram_size
        self.model.config.length_penalty = config.length_penalty
        self.model.config.num_beams = config.num_beams

    def update_client_id(self, client_id):
        self.client_id = client_id

    def generate(self, **kwargs):
        return self.model.generate(**kwargs)

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        start_positions=None,
        end_positions=None,
        labels=None,
        pretrain_task=None,
        contrast_monitor=None,
        in_contrast_prepare=None,
        example_indices=None,
    ):
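        """Run one forward pass.

        When ``in_contrast_prepare`` is set, only the synthetic tokens in
        ``contrast_monitor`` are decoded and their logits and projected
        hidden states are returned. Otherwise the regular task loss is
        computed and, if enabled, the contrastive loss is added during
        training.
        """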
        if in_contrast_prepare:  # return dec_hidden_states & dec_out
            self.eval()
            with torch.no_grad():
                example_indices = [
                    k for k in example_indices
                    if k.item() in contrast_monitor.synth_tokens
                ]
                if len(example_indices) == 0:
                    return ModelOutput(example_indices=example_indices)

                example_indices = torch.stack(example_indices)
                synth_input_ids = torch.stack([
                    contrast_monitor.synth_tokens[k.item()]
                    for k in example_indices
                ]).to(self.model.device)

                enc_hidden = torch.stack([
                    contrast_monitor.enc_hidden[k.item()]
                    for k in example_indices
                ]).to(self.model.device)
                outputs = self.model.decoder.bert(
                    input_ids=synth_input_ids,
                    encoder_hidden_states=enc_hidden,
                )
                logits = self.model.decoder.cls(outputs.last_hidden_state)
                dec_hidden = self.contrast_head(
                    outputs.last_hidden_state).mean(1)

                return ModelOutput(logits=logits.argmax(-1),
                                   hidden_states=dec_hidden,
                                   example_indices=example_indices)

        enc_outputs = self.model.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
        )

        regular_loss, contrastive_loss = None, None
        if self.task == 'pretrain':
            if pretrain_task == 'mlm':
                logits = self.lm_head(enc_outputs.last_hidden_state)
                loss_fct = CrossEntropyLoss()
                masked_lm_loss = loss_fct(logits.view(-1, self.vocab_size),
                                          labels.view(-1))
                loss = masked_lm_loss

            elif pretrain_task == 'denoise':
                dec_outputs = self.model.decoder.bert(
                    input_ids=labels,
                    encoder_hidden_states=enc_outputs.last_hidden_state,
                    encoder_attention_mask=attention_mask,
                )
                logits = self.model.decoder.cls(
                    dec_outputs.last_hidden_state)[:, :-1, :]
                loss_fct = CrossEntropyLoss(ignore_index=self.padding_idx)
                denoise_loss = loss_fct(
                    logits.contiguous().view(-1, self.vocab_size),
                    labels[:, 1:].contiguous().view(-1))
                loss = denoise_loss

            else:
                raise KeyError(
                    'Unsupported pretrain task: \'{}\''.format(pretrain_task))

        else:
            # regular loss
            if self.task in {'imdb', 'agnews'}:
                pooled_output = self.dropout(enc_outputs.pooler_output)
                logits = self.classifier(pooled_output)
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, logits.size(-1)),
                                labels.view(-1))

            elif self.task in {'squad', 'newsqa'}:
                logits = self.classifier(enc_outputs.last_hidden_state)
                start_logits, end_logits = logits.split(1, dim=-1)
                start_logits = start_logits.squeeze(-1).contiguous()
                end_logits = end_logits.squeeze(-1).contiguous()
                logits = (start_logits, end_logits)

                # sometimes the start/end positions are outside our model
                # inputs, we ignore these terms
                ignored_index = start_logits.size(1)
                start_positions = start_positions.clamp(0, ignored_index)
                end_positions = end_positions.clamp(0, ignored_index)

                loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
                start_loss = loss_fct(start_logits, start_positions)
                end_loss = loss_fct(end_logits, end_positions)
                loss = (start_loss + end_loss) / 2

            elif self.task in {'cnndm', 'msqg'}:
                dec_outputs = self.model.decoder.bert(
                    input_ids=labels,
                    encoder_hidden_states=enc_outputs.last_hidden_state,
                    encoder_attention_mask=attention_mask,
                )
                dec_hidden_states = dec_outputs.last_hidden_state
                logits = self.model.decoder.cls(dec_hidden_states)[:, :-1, :]

                num_tokens = labels[:, 1:].ne(self.padding_idx).sum().item()
                label_smoothing = self.label_smoothing if self.training \
                    else 0.0
                if label_smoothing > 0:
                    loss_fct = LabelSmoothingLoss(
                        label_smoothing,
                        self.vocab_size,
                        ignore_index=self.padding_idx,
                    ).to(logits.device)
                    loss = loss_fct(
                        F.log_softmax(logits.contiguous().view(
                            -1, self.vocab_size),
                                      dim=-1),
                        labels[:, 1:].contiguous().view(-1)) / num_tokens
                else:
                    loss_fct = CrossEntropyLoss(ignore_index=self.padding_idx)
                    loss = loss_fct(
                        logits.contiguous().view(-1, self.vocab_size),
                        labels[:, 1:].contiguous().view(-1))

            else:
                raise KeyError('Unsupported task: \'{}\''.format(self.task))

            # contrastive loss (InfoNCE-style): decoder states from the
            # top-ranked clients in ``topk_group_ids`` serve as positives,
            # those from the lowest-ranked clients in ``all_group_ids``
            # as negatives.
            if self.use_contrastive_loss and self.training:
                regular_loss = loss.clone()
                example_indices = [
                    k for k in example_indices
                    if k.item() in contrast_monitor.synth_tokens
                ]
                all_group_ids = contrast_monitor.all_group_ids[self.client_id]
                topk_group_ids = \
                    contrast_monitor.topk_group_ids[self.client_id]
                if len(example_indices) > 0 and len(topk_group_ids) > 1:
                    example_indices = torch.stack(example_indices)
                    synth_input_ids = torch.stack([
                        contrast_monitor.synth_tokens[k.item()]
                        for k in example_indices
                    ]).to(self.model.device)

                    contrast_enc_hidden = torch.stack([
                        contrast_monitor.enc_hidden[k.item()]
                        for k in example_indices
                    ]).to(self.model.device)
                    contrast_outputs = self.model.decoder.bert(
                        input_ids=synth_input_ids,
                        encoder_hidden_states=contrast_enc_hidden,
                    )
                    cur_dec_hidden = self.contrast_head(
                        contrast_outputs.last_hidden_state).mean(1)

                    pos_client_ids = [
                        x for x in topk_group_ids[1:self.contrast_topk + 1]
                    ]
                    all_dec_hiddens = contrast_monitor.dec_hidden
                    sim_hiddens = [[
                        all_dec_hiddens[cid][k.item()] for k in example_indices
                    ] for cid in pos_client_ids]
                    sim_hiddens = torch.stack([
                        torch.stack(hid) for hid in sim_hiddens
                    ]).mean(0).to(self.model.device)
                    sim_matrix = F.cosine_similarity(cur_dec_hidden,
                                                     sim_hiddens,
                                                     dim=-1)
                    nominator = torch.exp(sim_matrix / self.contrast_temp)
                    denominator = nominator

                    neg_client_ids = [
                        x for x in all_group_ids[::-1][:self.contrast_topk]
                        if x not in topk_group_ids
                    ]
                    if len(neg_client_ids) > 0:
                        dissim_hiddens = [[
                            all_dec_hiddens[cid][k.item()]
                            for k in example_indices
                        ] for cid in neg_client_ids]
                        dissim_hiddens = torch.stack([
                            torch.stack(hid) for hid in dissim_hiddens
                        ]).to(self.model.device)
                        dissim_matrix = F.cosine_similarity(
                            cur_dec_hidden.unsqueeze(0),
                            dissim_hiddens,
                            dim=-1)
                        denominator = denominator + (torch.exp(
                            dissim_matrix / self.contrast_temp)).sum(0)

                    contrastive_loss = -torch.log(
                        nominator / denominator).mean()
                    loss += contrastive_loss

        return ModelOutput(loss=loss,
                           regular_loss=regular_loss,
                           contrastive_loss=contrastive_loss,
                           logits=logits)
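

# Illustrative usage sketch (not part of the original module): the config
# below is a stand-in whose fields mirror the attributes accessed in
# ``ATCModel.__init__``; the concrete values ('bert-base-uncased', 'imdb',
# etc.) are assumptions for demonstration only.
if __name__ == '__main__':
    from types import SimpleNamespace

    cfg = SimpleNamespace(
        model_type='bert-base-uncased',  # assumed HF checkpoint name
        task='imdb',  # exercises the classification branch of forward()
        num_labels=2,
        label_smoothing=0.1,
        use_contrastive_loss=False,  # skip the contrast_monitor machinery
        max_length=64,
        min_length=5,
        no_repeat_ngram_size=3,
        length_penalty=2.0,
        num_beams=5,
    )
    model = ATCModel(cfg)

    # Dummy batch: 2 sequences of 16 token ids with binary labels.
    input_ids = torch.randint(0, model.vocab_size, (2, 16))
    attention_mask = torch.ones_like(input_ids)
    labels = torch.tensor([0, 1])
    out = model(input_ids=input_ids,
                attention_mask=attention_mask,
                labels=labels)
    print(out.loss, out.logits.shape)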