@@ -120,7 +120,7 @@ class RayChannelBase(Channel, ABC):
120120 Channel for communications between ray processes.
121121 """
122122
123- __slots__ = "_channel_index" , "_channel_id" , "_in_queue" , " _closed"
123+ __slots__ = "_channel_index" , "_channel_id" , "_closed"
124124
125125 name = "ray"
126126 _channel_index_gen = itertools .count ()
@@ -142,7 +142,6 @@ def __init__(
142142 self ._channel_id = channel_id or ChannelID (
143143 local_address , _gen_client_id (), self ._channel_index , dest_address
144144 )
145- self ._in_queue = asyncio .Queue ()
146145 self ._closed = asyncio .Event ()
147146
148147 @property
@@ -169,7 +168,7 @@ class RayClientChannel(RayChannelBase):
169168 A channel from ray driver/actor to ray actor. Use ray call reply for client channel recv.
170169 """
171170
172- __slots__ = ( "_peer_actor" ,)
171+ __slots__ = "_peer_actor" , "_done" , "_todo"
173172
174173 def __init__ (
175174 self ,
@@ -181,36 +180,47 @@ def __init__(
181180 super ().__init__ (None , dest_address , channel_index , channel_id , compression )
182181 # ray actor should be created with the address as the name.
183182 self ._peer_actor : "ray.actor.ActorHandle" = ray .get_actor (dest_address )
183+ self ._done = asyncio .Queue ()
184+ self ._todo = set ()
185+
186+ def _submit_task (self , message : Any , object_ref : "ray.ObjectRef" ):
187+ async def handle_task (message : Any , object_ref : "ray.ObjectRef" ):
188+ # use `%.500` to avoid print too long messages
189+ with debug_async_timeout (
190+ "ray_object_retrieval_timeout" , "Client sent message is %.500s" , message
191+ ):
192+ result = await object_ref
193+ if isinstance (result , RayChannelException ):
194+ raise result .exc_value .with_traceback (result .exc_traceback )
195+ return result .message
196+
197+ def _on_completion (future ):
198+ self ._todo .remove (future )
199+ self ._done .put_nowait (future )
200+
201+ future = asyncio .ensure_future (handle_task (message , object_ref ))
202+ future .add_done_callback (_on_completion )
203+ self ._todo .add (future )
184204
185205 @implements (Channel .send )
186206 async def send (self , message : Any ):
187207 if self ._closed .is_set (): # pragma: no cover
188208 raise ChannelClosed ("Channel already closed, cannot send message" )
189- # Put ray object ref to queue
190- self ._in_queue .put_nowait (
191- (
192- message ,
193- self ._peer_actor .__on_ray_recv__ .remote (
194- self .channel_id , _ArgWrapper (message )
195- ),
196- )
209+ # Put ray object ref to todo queue
210+ task = self ._peer_actor .__on_ray_recv__ .remote (
211+ self .channel_id , _ArgWrapper (message )
197212 )
213+ self ._submit_task (message , task )
214+ await asyncio .sleep (0 )
198215
199216 @implements (Channel .recv )
200217 async def recv (self ):
201218 if self ._closed .is_set (): # pragma: no cover
202219 raise ChannelClosed ("Channel already closed, cannot recv message" )
203220 try :
204- # Wait on ray object ref
205- message , object_ref = await self ._in_queue .get ()
206- # use `%.500` to avoid print too long messages
207- with debug_async_timeout (
208- "ray_object_retrieval_timeout" , "Client sent message is %.500s" , message
209- ):
210- result = await object_ref
211- if isinstance (result , RayChannelException ):
212- raise result .exc_value .with_traceback (result .exc_traceback )
213- return result .message
221+ # Wait first done.
222+ future = await self ._done .get ()
223+ return future .result ()
214224 except ray .exceptions .RayActorError :
215225 if not self ._closed .is_set ():
216226 # raise a EOFError as the SocketChannel does
@@ -228,7 +238,7 @@ class RayServerChannel(RayChannelBase):
228238 message's reply.
229239 """
230240
231- __slots__ = "_out_queue" , "_msg_recv_counter" , "_msg_sent_counter"
241+ __slots__ = "_in_queue" , " _out_queue" , "_msg_recv_counter" , "_msg_sent_counter"
232242
233243 def __init__ (
234244 self ,
@@ -238,6 +248,7 @@ def __init__(
238248 compression = None ,
239249 ):
240250 super ().__init__ (local_address , None , channel_index , channel_id , compression )
251+ self ._in_queue = asyncio .Queue ()
241252 self ._out_queue = asyncio .Queue ()
242253 self ._msg_recv_counter = 0
243254 self ._msg_sent_counter = 0
0 commit comments