diff --git a/internal/api/runners.go b/internal/api/runners.go index c82bd3f..1f35c57 100644 --- a/internal/api/runners.go +++ b/internal/api/runners.go @@ -8,7 +8,6 @@ import ( "gitlab.hpi.de/codeocean/codemoon/poseidon/internal/config" "gitlab.hpi.de/codeocean/codemoon/poseidon/internal/runner" "gitlab.hpi.de/codeocean/codemoon/poseidon/pkg/dto" - "gitlab.hpi.de/codeocean/codemoon/poseidon/pkg/execution" "net/http" "net/url" ) @@ -113,8 +112,8 @@ func (r *RunnerController) execute(writer http.ResponseWriter, request *http.Req writeInternalServerError(writer, err, dto.ErrorUnknown) return } - id := execution.ID(newUUID.String()) - targetRunner.Add(id, executionRequest) + id := newUUID.String() + targetRunner.StoreExecution(id, executionRequest) webSocketURL := url.URL{ Scheme: scheme, Host: request.Host, diff --git a/internal/api/runners_test.go b/internal/api/runners_test.go index 8cf761d..a73ad2a 100644 --- a/internal/api/runners_test.go +++ b/internal/api/runners_test.go @@ -9,7 +9,6 @@ import ( "github.com/stretchr/testify/suite" "gitlab.hpi.de/codeocean/codemoon/poseidon/internal/runner" "gitlab.hpi.de/codeocean/codemoon/poseidon/pkg/dto" - "gitlab.hpi.de/codeocean/codemoon/poseidon/pkg/execution" "gitlab.hpi.de/codeocean/codemoon/poseidon/tests" "net/http" "net/http/httptest" @@ -86,15 +85,15 @@ type RunnerRouteTestSuite struct { runnerManager *runner.ManagerMock router *mux.Router runner runner.Runner - executionID execution.ID + executionID string } func (s *RunnerRouteTestSuite) SetupTest() { s.runnerManager = &runner.ManagerMock{} s.router = NewRouter(s.runnerManager, nil) s.runner = runner.NewNomadJob("some-id", nil, nil, nil) - s.executionID = "execution-id" - s.runner.Add(s.executionID, &dto.ExecutionRequest{}) + s.executionID = "execution" + s.runner.StoreExecution(s.executionID, &dto.ExecutionRequest{}) s.runnerManager.On("Get", s.runner.ID()).Return(s.runner, nil) } @@ -201,10 +200,9 @@ func (s *RunnerRouteTestSuite) TestExecuteRoute() { webSocketURL, err := url.Parse(webSocketResponse.WebSocketURL) s.Require().NoError(err) executionID := webSocketURL.Query().Get(ExecutionIDKey) - storedExecutionRequest, ok := s.runner.Pop(execution.ID(executionID)) + ok := s.runner.ExecutionExists(executionID) s.True(ok, "No execution request with this id: ", executionID) - s.Equal(&executionRequest, storedExecutionRequest) }) }) diff --git a/internal/api/websocket.go b/internal/api/websocket.go index db98595..f372196 100644 --- a/internal/api/websocket.go +++ b/internal/api/websocket.go @@ -8,7 +8,6 @@ import ( "github.com/gorilla/websocket" "gitlab.hpi.de/codeocean/codemoon/poseidon/internal/runner" "gitlab.hpi.de/codeocean/codemoon/poseidon/pkg/dto" - "gitlab.hpi.de/codeocean/codemoon/poseidon/pkg/execution" "io" "net/http" "sync" @@ -299,9 +298,8 @@ func (wp *webSocketProxy) writeMessage(messageType int, data []byte) error { // connectToRunner is the endpoint for websocket connections. func (r *RunnerController) connectToRunner(writer http.ResponseWriter, request *http.Request) { targetRunner, _ := runner.FromContext(request.Context()) - executionID := execution.ID(request.URL.Query().Get(ExecutionIDKey)) - executionRequest, ok := targetRunner.Pop(executionID) - if !ok { + executionID := request.URL.Query().Get(ExecutionIDKey) + if !targetRunner.ExecutionExists(executionID) { writeNotFound(writer, ErrUnknownExecutionID) return } @@ -317,7 +315,11 @@ func (r *RunnerController) connectToRunner(writer http.ResponseWriter, request * } log.WithField("runnerId", targetRunner.ID()).WithField("executionID", executionID).Info("Running execution") - exit, cancel := targetRunner.ExecuteInteractively(executionRequest, proxy.Stdin, proxy.Stdout, proxy.Stderr) + exit, cancel, err := targetRunner.ExecuteInteractively(executionID, proxy.Stdin, proxy.Stdout, proxy.Stderr) + if err != nil { + proxy.closeWithError(fmt.Sprintf("execution failed with: %v", err)) + return + } proxy.waitForExit(exit, cancel) } diff --git a/internal/api/websocket_test.go b/internal/api/websocket_test.go index 4ec024e..fa9d798 100644 --- a/internal/api/websocket_test.go +++ b/internal/api/websocket_test.go @@ -15,7 +15,6 @@ import ( "gitlab.hpi.de/codeocean/codemoon/poseidon/internal/nomad" "gitlab.hpi.de/codeocean/codemoon/poseidon/internal/runner" "gitlab.hpi.de/codeocean/codemoon/poseidon/pkg/dto" - "gitlab.hpi.de/codeocean/codemoon/poseidon/pkg/execution" "gitlab.hpi.de/codeocean/codemoon/poseidon/tests" "gitlab.hpi.de/codeocean/codemoon/poseidon/tests/helpers" "io" @@ -34,7 +33,7 @@ func TestWebSocketTestSuite(t *testing.T) { type WebSocketTestSuite struct { suite.Suite router *mux.Router - executionID execution.ID + executionID string runner runner.Runner apiMock *nomad.ExecutorAPIMock server *httptest.Server @@ -46,7 +45,7 @@ func (s *WebSocketTestSuite) SetupTest() { // default execution s.executionID = "execution-id" - s.runner.Add(s.executionID, &executionRequestHead) + s.runner.StoreExecution(s.executionID, &executionRequestHead) mockAPIExecuteHead(s.apiMock) runnerManager := &runner.ManagerMock{} @@ -125,8 +124,8 @@ func (s *WebSocketTestSuite) TestWebsocketConnection() { } func (s *WebSocketTestSuite) TestCancelWebSocketConnection() { - executionID := execution.ID("sleeping-execution") - s.runner.Add(executionID, &executionRequestSleep) + executionID := "sleeping-execution" + s.runner.StoreExecution(executionID, &executionRequestSleep) canceled := mockAPIExecuteSleep(s.apiMock) wsURL, err := webSocketURL("ws", s.server, s.router, s.runner.ID(), executionID) @@ -156,10 +155,10 @@ func (s *WebSocketTestSuite) TestCancelWebSocketConnection() { } func (s *WebSocketTestSuite) TestWebSocketConnectionTimeout() { - executionID := execution.ID("time-out-execution") + executionID := "time-out-execution" limitExecution := executionRequestSleep limitExecution.TimeLimit = 2 - s.runner.Add(executionID, &limitExecution) + s.runner.StoreExecution(executionID, &limitExecution) canceled := mockAPIExecuteSleep(s.apiMock) wsURL, err := webSocketURL("ws", s.server, s.router, s.runner.ID(), executionID) @@ -190,8 +189,8 @@ func (s *WebSocketTestSuite) TestWebSocketConnectionTimeout() { } func (s *WebSocketTestSuite) TestWebsocketStdoutAndStderr() { - executionID := execution.ID("ls-execution") - s.runner.Add(executionID, &executionRequestLs) + executionID := "ls-execution" + s.runner.StoreExecution(executionID, &executionRequestLs) mockAPIExecuteLs(s.apiMock) wsURL, err := webSocketURL("ws", s.server, s.router, s.runner.ID(), executionID) @@ -210,8 +209,8 @@ func (s *WebSocketTestSuite) TestWebsocketStdoutAndStderr() { } func (s *WebSocketTestSuite) TestWebsocketError() { - executionID := execution.ID("error-execution") - s.runner.Add(executionID, &executionRequestError) + executionID := "error-execution" + s.runner.StoreExecution(executionID, &executionRequestError) mockAPIExecuteError(s.apiMock) wsURL, err := webSocketURL("ws", s.server, s.router, s.runner.ID(), executionID) @@ -229,8 +228,8 @@ func (s *WebSocketTestSuite) TestWebsocketError() { } func (s *WebSocketTestSuite) TestWebsocketNonZeroExit() { - executionID := execution.ID("exit-execution") - s.runner.Add(executionID, &executionRequestExitNonZero) + executionID := "exit-execution" + s.runner.StoreExecution(executionID, &executionRequestExitNonZero) mockAPIExecuteExitNonZero(s.apiMock) wsURL, err := webSocketURL("ws", s.server, s.router, s.runner.ID(), executionID) @@ -251,8 +250,8 @@ func TestWebsocketTLS(t *testing.T) { runnerID := "runner-id" r, apiMock := newNomadAllocationWithMockedAPIClient(runnerID) - executionID := execution.ID("execution-id") - r.Add(executionID, &executionRequestLs) + executionID := "execution-id" + r.StoreExecution(executionID, &executionRequestLs) mockAPIExecuteLs(apiMock) runnerManager := &runner.ManagerMock{} @@ -380,7 +379,7 @@ func newNomadAllocationWithMockedAPIClient(runnerID string) (runner.Runner, *nom } func webSocketURL(scheme string, server *httptest.Server, router *mux.Router, - runnerID string, executionID execution.ID, + runnerID string, executionID string, ) (*url.URL, error) { websocketURL, err := url.Parse(server.URL) if err != nil { @@ -396,7 +395,7 @@ func webSocketURL(scheme string, server *httptest.Server, router *mux.Router, return websocketURL, nil } -func (s *WebSocketTestSuite) webSocketURL(scheme, runnerID string, executionID execution.ID) (*url.URL, error) { +func (s *WebSocketTestSuite) webSocketURL(scheme, runnerID, executionID string) (*url.URL, error) { return webSocketURL(scheme, s.server, s.router, runnerID, executionID) } diff --git a/internal/runner/runner.go b/internal/runner/runner.go index 4a86919..16ee3bb 100644 --- a/internal/runner/runner.go +++ b/internal/runner/runner.go @@ -29,26 +29,40 @@ const ( executionTimeoutGracePeriod = 3 * time.Second ) -var ErrorFileCopyFailed = errors.New("file copy failed") +var ( + ErrorUnknownExecution = errors.New("unknown execution") + ErrorFileCopyFailed = errors.New("file copy failed") +) + +type ExitInfo struct { + Code uint8 + Err error +} type Runner interface { + InactivityTimer + // ID returns the id of the runner. ID() string + // MappedPorts returns the mapped ports of the runner. MappedPorts() []*dto.MappedPort - execution.Storer - InactivityTimer + // StoreExecution adds a new execution to the runner that can then be executed using ExecuteInteractively. + StoreExecution(id string, executionRequest *dto.ExecutionRequest) + + // ExecutionExists returns whether the execution with the given id is already stored. + ExecutionExists(id string) bool // ExecuteInteractively runs the given execution request and forwards from and to the given reader and writers. // An ExitInfo is sent to the exit channel on command completion. // Output from the runner is forwarded immediately. ExecuteInteractively( - request *dto.ExecutionRequest, + id string, stdin io.ReadWriter, stdout, stderr io.Writer, - ) (exit <-chan ExitInfo, cancel context.CancelFunc) + ) (exit <-chan ExitInfo, cancel context.CancelFunc, err error) // UpdateFileSystem processes a dto.UpdateFileSystemRequest by first deleting each given dto.FilePath recursively // and then copying each given dto.File to the runner. @@ -60,8 +74,8 @@ type Runner interface { // NomadJob is an abstraction to communicate with Nomad environments. type NomadJob struct { - execution.Storer InactivityTimer + executions execution.Storer id string portMappings []nomadApi.PortMapping api nomad.ExecutorAPI @@ -76,7 +90,7 @@ func NewNomadJob(id string, portMappings []nomadApi.PortMapping, id: id, portMappings: portMappings, api: apiClient, - Storer: execution.NewLocalStorage(), + executions: execution.NewLocalStorage(), manager: manager, } job.InactivityTimer = NewInactivityTimer(job, manager) @@ -98,9 +112,74 @@ func (r *NomadJob) MappedPorts() []*dto.MappedPort { return ports } -type ExitInfo struct { - Code uint8 - Err error +func (r *NomadJob) StoreExecution(id string, request *dto.ExecutionRequest) { + r.executions.Add(execution.ID(id), request) +} + +func (r *NomadJob) ExecutionExists(id string) bool { + return r.executions.Exists(execution.ID(id)) +} + +func (r *NomadJob) ExecuteInteractively( + id string, + stdin io.ReadWriter, + stdout, stderr io.Writer, +) (<-chan ExitInfo, context.CancelFunc, error) { + request, ok := r.executions.Pop(execution.ID(id)) + if !ok { + return nil, nil, ErrorUnknownExecution + } + + r.ResetTimeout() + + command, ctx, cancel := prepareExecution(request) + exitInternal := make(chan ExitInfo) + exit := make(chan ExitInfo, 1) + ctxExecute, cancelExecute := context.WithCancel(context.Background()) + + go r.executeCommand(ctxExecute, command, stdin, stdout, stderr, exitInternal) + go r.handleExitOrContextDone(ctx, cancelExecute, exitInternal, exit, stdin) + + return exit, cancel, nil +} + +func (r *NomadJob) UpdateFileSystem(copyRequest *dto.UpdateFileSystemRequest) error { + r.ResetTimeout() + + var tarBuffer bytes.Buffer + if err := createTarArchiveForFiles(copyRequest.Copy, &tarBuffer); err != nil { + return err + } + + fileDeletionCommand := fileDeletionCommand(copyRequest.Delete) + copyCommand := "tar --extract --absolute-names --verbose --file=/dev/stdin;" + updateFileCommand := (&dto.ExecutionRequest{Command: fileDeletionCommand + copyCommand}).FullCommand() + stdOut := bytes.Buffer{} + stdErr := bytes.Buffer{} + exitCode, err := r.api.ExecuteCommand(r.id, context.Background(), updateFileCommand, false, + &tarBuffer, &stdOut, &stdErr) + + if err != nil { + return fmt.Errorf( + "%w: nomad error during file copy: %v", + nomad.ErrorExecutorCommunicationFailed, + err) + } + if exitCode != 0 { + return fmt.Errorf( + "%w: stderr output '%s' and stdout output '%s'", + ErrorFileCopyFailed, + stdErr.String(), + stdOut.String()) + } + return nil +} + +func (r *NomadJob) Destroy() error { + if err := r.manager.Return(r); err != nil { + return fmt.Errorf("error while destroying runner: %w", err) + } + return nil } func prepareExecution(request *dto.ExecutionRequest) ( @@ -164,61 +243,6 @@ func (r *NomadJob) handleExitOrContextDone(ctx context.Context, cancelExecute co } } -func (r *NomadJob) ExecuteInteractively( - request *dto.ExecutionRequest, - stdin io.ReadWriter, - stdout, stderr io.Writer, -) (<-chan ExitInfo, context.CancelFunc) { - r.ResetTimeout() - - command, ctx, cancel := prepareExecution(request) - exitInternal := make(chan ExitInfo) - exit := make(chan ExitInfo, 1) - ctxExecute, cancelExecute := context.WithCancel(context.Background()) - go r.executeCommand(ctxExecute, command, stdin, stdout, stderr, exitInternal) - go r.handleExitOrContextDone(ctx, cancelExecute, exitInternal, exit, stdin) - return exit, cancel -} - -func (r *NomadJob) UpdateFileSystem(copyRequest *dto.UpdateFileSystemRequest) error { - r.ResetTimeout() - - var tarBuffer bytes.Buffer - if err := createTarArchiveForFiles(copyRequest.Copy, &tarBuffer); err != nil { - return err - } - - fileDeletionCommand := fileDeletionCommand(copyRequest.Delete) - copyCommand := "tar --extract --absolute-names --verbose --file=/dev/stdin;" - updateFileCommand := (&dto.ExecutionRequest{Command: fileDeletionCommand + copyCommand}).FullCommand() - stdOut := bytes.Buffer{} - stdErr := bytes.Buffer{} - exitCode, err := r.api.ExecuteCommand(r.id, context.Background(), updateFileCommand, false, - &tarBuffer, &stdOut, &stdErr) - - if err != nil { - return fmt.Errorf( - "%w: nomad error during file copy: %v", - nomad.ErrorExecutorCommunicationFailed, - err) - } - if exitCode != 0 { - return fmt.Errorf( - "%w: stderr output '%s' and stdout output '%s'", - ErrorFileCopyFailed, - stdErr.String(), - stdOut.String()) - } - return nil -} - -func (r *NomadJob) Destroy() error { - if err := r.manager.Return(r); err != nil { - return fmt.Errorf("error while destroying runner: %w", err) - } - return nil -} - func createTarArchiveForFiles(filesToCopy []dto.File, w io.Writer) error { tarWriter := tar.NewWriter(w) for _, file := range filesToCopy { diff --git a/internal/runner/runner_mock.go b/internal/runner/runner_mock.go index a558189..2ed7c44 100644 --- a/internal/runner/runner_mock.go +++ b/internal/runner/runner_mock.go @@ -4,7 +4,6 @@ package runner import ( context "context" - "gitlab.hpi.de/codeocean/codemoon/poseidon/pkg/execution" io "io" dto "gitlab.hpi.de/codeocean/codemoon/poseidon/pkg/dto" @@ -19,11 +18,6 @@ type RunnerMock struct { mock.Mock } -// Add provides a mock function with given fields: id, executionRequest -func (_m *RunnerMock) Add(id execution.ID, executionRequest *dto.ExecutionRequest) { - _m.Called(id, executionRequest) -} - // Destroy provides a mock function with given fields: func (_m *RunnerMock) Destroy() error { ret := _m.Called() @@ -38,13 +32,13 @@ func (_m *RunnerMock) Destroy() error { return r0 } -// ExecuteInteractively provides a mock function with given fields: request, stdin, stdout, stderr -func (_m *RunnerMock) ExecuteInteractively(request *dto.ExecutionRequest, stdin io.ReadWriter, stdout io.Writer, stderr io.Writer) (<-chan ExitInfo, context.CancelFunc) { - ret := _m.Called(request, stdin, stdout, stderr) +// ExecuteInteractively provides a mock function with given fields: id, stdin, stdout, stderr +func (_m *RunnerMock) ExecuteInteractively(id string, stdin io.ReadWriter, stdout io.Writer, stderr io.Writer) (<-chan ExitInfo, context.CancelFunc, error) { + ret := _m.Called(id, stdin, stdout, stderr) var r0 <-chan ExitInfo - if rf, ok := ret.Get(0).(func(*dto.ExecutionRequest, io.ReadWriter, io.Writer, io.Writer) <-chan ExitInfo); ok { - r0 = rf(request, stdin, stdout, stderr) + if rf, ok := ret.Get(0).(func(string, io.ReadWriter, io.Writer, io.Writer) <-chan ExitInfo); ok { + r0 = rf(id, stdin, stdout, stderr) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(<-chan ExitInfo) @@ -52,15 +46,36 @@ func (_m *RunnerMock) ExecuteInteractively(request *dto.ExecutionRequest, stdin } var r1 context.CancelFunc - if rf, ok := ret.Get(1).(func(*dto.ExecutionRequest, io.ReadWriter, io.Writer, io.Writer) context.CancelFunc); ok { - r1 = rf(request, stdin, stdout, stderr) + if rf, ok := ret.Get(1).(func(string, io.ReadWriter, io.Writer, io.Writer) context.CancelFunc); ok { + r1 = rf(id, stdin, stdout, stderr) } else { if ret.Get(1) != nil { r1 = ret.Get(1).(context.CancelFunc) } } - return r0, r1 + var r2 error + if rf, ok := ret.Get(2).(func(string, io.ReadWriter, io.Writer, io.Writer) error); ok { + r2 = rf(id, stdin, stdout, stderr) + } else { + r2 = ret.Error(2) + } + + return r0, r1, r2 +} + +// ExecutionExists provides a mock function with given fields: id +func (_m *RunnerMock) ExecutionExists(id string) bool { + ret := _m.Called(id) + + var r0 bool + if rf, ok := ret.Get(0).(func(string) bool); ok { + r0 = rf(id) + } else { + r0 = ret.Get(0).(bool) + } + + return r0 } // ID provides a mock function with given fields: @@ -93,29 +108,6 @@ func (_m *RunnerMock) MappedPorts() []*dto.MappedPort { return r0 } -// Pop provides a mock function with given fields: id -func (_m *RunnerMock) Pop(id execution.ID) (*dto.ExecutionRequest, bool) { - ret := _m.Called(id) - - var r0 *dto.ExecutionRequest - if rf, ok := ret.Get(0).(func(execution.ID) *dto.ExecutionRequest); ok { - r0 = rf(id) - } else { - if ret.Get(0) != nil { - r0 = ret.Get(0).(*dto.ExecutionRequest) - } - } - - var r1 bool - if rf, ok := ret.Get(1).(func(execution.ID) bool); ok { - r1 = rf(id) - } else { - r1 = ret.Get(1).(bool) - } - - return r0, r1 -} - // ResetTimeout provides a mock function with given fields: func (_m *RunnerMock) ResetTimeout() { _m.Called() @@ -131,6 +123,11 @@ func (_m *RunnerMock) StopTimeout() { _m.Called() } +// StoreExecution provides a mock function with given fields: id, executionRequest +func (_m *RunnerMock) StoreExecution(id string, executionRequest *dto.ExecutionRequest) { + _m.Called(id, executionRequest) +} + // TimeoutPassed provides a mock function with given fields: func (_m *RunnerMock) TimeoutPassed() bool { ret := _m.Called() diff --git a/internal/runner/runner_test.go b/internal/runner/runner_test.go index 7e90393..7443162 100644 --- a/internal/runner/runner_test.go +++ b/internal/runner/runner_test.go @@ -22,6 +22,8 @@ import ( "time" ) +const defaultExecutionID = "execution-id" + func TestIdIsStored(t *testing.T) { runner := NewNomadJob(tests.DefaultJobID, nil, nil, nil) assert.Equal(t, tests.DefaultJobID, runner.ID()) @@ -49,9 +51,9 @@ func TestExecutionRequestIsStored(t *testing.T) { TimeLimit: 10, Environment: nil, } - id := execution.ID("test-execution") - runner.Add(id, executionRequest) - storedExecutionRunner, ok := runner.Pop(id) + id := "test-execution" + runner.StoreExecution(id, executionRequest) + storedExecutionRunner, ok := runner.executions.Pop(execution.ID(id)) assert.True(t, ok, "Getting an execution should not return ok false") assert.Equal(t, executionRequest, storedExecutionRunner) @@ -119,7 +121,7 @@ func (s *ExecuteInteractivelyTestSuite) SetupTest() { s.manager.On("Return", mock.Anything).Return(nil) s.runner = &NomadJob{ - Storer: execution.NewLocalStorage(), + executions: execution.NewLocalStorage(), InactivityTimer: s.timer, id: tests.DefaultRunnerID, api: s.apiMock, @@ -127,9 +129,16 @@ func (s *ExecuteInteractivelyTestSuite) SetupTest() { } } +func (s *ExecuteInteractivelyTestSuite) TestReturnsErrorWhenExecutionDoesNotExist() { + _, _, err := s.runner.ExecuteInteractively("non-existent-id", nil, nil, nil) + s.ErrorIs(err, ErrorUnknownExecution) +} + func (s *ExecuteInteractivelyTestSuite) TestCallsApi() { request := &dto.ExecutionRequest{Command: "echo 'Hello World!'"} - s.runner.ExecuteInteractively(request, nil, nil, nil) + s.runner.StoreExecution(defaultExecutionID, request) + _, _, err := s.runner.ExecuteInteractively(defaultExecutionID, nil, nil, nil) + s.Require().NoError(err) time.Sleep(tests.ShortTimeout) s.apiMock.AssertCalled(s.T(), "ExecuteCommand", tests.DefaultRunnerID, mock.Anything, request.FullCommand(), @@ -143,7 +152,9 @@ func (s *ExecuteInteractivelyTestSuite) TestReturnsAfterTimeout() { timeLimit := 1 executionRequest := &dto.ExecutionRequest{TimeLimit: timeLimit} - exit, _ := s.runner.ExecuteInteractively(executionRequest, &nullio.ReadWriter{}, nil, nil) + s.runner.StoreExecution(defaultExecutionID, executionRequest) + exit, _, err := s.runner.ExecuteInteractively(defaultExecutionID, &nullio.ReadWriter{}, nil, nil) + s.Require().NoError(err) select { case <-exit: @@ -172,7 +183,9 @@ func (s *ExecuteInteractivelyTestSuite) TestSendsSignalAfterTimeout() { }).Return(0, nil) timeLimit := 1 executionRequest := &dto.ExecutionRequest{TimeLimit: timeLimit} - _, _ = s.runner.ExecuteInteractively(executionRequest, bytes.NewBuffer(make([]byte, 1)), nil, nil) + s.runner.StoreExecution(defaultExecutionID, executionRequest) + _, _, err := s.runner.ExecuteInteractively(defaultExecutionID, bytes.NewBuffer(make([]byte, 1)), nil, nil) + s.Require().NoError(err) select { case <-time.After(2 * (time.Duration(timeLimit) * time.Second)): s.FailNow("The execution should receive a SIGQUIT after the timeout") @@ -186,21 +199,27 @@ func (s *ExecuteInteractivelyTestSuite) TestDestroysRunnerAfterTimeoutAndSignal( }) timeLimit := 1 executionRequest := &dto.ExecutionRequest{TimeLimit: timeLimit} - _, _ = s.runner.ExecuteInteractively(executionRequest, bytes.NewBuffer(make([]byte, 1)), nil, nil) + s.runner.StoreExecution(defaultExecutionID, executionRequest) + _, _, err := s.runner.ExecuteInteractively(defaultExecutionID, bytes.NewBuffer(make([]byte, 1)), nil, nil) + s.Require().NoError(err) <-time.After(executionTimeoutGracePeriod + time.Duration(timeLimit)*time.Second + tests.ShortTimeout) s.manager.AssertCalled(s.T(), "Return", s.runner) } func (s *ExecuteInteractivelyTestSuite) TestResetTimerGetsCalled() { executionRequest := &dto.ExecutionRequest{} - s.runner.ExecuteInteractively(executionRequest, nil, nil, nil) + s.runner.StoreExecution(defaultExecutionID, executionRequest) + _, _, err := s.runner.ExecuteInteractively(defaultExecutionID, nil, nil, nil) + s.Require().NoError(err) s.timer.AssertCalled(s.T(), "ResetTimeout") } func (s *ExecuteInteractivelyTestSuite) TestExitHasTimeoutErrorIfRunnerTimesOut() { s.mockedTimeoutPassedCall.Return(true) executionRequest := &dto.ExecutionRequest{} - exitChannel, _ := s.runner.ExecuteInteractively(executionRequest, &nullio.ReadWriter{}, nil, nil) + s.runner.StoreExecution(defaultExecutionID, executionRequest) + exitChannel, _, err := s.runner.ExecuteInteractively(defaultExecutionID, &nullio.ReadWriter{}, nil, nil) + s.Require().NoError(err) exit := <-exitChannel s.Equal(ErrorRunnerInactivityTimeout, exit.Err) } @@ -225,7 +244,7 @@ func (s *UpdateFileSystemTestSuite) SetupTest() { s.timer.On("ResetTimeout").Return() s.timer.On("TimeoutPassed").Return(false) s.runner = &NomadJob{ - Storer: execution.NewLocalStorage(), + executions: execution.NewLocalStorage(), InactivityTimer: s.timer, id: tests.DefaultRunnerID, api: s.apiMock, diff --git a/internal/runner/storage_test.go b/internal/runner/storage_test.go index d1471c4..e4307fc 100644 --- a/internal/runner/storage_test.go +++ b/internal/runner/storage_test.go @@ -20,7 +20,7 @@ type RunnerPoolTestSuite struct { func (s *RunnerPoolTestSuite) SetupTest() { s.runnerStorage = NewLocalRunnerStorage() s.runner = NewRunner(tests.DefaultRunnerID, nil) - s.runner.Add(tests.DefaultExecutionID, &dto.ExecutionRequest{Command: "true"}) + s.runner.StoreExecution(tests.DefaultExecutionID, &dto.ExecutionRequest{Command: "true"}) } func (s *RunnerPoolTestSuite) TestAddedRunnerCanBeRetrieved() { diff --git a/pkg/execution/execution.go b/pkg/execution/execution.go index 5a59e8e..04afd93 100644 --- a/pkg/execution/execution.go +++ b/pkg/execution/execution.go @@ -13,6 +13,9 @@ type Storer interface { // It overwrites the existing execution if an execution with the same id already exists. Add(id ID, executionRequest *dto.ExecutionRequest) + // Exists returns whether the execution with the given id exists in the store. + Exists(id ID) bool + // Pop deletes the execution with the given id from the storage and returns it. // If no such execution exists, ok is false and true otherwise. Pop(id ID) (request *dto.ExecutionRequest, ok bool) diff --git a/pkg/execution/local_storage.go b/pkg/execution/local_storage.go index 7e756d1..74ff95c 100644 --- a/pkg/execution/local_storage.go +++ b/pkg/execution/local_storage.go @@ -26,6 +26,13 @@ func (s *localStorage) Add(id ID, executionRequest *dto.ExecutionRequest) { s.executions[id] = executionRequest } +func (s *localStorage) Exists(id ID) bool { + s.Lock() + defer s.Unlock() + _, ok := s.executions[id] + return ok +} + func (s *localStorage) Pop(id ID) (*dto.ExecutionRequest, bool) { s.Lock() defer s.Unlock()