create PersonalizedCLIPEmbedder
vicgalle committed Oct 8, 2022
commit 61baaa76fb367216be2c061add1605e81e091197
235 changes: 178 additions & 57 deletions in ldm/modules/encoders/modules.py
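The diff below adds a PersonalizedCLIPEmbedder that nudges the CLIP text encoder toward a precomputed aesthetic embedding loaded from disk (for example aesthetic_embeddings/cloudcore.pt). The commit does not show how such a file is produced; the sketch below is a hypothetical helper, assuming the embedding is simply the mean of normalized CLIP image features over a few reference images. The function name, output path, and the averaging choice are assumptions, not part of this commit.

```python
# Hypothetical helper (not part of this commit): build an aesthetic embedding
# of the kind loaded from aesthetic_embeddings/*.pt, assuming it is the mean
# of normalized CLIP image features over a few reference images.
import torch
from PIL import Image
from transformers import CLIPModel, CLIPProcessor


def build_aesthetic_embedding(image_paths, out_path, version="openai/clip-vit-large-patch14"):
    model = CLIPModel.from_pretrained(version).eval()
    processor = CLIPProcessor.from_pretrained(version)
    with torch.no_grad():
        images = [Image.open(p).convert("RGB") for p in image_paths]
        inputs = processor(images=images, return_tensors="pt")
        feats = model.get_image_features(**inputs)        # (N, 768) for ViT-L/14
        feats = feats / feats.norm(dim=-1, keepdim=True)  # unit-normalize each image
        emb = feats.mean(dim=0, keepdim=True)             # (1, 768) averaged embedding
    torch.save(emb, out_path)
    return emb
```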
@@ -1,12 +1,16 @@
import torch
import torch.nn as nn
import torch.optim as optim
from functools import partial
import clip
from einops import rearrange, repeat
from transformers import CLIPTokenizer, CLIPTextModel
from transformers import CLIPTokenizer, CLIPTextModel, CLIPProcessor, CLIPModel
import kornia

from ldm.modules.x_transformer import Encoder, TransformerWrapper # TODO: can we directly rely on lucidrains code and simply add this as a reuirement? --> test
from ldm.modules.x_transformer import (
Encoder,
TransformerWrapper,
) # TODO: can we directly rely on lucidrains code and simply add this as a reuirement? --> test


class AbstractEncoder(nn.Module):
@@ -17,9 +21,8 @@ def encode(self, *args, **kwargs):
raise NotImplementedError



class ClassEmbedder(nn.Module):
def __init__(self, embed_dim, n_classes=1000, key='class'):
def __init__(self, embed_dim, n_classes=1000, key="class"):
super().__init__()
self.key = key
self.embedding = nn.Embedding(n_classes, embed_dim)
@@ -35,11 +38,15 @@ def forward(self, batch, key=None):

class TransformerEmbedder(AbstractEncoder):
"""Some transformer encoder layers"""

def __init__(self, n_embed, n_layer, vocab_size, max_seq_len=77, device="cuda"):
super().__init__()
self.device = device
self.transformer = TransformerWrapper(num_tokens=vocab_size, max_seq_len=max_seq_len,
attn_layers=Encoder(dim=n_embed, depth=n_layer))
self.transformer = TransformerWrapper(
num_tokens=vocab_size,
max_seq_len=max_seq_len,
attn_layers=Encoder(dim=n_embed, depth=n_layer),
)

def forward(self, tokens):
tokens = tokens.to(self.device) # meh
@@ -51,18 +58,27 @@ def encode(self, x):


class BERTTokenizer(AbstractEncoder):
""" Uses a pretrained BERT tokenizer by huggingface. Vocab size: 30522 (?)"""
"""Uses a pretrained BERT tokenizer by huggingface. Vocab size: 30522 (?)"""

def __init__(self, device="cuda", vq_interface=True, max_length=77):
super().__init__()
from transformers import BertTokenizerFast # TODO: add to reuquirements

self.tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
self.device = device
self.vq_interface = vq_interface
self.max_length = max_length

def forward(self, text):
batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
batch_encoding = self.tokenizer(
text,
truncation=True,
max_length=self.max_length,
return_length=True,
return_overflowing_tokens=False,
padding="max_length",
return_tensors="pt",
)
tokens = batch_encoding["input_ids"].to(self.device)
return tokens

@@ -79,20 +95,32 @@ def decode(self, text):

class BERTEmbedder(AbstractEncoder):
"""Uses the BERT tokenizr model and add some transformer encoder layers"""
def __init__(self, n_embed, n_layer, vocab_size=30522, max_seq_len=77,
device="cuda",use_tokenizer=True, embedding_dropout=0.0):

def __init__(
self,
n_embed,
n_layer,
vocab_size=30522,
max_seq_len=77,
device="cuda",
use_tokenizer=True,
embedding_dropout=0.0,
):
super().__init__()
self.use_tknz_fn = use_tokenizer
if self.use_tknz_fn:
self.tknz_fn = BERTTokenizer(vq_interface=False, max_length=max_seq_len)
self.device = device
self.transformer = TransformerWrapper(num_tokens=vocab_size, max_seq_len=max_seq_len,
attn_layers=Encoder(dim=n_embed, depth=n_layer),
emb_dropout=embedding_dropout)
self.transformer = TransformerWrapper(
num_tokens=vocab_size,
max_seq_len=max_seq_len,
attn_layers=Encoder(dim=n_embed, depth=n_layer),
emb_dropout=embedding_dropout,
)

def forward(self, text):
if self.use_tknz_fn:
tokens = self.tknz_fn(text)#.to(self.device)
tokens = self.tknz_fn(text) # .to(self.device)
else:
tokens = text
z = self.transformer(tokens, return_embeddings=True)
@@ -104,59 +132,134 @@ def encode(self, text):


class SpatialRescaler(nn.Module):
def __init__(self,
n_stages=1,
method='bilinear',
multiplier=0.5,
in_channels=3,
out_channels=None,
bias=False):
def __init__(
self,
n_stages=1,
method="bilinear",
multiplier=0.5,
in_channels=3,
out_channels=None,
bias=False,
):
super().__init__()
self.n_stages = n_stages
assert self.n_stages >= 0
assert method in ['nearest','linear','bilinear','trilinear','bicubic','area']
assert method in [
"nearest",
"linear",
"bilinear",
"trilinear",
"bicubic",
"area",
]
self.multiplier = multiplier
self.interpolator = partial(torch.nn.functional.interpolate, mode=method)
self.remap_output = out_channels is not None
if self.remap_output:
print(f'Spatial Rescaler mapping from {in_channels} to {out_channels} channels after resizing.')
self.channel_mapper = nn.Conv2d(in_channels,out_channels,1,bias=bias)
print(
f"Spatial Rescaler mapping from {in_channels} to {out_channels} channels after resizing."
)
self.channel_mapper = nn.Conv2d(in_channels, out_channels, 1, bias=bias)

def forward(self,x):
def forward(self, x):
for stage in range(self.n_stages):
x = self.interpolator(x, scale_factor=self.multiplier)


if self.remap_output:
x = self.channel_mapper(x)
return x

def encode(self, x):
return self(x)

class FrozenCLIPEmbedder(AbstractEncoder):
"""Uses the CLIP transformer encoder for text (from Hugging Face)"""
def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77):

class PersonalizedCLIPEmbedder(AbstractEncoder):
"""Uses the CLIP transformer encoder with the option of personalization with an aesthetic embedding"""

def __init__(
self,
version="openai/clip-vit-large-patch14",
device="cuda",
max_length=77,
T=5,
lr=0.0001,
aesthetic_embedding_path="aesthetic_embeddings/cloudcore.pt",
):
super().__init__()
self.tokenizer = CLIPTokenizer.from_pretrained(version)
self.transformer = CLIPTextModel.from_pretrained(version)
self.full_clip_processor = CLIPProcessor.from_pretrained(version)
self.device = device
self.max_length = max_length
self.freeze()

self.T = T
self.lr = lr
self.aesthetic_embedding_path = aesthetic_embedding_path
# self.freeze()

def freeze(self):
self.transformer = self.transformer.eval()
for param in self.parameters():
param.requires_grad = False

def forward(self, text):
batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
tokens = batch_encoding["input_ids"].to(self.device)
outputs = self.transformer(input_ids=tokens)

z = outputs.last_hidden_state
return z
with torch.enable_grad():
batch_encoding = self.tokenizer(
text,
truncation=True,
max_length=self.max_length,
return_length=True,
return_overflowing_tokens=False,
padding="max_length",
return_tensors="pt",
)
tokens = batch_encoding["input_ids"].to(self.device)

if text[0] != "":

# This is the model to be personalized
full_clip_model = CLIPModel.from_pretrained(
"openai/clip-vit-large-patch14",
).to(self.device)

# We load the aesthetic embeddings
image_embs = torch.load(self.aesthetic_embedding_path).to(self.device)

# We compute the loss (similarity between the prompt embedding and the aesthetic embedding)
image_embs /= image_embs.norm(dim=-1, keepdim=True)
text_embs = full_clip_model.get_text_features(tokens)
text_embs /= text_embs.norm(dim=-1, keepdim=True)
sim = text_embs @ image_embs.T
loss = -sim
print(loss)

# lr = 0.0001

# We optimize the model to maximize the similarity
optimizer = optim.Adam(
full_clip_model.text_model.parameters(), lr=self.lr
)

# T = 0
for i in range(self.T):
optimizer.zero_grad()

loss.mean().backward()
optimizer.step()

text_embs = full_clip_model.get_text_features(tokens)
text_embs /= text_embs.norm(dim=-1, keepdim=True)
sim = text_embs @ image_embs.T
loss = -sim
print(loss)

z = full_clip_model.text_model(input_ids=tokens).last_hidden_state

else:
z = self.transformer(input_ids=tokens).last_hidden_state

self.freeze()
return z

def encode(self, text):
return self(text)
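For non-empty prompts, the forward pass above loads a fresh CLIPModel, takes T Adam steps on its text tower so that the normalized prompt embedding moves toward the aesthetic embedding, and returns the resulting per-token hidden states; the empty prompt used for unconditional guidance skips personalization and goes through the frozen encoder instead. A minimal usage sketch, assuming a CUDA device and an existing aesthetic_embeddings/cloudcore.pt (the prompt string is illustrative):

```python
# Usage sketch (assumptions: CUDA is available and
# aesthetic_embeddings/cloudcore.pt exists; the prompt is made up).
embedder = PersonalizedCLIPEmbedder(
    T=5, lr=0.0001, aesthetic_embedding_path="aesthetic_embeddings/cloudcore.pt"
).to("cuda")

cond = embedder.encode(["a photo of a castle, intricate details"])  # runs T personalization steps
uncond = embedder.encode([""])  # empty prompt: no personalization, frozen encoder path
print(cond.shape)  # expected (1, 77, 768): per-token hidden states for cross-attention
```

Because the personalized branch calls CLIPModel.from_pretrained inside forward, every conditioned prompt starts its T gradient steps from the pretrained weights rather than accumulating personalization across calls.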
@@ -166,7 +269,15 @@ class FrozenCLIPTextEmbedder(nn.Module):
"""
Uses the CLIP transformer encoder for text.
"""
def __init__(self, version='ViT-L/14', device="cuda", max_length=77, n_repeat=1, normalize=True):

def __init__(
self,
version="ViT-L/14",
device="cuda",
max_length=77,
n_repeat=1,
normalize=True,
):
super().__init__()
self.model, _ = clip.load(version, jit=False, device="cpu")
self.device = device
@@ -188,37 +299,46 @@ def forward(self, text):

def encode(self, text):
z = self(text)
if z.ndim==2:
if z.ndim == 2:
z = z[:, None, :]
z = repeat(z, 'b 1 d -> b k d', k=self.n_repeat)
z = repeat(z, "b 1 d -> b k d", k=self.n_repeat)
return z


class FrozenClipImageEmbedder(nn.Module):
"""
Uses the CLIP image encoder.
"""
Uses the CLIP image encoder.
"""

def __init__(
self,
model,
jit=False,
device='cuda' if torch.cuda.is_available() else 'cpu',
antialias=False,
):
self,
model,
jit=False,
device="cuda" if torch.cuda.is_available() else "cpu",
antialias=False,
):
super().__init__()
self.model, _ = clip.load(name=model, device=device, jit=jit)

self.antialias = antialias

self.register_buffer('mean', torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False)
self.register_buffer('std', torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False)
self.register_buffer(
"mean", torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False
)
self.register_buffer(
"std", torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False
)

def preprocess(self, x):
# normalize to [0,1]
x = kornia.geometry.resize(x, (224, 224),
interpolation='bicubic',align_corners=True,
antialias=self.antialias)
x = (x + 1.) / 2.
x = kornia.geometry.resize(
x,
(224, 224),
interpolation="bicubic",
align_corners=True,
antialias=self.antialias,
)
x = (x + 1.0) / 2.0
# renormalize according to clip
x = kornia.enhance.normalize(x, self.mean, self.std)
return x
@@ -230,5 +350,6 @@ def forward(self, x):

if __name__ == "__main__":
from ldm.util import count_params
model = FrozenCLIPEmbedder()
count_params(model, verbose=True)

model = PersonalizedCLIPEmbedder()
count_params(model, verbose=True)
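For reference, the personalization step introduced in this commit can be restated as a standalone function. This is a sketch, not the committed code: it recomputes the loss at the top of each of the T steps (the commit computes it once before the loop and again at the end of every iteration, which yields the same sequence of updates) and it omits the freeze/unfreeze handling of the surrounding class.

```python
# Sketch of the aesthetic-gradient personalization step: fine-tune a copy of
# the CLIP text tower for T steps so the prompt embedding points toward the
# aesthetic embedding, then return its hidden states for conditioning.
import torch
import torch.optim as optim
from transformers import CLIPModel, CLIPTokenizer


def personalize_text_encoder(prompt, aesthetic_embedding_path, T=5, lr=1e-4,
                             version="openai/clip-vit-large-patch14", device="cuda"):
    tokenizer = CLIPTokenizer.from_pretrained(version)
    clip_model = CLIPModel.from_pretrained(version).to(device)

    tokens = tokenizer(prompt, truncation=True, max_length=77, padding="max_length",
                       return_tensors="pt")["input_ids"].to(device)

    image_embs = torch.load(aesthetic_embedding_path).to(device)
    image_embs = image_embs / image_embs.norm(dim=-1, keepdim=True)

    optimizer = optim.Adam(clip_model.text_model.parameters(), lr=lr)
    for _ in range(T):
        optimizer.zero_grad()
        text_embs = clip_model.get_text_features(tokens)
        text_embs = text_embs / text_embs.norm(dim=-1, keepdim=True)
        loss = -(text_embs @ image_embs.T).mean()  # maximize cosine similarity
        loss.backward()
        optimizer.step()

    with torch.no_grad():
        return clip_model.text_model(input_ids=tokens).last_hidden_state
```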