Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
import java.util.Iterator;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.TimeUnit;

import org.apache.sshd.client.ClientAuthenticationManager;
import org.apache.sshd.client.ClientFactoryManager;
Expand Down Expand Up @@ -235,12 +236,31 @@ ChannelExec createExecChannel(byte[] command, PtyChannelConfigurationHolder ptyC
* error or a non-zero exit status was received. If this happens, then a {@link RemoteException}
* is thrown with a cause of {@link ServerException} containing the remote captured standard
* error - including CR/LF(s)
* @see #executeRemoteCommand(String, OutputStream, Charset)
* @see #executeRemoteCommand(String, Duration)
*/
default String executeRemoteCommand(String command) throws IOException {
return executeRemoteCommand(command, Duration.ZERO);
}

/**
* Execute a command that requires no input and returns its output
*
* @param command The command to execute
* @param timeout Timeout for the remote command execution. Applies to both channel opening and result waiting.
* A zero or negative value means no timeout.
* @return The command's standard output result
* @return The command's standard output result (assumed to be in US-ASCII)
* @throws IOException If failed to execute the command - including if <U>anything</U> was written to the standard
* error or a non-zero exit status was received. If this happens, then a {@link RemoteException}
* is thrown with a cause of {@link ServerException} containing the remote captured standard
* error - including CR/LF(s)
* @see #executeRemoteCommand(String, OutputStream, Charset)
* @see #executeRemoteCommand(String, OutputStream, Charset, Duration)
*/
default String executeRemoteCommand(String command, Duration timeout) throws IOException {
try (ByteArrayOutputStream stderr = new ByteArrayOutputStream()) {
try {
return executeRemoteCommand(command, stderr, StandardCharsets.US_ASCII);
return executeRemoteCommand(command, stderr, StandardCharsets.US_ASCII, timeout);
} finally {
if (stderr.size() > 0) {
String errorMessage = stderr.toString(StandardCharsets.US_ASCII.name());
Expand All @@ -264,15 +284,37 @@ default String executeRemoteCommand(String command) throws IOException {
* was output to the standard error stream, but does check the reported exit status (if any) for
* non-zero value. If non-zero exit status received then a {@link RemoteException} is thrown
* with' a {@link ServerException} cause containing the exits value
* @see #executeRemoteCommand(String, OutputStream, OutputStream, Charset)
* @see #executeRemoteCommand(String, OutputStream, OutputStream, Charset, Duration)
*/
default String executeRemoteCommand(String command, OutputStream stderr, Charset charset) throws IOException {
return executeRemoteCommand(command, stderr, charset, Duration.ZERO);
}

/**
* Execute a command that requires no input and returns its output
*
* @param command The command to execute - without a terminating LF
* @param stderr Standard error output stream - if {@code null} then error stream data is ignored.
* <B>Note:</B> if the stream is not {@code null} then it will be left <U>open</U> when this
* method returns or exception is thrown
* @param charset The command {@link Charset} for input/output/error - if {@code null} then US_ASCII is assumed
* @param timeout Timeout for the remote command execution. Applies to both channel opening and result waiting.
* A zero or negative value means no timeout.
* @return The command's standard output result
* @throws IOException If failed to manage the command channel - <B>Note:</B> the code does not check if anything
* was output to the standard error stream, but does check the reported exit status (if any) for
* non-zero value. If non-zero exit status received then a {@link RemoteException} is thrown
* with' a {@link ServerException} cause containing the exits value
* @see #executeRemoteCommand(String, OutputStream, OutputStream, Charset, Duration)
*/
default String executeRemoteCommand(String command, OutputStream stderr, Charset charset, Duration timeout)
throws IOException {
if (charset == null) {
charset = StandardCharsets.US_ASCII;
}

try (ByteArrayOutputStream stdout = new ByteArrayOutputStream(Byte.MAX_VALUE)) {
executeRemoteCommand(command, stdout, stderr, charset);
executeRemoteCommand(command, stdout, stderr, charset, timeout);
byte[] outBytes = stdout.toByteArray();
return new String(outBytes, charset);
}
Expand All @@ -290,26 +332,70 @@ default String executeRemoteCommand(String command, OutputStream stderr, Charset
* thrown
* @param charset The command {@link Charset} for output/error - if {@code null} then US_ASCII is assumed
* @throws IOException If failed to execute the command or got a non-zero exit status
* @see ClientChannel#validateCommandExitStatusCode(String, Integer) validateCommandExitStatusCode
* @see #executeRemoteCommand(String, OutputStream, OutputStream, Charset, Duration)
*/
default void executeRemoteCommand(
String command, OutputStream stdout, OutputStream stderr, Charset charset)
throws IOException {
executeRemoteCommand(command, stdout, stderr, charset, Duration.ZERO);
}

/**
* Execute a command that requires no input and redirects its STDOUT/STDERR streams to the user-provided ones
*
* @param command The command to execute - without a terminating LF.
* @param stdout Standard output stream - if {@code null} then stream data is ignored. <b>Note:</b> if the
* stream is not {@code null}, it will be left <u>open</u> when this method returns or an
* exception is thrown.
* @param stderr Error output stream - if {@code null} then error stream data is ignored. <b>Note:</b> if the
* stream is not {@code null}, it will be left <u>open</u> when this method returns or an
* exception is thrown.
* @param charset The charset to use for encoding the command and decoding the output/error streams. If
* {@code null}, US-ASCII is assumed.
* @param timeout Timeout for the remote command execution. Applies to both channel opening and result waiting.
* A zero or negative value means no timeout.
* @throws IOException If the command execution fails, times out, or returns a non-zero exit code. A
* {@link RemoteException} may be thrown if the remote side reports an error.
* @see ClientChannel#open()#verify(long, java.util.concurrent.TimeUnit)
* @see ClientChannel#waitFor(Collection, long)
* @see ClientChannel#validateCommandExitStatusCode(String, Integer) validateCommandExitStatusCode
*/
default void executeRemoteCommand(
String command, OutputStream stdout, OutputStream stderr, Charset charset, Duration timeout)
throws IOException {

if (charset == null) {
charset = StandardCharsets.US_ASCII;
}

if (timeout != null && timeout.isNegative()) {
throw new IllegalArgumentException("Timeout must be non-negative");
}

try (OutputStream channelErr = (stderr == null) ? new NullOutputStream() : new NoCloseOutputStream(stderr);
OutputStream channelOut = (stdout == null) ? new NullOutputStream() : new NoCloseOutputStream(stdout);
ClientChannel channel = createExecChannel(command, charset, null, Collections.emptyMap())) {

channel.setOut(channelOut);
channel.setErr(channelErr);
channel.open().await(); // TODO use verify and a configurable timeout

// TODO use a configurable timeout
Collection<ClientChannelEvent> waitMask = channel.waitFor(REMOTE_COMMAND_WAIT_EVENTS, 0L);
long waitTimeoutMillis;
if (timeout != null && !timeout.isZero()) {
long timeoutMillis = timeout.toMillis();
long startTime = System.currentTimeMillis();
channel.open().verify(timeoutMillis, TimeUnit.MILLISECONDS);

long elapsed = System.currentTimeMillis() - startTime;
waitTimeoutMillis = Math.max(1, timeoutMillis - elapsed);
} else {
channel.open().verify(); // wait indefinitely
waitTimeoutMillis = 0L;
}

Collection<ClientChannelEvent> waitMask = channel.waitFor(REMOTE_COMMAND_WAIT_EVENTS, waitTimeoutMillis);
if (waitMask.contains(ClientChannelEvent.TIMEOUT)) {
throw new SocketTimeoutException("Failed to retrieve command result in time: " + command);
throw new SocketTimeoutException(String.format(
"Failed to retrieve command '%s' result within timeout of %s ms", command, timeout));
}

Integer exitStatus = channel.getExitStatus();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,11 @@
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.net.SocketTimeoutException;
import java.nio.charset.StandardCharsets;
import java.rmi.RemoteException;
import java.rmi.ServerException;
import java.time.Duration;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
Expand Down Expand Up @@ -58,12 +60,6 @@
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.TestMethodOrder;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertSame;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.junit.jupiter.api.Assertions.fail;

/**
* @author <a href="mailto:[email protected]">Apache MINA SSHD Project</a>
*/
Expand Down Expand Up @@ -240,6 +236,72 @@ protected boolean handleCommandLine(String command) throws Exception {
assertEquals(Integer.toString(expectedErrorCode), actualErrorMessage, "Mismatched captured error code");
}

@Test
void executeCommandMethodWithConfigurableTimeout() throws Exception {
String expectedCommand = getCurrentTestName() + "-CMD";
String expectedResponse = getCurrentTestName() + "-RSP";
Duration timeout = Duration.ofMillis(10000L);
sshd.setCommandFactory((session, command) -> new CommandExecutionHelper(command) {
private boolean cmdProcessed;

@Override
protected boolean handleCommandLine(String command) throws Exception {
assertEquals(expectedCommand, command, "Mismatched incoming command");
assertFalse(cmdProcessed, "Duplicated command call");
OutputStream stdout = getOutputStream();
Thread.sleep(500L);
stdout.write(expectedResponse.getBytes(StandardCharsets.US_ASCII));
stdout.flush();
cmdProcessed = true;
return false;
}
});

try (ClientSession session = client.connect(getCurrentTestName(), TEST_LOCALHOST, port)
.verify(CONNECT_TIMEOUT)
.getSession()) {
session.addPasswordIdentity(getCurrentTestName());
session.auth().verify(AUTH_TIMEOUT);

// NOTE !!! The LF is only because we are using a buffered reader on the server end to read the command
String actualResponse = session.executeRemoteCommand(expectedCommand + "\n", timeout);
assertEquals(expectedResponse, actualResponse, "Mismatched command response");
}
}

@Test
void exceptionThrownOnExecuteCommandTimeout() throws Exception {
String expectedCommand = getCurrentTestName() + "-CMD";
Duration timeout = Duration.ofMillis(500L);

sshd.setCommandFactory((session, command) -> new CommandExecutionHelper(command) {
private boolean cmdProcessed;

@Override
protected boolean handleCommandLine(String command) throws Exception {
assertEquals(expectedCommand, command, "Mismatched incoming command");
assertFalse(cmdProcessed, "Duplicated command call");
Thread.sleep(timeout.plusMillis(200L).toMillis());
OutputStream stdout = getOutputStream();
stdout.write(command.getBytes(StandardCharsets.US_ASCII));
stdout.flush();
cmdProcessed = true;
return false;
}
});

try (ClientSession session = client.connect(getCurrentTestName(), TEST_LOCALHOST, port)
.verify(CONNECT_TIMEOUT)
.getSession()) {
session.addPasswordIdentity(getCurrentTestName());
session.auth().verify(AUTH_TIMEOUT);

assertThrows(SocketTimeoutException.class, () -> {
session.executeRemoteCommand(expectedCommand + "\n", timeout);
});
}
}

// see SSHD-859
@Test
void connectionContextPropagation() throws Exception {
Expand Down