@@ -41,7 +41,6 @@ use libp2p::core::{upgrade, InboundUpgrade, OutboundUpgrade, UpgradeInfo};
4141use log:: { error, warn} ;
4242use sc_network_common:: protocol:: ProtocolName ;
4343use std:: {
44- convert:: Infallible ,
4544 io, mem,
4645 pin:: Pin ,
4746 task:: { Context , Poll } ,
@@ -221,10 +220,7 @@ where
221220
222221 /// Equivalent to `Stream::poll_next`, except that it only drives the handshake and is
223222 /// guaranteed to not generate any notification.
224- pub fn poll_process (
225- self : Pin < & mut Self > ,
226- cx : & mut Context ,
227- ) -> Poll < Result < Infallible , io:: Error > > {
223+ pub fn poll_process ( self : Pin < & mut Self > , cx : & mut Context ) -> Poll < Result < ( ) , io:: Error > > {
228224 let mut this = self . project ( ) ;
229225
230226 loop {
@@ -246,8 +242,10 @@ where
246242 } ,
247243 NotificationsInSubstreamHandshake :: Flush => {
248244 match Sink :: poll_flush ( this. socket . as_mut ( ) , cx) ? {
249- Poll :: Ready ( ( ) ) =>
250- * this. handshake = NotificationsInSubstreamHandshake :: Sent ,
245+ Poll :: Ready ( ( ) ) => {
246+ * this. handshake = NotificationsInSubstreamHandshake :: Sent ;
247+ return Poll :: Ready ( Ok ( ( ) ) )
248+ } ,
251249 Poll :: Pending => {
252250 * this. handshake = NotificationsInSubstreamHandshake :: Flush ;
253251 return Poll :: Pending
@@ -260,7 +258,7 @@ where
260258 st @ NotificationsInSubstreamHandshake :: ClosingInResponseToRemote |
261259 st @ NotificationsInSubstreamHandshake :: BothSidesClosed => {
262260 * this. handshake = st;
263- return Poll :: Pending
261+ return Poll :: Ready ( Ok ( ( ) ) )
264262 } ,
265263 }
266264 }
@@ -443,6 +441,21 @@ where
443441
444442 fn poll_flush ( self : Pin < & mut Self > , cx : & mut Context ) -> Poll < Result < ( ) , Self :: Error > > {
445443 let mut this = self . project ( ) ;
444+
445+ // `Sink::poll_flush` does not expose stream closed error until we write something into
446+ // the stream, so the code below makes sure we detect that the substream was closed
447+ // even if we don't write anything into it.
448+ match Stream :: poll_next ( this. socket . as_mut ( ) , cx) {
449+ Poll :: Pending => { } ,
450+ Poll :: Ready ( Some ( _) ) => {
451+ error ! (
452+ target: "sub-libp2p" ,
453+ "Unexpected incoming data in `NotificationsOutSubstream`" ,
454+ ) ;
455+ } ,
456+ Poll :: Ready ( None ) => return Poll :: Ready ( Err ( NotificationsOutError :: Terminated ) ) ,
457+ }
458+
446459 Sink :: poll_flush ( this. socket . as_mut ( ) , cx) . map_err ( NotificationsOutError :: Io )
447460 }
448461
@@ -492,13 +505,21 @@ pub enum NotificationsOutError {
492505 /// I/O error on the substream.
493506 #[ error( transparent) ]
494507 Io ( #[ from] io:: Error ) ,
508+
509+ /// End of incoming data detected on out substream.
510+ #[ error( "substream was closed/reset" ) ]
511+ Terminated ,
495512}
496513
497514#[ cfg( test) ]
498515mod tests {
499- use super :: { NotificationsIn , NotificationsInOpen , NotificationsOut , NotificationsOutOpen } ;
500- use futures:: { channel:: oneshot, prelude:: * } ;
516+ use super :: {
517+ NotificationsIn , NotificationsInOpen , NotificationsOut , NotificationsOutError ,
518+ NotificationsOutOpen ,
519+ } ;
520+ use futures:: { channel:: oneshot, future, prelude:: * } ;
501521 use libp2p:: core:: upgrade;
522+ use std:: { pin:: Pin , task:: Poll } ;
502523 use tokio:: net:: { TcpListener , TcpStream } ;
503524 use tokio_util:: compat:: TokioAsyncReadCompatExt ;
504525
@@ -691,4 +712,95 @@ mod tests {
691712
692713 client. await . unwrap ( ) ;
693714 }
715+
716+ #[ tokio:: test]
717+ async fn send_handshake_without_polling_for_incoming_data ( ) {
718+ const PROTO_NAME : & str = "/test/proto/1" ;
719+ let ( listener_addr_tx, listener_addr_rx) = oneshot:: channel ( ) ;
720+
721+ let client = tokio:: spawn ( async move {
722+ let socket = TcpStream :: connect ( listener_addr_rx. await . unwrap ( ) ) . await . unwrap ( ) ;
723+ let NotificationsOutOpen { handshake, .. } = upgrade:: apply_outbound (
724+ socket. compat ( ) ,
725+ NotificationsOut :: new ( PROTO_NAME , Vec :: new ( ) , & b"initial message" [ ..] , 1024 * 1024 ) ,
726+ upgrade:: Version :: V1 ,
727+ )
728+ . await
729+ . unwrap ( ) ;
730+
731+ assert_eq ! ( handshake, b"hello world" ) ;
732+ } ) ;
733+
734+ let listener = TcpListener :: bind ( "127.0.0.1:0" ) . await . unwrap ( ) ;
735+ listener_addr_tx. send ( listener. local_addr ( ) . unwrap ( ) ) . unwrap ( ) ;
736+
737+ let ( socket, _) = listener. accept ( ) . await . unwrap ( ) ;
738+ let NotificationsInOpen { handshake, mut substream, .. } = upgrade:: apply_inbound (
739+ socket. compat ( ) ,
740+ NotificationsIn :: new ( PROTO_NAME , Vec :: new ( ) , 1024 * 1024 ) ,
741+ )
742+ . await
743+ . unwrap ( ) ;
744+
745+ assert_eq ! ( handshake, b"initial message" ) ;
746+ substream. send_handshake ( & b"hello world" [ ..] ) ;
747+
748+ // Actually send the handshake.
749+ future:: poll_fn ( |cx| Pin :: new ( & mut substream) . poll_process ( cx) ) . await . unwrap ( ) ;
750+
751+ client. await . unwrap ( ) ;
752+ }
753+
754+ #[ tokio:: test]
755+ async fn can_detect_dropped_out_substream_without_writing_data ( ) {
756+ const PROTO_NAME : & str = "/test/proto/1" ;
757+ let ( listener_addr_tx, listener_addr_rx) = oneshot:: channel ( ) ;
758+
759+ let client = tokio:: spawn ( async move {
760+ let socket = TcpStream :: connect ( listener_addr_rx. await . unwrap ( ) ) . await . unwrap ( ) ;
761+ let NotificationsOutOpen { handshake, mut substream, .. } = upgrade:: apply_outbound (
762+ socket. compat ( ) ,
763+ NotificationsOut :: new ( PROTO_NAME , Vec :: new ( ) , & b"initial message" [ ..] , 1024 * 1024 ) ,
764+ upgrade:: Version :: V1 ,
765+ )
766+ . await
767+ . unwrap ( ) ;
768+
769+ assert_eq ! ( handshake, b"hello world" ) ;
770+
771+ future:: poll_fn ( |cx| match Pin :: new ( & mut substream) . poll_flush ( cx) {
772+ Poll :: Pending => Poll :: Pending ,
773+ Poll :: Ready ( Ok ( ( ) ) ) => {
774+ cx. waker ( ) . wake_by_ref ( ) ;
775+ Poll :: Pending
776+ } ,
777+ Poll :: Ready ( Err ( e) ) => {
778+ assert ! ( matches!( e, NotificationsOutError :: Terminated ) ) ;
779+ Poll :: Ready ( ( ) )
780+ } ,
781+ } )
782+ . await ;
783+ } ) ;
784+
785+ let listener = TcpListener :: bind ( "127.0.0.1:0" ) . await . unwrap ( ) ;
786+ listener_addr_tx. send ( listener. local_addr ( ) . unwrap ( ) ) . unwrap ( ) ;
787+
788+ let ( socket, _) = listener. accept ( ) . await . unwrap ( ) ;
789+ let NotificationsInOpen { handshake, mut substream, .. } = upgrade:: apply_inbound (
790+ socket. compat ( ) ,
791+ NotificationsIn :: new ( PROTO_NAME , Vec :: new ( ) , 1024 * 1024 ) ,
792+ )
793+ . await
794+ . unwrap ( ) ;
795+
796+ assert_eq ! ( handshake, b"initial message" ) ;
797+
798+ // Send the handhsake.
799+ substream. send_handshake ( & b"hello world" [ ..] ) ;
800+ future:: poll_fn ( |cx| Pin :: new ( & mut substream) . poll_process ( cx) ) . await . unwrap ( ) ;
801+
802+ drop ( substream) ;
803+
804+ client. await . unwrap ( ) ;
805+ }
694806}
0 commit comments