Skip to content

Commit 48d89e9

Browse files
pems with index-batching for a3tgcn
1 parent 58867d4 commit 48d89e9

File tree

1 file changed

+171
-0
lines changed

1 file changed

+171
-0
lines changed
Lines changed: 171 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,171 @@
1+
import numpy as np
2+
import time
3+
import csv
4+
import torch
5+
import torch.nn.functional as F
6+
from torch_geometric.nn import GCNConv
7+
from torch_geometric_temporal.nn.recurrent import A3TGCN2
8+
from torch_geometric_temporal.dataset import PemsDatasetLoader
9+
import argparse
10+
from utils import *
11+
12+
13+
def parse_arguments():
14+
parser = argparse.ArgumentParser(description="Demo of index batching with PemsBay dataset")
15+
16+
parser.add_argument(
17+
"-e", "--epochs", type=int, default=100, help="The desired number of training epochs"
18+
)
19+
parser.add_argument(
20+
"-bs", "--batch-size", type=int, default=64, help="The desired batch size"
21+
)
22+
parser.add_argument(
23+
"-g", "--gpu", type=str, default="False", help="Should data be preprocessed and migrated directly to the GPU"
24+
)
25+
parser.add_argument(
26+
"-d", "--debug", type=str, default="False", help="Print values for debugging"
27+
)
28+
return parser.parse_args()
29+
30+
# Making the model
31+
class TemporalGNN(torch.nn.Module):
32+
def __init__(self, node_features, periods, batch_size):
33+
super(TemporalGNN, self).__init__()
34+
# Attention Temporal Graph Convolutional Cell
35+
self.tgnn = A3TGCN2(in_channels=node_features, out_channels=32, periods=periods,batch_size=batch_size) # node_features=2, periods=12
36+
# Equals single-shot prediction
37+
self.linear = torch.nn.Linear(32, periods)
38+
39+
def forward(self, x, edge_index):
40+
"""
41+
x = Node features for T time steps
42+
edge_index = Graph edge indices
43+
"""
44+
h = self.tgnn(x, edge_index) # x [b, 207, 2, 12] returns h [b, 207, 12]
45+
h = F.relu(h)
46+
h = self.linear(h)
47+
return h
48+
49+
50+
51+
def train(train_dataloader, val_dataloader, batch_size, epochs, edges, DEVICE, allGPU=False, debug=False):
52+
53+
# Create model and optimizers
54+
model = TemporalGNN(node_features=2, periods=12, batch_size=batch_size).to(DEVICE)
55+
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
56+
loss_fn = torch.nn.MSELoss()
57+
58+
stats = []
59+
t_mse = []
60+
v_mse = []
61+
62+
63+
edges = edges.to(DEVICE)
64+
for epoch in range(epochs):
65+
step = 0
66+
loss_list = []
67+
t1 = time.time()
68+
i = 1
69+
total = len(train_dataloader)
70+
mae_total = 0
71+
for batch in train_dataloader:
72+
X_batch, y_batch = batch
73+
74+
# Need to permute based on expected input shape for ATGCN
75+
if allGPU:
76+
X_batch = X_batch.permute(0, 2, 3, 1)
77+
y_batch = y_batch[...,0].permute(0, 2, 1)
78+
else:
79+
X_batch = X_batch.permute(0, 2, 3, 1).to(DEVICE).float()
80+
y_batch = y_batch[...,0].permute(0, 2, 1).to(DEVICE).float()
81+
82+
83+
84+
y_hat = model(X_batch, edges) # Get model predictions
85+
loss = loss_fn(y_hat, y_batch) # Mean squared error #loss = torch.mean((y_hat-labels)**2) sqrt to change it to rmse
86+
87+
loss.backward()
88+
optimizer.step()
89+
optimizer.zero_grad()
90+
step= step+ 1
91+
loss_list.append(loss.item())
92+
93+
if debug:
94+
print(f"Train Batch: {i}/{total}", end="\r")
95+
i+=1
96+
97+
98+
model.eval()
99+
step = 0
100+
# Store for analysis
101+
total_loss = []
102+
i = 1
103+
total = len(val_dataloader)
104+
if debug:
105+
print(" ", end="\r")
106+
with torch.no_grad():
107+
for batch in val_dataloader:
108+
X_batch, y_batch = batch
109+
110+
111+
# Need to permute based on expected input shape for ATGCN
112+
if allGPU:
113+
X_batch = X_batch.permute(0, 2, 3, 1)
114+
y_batch = y_batch[...,0].permute(0, 2, 1)
115+
else:
116+
X_batch = X_batch.permute(0, 2, 3, 1).to(DEVICE).float()
117+
y_batch = y_batch[...,0].permute(0, 2, 1).to(DEVICE).float()
118+
119+
# Get model predictions
120+
y_hat = model(X_batch, edges)
121+
# Mean squared error
122+
loss = loss_fn(y_hat, y_batch)
123+
total_loss.append(loss.item())
124+
125+
mae_total += masked_mae_loss(y_hat, y_batch)
126+
127+
if debug:
128+
print(f"Val Batch: {i}/{total}", end="\r")
129+
i += 1
130+
131+
mae = mae_total / len(val_dataloader)
132+
t2 = time.time()
133+
print("Epoch {} time: {:.4f} train RMSE: {:.4f} Test MSE: {:.4f} Test MAE: {:.4f}".format(epoch,t2 - t1, sum(loss_list)/len(loss_list), sum(total_loss)/len(total_loss), mae))
134+
stats.append([epoch, t2-t1, sum(loss_list)/len(loss_list), sum(total_loss)/len(total_loss)])
135+
t_mse.append(sum(loss_list)/len(loss_list))
136+
v_mse.append(sum(total_loss)/len(total_loss))
137+
return min(t_mse), min(v_mse)
138+
139+
140+
141+
142+
143+
144+
145+
146+
def main():
147+
args = parse_arguments()
148+
allGPU = args.gpu.lower() in ["true", "y", "t", "yes"]
149+
debug = args.debug.lower() in ["true", "y", "t", "yes"]
150+
batch_size = args.batch_size
151+
epochs = args.epochs
152+
153+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
154+
shuffle= True
155+
156+
157+
start = time.time()
158+
p1 = time.time()
159+
indexLoader = PemsDatasetLoader(index=True)
160+
if allGPU:
161+
train_dataloader, val_dataloader, test_dataloader, edges, edge_weights, mean, std = indexLoader.get_index_dataset(batch_size=batch_size, shuffle=shuffle, allGPU=0)
162+
else:
163+
train_dataloader, val_dataloader, test_dataloader, edges, edge_weights, mean, std = indexLoader.get_index_dataset(batch_size=batch_size, shuffle=shuffle)
164+
p2 = time.time()
165+
t_mse, v_mse = train(train_dataloader, val_dataloader, batch_size, epochs, edges, device, debug=debug)
166+
end = time.time()
167+
168+
print(f"Runtime: {round(end - start,2)}; T-MSE: {round(t_mse, 3)}; V-MSE: {round(v_mse, 3)}")
169+
170+
if __name__ == "__main__":
171+
main()

0 commit comments

Comments
 (0)