-
Notifications
You must be signed in to change notification settings - Fork 105
Expand file tree
/
Copy pathmodule.py
More file actions
232 lines (199 loc) · 7.65 KB
/
module.py
File metadata and controls
232 lines (199 loc) · 7.65 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
# credits: https://github.com/ashleve/lightning-hydra-template/blob/main/src/models/mnist_module.py
from typing import Any
import torch
from pytorch_lightning import LightningModule
from torchvision.transforms import transforms
from climax.arch import ClimaX
from climax.utils.lr_scheduler import LinearWarmupCosineAnnealingLR
from climax.utils.metrics import (
lat_weighted_acc,
lat_weighted_mse,
lat_weighted_mse_val,
lat_weighted_rmse,
)
from climax.utils.pos_embed import interpolate_pos_embed
class GlobalForecastModule(LightningModule):
    """Lightning module for global forecasting with the ClimaX model.

    Before training or evaluation the caller must configure the module through
    ``set_lat_lon``, ``set_pred_range``, ``set_denormalization``,
    ``set_val_clim`` and ``set_test_clim``: the step methods read the
    attributes those setters assign, and ``__init__`` does not initialize them.

    Args:
        net (ClimaX): ClimaX model.
        pretrained_path (str, optional): Local path or ``http...`` URL of a
            pre-trained checkpoint. An empty string (or ``None``) skips loading.
        lr (float, optional): Learning rate.
        beta_1 (float, optional): Beta 1 for AdamW.
        beta_2 (float, optional): Beta 2 for AdamW.
        weight_decay (float, optional): Weight decay for AdamW. Not applied to
            parameters whose name matches ``_NO_DECAY_KEYWORDS``.
        warmup_epochs (int, optional): Number of warmup epochs.
        max_epochs (int, optional): Number of total epochs.
        warmup_start_lr (float, optional): Starting learning rate for warmup.
        eta_min (float, optional): Minimum learning rate.

    Note:
        The LR scheduler is registered with ``interval="step"``, so Lightning
        steps it after every optimizer step; ``warmup_epochs`` / ``max_epochs``
        therefore effectively count optimizer steps.
    """

    # Parameters whose name contains any of these substrings are excluded from
    # weight decay. ("time_pos_embed" is already matched by "pos_embed"; it is
    # listed only to make the intent explicit.)
    _NO_DECAY_KEYWORDS = ("var_embed", "pos_embed", "time_pos_embed")

    def __init__(
        self,
        net: ClimaX,
        pretrained_path: str = "",
        lr: float = 5e-4,
        beta_1: float = 0.9,
        beta_2: float = 0.99,
        weight_decay: float = 1e-5,
        warmup_epochs: int = 10000,
        max_epochs: int = 200000,
        warmup_start_lr: float = 1e-8,
        eta_min: float = 1e-8,
    ):
        super().__init__()
        # ``net`` is kept as a submodule, so exclude it from saved hyperparameters.
        self.save_hyperparameters(logger=False, ignore=["net"])
        self.net = net
        # Truthiness (rather than ``len(...) > 0``) so that ``None`` is treated
        # like "" instead of raising TypeError.
        if pretrained_path:
            self.load_pretrained_weights(pretrained_path)

    def load_pretrained_weights(self, pretrained_path):
        """Load a pre-trained checkpoint into this module (non-strictly).

        The checkpoint is fetched with ``torch.hub`` when ``pretrained_path``
        starts with ``"http"`` and read from disk (mapped to CPU) otherwise.
        Its ``"state_dict"`` entry is adapted before loading:

        * positional embeddings are interpolated to ``self.net.img_size``;
        * keys containing ``"channel"`` are renamed to use ``"var"`` instead;
        * keys absent from ``self.state_dict()``, or whose tensor shape
          differs, are dropped (each removal is printed).

        The result of ``load_state_dict(..., strict=False)`` (missing /
        unexpected keys) is printed.

        Args:
            pretrained_path (str): Local path or URL of the checkpoint.

        Raises:
            ValueError: If ``self.net.parallel_patch_embed`` is enabled but the
                checkpoint has no ``token_embeds.proj_weights`` entry.
        """
        if pretrained_path.startswith("http"):
            checkpoint = torch.hub.load_state_dict_from_url(pretrained_path)
        else:
            # NOTE(review): recent torch versions default ``torch.load`` to
            # ``weights_only=True``; confirm these checkpoints load under it.
            checkpoint = torch.load(pretrained_path, map_location=torch.device("cpu"))

        print("Loading pre-trained checkpoint from: %s" % pretrained_path)
        checkpoint_model = checkpoint["state_dict"]

        # interpolate positional embedding
        interpolate_pos_embed(self.net, checkpoint_model, new_size=self.net.img_size)

        state_dict = self.state_dict()

        if self.net.parallel_patch_embed:
            # Loadable keys must match ``self.state_dict()``, whose keys carry the
            # ``net.`` prefix (the model lives at ``self.net``). An exact match on
            # the unprefixed name therefore never succeeds for a usable
            # checkpoint, so match by suffix (which still accepts the bare name).
            has_parallel_weights = any(
                key.endswith("token_embeds.proj_weights") for key in checkpoint_model
            )
            if not has_parallel_weights:
                raise ValueError(
                    "Pretrained checkpoint does not have token_embeds.proj_weights for parallel processing. Please convert the checkpoints first or disable parallel patch_embed tokenization."
                )

        # Rename keys using "channel" to the "var" naming this model uses
        # (presumably a legacy naming in older checkpoints — confirm).
        for key in list(checkpoint_model.keys()):
            if "channel" in key:
                checkpoint_model[key.replace("channel", "var")] = checkpoint_model.pop(key)

        # Drop anything that cannot be loaded into this module as-is.
        for key in list(checkpoint_model.keys()):
            if key not in state_dict or checkpoint_model[key].shape != state_dict[key].shape:
                print(f"Removing key {key} from pretrained checkpoint")
                del checkpoint_model[key]

        # load pre-trained model
        msg = self.load_state_dict(checkpoint_model, strict=False)
        print(msg)

    def set_denormalization(self, mean, std):
        """Store the transform passed as ``transform`` to ``self.net.evaluate``.

        ``transforms.Normalize`` computes ``(x - mean) / std``. To *undo* a
        normalization the caller presumably passes already-inverted statistics
        (e.g. ``mean=-m/s``, ``std=1/s``) — TODO confirm against the caller.
        """
        self.denormalization = transforms.Normalize(mean, std)

    def set_lat_lon(self, lat, lon):
        """Store grid coordinates; ``lat`` is passed to the metric functions.

        ``lon`` is stored but not used by this module's own methods.
        """
        self.lat = lat
        self.lon = lon

    def set_pred_range(self, r):
        """Store the prediction range used to build metric-name suffixes.

        Values below 24 are reported as ``"<r>_hours"``, otherwise as
        ``"<r // 24>_days"`` (see ``_get_log_postfix``).
        """
        self.pred_range = r

    def set_val_clim(self, clim):
        """Store the climatology passed to the metrics during validation."""
        self.val_clim = clim

    def set_test_clim(self, clim):
        """Store the climatology passed to the metrics during testing."""
        self.test_clim = clim

    def training_step(self, batch: Any, batch_idx: int):
        """Compute the latitude-weighted MSE on one batch and log every entry.

        Args:
            batch: Tuple ``(x, y, lead_times, variables, out_variables)``.
            batch_idx: Index of the batch (unused).

        Returns:
            The ``"loss"`` entry of the metric dict, which Lightning backpropagates.
        """
        x, y, lead_times, variables, out_variables = batch
        # Call the module instance rather than ``.forward`` so registered hooks
        # and wrappers run.
        loss_dicts, _ = self.net(
            x, y, lead_times, variables, out_variables, [lat_weighted_mse], lat=self.lat
        )
        # One result dict per requested metric; only one metric was requested.
        loss_dict = loss_dicts[0]
        for name, value in loss_dict.items():
            self.log(
                "train/" + name,
                value,
                on_step=True,
                on_epoch=False,
                prog_bar=True,
            )
        return loss_dict["loss"]

    def _get_log_postfix(self):
        """Return the metric-name suffix: ``"<h>_hours"`` below 24, else ``"<d>_days"``."""
        if self.pred_range < 24:
            return f"{self.pred_range}_hours"
        days = int(self.pred_range / 24)
        return f"{days}_days"

    def _evaluate_and_log(self, batch, clim, stage):
        """Run the shared val/test evaluation on ``batch`` and log under ``stage/``.

        Args:
            batch: Tuple ``(x, y, lead_times, variables, out_variables)``.
            clim: Climatology forwarded to the metric functions.
            stage (str): Log-key prefix, e.g. ``"val"`` or ``"test"``.

        Returns:
            dict: All metric entries merged into one dict (when two metric dicts
            share a name, the later one wins).
        """
        x, y, lead_times, variables, out_variables = batch
        all_loss_dicts = self.net.evaluate(
            x,
            y,
            lead_times,
            variables,
            out_variables,
            transform=self.denormalization,
            metrics=[lat_weighted_mse_val, lat_weighted_rmse, lat_weighted_acc],
            lat=self.lat,
            clim=clim,
            log_postfix=self._get_log_postfix(),
        )

        loss_dict = {}
        for metric_dict in all_loss_dicts:
            loss_dict.update(metric_dict)

        for name, value in loss_dict.items():
            self.log(
                f"{stage}/{name}",
                value,
                on_step=False,
                on_epoch=True,
                prog_bar=False,
                sync_dist=True,
            )
        return loss_dict

    def validation_step(self, batch: Any, batch_idx: int):
        """Evaluate one validation batch against ``val_clim``; metrics logged as ``val/*``."""
        return self._evaluate_and_log(batch, self.val_clim, "val")

    def test_step(self, batch: Any, batch_idx: int):
        """Evaluate one test batch against ``test_clim``; metrics logged as ``test/*``."""
        return self._evaluate_and_log(batch, self.test_clim, "test")

    def configure_optimizers(self):
        """Build AdamW with two parameter groups plus a per-step LR schedule.

        Parameters whose name contains any of ``_NO_DECAY_KEYWORDS`` get
        ``weight_decay=0``; all others use ``hparams.weight_decay``. Both groups
        share ``lr`` and ``betas``. The scheduler is
        ``LinearWarmupCosineAnnealingLR`` and is stepped every optimizer step.
        """
        decay, no_decay = [], []
        for name, param in self.named_parameters():
            if any(keyword in name for keyword in self._NO_DECAY_KEYWORDS):
                no_decay.append(param)
            else:
                decay.append(param)

        betas = (self.hparams.beta_1, self.hparams.beta_2)
        optimizer = torch.optim.AdamW(
            [
                {
                    "params": decay,
                    "lr": self.hparams.lr,
                    "betas": betas,
                    "weight_decay": self.hparams.weight_decay,
                },
                {
                    "params": no_decay,
                    "lr": self.hparams.lr,
                    "betas": betas,
                    "weight_decay": 0,
                },
            ]
        )

        lr_scheduler = LinearWarmupCosineAnnealingLR(
            optimizer,
            self.hparams.warmup_epochs,
            self.hparams.max_epochs,
            self.hparams.warmup_start_lr,
            self.hparams.eta_min,
        )
        scheduler = {"scheduler": lr_scheduler, "interval": "step", "frequency": 1}

        return {"optimizer": optimizer, "lr_scheduler": scheduler}