41 changes: 29 additions & 12 deletions src/climax/arch.py
@@ -13,6 +13,8 @@
get_2d_sincos_pos_embed,
)

from .parallelpatchembed import ParallelVarPatchEmbed


class ClimaX(nn.Module):
"""Implements the ClimaX model as described in the paper,
@@ -29,6 +31,7 @@ class ClimaX(nn.Module):
mlp_ratio (float): ratio of mlp hidden dimension to embedding dimension
drop_path (float): stochastic depth rate
drop_rate (float): dropout rate
parallel_patch_embed (bool): whether to use parallel patch embedding
"""

def __init__(
@@ -43,19 +46,24 @@ def __init__(
mlp_ratio=4.0,
drop_path=0.1,
drop_rate=0.1,
parallel_patch_embed=False,
):
super().__init__()

# TODO: remove time_history parameter
self.img_size = img_size
self.patch_size = patch_size
self.default_vars = default_vars

self.parallel_patch_embed = parallel_patch_embed
# variable tokenization: separate embedding layer for each input variable
self.token_embeds = nn.ModuleList(
[PatchEmbed(img_size, patch_size, 1, embed_dim) for i in range(len(default_vars))]
)
self.num_patches = self.token_embeds[0].num_patches
if self.parallel_patch_embed:
self.token_embeds = ParallelVarPatchEmbed(len(default_vars), img_size, patch_size, embed_dim)
self.num_patches = self.token_embeds.num_patches
else:
self.token_embeds = nn.ModuleList(
[PatchEmbed(img_size, patch_size, 1, embed_dim) for i in range(len(default_vars))]
)
self.num_patches = self.token_embeds[0].num_patches

# variable embedding to denote which variable each token belongs to
# helps in aggregating variables
@@ -118,9 +126,14 @@ def initialize_weights(self):
self.var_embed.data.copy_(torch.from_numpy(var_embed).float().unsqueeze(0))

# token embedding layer
for i in range(len(self.token_embeds)):
w = self.token_embeds[i].proj.weight.data
trunc_normal_(w.view([w.shape[0], -1]), std=0.02)
if self.parallel_patch_embed:
for i in range(len(self.token_embeds.proj_weights)):
w = self.token_embeds.proj_weights[i].data
trunc_normal_(w.view([w.shape[0], -1]), std=0.02)
else:
for i in range(len(self.token_embeds)):
w = self.token_embeds[i].proj.weight.data
trunc_normal_(w.view([w.shape[0], -1]), std=0.02)

# initialize nn.Linear and nn.LayerNorm
self.apply(self._init_weights)
@@ -193,10 +206,14 @@ def forward_encoder(self, x: torch.Tensor, lead_times: torch.Tensor, variables):
# tokenize each variable separately
embeds = []
var_ids = self.get_var_ids(variables, x.device)
for i in range(len(var_ids)):
id = var_ids[i]
embeds.append(self.token_embeds[id](x[:, i : i + 1]))
x = torch.stack(embeds, dim=1) # B, V, L, D

if self.parallel_patch_embed:
x = self.token_embeds(x, var_ids) # B, V, L, D
else:
for i in range(len(var_ids)):
id = var_ids[i]
embeds.append(self.token_embeds[id](x[:, i : i + 1]))
x = torch.stack(embeds, dim=1) # B, V, L, D

# add variable embedding
var_embed = self.get_var_emb(self.var_embed, variables)
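A note on the tokenization contract this hunk preserves: both the serial and the parallel path emit tokens of shape (B, V, L, D), where L = (H / patch_size) * (W / patch_size). A minimal shape check, reusing the placeholder sizes from tests/test_parallel_patch_embed.py below (the concrete values are illustrative, not mandated by the PR):

import torch

from climax.arch import ClimaX

variables = ("a", "b", "c")
model = ClimaX(variables, img_size=[32, 64], patch_size=4, embed_dim=128, parallel_patch_embed=True)
x = torch.rand(4, len(variables), 32, 64)                  # B, V, H, W
var_ids = model.get_var_ids(variables, x.device)
tokens = model.token_embeds(x, var_ids)                    # grouped-conv tokenization path
assert tokens.shape == (4, 3, (32 // 4) * (64 // 4), 128)  # (B, V, L, D) with L = 8 * 16 = 128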
8 changes: 7 additions & 1 deletion src/climax/global_forecast/module.py
@@ -55,7 +55,7 @@ def __init__(
self.load_pretrained_weights(pretrained_path)

def load_pretrained_weights(self, pretrained_path):
if pretrained_path.startswith('http'):
if pretrained_path.startswith("http"):
checkpoint = torch.hub.load_state_dict_from_url(pretrained_path)
else:
checkpoint = torch.load(pretrained_path, map_location=torch.device("cpu"))
@@ -65,6 +65,12 @@ def load_pretrained_weights(self, pretrained_path):
interpolate_pos_embed(self.net, checkpoint_model, new_size=self.net.img_size)

state_dict = self.state_dict()
if self.net.parallel_patch_embed:
if "token_embeds.proj_weights" not in checkpoint_model.keys():
raise ValueError(
"Pretrained checkpoint does not have token_embeds.proj_weights for parallel processing. Please convert the checkpoints first or disable parallel patch_embed tokenization."
)

# checkpoint_keys = list(checkpoint_model.keys())
for k in list(checkpoint_model.keys()):
if "channel" in k:
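The ValueError above tells users to convert serial checkpoints before enabling the parallel embedding, but no conversion utility appears in this diff. A minimal sketch of such a conversion, assuming the serial checkpoint keeps per-variable keys named token_embeds.{i}.proj.weight and token_embeds.{i}.proj.bias (the helper name and key layout are assumptions, not part of the PR):

import torch

def convert_serial_to_parallel(checkpoint_model, num_vars):
    # Stack the per-variable PatchEmbed kernels into the grouped tensors
    # expected by ParallelVarPatchEmbed (token_embeds.proj_weights / proj_biases).
    weights, biases = [], []
    for i in range(num_vars):
        weights.append(checkpoint_model.pop(f"token_embeds.{i}.proj.weight"))  # (D, 1, p, p)
        biases.append(checkpoint_model.pop(f"token_embeds.{i}.proj.bias"))     # (D,)
    checkpoint_model["token_embeds.proj_weights"] = torch.stack(weights, dim=0)  # (V, D, 1, p, p)
    checkpoint_model["token_embeds.proj_biases"] = torch.stack(biases, dim=0)    # (V, D)
    return checkpoint_model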
75 changes: 75 additions & 0 deletions src/climax/parallelpatchembed.py
@@ -0,0 +1,75 @@
import math

import torch
import torch.nn.functional as F
from timm.models.layers.helpers import to_2tuple
from torch import nn


def _get_conv2d_weights(
in_channels,
out_channels,
kernel_size,
):
weight = torch.empty(out_channels, in_channels, *kernel_size)
return weight


def _get_conv2d_biases(out_channels):
bias = torch.empty(out_channels)
return bias


class ParallelVarPatchEmbed(nn.Module):
"""Variable to Patch Embedding with multiple variables in a single kernel. Key idea is to use Grouped Convolutions.

Args:
max_vars (int): Maximum number of variables
img_size (int): Image size
patch_size (int): Patch size
embed_dim (int): Embedding dimension
norm_layer (nn.Module, optional): Normalization layer. Defaults to None.
flatten (bool, optional): Flatten the output. Defaults to True.
"""

def __init__(self, max_vars: int, img_size, patch_size, embed_dim, norm_layer=None, flatten=True):
super().__init__()
self.max_vars = max_vars
self.img_size = to_2tuple(img_size)
self.patch_size = to_2tuple(patch_size)
self.grid_size = (img_size[0] // self.patch_size[0], img_size[1] // self.patch_size[1])
self.num_patches = self.grid_size[0] * self.grid_size[1]
self.flatten = flatten

grouped_weights = torch.stack(
[_get_conv2d_weights(1, embed_dim, self.patch_size) for _ in range(max_vars)], dim=0
)
self.proj_weights = nn.Parameter(grouped_weights)
grouped_biases = torch.stack([_get_conv2d_biases(embed_dim) for _ in range(max_vars)], dim=0)
self.proj_biases = nn.Parameter(grouped_biases)
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
self.reset_parameters()

def reset_parameters(self):
for idx in range(self.max_vars):
nn.init.kaiming_uniform_(self.proj_weights[idx], a=math.sqrt(5))
fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.proj_weights[idx])
if fan_in != 0:
bound = 1 / math.sqrt(fan_in)
nn.init.uniform_(self.proj_biases[idx], -bound, bound)

def forward(self, x, vars=None):
B, C, H, W = x.shape
if vars is None:
vars = range(self.max_vars)
weights = self.proj_weights[vars].flatten(0, 1)
biases = self.proj_biases[vars].flatten(0, 1)

groups = len(vars)
proj = F.conv2d(x, weights, biases, groups=groups, stride=self.patch_size)
if self.flatten:
proj = proj.reshape(B, groups, -1, *proj.shape[-2:])
proj = proj.flatten(3).transpose(2, 3)

proj = self.norm(proj)
return proj
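The docstring's key idea, computing all per-variable projections with a single grouped convolution, can be verified in isolation. A self-contained sketch with placeholder sizes (not taken from the PR) showing that the grouped call matches V independent single-channel convolutions:

import torch
import torch.nn.functional as F

B, V, D, p = 2, 3, 8, 4
x = torch.rand(B, V, 32, 64)   # one input channel per variable
w = torch.rand(V, D, 1, p, p)  # per-variable kernels, laid out like proj_weights
b = torch.rand(V, D)           # per-variable biases, laid out like proj_biases

# One grouped conv over all variables at once ...
grouped = F.conv2d(x, w.flatten(0, 1), b.flatten(0, 1), stride=p, groups=V)
# ... equals V separate 1-channel convs concatenated along the channel dimension.
serial = torch.cat([F.conv2d(x[:, i : i + 1], w[i], b[i], stride=p) for i in range(V)], dim=1)
assert torch.allclose(grouped, serial, atol=1e-6)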
16 changes: 12 additions & 4 deletions src/climax/regional_forecast/module.py
@@ -55,7 +55,7 @@ def __init__(
self.load_pretrained_weights(pretrained_path)

def load_pretrained_weights(self, pretrained_path):
if pretrained_path.startswith('http'):
if pretrained_path.startswith("http"):
checkpoint = torch.hub.load_state_dict_from_url(pretrained_path)
else:
checkpoint = torch.load(pretrained_path, map_location=torch.device("cpu"))
@@ -66,6 +66,12 @@ def load_pretrained_weights(self, pretrained_path):
interpolate_pos_embed(self.net, checkpoint_model, new_size=self.net.img_size)

state_dict = self.state_dict()
if self.net.parallel_patch_embed:
if "token_embeds.proj_weights" not in checkpoint_model.keys():
raise ValueError(
"Pretrained checkpoint does not have token_embeds.proj_weights for parallel processing. Please convert the checkpoints first or disable parallel patch_embed tokenization."
)

# checkpoint_keys = list(checkpoint_model.keys())
for k in list(checkpoint_model.keys()):
if "channel" in k:
@@ -102,7 +108,9 @@ def get_patch_size(self):
def training_step(self, batch: Any, batch_idx: int):
x, y, lead_times, variables, out_variables, region_info = batch

loss_dict, _ = self.net.forward(x, y, lead_times, variables, out_variables, [lat_weighted_mse], lat=self.lat, region_info=region_info)
loss_dict, _ = self.net.forward(
x, y, lead_times, variables, out_variables, [lat_weighted_mse], lat=self.lat, region_info=region_info
)
loss_dict = loss_dict[0]
for var in loss_dict.keys():
self.log(
@@ -136,7 +144,7 @@ def validation_step(self, batch: Any, batch_idx: int):
lat=self.lat,
clim=self.val_clim,
log_postfix=log_postfix,
region_info=region_info
region_info=region_info,
)

loss_dict = {}
@@ -175,7 +183,7 @@ def test_step(self, batch: Any, batch_idx: int):
lat=self.lat,
clim=self.test_clim,
log_postfix=log_postfix,
region_info=region_info
region_info=region_info,
)

loss_dict = {}
36 changes: 36 additions & 0 deletions tests/test_parallel_patch_embed.py
@@ -0,0 +1,36 @@
import torch

from climax.arch import ClimaX


def test_parallel_patch_embed():
vars = tuple(["a", "b", "c"])
x = torch.rand(4, len(vars), 32, 64)
serial_model = ClimaX(vars, img_size=[32, 64], patch_size=4, embed_dim=128, parallel_patch_embed=False)
parallel_model = ClimaX(vars, img_size=[32, 64], patch_size=4, embed_dim=128, parallel_patch_embed=True)
assert serial_model.token_embeds[0].num_patches == parallel_model.num_patches
assert len(serial_model.token_embeds) == len(parallel_model.token_embeds.proj_weights)
for i in range(len(serial_model.token_embeds)):
serial_model.token_embeds[i].proj.weight.data = parallel_model.token_embeds.proj_weights[i].data
serial_model.token_embeds[i].proj.bias.data = parallel_model.token_embeds.proj_biases[i].data

test_vars = [["a"], ["b"], ["c"], ["a", "b"], ["a", "c"], ["b", "c"], ["a", "b", "c"]]
for vars in test_vars:
vars = tuple(vars)
serial_var_ids = serial_model.get_var_ids(vars, x.device)
parallel_var_ids = parallel_model.get_var_ids(vars, x.device)
assert all(serial_var_ids == parallel_var_ids)
x = torch.rand(4, len(vars), 32, 64)
parallel_embed = parallel_model.token_embeds(x, parallel_var_ids)
embeds = []
for i in range(len(serial_var_ids)):
id = serial_var_ids[i]
embeds.append(serial_model.token_embeds[id](x[:, i : i + 1, :, :]))

serial_embed = torch.stack(embeds, dim=1)
assert parallel_embed.shape == serial_embed.shape
assert torch.allclose(serial_embed, parallel_embed)


if __name__ == "__main__":
test_parallel_patch_embed()