Skip to content

Commit 52da897

Browse files
committed
✨: Add support for parallel optimization
1 parent ba4e7ca commit 52da897

File tree

2 files changed

+45
-21
lines changed

2 files changed

+45
-21
lines changed

backtesting/backtesting.py

Lines changed: 28 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,6 @@
2626
from math import copysign
2727
from numpy.random import default_rng
2828

29-
from .samplers import LatinHypercubeSampler
30-
3129
try:
3230
from tqdm.auto import tqdm as _tqdm
3331
_tqdm = partial(_tqdm, leave=False)
@@ -1257,6 +1255,7 @@ def optimize(self, *,
12571255
return_optimization: bool = False,
12581256
random_state: Optional[int] = None,
12591257
n_initial_points: Optional[int] = None,
1258+
n_workers: int = 1,
12601259
**kwargs) -> Union[pd.Series,
12611260
Tuple[pd.Series, pd.Series],
12621261
Tuple[pd.Series, pd.Series, dict]]:
@@ -1546,7 +1545,8 @@ def _optimize_openbox() -> Union[pd.Series,
15461545
Tuple[pd.Series, pd.Series],
15471546
Tuple[pd.Series, pd.Series, dict]]:
15481547
try:
1549-
from openbox import space as sp, Optimizer
1548+
from openbox import space as sp, Optimizer, ParallelOptimizer
1549+
from .samplers import LatinHypercubeSampler
15501550
except ImportError:
15511551
raise ImportError("Need package 'openbox' for method='openbox'. "
15521552
"pip install openbox") from None
@@ -1609,20 +1609,31 @@ def eval_run(config: sp.Configuration):
16091609
logging.warning(f'Only {len(valid_initial_configs)}/{len(initial_configs)} valid configurations are generated for initial design strategy "Latin Hypercube". ')
16101610
num_random_config = n_initial_points - len(valid_initial_configs)
16111611
valid_initial_configs += Advisor.sample_random_configs(space, num_random_config, excluded_configs=valid_initial_configs)
1612-
opt = Optimizer(
1613-
eval_run,
1614-
space,
1615-
num_constraints=1,
1616-
num_objectives=1,
1617-
surrogate_type='auto',
1618-
acq_type='auto',
1619-
acq_optimizer_type='auto',
1620-
max_runs=max_tries,
1621-
task_id='soc',
1622-
random_state=random_state,
1623-
initial_configurations=valid_initial_configs,
1624-
initial_runs=initial_runs,
1625-
)
1612+
1613+
params = {
1614+
"objective_function": eval_run,
1615+
"config_space": space,
1616+
"num_constraints": 1,
1617+
"num_objectives": 1,
1618+
"surrogate_type": 'auto',
1619+
"acq_type": 'auto',
1620+
"acq_optimizer_type": 'auto',
1621+
"max_runs": max_tries,
1622+
"task_id": 'soc',
1623+
"random_state": random_state,
1624+
"initial_configurations": valid_initial_configs,
1625+
"initial_runs": initial_runs,
1626+
}
1627+
1628+
if n_workers == 1:
1629+
opt = Optimizer(**params)
1630+
else:
1631+
opt = ParallelOptimizer(
1632+
parallel_strategy="async",
1633+
batch_size=n_workers,
1634+
batch_strategy="default",
1635+
**params
1636+
)
16261637
history = opt.run()
16271638
optimal_configurations = history.get_incumbents()
16281639
if not len(optimal_configurations):

backtesting/test/_test.py

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from glob import glob
1010
from runpy import run_path
1111
from tempfile import NamedTemporaryFile, gettempdir
12-
from time import time
12+
from time import time, sleep
1313
from unittest import TestCase
1414
from unittest.mock import patch
1515

@@ -87,9 +87,9 @@ def test_run_invalid_param(self):
8787

8888
def test_run_speed(self):
8989
bt = Backtest(GOOG, SmaCross)
90-
start = time.process_time()
90+
start = time()
9191
bt.run()
92-
end = time.process_time()
92+
end = time()
9393
self.assertLess(end - start, .3)
9494

9595
def test_data_missing_columns(self):
@@ -581,6 +581,19 @@ def test_method_openbox(self):
581581
random_state=2)
582582
self.assertIsInstance(res, pd.Series)
583583

584+
def test_method_openbox_parallel(self):
585+
bt = Backtest(GOOG.iloc[:100], SmaCross)
586+
res = bt.optimize(
587+
fast=range(2, 20), slow=np.arange(2, 20, dtype=object),
588+
constraint=lambda p: p.fast < p.slow,
589+
max_tries=30,
590+
method='openbox',
591+
return_optimization=False,
592+
return_heatmap=False,
593+
random_state=2,
594+
n_workers=2)
595+
self.assertIsInstance(res, pd.Series)
596+
584597
def test_timing(self):
585598
bt = Backtest(GOOG.iloc[:100], SmaCross)
586599

@@ -877,7 +890,7 @@ def test_plot_heatmaps(self):
877890

878891
# Preview
879892
plot_heatmaps(heatmap, filename=f)
880-
time.sleep(5)
893+
sleep(5)
881894

882895
def test_random_ohlc_data(self):
883896
generator = random_ohlc_data(GOOG, frac=1)

0 commit comments

Comments
 (0)