Skip to content

Commit e12963d

Browse files
author
Xuye (Chris) Qin
authored
Fix XGBoost when some workers do not have evals data (mars-project#2861)
1 parent 96af4fa commit e12963d

File tree

4 files changed

+298
-128
lines changed

4 files changed

+298
-128
lines changed

mars/learn/contrib/xgboost/start_tracker.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,9 +47,9 @@ def execute(cls, ctx, op):
4747

4848
env = {"DMLC_NUM_WORKER": op.n_workers}
4949
rabit_context = RabitTracker(
50-
hostIP=ctx.get_local_host_ip(), nslave=op.n_workers
50+
host_ip=ctx.get_local_host_ip(), n_workers=op.n_workers
5151
)
52-
env.update(rabit_context.slave_envs())
52+
env.update(rabit_context.worker_envs())
5353

5454
rabit_context.start(op.n_workers)
5555
thread = Thread(target=rabit_context.join)

mars/learn/contrib/xgboost/tests/test_train.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -60,16 +60,17 @@ def test_local_train_dataframe(setup):
6060

6161

6262
@pytest.mark.skipif(xgboost is None, reason="XGBoost not installed")
63-
def test_train_evals(setup_cluster):
63+
@pytest.mark.parametrize("chunk_size", [n_rows // 5, n_rows])
64+
def test_train_evals(setup_cluster, chunk_size):
6465
rs = mt.random.RandomState(0)
6566
# keep 1 chunk for X and y
6667
X = rs.rand(n_rows, n_columns, chunk_size=(n_rows, n_columns // 2))
6768
y = rs.rand(n_rows, chunk_size=n_rows)
6869
base_margin = rs.rand(n_rows, chunk_size=n_rows)
6970
dtrain = MarsDMatrix(X, y, base_margin=base_margin)
7071
eval_x = MarsDMatrix(
71-
rs.rand(n_rows, n_columns, chunk_size=n_rows // 5),
72-
rs.rand(n_rows, chunk_size=n_rows // 5),
72+
rs.rand(n_rows, n_columns, chunk_size=chunk_size),
73+
rs.rand(n_rows, chunk_size=chunk_size),
7374
)
7475
evals = [(eval_x, "eval_x")]
7576
eval_result = dict()

0 commit comments

Comments
 (0)