From 8d24bda61aa07a1c143393e485ca2518d33556c4 Mon Sep 17 00:00:00 2001 From: Konrad Hanff Date: Wed, 21 Jul 2021 09:12:44 +0200 Subject: [PATCH] Send SIGQUIT when cancelling an execution When the context passed to Nomad Allocation Exec is cancelled, the process is not terminated. Instead, just the WebSocket connection is closed. In order to terminate long-running processes, a special character is injected into the standard input stream. This character is parsed by the tty line discipline (tty has to be true). The line discipline sends a SIGQUIT signal to the process, terminating it and producing a core dump (in a file called 'core'). The SIGQUIT signal can be caught but isn't by default, which is why the runner is destroyed if the program does not terminate during a grace period after the signal was sent. --- internal/api/websocket.go | 34 +++++- internal/api/websocket_test.go | 23 ++-- internal/nomad/nomad.go | 4 +- internal/nomad/nomad_test.go | 10 +- internal/runner/runner.go | 105 ++++++++++++++---- internal/runner/runner_mock.go | 22 +++- internal/runner/runner_test.go | 62 +++++++++-- pkg/nullio/nullio.go | 24 ++++ .../nullio_test.go} | 6 +- pkg/nullreader/nullreader.go | 10 -- 10 files changed, 237 insertions(+), 63 deletions(-) create mode 100644 pkg/nullio/nullio.go rename pkg/{nullreader/nullreader_test.go => nullio/nullio_test.go} (81%) delete mode 100644 pkg/nullreader/nullreader.go diff --git a/internal/api/websocket.go b/internal/api/websocket.go index cad2a8b..c5d2e1b 100644 --- a/internal/api/websocket.go +++ b/internal/api/websocket.go @@ -24,8 +24,11 @@ type webSocketConnection interface { SetCloseHandler(handler func(code int, text string) error) } +// WebSocketReader is an interface that is intended for providing abstraction around reading from a WebSocket. +// Besides io.Reader, it also implements io.Writer. The Write method is used to inject data into the WebSocket stream. type WebSocketReader interface { io.Reader + io.Writer startReadInputLoop() context.CancelFunc } @@ -38,12 +41,17 @@ type codeOceanToRawReader struct { // and retrieve it when Read(..) is called. Since channels are thread-safe, we use one here // instead of bytes.Buffer. buffer chan byte + // The priorityBuffer is a buffer for injecting data into stdin of the execution from Poseidon, + // for example the character that causes the tty to generate a SIGQUIT signal. + // It is always read before the regular buffer. + priorityBuffer chan byte } func newCodeOceanToRawReader(connection webSocketConnection) *codeOceanToRawReader { return &codeOceanToRawReader{ - connection: connection, - buffer: make(chan byte, CodeOceanToRawReaderBufferSize), + connection: connection, + buffer: make(chan byte, CodeOceanToRawReaderBufferSize), + priorityBuffer: make(chan byte, CodeOceanToRawReaderBufferSize), } } @@ -101,16 +109,20 @@ func (cr *codeOceanToRawReader) startReadInputLoop() context.CancelFunc { } // Read implements the io.Reader interface. -// It returns bytes from the buffer. +// It returns bytes from the buffer or priorityBuffer. func (cr *codeOceanToRawReader) Read(p []byte) (int, error) { if len(p) == 0 { return 0, nil } // Ensure to not return until at least one byte has been read to avoid busy waiting. - p[0] = <-cr.buffer + select { + case p[0] = <-cr.priorityBuffer: + case p[0] = <-cr.buffer: + } var n int for n = 1; n < len(p); n++ { select { + case p[n] = <-cr.priorityBuffer: case p[n] = <-cr.buffer: default: return n, nil @@ -119,6 +131,20 @@ func (cr *codeOceanToRawReader) Read(p []byte) (int, error) { return n, nil } +// Write implements the io.Writer interface. +// Data written to a codeOceanToRawReader using this method is returned by Read before other data from the WebSocket. +func (cr *codeOceanToRawReader) Write(p []byte) (n int, err error) { + var c byte + for n, c = range p { + select { + case cr.priorityBuffer <- c: + default: + break + } + } + return n, nil +} + // rawToCodeOceanWriter is an io.Writer implementation that, when written to, wraps the written data in the appropriate // json structure and sends it to the CodeOcean via WebSocket. type rawToCodeOceanWriter struct { diff --git a/internal/api/websocket_test.go b/internal/api/websocket_test.go index cecdfd7..7111d39 100644 --- a/internal/api/websocket_test.go +++ b/internal/api/websocket_test.go @@ -223,7 +223,7 @@ func (s *WebSocketTestSuite) TestWebsocketError() { s.True(websocket.IsCloseError(err, websocket.CloseNormalClosure)) _, _, errMessages := helpers.WebSocketOutputMessages(messages) - s.Equal(1, len(errMessages)) + s.Require().Equal(1, len(errMessages)) s.Equal("Error executing the request", errMessages[0]) } @@ -370,10 +370,12 @@ func TestCodeOceanToRawReaderReturnsOnlyAfterOneByteWasReadFromConnection(t *tes // --- Test suite specific test helpers --- -func newNomadAllocationWithMockedAPIClient(runnerID string) (r runner.Runner, executorAPIMock *nomad.ExecutorAPIMock) { - executorAPIMock = &nomad.ExecutorAPIMock{} - r = runner.NewNomadJob(runnerID, nil, executorAPIMock, nil) - return +func newNomadAllocationWithMockedAPIClient(runnerID string) (runner.Runner, *nomad.ExecutorAPIMock) { + executorAPIMock := &nomad.ExecutorAPIMock{} + manager := &runner.ManagerMock{} + manager.On("Return", mock.Anything).Return(nil) + r := runner.NewNomadJob(runnerID, nil, executorAPIMock, manager) + return r, executorAPIMock } func webSocketURL(scheme string, server *httptest.Server, router *mux.Router, @@ -429,14 +431,21 @@ func mockAPIExecuteHead(api *nomad.ExecutorAPIMock) { var executionRequestSleep = dto.ExecutionRequest{Command: "sleep infinity"} -// mockAPIExecuteSleep mocks the ExecuteCommand method of an ExecutorAPI to sleep until the execution is canceled. +// mockAPIExecuteSleep mocks the ExecuteCommand method of an ExecutorAPI to sleep +// until the execution receives a SIGQUIT. func mockAPIExecuteSleep(api *nomad.ExecutorAPIMock) <-chan bool { canceled := make(chan bool, 1) mockAPIExecute(api, &executionRequestSleep, func(_ string, ctx context.Context, _ []string, _ bool, stdin io.Reader, stdout io.Writer, stderr io.Writer, ) (int, error) { - <-ctx.Done() + var err error + buffer := make([]byte, 1) //nolint:makezero // if the length is zero, the Read call never reads anything + for n := 0; !(n == 1 && buffer[0] == runner.SIGQUIT); n, err = stdin.Read(buffer) { + if err != nil { + return 0, fmt.Errorf("error while reading stdin: %w", err) + } + } close(canceled) return 0, ctx.Err() }) diff --git a/internal/nomad/nomad.go b/internal/nomad/nomad.go index 8881083..a56c007 100644 --- a/internal/nomad/nomad.go +++ b/internal/nomad/nomad.go @@ -8,7 +8,7 @@ import ( "github.com/hashicorp/nomad/nomad/structs" "gitlab.hpi.de/codeocean/codemoon/poseidon/internal/config" "gitlab.hpi.de/codeocean/codemoon/poseidon/pkg/logging" - "gitlab.hpi.de/codeocean/codemoon/poseidon/pkg/nullreader" + "gitlab.hpi.de/codeocean/codemoon/poseidon/pkg/nullio" "io" "net/url" "strconv" @@ -348,7 +348,7 @@ func (a *APIClient) executeCommandInteractivelyWithStderr(allocationID string, c go func() { // Catch stderr in separate execution. exit, err := a.Execute(allocationID, ctx, stderrFifoCommand(currentNanoTime), true, - nullreader.NullReader{}, stderr, io.Discard) + nullio.Reader{}, stderr, io.Discard) if err != nil { log.WithError(err).WithField("runner", allocationID).Warn("Stderr task finished with error") } diff --git a/internal/nomad/nomad_test.go b/internal/nomad/nomad_test.go index e89464e..d96eead 100644 --- a/internal/nomad/nomad_test.go +++ b/internal/nomad/nomad_test.go @@ -12,7 +12,7 @@ import ( "github.com/stretchr/testify/require" "github.com/stretchr/testify/suite" "gitlab.hpi.de/codeocean/codemoon/poseidon/internal/config" - "gitlab.hpi.de/codeocean/codemoon/poseidon/pkg/nullreader" + "gitlab.hpi.de/codeocean/codemoon/poseidon/pkg/nullio" "gitlab.hpi.de/codeocean/codemoon/poseidon/tests" "io" "net/url" @@ -679,7 +679,7 @@ func (s *ExecuteCommandTestSuite) TestWithSeparateStderr() { }) exitCode, err := s.nomadAPIClient. - ExecuteCommand(s.allocationID, s.ctx, s.testCommandArray, withTTY, nullreader.NullReader{}, &stdout, &stderr) + ExecuteCommand(s.allocationID, s.ctx, s.testCommandArray, withTTY, nullio.Reader{}, &stdout, &stderr) s.Require().NoError(err) s.apiMock.AssertNumberOfCalls(s.T(), "Execute", 2) @@ -710,7 +710,7 @@ func (s *ExecuteCommandTestSuite) TestWithSeparateStderrReturnsCommandError() { s.mockExecute(s.testCommandArray, 1, tests.ErrDefault, func(args mock.Arguments) {}) s.mockExecute(mock.AnythingOfType("[]string"), 1, nil, func(args mock.Arguments) {}) _, err := s.nomadAPIClient. - ExecuteCommand(s.allocationID, s.ctx, s.testCommandArray, withTTY, nullreader.NullReader{}, io.Discard, io.Discard) + ExecuteCommand(s.allocationID, s.ctx, s.testCommandArray, withTTY, nullio.Reader{}, io.Discard, io.Discard) s.Equal(tests.ErrDefault, err) } @@ -732,7 +732,7 @@ func (s *ExecuteCommandTestSuite) TestWithoutSeparateStderr() { }) exitCode, err := s.nomadAPIClient. - ExecuteCommand(s.allocationID, s.ctx, s.testCommandArray, withTTY, nullreader.NullReader{}, &stdout, &stderr) + ExecuteCommand(s.allocationID, s.ctx, s.testCommandArray, withTTY, nullio.Reader{}, &stdout, &stderr) s.Require().NoError(err) s.apiMock.AssertNumberOfCalls(s.T(), "Execute", 1) @@ -745,7 +745,7 @@ func (s *ExecuteCommandTestSuite) TestWithoutSeparateStderrReturnsCommandError() config.Config.Server.InteractiveStderr = false s.mockExecute(s.testCommandArray, 1, tests.ErrDefault, func(args mock.Arguments) {}) _, err := s.nomadAPIClient. - ExecuteCommand(s.allocationID, s.ctx, s.testCommandArray, withTTY, nullreader.NullReader{}, io.Discard, io.Discard) + ExecuteCommand(s.allocationID, s.ctx, s.testCommandArray, withTTY, nullio.Reader{}, io.Discard, io.Discard) s.ErrorIs(err, tests.ErrDefault) } diff --git a/internal/runner/runner.go b/internal/runner/runner.go index fae635e..87eed0c 100644 --- a/internal/runner/runner.go +++ b/internal/runner/runner.go @@ -25,6 +25,11 @@ type ExecutionID string const ( // runnerContextKey is the key used to store runners in context.Context. runnerContextKey ContextKey = "runner" + // SIGQUIT is the character that causes a tty to send the SIGQUIT signal to the controlled process. + SIGQUIT = 0x1c + // executionTimeoutGracePeriod is the time to wait after sending a SIGQUIT signal to a timed out execution. + // If the execution does not return after this grace period, the runner is destroyed. + executionTimeoutGracePeriod = 3 * time.Second ) var ( @@ -143,7 +148,7 @@ type Runner interface { // Output from the runner is forwarded immediately. ExecuteInteractively( request *dto.ExecutionRequest, - stdin io.Reader, + stdin io.ReadWriter, stdout, stderr io.Writer, ) (exit <-chan ExitInfo, cancel context.CancelFunc) @@ -151,6 +156,9 @@ type Runner interface { // UpdateFileSystem processes a dto.UpdateFileSystemRequest by first deleting each given dto.FilePath recursively // and then copying each given dto.File to the runner. UpdateFileSystem(request *dto.UpdateFileSystemRequest) error + + // Destroy destroys the Runner in Nomad. + Destroy() error } // NomadJob is an abstraction to communicate with Nomad environments. @@ -160,6 +168,7 @@ type NomadJob struct { id string portMappings []nomadApi.PortMapping api nomad.ExecutorAPI + manager Manager } // NewNomadJob creates a new NomadJob with the provided id. @@ -171,6 +180,7 @@ func NewNomadJob(id string, portMappings []nomadApi.PortMapping, portMappings: portMappings, api: apiClient, ExecutionStorage: NewLocalExecutionStorage(), + manager: manager, } job.InactivityTimer = NewInactivityTimer(job, manager) return job @@ -196,30 +206,80 @@ type ExitInfo struct { Err error } -func (r *NomadJob) ExecuteInteractively( - request *dto.ExecutionRequest, - stdin io.Reader, - stdout, stderr io.Writer, -) (<-chan ExitInfo, context.CancelFunc) { - r.ResetTimeout() - - command := request.FullCommand() - var ctx context.Context - var cancel context.CancelFunc +func prepareExecution(request *dto.ExecutionRequest) ( + command []string, ctx context.Context, cancel context.CancelFunc, +) { + command = request.FullCommand() if request.TimeLimit == 0 { ctx, cancel = context.WithCancel(context.Background()) } else { ctx, cancel = context.WithTimeout(context.Background(), time.Duration(request.TimeLimit)*time.Second) } - exit := make(chan ExitInfo) - go func() { - exitCode, err := r.api.ExecuteCommand(r.id, ctx, command, true, stdin, stdout, stderr) - if err == nil && r.TimeoutPassed() { - err = ErrorRunnerInactivityTimeout - } - exit <- ExitInfo{uint8(exitCode), err} + return command, ctx, cancel +} + +func (r *NomadJob) executeCommand(ctx context.Context, command []string, + stdin io.ReadWriter, stdout, stderr io.Writer, exit chan<- ExitInfo, +) { + exitCode, err := r.api.ExecuteCommand(r.id, ctx, command, true, stdin, stdout, stderr) + if err == nil && r.TimeoutPassed() { + err = ErrorRunnerInactivityTimeout + } + exit <- ExitInfo{uint8(exitCode), err} +} + +func (r *NomadJob) handleExitOrContextDone(ctx context.Context, cancelExecute context.CancelFunc, + exitInternal <-chan ExitInfo, exit chan<- ExitInfo, stdin io.ReadWriter, +) { + defer cancelExecute() + select { + case exitInfo := <-exitInternal: + exit <- exitInfo close(exit) - }() + return + case <-ctx.Done(): + // From this time on until the WebSocket connection to the client is closed in /internal/api/websocket.go + // waitForExit, output can still be forwarded to the client. We accept this race condition because adding + // a locking mechanism would complicate the interfaces used (currently io.Writer). + exit <- ExitInfo{255, ctx.Err()} + close(exit) + } + // This injects the SIGQUIT character into the stdin. This character is parsed by the tty line discipline + // (tty has to be true) and converted to a SIGQUIT signal sent to the foreground process attached to the tty. + // By default, SIGQUIT causes the process to terminate and produces a core dump. Processes can catch this signal + // and ignore it, which is why we destroy the runner if the process does not terminate after a grace period. + n, err := stdin.Write([]byte{SIGQUIT}) + if n != 1 { + log.WithField("runner", r.id).Warn("Could not send SIGQUIT because nothing was written") + } + if err != nil { + log.WithField("runner", r.id).WithError(err).Warn("Could not send SIGQUIT due to error") + } + + select { + case <-exitInternal: + log.WithField("runner", r.id).Debug("Execution terminated after SIGQUIT") + case <-time.After(executionTimeoutGracePeriod): + log.WithField("runner", r.id).Info("Execution did not quit after SIGQUIT") + if err := r.Destroy(); err != nil { + log.WithField("runner", r.id).Error("Error when destroying runner") + } + } +} + +func (r *NomadJob) ExecuteInteractively( + request *dto.ExecutionRequest, + stdin io.ReadWriter, + stdout, stderr io.Writer, +) (<-chan ExitInfo, context.CancelFunc) { + r.ResetTimeout() + + command, ctx, cancel := prepareExecution(request) + exitInternal := make(chan ExitInfo) + exit := make(chan ExitInfo, 1) + ctxExecute, cancelExecute := context.WithCancel(context.Background()) + go r.executeCommand(ctxExecute, command, stdin, stdout, stderr, exitInternal) + go r.handleExitOrContextDone(ctx, cancelExecute, exitInternal, exit, stdin) return exit, cancel } @@ -255,6 +315,13 @@ func (r *NomadJob) UpdateFileSystem(copyRequest *dto.UpdateFileSystemRequest) er return nil } +func (r *NomadJob) Destroy() error { + if err := r.manager.Return(r); err != nil { + return fmt.Errorf("error while destroying runner: %w", err) + } + return nil +} + func createTarArchiveForFiles(filesToCopy []dto.File, w io.Writer) error { tarWriter := tar.NewWriter(w) for _, file := range filesToCopy { diff --git a/internal/runner/runner_mock.go b/internal/runner/runner_mock.go index 408c454..6ba68d0 100644 --- a/internal/runner/runner_mock.go +++ b/internal/runner/runner_mock.go @@ -23,12 +23,26 @@ func (_m *RunnerMock) Add(id ExecutionID, executionRequest *dto.ExecutionRequest _m.Called(id, executionRequest) } +// Destroy provides a mock function with given fields: +func (_m *RunnerMock) Destroy() error { + ret := _m.Called() + + var r0 error + if rf, ok := ret.Get(0).(func() error); ok { + r0 = rf() + } else { + r0 = ret.Error(0) + } + + return r0 +} + // ExecuteInteractively provides a mock function with given fields: request, stdin, stdout, stderr -func (_m *RunnerMock) ExecuteInteractively(request *dto.ExecutionRequest, stdin io.Reader, stdout io.Writer, stderr io.Writer) (<-chan ExitInfo, context.CancelFunc) { +func (_m *RunnerMock) ExecuteInteractively(request *dto.ExecutionRequest, stdin io.ReadWriter, stdout io.Writer, stderr io.Writer) (<-chan ExitInfo, context.CancelFunc) { ret := _m.Called(request, stdin, stdout, stderr) var r0 <-chan ExitInfo - if rf, ok := ret.Get(0).(func(*dto.ExecutionRequest, io.Reader, io.Writer, io.Writer) <-chan ExitInfo); ok { + if rf, ok := ret.Get(0).(func(*dto.ExecutionRequest, io.ReadWriter, io.Writer, io.Writer) <-chan ExitInfo); ok { r0 = rf(request, stdin, stdout, stderr) } else { if ret.Get(0) != nil { @@ -37,7 +51,7 @@ func (_m *RunnerMock) ExecuteInteractively(request *dto.ExecutionRequest, stdin } var r1 context.CancelFunc - if rf, ok := ret.Get(1).(func(*dto.ExecutionRequest, io.Reader, io.Writer, io.Writer) context.CancelFunc); ok { + if rf, ok := ret.Get(1).(func(*dto.ExecutionRequest, io.ReadWriter, io.Writer, io.Writer) context.CancelFunc); ok { r1 = rf(request, stdin, stdout, stderr) } else { if ret.Get(1) != nil { @@ -48,7 +62,7 @@ func (_m *RunnerMock) ExecuteInteractively(request *dto.ExecutionRequest, stdin return r0, r1 } -// Id provides a mock function with given fields: +// ID provides a mock function with given fields: func (_m *RunnerMock) ID() string { ret := _m.Called() diff --git a/internal/runner/runner_test.go b/internal/runner/runner_test.go index 0cfebc3..81d7481 100644 --- a/internal/runner/runner_test.go +++ b/internal/runner/runner_test.go @@ -12,6 +12,7 @@ import ( "github.com/stretchr/testify/suite" "gitlab.hpi.de/codeocean/codemoon/poseidon/internal/nomad" "gitlab.hpi.de/codeocean/codemoon/poseidon/pkg/dto" + "gitlab.hpi.de/codeocean/codemoon/poseidon/pkg/nullio" "gitlab.hpi.de/codeocean/codemoon/poseidon/tests" "io" "regexp" @@ -82,6 +83,15 @@ func TestFromContextReturnsIsNotOkWhenContextHasNoRunner(t *testing.T) { assert.False(t, ok) } +func TestDestroyReturnsRunner(t *testing.T) { + manager := &ManagerMock{} + manager.On("Return", mock.Anything).Return(nil) + runner := NewRunner(tests.DefaultRunnerID, manager) + err := runner.Destroy() + assert.NoError(t, err) + manager.AssertCalled(t, "Return", runner) +} + func TestExecuteInteractivelyTestSuite(t *testing.T) { suite.Run(t, new(ExecuteInteractivelyTestSuite)) } @@ -91,6 +101,7 @@ type ExecuteInteractivelyTestSuite struct { runner *NomadJob apiMock *nomad.ExecutorAPIMock timer *InactivityTimerMock + manager *ManagerMock mockedExecuteCommandCall *mock.Call mockedTimeoutPassedCall *mock.Call } @@ -103,11 +114,15 @@ func (s *ExecuteInteractivelyTestSuite) SetupTest() { s.timer = &InactivityTimerMock{} s.timer.On("ResetTimeout").Return() s.mockedTimeoutPassedCall = s.timer.On("TimeoutPassed").Return(false) + s.manager = &ManagerMock{} + s.manager.On("Return", mock.Anything).Return(nil) + s.runner = &NomadJob{ ExecutionStorage: NewLocalExecutionStorage(), InactivityTimer: s.timer, id: tests.DefaultRunnerID, api: s.apiMock, + manager: s.manager, } } @@ -122,15 +137,12 @@ func (s *ExecuteInteractivelyTestSuite) TestCallsApi() { func (s *ExecuteInteractivelyTestSuite) TestReturnsAfterTimeout() { s.mockedExecuteCommandCall.Run(func(args mock.Arguments) { - ctx, ok := args.Get(1).(context.Context) - s.Require().True(ok) - <-ctx.Done() - }). - Return(0, nil) + select {} + }).Return(0, nil) timeLimit := 1 execution := &dto.ExecutionRequest{TimeLimit: timeLimit} - exit, _ := s.runner.ExecuteInteractively(execution, nil, nil, nil) + exit, _ := s.runner.ExecuteInteractively(execution, &nullio.ReadWriter{}, nil, nil) select { case <-exit: @@ -142,20 +154,52 @@ func (s *ExecuteInteractivelyTestSuite) TestReturnsAfterTimeout() { case <-time.After(time.Duration(timeLimit) * time.Second): s.FailNow("ExecuteInteractively should return after the time limit") case exitInfo := <-exit: - s.Equal(uint8(0), exitInfo.Code) + s.Equal(uint8(255), exitInfo.Code) } } +func (s *ExecuteInteractivelyTestSuite) TestSendsSignalAfterTimeout() { + quit := make(chan struct{}) + s.mockedExecuteCommandCall.Run(func(args mock.Arguments) { + stdin, ok := args.Get(4).(io.Reader) + s.Require().True(ok) + buffer := make([]byte, 1) //nolint:makezero,lll // If the length is zero, the Read call never reads anything. gofmt want this alignment. + for n := 0; !(n == 1 && buffer[0] == SIGQUIT); n, _ = stdin.Read(buffer) { //nolint:errcheck,lll // Read returns EOF errors but that is expected. This nolint makes the line too long. + time.After(tests.ShortTimeout) + } + close(quit) + }).Return(0, nil) + timeLimit := 1 + execution := &dto.ExecutionRequest{TimeLimit: timeLimit} + _, _ = s.runner.ExecuteInteractively(execution, bytes.NewBuffer(make([]byte, 1)), nil, nil) + select { + case <-time.After(2 * (time.Duration(timeLimit) * time.Second)): + s.FailNow("The execution should receive a SIGQUIT after the timeout") + case <-quit: + } +} + +func (s *ExecuteInteractivelyTestSuite) TestDestroysRunnerAfterTimeoutAndSignal() { + s.mockedExecuteCommandCall.Run(func(args mock.Arguments) { + select {} + }) + timeLimit := 1 + execution := &dto.ExecutionRequest{TimeLimit: timeLimit} + _, _ = s.runner.ExecuteInteractively(execution, bytes.NewBuffer(make([]byte, 1)), nil, nil) + <-time.After(executionTimeoutGracePeriod + time.Duration(timeLimit)*time.Second + tests.ShortTimeout) + s.manager.AssertCalled(s.T(), "Return", s.runner) +} + func (s *ExecuteInteractivelyTestSuite) TestResetTimerGetsCalled() { execution := &dto.ExecutionRequest{} s.runner.ExecuteInteractively(execution, nil, nil, nil) s.timer.AssertCalled(s.T(), "ResetTimeout") } -func (s *ExecuteInteractivelyTestSuite) TestExitHasTimeoutErrorIfExecutionTimesOut() { +func (s *ExecuteInteractivelyTestSuite) TestExitHasTimeoutErrorIfRunnerTimesOut() { s.mockedTimeoutPassedCall.Return(true) execution := &dto.ExecutionRequest{} - exitChannel, _ := s.runner.ExecuteInteractively(execution, nil, nil, nil) + exitChannel, _ := s.runner.ExecuteInteractively(execution, &nullio.ReadWriter{}, nil, nil) exit := <-exitChannel s.Equal(ErrorRunnerInactivityTimeout, exit.Err) } diff --git a/pkg/nullio/nullio.go b/pkg/nullio/nullio.go new file mode 100644 index 0000000..47d3be3 --- /dev/null +++ b/pkg/nullio/nullio.go @@ -0,0 +1,24 @@ +package nullio + +import ( + "fmt" + "io" +) + +// Reader is a struct that implements the io.Reader interface. Read does not return when called. +type Reader struct{} + +func (r Reader) Read(_ []byte) (int, error) { + // An empty select blocks forever. + select {} +} + +// ReadWriter implements io.ReadWriter and does nothing on Read an Write. +type ReadWriter struct { + Reader +} + +func (nrw *ReadWriter) Write(p []byte) (int, error) { + n, err := io.Discard.Write(p) + return n, fmt.Errorf("error writing to io.Discard: %w", err) +} diff --git a/pkg/nullreader/nullreader_test.go b/pkg/nullio/nullio_test.go similarity index 81% rename from pkg/nullreader/nullreader_test.go rename to pkg/nullio/nullio_test.go index e606555..cee7a57 100644 --- a/pkg/nullreader/nullreader_test.go +++ b/pkg/nullio/nullio_test.go @@ -1,4 +1,4 @@ -package nullreader +package nullio import ( "github.com/stretchr/testify/assert" @@ -9,8 +9,8 @@ import ( const shortTimeout = 100 * time.Millisecond -func TestNullReaderDoesNotReturnImmediately(t *testing.T) { - reader := &NullReader{} +func TestReaderDoesNotReturnImmediately(t *testing.T) { + reader := &Reader{} readerReturned := make(chan bool) go func() { p := make([]byte, 0, 5) diff --git a/pkg/nullreader/nullreader.go b/pkg/nullreader/nullreader.go deleted file mode 100644 index 040ad6e..0000000 --- a/pkg/nullreader/nullreader.go +++ /dev/null @@ -1,10 +0,0 @@ -package nullreader - -// NullReader is a struct that implements the io.Reader interface and returns nothing when reading -// from it. -type NullReader struct{} - -func (r NullReader) Read(_ []byte) (int, error) { - // An empty select blocks forever. - select {} -}