Update traffic_dataloader_v2.py

This commit is contained in:
HengZhang 2024-11-27 21:14:01 +08:00
parent 578ff48c71
commit 22b229891a
1 changed files with 43 additions and 11 deletions

View File

@ -86,11 +86,10 @@ def data_loader(X, Y, batch_size, shuffle=True, drop_last=True):
def load_traffic_data(config, client_cfgs): def load_traffic_data(config, client_cfgs):
print("Use Mini graph")
root = config.data.root root = config.data.root
dataName = 'PEMSD' + root[-1] dataName = 'PEMSD' + root[-1]
raw_data = load_st_dataset(dataName) raw_data = load_st_dataset(dataName)
sub_graph_size = config.model.minigraph_size
l, n, f = raw_data.shape l, n, f = raw_data.shape
@ -132,9 +131,6 @@ def load_traffic_data(config, client_cfgs):
x_train[..., :config.model.input_dim] = scaler.transform(x_train[..., :config.model.input_dim]) x_train[..., :config.model.input_dim] = scaler.transform(x_train[..., :config.model.input_dim])
x_val[..., :config.model.input_dim] = scaler.transform(x_val[..., :config.model.input_dim]) x_val[..., :config.model.input_dim] = scaler.transform(x_val[..., :config.model.input_dim])
x_test[..., :config.model.input_dim] = scaler.transform(x_test[..., :config.model.input_dim]) x_test[..., :config.model.input_dim] = scaler.transform(x_test[..., :config.model.input_dim])
# y_train[..., :config.model.output_dim] = scaler.transform(y_train[..., :config.model.output_dim])
# y_val[..., :config.model.output_dim] = scaler.transform(y_val[..., :config.model.output_dim])
# y_test[..., :config.model.output_dim] = scaler.transform(y_test[..., :config.model.output_dim])
# Client-side dataset splitting # Client-side dataset splitting
node_num = config.data.num_nodes node_num = config.data.num_nodes
@ -171,18 +167,20 @@ def load_traffic_data(config, client_cfgs):
device = 'cuda' if torch.cuda.is_available() else 'cpu' device = 'cuda' if torch.cuda.is_available() else 'cpu'
minigraph_size = config.model.minigraph_size
data_list[i + 1] = { data_list[i + 1] = {
'train': torch.utils.data.TensorDataset( 'train': torch.utils.data.TensorDataset(
torch.tensor(sub_array_train, dtype=torch.float, device=device), torch.tensor(split_into_mini_graphs(sub_array_train, minigraph_size), dtype=torch.float, device=device),
torch.tensor(sub_y_train, dtype=torch.float, device=device) torch.tensor(split_into_mini_graphs(sub_y_train, minigraph_size), dtype=torch.float, device=device)
), ),
'val': torch.utils.data.TensorDataset( 'val': torch.utils.data.TensorDataset(
torch.tensor(sub_array_val, dtype=torch.float, device=device), torch.tensor(split_into_mini_graphs(sub_array_val, minigraph_size), dtype=torch.float, device=device),
torch.tensor(sub_y_val, dtype=torch.float, device=device) torch.tensor(split_into_mini_graphs(sub_y_val, minigraph_size), dtype=torch.float, device=device)
), ),
'test': torch.utils.data.TensorDataset( 'test': torch.utils.data.TensorDataset(
torch.tensor(sub_array_test, dtype=torch.float, device=device), torch.tensor(split_into_mini_graphs(sub_array_test, minigraph_size), dtype=torch.float, device=device),
torch.tensor(sub_y_test, dtype=torch.float, device=device) torch.tensor(split_into_mini_graphs(sub_y_test, minigraph_size), dtype=torch.float, device=device)
) )
} }
cur_index += per_samples cur_index += per_samples
@ -190,6 +188,40 @@ def load_traffic_data(config, client_cfgs):
return data_list, config return data_list, config
def split_into_mini_graphs(tensor, minigraph_size):
"""
Splits a tensor into mini-graphs of specified size. Pads the last mini-graph with dummy nodes if necessary.
Args:
tensor (np.ndarray): Input tensor with shape (timestep, horizon, node_num, dim).
minigraph_size (int): The size of each mini-graph.
Returns:
np.ndarray: Output tensor with shape (timestep, horizon, minigraph_num, minigraph_size, dim).
"""
timestep, horizon, node_num, dim = tensor.shape
# Calculate the number of mini-graphs
minigraph_num = (node_num + minigraph_size - 1) // minigraph_size # Round up division
# Initialize output tensor with zeros (dummy nodes)
output = np.zeros((timestep, horizon, minigraph_num, minigraph_size, dim), dtype=tensor.dtype)
# Fill in the real data
for i in range(minigraph_num):
start_idx = i * minigraph_size
end_idx = min(start_idx + minigraph_size, node_num) # Ensure we don't exceed the node number
slice_size = end_idx - start_idx
# Assign the data to the corresponding mini-graph
output[:, :, i, :slice_size, :] = tensor[:, :, start_idx:end_idx, :]
# For the remaining part in the mini-graph, it remains as dummy nodes (zeros)
return output
if __name__ == '__main__': if __name__ == '__main__':
a = 'data/trafficflow/PeMS04' a = 'data/trafficflow/PeMS04'
name = 'PEMSD' + a[-1] name = 'PEMSD' + a[-1]