Skip to content

Commit 369ed4a

Browse files
committed
update
1 parent 6a23e53 commit 369ed4a

File tree

1 file changed

+7
-3
lines changed

1 file changed

+7
-3
lines changed

chapter8_Application/neural-transfer/build_model.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
1+
import torch
12
import torch.nn as nn
23
import torchvision.models as models
34

45
import loss
56

67
vgg = models.vgg19(pretrained=True).features
7-
vgg = vgg.cuda()
8+
if torch.cuda.is_available():
9+
vgg = vgg.cuda()
810

911
content_layers_default = ['conv_4']
1012
style_layers_default = ['conv_1', 'conv_2', 'conv_3', 'conv_4', 'conv_5']
@@ -22,9 +24,11 @@ def get_style_model_and_loss(style_img,
2224
style_loss_list = []
2325

2426
model = nn.Sequential()
25-
model = model.cuda()
27+
if torch.cuda.is_available():
28+
model = model.cuda()
2629
gram = loss.Gram()
27-
gram = gram.cuda()
30+
if torch.cuda.is_available():
31+
gram = gram.cuda()
2832

2933
i = 1
3034
for layer in cnn:

0 commit comments

Comments
 (0)