Skip to content

Commit 9423511

Browse files
authored
Fix dtype infer in DataFrame arithmetic on datetime consts (mars-project#2879)
1 parent 674cdf2 commit 9423511

File tree

4 files changed

+40
-2
lines changed

4 files changed

+40
-2
lines changed

mars/core/graph/tests/test_graph.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import logging
1516
import pytest
1617

1718
from .... import tensor as mt
@@ -103,4 +104,17 @@ def test_to_dot():
103104
graph = arr2.build_graph(fuse_enabled=False, tile=True)
104105

105106
dot = str(graph.to_dot(trunc_key=5))
106-
assert all(str(n.op.key)[5] in dot for n in graph) is True
107+
try:
108+
assert all(str(n.op.key)[5] in dot for n in graph) is True
109+
except AssertionError:
110+
graph_reprs = []
111+
for n in graph:
112+
graph_reprs.append(
113+
f"{n.op.key} -> {[succ.op.key for succ in graph.successors(n)]}"
114+
)
115+
logging.error(
116+
"Unexpected error in test_to_dot.\ndot = %r\ngraph_repr: %r",
117+
dot,
118+
"\n".join(graph_reprs),
119+
)
120+
raise

mars/dataframe/arithmetic/tests/test_arithmetic.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import datetime
1516
import itertools
1617
import operator
1718
from dataclasses import dataclass
@@ -1550,3 +1551,19 @@ def test_arithmetic_lazy_chunk_meta():
15501551
pd.testing.assert_index_equal(chunk.index_value.to_pandas(), pd.RangeIndex(3))
15511552
assert chunk._FIELD_VALUES.get("_columns_value") is None
15521553
pd.testing.assert_index_equal(chunk.columns_value.to_pandas(), pd.RangeIndex(3))
1554+
1555+
1556+
def test_datetime_arithmetic():
    """Check that the dtype of datetime arithmetic matches pandas.

    A datetime series is combined with scalar timedelta and timestamp
    constants (both the pandas and the standard-library flavours); only the
    resulting dtype is compared against what pandas itself produces.
    """
    day_offsets = [pd.Timedelta(days=d) for d in range(10)]
    data1 = pd.Series(day_offsets) + datetime.datetime.now()
    s1 = from_pandas_series(data1)

    # datetime series + timedelta constant
    for delta in (pd.Timedelta(days=10), datetime.timedelta(days=10)):
        assert (s1 + delta).dtype == (data1 + delta).dtype

    # datetime series - timestamp constant
    for moment in (pd.Timestamp.now(), datetime.datetime.now()):
        assert (s1 - moment).dtype == (data1 - moment).dtype

mars/services/scheduling/worker/quota.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -334,7 +334,9 @@ async def update_mem_stats(self):
334334
await self._process_requests()
335335
self._last_memory_available = cur_mem_available
336336
self._report_quota_info()
337-
self.ref().update_mem_stats.tell_delay(delay=self._refresh_time)
337+
self._stat_refresh_task = self.ref().update_mem_stats.tell_delay(
338+
delay=self._refresh_time
339+
)
338340

339341
async def _has_space(self, delta: int):
340342
if self._hard_limit is None:

mars/utils.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
import asyncio
1818
import dataclasses
19+
import datetime
1920
import enum
2021
import functools
2122
import importlib
@@ -1033,6 +1034,10 @@ def is_object_dtype(dtype: np.dtype) -> bool:
10331034
def get_dtype(dtype: Union[np.dtype, pd.api.extensions.ExtensionDtype]):
    """Normalize a dtype-like object into a NumPy or pandas extension dtype.

    Parameters
    ----------
    dtype
        A pandas extension dtype, a Python/pandas datetime or timedelta
        class, or anything accepted by ``np.dtype``.

    Returns
    -------
    The input itself when pandas reports it as an extension-array dtype;
    ``datetime64[ns]`` for ``datetime.datetime`` and its subclasses
    (``pd.Timestamp`` is one); ``timedelta64[ns]`` for
    ``datetime.timedelta`` and its subclasses (``pd.Timedelta`` is one);
    otherwise ``np.dtype(dtype)``.
    """
    if pd.api.types.is_extension_array_dtype(dtype):
        return dtype
    if isinstance(dtype, type):
        # Use issubclass rather than identity so user-defined subclasses of
        # the datetime types do not fall through to np.dtype(), which would
        # turn them into ``object``.
        # NOTE(review): a bare class carries no timezone information, so the
        # result is always the naive ``[ns]`` dtype -- confirm callers do not
        # rely on this for tz-aware constants.
        if issubclass(dtype, datetime.datetime):
            return np.dtype("datetime64[ns]")
        if issubclass(dtype, datetime.timedelta):
            return np.dtype("timedelta64[ns]")
    return np.dtype(dtype)
10381043

0 commit comments

Comments
 (0)