from model.STGNRDE.GRDE import NeuralGCDE
from model.STGNRDE.vector_fields import FinalTanh_f, VectorField_g


def make_model(args):
    """Assemble a ``NeuralGCDE`` model from a configuration mapping.

    Two vector-field modules are created first (``FinalTanh_f`` and
    ``VectorField_g``) and then handed to ``NeuralGCDE`` as ``func_f`` and
    ``func_g``.

    Args:
        args: Subscriptable configuration (indexed with string keys). The
            keys read here are ``'input_dim'``, ``'hid_dim'``,
            ``'hid_hid_dim'``, ``'num_layers'``, ``'num_nodes'``,
            ``'cheb_k'``, ``'embed_dim'``, ``'g_type'``, ``'output_dim'``,
            ``'device'`` and ``'solver'``. The whole object is also passed
            through to ``NeuralGCDE`` as its first positional argument.

    Returns:
        The constructed ``NeuralGCDE`` instance.

    Note:
        ``initial=True``, ``atol=1e-9`` and ``rtol=1e-7`` are fixed here and
        are not configurable through ``args``. Presumably the tolerances are
        forwarded to the differential-equation solver -- confirm against
        ``NeuralGCDE``.
    """
    # Constructor arguments that both vector fields receive identically.
    common_field_kwargs = {
        'input_channels': args['input_dim'],
        'hidden_channels': args['hid_dim'],
        'hidden_hidden_channels': args['hid_hid_dim'],
        'num_hidden_layers': args['num_layers'],
    }

    # Keep the construction order f -> g -> model: each constructor may draw
    # from the global RNG when initialising parameters, so reordering could
    # change the initial weights under a fixed seed.
    f_field = FinalTanh_f(**common_field_kwargs)

    g_field = VectorField_g(
        **common_field_kwargs,
        num_nodes=args['num_nodes'],
        cheb_k=args['cheb_k'],
        embed_dim=args['embed_dim'],
        g_type=args['g_type'],
    )

    return NeuralGCDE(
        args,
        func_f=f_field,
        func_g=g_field,
        input_channels=args['input_dim'],
        hidden_channels=args['hid_dim'],
        output_channels=args['output_dim'],
        initial=True,
        device=args['device'],
        atol=1e-9,
        rtol=1e-7,
        solver=args['solver'],
    )