parent
6ea133716f
commit
1ee3b39de2
|
|
@ -0,0 +1,73 @@
|
|||
use_gpu: True
|
||||
seed: 10
|
||||
device: 1
|
||||
early_stop:
|
||||
patience: 60
|
||||
federate:
|
||||
mode: standalone
|
||||
total_round_num: 70
|
||||
client_num: 10
|
||||
data:
|
||||
root: data/trafficflow/PeMS03
|
||||
type: trafficflow
|
||||
splitter: trafficflowprediction
|
||||
num_nodes: 358
|
||||
lag: 12
|
||||
horizon: 12
|
||||
val_ratio: 0.2
|
||||
test_ratio: 0.2
|
||||
tod: False
|
||||
normalizer: std
|
||||
column_wise: False
|
||||
default_graph: True
|
||||
add_time_in_day: True
|
||||
add_day_in_week: True
|
||||
steps_per_day: 288
|
||||
days_per_week: 7
|
||||
dataloader:
|
||||
type: trafficflow
|
||||
batch_size: 64
|
||||
drop_last: True
|
||||
model:
|
||||
type: FedDGCN
|
||||
task: TrafficFlowPrediction
|
||||
dropout: 0.1
|
||||
horizon: 12
|
||||
num_nodes: 0
|
||||
input_dim: 1
|
||||
output_dim: 1
|
||||
embed_dim: 10
|
||||
rnn_units: 64
|
||||
num_layers: 1
|
||||
cheb_order: 2
|
||||
use_day: True
|
||||
use_week: True
|
||||
train:
|
||||
batch_or_epoch: 'epoch'
|
||||
local_update_steps: 1
|
||||
optimizer:
|
||||
type: 'Adam'
|
||||
lr: 0.003
|
||||
weight_decay: 0.0
|
||||
batch_size: 64
|
||||
epochs: 300
|
||||
lr_init: 0.003
|
||||
weight_decay: 0
|
||||
lr_decay: False
|
||||
lr_decay_rate: 0.3
|
||||
lr_decay_step: 5,20,40,70
|
||||
early_stop: False
|
||||
early_stop_patience: 15
|
||||
grad_norm: True
|
||||
max_grad_norm: 5
|
||||
real_value: True
|
||||
criterion:
|
||||
type: RMSE
|
||||
trainer:
|
||||
type: trafficflowtrainer
|
||||
log_dir: ./
|
||||
grad:
|
||||
grad_clip: 5.0
|
||||
eval:
|
||||
metrics: ['avg_loss']
|
||||
|
||||
|
|
@ -0,0 +1,74 @@
|
|||
use_gpu: True
|
||||
seed: 10
|
||||
device: 0
|
||||
early_stop:
|
||||
patience: 60
|
||||
federate:
|
||||
mode: standalone
|
||||
total_round_num: 70
|
||||
client_num: 10
|
||||
#personalization:
|
||||
# local_param: ['D_i_W_emb', "T_i_D_emb", "encoder1", "encoder2", "node_embeddings1", "node_embeddings2"]
|
||||
data:
|
||||
root: data/trafficflow/PeMS04
|
||||
type: trafficflow
|
||||
splitter: trafficflowprediction
|
||||
num_nodes: 307
|
||||
lag: 12
|
||||
horizon: 12
|
||||
val_ratio: 0.2
|
||||
test_ratio: 0.2
|
||||
tod: False
|
||||
normalizer: std
|
||||
column_wise: False
|
||||
default_graph: True
|
||||
add_time_in_day: True
|
||||
add_day_in_week: True
|
||||
steps_per_day: 288
|
||||
days_per_week: 7
|
||||
dataloader:
|
||||
type: trafficflow
|
||||
batch_size: 64
|
||||
drop_last: True
|
||||
model:
|
||||
type: FedDGCN
|
||||
task: TrafficFlowPrediction
|
||||
dropout: 0.1
|
||||
horizon: 12
|
||||
num_nodes: 0
|
||||
input_dim: 1
|
||||
output_dim: 1
|
||||
embed_dim: 10
|
||||
rnn_units: 64
|
||||
num_layers: 1
|
||||
cheb_order: 2
|
||||
use_day: True
|
||||
use_week: True
|
||||
train:
|
||||
batch_or_epoch: 'epoch'
|
||||
local_update_steps: 1
|
||||
optimizer:
|
||||
type: 'Adam'
|
||||
lr: 0.003
|
||||
weight_decay: 0.0
|
||||
batch_size: 64
|
||||
epochs: 300
|
||||
lr_init: 0.003
|
||||
weight_decay: 0
|
||||
lr_decay: False
|
||||
lr_decay_rate: 0.3
|
||||
lr_decay_step: 5,20,40,70
|
||||
early_stop: False
|
||||
early_stop_patience: 15
|
||||
grad_norm: True
|
||||
max_grad_norm: 5
|
||||
real_value: True
|
||||
criterion:
|
||||
type: RMSE
|
||||
trainer:
|
||||
type: trafficflowtrainer
|
||||
log_dir: ./
|
||||
grad:
|
||||
grad_clip: 5.0
|
||||
eval:
|
||||
metrics: ['avg_loss']
|
||||
|
|
@ -0,0 +1,73 @@
|
|||
use_gpu: True
|
||||
seed: 10
|
||||
device: 0
|
||||
early_stop:
|
||||
patience: 60
|
||||
federate:
|
||||
mode: standalone
|
||||
total_round_num: 70
|
||||
client_num: 10
|
||||
data:
|
||||
root: data/trafficflow/PeMS07
|
||||
type: trafficflow
|
||||
splitter: trafficflowprediction
|
||||
num_nodes: 883
|
||||
lag: 12
|
||||
horizon: 12
|
||||
val_ratio: 0.2
|
||||
test_ratio: 0.2
|
||||
tod: False
|
||||
normalizer: std
|
||||
column_wise: False
|
||||
default_graph: True
|
||||
add_time_in_day: True
|
||||
add_day_in_week: True
|
||||
steps_per_day: 288
|
||||
days_per_week: 7
|
||||
dataloader:
|
||||
type: trafficflow
|
||||
batch_size: 16
|
||||
drop_last: True
|
||||
model:
|
||||
type: FedDGCN
|
||||
task: TrafficFlowPrediction
|
||||
dropout: 0.1
|
||||
horizon: 12
|
||||
num_nodes: 0
|
||||
input_dim: 1
|
||||
output_dim: 1
|
||||
embed_dim: 10
|
||||
rnn_units: 64
|
||||
num_layers: 1
|
||||
cheb_order: 2
|
||||
use_day: True
|
||||
use_week: True
|
||||
train:
|
||||
batch_or_epoch: 'epoch'
|
||||
local_update_steps: 1
|
||||
optimizer:
|
||||
type: 'Adam'
|
||||
lr: 0.0015
|
||||
weight_decay: 0.0
|
||||
batch_size: 16
|
||||
epochs: 250
|
||||
lr_init: 0.0015
|
||||
weight_decay: 0
|
||||
lr_decay: False
|
||||
lr_decay_rate: 0.3
|
||||
lr_decay_step: 5,20,40,70
|
||||
early_stop: False
|
||||
early_stop_patience: 15
|
||||
grad_norm: True
|
||||
max_grad_norm: 5
|
||||
real_value: True
|
||||
criterion:
|
||||
type: RMSE
|
||||
trainer:
|
||||
type: trafficflowtrainer
|
||||
log_dir: ./
|
||||
grad:
|
||||
grad_clip: 5.0
|
||||
eval:
|
||||
metrics: ['avg_loss']
|
||||
|
||||
|
|
@ -0,0 +1,73 @@
|
|||
use_gpu: True
|
||||
seed: 10
|
||||
device: 0
|
||||
early_stop:
|
||||
patience: 60
|
||||
federate:
|
||||
mode: standalone
|
||||
total_round_num: 70
|
||||
client_num: 10
|
||||
data:
|
||||
root: data/trafficflow/PeMS07
|
||||
type: trafficflow
|
||||
splitter: trafficflowprediction
|
||||
num_nodes: 883
|
||||
lag: 12
|
||||
horizon: 12
|
||||
val_ratio: 0.2
|
||||
test_ratio: 0.2
|
||||
tod: False
|
||||
normalizer: std
|
||||
column_wise: False
|
||||
default_graph: True
|
||||
add_time_in_day: True
|
||||
add_day_in_week: True
|
||||
steps_per_day: 288
|
||||
days_per_week: 7
|
||||
dataloader:
|
||||
type: trafficflow
|
||||
batch_size: 16
|
||||
drop_last: True
|
||||
model:
|
||||
type: FedDGCN
|
||||
task: TrafficFlowPrediction
|
||||
dropout: 0.1
|
||||
horizon: 12
|
||||
num_nodes: 0
|
||||
input_dim: 1
|
||||
output_dim: 1
|
||||
embed_dim: 10
|
||||
rnn_units: 64
|
||||
num_layers: 1
|
||||
cheb_order: 2
|
||||
use_day: True
|
||||
use_week: True
|
||||
train:
|
||||
batch_or_epoch: 'epoch'
|
||||
local_update_steps: 1
|
||||
optimizer:
|
||||
type: 'Adam'
|
||||
lr: 0.0015
|
||||
weight_decay: 0.0
|
||||
batch_size: 16
|
||||
epochs: 250
|
||||
lr_init: 0.0015
|
||||
weight_decay: 0
|
||||
lr_decay: False
|
||||
lr_decay_rate: 0.3
|
||||
lr_decay_step: 5,20,40,70
|
||||
early_stop: False
|
||||
early_stop_patience: 15
|
||||
grad_norm: True
|
||||
max_grad_norm: 5
|
||||
real_value: True
|
||||
criterion:
|
||||
type: L1Loss
|
||||
trainer:
|
||||
type: trafficflowtrainer
|
||||
log_dir: ./
|
||||
grad:
|
||||
grad_clip: 5.0
|
||||
eval:
|
||||
metrics: ['avg_loss']
|
||||
|
||||
|
|
@ -0,0 +1,71 @@
|
|||
use_gpu: True
|
||||
seed: 10
|
||||
device: 0
|
||||
early_stop:
|
||||
patience: 60
|
||||
federate:
|
||||
mode: standalone
|
||||
total_round_num: 70
|
||||
client_num: 10
|
||||
data:
|
||||
root: data/trafficflow/PeMS08
|
||||
type: trafficflow
|
||||
splitter: trafficflowprediction
|
||||
num_nodes: 170
|
||||
lag: 12
|
||||
horizon: 12
|
||||
val_ratio: 0.2
|
||||
test_ratio: 0.2
|
||||
tod: False
|
||||
normalizer: std
|
||||
column_wise: False
|
||||
default_graph: True
|
||||
add_time_in_day: True
|
||||
add_day_in_week: True
|
||||
steps_per_day: 288
|
||||
days_per_week: 7
|
||||
dataloader:
|
||||
type: trafficflow
|
||||
batch_size: 64
|
||||
drop_last: True
|
||||
model:
|
||||
type: FedDGCN
|
||||
task: TrafficFlowPrediction
|
||||
dropout: 0.1
|
||||
horizon: 12
|
||||
num_nodes: 0
|
||||
input_dim: 1
|
||||
output_dim: 1
|
||||
embed_dim: 10
|
||||
rnn_units: 64
|
||||
num_layers: 1
|
||||
cheb_order: 2
|
||||
use_day: True
|
||||
use_week: True
|
||||
train:
|
||||
batch_or_epoch: 'epoch'
|
||||
local_update_steps: 1
|
||||
optimizer:
|
||||
type: 'Adam'
|
||||
lr: 0.01
|
||||
weight_decay: 0.0
|
||||
batch_size: 64
|
||||
epochs: 300
|
||||
lr_init: 0.003
|
||||
lr_decay: False
|
||||
lr_decay_rate: 0.3
|
||||
lr_decay_step: 5,20,40,70
|
||||
early_stop: False
|
||||
early_stop_patience: 15
|
||||
grad_norm: True
|
||||
real_value: True
|
||||
criterion:
|
||||
type: RMSE
|
||||
trainer:
|
||||
type: trafficflowtrainer
|
||||
log_dir: ./
|
||||
grad:
|
||||
grad_clip: 5.0
|
||||
eval:
|
||||
metrics: ['avg_loss']
|
||||
|
||||
|
|
@ -0,0 +1,196 @@
|
|||
import numpy as np
|
||||
import torch
|
||||
import torch.utils.data
|
||||
from federatedscope.trafficflow.dataset.add_window import add_window_horizon
|
||||
from federatedscope.trafficflow.dataset.normalization import (
|
||||
NScaler, MinMax01Scaler, MinMax11Scaler, StandardScaler, ColumnMinMaxScaler)
|
||||
from federatedscope.trafficflow.dataset.traffic_dataset import load_st_dataset
|
||||
def normalize_dataset(data, normalizer, column_wise=False):
|
||||
if normalizer == 'max01':
|
||||
if column_wise:
|
||||
minimum = data.min(axis=0, keepdims=True)
|
||||
maximum = data.max(axis=0, keepdims=True)
|
||||
else:
|
||||
minimum = data.min()
|
||||
maximum = data.max()
|
||||
scaler = MinMax01Scaler(minimum, maximum)
|
||||
data = scaler.transform(data)
|
||||
print('Normalize the dataset by MinMax01 Normalization')
|
||||
elif normalizer == 'max11':
|
||||
if column_wise:
|
||||
minimum = data.min(axis=0, keepdims=True)
|
||||
maximum = data.max(axis=0, keepdims=True)
|
||||
else:
|
||||
minimum = data.min()
|
||||
maximum = data.max()
|
||||
scaler = MinMax11Scaler(minimum, maximum)
|
||||
data = scaler.transform(data)
|
||||
print('Normalize the dataset by MinMax11 Normalization')
|
||||
elif normalizer == 'std':
|
||||
if column_wise:
|
||||
mean = data.mean(axis=0, keepdims=True)
|
||||
std = data.std(axis=0, keepdims=True)
|
||||
else:
|
||||
mean = data.mean()
|
||||
std = data.std()
|
||||
scaler = StandardScaler(mean, std)
|
||||
# data = scaler.transform(data)
|
||||
print('Normalize the dataset by Standard Normalization')
|
||||
elif normalizer == 'None':
|
||||
scaler = NScaler()
|
||||
data = scaler.transform(data)
|
||||
print('Does not normalize the dataset')
|
||||
elif normalizer == 'cmax':
|
||||
#column min max, to be depressed
|
||||
#note: axis must be the spatial dimension, please check !
|
||||
scaler = ColumnMinMaxScaler(data.min(axis=0), data.max(axis=0))
|
||||
data = scaler.transform(data)
|
||||
print('Normalize the dataset by Column Min-Max Normalization')
|
||||
else:
|
||||
raise ValueError
|
||||
return scaler
|
||||
|
||||
|
||||
def split_data_by_days(data, val_days, test_days, interval=30):
|
||||
"""
|
||||
:param data: [B, *]
|
||||
:param val_days:
|
||||
:param test_days:
|
||||
:param interval: interval (15, 30, 60) minutes
|
||||
:return:
|
||||
"""
|
||||
t = int((24 * 60) / interval)
|
||||
x = -t * int(test_days)
|
||||
test_data = data[-t * int(test_days):]
|
||||
val_data = data[-t * int(test_days + val_days): -t * int(test_days)]
|
||||
train_data = data[:-t * int(test_days + val_days)]
|
||||
return train_data, val_data, test_data
|
||||
|
||||
|
||||
def split_data_by_ratio(data, val_ratio, test_ratio):
|
||||
data_len = data.shape[0]
|
||||
test_data = data[-int(data_len * test_ratio):]
|
||||
val_data = data[-int(data_len * (test_ratio + val_ratio)):-int(data_len * test_ratio)]
|
||||
train_data = data[:-int(data_len * (test_ratio + val_ratio))]
|
||||
return train_data, val_data, test_data
|
||||
|
||||
|
||||
def data_loader(X, Y, batch_size, shuffle=True, drop_last=True):
|
||||
cuda = True if torch.cuda.is_available() else False
|
||||
TensorFloat = torch.cuda.FloatTensor if cuda else torch.FloatTensor
|
||||
X, Y = TensorFloat(X), TensorFloat(Y)
|
||||
data = torch.utils.data.TensorDataset(X, Y)
|
||||
dataloader = torch.utils.data.DataLoader(data, batch_size=batch_size,
|
||||
shuffle=shuffle, drop_last=drop_last)
|
||||
return dataloader
|
||||
|
||||
|
||||
def load_traffic_data(config, client_cfgs):
|
||||
root = config.data.root
|
||||
dataName = 'PEMSD' + root[-1]
|
||||
raw_data = load_st_dataset(dataName)
|
||||
|
||||
|
||||
l, n, f = raw_data.shape
|
||||
|
||||
feature_list = [raw_data]
|
||||
|
||||
|
||||
# numerical time_in_day
|
||||
time_ind = [i % config.data.steps_per_day / config.data.steps_per_day for i in range(raw_data.shape[0])]
|
||||
time_ind = np.array(time_ind)
|
||||
time_in_day = np.tile(time_ind, [1, n, 1]).transpose((2, 1, 0))
|
||||
feature_list.append(time_in_day)
|
||||
|
||||
# numerical day_in_week
|
||||
day_in_week = [(i // config.data.steps_per_day) % config.data.days_per_week for i in range(raw_data.shape[0])]
|
||||
day_in_week = np.array(day_in_week)
|
||||
day_in_week = np.tile(day_in_week, [1, n, 1]).transpose((2, 1, 0))
|
||||
feature_list.append(day_in_week)
|
||||
|
||||
# data = np.concatenate(feature_list, axis=-1)
|
||||
single = False
|
||||
x, y = add_window_horizon(raw_data, config.data.lag, config.data.horizon, single)
|
||||
x_day, y_day = add_window_horizon(time_in_day, config.data.lag, config.data.horizon, single)
|
||||
x_week, y_week = add_window_horizon(day_in_week, config.data.lag, config.data.horizon, single)
|
||||
x, y = np.concatenate([x, x_day, x_week], axis=-1), np.concatenate([y, y_day, y_week], axis=-1)
|
||||
|
||||
# split dataset by days or by ratio
|
||||
if config.data.test_ratio > 1:
|
||||
x_train, x_val, x_test = split_data_by_days(x, config.data.val_ratio, config.data.test_ratio)
|
||||
y_train, y_val, y_test = split_data_by_days(y, config.data.val_ratio, config.data.test_ratio)
|
||||
else:
|
||||
x_train, x_val, x_test = split_data_by_ratio(x, config.data.val_ratio, config.data.test_ratio)
|
||||
y_train, y_val, y_test = split_data_by_ratio(y, config.data.val_ratio, config.data.test_ratio)
|
||||
|
||||
# normalize st data
|
||||
normalizer = 'std'
|
||||
scaler = normalize_dataset(x_train[..., :config.model.input_dim], normalizer, config.data.column_wise)
|
||||
config.data.scaler = [float(scaler.mean), float(scaler.std)]
|
||||
|
||||
x_train[..., :config.model.input_dim] = scaler.transform(x_train[..., :config.model.input_dim])
|
||||
x_val[..., :config.model.input_dim] = scaler.transform(x_val[..., :config.model.input_dim])
|
||||
x_test[..., :config.model.input_dim] = scaler.transform(x_test[..., :config.model.input_dim])
|
||||
# y_train[..., :config.model.output_dim] = scaler.transform(y_train[..., :config.model.output_dim])
|
||||
# y_val[..., :config.model.output_dim] = scaler.transform(y_val[..., :config.model.output_dim])
|
||||
# y_test[..., :config.model.output_dim] = scaler.transform(y_test[..., :config.model.output_dim])
|
||||
|
||||
# 客户端分割数据集
|
||||
node_num = config.data.num_nodes
|
||||
client_num = config.federate.client_num
|
||||
per_samples = node_num // client_num
|
||||
data_list, cur_index = dict(), 0
|
||||
input_dim, output_dim = config.model.input_dim, config.model.output_dim
|
||||
for i in range(client_num):
|
||||
if cur_index + per_samples <= node_num:
|
||||
# 正常截取
|
||||
sub_array_train = x_train[:, :, cur_index:cur_index + per_samples, :]
|
||||
sub_array_val = x_val[:, :, cur_index:cur_index + per_samples, :]
|
||||
sub_array_test = x_test[:, :, cur_index:cur_index + per_samples, :]
|
||||
|
||||
sub_y_train = y_train[:, :, cur_index:cur_index + per_samples, :output_dim]
|
||||
sub_y_val = y_val[:, :, cur_index:cur_index + per_samples, :output_dim]
|
||||
sub_y_test = y_test[:, :, cur_index:cur_index + per_samples, :output_dim]
|
||||
else:
|
||||
# 不足一个per_samples,补0列
|
||||
sub_array_train = x_train[:, :, cur_index:cur_index + per_samples, :]
|
||||
sub_array_val = x_val[:, :, cur_index:cur_index + per_samples, :]
|
||||
sub_array_test = x_test[:, :, cur_index:cur_index + per_samples, :]
|
||||
padding = np.zeros((x_train.shape[0], config.data.lag ,config.data.lag, per_samples - x_train.shape[1], config.model.output_dim))
|
||||
sub_array_train = np.concatenate((sub_array_train, padding), axis=2)
|
||||
sub_array_val = np.concatenate((sub_array_val, padding), axis=2)
|
||||
sub_array_test = np.concatenate((sub_array_test, padding), axis=2)
|
||||
|
||||
sub_y_train = y_train[:, :, cur_index:cur_index + per_samples, :]
|
||||
sub_y_val = y_val[:, :, cur_index:cur_index + per_samples, :]
|
||||
sub_y_test = y_test[:, :, cur_index:cur_index + per_samples, :]
|
||||
sub_y_train = np.concatenate((sub_y_train, padding), axis=2)
|
||||
sub_y_val = np.concatenate((sub_y_val, padding), axis=2)
|
||||
sub_y_test = np.concatenate((sub_y_test, padding), axis=2)
|
||||
|
||||
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
||||
|
||||
data_list[i + 1] = {
|
||||
'train': torch.utils.data.TensorDataset(
|
||||
torch.tensor(sub_array_train, dtype=torch.float, device=device),
|
||||
torch.tensor(sub_y_train, dtype=torch.float, device=device)
|
||||
),
|
||||
'val': torch.utils.data.TensorDataset(
|
||||
torch.tensor(sub_array_val, dtype=torch.float, device=device),
|
||||
torch.tensor(sub_y_val, dtype=torch.float, device=device)
|
||||
),
|
||||
'test': torch.utils.data.TensorDataset(
|
||||
torch.tensor(sub_array_test, dtype=torch.float, device=device),
|
||||
torch.tensor(sub_y_test, dtype=torch.float, device=device)
|
||||
)
|
||||
}
|
||||
cur_index += per_samples
|
||||
config.model.num_nodes = per_samples
|
||||
return data_list, config
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
a = 'data/trafficflow/PeMS04'
|
||||
name = 'PEMSD' + a[-1]
|
||||
raw_data = load_st_dataset(name)
|
||||
pass
|
||||
|
|
@ -0,0 +1,37 @@
|
|||
import numpy as np
|
||||
|
||||
|
||||
def add_window_horizon(data, window=3, horizon=1, single=False):
|
||||
"""
|
||||
:param data: shape [B, ...]
|
||||
:param window:
|
||||
:param horizon:
|
||||
:param single:
|
||||
:return: X is [B, W, ...], Y is [B, H, ...]
|
||||
"""
|
||||
length = len(data)
|
||||
end_index = length - horizon - window + 1
|
||||
x = [] # windows
|
||||
y = [] # horizon
|
||||
index = 0
|
||||
if single:
|
||||
while index < end_index:
|
||||
x.append(data[index:index + window])
|
||||
y.append(data[index + window + horizon - 1:index + window + horizon])
|
||||
index = index + 1
|
||||
else:
|
||||
while index < end_index:
|
||||
x.append(data[index:index + window])
|
||||
y.append(data[index + window:index + window + horizon])
|
||||
index = index + 1
|
||||
x = np.array(x)
|
||||
y = np.array(y)
|
||||
return x, y
|
||||
|
||||
# if __name__ == '__main__':
|
||||
# from data.load_raw_data import Load_Sydney_Demand_Data
|
||||
# path = '../data/1h_data_new3.csv'
|
||||
# data = Load_Sydney_Demand_Data(path)
|
||||
# print(data.shape)
|
||||
# X, Y = Add_Window_Horizon(data, horizon=2)
|
||||
# print(X.shape, Y.shape)
|
||||
|
|
@ -0,0 +1,101 @@
|
|||
import numpy as np
|
||||
import torch
|
||||
|
||||
|
||||
class NScaler(object):
|
||||
def transform(self, data):
|
||||
return data
|
||||
|
||||
def inverse_transform(self, data):
|
||||
return data
|
||||
|
||||
|
||||
class StandardScaler:
|
||||
"""
|
||||
Standard the input
|
||||
"""
|
||||
|
||||
def __init__(self, mean, std):
|
||||
self.mean = mean
|
||||
self.std = std
|
||||
|
||||
def transform(self, data):
|
||||
return (data - self.mean) / self.std
|
||||
|
||||
def inverse_transform(self, data):
|
||||
if type(data) is torch.Tensor and type(self.mean) is np.ndarray:
|
||||
self.std = torch.from_numpy(self.std).to(data.device).type(data.dtype)
|
||||
self.mean = torch.from_numpy(self.mean).to(data.device).type(data.dtype)
|
||||
return (data * self.std) + self.mean
|
||||
|
||||
|
||||
class MinMax01Scaler:
|
||||
"""
|
||||
Standard the input
|
||||
"""
|
||||
|
||||
def __init__(self, min_value, max_value):
|
||||
self.min_value = min_value
|
||||
self.max_value = max_value
|
||||
|
||||
def transform(self, data):
|
||||
return (data - self.min_value) / (self.max_value - self.min_value)
|
||||
|
||||
def inverse_transform(self, data):
|
||||
if type(data) is torch.Tensor and type(self.min_value) is np.ndarray:
|
||||
self.min_value = torch.from_numpy(self.min_value).to(data.device).type(data.dtype)
|
||||
self.max_value = torch.from_numpy(self.max_value).to(data.device).type(data.dtype)
|
||||
return data * (self.max_value - self.min_value) + self.min_value
|
||||
|
||||
|
||||
class MinMax11Scaler:
|
||||
"""
|
||||
Standard the input
|
||||
"""
|
||||
|
||||
def __init__(self, min_value, max_value):
|
||||
self.min_value = min_value
|
||||
self.max_value = max_value
|
||||
|
||||
def transform(self, data):
|
||||
return ((data - self.min_value) / (self.max_value - self.min_value)) * 2. - 1.
|
||||
|
||||
def inverse_transform(self, data):
|
||||
if type(data) is torch.Tensor and type(self.min_value) is np.ndarray:
|
||||
self.min_value = torch.from_numpy(self.min_value).to(data.device).type(data.dtype)
|
||||
self.max_value = torch.from_numpy(self.max_value).to(data.device).type(data.dtype)
|
||||
return ((data + 1.) / 2.) * (self.max_value - self.min_value) + self.min_value
|
||||
|
||||
|
||||
class ColumnMinMaxScaler:
|
||||
# Note: to use this scale, must init the min and max with column min and column max
|
||||
def __init__(self, min_value, max_value):
|
||||
self.min_value = min_value
|
||||
self.min_max = max_value - self.min_value
|
||||
self.min_max[self.min_max == 0] = 1
|
||||
|
||||
def transform(self, data):
|
||||
print(data.shape, self.min_max.shape)
|
||||
return (data - self.min_value) / self.min_max
|
||||
|
||||
def inverse_transform(self, data):
|
||||
if type(data) is torch.Tensor and type(self.min_value) is np.ndarray:
|
||||
self.min_max = torch.from_numpy(self.min_max).to(data.device).type(torch.float32)
|
||||
self.min_value = torch.from_numpy(self.min_value).to(data.device).type(torch.float32)
|
||||
# print(data.dtype, self.min_max.dtype, self.min.dtype)
|
||||
return data * self.min_max + self.min_value
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_data = np.array([[0, 0, 0, 1], [0, 1, 3, 2], [0, 2, 1, 3]])
|
||||
print(test_data)
|
||||
minimum = test_data.min(axis=1)
|
||||
print(minimum, minimum.shape, test_data.shape)
|
||||
maximum = test_data.max(axis=1)
|
||||
print(maximum)
|
||||
print(test_data - minimum)
|
||||
test_data = (test_data - minimum) / (maximum - minimum)
|
||||
print(test_data)
|
||||
print(0 == 0)
|
||||
print(0.00 == 0)
|
||||
print(0 == 0.00)
|
||||
|
|
@ -0,0 +1,52 @@
|
|||
import os
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
"""
|
||||
return ndarray (time_step, node_num, channel)
|
||||
"""
|
||||
|
||||
|
||||
def load_st_dataset(dataset):
|
||||
# output B, N, D
|
||||
if dataset == 'PEMSD3':
|
||||
data_path = os.path.join('./data/trafficflow/PeMS03/PEMS03.npz')
|
||||
data = np.load(data_path)['data'][:, :, 0] # only the first dimension, traffic flow data
|
||||
data = data.astype(np.float32)
|
||||
elif dataset == 'PEMSD4':
|
||||
data_path = os.path.join('./data/trafficflow/PeMS04/PEMS04.npz')
|
||||
data = np.load(data_path)['data'][:, :, 0] # only the first dimension, traffic flow data
|
||||
data = data.astype(np.float32)
|
||||
elif dataset == 'PEMSD7':
|
||||
data_path = os.path.join('./data/trafficflow/PeMS07/PeMS07.npz')
|
||||
data = np.load(data_path)['data'][:, :, 0] # only the first dimension, traffic flow data
|
||||
data = data.astype(np.float32)
|
||||
elif dataset == 'PEMSD8':
|
||||
data_path = os.path.join('./data/trafficflow/PeMS08/PeMS08.npz')
|
||||
data = np.load(data_path)['data'][:, :, 0] # only the first dimension, traffic flow data
|
||||
data = data.astype(np.float32)
|
||||
elif dataset == 'PEMSD7(L)':
|
||||
data_path = os.path.join('./data/trafficflow/PeMS07(L)/PEMS07L.npz')
|
||||
data = np.load(data_path)['data'][:, :, 0] # only the first dimension, traffic flow data
|
||||
elif dataset == 'PEMSD7(M)':
|
||||
data_path = os.path.join('./data/trafficflow/PeMS07(M)/V_228.csv')
|
||||
data = np.array(pd.read_csv(data_path, header=None)) # only the first dimension, traffic flow data
|
||||
elif dataset == 'METR-LA':
|
||||
data_path = os.path.join('./data/trafficflow/METR-LA/METR.h5')
|
||||
data = pd.read_hdf(data_path)
|
||||
elif dataset == 'BJ':
|
||||
data_path = os.path.join('./data/trafficflow/BJ/BJ500.csv')
|
||||
data = np.array(pd.read_csv(data_path, header=0, index_col=0))
|
||||
else:
|
||||
raise ValueError
|
||||
if len(data.shape) == 2:
|
||||
data = np.expand_dims(data, axis=-1)
|
||||
print('Load %s Dataset shaped: ' % dataset, data.shape, data.max(), data.min(), data.mean(), np.median(data))
|
||||
return data
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
dataset = 'PEMSD8'
|
||||
data = load_st_dataset(dataset)
|
||||
print("Finished")
|
||||
|
|
@ -0,0 +1,49 @@
|
|||
from federatedscope.register import register_criterion
|
||||
|
||||
"""
|
||||
Adding RMSE, MAPE for traffic flow prediction
|
||||
"""
|
||||
def TFP_criterion(type, device):
|
||||
try:
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
except ImportError:
|
||||
nn = None
|
||||
criterion = None
|
||||
|
||||
class RMSELoss(nn.Module):
|
||||
def __init__(self):
|
||||
super(RMSELoss, self).__init__()
|
||||
self.mse = nn.MSELoss()
|
||||
|
||||
def forward(self, y_pred, y_true):
|
||||
return torch.sqrt(self.mse(y_pred, y_true))
|
||||
|
||||
class MAPELoss(nn.Module):
|
||||
def __init__(self, epsilon=1e-10):
|
||||
super(MAPELoss, self).__init__()
|
||||
self.epsilon = epsilon
|
||||
|
||||
def forward(self, y_pred, y_true):
|
||||
mask_value = 0.1
|
||||
if mask_value is not None:
|
||||
mask = torch.gt(y_true, mask_value)
|
||||
pred = torch.masked_select(y_pred, mask)
|
||||
true = torch.masked_select(y_true, mask)
|
||||
return torch.mean(torch.abs(torch.div((true - pred), (true + 0.001)))) * 100
|
||||
|
||||
|
||||
if type == 'RMSE':
|
||||
if nn is not None:
|
||||
criterion = RMSELoss().to(device)
|
||||
elif type == 'MAPE':
|
||||
if nn is not None:
|
||||
criterion = MAPELoss().to(device)
|
||||
else:
|
||||
criterion = None
|
||||
|
||||
return criterion
|
||||
|
||||
# Register the custom RMSE and MAPE criterion
|
||||
register_criterion('RMSE', TFP_criterion)
|
||||
register_criterion('MAPE', TFP_criterion)
|
||||
|
|
@ -0,0 +1,69 @@
|
|||
import torch
|
||||
import torch.nn.functional as F
|
||||
import torch.nn as nn
|
||||
import numpy as np
|
||||
from collections import OrderedDict
|
||||
|
||||
|
||||
class DGCN(nn.Module):
|
||||
def __init__(self, dim_in, dim_out, cheb_k, embed_dim):
|
||||
super(DGCN, self).__init__()
|
||||
self.cheb_k = cheb_k
|
||||
self.weights_pool = nn.Parameter(torch.FloatTensor(embed_dim, cheb_k, dim_in, dim_out))
|
||||
self.weights = nn.Parameter(torch.FloatTensor(cheb_k, dim_in, dim_out))
|
||||
self.bias_pool = nn.Parameter(torch.FloatTensor(embed_dim, dim_out))
|
||||
self.bias = nn.Parameter(torch.FloatTensor(dim_out))
|
||||
# 初始化参数
|
||||
nn.init.xavier_uniform_(self.weights_pool)
|
||||
nn.init.xavier_uniform_(self.weights)
|
||||
nn.init.zeros_(self.bias_pool)
|
||||
nn.init.zeros_(self.bias)
|
||||
|
||||
self.hyperGNN_dim = 16
|
||||
self.middle_dim = 2
|
||||
self.embed_dim = embed_dim
|
||||
self.fc = nn.Sequential(
|
||||
OrderedDict([('fc1', nn.Linear(dim_in, self.hyperGNN_dim)),
|
||||
('sigmoid1', nn.Sigmoid()),
|
||||
('fc2', nn.Linear(self.hyperGNN_dim, self.middle_dim)),
|
||||
('sigmoid2', nn.Sigmoid()),
|
||||
('fc3', nn.Linear(self.middle_dim, self.embed_dim))]))
|
||||
|
||||
def forward(self, x, node_embeddings):
|
||||
node_num = node_embeddings[0].shape[1]
|
||||
supports1 = torch.eye(node_num).to(node_embeddings[0].device)
|
||||
filter = self.fc(x)
|
||||
nodevec = torch.tanh(torch.mul(node_embeddings[0], filter)) # [B,N,dim_in]
|
||||
graph = F.relu(torch.matmul(nodevec, nodevec.transpose(2, 1)))
|
||||
supports2 = DGCN.get_laplacian(graph, supports1)
|
||||
|
||||
x_g1 = torch.einsum("nm,bmc->bnc", supports1, x)
|
||||
x_g2 = torch.einsum("bnm,bmc->bnc", supports2, x)
|
||||
x_g = torch.stack([x_g1, x_g2], dim=1)
|
||||
|
||||
weights = torch.einsum('nd,dkio->nkio', node_embeddings[1], self.weights_pool)
|
||||
bias = torch.matmul(node_embeddings[1], self.bias_pool)
|
||||
|
||||
x_g = x_g.permute(0, 2, 1, 3)
|
||||
x_gconv = torch.einsum('bnki,nkio->bno', x_g, weights) + bias
|
||||
|
||||
return x_gconv
|
||||
|
||||
@staticmethod
|
||||
def get_laplacian(graph, I, normalize=True):
|
||||
"""
|
||||
return the laplacian of the graph.
|
||||
|
||||
:param graph: the graph structure without self loop, [N, N].
|
||||
:param normalize: whether to used the normalized laplacian.
|
||||
:return: graph laplacian.
|
||||
"""
|
||||
if normalize:
|
||||
epsilon = 1e-6
|
||||
D = torch.diag_embed((torch.sum(graph, dim=-1) + epsilon) ** (-1 / 2))
|
||||
L = torch.matmul(torch.matmul(D, graph), D)
|
||||
else:
|
||||
graph = graph + I
|
||||
D = torch.diag_embed(torch.sum(graph, dim=-1) ** (-1 / 2))
|
||||
L = torch.matmul(torch.matmul(D, graph), D)
|
||||
return L
|
||||
|
|
@ -0,0 +1,25 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
from federatedscope.trafficflow.model.DGCN import DGCN
|
||||
|
||||
|
||||
class DDGCRNCell(nn.Module):
|
||||
def __init__(self, node_num, dim_in, dim_out, cheb_k, embed_dim):
|
||||
super(DDGCRNCell, self).__init__()
|
||||
self.node_num = node_num
|
||||
self.hidden_dim = dim_out
|
||||
self.gate = DGCN(dim_in + self.hidden_dim, 2 * dim_out, cheb_k, embed_dim)
|
||||
self.update = DGCN(dim_in + self.hidden_dim, dim_out, cheb_k, embed_dim)
|
||||
|
||||
def forward(self, x, state, node_embeddings):
|
||||
state = state.to(x.device)
|
||||
input_and_state = torch.cat((x, state), dim=-1)
|
||||
z_r = torch.sigmoid(self.gate(input_and_state, node_embeddings))
|
||||
z, r = torch.split(z_r, self.hidden_dim, dim=-1)
|
||||
candidate = torch.cat((x, z * state), dim=-1)
|
||||
hc = torch.tanh(self.update(candidate, node_embeddings))
|
||||
h = r * state + (1 - r) * hc
|
||||
return h
|
||||
|
||||
def init_hidden_state(self, batch_size):
|
||||
return torch.zeros(batch_size, self.node_num, self.hidden_dim)
|
||||
|
|
@ -0,0 +1,116 @@
|
|||
from federatedscope.register import register_model
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from federatedscope.trafficflow.model.DGCRUCell import DDGCRNCell
|
||||
|
||||
class DGCRM(nn.Module):
|
||||
def __init__(self, node_num, dim_in, dim_out, cheb_k, embed_dim, num_layers=1):
|
||||
super(DGCRM, self).__init__()
|
||||
assert num_layers >= 1, 'At least one DCRNN layer in the Encoder.'
|
||||
self.node_num = node_num
|
||||
self.input_dim = dim_in
|
||||
self.num_layers = num_layers
|
||||
self.DGCRM_cells = nn.ModuleList()
|
||||
self.DGCRM_cells.append(DDGCRNCell(node_num, dim_in, dim_out, cheb_k, embed_dim))
|
||||
for _ in range(1, num_layers):
|
||||
self.DGCRM_cells.append(DDGCRNCell(node_num, dim_out, dim_out, cheb_k, embed_dim))
|
||||
|
||||
def forward(self, x, init_state, node_embeddings):
|
||||
assert x.shape[2] == self.node_num and x.shape[3] == self.input_dim
|
||||
seq_length = x.shape[1]
|
||||
current_inputs = x
|
||||
output_hidden = []
|
||||
for i in range(self.num_layers):
|
||||
state = init_state[i]
|
||||
inner_states = []
|
||||
for t in range(seq_length):
|
||||
state = self.DGCRM_cells[i](current_inputs[:, t, :, :], state, [node_embeddings[0][:, t, :, :], node_embeddings[1]])
|
||||
inner_states.append(state)
|
||||
output_hidden.append(state)
|
||||
current_inputs = torch.stack(inner_states, dim=1)
|
||||
return current_inputs, output_hidden
|
||||
|
||||
def init_hidden(self, batch_size):
|
||||
init_states = []
|
||||
for i in range(self.num_layers):
|
||||
init_states.append(self.DGCRM_cells[i].init_hidden_state(batch_size))
|
||||
return torch.stack(init_states, dim=0) #(num_layers, B, N, hidden_dim)
|
||||
|
||||
# Build you torch or tf model class here
|
||||
class FedDGCN(nn.Module):
|
||||
def __init__(self, args):
|
||||
super(FedDGCN, self).__init__()
|
||||
self.num_node = args.num_nodes
|
||||
self.input_dim = args.input_dim
|
||||
self.hidden_dim = args.rnn_units
|
||||
self.output_dim = args.output_dim
|
||||
self.horizon = args.horizon
|
||||
self.num_layers = args.num_layers
|
||||
self.use_D = args.use_day
|
||||
self.use_W = args.use_week
|
||||
self.dropout1 = nn.Dropout(p=args.dropout) # 0.1
|
||||
self.dropout2 = nn.Dropout(p=args.dropout)
|
||||
self.node_embeddings1 = nn.Parameter(torch.randn(self.num_node, args.embed_dim), requires_grad=True)
|
||||
self.node_embeddings2 = nn.Parameter(torch.randn(self.num_node, args.embed_dim), requires_grad=True)
|
||||
self.T_i_D_emb = nn.Parameter(torch.empty(288, args.embed_dim))
|
||||
self.D_i_W_emb = nn.Parameter(torch.empty(7, args.embed_dim))
|
||||
# 初始化参数
|
||||
nn.init.xavier_uniform_(self.node_embeddings1)
|
||||
nn.init.xavier_uniform_(self.T_i_D_emb)
|
||||
nn.init.xavier_uniform_(self.D_i_W_emb)
|
||||
|
||||
self.encoder1 = DGCRM(args.num_nodes, args.input_dim, args.rnn_units, args.cheb_order,
|
||||
args.embed_dim, args.num_layers)
|
||||
self.encoder2 = DGCRM(args.num_nodes, args.input_dim, args.rnn_units, args.cheb_order,
|
||||
args.embed_dim, args.num_layers)
|
||||
# predictor
|
||||
self.end_conv1 = nn.Conv2d(1, args.horizon * self.output_dim, kernel_size=(1, self.hidden_dim), bias=True)
|
||||
self.end_conv2 = nn.Conv2d(1, args.horizon * self.output_dim, kernel_size=(1, self.hidden_dim), bias=True)
|
||||
self.end_conv3 = nn.Conv2d(1, args.horizon * self.output_dim, kernel_size=(1, self.hidden_dim), bias=True)
|
||||
|
||||
def forward(self, source, i=2):
|
||||
node_embedding1 = self.node_embeddings1
|
||||
if self.use_D:
|
||||
t_i_d_data = source[..., 1]
|
||||
T_i_D_emb = self.T_i_D_emb[(t_i_d_data * 288).type(torch.LongTensor)]
|
||||
node_embedding1 = torch.mul(node_embedding1, T_i_D_emb)
|
||||
|
||||
if self.use_W:
|
||||
d_i_w_data = source[..., 2]
|
||||
D_i_W_emb = self.D_i_W_emb[(d_i_w_data).type(torch.LongTensor)]
|
||||
node_embedding1 = torch.mul(node_embedding1, D_i_W_emb)
|
||||
|
||||
node_embeddings=[node_embedding1,self.node_embeddings1]
|
||||
|
||||
source = source[..., 0].unsqueeze(-1)
|
||||
|
||||
init_state1 = self.encoder1.init_hidden(source.shape[0])
|
||||
output, _ = self.encoder1(source, init_state1, node_embeddings)
|
||||
output = self.dropout1(output[:, -1:, :, :])
|
||||
|
||||
output1 = self.end_conv1(output)
|
||||
source1 = self.end_conv2(output)
|
||||
|
||||
source2 = source - source1
|
||||
|
||||
init_state2 = self.encoder2.init_hidden(source2.shape[0])
|
||||
output2, _ = self.encoder2(source2, init_state2, node_embeddings)
|
||||
output2 = self.dropout2(output2[:, -1:, :, :])
|
||||
output2 = self.end_conv3(output2)
|
||||
|
||||
return output1 + output2
|
||||
|
||||
|
||||
# Instantiate your model class with config and data
|
||||
def ModelBuilder(model_config, local_data):
|
||||
model = FedDGCN(model_config)
|
||||
return model
|
||||
|
||||
|
||||
def call_ddgcrn(model_config, local_data):
|
||||
if model_config.type == "DDGCRN":
|
||||
model = ModelBuilder(model_config, local_data)
|
||||
return model
|
||||
|
||||
|
||||
register_model("DDGCRN", call_ddgcrn)
|
||||
|
|
@ -0,0 +1,32 @@
|
|||
|
||||
from federatedscope.register import register_splitter
|
||||
from federatedscope.core.splitters import BaseSplitter
|
||||
|
||||
|
||||
class TrafficSplitter(BaseSplitter):
|
||||
def __init__(self, client_num, **kwargs):
|
||||
super(TrafficSplitter, self).__init__(client_num, **kwargs)
|
||||
|
||||
def __call__(self, dataset, *args, **kwargs):
|
||||
"""
|
||||
后面考虑子图标记划分
|
||||
|
||||
Args:
|
||||
dataset: ndarray(timestep, num_node, channel)
|
||||
*args:
|
||||
**kwargs:
|
||||
|
||||
Returns:
|
||||
[ndarray(timestep, per_node, channel) * client_nums]
|
||||
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
def call_my_splitter(splitter_type, client_num, **kwargs):
|
||||
if splitter_type == 'trafficflow':
|
||||
splitter = TrafficSplitter(client_num, **kwargs)
|
||||
return splitter
|
||||
|
||||
|
||||
register_splitter('trafficflow', call_my_splitter)
|
||||
|
|
@ -0,0 +1,62 @@
|
|||
import numpy as np
|
||||
import torch
|
||||
|
||||
from federatedscope.core.trainers.torch_trainer import GeneralTorchTrainer as Trainer
|
||||
from federatedscope.trafficflow.dataset.normalization import StandardScaler
|
||||
from federatedscope.core.trainers.enums import MODE, LIFECYCLE
|
||||
from federatedscope.core.trainers.context import Context, CtxVar, lifecycle
|
||||
|
||||
def print_model_parameters(model):
|
||||
print("Model parameters and their shapes:")
|
||||
for name, param in model.named_parameters():
|
||||
print(f"{name}: {param.shape}")
|
||||
|
||||
class TrafficflowTrainer(Trainer):
|
||||
def __init__(self, model, scaler, args, data, device,
|
||||
monitor):
|
||||
super().__init__(model, data, device, args, monitor=monitor)
|
||||
self.scaler = StandardScaler(scaler[0], scaler[1])
|
||||
|
||||
def train(self, target_data_split_name="train", hooks_set=None):
|
||||
hooks_set = hooks_set or self.hooks_in_train
|
||||
|
||||
self.ctx.check_split(target_data_split_name)
|
||||
|
||||
num_samples = self._run_routine(MODE.TRAIN, hooks_set,
|
||||
target_data_split_name)
|
||||
|
||||
train_loss = self.ctx.eval_metrics
|
||||
val_loss = self.evaluate('val')
|
||||
test_loss = self.evaluate('test')
|
||||
all_metrics = {'train_loss': train_loss['train_avg_loss'],
|
||||
'val_loss': val_loss['val_avg_loss'],
|
||||
'test_loss': test_loss['test_avg_loss'],
|
||||
}
|
||||
self.ctx.eval_metrics = all_metrics
|
||||
|
||||
return num_samples, self.get_model_para(), self.ctx.eval_metrics
|
||||
|
||||
def _hook_on_batch_forward(self, ctx):
|
||||
"""
|
||||
Note:
|
||||
The modified attributes and according operations are shown below:
|
||||
================================== ===========================
|
||||
Attribute Operation
|
||||
================================== ===========================
|
||||
``ctx.y_true`` Move to `ctx.device`
|
||||
``ctx.y_prob`` Forward propagation get y_prob
|
||||
``ctx.loss_batch`` Calculate the loss
|
||||
``ctx.batch_size`` Get the batch_size
|
||||
================================== ===========================
|
||||
"""
|
||||
x, label = [_.to(ctx.device) for _ in ctx.data_batch]
|
||||
pred = ctx.model(x)
|
||||
pred = self.scaler.inverse_transform(pred)
|
||||
|
||||
if len(label.size()) == 0:
|
||||
label = label.unsqueeze(0)
|
||||
|
||||
ctx.y_true = CtxVar(label, LIFECYCLE.BATCH)
|
||||
ctx.y_prob = CtxVar(pred, LIFECYCLE.BATCH)
|
||||
ctx.loss_batch = CtxVar(ctx.criterion(pred, label), LIFECYCLE.BATCH)
|
||||
ctx.batch_size = CtxVar(len(label), LIFECYCLE.BATCH)
|
||||
|
|
@ -0,0 +1,15 @@
|
|||
from federatedscope.register import register_trainer
|
||||
|
||||
def call_trafficflow_trainer(config, model, data, device, monitor):
|
||||
if config.trainer.type == 'trafficflowtrainer':
|
||||
from federatedscope.trafficflow.trainer.trafficflow import TrafficflowTrainer
|
||||
Trainer = TrafficflowTrainer(model=model,
|
||||
scaler=config.data.scaler,
|
||||
args=config,
|
||||
data=data,
|
||||
device=device,
|
||||
monitor=monitor)
|
||||
return Trainer
|
||||
|
||||
|
||||
# register_trainer('trafficflowtrainer', call_trafficflow_trainer)
|
||||
Loading…
Reference in New Issue