1
+ import numpy as np
2
+ import tensorflow as tf
3
+ from tensorflow import keras
4
+ import matplotlib .pyplot as plt
5
+
6
+
7
class NeuralStyleTransfer:
    """Neural style transfer driven by a pretrained VGG19 network.

    A "combination" image, initialised from the base image, is optimised
    with SGD so that it minimises a weighted sum of three terms:

    * a content loss against the base image (layer ``block5_conv2``),
    * a style loss against the style reference image (Gram matrices of the
      five ``blockN_conv1`` layers),
    * a total variation loss on the combination image itself.
    """

    # Offsets added back to channels 0, 1, 2 by ``deprocess_image`` before
    # the channel axis is reversed; they undo the zero-centering performed
    # by ``keras.applications.vgg19.preprocess_input``.
    _CHANNEL_OFFSETS = (103.939, 116.779, 123.68)

    def __init__(self, base_image_url, style_image_url, img_height=400,
                 style_weight=1e-6, content_weight=2.5e-8, tv_weight=1e-6):
        """
        Download both images and build the VGG19 feature extractor.

        - base_image_url: URL of the base (content) image
        - style_image_url: URL of the style reference image
        - img_height: Height, in pixels, of the processed images; the width
          is derived from the base image's aspect ratio
        - style_weight: Weight for style loss
        - content_weight: Weight for content loss
        - tv_weight: Weight for total variation loss
        """
        # keras.utils.get_file caches by file name, so the name must depend
        # on the URL; a fixed name would hand back a stale image whenever a
        # different URL is requested later.
        self.base_image_path = keras.utils.get_file(
            self._cache_name("base", base_image_url), origin=base_image_url
        )
        self.style_reference_image_path = keras.utils.get_file(
            self._cache_name("style", style_image_url), origin=style_image_url
        )

        # PIL reports size as (width, height).
        original_width, original_height = keras.utils.load_img(self.base_image_path).size
        self.img_height = img_height
        self.img_width = round(original_width * img_height / original_height)

        self.style_weight = style_weight
        self.content_weight = content_weight
        self.tv_weight = tv_weight

        self.style_layer_names = [
            "block1_conv1",
            "block2_conv1",
            "block3_conv1",
            "block4_conv1",
            "block5_conv1",
        ]
        self.content_layer_name = "block5_conv2"

        self.model = keras.applications.vgg19.VGG19(weights="imagenet", include_top=False)
        self.feature_extractor = self.build_feature_extractor()

    @staticmethod
    def _cache_name(prefix, url):
        """Return a cache file name for *url* that is unique per URL."""
        import hashlib

        digest = hashlib.sha256(str(url).encode("utf-8")).hexdigest()[:16]
        return f"{prefix}_{digest}.jpg"

    def preprocess_image(self, image_path):
        """
        Load an image file as a VGG19-ready batch of one.

        The image is resized to (img_height, img_width), converted to an
        array, given a leading batch axis and passed through
        ``vgg19.preprocess_input``.
        """
        img = keras.utils.load_img(image_path, target_size=(self.img_height, self.img_width))
        img = keras.utils.img_to_array(img)
        img = np.expand_dims(img, axis=0)
        return keras.applications.vgg19.preprocess_input(img)

    def deprocess_image(self, img):
        """
        Convert a processed image back to a viewable uint8 array.

        *img* must hold img_height * img_width * 3 values. The input is never
        modified: all arithmetic happens on a floating-point copy. Returns an
        array of shape (img_height, img_width, 3) with values in 0..255.
        """
        pixels = np.array(img, copy=True)
        if not np.issubdtype(pixels.dtype, np.floating):
            # In-place float additions are not defined for integer arrays.
            pixels = pixels.astype("float32")
        pixels = pixels.reshape((self.img_height, self.img_width, 3))
        for channel, offset in enumerate(self._CHANNEL_OFFSETS):
            pixels[:, :, channel] += offset
        pixels = pixels[:, :, ::-1]  # reverse the channel order
        return np.clip(pixels, 0, 255).astype("uint8")

    def build_feature_extractor(self):
        """
        Build a model mapping an input batch to every VGG19 layer output.

        The returned model yields a dict keyed by layer name.
        """
        outputs_dict = {layer.name: layer.output for layer in self.model.layers}
        return keras.Model(inputs=self.model.inputs, outputs=outputs_dict)

    @staticmethod
    def content_loss(base_img, combination_img):
        """Return the sum of squared differences between two feature maps."""
        return tf.reduce_sum(tf.square(combination_img - base_img))

    @staticmethod
    def gram_matrix(x):
        """
        Return the Gram matrix of a rank-3 feature map.

        The last axis is moved to the front and the remaining axes are
        flattened, so the result is (last_dim, last_dim).
        """
        x = tf.transpose(x, (2, 0, 1))
        features = tf.reshape(x, (tf.shape(x)[0], -1))
        return tf.matmul(features, tf.transpose(features))

    def style_loss(self, style_img, combination_img):
        """
        Compute the style loss between two feature maps via Gram matrices.
        """
        S = self.gram_matrix(style_img)
        C = self.gram_matrix(combination_img)
        # NOTE(review): the normalisation uses the image's 3 channels and its
        # pixel count, not the feature map's own dimensions. Kept unchanged
        # because the default style_weight is tuned against this scale.
        channels = 3
        size = self.img_height * self.img_width
        return tf.reduce_sum(tf.square(S - C)) / (4.0 * (channels ** 2) * (size ** 2))

    def total_variation_loss(self, x):
        """
        Compute the total variation loss of a rank-4 image batch *x*.

        Squared differences between each pixel and its neighbour along axis 1
        and along axis 2 are summed, raised to the power 1.25 and reduced.
        """
        core = x[:, :-1, :-1, :]
        a = tf.square(core - x[:, 1:, :-1, :])
        b = tf.square(core - x[:, :-1, 1:, :])
        return tf.reduce_sum(tf.pow(a + b, 1.25))

    def compute_loss(self, combination_image, base_image, style_reference_image):
        """
        Compute the total (content + style + total variation) loss.

        The three images are stacked into one batch in the order
        base, style, combination, so index 0/1/2 of every layer output
        selects the corresponding image's features.
        """
        input_tensor = tf.concat([base_image, style_reference_image, combination_image], axis=0)
        features = self.feature_extractor(input_tensor)

        loss = tf.zeros(shape=())

        content_features = features[self.content_layer_name]
        loss += self.content_weight * self.content_loss(
            content_features[0, :, :, :], content_features[2, :, :, :]
        )

        # The style weight is shared equally between the style layers.
        per_layer_weight = self.style_weight / len(self.style_layer_names)
        for layer_name in self.style_layer_names:
            layer_features = features[layer_name]
            loss += per_layer_weight * self.style_loss(
                layer_features[1, :, :, :], layer_features[2, :, :, :]
            )

        loss += self.tv_weight * self.total_variation_loss(combination_image)
        return loss

    def _loss_and_grads(self, combination_image, base_image, style_reference_image):
        """Eager implementation behind ``compute_loss_and_grads``."""
        with tf.GradientTape() as tape:
            loss = self.compute_loss(combination_image, base_image, style_reference_image)
        grads = tape.gradient(loss, combination_image)
        return loss, grads

    def compute_loss_and_grads(self, combination_image, base_image, style_reference_image):
        """
        Return ``(loss, grads)`` where grads is d(loss)/d(combination_image).

        The computation runs through ``tf.function``. The wrapper is created
        on first use rather than with a class-level decorator, so defining
        the class itself makes no TensorFlow call.
        """
        compiled = getattr(self, "_compiled_loss_and_grads", None)
        if compiled is None:
            compiled = tf.function(self._loss_and_grads)
            self._compiled_loss_and_grads = compiled
        return compiled(combination_image, base_image, style_reference_image)

    def train(self, iterations=4000, learning_rate=100.0, decay_steps=100, decay_rate=0.96,
              save_every=100):
        """
        Optimise the combination image and write snapshots to disk.

        - iterations: Number of SGD steps
        - learning_rate: Initial learning rate of the exponential decay schedule
        - decay_steps: Decay steps of the schedule
        - decay_rate: Decay rate of the schedule
        - save_every: Snapshot interval in steps; a falsy value keeps only
          the final snapshot

        A snapshot named ``combination_image_at_iteration_<step>.png`` is
        saved every ``save_every`` steps and always after the last step.
        Returns the final image as a uint8 array of shape
        (img_height, img_width, 3). With ``iterations < 1`` nothing is
        optimised or saved and the unmodified base image is returned.
        """
        schedule = keras.optimizers.schedules.ExponentialDecay(
            initial_learning_rate=learning_rate, decay_steps=decay_steps, decay_rate=decay_rate
        )
        optimizer = keras.optimizers.SGD(schedule)

        base_image = self.preprocess_image(self.base_image_path)
        style_reference_image = self.preprocess_image(self.style_reference_image_path)
        # The variable copies the array, so the base image only needs to be
        # decoded from disk once.
        combination_image = tf.Variable(base_image)

        snapshot = None
        for step in range(1, iterations + 1):
            loss, grads = self.compute_loss_and_grads(
                combination_image, base_image, style_reference_image
            )
            optimizer.apply_gradients([(grads, combination_image)])

            is_last = step == iterations
            if is_last or (save_every and step % save_every == 0):
                print(f"Iteration {step}: loss={float(loss):.2f}")
                snapshot = self.deprocess_image(combination_image.numpy())
                fname = f"combination_image_at_iteration_{step}.png"
                keras.utils.save_img(fname, snapshot)

        if snapshot is None:
            snapshot = self.deprocess_image(combination_image.numpy())

        print("Training complete! Final image saved.")
        return snapshot
156
+
157
+
158
# Run the demo only when executed as a script: constructing the model
# downloads images and weights, and training is long-running, so neither
# should happen as a side effect of importing this module.
if __name__ == "__main__":
    nst = NeuralStyleTransfer(
        base_image_url="https://img-datasets.s3.amazonaws.com/sf.jpg",
        style_image_url="https://img-datasets.s3.amazonaws.com/starry_night.jpg",
    )
    nst.train()
0 commit comments