From 44c41bc507ac2f46fb169a1fea18b95660aa6360 Mon Sep 17 00:00:00 2001 From: Mish <10400064+mishushakov@users.noreply.github.com> Date: Fri, 19 Sep 2025 23:51:57 +0200 Subject: [PATCH 1/2] pre-flight check to ensure web socket is connected and reconnect --- template/server/messaging.py | 87 ++++++++++++++++++++++++++++++++++-- 1 file changed, 83 insertions(+), 4 deletions(-) diff --git a/template/server/messaging.py b/template/server/messaging.py index e541351..42c4adf 100644 --- a/template/server/messaging.py +++ b/template/server/messaging.py @@ -80,6 +80,62 @@ async def connect(self): name="receive_message", ) + async def reconnect(self, max_retries: int = 5, retry_delay: float = 0.1): + """Reconnect the WebSocket if it's disconnected with retry logic.""" + logger.info(f"Attempting to reconnect WebSocket {self.context_id}") + + # Close existing connection if any + if self._ws is not None: + try: + await self._ws.close() + except Exception as e: + logger.warning(f"Error closing existing WebSocket: {e}") + + # Cancel existing receive task if any + if self._receive_task is not None and not self._receive_task.done(): + self._receive_task.cancel() + try: + await self._receive_task + except asyncio.CancelledError: + pass + + # Reset WebSocket and task references + self._ws = None + self._receive_task = None + + # Attempt to reconnect with fixed delay + for attempt in range(max_retries): + try: + await self.connect() + logger.info(f"Successfully reconnected WebSocket {self.context_id} on attempt {attempt + 1}") + return True + except Exception as e: + if attempt < max_retries - 1: + logger.warning(f"Reconnection attempt {attempt + 1} failed: {e}. Retrying in {retry_delay}s...") + await asyncio.sleep(retry_delay) + else: + logger.error(f"Failed to reconnect WebSocket {self.context_id} after {max_retries} attempts: {e}") + return False + + return False + + def is_connected(self) -> bool: + """Check if the WebSocket is connected and healthy.""" + return ( + self._ws is not None + and not self._ws.closed + and self._receive_task is not None + and not self._receive_task.done() + ) + + async def ensure_connected(self): + """Ensure WebSocket is connected, reconnect if necessary.""" + if not self.is_connected(): + logger.warning(f"WebSocket {self.context_id} is not connected, attempting to reconnect") + success = await self.reconnect() + if not success: + raise Exception(f"Failed to reconnect WebSocket {self.context_id}") + def _get_execute_request( self, msg_id: str, code: Union[str, StrictStr], background: bool ) -> str: @@ -209,11 +265,15 @@ async def _cleanup_env_vars(self, env_vars: Dict[StrictStr, str]): cleanup_code = self._reset_env_vars_code(env_vars) if cleanup_code: logger.info(f"Cleaning up env vars: {cleanup_code}") + # Ensure WebSocket is connected before sending cleanup request + await self.ensure_connected() request = self._get_execute_request(message_id, cleanup_code, True) + if self._ws is None: + raise Exception("WebSocket not connected") await self._ws.send(request) async for item in self._wait_for_result(message_id): - if item["type"] == "error": + if isinstance(item, dict) and item.get("type") == "error": logger.error(f"Error during env var cleanup: {item}") finally: del self._executions[message_id] @@ -242,6 +302,10 @@ async def change_current_directory( ): message_id = str(uuid.uuid4()) self._executions[message_id] = Execution(in_background=True) + + # Ensure WebSocket is connected before changing directory + await self.ensure_connected() + if language == "python": request = self._get_execute_request(message_id, f"%cd {path}", True) elif language == "deno": @@ -262,10 +326,13 @@ async def change_current_directory( else: return + if self._ws is None: + raise Exception("WebSocket not connected") + await self._ws.send(request) async for item in self._wait_for_result(message_id): - if item["type"] == "error": + if isinstance(item, dict) and item.get("type") == "error": raise ExecutionError(f"Error during execution: {item}") async def execute( @@ -277,8 +344,8 @@ async def execute( message_id = str(uuid.uuid4()) self._executions[message_id] = Execution() - if self._ws is None: - raise Exception("WebSocket not connected") + # Ensure WebSocket is connected before executing + await self.ensure_connected() async with self._lock: # Wait for any pending cleanup task to complete @@ -319,6 +386,8 @@ async def execute( request = self._get_execute_request(message_id, complete_code, False) # Send the code for execution + if self._ws is None: + raise Exception("WebSocket not connected") await self._ws.send(request) # Stream the results @@ -343,6 +412,16 @@ async def _receive_message(self): await self._process_message(json.loads(message)) except Exception as e: logger.error(f"WebSocket received error while receiving messages: {str(e)}") + # Mark all pending executions as failed due to connection loss + for execution in self._executions.values(): + await execution.queue.put( + Error( + name="ConnectionLost", + value="WebSocket connection was lost during execution", + traceback="", + ) + ) + await execution.queue.put(UnexpectedEndOfExecution()) async def _process_message(self, data: dict): """ From 21ecc8aeb332e11b3fd617c10bef8d27a233f8bc Mon Sep 17 00:00:00 2001 From: Mish <10400064+mishushakov@users.noreply.github.com> Date: Fri, 19 Sep 2025 23:59:12 +0200 Subject: [PATCH 2/2] reconnect on receive --- template/server/messaging.py | 52 +++++++++++++++++++++++++----------- 1 file changed, 37 insertions(+), 15 deletions(-) diff --git a/template/server/messaging.py b/template/server/messaging.py index 42c4adf..1be4f2d 100644 --- a/template/server/messaging.py +++ b/template/server/messaging.py @@ -83,14 +83,14 @@ async def connect(self): async def reconnect(self, max_retries: int = 5, retry_delay: float = 0.1): """Reconnect the WebSocket if it's disconnected with retry logic.""" logger.info(f"Attempting to reconnect WebSocket {self.context_id}") - + # Close existing connection if any if self._ws is not None: try: await self._ws.close() except Exception as e: logger.warning(f"Error closing existing WebSocket: {e}") - + # Cancel existing receive task if any if self._receive_task is not None and not self._receive_task.done(): self._receive_task.cancel() @@ -98,40 +98,48 @@ async def reconnect(self, max_retries: int = 5, retry_delay: float = 0.1): await self._receive_task except asyncio.CancelledError: pass - + # Reset WebSocket and task references self._ws = None self._receive_task = None - + # Attempt to reconnect with fixed delay for attempt in range(max_retries): try: await self.connect() - logger.info(f"Successfully reconnected WebSocket {self.context_id} on attempt {attempt + 1}") + logger.info( + f"Successfully reconnected WebSocket {self.context_id} on attempt {attempt + 1}" + ) return True except Exception as e: if attempt < max_retries - 1: - logger.warning(f"Reconnection attempt {attempt + 1} failed: {e}. Retrying in {retry_delay}s...") + logger.warning( + f"Reconnection attempt {attempt + 1} failed: {e}. Retrying in {retry_delay}s..." + ) await asyncio.sleep(retry_delay) else: - logger.error(f"Failed to reconnect WebSocket {self.context_id} after {max_retries} attempts: {e}") + logger.error( + f"Failed to reconnect WebSocket {self.context_id} after {max_retries} attempts: {e}" + ) return False - + return False def is_connected(self) -> bool: """Check if the WebSocket is connected and healthy.""" return ( - self._ws is not None - and not self._ws.closed - and self._receive_task is not None + self._ws is not None + and not self._ws.closed + and self._receive_task is not None and not self._receive_task.done() ) async def ensure_connected(self): """Ensure WebSocket is connected, reconnect if necessary.""" if not self.is_connected(): - logger.warning(f"WebSocket {self.context_id} is not connected, attempting to reconnect") + logger.warning( + f"WebSocket {self.context_id} is not connected, attempting to reconnect" + ) success = await self.reconnect() if not success: raise Exception(f"Failed to reconnect WebSocket {self.context_id}") @@ -302,10 +310,10 @@ async def change_current_directory( ): message_id = str(uuid.uuid4()) self._executions[message_id] = Execution(in_background=True) - + # Ensure WebSocket is connected before changing directory await self.ensure_connected() - + if language == "python": request = self._get_execute_request(message_id, f"%cd {path}", True) elif language == "deno": @@ -328,7 +336,7 @@ async def change_current_directory( if self._ws is None: raise Exception("WebSocket not connected") - + await self._ws.send(request) async for item in self._wait_for_result(message_id): @@ -412,6 +420,20 @@ async def _receive_message(self): await self._process_message(json.loads(message)) except Exception as e: logger.error(f"WebSocket received error while receiving messages: {str(e)}") + + # Attempt to reconnect when connection drops + logger.info("Attempting to reconnect due to connection loss...") + reconnect_success = await self.reconnect() + + if reconnect_success: + logger.info("Successfully reconnected after connection loss") + # Continue receiving messages with the new connection + try: + async for message in self._ws: + await self._process_message(json.loads(message)) + except Exception as reconnect_e: + logger.error(f"Error in reconnected WebSocket: {str(reconnect_e)}") + # Mark all pending executions as failed due to connection loss for execution in self._executions.values(): await execution.queue.put(