Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 23 additions & 24 deletions tpot/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,10 +228,10 @@ def _get_arg_parser():
'f1_micro, '
'f1_samples, '
'f1_weighted, '
'log_loss, '
'mean_absolute_error, '
'mean_squared_error, '
'median_absolute_error, '
'neg_log_loss, '
'neg_mean_absolute_error, '
'neg_mean_squared_error, '
'neg_median_absolute_error, '
'precision, '
'precision_macro, '
'precision_micro, '
Expand All @@ -254,7 +254,7 @@ def _get_arg_parser():
default=5,
type=int,
help=(
'Number of folds to evaluate each pipeline over in k-fold '
'Number of folds to evaluate each pipeline over in stratified k-fold '
'cross-validation during the TPOT optimization process.'
)
)
Expand All @@ -266,8 +266,8 @@ def _get_arg_parser():
default=1.0,
type=float,
help=(
'Subsample ratio of the training instance. Setting it to 0.5 means that TPOT'
'randomly collects half of training samples for pipeline optimization process.'
'Subsample ratio of the training instance. Setting it to 0.5 means that TPOT '
'use a random subsample of half of training data for the pipeline optimization process.'
)
)

Expand Down Expand Up @@ -327,12 +327,13 @@ def _get_arg_parser():
'-config',
action='store',
dest='CONFIG_FILE',
default='',
default=None,
type=str,
help=(
'Configuration file for customizing the operators and parameters '
'that TPOT uses in the optimization process. Must be a python '
'module containing a dict export named "tpot_config".'
'that TPOT uses in the optimization process. Must be a Python '
'module containing a dict export named "tpot_config" or the name of '
'built-in configuration.'
)
)

Expand Down Expand Up @@ -370,14 +371,14 @@ def _get_arg_parser():

def _print_args(args):
print('\nTPOT settings:')
for arg, arg_val in args.__dict__.items():
for arg, arg_val in sorted(args.__dict__.items()):
if arg == 'DISABLE_UPDATE_CHECK':
continue
elif arg == 'SCORING_FN' and arg_val is None:
if args.TPOT_MODE == 'classification':
arg_val = 'accuracy'
else:
arg_val = 'mean_squared_error'
arg_val = 'neg_mean_squared_error'
elif arg == 'OFFSPRING_SIZE' and arg_val is None:
arg_val = args.__dict__['POPULATION_SIZE']
else:
Expand Down Expand Up @@ -419,7 +420,7 @@ def main(args):
training_features, testing_features, training_target, testing_target = \
train_test_split(features, input_data[args.TARGET_NAME], random_state=args.RANDOM_STATE)
tpot_type = TPOTClassifier if args.TPOT_MODE == 'classification' else TPOTRegressor
tpot = tpot_type(
tpot_obj = tpot_type(
generations=args.GENERATIONS,
population_size=args.POPULATION_SIZE,
offspring_size=args.OFFSPRING_SIZE,
Expand All @@ -437,29 +438,27 @@ def main(args):
disable_update_check=args.DISABLE_UPDATE_CHECK
)

print('')

tpot.fit(training_features, training_target)
tpot_obj.fit(training_features, training_target)

if args.VERBOSITY in [1, 2] and tpot._optimized_pipeline:
training_score = max([x.wvalues[1] for x in tpot._pareto_front.keys])
if args.VERBOSITY in [1, 2] and tpot_obj._optimized_pipeline:
training_score = max([x.wvalues[1] for x in tpot_obj._pareto_front.keys])
print('\nTraining score: {}'.format(abs(training_score)))
print('Holdout score: {}'.format(tpot.score(testing_features, testing_target)))
print('Holdout score: {}'.format(tpot_obj.score(testing_features, testing_target)))

elif args.VERBOSITY >= 3 and tpot._pareto_front:
elif args.VERBOSITY >= 3 and tpot_obj._pareto_front:
print('Final Pareto front testing scores:')
pipelines = zip(tpot._pareto_front.items, reversed(tpot._pareto_front.keys))
pipelines = zip(tpot_obj._pareto_front.items, reversed(tpot_obj._pareto_front.keys))
for pipeline, pipeline_scores in pipelines:
tpot._fitted_pipeline = tpot._pareto_front_fitted_pipelines[str(pipeline)]
tpot_obj._fitted_pipeline = tpot_obj.pareto_front_fitted_pipelines_[str(pipeline)]
print('{TRAIN_SCORE}\t{TEST_SCORE}\t{PIPELINE}'.format(
TRAIN_SCORE=int(abs(pipeline_scores.wvalues[0])),
TEST_SCORE=tpot.score(testing_features, testing_target),
TEST_SCORE=tpot_obj.score(testing_features, testing_target),
PIPELINE=pipeline
)
)

if args.OUTPUT_FILE != '':
tpot.export(args.OUTPUT_FILE)
tpot_obj.export(args.OUTPUT_FILE)


if __name__ == '__main__':
Expand Down