Skip to content

Commit 79dd902

Browse files
authored
[Ray] Ray client channel get recv when first complied (mars-project#2740)
1 parent 930b3f0 commit 79dd902

File tree

1 file changed

+33
-22
lines changed

1 file changed

+33
-22
lines changed

mars/oscar/backends/ray/communication.py

Lines changed: 33 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,7 @@ class RayChannelBase(Channel, ABC):
120120
Channel for communications between ray processes.
121121
"""
122122

123-
__slots__ = "_channel_index", "_channel_id", "_in_queue", "_closed"
123+
__slots__ = "_channel_index", "_channel_id", "_closed"
124124

125125
name = "ray"
126126
_channel_index_gen = itertools.count()
@@ -142,7 +142,6 @@ def __init__(
142142
self._channel_id = channel_id or ChannelID(
143143
local_address, _gen_client_id(), self._channel_index, dest_address
144144
)
145-
self._in_queue = asyncio.Queue()
146145
self._closed = asyncio.Event()
147146

148147
@property
@@ -169,7 +168,7 @@ class RayClientChannel(RayChannelBase):
169168
A channel from ray driver/actor to ray actor. Use ray call reply for client channel recv.
170169
"""
171170

172-
__slots__ = ("_peer_actor",)
171+
__slots__ = "_peer_actor", "_done", "_todo"
173172

174173
def __init__(
175174
self,
@@ -181,36 +180,47 @@ def __init__(
181180
super().__init__(None, dest_address, channel_index, channel_id, compression)
182181
# ray actor should be created with the address as the name.
183182
self._peer_actor: "ray.actor.ActorHandle" = ray.get_actor(dest_address)
183+
self._done = asyncio.Queue()
184+
self._todo = set()
185+
186+
def _submit_task(self, message: Any, object_ref: "ray.ObjectRef"):
187+
async def handle_task(message: Any, object_ref: "ray.ObjectRef"):
188+
# use `%.500` to avoid print too long messages
189+
with debug_async_timeout(
190+
"ray_object_retrieval_timeout", "Client sent message is %.500s", message
191+
):
192+
result = await object_ref
193+
if isinstance(result, RayChannelException):
194+
raise result.exc_value.with_traceback(result.exc_traceback)
195+
return result.message
196+
197+
def _on_completion(future):
198+
self._todo.remove(future)
199+
self._done.put_nowait(future)
200+
201+
future = asyncio.ensure_future(handle_task(message, object_ref))
202+
future.add_done_callback(_on_completion)
203+
self._todo.add(future)
184204

185205
@implements(Channel.send)
186206
async def send(self, message: Any):
187207
if self._closed.is_set(): # pragma: no cover
188208
raise ChannelClosed("Channel already closed, cannot send message")
189-
# Put ray object ref to queue
190-
self._in_queue.put_nowait(
191-
(
192-
message,
193-
self._peer_actor.__on_ray_recv__.remote(
194-
self.channel_id, _ArgWrapper(message)
195-
),
196-
)
209+
# Put ray object ref to todo queue
210+
task = self._peer_actor.__on_ray_recv__.remote(
211+
self.channel_id, _ArgWrapper(message)
197212
)
213+
self._submit_task(message, task)
214+
await asyncio.sleep(0)
198215

199216
@implements(Channel.recv)
200217
async def recv(self):
201218
if self._closed.is_set(): # pragma: no cover
202219
raise ChannelClosed("Channel already closed, cannot recv message")
203220
try:
204-
# Wait on ray object ref
205-
message, object_ref = await self._in_queue.get()
206-
# use `%.500` to avoid print too long messages
207-
with debug_async_timeout(
208-
"ray_object_retrieval_timeout", "Client sent message is %.500s", message
209-
):
210-
result = await object_ref
211-
if isinstance(result, RayChannelException):
212-
raise result.exc_value.with_traceback(result.exc_traceback)
213-
return result.message
221+
# Wait first done.
222+
future = await self._done.get()
223+
return future.result()
214224
except ray.exceptions.RayActorError:
215225
if not self._closed.is_set():
216226
# raise a EOFError as the SocketChannel does
@@ -228,7 +238,7 @@ class RayServerChannel(RayChannelBase):
228238
message's reply.
229239
"""
230240

231-
__slots__ = "_out_queue", "_msg_recv_counter", "_msg_sent_counter"
241+
__slots__ = "_in_queue", "_out_queue", "_msg_recv_counter", "_msg_sent_counter"
232242

233243
def __init__(
234244
self,
@@ -238,6 +248,7 @@ def __init__(
238248
compression=None,
239249
):
240250
super().__init__(local_address, None, channel_index, channel_id, compression)
251+
self._in_queue = asyncio.Queue()
241252
self._out_queue = asyncio.Queue()
242253
self._msg_recv_counter = 0
243254
self._msg_sent_counter = 0

0 commit comments

Comments
 (0)