Skip to content
This repository was archived by the owner on Jan 5, 2024. It is now read-only.
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Test that dummy pytorch model passes checks (currently does not)
  • Loading branch information
ggaziv committed Jan 11, 2022
commit ead336fc598fe79385f65b23a7f8b3e67c20eb68
33 changes: 33 additions & 0 deletions tests/models/test_dummy_pytorch_model_submit.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import functools

import torchvision.models
from model_tools.activations.pytorch import PytorchWrapper
from model_tools.activations.pytorch import load_preprocess_images
from model_tools.check_submission import check_models


def get_model_list():
return ['dummy_model']


def get_model(name):
assert name == 'dummy_model'
from torch import nn
model = nn.Sequential(nn.Conv2d(3, 3, 3))
preprocessing = functools.partial(load_preprocess_images, image_size=224)
wrapper = PytorchWrapper(identifier='alexnet', model=model, preprocessing=preprocessing)
wrapper.image_size = 224
return wrapper


def get_layers(name):
assert name == 'dummy_model'
return ['0']


def get_bibtex(model_identifier):
return """Dummy Model"""


def test_dummy_pytorch_model_submit():
check_models.check_base_models(__name__)