Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGES.md
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
## Bug Fixes

* [GH-650](https://github.com/apache/mina-sshd/issues/650) Use the correct key from a user certificate in server-side pubkey auth
* [GH-663](https://github.com/apache/mina-sshd/issues/663) Fix racy `IoSession` creation
* [GH-664](https://github.com/apache/mina-sshd/issues/664) Skip MAC negotiation if an AEAD cipher was negotiated

## New Features
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -215,8 +215,10 @@ protected void onCompleted(Void result, Object attachment) {

handler.sessionCreated(session);
sessionId = session.getId();
sessions.put(sessionId, session);
future.setSession(session);
IoSession registered = mapSession(session);
if (registered == session) {
future.setSession(session);
}
if (session != future.getSession()) {
session.close(true);
throw new CancellationException();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ public abstract class Nio2Service extends AbstractInnerCloseable implements IoSe
private final AsynchronousChannelGroup group;
private final ExecutorService executor;
private IoServiceEventListener eventListener;
private boolean noMoreSessions;

protected Nio2Service(PropertyResolver propertyResolver, IoHandler handler, AsynchronousChannelGroup group,
ExecutorService resumeTasks) {
Expand Down Expand Up @@ -127,7 +128,7 @@ public void dispose() {
@Override
protected Closeable getInnerCloseable() {
return builder()
.parallel(toString(), sessions.values())
.parallel(toString(), snapshot())
.build();
}

Expand All @@ -140,6 +141,23 @@ public void sessionClosed(Nio2Session session) {
unmapSession(session.getId());
}

private Collection<IoSession> snapshot() {
synchronized (this) {
noMoreSessions = true;
}
return sessions.values();
}

protected IoSession mapSession(IoSession session) {
synchronized (this) {
if (noMoreSessions) {
return null;
}
sessions.put(session.getId(), session);
return session;
}
}

protected void unmapSession(Long sessionId) {
if (sessionId != null) {
IoSession ioSession = sessions.remove(sessionId);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.sshd.common.io;

import java.io.IOException;
import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicReference;

import org.apache.mina.core.buffer.IoBuffer;
import org.apache.mina.core.service.IoHandlerAdapter;
import org.apache.mina.transport.socket.nio.NioSocketAcceptor;
import org.apache.sshd.client.SshClient;
import org.apache.sshd.common.future.SshFutureListener;
import org.apache.sshd.common.util.Readable;
import org.apache.sshd.util.test.BaseTestSupport;
import org.junit.jupiter.api.Test;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
* Tests for low-level connections.
*/
class IoConnectionTest extends BaseTestSupport {

private static final Logger LOG = LoggerFactory.getLogger(IoConnectionTest.class);

@Test
void connectorRace() throws Exception {
CountDownLatch connectionMade = new CountDownLatch(1);
CountDownLatch connectorClosing = new CountDownLatch(1);
CountDownLatch futureTriggered = new CountDownLatch(1);
CountDownLatch ioSessionClosed = new CountDownLatch(1);
AtomicReference<IoSession> session = new AtomicReference<>();
AtomicBoolean connectorIsClosing = new AtomicBoolean();
AtomicBoolean sessionWaited = new AtomicBoolean();

SshClient client = setupTestClient();
IoServiceFactory serviceFactory = DefaultIoServiceFactoryFactory.getDefaultIoServiceFactoryFactoryInstance()
.create(client);
IoConnector connector = serviceFactory.createConnector(new IoHandler() {

@Override
public void sessionCreated(org.apache.sshd.common.io.IoSession session) throws Exception {
connectionMade.countDown();
sessionWaited.set(connectorClosing.await(5, TimeUnit.SECONDS));
}

@Override
public void sessionClosed(org.apache.sshd.common.io.IoSession session) throws Exception {
ioSessionClosed.countDown();
}

@Override
public void exceptionCaught(org.apache.sshd.common.io.IoSession session, Throwable cause) throws Exception {
// Nothing
}

@Override
public void messageReceived(org.apache.sshd.common.io.IoSession session, Readable message) throws Exception {
// Nothing; we're not actually sending or receiving data.
}
});
NioSocketAcceptor acceptor = startEchoServer();
try {
InetSocketAddress connectAddress = new InetSocketAddress(InetAddress.getByName(TEST_LOCALHOST),
acceptor.getLocalAddress().getPort());
IoConnectFuture future = connector.connect(connectAddress, null, null);
connectionMade.await(5, TimeUnit.SECONDS);
connector.close();
connectorClosing.countDown();
future.addListener(new SshFutureListener<IoConnectFuture>() {

@Override
public void operationComplete(IoConnectFuture future) {
session.set(future.getSession());
connectorIsClosing.set(!connector.isOpen());
futureTriggered.countDown();
}
});
assertTrue(futureTriggered.await(5, TimeUnit.SECONDS));
Throwable error = future.getException();
if (error != null) {
LOG.info("{}: Connect future was terminated exceptionally: {} ", getCurrentTestName(), error);
error.printStackTrace();
} else if (future.isCanceled()) {
LOG.info("{}: Connect future was canceled", getCurrentTestName());
}
assertEquals(0, connectionMade.getCount());
assertTrue(sessionWaited.get());
assertNull(session.get());
assertTrue(connectorIsClosing.get());
// Since sessionCreated() was called we also expect sessionClosed() to get called eventually.
assertTrue(ioSessionClosed.await(5, TimeUnit.SECONDS));
} finally {
acceptor.dispose(false);
}
}

private NioSocketAcceptor startEchoServer() throws IOException {
NioSocketAcceptor acceptor = new NioSocketAcceptor();
acceptor.setHandler(new IoHandlerAdapter() {

@Override
public void messageReceived(org.apache.mina.core.session.IoSession session, Object message) throws Exception {
IoBuffer recv = (IoBuffer) message;
IoBuffer sent = IoBuffer.allocate(recv.remaining());
sent.put(recv);
sent.flip();
session.write(sent);
}
});
acceptor.setReuseAddress(true);
acceptor.bind(new InetSocketAddress(0));
return acceptor;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ public void setSession(org.apache.sshd.common.io.IoSession session) {
Throwable t = cf.getException();
if (t != null) {
future.setException(t);
} else if (cf.isCanceled()) {
} else if (cf.isCanceled() || !isOpen()) {
IoSession ioSession = createdSession.getAndSet(null);
CancelFuture cancellation = future.cancel();
if (ioSession != null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,12 +37,10 @@
import io.netty.channel.ChannelInitializer;
import io.netty.channel.ChannelOption;
import io.netty.channel.ChannelPipeline;
import io.netty.channel.group.DefaultChannelGroup;
import io.netty.channel.socket.SocketChannel;
import io.netty.channel.socket.nio.NioServerSocketChannel;
import io.netty.handler.logging.LogLevel;
import io.netty.handler.logging.LoggingHandler;
import io.netty.util.concurrent.GlobalEventExecutor;
import org.apache.sshd.common.future.CloseFuture;
import org.apache.sshd.common.io.IoAcceptor;
import org.apache.sshd.common.io.IoHandler;
Expand All @@ -64,10 +62,9 @@ public class NettyIoAcceptor extends NettyIoService implements IoAcceptor {
protected final Map<SocketAddress, Channel> boundAddresses = new ConcurrentHashMap<>();

public NettyIoAcceptor(NettyIoServiceFactory factory, IoHandler handler) {
super(factory, handler);
super(factory, handler, "sshd-acceptor-channels");

Boolean reuseaddr = CoreModuleProperties.SOCKET_REUSEADDR.getRequired(factory.manager);
channelGroup = new DefaultChannelGroup("sshd-acceptor-channels", GlobalEventExecutor.INSTANCE);
bootstrap.group(factory.eventLoopGroup)
.channel(NioServerSocketChannel.class)
.option(ChannelOption.SO_BACKLOG, CoreModuleProperties.SOCKET_BACKLOG.getRequired(factory.manager))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,10 @@
import io.netty.channel.ChannelInitializer;
import io.netty.channel.ChannelOption;
import io.netty.channel.ChannelPipeline;
import io.netty.channel.group.DefaultChannelGroup;
import io.netty.channel.socket.SocketChannel;
import io.netty.channel.socket.nio.NioSocketChannel;
import io.netty.handler.logging.LogLevel;
import io.netty.handler.logging.LoggingHandler;
import io.netty.util.concurrent.GlobalEventExecutor;
import org.apache.sshd.common.AttributeRepository;
import org.apache.sshd.common.future.CancelFuture;
import org.apache.sshd.common.io.DefaultIoConnectFuture;
Expand All @@ -51,8 +49,7 @@ public class NettyIoConnector extends NettyIoService implements IoConnector {
private static final LoggingHandler LOGGING_TRACE = new LoggingHandler(NettyIoConnector.class, LogLevel.TRACE);

public NettyIoConnector(NettyIoServiceFactory factory, IoHandler handler) {
super(factory, handler);
channelGroup = new DefaultChannelGroup("sshd-connector-channels", GlobalEventExecutor.INSTANCE);
super(factory, handler, "sshd-connector-channels");
}

@Override
Expand Down
41 changes: 38 additions & 3 deletions sshd-netty/src/main/java/org/apache/sshd/netty/NettyIoService.java
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,18 @@

package org.apache.sshd.netty;

import java.util.Collections;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.CancellationException;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicLong;

import io.netty.channel.Channel;
import io.netty.channel.group.ChannelGroup;
import io.netty.channel.group.DefaultChannelGroup;
import io.netty.util.AttributeKey;
import io.netty.util.concurrent.GlobalEventExecutor;
import org.apache.sshd.common.io.IoConnectFuture;
import org.apache.sshd.common.io.IoHandler;
import org.apache.sshd.common.io.IoService;
Expand All @@ -44,16 +49,46 @@ public abstract class NettyIoService extends AbstractCloseable implements IoServ

protected final AtomicLong sessionSeq = new AtomicLong();
protected final Map<Long, IoSession> sessions = new ConcurrentHashMap<>();
protected ChannelGroup channelGroup;
protected final ChannelGroup channelGroup;
protected final NettyIoServiceFactory factory;
protected final IoHandler handler;
private boolean noMoreSessions;

private IoServiceEventListener eventListener;

protected NettyIoService(NettyIoServiceFactory factory, IoHandler handler) {
protected NettyIoService(NettyIoServiceFactory factory, IoHandler handler, String channelGroupName) {
this.factory = Objects.requireNonNull(factory, "No factory instance provided");
this.handler = Objects.requireNonNull(handler, "No I/O handler provied");
this.eventListener = factory.getIoServiceEventListener();
this.channelGroup = new DefaultChannelGroup(Objects.requireNonNull(channelGroupName, "No channel group name"),
GlobalEventExecutor.INSTANCE);
}

@Override
protected void doCloseImmediately() {
synchronized (this) {
noMoreSessions = true;
}
channelGroup.close();
super.doCloseImmediately();
}

protected void registerChannel(Channel channel) throws CancellationException {
synchronized (this) {
if (noMoreSessions) {
throw new CancellationException("NettyIoService closed");
}
channelGroup.add(channel);
}
}

protected void mapSession(IoSession session) throws CancellationException {
synchronized (this) {
if (noMoreSessions) {
throw new CancellationException("NettyIoService closed; cannot register new session");
}
sessions.put(session.getId(), session);
}
}

@Override
Expand All @@ -68,6 +103,6 @@ public void setIoServiceEventListener(IoServiceEventListener listener) {

@Override
public Map<Long, IoSession> getManagedSessions() {
return sessions;
return Collections.unmodifiableMap(sessions);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -241,8 +241,6 @@ protected void doCloseImmediately() {
protected void channelActive(ChannelHandlerContext ctx) throws Exception {
context = ctx;
Channel channel = ctx.channel();
service.channelGroup.add(channel);
service.sessions.put(id, NettyIoSession.this);
prev = context.newPromise().setSuccess();
remoteAddr = channel.remoteAddress();
// If handler.sessionCreated() propagates an exception, we'll have a NettyIoSession without SSH session. We'll
Expand All @@ -254,15 +252,17 @@ protected void channelActive(ChannelHandlerContext ctx) throws Exception {
Attribute<IoConnectFuture> connectFuture = channel.attr(NettyIoService.CONNECT_FUTURE_KEY);
IoConnectFuture future = connectFuture.get();
try {
service.registerChannel(channel);
handler.sessionCreated(NettyIoSession.this);
service.mapSession(this);
if (future != null) {
future.setSession(NettyIoSession.this);
if (future.getSession() != NettyIoSession.this) {
close(true);
}
}
} catch (Throwable e) {
log.warn("channelActive(session={}): could not create SSH session ({}); closing", this, e.getClass().getName(), e);
warn("channelActive(session={}): could not create SSH session ({}); closing", this, e.getClass().getName(), e);
try {
if (future != null) {
future.setException(e);
Expand Down