Commit 3a9d4f4

Author xyliao committed:
finish fine tune trainer
1 parent 3534e1c commit 3a9d4f4

2 files changed: 210 additions, 0 deletions

Lines changed: 53 additions & 0 deletions
# encoding: utf-8
"""
@author: xyliao

"""
import warnings
from pprint import pprint


class DefaultConfig(object):
    model = 'resnet50'

    # Dataset.
    train_data_path = './hymenoptera_data/train/'
    test_data_path = './hymenoptera_data/val/'

    # Where to store results and saved models.
    # result_file = 'result.txt'
    save_file = './checkpoints/'
    save_freq = 30  # save the model every N epochs
    save_best = True  # whether to save the model with the best test metric

    # Visualization of results on TensorBoard.
    # vis_dir = './vis/'
    plot_freq = 100  # plot to TensorBoard every N iterations

    # Model hyperparameters.
    use_gpu = True  # whether to use the GPU
    ctx = 0  # which CUDA device to run on
    batch_size = 64  # batch size
    num_workers = 4  # how many worker processes for loading data
    max_epoch = 30
    lr = 1e-2  # initial learning rate
    momentum = 0
    weight_decay = 1e-4
    lr_decay = 0.95
    # lr_decay_freq = 10

    def _parse(self, kwargs):
        # Override default attributes with user-supplied keyword arguments.
        for k, v in kwargs.items():
            if not hasattr(self, k):
                warnings.warn("Warning: opt has no attribute %s" % k)
            setattr(self, k, v)

        print('=========user config==========')
        pprint(self._state_dict())
        print('============end===============')

    def _state_dict(self):
        # Collect all public class attributes into a plain dict.
        return {k: getattr(self, k) for k, _ in DefaultConfig.__dict__.items()
                if not k.startswith('_')}


opt = DefaultConfig()
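
For reference, a minimal, hypothetical usage sketch of this config object (it assumes the file above is saved as config.py, which matches the `from config import opt` import in the second file; the override values below are illustrative only):

from config import opt  # assumption: the file above is config.py

# _parse warns on unknown keys, sets the given attributes, and prints the merged config.
opt._parse({'lr': 1e-3, 'max_epoch': 50})
print(opt.lr)          # -> 0.001 (overridden)
print(opt.batch_size)  # -> 64 (unchanged default)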
Lines changed: 157 additions & 0 deletions
# encoding: utf-8
"""
@author: xyliao

"""
import copy

import torch
from config import opt
from mxtorch import meter
from mxtorch import transforms as tfs
from mxtorch.trainer import *
from mxtorch.vision import model_zoo
from torch import nn
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
from tqdm import tqdm

# Training-time augmentation: random crop and flip, then normalization with ImageNet statistics.
train_tf = tfs.Compose([
    tfs.RandomResizedCrop(224),
    tfs.RandomHorizontalFlip(),
    tfs.ToTensor(),
    tfs.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])


def test_tf(img):
    # Deterministic preprocessing for evaluation: resize, center-crop, normalize.
    img = tfs.Resize(256)(img)
    img = tfs.CenterCrop(224)(img)
    normalize = tfs.Compose([
        tfs.ToTensor(),
        tfs.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])
    img = normalize(img)
    return img


def get_train_data():
    train_set = ImageFolder(opt.train_data_path, train_tf)
    return DataLoader(train_set, opt.batch_size, shuffle=True, num_workers=opt.num_workers)


def get_test_data():
    test_set = ImageFolder(opt.test_data_path, test_tf)
    return DataLoader(test_set, opt.batch_size, shuffle=True, num_workers=opt.num_workers)


def get_model():
    # Load a pretrained ResNet-50 and replace the classifier head for two classes.
    model = model_zoo.resnet50(pretrained=True)
    model.fc = nn.Linear(2048, 2)
    if opt.use_gpu:
        model = model.cuda(opt.ctx)
    return model


def get_loss(score, label):
    return nn.CrossEntropyLoss()(score, label)


def get_optimizer(model):
    optimizer = torch.optim.SGD(model.parameters(), lr=opt.lr, momentum=opt.momentum,
                                weight_decay=opt.weight_decay)
    return ScheduledOptim(optimizer)


class FineTuneTrainer(Trainer):
    def __init__(self):
        model = get_model()
        criterion = get_loss
        optimizer = get_optimizer(model)
        super().__init__(model, criterion, optimizer)

        self.metric_meter['loss'] = meter.AverageValueMeter()
        self.metric_meter['acc'] = meter.AverageValueMeter()

    def train(self, train_data):
        self.model.train()
        for data in tqdm(train_data):
            img, label = data
            if opt.use_gpu:
                img = img.cuda(opt.ctx)
                label = label.cuda(opt.ctx)
            img = Variable(img)
            label = Variable(label)

            # Forward.
            score = self.model(img)
            loss = self.criterion(score, label)

            # Backward.
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            # Update meters.
            acc = (score.max(1)[1] == label).float().mean()
            self.metric_meter['loss'].add(loss.data[0])
            self.metric_meter['acc'].add(acc.data[0])

            # Update to TensorBoard.
            # if (self.n_iter + 1) % opt.plot_freq == 0:
            #     self.writer.add_scalars('loss', {'train': self.metric_meter['loss'].value()[0]}, self.n_plot)
            #     self.writer.add_scalars('acc', {'train': self.metric_meter['acc'].value()[0]}, self.n_plot)
            #     self.n_plot += 1
            self.n_iter += 1

        # Log the train metrics so they appear in the printed result.
        self.metric_log['train loss'] = self.metric_meter['loss'].value()[0]
        self.metric_log['train acc'] = self.metric_meter['acc'].value()[0]

    def test(self, test_data):
        self.model.eval()
        for data in tqdm(test_data):
            img, label = data
            if opt.use_gpu:
                img = img.cuda(opt.ctx)
                label = label.cuda(opt.ctx)
            img = Variable(img, volatile=True)
            label = Variable(label, volatile=True)

            score = self.model(img)
            loss = self.criterion(score, label)
            acc = (score.max(1)[1] == label).float().mean()

            self.metric_meter['loss'].add(loss.data[0])
            self.metric_meter['acc'].add(acc.data[0])

        # Update to TensorBoard.
        # self.writer.add_scalars('loss', {'test': self.metric_meter['loss'].value()[0]}, self.n_plot)
        # self.writer.add_scalars('acc', {'test': self.metric_meter['acc'].value()[0]}, self.n_plot)
        # self.n_plot += 1

        # Log the test metrics to the dict.
        self.metric_log['test loss'] = self.metric_meter['loss'].value()[0]
        self.metric_log['test acc'] = self.metric_meter['acc'].value()[0]

    def get_best_model(self):
        # Keep a copy of the weights whenever the test loss improves.
        if self.metric_log['test loss'] < self.best_metric:
            self.best_model = copy.deepcopy(self.model.state_dict())
            self.best_metric = self.metric_log['test loss']


def train(**kwargs):
    opt._parse(kwargs)

    train_data = get_train_data()
    test_data = get_test_data()

    fine_tune_trainer = FineTuneTrainer()
    fine_tune_trainer.fit(train_data, test_data)


if __name__ == '__main__':
    import fire

    fire.Fire()
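
Since the module is exposed through fire.Fire(), the training entry point can be driven from the command line. A hedged invocation sketch follows; the file name main.py and the dataset location are assumptions, not stated in the commit:

# Hypothetical command line, assuming this file is saved as main.py and the
# hymenoptera_data folder sits next to it (as the config paths expect):
#
#   python main.py train --lr=0.001 --max_epoch=50
#
# fire maps the flags onto train(**kwargs), which forwards them to opt._parse
# before building the data loaders and fitting the FineTuneTrainer.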
