Skip to content

Commit 02ae18a

Browse files
committed
fix quant module
1 parent 764d8f5 commit 02ae18a

File tree

3 files changed

+104
-32
lines changed

3 files changed

+104
-32
lines changed

deploy/slim/prune/pruning_and_finetune.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,7 @@ def main():
135135

136136
if alg in ['EAST', 'DB']:
137137
program.train_eval_det_run(
138-
config, exe, train_info_dict, eval_info_dict, is_pruning=True)
138+
config, exe, train_info_dict, eval_info_dict, is_slim="prune")
139139
else:
140140
program.train_eval_rec_run(config, exe, train_info_dict, eval_info_dict)
141141

deploy/slim/quantization/quant.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -155,14 +155,13 @@ def main():
155155
act_preprocess_func=act_preprocess_func,
156156
optimizer_func=optimizer_func,
157157
executor=executor,
158-
for_test=False,
159-
return_program=True)
158+
for_test=False)
160159

161160
# compile program for multi-devices
162161
train_compile_program = program.create_multi_devices_program(
163162
quant_train_program, train_opt_loss_name, for_quant=True)
164163

165-
init_model(config, quant_train_program, exe)
164+
init_model(config, train_program, exe)
166165

167166
train_info_dict = {'compile_program':train_compile_program,\
168167
'train_program':quant_train_program,\
@@ -177,9 +176,11 @@ def main():
177176
'fetch_varname_list':eval_fetch_varname_list}
178177

179178
if train_alg_type == 'det':
180-
program.train_eval_det_run(config, exe, train_info_dict, eval_info_dict)
179+
program.train_eval_det_run(
180+
config, exe, train_info_dict, eval_info_dict, is_slim="quant")
181181
else:
182-
program.train_eval_rec_run(config, exe, train_info_dict, eval_info_dict)
182+
program.train_eval_rec_run(
183+
config, exe, train_info_dict, eval_info_dict, is_slim="quant")
183184

184185

185186
if __name__ == '__main__':

tools/program.py

Lines changed: 97 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -241,9 +241,11 @@ def create_multi_devices_program(program, loss_var_name, for_quant=False):
241241
build_strategy.enable_inplace = True
242242
if for_quant:
243243
build_strategy.fuse_all_reduce_ops = False
244+
else:
245+
program = fluid.CompiledProgram(program)
244246
exec_strategy = fluid.ExecutionStrategy()
245247
exec_strategy.num_iteration_per_drop_scope = 1
246-
compile_program = fluid.CompiledProgram(program).with_data_parallel(
248+
compile_program = program.with_data_parallel(
247249
loss_name=loss_var_name,
248250
build_strategy=build_strategy,
249251
exec_strategy=exec_strategy)
@@ -254,7 +256,7 @@ def train_eval_det_run(config,
254256
exe,
255257
train_info_dict,
256258
eval_info_dict,
257-
is_pruning=False):
259+
is_slim=None):
258260
'''
259261
main program of evaluation for detection
260262
'''
@@ -313,14 +315,17 @@ def train_eval_det_run(config,
313315
best_batch_id = train_batch_id
314316
best_epoch = epoch
315317
save_path = save_model_dir + "/best_accuracy"
316-
if is_pruning:
317-
import paddleslim as slim
318-
slim.prune.save_model(
319-
exe, train_info_dict['train_program'],
320-
save_path)
321-
else:
318+
if is_slim is None:
322319
save_model(train_info_dict['train_program'],
323320
save_path)
321+
else:
322+
import paddleslim as slim
323+
if is_slim == "prune":
324+
slim.prune.save_model(
325+
exe, train_info_dict['train_program'],
326+
save_path)
327+
elif is_slim == "quant":
328+
save_model(eval_info_dict['program'], save_path)
324329
strs = 'Test iter: {}, metrics:{}, best_hmean:{:.6f}, best_epoch:{}, best_batch_id:{}'.format(
325330
train_batch_id, metrics, best_eval_hmean, best_epoch,
326331
best_batch_id)
@@ -331,24 +336,34 @@ def train_eval_det_run(config,
331336
train_loader.reset()
332337
if epoch == 0 and save_epoch_step == 1:
333338
save_path = save_model_dir + "/iter_epoch_0"
334-
if is_pruning:
335-
import paddleslim as slim
336-
slim.prune.save_model(exe, train_info_dict['train_program'],
337-
save_path)
338-
else:
339+
if is_slim is None:
339340
save_model(train_info_dict['train_program'], save_path)
341+
else:
342+
import paddleslim as slim
343+
if is_slim == "prune":
344+
slim.prune.save_model(exe, train_info_dict['train_program'],
345+
save_path)
346+
elif is_slim == "quant":
347+
save_model(eval_info_dict['program'], save_path)
340348
if epoch > 0 and epoch % save_epoch_step == 0:
341349
save_path = save_model_dir + "/iter_epoch_%d" % (epoch)
342-
if is_pruning:
343-
import paddleslim as slim
344-
slim.prune.save_model(exe, train_info_dict['train_program'],
345-
save_path)
346-
else:
350+
if is_slim is None:
347351
save_model(train_info_dict['train_program'], save_path)
352+
else:
353+
import paddleslim as slim
354+
if is_slim == "prune":
355+
slim.prune.save_model(exe, train_info_dict['train_program'],
356+
save_path)
357+
elif is_slim == "quant":
358+
save_model(eval_info_dict['program'], save_path)
348359
return
349360

350361

351-
def train_eval_rec_run(config, exe, train_info_dict, eval_info_dict):
362+
def train_eval_rec_run(config,
363+
exe,
364+
train_info_dict,
365+
eval_info_dict,
366+
is_slim=None):
352367
'''
353368
main program of evaluation for recognition
354369
'''
@@ -428,7 +443,17 @@ def train_eval_rec_run(config, exe, train_info_dict, eval_info_dict):
428443
best_batch_id = train_batch_id
429444
best_epoch = epoch
430445
save_path = save_model_dir + "/best_accuracy"
431-
save_model(train_info_dict['train_program'], save_path)
446+
if is_slim is None:
447+
save_model(train_info_dict['train_program'],
448+
save_path)
449+
else:
450+
import paddleslim as slim
451+
if is_slim == "prune":
452+
slim.prune.save_model(
453+
exe, train_info_dict['train_program'],
454+
save_path)
455+
elif is_slim == "quant":
456+
save_model(eval_info_dict['program'], save_path)
432457
strs = 'Test iter: {}, acc:{:.6f}, best_acc:{:.6f}, best_epoch:{}, best_batch_id:{}, eval_sample_num:{}'.format(
433458
train_batch_id, eval_acc, best_eval_acc, best_epoch,
434459
best_batch_id, eval_sample_num)
@@ -439,14 +464,34 @@ def train_eval_rec_run(config, exe, train_info_dict, eval_info_dict):
439464
train_loader.reset()
440465
if epoch == 0 and save_epoch_step == 1:
441466
save_path = save_model_dir + "/iter_epoch_0"
442-
save_model(train_info_dict['train_program'], save_path)
467+
if is_slim is None:
468+
save_model(train_info_dict['train_program'], save_path)
469+
else:
470+
import paddleslim as slim
471+
if is_slim == "prune":
472+
slim.prune.save_model(exe, train_info_dict['train_program'],
473+
save_path)
474+
elif is_slim == "quant":
475+
save_model(eval_info_dict['program'], save_path)
443476
if epoch > 0 and epoch % save_epoch_step == 0:
444477
save_path = save_model_dir + "/iter_epoch_%d" % (epoch)
445-
save_model(train_info_dict['train_program'], save_path)
478+
if is_slim is None:
479+
save_model(train_info_dict['train_program'], save_path)
480+
else:
481+
import paddleslim as slim
482+
if is_slim == "prune":
483+
slim.prune.save_model(exe, train_info_dict['train_program'],
484+
save_path)
485+
elif is_slim == "quant":
486+
save_model(eval_info_dict['program'], save_path)
446487
return
447488

448489

449-
def train_eval_cls_run(config, exe, train_info_dict, eval_info_dict):
490+
def train_eval_cls_run(config,
491+
exe,
492+
train_info_dict,
493+
eval_info_dict,
494+
is_slim=None):
450495
train_batch_id = 0
451496
log_smooth_window = config['Global']['log_smooth_window']
452497
epoch_num = config['Global']['epoch_num']
@@ -509,7 +554,17 @@ def train_eval_cls_run(config, exe, train_info_dict, eval_info_dict):
509554
best_batch_id = train_batch_id
510555
best_epoch = epoch
511556
save_path = save_model_dir + "/best_accuracy"
512-
save_model(train_info_dict['train_program'], save_path)
557+
if is_slim is None:
558+
save_model(train_info_dict['train_program'],
559+
save_path)
560+
else:
561+
import paddleslim as slim
562+
if is_slim == "prune":
563+
slim.prune.save_model(
564+
exe, train_info_dict['train_program'],
565+
save_path)
566+
elif is_slim == "quant":
567+
save_model(eval_info_dict['program'], save_path)
513568
strs = 'Test iter: {}, acc:{:.6f}, best_acc:{:.6f}, best_epoch:{}, best_batch_id:{}, eval_sample_num:{}'.format(
514569
train_batch_id, eval_acc, best_eval_acc, best_epoch,
515570
best_batch_id, eval_sample_num)
@@ -520,10 +575,26 @@ def train_eval_cls_run(config, exe, train_info_dict, eval_info_dict):
520575
train_loader.reset()
521576
if epoch == 0 and save_epoch_step == 1:
522577
save_path = save_model_dir + "/iter_epoch_0"
523-
save_model(train_info_dict['train_program'], save_path)
578+
if is_slim is None:
579+
save_model(train_info_dict['train_program'], save_path)
580+
else:
581+
import paddleslim as slim
582+
if is_slim == "prune":
583+
slim.prune.save_model(exe, train_info_dict['train_program'],
584+
save_path)
585+
elif is_slim == "quant":
586+
save_model(eval_info_dict['program'], save_path)
524587
if epoch > 0 and epoch % save_epoch_step == 0:
525588
save_path = save_model_dir + "/iter_epoch_%d" % (epoch)
526-
save_model(train_info_dict['train_program'], save_path)
589+
if is_slim is None:
590+
save_model(train_info_dict['train_program'], save_path)
591+
else:
592+
import paddleslim as slim
593+
if is_slim == "prune":
594+
slim.prune.save_model(exe, train_info_dict['train_program'],
595+
save_path)
596+
elif is_slim == "quant":
597+
save_model(eval_info_dict['program'], save_path)
527598
return
528599

529600

0 commit comments

Comments
 (0)