Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
109 changes: 105 additions & 4 deletions template/server/messaging.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,70 @@ async def connect(self):
name="receive_message",
)

async def reconnect(self, max_retries: int = 5, retry_delay: float = 0.1):
    """Tear down the current WebSocket state and try to connect again.

    The existing socket (if any) is closed, the existing receive task (if
    any) is cancelled, both references are reset to ``None``, and then
    ``self.connect()`` is attempted up to ``max_retries`` times with a fixed
    pause between attempts.

    Args:
        max_retries: Maximum number of ``connect()`` attempts. With a value
            of zero or less no attempt is made and ``False`` is returned.
        retry_delay: Seconds to sleep between failed attempts.

    Returns:
        ``True`` as soon as one ``connect()`` call succeeds, ``False`` when
        every attempt raised.
    """
    logger.info(f"Attempting to reconnect WebSocket {self.context_id}")

    # Close existing connection if any. Closing is best-effort: the socket
    # is being discarded anyway, so a failure here is only logged.
    if self._ws is not None:
        try:
            await self._ws.close()
        except Exception as e:
            logger.warning(f"Error closing existing WebSocket: {e}")

    # Cancel existing receive task if any.
    # This method may be invoked from inside the receive task itself (when
    # the receive loop notices the connection dropped). A task must not
    # cancel and await itself: that leaves a pending cancellation request on
    # the running task and only proceeds by swallowing its own
    # CancelledError. In that situation just drop the reference instead.
    receive_task = self._receive_task
    if (
        receive_task is not None
        and receive_task is not asyncio.current_task()
        and not receive_task.done()
    ):
        receive_task.cancel()
        try:
            await receive_task
        except asyncio.CancelledError:
            pass

    # Reset WebSocket and task references so a failed reconnect leaves the
    # object in a clearly "not connected" state.
    self._ws = None
    self._receive_task = None

    # Attempt to reconnect with fixed delay
    for attempt in range(max_retries):
        try:
            await self.connect()
        except Exception as e:
            if attempt >= max_retries - 1:
                # Last attempt: give up and report failure to the caller.
                logger.error(
                    f"Failed to reconnect WebSocket {self.context_id} after {max_retries} attempts: {e}"
                )
                return False
            logger.warning(
                f"Reconnection attempt {attempt + 1} failed: {e}. Retrying in {retry_delay}s..."
            )
            await asyncio.sleep(retry_delay)
        else:
            logger.info(
                f"Successfully reconnected WebSocket {self.context_id} on attempt {attempt + 1}"
            )
            return True

    # Only reached when max_retries <= 0 (the loop body never ran).
    return False

def is_connected(self) -> bool:
    """Report whether both the socket and its receive task look alive.

    Returns ``True`` only when a WebSocket object is present and its
    ``closed`` attribute is falsy, and a receive task is present and has not
    finished. Any other combination yields ``False``.
    """
    socket = self._ws
    if socket is None or socket.closed:
        return False

    receiver = self._receive_task
    if receiver is None:
        return False

    return not receiver.done()

async def ensure_connected(self):
    """Make sure the WebSocket is usable, reconnecting when it is not.

    Does nothing when ``is_connected()`` already reports a healthy
    connection. Otherwise logs a warning and awaits ``reconnect()`` with its
    default settings.

    Raises:
        Exception: If ``reconnect()`` reports failure.
    """
    if self.is_connected():
        return

    logger.warning(
        f"WebSocket {self.context_id} is not connected, attempting to reconnect"
    )
    reconnected = await self.reconnect()
    if reconnected:
        return

    raise Exception(f"Failed to reconnect WebSocket {self.context_id}")

def _get_execute_request(
self, msg_id: str, code: Union[str, StrictStr], background: bool
) -> str:
Expand Down Expand Up @@ -209,11 +273,15 @@ async def _cleanup_env_vars(self, env_vars: Dict[StrictStr, str]):
cleanup_code = self._reset_env_vars_code(env_vars)
if cleanup_code:
logger.info(f"Cleaning up env vars: {cleanup_code}")
# Ensure WebSocket is connected before sending cleanup request
await self.ensure_connected()
request = self._get_execute_request(message_id, cleanup_code, True)
if self._ws is None:
raise Exception("WebSocket not connected")
await self._ws.send(request)

async for item in self._wait_for_result(message_id):
if item["type"] == "error":
if isinstance(item, dict) and item.get("type") == "error":
logger.error(f"Error during env var cleanup: {item}")
finally:
del self._executions[message_id]
Expand Down Expand Up @@ -242,6 +310,10 @@ async def change_current_directory(
):
message_id = str(uuid.uuid4())
self._executions[message_id] = Execution(in_background=True)

# Ensure WebSocket is connected before changing directory
await self.ensure_connected()

if language == "python":
request = self._get_execute_request(message_id, f"%cd {path}", True)
elif language == "deno":
Expand All @@ -262,10 +334,13 @@ async def change_current_directory(
else:
return

if self._ws is None:
raise Exception("WebSocket not connected")

await self._ws.send(request)

async for item in self._wait_for_result(message_id):
if item["type"] == "error":
if isinstance(item, dict) and item.get("type") == "error":
raise ExecutionError(f"Error during execution: {item}")

async def execute(
Expand All @@ -277,8 +352,8 @@ async def execute(
message_id = str(uuid.uuid4())
self._executions[message_id] = Execution()

if self._ws is None:
raise Exception("WebSocket not connected")
# Ensure WebSocket is connected before executing
await self.ensure_connected()

async with self._lock:
# Wait for any pending cleanup task to complete
Expand Down Expand Up @@ -319,6 +394,8 @@ async def execute(
request = self._get_execute_request(message_id, complete_code, False)

# Send the code for execution
if self._ws is None:
raise Exception("WebSocket not connected")
await self._ws.send(request)

# Stream the results
Expand All @@ -344,6 +421,30 @@ async def _receive_message(self):
except Exception as e:
logger.error(f"WebSocket received error while receiving messages: {str(e)}")

# Attempt to reconnect when connection drops
logger.info("Attempting to reconnect due to connection loss...")
reconnect_success = await self.reconnect()

if reconnect_success:
logger.info("Successfully reconnected after connection loss")
# Continue receiving messages with the new connection
try:
async for message in self._ws:
await self._process_message(json.loads(message))
except Exception as reconnect_e:
logger.error(f"Error in reconnected WebSocket: {str(reconnect_e)}")

# Mark all pending executions as failed due to connection loss
for execution in self._executions.values():
await execution.queue.put(
Error(
name="ConnectionLost",
value="WebSocket connection was lost during execution",
traceback="",
)
)
await execution.queue.put(UnexpectedEndOfExecution())

async def _process_message(self, data: dict):
"""
Process messages from the WebSocket
Expand Down
Loading