diff --git a/sshd-core/src/main/java/org/apache/sshd/client/session/ClientSession.java b/sshd-core/src/main/java/org/apache/sshd/client/session/ClientSession.java
index b72fc80b6..3a413522f 100644
--- a/sshd-core/src/main/java/org/apache/sshd/client/session/ClientSession.java
+++ b/sshd-core/src/main/java/org/apache/sshd/client/session/ClientSession.java
@@ -37,6 +37,7 @@
import java.util.Iterator;
import java.util.Map;
import java.util.Set;
+import java.util.concurrent.TimeUnit;
import org.apache.sshd.client.ClientAuthenticationManager;
import org.apache.sshd.client.ClientFactoryManager;
@@ -235,12 +236,31 @@ ChannelExec createExecChannel(byte[] command, PtyChannelConfigurationHolder ptyC
* error or a non-zero exit status was received. If this happens, then a {@link RemoteException}
* is thrown with a cause of {@link ServerException} containing the remote captured standard
* error - including CR/LF(s)
- * @see #executeRemoteCommand(String, OutputStream, Charset)
+ * @see #executeRemoteCommand(String, Duration)
*/
default String executeRemoteCommand(String command) throws IOException {
+ return executeRemoteCommand(command, Duration.ZERO);
+ }
+
+ /**
+ * Execute a command that requires no input and returns its output
+ *
+ * @param command The command to execute
+ * @param timeout Timeout for the remote command execution. Applies to both channel opening and result waiting.
+ * A zero or negative value means no timeout.
+ * @return The command's standard output result
+ * @return The command's standard output result (assumed to be in US-ASCII)
+ * @throws IOException If failed to execute the command - including if anything was written to the standard
+ * error or a non-zero exit status was received. If this happens, then a {@link RemoteException}
+ * is thrown with a cause of {@link ServerException} containing the remote captured standard
+ * error - including CR/LF(s)
+ * @see #executeRemoteCommand(String, OutputStream, Charset)
+ * @see #executeRemoteCommand(String, OutputStream, Charset, Duration)
+ */
+ default String executeRemoteCommand(String command, Duration timeout) throws IOException {
try (ByteArrayOutputStream stderr = new ByteArrayOutputStream()) {
try {
- return executeRemoteCommand(command, stderr, StandardCharsets.US_ASCII);
+ return executeRemoteCommand(command, stderr, StandardCharsets.US_ASCII, timeout);
} finally {
if (stderr.size() > 0) {
String errorMessage = stderr.toString(StandardCharsets.US_ASCII.name());
@@ -264,15 +284,37 @@ default String executeRemoteCommand(String command) throws IOException {
* was output to the standard error stream, but does check the reported exit status (if any) for
* non-zero value. If non-zero exit status received then a {@link RemoteException} is thrown
* with' a {@link ServerException} cause containing the exits value
- * @see #executeRemoteCommand(String, OutputStream, OutputStream, Charset)
+ * @see #executeRemoteCommand(String, OutputStream, OutputStream, Charset, Duration)
*/
default String executeRemoteCommand(String command, OutputStream stderr, Charset charset) throws IOException {
+ return executeRemoteCommand(command, stderr, charset, Duration.ZERO);
+ }
+
+ /**
+ * Execute a command that requires no input and returns its output
+ *
+ * @param command The command to execute - without a terminating LF
+ * @param stderr Standard error output stream - if {@code null} then error stream data is ignored.
+ * Note: if the stream is not {@code null} then it will be left open when this
+ * method returns or exception is thrown
+ * @param charset The command {@link Charset} for input/output/error - if {@code null} then US_ASCII is assumed
+ * @param timeout Timeout for the remote command execution. Applies to both channel opening and result waiting.
+ * A zero or negative value means no timeout.
+ * @return The command's standard output result
+ * @throws IOException If failed to manage the command channel - Note: the code does not check if anything
+ * was output to the standard error stream, but does check the reported exit status (if any) for
+ * non-zero value. If non-zero exit status received then a {@link RemoteException} is thrown
+ * with' a {@link ServerException} cause containing the exits value
+ * @see #executeRemoteCommand(String, OutputStream, OutputStream, Charset, Duration)
+ */
+ default String executeRemoteCommand(String command, OutputStream stderr, Charset charset, Duration timeout)
+ throws IOException {
if (charset == null) {
charset = StandardCharsets.US_ASCII;
}
try (ByteArrayOutputStream stdout = new ByteArrayOutputStream(Byte.MAX_VALUE)) {
- executeRemoteCommand(command, stdout, stderr, charset);
+ executeRemoteCommand(command, stdout, stderr, charset, timeout);
byte[] outBytes = stdout.toByteArray();
return new String(outBytes, charset);
}
@@ -290,26 +332,70 @@ default String executeRemoteCommand(String command, OutputStream stderr, Charset
* thrown
* @param charset The command {@link Charset} for output/error - if {@code null} then US_ASCII is assumed
* @throws IOException If failed to execute the command or got a non-zero exit status
- * @see ClientChannel#validateCommandExitStatusCode(String, Integer) validateCommandExitStatusCode
+ * @see #executeRemoteCommand(String, OutputStream, OutputStream, Charset, Duration)
*/
default void executeRemoteCommand(
String command, OutputStream stdout, OutputStream stderr, Charset charset)
throws IOException {
+ executeRemoteCommand(command, stdout, stderr, charset, Duration.ZERO);
+ }
+
+ /**
+ * Execute a command that requires no input and redirects its STDOUT/STDERR streams to the user-provided ones
+ *
+ * @param command The command to execute - without a terminating LF.
+ * @param stdout Standard output stream - if {@code null} then stream data is ignored. Note: if the
+ * stream is not {@code null}, it will be left open when this method returns or an
+ * exception is thrown.
+ * @param stderr Error output stream - if {@code null} then error stream data is ignored. Note: if the
+ * stream is not {@code null}, it will be left open when this method returns or an
+ * exception is thrown.
+ * @param charset The charset to use for encoding the command and decoding the output/error streams. If
+ * {@code null}, US-ASCII is assumed.
+ * @param timeout Timeout for the remote command execution. Applies to both channel opening and result waiting.
+ * A zero or negative value means no timeout.
+ * @throws IOException If the command execution fails, times out, or returns a non-zero exit code. A
+ * {@link RemoteException} may be thrown if the remote side reports an error.
+ * @see ClientChannel#open()#verify(long, java.util.concurrent.TimeUnit)
+ * @see ClientChannel#waitFor(Collection, long)
+ * @see ClientChannel#validateCommandExitStatusCode(String, Integer) validateCommandExitStatusCode
+ */
+ default void executeRemoteCommand(
+ String command, OutputStream stdout, OutputStream stderr, Charset charset, Duration timeout)
+ throws IOException {
+
if (charset == null) {
charset = StandardCharsets.US_ASCII;
}
+ if (timeout != null && timeout.isNegative()) {
+ throw new IllegalArgumentException("Timeout must be non-negative");
+ }
+
try (OutputStream channelErr = (stderr == null) ? new NullOutputStream() : new NoCloseOutputStream(stderr);
OutputStream channelOut = (stdout == null) ? new NullOutputStream() : new NoCloseOutputStream(stdout);
ClientChannel channel = createExecChannel(command, charset, null, Collections.emptyMap())) {
+
channel.setOut(channelOut);
channel.setErr(channelErr);
- channel.open().await(); // TODO use verify and a configurable timeout
- // TODO use a configurable timeout
- Collection waitMask = channel.waitFor(REMOTE_COMMAND_WAIT_EVENTS, 0L);
+ long waitTimeoutMillis;
+ if (timeout != null && !timeout.isZero()) {
+ long timeoutMillis = timeout.toMillis();
+ long startTime = System.currentTimeMillis();
+ channel.open().verify(timeoutMillis, TimeUnit.MILLISECONDS);
+
+ long elapsed = System.currentTimeMillis() - startTime;
+ waitTimeoutMillis = Math.max(1, timeoutMillis - elapsed);
+ } else {
+ channel.open().verify(); // wait indefinitely
+ waitTimeoutMillis = 0L;
+ }
+
+ Collection waitMask = channel.waitFor(REMOTE_COMMAND_WAIT_EVENTS, waitTimeoutMillis);
if (waitMask.contains(ClientChannelEvent.TIMEOUT)) {
- throw new SocketTimeoutException("Failed to retrieve command result in time: " + command);
+ throw new SocketTimeoutException(String.format(
+ "Failed to retrieve command '%s' result within timeout of %s ms", command, timeout));
}
Integer exitStatus = channel.getExitStatus();
diff --git a/sshd-core/src/test/java/org/apache/sshd/client/session/ClientSessionTest.java b/sshd-core/src/test/java/org/apache/sshd/client/session/ClientSessionTest.java
index ff728425f..98885e4a6 100644
--- a/sshd-core/src/test/java/org/apache/sshd/client/session/ClientSessionTest.java
+++ b/sshd-core/src/test/java/org/apache/sshd/client/session/ClientSessionTest.java
@@ -23,9 +23,11 @@
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
+import java.net.SocketTimeoutException;
import java.nio.charset.StandardCharsets;
import java.rmi.RemoteException;
import java.rmi.ServerException;
+import java.time.Duration;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
@@ -58,12 +60,6 @@
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.TestMethodOrder;
-import static org.junit.jupiter.api.Assertions.assertEquals;
-import static org.junit.jupiter.api.Assertions.assertFalse;
-import static org.junit.jupiter.api.Assertions.assertSame;
-import static org.junit.jupiter.api.Assertions.assertTrue;
-import static org.junit.jupiter.api.Assertions.fail;
-
/**
* @author Apache MINA SSHD Project
*/
@@ -240,6 +236,72 @@ protected boolean handleCommandLine(String command) throws Exception {
assertEquals(Integer.toString(expectedErrorCode), actualErrorMessage, "Mismatched captured error code");
}
+ @Test
+ void executeCommandMethodWithConfigurableTimeout() throws Exception {
+ String expectedCommand = getCurrentTestName() + "-CMD";
+ String expectedResponse = getCurrentTestName() + "-RSP";
+ Duration timeout = Duration.ofMillis(10000L);
+ sshd.setCommandFactory((session, command) -> new CommandExecutionHelper(command) {
+ private boolean cmdProcessed;
+
+ @Override
+ protected boolean handleCommandLine(String command) throws Exception {
+ assertEquals(expectedCommand, command, "Mismatched incoming command");
+ assertFalse(cmdProcessed, "Duplicated command call");
+ OutputStream stdout = getOutputStream();
+ Thread.sleep(500L);
+ stdout.write(expectedResponse.getBytes(StandardCharsets.US_ASCII));
+ stdout.flush();
+ cmdProcessed = true;
+ return false;
+ }
+ });
+
+ try (ClientSession session = client.connect(getCurrentTestName(), TEST_LOCALHOST, port)
+ .verify(CONNECT_TIMEOUT)
+ .getSession()) {
+ session.addPasswordIdentity(getCurrentTestName());
+ session.auth().verify(AUTH_TIMEOUT);
+
+ // NOTE !!! The LF is only because we are using a buffered reader on the server end to read the command
+ String actualResponse = session.executeRemoteCommand(expectedCommand + "\n", timeout);
+ assertEquals(expectedResponse, actualResponse, "Mismatched command response");
+ }
+ }
+
+ @Test
+ void exceptionThrownOnExecuteCommandTimeout() throws Exception {
+ String expectedCommand = getCurrentTestName() + "-CMD";
+ Duration timeout = Duration.ofMillis(500L);
+
+ sshd.setCommandFactory((session, command) -> new CommandExecutionHelper(command) {
+ private boolean cmdProcessed;
+
+ @Override
+ protected boolean handleCommandLine(String command) throws Exception {
+ assertEquals(expectedCommand, command, "Mismatched incoming command");
+ assertFalse(cmdProcessed, "Duplicated command call");
+ Thread.sleep(timeout.plusMillis(200L).toMillis());
+ OutputStream stdout = getOutputStream();
+ stdout.write(command.getBytes(StandardCharsets.US_ASCII));
+ stdout.flush();
+ cmdProcessed = true;
+ return false;
+ }
+ });
+
+ try (ClientSession session = client.connect(getCurrentTestName(), TEST_LOCALHOST, port)
+ .verify(CONNECT_TIMEOUT)
+ .getSession()) {
+ session.addPasswordIdentity(getCurrentTestName());
+ session.auth().verify(AUTH_TIMEOUT);
+
+ assertThrows(SocketTimeoutException.class, () -> {
+ session.executeRemoteCommand(expectedCommand + "\n", timeout);
+ });
+ }
+ }
+
// see SSHD-859
@Test
void connectionContextPropagation() throws Exception {