diff --git a/pkg/cli/clisqlshell/BUILD.bazel b/pkg/cli/clisqlshell/BUILD.bazel index 622e64023d37..8eaaaf14da88 100644 --- a/pkg/cli/clisqlshell/BUILD.bazel +++ b/pkg/cli/clisqlshell/BUILD.bazel @@ -26,6 +26,7 @@ go_library( "//pkg/sql/scanner", "//pkg/sql/sqlfsm", "//pkg/util/envutil", + "//pkg/util/syncutil", "@com_github_cockroachdb_errors//:errors", "@com_github_knz_go_libedit//:go-libedit", ], diff --git a/pkg/cli/clisqlshell/context.go b/pkg/cli/clisqlshell/context.go index 6ac36880c7d3..227901d3d85f 100644 --- a/pkg/cli/clisqlshell/context.go +++ b/pkg/cli/clisqlshell/context.go @@ -15,6 +15,7 @@ import ( "time" democlusterapi "github.com/cockroachdb/cockroach/pkg/cli/democluster/api" + "github.com/cockroachdb/cockroach/pkg/util/syncutil" ) // Context represents the external configuration of the interactive @@ -71,4 +72,11 @@ type internalContext struct { // current database name, if known. This is maintained on a best-effort basis. dbName string + + // state about the current query. + mu struct { + syncutil.Mutex + cancelFn func() + doneCh chan struct{} + } } diff --git a/pkg/cli/clisqlshell/sql.go b/pkg/cli/clisqlshell/sql.go index d6bbd8385c71..a97e1037a41c 100644 --- a/pkg/cli/clisqlshell/sql.go +++ b/pkg/cli/clisqlshell/sql.go @@ -19,6 +19,7 @@ import ( "net/url" "os" "os/exec" + "os/signal" "path/filepath" "regexp" "sort" @@ -1607,10 +1608,11 @@ func (c *cliState) doRunStatements(nextState cliStateEnum) cliStateEnum { } // Now run the statement/query. - c.exitErr = c.sqlExecCtx.RunQueryAndFormatResults( - context.Background(), - c.conn, c.iCtx.stdout, c.iCtx.stderr, - clisqlclient.MakeQuery(c.concatLines)) + c.exitErr = c.runWithInterruptableCtx(func(ctx context.Context) error { + return c.sqlExecCtx.RunQueryAndFormatResults(ctx, + c.conn, c.iCtx.stdout, c.iCtx.stderr, + clisqlclient.MakeQuery(c.concatLines)) + }) if c.exitErr != nil { if !c.singleStatement { clierror.OutputError(c.iCtx.stderr, c.exitErr, true /*showSeverity*/, false /*verbose*/) @@ -1640,10 +1642,11 @@ func (c *cliState) doRunStatements(nextState cliStateEnum) cliStateEnum { if strings.Contains(c.iCtx.autoTrace, "kv") { traceType = "kv" } - if err := c.sqlExecCtx.RunQueryAndFormatResults( - context.Background(), - c.conn, c.iCtx.stdout, c.iCtx.stderr, - clisqlclient.MakeQuery(fmt.Sprintf("SHOW %s TRACE FOR SESSION", traceType))); err != nil { + if err := c.runWithInterruptableCtx(func(ctx context.Context) error { + return c.sqlExecCtx.RunQueryAndFormatResults(ctx, + c.conn, c.iCtx.stdout, c.iCtx.stderr, + clisqlclient.MakeQuery(fmt.Sprintf("SHOW %s TRACE FOR SESSION", traceType))) + }); err != nil { clierror.OutputError(c.iCtx.stderr, err, true /*showSeverity*/, false /*verbose*/) if c.exitErr == nil { // Both the query and SET tracing had encountered no error @@ -1705,6 +1708,9 @@ func NewShell( // RunInteractive implements the Shell interface. func (c *cliState) RunInteractive(cmdIn, cmdOut, cmdErr *os.File) (exitErr error) { + finalFn := c.maybeHandleInterrupt() + defer finalFn() + return c.doRunShell(cliStart, cmdIn, cmdOut, cmdErr) } @@ -1986,3 +1992,85 @@ func (c *cliState) serverSideParse(sql string) (helpText string, err error) { } return "", nil } + +func (c *cliState) maybeHandleInterrupt() func() { + if !c.cliCtx.IsInteractive { + return func() {} + } + intCh := make(chan os.Signal, 1) + signal.Notify(intCh, os.Interrupt) + ctx, cancel := context.WithCancel(context.Background()) + go func() { + for { + select { + case <-intCh: + c.iCtx.mu.Lock() + cancelFn, doneCh := c.iCtx.mu.cancelFn, c.iCtx.mu.doneCh + c.iCtx.mu.Unlock() + if cancelFn == nil { + // No query currently executing; nothing to do. + continue + } + + fmt.Fprintf(c.iCtx.stderr, "\nattempting to cancel query...\n") + // Cancel the query's context, which should make the driver + // send a cancellation message. + cancelFn() + + // Now wait for the shell to process the cancellation. + // + // If it takes too long (e.g. server has become unresponsive, + // or we're connected to a pre-22.1 server which does not + // support cancellation), fall back to the previous behavior + // which is to interrupt the shell altogether. + tooLongTimer := time.After(3 * time.Second) + wait: + for { + select { + case <-doneCh: + break wait + case <-tooLongTimer: + fmt.Fprintln(c.iCtx.stderr, "server does not respond to query cancellation; a second interrupt will stop the shell.") + signal.Reset(os.Interrupt) + } + } + // Re-arm the signal handler. + signal.Notify(intCh, os.Interrupt) + + case <-ctx.Done(): + // Shell is terminating. + return + } + } + }() + return cancel +} + +func (c *cliState) runWithInterruptableCtx(fn func(ctx context.Context) error) error { + if !c.cliCtx.IsInteractive { + return fn(context.Background()) + } + // The cancellation function can be used by the Ctrl+C handler + // to cancel this query. + ctx, cancel := context.WithCancel(context.Background()) + // doneCh will be used on the return path to teach the Ctrl+C + // handler that the query has been cancelled successfully. + doneCh := make(chan struct{}) + defer func() { close(doneCh) }() + + // Inform the Ctrl+C handler that this query is executing. + c.iCtx.mu.Lock() + c.iCtx.mu.cancelFn = cancel + c.iCtx.mu.doneCh = doneCh + c.iCtx.mu.Unlock() + defer func() { + c.iCtx.mu.Lock() + c.iCtx.mu.cancelFn = nil + c.iCtx.mu.doneCh = nil + c.iCtx.mu.Unlock() + }() + + // Now run the query. + err := fn(ctx) + return err +} diff --git a/pkg/cli/interactive_tests/test_interrupt.tcl b/pkg/cli/interactive_tests/test_interrupt.tcl new file mode 100644 index 000000000000..404163830f00 --- /dev/null +++ b/pkg/cli/interactive_tests/test_interrupt.tcl @@ -0,0 +1,108 @@ +#! /usr/bin/env expect -f + +source [file join [file dirname $argv0] common.tcl] + +start_server $argv + +spawn $argv sql +eexpect "defaultdb>" + +start_test "Check that interrupt with a partial line clears the line." +send "asasaa" +interrupt +eexpect "defaultdb>" +end_test + +start_test "Check that interrupt with a multiline clears the current line." +send "select\r" +eexpect " -> " +send "'XXX'" +interrupt +send "'YYY';\r" +eexpect "column" +eexpect "YYY" +eexpect "defaultdb>" +end_test + +start_test "Check that interrupt with an empty line prints an information message." +interrupt +eexpect "terminate input to exit" +eexpect "defaultdb>" +end_test + +start_test "Check that interrupt can cancel a query." +send "select pg_sleep(1000000);\r" +eexpect "\r\n" +sleep 0.4 +interrupt +eexpect "query execution canceled" +eexpect "57014" + +# TODO(knz): we currently need to trigger a reconnection +# before we get a healthy prompt. This will be fixed +# in a later version. +send "\r" +eexpect "defaultdb>" +end_test + +# Quit the SQL client, and open a unix shell. +send_eof +eexpect eof +spawn /bin/bash +set shell2_spawn_id $spawn_id +send "PS1=':''/# '\r" +eexpect ":/# " + +start_test "Check that interactive, non-terminal queries are cancellable without terminating the shell." +send "$argv sql >/dev/null\r" +eexpect "\r\n" +sleep 0.4 +send "select 'A'+3;\r" +eexpect "\r\n" +# Sanity check: we can still process _some_ SQL. +# stderr is not redirected so we still see errors. +eexpect "unsupported binary operator" + +# Now on to check cancellation. +send "select pg_sleep(10000);\r" +sleep 0.4 +interrupt +eexpect "query execution canceled" +eexpect "57014" + +# TODO(knz): we currently need to trigger a reconnection +# before we get a healthy prompt. This will be fixed +# in a later version. +send "\rselect 1;\r" + +# Send another query, expect an error. The shell should +# not have terminated by this point. +send "select 'A'+3;\r" +eexpect "\r\n" +eexpect "unsupported binary operator" + +# Terminate SQL shell, expect unix shell. +send_eof +eexpect ":/# " + +start_test "Check that non-interactive interrupts terminate the SQL shell." +send "cat | $argv sql\r" +eexpect "\r\n" +sleep 0.4 +# Sanity check. +send "select 'XX'||'YY';\r" +eexpect "XXYY" + +# Check what interrupt does. +send "select pg_sleep(10000);\r" +sleep 0.4 +interrupt +# This exits the SQL shell directly. expect unix shell. +eexpect ":/# " + +end_test + +send "exit 0\r" +eexpect eof + +stop_server $argv