Update traffic_dataloader_v2.py
This commit is contained in:
parent
578ff48c71
commit
22b229891a
|
|
@ -86,11 +86,10 @@ def data_loader(X, Y, batch_size, shuffle=True, drop_last=True):
|
|||
|
||||
|
||||
def load_traffic_data(config, client_cfgs):
|
||||
print("Use Mini graph")
|
||||
root = config.data.root
|
||||
dataName = 'PEMSD' + root[-1]
|
||||
raw_data = load_st_dataset(dataName)
|
||||
|
||||
sub_graph_size = config.model.minigraph_size
|
||||
|
||||
l, n, f = raw_data.shape
|
||||
|
||||
|
|
@ -132,9 +131,6 @@ def load_traffic_data(config, client_cfgs):
|
|||
x_train[..., :config.model.input_dim] = scaler.transform(x_train[..., :config.model.input_dim])
|
||||
x_val[..., :config.model.input_dim] = scaler.transform(x_val[..., :config.model.input_dim])
|
||||
x_test[..., :config.model.input_dim] = scaler.transform(x_test[..., :config.model.input_dim])
|
||||
# y_train[..., :config.model.output_dim] = scaler.transform(y_train[..., :config.model.output_dim])
|
||||
# y_val[..., :config.model.output_dim] = scaler.transform(y_val[..., :config.model.output_dim])
|
||||
# y_test[..., :config.model.output_dim] = scaler.transform(y_test[..., :config.model.output_dim])
|
||||
|
||||
# Client-side dataset splitting
|
||||
node_num = config.data.num_nodes
|
||||
|
|
@ -171,18 +167,20 @@ def load_traffic_data(config, client_cfgs):
|
|||
|
||||
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
||||
|
||||
minigraph_size = config.model.minigraph_size
|
||||
|
||||
data_list[i + 1] = {
|
||||
'train': torch.utils.data.TensorDataset(
|
||||
torch.tensor(sub_array_train, dtype=torch.float, device=device),
|
||||
torch.tensor(sub_y_train, dtype=torch.float, device=device)
|
||||
torch.tensor(split_into_mini_graphs(sub_array_train, minigraph_size), dtype=torch.float, device=device),
|
||||
torch.tensor(split_into_mini_graphs(sub_y_train, minigraph_size), dtype=torch.float, device=device)
|
||||
),
|
||||
'val': torch.utils.data.TensorDataset(
|
||||
torch.tensor(sub_array_val, dtype=torch.float, device=device),
|
||||
torch.tensor(sub_y_val, dtype=torch.float, device=device)
|
||||
torch.tensor(split_into_mini_graphs(sub_array_val, minigraph_size), dtype=torch.float, device=device),
|
||||
torch.tensor(split_into_mini_graphs(sub_y_val, minigraph_size), dtype=torch.float, device=device)
|
||||
),
|
||||
'test': torch.utils.data.TensorDataset(
|
||||
torch.tensor(sub_array_test, dtype=torch.float, device=device),
|
||||
torch.tensor(sub_y_test, dtype=torch.float, device=device)
|
||||
torch.tensor(split_into_mini_graphs(sub_array_test, minigraph_size), dtype=torch.float, device=device),
|
||||
torch.tensor(split_into_mini_graphs(sub_y_test, minigraph_size), dtype=torch.float, device=device)
|
||||
)
|
||||
}
|
||||
cur_index += per_samples
|
||||
|
|
@ -190,6 +188,40 @@ def load_traffic_data(config, client_cfgs):
|
|||
return data_list, config
|
||||
|
||||
|
||||
def split_into_mini_graphs(tensor, minigraph_size):
|
||||
"""
|
||||
Splits a tensor into mini-graphs of specified size. Pads the last mini-graph with dummy nodes if necessary.
|
||||
|
||||
Args:
|
||||
tensor (np.ndarray): Input tensor with shape (timestep, horizon, node_num, dim).
|
||||
minigraph_size (int): The size of each mini-graph.
|
||||
|
||||
Returns:
|
||||
np.ndarray: Output tensor with shape (timestep, horizon, minigraph_num, minigraph_size, dim).
|
||||
"""
|
||||
timestep, horizon, node_num, dim = tensor.shape
|
||||
|
||||
# Calculate the number of mini-graphs
|
||||
minigraph_num = (node_num + minigraph_size - 1) // minigraph_size # Round up division
|
||||
|
||||
# Initialize output tensor with zeros (dummy nodes)
|
||||
output = np.zeros((timestep, horizon, minigraph_num, minigraph_size, dim), dtype=tensor.dtype)
|
||||
|
||||
# Fill in the real data
|
||||
for i in range(minigraph_num):
|
||||
start_idx = i * minigraph_size
|
||||
end_idx = min(start_idx + minigraph_size, node_num) # Ensure we don't exceed the node number
|
||||
slice_size = end_idx - start_idx
|
||||
|
||||
# Assign the data to the corresponding mini-graph
|
||||
output[:, :, i, :slice_size, :] = tensor[:, :, start_idx:end_idx, :]
|
||||
|
||||
# For the remaining part in the mini-graph, it remains as dummy nodes (zeros)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
a = 'data/trafficflow/PeMS04'
|
||||
name = 'PEMSD' + a[-1]
|
||||
|
|
|
|||
Loading…
Reference in New Issue