Skip to content

Commit b7a7b9b

Browse files
committed
simplify grid build
1 parent f4af925 commit b7a7b9b

File tree

1 file changed

+5
-11
lines changed

1 file changed

+5
-11
lines changed

python/pyspark/ml/tuning.py

Lines changed: 5 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515
# limitations under the License.
1616
#
1717

18+
import itertools
19+
1820
__all__ = ['ParamGridBuilder']
1921

2022

@@ -76,17 +78,9 @@ def build(self):
7678
Builds and returns all combinations of parameters specified
7779
by the param grid.
7880
"""
79-
param_maps = [{}]
80-
for (param, values) in self._param_grid.items():
81-
new_param_maps = []
82-
for value in values:
83-
for old_map in param_maps:
84-
copied_map = old_map.copy()
85-
copied_map[param] = value
86-
new_param_maps.append(copied_map)
87-
param_maps = new_param_maps
88-
89-
return param_maps
81+
keys = self._param_grid.keys()
82+
grid_values = self._param_grid.values()
83+
return [dict(zip(keys, prod)) for prod in itertools.product(*grid_values)]
9084

9185

9286
if __name__ == "__main__":

0 commit comments

Comments
 (0)