Skip to content
Open
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Ported to keras v3
  • Loading branch information
di-kan committed Sep 7, 2024
commit 91a7c76064f68ce6c2b5841e53dfc3a59c0c06a3
29 changes: 9 additions & 20 deletions examples/vision/shiftvit.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,24 +27,15 @@
In this example, we minimally implement the paper with close alignement to the author's
[official implementation](https://github.com/microsoft/SPACH/blob/main/models/shiftvit.py).

This example requires TensorFlow 2.9 or higher, as well as TensorFlow Addons, which can
be installed using the following command:
"""
"""shell
pip install -qq -U tensorflow-addons
"""

"""
## Setup and imports
"""

import numpy as np
import matplotlib.pyplot as plt

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import tensorflow_addons as tfa
import keras
from keras import layers

import pathlib
import glob
Expand Down Expand Up @@ -280,7 +271,7 @@ def __init__(self, drop_path_prob, **kwargs):
def call(self, x, training=False):
if training:
keep_prob = 1 - self.drop_path_prob
shape = (tf.shape(x)[0],) + (1,) * (len(tf.shape(x)) - 1)
shape = (tf.shape(x)[0],) + (1,) * (len(x.shape) - 1)
random_tensor = keep_prob + tf.random.uniform(shape, 0, 1)
random_tensor = tf.floor(random_tensor)
return (x / keep_prob) * random_tensor
Expand Down Expand Up @@ -871,7 +862,7 @@ def get_config(self):
)

# Get the optimizer.
optimizer = tfa.optimizers.AdamW(
optimizer = keras.optimizers.AdamW(
learning_rate=scheduled_lrs, weight_decay=config.weight_decay
)

Expand Down Expand Up @@ -913,7 +904,7 @@ def get_config(self):

It can be saved in TF SavedModel format only. In general, this is the recommended format for saving models as well.
"""
model.save("ShiftViT")
model.export("ShiftViT")

"""
## Model inference
Expand All @@ -932,12 +923,10 @@ def get_config(self):
"""
**Load saved model**
"""
# Custom objects are not included when the model is saved.
# At loading time, these objects need to be passed for reconstruction of the model
saved_model = tf.keras.models.load_model(
"ShiftViT",
custom_objects={"WarmUpCosine": WarmUpCosine, "AdamW": tfa.optimizers.AdamW},
)
saved_layer = keras.layers.TFSMLayer("ShiftViT")
inputs = tf.keras.Input(shape=(config.input_shape)) # specify your input shape
outputs = saved_layer(inputs)
saved_model = tf.keras.Model(inputs, outputs)

"""
**Utility functions for inference**
Expand Down