From 3cb8fddc8e338e4bf0efcd92859e3b68ee2d079e Mon Sep 17 00:00:00 2001 From: Seth Date: Tue, 10 Jun 2025 02:17:00 +0000 Subject: [PATCH 1/6] update Metra-LA to support index-batching --- torch_geometric_temporal/dataset/metr_la.py | 139 +++++++++++++++++--- 1 file changed, 123 insertions(+), 16 deletions(-) diff --git a/torch_geometric_temporal/dataset/metr_la.py b/torch_geometric_temporal/dataset/metr_la.py index 00ba68ec..00169cca 100644 --- a/torch_geometric_temporal/dataset/metr_la.py +++ b/torch_geometric_temporal/dataset/metr_la.py @@ -5,6 +5,9 @@ import numpy as np import torch from torch_geometric.utils import dense_to_sparse +from torch.utils.data import DataLoader +from torch.utils.data.distributed import DistributedSampler +from typing import Tuple from ..signal import StaticGraphTemporalSignal @@ -20,11 +23,17 @@ class METRLADatasetLoader(object): Data-Driven Traffic Forecasting" `_ """ - def __init__(self, raw_data_dir=os.path.join(os.getcwd(), "data")): + def __init__(self, raw_data_dir=os.path.join(os.getcwd(), "data"), index: bool = False): super(METRLADatasetLoader, self).__init__() + self.index = index + self.raw_data_dir = raw_data_dir self._read_web_data() + if index: + from ..signal.index_dataset import IndexDataset + self.IndexDataset = IndexDataset + def _download_url(self, url, save_path): # pragma: no cover context = ssl._create_unverified_context() with urllib.request.urlopen(url, context=context) as dl_file: @@ -51,21 +60,21 @@ def _read_web_data(self): os.path.join(self.raw_data_dir, "METR-LA.zip"), "r" ) as zip_fh: zip_fh.extractall(self.raw_data_dir) - - A = np.load(os.path.join(self.raw_data_dir, "adj_mat.npy")) - X = np.load(os.path.join(self.raw_data_dir, "node_values.npy")).transpose( - (1, 2, 0) - ) - X = X.astype(np.float32) - - # Normalise as in DCRNN paper (via Z-Score Method) - means = np.mean(X, axis=(0, 2)) - X = X - means.reshape(1, -1, 1) - stds = np.std(X, axis=(0, 2)) - X = X / stds.reshape(1, -1, 1) - - self.A = torch.from_numpy(A) - self.X = torch.from_numpy(X) + if not self.index: + A = np.load(os.path.join(self.raw_data_dir, "adj_mat.npy")) + X = np.load(os.path.join(self.raw_data_dir, "node_values.npy")).transpose( + (1, 2, 0) + ) + X = X.astype(np.float32) + + # Normalise as in DCRNN paper (via Z-Score Method) + means = np.mean(X, axis=(0, 2)) + X = X - means.reshape(1, -1, 1) + stds = np.std(X, axis=(0, 2)) + X = X / stds.reshape(1, -1, 1) + + self.A = torch.from_numpy(A) + self.X = torch.from_numpy(X) def _get_edges_and_weights(self): edge_indices, values = dense_to_sparse(self.A) @@ -116,3 +125,101 @@ def get_dataset( ) return dataset + + + + def get_index_dataset(self, lags: int = 12, batch_size: int = 64, shuffle: bool = False, allGPU: int = -1, + ratio: Tuple[float, float, float] = (0.7, 0.1, 0.2), world_size: int =-1, ddp_rank: int = -1, + dask_batching: bool = False) -> Tuple[DataLoader, DataLoader, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Returns torch dataloaders using index batching for PeMS dataset. + + Args: + lags (int, optional): The number of time lags. Defaults to 12. + batch_size (int, optional): Batch size. Defaults to 64. + shuffle (bool, optional): If the data should be shuffled. Defaults to False. + allGPU (int, optional): GPU device ID for performing preprocessing in GPU memory. + If -1, computation is done on CPU. Defaults to -1. + world_size (int, optional): The number of workers if DDP is being used. Defaults to -1. + ddp_rank (int, optional): The DDP rank of the worker if DDP is being used. Defaults to -1. + ratio (tuple of float, optional): The desired train, validation, and test split ratios, respectively. + + Returns: + Tuple[torch.utils.data.DataLoader, torch.utils.data.DataLoader, torch.utils.data.DataLoader, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + + A 7-tuple containing: + - **train_dataLoader** (*torch.utils.data.DataLoader*): Dataloader for the training set. + - **val_dataLoader** (*torch.utils.data.DataLoader*): Dataloader for the validation set. + - **test_dataLoader** (*torch.utils.data.DataLoader*): Dataloader for the test set. + - **edges** (*torch.Tensor*): The graph edges as a 2D matrix, shape `[2, num_edges]`. + - **edge_weights** (*torch.Tensor*): Each graph edge's weight, shape `[num_edges]`. + - **means** (*torch.Tensor*): The means of each feature dimension. + - **stds** (*torch.Tensor*): The standard deviations of each feature dimension. + """ + + # adj matrix setup + A = np.load(os.path.join(self.raw_data_dir, "adj_mat.npy")) + edges, edge_weights = dense_to_sparse(torch.from_numpy(A)) + + data = np.load(os.path.join(self.raw_data_dir, "node_values.npy")).transpose( (1, 2, 0)) + data = data.astype(np.float32) + + # Normalise as in DCRNN paper (via Z-Score Method) + if allGPU != -1: + data = torch.tensor(data,dtype=torch.float).to(f"cuda:{allGPU}") + means = torch.mean(data, dim=(0, 2), keepdim=True) + data = data - means + + stds = torch.std(data, dim=(0, 2), keepdim=True) + data = data / stds + data = data.permute(2, 0, 1) + + means.squeeze_() + stds.squeeze_() + + else: + + means = np.mean(data, axis=(0, 2)) + data = data - means.reshape(1, -1, 1) + stds = np.std(data, axis=(0, 2)) + data = data / stds.reshape(1, -1, 1) + data = data.transpose((2, 0, 1)) + + means = torch.tensor(means,dtype=torch.float) + stds = torch.tensor(stds,dtype=torch.float) + + + num_samples = data.shape[0] + x_i = np.arange(num_samples - (2 * lags - 1)) + num_samples = x_i.shape[0] + num_train = round(num_samples * ratio[0]) + num_test = round(num_samples * ratio[2]) + num_val = num_samples - num_train - num_test + + x_train = x_i[:num_train] + x_val = x_i[num_train: num_train + num_val] + x_test = x_i[-num_test:] + + train_dataset = self.IndexDataset(x_train,data,lags,gpu=not (allGPU == -1), lazy=dask_batching) + val_dataset = self.IndexDataset(x_val,data,lags,gpu=not (allGPU == -1), lazy=dask_batching) + test_dataset = self.IndexDataset(x_test,data,lags,gpu=not (allGPU == -1),lazy=dask_batching) + + + if ddp_rank != -1: + train_sampler = DistributedSampler(train_dataset, num_replicas=world_size, rank=ddp_rank, shuffle=shuffle) + train_dataloader = DataLoader(train_dataset, batch_size=batch_size, sampler=train_sampler) + + val_sampler = DistributedSampler(val_dataset, num_replicas=world_size, rank=ddp_rank, shuffle=shuffle) + val_dataloader = DataLoader(val_dataset, batch_size=batch_size, sampler=val_sampler) + + test_sampler = DistributedSampler(test_dataset, num_replicas=world_size, rank=ddp_rank, shuffle=shuffle) + test_dataloader = DataLoader(test_dataset, batch_size=batch_size, sampler=test_sampler) + else: + train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=shuffle) + val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=shuffle) + test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=shuffle) + + return train_dataloader, val_dataloader, test_dataloader, edges, edge_weights, means, stds + + + From 0231863860daf2d533e088c30767cad11c8ebb41 Mon Sep 17 00:00:00 2001 From: Seth Date: Tue, 10 Jun 2025 02:28:00 +0000 Subject: [PATCH 2/6] fix doc strings --- torch_geometric_temporal/dataset/metr_la.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch_geometric_temporal/dataset/metr_la.py b/torch_geometric_temporal/dataset/metr_la.py index 00169cca..323bf85d 100644 --- a/torch_geometric_temporal/dataset/metr_la.py +++ b/torch_geometric_temporal/dataset/metr_la.py @@ -132,7 +132,7 @@ def get_index_dataset(self, lags: int = 12, batch_size: int = 64, shuffle: bool ratio: Tuple[float, float, float] = (0.7, 0.1, 0.2), world_size: int =-1, ddp_rank: int = -1, dask_batching: bool = False) -> Tuple[DataLoader, DataLoader, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: """ - Returns torch dataloaders using index batching for PeMS dataset. + Returns torch dataloaders using index batching for Metr-LA dataset. Args: lags (int, optional): The number of time lags. Defaults to 12. From 861b1912f1253a7cf4198477d72f3eb738adcb46 Mon Sep 17 00:00:00 2001 From: Seth Date: Tue, 10 Jun 2025 02:34:02 +0000 Subject: [PATCH 3/6] Create DCRNN sub-folder --- examples/indexBatching/{ => DCRNN}/chicken_pox_main.py | 0 examples/indexBatching/{ => DCRNN}/pems_allLA_main.py | 0 examples/indexBatching/{ => DCRNN}/pems_bay_main.py | 0 examples/indexBatching/{ => DCRNN}/pems_ddp.py | 0 examples/indexBatching/{ => DCRNN}/pems_main.py | 0 examples/indexBatching/{ => DCRNN}/submit.sh | 0 examples/indexBatching/{ => DCRNN}/utils.py | 0 examples/indexBatching/{ => DCRNN}/windmill_main.py | 0 examples/indexBatching/README.md | 4 ++++ 9 files changed, 4 insertions(+) rename examples/indexBatching/{ => DCRNN}/chicken_pox_main.py (100%) rename examples/indexBatching/{ => DCRNN}/pems_allLA_main.py (100%) rename examples/indexBatching/{ => DCRNN}/pems_bay_main.py (100%) rename examples/indexBatching/{ => DCRNN}/pems_ddp.py (100%) rename examples/indexBatching/{ => DCRNN}/pems_main.py (100%) rename examples/indexBatching/{ => DCRNN}/submit.sh (100%) rename examples/indexBatching/{ => DCRNN}/utils.py (100%) rename examples/indexBatching/{ => DCRNN}/windmill_main.py (100%) diff --git a/examples/indexBatching/chicken_pox_main.py b/examples/indexBatching/DCRNN/chicken_pox_main.py similarity index 100% rename from examples/indexBatching/chicken_pox_main.py rename to examples/indexBatching/DCRNN/chicken_pox_main.py diff --git a/examples/indexBatching/pems_allLA_main.py b/examples/indexBatching/DCRNN/pems_allLA_main.py similarity index 100% rename from examples/indexBatching/pems_allLA_main.py rename to examples/indexBatching/DCRNN/pems_allLA_main.py diff --git a/examples/indexBatching/pems_bay_main.py b/examples/indexBatching/DCRNN/pems_bay_main.py similarity index 100% rename from examples/indexBatching/pems_bay_main.py rename to examples/indexBatching/DCRNN/pems_bay_main.py diff --git a/examples/indexBatching/pems_ddp.py b/examples/indexBatching/DCRNN/pems_ddp.py similarity index 100% rename from examples/indexBatching/pems_ddp.py rename to examples/indexBatching/DCRNN/pems_ddp.py diff --git a/examples/indexBatching/pems_main.py b/examples/indexBatching/DCRNN/pems_main.py similarity index 100% rename from examples/indexBatching/pems_main.py rename to examples/indexBatching/DCRNN/pems_main.py diff --git a/examples/indexBatching/submit.sh b/examples/indexBatching/DCRNN/submit.sh similarity index 100% rename from examples/indexBatching/submit.sh rename to examples/indexBatching/DCRNN/submit.sh diff --git a/examples/indexBatching/utils.py b/examples/indexBatching/DCRNN/utils.py similarity index 100% rename from examples/indexBatching/utils.py rename to examples/indexBatching/DCRNN/utils.py diff --git a/examples/indexBatching/windmill_main.py b/examples/indexBatching/DCRNN/windmill_main.py similarity index 100% rename from examples/indexBatching/windmill_main.py rename to examples/indexBatching/DCRNN/windmill_main.py diff --git a/examples/indexBatching/README.md b/examples/indexBatching/README.md index 62d5c4ce..f6133354 100644 --- a/examples/indexBatching/README.md +++ b/examples/indexBatching/README.md @@ -3,11 +3,15 @@ Index-batching is a technique that reduces the memory cost of training ST-GNNs with spatiotemporal data with no impact on accurary, enabling greater scalability and training on the full PeMS dataset without graph partioning for the first time. Leveraging the reduced memory footprint, this techique also enables GPU-index-batching - a technique that performs preprocessing entirely in GPU memory and utilizes a single CPU-to-GPU mem-copy in place of batch-level CPU-to-GPU transfers throughout training. We implemented GPU-index-batching and index-batching for the following existing datasets and added two new datasets (highlighted in bold) to PyTorch Geometric Temporal (PGT): * PeMs-Bay +* Metr-LA * WindmillLarge * HungaryChickenpox * **PeMSAllLA** * **PeMS** +This folder contains examples with DCRNN and A3TGCN. We hope to build out our examples over time. + + Utilizing index-batching requires minimal modifications to the existing PGT workflow. For example, the following is a sample training loop with static graph dataset with temporal signal: ``` From 06e101f0dc49122d5ec737f9d4f5ebf70713f498 Mon Sep 17 00:00:00 2001 From: Seth Date: Tue, 10 Jun 2025 02:34:49 +0000 Subject: [PATCH 4/6] index-batching with a3tgcn --- examples/indexBatching/A3TGCN/metr_la_main.py | 166 ++++++++++++++++++ 1 file changed, 166 insertions(+) create mode 100644 examples/indexBatching/A3TGCN/metr_la_main.py diff --git a/examples/indexBatching/A3TGCN/metr_la_main.py b/examples/indexBatching/A3TGCN/metr_la_main.py new file mode 100644 index 00000000..329c69b1 --- /dev/null +++ b/examples/indexBatching/A3TGCN/metr_la_main.py @@ -0,0 +1,166 @@ +import numpy as np +import time +import csv +import torch +import torch.nn.functional as F +from torch_geometric.nn import GCNConv +from torch_geometric_temporal.nn.recurrent import A3TGCN2 +from torch_geometric_temporal.dataset import METRLADatasetLoader +import argparse + + +def parse_arguments(): + parser = argparse.ArgumentParser(description="Demo of index batching with PemsBay dataset") + + parser.add_argument( + "-e", "--epochs", type=int, default=100, help="The desired number of training epochs" + ) + parser.add_argument( + "-bs", "--batch-size", type=int, default=64, help="The desired batch size" + ) + parser.add_argument( + "-g", "--gpu", type=str, default="False", help="Should data be preprocessed and migrated directly to the GPU" + ) + parser.add_argument( + "-d", "--debug", type=str, default="False", help="Print values for debugging" + ) + return parser.parse_args() + +# Making the model +class TemporalGNN(torch.nn.Module): + def __init__(self, node_features, periods, batch_size): + super(TemporalGNN, self).__init__() + # Attention Temporal Graph Convolutional Cell + self.tgnn = A3TGCN2(in_channels=node_features, out_channels=32, periods=periods,batch_size=batch_size) # node_features=2, periods=12 + # Equals single-shot prediction + self.linear = torch.nn.Linear(32, periods) + + def forward(self, x, edge_index): + """ + x = Node features for T time steps + edge_index = Graph edge indices + """ + h = self.tgnn(x, edge_index) # x [b, 207, 2, 12] returns h [b, 207, 12] + h = F.relu(h) + h = self.linear(h) + return h + + + +def train(train_dataloader, val_dataloader, batch_size, epochs, edges, DEVICE, allGPU=False, debug=False): + + # Create model and optimizers + model = TemporalGNN(node_features=2, periods=12, batch_size=batch_size).to(DEVICE) + optimizer = torch.optim.Adam(model.parameters(), lr=0.001) + loss_fn = torch.nn.MSELoss() + + stats = [] + t_mse = [] + v_mse = [] + + + edges = edges.to(DEVICE) + for epoch in range(epochs): + step = 0 + loss_list = [] + t1 = time.time() + i = 1 + total = len(train_dataloader) + for batch in train_dataloader: + X_batch, y_batch = batch + + # Need to permute based on expected input shape for ATGCN + if allGPU: + X_batch = X_batch.permute(0, 2, 3, 1) + y_batch = y_batch[...,0].permute(0, 2, 1) + else: + X_batch = X_batch.permute(0, 2, 3, 1).to(DEVICE) + y_batch = y_batch[...,0].permute(0, 2, 1).to(DEVICE) + + + + y_hat = model(X_batch, edges) # Get model predictions + loss = loss_fn(y_hat, y_batch) # Mean squared error #loss = torch.mean((y_hat-labels)**2) sqrt to change it to rmse + loss.backward() + optimizer.step() + optimizer.zero_grad() + step= step+ 1 + loss_list.append(loss.item()) + + if debug: + print(f"Train Batch: {i}/{total}", end="\r") + i+=1 + + + model.eval() + step = 0 + # Store for analysis + total_loss = [] + i = 1 + total = len(val_dataloader) + if debug: + print(" ", end="\r") + with torch.no_grad(): + for batch in val_dataloader: + X_batch, y_batch = batch + + + # Need to permute based on expected input shape for ATGCN + if allGPU: + X_batch = X_batch.permute(0, 2, 3, 1) + y_batch = y_batch[...,0].permute(0, 2, 1) + else: + X_batch = X_batch.permute(0, 2, 3, 1).to(DEVICE) + y_batch = y_batch[...,0].permute(0, 2, 1).to(DEVICE) + + # Get model predictions + y_hat = model(X_batch, edges) + # Mean squared error + loss = loss_fn(y_hat, y_batch) + total_loss.append(loss.item()) + + if debug: + print(f"Val Batch: {i}/{total}", end="\r") + i += 1 + + + t2 = time.time() + print("Epoch {} time: {:.4f} train RMSE: {:.4f} Test MSE: {:.4f}".format(epoch,t2 - t1, sum(loss_list)/len(loss_list), sum(total_loss)/len(total_loss))) + stats.append([epoch, t2-t1, sum(loss_list)/len(loss_list), sum(total_loss)/len(total_loss)]) + t_mse.append(sum(loss_list)/len(loss_list)) + v_mse.append(sum(total_loss)/len(total_loss)) + return min(t_mse), min(v_mse) + + + + + + + + +def main(): + args = parse_arguments() + allGPU = args.gpu.lower() in ["true", "y", "t", "yes"] + debug = args.debug.lower() in ["true", "y", "t", "yes"] + batch_size = args.batch_size + epochs = args.epochs + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + shuffle= True + + + start = time.time() + p1 = time.time() + indexLoader = METRLADatasetLoader(index=True) + if allGPU: + train_dataloader, val_dataloader, test_dataloader, edges, edge_weights, mean, std = indexLoader.get_index_dataset(batch_size=batch_size, shuffle=shuffle, allGPU=0) + else: + train_dataloader, val_dataloader, test_dataloader, edges, edge_weights, mean, std = indexLoader.get_index_dataset(batch_size=batch_size, shuffle=shuffle) + p2 = time.time() + t_mse, v_mse = train(train_dataloader, val_dataloader, batch_size, epochs, edges, device, debug=debug) + end = time.time() + + print(f"Runtime: {round(end - start,2)}; T-MSE: {round(t_mse, 3)}; V-MSE: {round(v_mse, 3)}") + +if __name__ == "__main__": + main() \ No newline at end of file From 4b17db72457627ca41390b69edd47954a1177f0c Mon Sep 17 00:00:00 2001 From: Seth Ockerman Date: Mon, 9 Jun 2025 21:42:49 -0500 Subject: [PATCH 5/6] update docs --- docs/source/notes/introduction.rst | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/source/notes/introduction.rst b/docs/source/notes/introduction.rst index 41a042ff..db2cb27b 100644 --- a/docs/source/notes/introduction.rst +++ b/docs/source/notes/introduction.rst @@ -130,6 +130,7 @@ index-batching for the following existing datasets and added two new datasets (h to PyTorch Geometric Temporal (PGT): * PeMs-Bay +* Metr-LA * WindmillLarge * HungaryChickenpox * **PeMSAllLA** From e883dff0b5b29c6faae1479d11818f5958a51b08 Mon Sep 17 00:00:00 2001 From: Seth Date: Tue, 10 Jun 2025 03:29:23 +0000 Subject: [PATCH 6/6] pems_ddp.py --- examples/indexBatching/A3TGCN/pems_ddp.py | 244 ++++++++++++++++++++++ 1 file changed, 244 insertions(+) create mode 100644 examples/indexBatching/A3TGCN/pems_ddp.py diff --git a/examples/indexBatching/A3TGCN/pems_ddp.py b/examples/indexBatching/A3TGCN/pems_ddp.py new file mode 100644 index 00000000..2d98111f --- /dev/null +++ b/examples/indexBatching/A3TGCN/pems_ddp.py @@ -0,0 +1,244 @@ +import time +import csv +import argparse +import uuid +import os + +import numpy as np +import time +import csv +import torch +import torch.nn.functional as F +from torch_geometric.nn import GCNConv +from torch_geometric_temporal.nn.recurrent import A3TGCN2 + +from torch_geometric_temporal.dataset import PemsBayDatasetLoader,PemsAllLADatasetLoader,PemsDatasetLoader,METRLADatasetLoader + + +import torch +import torch.optim as optim +import torch.distributed as dist +from torch.nn.parallel import DistributedDataParallel as DDP + + +from dask.distributed import LocalCluster +from dask.distributed import Client +from dask_pytorch_ddp import dispatch, results +from dask.distributed import wait as Wait + + +def parse_arguments(): + """Parse command-line arguments.""" + parser = argparse.ArgumentParser(description="Demo of DDP index-batching with PeMS-Bay, PeMS-All-LA, and PeMS") + + parser.add_argument( + "-e", "--epochs", type=int, default=100, help="The desired number of training epochs" + ) + parser.add_argument( + "-bs", "--batch-size", type=int, default=64, help="The desired batch size" + ) + parser.add_argument( + "-g", "--gpu", type=str, default="False", help="Should data be preprocessed and migrated directly to the GPU" + ) + parser.add_argument( + "-d", "--debug", type=str, default="False", help="Print values for debugging" + ) + + parser.add_argument( + "--dask-cluster-file", type=str, default="", help="Dask scheduler file for the Dask CLI Interfance" + ) + + parser.add_argument( + "-np","--npar", type=int, default=1, help="The number of GPUs/workers per node" + ) + parser.add_argument( + "--dataset", type=str, default="pems-bay", help="Which dataset is in use" + ) + + + return parser.parse_args() + + +# Making the model +class TemporalGNN(torch.nn.Module): + def __init__(self, node_features, periods, batch_size): + super(TemporalGNN, self).__init__() + # Attention Temporal Graph Convolutional Cell + self.tgnn = A3TGCN2(in_channels=node_features, out_channels=32, periods=periods,batch_size=batch_size) # node_features=2, periods=12 + # Equals single-shot prediction + self.linear = torch.nn.Linear(32, periods) + + def forward(self, x, edge_index): + """ + x = Node features for T time steps + edge_index = Graph edge indices + """ + h = self.tgnn(x, edge_index) # x [b, 207, 2, 12] returns h [b, 207, 12] + h = F.relu(h) + h = self.linear(h) + return h + + + +def train(args=None, epochs=None, batch_size=None, allGPU=False, debug=False, loader=None, start_time=None): + + worker_rank = int(dist.get_rank()) + gpu = worker_rank % 4 + device = torch.device(f"cuda:{gpu}" if torch.cuda.is_available() else "cpu") + if torch.cuda.is_available(): torch.cuda.set_device(device) + + world_size = dist.get_world_size() + + + if allGPU == True: + train_dataloader, val_dataloader, test_dataloader, edges, edge_weights, mean, std = loader.get_index_dataset(allGPU=gpu, batch_size=batch_size, world_size=world_size, ddp_rank=worker_rank) + else: + train_dataloader, val_dataloader, test_dataloader, edges, edge_weights, mean, std = loader.get_index_dataset(batch_size=batch_size, world_size=world_size, ddp_rank=worker_rank) + + + + model = TemporalGNN(node_features=2, periods=12, batch_size=batch_size).to(device) + optimizer = torch.optim.Adam(model.parameters(), lr=0.001) + loss_fn = torch.nn.MSELoss() + + stats = [] + t_mse = [] + v_mse = [] + + if torch.cuda.is_available(): + model = DDP(model, gradient_as_bucket_view=True, device_ids=[device], output_device=[device]) + else: + model = DDP(model, gradient_as_bucket_view=True) + + + + # Training loop + stats = [] + min_t = 9999 + min_v = 9999 + + edges = edges.to(device) + for epoch in range(epochs): + step = 0 + loss_list = [] + t1 = time.time() + i = 1 + total = len(train_dataloader) + for batch in train_dataloader: + X_batch, y_batch = batch + + # Need to permute based on expected input shape for ATGCN + if allGPU: + X_batch = X_batch.permute(0, 2, 3, 1) + y_batch = y_batch[...,0].permute(0, 2, 1) + else: + X_batch = X_batch.permute(0, 2, 3, 1).to(device).float() + y_batch = y_batch[...,0].permute(0, 2, 1).to(device).float() + + + + y_hat = model(X_batch, edges) # Get model predictions + loss = loss_fn(y_hat, y_batch) # Mean squared error #loss = torch.mean((y_hat-labels)**2) sqrt to change it to rmse + loss.backward() + optimizer.step() + optimizer.zero_grad() + step= step+ 1 + loss_list.append(loss.item()) + + if debug: + print(f"Train Batch: {i}/{total}", end="\r") + i+=1 + + + model.eval() + step = 0 + + # Store for analysis + val_loss = 0 + i = 1 + total = len(val_dataloader) + if debug: + print(" ", end="\r") + with torch.no_grad(): + for batch in val_dataloader: + X_batch, y_batch = batch + + + # Need to permute based on expected input shape for ATGCN + if allGPU: + X_batch = X_batch.permute(0, 2, 3, 1) + y_batch = y_batch[...,0].permute(0, 2, 1) + else: + X_batch = X_batch.permute(0, 2, 3, 1).to(device).float() + y_batch = y_batch[...,0].permute(0, 2, 1).to(device).float() + + # Get model predictions + y_hat = model(X_batch, edges) + # Mean squared error + loss = loss_fn(y_hat, y_batch) + val_loss += loss.item() + + if debug: + print(f"Val Batch: {i}/{total}", end="\r") + i += 1 + + val_tensor = torch.tensor([val_loss, len(val_dataloader)]) + dist.reduce(val_tensor,dst=0, op=dist.ReduceOp.SUM) + t2 = time.time() + + if worker_rank == 0: + val_loss = val_tensor[0]/ val_tensor[1] + + t2 = time.time() + print("Epoch {} time: {:.4f} Train RMSE: {:.4f} Validation MSE: {:.4f}".format(epoch,t2 - t1, sum(loss_list)/len(loss_list), val_loss.item())) + stats.append([epoch, t2-t1, sum(loss_list)/len(loss_list), val_loss.item()]) + t_mse.append(sum(loss_list)/len(loss_list)) + v_mse.append(val_loss.item()) + + + + print(f"Runtime: {round(t2 - start_time,2)}; Best Train MSE: {min(t_mse)}; Best Validation MSE: {v_mse}", flush=True) + +def main(): + args = parse_arguments() + allGPU = args.gpu.lower() in ["true", "y", "t", "yes"] + debug = args.debug.lower() in ["true", "y", "t", "yes"] + batch_size = args.batch_size + epochs = args.epochs + npar = args.npar + + if args.dataset.lower() not in ["pems-bay","pemsallla", "pems", "metr-la"]: + raise ValueError("Invalid argument for --dataset. --dataset must be 'metr-la', 'pems-bay', 'pemsAllLA', or 'pems'") + + t1 = time.time() + + # force the datasets to download before launching dask cluster + if args.dataset.lower() == "pems-bay": + loader = PemsBayDatasetLoader(index=True) + if args.dataset.lower() == "pemsallla": + loader = PemsAllLADatasetLoader(index=True) + if args.dataset.lower() == "pems": + loader = PemsDatasetLoader(index=True) + if args.dataset.lower() == "metr-la": + loader = METRLADatasetLoader(index=True) + + + if args.dask_cluster_file != "": + client = Client(scheduler_file = args.dask_cluster_file) + else: + cluster = LocalCluster(n_workers=npar) + client = Client(cluster) + + futures = dispatch.run(client, train, + args=args, debug=debug, epochs=epochs, batch_size=batch_size,allGPU=allGPU,loader=loader, + start_time=t1, + backend="gloo") + + key = uuid.uuid4().hex + rh = results.DaskResultsHandler(key) + rh.process_results(".", futures, raise_errors=False) + client.shutdown() + + +if __name__ == "__main__": + main() \ No newline at end of file