
Commit 5842b6d: upload
1 parent 5b59ff4 commit 5842b6d

File tree

20 files changed: +805 -0 lines changed
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
Lines changed: 2 additions & 0 deletions
@@ -0,0 +1,2 @@
# neural-transfer
This is my implementation of neural-transfer, following http://pytorch.org/tutorials/advanced/neural_style_tutorial.html#sphx-glr-advanced-neural-style-tutorial-py
Lines changed: 57 additions & 0 deletions
@@ -0,0 +1,57 @@
import torch.nn as nn
import torchvision.models as models

import loss

# pretrained VGG-19 feature extractor used as the backbone for style transfer
vgg = models.vgg19(pretrained=True).features
vgg = vgg.cuda()

content_layers_default = ['conv_4']
style_layers_default = ['conv_1', 'conv_2', 'conv_3', 'conv_4', 'conv_5']


def get_style_model_and_loss(style_img,
                             content_img,
                             cnn=vgg,
                             style_weight=1000,
                             content_weight=1,
                             content_layers=content_layers_default,
                             style_layers=style_layers_default):

    content_loss_list = []
    style_loss_list = []

    model = nn.Sequential()
    model = model.cuda()
    gram = loss.Gram()
    gram = gram.cuda()

    i = 1
    for layer in cnn:
        if isinstance(layer, nn.Conv2d):
            name = 'conv_' + str(i)
            model.add_module(name, layer)

            if name in content_layers:
                # the content target is this layer's response to the content image
                target = model(content_img)
                content_loss = loss.Content_Loss(target, content_weight)
                model.add_module('content_loss_' + str(i), content_loss)
                content_loss_list.append(content_loss)

            if name in style_layers:
                # the style target is the Gram matrix of this layer's response to the style image
                target = model(style_img)
                target = gram(target)
                style_loss = loss.Style_Loss(target, style_weight)
                model.add_module('style_loss_' + str(i), style_loss)
                style_loss_list.append(style_loss)

            i += 1
        if isinstance(layer, nn.MaxPool2d):
            name = 'pool_' + str(i)
            model.add_module(name, layer)

        if isinstance(layer, nn.ReLU):
            name = 'relu_' + str(i)
            model.add_module(name, layer)

    return model, style_loss_list, content_loss_list
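
The hunk above only assembles the network with the loss modules; the optimization loop that drives the transfer is not part of this file. As a rough, non-authoritative sketch of how the returned model and loss lists might be wired together (the optimizer choice, step count, load_img import path, and requires_grad_ call are assumptions made here, following the linked tutorial rather than this commit):

import torch.optim as optim

from load_img import load_img  # hypothetical module name for the image helpers shown further below


def run_style_transfer(content_path, style_path, num_steps=300):
    content_img = load_img(content_path).cuda()
    style_img = load_img(style_path).cuda()
    # optimize the pixels of a copy of the content image
    input_img = content_img.clone().requires_grad_(True)

    model, style_loss_list, content_loss_list = get_style_model_and_loss(style_img, content_img)
    optimizer = optim.LBFGS([input_img])

    for _ in range(num_steps):
        def closure():
            optimizer.zero_grad()
            model(input_img)  # each inserted loss module records its .loss during this pass
            total = sum(sl.loss for sl in style_loss_list) + sum(cl.loss for cl in content_loss_list)
            total.backward()
            return total
        optimizer.step(closure)

    return input_img.detach()

This sketch sums the recorded losses and backpropagates once per closure call; the loss classes in loss.py also expose their own backward helpers, which would be an equivalent way to accumulate the gradients.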

chapter8_Application/neural-transfer/demo.ipynb

Lines changed: 118 additions & 0 deletions
Large diffs are not rendered by default.
Lines changed: 18 additions & 0 deletions
@@ -0,0 +1,18 @@
import PIL.Image as Image
import torchvision.transforms as transforms

img_size = 512


def load_img(img_path):
    img = Image.open(img_path).convert('RGB')
    img = img.resize((img_size, img_size))
    img = transforms.ToTensor()(img)
    img = img.unsqueeze(0)
    return img


def show_img(img):
    img = img.squeeze(0)
    img = transforms.ToPILImage()(img)
    img.show()
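
These helpers resize every image to a fixed 512 x 512 square and add a batch dimension so the tensor can be fed straight into the VGG model. A short usage example (the image paths are placeholders, not files included in this commit):

content_img = load_img('images/content.jpg')   # FloatTensor of shape (1, 3, 512, 512), values in [0, 1]
style_img = load_img('images/style.jpg')

show_img(content_img)  # converts back to a PIL image and opens it with the default viewer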
Lines changed: 51 additions & 0 deletions
@@ -0,0 +1,51 @@
import torch.nn as nn
import torch


class Content_Loss(nn.Module):
    def __init__(self, target, weight):
        super(Content_Loss, self).__init__()
        self.weight = weight
        # target must be detached so it is treated as a fixed reference value outside the graph;
        # otherwise gradients would be tracked through it and the forward pass would fail
        self.target = target.detach() * self.weight
        self.criterion = nn.MSELoss()

    def forward(self, input):
        self.loss = self.criterion(input * self.weight, self.target)
        out = input.clone()
        return out

    def backward(self, retain_graph=True):
        self.loss.backward(retain_graph=retain_graph)
        return self.loss


class Gram(nn.Module):
    def __init__(self):
        super(Gram, self).__init__()

    def forward(self, input):
        a, b, c, d = input.size()
        feature = input.view(a * b, c * d)
        gram = torch.mm(feature, feature.t())
        # normalize by the total number of elements in the feature map
        gram /= (a * b * c * d)
        return gram


class Style_Loss(nn.Module):
    def __init__(self, target, weight):
        super(Style_Loss, self).__init__()
        self.weight = weight
        self.target = target.detach() * self.weight
        self.gram = Gram()
        self.criterion = nn.MSELoss()

    def forward(self, input):
        G = self.gram(input) * self.weight
        self.loss = self.criterion(G, self.target)
        out = input.clone()
        return out

    def backward(self, retain_graph=True):
        self.loss.backward(retain_graph=retain_graph)
        return self.loss
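
For intuition, the Gram module turns a single image's feature map of shape (1, channels, height, width) into a channels x channels matrix of inner products between the flattened channel responses, normalized by the total number of elements; Style_Loss then compares this matrix against the one precomputed from the style image. A quick illustrative check (not part of the commit):

import torch

gram = Gram()
features = torch.rand(1, 8, 16, 16)    # fake feature map: 1 image, 8 channels, 16x16 spatial
g = gram(features)
print(g.shape)                          # torch.Size([8, 8]): one inner product per channel pair
print(torch.allclose(g, g.t()))         # True, since the Gram matrix is symmetric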
