68 lines
2.1 KiB
Python
68 lines
2.1 KiB
Python
import logging
|
|
import numpy as np
|
|
|
|
from rdkit import Chem
|
|
from rdkit import RDLogger
|
|
from rdkit.Chem.Scaffolds import MurckoScaffold
|
|
|
|
from federatedscope.core.splitters import BaseSplitter
|
|
|
|
# Module-level logger named after this module; used below to announce the
# start of a scaffold split.
logger = logging.getLogger(__name__)
|
|
|
|
# Import-time side effect: turns off RDKit log channels matching 'rdApp.*'.
# NOTE(review): presumably done to keep per-molecule parser warnings out of
# the output while a whole dataset is processed -- confirm intent.
RDLogger.DisableLog('rdApp.*')
|
|
|
|
|
|
def generate_scaffold(smiles, include_chirality=False):
    """Return the Murcko scaffold string of the target molecule.

    Arguments:
        smiles (str): SMILES string of the molecule.
        include_chirality (bool): forwarded to RDKit as ``includeChirality``.

    Returns:
        The scaffold string produced by
        ``MurckoScaffold.MurckoScaffoldSmiles`` for the parsed molecule.

    Raises:
        ValueError: if RDKit cannot parse ``smiles`` into a molecule.
    """
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        # MolFromSmiles reports a parse failure by returning None. Fail here
        # and name the offending input, instead of letting RDKit raise a
        # generic error that does not say which molecule was bad.
        raise ValueError('Invalid SMILES string: %r' % (smiles, ))
    return MurckoScaffold.MurckoScaffoldSmiles(
        mol=mol, includeChirality=include_chirality)
|
|
|
|
|
|
def gen_scaffold_split(dataset, client_num=5):
    r"""Partition dataset indices into ``client_num`` scaffold-ordered chunks.

    Each sample's ``smiles`` attribute is mapped to a scaffold string with
    ``generate_scaffold``. Indices that share a scaffold stay adjacent,
    scaffold groups are ordered from largest to smallest (groups of equal
    size are ordered by their first index, larger first), and the flattened
    index sequence is cut into ``client_num`` contiguous pieces with
    ``np.array_split``.

    Arguments:
        dataset: iterable of samples, each exposing a ``smiles`` attribute.
        client_num (int): number of pieces to produce.

    Returns:
        list: ``client_num`` numpy arrays of dataset indices, one per piece.
    """
    logger.info('Scaffold split might take minutes, please wait...')

    # Group sample indices by scaffold. ``enumerate`` yields indices in
    # ascending order, so every group is already sorted and group[0] is its
    # smallest index.
    scaffold_to_idxs = {}
    for idx, data in enumerate(dataset):
        scaffold = generate_scaffold(data.smiles)
        scaffold_to_idxs.setdefault(scaffold, []).append(idx)

    # Sort from largest to smallest scaffold sets; the first index of each
    # group is unique, so the ordering is fully deterministic.
    ordered_groups = sorted(scaffold_to_idxs.values(),
                            key=lambda group: (len(group), group[0]),
                            reverse=True)

    # Flatten in linear time (``sum(list_of_lists, [])`` is quadratic).
    scaffold_idxs = [idx for group in ordered_groups for idx in group]

    # Split data to list
    return list(np.array_split(scaffold_idxs, client_num))
|
|
|
|
|
|
class ScaffoldSplitter(BaseSplitter):
    """
    Split molecules via scaffold. This splitter orders all molecules by \
    their scaffold group (see ``gen_scaffold_split``) and splits them into \
    several parts.

    Arguments:
        client_num (int): Split data into client_num of pieces.
    """
    def __init__(self, client_num):
        super(ScaffoldSplitter, self).__init__(client_num)
        # Stored on the instance so that __call__ can honour the requested
        # number of pieces without depending on how BaseSplitter keeps it.
        self.client_num = client_num

    def __call__(self, dataset, **kwargs):
        """Split ``dataset`` into ``client_num`` lists of samples.

        Arguments:
            dataset: iterable of samples, each exposing a ``smiles``
                attribute (read by ``gen_scaffold_split``).
            **kwargs: accepted but not used by this splitter.

        Returns:
            list: ``client_num`` lists holding the samples of each piece.
        """
        # Materialize once so samples can be looked up by position below.
        dataset = list(dataset)
        # Pass the configured client count; omitting it would silently fall
        # back to gen_scaffold_split's default of 5 pieces.
        idx_slice = gen_scaffold_split(dataset, self.client_num)
        return [[dataset[idx] for idx in idxs] for idxs in idx_slice]
|