diff --git a/sshd-core/src/main/java/org/apache/sshd/client/session/ClientSession.java b/sshd-core/src/main/java/org/apache/sshd/client/session/ClientSession.java index b72fc80b6..3a413522f 100644 --- a/sshd-core/src/main/java/org/apache/sshd/client/session/ClientSession.java +++ b/sshd-core/src/main/java/org/apache/sshd/client/session/ClientSession.java @@ -37,6 +37,7 @@ import java.util.Iterator; import java.util.Map; import java.util.Set; +import java.util.concurrent.TimeUnit; import org.apache.sshd.client.ClientAuthenticationManager; import org.apache.sshd.client.ClientFactoryManager; @@ -235,12 +236,31 @@ ChannelExec createExecChannel(byte[] command, PtyChannelConfigurationHolder ptyC * error or a non-zero exit status was received. If this happens, then a {@link RemoteException} * is thrown with a cause of {@link ServerException} containing the remote captured standard * error - including CR/LF(s) - * @see #executeRemoteCommand(String, OutputStream, Charset) + * @see #executeRemoteCommand(String, Duration) */ default String executeRemoteCommand(String command) throws IOException { + return executeRemoteCommand(command, Duration.ZERO); + } + + /** + * Execute a command that requires no input and returns its output + * + * @param command The command to execute + * @param timeout Timeout for the remote command execution. Applies to both channel opening and result waiting. + * A zero or negative value means no timeout. + * @return The command's standard output result + * @return The command's standard output result (assumed to be in US-ASCII) + * @throws IOException If failed to execute the command - including if anything was written to the standard + * error or a non-zero exit status was received. If this happens, then a {@link RemoteException} + * is thrown with a cause of {@link ServerException} containing the remote captured standard + * error - including CR/LF(s) + * @see #executeRemoteCommand(String, OutputStream, Charset) + * @see #executeRemoteCommand(String, OutputStream, Charset, Duration) + */ + default String executeRemoteCommand(String command, Duration timeout) throws IOException { try (ByteArrayOutputStream stderr = new ByteArrayOutputStream()) { try { - return executeRemoteCommand(command, stderr, StandardCharsets.US_ASCII); + return executeRemoteCommand(command, stderr, StandardCharsets.US_ASCII, timeout); } finally { if (stderr.size() > 0) { String errorMessage = stderr.toString(StandardCharsets.US_ASCII.name()); @@ -264,15 +284,37 @@ default String executeRemoteCommand(String command) throws IOException { * was output to the standard error stream, but does check the reported exit status (if any) for * non-zero value. If non-zero exit status received then a {@link RemoteException} is thrown * with' a {@link ServerException} cause containing the exits value - * @see #executeRemoteCommand(String, OutputStream, OutputStream, Charset) + * @see #executeRemoteCommand(String, OutputStream, OutputStream, Charset, Duration) */ default String executeRemoteCommand(String command, OutputStream stderr, Charset charset) throws IOException { + return executeRemoteCommand(command, stderr, charset, Duration.ZERO); + } + + /** + * Execute a command that requires no input and returns its output + * + * @param command The command to execute - without a terminating LF + * @param stderr Standard error output stream - if {@code null} then error stream data is ignored. + * Note: if the stream is not {@code null} then it will be left open when this + * method returns or exception is thrown + * @param charset The command {@link Charset} for input/output/error - if {@code null} then US_ASCII is assumed + * @param timeout Timeout for the remote command execution. Applies to both channel opening and result waiting. + * A zero or negative value means no timeout. + * @return The command's standard output result + * @throws IOException If failed to manage the command channel - Note: the code does not check if anything + * was output to the standard error stream, but does check the reported exit status (if any) for + * non-zero value. If non-zero exit status received then a {@link RemoteException} is thrown + * with' a {@link ServerException} cause containing the exits value + * @see #executeRemoteCommand(String, OutputStream, OutputStream, Charset, Duration) + */ + default String executeRemoteCommand(String command, OutputStream stderr, Charset charset, Duration timeout) + throws IOException { if (charset == null) { charset = StandardCharsets.US_ASCII; } try (ByteArrayOutputStream stdout = new ByteArrayOutputStream(Byte.MAX_VALUE)) { - executeRemoteCommand(command, stdout, stderr, charset); + executeRemoteCommand(command, stdout, stderr, charset, timeout); byte[] outBytes = stdout.toByteArray(); return new String(outBytes, charset); } @@ -290,26 +332,70 @@ default String executeRemoteCommand(String command, OutputStream stderr, Charset * thrown * @param charset The command {@link Charset} for output/error - if {@code null} then US_ASCII is assumed * @throws IOException If failed to execute the command or got a non-zero exit status - * @see ClientChannel#validateCommandExitStatusCode(String, Integer) validateCommandExitStatusCode + * @see #executeRemoteCommand(String, OutputStream, OutputStream, Charset, Duration) */ default void executeRemoteCommand( String command, OutputStream stdout, OutputStream stderr, Charset charset) throws IOException { + executeRemoteCommand(command, stdout, stderr, charset, Duration.ZERO); + } + + /** + * Execute a command that requires no input and redirects its STDOUT/STDERR streams to the user-provided ones + * + * @param command The command to execute - without a terminating LF. + * @param stdout Standard output stream - if {@code null} then stream data is ignored. Note: if the + * stream is not {@code null}, it will be left open when this method returns or an + * exception is thrown. + * @param stderr Error output stream - if {@code null} then error stream data is ignored. Note: if the + * stream is not {@code null}, it will be left open when this method returns or an + * exception is thrown. + * @param charset The charset to use for encoding the command and decoding the output/error streams. If + * {@code null}, US-ASCII is assumed. + * @param timeout Timeout for the remote command execution. Applies to both channel opening and result waiting. + * A zero or negative value means no timeout. + * @throws IOException If the command execution fails, times out, or returns a non-zero exit code. A + * {@link RemoteException} may be thrown if the remote side reports an error. + * @see ClientChannel#open()#verify(long, java.util.concurrent.TimeUnit) + * @see ClientChannel#waitFor(Collection, long) + * @see ClientChannel#validateCommandExitStatusCode(String, Integer) validateCommandExitStatusCode + */ + default void executeRemoteCommand( + String command, OutputStream stdout, OutputStream stderr, Charset charset, Duration timeout) + throws IOException { + if (charset == null) { charset = StandardCharsets.US_ASCII; } + if (timeout != null && timeout.isNegative()) { + throw new IllegalArgumentException("Timeout must be non-negative"); + } + try (OutputStream channelErr = (stderr == null) ? new NullOutputStream() : new NoCloseOutputStream(stderr); OutputStream channelOut = (stdout == null) ? new NullOutputStream() : new NoCloseOutputStream(stdout); ClientChannel channel = createExecChannel(command, charset, null, Collections.emptyMap())) { + channel.setOut(channelOut); channel.setErr(channelErr); - channel.open().await(); // TODO use verify and a configurable timeout - // TODO use a configurable timeout - Collection waitMask = channel.waitFor(REMOTE_COMMAND_WAIT_EVENTS, 0L); + long waitTimeoutMillis; + if (timeout != null && !timeout.isZero()) { + long timeoutMillis = timeout.toMillis(); + long startTime = System.currentTimeMillis(); + channel.open().verify(timeoutMillis, TimeUnit.MILLISECONDS); + + long elapsed = System.currentTimeMillis() - startTime; + waitTimeoutMillis = Math.max(1, timeoutMillis - elapsed); + } else { + channel.open().verify(); // wait indefinitely + waitTimeoutMillis = 0L; + } + + Collection waitMask = channel.waitFor(REMOTE_COMMAND_WAIT_EVENTS, waitTimeoutMillis); if (waitMask.contains(ClientChannelEvent.TIMEOUT)) { - throw new SocketTimeoutException("Failed to retrieve command result in time: " + command); + throw new SocketTimeoutException(String.format( + "Failed to retrieve command '%s' result within timeout of %s ms", command, timeout)); } Integer exitStatus = channel.getExitStatus(); diff --git a/sshd-core/src/test/java/org/apache/sshd/client/session/ClientSessionTest.java b/sshd-core/src/test/java/org/apache/sshd/client/session/ClientSessionTest.java index ff728425f..98885e4a6 100644 --- a/sshd-core/src/test/java/org/apache/sshd/client/session/ClientSessionTest.java +++ b/sshd-core/src/test/java/org/apache/sshd/client/session/ClientSessionTest.java @@ -23,9 +23,11 @@ import java.io.IOException; import java.io.InputStream; import java.io.OutputStream; +import java.net.SocketTimeoutException; import java.nio.charset.StandardCharsets; import java.rmi.RemoteException; import java.rmi.ServerException; +import java.time.Duration; import java.util.ArrayList; import java.util.Arrays; import java.util.Collection; @@ -58,12 +60,6 @@ import org.junit.jupiter.api.Test; import org.junit.jupiter.api.TestMethodOrder; -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertFalse; -import static org.junit.jupiter.api.Assertions.assertSame; -import static org.junit.jupiter.api.Assertions.assertTrue; -import static org.junit.jupiter.api.Assertions.fail; - /** * @author Apache MINA SSHD Project */ @@ -240,6 +236,72 @@ protected boolean handleCommandLine(String command) throws Exception { assertEquals(Integer.toString(expectedErrorCode), actualErrorMessage, "Mismatched captured error code"); } + @Test + void executeCommandMethodWithConfigurableTimeout() throws Exception { + String expectedCommand = getCurrentTestName() + "-CMD"; + String expectedResponse = getCurrentTestName() + "-RSP"; + Duration timeout = Duration.ofMillis(10000L); + sshd.setCommandFactory((session, command) -> new CommandExecutionHelper(command) { + private boolean cmdProcessed; + + @Override + protected boolean handleCommandLine(String command) throws Exception { + assertEquals(expectedCommand, command, "Mismatched incoming command"); + assertFalse(cmdProcessed, "Duplicated command call"); + OutputStream stdout = getOutputStream(); + Thread.sleep(500L); + stdout.write(expectedResponse.getBytes(StandardCharsets.US_ASCII)); + stdout.flush(); + cmdProcessed = true; + return false; + } + }); + + try (ClientSession session = client.connect(getCurrentTestName(), TEST_LOCALHOST, port) + .verify(CONNECT_TIMEOUT) + .getSession()) { + session.addPasswordIdentity(getCurrentTestName()); + session.auth().verify(AUTH_TIMEOUT); + + // NOTE !!! The LF is only because we are using a buffered reader on the server end to read the command + String actualResponse = session.executeRemoteCommand(expectedCommand + "\n", timeout); + assertEquals(expectedResponse, actualResponse, "Mismatched command response"); + } + } + + @Test + void exceptionThrownOnExecuteCommandTimeout() throws Exception { + String expectedCommand = getCurrentTestName() + "-CMD"; + Duration timeout = Duration.ofMillis(500L); + + sshd.setCommandFactory((session, command) -> new CommandExecutionHelper(command) { + private boolean cmdProcessed; + + @Override + protected boolean handleCommandLine(String command) throws Exception { + assertEquals(expectedCommand, command, "Mismatched incoming command"); + assertFalse(cmdProcessed, "Duplicated command call"); + Thread.sleep(timeout.plusMillis(200L).toMillis()); + OutputStream stdout = getOutputStream(); + stdout.write(command.getBytes(StandardCharsets.US_ASCII)); + stdout.flush(); + cmdProcessed = true; + return false; + } + }); + + try (ClientSession session = client.connect(getCurrentTestName(), TEST_LOCALHOST, port) + .verify(CONNECT_TIMEOUT) + .getSession()) { + session.addPasswordIdentity(getCurrentTestName()); + session.auth().verify(AUTH_TIMEOUT); + + assertThrows(SocketTimeoutException.class, () -> { + session.executeRemoteCommand(expectedCommand + "\n", timeout); + }); + } + } + // see SSHD-859 @Test void connectionContextPropagation() throws Exception {