diff --git a/.golangci.yaml b/.golangci.yaml index c620bd6..798c154 100644 --- a/.golangci.yaml +++ b/.golangci.yaml @@ -93,3 +93,7 @@ issues: - path: _test\.go linters: - noctx + # Always closing the HTTP body unnecessarily complicates the tests + - bodyclose + # We don't need to wrap errors in tests + - wrapcheck \ No newline at end of file diff --git a/api/api_test.go b/api/api_test.go index 37a9e87..69c8cc5 100644 --- a/api/api_test.go +++ b/api/api_test.go @@ -56,7 +56,8 @@ func TestNewRouterV1WithAuthenticationEnabled(t *testing.T) { }) t.Run("protected route is not accessible", func(t *testing.T) { - // request an available API route that should be guarded by authentication (which one, in particular, does not matter here) + // request an available API route that should be guarded by authentication. + // (which one, in particular, does not matter here) request, err := http.NewRequest(http.MethodPost, "/api/v1/runners", nil) if err != nil { t.Fatal(err) diff --git a/api/auth/auth_test.go b/api/auth/auth_test.go index 4f1f05b..de51502 100644 --- a/api/auth/auth_test.go +++ b/api/auth/auth_test.go @@ -20,53 +20,53 @@ type AuthenticationMiddlewareTestSuite struct { httpAuthenticationMiddleware http.Handler } -func (suite *AuthenticationMiddlewareTestSuite) SetupTest() { +func (s *AuthenticationMiddlewareTestSuite) SetupTest() { correctAuthenticationToken = []byte(testToken) - suite.recorder = httptest.NewRecorder() + s.recorder = httptest.NewRecorder() request, err := http.NewRequest(http.MethodGet, "/api/v1/test", nil) if err != nil { - suite.T().Fatal(err) + s.T().Fatal(err) } - suite.request = request - suite.httpAuthenticationMiddleware = HTTPAuthenticationMiddleware( + s.request = request + s.httpAuthenticationMiddleware = HTTPAuthenticationMiddleware( http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) })) } -func (suite *AuthenticationMiddlewareTestSuite) TearDownTest() { +func (s *AuthenticationMiddlewareTestSuite) TearDownTest() { correctAuthenticationToken = []byte(nil) } -func (suite *AuthenticationMiddlewareTestSuite) TestReturns401WhenHeaderUnset() { - suite.httpAuthenticationMiddleware.ServeHTTP(suite.recorder, suite.request) - assert.Equal(suite.T(), http.StatusUnauthorized, suite.recorder.Code) +func (s *AuthenticationMiddlewareTestSuite) TestReturns401WhenHeaderUnset() { + s.httpAuthenticationMiddleware.ServeHTTP(s.recorder, s.request) + assert.Equal(s.T(), http.StatusUnauthorized, s.recorder.Code) } -func (suite *AuthenticationMiddlewareTestSuite) TestReturns401WhenTokenWrong() { - suite.request.Header.Set(TokenHeader, "Wr0ngT0k3n") - suite.httpAuthenticationMiddleware.ServeHTTP(suite.recorder, suite.request) - assert.Equal(suite.T(), http.StatusUnauthorized, suite.recorder.Code) +func (s *AuthenticationMiddlewareTestSuite) TestReturns401WhenTokenWrong() { + s.request.Header.Set(TokenHeader, "Wr0ngT0k3n") + s.httpAuthenticationMiddleware.ServeHTTP(s.recorder, s.request) + assert.Equal(s.T(), http.StatusUnauthorized, s.recorder.Code) } -func (suite *AuthenticationMiddlewareTestSuite) TestWarnsWhenUnauthorized() { +func (s *AuthenticationMiddlewareTestSuite) TestWarnsWhenUnauthorized() { var hook *test.Hook logger, hook := test.NewNullLogger() log = logger.WithField("pkg", "api/auth") - suite.request.Header.Set(TokenHeader, "Wr0ngT0k3n") - suite.httpAuthenticationMiddleware.ServeHTTP(suite.recorder, suite.request) + s.request.Header.Set(TokenHeader, "Wr0ngT0k3n") + s.httpAuthenticationMiddleware.ServeHTTP(s.recorder, s.request) - assert.Equal(suite.T(), http.StatusUnauthorized, suite.recorder.Code) - assert.Equal(suite.T(), logrus.WarnLevel, hook.LastEntry().Level) - assert.Equal(suite.T(), hook.LastEntry().Data["token"], "Wr0ngT0k3n") + assert.Equal(s.T(), http.StatusUnauthorized, s.recorder.Code) + assert.Equal(s.T(), logrus.WarnLevel, hook.LastEntry().Level) + assert.Equal(s.T(), hook.LastEntry().Data["token"], "Wr0ngT0k3n") } -func (suite *AuthenticationMiddlewareTestSuite) TestPassesWhenTokenCorrect() { - suite.request.Header.Set(TokenHeader, testToken) - suite.httpAuthenticationMiddleware.ServeHTTP(suite.recorder, suite.request) +func (s *AuthenticationMiddlewareTestSuite) TestPassesWhenTokenCorrect() { + s.request.Header.Set(TokenHeader, testToken) + s.httpAuthenticationMiddleware.ServeHTTP(s.recorder, s.request) - assert.Equal(suite.T(), http.StatusOK, suite.recorder.Code) + assert.Equal(s.T(), http.StatusOK, s.recorder.Code) } func TestHTTPAuthenticationMiddleware(t *testing.T) { diff --git a/api/dto/dto.go b/api/dto/dto.go index 6d64284..126f1f2 100644 --- a/api/dto/dto.go +++ b/api/dto/dto.go @@ -10,7 +10,7 @@ import ( // RunnerRequest is the expected json structure of the request body for the ProvideRunner function. type RunnerRequest struct { - ExecutionEnvironmentId int `json:"executionEnvironmentId"` + ExecutionEnvironmentID int `json:"executionEnvironmentId"` InactivityTimeout int `json:"inactivityTimeout"` } @@ -22,7 +22,7 @@ type ExecutionRequest struct { } func (er *ExecutionRequest) FullCommand() []string { - var command []string + command := make([]string, 0) command = append(command, "env", "-") for variable, value := range er.Environment { command = append(command, fmt.Sprintf("%s=%s", variable, value)) @@ -31,7 +31,8 @@ func (er *ExecutionRequest) FullCommand() []string { return command } -// ExecutionEnvironmentRequest is the expected json structure of the request body for the create execution environment function. +// ExecutionEnvironmentRequest is the expected json structure of the request body +// for the create execution environment function. type ExecutionEnvironmentRequest struct { PrewarmingPoolSize uint `json:"prewarmingPoolSize"` CPULimit uint `json:"cpuLimit"` @@ -43,12 +44,12 @@ type ExecutionEnvironmentRequest struct { // RunnerResponse is the expected response when providing a runner. type RunnerResponse struct { - Id string `json:"runnerId"` + ID string `json:"runnerId"` } // ExecutionResponse is the expected response when creating an execution for a runner. type ExecutionResponse struct { - WebSocketUrl string `json:"websocketUrl"` + WebSocketURL string `json:"websocketUrl"` } // UpdateFileSystemRequest is the expected json structure of the request body for the update file system route. @@ -102,6 +103,13 @@ const ( WebSocketExit WebSocketMessageType = "exit" ) +var ( + ErrUnknownWebSocketMessageType = errors.New("unknown WebSocket message type") + ErrMissingType = errors.New("type is missing") + ErrMissingData = errors.New("data is missing") + ErrInvalidType = errors.New("invalid type") +) + // WebSocketMessage is the type for all messages send in the WebSocket to the client. // Depending on the MessageType the Data or ExitCode might not be included in the marshaled json message. type WebSocketMessage struct { @@ -112,24 +120,29 @@ type WebSocketMessage struct { // MarshalJSON implements the json.Marshaler interface. // This converts the WebSocketMessage into the expected schema (see docs/websocket.schema.json). -func (m WebSocketMessage) MarshalJSON() ([]byte, error) { +func (m WebSocketMessage) MarshalJSON() (res []byte, err error) { switch m.Type { case WebSocketOutputStdout, WebSocketOutputStderr, WebSocketOutputError: - return json.Marshal(struct { + res, err = json.Marshal(struct { MessageType WebSocketMessageType `json:"type"` Data string `json:"data"` }{m.Type, m.Data}) case WebSocketMetaStart, WebSocketMetaTimeout: - return json.Marshal(struct { + res, err = json.Marshal(struct { MessageType WebSocketMessageType `json:"type"` }{m.Type}) case WebSocketExit: - return json.Marshal(struct { + res, err = json.Marshal(struct { MessageType WebSocketMessageType `json:"type"` ExitCode uint8 `json:"data"` }{m.Type, m.ExitCode}) } - return nil, errors.New("unhandled WebSocket message type") + if err != nil { + return nil, fmt.Errorf("error marshaling WebSocketMessage: %w", err) + } else if res == nil { + return nil, ErrUnknownWebSocketMessageType + } + return res, nil } // UnmarshalJSON implements the json.Unmarshaler interface. @@ -138,47 +151,47 @@ func (m *WebSocketMessage) UnmarshalJSON(rawMessage []byte) error { messageMap := make(map[string]interface{}) err := json.Unmarshal(rawMessage, &messageMap) if err != nil { - return err + return fmt.Errorf("error unmarshiling raw WebSocket message: %w", err) } messageType, ok := messageMap["type"] if !ok { - return errors.New("missing key type") + return ErrMissingType } messageTypeString, ok := messageType.(string) if !ok { - return errors.New("value of key type must be a string") + return fmt.Errorf("value of key type must be a string: %w", ErrInvalidType) } switch messageType := WebSocketMessageType(messageTypeString); messageType { case WebSocketExit: data, ok := messageMap["data"] if !ok { - return errors.New("missing key data") + return ErrMissingData } // json.Unmarshal converts any number to a float64 in the massageMap, so we must first cast it to the float. exit, ok := data.(float64) if !ok { - return errors.New("value of key data must be a number") + return fmt.Errorf("value of key data must be a number: %w", ErrInvalidType) } if exit != float64(uint8(exit)) { - return errors.New("value of key data must be uint8") + return fmt.Errorf("value of key data must be uint8: %w", ErrInvalidType) } m.Type = messageType m.ExitCode = uint8(exit) case WebSocketOutputStdout, WebSocketOutputStderr, WebSocketOutputError: data, ok := messageMap["data"] if !ok { - return errors.New("missing key data") + return ErrMissingData } text, ok := data.(string) if !ok { - return errors.New("value of key data must be a string") + return fmt.Errorf("value of key data must be a string: %w", ErrInvalidType) } m.Type = messageType m.Data = text case WebSocketMetaStart, WebSocketMetaTimeout: m.Type = messageType default: - return errors.New("unknown WebSocket message type") + return ErrUnknownWebSocketMessageType } return nil } diff --git a/api/environments.go b/api/environments.go index 4a9a5a9..2f324c8 100644 --- a/api/environments.go +++ b/api/environments.go @@ -2,6 +2,7 @@ package api import ( "encoding/json" + "errors" "fmt" "github.com/gorilla/mux" "gitlab.hpi.de/codeocean/codemoon/poseidon/api/dto" @@ -15,6 +16,8 @@ const ( createOrUpdateRouteName = "createOrUpdate" ) +var ErrMissingURLParameter = errors.New("url parameter missing") + type EnvironmentController struct { manager environment.Manager } @@ -35,7 +38,7 @@ func (e *EnvironmentController) createOrUpdate(writer http.ResponseWriter, reque id, ok := mux.Vars(request)[executionEnvironmentIDKey] if !ok { - writeBadRequest(writer, fmt.Errorf("could not find %s", executionEnvironmentIDKey)) + writeBadRequest(writer, fmt.Errorf("could not find %s: %w", executionEnvironmentIDKey, ErrMissingURLParameter)) return } environmentID, err := runner.NewEnvironmentID(id) @@ -54,8 +57,3 @@ func (e *EnvironmentController) createOrUpdate(writer http.ResponseWriter, reque writer.WriteHeader(http.StatusNoContent) } } - -// delete removes an execution environment from the executor -func (e *EnvironmentController) delete(writer http.ResponseWriter, request *http.Request) { // nolint:unused ToDo - -} diff --git a/api/helpers.go b/api/helpers.go index ee03969..89b72db 100644 --- a/api/helpers.go +++ b/api/helpers.go @@ -2,23 +2,24 @@ package api import ( "encoding/json" + "fmt" "gitlab.hpi.de/codeocean/codemoon/poseidon/api/dto" "net/http" ) func writeInternalServerError(writer http.ResponseWriter, err error, errorCode dto.ErrorCode) { - sendJson(writer, &dto.InternalServerError{Message: err.Error(), ErrorCode: errorCode}, http.StatusInternalServerError) + sendJSON(writer, &dto.InternalServerError{Message: err.Error(), ErrorCode: errorCode}, http.StatusInternalServerError) } func writeBadRequest(writer http.ResponseWriter, err error) { - sendJson(writer, &dto.ClientError{Message: err.Error()}, http.StatusBadRequest) + sendJSON(writer, &dto.ClientError{Message: err.Error()}, http.StatusBadRequest) } func writeNotFound(writer http.ResponseWriter, err error) { - sendJson(writer, &dto.ClientError{Message: err.Error()}, http.StatusNotFound) + sendJSON(writer, &dto.ClientError{Message: err.Error()}, http.StatusNotFound) } -func sendJson(writer http.ResponseWriter, content interface{}, httpStatusCode int) { +func sendJSON(writer http.ResponseWriter, content interface{}, httpStatusCode int) { writer.Header().Set("Content-Type", "application/json") writer.WriteHeader(httpStatusCode) response, err := json.Marshal(content) @@ -36,7 +37,7 @@ func sendJson(writer http.ResponseWriter, content interface{}, httpStatusCode in func parseJSONRequestBody(writer http.ResponseWriter, request *http.Request, structure interface{}) error { if err := json.NewDecoder(request.Body).Decode(structure); err != nil { writeBadRequest(writer, err) - return err + return fmt.Errorf("error parsing JSON request body: %w", err) } return nil } diff --git a/api/runners.go b/api/runners.go index 4b34e25..9259171 100644 --- a/api/runners.go +++ b/api/runners.go @@ -1,6 +1,7 @@ package api import ( + "errors" "fmt" "github.com/google/uuid" "github.com/gorilla/mux" @@ -16,8 +17,8 @@ const ( WebsocketPath = "/websocket" UpdateFileSystemPath = "/files" DeleteRoute = "deleteRunner" - RunnerIdKey = "runnerId" - ExecutionIdKey = "executionId" + RunnerIDKey = "runnerId" + ExecutionIDKey = "executionID" ProvideRoute = "provideRunner" ) @@ -30,9 +31,10 @@ type RunnerController struct { func (r *RunnerController) ConfigureRoutes(router *mux.Router) { runnersRouter := router.PathPrefix(RunnersPath).Subrouter() runnersRouter.HandleFunc("", r.provide).Methods(http.MethodPost).Name(ProvideRoute) - r.runnerRouter = runnersRouter.PathPrefix(fmt.Sprintf("/{%s}", RunnerIdKey)).Subrouter() + r.runnerRouter = runnersRouter.PathPrefix(fmt.Sprintf("/{%s}", RunnerIDKey)).Subrouter() r.runnerRouter.Use(r.findRunnerMiddleware) - r.runnerRouter.HandleFunc(UpdateFileSystemPath, r.updateFileSystem).Methods(http.MethodPatch).Name(UpdateFileSystemPath) + r.runnerRouter.HandleFunc(UpdateFileSystemPath, r.updateFileSystem).Methods(http.MethodPatch). + Name(UpdateFileSystemPath) r.runnerRouter.HandleFunc(ExecutePath, r.execute).Methods(http.MethodPost).Name(ExecutePath) r.runnerRouter.HandleFunc(WebsocketPath, r.connectToRunner).Methods(http.MethodGet).Name(WebsocketPath) r.runnerRouter.HandleFunc("", r.delete).Methods(http.MethodDelete).Name(DeleteRoute) @@ -46,21 +48,21 @@ func (r *RunnerController) provide(writer http.ResponseWriter, request *http.Req if err := parseJSONRequestBody(writer, request, runnerRequest); err != nil { return } - environmentId := runner.EnvironmentID(runnerRequest.ExecutionEnvironmentId) - nextRunner, err := r.manager.Claim(environmentId, runnerRequest.InactivityTimeout) + environmentID := runner.EnvironmentID(runnerRequest.ExecutionEnvironmentID) + nextRunner, err := r.manager.Claim(environmentID, runnerRequest.InactivityTimeout) if err != nil { switch err { case runner.ErrUnknownExecutionEnvironment: writeNotFound(writer, err) case runner.ErrNoRunnersAvailable: - log.WithField("environment", environmentId).Warn("No runners available") + log.WithField("environment", environmentID).Warn("No runners available") writeInternalServerError(writer, err, dto.ErrorNomadOverload) default: writeInternalServerError(writer, err, dto.ErrorUnknown) } return } - sendJson(writer, &dto.RunnerResponse{Id: nextRunner.Id()}, http.StatusOK) + sendJSON(writer, &dto.RunnerResponse{ID: nextRunner.ID()}, http.StatusOK) } // updateFileSystem handles the files API route. @@ -98,36 +100,36 @@ func (r *RunnerController) execute(writer http.ResponseWriter, request *http.Req } targetRunner, _ := runner.FromContext(request.Context()) - path, err := r.runnerRouter.Get(WebsocketPath).URL(RunnerIdKey, targetRunner.Id()) + path, err := r.runnerRouter.Get(WebsocketPath).URL(RunnerIDKey, targetRunner.ID()) if err != nil { log.WithError(err).Error("Could not create runner websocket URL.") writeInternalServerError(writer, err, dto.ErrorUnknown) return } - newUuid, err := uuid.NewRandom() + newUUID, err := uuid.NewRandom() if err != nil { log.WithError(err).Error("Could not create execution id") writeInternalServerError(writer, err, dto.ErrorUnknown) return } - id := runner.ExecutionId(newUuid.String()) + id := runner.ExecutionID(newUUID.String()) targetRunner.Add(id, executionRequest) - webSocketUrl := url.URL{ + webSocketURL := url.URL{ Scheme: scheme, Host: request.Host, Path: path.String(), - RawQuery: fmt.Sprintf("%s=%s", ExecutionIdKey, id), + RawQuery: fmt.Sprintf("%s=%s", ExecutionIDKey, id), } - sendJson(writer, &dto.ExecutionResponse{WebSocketUrl: webSocketUrl.String()}, http.StatusOK) + sendJSON(writer, &dto.ExecutionResponse{WebSocketURL: webSocketURL.String()}, http.StatusOK) } // The findRunnerMiddleware looks up the runnerId for routes containing it // and adds the runner to the context of the request. func (r *RunnerController) findRunnerMiddleware(next http.Handler) http.Handler { return http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) { - runnerId := mux.Vars(request)[RunnerIdKey] - targetRunner, err := r.manager.Get(runnerId) + runnerID := mux.Vars(request)[RunnerIDKey] + targetRunner, err := r.manager.Get(runnerID) if err != nil { writeNotFound(writer, err) return @@ -145,7 +147,7 @@ func (r *RunnerController) delete(writer http.ResponseWriter, request *http.Requ err := r.manager.Return(targetRunner) if err != nil { - if err == runner.ErrUnknownExecutionEnvironment { + if errors.Is(err, runner.ErrUnknownExecutionEnvironment) { writeNotFound(writer, err) } else { writeInternalServerError(writer, err, dto.ErrorNomadInternalServerError) diff --git a/api/runners_test.go b/api/runners_test.go index 36e6d6a..d4f9818 100644 --- a/api/runners_test.go +++ b/api/runners_test.go @@ -3,7 +3,6 @@ package api import ( "bytes" "encoding/json" - "errors" "fmt" "github.com/gorilla/mux" "github.com/stretchr/testify/mock" @@ -32,7 +31,7 @@ func (s *MiddlewareTestSuite) SetupTest() { s.runner = runner.NewNomadJob(tests.DefaultRunnerID, nil, nil) s.capturedRunner = nil s.runnerRequest = func(runnerId string) *http.Request { - path, err := s.router.Get("test-runner-id").URL(RunnerIdKey, runnerId) + path, err := s.router.Get("test-runner-id").URL(RunnerIDKey, runnerId) s.Require().NoError(err) request, err := http.NewRequest(http.MethodPost, path.String(), nil) s.Require().NoError(err) @@ -50,7 +49,7 @@ func (s *MiddlewareTestSuite) SetupTest() { s.router = mux.NewRouter() runnerController := &RunnerController{s.manager, s.router} s.router.Use(runnerController.findRunnerMiddleware) - s.router.HandleFunc(fmt.Sprintf("/test/{%s}", RunnerIdKey), runnerRouteHandler).Name("test-runner-id") + s.router.HandleFunc(fmt.Sprintf("/test/{%s}", RunnerIDKey), runnerRouteHandler).Name("test-runner-id") } func TestMiddlewareTestSuite(t *testing.T) { @@ -58,10 +57,10 @@ func TestMiddlewareTestSuite(t *testing.T) { } func (s *MiddlewareTestSuite) TestFindRunnerMiddlewareIfRunnerExists() { - s.manager.On("Get", s.runner.Id()).Return(s.runner, nil) + s.manager.On("Get", s.runner.ID()).Return(s.runner, nil) recorder := httptest.NewRecorder() - s.router.ServeHTTP(recorder, s.runnerRequest(s.runner.Id())) + s.router.ServeHTTP(recorder, s.runnerRequest(s.runner.ID())) s.Equal(http.StatusOK, recorder.Code) s.Equal(s.runner, s.capturedRunner) @@ -86,16 +85,16 @@ type RunnerRouteTestSuite struct { runnerManager *runner.ManagerMock router *mux.Router runner runner.Runner - executionId runner.ExecutionId + executionID runner.ExecutionID } func (s *RunnerRouteTestSuite) SetupTest() { s.runnerManager = &runner.ManagerMock{} s.router = NewRouter(s.runnerManager, nil) s.runner = runner.NewNomadJob("some-id", nil, nil) - s.executionId = "execution-id" - s.runner.Add(s.executionId, &dto.ExecutionRequest{}) - s.runnerManager.On("Get", s.runner.Id()).Return(s.runner, nil) + s.executionID = "execution-id" + s.runner.Add(s.executionID, &dto.ExecutionRequest{}) + s.runnerManager.On("Get", s.runner.ID()).Return(s.runner, nil) } func TestProvideRunnerTestSuite(t *testing.T) { @@ -115,7 +114,7 @@ func (s *ProvideRunnerTestSuite) SetupTest() { s.Require().NoError(err) s.path = path.String() - runnerRequest := dto.RunnerRequest{ExecutionEnvironmentId: tests.DefaultEnvironmentIDAsInteger} + runnerRequest := dto.RunnerRequest{ExecutionEnvironmentID: tests.DefaultEnvironmentIDAsInteger} body, err := json.Marshal(runnerRequest) s.Require().NoError(err) s.defaultRequest, err = http.NewRequest(http.MethodPost, s.path, bytes.NewReader(body)) @@ -123,7 +122,8 @@ func (s *ProvideRunnerTestSuite) SetupTest() { } func (s *ProvideRunnerTestSuite) TestValidRequestReturnsRunner() { - s.runnerManager.On("Claim", mock.AnythingOfType("runner.EnvironmentID"), mock.AnythingOfType("int")).Return(s.runner, nil) + s.runnerManager.On("Claim", mock.AnythingOfType("runner.EnvironmentID"), + mock.AnythingOfType("int")).Return(s.runner, nil) recorder := httptest.NewRecorder() s.router.ServeHTTP(recorder, s.defaultRequest) @@ -134,7 +134,7 @@ func (s *ProvideRunnerTestSuite) TestValidRequestReturnsRunner() { err := json.NewDecoder(recorder.Result().Body).Decode(&runnerResponse) s.Require().NoError(err) _ = recorder.Result().Body.Close() - s.Equal(s.runner.Id(), runnerResponse.Id) + s.Equal(s.runner.ID(), runnerResponse.ID) }) } @@ -173,7 +173,7 @@ func (s *ProvideRunnerTestSuite) TestWhenNoRunnerAvailableReturnsNomadOverload() } func (s *RunnerRouteTestSuite) TestExecuteRoute() { - path, err := s.router.Get(ExecutePath).URL(RunnerIdKey, s.runner.Id()) + path, err := s.router.Get(ExecutePath).URL(RunnerIDKey, s.runner.ID()) s.Require().NoError(err) s.Run("valid request", func() { @@ -197,12 +197,12 @@ func (s *RunnerRouteTestSuite) TestExecuteRoute() { s.Equal(http.StatusOK, recorder.Code) s.Run("creates an execution request for the runner", func() { - webSocketUrl, err := url.Parse(webSocketResponse.WebSocketUrl) + webSocketURL, err := url.Parse(webSocketResponse.WebSocketURL) s.Require().NoError(err) - executionId := webSocketUrl.Query().Get(ExecutionIdKey) - storedExecutionRequest, ok := s.runner.Pop(runner.ExecutionId(executionId)) + executionID := webSocketURL.Query().Get(ExecutionIDKey) + storedExecutionRequest, ok := s.runner.Pop(runner.ExecutionID(executionID)) - s.True(ok, "No execution request with this id: ", executionId) + s.True(ok, "No execution request with this id: ", executionID) s.Equal(&executionRequest, storedExecutionRequest) }) }) @@ -231,9 +231,9 @@ type UpdateFileSystemRouteTestSuite struct { func (s *UpdateFileSystemRouteTestSuite) SetupTest() { s.RunnerRouteTestSuite.SetupTest() - routeUrl, err := s.router.Get(UpdateFileSystemPath).URL(RunnerIdKey, tests.DefaultMockID) + routeURL, err := s.router.Get(UpdateFileSystemPath).URL(RunnerIDKey, tests.DefaultMockID) s.Require().NoError(err) - s.path = routeUrl.String() + s.path = routeURL.String() s.runnerMock = &runner.RunnerMock{} s.runnerManager.On("Get", tests.DefaultMockID).Return(s.runnerMock, nil) s.recorder = httptest.NewRecorder() @@ -243,8 +243,10 @@ func (s *UpdateFileSystemRouteTestSuite) TestUpdateFileSystemReturnsNoContentOnV s.runnerMock.On("UpdateFileSystem", mock.AnythingOfType("*dto.UpdateFileSystemRequest")).Return(nil) copyRequest := dto.UpdateFileSystemRequest{} - body, _ := json.Marshal(copyRequest) - request, _ := http.NewRequest(http.MethodPatch, s.path, bytes.NewReader(body)) + body, err := json.Marshal(copyRequest) + s.Require().NoError(err) + request, err := http.NewRequest(http.MethodPatch, s.path, bytes.NewReader(body)) + s.Require().NoError(err) s.router.ServeHTTP(s.recorder, request) s.Equal(http.StatusNoContent, s.recorder.Code) @@ -252,7 +254,8 @@ func (s *UpdateFileSystemRouteTestSuite) TestUpdateFileSystemReturnsNoContentOnV } func (s *UpdateFileSystemRouteTestSuite) TestUpdateFileSystemReturnsBadRequestOnInvalidRequestBody() { - request, _ := http.NewRequest(http.MethodPatch, s.path, strings.NewReader("")) + request, err := http.NewRequest(http.MethodPatch, s.path, strings.NewReader("")) + s.Require().NoError(err) s.router.ServeHTTP(s.recorder, request) s.Equal(http.StatusBadRequest, s.recorder.Code) @@ -261,10 +264,13 @@ func (s *UpdateFileSystemRouteTestSuite) TestUpdateFileSystemReturnsBadRequestOn func (s *UpdateFileSystemRouteTestSuite) TestUpdateFileSystemToNonExistingRunnerReturnsNotFound() { invalidID := "some-invalid-runner-id" s.runnerManager.On("Get", invalidID).Return(nil, runner.ErrRunnerNotFound) - path, _ := s.router.Get(UpdateFileSystemPath).URL(RunnerIdKey, invalidID) + path, err := s.router.Get(UpdateFileSystemPath).URL(RunnerIDKey, invalidID) + s.Require().NoError(err) copyRequest := dto.UpdateFileSystemRequest{} - body, _ := json.Marshal(copyRequest) - request, _ := http.NewRequest(http.MethodPatch, path.String(), bytes.NewReader(body)) + body, err := json.Marshal(copyRequest) + s.Require().NoError(err) + request, err := http.NewRequest(http.MethodPatch, path.String(), bytes.NewReader(body)) + s.Require().NoError(err) s.router.ServeHTTP(s.recorder, request) s.Equal(http.StatusNotFound, s.recorder.Code) @@ -276,8 +282,10 @@ func (s *UpdateFileSystemRouteTestSuite) TestUpdateFileSystemReturnsInternalServ Return(runner.ErrorFileCopyFailed) copyRequest := dto.UpdateFileSystemRequest{} - body, _ := json.Marshal(copyRequest) - request, _ := http.NewRequest(http.MethodPatch, s.path, bytes.NewReader(body)) + body, err := json.Marshal(copyRequest) + s.Require().NoError(err) + request, err := http.NewRequest(http.MethodPatch, s.path, bytes.NewReader(body)) + s.Require().NoError(err) s.router.ServeHTTP(s.recorder, request) s.Equal(http.StatusInternalServerError, s.recorder.Code) @@ -294,7 +302,7 @@ type DeleteRunnerRouteTestSuite struct { func (s *DeleteRunnerRouteTestSuite) SetupTest() { s.RunnerRouteTestSuite.SetupTest() - deleteURL, err := s.router.Get(DeleteRoute).URL(RunnerIdKey, s.runner.Id()) + deleteURL, err := s.router.Get(DeleteRoute).URL(RunnerIDKey, s.runner.ID()) s.Require().NoError(err) s.path = deleteURL.String() } @@ -316,7 +324,7 @@ func (s *DeleteRunnerRouteTestSuite) TestValidRequestReturnsNoContent() { } func (s *DeleteRunnerRouteTestSuite) TestReturnInternalServerErrorWhenApiCallToNomadFailed() { - s.runnerManager.On("Return", s.runner).Return(errors.New("API call failed")) + s.runnerManager.On("Return", s.runner).Return(tests.ErrDefault) recorder := httptest.NewRecorder() request, err := http.NewRequest(http.MethodDelete, s.path, nil) @@ -328,8 +336,8 @@ func (s *DeleteRunnerRouteTestSuite) TestReturnInternalServerErrorWhenApiCallToN } func (s *DeleteRunnerRouteTestSuite) TestDeleteInvalidRunnerIdReturnsNotFound() { - s.runnerManager.On("Get", mock.AnythingOfType("string")).Return(nil, errors.New("API call failed")) - deleteURL, err := s.router.Get(DeleteRoute).URL(RunnerIdKey, "1nv4l1dID") + s.runnerManager.On("Get", mock.AnythingOfType("string")).Return(nil, tests.ErrDefault) + deleteURL, err := s.router.Get(DeleteRoute).URL(RunnerIDKey, "1nv4l1dID") s.Require().NoError(err) deletePath := deleteURL.String() diff --git a/api/websocket.go b/api/websocket.go index 3cf0f23..418d71f 100644 --- a/api/websocket.go +++ b/api/websocket.go @@ -4,6 +4,7 @@ import ( "context" "encoding/json" "errors" + "fmt" "github.com/gorilla/websocket" "gitlab.hpi.de/codeocean/codemoon/poseidon/api/dto" "gitlab.hpi.de/codeocean/codemoon/poseidon/runner" @@ -11,6 +12,10 @@ import ( "net/http" ) +const CodeOceanToRawReaderBufferSize = 1024 + +var ErrUnknownExecutionID = errors.New("execution id unknown") + type webSocketConnection interface { WriteMessage(messageType int, data []byte) error Close() error @@ -21,11 +26,11 @@ type webSocketConnection interface { type WebSocketReader interface { io.Reader - readInputLoop() context.CancelFunc + startReadInputLoop() context.CancelFunc } -// codeOceanToRawReader is an io.Reader implementation that provides the content of the WebSocket connection to CodeOcean. -// You have to start the Reader by calling readInputLoop. After that you can use the Read function. +// codeOceanToRawReader is an io.Reader implementation that provides the content of the WebSocket connection +// to CodeOcean. You have to start the Reader by calling readInputLoop. After that you can use the Read function. type codeOceanToRawReader struct { connection webSocketConnection @@ -38,74 +43,79 @@ type codeOceanToRawReader struct { func newCodeOceanToRawReader(connection webSocketConnection) *codeOceanToRawReader { return &codeOceanToRawReader{ connection: connection, - buffer: make(chan byte, 1024), + buffer: make(chan byte, CodeOceanToRawReaderBufferSize), } } -// readInputLoop asynchronously reads from the WebSocket connection and buffers the user's input. +// readInputLoop reads from the WebSocket connection and buffers the user's input. // This is necessary because input must be read for the connection to handle special messages like close and call the // CloseHandler. -func (cr *codeOceanToRawReader) readInputLoop() context.CancelFunc { - ctx, cancel := context.WithCancel(context.Background()) - go func() { - readMessage := make(chan bool) - for { - var messageType int - var reader io.Reader - var err error +func (cr *codeOceanToRawReader) readInputLoop(ctx context.Context) { + readMessage := make(chan bool) + for { + var messageType int + var reader io.Reader + var err error - go func() { - messageType, reader, err = cr.connection.NextReader() - readMessage <- true - }() + go func() { + messageType, reader, err = cr.connection.NextReader() + readMessage <- true + }() + select { + case <-readMessage: + case <-ctx.Done(): + return + } + + if err != nil { + log.WithError(err).Warn("Error reading client message") + return + } + if messageType != websocket.TextMessage { + log.WithField("messageType", messageType).Warn("Received message of wrong type") + return + } + + message, err := io.ReadAll(reader) + if err != nil { + log.WithError(err).Warn("error while reading WebSocket message") + return + } + for _, character := range message { select { - case <-readMessage: + case cr.buffer <- character: case <-ctx.Done(): return } - - if err != nil { - log.WithError(err).Warn("Error reading client message") - return - } - if messageType != websocket.TextMessage { - log.WithField("messageType", messageType).Warn("Received message of wrong type") - return - } - - message, err := io.ReadAll(reader) - if err != nil { - log.WithError(err).Warn("error while reading WebSocket message") - return - } - for _, character := range message { - select { - case cr.buffer <- character: - case <-ctx.Done(): - return - } - } } - }() + } +} + +// startReadInputLoop start the read input loop asynchronously and returns a context.CancelFunc which can be used +// to cancel the read input loop. +func (cr *codeOceanToRawReader) startReadInputLoop() context.CancelFunc { + ctx, cancel := context.WithCancel(context.Background()) + go cr.readInputLoop(ctx) return cancel } // Read implements the io.Reader interface. // It returns bytes from the buffer. -func (cr *codeOceanToRawReader) Read(p []byte) (n int, err error) { +func (cr *codeOceanToRawReader) Read(p []byte) (int, error) { if len(p) == 0 { - return + return 0, nil } // Ensure to not return until at least one byte has been read to avoid busy waiting. p[0] = <-cr.buffer + var n int for n = 1; n < len(p); n++ { select { case p[n] = <-cr.buffer: default: - return + return n, nil } } - return + return n, nil } // rawToCodeOceanWriter is an io.Writer implementation that, when written to, wraps the written data in the appropriate @@ -137,7 +147,7 @@ func upgradeConnection(writer http.ResponseWriter, request *http.Request) (webSo connection, err := connUpgrader.Upgrade(writer, request, nil) if err != nil { log.WithError(err).Warn("Connection upgrade failed") - return nil, err + return nil, fmt.Errorf("error upgrading the connection: %w", err) } return connection, nil } @@ -161,7 +171,7 @@ func newWebSocketProxy(connection webSocketConnection) (*webSocketProxy, error) closeHandler := connection.CloseHandler() connection.SetCloseHandler(func(code int, text string) error { - // The default close handler always returns nil, so the error can be safely ignored. + //nolint:errcheck // The default close handler always returns nil, so the error can be safely ignored. _ = closeHandler(code, text) close(proxy.userExit) return nil @@ -173,7 +183,7 @@ func newWebSocketProxy(connection webSocketConnection) (*webSocketProxy, error) // and handles WebSocket exit messages. func (wp *webSocketProxy) waitForExit(exit <-chan runner.ExitInfo, cancelExecution context.CancelFunc) { defer wp.close() - cancelInputLoop := wp.Stdin.readInputLoop() + cancelInputLoop := wp.Stdin.startReadInputLoop() var exitInfo runner.ExitInfo select { case exitInfo = <-exit: @@ -187,12 +197,18 @@ func (wp *webSocketProxy) waitForExit(exit <-chan runner.ExitInfo, cancelExecuti } if errors.Is(exitInfo.Err, context.DeadlineExceeded) || errors.Is(exitInfo.Err, runner.ErrorRunnerInactivityTimeout) { - _ = wp.sendToClient(dto.WebSocketMessage{Type: dto.WebSocketMetaTimeout}) + err := wp.sendToClient(dto.WebSocketMessage{Type: dto.WebSocketMetaTimeout}) + if err != nil { + log.WithError(err).Warn("Failed to send timeout message to client") + } return } else if exitInfo.Err != nil { errorMessage := "Error executing the request" log.WithError(exitInfo.Err).Warn(errorMessage) - _ = wp.sendToClient(dto.WebSocketMessage{Type: dto.WebSocketOutputError, Data: errorMessage}) + err := wp.sendToClient(dto.WebSocketMessage{Type: dto.WebSocketOutputError, Data: errorMessage}) + if err != nil { + log.WithError(err).Warn("Failed to send output error message to client") + } return } log.WithField("exit_code", exitInfo.Code).Debug() @@ -211,27 +227,29 @@ func (wp *webSocketProxy) sendToClient(message dto.WebSocketMessage) error { if err != nil { log.WithField("message", message).WithError(err).Warn("Marshal error") wp.closeWithError("Error creating message") - return err + return fmt.Errorf("error marshaling WebSocket message: %w", err) } err = wp.connection.WriteMessage(websocket.TextMessage, encodedMessage) if err != nil { errorMessage := "Error writing the exit message" log.WithError(err).Warn(errorMessage) wp.closeWithError(errorMessage) - return err + return fmt.Errorf("error writing WebSocket message: %w", err) } return nil } func (wp *webSocketProxy) closeWithError(message string) { - err := wp.connection.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseInternalServerErr, message)) + err := wp.connection.WriteMessage(websocket.CloseMessage, + websocket.FormatCloseMessage(websocket.CloseInternalServerErr, message)) if err != nil { log.WithError(err).Warn("Error during websocket close") } } func (wp *webSocketProxy) close() { - err := wp.connection.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, "")) + err := wp.connection.WriteMessage(websocket.CloseMessage, + websocket.FormatCloseMessage(websocket.CloseNormalClosure, "")) _ = wp.connection.Close() if err != nil { log.WithError(err).Warn("Error during websocket close") @@ -241,10 +259,10 @@ func (wp *webSocketProxy) close() { // connectToRunner is the endpoint for websocket connections. func (r *RunnerController) connectToRunner(writer http.ResponseWriter, request *http.Request) { targetRunner, _ := runner.FromContext(request.Context()) - executionId := runner.ExecutionId(request.URL.Query().Get(ExecutionIdKey)) - executionRequest, ok := targetRunner.Pop(executionId) + executionID := runner.ExecutionID(request.URL.Query().Get(ExecutionIDKey)) + executionRequest, ok := targetRunner.Pop(executionID) if !ok { - writeNotFound(writer, errors.New("executionId does not exist")) + writeNotFound(writer, ErrUnknownExecutionID) return } @@ -258,7 +276,7 @@ func (r *RunnerController) connectToRunner(writer http.ResponseWriter, request * return } - log.WithField("runnerId", targetRunner.Id()).WithField("executionId", executionId).Info("Running execution") + log.WithField("runnerId", targetRunner.ID()).WithField("executionID", executionID).Info("Running execution") exit, cancel := targetRunner.ExecuteInteractively(executionRequest, proxy.Stdin, proxy.Stdout, proxy.Stderr) proxy.waitForExit(exit, cancel) diff --git a/api/websocket_test.go b/api/websocket_test.go index 2df427b..52a6cb2 100644 --- a/api/websocket_test.go +++ b/api/websocket_test.go @@ -5,7 +5,6 @@ import ( "context" "crypto/tls" "encoding/json" - "errors" "fmt" "github.com/gorilla/mux" "github.com/gorilla/websocket" @@ -34,144 +33,146 @@ func TestWebSocketTestSuite(t *testing.T) { type WebSocketTestSuite struct { suite.Suite router *mux.Router - executionId runner.ExecutionId + executionID runner.ExecutionID runner runner.Runner apiMock *nomad.ExecutorAPIMock server *httptest.Server } -func (suite *WebSocketTestSuite) SetupTest() { - runnerId := "runner-id" - suite.runner, suite.apiMock = newNomadAllocationWithMockedApiClient(runnerId) +func (s *WebSocketTestSuite) SetupTest() { + runnerID := "runner-id" + s.runner, s.apiMock = newNomadAllocationWithMockedAPIClient(runnerID) // default execution - suite.executionId = "execution-id" - suite.runner.Add(suite.executionId, &executionRequestHead) - mockApiExecuteHead(suite.apiMock) + s.executionID = "execution-id" + s.runner.Add(s.executionID, &executionRequestHead) + mockAPIExecuteHead(s.apiMock) runnerManager := &runner.ManagerMock{} - runnerManager.On("Get", suite.runner.Id()).Return(suite.runner, nil) - suite.router = NewRouter(runnerManager, nil) - suite.server = httptest.NewServer(suite.router) + runnerManager.On("Get", s.runner.ID()).Return(s.runner, nil) + s.router = NewRouter(runnerManager, nil) + s.server = httptest.NewServer(s.router) } -func (suite *WebSocketTestSuite) TearDownTest() { - suite.server.Close() +func (s *WebSocketTestSuite) TearDownTest() { + s.server.Close() } -func (suite *WebSocketTestSuite) TestWebsocketConnectionCanBeEstablished() { - wsUrl, err := suite.webSocketUrl("ws", suite.runner.Id(), suite.executionId) - suite.Require().NoError(err) - _, _, err = websocket.DefaultDialer.Dial(wsUrl.String(), nil) - suite.Require().NoError(err) +func (s *WebSocketTestSuite) TestWebsocketConnectionCanBeEstablished() { + wsURL, err := s.webSocketURL("ws", s.runner.ID(), s.executionID) + s.Require().NoError(err) + _, _, err = websocket.DefaultDialer.Dial(wsURL.String(), nil) + s.Require().NoError(err) } -func (suite *WebSocketTestSuite) TestWebsocketReturns404IfExecutionDoesNotExist() { - wsUrl, err := suite.webSocketUrl("ws", suite.runner.Id(), "invalid-execution-id") - suite.Require().NoError(err) - _, response, _ := websocket.DefaultDialer.Dial(wsUrl.String(), nil) - suite.Equal(http.StatusNotFound, response.StatusCode) +func (s *WebSocketTestSuite) TestWebsocketReturns404IfExecutionDoesNotExist() { + wsURL, err := s.webSocketURL("ws", s.runner.ID(), "invalid-execution-id") + s.Require().NoError(err) + _, response, err := websocket.DefaultDialer.Dial(wsURL.String(), nil) + s.Require().ErrorIs(err, websocket.ErrBadHandshake) + s.Equal(http.StatusNotFound, response.StatusCode) } -func (suite *WebSocketTestSuite) TestWebsocketReturns400IfRequestedViaHttp() { - wsUrl, err := suite.webSocketUrl("http", suite.runner.Id(), suite.executionId) - suite.Require().NoError(err) - response, err := http.Get(wsUrl.String()) - suite.Require().NoError(err) +func (s *WebSocketTestSuite) TestWebsocketReturns400IfRequestedViaHttp() { + wsURL, err := s.webSocketURL("http", s.runner.ID(), s.executionID) + s.Require().NoError(err) + response, err := http.Get(wsURL.String()) + s.Require().NoError(err) // This functionality is implemented by the WebSocket library. - suite.Equal(http.StatusBadRequest, response.StatusCode) + s.Equal(http.StatusBadRequest, response.StatusCode) } -func (suite *WebSocketTestSuite) TestWebsocketConnection() { - wsUrl, err := suite.webSocketUrl("ws", suite.runner.Id(), suite.executionId) - suite.Require().NoError(err) - connection, _, err := websocket.DefaultDialer.Dial(wsUrl.String(), nil) - suite.Require().NoError(err) +func (s *WebSocketTestSuite) TestWebsocketConnection() { + wsURL, err := s.webSocketURL("ws", s.runner.ID(), s.executionID) + s.Require().NoError(err) + connection, _, err := websocket.DefaultDialer.Dial(wsURL.String(), nil) + s.Require().NoError(err) err = connection.SetReadDeadline(time.Now().Add(5 * time.Second)) - suite.Require().NoError(err) + s.Require().NoError(err) - suite.Run("Receives start message", func() { + s.Run("Receives start message", func() { message, err := helpers.ReceiveNextWebSocketMessage(connection) - suite.Require().NoError(err) - suite.Equal(dto.WebSocketMetaStart, message.Type) + s.Require().NoError(err) + s.Equal(dto.WebSocketMetaStart, message.Type) }) - suite.Run("Executes the request in the runner", func() { + s.Run("Executes the request in the runner", func() { <-time.After(100 * time.Millisecond) - suite.apiMock.AssertCalled(suite.T(), "ExecuteCommand", + s.apiMock.AssertCalled(s.T(), "ExecuteCommand", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything) }) - suite.Run("Can send input", func() { + s.Run("Can send input", func() { err = connection.WriteMessage(websocket.TextMessage, []byte("Hello World\n")) - suite.Require().NoError(err) + s.Require().NoError(err) }) messages, err := helpers.ReceiveAllWebSocketMessages(connection) - suite.Require().Error(err) - suite.Equal(&websocket.CloseError{Code: websocket.CloseNormalClosure}, err) + s.Require().Error(err) + s.True(websocket.IsCloseError(err, websocket.CloseNormalClosure)) - suite.Run("Receives output message", func() { + s.Run("Receives output message", func() { stdout, _, _ := helpers.WebSocketOutputMessages(messages) - suite.Equal("Hello World", stdout) + s.Equal("Hello World", stdout) }) - suite.Run("Receives exit message", func() { + s.Run("Receives exit message", func() { controlMessages := helpers.WebSocketControlMessages(messages) - suite.Require().Equal(1, len(controlMessages)) - suite.Equal(dto.WebSocketExit, controlMessages[0].Type) + s.Require().Equal(1, len(controlMessages)) + s.Equal(dto.WebSocketExit, controlMessages[0].Type) }) } -func (suite *WebSocketTestSuite) TestCancelWebSocketConnection() { - executionId := runner.ExecutionId("sleeping-execution") - suite.runner.Add(executionId, &executionRequestSleep) - canceled := mockApiExecuteSleep(suite.apiMock) +func (s *WebSocketTestSuite) TestCancelWebSocketConnection() { + executionID := runner.ExecutionID("sleeping-execution") + s.runner.Add(executionID, &executionRequestSleep) + canceled := mockAPIExecuteSleep(s.apiMock) - wsUrl, err := webSocketUrl("ws", suite.server, suite.router, suite.runner.Id(), executionId) - suite.Require().NoError(err) - connection, _, err := websocket.DefaultDialer.Dial(wsUrl.String(), nil) - suite.Require().NoError(err) + wsURL, err := webSocketURL("ws", s.server, s.router, s.runner.ID(), executionID) + s.Require().NoError(err) + connection, _, err := websocket.DefaultDialer.Dial(wsURL.String(), nil) + s.Require().NoError(err) message, err := helpers.ReceiveNextWebSocketMessage(connection) - suite.Require().NoError(err) - suite.Equal(dto.WebSocketMetaStart, message.Type) + s.Require().NoError(err) + s.Equal(dto.WebSocketMetaStart, message.Type) select { case <-canceled: - suite.Fail("ExecuteInteractively canceled unexpected") + s.Fail("ExecuteInteractively canceled unexpected") default: } - err = connection.WriteControl(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""), time.Now().Add(time.Second)) - suite.Require().NoError(err) + err = connection.WriteControl(websocket.CloseMessage, + websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""), time.Now().Add(time.Second)) + s.Require().NoError(err) select { case <-canceled: case <-time.After(time.Second): - suite.Fail("ExecuteInteractively not canceled") + s.Fail("ExecuteInteractively not canceled") } } -func (suite *WebSocketTestSuite) TestWebSocketConnectionTimeout() { - executionId := runner.ExecutionId("time-out-execution") +func (s *WebSocketTestSuite) TestWebSocketConnectionTimeout() { + executionID := runner.ExecutionID("time-out-execution") limitExecution := executionRequestSleep limitExecution.TimeLimit = 2 - suite.runner.Add(executionId, &limitExecution) - canceled := mockApiExecuteSleep(suite.apiMock) + s.runner.Add(executionID, &limitExecution) + canceled := mockAPIExecuteSleep(s.apiMock) - wsUrl, err := webSocketUrl("ws", suite.server, suite.router, suite.runner.Id(), executionId) - suite.Require().NoError(err) - connection, _, err := websocket.DefaultDialer.Dial(wsUrl.String(), nil) - suite.Require().NoError(err) + wsURL, err := webSocketURL("ws", s.server, s.router, s.runner.ID(), executionID) + s.Require().NoError(err) + connection, _, err := websocket.DefaultDialer.Dial(wsURL.String(), nil) + s.Require().NoError(err) message, err := helpers.ReceiveNextWebSocketMessage(connection) - suite.Require().NoError(err) - suite.Equal(dto.WebSocketMetaStart, message.Type) + s.Require().NoError(err) + s.Equal(dto.WebSocketMetaStart, message.Type) select { case <-canceled: - suite.Fail("ExecuteInteractively canceled unexpected") + s.Fail("ExecuteInteractively canceled unexpected") case <-time.After(time.Duration(limitExecution.TimeLimit-1) * time.Second): <-time.After(time.Second) } @@ -179,94 +180,94 @@ func (suite *WebSocketTestSuite) TestWebSocketConnectionTimeout() { select { case <-canceled: case <-time.After(time.Second): - suite.Fail("ExecuteInteractively not canceled") + s.Fail("ExecuteInteractively not canceled") } message, err = helpers.ReceiveNextWebSocketMessage(connection) - suite.Require().NoError(err) - suite.Equal(dto.WebSocketMetaTimeout, message.Type) + s.Require().NoError(err) + s.Equal(dto.WebSocketMetaTimeout, message.Type) } -func (suite *WebSocketTestSuite) TestWebsocketStdoutAndStderr() { - executionId := runner.ExecutionId("ls-execution") - suite.runner.Add(executionId, &executionRequestLs) - mockApiExecuteLs(suite.apiMock) +func (s *WebSocketTestSuite) TestWebsocketStdoutAndStderr() { + executionID := runner.ExecutionID("ls-execution") + s.runner.Add(executionID, &executionRequestLs) + mockAPIExecuteLs(s.apiMock) - wsUrl, err := webSocketUrl("ws", suite.server, suite.router, suite.runner.Id(), executionId) - suite.Require().NoError(err) - connection, _, err := websocket.DefaultDialer.Dial(wsUrl.String(), nil) - suite.Require().NoError(err) + wsURL, err := webSocketURL("ws", s.server, s.router, s.runner.ID(), executionID) + s.Require().NoError(err) + connection, _, err := websocket.DefaultDialer.Dial(wsURL.String(), nil) + s.Require().NoError(err) messages, err := helpers.ReceiveAllWebSocketMessages(connection) - suite.Require().Error(err) - suite.Equal(&websocket.CloseError{Code: websocket.CloseNormalClosure}, err) + s.Require().Error(err) + s.True(websocket.IsCloseError(err, websocket.CloseNormalClosure)) stdout, stderr, _ := helpers.WebSocketOutputMessages(messages) - suite.Contains(stdout, "existing-file") + s.Contains(stdout, "existing-file") - suite.Contains(stderr, "non-existing-file") + s.Contains(stderr, "non-existing-file") } -func (suite *WebSocketTestSuite) TestWebsocketError() { - executionId := runner.ExecutionId("error-execution") - suite.runner.Add(executionId, &executionRequestError) - mockApiExecuteError(suite.apiMock) +func (s *WebSocketTestSuite) TestWebsocketError() { + executionID := runner.ExecutionID("error-execution") + s.runner.Add(executionID, &executionRequestError) + mockAPIExecuteError(s.apiMock) - wsUrl, err := webSocketUrl("ws", suite.server, suite.router, suite.runner.Id(), executionId) - suite.Require().NoError(err) - connection, _, err := websocket.DefaultDialer.Dial(wsUrl.String(), nil) - suite.Require().NoError(err) + wsURL, err := webSocketURL("ws", s.server, s.router, s.runner.ID(), executionID) + s.Require().NoError(err) + connection, _, err := websocket.DefaultDialer.Dial(wsURL.String(), nil) + s.Require().NoError(err) messages, err := helpers.ReceiveAllWebSocketMessages(connection) - suite.Require().Error(err) - suite.Equal(&websocket.CloseError{Code: websocket.CloseNormalClosure}, err) + s.Require().Error(err) + s.True(websocket.IsCloseError(err, websocket.CloseNormalClosure)) _, _, errMessages := helpers.WebSocketOutputMessages(messages) - suite.Equal(1, len(errMessages)) - suite.Equal("Error executing the request", errMessages[0]) + s.Equal(1, len(errMessages)) + s.Equal("Error executing the request", errMessages[0]) } -func (suite *WebSocketTestSuite) TestWebsocketNonZeroExit() { - executionId := runner.ExecutionId("exit-execution") - suite.runner.Add(executionId, &executionRequestExitNonZero) - mockApiExecuteExitNonZero(suite.apiMock) +func (s *WebSocketTestSuite) TestWebsocketNonZeroExit() { + executionID := runner.ExecutionID("exit-execution") + s.runner.Add(executionID, &executionRequestExitNonZero) + mockAPIExecuteExitNonZero(s.apiMock) - wsUrl, err := webSocketUrl("ws", suite.server, suite.router, suite.runner.Id(), executionId) - suite.Require().NoError(err) - connection, _, err := websocket.DefaultDialer.Dial(wsUrl.String(), nil) - suite.Require().NoError(err) + wsURL, err := webSocketURL("ws", s.server, s.router, s.runner.ID(), executionID) + s.Require().NoError(err) + connection, _, err := websocket.DefaultDialer.Dial(wsURL.String(), nil) + s.Require().NoError(err) messages, err := helpers.ReceiveAllWebSocketMessages(connection) - suite.Require().Error(err) - suite.Equal(&websocket.CloseError{Code: websocket.CloseNormalClosure}, err) + s.Require().Error(err) + s.True(websocket.IsCloseError(err, websocket.CloseNormalClosure)) controlMessages := helpers.WebSocketControlMessages(messages) - suite.Equal(2, len(controlMessages)) - suite.Equal(&dto.WebSocketMessage{Type: dto.WebSocketExit, ExitCode: 42}, controlMessages[1]) + s.Equal(2, len(controlMessages)) + s.Equal(&dto.WebSocketMessage{Type: dto.WebSocketExit, ExitCode: 42}, controlMessages[1]) } func TestWebsocketTLS(t *testing.T) { - runnerId := "runner-id" - r, apiMock := newNomadAllocationWithMockedApiClient(runnerId) + runnerID := "runner-id" + r, apiMock := newNomadAllocationWithMockedAPIClient(runnerID) - executionId := runner.ExecutionId("execution-id") - r.Add(executionId, &executionRequestLs) - mockApiExecuteLs(apiMock) + executionID := runner.ExecutionID("execution-id") + r.Add(executionID, &executionRequestLs) + mockAPIExecuteLs(apiMock) runnerManager := &runner.ManagerMock{} - runnerManager.On("Get", r.Id()).Return(r, nil) + runnerManager.On("Get", r.ID()).Return(r, nil) router := NewRouter(runnerManager, nil) server, err := helpers.StartTLSServer(t, router) require.NoError(t, err) defer server.Close() - wsUrl, err := webSocketUrl("wss", server, router, runnerId, executionId) + wsURL, err := webSocketURL("wss", server, router, runnerID, executionID) require.NoError(t, err) - config := &tls.Config{RootCAs: nil, InsecureSkipVerify: true} + config := &tls.Config{RootCAs: nil, InsecureSkipVerify: true} //nolint:gosec // test needs self-signed cert d := websocket.Dialer{TLSClientConfig: config} - connection, _, err := d.Dial(wsUrl.String(), nil) + connection, _, err := d.Dial(wsURL.String(), nil) require.NoError(t, err) message, err := helpers.ReceiveNextWebSocketMessage(connection) @@ -274,7 +275,7 @@ func TestWebsocketTLS(t *testing.T) { assert.Equal(t, dto.WebSocketMetaStart, message.Type) _, err = helpers.ReceiveAllWebSocketMessages(connection) require.Error(t, err) - assert.Equal(t, &websocket.CloseError{Code: websocket.CloseNormalClosure}, err) + assert.True(t, websocket.IsCloseError(err, websocket.CloseNormalClosure)) } func TestRawToCodeOceanWriter(t *testing.T) { @@ -284,7 +285,9 @@ func TestRawToCodeOceanWriter(t *testing.T) { connectionMock := &webSocketConnectionMock{} connectionMock.On("WriteMessage", mock.AnythingOfType("int"), mock.AnythingOfType("[]uint8")). Run(func(args mock.Arguments) { - message = args.Get(1).([]byte) + var ok bool + message, ok = args.Get(1).([]byte) + require.True(t, ok) }). Return(nil) connectionMock.On("CloseHandler").Return(nil) @@ -300,10 +303,11 @@ func TestRawToCodeOceanWriter(t *testing.T) { _, err = writer.Write([]byte(testMessage)) require.NoError(t, err) - expected, _ := json.Marshal(struct { + expected, err := json.Marshal(struct { Type string `json:"type"` Data string `json:"data"` }{string(dto.WebSocketOutputStdout), testMessage}) + require.NoError(t, err) assert.Equal(t, expected, message) } @@ -312,8 +316,10 @@ func TestCodeOceanToRawReaderReturnsOnlyAfterOneByteWasRead(t *testing.T) { read := make(chan bool) go func() { + //nolint:makezero // we can't make zero initial length here as the reader otherwise doesn't block p := make([]byte, 10) - _, _ = reader.Read(p) + _, err := reader.Read(p) + require.NoError(t, err) read <- true }() @@ -340,13 +346,15 @@ func TestCodeOceanToRawReaderReturnsOnlyAfterOneByteWasReadFromConnection(t *tes }) reader := newCodeOceanToRawReader(connection) - cancel := reader.readInputLoop() + cancel := reader.startReadInputLoop() defer cancel() read := make(chan bool) + //nolint:makezero // this is required here to make the Read call blocking message := make([]byte, 10) go func() { - _, _ = reader.Read(message) + _, err := reader.Read(message) + require.NoError(t, err) read <- true }() @@ -362,36 +370,39 @@ func TestCodeOceanToRawReaderReturnsOnlyAfterOneByteWasReadFromConnection(t *tes // --- Test suite specific test helpers --- -func newNomadAllocationWithMockedApiClient(runnerId string) (r runner.Runner, mock *nomad.ExecutorAPIMock) { - mock = &nomad.ExecutorAPIMock{} - r = runner.NewNomadJob(runnerId, mock, nil) +func newNomadAllocationWithMockedAPIClient(runnerID string) (r runner.Runner, executorAPIMock *nomad.ExecutorAPIMock) { + executorAPIMock = &nomad.ExecutorAPIMock{} + r = runner.NewNomadJob(runnerID, executorAPIMock, nil) return } -func webSocketUrl(scheme string, server *httptest.Server, router *mux.Router, runnerId string, executionId runner.ExecutionId) (*url.URL, error) { - websocketUrl, err := url.Parse(server.URL) +func webSocketURL(scheme string, server *httptest.Server, router *mux.Router, + runnerID string, executionID runner.ExecutionID, +) (*url.URL, error) { + websocketURL, err := url.Parse(server.URL) if err != nil { return nil, err } - path, err := router.Get(WebsocketPath).URL(RunnerIdKey, runnerId) + path, err := router.Get(WebsocketPath).URL(RunnerIDKey, runnerID) if err != nil { return nil, err } - websocketUrl.Scheme = scheme - websocketUrl.Path = path.Path - websocketUrl.RawQuery = fmt.Sprintf("executionId=%s", executionId) - return websocketUrl, nil + websocketURL.Scheme = scheme + websocketURL.Path = path.Path + websocketURL.RawQuery = fmt.Sprintf("executionID=%s", executionID) + return websocketURL, nil } -func (suite *WebSocketTestSuite) webSocketUrl(scheme, runnerId string, executionId runner.ExecutionId) (*url.URL, error) { - return webSocketUrl(scheme, suite.server, suite.router, runnerId, executionId) +func (s *WebSocketTestSuite) webSocketURL(scheme, runnerID string, executionID runner.ExecutionID) (*url.URL, error) { + return webSocketURL(scheme, s.server, s.router, runnerID, executionID) } var executionRequestLs = dto.ExecutionRequest{Command: "ls"} -// mockApiExecuteLs mocks the ExecuteCommand of an ExecutorApi to act as if 'ls existing-file non-existing-file' was executed. -func mockApiExecuteLs(api *nomad.ExecutorAPIMock) { - mockApiExecute(api, &executionRequestLs, +// mockAPIExecuteLs mocks the ExecuteCommand of an ExecutorApi to act as if +// 'ls existing-file non-existing-file' was executed. +func mockAPIExecuteLs(api *nomad.ExecutorAPIMock) { + mockAPIExecute(api, &executionRequestLs, func(_ string, _ context.Context, _ []string, _ bool, _ io.Reader, stdout, stderr io.Writer) (int, error) { _, _ = stdout.Write([]byte("existing-file\n")) _, _ = stderr.Write([]byte("ls: cannot access 'non-existing-file': No such file or directory\n")) @@ -401,10 +412,12 @@ func mockApiExecuteLs(api *nomad.ExecutorAPIMock) { var executionRequestHead = dto.ExecutionRequest{Command: "head -n 1"} -// mockApiExecuteHead mocks the ExecuteCommand of an ExecutorApi to act as if 'head -n 1' was executed. -func mockApiExecuteHead(api *nomad.ExecutorAPIMock) { - mockApiExecute(api, &executionRequestHead, - func(_ string, _ context.Context, _ []string, _ bool, stdin io.Reader, stdout io.Writer, stderr io.Writer) (int, error) { +// mockAPIExecuteHead mocks the ExecuteCommand of an ExecutorApi to act as if 'head -n 1' was executed. +func mockAPIExecuteHead(api *nomad.ExecutorAPIMock) { + mockAPIExecute(api, &executionRequestHead, + func(_ string, _ context.Context, _ []string, _ bool, + stdin io.Reader, stdout io.Writer, stderr io.Writer, + ) (int, error) { scanner := bufio.NewScanner(stdin) for !scanner.Scan() { scanner = bufio.NewScanner(stdin) @@ -416,11 +429,13 @@ func mockApiExecuteHead(api *nomad.ExecutorAPIMock) { var executionRequestSleep = dto.ExecutionRequest{Command: "sleep infinity"} -// mockApiExecuteSleep mocks the ExecuteCommand method of an ExecutorAPI to sleep until the execution is canceled. -func mockApiExecuteSleep(api *nomad.ExecutorAPIMock) <-chan bool { +// mockAPIExecuteSleep mocks the ExecuteCommand method of an ExecutorAPI to sleep until the execution is canceled. +func mockAPIExecuteSleep(api *nomad.ExecutorAPIMock) <-chan bool { canceled := make(chan bool, 1) - mockApiExecute(api, &executionRequestSleep, - func(_ string, ctx context.Context, _ []string, _ bool, stdin io.Reader, stdout io.Writer, stderr io.Writer) (int, error) { + mockAPIExecute(api, &executionRequestSleep, + func(_ string, ctx context.Context, _ []string, _ bool, + stdin io.Reader, stdout io.Writer, stderr io.Writer, + ) (int, error) { <-ctx.Done() close(canceled) return 0, ctx.Err() @@ -430,28 +445,30 @@ func mockApiExecuteSleep(api *nomad.ExecutorAPIMock) <-chan bool { var executionRequestError = dto.ExecutionRequest{Command: "error"} -// mockApiExecuteError mocks the ExecuteCommand method of an ExecutorApi to return an error. -func mockApiExecuteError(api *nomad.ExecutorAPIMock) { - mockApiExecute(api, &executionRequestError, +// mockAPIExecuteError mocks the ExecuteCommand method of an ExecutorApi to return an error. +func mockAPIExecuteError(api *nomad.ExecutorAPIMock) { + mockAPIExecute(api, &executionRequestError, func(_ string, _ context.Context, _ []string, _ bool, _ io.Reader, _, _ io.Writer) (int, error) { - return 0, errors.New("intended error") + return 0, tests.ErrDefault }) } var executionRequestExitNonZero = dto.ExecutionRequest{Command: "exit 42"} -// mockApiExecuteExitNonZero mocks the ExecuteCommand method of an ExecutorApi to exit with exit status 42. -func mockApiExecuteExitNonZero(api *nomad.ExecutorAPIMock) { - mockApiExecute(api, &executionRequestExitNonZero, +// mockAPIExecuteExitNonZero mocks the ExecuteCommand method of an ExecutorApi to exit with exit status 42. +func mockAPIExecuteExitNonZero(api *nomad.ExecutorAPIMock) { + mockAPIExecute(api, &executionRequestExitNonZero, func(_ string, _ context.Context, _ []string, _ bool, _ io.Reader, _, _ io.Writer) (int, error) { return 42, nil }) } -// mockApiExecute mocks the ExecuteCommand method of an ExecutorApi to call the given method run when the command +// mockAPIExecute mocks the ExecuteCommand method of an ExecutorApi to call the given method run when the command // corresponding to the given ExecutionRequest is called. -func mockApiExecute(api *nomad.ExecutorAPIMock, request *dto.ExecutionRequest, - run func(runnerId string, ctx context.Context, command []string, tty bool, stdin io.Reader, stdout, stderr io.Writer) (int, error)) { +func mockAPIExecute(api *nomad.ExecutorAPIMock, request *dto.ExecutionRequest, + run func(runnerId string, ctx context.Context, command []string, tty bool, + stdin io.Reader, stdout, stderr io.Writer) (int, error), +) { call := api.On("ExecuteCommand", mock.AnythingOfType("string"), mock.Anything, diff --git a/config/config.go b/config/config.go index d0879d1..3fe89b3 100644 --- a/config/config.go +++ b/config/config.go @@ -5,6 +5,7 @@ import ( "errors" "flag" "fmt" + "github.com/sirupsen/logrus" "gitlab.hpi.de/codeocean/codemoon/poseidon/logging" "gopkg.in/yaml.v3" "net/url" @@ -45,6 +46,7 @@ var ( CurvePreferences: []tls.CurveID{tls.CurveP521, tls.CurveP384, tls.CurveP256}, PreferServerCipherSuites: true, } + ErrConfigInitialized = errors.New("configuration is already initialized") ) // server configures the Poseidon webserver. @@ -85,7 +87,7 @@ type configuration struct { // should be called directly after starting the program. func InitConfig() error { if configurationInitialized { - return errors.New("configuration is already initialized") + return ErrConfigInitialized } configurationInitialized = true content := readConfigFile() @@ -104,9 +106,9 @@ func (c *configuration) PoseidonAPIURL() *url.URL { return parseURL(Config.Server.Address, Config.Server.Port, false) } -func parseURL(address string, port int, tls bool) *url.URL { +func parseURL(address string, port int, tlsEnabled bool) *url.URL { scheme := "http" - if tls { + if tlsEnabled { scheme = "https" } return &url.URL{ @@ -151,33 +153,7 @@ func readFromEnvironment(prefix string, value reflect.Value) { } if value.Kind() != reflect.Struct { - content, ok := os.LookupEnv(prefix) - if !ok { - return - } - logEntry = logEntry.WithField("content", content) - - switch value.Kind() { - case reflect.String: - value.SetString(content) - case reflect.Int: - integer, err := strconv.Atoi(content) - if err != nil { - logEntry.Warn("Could not parse environment variable as integer") - return - } - value.SetInt(int64(integer)) - case reflect.Bool: - boolean, err := strconv.ParseBool(content) - if err != nil { - logEntry.Warn("Could not parse environment variable as boolean") - return - } - value.SetBool(boolean) - default: - // ignore this field - logEntry.WithField("type", value.Type().Name()).Warn("Setting configuration option via environment variables is not supported") - } + loadValue(prefix, value, logEntry) } else { for i := 0; i < value.NumField(); i++ { fieldName := value.Type().Field(i).Name @@ -186,3 +162,34 @@ func readFromEnvironment(prefix string, value reflect.Value) { } } } + +func loadValue(prefix string, value reflect.Value, logEntry *logrus.Entry) { + content, ok := os.LookupEnv(prefix) + if !ok { + return + } + logEntry = logEntry.WithField("content", content) + + switch value.Kind() { + case reflect.String: + value.SetString(content) + case reflect.Int: + integer, err := strconv.Atoi(content) + if err != nil { + logEntry.Warn("Could not parse environment variable as integer") + return + } + value.SetInt(int64(integer)) + case reflect.Bool: + boolean, err := strconv.ParseBool(content) + if err != nil { + logEntry.Warn("Could not parse environment variable as boolean") + return + } + value.SetBool(boolean) + default: + // ignore this field + logEntry.WithField("type", value.Type().Name()). + Warn("Setting configuration option via environment variables is not supported") + } +} diff --git a/config/config_test.go b/config/config_test.go index 7c6db87..4366cd9 100644 --- a/config/config_test.go +++ b/config/config_test.go @@ -5,6 +5,7 @@ import ( "github.com/sirupsen/logrus" "github.com/sirupsen/logrus/hooks/test" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "os" "path/filepath" "reflect" @@ -39,8 +40,9 @@ func (c *configuration) getReflectValue() reflect.Value { return reflect.ValueOf(c).Elem() } -// writeConfigurationFile creates a file on disk and returns the path to it +// writeConfigurationFile creates a file on disk and returns the path to it. func writeConfigurationFile(t *testing.T, name string, content []byte) string { + t.Helper() directory := t.TempDir() filePath := filepath.Join(directory, name) file, err := os.Create(filePath) @@ -48,7 +50,8 @@ func writeConfigurationFile(t *testing.T, name string, content []byte) string { t.Fatal("Could not create config file") } defer file.Close() - _, _ = file.Write(content) + _, err = file.Write(content) + require.NoError(t, err) return filePath } @@ -62,13 +65,15 @@ func TestCallingInitConfigTwiceReturnsError(t *testing.T) { func TestCallingInitConfigTwiceDoesNotChangeConfig(t *testing.T) { configurationInitialized = false - _ = InitConfig() + err := InitConfig() + require.NoError(t, err) Config = newTestConfiguration() filePath := writeConfigurationFile(t, "test.yaml", []byte("server:\n port: 5000\n")) oldArgs := os.Args defer func() { os.Args = oldArgs }() os.Args = append(os.Args, "-config", filePath) - _ = InitConfig() + err = InitConfig() + require.Error(t, err) assert.Equal(t, 3000, Config.Server.Port) } @@ -156,7 +161,8 @@ func TestReadConfigFileOverwritesConfig(t *testing.T) { defer func() { os.Args = oldArgs }() os.Args = append(os.Args, "-config", filePath) configurationInitialized = false - _ = InitConfig() + err := InitConfig() + require.NoError(t, err) assert.Equal(t, 5000, Config.Server.Port) } @@ -166,7 +172,8 @@ func TestReadNonExistingConfigFileDoesNotOverwriteConfig(t *testing.T) { defer func() { os.Args = oldArgs }() os.Args = append(os.Args, "-config", "file_does_not_exist.yaml") configurationInitialized = false - _ = InitConfig() + err := InitConfig() + require.NoError(t, err) assert.Equal(t, 3000, Config.Server.Port) } diff --git a/environment/manager.go b/environment/manager.go index b3e4400..80a39fe 100644 --- a/environment/manager.go +++ b/environment/manager.go @@ -34,19 +34,18 @@ type Manager interface { id runner.EnvironmentID, request dto.ExecutionEnvironmentRequest, ) (bool, error) - - // Delete removes the execution environment with the given id from the executor. - Delete(id string) } -func NewNomadEnvironmentManager(runnerManager runner.Manager, - apiClient nomad.ExecutorAPI) (m *NomadEnvironmentManager) { - m = &NomadEnvironmentManager{runnerManager, apiClient, *parseJob(templateEnvironmentJobHCL)} +func NewNomadEnvironmentManager( + runnerManager runner.Manager, + apiClient nomad.ExecutorAPI, +) *NomadEnvironmentManager { + m := &NomadEnvironmentManager{runnerManager, apiClient, *parseJob(templateEnvironmentJobHCL)} if err := m.Load(); err != nil { log.WithError(err).Error("Error recovering the execution environments") } runnerManager.Load() - return + return m } type NomadEnvironmentManager struct { @@ -64,20 +63,16 @@ func (m *NomadEnvironmentManager) CreateOrUpdate( request.Image, request.NetworkAccess, request.ExposedPorts) if err != nil { - return false, err + return false, fmt.Errorf("error registering template job in API: %w", err) } created, err := m.runnerManager.CreateOrUpdateEnvironment(id, request.PrewarmingPoolSize, templateJob, true) if err != nil { - return created, err + return created, fmt.Errorf("error updating environment in runner manager: %w", err) } return created, nil } -func (m *NomadEnvironmentManager) Delete(id string) { - -} - func (m *NomadEnvironmentManager) Load() error { templateJobs, err := m.api.LoadEnvironmentJobs() if err != nil { diff --git a/environment/manager_test.go b/environment/manager_test.go index bd0c3c9..5958fec 100644 --- a/environment/manager_test.go +++ b/environment/manager_test.go @@ -62,13 +62,6 @@ func (s *CreateOrUpdateTestSuite) mockCreateOrUpdateEnvironment(created bool, er Return(created, err) } -func (s *CreateOrUpdateTestSuite) createJobForRequest() *nomadApi.Job { - return nomad.CreateTemplateJob(&s.manager.templateEnvironmentJob, - runner.TemplateJobID(tests.DefaultEnvironmentIDAsInteger), - s.request.PrewarmingPoolSize, s.request.CPULimit, s.request.MemoryLimit, - s.request.Image, s.request.NetworkAccess, s.request.ExposedPorts) -} - func (s *CreateOrUpdateTestSuite) TestRegistersCorrectTemplateJob() { s.mockRegisterTemplateJob(&nomadApi.Job{}, nil) s.mockCreateOrUpdateEnvironment(true, nil) @@ -86,7 +79,7 @@ func (s *CreateOrUpdateTestSuite) TestReturnsErrorWhenRegisterTemplateJobReturns s.mockRegisterTemplateJob(nil, tests.ErrDefault) created, err := s.manager.CreateOrUpdate(s.environmentID, s.request) - s.Equal(tests.ErrDefault, err) + s.ErrorIs(err, tests.ErrDefault) s.False(created) } @@ -106,20 +99,22 @@ func (s *CreateOrUpdateTestSuite) TestReturnsErrorIfCreatesOrUpdateEnvironmentRe s.mockRegisterTemplateJob(&nomadApi.Job{}, nil) s.mockCreateOrUpdateEnvironment(false, tests.ErrDefault) _, err := s.manager.CreateOrUpdate(runner.EnvironmentID(tests.DefaultEnvironmentIDAsInteger), s.request) - s.Equal(tests.ErrDefault, err) + s.ErrorIs(err, tests.ErrDefault) } func (s *CreateOrUpdateTestSuite) TestReturnsTrueIfCreatesOrUpdateEnvironmentReturnsTrue() { s.mockRegisterTemplateJob(&nomadApi.Job{}, nil) s.mockCreateOrUpdateEnvironment(true, nil) - created, _ := s.manager.CreateOrUpdate(runner.EnvironmentID(tests.DefaultEnvironmentIDAsInteger), s.request) + created, err := s.manager.CreateOrUpdate(runner.EnvironmentID(tests.DefaultEnvironmentIDAsInteger), s.request) + s.Require().NoError(err) s.True(created) } func (s *CreateOrUpdateTestSuite) TestReturnsFalseIfCreatesOrUpdateEnvironmentReturnsFalse() { s.mockRegisterTemplateJob(&nomadApi.Job{}, nil) s.mockCreateOrUpdateEnvironment(false, nil) - created, _ := s.manager.CreateOrUpdate(runner.EnvironmentID(tests.DefaultEnvironmentIDAsInteger), s.request) + created, err := s.manager.CreateOrUpdate(runner.EnvironmentID(tests.DefaultEnvironmentIDAsInteger), s.request) + s.Require().NoError(err) s.False(created) } diff --git a/logging/logging.go b/logging/logging.go index d17b9d0..721bb86 100644 --- a/logging/logging.go +++ b/logging/logging.go @@ -2,6 +2,7 @@ package logging import ( "bufio" + "fmt" "github.com/sirupsen/logrus" "net" "net/http" @@ -33,7 +34,7 @@ func GetLogger(pkg string) *logrus.Entry { } // loggingResponseWriter wraps the default http.ResponseWriter and catches the status code -// that is written +// that is written. type loggingResponseWriter struct { http.ResponseWriter statusCode int @@ -49,10 +50,14 @@ func (writer *loggingResponseWriter) WriteHeader(code int) { } func (writer *loggingResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { - return writer.ResponseWriter.(http.Hijacker).Hijack() + conn, rw, err := writer.ResponseWriter.(http.Hijacker).Hijack() + if err != nil { + return conn, nil, fmt.Errorf("hijacking connection failed: %w", err) + } + return conn, rw, nil } -// HTTPLoggingMiddleware returns an http.Handler that logs different information about every request +// HTTPLoggingMiddleware returns an http.Handler that logs different information about every request. func HTTPLoggingMiddleware(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { start := time.Now().UTC() diff --git a/logging/logging_test.go b/logging/logging_test.go index a5f032e..de30050 100644 --- a/logging/logging_test.go +++ b/logging/logging_test.go @@ -9,7 +9,7 @@ import ( "testing" ) -func mockHttpStatusHandler(status int) http.Handler { +func mockHTTPStatusHandler(status int) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(status) }) @@ -25,7 +25,7 @@ func TestHTTPMiddlewareWarnsWhenInternalServerError(t *testing.T) { t.Fatal(err) } recorder := httptest.NewRecorder() - HTTPLoggingMiddleware(mockHttpStatusHandler(500)).ServeHTTP(recorder, request) + HTTPLoggingMiddleware(mockHTTPStatusHandler(500)).ServeHTTP(recorder, request) assert.Equal(t, 1, len(hook.Entries)) assert.Equal(t, logrus.WarnLevel, hook.LastEntry().Level) @@ -41,7 +41,7 @@ func TestHTTPMiddlewareDebugsWhenStatusOK(t *testing.T) { t.Fatal(err) } recorder := httptest.NewRecorder() - HTTPLoggingMiddleware(mockHttpStatusHandler(200)).ServeHTTP(recorder, request) + HTTPLoggingMiddleware(mockHTTPStatusHandler(200)).ServeHTTP(recorder, request) assert.Equal(t, 1, len(hook.Entries)) assert.Equal(t, logrus.DebugLevel, hook.LastEntry().Level) diff --git a/main.go b/main.go index 2f7fdfc..a8331f7 100644 --- a/main.go +++ b/main.go @@ -2,6 +2,7 @@ package main import ( "context" + "errors" "gitlab.hpi.de/codeocean/codemoon/poseidon/api" "gitlab.hpi.de/codeocean/codemoon/poseidon/config" "gitlab.hpi.de/codeocean/codemoon/poseidon/environment" @@ -15,7 +16,10 @@ import ( "time" ) -var log = logging.GetLogger("main") +var ( + gracefulShutdownWait = 15 * time.Second + log = logging.GetLogger("main") +) func runServer(server *http.Server) { log.WithField("address", server.Addr).Info("Starting server") @@ -31,7 +35,7 @@ func runServer(server *http.Server) { err = server.ListenAndServe() } if err != nil { - if err == http.ErrServerClosed { + if errors.Is(err, http.ErrServerClosed) { log.WithError(err).Info("Server closed") } else { log.WithError(err).Fatal("Error during listening and serving") @@ -59,7 +63,7 @@ func initServer() *http.Server { } // shutdownOnOSSignal listens for a signal from the operation system -// When receiving a signal the server shuts down but waits up to 15 seconds to close remaining connections +// When receiving a signal the server shuts down but waits up to 15 seconds to close remaining connections. func shutdownOnOSSignal(server *http.Server) { // wait for SIGINT signals := make(chan os.Signal, 1) @@ -67,9 +71,11 @@ func shutdownOnOSSignal(server *http.Server) { <-signals log.Info("Received SIGINT, shutting down ...") - ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second) + ctx, cancel := context.WithTimeout(context.Background(), gracefulShutdownWait) defer cancel() - _ = server.Shutdown(ctx) + if err := server.Shutdown(ctx); err != nil { + log.WithError(err).Warn("error shutting server down") + } } func main() { diff --git a/nomad/api_querier.go b/nomad/api_querier.go index befdd19..93108b1 100644 --- a/nomad/api_querier.go +++ b/nomad/api_querier.go @@ -3,6 +3,7 @@ package nomad import ( "context" "errors" + "fmt" nomadApi "github.com/hashicorp/nomad/api" "io" "net/url" @@ -21,13 +22,13 @@ type apiQuerier interface { LoadJobList() (list []*nomadApi.JobListStub, err error) // JobScale returns the scale of the passed job. - JobScale(jobId string) (jobScale uint, err error) + JobScale(jobID string) (jobScale uint, err error) // SetJobScale sets the scaling count of the passed job to Nomad. - SetJobScale(jobId string, count uint, reason string) (err error) + SetJobScale(jobID string, count uint, reason string) (err error) - // DeleteRunner deletes the runner with the given Id. - DeleteRunner(runnerId string) (err error) + // DeleteRunner deletes the runner with the given ID. + DeleteRunner(runnerID string) (err error) // Execute runs a command in the passed job. Execute(jobID string, ctx context.Context, command []string, tty bool, @@ -63,8 +64,11 @@ func (nc *nomadAPIClient) init(nomadURL *url.URL, nomadNamespace string) (err er TLSConfig: &nomadApi.TLSConfig{}, Namespace: nomadNamespace, }) + if err != nil { + return fmt.Errorf("error creating new Nomad client: %w", err) + } nc.namespace = nomadNamespace - return err + return nil } func (nc *nomadAPIClient) DeleteRunner(runnerID string) (err error) { @@ -74,32 +78,43 @@ func (nc *nomadAPIClient) DeleteRunner(runnerID string) (err error) { func (nc *nomadAPIClient) Execute(runnerID string, ctx context.Context, command []string, tty bool, - stdin io.Reader, stdout, stderr io.Writer) (int, error) { + stdin io.Reader, stdout, stderr io.Writer, +) (int, error) { allocations, _, err := nc.client.Jobs().Allocations(runnerID, false, nil) + if err != nil { + return 1, fmt.Errorf("error retrieving allocations for runner: %w", err) + } if len(allocations) == 0 { return 1, ErrorNoAllocationFound } allocation, _, err := nc.client.Allocations().Info(allocations[0].ID, nil) if err != nil { - return 1, err + return 1, fmt.Errorf("error retrieving allocation info: %w", err) } - return nc.client.Allocations().Exec(ctx, allocation, TaskName, tty, command, stdin, stdout, stderr, nil, nil) + exitCode, err := nc.client.Allocations().Exec(ctx, allocation, TaskName, tty, command, stdin, stdout, stderr, nil, nil) + if err != nil { + return 1, fmt.Errorf("error executing command in allocation: %w", err) + } + return exitCode, nil } -func (nc *nomadAPIClient) listJobs(prefix string) (jobs []*nomadApi.JobListStub, err error) { +func (nc *nomadAPIClient) listJobs(prefix string) ([]*nomadApi.JobListStub, error) { q := nomadApi.QueryOptions{ Namespace: nc.namespace, Prefix: prefix, } - jobs, _, err = nc.client.Jobs().List(&q) - return + jobs, _, err := nc.client.Jobs().List(&q) + if err != nil { + return nil, fmt.Errorf("error listing Nomad jobs: %w", err) + } + return jobs, nil } func (nc *nomadAPIClient) RegisterNomadJob(job *nomadApi.Job) (string, error) { job.Namespace = &nc.namespace resp, _, err := nc.client.Jobs().Register(job, nil) if err != nil { - return "", err + return "", fmt.Errorf("error registering Nomad job: %w", err) } if resp.Warnings != "" { log. @@ -110,26 +125,32 @@ func (nc *nomadAPIClient) RegisterNomadJob(job *nomadApi.Job) (string, error) { return resp.EvalID, nil } -func (nc *nomadAPIClient) EvaluationStream(evalID string, ctx context.Context) (stream <-chan *nomadApi.Events, err error) { - stream, err = nc.client.EventStream().Stream( +func (nc *nomadAPIClient) EvaluationStream(evalID string, ctx context.Context) (<-chan *nomadApi.Events, error) { + stream, err := nc.client.EventStream().Stream( ctx, map[nomadApi.Topic][]string{ nomadApi.TopicEvaluation: {evalID}, }, 0, nc.queryOptions()) - return + if err != nil { + return nil, fmt.Errorf("error retrieving Nomad Evaluation event stream: %w", err) + } + return stream, nil } -func (nc *nomadAPIClient) AllocationStream(ctx context.Context) (stream <-chan *nomadApi.Events, err error) { - stream, err = nc.client.EventStream().Stream( +func (nc *nomadAPIClient) AllocationStream(ctx context.Context) (<-chan *nomadApi.Events, error) { + stream, err := nc.client.EventStream().Stream( ctx, map[nomadApi.Topic][]string{ nomadApi.TopicAllocation: {}, }, 0, nc.queryOptions()) - return + if err != nil { + return nil, fmt.Errorf("error retrieving Nomad Allocation event stream: %w", err) + } + return stream, nil } func (nc *nomadAPIClient) queryOptions() *nomadApi.QueryOptions { @@ -151,14 +172,14 @@ func (nc *nomadAPIClient) LoadJobList() (list []*nomadApi.JobListStub, err error } // JobScale returns the scale of the passed job. -func (nc *nomadAPIClient) JobScale(jobID string) (jobScale uint, err error) { +func (nc *nomadAPIClient) JobScale(jobID string) (uint, error) { status, _, err := nc.client.Jobs().ScaleStatus(jobID, nc.queryOptions()) if err != nil { - return + return 0, fmt.Errorf("error retrieving scale status of job: %w", err) } // ToDo: Consider counting also the placed and desired allocations - jobScale = uint(status.TaskGroups[TaskGroupName].Running) - return + jobScale := uint(status.TaskGroups[TaskGroupName].Running) + return jobScale, nil } // SetJobScale sets the scaling count of the passed job to Nomad. diff --git a/nomad/api_querier_mock.go b/nomad/api_querier_mock.go index b68a8b9..f927ee4 100644 --- a/nomad/api_querier_mock.go +++ b/nomad/api_querier_mock.go @@ -100,7 +100,7 @@ func (_m *apiQuerierMock) Execute(jobID string, ctx context.Context, command []s return r0, r1 } -// JobScale provides a mock function with given fields: jobId +// JobScale provides a mock function with given fields: jobID func (_m *apiQuerierMock) JobScale(jobId string) (uint, error) { ret := _m.Called(jobId) @@ -165,7 +165,7 @@ func (_m *apiQuerierMock) RegisterNomadJob(job *api.Job) (string, error) { return r0, r1 } -// SetJobScale provides a mock function with given fields: jobId, count, reason +// SetJobScale provides a mock function with given fields: jobID, count, reason func (_m *apiQuerierMock) SetJobScale(jobId string, count uint, reason string) error { ret := _m.Called(jobId, count, reason) diff --git a/nomad/executor_api_mock.go b/nomad/executor_api_mock.go index 4026ab3..abaae17 100644 --- a/nomad/executor_api_mock.go +++ b/nomad/executor_api_mock.go @@ -121,7 +121,7 @@ func (_m *ExecutorAPIMock) ExecuteCommand(allocationID string, ctx context.Conte return r0, r1 } -// JobScale provides a mock function with given fields: jobId +// JobScale provides a mock function with given fields: jobID func (_m *ExecutorAPIMock) JobScale(jobId string) (uint, error) { ret := _m.Called(jobId) @@ -320,7 +320,7 @@ func (_m *ExecutorAPIMock) RegisterTemplateJob(defaultJob *api.Job, id string, p return r0, r1 } -// SetJobScale provides a mock function with given fields: jobId, count, reason +// SetJobScale provides a mock function with given fields: jobID, count, reason func (_m *ExecutorAPIMock) SetJobScale(jobId string, count uint, reason string) error { ret := _m.Called(jobId, count, reason) diff --git a/nomad/job_test.go b/nomad/job_test.go index f904b0f..a901369 100644 --- a/nomad/job_test.go +++ b/nomad/job_test.go @@ -140,16 +140,7 @@ func TestConfigureNetworkSetsCorrectValues(t *testing.T) { assert.Equal(t, "bridge", networkResource.Mode) require.Equal(t, len(ports), len(networkResource.DynamicPorts)) - for _, expectedPort := range ports { - found := false - for _, actualPort := range networkResource.DynamicPorts { - if actualPort.To == int(expectedPort) { - found = true - break - } - } - assert.True(t, found, fmt.Sprintf("port list should contain %v", expectedPort)) - } + assertExpectedPorts(t, ports, networkResource) mode, ok := testTask.Config["network_mode"] assert.True(t, ok) @@ -158,6 +149,20 @@ func TestConfigureNetworkSetsCorrectValues(t *testing.T) { }) } +func assertExpectedPorts(t *testing.T, expectedPorts []uint16, networkResource *nomadApi.NetworkResource) { + t.Helper() + for _, expectedPort := range expectedPorts { + found := false + for _, actualPort := range networkResource.DynamicPorts { + if actualPort.To == int(expectedPort) { + found = true + break + } + } + assert.True(t, found, fmt.Sprintf("port list should contain %v", expectedPort)) + } +} + func TestConfigureTaskWhenNoTaskExists(t *testing.T) { taskGroup := createTestTaskGroup() require.Equal(t, 0, len(taskGroup.Tasks)) diff --git a/nomad/nomad.go b/nomad/nomad.go index 14d0c0a..60d4fae 100644 --- a/nomad/nomad.go +++ b/nomad/nomad.go @@ -85,7 +85,10 @@ func NewExecutorAPI(nomadURL *url.URL, nomadNamespace string) (ExecutorAPI, erro // init prepares an apiClient to be able to communicate to a provided Nomad API. func (a *APIClient) init(nomadURL *url.URL, nomadNamespace string) error { - return a.apiQuerier.init(nomadURL, nomadNamespace) + if err := a.apiQuerier.init(nomadURL, nomadNamespace); err != nil { + return fmt.Errorf("error initializing API querier: %w", err) + } + return nil } func (a *APIClient) LoadRunnerIDs(environmentID string) (runnerIDs []string, err error) { @@ -305,7 +308,11 @@ func (a *APIClient) ExecuteCommand(allocationID string, if tty && config.Config.Server.InteractiveStderr { return a.executeCommandInteractivelyWithStderr(allocationID, ctx, command, stdin, stdout, stderr) } - return a.apiQuerier.Execute(allocationID, ctx, command, tty, stdin, stdout, stderr) + exitCode, err := a.apiQuerier.Execute(allocationID, ctx, command, tty, stdin, stdout, stderr) + if err != nil { + return 1, fmt.Errorf("error executing command in API: %w", err) + } + return exitCode, nil } // executeCommandInteractivelyWithStderr executes the given command interactively and splits stdout @@ -325,7 +332,8 @@ func (a *APIClient) executeCommandInteractivelyWithStderr(allocationID string, c stderrExitChan := make(chan int) go func() { // Catch stderr in separate execution. - exit, err := a.Execute(allocationID, ctx, stderrFifoCommand(currentNanoTime), true, util.NullReader{}, stderr, io.Discard) + exit, err := a.Execute(allocationID, ctx, stderrFifoCommand(currentNanoTime), true, + util.NullReader{}, stderr, io.Discard) if err != nil { log.WithError(err).WithField("runner", allocationID).Warn("Stderr task finished with error") } @@ -342,15 +350,15 @@ func (a *APIClient) executeCommandInteractivelyWithStderr(allocationID string, c const ( // stderrFifoFormat represents the format we use for our stderr fifos. The %d should be unique for the execution // as otherwise multiple executions are not possible. - // Example: /tmp/stderr_1623330777825234133.fifo + // Example: "/tmp/stderr_1623330777825234133.fifo". stderrFifoFormat = "/tmp/stderr_%d.fifo" // stderrFifoCommandFormat, if executed, is supposed to create a fifo, read from it and remove it in the end. - // Example: mkfifo my.fifo && (cat my.fifo; rm my.fifo) + // Example: "mkfifo my.fifo && (cat my.fifo; rm my.fifo)". stderrFifoCommandFormat = "mkfifo %s && (cat %s; rm %s)" // stderrWrapperCommandFormat, if executed, is supposed to wait until a fifo exists (it sleeps 10ms to reduce load // cause by busy waiting on the system). Once the fifo exists, the given command is executed and its stderr // redirected to the fifo. - // Example: until [ -e my.fifo ]; do sleep 0.01; done; (echo "my.fifo exists") 2> my.fifo + // Example: "until [ -e my.fifo ]; do sleep 0.01; done; (echo \"my.fifo exists\") 2> my.fifo". stderrWrapperCommandFormat = "until [ -e %s ]; do sleep 0.01; done; (%s) 2> %s" ) diff --git a/nomad/nomad_test.go b/nomad/nomad_test.go index e583034..65cfe77 100644 --- a/nomad/nomad_test.go +++ b/nomad/nomad_test.go @@ -3,7 +3,6 @@ package nomad import ( "bytes" "context" - "errors" "fmt" nomadApi "github.com/hashicorp/nomad/api" "github.com/hashicorp/nomad/nomad/structs" @@ -29,9 +28,9 @@ func TestLoadRunnersTestSuite(t *testing.T) { type LoadRunnersTestSuite struct { suite.Suite - jobId string + jobID string mock *apiQuerierMock - nomadApiClient APIClient + nomadAPIClient APIClient availableRunner *nomadApi.JobListStub anotherAvailableRunner *nomadApi.JobListStub pendingRunner *nomadApi.JobListStub @@ -39,10 +38,10 @@ type LoadRunnersTestSuite struct { } func (s *LoadRunnersTestSuite) SetupTest() { - s.jobId = tests.DefaultJobID + s.jobID = tests.DefaultJobID s.mock = &apiQuerierMock{} - s.nomadApiClient = APIClient{apiQuerier: s.mock} + s.nomadAPIClient = APIClient{apiQuerier: s.mock} s.availableRunner = newJobListStub(tests.DefaultJobID, structs.JobStatusRunning, 1) s.anotherAvailableRunner = newJobListStub(tests.AnotherJobID, structs.JobStatusRunning, 1) @@ -65,7 +64,7 @@ func (s *LoadRunnersTestSuite) TestErrorOfUnderlyingApiCallIsPropagated() { s.mock.On("listJobs", mock.AnythingOfType("string")). Return(nil, tests.ErrDefault) - returnedIds, err := s.nomadApiClient.LoadRunnerIDs(s.jobId) + returnedIds, err := s.nomadAPIClient.LoadRunnerIDs(s.jobID) s.Nil(returnedIds) s.Equal(tests.ErrDefault, err) } @@ -74,7 +73,7 @@ func (s *LoadRunnersTestSuite) TestReturnsNoErrorWhenUnderlyingApiCallDoesNot() s.mock.On("listJobs", mock.AnythingOfType("string")). Return([]*nomadApi.JobListStub{}, nil) - _, err := s.nomadApiClient.LoadRunnerIDs(s.jobId) + _, err := s.nomadAPIClient.LoadRunnerIDs(s.jobID) s.NoError(err) } @@ -82,7 +81,8 @@ func (s *LoadRunnersTestSuite) TestAvailableRunnerIsReturned() { s.mock.On("listJobs", mock.AnythingOfType("string")). Return([]*nomadApi.JobListStub{s.availableRunner}, nil) - returnedIds, _ := s.nomadApiClient.LoadRunnerIDs(s.jobId) + returnedIds, err := s.nomadAPIClient.LoadRunnerIDs(s.jobID) + s.Require().NoError(err) s.Len(returnedIds, 1) s.Equal(s.availableRunner.ID, returnedIds[0]) } @@ -91,7 +91,8 @@ func (s *LoadRunnersTestSuite) TestPendingRunnerIsNotReturned() { s.mock.On("listJobs", mock.AnythingOfType("string")). Return([]*nomadApi.JobListStub{s.pendingRunner}, nil) - returnedIds, _ := s.nomadApiClient.LoadRunnerIDs(s.jobId) + returnedIds, err := s.nomadAPIClient.LoadRunnerIDs(s.jobID) + s.Require().NoError(err) s.Empty(returnedIds) } @@ -99,7 +100,8 @@ func (s *LoadRunnersTestSuite) TestDeadRunnerIsNotReturned() { s.mock.On("listJobs", mock.AnythingOfType("string")). Return([]*nomadApi.JobListStub{s.deadRunner}, nil) - returnedIds, _ := s.nomadApiClient.LoadRunnerIDs(s.jobId) + returnedIds, err := s.nomadAPIClient.LoadRunnerIDs(s.jobID) + s.Require().NoError(err) s.Empty(returnedIds) } @@ -113,7 +115,8 @@ func (s *LoadRunnersTestSuite) TestReturnsAllAvailableRunners() { s.mock.On("listJobs", mock.AnythingOfType("string")). Return(runnersList, nil) - returnedIds, _ := s.nomadApiClient.LoadRunnerIDs(s.jobId) + returnedIds, err := s.nomadAPIClient.LoadRunnerIDs(s.jobID) + s.Require().NoError(err) s.Len(returnedIds, 2) s.Contains(returnedIds, s.availableRunner.ID) s.Contains(returnedIds, s.anotherAvailableRunner.ID) @@ -189,12 +192,11 @@ func TestApiClient_MonitorEvaluationReturnsNilWhenStreamIsClosed(t *testing.T) { func TestApiClient_MonitorEvaluationReturnsErrorWhenStreamReturnsError(t *testing.T) { apiMock := &apiQuerierMock{} - expectedErr := errors.New("test error") apiMock.On("EvaluationStream", mock.AnythingOfType("string"), mock.AnythingOfType("*context.emptyCtx")). - Return(nil, expectedErr) + Return(nil, tests.ErrDefault) apiClient := &APIClient{apiMock} err := apiClient.MonitorEvaluation("id", context.Background()) - assert.ErrorIs(t, err, expectedErr) + assert.ErrorIs(t, err, tests.ErrDefault) } type eventPayload struct { @@ -205,10 +207,11 @@ type eventPayload struct { // eventForEvaluation takes an evaluation and creates an Event with the given evaluation // as its payload. Nomad uses the mapstructure library to decode the payload, which we // simply reverse here. -func eventForEvaluation(t *testing.T, eval nomadApi.Evaluation) nomadApi.Event { +func eventForEvaluation(t *testing.T, eval *nomadApi.Evaluation) nomadApi.Event { + t.Helper() payload := make(map[string]interface{}) - err := mapstructure.Decode(eventPayload{Evaluation: &eval}, &payload) + err := mapstructure.Decode(eventPayload{Evaluation: eval}, &payload) if err != nil { t.Fatalf("Couldn't decode evaluation %v into payload map", eval) return nomadApi.Event{} @@ -259,10 +262,10 @@ func TestApiClient_MonitorEvaluationWithSuccessfulEvent(t *testing.T) { // make sure that the tested function can complete require.Nil(t, checkEvaluation(&eval)) - events := nomadApi.Events{Events: []nomadApi.Event{eventForEvaluation(t, eval)}} - pendingEvaluationEvents := nomadApi.Events{Events: []nomadApi.Event{eventForEvaluation(t, pendingEval)}} + events := nomadApi.Events{Events: []nomadApi.Event{eventForEvaluation(t, &eval)}} + pendingEvaluationEvents := nomadApi.Events{Events: []nomadApi.Event{eventForEvaluation(t, &pendingEval)}} multipleEventsWithPending := nomadApi.Events{Events: []nomadApi.Event{ - eventForEvaluation(t, pendingEval), eventForEvaluation(t, eval), + eventForEvaluation(t, &pendingEval), eventForEvaluation(t, &eval), }} var cases = []struct { @@ -298,10 +301,10 @@ func TestApiClient_MonitorEvaluationWithFailingEvent(t *testing.T) { pendingEval := nomadApi.Evaluation{Status: structs.EvalStatusPending} - events := nomadApi.Events{Events: []nomadApi.Event{eventForEvaluation(t, eval)}} - pendingEvaluationEvents := nomadApi.Events{Events: []nomadApi.Event{eventForEvaluation(t, pendingEval)}} + events := nomadApi.Events{Events: []nomadApi.Event{eventForEvaluation(t, &eval)}} + pendingEvaluationEvents := nomadApi.Events{Events: []nomadApi.Event{eventForEvaluation(t, &pendingEval)}} multipleEventsWithPending := nomadApi.Events{Events: []nomadApi.Event{ - eventForEvaluation(t, pendingEval), eventForEvaluation(t, eval), + eventForEvaluation(t, &pendingEval), eventForEvaluation(t, &eval), }} eventsWithErr := nomadApi.Events{Err: tests.ErrDefault, Events: []nomadApi.Event{{}}} @@ -390,7 +393,8 @@ func TestCheckEvaluationWithoutFailedAllocations(t *testing.T) { }) t.Run("when evaluation status not complete", func(t *testing.T) { - incompleteStates := []string{structs.EvalStatusFailed, structs.EvalStatusCancelled, structs.EvalStatusBlocked, structs.EvalStatusPending} + incompleteStates := []string{structs.EvalStatusFailed, structs.EvalStatusCancelled, + structs.EvalStatusBlocked, structs.EvalStatusPending} for _, status := range incompleteStates { evaluation.Status = status err := checkEvaluation(&evaluation) @@ -741,7 +745,7 @@ func (s *ExecuteCommandTestSuite) TestWithoutSeparateStderrReturnsCommandError() s.mockExecute(s.testCommandArray, 1, tests.ErrDefault, func(args mock.Arguments) {}) _, err := s.nomadAPIClient. ExecuteCommand(s.allocationID, s.ctx, s.testCommandArray, withTTY, util.NullReader{}, io.Discard, io.Discard) - s.Equal(tests.ErrDefault, err) + s.ErrorIs(err, tests.ErrDefault) } func (s *ExecuteCommandTestSuite) mockExecute(command interface{}, exitCode int, diff --git a/runner/execution_storage.go b/runner/execution_storage.go index 74a881c..3d7d17a 100644 --- a/runner/execution_storage.go +++ b/runner/execution_storage.go @@ -9,38 +9,38 @@ import ( type ExecutionStorage interface { // Add adds a runner to the storage. // It overwrites the existing execution if an execution with the same id already exists. - Add(id ExecutionId, executionRequest *dto.ExecutionRequest) + Add(id ExecutionID, executionRequest *dto.ExecutionRequest) // Pop deletes the execution with the given id from the storage and returns it. // If no such execution exists, ok is false and true otherwise. - Pop(id ExecutionId) (request *dto.ExecutionRequest, ok bool) + Pop(id ExecutionID) (request *dto.ExecutionRequest, ok bool) } // localExecutionStorage stores execution objects in the local application memory. -// ToDo: Create implementation that use some persistent storage like a database +// ToDo: Create implementation that use some persistent storage like a database. type localExecutionStorage struct { sync.RWMutex - executions map[ExecutionId]*dto.ExecutionRequest + executions map[ExecutionID]*dto.ExecutionRequest } // NewLocalExecutionStorage responds with an ExecutionStorage implementation. // This implementation stores the data thread-safe in the local application memory. func NewLocalExecutionStorage() *localExecutionStorage { return &localExecutionStorage{ - executions: make(map[ExecutionId]*dto.ExecutionRequest), + executions: make(map[ExecutionID]*dto.ExecutionRequest), } } -func (s *localExecutionStorage) Add(id ExecutionId, executionRequest *dto.ExecutionRequest) { +func (s *localExecutionStorage) Add(id ExecutionID, executionRequest *dto.ExecutionRequest) { s.Lock() defer s.Unlock() s.executions[id] = executionRequest } -func (s *localExecutionStorage) Pop(id ExecutionId) (request *dto.ExecutionRequest, ok bool) { +func (s *localExecutionStorage) Pop(id ExecutionID) (*dto.ExecutionRequest, bool) { s.Lock() defer s.Unlock() - request, ok = s.executions[id] + request, ok := s.executions[id] delete(s.executions, id) - return + return request, ok } diff --git a/runner/manager.go b/runner/manager.go index bc557b4..cd322ac 100644 --- a/runner/manager.go +++ b/runner/manager.go @@ -6,6 +6,7 @@ import ( "fmt" "github.com/google/uuid" nomadApi "github.com/hashicorp/nomad/api" + "github.com/sirupsen/logrus" "gitlab.hpi.de/codeocean/codemoon/poseidon/logging" "gitlab.hpi.de/codeocean/codemoon/poseidon/nomad" "strconv" @@ -127,7 +128,8 @@ func (m *NomadRunnerManager) updateEnvironment(id EnvironmentID, desiredIdleRunn } environment.desiredIdleRunnersCount = desiredIdleRunnersCount environment.templateJob = newTemplateJob - err := nomad.SetMetaConfigValue(newTemplateJob, nomad.ConfigMetaPoolSizeKey, strconv.Itoa(int(desiredIdleRunnersCount))) + err := nomad.SetMetaConfigValue(newTemplateJob, nomad.ConfigMetaPoolSizeKey, + strconv.Itoa(int(desiredIdleRunnersCount))) if err != nil { return fmt.Errorf("update environment couldn't update template environment: %w", err) } @@ -177,7 +179,7 @@ func (m *NomadRunnerManager) Claim(environmentID EnvironmentID, duration int) (R return nil, ErrNoRunnersAvailable } m.usedRunners.Add(runner) - err := m.apiClient.MarkRunnerAsUsed(runner.Id(), duration) + err := m.apiClient.MarkRunnerAsUsed(runner.ID(), duration) if err != nil { return nil, fmt.Errorf("can't mark runner as used: %w", err) } @@ -200,14 +202,14 @@ func (m *NomadRunnerManager) Get(runnerID string) (Runner, error) { return runner, nil } -func (m *NomadRunnerManager) Return(r Runner) (err error) { +func (m *NomadRunnerManager) Return(r Runner) error { r.StopTimeout() - err = m.apiClient.DeleteRunner(r.Id()) + err := m.apiClient.DeleteRunner(r.ID()) if err != nil { - return + return fmt.Errorf("error deleting runner in Nomad: %w", err) } - m.usedRunners.Delete(r.Id()) - return + m.usedRunners.Delete(r.ID()) + return nil } func (m *NomadRunnerManager) Load() { @@ -218,24 +220,7 @@ func (m *NomadRunnerManager) Load() { environmentLogger.WithError(err).Warn("Error fetching the runner jobs") } for _, job := range runnerJobs { - configTaskGroup := nomad.FindConfigTaskGroup(job) - if configTaskGroup == nil { - environmentLogger.Infof("Couldn't find config task group in job %s, skipping ...", *job.ID) - continue - } - isUsed := configTaskGroup.Meta[nomad.ConfigMetaUsedKey] == nomad.ConfigMetaUsedValue - newJob := NewNomadJob(*job.ID, m.apiClient, m) - if isUsed { - m.usedRunners.Add(newJob) - timeout, err := strconv.Atoi(configTaskGroup.Meta[nomad.ConfigMetaTimeoutKey]) - if err != nil { - log.WithError(err).Warn("Error loading timeout from meta values") - } else { - newJob.SetupTimeout(time.Duration(timeout) * time.Second) - } - } else { - environment.idleRunners.Add(newJob) - } + m.loadSingleJob(job, environmentLogger, environment) } err = m.scaleEnvironment(environment.ID()) if err != nil { @@ -244,6 +229,29 @@ func (m *NomadRunnerManager) Load() { } } +func (m *NomadRunnerManager) loadSingleJob(job *nomadApi.Job, environmentLogger *logrus.Entry, + environment *NomadEnvironment, +) { + configTaskGroup := nomad.FindConfigTaskGroup(job) + if configTaskGroup == nil { + environmentLogger.Infof("Couldn't find config task group in job %s, skipping ...", *job.ID) + return + } + isUsed := configTaskGroup.Meta[nomad.ConfigMetaUsedKey] == nomad.ConfigMetaUsedValue + newJob := NewNomadJob(*job.ID, m.apiClient, m) + if isUsed { + m.usedRunners.Add(newJob) + timeout, err := strconv.Atoi(configTaskGroup.Meta[nomad.ConfigMetaTimeoutKey]) + if err != nil { + environmentLogger.WithError(err).Warn("Error loading timeout from meta values") + } else { + newJob.SetupTimeout(time.Duration(timeout) * time.Second) + } + } else { + environment.idleRunners.Add(newJob) + } +} + func (m *NomadRunnerManager) keepRunnersSynced(ctx context.Context) { retries := 0 for ctx.Err() == nil { @@ -319,7 +327,7 @@ func (m *NomadRunnerManager) createRunners(environment *NomadEnvironment, count func (m *NomadRunnerManager) createRunner(environment *NomadEnvironment) error { newUUID, err := uuid.NewUUID() if err != nil { - return fmt.Errorf("failed generating runner id") + return fmt.Errorf("failed generating runner id: %w", err) } newRunnerID := RunnerJobID(environment.ID(), newUUID.String()) @@ -327,7 +335,11 @@ func (m *NomadRunnerManager) createRunner(environment *NomadEnvironment) error { template.ID = &newRunnerID template.Name = &newRunnerID - return m.apiClient.RegisterRunnerJob(&template) + err = m.apiClient.RegisterRunnerJob(&template) + if err != nil { + return fmt.Errorf("error registering new runner job: %w", err) + } + return nil } func (m *NomadRunnerManager) removeRunners(environment *NomadEnvironment, count uint) error { @@ -337,7 +349,7 @@ func (m *NomadRunnerManager) removeRunners(environment *NomadEnvironment, count if !ok { return fmt.Errorf("could not delete expected idle runner: %w", ErrRunnerNotFound) } - err := m.apiClient.DeleteRunner(r.Id()) + err := m.apiClient.DeleteRunner(r.ID()) if err != nil { return fmt.Errorf("could not delete expected Nomad idle runner: %w", err) } @@ -345,9 +357,9 @@ func (m *NomadRunnerManager) removeRunners(environment *NomadEnvironment, count return nil } -// RunnerJobID returns the nomad job id of the runner with the given environment id and uuid. -func RunnerJobID(environmentID EnvironmentID, uuid string) string { - return fmt.Sprintf("%d-%s", environmentID, uuid) +// RunnerJobID returns the nomad job id of the runner with the given environmentID and id. +func RunnerJobID(environmentID EnvironmentID, id string) string { + return fmt.Sprintf("%d-%s", environmentID, id) } // EnvironmentIDFromJobID returns the environment id that is part of the passed job id. @@ -363,6 +375,8 @@ func EnvironmentIDFromJobID(jobID string) (EnvironmentID, error) { return EnvironmentID(environmentID), nil } +const templateJobNameParts = 2 + // TemplateJobID returns the id of the template job for the environment with the given id. func TemplateJobID(id EnvironmentID) string { return fmt.Sprintf("%s-%d", nomad.TemplateJobPrefix, id) @@ -371,12 +385,12 @@ func TemplateJobID(id EnvironmentID) string { // IsEnvironmentTemplateID checks if the passed job id belongs to a template job. func IsEnvironmentTemplateID(jobID string) bool { parts := strings.Split(jobID, "-") - return len(parts) == 2 && parts[0] == nomad.TemplateJobPrefix + return len(parts) == templateJobNameParts && parts[0] == nomad.TemplateJobPrefix } func EnvironmentIDFromTemplateJobID(id string) (string, error) { parts := strings.Split(id, "-") - if len(parts) < 2 { + if len(parts) < templateJobNameParts { return "", fmt.Errorf("invalid template job id: %w", ErrorInvalidJobID) } return parts[1], nil diff --git a/runner/manager_test.go b/runner/manager_test.go index d14dd6c..1e481d0 100644 --- a/runner/manager_test.go +++ b/runner/manager_test.go @@ -2,7 +2,6 @@ package runner import ( "context" - "errors" nomadApi "github.com/hashicorp/nomad/api" "github.com/sirupsen/logrus" "github.com/sirupsen/logrus/hooks/test" @@ -129,8 +128,9 @@ func (s *ManagerTestSuite) TestClaimThrowsAnErrorIfNoRunnersAvailable() { func (s *ManagerTestSuite) TestClaimAddsRunnerToUsedRunners() { s.AddIdleRunnerForDefaultEnvironment(s.exerciseRunner) - receivedRunner, _ := s.nomadRunnerManager.Claim(defaultEnvironmentID, defaultInactivityTimeout) - savedRunner, ok := s.nomadRunnerManager.usedRunners.Get(receivedRunner.Id()) + receivedRunner, err := s.nomadRunnerManager.Claim(defaultEnvironmentID, defaultInactivityTimeout) + s.Require().NoError(err) + savedRunner, ok := s.nomadRunnerManager.usedRunners.Get(receivedRunner.ID()) s.True(ok) s.Equal(savedRunner, receivedRunner) } @@ -147,7 +147,7 @@ func (s *ManagerTestSuite) TestTwoClaimsAddExactlyTwoRunners() { func (s *ManagerTestSuite) TestGetReturnsRunnerIfRunnerIsUsed() { s.nomadRunnerManager.usedRunners.Add(s.exerciseRunner) - savedRunner, err := s.nomadRunnerManager.Get(s.exerciseRunner.Id()) + savedRunner, err := s.nomadRunnerManager.Get(s.exerciseRunner.ID()) s.NoError(err) s.Equal(savedRunner, s.exerciseRunner) } @@ -163,7 +163,7 @@ func (s *ManagerTestSuite) TestReturnRemovesRunnerFromUsedRunners() { s.nomadRunnerManager.usedRunners.Add(s.exerciseRunner) err := s.nomadRunnerManager.Return(s.exerciseRunner) s.Nil(err) - _, ok := s.nomadRunnerManager.usedRunners.Get(s.exerciseRunner.Id()) + _, ok := s.nomadRunnerManager.usedRunners.Get(s.exerciseRunner.ID()) s.False(ok) } @@ -171,11 +171,11 @@ func (s *ManagerTestSuite) TestReturnCallsDeleteRunnerApiMethod() { s.apiMock.On("DeleteRunner", mock.AnythingOfType("string")).Return(nil) err := s.nomadRunnerManager.Return(s.exerciseRunner) s.Nil(err) - s.apiMock.AssertCalled(s.T(), "DeleteRunner", s.exerciseRunner.Id()) + s.apiMock.AssertCalled(s.T(), "DeleteRunner", s.exerciseRunner.ID()) } func (s *ManagerTestSuite) TestReturnReturnsErrorWhenApiCallFailed() { - s.apiMock.On("DeleteRunner", mock.AnythingOfType("string")).Return(errors.New("return failed")) + s.apiMock.On("DeleteRunner", mock.AnythingOfType("string")).Return(tests.ErrDefault) err := s.nomadRunnerManager.Return(s.exerciseRunner) s.Error(err) } diff --git a/runner/runner.go b/runner/runner.go index 239ca74..c9da54e 100644 --- a/runner/runner.go +++ b/runner/runner.go @@ -18,11 +18,11 @@ import ( // ContextKey is the type for keys in a request context. type ContextKey string -// ExecutionId is an id for an execution in a Runner. -type ExecutionId string +// ExecutionID is an id for an execution in a Runner. +type ExecutionID string const ( - // runnerContextKey is the key used to store runners in context.Context + // runnerContextKey is the key used to store runners in context.Context. runnerContextKey ContextKey = "runner" ) @@ -93,9 +93,9 @@ func (t *InactivityTimerImplementation) SetupTimeout(duration time.Duration) { t.Unlock() err := t.manager.Return(t.runner) if err != nil { - log.WithError(err).WithField("id", t.runner.Id()).Warn("Returning runner after inactivity caused an error") + log.WithError(err).WithField("id", t.runner.ID()).Warn("Returning runner after inactivity caused an error") } else { - log.WithField("id", t.runner.Id()).Info("Returning runner due to inactivity timeout") + log.WithField("id", t.runner.ID()).Info("Returning runner due to inactivity timeout") } }) } @@ -129,8 +129,8 @@ func (t *InactivityTimerImplementation) TimeoutPassed() bool { } type Runner interface { - // Id returns the id of the runner. - Id() string + // ID returns the id of the runner. + ID() string ExecutionStorage InactivityTimer @@ -169,7 +169,7 @@ func NewNomadJob(id string, apiClient nomad.ExecutorAPI, manager Manager) *Nomad return job } -func (r *NomadJob) Id() string { +func (r *NomadJob) ID() string { return r.id } @@ -241,21 +241,24 @@ func createTarArchiveForFiles(filesToCopy []dto.File, w io.Writer) error { tarWriter := tar.NewWriter(w) for _, file := range filesToCopy { if err := tarWriter.WriteHeader(tarHeader(file)); err != nil { + err := fmt.Errorf("error writing tar file header: %w", err) log. - WithError(err). WithField("file", file). - Error("Error writing tar file header") + Error(err) return err } if _, err := tarWriter.Write(file.ByteContent()); err != nil { + err := fmt.Errorf("error writing tar file content: %w", err) log. - WithError(err). WithField("file", file). - Error("Error writing tar file content") + Error(err) return err } } - return tarWriter.Close() + if err := tarWriter.Close(); err != nil { + return fmt.Errorf("error closing tar writer: %w", err) + } + return nil } func fileDeletionCommand(pathsToDelete []dto.FilePath) string { @@ -265,7 +268,8 @@ func fileDeletionCommand(pathsToDelete []dto.FilePath) string { command := "rm --recursive --force " for _, filePath := range pathsToDelete { // To avoid command injection, filenames need to be quoted. - // See https://unix.stackexchange.com/questions/347332/what-characters-need-to-be-escaped-in-files-without-quotes for details. + // See https://unix.stackexchange.com/questions/347332/what-characters-need-to-be-escaped-in-files-without-quotes + // for details. singleQuoteEscapedFileName := strings.ReplaceAll(filePath.Cleaned(), "'", "'\\''") command += fmt.Sprintf("'%s' ", singleQuoteEscapedFileName) } @@ -293,11 +297,15 @@ func tarHeader(file dto.File) *tar.Header { // MarshalJSON implements json.Marshaler interface. // This exports private attributes like the id too. func (r *NomadJob) MarshalJSON() ([]byte, error) { - return json.Marshal(struct { + res, err := json.Marshal(struct { ID string `json:"runnerId"` }{ - ID: r.Id(), + ID: r.ID(), }) + if err != nil { + return nil, fmt.Errorf("error marshaling Nomad job: %w", err) + } + return res, nil } // NewContext creates a context containing a runner. diff --git a/runner/runner_mock.go b/runner/runner_mock.go index 9ac27fe..a48537e 100644 --- a/runner/runner_mock.go +++ b/runner/runner_mock.go @@ -19,7 +19,7 @@ type RunnerMock struct { } // Add provides a mock function with given fields: id, executionRequest -func (_m *RunnerMock) Add(id ExecutionId, executionRequest *dto.ExecutionRequest) { +func (_m *RunnerMock) Add(id ExecutionID, executionRequest *dto.ExecutionRequest) { _m.Called(id, executionRequest) } @@ -49,7 +49,7 @@ func (_m *RunnerMock) ExecuteInteractively(request *dto.ExecutionRequest, stdin } // Id provides a mock function with given fields: -func (_m *RunnerMock) Id() string { +func (_m *RunnerMock) ID() string { ret := _m.Called() var r0 string @@ -63,11 +63,11 @@ func (_m *RunnerMock) Id() string { } // Pop provides a mock function with given fields: id -func (_m *RunnerMock) Pop(id ExecutionId) (*dto.ExecutionRequest, bool) { +func (_m *RunnerMock) Pop(id ExecutionID) (*dto.ExecutionRequest, bool) { ret := _m.Called(id) var r0 *dto.ExecutionRequest - if rf, ok := ret.Get(0).(func(ExecutionId) *dto.ExecutionRequest); ok { + if rf, ok := ret.Get(0).(func(ExecutionID) *dto.ExecutionRequest); ok { r0 = rf(id) } else { if ret.Get(0) != nil { @@ -76,7 +76,7 @@ func (_m *RunnerMock) Pop(id ExecutionId) (*dto.ExecutionRequest, bool) { } var r1 bool - if rf, ok := ret.Get(1).(func(ExecutionId) bool); ok { + if rf, ok := ret.Get(1).(func(ExecutionID) bool); ok { r1 = rf(id) } else { r1 = ret.Get(1).(bool) diff --git a/runner/runner_test.go b/runner/runner_test.go index 508548c..484e878 100644 --- a/runner/runner_test.go +++ b/runner/runner_test.go @@ -8,6 +8,7 @@ import ( "fmt" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" "github.com/stretchr/testify/suite" "gitlab.hpi.de/codeocean/codemoon/poseidon/api/dto" "gitlab.hpi.de/codeocean/codemoon/poseidon/nomad" @@ -21,7 +22,7 @@ import ( func TestIdIsStored(t *testing.T) { runner := NewNomadJob(tests.DefaultJobID, nil, nil) - assert.Equal(t, tests.DefaultJobID, runner.Id()) + assert.Equal(t, tests.DefaultJobID, runner.ID()) } func TestMarshalRunner(t *testing.T) { @@ -38,7 +39,7 @@ func TestExecutionRequestIsStored(t *testing.T) { TimeLimit: 10, Environment: nil, } - id := ExecutionId("test-execution") + id := ExecutionID("test-execution") runner.Add(id, executionRequest) storedExecutionRunner, ok := runner.Pop(id) @@ -50,7 +51,8 @@ func TestNewContextReturnsNewContextWithRunner(t *testing.T) { runner := NewNomadJob(tests.DefaultRunnerID, nil, nil) ctx := context.Background() newCtx := NewContext(ctx, runner) - storedRunner := newCtx.Value(runnerContextKey).(Runner) + storedRunner, ok := newCtx.Value(runnerContextKey).(Runner) + require.True(t, ok) assert.NotEqual(t, ctx, newCtx) assert.Equal(t, runner, storedRunner) @@ -106,12 +108,14 @@ func (s *ExecuteInteractivelyTestSuite) TestCallsApi() { s.runner.ExecuteInteractively(request, nil, nil, nil) time.Sleep(tests.ShortTimeout) - s.apiMock.AssertCalled(s.T(), "ExecuteCommand", tests.DefaultRunnerID, mock.Anything, request.FullCommand(), true, mock.Anything, mock.Anything, mock.Anything) + s.apiMock.AssertCalled(s.T(), "ExecuteCommand", tests.DefaultRunnerID, mock.Anything, request.FullCommand(), + true, mock.Anything, mock.Anything, mock.Anything) } func (s *ExecuteInteractivelyTestSuite) TestReturnsAfterTimeout() { s.mockedExecuteCommandCall.Run(func(args mock.Arguments) { - ctx := args.Get(1).(context.Context) + ctx, ok := args.Get(1).(context.Context) + s.Require().True(ok) <-ctx.Done() }). Return(0, nil) @@ -173,10 +177,14 @@ func (s *UpdateFileSystemTestSuite) SetupTest() { id: tests.DefaultRunnerID, api: s.apiMock, } - s.mockedExecuteCommandCall = s.apiMock.On("ExecuteCommand", tests.DefaultRunnerID, mock.Anything, mock.Anything, false, mock.Anything, mock.Anything, mock.Anything). + s.mockedExecuteCommandCall = s.apiMock.On("ExecuteCommand", tests.DefaultRunnerID, mock.Anything, + mock.Anything, false, mock.Anything, mock.Anything, mock.Anything). Run(func(args mock.Arguments) { - s.command = args.Get(2).([]string) - s.stdin = args.Get(4).(*bytes.Buffer) + var ok bool + s.command, ok = args.Get(2).([]string) + s.Require().True(ok) + s.stdin, ok = args.Get(4).(*bytes.Buffer) + s.Require().True(ok) }).Return(0, nil) } @@ -186,7 +194,8 @@ func (s *UpdateFileSystemTestSuite) TestUpdateFileSystemForRunnerPerformsTarExtr copyRequest := &dto.UpdateFileSystemRequest{} err := s.runner.UpdateFileSystem(copyRequest) s.NoError(err) - s.apiMock.AssertCalled(s.T(), "ExecuteCommand", mock.Anything, mock.Anything, mock.Anything, false, mock.Anything, mock.Anything, mock.Anything) + s.apiMock.AssertCalled(s.T(), "ExecuteCommand", mock.Anything, mock.Anything, mock.Anything, + false, mock.Anything, mock.Anything, mock.Anything) s.Regexp("tar --extract --absolute-names", s.command) } @@ -205,10 +214,12 @@ func (s *UpdateFileSystemTestSuite) TestUpdateFileSystemForRunnerReturnsErrorIfA } func (s *UpdateFileSystemTestSuite) TestFilesToCopyAreIncludedInTarArchive() { - copyRequest := &dto.UpdateFileSystemRequest{Copy: []dto.File{{Path: tests.DefaultFileName, Content: []byte(tests.DefaultFileContent)}}} + copyRequest := &dto.UpdateFileSystemRequest{Copy: []dto.File{ + {Path: tests.DefaultFileName, Content: []byte(tests.DefaultFileContent)}}} err := s.runner.UpdateFileSystem(copyRequest) s.NoError(err) - s.apiMock.AssertCalled(s.T(), "ExecuteCommand", mock.Anything, mock.Anything, mock.Anything, false, mock.Anything, mock.Anything, mock.Anything) + s.apiMock.AssertCalled(s.T(), "ExecuteCommand", mock.Anything, mock.Anything, mock.Anything, false, + mock.Anything, mock.Anything, mock.Anything) tarFiles := s.readFilesFromTarArchive(s.stdin) s.Len(tarFiles, 1) @@ -219,8 +230,10 @@ func (s *UpdateFileSystemTestSuite) TestFilesToCopyAreIncludedInTarArchive() { } func (s *UpdateFileSystemTestSuite) TestTarFilesContainCorrectPathForRelativeFilePath() { - copyRequest := &dto.UpdateFileSystemRequest{Copy: []dto.File{{Path: tests.DefaultFileName, Content: []byte(tests.DefaultFileContent)}}} - _ = s.runner.UpdateFileSystem(copyRequest) + copyRequest := &dto.UpdateFileSystemRequest{Copy: []dto.File{ + {Path: tests.DefaultFileName, Content: []byte(tests.DefaultFileContent)}}} + err := s.runner.UpdateFileSystem(copyRequest) + s.Require().NoError(err) tarFiles := s.readFilesFromTarArchive(s.stdin) s.Len(tarFiles, 1) @@ -229,8 +242,10 @@ func (s *UpdateFileSystemTestSuite) TestTarFilesContainCorrectPathForRelativeFil } func (s *UpdateFileSystemTestSuite) TestFilesWithAbsolutePathArePutInAbsoluteLocation() { - copyRequest := &dto.UpdateFileSystemRequest{Copy: []dto.File{{Path: tests.FileNameWithAbsolutePath, Content: []byte(tests.DefaultFileContent)}}} - _ = s.runner.UpdateFileSystem(copyRequest) + copyRequest := &dto.UpdateFileSystemRequest{Copy: []dto.File{ + {Path: tests.FileNameWithAbsolutePath, Content: []byte(tests.DefaultFileContent)}}} + err := s.runner.UpdateFileSystem(copyRequest) + s.Require().NoError(err) tarFiles := s.readFilesFromTarArchive(s.stdin) s.Len(tarFiles, 1) @@ -239,7 +254,8 @@ func (s *UpdateFileSystemTestSuite) TestFilesWithAbsolutePathArePutInAbsoluteLoc func (s *UpdateFileSystemTestSuite) TestDirectoriesAreMarkedAsDirectoryInTar() { copyRequest := &dto.UpdateFileSystemRequest{Copy: []dto.File{{Path: tests.DefaultDirectoryName, Content: []byte{}}}} - _ = s.runner.UpdateFileSystem(copyRequest) + err := s.runner.UpdateFileSystem(copyRequest) + s.Require().NoError(err) tarFiles := s.readFilesFromTarArchive(s.stdin) s.Len(tarFiles, 1) @@ -253,7 +269,8 @@ func (s *UpdateFileSystemTestSuite) TestFilesToRemoveGetRemoved() { copyRequest := &dto.UpdateFileSystemRequest{Delete: []dto.FilePath{tests.DefaultFileName}} err := s.runner.UpdateFileSystem(copyRequest) s.NoError(err) - s.apiMock.AssertCalled(s.T(), "ExecuteCommand", mock.Anything, mock.Anything, mock.Anything, false, mock.Anything, mock.Anything, mock.Anything) + s.apiMock.AssertCalled(s.T(), "ExecuteCommand", mock.Anything, mock.Anything, mock.Anything, false, + mock.Anything, mock.Anything, mock.Anything) s.Regexp(fmt.Sprintf("rm[^;]+%s' *;", regexp.QuoteMeta(tests.DefaultFileName)), s.command) } @@ -261,7 +278,8 @@ func (s *UpdateFileSystemTestSuite) TestFilesToRemoveGetEscaped() { copyRequest := &dto.UpdateFileSystemRequest{Delete: []dto.FilePath{"/some/potentially/harmful'filename"}} err := s.runner.UpdateFileSystem(copyRequest) s.NoError(err) - s.apiMock.AssertCalled(s.T(), "ExecuteCommand", mock.Anything, mock.Anything, mock.Anything, false, mock.Anything, mock.Anything, mock.Anything) + s.apiMock.AssertCalled(s.T(), "ExecuteCommand", mock.Anything, mock.Anything, mock.Anything, false, + mock.Anything, mock.Anything, mock.Anything) s.Contains(strings.Join(s.command, " "), "'/some/potentially/harmful'\\''filename'") } @@ -285,7 +303,8 @@ func (s *UpdateFileSystemTestSuite) readFilesFromTarArchive(tarArchive io.Reader if err != nil { break } - bf, _ := io.ReadAll(reader) + bf, err := io.ReadAll(reader) + s.Require().NoError(err) files = append(files, TarFile{Name: hdr.Name, Content: string(bf), TypeFlag: hdr.Typeflag}) } return files diff --git a/runner/storage.go b/runner/storage.go index d4ef9e0..519f7bd 100644 --- a/runner/storage.go +++ b/runner/storage.go @@ -14,7 +14,8 @@ type Storage interface { // Iff the runner does not exist in the storage, ok will be false. Get(id string) (r Runner, ok bool) - // Delete deletes the runner with the passed id from the storage. It does nothing if no runner with the id is present in the store. + // Delete deletes the runner with the passed id from the storage. + // It does nothing if no runner with the id is present in the store. Delete(id string) // Length returns the number of currently stored runners in the storage. @@ -26,14 +27,14 @@ type Storage interface { } // localRunnerStorage stores runner objects in the local application memory. -// ToDo: Create implementation that use some persistent storage like a database +// ToDo: Create implementation that use some persistent storage like a database. type localRunnerStorage struct { sync.RWMutex runners map[string]Runner } // NewLocalRunnerStorage responds with a Storage implementation. -// This implementation stores the data thread-safe in the local application memory +// This implementation stores the data thread-safe in the local application memory. func NewLocalRunnerStorage() *localRunnerStorage { return &localRunnerStorage{ runners: make(map[string]Runner), @@ -43,7 +44,7 @@ func NewLocalRunnerStorage() *localRunnerStorage { func (s *localRunnerStorage) Add(r Runner) { s.Lock() defer s.Unlock() - s.runners[r.Id()] = r + s.runners[r.ID()] = r } func (s *localRunnerStorage) Get(id string) (r Runner, ok bool) { @@ -63,7 +64,7 @@ func (s *localRunnerStorage) Sample() (Runner, bool) { s.Lock() defer s.Unlock() for _, runner := range s.runners { - delete(s.runners, runner.Id()) + delete(s.runners, runner.ID()) return runner, true } return nil, false diff --git a/runner/storage_test.go b/runner/storage_test.go index 2b25c64..c6960f2 100644 --- a/runner/storage_test.go +++ b/runner/storage_test.go @@ -17,87 +17,87 @@ type RunnerPoolTestSuite struct { runner Runner } -func (suite *RunnerPoolTestSuite) SetupTest() { - suite.runnerStorage = NewLocalRunnerStorage() - suite.runner = NewRunner(tests.DefaultRunnerID, nil) - suite.runner.Add(tests.DefaultExecutionID, &dto.ExecutionRequest{Command: "true"}) +func (s *RunnerPoolTestSuite) SetupTest() { + s.runnerStorage = NewLocalRunnerStorage() + s.runner = NewRunner(tests.DefaultRunnerID, nil) + s.runner.Add(tests.DefaultExecutionID, &dto.ExecutionRequest{Command: "true"}) } -func (suite *RunnerPoolTestSuite) TestAddedRunnerCanBeRetrieved() { - suite.runnerStorage.Add(suite.runner) - retrievedRunner, ok := suite.runnerStorage.Get(suite.runner.Id()) - suite.True(ok, "A saved runner should be retrievable") - suite.Equal(suite.runner, retrievedRunner) +func (s *RunnerPoolTestSuite) TestAddedRunnerCanBeRetrieved() { + s.runnerStorage.Add(s.runner) + retrievedRunner, ok := s.runnerStorage.Get(s.runner.ID()) + s.True(ok, "A saved runner should be retrievable") + s.Equal(s.runner, retrievedRunner) } -func (suite *RunnerPoolTestSuite) TestRunnerWithSameIdOverwritesOldOne() { - otherRunnerWithSameId := NewRunner(suite.runner.Id(), nil) +func (s *RunnerPoolTestSuite) TestRunnerWithSameIdOverwritesOldOne() { + otherRunnerWithSameID := NewRunner(s.runner.ID(), nil) // assure runner is actually different - suite.NotEqual(suite.runner, otherRunnerWithSameId) + s.NotEqual(s.runner, otherRunnerWithSameID) - suite.runnerStorage.Add(suite.runner) - suite.runnerStorage.Add(otherRunnerWithSameId) - retrievedRunner, _ := suite.runnerStorage.Get(suite.runner.Id()) - suite.NotEqual(suite.runner, retrievedRunner) - suite.Equal(otherRunnerWithSameId, retrievedRunner) + s.runnerStorage.Add(s.runner) + s.runnerStorage.Add(otherRunnerWithSameID) + retrievedRunner, _ := s.runnerStorage.Get(s.runner.ID()) + s.NotEqual(s.runner, retrievedRunner) + s.Equal(otherRunnerWithSameID, retrievedRunner) } -func (suite *RunnerPoolTestSuite) TestDeletedRunnersAreNotAccessible() { - suite.runnerStorage.Add(suite.runner) - suite.runnerStorage.Delete(suite.runner.Id()) - retrievedRunner, ok := suite.runnerStorage.Get(suite.runner.Id()) - suite.Nil(retrievedRunner) - suite.False(ok, "A deleted runner should not be accessible") +func (s *RunnerPoolTestSuite) TestDeletedRunnersAreNotAccessible() { + s.runnerStorage.Add(s.runner) + s.runnerStorage.Delete(s.runner.ID()) + retrievedRunner, ok := s.runnerStorage.Get(s.runner.ID()) + s.Nil(retrievedRunner) + s.False(ok, "A deleted runner should not be accessible") } -func (suite *RunnerPoolTestSuite) TestSampleReturnsRunnerWhenOneIsAvailable() { - suite.runnerStorage.Add(suite.runner) - sampledRunner, ok := suite.runnerStorage.Sample() - suite.NotNil(sampledRunner) - suite.True(ok) +func (s *RunnerPoolTestSuite) TestSampleReturnsRunnerWhenOneIsAvailable() { + s.runnerStorage.Add(s.runner) + sampledRunner, ok := s.runnerStorage.Sample() + s.NotNil(sampledRunner) + s.True(ok) } -func (suite *RunnerPoolTestSuite) TestSampleReturnsFalseWhenNoneIsAvailable() { - sampledRunner, ok := suite.runnerStorage.Sample() - suite.Nil(sampledRunner) - suite.False(ok) +func (s *RunnerPoolTestSuite) TestSampleReturnsFalseWhenNoneIsAvailable() { + sampledRunner, ok := s.runnerStorage.Sample() + s.Nil(sampledRunner) + s.False(ok) } -func (suite *RunnerPoolTestSuite) TestSampleRemovesRunnerFromPool() { - suite.runnerStorage.Add(suite.runner) - sampledRunner, _ := suite.runnerStorage.Sample() - _, ok := suite.runnerStorage.Get(sampledRunner.Id()) - suite.False(ok) +func (s *RunnerPoolTestSuite) TestSampleRemovesRunnerFromPool() { + s.runnerStorage.Add(s.runner) + sampledRunner, _ := s.runnerStorage.Sample() + _, ok := s.runnerStorage.Get(sampledRunner.ID()) + s.False(ok) } -func (suite *RunnerPoolTestSuite) TestLenOfEmptyPoolIsZero() { - suite.Equal(0, suite.runnerStorage.Length()) +func (s *RunnerPoolTestSuite) TestLenOfEmptyPoolIsZero() { + s.Equal(0, s.runnerStorage.Length()) } -func (suite *RunnerPoolTestSuite) TestLenChangesOnStoreContentChange() { - suite.Run("len increases when runner is added", func() { - suite.runnerStorage.Add(suite.runner) - suite.Equal(1, suite.runnerStorage.Length()) +func (s *RunnerPoolTestSuite) TestLenChangesOnStoreContentChange() { + s.Run("len increases when runner is added", func() { + s.runnerStorage.Add(s.runner) + s.Equal(1, s.runnerStorage.Length()) }) - suite.Run("len does not increase when runner with same id is added", func() { - suite.runnerStorage.Add(suite.runner) - suite.Equal(1, suite.runnerStorage.Length()) + s.Run("len does not increase when runner with same id is added", func() { + s.runnerStorage.Add(s.runner) + s.Equal(1, s.runnerStorage.Length()) }) - suite.Run("len increases again when different runner is added", func() { + s.Run("len increases again when different runner is added", func() { anotherRunner := NewRunner(tests.AnotherRunnerID, nil) - suite.runnerStorage.Add(anotherRunner) - suite.Equal(2, suite.runnerStorage.Length()) + s.runnerStorage.Add(anotherRunner) + s.Equal(2, s.runnerStorage.Length()) }) - suite.Run("len decreases when runner is deleted", func() { - suite.runnerStorage.Delete(suite.runner.Id()) - suite.Equal(1, suite.runnerStorage.Length()) + s.Run("len decreases when runner is deleted", func() { + s.runnerStorage.Delete(s.runner.ID()) + s.Equal(1, s.runnerStorage.Length()) }) - suite.Run("len decreases when runner is sampled", func() { - _, _ = suite.runnerStorage.Sample() - suite.Equal(0, suite.runnerStorage.Length()) + s.Run("len decreases when runner is sampled", func() { + _, _ = s.runnerStorage.Sample() + s.Equal(0, s.runnerStorage.Length()) }) } diff --git a/tests/e2e/e2e_test.go b/tests/e2e/e2e_test.go index 0567df7..c259ee0 100644 --- a/tests/e2e/e2e_test.go +++ b/tests/e2e/e2e_test.go @@ -85,7 +85,7 @@ func createDefaultEnvironment() { ExposedPorts: nil, } - resp, err := helpers.HttpPutJSON(path, request) + resp, err := helpers.HTTPPutJSON(path, request) if err != nil || resp.StatusCode != http.StatusCreated && resp.StatusCode != http.StatusNoContent { log.WithError(err).Fatal("Couldn't create default environment for e2e tests") } diff --git a/tests/e2e/environments_test.go b/tests/e2e/environments_test.go index 24b8274..f8711af 100644 --- a/tests/e2e/environments_test.go +++ b/tests/e2e/environments_test.go @@ -23,7 +23,7 @@ func TestCreateOrUpdateEnvironment(t *testing.T) { path := helpers.BuildURL(api.BasePath, api.EnvironmentsPath, tests.AnotherEnvironmentIDAsString) t.Run("returns bad request with empty body", func(t *testing.T) { - resp, err := helpers.HttpPut(path, strings.NewReader("")) + resp, err := helpers.HTTPPut(path, strings.NewReader("")) require.Nil(t, err) assert.Equal(t, http.StatusBadRequest, resp.StatusCode) @@ -91,7 +91,7 @@ func cleanupJobsForEnvironment(t *testing.T, environmentID string) { func assertPutReturnsStatusAndZeroContent(t *testing.T, path string, request dto.ExecutionEnvironmentRequest, status int) { t.Helper() - resp, err := helpers.HttpPutJSON(path, request) + resp, err := helpers.HTTPPutJSON(path, request) require.Nil(t, err) assert.Equal(t, status, resp.StatusCode) assert.Equal(t, int64(0), resp.ContentLength) diff --git a/tests/e2e/runners_test.go b/tests/e2e/runners_test.go index d606f13..6163a8e 100644 --- a/tests/e2e/runners_test.go +++ b/tests/e2e/runners_test.go @@ -16,9 +16,10 @@ import ( ) func (s *E2ETestSuite) TestProvideRunnerRoute() { - runnerRequestByteString, _ := json.Marshal(dto.RunnerRequest{ - ExecutionEnvironmentId: tests.DefaultEnvironmentIDAsInteger, + runnerRequestByteString, err := json.Marshal(dto.RunnerRequest{ + ExecutionEnvironmentID: tests.DefaultEnvironmentIDAsInteger, }) + s.Require().NoError(err) reader := bytes.NewReader(runnerRequestByteString) s.Run("valid request returns a runner", func() { @@ -29,7 +30,7 @@ func (s *E2ETestSuite) TestProvideRunnerRoute() { runnerResponse := new(dto.RunnerResponse) err = json.NewDecoder(resp.Body).Decode(runnerResponse) s.Require().NoError(err) - s.NotEmpty(runnerResponse.Id) + s.NotEmpty(runnerResponse.ID) }) s.Run("invalid request returns bad request", func() { @@ -39,9 +40,10 @@ func (s *E2ETestSuite) TestProvideRunnerRoute() { }) s.Run("requesting runner of unknown execution environment returns not found", func() { - runnerRequestByteString, _ := json.Marshal(dto.RunnerRequest{ - ExecutionEnvironmentId: tests.NonExistingIntegerID, + runnerRequestByteString, err := json.Marshal(dto.RunnerRequest{ + ExecutionEnvironmentID: tests.NonExistingIntegerID, }) + s.Require().NoError(err) reader := bytes.NewReader(runnerRequestByteString) resp, err := http.Post(helpers.BuildURL(api.BasePath, api.RunnersPath), "application/json", reader) s.Require().NoError(err) @@ -53,13 +55,17 @@ func (s *E2ETestSuite) TestProvideRunnerRoute() { // It needs a running Poseidon instance to work. func ProvideRunner(request *dto.RunnerRequest) (string, error) { url := helpers.BuildURL(api.BasePath, api.RunnersPath) - runnerRequestByteString, _ := json.Marshal(request) + runnerRequestByteString, err := json.Marshal(request) + if err != nil { + return "", err + } reader := strings.NewReader(string(runnerRequestByteString)) - resp, err := http.Post(url, "application/json", reader) + resp, err := http.Post(url, "application/json", reader) //nolint:gosec // url is not influenced by a user if err != nil { return "", err } if resp.StatusCode != http.StatusOK { + //nolint:goerr113 // dynamic error is ok in here, as it is a test return "", fmt.Errorf("expected response code 200 when getting runner, got %v", resp.StatusCode) } runnerResponse := new(dto.RunnerResponse) @@ -67,44 +73,47 @@ func ProvideRunner(request *dto.RunnerRequest) (string, error) { if err != nil { return "", err } - return runnerResponse.Id, nil + return runnerResponse.ID, nil } func (s *E2ETestSuite) TestDeleteRunnerRoute() { - runnerId, err := ProvideRunner(&dto.RunnerRequest{ - ExecutionEnvironmentId: tests.DefaultEnvironmentIDAsInteger, + runnerID, err := ProvideRunner(&dto.RunnerRequest{ + ExecutionEnvironmentID: tests.DefaultEnvironmentIDAsInteger, }) s.NoError(err) s.Run("Deleting the runner returns NoContent", func() { - resp, err := helpers.HttpDelete(helpers.BuildURL(api.BasePath, api.RunnersPath, runnerId), nil) + resp, err := helpers.HTTPDelete(helpers.BuildURL(api.BasePath, api.RunnersPath, runnerID), nil) s.NoError(err) s.Equal(http.StatusNoContent, resp.StatusCode) }) s.Run("Deleting it again returns NotFound", func() { - resp, err := helpers.HttpDelete(helpers.BuildURL(api.BasePath, api.RunnersPath, runnerId), nil) + resp, err := helpers.HTTPDelete(helpers.BuildURL(api.BasePath, api.RunnersPath, runnerID), nil) s.NoError(err) s.Equal(http.StatusNotFound, resp.StatusCode) }) s.Run("Deleting non-existing runner returns NotFound", func() { - resp, err := helpers.HttpDelete(helpers.BuildURL(api.BasePath, api.RunnersPath, tests.NonExistingStringID), nil) + resp, err := helpers.HTTPDelete(helpers.BuildURL(api.BasePath, api.RunnersPath, tests.NonExistingStringID), nil) s.NoError(err) s.Equal(http.StatusNotFound, resp.StatusCode) }) } +//nolint:funlen // there are a lot of tests for the files route, this function can be a little longer than 100 lines ;) func (s *E2ETestSuite) TestCopyFilesRoute() { runnerID, err := ProvideRunner(&dto.RunnerRequest{ - ExecutionEnvironmentId: tests.DefaultEnvironmentIDAsInteger, + ExecutionEnvironmentID: tests.DefaultEnvironmentIDAsInteger, }) s.NoError(err) - copyFilesRequestByteString, _ := json.Marshal(&dto.UpdateFileSystemRequest{ + copyFilesRequestByteString, err := json.Marshal(&dto.UpdateFileSystemRequest{ Copy: []dto.File{{Path: tests.DefaultFileName, Content: []byte(tests.DefaultFileContent)}}, }) + s.Require().NoError(err) sendCopyRequest := func(reader io.Reader) (*http.Response, error) { - return helpers.HttpPatch(helpers.BuildURL(api.BasePath, api.RunnersPath, runnerID, api.UpdateFileSystemPath), "application/json", reader) + return helpers.HTTPPatch(helpers.BuildURL(api.BasePath, api.RunnersPath, runnerID, api.UpdateFileSystemPath), + "application/json", reader) } s.Run("File copy with valid payload succeeds", func() { @@ -122,12 +131,13 @@ func (s *E2ETestSuite) TestCopyFilesRoute() { relativeFileContent := "Relative file content" absoluteFilePath := "/tmp/absolute/file/path.txt" absoluteFileContent := "Absolute file content" - testFilePathsCopyRequestString, _ := json.Marshal(&dto.UpdateFileSystemRequest{ + testFilePathsCopyRequestString, err := json.Marshal(&dto.UpdateFileSystemRequest{ Copy: []dto.File{ {Path: dto.FilePath(relativeFilePath), Content: []byte(relativeFileContent)}, {Path: dto.FilePath(absoluteFilePath), Content: []byte(absoluteFileContent)}, }, }) + s.Require().NoError(err) resp, err := sendCopyRequest(bytes.NewReader(testFilePathsCopyRequestString)) s.NoError(err) @@ -144,9 +154,10 @@ func (s *E2ETestSuite) TestCopyFilesRoute() { }) s.Run("File deletion request deletes file on runner", func() { - copyFilesRequestByteString, _ := json.Marshal(&dto.UpdateFileSystemRequest{ + copyFilesRequestByteString, err := json.Marshal(&dto.UpdateFileSystemRequest{ Delete: []dto.FilePath{tests.DefaultFileName}, }) + s.Require().NoError(err) resp, err := sendCopyRequest(bytes.NewReader(copyFilesRequestByteString)) s.NoError(err) @@ -160,10 +171,11 @@ func (s *E2ETestSuite) TestCopyFilesRoute() { }) s.Run("File copy happens after file deletion", func() { - copyFilesRequestByteString, _ := json.Marshal(&dto.UpdateFileSystemRequest{ + copyFilesRequestByteString, err := json.Marshal(&dto.UpdateFileSystemRequest{ Delete: []dto.FilePath{tests.DefaultFileName}, Copy: []dto.File{{Path: tests.DefaultFileName, Content: []byte(tests.DefaultFileContent)}}, }) + s.Require().NoError(err) resp, err := sendCopyRequest(bytes.NewReader(copyFilesRequestByteString)) s.NoError(err) @@ -177,12 +189,13 @@ func (s *E2ETestSuite) TestCopyFilesRoute() { s.Run("If one file produces permission denied error, others are still copied", func() { newFileContent := []byte("New content") - copyFilesRequestByteString, _ := json.Marshal(&dto.UpdateFileSystemRequest{ + copyFilesRequestByteString, err := json.Marshal(&dto.UpdateFileSystemRequest{ Copy: []dto.File{ {Path: "/dev/sda", Content: []byte(tests.DefaultFileContent)}, {Path: tests.DefaultFileName, Content: newFileContent}, }, }) + s.Require().NoError(err) resp, err := sendCopyRequest(bytes.NewReader(copyFilesRequestByteString)) s.NoError(err) @@ -199,13 +212,16 @@ func (s *E2ETestSuite) TestCopyFilesRoute() { }) s.Run("File copy with invalid payload returns bad request", func() { - resp, err := helpers.HttpPatch(helpers.BuildURL(api.BasePath, api.RunnersPath, runnerID, api.UpdateFileSystemPath), "text/html", strings.NewReader("")) + resp, err := helpers.HTTPPatch(helpers.BuildURL(api.BasePath, api.RunnersPath, runnerID, api.UpdateFileSystemPath), + "text/html", strings.NewReader("")) s.NoError(err) s.Equal(http.StatusBadRequest, resp.StatusCode) }) s.Run("Copying to non-existing runner returns NotFound", func() { - resp, err := helpers.HttpPatch(helpers.BuildURL(api.BasePath, api.RunnersPath, tests.NonExistingStringID, api.UpdateFileSystemPath), "application/json", bytes.NewReader(copyFilesRequestByteString)) + resp, err := helpers.HTTPPatch( + helpers.BuildURL(api.BasePath, api.RunnersPath, tests.NonExistingStringID, api.UpdateFileSystemPath), + "application/json", bytes.NewReader(copyFilesRequestByteString)) s.NoError(err) s.Equal(http.StatusNotFound, resp.StatusCode) }) @@ -214,7 +230,7 @@ func (s *E2ETestSuite) TestCopyFilesRoute() { func (s *E2ETestSuite) TestRunnerGetsDestroyedAfterInactivityTimeout() { inactivityTimeout := 5 // seconds runnerID, err := ProvideRunner(&dto.RunnerRequest{ - ExecutionEnvironmentId: tests.DefaultEnvironmentIDAsInteger, + ExecutionEnvironmentID: tests.DefaultEnvironmentIDAsInteger, InactivityTimeout: inactivityTimeout, }) s.Require().NoError(err) @@ -240,20 +256,23 @@ func (s *E2ETestSuite) TestRunnerGetsDestroyedAfterInactivityTimeout() { s.Equal(dto.WebSocketMetaTimeout, lastMessage.Type) } -func (s *E2ETestSuite) assertFileContent(runnerID, fileName string, expectedContent string) { +func (s *E2ETestSuite) assertFileContent(runnerID, fileName, expectedContent string) { stdout, stderr := s.PrintContentOfFileOnRunner(runnerID, fileName) s.Equal(expectedContent, stdout) s.Equal("", stderr) } -func (s *E2ETestSuite) PrintContentOfFileOnRunner(runnerId string, filename string) (string, string) { - webSocketURL, _ := ProvideWebSocketURL(&s.Suite, runnerId, &dto.ExecutionRequest{Command: fmt.Sprintf("cat %s", filename)}) - connection, _ := ConnectToWebSocket(webSocketURL) +func (s *E2ETestSuite) PrintContentOfFileOnRunner(runnerID, filename string) (stdout, stderr string) { + webSocketURL, err := ProvideWebSocketURL(&s.Suite, runnerID, + &dto.ExecutionRequest{Command: fmt.Sprintf("cat %s", filename)}) + s.Require().NoError(err) + connection, err := ConnectToWebSocket(webSocketURL) + s.Require().NoError(err) messages, err := helpers.ReceiveAllWebSocketMessages(connection) s.Require().Error(err) s.Equal(&websocket.CloseError{Code: websocket.CloseNormalClosure}, err) - stdout, stderr, _ := helpers.WebSocketOutputMessages(messages) + stdout, stderr, _ = helpers.WebSocketOutputMessages(messages) return stdout, stderr } diff --git a/tests/e2e/websocket_test.go b/tests/e2e/websocket_test.go index 5204dc8..c5d338f 100644 --- a/tests/e2e/websocket_test.go +++ b/tests/e2e/websocket_test.go @@ -18,12 +18,12 @@ import ( ) func (s *E2ETestSuite) TestExecuteCommandRoute() { - runnerId, err := ProvideRunner(&dto.RunnerRequest{ - ExecutionEnvironmentId: tests.DefaultEnvironmentIDAsInteger, + runnerID, err := ProvideRunner(&dto.RunnerRequest{ + ExecutionEnvironmentID: tests.DefaultEnvironmentIDAsInteger, }) s.Require().NoError(err) - webSocketURL, err := ProvideWebSocketURL(&s.Suite, runnerId, &dto.ExecutionRequest{Command: "true"}) + webSocketURL, err := ProvideWebSocketURL(&s.Suite, runnerID, &dto.ExecutionRequest{Command: "true"}) s.Require().NoError(err) s.NotEqual("", webSocketURL) @@ -50,7 +50,8 @@ func (s *E2ETestSuite) TestExecuteCommandRoute() { s.Require().Error(err) s.Equal(&websocket.CloseError{Code: websocket.CloseNormalClosure}, err) - _, _, _ = connection.ReadMessage() + _, _, err = connection.ReadMessage() + s.True(websocket.IsCloseError(err, websocket.CloseNormalClosure)) s.True(connectionClosed, "connection should be closed") } @@ -150,7 +151,7 @@ func (s *E2ETestSuite) TestEchoEnvironment() { func (s *E2ETestSuite) TestStderrFifoIsRemoved() { runnerID, err := ProvideRunner(&dto.RunnerRequest{ - ExecutionEnvironmentId: tests.DefaultEnvironmentIDAsInteger, + ExecutionEnvironmentID: tests.DefaultEnvironmentIDAsInteger, }) s.Require().NoError(err) @@ -188,35 +189,39 @@ func (s *E2ETestSuite) ListTempDirectory(runnerID string) string { // ProvideWebSocketConnection establishes a client WebSocket connection to run the passed ExecutionRequest. // It requires a running Poseidon instance. -func ProvideWebSocketConnection(suite *suite.Suite, request *dto.ExecutionRequest) (connection *websocket.Conn, err error) { - runnerId, err := ProvideRunner(&dto.RunnerRequest{ - ExecutionEnvironmentId: tests.DefaultEnvironmentIDAsInteger, +func ProvideWebSocketConnection(s *suite.Suite, request *dto.ExecutionRequest) (*websocket.Conn, error) { + runnerID, err := ProvideRunner(&dto.RunnerRequest{ + ExecutionEnvironmentID: tests.DefaultEnvironmentIDAsInteger, }) if err != nil { - return + return nil, fmt.Errorf("error providing runner: %w", err) } - webSocketURL, err := ProvideWebSocketURL(suite, runnerId, request) + webSocketURL, err := ProvideWebSocketURL(s, runnerID, request) if err != nil { - return + return nil, fmt.Errorf("error providing WebSocket URL: %w", err) } - connection, err = ConnectToWebSocket(webSocketURL) - return + connection, err := ConnectToWebSocket(webSocketURL) + if err != nil { + return nil, fmt.Errorf("error connecting to WebSocket: %w", err) + } + return connection, nil } // ProvideWebSocketURL creates a WebSocket endpoint from the ExecutionRequest via an external api request. // It requires a running Poseidon instance. -func ProvideWebSocketURL(suite *suite.Suite, runnerId string, request *dto.ExecutionRequest) (string, error) { - url := helpers.BuildURL(api.BasePath, api.RunnersPath, runnerId, api.ExecutePath) - executionRequestByteString, _ := json.Marshal(request) +func ProvideWebSocketURL(s *suite.Suite, runnerID string, request *dto.ExecutionRequest) (string, error) { + url := helpers.BuildURL(api.BasePath, api.RunnersPath, runnerID, api.ExecutePath) + executionRequestByteString, err := json.Marshal(request) + s.Require().NoError(err) reader := strings.NewReader(string(executionRequestByteString)) - resp, err := http.Post(url, "application/json", reader) - suite.Require().NoError(err) - suite.Require().Equal(http.StatusOK, resp.StatusCode) + resp, err := http.Post(url, "application/json", reader) //nolint:gosec // url is not influenced by a user + s.Require().NoError(err) + s.Require().Equal(http.StatusOK, resp.StatusCode) executionResponse := new(dto.ExecutionResponse) err = json.NewDecoder(resp.Body).Decode(executionResponse) - suite.Require().NoError(err) - return executionResponse.WebSocketUrl, nil + s.Require().NoError(err) + return executionResponse.WebSocketURL, nil } // ConnectToWebSocket establish an external WebSocket connection to the provided url. diff --git a/tests/helpers/test_helpers.go b/tests/helpers/test_helpers.go index cf5146f..9569fe8 100644 --- a/tests/helpers/test_helpers.go +++ b/tests/helpers/test_helpers.go @@ -6,6 +6,7 @@ import ( "bytes" "crypto/tls" "encoding/json" + "fmt" "github.com/gorilla/mux" "github.com/gorilla/websocket" nomadApi "github.com/hashicorp/nomad/api" @@ -22,15 +23,15 @@ import ( ) // BuildURL joins multiple route paths. -func BuildURL(parts ...string) (url string) { - url = config.Config.PoseidonAPIURL().String() +func BuildURL(parts ...string) string { + url := config.Config.PoseidonAPIURL().String() for _, part := range parts { if !strings.HasPrefix(part, "/") { url += "/" } url += part } - return + return url } // WebSocketOutputMessages extracts all stdout, stderr and error messages from the passed messages. @@ -46,7 +47,7 @@ func WebSocketOutputMessages(messages []*dto.WebSocketMessage) (stdout, stderr s errors = append(errors, msg.Data) } } - return + return stdout, stderr, errors } // WebSocketControlMessages extracts all meta (and exit) messages from the passed messages. @@ -57,11 +58,12 @@ func WebSocketControlMessages(messages []*dto.WebSocketMessage) (controls []*dto controls = append(controls, msg) } } - return + return controls } // ReceiveAllWebSocketMessages pulls all messages from the websocket connection without sending anything. -// This function does not return unless the server closes the connection or a readDeadline is set in the WebSocket connection. +// This function does not return unless the server closes the connection or a readDeadline is set +// in the WebSocket connection. func ReceiveAllWebSocketMessages(connection *websocket.Conn) (messages []*dto.WebSocketMessage, err error) { for { var message *dto.WebSocketMessage @@ -74,71 +76,99 @@ func ReceiveAllWebSocketMessages(connection *websocket.Conn) (messages []*dto.We } // ReceiveNextWebSocketMessage pulls the next message from the websocket connection. -// This function does not return unless the server sends a message, closes the connection or a readDeadline is set in the WebSocket connection. +// This function does not return unless the server sends a message, closes the connection or a readDeadline +// is set in the WebSocket connection. func ReceiveNextWebSocketMessage(connection *websocket.Conn) (*dto.WebSocketMessage, error) { _, reader, err := connection.NextReader() if err != nil { + //nolint:wrapcheck // we could either wrap here and do complicated things with errors.As or just not wrap + // the error in this test function and allow tests to use equal return nil, err } message := new(dto.WebSocketMessage) err = json.NewDecoder(reader).Decode(message) if err != nil { - return nil, err + return nil, fmt.Errorf("error decoding WebSocket message: %w", err) } return message, nil } // StartTLSServer runs a httptest.Server with the passed mux.Router and TLS enabled. -func StartTLSServer(t *testing.T, router *mux.Router) (server *httptest.Server, err error) { +func StartTLSServer(t *testing.T, router *mux.Router) (*httptest.Server, error) { + t.Helper() dir := t.TempDir() keyOut := filepath.Join(dir, "poseidon-test.key") certOut := filepath.Join(dir, "poseidon-test.crt") - err = exec.Command("openssl", "req", "-x509", "-nodes", "-newkey", "rsa:2048", + err := exec.Command("openssl", "req", "-x509", "-nodes", "-newkey", "rsa:2048", "-keyout", keyOut, "-out", certOut, "-days", "1", "-subj", "/CN=Poseidon test", "-addext", "subjectAltName=IP:127.0.0.1,DNS:localhost").Run() if err != nil { - return nil, err + return nil, fmt.Errorf("error creating self-signed cert: %w", err) } cert, err := tls.LoadX509KeyPair(certOut, keyOut) if err != nil { - return nil, err + return nil, fmt.Errorf("error loading x509 key pair: %w", err) } - server = httptest.NewUnstartedServer(router) - server.TLS = &tls.Config{Certificates: []tls.Certificate{cert}} + server := httptest.NewUnstartedServer(router) + server.TLS = &tls.Config{Certificates: []tls.Certificate{cert}, MinVersion: tls.VersionTLS13} server.StartTLS() - return + return server, nil } -// HttpDelete sends a Delete Http Request with body to the passed url. -func HttpDelete(url string, body io.Reader) (response *http.Response, err error) { - req, _ := http.NewRequest(http.MethodDelete, url, body) +// HTTPDelete sends a Delete Http Request with body to the passed url. +func HTTPDelete(url string, body io.Reader) (response *http.Response, err error) { + //nolint:noctx // we don't need a http.NewRequestWithContext in our tests + req, err := http.NewRequest(http.MethodDelete, url, body) + if err != nil { + return nil, fmt.Errorf("error creating request: %w", err) + } client := &http.Client{} - return client.Do(req) + response, err = client.Do(req) + if err != nil { + return nil, fmt.Errorf("error executing request: %w", err) + } + return response, nil } -// HttpPatch sends a Patch Http Request with body to the passed url. -func HttpPatch(url string, contentType string, body io.Reader) (response *http.Response, err error) { - req, _ := http.NewRequest(http.MethodPatch, url, body) +// HTTPPatch sends a Patch Http Request with body to the passed url. +func HTTPPatch(url, contentType string, body io.Reader) (response *http.Response, err error) { + //nolint:noctx // we don't need a http.NewRequestWithContext in our tests + req, err := http.NewRequest(http.MethodPatch, url, body) + if err != nil { + return nil, fmt.Errorf("error creating new request: %w", err) + } req.Header.Set("Content-Type", contentType) client := &http.Client{} - return client.Do(req) + resp, err := client.Do(req) + if err != nil { + return nil, fmt.Errorf("error executing request: %w", err) + } + return resp, nil } -func HttpPut(url string, body io.Reader) (response *http.Response, err error) { - req, _ := http.NewRequest(http.MethodPut, url, body) +func HTTPPut(url string, body io.Reader) (response *http.Response, err error) { + //nolint:noctx // we don't need a http.NewRequestWithContext in our tests + req, err := http.NewRequest(http.MethodPut, url, body) + if err != nil { + return nil, fmt.Errorf("error creating new request: %w", err) + } client := &http.Client{} - return client.Do(req) + resp, err := client.Do(req) + if err != nil { + return nil, fmt.Errorf("error executing request: %w", err) + } + return resp, nil } -func HttpPutJSON(url string, body interface{}) (response *http.Response, err error) { +func HTTPPutJSON(url string, body interface{}) (response *http.Response, err error) { requestByteString, err := json.Marshal(body) if err != nil { return } reader := bytes.NewReader(requestByteString) - return HttpPut(url, reader) + return HTTPPut(url, reader) } func CreateTemplateJob() (base, job *nomadApi.Job) { diff --git a/util/util_test.go b/util/util_test.go index fb7fce8..dace55d 100644 --- a/util/util_test.go +++ b/util/util_test.go @@ -2,6 +2,7 @@ package util import ( "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "gitlab.hpi.de/codeocean/codemoon/poseidon/tests" "testing" ) @@ -10,8 +11,9 @@ func TestNullReaderDoesNotReturnImmediately(t *testing.T) { reader := &NullReader{} readerReturned := make(chan bool) go func() { - p := make([]byte, 5) - _, _ = reader.Read(p) + p := make([]byte, 0, 5) + _, err := reader.Read(p) + require.NoError(t, err) close(readerReturned) }() assert.False(t, tests.ChannelReceivesSomething(readerReturned, tests.ShortTimeout))