1+ import torch
2+ import torch .nn as nn
3+ import torch .nn .functional as F
4+ import numpy as np
5+ from torch_geometric_temporal .dataset import PemsDatasetLoader
6+ from torch_geometric_temporal .nn .recurrent import TGCN2
7+ import argparse
8+ import time
9+
10+
11+ def parse_arguments ():
12+ parser = argparse .ArgumentParser (description = "Demo of index batching with PemsBay dataset" )
13+
14+ parser .add_argument (
15+ "-e" , "--epochs" , type = int , default = 100 , help = "The desired number of training epochs"
16+ )
17+ parser .add_argument (
18+ "-bs" , "--batch-size" , type = int , default = 64 , help = "The desired batch size"
19+ )
20+ parser .add_argument (
21+ "-g" , "--gpu" , type = str , default = "False" , help = "Should data be preprocessed and migrated directly to the GPU"
22+ )
23+ parser .add_argument (
24+ "-d" , "--debug" , type = str , default = "False" , help = "Print values for debugging"
25+ )
26+ return parser .parse_args ()
27+
28+ # --- Model ---
29+ class BatchedTGCN (nn .Module ):
30+ def __init__ (self , in_channels , hidden_dim , out_channels ):
31+ super ().__init__ ()
32+ self .tgnn = TGCN2 (in_channels , hidden_dim , 1 )
33+ self .linear = nn .Linear (hidden_dim , out_channels )
34+
35+ def forward (self , x , edge_index , edge_weight ):
36+ # x: [B, N, F, T]
37+ B , N , Fin , T = x .shape
38+
39+ h = None
40+ output_sequence = []
41+ for t in range (T ):
42+ h = self .tgnn (x [..., t ], edge_index , edge_weight , h ) # h: [B, N, hidden_dim]
43+ h_t = F .relu (h )
44+ out_t = self .linear (h_t ).unsqueeze (1 ) # [B, N, output_dim] → [B, 1, N, output_dim]
45+ output_sequence .append (out_t )
46+
47+ return torch .cat (output_sequence , dim = 1 ) # [B, T, N, output_dim]
48+
49+ def masked_mae_loss (y_pred , y_true ):
50+ mask = (y_true != 0 ).float ()
51+ mask /= mask .mean ()
52+ loss = torch .abs (y_pred - y_true )
53+ loss = loss * mask
54+ # trick for nans: https://discuss.pytorch.org/t/how-to-set-nan-in-tensor-to-0/3918/3
55+ loss [loss != loss ] = 0
56+ return loss .mean ()
57+
58+ def train (train_dataloader , val_dataloader , mean , std , batch_size , epochs , edge_index , edge_weight , DEVICE , allGPU = False , debug = False ):
59+ # currently predicting speed and time of day. This can be changed to just predict speed.
60+ model = BatchedTGCN (in_channels = 2 , out_channels = 2 , hidden_dim = 32 , ).to (DEVICE )
61+
62+ if not allGPU :
63+ mean = mean .to (DEVICE )
64+ std = std .to (DEVICE )
65+
66+ edge_index = edge_index .to (DEVICE )
67+ edge_weight = edge_weight .to (DEVICE )
68+
69+ optimizer = torch .optim .Adam (model .parameters (), lr = 0.001 )
70+
71+ t_maes = []
72+ v_maes = []
73+
74+ for epoch in range (epochs ):
75+ model .train ()
76+ epoch_loss = []
77+ i = 1
78+ total = len (train_dataloader )
79+ t1 = time .time ()
80+ for x , y in train_dataloader :
81+
82+ if allGPU :
83+ x = x .permute (0 ,2 ,3 ,1 ).float ()
84+ y = y .float ()
85+ else :
86+ x = x .permute (0 ,2 ,3 ,1 ).to (DEVICE ).float ()
87+ y = y .to (DEVICE ).float ()
88+
89+ y_hat = model (x , edge_index , edge_weight ).squeeze ()
90+ loss = masked_mae_loss ((y_hat * std ) + mean , (y * std ) + mean )
91+
92+ loss .backward ()
93+ optimizer .step ()
94+ optimizer .zero_grad ()
95+ epoch_loss .append (loss .item ())
96+ if debug :
97+ print (f"Train Batch: { i } /{ total } " , end = "\r " )
98+ i += 1
99+
100+ if debug :
101+ print (" " , end = "\r " )
102+ model .eval ()
103+ test_loss = []
104+ i = 1
105+ total = len (val_dataloader )
106+ with torch .no_grad ():
107+ for x , y in val_dataloader :
108+ if allGPU :
109+ x = x .permute (0 ,2 ,3 ,1 ).float ()
110+ y = y .float ()
111+ else :
112+ x = x .permute (0 ,2 ,3 ,1 ).to (DEVICE ).float ()
113+ y = y .to (DEVICE ).float ()
114+
115+ y_hat = model (x , edge_index , edge_weight ).squeeze ()
116+
117+ loss = masked_mae_loss ((y_hat * std ) + mean , (y * std ) + mean )
118+ test_loss .append (loss .item ())
119+
120+ if debug :
121+ print (f"Test Batch: { i } /{ total } " , end = "\r " )
122+ i += 1
123+
124+
125+ t_maes .append (np .mean (epoch_loss ))
126+ v_maes .append (np .mean (test_loss ))
127+ t2 = time .time ()
128+
129+ print (f"Epoch { epoch + 1 } /{ epochs } , Runtime: { t2 - t1 } , Train Loss: { np .mean (epoch_loss ):.4f} , Val Loss: { np .mean (test_loss ):.4f} " , flush = True )
130+
131+ return min (t_maes ).item (), min (v_maes ).item ()
132+
133+ def main ():
134+ args = parse_arguments ()
135+ allGPU = args .gpu .lower () in ["true" , "y" , "t" , "yes" ]
136+ debug = args .debug .lower () in ["true" , "y" , "t" , "yes" ]
137+ batch_size = args .batch_size
138+ epochs = args .epochs
139+
140+ device = torch .device ("cuda" if torch .cuda .is_available () else "cpu" )
141+ shuffle = True
142+
143+ start = time .time ()
144+ p1 = time .time ()
145+ indexLoader = PemsDatasetLoader (index = True )
146+ if allGPU :
147+ train_dataloader , val_dataloader , test_dataloader , edges , edge_weights , mean , std = indexLoader .get_index_dataset (batch_size = batch_size , shuffle = shuffle , allGPU = 0 )
148+ else :
149+ train_dataloader , val_dataloader , test_dataloader , edges , edge_weights , mean , std = indexLoader .get_index_dataset (batch_size = batch_size , shuffle = shuffle )
150+ p2 = time .time ()
151+ t_mse , v_mse = train (train_dataloader , val_dataloader , mean , std , batch_size , epochs , edges , edge_weights , device , debug = debug )
152+ end = time .time ()
153+
154+ print (f"Runtime: { round (end - start ,2 )} ; T-MAE: { round (t_mse , 5 )} ; V-MAE: { round (v_mse , 5 )} " )
155+
156+ if __name__ == "__main__" :
157+ main ()
0 commit comments