# Mirror of https://github.com/czzhangheng/STDEN.git (Python, 38 lines, 1.1 KiB)
from __future__ import absolute_import
|
|
from __future__ import division
|
|
from __future__ import print_function
|
|
|
|
import argparse
|
|
import yaml
|
|
|
|
from lib.utils import load_graph_data
|
|
from model.stden_supervisor import STDENSupervisor
|
|
|
|
import numpy as np
|
|
import torch
|
|
|
|
def main(args):
    """Train an STDEN model from a YAML configuration file.

    Reads the YAML file at ``args.config_filename``, loads the graph data
    named by the ``graph_pkl_filename`` entry of the config's ``data``
    section, builds an ``STDENSupervisor`` from the loaded graph data plus
    the whole config passed as keyword arguments, and calls its ``train()``
    method.

    Args:
        args: Parsed command-line namespace. Only ``config_filename`` is read
            in this function.
    """
    with open(args.config_filename) as f:
        # Use safe_load: a bare yaml.load(f) without a Loader raises TypeError
        # on PyYAML >= 6 and may construct arbitrary Python objects on older
        # versions. A plain config of mappings/lists/scalars loads identically.
        supervisor_config = yaml.safe_load(f)

    # The file is only needed for parsing, so it is closed before the
    # (potentially long-running) training starts.
    graph_pkl_filename = supervisor_config['data'].get('graph_pkl_filename')
    adj_mx = load_graph_data(graph_pkl_filename)

    supervisor = STDENSupervisor(adj_mx=adj_mx, **supervisor_config)
    supervisor.train()
|
|
|
|
|
|
def str2bool(value):
    """Convert a command-line string such as ``true``/``False``/``1`` to a bool.

    ``argparse`` with ``type=bool`` applies ``bool()`` to the raw string, so
    every non-empty value -- including ``"False"`` -- becomes ``True``. This
    parser interprets the text instead.

    Args:
        value: The raw argument string (or an actual ``bool``, which is
            returned unchanged).

    Returns:
        The parsed boolean. An empty string maps to ``False``, matching what
        ``bool('')`` produced before.

    Raises:
        argparse.ArgumentTypeError: If the text is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError(
        'Boolean value expected, got %r.' % (value,))


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--config_filename', default=None, type=str,
                        help='Configuration filename for restoring the model.')
    # NOTE(review): nothing visible in this script reads use_cpu_only after
    # parsing -- confirm whether it is meant to be forwarded somewhere.
    parser.add_argument('--use_cpu_only', default=False, type=str2bool,
                        help='Set to true to only use cpu.')
    parser.add_argument('-r', '--random_seed', type=int, default=2021,
                        help="Random seed for reproduction.")
    args = parser.parse_args()

    # Without a config file main() would fail inside open(None) with an
    # opaque TypeError; report the problem as a usage error instead.
    if args.config_filename is None:
        parser.error('--config_filename is required.')

    # Seed both RNG sources before any model or data code runs.
    torch.manual_seed(args.random_seed)
    np.random.seed(args.random_seed)

    main(args)
|