diff --git a/internal/api/websocket.go b/internal/api/websocket.go index cad2a8b..c5d2e1b 100644 --- a/internal/api/websocket.go +++ b/internal/api/websocket.go @@ -24,8 +24,11 @@ type webSocketConnection interface { SetCloseHandler(handler func(code int, text string) error) } +// WebSocketReader is an interface that is intended for providing abstraction around reading from a WebSocket. +// Besides io.Reader, it also implements io.Writer. The Write method is used to inject data into the WebSocket stream. type WebSocketReader interface { io.Reader + io.Writer startReadInputLoop() context.CancelFunc } @@ -38,12 +41,17 @@ type codeOceanToRawReader struct { // and retrieve it when Read(..) is called. Since channels are thread-safe, we use one here // instead of bytes.Buffer. buffer chan byte + // The priorityBuffer is a buffer for injecting data into stdin of the execution from Poseidon, + // for example the character that causes the tty to generate a SIGQUIT signal. + // It is always read before the regular buffer. + priorityBuffer chan byte } func newCodeOceanToRawReader(connection webSocketConnection) *codeOceanToRawReader { return &codeOceanToRawReader{ - connection: connection, - buffer: make(chan byte, CodeOceanToRawReaderBufferSize), + connection: connection, + buffer: make(chan byte, CodeOceanToRawReaderBufferSize), + priorityBuffer: make(chan byte, CodeOceanToRawReaderBufferSize), } } @@ -101,16 +109,20 @@ func (cr *codeOceanToRawReader) startReadInputLoop() context.CancelFunc { } // Read implements the io.Reader interface. -// It returns bytes from the buffer. +// It returns bytes from the buffer or priorityBuffer. func (cr *codeOceanToRawReader) Read(p []byte) (int, error) { if len(p) == 0 { return 0, nil } // Ensure to not return until at least one byte has been read to avoid busy waiting. - p[0] = <-cr.buffer + select { + case p[0] = <-cr.priorityBuffer: + case p[0] = <-cr.buffer: + } var n int for n = 1; n < len(p); n++ { select { + case p[n] = <-cr.priorityBuffer: case p[n] = <-cr.buffer: default: return n, nil @@ -119,6 +131,20 @@ func (cr *codeOceanToRawReader) Read(p []byte) (int, error) { return n, nil } +// Write implements the io.Writer interface. +// Data written to a codeOceanToRawReader using this method is returned by Read before other data from the WebSocket. +func (cr *codeOceanToRawReader) Write(p []byte) (n int, err error) { + var c byte + for n, c = range p { + select { + case cr.priorityBuffer <- c: + default: + break + } + } + return n, nil +} + // rawToCodeOceanWriter is an io.Writer implementation that, when written to, wraps the written data in the appropriate // json structure and sends it to the CodeOcean via WebSocket. type rawToCodeOceanWriter struct { diff --git a/internal/api/websocket_test.go b/internal/api/websocket_test.go index cecdfd7..7111d39 100644 --- a/internal/api/websocket_test.go +++ b/internal/api/websocket_test.go @@ -223,7 +223,7 @@ func (s *WebSocketTestSuite) TestWebsocketError() { s.True(websocket.IsCloseError(err, websocket.CloseNormalClosure)) _, _, errMessages := helpers.WebSocketOutputMessages(messages) - s.Equal(1, len(errMessages)) + s.Require().Equal(1, len(errMessages)) s.Equal("Error executing the request", errMessages[0]) } @@ -370,10 +370,12 @@ func TestCodeOceanToRawReaderReturnsOnlyAfterOneByteWasReadFromConnection(t *tes // --- Test suite specific test helpers --- -func newNomadAllocationWithMockedAPIClient(runnerID string) (r runner.Runner, executorAPIMock *nomad.ExecutorAPIMock) { - executorAPIMock = &nomad.ExecutorAPIMock{} - r = runner.NewNomadJob(runnerID, nil, executorAPIMock, nil) - return +func newNomadAllocationWithMockedAPIClient(runnerID string) (runner.Runner, *nomad.ExecutorAPIMock) { + executorAPIMock := &nomad.ExecutorAPIMock{} + manager := &runner.ManagerMock{} + manager.On("Return", mock.Anything).Return(nil) + r := runner.NewNomadJob(runnerID, nil, executorAPIMock, manager) + return r, executorAPIMock } func webSocketURL(scheme string, server *httptest.Server, router *mux.Router, @@ -429,14 +431,21 @@ func mockAPIExecuteHead(api *nomad.ExecutorAPIMock) { var executionRequestSleep = dto.ExecutionRequest{Command: "sleep infinity"} -// mockAPIExecuteSleep mocks the ExecuteCommand method of an ExecutorAPI to sleep until the execution is canceled. +// mockAPIExecuteSleep mocks the ExecuteCommand method of an ExecutorAPI to sleep +// until the execution receives a SIGQUIT. func mockAPIExecuteSleep(api *nomad.ExecutorAPIMock) <-chan bool { canceled := make(chan bool, 1) mockAPIExecute(api, &executionRequestSleep, func(_ string, ctx context.Context, _ []string, _ bool, stdin io.Reader, stdout io.Writer, stderr io.Writer, ) (int, error) { - <-ctx.Done() + var err error + buffer := make([]byte, 1) //nolint:makezero // if the length is zero, the Read call never reads anything + for n := 0; !(n == 1 && buffer[0] == runner.SIGQUIT); n, err = stdin.Read(buffer) { + if err != nil { + return 0, fmt.Errorf("error while reading stdin: %w", err) + } + } close(canceled) return 0, ctx.Err() }) diff --git a/internal/nomad/nomad.go b/internal/nomad/nomad.go index 8881083..a56c007 100644 --- a/internal/nomad/nomad.go +++ b/internal/nomad/nomad.go @@ -8,7 +8,7 @@ import ( "github.com/hashicorp/nomad/nomad/structs" "gitlab.hpi.de/codeocean/codemoon/poseidon/internal/config" "gitlab.hpi.de/codeocean/codemoon/poseidon/pkg/logging" - "gitlab.hpi.de/codeocean/codemoon/poseidon/pkg/nullreader" + "gitlab.hpi.de/codeocean/codemoon/poseidon/pkg/nullio" "io" "net/url" "strconv" @@ -348,7 +348,7 @@ func (a *APIClient) executeCommandInteractivelyWithStderr(allocationID string, c go func() { // Catch stderr in separate execution. exit, err := a.Execute(allocationID, ctx, stderrFifoCommand(currentNanoTime), true, - nullreader.NullReader{}, stderr, io.Discard) + nullio.Reader{}, stderr, io.Discard) if err != nil { log.WithError(err).WithField("runner", allocationID).Warn("Stderr task finished with error") } diff --git a/internal/nomad/nomad_test.go b/internal/nomad/nomad_test.go index e89464e..d96eead 100644 --- a/internal/nomad/nomad_test.go +++ b/internal/nomad/nomad_test.go @@ -12,7 +12,7 @@ import ( "github.com/stretchr/testify/require" "github.com/stretchr/testify/suite" "gitlab.hpi.de/codeocean/codemoon/poseidon/internal/config" - "gitlab.hpi.de/codeocean/codemoon/poseidon/pkg/nullreader" + "gitlab.hpi.de/codeocean/codemoon/poseidon/pkg/nullio" "gitlab.hpi.de/codeocean/codemoon/poseidon/tests" "io" "net/url" @@ -679,7 +679,7 @@ func (s *ExecuteCommandTestSuite) TestWithSeparateStderr() { }) exitCode, err := s.nomadAPIClient. - ExecuteCommand(s.allocationID, s.ctx, s.testCommandArray, withTTY, nullreader.NullReader{}, &stdout, &stderr) + ExecuteCommand(s.allocationID, s.ctx, s.testCommandArray, withTTY, nullio.Reader{}, &stdout, &stderr) s.Require().NoError(err) s.apiMock.AssertNumberOfCalls(s.T(), "Execute", 2) @@ -710,7 +710,7 @@ func (s *ExecuteCommandTestSuite) TestWithSeparateStderrReturnsCommandError() { s.mockExecute(s.testCommandArray, 1, tests.ErrDefault, func(args mock.Arguments) {}) s.mockExecute(mock.AnythingOfType("[]string"), 1, nil, func(args mock.Arguments) {}) _, err := s.nomadAPIClient. - ExecuteCommand(s.allocationID, s.ctx, s.testCommandArray, withTTY, nullreader.NullReader{}, io.Discard, io.Discard) + ExecuteCommand(s.allocationID, s.ctx, s.testCommandArray, withTTY, nullio.Reader{}, io.Discard, io.Discard) s.Equal(tests.ErrDefault, err) } @@ -732,7 +732,7 @@ func (s *ExecuteCommandTestSuite) TestWithoutSeparateStderr() { }) exitCode, err := s.nomadAPIClient. - ExecuteCommand(s.allocationID, s.ctx, s.testCommandArray, withTTY, nullreader.NullReader{}, &stdout, &stderr) + ExecuteCommand(s.allocationID, s.ctx, s.testCommandArray, withTTY, nullio.Reader{}, &stdout, &stderr) s.Require().NoError(err) s.apiMock.AssertNumberOfCalls(s.T(), "Execute", 1) @@ -745,7 +745,7 @@ func (s *ExecuteCommandTestSuite) TestWithoutSeparateStderrReturnsCommandError() config.Config.Server.InteractiveStderr = false s.mockExecute(s.testCommandArray, 1, tests.ErrDefault, func(args mock.Arguments) {}) _, err := s.nomadAPIClient. - ExecuteCommand(s.allocationID, s.ctx, s.testCommandArray, withTTY, nullreader.NullReader{}, io.Discard, io.Discard) + ExecuteCommand(s.allocationID, s.ctx, s.testCommandArray, withTTY, nullio.Reader{}, io.Discard, io.Discard) s.ErrorIs(err, tests.ErrDefault) } diff --git a/internal/runner/runner.go b/internal/runner/runner.go index fae635e..87eed0c 100644 --- a/internal/runner/runner.go +++ b/internal/runner/runner.go @@ -25,6 +25,11 @@ type ExecutionID string const ( // runnerContextKey is the key used to store runners in context.Context. runnerContextKey ContextKey = "runner" + // SIGQUIT is the character that causes a tty to send the SIGQUIT signal to the controlled process. + SIGQUIT = 0x1c + // executionTimeoutGracePeriod is the time to wait after sending a SIGQUIT signal to a timed out execution. + // If the execution does not return after this grace period, the runner is destroyed. + executionTimeoutGracePeriod = 3 * time.Second ) var ( @@ -143,7 +148,7 @@ type Runner interface { // Output from the runner is forwarded immediately. ExecuteInteractively( request *dto.ExecutionRequest, - stdin io.Reader, + stdin io.ReadWriter, stdout, stderr io.Writer, ) (exit <-chan ExitInfo, cancel context.CancelFunc) @@ -151,6 +156,9 @@ type Runner interface { // UpdateFileSystem processes a dto.UpdateFileSystemRequest by first deleting each given dto.FilePath recursively // and then copying each given dto.File to the runner. UpdateFileSystem(request *dto.UpdateFileSystemRequest) error + + // Destroy destroys the Runner in Nomad. + Destroy() error } // NomadJob is an abstraction to communicate with Nomad environments. @@ -160,6 +168,7 @@ type NomadJob struct { id string portMappings []nomadApi.PortMapping api nomad.ExecutorAPI + manager Manager } // NewNomadJob creates a new NomadJob with the provided id. @@ -171,6 +180,7 @@ func NewNomadJob(id string, portMappings []nomadApi.PortMapping, portMappings: portMappings, api: apiClient, ExecutionStorage: NewLocalExecutionStorage(), + manager: manager, } job.InactivityTimer = NewInactivityTimer(job, manager) return job @@ -196,30 +206,80 @@ type ExitInfo struct { Err error } -func (r *NomadJob) ExecuteInteractively( - request *dto.ExecutionRequest, - stdin io.Reader, - stdout, stderr io.Writer, -) (<-chan ExitInfo, context.CancelFunc) { - r.ResetTimeout() - - command := request.FullCommand() - var ctx context.Context - var cancel context.CancelFunc +func prepareExecution(request *dto.ExecutionRequest) ( + command []string, ctx context.Context, cancel context.CancelFunc, +) { + command = request.FullCommand() if request.TimeLimit == 0 { ctx, cancel = context.WithCancel(context.Background()) } else { ctx, cancel = context.WithTimeout(context.Background(), time.Duration(request.TimeLimit)*time.Second) } - exit := make(chan ExitInfo) - go func() { - exitCode, err := r.api.ExecuteCommand(r.id, ctx, command, true, stdin, stdout, stderr) - if err == nil && r.TimeoutPassed() { - err = ErrorRunnerInactivityTimeout - } - exit <- ExitInfo{uint8(exitCode), err} + return command, ctx, cancel +} + +func (r *NomadJob) executeCommand(ctx context.Context, command []string, + stdin io.ReadWriter, stdout, stderr io.Writer, exit chan<- ExitInfo, +) { + exitCode, err := r.api.ExecuteCommand(r.id, ctx, command, true, stdin, stdout, stderr) + if err == nil && r.TimeoutPassed() { + err = ErrorRunnerInactivityTimeout + } + exit <- ExitInfo{uint8(exitCode), err} +} + +func (r *NomadJob) handleExitOrContextDone(ctx context.Context, cancelExecute context.CancelFunc, + exitInternal <-chan ExitInfo, exit chan<- ExitInfo, stdin io.ReadWriter, +) { + defer cancelExecute() + select { + case exitInfo := <-exitInternal: + exit <- exitInfo close(exit) - }() + return + case <-ctx.Done(): + // From this time on until the WebSocket connection to the client is closed in /internal/api/websocket.go + // waitForExit, output can still be forwarded to the client. We accept this race condition because adding + // a locking mechanism would complicate the interfaces used (currently io.Writer). + exit <- ExitInfo{255, ctx.Err()} + close(exit) + } + // This injects the SIGQUIT character into the stdin. This character is parsed by the tty line discipline + // (tty has to be true) and converted to a SIGQUIT signal sent to the foreground process attached to the tty. + // By default, SIGQUIT causes the process to terminate and produces a core dump. Processes can catch this signal + // and ignore it, which is why we destroy the runner if the process does not terminate after a grace period. + n, err := stdin.Write([]byte{SIGQUIT}) + if n != 1 { + log.WithField("runner", r.id).Warn("Could not send SIGQUIT because nothing was written") + } + if err != nil { + log.WithField("runner", r.id).WithError(err).Warn("Could not send SIGQUIT due to error") + } + + select { + case <-exitInternal: + log.WithField("runner", r.id).Debug("Execution terminated after SIGQUIT") + case <-time.After(executionTimeoutGracePeriod): + log.WithField("runner", r.id).Info("Execution did not quit after SIGQUIT") + if err := r.Destroy(); err != nil { + log.WithField("runner", r.id).Error("Error when destroying runner") + } + } +} + +func (r *NomadJob) ExecuteInteractively( + request *dto.ExecutionRequest, + stdin io.ReadWriter, + stdout, stderr io.Writer, +) (<-chan ExitInfo, context.CancelFunc) { + r.ResetTimeout() + + command, ctx, cancel := prepareExecution(request) + exitInternal := make(chan ExitInfo) + exit := make(chan ExitInfo, 1) + ctxExecute, cancelExecute := context.WithCancel(context.Background()) + go r.executeCommand(ctxExecute, command, stdin, stdout, stderr, exitInternal) + go r.handleExitOrContextDone(ctx, cancelExecute, exitInternal, exit, stdin) return exit, cancel } @@ -255,6 +315,13 @@ func (r *NomadJob) UpdateFileSystem(copyRequest *dto.UpdateFileSystemRequest) er return nil } +func (r *NomadJob) Destroy() error { + if err := r.manager.Return(r); err != nil { + return fmt.Errorf("error while destroying runner: %w", err) + } + return nil +} + func createTarArchiveForFiles(filesToCopy []dto.File, w io.Writer) error { tarWriter := tar.NewWriter(w) for _, file := range filesToCopy { diff --git a/internal/runner/runner_mock.go b/internal/runner/runner_mock.go index 408c454..6ba68d0 100644 --- a/internal/runner/runner_mock.go +++ b/internal/runner/runner_mock.go @@ -23,12 +23,26 @@ func (_m *RunnerMock) Add(id ExecutionID, executionRequest *dto.ExecutionRequest _m.Called(id, executionRequest) } +// Destroy provides a mock function with given fields: +func (_m *RunnerMock) Destroy() error { + ret := _m.Called() + + var r0 error + if rf, ok := ret.Get(0).(func() error); ok { + r0 = rf() + } else { + r0 = ret.Error(0) + } + + return r0 +} + // ExecuteInteractively provides a mock function with given fields: request, stdin, stdout, stderr -func (_m *RunnerMock) ExecuteInteractively(request *dto.ExecutionRequest, stdin io.Reader, stdout io.Writer, stderr io.Writer) (<-chan ExitInfo, context.CancelFunc) { +func (_m *RunnerMock) ExecuteInteractively(request *dto.ExecutionRequest, stdin io.ReadWriter, stdout io.Writer, stderr io.Writer) (<-chan ExitInfo, context.CancelFunc) { ret := _m.Called(request, stdin, stdout, stderr) var r0 <-chan ExitInfo - if rf, ok := ret.Get(0).(func(*dto.ExecutionRequest, io.Reader, io.Writer, io.Writer) <-chan ExitInfo); ok { + if rf, ok := ret.Get(0).(func(*dto.ExecutionRequest, io.ReadWriter, io.Writer, io.Writer) <-chan ExitInfo); ok { r0 = rf(request, stdin, stdout, stderr) } else { if ret.Get(0) != nil { @@ -37,7 +51,7 @@ func (_m *RunnerMock) ExecuteInteractively(request *dto.ExecutionRequest, stdin } var r1 context.CancelFunc - if rf, ok := ret.Get(1).(func(*dto.ExecutionRequest, io.Reader, io.Writer, io.Writer) context.CancelFunc); ok { + if rf, ok := ret.Get(1).(func(*dto.ExecutionRequest, io.ReadWriter, io.Writer, io.Writer) context.CancelFunc); ok { r1 = rf(request, stdin, stdout, stderr) } else { if ret.Get(1) != nil { @@ -48,7 +62,7 @@ func (_m *RunnerMock) ExecuteInteractively(request *dto.ExecutionRequest, stdin return r0, r1 } -// Id provides a mock function with given fields: +// ID provides a mock function with given fields: func (_m *RunnerMock) ID() string { ret := _m.Called() diff --git a/internal/runner/runner_test.go b/internal/runner/runner_test.go index 0cfebc3..81d7481 100644 --- a/internal/runner/runner_test.go +++ b/internal/runner/runner_test.go @@ -12,6 +12,7 @@ import ( "github.com/stretchr/testify/suite" "gitlab.hpi.de/codeocean/codemoon/poseidon/internal/nomad" "gitlab.hpi.de/codeocean/codemoon/poseidon/pkg/dto" + "gitlab.hpi.de/codeocean/codemoon/poseidon/pkg/nullio" "gitlab.hpi.de/codeocean/codemoon/poseidon/tests" "io" "regexp" @@ -82,6 +83,15 @@ func TestFromContextReturnsIsNotOkWhenContextHasNoRunner(t *testing.T) { assert.False(t, ok) } +func TestDestroyReturnsRunner(t *testing.T) { + manager := &ManagerMock{} + manager.On("Return", mock.Anything).Return(nil) + runner := NewRunner(tests.DefaultRunnerID, manager) + err := runner.Destroy() + assert.NoError(t, err) + manager.AssertCalled(t, "Return", runner) +} + func TestExecuteInteractivelyTestSuite(t *testing.T) { suite.Run(t, new(ExecuteInteractivelyTestSuite)) } @@ -91,6 +101,7 @@ type ExecuteInteractivelyTestSuite struct { runner *NomadJob apiMock *nomad.ExecutorAPIMock timer *InactivityTimerMock + manager *ManagerMock mockedExecuteCommandCall *mock.Call mockedTimeoutPassedCall *mock.Call } @@ -103,11 +114,15 @@ func (s *ExecuteInteractivelyTestSuite) SetupTest() { s.timer = &InactivityTimerMock{} s.timer.On("ResetTimeout").Return() s.mockedTimeoutPassedCall = s.timer.On("TimeoutPassed").Return(false) + s.manager = &ManagerMock{} + s.manager.On("Return", mock.Anything).Return(nil) + s.runner = &NomadJob{ ExecutionStorage: NewLocalExecutionStorage(), InactivityTimer: s.timer, id: tests.DefaultRunnerID, api: s.apiMock, + manager: s.manager, } } @@ -122,15 +137,12 @@ func (s *ExecuteInteractivelyTestSuite) TestCallsApi() { func (s *ExecuteInteractivelyTestSuite) TestReturnsAfterTimeout() { s.mockedExecuteCommandCall.Run(func(args mock.Arguments) { - ctx, ok := args.Get(1).(context.Context) - s.Require().True(ok) - <-ctx.Done() - }). - Return(0, nil) + select {} + }).Return(0, nil) timeLimit := 1 execution := &dto.ExecutionRequest{TimeLimit: timeLimit} - exit, _ := s.runner.ExecuteInteractively(execution, nil, nil, nil) + exit, _ := s.runner.ExecuteInteractively(execution, &nullio.ReadWriter{}, nil, nil) select { case <-exit: @@ -142,20 +154,52 @@ func (s *ExecuteInteractivelyTestSuite) TestReturnsAfterTimeout() { case <-time.After(time.Duration(timeLimit) * time.Second): s.FailNow("ExecuteInteractively should return after the time limit") case exitInfo := <-exit: - s.Equal(uint8(0), exitInfo.Code) + s.Equal(uint8(255), exitInfo.Code) } } +func (s *ExecuteInteractivelyTestSuite) TestSendsSignalAfterTimeout() { + quit := make(chan struct{}) + s.mockedExecuteCommandCall.Run(func(args mock.Arguments) { + stdin, ok := args.Get(4).(io.Reader) + s.Require().True(ok) + buffer := make([]byte, 1) //nolint:makezero,lll // If the length is zero, the Read call never reads anything. gofmt want this alignment. + for n := 0; !(n == 1 && buffer[0] == SIGQUIT); n, _ = stdin.Read(buffer) { //nolint:errcheck,lll // Read returns EOF errors but that is expected. This nolint makes the line too long. + time.After(tests.ShortTimeout) + } + close(quit) + }).Return(0, nil) + timeLimit := 1 + execution := &dto.ExecutionRequest{TimeLimit: timeLimit} + _, _ = s.runner.ExecuteInteractively(execution, bytes.NewBuffer(make([]byte, 1)), nil, nil) + select { + case <-time.After(2 * (time.Duration(timeLimit) * time.Second)): + s.FailNow("The execution should receive a SIGQUIT after the timeout") + case <-quit: + } +} + +func (s *ExecuteInteractivelyTestSuite) TestDestroysRunnerAfterTimeoutAndSignal() { + s.mockedExecuteCommandCall.Run(func(args mock.Arguments) { + select {} + }) + timeLimit := 1 + execution := &dto.ExecutionRequest{TimeLimit: timeLimit} + _, _ = s.runner.ExecuteInteractively(execution, bytes.NewBuffer(make([]byte, 1)), nil, nil) + <-time.After(executionTimeoutGracePeriod + time.Duration(timeLimit)*time.Second + tests.ShortTimeout) + s.manager.AssertCalled(s.T(), "Return", s.runner) +} + func (s *ExecuteInteractivelyTestSuite) TestResetTimerGetsCalled() { execution := &dto.ExecutionRequest{} s.runner.ExecuteInteractively(execution, nil, nil, nil) s.timer.AssertCalled(s.T(), "ResetTimeout") } -func (s *ExecuteInteractivelyTestSuite) TestExitHasTimeoutErrorIfExecutionTimesOut() { +func (s *ExecuteInteractivelyTestSuite) TestExitHasTimeoutErrorIfRunnerTimesOut() { s.mockedTimeoutPassedCall.Return(true) execution := &dto.ExecutionRequest{} - exitChannel, _ := s.runner.ExecuteInteractively(execution, nil, nil, nil) + exitChannel, _ := s.runner.ExecuteInteractively(execution, &nullio.ReadWriter{}, nil, nil) exit := <-exitChannel s.Equal(ErrorRunnerInactivityTimeout, exit.Err) } diff --git a/pkg/nullio/nullio.go b/pkg/nullio/nullio.go new file mode 100644 index 0000000..47d3be3 --- /dev/null +++ b/pkg/nullio/nullio.go @@ -0,0 +1,24 @@ +package nullio + +import ( + "fmt" + "io" +) + +// Reader is a struct that implements the io.Reader interface. Read does not return when called. +type Reader struct{} + +func (r Reader) Read(_ []byte) (int, error) { + // An empty select blocks forever. + select {} +} + +// ReadWriter implements io.ReadWriter and does nothing on Read an Write. +type ReadWriter struct { + Reader +} + +func (nrw *ReadWriter) Write(p []byte) (int, error) { + n, err := io.Discard.Write(p) + return n, fmt.Errorf("error writing to io.Discard: %w", err) +} diff --git a/pkg/nullreader/nullreader_test.go b/pkg/nullio/nullio_test.go similarity index 81% rename from pkg/nullreader/nullreader_test.go rename to pkg/nullio/nullio_test.go index e606555..cee7a57 100644 --- a/pkg/nullreader/nullreader_test.go +++ b/pkg/nullio/nullio_test.go @@ -1,4 +1,4 @@ -package nullreader +package nullio import ( "github.com/stretchr/testify/assert" @@ -9,8 +9,8 @@ import ( const shortTimeout = 100 * time.Millisecond -func TestNullReaderDoesNotReturnImmediately(t *testing.T) { - reader := &NullReader{} +func TestReaderDoesNotReturnImmediately(t *testing.T) { + reader := &Reader{} readerReturned := make(chan bool) go func() { p := make([]byte, 0, 5) diff --git a/pkg/nullreader/nullreader.go b/pkg/nullreader/nullreader.go deleted file mode 100644 index 040ad6e..0000000 --- a/pkg/nullreader/nullreader.go +++ /dev/null @@ -1,10 +0,0 @@ -package nullreader - -// NullReader is a struct that implements the io.Reader interface and returns nothing when reading -// from it. -type NullReader struct{} - -func (r NullReader) Read(_ []byte) (int, error) { - // An empty select blocks forever. - select {} -}