@@ -240,7 +240,7 @@ async def udp_sendto(self, host, port, data, answer_cb, local_addr=None):
240240DIRECT = ProxyDirect ()
241241
242242class ProxySimple (ProxyDirect ):
243- def __init__ (self , protos , cipher , users , rule , bind ,
243+ def __init__ (self , jump , protos , cipher , users , rule , bind ,
244244 host_name , port , unix , lbind , sslclient , sslserver ):
245245 super ().__init__ (lbind )
246246 self .protos = protos
@@ -253,7 +253,7 @@ def __init__(self, protos, cipher, users, rule, bind,
253253 self .unix = unix
254254 self .sslclient = sslclient
255255 self .sslserver = sslserver
256- self .jump = None
256+ self .jump = jump
257257 def logtext (self , host , port ):
258258 return f' -> { self .rproto .name + ("+ssl" if self .sslclient else "" )} { self .bind } ' + self .jump .logtext (host , port )
259259 def match_rule (self , host , port ):
@@ -350,69 +350,91 @@ def quic_event_received(s, event):
350350 reader , writer = conn ._create_stream (stream_id )
351351 self .patch_writer (writer )
352352 return reader , writer
353- async def start_server (self , args , stream_handler = stream_handler ):
353+ def start_server (self , args , stream_handler = stream_handler ):
354354 import aioquic .asyncio
355355 def handler (reader , writer ):
356356 self .patch_writer (writer )
357357 asyncio .ensure_future (stream_handler (reader , writer , ** vars (self ), ** args ))
358- server = await aioquic .asyncio .serve (
358+ return aioquic .asyncio .serve (
359359 self .host_name ,
360360 self .port ,
361361 configuration = self .quicserver ,
362362 stream_handler = handler
363363 )
364- return server
365364
366365class ProxySSH (ProxySimple ):
367366 def __init__ (self , ** kw ):
368367 super ().__init__ (** kw )
369- self .streams = None
368+ self .sshconn = None
370369 def logtext (self , host , port ):
371370 return f' -> sshtunnel { self .bind } ' + self .jump .logtext (host , port )
372- async def wait_open_connection (self , * args , tunnel = None ):
373- if self .streams is not None :
374- if not self .streams .done ():
375- await self .streams
376- return self .streams .result ()
377- self .streams = asyncio .get_event_loop ().create_future ()
378- try :
379- import asyncssh
380- except Exception :
381- raise Exception ('Missing library: "pip3 install asyncssh"' )
382- username , password = self .auth .decode ().split (':' , 1 )
383- if password .startswith (':' ):
384- client_keys = [password [1 :]]
385- password = None
371+ def patch_stream (self , ssh_reader , writer , host , port ):
372+ reader = asyncio .StreamReader ()
373+ async def channel ():
374+ while not writer .is_closing ():
375+ buf = await ssh_reader .read (65536 )
376+ if not buf :
377+ break
378+ reader .feed_data (buf )
379+ reader .feed_eof ()
380+ asyncio .ensure_future (channel ())
381+ remote_addr = ('ssh:' + str (host ), port )
382+ writer .get_extra_info = dict (peername = remote_addr , sockname = remote_addr ).get
383+ return reader , writer
384+ async def wait_ssh_connection (self , local_addr = None , family = 0 , tunnel = None ):
385+ if self .sshconn is not None :
386+ if not self .sshconn .done ():
387+ await self .sshconn
386388 else :
387- client_keys = None
388- conn = await asyncssh .connect (host = self .host_name , port = self .port , x509_trusted_certs = None , known_hosts = None , username = username , password = password , client_keys = client_keys , keepalive_interval = 60 , tunnel = tunnel )
389- if not self .streams .done ():
390- self .streams .set_result ((conn , None ))
391- return conn , None
392- async def prepare_ciphers_and_headers (self , reader_remote , writer_remote , host , port ):
393- whost , wport = self .jump .destination (host , port )
389+ self .sshconn = asyncio .get_event_loop ().create_future ()
390+ try :
391+ import asyncssh
392+ except Exception :
393+ raise Exception ('Missing library: "pip3 install asyncssh"' )
394+ username , password = self .auth .decode ().split (':' , 1 )
395+ if password .startswith (':' ):
396+ client_keys = [password [1 :]]
397+ password = None
398+ else :
399+ client_keys = None
400+ conn = await asyncssh .connect (host = self .host_name , port = self .port , local_addr = local_addr , family = family , x509_trusted_certs = None , known_hosts = None , username = username , password = password , client_keys = client_keys , keepalive_interval = 60 , tunnel = tunnel )
401+ self .sshconn .set_result (conn )
402+ async def wait_open_connection (self , host , port , local_addr , family , tunnel = None ):
403+ await self .wait_ssh_connection (local_addr , family , tunnel )
404+ conn = self .sshconn .result ()
394405 if isinstance (self .jump , ProxySSH ):
395- reader_remote , writer_remote = await self .jump .wait_open_connection (tunnel = reader_remote )
406+ reader , writer = await self .jump .wait_open_connection (host , port , None , None , conn )
396407 else :
408+ host , port = self .jump .destination (host , port )
397409 if self .jump .unix :
398- ssh_reader_stream , writer_remote = await reader_remote .open_unix_connection (self .jump .bind )
410+ reader , writer = await conn .open_unix_connection (self .jump .bind )
399411 else :
400- ssh_reader_stream , writer_remote = await reader_remote .open_connection (whost , wport )
401- reader_remote = asyncio .StreamReader ()
402- async def channel ():
403- while not writer_remote .is_closing ():
404- buf = await ssh_reader_stream .read (65536 )
405- if not buf :
406- break
407- reader_remote .feed_data (buf )
408- reader_remote .feed_eof ()
409- asyncio .ensure_future (channel ())
410- return await self .jump .prepare_ciphers_and_headers (reader_remote , writer_remote , host , port )
412+ reader , writer = await conn .open_connection (host , port )
413+ reader , writer = self .patch_stream (reader , writer , host , port )
414+ return reader , writer
415+ async def start_server (self , args , stream_handler = stream_handler , tunnel = None ):
416+ await self .wait_ssh_connection (tunnel = tunnel )
417+ conn = self .sshconn .result ()
418+ if isinstance (self .jump , ProxySSH ):
419+ return await self .jump .start_server (args , stream_handler , conn )
420+ else :
421+ def handler (host , port ):
422+ def handler_stream (reader , writer ):
423+ reader , writer = self .patch_stream (reader , writer , host , port )
424+ return stream_handler (reader , writer , ** vars (self .jump ), ** args )
425+ return handler_stream
426+ if self .jump .unix :
427+ return await conn .start_unix_server (handler , self .jump .bind )
428+ else :
429+ return await conn .start_server (handler , self .jump .host_name , self .jump .port )
411430
412431class ProxyBackward (ProxySimple ):
413432 def __init__ (self , backward , backward_num , ** kw ):
414433 super ().__init__ (** kw )
415434 self .backward = backward
435+ self .server = backward
436+ while type (self .server .jump ) != ProxyDirect :
437+ self .server = self .server .jump
416438 self .backward_num = backward_num
417439 self .closed = False
418440 self .writers = set ()
@@ -430,7 +452,7 @@ def close(self):
430452 except Exception :
431453 pass
432454 async def start_server (self , args , stream_handler = stream_handler ):
433- handler = functools .partial (stream_handler , ** vars (self ), ** args )
455+ handler = functools .partial (stream_handler , ** vars (self . server ), ** args )
434456 for _ in range (self .backward_num ):
435457 asyncio .ensure_future (self .start_server_run (handler ))
436458 return self
@@ -443,7 +465,9 @@ async def start_server_run(self, handler):
443465 if self .closed :
444466 writer .close ()
445467 break
446- writer .write (self .auth or b'\x01 ' )
468+ if isinstance (self .server , ProxyQUIC ):
469+ writer .write (b'\x01 ' )
470+ writer .write (self .server .auth )
447471 self .writers .add (writer )
448472 try :
449473 data = await reader .read_n (1 )
@@ -456,6 +480,7 @@ async def start_server_run(self, handler):
456480 writer .close ()
457481 errwait = 0
458482 self .writers .discard (writer )
483+ writer = None
459484 except Exception as ex :
460485 try :
461486 writer .close ()
@@ -466,11 +491,14 @@ async def start_server_run(self, handler):
466491 errwait = min (errwait * 1.3 + 0.1 , 30 )
467492 def start_backward_client (self , args ):
468493 async def handler (reader , writer , ** kw ):
469- auth = self .auth or b'\x01 '
470- try :
471- assert auth == (await reader .read_n (len (auth )))
472- except Exception :
473- return
494+ auth = self .server .auth
495+ if isinstance (self .server , ProxyQUIC ):
496+ auth = b'\x01 ' + auth
497+ if auth :
498+ try :
499+ assert auth == (await reader .read_n (len (auth )))
500+ except Exception :
501+ return
474502 await self .conn .put ((reader , writer ))
475503 return self .backward .start_server (args , handler )
476504
@@ -484,11 +512,10 @@ def compile_rule(filename):
484512def proxies_by_uri (uri_jumps ):
485513 jump = DIRECT
486514 for uri in reversed (uri_jumps .split ('__' )):
487- proxy = proxy_by_uri (uri )
488- proxy .jump , jump = jump , proxy
515+ jump = proxy_by_uri (uri , jump )
489516 return jump
490517
491- def proxy_by_uri (uri ):
518+ def proxy_by_uri (uri , jump ):
492519 scheme , _ , uri = uri .partition ('://' )
493520 url = urllib .parse .urlparse ('s://' + uri )
494521 rawprotos = [i .lower () for i in scheme .split ('+' )]
@@ -551,7 +578,7 @@ def proxy_by_uri(uri):
551578 if 'direct' in protonames :
552579 return ProxyDirect (lbind = lbind )
553580 else :
554- params = dict (protos = protos , cipher = cipher , users = users , rule = url .query , bind = loc or urlpath ,
581+ params = dict (jump = jump , protos = protos , cipher = cipher , users = users , rule = url .query , bind = loc or urlpath ,
555582 host_name = host_name , port = port , unix = not loc , lbind = lbind , sslclient = sslclient , sslserver = sslserver )
556583 if 'quic' in rawprotos :
557584 proxy = ProxyQUIC (quicserver , quicclient , ** params )
@@ -600,9 +627,9 @@ async def test_url(url, rserver):
600627
601628def main ():
602629 parser = argparse .ArgumentParser (description = __description__ + '\n Supported protocols: http,socks4,socks5,shadowsocks,shadowsocksr,redirect,pf,tunnel' , epilog = f'Online help: <{ __url__ } >' )
603- parser .add_argument ('-l' , dest = 'listen' , default = [], action = 'append' , type = proxy_by_uri , help = 'tcp server uri (default: http+socks4+socks5://:8080/)' )
630+ parser .add_argument ('-l' , dest = 'listen' , default = [], action = 'append' , type = proxies_by_uri , help = 'tcp server uri (default: http+socks4+socks5://:8080/)' )
604631 parser .add_argument ('-r' , dest = 'rserver' , default = [], action = 'append' , type = proxies_by_uri , help = 'tcp remote server uri (default: direct)' )
605- parser .add_argument ('-ul' , dest = 'ulisten' , default = [], action = 'append' , type = proxy_by_uri , help = 'udp server setting uri (default: none)' )
632+ parser .add_argument ('-ul' , dest = 'ulisten' , default = [], action = 'append' , type = proxies_by_uri , help = 'udp server setting uri (default: none)' )
606633 parser .add_argument ('-ur' , dest = 'urserver' , default = [], action = 'append' , type = proxies_by_uri , help = 'udp remote server uri (default: direct)' )
607634 parser .add_argument ('-b' , dest = 'block' , type = compile_rule , help = 'block regex rules' )
608635 parser .add_argument ('-a' , dest = 'alived' , default = 0 , type = int , help = 'interval to check remote alive (default: no check)' )
0 commit comments