import os
import logging
import random
import csv
import json
import gzip
import zipfile
import shutil
import copy
import torch
import numpy as np
from tqdm import tqdm

from federatedscope.core.data.utils import download_url
from federatedscope.core.gpu_manager import GPUManager
from federatedscope.nlp.hetero_tasks.model.model import ATCModel

logger = logging.getLogger(__name__)


class HeteroNLPDataLoader(object):
    """
    Load the datasets of the hetero NLP tasks (one or more datasets), split
    each of them into train/val/test, and partition each split across the
    corresponding number of clients.
    """
    def __init__(self, data_dir, data_name, num_of_clients, split=[0.9, 0.1]):
        self.data_dir = data_dir
        self.data_name = data_name
        self.num_of_clients = num_of_clients
        self.split = split  # split for train:val
        self.train_data = []
        self.val_data = []
        self.test_data = []

    def get_data(self):
        for each_data, each_client_num in zip(self.data_name,
                                              self.num_of_clients):
            train_and_val_data = self._load(each_data, 'train',
                                            each_client_num)
            each_train_data = [
                data[:int(self.split[0] * len(data))]
                for data in train_and_val_data
            ]
            each_val_data = [
                data[-int(self.split[1] * len(data)):]
                for data in train_and_val_data
            ]
            each_test_data = self._load(each_data, 'test', each_client_num)
            self.train_data.extend(each_train_data)
            self.val_data.extend(each_val_data)
            self.test_data.extend(each_test_data)

        return {
            'train': self.train_data,
            'val': self.val_data,
            'test': self.test_data
        }

    def _load(self, dataset, split, num_of_client):
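        """
        Read the ``split`` part of ``dataset`` from disk (downloading and
        extracting it first if it is not found locally) and partition the
        examples into ``num_of_client`` shards.
        """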
        data_dir = os.path.join(self.data_dir, dataset)
        if not os.path.exists(data_dir):
            logger.info(f'Start to download the dataset {dataset} ...')
            self._download_and_extract(dataset)

        # read data
        data = []
        if dataset == 'imdb':
            pos_files = os.listdir(os.path.join(data_dir, split, 'pos'))
            neg_files = os.listdir(os.path.join(data_dir, split, 'neg'))
            for file in pos_files:
                path = os.path.join(data_dir, split, 'pos', file)
                with open(path) as f:
                    line = f.readline()
                data.append({'text': line, 'label': 1})
            for file in neg_files:
                path = os.path.join(data_dir, split, 'neg', file)
                with open(path) as f:
                    line = f.readline()
                data.append({'text': line, 'label': 0})
            random.shuffle(data)

        elif dataset == 'agnews':
            with open(os.path.join(data_dir, split + '.csv'),
                      encoding="utf-8") as csv_file:
                csv_reader = csv.reader(csv_file,
                                        quotechar='"',
                                        delimiter=",",
                                        quoting=csv.QUOTE_ALL,
                                        skipinitialspace=True)
                for i, row in enumerate(csv_reader):
                    label, title, description = row
                    label = int(label) - 1
                    text = ' [SEP] '.join((title, description))
                    data.append({'text': text, 'label': label})

        elif dataset == 'squad':
            with open(os.path.join(data_dir, split + '.json'),
                      'r',
                      encoding='utf-8') as reader:
                raw_data = json.load(reader)['data']
            for line in raw_data:
                for para in line['paragraphs']:
                    context = para['context']
                    for qa in para['qas']:
                        data.append({'context': context, 'qa': qa})

        elif dataset == 'newsqa':
            with gzip.GzipFile(os.path.join(data_dir, split + '.jsonl.gz'),
                               'r') as reader:
                content = reader.read().decode('utf-8').strip().split('\n')[1:]
            raw_data = [json.loads(line) for line in content]
            for line in raw_data:
                context = line['context']
                for qa in line['qas']:
                    data.append({'context': context, 'qa': qa})

        elif dataset in {'cnndm', 'msqg'}:
            src_file = os.path.join(data_dir, split + '.src')
            tgt_file = os.path.join(data_dir, split + '.tgt')
            with open(src_file) as f:
                src_data = [
                    line.strip().replace('<S_SEP>', '[SEP]') for line in f
                ]
            with open(tgt_file) as f:
                tgt_data = [
                    line.strip().replace('<S_SEP>', '[SEP]') for line in f
                ]
            for src, tgt in zip(src_data, tgt_data):
                data.append({'src': src, 'tgt': tgt})

        # split data
        logger.info(f'Splitting dataset {dataset} ({split})')
        splited_data = []
        n = len(data) // num_of_client
        data_idx = 0
        for i in range(num_of_client):
            num_split = n if i < num_of_client - 1 else \
                len(data) - n * (num_of_client - 1)
            cur_data = data[data_idx:data_idx + num_split]
            data_idx += num_split
            splited_data.append(cur_data)
        logger.info(f'Dataset {dataset} ({split}) is split into '
                    f'{[len(x) for x in splited_data]}')

        return splited_data

    def _download_and_extract(self, dataset):
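        """
        Download ``<dataset>.zip`` from the FederatedScope OSS bucket into
        ``self.data_dir``, extract it, and clean up the archive and the
        temporary extraction directory.
        """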
        url = 'https://federatedscope.oss-cn-beijing.aliyuncs.com'
        os.makedirs(self.data_dir, exist_ok=True)
        download_url(f'{url}/{dataset}.zip', self.data_dir)

        raw_dir = os.path.join(self.data_dir, dataset + '_raw')
        extract_dir = os.path.join(self.data_dir, dataset)
        with zipfile.ZipFile(os.path.join(self.data_dir, f'{dataset}.zip'),
                             'r') as zip_ref:
            zip_ref.extractall(raw_dir)
        shutil.move(os.path.join(raw_dir, dataset), self.data_dir)
        if os.path.exists(os.path.join(extract_dir, '.DS_Store')):
            os.remove(os.path.join(extract_dir, '.DS_Store'))
        os.remove(os.path.join(self.data_dir, f'{dataset}.zip'))
        shutil.rmtree(raw_dir)


class SynthDataProcessor(object):
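    """
    Build synthetic data for the hetero NLP tasks: encoder hidden states
    produced by each client's pretrained ATC model are merged with
    per-example weights, decoded back to tokens with an averaged LM head,
    and cached to disk as memmap files.
    """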
    def __init__(self, config, datasets):
        self.cfg = config
        self.device = GPUManager(
            gpu_available=self.cfg.use_gpu,
            specified_device=self.cfg.device).auto_choice()
        self.pretrain_dir = config.federate.atc_load_from
        self.cache_dir = 'cache_debug' if \
            config.data.is_debug else config.data.cache_dir
        self.save_dir = os.path.join(self.cache_dir, 'synthetic')
        self.batch_size = config.data.hetero_synth_batch_size
        self.datasets = datasets
        self.num_clients = len(datasets)
        self.synth_prim_weight = config.data.hetero_synth_prim_weight
        self.synth_feat_dim = config.data.hetero_synth_feat_dim
        self.models = {}

    def save_data(self):
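        """
        Generate the synthetic features and tokens and cache them under
        ``self.save_dir``. Generation is skipped entirely if the cache
        directory already exists.
        """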
        if os.path.exists(self.save_dir):
            return

        max_sz, max_len = 1e8, 0
        for client_id in range(1, self.num_clients + 1):
            dataset = self.datasets[client_id -
                                    1]['train_contrast']['dataloader'].dataset
            max_sz = min(max_sz, len(dataset))
            max_len = max(max_len, len(dataset[0]['token_ids']))
        enc_hiddens = np.memmap(filename=os.path.join(self.cfg.outdir,
                                                      'tmp_feat.memmap'),
                                shape=(self.num_clients, max_sz, max_len,
                                       self.synth_feat_dim),
                                mode='w+',
                                dtype=np.float32)
        self._get_models()

        logger.info('Generating synthetic encoder hidden states')
        for client_id in tqdm(range(1, self.num_clients + 1)):
            dataloader = self.datasets[client_id -
                                       1]['train_contrast']['dataloader']
            model = self.models[client_id]
            model.eval()
            model.to(self.device)
            enc_hid = []
            for batch_i, data_batch in tqdm(enumerate(dataloader),
                                            total=len(dataloader)):
                token_ids = data_batch['token_ids']
                token_type_ids = data_batch['token_type_ids']
                attention_mask = data_batch['attention_mask']
                enc_out = model.model.encoder(
                    input_ids=token_ids.to(self.device),
                    attention_mask=attention_mask.to(self.device),
                    token_type_ids=token_type_ids.to(self.device),
                )
                enc_hid.append(enc_out.last_hidden_state.detach().cpu())

            enc_hid = torch.cat(enc_hid)
            if enc_hid.size(1) < max_len:
                enc_hid = torch.cat([
                    enc_hid,
                    torch.zeros(enc_hid.size(0), max_len - enc_hid.size(1),
                                self.synth_feat_dim)
                ],
                                    dim=1)
            enc_hiddens[client_id - 1] = enc_hid[:max_sz]
            model.to('cpu')

        all_hids = torch.from_numpy(enc_hiddens)
        # Pick a random primary client for each example to avoid the
        # over-smoothed results that arise when all clients are merged
        # with equal weights.
        prim_indices = [
            random.randint(0,
                           len(all_hids) - 1) for _ in range(len(all_hids[0]))
        ]
        all_weights = torch.ones(len(all_hids), len(all_hids[0]))
        all_weights *= (1 - self.synth_prim_weight) / (len(all_hids) - 1)
        for i, pi in enumerate(prim_indices):
            all_weights[pi, i] = self.synth_prim_weight
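        # Illustrative weight layout (assumed values, not tied to any real
        # config): with 3 clients and synth_prim_weight = 0.5, each column of
        # all_weights sums to 1, so an example whose primary client is the
        # second one gets per-client weights [0.25, 0.5, 0.25].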
        avg_hids = (all_hids * all_weights[:, :, None, None]).sum(0)

        logger.info('Generating synthetic input tokens')
        lm_head = self._get_avg_lm_head().to(self.device)
        with torch.no_grad():
            pred_toks = torch.cat([
                lm_head(avg_hids[i:i + self.batch_size].to(
                    self.device)).detach().cpu().argmax(dim=-1)
                for i in tqdm(range(0, avg_hids.size(0), self.batch_size))
            ])

        if self.cache_dir:
            logger.info('Saving synthetic data to \'{}\''.format(
                self.save_dir))
            os.makedirs(self.save_dir, exist_ok=True)
            saved_feats = np.memmap(
                filename=os.path.join(
                    self.save_dir,
                    'feature_{}.memmap'.format(self.synth_prim_weight)),
                shape=avg_hids.size(),
                mode='w+',
                dtype=np.float32,
            )
            saved_toks = np.memmap(
                filename=os.path.join(
                    self.save_dir,
                    'token_{}.memmap'.format(self.synth_prim_weight)),
                shape=pred_toks.size(),
                mode='w+',
                dtype=np.int64,
            )
            for i in range(len(avg_hids)):
                saved_feats[i] = avg_hids[i]
                saved_toks[i] = pred_toks[i]
            shapes = {'feature': avg_hids.size(), 'token': pred_toks.size()}
            with open(os.path.join(self.save_dir, 'shapes.json'), 'w') as f:
                json.dump(shapes, f)

        if os.path.exists(os.path.join(self.cfg.outdir, 'tmp_feat.memmap')):
            os.remove(os.path.join(self.cfg.outdir, 'tmp_feat.memmap'))

    def _get_models(self):
        for client_id in range(1, self.num_clients + 1):
            self.models[client_id] = self._load_model(ATCModel(self.cfg.model),
                                                      client_id)

    def _get_avg_lm_head(self):
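        """
        Average the ``lm_head`` parameters of all clients' models and return
        a new LM head loaded with the averaged weights.
        """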
        all_params = copy.deepcopy([
            self.models[k].lm_head.state_dict()
            for k in range(1, self.num_clients + 1)
        ])
        avg_param = all_params[0]
        for k in avg_param:
            for i in range(len(all_params)):
                local_param = all_params[i][k].float()
                if i == 0:
                    avg_param[k] = local_param / len(all_params)
                else:
                    avg_param[k] += local_param / len(all_params)
        avg_lm_head = copy.deepcopy(self.models[1].lm_head)
        avg_lm_head.load_state_dict(avg_param)
        return avg_lm_head

    def _load_model(self, model, client_id):
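        """
        Initialize ``model`` for ``client_id`` from the pretrained ATC
        checkpoints: the shared global checkpoint is loaded first and, if
        present, the client-specific checkpoint overrides it.
        """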
        global_dir = os.path.join(self.pretrain_dir, 'global')
        client_dir = os.path.join(self.pretrain_dir, 'client')
        global_ckpt_path = os.path.join(global_dir,
                                        'global_model_{}.pt'.format(client_id))
        client_ckpt_path = os.path.join(client_dir,
                                        'client_model_{}.pt'.format(client_id))
        if os.path.exists(global_ckpt_path):
            model_ckpt = model.state_dict()
            logger.info('Loading model from \'{}\''.format(global_ckpt_path))
            global_ckpt = torch.load(global_ckpt_path,
                                     map_location='cpu')['model']
            model_ckpt.update(global_ckpt)
            if os.path.exists(client_ckpt_path):
                logger.info(
                    'Updating model from \'{}\''.format(client_ckpt_path))
                client_ckpt = torch.load(client_ckpt_path,
                                         map_location='cpu')['model']
                model_ckpt.update(client_ckpt)
            model.load_state_dict(model_ckpt)
        return model