Skip to content

Commit b95ee36

Browse files
author
Xuye (Chris) Qin
authored
Improved performance of mars.learn.metrics.{roc_curve, roc_auc_score} (mars-project#2838)
1 parent 505f884 commit b95ee36

File tree

19 files changed

+374
-199
lines changed

19 files changed

+374
-199
lines changed

.github/workflows/benchmark-ci.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ jobs:
4949
git remote add upstream https://github.com/mars-project/mars.git
5050
git fetch upstream
5151
asv machine --yes
52-
asv continuous -f 1.1 upstream/master HEAD
52+
asv continuous -f 1.1 --strict upstream/master HEAD
5353
if: ${{ steps.build.outcome == 'success' }}
5454

5555
- name: Publish benchmarks artifact
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
# Copyright 1999-2022 Alibaba Group Holding Ltd.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
# Copyright 1999-2022 Alibaba Group Holding Ltd.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import mars
16+
import mars.tensor as mt
17+
from sklearn.datasets import make_classification
18+
from sklearn.linear_model import LogisticRegression
19+
from mars.learn import metrics
20+
21+
22+
class MetricsSuite:
23+
"""
24+
Benchmark learn metrics.
25+
"""
26+
27+
params = [20_000, 100_000]
28+
29+
def setup(self, chunk_size: int):
30+
X, y = make_classification(100_000, random_state=0)
31+
self.raw_X, self.raw_y = X, y
32+
clf = LogisticRegression(solver="liblinear", random_state=0).fit(X, y)
33+
self.raw_pred_y = clf.predict_proba(X)[:, 1]
34+
self._session = mars.new_session()
35+
self.y = mt.tensor(self.raw_y, chunk_size=chunk_size)
36+
self.pred_y = mt.tensor(self.raw_pred_y, chunk_size=chunk_size)
37+
38+
def teardown(self, chunk_size: int):
39+
self._session.stop_server()
40+
41+
def time_roc_curve_auc(self, chunk_size: int):
42+
fpr, tpr, _ = metrics.roc_curve(self.y, self.pred_y)
43+
metrics.auc(fpr, tpr)
44+
45+
def time_roc_auc_score(self, chunk_size: int):
46+
metrics.roc_auc_score(self.y, self.pred_y)

mars/config.py

Lines changed: 6 additions & 146 deletions
Original file line numberDiff line numberDiff line change
@@ -385,159 +385,19 @@ def validate(x):
385385
# deploy
386386
default_options.register_option("deploy.open_browser", True, validator=is_bool)
387387

388-
# Scheduler
389-
default_options.register_option(
390-
"scheduler.assign_chunk_workers", False, validator=is_bool, serialize=True
391-
)
392-
default_options.register_option(
393-
"scheduler.enable_active_push", True, validator=is_bool, serialize=True
394-
)
395-
default_options.register_option(
396-
"scheduler.enable_chunk_relocation", False, validator=is_bool, serialize=True
397-
)
398-
default_options.register_option(
399-
"scheduler.check_interval", 1, validator=is_integer, serialize=True
400-
)
401-
default_options.register_option(
402-
"scheduler.default_cpu_usage", 1, validator=(is_integer, is_float), serialize=True
403-
)
404-
default_options.register_option(
405-
"scheduler.default_cuda_usage",
406-
0.5,
407-
validator=(is_integer, is_float),
408-
serialize=True,
409-
)
410-
default_options.register_option(
411-
"scheduler.assign_timeout", 600, validator=is_integer, serialize=True
412-
)
413-
default_options.register_option(
414-
"scheduler.execution_timeout", 600, validator=is_integer, serialize=True
415-
)
416-
default_options.register_option(
417-
"scheduler.retry_num", 4, validator=is_integer, serialize=True
418-
)
419-
default_options.register_option(
420-
"scheduler.fetch_limit", 10 * 1024**2, validator=is_integer, serialize=True
421-
)
422-
default_options.register_option(
423-
"scheduler.retry_delay", 60, validator=is_integer, serialize=True
424-
)
425-
426-
default_options.register_option("scheduler.dump_graph_data", False, validator=is_bool)
427-
428-
default_options.register_option(
429-
"scheduler.enable_failover", True, validator=is_bool, serialize=True
430-
)
431-
default_options.register_option(
432-
"scheduler.status_timeout", 30, validator=is_numeric, serialize=True
433-
)
434-
default_options.register_option(
435-
"scheduler.worker_blacklist_time", 3600, validator=is_numeric, serialize=True
436-
)
437-
438-
# enqueue operands in a batch when creating OperandActors
439-
default_options.register_option(
440-
"scheduler.batch_enqueue_initials", True, validator=is_bool, serialize=True
441-
)
442-
# invoke assigning when there are no ready descendants
443-
default_options.register_option(
444-
"scheduler.aggressive_assign", False, validator=is_bool, serialize=True
445-
)
446-
447-
# Worker
448-
default_options.register_option(
449-
"worker.spill_directory", None, validator=(is_null, is_string, is_list)
450-
)
451-
default_options.register_option(
452-
"worker.disk_compression", "lz4", validator=is_string, serialize=True
453-
)
454-
default_options.register_option(
455-
"worker.min_spill_size", "5%", validator=(is_string, is_integer)
456-
)
457-
default_options.register_option(
458-
"worker.max_spill_size", "95%", validator=(is_string, is_integer)
459-
)
460-
default_options.register_option(
461-
"worker.min_cache_mem_size", None, validator=(is_null, is_string, is_integer)
462-
)
463-
default_options.register_option(
464-
"worker.callback_preserve_time", 3600 * 24, validator=is_integer
465-
)
466-
default_options.register_option(
467-
"worker.event_preserve_time", 3600 * 24, validator=(is_integer, is_float)
468-
)
469-
default_options.register_option(
470-
"worker.copy_block_size", 64 * 1024, validator=is_integer
471-
)
472-
default_options.register_option("worker.cuda_thread_num", 2, validator=is_integer)
473-
default_options.register_option(
474-
"worker.transfer_block_size", 1 * 1024**2, validator=is_integer
475-
)
476-
default_options.register_option(
477-
"worker.transfer_compression", "lz4", validator=is_string, serialize=True
478-
)
479-
default_options.register_option(
480-
"worker.prepare_data_timeout", 1000, validator=is_integer
481-
)
482-
default_options.register_option(
483-
"worker.peer_blacklist_time", 3600, validator=is_numeric, serialize=True
484-
)
485-
default_options.register_option(
486-
"worker.io_parallel_num", 1, validator=is_integer, serialize=True
487-
)
488-
default_options.register_option(
489-
"worker.recover_dead_process", True, validator=is_bool, serialize=True
490-
)
491-
default_options.register_option(
492-
"worker.write_shuffle_to_disk", False, validator=is_bool, serialize=True
493-
)
494-
495-
default_options.register_option(
496-
"worker.filemerger.enabled", True, validator=is_bool, serialize=True
497-
)
498-
default_options.register_option(
499-
"worker.filemerger.concurrency", 128, validator=is_integer, serialize=True
500-
)
501-
default_options.register_option(
502-
"worker.filemerger.max_accept_size",
503-
128 * 1024,
504-
validator=is_integer,
505-
serialize=True,
506-
)
507-
default_options.register_option(
508-
"worker.filemerger.max_file_size",
509-
32 * 1024**2,
510-
validator=is_integer,
511-
serialize=True,
512-
)
513-
514-
default_options.register_option(
515-
"worker.plasma_dir", None, validator=(is_string, is_null)
516-
)
517-
default_options.register_option(
518-
"worker.plasma_limit", None, validator=(is_string, is_integer, is_null)
519-
)
388+
# optimization
389+
default_options.register_option("optimize_tileable_graph", True, validator=is_bool)
520390

521-
default_options.register_option(
522-
"worker.plasma_socket", "/tmp/plasma", validator=is_string
523-
)
391+
# eager mode
392+
default_options.register_option("eager_mode", False, validator=is_bool)
524393

525394
# optimization
526-
default_options.register_option("optimize.min_stats_count", 10, validator=is_integer)
527-
default_options.register_option(
528-
"optimize.stats_sufficient_ratio", 0.9, validator=is_float, serialize=True
529-
)
530-
default_options.register_option(
531-
"optimize.default_disk_io_speed", 10 * 1024**2, validator=is_integer
532-
)
533395
default_options.register_option(
534396
"optimize.head_optimize_threshold", 1000, validator=is_integer
535397
)
536398

537-
default_options.register_option("optimize_tileable_graph", True, validator=is_bool)
538-
539-
# eager mode
540-
default_options.register_option("eager_mode", False, validator=is_bool)
399+
# debug
400+
default_options.register_option("warn_duplicated_execution", False, validator=is_bool)
541401

542402
# client serialize type
543403
default_options.register_option("client.serial_type", "arrow", validator=is_string)

mars/core/entity/tileables.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121

2222
import numpy as np
2323

24-
from ...serialization.serializables import FieldTypes, TupleField
24+
from ...serialization.serializables import FieldTypes, BoolField, TupleField
2525
from ...typing import OperandType, TileableType, ChunkType
2626
from ...utils import on_serialize_shape, on_deserialize_shape, on_serialize_nsplits
2727
from ..base import Base
@@ -273,6 +273,8 @@ class TileableData(EntityData, _ExecutableMixin):
273273
FieldTypes.tuple(FieldTypes.uint64),
274274
on_serialize=on_serialize_nsplits,
275275
)
276+
# whether to cache tileable data; if true, this data will be materialized
277+
cache = BoolField("cache", default=False)
276278

277279
def __init__(self: TileableType, *args, **kwargs):
278280
if kwargs.get("_nsplits", None) is not None:

mars/dataframe/base/diff.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626

2727

2828
class DataFrameDiff(DataFrameOperandMixin, DataFrameOperand):
29-
_op_type_ = opcodes.DATAFRAME_DIFF
29+
_op_type_ = opcodes.DIFF
3030

3131
_periods = Int64Field("periods")
3232
_axis = Int8Field("axis")

0 commit comments

Comments
 (0)