diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..9e3d04c --- /dev/null +++ b/.gitignore @@ -0,0 +1 @@ +venv* diff --git a/README.md b/README.md index 243dd95..52a742e 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,6 @@ # Implementation of CutPaste -This is a **unofficial** work in progress PyTorch reimplementation of [CutPaste: Self-Supervised Learning for Anomaly Detection and Localization](https://arxiv.org/abs/2104.04015) and in no way affiliated with the original authors. Use at own risk. Pull requestes and feedback is appreciated. +This is an **unofficial** work in progress PyTorch reimplementation of [CutPaste: Self-Supervised Learning for Anomaly Detection and Localization](https://arxiv.org/abs/2104.04015) and in no way affiliated with the original authors. Use at your own risk. Pull requests and feedback are appreciated. ## Setup Download the MVTec Anomaly detection Dataset from [here](https://www.mvtec.com/company/research/datasets/mvtec-ad) and extract it into a new folder named `Data`. @@ -25,6 +25,11 @@ python run_training.py --model_dir models --head_layer 2 ``` The Script will train a model for each defect type and save it in the `model_dir` Folder. +To enable training on an Nvidia GPU use the `--cuda 1` flag. +``` +python run_training.py --model_dir models --head_layer 2 --cuda 1 +``` + One can track the training progress of the models with tensorboard: ``` tensorboard --logdir logdirs @@ -36,13 +41,130 @@ python eval.py --model_dir models --head_layer 2 ``` This will create a new directory `Eval` with plots for each defect type/model. -## Some implementation details -Only the normal CutPaste augmentation and 2-Class classification variant is implemented. +# Implementation details +### CutPaste Location The pasted image patch always origins from the same image it is pasted to. I'm not sure if this is a Problem and if this is also the case in the original paper/code. +### Epochs +Li et al.
define "256 parameter update steps" as one epoch. The `--epochs` parameter takes the number of update steps, not the number of such epochs. + +### Batch Size +Li et al. use a "batch size of 64 (or 96 for 3-way)". Because the number of images fed into the model changes from the normal to the 3-way variant, I suspect that they always start with 32 images that get augmented. The `--batch_size` parameter specifies the number of images read from disk. So for all variants `--batch_size=32` should correspond to the batch size used by Li et al. + +### Projection head +I did not find a model description of the projection head Li et al. use. +The `--head_layer` parameter varies the number of layers used in this implementation. +In total `head_layer + 2` fully connected layers are used: +`head_layer` layers with 512 neurons, followed by a layer with 128 neurons and an output layer with 2 or 3 neurons. The number of output neurons depends on the variant: 2 for `normal` and `scar`, 3 for `3way`. + +### Augmentations used before CutPaste +Li et al. "apply random translation and +color jitters for data augmentation". +This implementation only applies color jitter before the CutPaste augmentation. I tried to use [torchvision.transforms.RandomResizedCrop](https://pytorch.org/vision/stable/transforms.html#torchvision.transforms.RandomResizedCrop) as translation, but in a brief test I did not find that it improves performance. + +### TensorFlow vs PyTorch +Li et al. use TensorFlow for their implementation. This implementation uses PyTorch. + +### Kernel Density Estimation +I implemented two anomaly scoring pipelines: Kernel Density Estimation and Mahalanobis distance. +Li et al. use sklearn for the density estimation, but [Ripple et al.](https://github.com/ORippler/gaussian-ad-mvtec) have their own implementation. +`eval.py` has a `--density` flag that can be set to `torch` for the Ripple et al. implementation or `sklearn` for my sklearn implementation. 
+In my limited testing both implementations produce nearly identical ROC AUCs: +``` +> python eval.py --density torch --cuda 1 --head_layer 2 --save_plots 0| grep AUC +bottle AUC: 0.9944444444444445 +cable AUC: 0.8549475262368815 +capsule AUC: 0.8232947746310331 +carpet AUC: 0.9329855537720706 +grid AUC: 0.982456140350877 +hazelnut AUC: 0.9160714285714285 +leather AUC: 1.0 +metal_nut AUC: 0.9403714565004888 +pill AUC: 0.8046917621385706 +screw AUC: 0.701988112318098 +tile AUC: 0.9430014430014431 +toothbrush AUC: 0.8972222222222221 +transistor AUC: 0.9008333333333334 +wood AUC: 0.9815789473684211 +zipper AUC: 0.9997373949579832 + +> python eval.py --density sklearn --cuda 1 --head_layer 2 --save_plots 0| grep AUC +bottle AUC: 0.9944444444444445 +cable AUC: 0.8549475262368815 +capsule AUC: 0.8232947746310331 +carpet AUC: 0.9329855537720706 +grid AUC: 0.982456140350877 +hazelnut AUC: 0.9160714285714285 +leather AUC: 1.0 +metal_nut AUC: 0.9403714565004888 +pill AUC: 0.8046917621385706 +screw AUC: 0.701988112318098 +tile AUC: 0.9430014430014431 +toothbrush AUC: 0.8972222222222221 +transistor AUC: 0.9008333333333334 +wood AUC: 0.9815789473684211 +zipper AUC: 0.9997373949579832 +``` + + +# Results +This implementation only tries to recreate the main results from section 4.1, shown in Table 1. +## CutPaste +``` +python run_training.py --epochs 10000 --test_epochs 32 --no-pretrained --cuda 1 --head_layer 1 --batch_size 32 --variant normal +``` +![training loss](doc/imgs/normal_loss.png) +The blue line is the raw value and the orange line is an average over 100 epochs. +![training accuracy](doc/imgs/normal_acc.png) +![validation accuracy](doc/imgs/normal_eval_auc.png) +We only compute the ROC AUC every 32nd update step; here the orange line is an average over 320 update steps (10 ROC AUC values). +Note: The validation accuracy (named test set ROC AUC) uses the Mahalanobis distance as the anomaly score. 
It cannot be directly compared with the accuracy during training. + +![comparison with Li et al.](doc/imgs/author_vs_thisimpl_CutPaste.png) + +Note that for readability, the y-axis starts at 40% ROC AUC. +## CutPaste (scar) +``` +python run_training.py --epochs 10000 --test_epochs 32 --no-pretrained --cuda 1 --head_layer 1 --batch_size 32 --variant scar +``` +![training loss](doc/imgs/scar_loss.png) +![training accuracy](doc/imgs/scar_acc.png) +![validation accuracy](doc/imgs/scar_eval_auc.png) + +![comparison with Li et al.](doc/imgs/author_vs_thisimpl_CutPaste_scar.png) +## CutPaste (3-way) +Due to limited computing resources, the evaluation during training is disabled. +``` +python run_training.py --epochs 10000 --test_epochs -1 --no-pretrained --cuda 1 --head_layer 1 --batch_size 32 --variant 3way +``` +![training loss](doc/imgs/3way_loss.png) +![training accuracy](doc/imgs/3way_acc.png) +![comparison with Li et al.](doc/imgs/author_vs_thisimpl_CutPaste_3way.png) + +# Comparison to Li et al. +| defect_type | CutPaste | Li et al. CutPaste | CutPaste (scar) | Li et al. CutPaste (scar) | CutPaste (3-way) | Li et al. 
CutPaste (3-way) | +|:--------------|-----------:|---------------------:|------------------:|----------------------------:|-------------------:|-----------------------------:| +| bottle | 99.7 | 99.2 | 97.9 | 98.0 | 99.6 | 98.3 | +| cable | 92.3 | 87.1 | 75.0 | 78.8 | 77.2 | 80.6 | +| capsule | 86.2 | 87.9 | 84.5 | 95.3 | 92.4 | 96.2 | +| carpet | 59.8 | 67.9 | 88.6 | 94.6 | 60.1 | 93.1 | +| grid | 100.0 | 99.9 | 99.9 | 95.5 | 100.0 | 99.9 | +| hazelnut | 83.7 | 91.3 | 87.5 | 96.7 | 86.8 | 97.3 | +| leather | 99.5 | 99.7 | 99.5 | 100.0 | 100.0 | 100.0 | +| metal_nut | 91.5 | 96.8 | 80.6 | 97.9 | 87.8 | 99.3 | +| pill | 89.4 | 93.4 | 78.4 | 85.8 | 91.7 | 92.4 | +| screw | 44.1 | 54.4 | 80.7 | 83.7 | 86.8 | 86.3 | +| tile | 88.7 | 95.9 | 95.3 | 89.4 | 97.2 | 93.4 | +| toothbrush | 96.7 | 99.2 | 88.3 | 96.7 | 94.7 | 98.3 | +| transistor | 95.1 | 96.4 | 86.8 | 91.1 | 93.0 | 95.5 | +| wood | 98.6 | 94.9 | 98.0 | 98.7 | 99.4 | 98.6 | +| zipper | 99.6 | 99.4 | 95.9 | 99.5 | 98.8 | 99.4 | +| average | 88.3 | 90.9 | 89.1 | 93.4 | 91.0 | 95.2 | + +![comparison with Li et al.](doc/imgs/compare_all.png) # TODOs -- [ ] implement Cut-Paste Scar +- [x] implement Cut-Paste Scar - [ ] implement gradCam - [ ] implement localization variant - [ ] add option to finetune on EfficientNet(B4) diff --git a/cutpaste.py b/cutpaste.py index a828299..ca276f6 100644 --- a/cutpaste.py +++ b/cutpaste.py @@ -102,7 +102,6 @@ def __call__(self, img): box = [from_location_w, from_location_h, from_location_w + cut_w, from_location_h + cut_h] patch = img.crop(box) - print(patch.size) if self.colorJitter: patch = self.colorJitter(patch) diff --git a/dataset.py b/dataset.py index 2cf7ea6..4447c80 100644 --- a/dataset.py +++ b/dataset.py @@ -3,8 +3,22 @@ from PIL import Image from joblib import Parallel, delayed +class Repeat(Dataset): + def __init__(self, org_dataset, new_length): + self.org_dataset = org_dataset + self.org_length = len(self.org_dataset) + self.new_length = new_length + + def 
__len__(self): + return self.new_length + + def __getitem__(self, idx): + return self.org_dataset[idx % self.org_length] + class MVTecAT(Dataset): - """Face Landmarks dataset.""" + """MVTec anomaly detection dataset. + Link: https://www.mvtec.com/company/research/datasets/mvtec-ad + """ def __init__(self, root_dir, defect_name, size, transform=None, mode="train"): """ @@ -12,7 +26,7 @@ def __init__(self, root_dir, defect_name, size, transform=None, mode="train"): root_dir (string): Directory with the MVTec AD dataset. defect_name (string): defect to load. transform: Transform to apply to data - mode: "train" loads training sammples "test" test samples default "train" + mode: "train" loads training samples, "test" loads test samples (default: "train") self.root_dir = Path(root_dir) self.defect_name = defect_name diff --git a/density.py b/density.py new file mode 100644 index 0000000..848223d --- /dev/null +++ b/density.py @@ -0,0 +1,68 @@ + +from sklearn.covariance import LedoitWolf +from sklearn.neighbors import KernelDensity +import torch + + +class Density(object): + def fit(self, embeddings): + raise NotImplementedError + + def predict(self, embeddings): + raise NotImplementedError + + +class GaussianDensityTorch(Density): + """Gaussian Density estimation similar to the implementation used by Ripple et al. + The code of Ripple et al. can be found here: https://github.com/ORippler/gaussian-ad-mvtec. + """ + def fit(self, embeddings): + self.mean = torch.mean(embeddings, axis=0) + self.inv_cov = torch.Tensor(LedoitWolf().fit(embeddings.cpu()).precision_,device="cpu") + + def predict(self, embeddings): + distances = self.mahalanobis_distance(embeddings, self.mean, self.inv_cov) + return distances + + @staticmethod + def mahalanobis_distance( + values: torch.Tensor, mean: torch.Tensor, inv_covariance: torch.Tensor + ) -> torch.Tensor: + """Compute the batched Mahalanobis distance. + values is a batch of feature vectors. 
+ mean is either the mean of the distribution to compare, or a second + batch of feature vectors. + inv_covariance is the inverse covariance of the target distribution. + + from https://github.com/ORippler/gaussian-ad-mvtec/blob/4e85fb5224eee13e8643b684c8ef15ab7d5d016e/src/gaussian/model.py#L308 + """ + assert values.dim() == 2 + assert 1 <= mean.dim() <= 2 + assert len(inv_covariance.shape) == 2 + assert values.shape[1] == mean.shape[-1] + assert mean.shape[-1] == inv_covariance.shape[0] + assert inv_covariance.shape[0] == inv_covariance.shape[1] + + if mean.dim() == 1: # Distribution mean. + mean = mean.unsqueeze(0) + x_mu = values - mean # batch x features + # Same as dist = x_mu.t() * inv_covariance * x_mu batch wise + dist = torch.einsum("im,mn,in->i", x_mu, inv_covariance, x_mu) + return dist.sqrt() + +class GaussianDensitySklearn(Density): + """Li et al. use sklearn for density estimation. + This implementation uses the sklearn KernelDensity module for fitting and predicting. + """ + def fit(self, embeddings): + # fit a kernel density estimator with a fixed bandwidth + # (no bandwidth search is performed here) + self.kde = KernelDensity(kernel='gaussian', bandwidth=1).fit(embeddings) + + def predict(self, embeddings): + scores = self.kde.score_samples(embeddings) + + # invert scores so they match the class labels for the AUC calculation + scores = -scores + + return scores diff --git a/doc/imgs/3way_acc.png b/doc/imgs/3way_acc.png new file mode 100644 index 0000000..f40da1a Binary files /dev/null and b/doc/imgs/3way_acc.png differ diff --git a/doc/imgs/3way_eval_auc.png b/doc/imgs/3way_eval_auc.png new file mode 100644 index 0000000..78ea0ae Binary files /dev/null and b/doc/imgs/3way_eval_auc.png differ diff --git a/doc/imgs/3way_loss.png b/doc/imgs/3way_loss.png new file mode 100644 index 0000000..f6cb731 Binary files /dev/null and b/doc/imgs/3way_loss.png differ diff --git a/doc/imgs/author_vs_thisimpl_CutPaste.png b/doc/imgs/author_vs_thisimpl_CutPaste.png new 
file mode 100644 index 0000000..ce76d3e Binary files /dev/null and b/doc/imgs/author_vs_thisimpl_CutPaste.png differ diff --git a/doc/imgs/author_vs_thisimpl_CutPaste_3way.png b/doc/imgs/author_vs_thisimpl_CutPaste_3way.png new file mode 100644 index 0000000..ce43842 Binary files /dev/null and b/doc/imgs/author_vs_thisimpl_CutPaste_3way.png differ diff --git a/doc/imgs/author_vs_thisimpl_CutPaste_scar.png b/doc/imgs/author_vs_thisimpl_CutPaste_scar.png new file mode 100644 index 0000000..dcee4ac Binary files /dev/null and b/doc/imgs/author_vs_thisimpl_CutPaste_scar.png differ diff --git a/doc/imgs/compare_all.png b/doc/imgs/compare_all.png new file mode 100644 index 0000000..7d12070 Binary files /dev/null and b/doc/imgs/compare_all.png differ diff --git a/doc/imgs/normal_acc.png b/doc/imgs/normal_acc.png new file mode 100644 index 0000000..849eb42 Binary files /dev/null and b/doc/imgs/normal_acc.png differ diff --git a/doc/imgs/normal_eval_auc.png b/doc/imgs/normal_eval_auc.png new file mode 100644 index 0000000..4849a14 Binary files /dev/null and b/doc/imgs/normal_eval_auc.png differ diff --git a/doc/imgs/normal_loss.png b/doc/imgs/normal_loss.png new file mode 100644 index 0000000..313d0e5 Binary files /dev/null and b/doc/imgs/normal_loss.png differ diff --git a/doc/imgs/scar_acc.png b/doc/imgs/scar_acc.png new file mode 100644 index 0000000..c59dde7 Binary files /dev/null and b/doc/imgs/scar_acc.png differ diff --git a/doc/imgs/scar_eval_auc.png b/doc/imgs/scar_eval_auc.png new file mode 100644 index 0000000..588ca41 Binary files /dev/null and b/doc/imgs/scar_eval_auc.png differ diff --git a/doc/imgs/scar_loss.png b/doc/imgs/scar_loss.png new file mode 100644 index 0000000..8355fd3 Binary files /dev/null and b/doc/imgs/scar_loss.png differ diff --git a/eval.py b/eval.py index 523585b..22009ae 100644 --- a/eval.py +++ b/eval.py @@ -1,6 +1,5 @@ from sklearn.metrics import roc_curve, auc from sklearn.manifold import TSNE -from sklearn.neighbors import KernelDensity 
from torchvision import transforms from torch.utils.data import DataLoader import torch @@ -14,9 +13,10 @@ from sklearn.utils import shuffle from sklearn.model_selection import GridSearchCV import numpy as np -from sklearn.covariance import LedoitWolf from collections import defaultdict +from density import GaussianDensitySklearn, GaussianDensityTorch import pandas as pd +from utils import str2bool test_data_eval = None test_transform = None @@ -37,7 +37,7 @@ def get_train_embeds(model, size, defect_type, transform, device): train_embed = torch.cat(train_embed) return train_embed -def eval_model(modelname, defect_type, device="cpu", save_plots=False, size=256, show_training_data=True, model=None, train_embed=None, head_layer=8): +def eval_model(modelname, defect_type, device="cpu", save_plots=False, size=256, show_training_data=True, model=None, train_embed=None, head_layer=8, density=GaussianDensityTorch()): # create test dataset global test_data_eval,test_transform, cached_type @@ -95,8 +95,8 @@ def eval_model(modelname, defect_type, device="cpu", save_plots=False, size=256, # also show some of the training data show_training_data = False if show_training_data: - #augmentation settig - # TODO: do all of this in a seperate function that we can call in training and evaluation. + #augmentation setting + # TODO: do all of this in a separate function that we can call in training and evaluation. 
# very ugly to just copy the code lol min_scale = 0.5 @@ -145,53 +145,10 @@ def eval_model(modelname, defect_type, device="cpu", save_plots=False, size=256, plot_tsne(tsne_labels, tsne_embeds, eval_dir / "tsne.png") else: eval_dir = Path("unused") - # TODO: put the GDE stuff into the Model class and do this at the end of the training - # # estemate KDE parameters - # # use grid search cross-validation to optimize the bandwidth - # params = {'bandwidth': np.logspace(-10, 10, 50)} - # grid = GridSearchCV(KernelDensity(), params) - # grid.fit(embeds) - - # print("best bandwidth: {0}".format(grid.best_estimator_.bandwidth)) - - # # use the best estimator to compute the kernel density estimate - # kde = grid.best_estimator_ - # kde = KernelDensity(kernel='gaussian', bandwidth=1).fit(train_embed) - # scores = kde.score_samples(embeds) - # print(scores) - # we get the probability to be in the correct distribution - # but our labels are inverted (1 for out of distribution) - # so we have to relabel - - # use own formulation with malanobis distance - # from https://github.com/ORippler/gaussian-ad-mvtec/blob/4e85fb5224eee13e8643b684c8ef15ab7d5d016e/src/gaussian/model.py#L308 - def mahalanobis_distance( - values: torch.Tensor, mean: torch.Tensor, inv_covariance: torch.Tensor - ) -> torch.Tensor: - """Compute the batched mahalanobis distance. - values is a batch of feature vectors. - mean is either the mean of the distribution to compare, or a second - batch of feature vectors. - inv_covariance is the inverse covariance of the target distribution. - """ - assert values.dim() == 2 - assert 1 <= mean.dim() <= 2 - assert len(inv_covariance.shape) == 2 - assert values.shape[1] == mean.shape[-1] - assert mean.shape[-1] == inv_covariance.shape[0] - assert inv_covariance.shape[0] == inv_covariance.shape[1] - - if mean.dim() == 1: # Distribution mean. 
- mean = mean.unsqueeze(0) - x_mu = values - mean # batch x features - # Same as dist = x_mu.t() * inv_covariance * x_mu batch wise - dist = torch.einsum("im,mn,in->i", x_mu, inv_covariance, x_mu) - return dist.sqrt() - # claculate mean - mean = torch.mean(train_embed, axis=0) - inv_cov = torch.Tensor(LedoitWolf().fit(train_embed.cpu()).precision_,device="cpu") - - distances = mahalanobis_distance(embeds, mean, inv_cov) + + print(f"using density estimation {density.__class__.__name__}") + density.fit(train_embed) + distances = density.predict(embeds) #TODO: set threshold on mahalanobis distances and use "real" probabilities roc_auc = plot_roc(labels, distances, eval_dir / "roc_plot.png", modelname=modelname, save_plots=save_plots) @@ -243,12 +200,17 @@ def plot_tsne(labels, embeds, filename): parser.add_argument('--model_dir', default="models", help=' directory contating models to evaluate (default: models)') - parser.add_argument('--cuda', default=False, + parser.add_argument('--cuda', default=False, type=str2bool, help='use cuda for model predictions (default: False)') parser.add_argument('--head_layer', default=8, type=int, help='number of layers in the projection head (default: 8)') + parser.add_argument('--density', default="torch", choices=["torch", "sklearn"], + help='density implementation to use. See `density.py` for both implementations. 
(default: torch)') + + parser.add_argument('--save_plots', default=True, type=str2bool, + help='save TSNE and ROC plots (default: True)') args = parser.parse_args() @@ -278,6 +240,12 @@ def plot_tsne(labels, embeds, filename): device = "cuda" if args.cuda else "cpu" + density_mapping = { + "torch": GaussianDensityTorch, + "sklearn": GaussianDensitySklearn + } + density = density_mapping[args.density] + # find models model_names = [list(Path(args.model_dir).glob(f"model-{data_type}*"))[0] for data_type in types if len(list(Path(args.model_dir).glob(f"model-{data_type}*"))) > 0] if len(model_names) < len(all_types): @@ -287,13 +255,13 @@ def plot_tsne(labels, embeds, filename): for model_name, data_type in zip(model_names, types): print(f"evaluating {data_type}") - roc_auc = eval_model(model_name, data_type, save_plots=True, device=device, head_layer=args.head_layer) + roc_auc = eval_model(model_name, data_type, save_plots=args.save_plots, device=device, head_layer=args.head_layer, density=density()) print(f"{data_type} AUC: {roc_auc}") obj["defect_type"].append(data_type) obj["roc_auc"].append(roc_auc) # save pandas dataframe - eval_dir = Path("eval") / model_name + eval_dir = Path("eval") / args.model_dir eval_dir.mkdir(parents=True, exist_ok=True) df = pd.DataFrame(obj) df.to_csv(str(eval_dir) + "_perf.csv") diff --git a/model.py b/model.py index 3a94879..018816a 100644 --- a/model.py +++ b/model.py @@ -10,7 +10,7 @@ def __init__(self, pretrained=True, head_layers=[512,512,512,512,512,512,512,512 #self.resnet18 = torch.hub.load('pytorch/vision:v0.9.0', 'resnet18', pretrained=pretrained) self.resnet18 = resnet18(pretrained=pretrained) - # create MPL head as seen in the code in: https://github.com/uoguelph-mlrg/Cutout/blob/master/util/cutout.py + # create MLP head as seen in the code in: https://github.com/uoguelph-mlrg/Cutout/blob/master/util/cutout.py # TODO: check if this is really the right architecture last_layer = 512 sequential_layers = [] @@ 
-21,9 +21,6 @@ def __init__(self, pretrained=True, head_layers=[512,512,512,512,512,512,512,512 last_layer = num_neurons #the last layer without activation - #TODO: is this correct? check one classe representation framework paper/code - # sequential_layers.append(nn.Linear(last_layer, head_layers[-1])) - # last_layer = head_layers[-1] head = nn.Sequential( *sequential_layers diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..74c1afb --- /dev/null +++ b/requirements.txt @@ -0,0 +1,7 @@ +torch +torchvision +scikit-learn +pandas +seaborn +tqdm +tensorboard diff --git a/run_training.py b/run_training.py index cb6c5a8..c5c514e 100644 --- a/run_training.py +++ b/run_training.py @@ -13,10 +13,11 @@ from torchvision import transforms -from dataset import MVTecAT +from dataset import MVTecAT, Repeat from cutpaste import CutPasteNormal,CutPasteScar, CutPaste3Way, CutPasteUnion, cut_paste_collate_fn from model import ProjectionNet from eval import eval_model +from utils import str2bool def run_training(data_type="screw", model_dir="models", @@ -43,7 +44,7 @@ def run_training(data_type="screw", model_name = f"model-{data_type}" + '-{date:%Y-%m-%d_%H_%M_%S}'.format(date=datetime.datetime.now() ) #augmentation: - min_scale = 0.5 + min_scale = 1 # create Training Dataset and Dataloader after_cutpaste_transform = transforms.Compose([]) @@ -52,14 +53,15 @@ def run_training(data_type="screw", std=[0.229, 0.224, 0.225])) train_transform = transforms.Compose([]) - # train_transform.transforms.append(transforms.RandomResizedCrop(size, scale=(min_scale,1))) +#train_transform.transforms.append(transforms.RandomResizedCrop(size, scale=(min_scale,1))) + train_transform.transforms.append(transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1)) # train_transform.transforms.append(transforms.GaussianBlur(int(size/10), sigma=(0.1,2.0))) train_transform.transforms.append(transforms.Resize((size,size))) 
train_transform.transforms.append(cutpate_type(transform = after_cutpaste_transform)) # train_transform.transforms.append(transforms.ToTensor()) train_data = MVTecAT("Data", data_type, transform = train_transform, size=int(size * (1/min_scale))) - dataloader = DataLoader(train_data, batch_size=batch_size, drop_last=True, + dataloader = DataLoader(Repeat(train_data, 3000), batch_size=batch_size, drop_last=True, shuffle=True, num_workers=workers, collate_fn=cut_paste_collate_fn, persistent_workers=True, pin_memory=True, prefetch_factor=5) @@ -203,7 +205,7 @@ def get_data_inf(): parser.add_argument('--variant', default="3way", choices=['normal', 'scar', '3way', 'union'], help='cutpaste variant to use (dafault: "3way")') - parser.add_argument('--cuda', default=False, + parser.add_argument('--cuda', default=False, type=str2bool, help='use cuda for training (default: False)') parser.add_argument('--workers', default=8, type=int, help="number of workers to use for data loading (default:8)") diff --git a/utils.py b/utils.py new file mode 100644 index 0000000..d925fa0 --- /dev/null +++ b/utils.py @@ -0,0 +1,15 @@ +import argparse + + +def str2bool(v): + """argparse handles type=bool in a weird way. + See this stack overflow: https://stackoverflow.com/questions/15008758/parsing-boolean-values-with-argparse + we can use this function as a type converter for boolean values + """ + if isinstance(v, bool): + return v + if v.lower() in ('yes', 'true', 't', 'y', '1'): + return True + elif v.lower() in ('no', 'false', 'f', 'n', '0'): + return False + else: + raise argparse.ArgumentTypeError('Boolean value expected.') \ No newline at end of file
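The anomaly score added in `density.py` boils down to a batched quadratic form: for each embedding, the Mahalanobis distance from the training mean under the estimated inverse covariance. As a standalone sketch of that math (plain Python lists instead of torch tensors, with made-up toy vectors; `torch.einsum("im,mn,in->i", ...)` in the diff computes the same quantity):

```python
import math

def mahalanobis_distance(values, mean, inv_covariance):
    """values: batch of feature vectors, mean: distribution mean,
    inv_covariance: inverse covariance matrix of the target distribution."""
    distances = []
    for row in values:
        x_mu = [v - m for v, m in zip(row, mean)]  # center the vector
        # quadratic form x_mu^T * inv_covariance * x_mu
        dist = sum(
            x_mu[i] * inv_covariance[i][j] * x_mu[j]
            for i in range(len(x_mu))
            for j in range(len(x_mu))
        )
        distances.append(math.sqrt(dist))
    return distances

# with the identity as inverse covariance the score reduces to the
# Euclidean distance from the mean
values = [[3.0, 4.0], [0.0, 0.0]]
mean = [0.0, 0.0]
identity = [[1.0, 0.0], [0.0, 1.0]]
print(mahalanobis_distance(values, mean, identity))  # → [5.0, 0.0]
```

A non-identity inverse covariance (as produced by the Ledoit-Wolf estimator in the diff) rescales and de-correlates the feature axes before measuring the distance.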