Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
parallel patch embed for faster execution
  • Loading branch information
rejuvyesh committed Apr 24, 2023
commit ab9ea6e39ca095a6967068c6a31b10e77c3d0f9d
40 changes: 28 additions & 12 deletions src/climax/arch.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
get_2d_sincos_pos_embed,
)

from .parallelpatchembed import ParallelVarPatchEmbed


class ClimaX(nn.Module):
"""Implements the ClimaX model as described in the paper,
Expand Down Expand Up @@ -43,19 +45,24 @@ def __init__(
mlp_ratio=4.0,
drop_path=0.1,
drop_rate=0.1,
parallel_patch_embed=False,
):
super().__init__()

# TODO: remove time_history parameter
self.img_size = img_size
self.patch_size = patch_size
self.default_vars = default_vars

self.parallel_patch_embed = parallel_patch_embed
# variable tokenization: separate embedding layer for each input variable
self.token_embeds = nn.ModuleList(
[PatchEmbed(img_size, patch_size, 1, embed_dim) for i in range(len(default_vars))]
)
self.num_patches = self.token_embeds[0].num_patches
if self.parallel_patch_embed:
self.token_embeds = ParallelVarPatchEmbed(len(default_vars), img_size, patch_size, embed_dim)
self.num_patches = self.token_embeds.num_patches
else:
self.token_embeds = nn.ModuleList(
[PatchEmbed(img_size, patch_size, 1, embed_dim) for i in range(len(default_vars))]
)
self.num_patches = self.token_embeds[0].num_patches

# variable embedding to denote which variable each token belongs to
# helps in aggregating variables
Expand Down Expand Up @@ -118,9 +125,14 @@ def initialize_weights(self):
self.var_embed.data.copy_(torch.from_numpy(var_embed).float().unsqueeze(0))

# token embedding layer
for i in range(len(self.token_embeds)):
w = self.token_embeds[i].proj.weight.data
trunc_normal_(w.view([w.shape[0], -1]), std=0.02)
if self.parallel_patch_embed:
for i in range(len(self.token_embeds.proj_weights)):
w = self.token_embeds.proj_weights[i].data
trunc_normal_(w.view([w.shape[0], -1]), std=0.02)
else:
for i in range(len(self.token_embeds)):
w = self.token_embeds[i].proj.weight.data
trunc_normal_(w.view([w.shape[0], -1]), std=0.02)

# initialize nn.Linear and nn.LayerNorm
self.apply(self._init_weights)
Expand Down Expand Up @@ -193,10 +205,14 @@ def forward_encoder(self, x: torch.Tensor, lead_times: torch.Tensor, variables):
# tokenize each variable separately
embeds = []
var_ids = self.get_var_ids(variables, x.device)
for i in range(len(var_ids)):
id = var_ids[i]
embeds.append(self.token_embeds[id](x[:, i : i + 1]))
x = torch.stack(embeds, dim=1) # B, V, L, D

if self.parallel_patch_embed:
x = self.token_embeds(x, var_ids) # B, V, L, D
else:
for i in range(len(var_ids)):
id = var_ids[i]
embeds.append(self.token_embeds[id](x[:, i : i + 1]))
x = torch.stack(embeds, dim=1) # B, V, L, D

# add variable embedding
var_embed = self.get_var_emb(self.var_embed, variables)
Expand Down
75 changes: 75 additions & 0 deletions src/climax/parallelpatchembed.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
import math

import torch
import torch.nn.functional as F
from timm.models.layers.helpers import to_2tuple
from torch import nn


def _get_conv2d_weights(
in_channels,
out_channels,
kernel_size,
):
weight = torch.empty(out_channels, in_channels, *kernel_size)
return weight


def _get_conv2d_biases(out_channels):
bias = torch.empty(out_channels)
return bias


def _to_2tuple(x):
    """Return *x* as a 2-tuple: sequences pass through, scalars are repeated.

    Local replacement for ``timm.models.layers.helpers.to_2tuple`` — that
    import path is a private timm module that moved in timm >= 0.9, so we
    avoid depending on it here.
    """
    if isinstance(x, (tuple, list)):
        return tuple(x)
    return (x, x)


class ParallelVarPatchEmbed(nn.Module):
    """Variable to Patch Embedding with multiple variables in a single kernel. Key idea is to use Grouped Convolutions.

    Instead of one ``PatchEmbed`` (one conv2d) per variable, all per-variable
    projection kernels are stored in a single stacked parameter and applied
    with one grouped ``F.conv2d`` call, which runs faster than a Python loop
    of separate convolutions.

    Args:
        max_vars (int): Maximum number of variables
        img_size (int or tuple[int, int]): Image size
        patch_size (int or tuple[int, int]): Patch size
        embed_dim (int): Embedding dimension
        norm_layer (nn.Module, optional): Normalization layer. Defaults to None.
        flatten (bool, optional): Flatten the output. Defaults to True.
    """

    def __init__(self, max_vars: int, img_size, patch_size, embed_dim, norm_layer=None, flatten=True):
        super().__init__()
        self.max_vars = max_vars
        self.img_size = _to_2tuple(img_size)
        self.patch_size = _to_2tuple(patch_size)
        # BUG FIX: index the normalized 2-tuple, not the raw argument.
        # The original `img_size[0]` raised TypeError whenever img_size was
        # passed as a plain int (which to_2tuple exists to support).
        self.grid_size = (self.img_size[0] // self.patch_size[0], self.img_size[1] // self.patch_size[1])
        self.num_patches = self.grid_size[0] * self.grid_size[1]
        self.flatten = flatten

        # One conv kernel per variable, stacked along dim 0:
        # proj_weights: (max_vars, embed_dim, 1, ph, pw); proj_biases: (max_vars, embed_dim).
        # Allocated uninitialized; reset_parameters() fills them in below.
        self.proj_weights = nn.Parameter(torch.empty(max_vars, embed_dim, 1, *self.patch_size))
        self.proj_biases = nn.Parameter(torch.empty(max_vars, embed_dim))
        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
        self.reset_parameters()

    def reset_parameters(self):
        """Initialize each variable's kernel and bias the way ``nn.Conv2d`` does."""
        for idx in range(self.max_vars):
            nn.init.kaiming_uniform_(self.proj_weights[idx], a=math.sqrt(5))
            # NOTE: _calculate_fan_in_and_fan_out is a private torch API, but it
            # is the same routine nn.Conv2d.reset_parameters uses internally.
            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.proj_weights[idx])
            if fan_in != 0:
                bound = 1 / math.sqrt(fan_in)
                nn.init.uniform_(self.proj_biases[idx], -bound, bound)

    def forward(self, x, vars=None):
        """Embed each input channel with its variable-specific kernel.

        Args:
            x: Input of shape (B, C, H, W); channel i is projected with the
                kernel selected by ``vars[i]``, so len(vars) must equal C.
            vars: Sequence (or index tensor) of variable ids selecting which
                kernels to use. Defaults to all ``max_vars`` variables.

        Returns:
            (B, C, num_patches, embed_dim) when ``flatten`` is True, otherwise
            the raw grouped-conv output of shape (B, C * embed_dim, gh, gw).
        """
        B, C, H, W = x.shape
        if vars is None:
            vars = list(range(self.max_vars))
        # Select and fuse the per-variable kernels:
        # (len(vars), D, 1, ph, pw) -> (len(vars)*D, 1, ph, pw)
        weights = self.proj_weights[vars].flatten(0, 1)
        biases = self.proj_biases[vars].flatten(0, 1)

        # One grouped convolution applies every selected kernel at once.
        groups = len(vars)
        proj = F.conv2d(x, weights, biases, groups=groups, stride=self.patch_size)
        if self.flatten:
            # (B, V*D, gh, gw) -> (B, V, D, gh, gw) -> (B, V, L, D)
            proj = proj.reshape(B, groups, -1, *proj.shape[-2:])
            proj = proj.flatten(3).transpose(2, 3)

        proj = self.norm(proj)
        return proj
33 changes: 33 additions & 0 deletions tests/test_parallel_patch_embed.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import torch

from climax.arch import ClimaX


def test_parallel_patch_embed():
    """The grouped-conv (parallel) patch embedding must reproduce the serial per-variable path."""
    variables = ("a", "b", "c")
    inputs = torch.rand(4, len(variables), 32, 64)

    serial_model = ClimaX(variables, img_size=[32, 64], patch_size=4, embed_dim=128, parallel_patch_embed=False)
    parallel_model = ClimaX(variables, img_size=[32, 64], patch_size=4, embed_dim=128, parallel_patch_embed=True)

    assert serial_model.token_embeds[0].num_patches == parallel_model.num_patches
    assert len(serial_model.token_embeds) == len(parallel_model.token_embeds.proj_weights)

    # Copy the parallel model's kernels into the serial embeds so the two
    # paths compute with identical parameters.
    for idx, embed in enumerate(serial_model.token_embeds):
        embed.proj.weight.data = parallel_model.token_embeds.proj_weights[idx].data
        embed.proj.bias.data = parallel_model.token_embeds.proj_biases[idx].data

    serial_var_ids = serial_model.get_var_ids(variables, inputs.device)
    parallel_var_ids = parallel_model.get_var_ids(variables, inputs.device)
    assert all(serial_var_ids == parallel_var_ids)

    parallel_embed = parallel_model.token_embeds(inputs, parallel_var_ids)
    serial_outputs = [
        serial_model.token_embeds[var_id](inputs[:, pos : pos + 1, :, :])
        for pos, var_id in enumerate(serial_var_ids)
    ]
    serial_embed = torch.stack(serial_outputs, dim=1)

    assert parallel_embed.shape == serial_embed.shape
    assert torch.allclose(serial_embed, parallel_embed)


if __name__ == "__main__":
    test_parallel_patch_embed()