FS-TFP/federatedscope/organizer/utils.py

161 lines
5.3 KiB
Python

import os
import shlex
import logging
import paramiko
from collections.abc import MutableMapping
from federatedscope.core.cmd_args import parse_args
from federatedscope.core.configs.config import global_cfg
from federatedscope.core.auxiliaries.data_builder import get_data
from federatedscope.organizer.cfg_client import env_name, root_path, fs_version
logger = logging.getLogger('federatedscope')
logger.setLevel(logging.INFO)
class SSHManager(object):
def __init__(self, ip, user, psw, ssh_port=22):
self.ip, self.user, self.psw = ip, user, psw
self.ssh_port = ssh_port
self.ssh, self.trans = None, None
self.setup_fs()
self.tasks = set()
def _connect(self):
self.trans = paramiko.Transport((self.ip, self.ssh_port))
self.trans.connect(username=self.user, password=self.psw)
self.ssh = paramiko.SSHClient()
self.ssh._transport = self.trans
self.ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
def _disconnect(self):
self.trans.close()
def _exec_cmd(self, command):
if self.trans is None or self.ssh is None:
self._connect()
command = f'source ~/.bashrc; cd ~; {command}'
_, stdout, stderr = self.ssh.exec_command(command)
stdout = stdout.read().decode('ascii').strip("\n")
stderr = stderr.read().decode('ascii').strip("\n")
return stdout, stderr
def _exec_python_cmd(self, command):
if self.trans is None or self.ssh is None:
self._connect()
command = f'source ~/.bashrc; conda activate {env_name}; ' \
f'cd ~/{root_path}/FederatedScope; {command}'
_, stdout, stderr = self.ssh.exec_command(command)
stdout = stdout.read().decode('ascii').strip("\n")
stderr = stderr.read().decode('ascii').strip("\n")
return stdout, stderr
def _check_conda(self):
"""
Check and install conda env.
"""
# Check conda
conda, _ = self._exec_cmd('which conda')
if conda is None:
logger.exception('No conda, please install conda first.')
# TODO: Install conda here
return False
# Check env & FS
output, err = self._exec_cmd(f'conda activate {env_name}; '
f'python -c "import federatedscope; '
f'print(federatedscope.__version__)"')
if err:
logger.error(err)
# TODO: Install FS env here
return False
elif output != fs_version:
logger.warning(f'The installed FS version is {output}, however'
f' {fs_version} is required.')
logger.info(f'Conda environment found, named {env_name}.')
return True
def _check_source(self):
"""
Check and download FS repo.
"""
fs_path = os.path.join(root_path, 'FederatedScope', '.git')
output, _ = self._exec_cmd(f'[ -d {fs_path} ] && echo "Found" || '
f'echo "Not found"')
if output == 'Not found':
# TODO: git clone here
logger.exception(f'Repo not find in {fs_path}.')
return False
logger.info(f'FS repo Found in {root_path}.')
return True
def _check_task_status(self):
"""
Check task status.
"""
pass
def setup_fs(self):
logger.info("Checking environment, please wait...")
if not self._check_conda():
raise Exception('The environment is not configured properly.')
if not self._check_source():
raise Exception('The FS repo is not configured properly.')
def launch_task(self, command):
self._check_task_status()
stdout, _ = self._exec_python_cmd(f'nohup python '
f'federatedscope/main.py {command} '
f'> /dev/null 2>&1 & echo $!')
self.tasks.add(stdout)
return stdout
def anonymize(info, mask):
for key, value in info.items():
if key == mask:
info[key] = "******"
elif isinstance(value, dict):
anonymize(info[key], mask)
return info
def args2yaml(args):
init_cfg = global_cfg.clone()
args = parse_args(shlex.split(args))
init_cfg.merge_from_file(args.cfg_file)
init_cfg.merge_from_list(args.opts)
_, modified_cfg = get_data(config=init_cfg.clone())
init_cfg.merge_from_other_cfg(modified_cfg)
init_cfg.freeze(inform=False, save=False)
init_cfg.defrost()
return init_cfg
def flatten_dict(d, parent_key='', sep='.'):
items = []
for key, value in d.items():
new_key = parent_key + sep + key if parent_key else key
if isinstance(value, MutableMapping):
items.extend(flatten_dict(value, new_key, sep=sep).items())
else:
items.append((new_key, value))
return dict(items)
def config2cmdargs(config):
'''
Arguments:
config (dict): key is cfg node name, value is the specified value.
Returns:
results (list): cmd args
'''
results = []
for key, value in config.items():
if value:
results.append(key)
results.append(value)
return results