@@ -61,6 +61,10 @@ def train(train_dataloader, val_dataloader, mean, std, batch_size, epochs, edge_
6161
6262 edge_index = edge_index .to (DEVICE )
6363 edge_weight = edge_weight .to (DEVICE )
64+
65+ if not allGPU :
66+ mean = mean .to (DEVICE )
67+ std = std .to (DEVICE )
6468
6569 optimizer = torch .optim .Adam (model .parameters (), lr = 0.001 )
6670
@@ -76,10 +80,12 @@ def train(train_dataloader, val_dataloader, mean, std, batch_size, epochs, edge_
7680 for x , y in train_dataloader :
7781
7882 if allGPU :
79- x = x .permute (0 ,2 ,3 ,1 )
83+ x = x .permute (0 ,2 ,3 ,1 ).float ()
84+ y = y .float ()
85+
8086 else :
81- x = x .permute (0 ,2 ,3 ,1 ).to (DEVICE )
82- y = y .to (DEVICE )
87+ x = x .permute (0 ,2 ,3 ,1 ).to (DEVICE ). float ()
88+ y = y .to (DEVICE ). float ()
8389
8490 y_hat = model (x , edge_index , edge_weight ).squeeze ()
8591 loss = masked_mae_loss ((y_hat * std ) + mean , (y * std ) + mean )
@@ -101,11 +107,11 @@ def train(train_dataloader, val_dataloader, mean, std, batch_size, epochs, edge_
101107 with torch .no_grad ():
102108 for x , y in val_dataloader :
103109 if allGPU :
104- x = x .permute (0 ,2 ,3 ,1 )
105- # y = y[...,0]
110+ x = x .permute (0 ,2 ,3 ,1 ). float ()
111+ y = y . float ()
106112 else :
107- x = x .permute (0 ,2 ,3 ,1 ).to (DEVICE )
108- y = y .to (DEVICE )
113+ x = x .permute (0 ,2 ,3 ,1 ).to (DEVICE ). float ()
114+ y = y .to (DEVICE ). float ()
109115
110116 y_hat = model (x , edge_index , edge_weight ).squeeze ()
111117
0 commit comments