Skip to content

Commit f50a888

Browse files
committed
Add Neural Style Transfer implementation
1 parent 9336c58 commit f50a888

File tree

1 file changed

+162
-0
lines changed

1 file changed

+162
-0
lines changed

Chapter12/Part_03/chapter_12.py

Lines changed: 162 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,162 @@
1+
import numpy as np
2+
import tensorflow as tf
3+
from tensorflow import keras
4+
import matplotlib.pyplot as plt
5+
6+
7+
class NeuralStyleTransfer:
8+
def __init__(self, base_image_url, style_image_url, img_height=400, style_weight=1e-6, content_weight=2.5e-8, tv_weight=1e-6):
9+
"""
10+
Initializes the Neural Style Transfer model with given parameters.
11+
12+
- base_image_url: URL of the base image
13+
- style_image_url: URL of the style reference image
14+
- img_height: Height of the processed images
15+
- style_weight: Weight for style loss
16+
- content_weight: Weight for content loss
17+
- tv_weight: Weight for total variation loss
18+
"""
19+
self.base_image_path = keras.utils.get_file("base.jpg", origin=base_image_url)
20+
self.style_reference_image_path = keras.utils.get_file("style.jpg", origin=style_image_url)
21+
22+
original_width, original_height = keras.utils.load_img(self.base_image_path).size
23+
self.img_height = img_height
24+
self.img_width = round(original_width * img_height / original_height)
25+
26+
self.style_weight = style_weight
27+
self.content_weight = content_weight
28+
self.tv_weight = tv_weight
29+
30+
self.model = keras.applications.vgg19.VGG19(weights="imagenet", include_top=False)
31+
self.feature_extractor = self.build_feature_extractor()
32+
33+
self.style_layer_names = [
34+
"block1_conv1",
35+
"block2_conv1",
36+
"block3_conv1",
37+
"block4_conv1",
38+
"block5_conv1",
39+
]
40+
self.content_layer_name = "block5_conv2"
41+
42+
def preprocess_image(self, image_path):
43+
"""
44+
Prepares an image for use with the VGG19 model.
45+
"""
46+
img = keras.utils.load_img(image_path, target_size=(self.img_height, self.img_width))
47+
img = keras.utils.img_to_array(img)
48+
img = np.expand_dims(img, axis=0)
49+
img = keras.applications.vgg19.preprocess_input(img)
50+
return img
51+
52+
def deprocess_image(self, img):
53+
"""
54+
Converts a processed image back to a viewable format.
55+
"""
56+
img = img.reshape((self.img_height, self.img_width, 3))
57+
img[:, :, 0] += 103.939
58+
img[:, :, 1] += 116.779
59+
img[:, :, 2] += 123.68
60+
img = img[:, :, ::-1]
61+
img = np.clip(img, 0, 255).astype("uint8")
62+
return img
63+
64+
def build_feature_extractor(self):
65+
"""
66+
Builds a feature extractor model from VGG19.
67+
"""
68+
outputs_dict = {layer.name: layer.output for layer in self.model.layers}
69+
return keras.Model(inputs=self.model.inputs, outputs=outputs_dict)
70+
71+
@staticmethod
72+
def content_loss(base_img, combination_img):
73+
return tf.reduce_sum(tf.square(combination_img - base_img))
74+
75+
@staticmethod
76+
def gram_matrix(x):
77+
x = tf.transpose(x, (2, 0, 1))
78+
features = tf.reshape(x, (tf.shape(x)[0], -1))
79+
return tf.matmul(features, tf.transpose(features))
80+
81+
def style_loss(self, style_img, combination_img):
82+
"""
83+
Computes the style loss using Gram matrices.
84+
"""
85+
S = self.gram_matrix(style_img)
86+
C = self.gram_matrix(combination_img)
87+
channels = 3
88+
size = self.img_height * self.img_width
89+
return tf.reduce_sum(tf.square(S - C)) / (4.0 * (channels ** 2) * (size ** 2))
90+
91+
def total_variation_loss(self, x):
92+
"""
93+
Computes the total variation loss for smoothness.
94+
"""
95+
a = tf.square(x[:, : self.img_height - 1, : self.img_width - 1, :] - x[:, 1:, : self.img_width - 1, :])
96+
b = tf.square(x[:, : self.img_height - 1, : self.img_width - 1, :] - x[:, : self.img_height - 1, 1:, :])
97+
return tf.reduce_sum(tf.pow(a + b, 1.25))
98+
99+
def compute_loss(self, combination_image, base_image, style_reference_image):
100+
"""
101+
Computes the total loss for optimization.
102+
"""
103+
input_tensor = tf.concat([base_image, style_reference_image, combination_image], axis=0)
104+
features = self.feature_extractor(input_tensor)
105+
106+
loss = tf.zeros(shape=())
107+
layer_features = features[self.content_layer_name]
108+
base_image_features = layer_features[0, :, :, :]
109+
combination_features = layer_features[2, :, :, :]
110+
loss += self.content_weight * self.content_loss(base_image_features, combination_features)
111+
112+
for layer_name in self.style_layer_names:
113+
layer_features = features[layer_name]
114+
style_reference_features = layer_features[1, :, :, :]
115+
combination_features = layer_features[2, :, :, :]
116+
loss += (self.style_weight / len(self.style_layer_names)) * self.style_loss(style_reference_features, combination_features)
117+
118+
loss += self.tv_weight * self.total_variation_loss(combination_image)
119+
return loss
120+
121+
@tf.function
122+
def compute_loss_and_grads(self, combination_image, base_image, style_reference_image):
123+
"""
124+
Computes gradients for optimization.
125+
"""
126+
with tf.GradientTape() as tape:
127+
loss = self.compute_loss(combination_image, base_image, style_reference_image)
128+
grads = tape.gradient(loss, combination_image)
129+
return loss, grads
130+
131+
def train(self, iterations=4000, learning_rate=100.0, decay_steps=100, decay_rate=0.96):
132+
"""
133+
Trains the neural style transfer model.
134+
"""
135+
optimizer = keras.optimizers.SGD(
136+
keras.optimizers.schedules.ExponentialDecay(
137+
initial_learning_rate=learning_rate, decay_steps=decay_steps, decay_rate=decay_rate
138+
)
139+
)
140+
141+
base_image = self.preprocess_image(self.base_image_path)
142+
style_reference_image = self.preprocess_image(self.style_reference_image_path)
143+
combination_image = tf.Variable(self.preprocess_image(self.base_image_path))
144+
145+
for i in range(1, iterations + 1):
146+
loss, grads = self.compute_loss_and_grads(combination_image, base_image, style_reference_image)
147+
optimizer.apply_gradients([(grads, combination_image)])
148+
149+
if i % 100 == 0:
150+
print(f"Iteration {i}: loss={loss:.2f}")
151+
img = self.deprocess_image(combination_image.numpy())
152+
fname = f"combination_image_at_iteration_{i}.png"
153+
keras.utils.save_img(fname, img)
154+
155+
print("Training complete! Final image saved.")
156+
157+
158+
nst = NeuralStyleTransfer(
159+
base_image_url="https://img-datasets.s3.amazonaws.com/sf.jpg",
160+
style_image_url="https://img-datasets.s3.amazonaws.com/starry_night.jpg"
161+
)
162+
nst.train()

0 commit comments

Comments
 (0)