From 22b229891a2193f72957ee2faf7541d51e0c5bea Mon Sep 17 00:00:00 2001 From: HengZhang Date: Wed, 27 Nov 2024 21:14:01 +0800 Subject: [PATCH] Update traffic_dataloader_v2.py --- .../dataloader/traffic_dataloader_v2.py | 54 +++++++++++++++---- 1 file changed, 43 insertions(+), 11 deletions(-) diff --git a/federatedscope/trafficflow/dataloader/traffic_dataloader_v2.py b/federatedscope/trafficflow/dataloader/traffic_dataloader_v2.py index 0a48590..bba354c 100644 --- a/federatedscope/trafficflow/dataloader/traffic_dataloader_v2.py +++ b/federatedscope/trafficflow/dataloader/traffic_dataloader_v2.py @@ -86,11 +86,10 @@ def data_loader(X, Y, batch_size, shuffle=True, drop_last=True): def load_traffic_data(config, client_cfgs): - print("Use Mini graph") root = config.data.root dataName = 'PEMSD' + root[-1] raw_data = load_st_dataset(dataName) - + sub_graph_size = config.model.minigraph_size l, n, f = raw_data.shape @@ -132,9 +131,6 @@ def load_traffic_data(config, client_cfgs): x_train[..., :config.model.input_dim] = scaler.transform(x_train[..., :config.model.input_dim]) x_val[..., :config.model.input_dim] = scaler.transform(x_val[..., :config.model.input_dim]) x_test[..., :config.model.input_dim] = scaler.transform(x_test[..., :config.model.input_dim]) - # y_train[..., :config.model.output_dim] = scaler.transform(y_train[..., :config.model.output_dim]) - # y_val[..., :config.model.output_dim] = scaler.transform(y_val[..., :config.model.output_dim]) - # y_test[..., :config.model.output_dim] = scaler.transform(y_test[..., :config.model.output_dim]) # Client-side dataset splitting node_num = config.data.num_nodes @@ -171,18 +167,20 @@ def load_traffic_data(config, client_cfgs): device = 'cuda' if torch.cuda.is_available() else 'cpu' + minigraph_size = config.model.minigraph_size + data_list[i + 1] = { 'train': torch.utils.data.TensorDataset( - torch.tensor(sub_array_train, dtype=torch.float, device=device), - torch.tensor(sub_y_train, dtype=torch.float, device=device) + torch.tensor(split_into_mini_graphs(sub_array_train, minigraph_size), dtype=torch.float, device=device), + torch.tensor(split_into_mini_graphs(sub_y_train, minigraph_size), dtype=torch.float, device=device) ), 'val': torch.utils.data.TensorDataset( - torch.tensor(sub_array_val, dtype=torch.float, device=device), - torch.tensor(sub_y_val, dtype=torch.float, device=device) + torch.tensor(split_into_mini_graphs(sub_array_val, minigraph_size), dtype=torch.float, device=device), + torch.tensor(split_into_mini_graphs(sub_y_val, minigraph_size), dtype=torch.float, device=device) ), 'test': torch.utils.data.TensorDataset( - torch.tensor(sub_array_test, dtype=torch.float, device=device), - torch.tensor(sub_y_test, dtype=torch.float, device=device) + torch.tensor(split_into_mini_graphs(sub_array_test, minigraph_size), dtype=torch.float, device=device), + torch.tensor(split_into_mini_graphs(sub_y_test, minigraph_size), dtype=torch.float, device=device) ) } cur_index += per_samples @@ -190,6 +188,40 @@ def load_traffic_data(config, client_cfgs): return data_list, config +def split_into_mini_graphs(tensor, minigraph_size): + """ + Splits a tensor into mini-graphs of specified size. Pads the last mini-graph with dummy nodes if necessary. + + Args: + tensor (np.ndarray): Input tensor with shape (timestep, horizon, node_num, dim). + minigraph_size (int): The size of each mini-graph. + + Returns: + np.ndarray: Output tensor with shape (timestep, horizon, minigraph_num, minigraph_size, dim). + """ + timestep, horizon, node_num, dim = tensor.shape + + # Calculate the number of mini-graphs + minigraph_num = (node_num + minigraph_size - 1) // minigraph_size # Round up division + + # Initialize output tensor with zeros (dummy nodes) + output = np.zeros((timestep, horizon, minigraph_num, minigraph_size, dim), dtype=tensor.dtype) + + # Fill in the real data + for i in range(minigraph_num): + start_idx = i * minigraph_size + end_idx = min(start_idx + minigraph_size, node_num) # Ensure we don't exceed the node number + slice_size = end_idx - start_idx + + # Assign the data to the corresponding mini-graph + output[:, :, i, :slice_size, :] = tensor[:, :, start_idx:end_idx, :] + + # For the remaining part in the mini-graph, it remains as dummy nodes (zeros) + + return output + + + if __name__ == '__main__': a = 'data/trafficflow/PeMS04' name = 'PEMSD' + a[-1]