import numpy as np
import time
import csv
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from torch_geometric_temporal.nn.recurrent import A3TGCN2
from torch_geometric_temporal.dataset import PemsDatasetLoader
import argparse
from utils import *


def parse_arguments():
    parser = argparse.ArgumentParser(description="Demo of index batching with the PemsBay dataset")

    parser.add_argument(
        "-e", "--epochs", type=int, default=100, help="The desired number of training epochs"
    )
    parser.add_argument(
        "-bs", "--batch-size", type=int, default=64, help="The desired batch size"
    )
    parser.add_argument(
        "-g", "--gpu", type=str, default="False", help="Whether the data should be preprocessed and moved directly to the GPU"
    )
    parser.add_argument(
        "-d", "--debug", type=str, default="False", help="Print values for debugging"
    )
    return parser.parse_args()
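

# Example invocation (the script name below is illustrative; the flags are the
# ones defined in parse_arguments above):
#   python pems_index_batching_demo.py --epochs 50 --batch-size 32 --gpu True --debug True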


# Making the model
class TemporalGNN(torch.nn.Module):
    def __init__(self, node_features, periods, batch_size):
        super(TemporalGNN, self).__init__()
        # Attention Temporal Graph Convolutional Cell
        self.tgnn = A3TGCN2(in_channels=node_features, out_channels=32, periods=periods, batch_size=batch_size)  # node_features=2, periods=12
        # Equals single-shot prediction
        self.linear = torch.nn.Linear(32, periods)

    def forward(self, x, edge_index):
        """
        x = Node features for T time steps
        edge_index = Graph edge indices
        """
        h = self.tgnn(x, edge_index)  # x [b, 207, 2, 12] returns h [b, 207, 32]
        h = F.relu(h)
        h = self.linear(h)
        return h
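

# Illustrative usage sketch (comments only, not executed): per the shape notes in
# forward(), the model expects x of shape [batch, nodes, features, periods] and
# returns [batch, nodes, periods]. The node count (207 in the comment above) is an
# example and depends on the dataset.
#
#   model = TemporalGNN(node_features=2, periods=12, batch_size=4)
#   x = torch.rand(4, 207, 2, 12)               # [batch, nodes, features, periods]
#   edge_index = torch.randint(0, 207, (2, 100))
#   out = model(x, edge_index)                  # -> [4, 207, 12]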


def train(train_dataloader, val_dataloader, batch_size, epochs, edges, DEVICE, allGPU=False, debug=False):

    # Create model and optimizers
    model = TemporalGNN(node_features=2, periods=12, batch_size=batch_size).to(DEVICE)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    loss_fn = torch.nn.MSELoss()

    stats = []
    t_mse = []
    v_mse = []

    edges = edges.to(DEVICE)
    for epoch in range(epochs):
        model.train()  # Re-enable training mode (model.eval() is set during validation below)
        step = 0
        loss_list = []
        t1 = time.time()
        i = 1
        total = len(train_dataloader)
        mae_total = 0
        for batch in train_dataloader:
            X_batch, y_batch = batch

            # Need to permute based on expected input shape for ATGCN
            if allGPU:
                X_batch = X_batch.permute(0, 2, 3, 1)
                y_batch = y_batch[..., 0].permute(0, 2, 1)
            else:
                X_batch = X_batch.permute(0, 2, 3, 1).to(DEVICE).float()
                y_batch = y_batch[..., 0].permute(0, 2, 1).to(DEVICE).float()
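            # Shape note (derived from the permutes above and the comment in forward()):
            # the loader is assumed to yield X and y as [batch, time, nodes, features];
            # permute(0, 2, 3, 1) gives [batch, nodes, features, time] for the model
            # input, and y[..., 0].permute(0, 2, 1) keeps only the first target feature,
            # giving [batch, nodes, time] to match the model output.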

            y_hat = model(X_batch, edges)   # Get model predictions
            loss = loss_fn(y_hat, y_batch)  # Mean squared error (take the sqrt for RMSE)

            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            step = step + 1
            loss_list.append(loss.item())

            if debug:
                print(f"Train Batch: {i}/{total}", end="\r")
                i += 1

        model.eval()
        step = 0
        # Store for analysis
        total_loss = []
        i = 1
        total = len(val_dataloader)
        if debug:
            print(" ", end="\r")
        with torch.no_grad():
            for batch in val_dataloader:
                X_batch, y_batch = batch

                # Need to permute based on expected input shape for ATGCN
                if allGPU:
                    X_batch = X_batch.permute(0, 2, 3, 1)
                    y_batch = y_batch[..., 0].permute(0, 2, 1)
                else:
                    X_batch = X_batch.permute(0, 2, 3, 1).to(DEVICE).float()
                    y_batch = y_batch[..., 0].permute(0, 2, 1).to(DEVICE).float()

                # Get model predictions
                y_hat = model(X_batch, edges)
                # Mean squared error
                loss = loss_fn(y_hat, y_batch)
                total_loss.append(loss.item())

                mae_total += masked_mae_loss(y_hat, y_batch)

                if debug:
                    print(f"Val Batch: {i}/{total}", end="\r")
                    i += 1

        mae = mae_total / len(val_dataloader)
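        # masked_mae_loss comes from utils (star-imported above); it is assumed to
        # return the batch-mean absolute error with missing/zero targets masked out,
        # so averaging it over the validation batches gives the epoch-level MAE.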
        t2 = time.time()
        train_mse = sum(loss_list) / len(loss_list)
        val_mse = sum(total_loss) / len(total_loss)
        print("Epoch {} time: {:.4f} Train MSE: {:.4f} Val MSE: {:.4f} Val MAE: {:.4f}".format(epoch, t2 - t1, train_mse, val_mse, mae))
        stats.append([epoch, t2 - t1, train_mse, val_mse])
        t_mse.append(train_mse)
        v_mse.append(val_mse)
    return min(t_mse), min(v_mse)


def main():
    args = parse_arguments()
    allGPU = args.gpu.lower() in ["true", "y", "t", "yes"]
    debug = args.debug.lower() in ["true", "y", "t", "yes"]
    batch_size = args.batch_size
    epochs = args.epochs

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    shuffle = True

    start = time.time()
    p1 = time.time()
    indexLoader = PemsDatasetLoader(index=True)
    if allGPU:
        # allGPU here is presumably the index of the GPU to preprocess onto (0 = cuda:0)
        train_dataloader, val_dataloader, test_dataloader, edges, edge_weights, mean, std = indexLoader.get_index_dataset(batch_size=batch_size, shuffle=shuffle, allGPU=0)
    else:
        train_dataloader, val_dataloader, test_dataloader, edges, edge_weights, mean, std = indexLoader.get_index_dataset(batch_size=batch_size, shuffle=shuffle)
    p2 = time.time()
    t_mse, v_mse = train(train_dataloader, val_dataloader, batch_size, epochs, edges, device, allGPU=allGPU, debug=debug)
    end = time.time()

    print(f"Runtime: {round(end - start, 2)}; T-MSE: {round(t_mse, 3)}; V-MSE: {round(v_mse, 3)}")


if __name__ == "__main__":
    main()