4 changes: 2 additions & 2 deletions ldm/models/diffusion/ddim.py

@@ -18,7 +18,7 @@ def __init__(self, model, schedule="linear", **kwargs):
 
     def register_buffer(self, name, attr):
         if type(attr) == torch.Tensor:
-            if attr.device != torch.device("cuda"):
+            if attr.device != torch.device("cuda") and torch.cuda.is_available():
                 attr = attr.to(torch.device("cuda"))
         setattr(self, name, attr)
 
@@ -238,4 +238,4 @@ def decode(self, x_latent, cond, t_start, unconditional_guidance_scale=1.0, unco
         x_dec, _ = self.p_sample_ddim(x_dec, cond, ts, index=index, use_original_steps=use_original_steps,
                                       unconditional_guidance_scale=unconditional_guidance_scale,
                                       unconditional_conditioning=unconditional_conditioning)
-        return x_dec
\ No newline at end of file
+        return x_dec
2 changes: 1 addition & 1 deletion ldm/models/diffusion/plms.py

@@ -17,7 +17,7 @@ def __init__(self, model, schedule="linear", **kwargs):
 
     def register_buffer(self, name, attr):
         if type(attr) == torch.Tensor:
-            if attr.device != torch.device("cuda"):
+            if attr.device != torch.device("cuda") and torch.cuda.is_available():
                 attr = attr.to(torch.device("cuda"))
         setattr(self, name, attr)
 
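Both sampler hunks apply the same guard. A minimal sketch of the resulting `register_buffer` behavior, using a hypothetical stand-in class (only `torch` is assumed; the real methods live on `DDIMSampler` and `PLMSSampler`):

```python
import torch

class SamplerStub:
    """Hypothetical stand-in for DDIMSampler/PLMSSampler."""

    def register_buffer(self, name, attr):
        # Tensors are moved to the GPU only when CUDA is actually available;
        # on CPU-only machines they stay on their current device.
        if type(attr) == torch.Tensor:
            if attr.device != torch.device("cuda") and torch.cuda.is_available():
                attr = attr.to(torch.device("cuda"))
        setattr(self, name, attr)

sampler = SamplerStub()
sampler.register_buffer("alphas", torch.ones(10))  # no-op move on a CPU-only host
```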
10 changes: 5 additions & 5 deletions ldm/modules/encoders/modules.py

@@ -35,7 +35,7 @@ def forward(self, batch, key=None):
 
 class TransformerEmbedder(AbstractEncoder):
     """Some transformer encoder layers"""
-    def __init__(self, n_embed, n_layer, vocab_size, max_seq_len=77, device="cuda"):
+    def __init__(self, n_embed, n_layer, vocab_size, max_seq_len=77, device="cuda" if torch.cuda.is_available() else "cpu"):
         super().__init__()
         self.device = device
         self.transformer = TransformerWrapper(num_tokens=vocab_size, max_seq_len=max_seq_len,
@@ -52,7 +52,7 @@ def encode(self, x):
 
 class BERTTokenizer(AbstractEncoder):
     """ Uses a pretrained BERT tokenizer by huggingface. Vocab size: 30522 (?)"""
-    def __init__(self, device="cuda", vq_interface=True, max_length=77):
+    def __init__(self, device="cuda" if torch.cuda.is_available() else "cpu", vq_interface=True, max_length=77):
         super().__init__()
         from transformers import BertTokenizerFast  # TODO: add to reuquirements
         self.tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
@@ -80,7 +80,7 @@ def decode(self, text):
 class BERTEmbedder(AbstractEncoder):
     """Uses the BERT tokenizr model and add some transformer encoder layers"""
     def __init__(self, n_embed, n_layer, vocab_size=30522, max_seq_len=77,
-                 device="cuda",use_tokenizer=True, embedding_dropout=0.0):
+                 device="cuda" if torch.cuda.is_available() else "cpu", use_tokenizer=True, embedding_dropout=0.0):
         super().__init__()
         self.use_tknz_fn = use_tokenizer
         if self.use_tknz_fn:
@@ -136,7 +136,7 @@ def encode(self, x):
 
 class FrozenCLIPEmbedder(AbstractEncoder):
     """Uses the CLIP transformer encoder for text (from Hugging Face)"""
-    def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77):
+    def __init__(self, version="openai/clip-vit-large-patch14", device="cuda" if torch.cuda.is_available() else "cpu", max_length=77):
         super().__init__()
         self.tokenizer = CLIPTokenizer.from_pretrained(version)
         self.transformer = CLIPTextModel.from_pretrained(version)
@@ -231,4 +231,4 @@ def forward(self, x):
 if __name__ == "__main__":
     from ldm.util import count_params
     model = FrozenCLIPEmbedder()
-    count_params(model, verbose=True)
\ No newline at end of file
+    count_params(model, verbose=True)
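One thing worth knowing about the constructor changes above: a conditional default like `device="cuda" if torch.cuda.is_available() else "cpu"` is evaluated once, when the class body executes at import time, not on every call. That is harmless here, since CUDA availability does not normally change within a process. A small sketch with a hypothetical class:

```python
import torch

class EncoderStub:
    """Hypothetical encoder mirroring the constructor defaults above."""

    def __init__(self, device="cuda" if torch.cuda.is_available() else "cpu"):
        # The default is computed at class-definition time; callers can still
        # override it explicitly, e.g. EncoderStub(device="cpu").
        self.device = device

print(EncoderStub().device)  # "cuda" on a GPU host, "cpu" otherwise
```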
8 changes: 5 additions & 3 deletions notebook_helpers.py

@@ -44,7 +44,8 @@ def load_model_from_config(config, ckpt):
     sd = pl_sd["state_dict"]
     model = instantiate_from_config(config.model)
     m, u = model.load_state_dict(sd, strict=False)
-    model.cuda()
+    if torch.cuda.is_available():
+        model.cuda()
     model.eval()
     return {"model": model}, global_step
 
@@ -117,7 +118,8 @@ def get_cond(mode, selected_path):
     c = rearrange(c, '1 c h w -> 1 h w c')
     c = 2. * c - 1.
 
-    c = c.to(torch.device("cuda"))
+    if torch.cuda.is_available():
+        c = c.to(torch.device("cuda"))
     example["LR_image"] = c
     example["image"] = c_up
 
@@ -267,4 +269,4 @@ def make_convolutional_sample(batch, model, mode="vanilla", custom_steps=None, e
     log["sample"] = x_sample
     log["time"] = t1 - t0
 
-    return log
\ No newline at end of file
+    return log
8 changes: 6 additions & 2 deletions scripts/knn2img.py

@@ -53,7 +53,8 @@ def load_model_from_config(config, ckpt, verbose=False):
         print("unexpected keys:")
         print(u)
 
-    model.cuda()
+    if torch.cuda.is_available():
+        model.cuda()
     model.eval()
     return model
 
@@ -358,7 +359,10 @@ def __call__(self, x, n):
             uc = None
             if searcher is not None:
                 nn_dict = searcher(c, opt.knn)
-                c = torch.cat([c, torch.from_numpy(nn_dict['nn_embeddings']).cuda()], dim=1)
+                nn_embeddings = torch.from_numpy(nn_dict['nn_embeddings'])
+                if torch.cuda.is_available():
+                    nn_embeddings = nn_embeddings.cuda()
+                c = torch.cat([c, nn_embeddings], dim=1)
             if opt.scale != 1.0:
                 uc = torch.zeros_like(c)
             if isinstance(prompts, tuple):
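The knn2img hunk splits the numpy-to-GPU move into two steps so the conversion still works without CUDA. An equivalent, slightly more compact idiom (not what this PR uses, just an alternative) computes the target device once; the array shape below is a placeholder standing in for `nn_dict['nn_embeddings']`:

```python
import numpy as np
import torch

# Resolve the best available device once, then .to() is a no-op on CPU-only hosts.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

nn_embeddings = torch.from_numpy(np.zeros((1, 4, 768), dtype=np.float32)).to(device)
```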
3 changes: 2 additions & 1 deletion scripts/sample_diffusion.py

@@ -220,7 +220,8 @@ def get_parser():
 def load_model_from_config(config, sd):
     model = instantiate_from_config(config)
     model.load_state_dict(sd,strict=False)
-    model.cuda()
+    if torch.cuda.is_available():
+        model.cuda()
     model.eval()
     return model
 
3 changes: 2 additions & 1 deletion scripts/txt2img.py

@@ -60,7 +60,8 @@ def load_model_from_config(config, ckpt, verbose=False):
         print("unexpected keys:")
         print(u)
 
-    model.cuda()
+    if torch.cuda.is_available():
+        model.cuda()
     model.eval()
     return model
 
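Taken together, the script changes converge on one loading pattern. A consolidated sketch of it, assuming the repo's `instantiate_from_config` helper and an OmegaConf-style `config` (details simplified from `notebook_helpers.py`):

```python
import torch
from ldm.util import instantiate_from_config

def load_model_from_config(config, ckpt):
    # Load checkpoint weights on CPU first, then move the model to the GPU
    # only if one exists; CPU-only hosts run the model where it was built.
    pl_sd = torch.load(ckpt, map_location="cpu")
    model = instantiate_from_config(config.model)
    model.load_state_dict(pl_sd["state_dict"], strict=False)
    if torch.cuda.is_available():
        model.cuda()
    model.eval()
    return model
```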