diff --git a/tpot/base.py b/tpot/base.py index df1b8fc17..b26502601 100644 --- a/tpot/base.py +++ b/tpot/base.py @@ -144,7 +144,11 @@ def __init__(self, generations=100, population_size=100, offspring_size=None, Random number generator seed for TPOT. Use this to make sure that TPOT will give you the same results each time you run it against the same data set with that seed. - config_dict: string (default: None) + config_dict: Python dictionary or string (default: None) + Python dictionary: + A dictionary customizing the operators and parameters that + TPOT uses in the optimization process. + For examples, see config_regressor.py and config_classifier.py Path for configuration file: A path to a configuration file for customizing the operators and parameters that TPOT uses in the optimization process. @@ -195,7 +199,9 @@ def __init__(self, generations=100, population_size=100, offspring_size=None, self.offspring_size = population_size if config_dict: - if config_dict == 'TPOT light': + if isinstance(config_dict, dict): + self.config_dict = config_dict + elif config_dict == 'TPOT light': if self.classification: self.config_dict = classifier_config_dict_light else: @@ -209,8 +215,8 @@ def __init__(self, generations=100, population_size=100, offspring_size=None, else: try: with open(config_dict, 'r') as input_file: - file_string = input_file.read() - operator_dict = eval(file_string[file_string.find('{'):(file_string.rfind('}') + 1)]) + file_string = input_file.read() + self.config_dict = eval(file_string[file_string.find('{'):(file_string.rfind('}') + 1)]) except: raise TypeError('The operator configuration file is in a bad format or not available. ' 'Please check the configuration file before running TPOT.')