FS-TFP/tests/test_yaml.py

75 lines
2.9 KiB
Python

# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import logging
import unittest
from federatedscope.core.configs.config import global_cfg
logger = logging.getLogger(__name__)
class YAMLTest(unittest.TestCase):
def setUp(self):
self.exclude_all = ['benchmark', 'scripts', 'federatedscope/autotune']
self.exclude_file = [
'.pre-commit-config.yaml', 'meta.yaml',
'federatedscope/gfl/baseline/isolated_gin_minibatch_on_cikmcup_per_client.yaml',
'federatedscope/gfl/baseline/fedavg_gin_minibatch_on_cikmcup_per_client.yaml',
'federatedscope/gfl/baseline/mini_graph_dc/fedavg_per_client.yaml'
]
self.exclude_str = ['config.yaml', 'config_client']
self.root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
self.exclude_all = [
os.path.join(self.root, f) for f in self.exclude_all
]
self.exclude_file = [
os.path.join(self.root, f) for f in self.exclude_file
]
print(('Testing %s.%s' % (type(self).__name__, self._testMethodName)))
def test_yaml(self):
init_cfg = global_cfg.clone()
sign, cont = False, False
error_file = []
for dirpath, _, filenames in os.walk(self.root):
for prefix in self.exclude_all:
if dirpath.startswith(prefix):
cont = True
break
if cont:
cont = False
continue
filenames = [f for f in filenames if f.endswith('.yaml')]
for f in filenames:
yaml_file = os.path.join(dirpath, f)
# Ignore `ss` search space and yaml file in `exclude_file`
if yaml_file in self.exclude_file or 'ss' in yaml_file:
continue
exclude = False
for s in self.exclude_str:
if s in yaml_file:
exclude = True
break
if exclude:
continue
try:
init_cfg.merge_from_file(yaml_file)
except KeyError as error:
error_file.append(yaml_file.removeprefix(self.root))
logger.error(
f"KeyError: {error} in file: {yaml_file.removeprefix(self.root)}"
)
sign = True
except ValueError as error:
error_file.append(yaml_file.removeprefix(self.root))
logger.error(
f"ValueError: {error} in file: {yaml_file.removeprefix(self.root)}"
)
sign = True
init_cfg = global_cfg.clone()
self.assertIs(sign, False, f'Yaml check failed in {error_file}.')
if __name__ == '__main__':
unittest.main()