161 lines
5.3 KiB
Python
161 lines
5.3 KiB
Python
import os
|
|
import shlex
|
|
import logging
|
|
import paramiko
|
|
|
|
from collections.abc import MutableMapping
|
|
|
|
from federatedscope.core.cmd_args import parse_args
|
|
from federatedscope.core.configs.config import global_cfg
|
|
from federatedscope.core.auxiliaries.data_builder import get_data
|
|
from federatedscope.organizer.cfg_client import env_name, root_path, fs_version
|
|
|
|
logger = logging.getLogger('federatedscope')
|
|
logger.setLevel(logging.INFO)
|
|
|
|
|
|
class SSHManager(object):
|
|
def __init__(self, ip, user, psw, ssh_port=22):
|
|
self.ip, self.user, self.psw = ip, user, psw
|
|
self.ssh_port = ssh_port
|
|
self.ssh, self.trans = None, None
|
|
self.setup_fs()
|
|
self.tasks = set()
|
|
|
|
def _connect(self):
|
|
self.trans = paramiko.Transport((self.ip, self.ssh_port))
|
|
self.trans.connect(username=self.user, password=self.psw)
|
|
self.ssh = paramiko.SSHClient()
|
|
self.ssh._transport = self.trans
|
|
self.ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
|
|
|
|
def _disconnect(self):
|
|
self.trans.close()
|
|
|
|
def _exec_cmd(self, command):
|
|
if self.trans is None or self.ssh is None:
|
|
self._connect()
|
|
command = f'source ~/.bashrc; cd ~; {command}'
|
|
_, stdout, stderr = self.ssh.exec_command(command)
|
|
stdout = stdout.read().decode('ascii').strip("\n")
|
|
stderr = stderr.read().decode('ascii').strip("\n")
|
|
return stdout, stderr
|
|
|
|
def _exec_python_cmd(self, command):
|
|
if self.trans is None or self.ssh is None:
|
|
self._connect()
|
|
command = f'source ~/.bashrc; conda activate {env_name}; ' \
|
|
f'cd ~/{root_path}/FederatedScope; {command}'
|
|
_, stdout, stderr = self.ssh.exec_command(command)
|
|
stdout = stdout.read().decode('ascii').strip("\n")
|
|
stderr = stderr.read().decode('ascii').strip("\n")
|
|
return stdout, stderr
|
|
|
|
def _check_conda(self):
|
|
"""
|
|
Check and install conda env.
|
|
"""
|
|
# Check conda
|
|
conda, _ = self._exec_cmd('which conda')
|
|
if conda is None:
|
|
logger.exception('No conda, please install conda first.')
|
|
# TODO: Install conda here
|
|
return False
|
|
|
|
# Check env & FS
|
|
output, err = self._exec_cmd(f'conda activate {env_name}; '
|
|
f'python -c "import federatedscope; '
|
|
f'print(federatedscope.__version__)"')
|
|
if err:
|
|
logger.error(err)
|
|
# TODO: Install FS env here
|
|
return False
|
|
elif output != fs_version:
|
|
logger.warning(f'The installed FS version is {output}, however'
|
|
f' {fs_version} is required.')
|
|
logger.info(f'Conda environment found, named {env_name}.')
|
|
return True
|
|
|
|
def _check_source(self):
|
|
"""
|
|
Check and download FS repo.
|
|
"""
|
|
fs_path = os.path.join(root_path, 'FederatedScope', '.git')
|
|
output, _ = self._exec_cmd(f'[ -d {fs_path} ] && echo "Found" || '
|
|
f'echo "Not found"')
|
|
if output == 'Not found':
|
|
# TODO: git clone here
|
|
logger.exception(f'Repo not find in {fs_path}.')
|
|
return False
|
|
logger.info(f'FS repo Found in {root_path}.')
|
|
return True
|
|
|
|
def _check_task_status(self):
|
|
"""
|
|
Check task status.
|
|
"""
|
|
pass
|
|
|
|
def setup_fs(self):
|
|
logger.info("Checking environment, please wait...")
|
|
if not self._check_conda():
|
|
raise Exception('The environment is not configured properly.')
|
|
if not self._check_source():
|
|
raise Exception('The FS repo is not configured properly.')
|
|
|
|
def launch_task(self, command):
|
|
self._check_task_status()
|
|
stdout, _ = self._exec_python_cmd(f'nohup python '
|
|
f'federatedscope/main.py {command} '
|
|
f'> /dev/null 2>&1 & echo $!')
|
|
self.tasks.add(stdout)
|
|
return stdout
|
|
|
|
|
|
def anonymize(info, mask):
|
|
for key, value in info.items():
|
|
if key == mask:
|
|
info[key] = "******"
|
|
elif isinstance(value, dict):
|
|
anonymize(info[key], mask)
|
|
return info
|
|
|
|
|
|
def args2yaml(args):
|
|
init_cfg = global_cfg.clone()
|
|
args = parse_args(shlex.split(args))
|
|
init_cfg.merge_from_file(args.cfg_file)
|
|
init_cfg.merge_from_list(args.opts)
|
|
_, modified_cfg = get_data(config=init_cfg.clone())
|
|
init_cfg.merge_from_other_cfg(modified_cfg)
|
|
init_cfg.freeze(inform=False, save=False)
|
|
init_cfg.defrost()
|
|
return init_cfg
|
|
|
|
|
|
def flatten_dict(d, parent_key='', sep='.'):
|
|
items = []
|
|
for key, value in d.items():
|
|
new_key = parent_key + sep + key if parent_key else key
|
|
if isinstance(value, MutableMapping):
|
|
items.extend(flatten_dict(value, new_key, sep=sep).items())
|
|
else:
|
|
items.append((new_key, value))
|
|
return dict(items)
|
|
|
|
|
|
def config2cmdargs(config):
|
|
'''
|
|
Arguments:
|
|
config (dict): key is cfg node name, value is the specified value.
|
|
Returns:
|
|
results (list): cmd args
|
|
'''
|
|
|
|
results = []
|
|
for key, value in config.items():
|
|
if value:
|
|
results.append(key)
|
|
results.append(value)
|
|
return results
|