@@ -615,7 +615,7 @@ async def channel():
615615 writer .get_extra_info = dict (peername = remote_addr , sockname = remote_addr ).get
616616 return reader , writer
617617 async def wait_ssh_connection (self , local_addr = None , family = 0 , tunnel = None ):
618- if self .sshconn is not None :
618+ if self .sshconn is not None and not self . sshconn . cancelled () :
619619 if not self .sshconn .done ():
620620 await self .sshconn
621621 else :
@@ -633,18 +633,24 @@ async def wait_ssh_connection(self, local_addr=None, family=0, tunnel=None):
633633 conn = await asyncssh .connect (host = self .host_name , port = self .port , local_addr = local_addr , family = family , x509_trusted_certs = None , known_hosts = None , username = username , password = password , client_keys = client_keys , keepalive_interval = 60 , tunnel = tunnel )
634634 self .sshconn .set_result (conn )
635635 async def wait_open_connection (self , host , port , local_addr , family , tunnel = None ):
636- await self .wait_ssh_connection (local_addr , family , tunnel )
637- conn = self .sshconn .result ()
638- if isinstance (self .jump , ProxySSH ):
639- reader , writer = await self .jump .wait_open_connection (host , port , None , None , conn )
640- else :
641- host , port = self .jump .destination (host , port )
642- if self .jump .unix :
643- reader , writer = await conn .open_unix_connection (self .jump .bind )
636+ try :
637+ await self .wait_ssh_connection (local_addr , family , tunnel )
638+ conn = self .sshconn .result ()
639+ if isinstance (self .jump , ProxySSH ):
640+ reader , writer = await self .jump .wait_open_connection (host , port , None , None , conn )
644641 else :
645- reader , writer = await conn .open_connection (host , port )
646- reader , writer = self .patch_stream (reader , writer , host , port )
647- return reader , writer
642+ host , port = self .jump .destination (host , port )
643+ if self .jump .unix :
644+ reader , writer = await conn .open_unix_connection (self .jump .bind )
645+ else :
646+ reader , writer = await conn .open_connection (host , port )
647+ reader , writer = self .patch_stream (reader , writer , host , port )
648+ return reader , writer
649+ except Exception as ex :
650+ if not self .sshconn .done ():
651+ self .sshconn .set_exception (ex )
652+ self .sshconn = None
653+ raise
648654 async def start_server (self , args , stream_handler = stream_handler , tunnel = None ):
649655 if type (self .jump ) is ProxyDirect :
650656 raise Exception ('ssh server mode unsupported' )
0 commit comments