diff --git a/CHANGES.md b/CHANGES.md index 5cfec3e9f..798f553a8 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -29,6 +29,7 @@ ## Bug Fixes * [GH-650](https://github.com/apache/mina-sshd/issues/650) Use the correct key from a user certificate in server-side pubkey auth +* [GH-663](https://github.com/apache/mina-sshd/issues/663) Fix racy `IoSession` creation * [GH-664](https://github.com/apache/mina-sshd/issues/664) Skip MAC negotiation if an AEAD cipher was negotiated ## New Features diff --git a/sshd-core/src/main/java/org/apache/sshd/common/io/nio2/Nio2Connector.java b/sshd-core/src/main/java/org/apache/sshd/common/io/nio2/Nio2Connector.java index 84941b68c..33d98a067 100644 --- a/sshd-core/src/main/java/org/apache/sshd/common/io/nio2/Nio2Connector.java +++ b/sshd-core/src/main/java/org/apache/sshd/common/io/nio2/Nio2Connector.java @@ -215,8 +215,10 @@ protected void onCompleted(Void result, Object attachment) { handler.sessionCreated(session); sessionId = session.getId(); - sessions.put(sessionId, session); - future.setSession(session); + IoSession registered = mapSession(session); + if (registered == session) { + future.setSession(session); + } if (session != future.getSession()) { session.close(true); throw new CancellationException(); diff --git a/sshd-core/src/main/java/org/apache/sshd/common/io/nio2/Nio2Service.java b/sshd-core/src/main/java/org/apache/sshd/common/io/nio2/Nio2Service.java index 82f869fe3..ebdabc2d8 100644 --- a/sshd-core/src/main/java/org/apache/sshd/common/io/nio2/Nio2Service.java +++ b/sshd-core/src/main/java/org/apache/sshd/common/io/nio2/Nio2Service.java @@ -72,6 +72,7 @@ public abstract class Nio2Service extends AbstractInnerCloseable implements IoSe private final AsynchronousChannelGroup group; private final ExecutorService executor; private IoServiceEventListener eventListener; + private boolean noMoreSessions; protected Nio2Service(PropertyResolver propertyResolver, IoHandler handler, AsynchronousChannelGroup group, ExecutorService resumeTasks) { @@ -127,7 +128,7 @@ public void dispose() { @Override protected Closeable getInnerCloseable() { return builder() - .parallel(toString(), sessions.values()) + .parallel(toString(), snapshot()) .build(); } @@ -140,6 +141,23 @@ public void sessionClosed(Nio2Session session) { unmapSession(session.getId()); } + private Collection snapshot() { + synchronized (this) { + noMoreSessions = true; + } + return sessions.values(); + } + + protected IoSession mapSession(IoSession session) { + synchronized (this) { + if (noMoreSessions) { + return null; + } + sessions.put(session.getId(), session); + return session; + } + } + protected void unmapSession(Long sessionId) { if (sessionId != null) { IoSession ioSession = sessions.remove(sessionId); diff --git a/sshd-core/src/test/java/org/apache/sshd/common/io/IoConnectionTest.java b/sshd-core/src/test/java/org/apache/sshd/common/io/IoConnectionTest.java new file mode 100644 index 000000000..e47307e14 --- /dev/null +++ b/sshd-core/src/test/java/org/apache/sshd/common/io/IoConnectionTest.java @@ -0,0 +1,136 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.sshd.common.io; + +import java.io.IOException; +import java.net.InetAddress; +import java.net.InetSocketAddress; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicReference; + +import org.apache.mina.core.buffer.IoBuffer; +import org.apache.mina.core.service.IoHandlerAdapter; +import org.apache.mina.transport.socket.nio.NioSocketAcceptor; +import org.apache.sshd.client.SshClient; +import org.apache.sshd.common.future.SshFutureListener; +import org.apache.sshd.common.util.Readable; +import org.apache.sshd.util.test.BaseTestSupport; +import org.junit.jupiter.api.Test; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Tests for low-level connections. + */ +class IoConnectionTest extends BaseTestSupport { + + private static final Logger LOG = LoggerFactory.getLogger(IoConnectionTest.class); + + @Test + void connectorRace() throws Exception { + CountDownLatch connectionMade = new CountDownLatch(1); + CountDownLatch connectorClosing = new CountDownLatch(1); + CountDownLatch futureTriggered = new CountDownLatch(1); + CountDownLatch ioSessionClosed = new CountDownLatch(1); + AtomicReference session = new AtomicReference<>(); + AtomicBoolean connectorIsClosing = new AtomicBoolean(); + AtomicBoolean sessionWaited = new AtomicBoolean(); + + SshClient client = setupTestClient(); + IoServiceFactory serviceFactory = DefaultIoServiceFactoryFactory.getDefaultIoServiceFactoryFactoryInstance() + .create(client); + IoConnector connector = serviceFactory.createConnector(new IoHandler() { + + @Override + public void sessionCreated(org.apache.sshd.common.io.IoSession session) throws Exception { + connectionMade.countDown(); + sessionWaited.set(connectorClosing.await(5, TimeUnit.SECONDS)); + } + + @Override + public void sessionClosed(org.apache.sshd.common.io.IoSession session) throws Exception { + ioSessionClosed.countDown(); + } + + @Override + public void exceptionCaught(org.apache.sshd.common.io.IoSession session, Throwable cause) throws Exception { + // Nothing + } + + @Override + public void messageReceived(org.apache.sshd.common.io.IoSession session, Readable message) throws Exception { + // Nothing; we're not actually sending or receiving data. + } + }); + NioSocketAcceptor acceptor = startEchoServer(); + try { + InetSocketAddress connectAddress = new InetSocketAddress(InetAddress.getByName(TEST_LOCALHOST), + acceptor.getLocalAddress().getPort()); + IoConnectFuture future = connector.connect(connectAddress, null, null); + connectionMade.await(5, TimeUnit.SECONDS); + connector.close(); + connectorClosing.countDown(); + future.addListener(new SshFutureListener() { + + @Override + public void operationComplete(IoConnectFuture future) { + session.set(future.getSession()); + connectorIsClosing.set(!connector.isOpen()); + futureTriggered.countDown(); + } + }); + assertTrue(futureTriggered.await(5, TimeUnit.SECONDS)); + Throwable error = future.getException(); + if (error != null) { + LOG.info("{}: Connect future was terminated exceptionally: {} ", getCurrentTestName(), error); + error.printStackTrace(); + } else if (future.isCanceled()) { + LOG.info("{}: Connect future was canceled", getCurrentTestName()); + } + assertEquals(0, connectionMade.getCount()); + assertTrue(sessionWaited.get()); + assertNull(session.get()); + assertTrue(connectorIsClosing.get()); + // Since sessionCreated() was called we also expect sessionClosed() to get called eventually. + assertTrue(ioSessionClosed.await(5, TimeUnit.SECONDS)); + } finally { + acceptor.dispose(false); + } + } + + private NioSocketAcceptor startEchoServer() throws IOException { + NioSocketAcceptor acceptor = new NioSocketAcceptor(); + acceptor.setHandler(new IoHandlerAdapter() { + + @Override + public void messageReceived(org.apache.mina.core.session.IoSession session, Object message) throws Exception { + IoBuffer recv = (IoBuffer) message; + IoBuffer sent = IoBuffer.allocate(recv.remaining()); + sent.put(recv); + sent.flip(); + session.write(sent); + } + }); + acceptor.setReuseAddress(true); + acceptor.bind(new InetSocketAddress(0)); + return acceptor; + } +} diff --git a/sshd-mina/src/main/java/org/apache/sshd/mina/MinaConnector.java b/sshd-mina/src/main/java/org/apache/sshd/mina/MinaConnector.java index 8ae10d870..cd218e18d 100644 --- a/sshd-mina/src/main/java/org/apache/sshd/mina/MinaConnector.java +++ b/sshd-mina/src/main/java/org/apache/sshd/mina/MinaConnector.java @@ -187,7 +187,7 @@ public void setSession(org.apache.sshd.common.io.IoSession session) { Throwable t = cf.getException(); if (t != null) { future.setException(t); - } else if (cf.isCanceled()) { + } else if (cf.isCanceled() || !isOpen()) { IoSession ioSession = createdSession.getAndSet(null); CancelFuture cancellation = future.cancel(); if (ioSession != null) { diff --git a/sshd-netty/src/main/java/org/apache/sshd/netty/NettyIoAcceptor.java b/sshd-netty/src/main/java/org/apache/sshd/netty/NettyIoAcceptor.java index b4f68ee65..4543f44ae 100644 --- a/sshd-netty/src/main/java/org/apache/sshd/netty/NettyIoAcceptor.java +++ b/sshd-netty/src/main/java/org/apache/sshd/netty/NettyIoAcceptor.java @@ -37,12 +37,10 @@ import io.netty.channel.ChannelInitializer; import io.netty.channel.ChannelOption; import io.netty.channel.ChannelPipeline; -import io.netty.channel.group.DefaultChannelGroup; import io.netty.channel.socket.SocketChannel; import io.netty.channel.socket.nio.NioServerSocketChannel; import io.netty.handler.logging.LogLevel; import io.netty.handler.logging.LoggingHandler; -import io.netty.util.concurrent.GlobalEventExecutor; import org.apache.sshd.common.future.CloseFuture; import org.apache.sshd.common.io.IoAcceptor; import org.apache.sshd.common.io.IoHandler; @@ -64,10 +62,9 @@ public class NettyIoAcceptor extends NettyIoService implements IoAcceptor { protected final Map boundAddresses = new ConcurrentHashMap<>(); public NettyIoAcceptor(NettyIoServiceFactory factory, IoHandler handler) { - super(factory, handler); + super(factory, handler, "sshd-acceptor-channels"); Boolean reuseaddr = CoreModuleProperties.SOCKET_REUSEADDR.getRequired(factory.manager); - channelGroup = new DefaultChannelGroup("sshd-acceptor-channels", GlobalEventExecutor.INSTANCE); bootstrap.group(factory.eventLoopGroup) .channel(NioServerSocketChannel.class) .option(ChannelOption.SO_BACKLOG, CoreModuleProperties.SOCKET_BACKLOG.getRequired(factory.manager)) diff --git a/sshd-netty/src/main/java/org/apache/sshd/netty/NettyIoConnector.java b/sshd-netty/src/main/java/org/apache/sshd/netty/NettyIoConnector.java index dd1f7b447..9d8572c0a 100644 --- a/sshd-netty/src/main/java/org/apache/sshd/netty/NettyIoConnector.java +++ b/sshd-netty/src/main/java/org/apache/sshd/netty/NettyIoConnector.java @@ -26,12 +26,10 @@ import io.netty.channel.ChannelInitializer; import io.netty.channel.ChannelOption; import io.netty.channel.ChannelPipeline; -import io.netty.channel.group.DefaultChannelGroup; import io.netty.channel.socket.SocketChannel; import io.netty.channel.socket.nio.NioSocketChannel; import io.netty.handler.logging.LogLevel; import io.netty.handler.logging.LoggingHandler; -import io.netty.util.concurrent.GlobalEventExecutor; import org.apache.sshd.common.AttributeRepository; import org.apache.sshd.common.future.CancelFuture; import org.apache.sshd.common.io.DefaultIoConnectFuture; @@ -51,8 +49,7 @@ public class NettyIoConnector extends NettyIoService implements IoConnector { private static final LoggingHandler LOGGING_TRACE = new LoggingHandler(NettyIoConnector.class, LogLevel.TRACE); public NettyIoConnector(NettyIoServiceFactory factory, IoHandler handler) { - super(factory, handler); - channelGroup = new DefaultChannelGroup("sshd-connector-channels", GlobalEventExecutor.INSTANCE); + super(factory, handler, "sshd-connector-channels"); } @Override diff --git a/sshd-netty/src/main/java/org/apache/sshd/netty/NettyIoService.java b/sshd-netty/src/main/java/org/apache/sshd/netty/NettyIoService.java index ebf0245d3..f8acd2302 100644 --- a/sshd-netty/src/main/java/org/apache/sshd/netty/NettyIoService.java +++ b/sshd-netty/src/main/java/org/apache/sshd/netty/NettyIoService.java @@ -19,13 +19,18 @@ package org.apache.sshd.netty; +import java.util.Collections; import java.util.Map; import java.util.Objects; +import java.util.concurrent.CancellationException; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.atomic.AtomicLong; +import io.netty.channel.Channel; import io.netty.channel.group.ChannelGroup; +import io.netty.channel.group.DefaultChannelGroup; import io.netty.util.AttributeKey; +import io.netty.util.concurrent.GlobalEventExecutor; import org.apache.sshd.common.io.IoConnectFuture; import org.apache.sshd.common.io.IoHandler; import org.apache.sshd.common.io.IoService; @@ -44,16 +49,46 @@ public abstract class NettyIoService extends AbstractCloseable implements IoServ protected final AtomicLong sessionSeq = new AtomicLong(); protected final Map sessions = new ConcurrentHashMap<>(); - protected ChannelGroup channelGroup; + protected final ChannelGroup channelGroup; protected final NettyIoServiceFactory factory; protected final IoHandler handler; + private boolean noMoreSessions; private IoServiceEventListener eventListener; - protected NettyIoService(NettyIoServiceFactory factory, IoHandler handler) { + protected NettyIoService(NettyIoServiceFactory factory, IoHandler handler, String channelGroupName) { this.factory = Objects.requireNonNull(factory, "No factory instance provided"); this.handler = Objects.requireNonNull(handler, "No I/O handler provied"); this.eventListener = factory.getIoServiceEventListener(); + this.channelGroup = new DefaultChannelGroup(Objects.requireNonNull(channelGroupName, "No channel group name"), + GlobalEventExecutor.INSTANCE); + } + + @Override + protected void doCloseImmediately() { + synchronized (this) { + noMoreSessions = true; + } + channelGroup.close(); + super.doCloseImmediately(); + } + + protected void registerChannel(Channel channel) throws CancellationException { + synchronized (this) { + if (noMoreSessions) { + throw new CancellationException("NettyIoService closed"); + } + channelGroup.add(channel); + } + } + + protected void mapSession(IoSession session) throws CancellationException { + synchronized (this) { + if (noMoreSessions) { + throw new CancellationException("NettyIoService closed; cannot register new session"); + } + sessions.put(session.getId(), session); + } } @Override @@ -68,6 +103,6 @@ public void setIoServiceEventListener(IoServiceEventListener listener) { @Override public Map getManagedSessions() { - return sessions; + return Collections.unmodifiableMap(sessions); } } diff --git a/sshd-netty/src/main/java/org/apache/sshd/netty/NettyIoSession.java b/sshd-netty/src/main/java/org/apache/sshd/netty/NettyIoSession.java index 02c90ed56..2ac593ec0 100644 --- a/sshd-netty/src/main/java/org/apache/sshd/netty/NettyIoSession.java +++ b/sshd-netty/src/main/java/org/apache/sshd/netty/NettyIoSession.java @@ -241,8 +241,6 @@ protected void doCloseImmediately() { protected void channelActive(ChannelHandlerContext ctx) throws Exception { context = ctx; Channel channel = ctx.channel(); - service.channelGroup.add(channel); - service.sessions.put(id, NettyIoSession.this); prev = context.newPromise().setSuccess(); remoteAddr = channel.remoteAddress(); // If handler.sessionCreated() propagates an exception, we'll have a NettyIoSession without SSH session. We'll @@ -254,7 +252,9 @@ protected void channelActive(ChannelHandlerContext ctx) throws Exception { Attribute connectFuture = channel.attr(NettyIoService.CONNECT_FUTURE_KEY); IoConnectFuture future = connectFuture.get(); try { + service.registerChannel(channel); handler.sessionCreated(NettyIoSession.this); + service.mapSession(this); if (future != null) { future.setSession(NettyIoSession.this); if (future.getSession() != NettyIoSession.this) { @@ -262,7 +262,7 @@ protected void channelActive(ChannelHandlerContext ctx) throws Exception { } } } catch (Throwable e) { - log.warn("channelActive(session={}): could not create SSH session ({}); closing", this, e.getClass().getName(), e); + warn("channelActive(session={}): could not create SSH session ({}); closing", this, e.getClass().getName(), e); try { if (future != null) { future.setException(e);