Restructure project
We previously didn't really had any structure in our project apart from creating a new folder for each package in our project root. Now that we have accumulated some packages, we use the well-known Golang project layout in order to clearly communicate our intent with packages. See https://github.com/golang-standards/project-layout
This commit is contained in:
57
internal/api/api.go
Normal file
57
internal/api/api.go
Normal file
@@ -0,0 +1,57 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"github.com/gorilla/mux"
|
||||
"gitlab.hpi.de/codeocean/codemoon/poseidon/internal/api/auth"
|
||||
"gitlab.hpi.de/codeocean/codemoon/poseidon/internal/environment"
|
||||
"gitlab.hpi.de/codeocean/codemoon/poseidon/internal/runner"
|
||||
"gitlab.hpi.de/codeocean/codemoon/poseidon/pkg/logging"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
var log = logging.GetLogger("api")
|
||||
|
||||
const (
|
||||
BasePath = "/api/v1"
|
||||
HealthPath = "/health"
|
||||
RunnersPath = "/runners"
|
||||
EnvironmentsPath = "/execution-environments"
|
||||
)
|
||||
|
||||
// NewRouter returns a *mux.Router which can be
|
||||
// used by the net/http package to serve the routes of our API. It
|
||||
// always returns a router for the newest version of our API. We
|
||||
// use gorilla/mux because it is more convenient than net/http, e.g.
|
||||
// when extracting path parameters.
|
||||
func NewRouter(runnerManager runner.Manager, environmentManager environment.Manager) *mux.Router {
|
||||
router := mux.NewRouter()
|
||||
// this can later be restricted to a specific host with
|
||||
// `router.Host(...)` and to HTTPS with `router.Schemes("https")`
|
||||
configureV1Router(router, runnerManager, environmentManager)
|
||||
router.Use(logging.HTTPLoggingMiddleware)
|
||||
return router
|
||||
}
|
||||
|
||||
// configureV1Router configures a given router with the routes of version 1 of the Poseidon API.
|
||||
func configureV1Router(router *mux.Router, runnerManager runner.Manager, environmentManager environment.Manager) {
|
||||
v1 := router.PathPrefix(BasePath).Subrouter()
|
||||
v1.HandleFunc(HealthPath, Health).Methods(http.MethodGet)
|
||||
|
||||
runnerController := &RunnerController{manager: runnerManager}
|
||||
environmentController := &EnvironmentController{manager: environmentManager}
|
||||
|
||||
configureRoutes := func(router *mux.Router) {
|
||||
runnerController.ConfigureRoutes(router)
|
||||
environmentController.ConfigureRoutes(router)
|
||||
}
|
||||
|
||||
if auth.InitializeAuthentication() {
|
||||
// Create new authenticated subrouter.
|
||||
// All routes added to v1 after this require authentication.
|
||||
authenticatedV1Router := v1.PathPrefix("").Subrouter()
|
||||
authenticatedV1Router.Use(auth.HTTPAuthenticationMiddleware)
|
||||
runnerController.ConfigureRoutes(authenticatedV1Router)
|
||||
} else {
|
||||
configureRoutes(v1)
|
||||
}
|
||||
}
|
70
internal/api/api_test.go
Normal file
70
internal/api/api_test.go
Normal file
@@ -0,0 +1,70 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"gitlab.hpi.de/codeocean/codemoon/poseidon/internal/config"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func mockHTTPHandler(writer http.ResponseWriter, _ *http.Request) {
|
||||
writer.WriteHeader(http.StatusOK)
|
||||
}
|
||||
|
||||
func TestNewRouterV1WithAuthenticationDisabled(t *testing.T) {
|
||||
config.Config.Server.Token = ""
|
||||
router := mux.NewRouter()
|
||||
configureV1Router(router, nil, nil)
|
||||
|
||||
t.Run("health route is accessible", func(t *testing.T) {
|
||||
request, err := http.NewRequest(http.MethodGet, "/api/v1/health", nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
recorder := httptest.NewRecorder()
|
||||
router.ServeHTTP(recorder, request)
|
||||
assert.Equal(t, http.StatusNoContent, recorder.Code)
|
||||
})
|
||||
|
||||
t.Run("added route is accessible", func(t *testing.T) {
|
||||
router.HandleFunc("/api/v1/test", mockHTTPHandler)
|
||||
request, err := http.NewRequest(http.MethodGet, "/api/v1/test", nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
recorder := httptest.NewRecorder()
|
||||
router.ServeHTTP(recorder, request)
|
||||
assert.Equal(t, http.StatusOK, recorder.Code)
|
||||
})
|
||||
}
|
||||
|
||||
func TestNewRouterV1WithAuthenticationEnabled(t *testing.T) {
|
||||
config.Config.Server.Token = "TestToken"
|
||||
router := mux.NewRouter()
|
||||
configureV1Router(router, nil, nil)
|
||||
|
||||
t.Run("health route is accessible", func(t *testing.T) {
|
||||
request, err := http.NewRequest(http.MethodGet, "/api/v1/health", nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
recorder := httptest.NewRecorder()
|
||||
router.ServeHTTP(recorder, request)
|
||||
assert.Equal(t, http.StatusNoContent, recorder.Code)
|
||||
})
|
||||
|
||||
t.Run("protected route is not accessible", func(t *testing.T) {
|
||||
// request an available API route that should be guarded by authentication.
|
||||
// (which one, in particular, does not matter here)
|
||||
request, err := http.NewRequest(http.MethodPost, "/api/v1/runners", nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
recorder := httptest.NewRecorder()
|
||||
router.ServeHTTP(recorder, request)
|
||||
assert.Equal(t, http.StatusUnauthorized, recorder.Code)
|
||||
})
|
||||
config.Config.Server.Token = ""
|
||||
}
|
36
internal/api/auth/auth.go
Normal file
36
internal/api/auth/auth.go
Normal file
@@ -0,0 +1,36 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"crypto/subtle"
|
||||
"gitlab.hpi.de/codeocean/codemoon/poseidon/internal/config"
|
||||
"gitlab.hpi.de/codeocean/codemoon/poseidon/pkg/logging"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
var log = logging.GetLogger("api/auth")
|
||||
|
||||
const TokenHeader = "X-Poseidon-Token"
|
||||
|
||||
var correctAuthenticationToken []byte
|
||||
|
||||
// InitializeAuthentication returns true iff the authentication is initialized successfully and can be used.
|
||||
func InitializeAuthentication() bool {
|
||||
token := config.Config.Server.Token
|
||||
if token == "" {
|
||||
return false
|
||||
}
|
||||
correctAuthenticationToken = []byte(token)
|
||||
return true
|
||||
}
|
||||
|
||||
func HTTPAuthenticationMiddleware(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
token := r.Header.Get(TokenHeader)
|
||||
if subtle.ConstantTimeCompare([]byte(token), correctAuthenticationToken) == 0 {
|
||||
log.WithField("token", token).Warn("Incorrect token")
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
91
internal/api/auth/auth_test.go
Normal file
91
internal/api/auth/auth_test.go
Normal file
@@ -0,0 +1,91 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/sirupsen/logrus/hooks/test"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/suite"
|
||||
"gitlab.hpi.de/codeocean/codemoon/poseidon/internal/config"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
)
|
||||
|
||||
const testToken = "C0rr3ctT0k3n"
|
||||
|
||||
type AuthenticationMiddlewareTestSuite struct {
|
||||
suite.Suite
|
||||
request *http.Request
|
||||
recorder *httptest.ResponseRecorder
|
||||
httpAuthenticationMiddleware http.Handler
|
||||
}
|
||||
|
||||
func (s *AuthenticationMiddlewareTestSuite) SetupTest() {
|
||||
correctAuthenticationToken = []byte(testToken)
|
||||
s.recorder = httptest.NewRecorder()
|
||||
request, err := http.NewRequest(http.MethodGet, "/api/v1/test", nil)
|
||||
if err != nil {
|
||||
s.T().Fatal(err)
|
||||
}
|
||||
s.request = request
|
||||
s.httpAuthenticationMiddleware = HTTPAuthenticationMiddleware(
|
||||
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
}
|
||||
|
||||
func (s *AuthenticationMiddlewareTestSuite) TearDownTest() {
|
||||
correctAuthenticationToken = []byte(nil)
|
||||
}
|
||||
|
||||
func (s *AuthenticationMiddlewareTestSuite) TestReturns401WhenHeaderUnset() {
|
||||
s.httpAuthenticationMiddleware.ServeHTTP(s.recorder, s.request)
|
||||
assert.Equal(s.T(), http.StatusUnauthorized, s.recorder.Code)
|
||||
}
|
||||
|
||||
func (s *AuthenticationMiddlewareTestSuite) TestReturns401WhenTokenWrong() {
|
||||
s.request.Header.Set(TokenHeader, "Wr0ngT0k3n")
|
||||
s.httpAuthenticationMiddleware.ServeHTTP(s.recorder, s.request)
|
||||
assert.Equal(s.T(), http.StatusUnauthorized, s.recorder.Code)
|
||||
}
|
||||
|
||||
func (s *AuthenticationMiddlewareTestSuite) TestWarnsWhenUnauthorized() {
|
||||
var hook *test.Hook
|
||||
logger, hook := test.NewNullLogger()
|
||||
log = logger.WithField("pkg", "api/auth")
|
||||
|
||||
s.request.Header.Set(TokenHeader, "Wr0ngT0k3n")
|
||||
s.httpAuthenticationMiddleware.ServeHTTP(s.recorder, s.request)
|
||||
|
||||
assert.Equal(s.T(), http.StatusUnauthorized, s.recorder.Code)
|
||||
assert.Equal(s.T(), logrus.WarnLevel, hook.LastEntry().Level)
|
||||
assert.Equal(s.T(), hook.LastEntry().Data["token"], "Wr0ngT0k3n")
|
||||
}
|
||||
|
||||
func (s *AuthenticationMiddlewareTestSuite) TestPassesWhenTokenCorrect() {
|
||||
s.request.Header.Set(TokenHeader, testToken)
|
||||
s.httpAuthenticationMiddleware.ServeHTTP(s.recorder, s.request)
|
||||
|
||||
assert.Equal(s.T(), http.StatusOK, s.recorder.Code)
|
||||
}
|
||||
|
||||
func TestHTTPAuthenticationMiddleware(t *testing.T) {
|
||||
suite.Run(t, new(AuthenticationMiddlewareTestSuite))
|
||||
}
|
||||
|
||||
func TestInitializeAuthentication(t *testing.T) {
|
||||
t.Run("if token unset", func(t *testing.T) {
|
||||
config.Config.Server.Token = ""
|
||||
initialized := InitializeAuthentication()
|
||||
assert.Equal(t, false, initialized)
|
||||
assert.Equal(t, []byte(nil), correctAuthenticationToken, "it should not set correctAuthenticationToken")
|
||||
})
|
||||
t.Run("if token set", func(t *testing.T) {
|
||||
config.Config.Server.Token = testToken
|
||||
initialized := InitializeAuthentication()
|
||||
assert.Equal(t, true, initialized)
|
||||
assert.Equal(t, []byte(testToken), correctAuthenticationToken, "it should set correctAuthenticationToken")
|
||||
config.Config.Server.Token = ""
|
||||
correctAuthenticationToken = []byte(nil)
|
||||
})
|
||||
}
|
59
internal/api/environments.go
Normal file
59
internal/api/environments.go
Normal file
@@ -0,0 +1,59 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/gorilla/mux"
|
||||
"gitlab.hpi.de/codeocean/codemoon/poseidon/internal/environment"
|
||||
"gitlab.hpi.de/codeocean/codemoon/poseidon/internal/runner"
|
||||
"gitlab.hpi.de/codeocean/codemoon/poseidon/pkg/dto"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
const (
|
||||
executionEnvironmentIDKey = "executionEnvironmentId"
|
||||
createOrUpdateRouteName = "createOrUpdate"
|
||||
)
|
||||
|
||||
var ErrMissingURLParameter = errors.New("url parameter missing")
|
||||
|
||||
type EnvironmentController struct {
|
||||
manager environment.Manager
|
||||
}
|
||||
|
||||
func (e *EnvironmentController) ConfigureRoutes(router *mux.Router) {
|
||||
environmentRouter := router.PathPrefix(EnvironmentsPath).Subrouter()
|
||||
specificEnvironmentRouter := environmentRouter.Path(fmt.Sprintf("/{%s:[0-9]+}", executionEnvironmentIDKey)).Subrouter()
|
||||
specificEnvironmentRouter.HandleFunc("", e.createOrUpdate).Methods(http.MethodPut).Name(createOrUpdateRouteName)
|
||||
}
|
||||
|
||||
// createOrUpdate creates/updates an execution environment on the executor.
|
||||
func (e *EnvironmentController) createOrUpdate(writer http.ResponseWriter, request *http.Request) {
|
||||
req := new(dto.ExecutionEnvironmentRequest)
|
||||
if err := json.NewDecoder(request.Body).Decode(req); err != nil {
|
||||
writeBadRequest(writer, err)
|
||||
return
|
||||
}
|
||||
|
||||
id, ok := mux.Vars(request)[executionEnvironmentIDKey]
|
||||
if !ok {
|
||||
writeBadRequest(writer, fmt.Errorf("could not find %s: %w", executionEnvironmentIDKey, ErrMissingURLParameter))
|
||||
return
|
||||
}
|
||||
environmentID, err := runner.NewEnvironmentID(id)
|
||||
if err != nil {
|
||||
writeBadRequest(writer, fmt.Errorf("could not update environment: %w", err))
|
||||
return
|
||||
}
|
||||
created, err := e.manager.CreateOrUpdate(environmentID, *req)
|
||||
if err != nil {
|
||||
writeInternalServerError(writer, err, dto.ErrorUnknown)
|
||||
}
|
||||
|
||||
if created {
|
||||
writer.WriteHeader(http.StatusCreated)
|
||||
} else {
|
||||
writer.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
}
|
117
internal/api/environments_test.go
Normal file
117
internal/api/environments_test.go
Normal file
@@ -0,0 +1,117 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/stretchr/testify/mock"
|
||||
"github.com/stretchr/testify/suite"
|
||||
"gitlab.hpi.de/codeocean/codemoon/poseidon/internal/environment"
|
||||
"gitlab.hpi.de/codeocean/codemoon/poseidon/internal/runner"
|
||||
"gitlab.hpi.de/codeocean/codemoon/poseidon/pkg/dto"
|
||||
"gitlab.hpi.de/codeocean/codemoon/poseidon/tests"
|
||||
"math"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strconv"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
type EnvironmentControllerTestSuite struct {
|
||||
suite.Suite
|
||||
manager *environment.ManagerMock
|
||||
router *mux.Router
|
||||
}
|
||||
|
||||
func TestEnvironmentControllerTestSuite(t *testing.T) {
|
||||
suite.Run(t, new(EnvironmentControllerTestSuite))
|
||||
}
|
||||
|
||||
func (s *EnvironmentControllerTestSuite) SetupTest() {
|
||||
s.manager = &environment.ManagerMock{}
|
||||
s.router = NewRouter(nil, s.manager)
|
||||
}
|
||||
|
||||
type CreateOrUpdateEnvironmentTestSuite struct {
|
||||
EnvironmentControllerTestSuite
|
||||
path string
|
||||
id runner.EnvironmentID
|
||||
body []byte
|
||||
}
|
||||
|
||||
func TestCreateOrUpdateEnvironmentTestSuite(t *testing.T) {
|
||||
suite.Run(t, new(CreateOrUpdateEnvironmentTestSuite))
|
||||
}
|
||||
|
||||
func (s *CreateOrUpdateEnvironmentTestSuite) SetupTest() {
|
||||
s.EnvironmentControllerTestSuite.SetupTest()
|
||||
s.id = tests.DefaultEnvironmentIDAsInteger
|
||||
testURL, err := s.router.Get(createOrUpdateRouteName).URL(executionEnvironmentIDKey, strconv.Itoa(int(s.id)))
|
||||
if err != nil {
|
||||
s.T().Fatal(err)
|
||||
}
|
||||
s.path = testURL.String()
|
||||
s.body, err = json.Marshal(dto.ExecutionEnvironmentRequest{})
|
||||
if err != nil {
|
||||
s.T().Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *CreateOrUpdateEnvironmentTestSuite) recordRequest() *httptest.ResponseRecorder {
|
||||
recorder := httptest.NewRecorder()
|
||||
request, err := http.NewRequest(http.MethodPut, s.path, bytes.NewReader(s.body))
|
||||
if err != nil {
|
||||
s.T().Fatal(err)
|
||||
}
|
||||
s.router.ServeHTTP(recorder, request)
|
||||
return recorder
|
||||
}
|
||||
|
||||
func (s *CreateOrUpdateEnvironmentTestSuite) TestReturnsBadRequestWhenBadBody() {
|
||||
s.body = []byte{}
|
||||
recorder := s.recordRequest()
|
||||
s.Equal(http.StatusBadRequest, recorder.Code)
|
||||
}
|
||||
|
||||
func (s *CreateOrUpdateEnvironmentTestSuite) TestReturnsInternalServerErrorWhenManagerReturnsError() {
|
||||
testError := tests.ErrDefault
|
||||
s.manager.
|
||||
On("CreateOrUpdate", s.id, mock.AnythingOfType("dto.ExecutionEnvironmentRequest")).
|
||||
Return(false, testError)
|
||||
|
||||
recorder := s.recordRequest()
|
||||
s.Equal(http.StatusInternalServerError, recorder.Code)
|
||||
s.Contains(recorder.Body.String(), testError.Error())
|
||||
}
|
||||
|
||||
func (s *CreateOrUpdateEnvironmentTestSuite) TestReturnsCreatedIfNewEnvironment() {
|
||||
s.manager.
|
||||
On("CreateOrUpdate", s.id, mock.AnythingOfType("dto.ExecutionEnvironmentRequest")).
|
||||
Return(true, nil)
|
||||
|
||||
recorder := s.recordRequest()
|
||||
s.Equal(http.StatusCreated, recorder.Code)
|
||||
}
|
||||
|
||||
func (s *CreateOrUpdateEnvironmentTestSuite) TestReturnsNoContentIfNotNewEnvironment() {
|
||||
s.manager.
|
||||
On("CreateOrUpdate", s.id, mock.AnythingOfType("dto.ExecutionEnvironmentRequest")).
|
||||
Return(false, nil)
|
||||
|
||||
recorder := s.recordRequest()
|
||||
s.Equal(http.StatusNoContent, recorder.Code)
|
||||
}
|
||||
|
||||
func (s *CreateOrUpdateEnvironmentTestSuite) TestReturnsNotFoundOnNonIntegerID() {
|
||||
s.path = strings.Join([]string{BasePath, EnvironmentsPath, "/", "invalid-id"}, "")
|
||||
recorder := s.recordRequest()
|
||||
s.Equal(http.StatusNotFound, recorder.Code)
|
||||
}
|
||||
|
||||
func (s *CreateOrUpdateEnvironmentTestSuite) TestFailsOnTooLargeID() {
|
||||
tooLargeIntStr := strconv.Itoa(math.MaxInt64) + "0"
|
||||
s.path = strings.Join([]string{BasePath, EnvironmentsPath, "/", tooLargeIntStr}, "")
|
||||
recorder := s.recordRequest()
|
||||
s.Equal(http.StatusBadRequest, recorder.Code)
|
||||
}
|
12
internal/api/health.go
Normal file
12
internal/api/health.go
Normal file
@@ -0,0 +1,12 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
)
|
||||
|
||||
// Health handles the health route.
|
||||
// It responds that the server is alive.
|
||||
// If it is not, the response won't reach the client.
|
||||
func Health(writer http.ResponseWriter, _ *http.Request) {
|
||||
writer.WriteHeader(http.StatusNoContent)
|
||||
}
|
18
internal/api/health_test.go
Normal file
18
internal/api/health_test.go
Normal file
@@ -0,0 +1,18 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestHealthRoute(t *testing.T) {
|
||||
request, err := http.NewRequest(http.MethodGet, "/health", nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
recorder := httptest.NewRecorder()
|
||||
http.HandlerFunc(Health).ServeHTTP(recorder, request)
|
||||
assert.Equal(t, http.StatusNoContent, recorder.Code)
|
||||
}
|
43
internal/api/helpers.go
Normal file
43
internal/api/helpers.go
Normal file
@@ -0,0 +1,43 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"gitlab.hpi.de/codeocean/codemoon/poseidon/pkg/dto"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
func writeInternalServerError(writer http.ResponseWriter, err error, errorCode dto.ErrorCode) {
|
||||
sendJSON(writer, &dto.InternalServerError{Message: err.Error(), ErrorCode: errorCode}, http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
func writeBadRequest(writer http.ResponseWriter, err error) {
|
||||
sendJSON(writer, &dto.ClientError{Message: err.Error()}, http.StatusBadRequest)
|
||||
}
|
||||
|
||||
func writeNotFound(writer http.ResponseWriter, err error) {
|
||||
sendJSON(writer, &dto.ClientError{Message: err.Error()}, http.StatusNotFound)
|
||||
}
|
||||
|
||||
func sendJSON(writer http.ResponseWriter, content interface{}, httpStatusCode int) {
|
||||
writer.Header().Set("Content-Type", "application/json")
|
||||
writer.WriteHeader(httpStatusCode)
|
||||
response, err := json.Marshal(content)
|
||||
if err != nil {
|
||||
// cannot produce infinite recursive loop, since json.Marshal of dto.InternalServerError won't return an error
|
||||
writeInternalServerError(writer, err, dto.ErrorUnknown)
|
||||
return
|
||||
}
|
||||
if _, err = writer.Write(response); err != nil {
|
||||
log.WithError(err).Error("Could not write JSON response")
|
||||
http.Error(writer, err.Error(), http.StatusInternalServerError)
|
||||
}
|
||||
}
|
||||
|
||||
func parseJSONRequestBody(writer http.ResponseWriter, request *http.Request, structure interface{}) error {
|
||||
if err := json.NewDecoder(request.Body).Decode(structure); err != nil {
|
||||
writeBadRequest(writer, err)
|
||||
return fmt.Errorf("error parsing JSON request body: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
159
internal/api/runners.go
Normal file
159
internal/api/runners.go
Normal file
@@ -0,0 +1,159 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/google/uuid"
|
||||
"github.com/gorilla/mux"
|
||||
"gitlab.hpi.de/codeocean/codemoon/poseidon/internal/config"
|
||||
"gitlab.hpi.de/codeocean/codemoon/poseidon/internal/runner"
|
||||
"gitlab.hpi.de/codeocean/codemoon/poseidon/pkg/dto"
|
||||
"net/http"
|
||||
"net/url"
|
||||
)
|
||||
|
||||
const (
|
||||
ExecutePath = "/execute"
|
||||
WebsocketPath = "/websocket"
|
||||
UpdateFileSystemPath = "/files"
|
||||
DeleteRoute = "deleteRunner"
|
||||
RunnerIDKey = "runnerId"
|
||||
ExecutionIDKey = "executionID"
|
||||
ProvideRoute = "provideRunner"
|
||||
)
|
||||
|
||||
type RunnerController struct {
|
||||
manager runner.Manager
|
||||
runnerRouter *mux.Router
|
||||
}
|
||||
|
||||
// ConfigureRoutes configures a given router with the runner routes of our API.
|
||||
func (r *RunnerController) ConfigureRoutes(router *mux.Router) {
|
||||
runnersRouter := router.PathPrefix(RunnersPath).Subrouter()
|
||||
runnersRouter.HandleFunc("", r.provide).Methods(http.MethodPost).Name(ProvideRoute)
|
||||
r.runnerRouter = runnersRouter.PathPrefix(fmt.Sprintf("/{%s}", RunnerIDKey)).Subrouter()
|
||||
r.runnerRouter.Use(r.findRunnerMiddleware)
|
||||
r.runnerRouter.HandleFunc(UpdateFileSystemPath, r.updateFileSystem).Methods(http.MethodPatch).
|
||||
Name(UpdateFileSystemPath)
|
||||
r.runnerRouter.HandleFunc(ExecutePath, r.execute).Methods(http.MethodPost).Name(ExecutePath)
|
||||
r.runnerRouter.HandleFunc(WebsocketPath, r.connectToRunner).Methods(http.MethodGet).Name(WebsocketPath)
|
||||
r.runnerRouter.HandleFunc("", r.delete).Methods(http.MethodDelete).Name(DeleteRoute)
|
||||
}
|
||||
|
||||
// provide handles the provide runners API route.
|
||||
// It tries to respond with the id of a unused runner.
|
||||
// This runner is then reserved for future use.
|
||||
func (r *RunnerController) provide(writer http.ResponseWriter, request *http.Request) {
|
||||
runnerRequest := new(dto.RunnerRequest)
|
||||
if err := parseJSONRequestBody(writer, request, runnerRequest); err != nil {
|
||||
return
|
||||
}
|
||||
environmentID := runner.EnvironmentID(runnerRequest.ExecutionEnvironmentID)
|
||||
nextRunner, err := r.manager.Claim(environmentID, runnerRequest.InactivityTimeout)
|
||||
if err != nil {
|
||||
switch err {
|
||||
case runner.ErrUnknownExecutionEnvironment:
|
||||
writeNotFound(writer, err)
|
||||
case runner.ErrNoRunnersAvailable:
|
||||
log.WithField("environment", environmentID).Warn("No runners available")
|
||||
writeInternalServerError(writer, err, dto.ErrorNomadOverload)
|
||||
default:
|
||||
writeInternalServerError(writer, err, dto.ErrorUnknown)
|
||||
}
|
||||
return
|
||||
}
|
||||
sendJSON(writer, &dto.RunnerResponse{ID: nextRunner.ID(), MappedPorts: nextRunner.MappedPorts()}, http.StatusOK)
|
||||
}
|
||||
|
||||
// updateFileSystem handles the files API route.
|
||||
// It takes an dto.UpdateFileSystemRequest and sends it to the runner for processing.
|
||||
func (r *RunnerController) updateFileSystem(writer http.ResponseWriter, request *http.Request) {
|
||||
fileCopyRequest := new(dto.UpdateFileSystemRequest)
|
||||
if err := parseJSONRequestBody(writer, request, fileCopyRequest); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
targetRunner, _ := runner.FromContext(request.Context())
|
||||
if err := targetRunner.UpdateFileSystem(fileCopyRequest); err != nil {
|
||||
log.WithError(err).Error("Could not perform the requested updateFileSystem.")
|
||||
writeInternalServerError(writer, err, dto.ErrorUnknown)
|
||||
return
|
||||
}
|
||||
|
||||
writer.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
|
||||
// execute handles the execute API route.
|
||||
// It takes an ExecutionRequest and stores it for a runner.
|
||||
// It returns a url to connect to for a websocket connection to this execution in the corresponding runner.
|
||||
func (r *RunnerController) execute(writer http.ResponseWriter, request *http.Request) {
|
||||
executionRequest := new(dto.ExecutionRequest)
|
||||
if err := parseJSONRequestBody(writer, request, executionRequest); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
var scheme string
|
||||
if config.Config.Server.TLS {
|
||||
scheme = "wss"
|
||||
} else {
|
||||
scheme = "ws"
|
||||
}
|
||||
targetRunner, _ := runner.FromContext(request.Context())
|
||||
|
||||
path, err := r.runnerRouter.Get(WebsocketPath).URL(RunnerIDKey, targetRunner.ID())
|
||||
if err != nil {
|
||||
log.WithError(err).Error("Could not create runner websocket URL.")
|
||||
writeInternalServerError(writer, err, dto.ErrorUnknown)
|
||||
return
|
||||
}
|
||||
newUUID, err := uuid.NewRandom()
|
||||
if err != nil {
|
||||
log.WithError(err).Error("Could not create execution id")
|
||||
writeInternalServerError(writer, err, dto.ErrorUnknown)
|
||||
return
|
||||
}
|
||||
id := runner.ExecutionID(newUUID.String())
|
||||
targetRunner.Add(id, executionRequest)
|
||||
webSocketURL := url.URL{
|
||||
Scheme: scheme,
|
||||
Host: request.Host,
|
||||
Path: path.String(),
|
||||
RawQuery: fmt.Sprintf("%s=%s", ExecutionIDKey, id),
|
||||
}
|
||||
|
||||
sendJSON(writer, &dto.ExecutionResponse{WebSocketURL: webSocketURL.String()}, http.StatusOK)
|
||||
}
|
||||
|
||||
// The findRunnerMiddleware looks up the runnerId for routes containing it
|
||||
// and adds the runner to the context of the request.
|
||||
func (r *RunnerController) findRunnerMiddleware(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) {
|
||||
runnerID := mux.Vars(request)[RunnerIDKey]
|
||||
targetRunner, err := r.manager.Get(runnerID)
|
||||
if err != nil {
|
||||
writeNotFound(writer, err)
|
||||
return
|
||||
}
|
||||
ctx := runner.NewContext(request.Context(), targetRunner)
|
||||
requestWithRunner := request.WithContext(ctx)
|
||||
next.ServeHTTP(writer, requestWithRunner)
|
||||
})
|
||||
}
|
||||
|
||||
// delete handles the delete runner API route.
|
||||
// It destroys the given runner on the executor and removes it from the used runners list.
|
||||
func (r *RunnerController) delete(writer http.ResponseWriter, request *http.Request) {
|
||||
targetRunner, _ := runner.FromContext(request.Context())
|
||||
|
||||
err := r.manager.Return(targetRunner)
|
||||
if err != nil {
|
||||
if errors.Is(err, runner.ErrUnknownExecutionEnvironment) {
|
||||
writeNotFound(writer, err)
|
||||
} else {
|
||||
writeInternalServerError(writer, err, dto.ErrorNomadInternalServerError)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
writer.WriteHeader(http.StatusNoContent)
|
||||
}
|
351
internal/api/runners_test.go
Normal file
351
internal/api/runners_test.go
Normal file
@@ -0,0 +1,351 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/stretchr/testify/mock"
|
||||
"github.com/stretchr/testify/suite"
|
||||
"gitlab.hpi.de/codeocean/codemoon/poseidon/internal/runner"
|
||||
"gitlab.hpi.de/codeocean/codemoon/poseidon/pkg/dto"
|
||||
"gitlab.hpi.de/codeocean/codemoon/poseidon/tests"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
type MiddlewareTestSuite struct {
|
||||
suite.Suite
|
||||
manager *runner.ManagerMock
|
||||
router *mux.Router
|
||||
runner runner.Runner
|
||||
capturedRunner runner.Runner
|
||||
runnerRequest func(string) *http.Request
|
||||
}
|
||||
|
||||
func (s *MiddlewareTestSuite) SetupTest() {
|
||||
s.manager = &runner.ManagerMock{}
|
||||
s.runner = runner.NewNomadJob(tests.DefaultRunnerID, nil, nil, nil)
|
||||
s.capturedRunner = nil
|
||||
s.runnerRequest = func(runnerId string) *http.Request {
|
||||
path, err := s.router.Get("test-runner-id").URL(RunnerIDKey, runnerId)
|
||||
s.Require().NoError(err)
|
||||
request, err := http.NewRequest(http.MethodPost, path.String(), nil)
|
||||
s.Require().NoError(err)
|
||||
return request
|
||||
}
|
||||
runnerRouteHandler := func(writer http.ResponseWriter, request *http.Request) {
|
||||
var ok bool
|
||||
s.capturedRunner, ok = runner.FromContext(request.Context())
|
||||
if ok {
|
||||
writer.WriteHeader(http.StatusOK)
|
||||
} else {
|
||||
writer.WriteHeader(http.StatusInternalServerError)
|
||||
}
|
||||
}
|
||||
s.router = mux.NewRouter()
|
||||
runnerController := &RunnerController{s.manager, s.router}
|
||||
s.router.Use(runnerController.findRunnerMiddleware)
|
||||
s.router.HandleFunc(fmt.Sprintf("/test/{%s}", RunnerIDKey), runnerRouteHandler).Name("test-runner-id")
|
||||
}
|
||||
|
||||
func TestMiddlewareTestSuite(t *testing.T) {
|
||||
suite.Run(t, new(MiddlewareTestSuite))
|
||||
}
|
||||
|
||||
func (s *MiddlewareTestSuite) TestFindRunnerMiddlewareIfRunnerExists() {
|
||||
s.manager.On("Get", s.runner.ID()).Return(s.runner, nil)
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
s.router.ServeHTTP(recorder, s.runnerRequest(s.runner.ID()))
|
||||
|
||||
s.Equal(http.StatusOK, recorder.Code)
|
||||
s.Equal(s.runner, s.capturedRunner)
|
||||
}
|
||||
|
||||
func (s *MiddlewareTestSuite) TestFindRunnerMiddlewareIfRunnerDoesNotExist() {
|
||||
invalidID := "some-invalid-runner-id"
|
||||
s.manager.On("Get", invalidID).Return(nil, runner.ErrRunnerNotFound)
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
s.router.ServeHTTP(recorder, s.runnerRequest(invalidID))
|
||||
|
||||
s.Equal(http.StatusNotFound, recorder.Code)
|
||||
}
|
||||
|
||||
func TestRunnerRouteTestSuite(t *testing.T) {
|
||||
suite.Run(t, new(RunnerRouteTestSuite))
|
||||
}
|
||||
|
||||
type RunnerRouteTestSuite struct {
|
||||
suite.Suite
|
||||
runnerManager *runner.ManagerMock
|
||||
router *mux.Router
|
||||
runner runner.Runner
|
||||
executionID runner.ExecutionID
|
||||
}
|
||||
|
||||
func (s *RunnerRouteTestSuite) SetupTest() {
|
||||
s.runnerManager = &runner.ManagerMock{}
|
||||
s.router = NewRouter(s.runnerManager, nil)
|
||||
s.runner = runner.NewNomadJob("some-id", nil, nil, nil)
|
||||
s.executionID = "execution-id"
|
||||
s.runner.Add(s.executionID, &dto.ExecutionRequest{})
|
||||
s.runnerManager.On("Get", s.runner.ID()).Return(s.runner, nil)
|
||||
}
|
||||
|
||||
func TestProvideRunnerTestSuite(t *testing.T) {
|
||||
suite.Run(t, new(ProvideRunnerTestSuite))
|
||||
}
|
||||
|
||||
type ProvideRunnerTestSuite struct {
|
||||
RunnerRouteTestSuite
|
||||
defaultRequest *http.Request
|
||||
path string
|
||||
}
|
||||
|
||||
func (s *ProvideRunnerTestSuite) SetupTest() {
|
||||
s.RunnerRouteTestSuite.SetupTest()
|
||||
|
||||
path, err := s.router.Get(ProvideRoute).URL()
|
||||
s.Require().NoError(err)
|
||||
s.path = path.String()
|
||||
|
||||
runnerRequest := dto.RunnerRequest{ExecutionEnvironmentID: tests.DefaultEnvironmentIDAsInteger}
|
||||
body, err := json.Marshal(runnerRequest)
|
||||
s.Require().NoError(err)
|
||||
s.defaultRequest, err = http.NewRequest(http.MethodPost, s.path, bytes.NewReader(body))
|
||||
s.Require().NoError(err)
|
||||
}
|
||||
|
||||
func (s *ProvideRunnerTestSuite) TestValidRequestReturnsRunner() {
|
||||
s.runnerManager.On("Claim", mock.AnythingOfType("runner.EnvironmentID"),
|
||||
mock.AnythingOfType("int")).Return(s.runner, nil)
|
||||
recorder := httptest.NewRecorder()
|
||||
|
||||
s.router.ServeHTTP(recorder, s.defaultRequest)
|
||||
s.Equal(http.StatusOK, recorder.Code)
|
||||
|
||||
s.Run("response contains runnerId", func() {
|
||||
var runnerResponse dto.RunnerResponse
|
||||
err := json.NewDecoder(recorder.Result().Body).Decode(&runnerResponse)
|
||||
s.Require().NoError(err)
|
||||
_ = recorder.Result().Body.Close()
|
||||
s.Equal(s.runner.ID(), runnerResponse.ID)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *ProvideRunnerTestSuite) TestInvalidRequestReturnsBadRequest() {
|
||||
badRequest, err := http.NewRequest(http.MethodPost, s.path, strings.NewReader(""))
|
||||
s.Require().NoError(err)
|
||||
recorder := httptest.NewRecorder()
|
||||
|
||||
s.router.ServeHTTP(recorder, badRequest)
|
||||
s.Equal(http.StatusBadRequest, recorder.Code)
|
||||
}
|
||||
|
||||
func (s *ProvideRunnerTestSuite) TestWhenExecutionEnvironmentDoesNotExistReturnsNotFound() {
|
||||
s.runnerManager.
|
||||
On("Claim", mock.AnythingOfType("runner.EnvironmentID"), mock.AnythingOfType("int")).
|
||||
Return(nil, runner.ErrUnknownExecutionEnvironment)
|
||||
recorder := httptest.NewRecorder()
|
||||
|
||||
s.router.ServeHTTP(recorder, s.defaultRequest)
|
||||
s.Equal(http.StatusNotFound, recorder.Code)
|
||||
}
|
||||
|
||||
func (s *ProvideRunnerTestSuite) TestWhenNoRunnerAvailableReturnsNomadOverload() {
|
||||
s.runnerManager.On("Claim", mock.AnythingOfType("runner.EnvironmentID"), mock.AnythingOfType("int")).
|
||||
Return(nil, runner.ErrNoRunnersAvailable)
|
||||
recorder := httptest.NewRecorder()
|
||||
|
||||
s.router.ServeHTTP(recorder, s.defaultRequest)
|
||||
s.Equal(http.StatusInternalServerError, recorder.Code)
|
||||
|
||||
var internalServerError dto.InternalServerError
|
||||
err := json.NewDecoder(recorder.Result().Body).Decode(&internalServerError)
|
||||
s.Require().NoError(err)
|
||||
_ = recorder.Result().Body.Close()
|
||||
s.Equal(dto.ErrorNomadOverload, internalServerError.ErrorCode)
|
||||
}
|
||||
|
||||
func (s *RunnerRouteTestSuite) TestExecuteRoute() {
|
||||
path, err := s.router.Get(ExecutePath).URL(RunnerIDKey, s.runner.ID())
|
||||
s.Require().NoError(err)
|
||||
|
||||
s.Run("valid request", func() {
|
||||
recorder := httptest.NewRecorder()
|
||||
executionRequest := dto.ExecutionRequest{
|
||||
Command: "command",
|
||||
TimeLimit: 10,
|
||||
Environment: nil,
|
||||
}
|
||||
body, err := json.Marshal(executionRequest)
|
||||
s.Require().NoError(err)
|
||||
request, err := http.NewRequest(http.MethodPost, path.String(), bytes.NewReader(body))
|
||||
s.Require().NoError(err)
|
||||
|
||||
s.router.ServeHTTP(recorder, request)
|
||||
|
||||
var webSocketResponse dto.ExecutionResponse
|
||||
err = json.NewDecoder(recorder.Result().Body).Decode(&webSocketResponse)
|
||||
s.Require().NoError(err)
|
||||
|
||||
s.Equal(http.StatusOK, recorder.Code)
|
||||
|
||||
s.Run("creates an execution request for the runner", func() {
|
||||
webSocketURL, err := url.Parse(webSocketResponse.WebSocketURL)
|
||||
s.Require().NoError(err)
|
||||
executionID := webSocketURL.Query().Get(ExecutionIDKey)
|
||||
storedExecutionRequest, ok := s.runner.Pop(runner.ExecutionID(executionID))
|
||||
|
||||
s.True(ok, "No execution request with this id: ", executionID)
|
||||
s.Equal(&executionRequest, storedExecutionRequest)
|
||||
})
|
||||
})
|
||||
|
||||
s.Run("invalid request", func() {
|
||||
recorder := httptest.NewRecorder()
|
||||
body := ""
|
||||
request, err := http.NewRequest(http.MethodPost, path.String(), strings.NewReader(body))
|
||||
s.Require().NoError(err)
|
||||
s.router.ServeHTTP(recorder, request)
|
||||
|
||||
s.Equal(http.StatusBadRequest, recorder.Code)
|
||||
})
|
||||
}
|
||||
|
||||
func TestUpdateFileSystemRouteTestSuite(t *testing.T) {
|
||||
suite.Run(t, new(UpdateFileSystemRouteTestSuite))
|
||||
}
|
||||
|
||||
type UpdateFileSystemRouteTestSuite struct {
|
||||
RunnerRouteTestSuite
|
||||
path string
|
||||
recorder *httptest.ResponseRecorder
|
||||
runnerMock *runner.RunnerMock
|
||||
}
|
||||
|
||||
func (s *UpdateFileSystemRouteTestSuite) SetupTest() {
|
||||
s.RunnerRouteTestSuite.SetupTest()
|
||||
routeURL, err := s.router.Get(UpdateFileSystemPath).URL(RunnerIDKey, tests.DefaultMockID)
|
||||
s.Require().NoError(err)
|
||||
s.path = routeURL.String()
|
||||
s.runnerMock = &runner.RunnerMock{}
|
||||
s.runnerManager.On("Get", tests.DefaultMockID).Return(s.runnerMock, nil)
|
||||
s.recorder = httptest.NewRecorder()
|
||||
}
|
||||
|
||||
func (s *UpdateFileSystemRouteTestSuite) TestUpdateFileSystemReturnsNoContentOnValidRequest() {
|
||||
s.runnerMock.On("UpdateFileSystem", mock.AnythingOfType("*dto.UpdateFileSystemRequest")).Return(nil)
|
||||
|
||||
copyRequest := dto.UpdateFileSystemRequest{}
|
||||
body, err := json.Marshal(copyRequest)
|
||||
s.Require().NoError(err)
|
||||
request, err := http.NewRequest(http.MethodPatch, s.path, bytes.NewReader(body))
|
||||
s.Require().NoError(err)
|
||||
|
||||
s.router.ServeHTTP(s.recorder, request)
|
||||
s.Equal(http.StatusNoContent, s.recorder.Code)
|
||||
s.runnerMock.AssertCalled(s.T(), "UpdateFileSystem", mock.AnythingOfType("*dto.UpdateFileSystemRequest"))
|
||||
}
|
||||
|
||||
func (s *UpdateFileSystemRouteTestSuite) TestUpdateFileSystemReturnsBadRequestOnInvalidRequestBody() {
|
||||
request, err := http.NewRequest(http.MethodPatch, s.path, strings.NewReader(""))
|
||||
s.Require().NoError(err)
|
||||
|
||||
s.router.ServeHTTP(s.recorder, request)
|
||||
s.Equal(http.StatusBadRequest, s.recorder.Code)
|
||||
}
|
||||
|
||||
func (s *UpdateFileSystemRouteTestSuite) TestUpdateFileSystemToNonExistingRunnerReturnsNotFound() {
|
||||
invalidID := "some-invalid-runner-id"
|
||||
s.runnerManager.On("Get", invalidID).Return(nil, runner.ErrRunnerNotFound)
|
||||
path, err := s.router.Get(UpdateFileSystemPath).URL(RunnerIDKey, invalidID)
|
||||
s.Require().NoError(err)
|
||||
copyRequest := dto.UpdateFileSystemRequest{}
|
||||
body, err := json.Marshal(copyRequest)
|
||||
s.Require().NoError(err)
|
||||
request, err := http.NewRequest(http.MethodPatch, path.String(), bytes.NewReader(body))
|
||||
s.Require().NoError(err)
|
||||
|
||||
s.router.ServeHTTP(s.recorder, request)
|
||||
s.Equal(http.StatusNotFound, s.recorder.Code)
|
||||
}
|
||||
|
||||
func (s *UpdateFileSystemRouteTestSuite) TestUpdateFileSystemReturnsInternalServerErrorWhenCopyFailed() {
|
||||
s.runnerMock.
|
||||
On("UpdateFileSystem", mock.AnythingOfType("*dto.UpdateFileSystemRequest")).
|
||||
Return(runner.ErrorFileCopyFailed)
|
||||
|
||||
copyRequest := dto.UpdateFileSystemRequest{}
|
||||
body, err := json.Marshal(copyRequest)
|
||||
s.Require().NoError(err)
|
||||
request, err := http.NewRequest(http.MethodPatch, s.path, bytes.NewReader(body))
|
||||
s.Require().NoError(err)
|
||||
|
||||
s.router.ServeHTTP(s.recorder, request)
|
||||
s.Equal(http.StatusInternalServerError, s.recorder.Code)
|
||||
}
|
||||
|
||||
func TestDeleteRunnerRouteTestSuite(t *testing.T) {
|
||||
suite.Run(t, new(DeleteRunnerRouteTestSuite))
|
||||
}
|
||||
|
||||
type DeleteRunnerRouteTestSuite struct {
|
||||
RunnerRouteTestSuite
|
||||
path string
|
||||
}
|
||||
|
||||
func (s *DeleteRunnerRouteTestSuite) SetupTest() {
|
||||
s.RunnerRouteTestSuite.SetupTest()
|
||||
deleteURL, err := s.router.Get(DeleteRoute).URL(RunnerIDKey, s.runner.ID())
|
||||
s.Require().NoError(err)
|
||||
s.path = deleteURL.String()
|
||||
}
|
||||
|
||||
func (s *DeleteRunnerRouteTestSuite) TestValidRequestReturnsNoContent() {
|
||||
s.runnerManager.On("Return", s.runner).Return(nil)
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
request, err := http.NewRequest(http.MethodDelete, s.path, nil)
|
||||
s.Require().NoError(err)
|
||||
|
||||
s.router.ServeHTTP(recorder, request)
|
||||
|
||||
s.Equal(http.StatusNoContent, recorder.Code)
|
||||
|
||||
s.Run("runner was returned to runner manager", func() {
|
||||
s.runnerManager.AssertCalled(s.T(), "Return", s.runner)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *DeleteRunnerRouteTestSuite) TestReturnInternalServerErrorWhenApiCallToNomadFailed() {
|
||||
s.runnerManager.On("Return", s.runner).Return(tests.ErrDefault)
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
request, err := http.NewRequest(http.MethodDelete, s.path, nil)
|
||||
s.Require().NoError(err)
|
||||
|
||||
s.router.ServeHTTP(recorder, request)
|
||||
|
||||
s.Equal(http.StatusInternalServerError, recorder.Code)
|
||||
}
|
||||
|
||||
func (s *DeleteRunnerRouteTestSuite) TestDeleteInvalidRunnerIdReturnsNotFound() {
|
||||
s.runnerManager.On("Get", mock.AnythingOfType("string")).Return(nil, tests.ErrDefault)
|
||||
deleteURL, err := s.router.Get(DeleteRoute).URL(RunnerIDKey, "1nv4l1dID")
|
||||
s.Require().NoError(err)
|
||||
deletePath := deleteURL.String()
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
request, err := http.NewRequest(http.MethodDelete, deletePath, nil)
|
||||
s.Require().NoError(err)
|
||||
|
||||
s.router.ServeHTTP(recorder, request)
|
||||
|
||||
s.Equal(http.StatusNotFound, recorder.Code)
|
||||
}
|
283
internal/api/websocket.go
Normal file
283
internal/api/websocket.go
Normal file
@@ -0,0 +1,283 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/gorilla/websocket"
|
||||
"gitlab.hpi.de/codeocean/codemoon/poseidon/internal/runner"
|
||||
"gitlab.hpi.de/codeocean/codemoon/poseidon/pkg/dto"
|
||||
"io"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
const CodeOceanToRawReaderBufferSize = 1024
|
||||
|
||||
var ErrUnknownExecutionID = errors.New("execution id unknown")
|
||||
|
||||
type webSocketConnection interface {
|
||||
WriteMessage(messageType int, data []byte) error
|
||||
Close() error
|
||||
NextReader() (messageType int, r io.Reader, err error)
|
||||
CloseHandler() func(code int, text string) error
|
||||
SetCloseHandler(handler func(code int, text string) error)
|
||||
}
|
||||
|
||||
type WebSocketReader interface {
|
||||
io.Reader
|
||||
startReadInputLoop() context.CancelFunc
|
||||
}
|
||||
|
||||
// codeOceanToRawReader is an io.Reader implementation that provides the content of the WebSocket connection
|
||||
// to CodeOcean. You have to start the Reader by calling readInputLoop. After that you can use the Read function.
|
||||
type codeOceanToRawReader struct {
|
||||
connection webSocketConnection
|
||||
|
||||
// A buffered channel of bytes is used to store data coming from CodeOcean via WebSocket
|
||||
// and retrieve it when Read(..) is called. Since channels are thread-safe, we use one here
|
||||
// instead of bytes.Buffer.
|
||||
buffer chan byte
|
||||
}
|
||||
|
||||
func newCodeOceanToRawReader(connection webSocketConnection) *codeOceanToRawReader {
|
||||
return &codeOceanToRawReader{
|
||||
connection: connection,
|
||||
buffer: make(chan byte, CodeOceanToRawReaderBufferSize),
|
||||
}
|
||||
}
|
||||
|
||||
// readInputLoop reads from the WebSocket connection and buffers the user's input.
|
||||
// This is necessary because input must be read for the connection to handle special messages like close and call the
|
||||
// CloseHandler.
|
||||
func (cr *codeOceanToRawReader) readInputLoop(ctx context.Context) {
|
||||
readMessage := make(chan bool)
|
||||
for {
|
||||
var messageType int
|
||||
var reader io.Reader
|
||||
var err error
|
||||
|
||||
go func() {
|
||||
messageType, reader, err = cr.connection.NextReader()
|
||||
readMessage <- true
|
||||
}()
|
||||
select {
|
||||
case <-readMessage:
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
log.WithError(err).Warn("Error reading client message")
|
||||
return
|
||||
}
|
||||
if messageType != websocket.TextMessage {
|
||||
log.WithField("messageType", messageType).Warn("Received message of wrong type")
|
||||
return
|
||||
}
|
||||
|
||||
message, err := io.ReadAll(reader)
|
||||
if err != nil {
|
||||
log.WithError(err).Warn("error while reading WebSocket message")
|
||||
return
|
||||
}
|
||||
for _, character := range message {
|
||||
select {
|
||||
case cr.buffer <- character:
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// startReadInputLoop start the read input loop asynchronously and returns a context.CancelFunc which can be used
|
||||
// to cancel the read input loop.
|
||||
func (cr *codeOceanToRawReader) startReadInputLoop() context.CancelFunc {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
go cr.readInputLoop(ctx)
|
||||
return cancel
|
||||
}
|
||||
|
||||
// Read implements the io.Reader interface.
|
||||
// It returns bytes from the buffer.
|
||||
func (cr *codeOceanToRawReader) Read(p []byte) (int, error) {
|
||||
if len(p) == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
// Ensure to not return until at least one byte has been read to avoid busy waiting.
|
||||
p[0] = <-cr.buffer
|
||||
var n int
|
||||
for n = 1; n < len(p); n++ {
|
||||
select {
|
||||
case p[n] = <-cr.buffer:
|
||||
default:
|
||||
return n, nil
|
||||
}
|
||||
}
|
||||
return n, nil
|
||||
}
|
||||
|
||||
// rawToCodeOceanWriter is an io.Writer implementation that, when written to, wraps the written data in the appropriate
|
||||
// json structure and sends it to the CodeOcean via WebSocket.
|
||||
type rawToCodeOceanWriter struct {
|
||||
proxy *webSocketProxy
|
||||
outputType dto.WebSocketMessageType
|
||||
}
|
||||
|
||||
// Write implements the io.Writer interface.
|
||||
// The passed data is forwarded to the WebSocket to CodeOcean.
|
||||
func (rc *rawToCodeOceanWriter) Write(p []byte) (int, error) {
|
||||
err := rc.proxy.sendToClient(dto.WebSocketMessage{Type: rc.outputType, Data: string(p)})
|
||||
return len(p), err
|
||||
}
|
||||
|
||||
// webSocketProxy is an encapsulation of logic for forwarding between Runners and CodeOcean.
|
||||
type webSocketProxy struct {
|
||||
userExit chan bool
|
||||
connection webSocketConnection
|
||||
Stdin WebSocketReader
|
||||
Stdout io.Writer
|
||||
Stderr io.Writer
|
||||
}
|
||||
|
||||
// upgradeConnection upgrades a connection to a websocket and returns a webSocketProxy for this connection.
|
||||
func upgradeConnection(writer http.ResponseWriter, request *http.Request) (webSocketConnection, error) {
|
||||
connUpgrader := websocket.Upgrader{}
|
||||
connection, err := connUpgrader.Upgrade(writer, request, nil)
|
||||
if err != nil {
|
||||
log.WithError(err).Warn("Connection upgrade failed")
|
||||
return nil, fmt.Errorf("error upgrading the connection: %w", err)
|
||||
}
|
||||
return connection, nil
|
||||
}
|
||||
|
||||
// newWebSocketProxy returns a initiated and started webSocketProxy.
|
||||
// As this proxy is already started, a start message is send to the client.
|
||||
func newWebSocketProxy(connection webSocketConnection) (*webSocketProxy, error) {
|
||||
stdin := newCodeOceanToRawReader(connection)
|
||||
proxy := &webSocketProxy{
|
||||
connection: connection,
|
||||
Stdin: stdin,
|
||||
userExit: make(chan bool),
|
||||
}
|
||||
proxy.Stdout = &rawToCodeOceanWriter{proxy: proxy, outputType: dto.WebSocketOutputStdout}
|
||||
proxy.Stderr = &rawToCodeOceanWriter{proxy: proxy, outputType: dto.WebSocketOutputStderr}
|
||||
|
||||
err := proxy.sendToClient(dto.WebSocketMessage{Type: dto.WebSocketMetaStart})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
closeHandler := connection.CloseHandler()
|
||||
connection.SetCloseHandler(func(code int, text string) error {
|
||||
//nolint:errcheck // The default close handler always returns nil, so the error can be safely ignored.
|
||||
_ = closeHandler(code, text)
|
||||
close(proxy.userExit)
|
||||
return nil
|
||||
})
|
||||
return proxy, nil
|
||||
}
|
||||
|
||||
// waitForExit waits for an exit of either the runner (when the command terminates) or the client closing the WebSocket
|
||||
// and handles WebSocket exit messages.
|
||||
func (wp *webSocketProxy) waitForExit(exit <-chan runner.ExitInfo, cancelExecution context.CancelFunc) {
|
||||
defer wp.close()
|
||||
cancelInputLoop := wp.Stdin.startReadInputLoop()
|
||||
var exitInfo runner.ExitInfo
|
||||
select {
|
||||
case exitInfo = <-exit:
|
||||
cancelInputLoop()
|
||||
log.Info("Execution returned")
|
||||
case <-wp.userExit:
|
||||
cancelInputLoop()
|
||||
cancelExecution()
|
||||
log.Info("Client closed the connection")
|
||||
return
|
||||
}
|
||||
|
||||
if errors.Is(exitInfo.Err, context.DeadlineExceeded) || errors.Is(exitInfo.Err, runner.ErrorRunnerInactivityTimeout) {
|
||||
err := wp.sendToClient(dto.WebSocketMessage{Type: dto.WebSocketMetaTimeout})
|
||||
if err != nil {
|
||||
log.WithError(err).Warn("Failed to send timeout message to client")
|
||||
}
|
||||
return
|
||||
} else if exitInfo.Err != nil {
|
||||
errorMessage := "Error executing the request"
|
||||
log.WithError(exitInfo.Err).Warn(errorMessage)
|
||||
err := wp.sendToClient(dto.WebSocketMessage{Type: dto.WebSocketOutputError, Data: errorMessage})
|
||||
if err != nil {
|
||||
log.WithError(err).Warn("Failed to send output error message to client")
|
||||
}
|
||||
return
|
||||
}
|
||||
log.WithField("exit_code", exitInfo.Code).Debug()
|
||||
|
||||
err := wp.sendToClient(dto.WebSocketMessage{
|
||||
Type: dto.WebSocketExit,
|
||||
ExitCode: exitInfo.Code,
|
||||
})
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func (wp *webSocketProxy) sendToClient(message dto.WebSocketMessage) error {
|
||||
encodedMessage, err := json.Marshal(message)
|
||||
if err != nil {
|
||||
log.WithField("message", message).WithError(err).Warn("Marshal error")
|
||||
wp.closeWithError("Error creating message")
|
||||
return fmt.Errorf("error marshaling WebSocket message: %w", err)
|
||||
}
|
||||
err = wp.connection.WriteMessage(websocket.TextMessage, encodedMessage)
|
||||
if err != nil {
|
||||
errorMessage := "Error writing the exit message"
|
||||
log.WithError(err).Warn(errorMessage)
|
||||
wp.closeWithError(errorMessage)
|
||||
return fmt.Errorf("error writing WebSocket message: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (wp *webSocketProxy) closeWithError(message string) {
|
||||
err := wp.connection.WriteMessage(websocket.CloseMessage,
|
||||
websocket.FormatCloseMessage(websocket.CloseInternalServerErr, message))
|
||||
if err != nil {
|
||||
log.WithError(err).Warn("Error during websocket close")
|
||||
}
|
||||
}
|
||||
|
||||
func (wp *webSocketProxy) close() {
|
||||
err := wp.connection.WriteMessage(websocket.CloseMessage,
|
||||
websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""))
|
||||
_ = wp.connection.Close()
|
||||
if err != nil {
|
||||
log.WithError(err).Warn("Error during websocket close")
|
||||
}
|
||||
}
|
||||
|
||||
// connectToRunner is the endpoint for websocket connections.
|
||||
func (r *RunnerController) connectToRunner(writer http.ResponseWriter, request *http.Request) {
|
||||
targetRunner, _ := runner.FromContext(request.Context())
|
||||
executionID := runner.ExecutionID(request.URL.Query().Get(ExecutionIDKey))
|
||||
executionRequest, ok := targetRunner.Pop(executionID)
|
||||
if !ok {
|
||||
writeNotFound(writer, ErrUnknownExecutionID)
|
||||
return
|
||||
}
|
||||
|
||||
connection, err := upgradeConnection(writer, request)
|
||||
if err != nil {
|
||||
writeInternalServerError(writer, err, dto.ErrorUnknown)
|
||||
return
|
||||
}
|
||||
proxy, err := newWebSocketProxy(connection)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
log.WithField("runnerId", targetRunner.ID()).WithField("executionID", executionID).Info("Running execution")
|
||||
exit, cancel := targetRunner.ExecuteInteractively(executionRequest, proxy.Stdin, proxy.Stdout, proxy.Stderr)
|
||||
|
||||
proxy.waitForExit(exit, cancel)
|
||||
}
|
93
internal/api/websocket_connection_mock.go
Normal file
93
internal/api/websocket_connection_mock.go
Normal file
@@ -0,0 +1,93 @@
|
||||
// Code generated by mockery v0.0.0-dev. DO NOT EDIT.
|
||||
|
||||
package api
|
||||
|
||||
import (
|
||||
io "io"
|
||||
|
||||
mock "github.com/stretchr/testify/mock"
|
||||
)
|
||||
|
||||
// webSocketConnectionMock is an autogenerated mock type for the webSocketConnection type
|
||||
type webSocketConnectionMock struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
// Close provides a mock function with given fields:
|
||||
func (_m *webSocketConnectionMock) Close() error {
|
||||
ret := _m.Called()
|
||||
|
||||
var r0 error
|
||||
if rf, ok := ret.Get(0).(func() error); ok {
|
||||
r0 = rf()
|
||||
} else {
|
||||
r0 = ret.Error(0)
|
||||
}
|
||||
|
||||
return r0
|
||||
}
|
||||
|
||||
// CloseHandler provides a mock function with given fields:
|
||||
func (_m *webSocketConnectionMock) CloseHandler() func(int, string) error {
|
||||
ret := _m.Called()
|
||||
|
||||
var r0 func(int, string) error
|
||||
if rf, ok := ret.Get(0).(func() func(int, string) error); ok {
|
||||
r0 = rf()
|
||||
} else {
|
||||
if ret.Get(0) != nil {
|
||||
r0 = ret.Get(0).(func(int, string) error)
|
||||
}
|
||||
}
|
||||
|
||||
return r0
|
||||
}
|
||||
|
||||
// NextReader provides a mock function with given fields:
|
||||
func (_m *webSocketConnectionMock) NextReader() (int, io.Reader, error) {
|
||||
ret := _m.Called()
|
||||
|
||||
var r0 int
|
||||
if rf, ok := ret.Get(0).(func() int); ok {
|
||||
r0 = rf()
|
||||
} else {
|
||||
r0 = ret.Get(0).(int)
|
||||
}
|
||||
|
||||
var r1 io.Reader
|
||||
if rf, ok := ret.Get(1).(func() io.Reader); ok {
|
||||
r1 = rf()
|
||||
} else {
|
||||
if ret.Get(1) != nil {
|
||||
r1 = ret.Get(1).(io.Reader)
|
||||
}
|
||||
}
|
||||
|
||||
var r2 error
|
||||
if rf, ok := ret.Get(2).(func() error); ok {
|
||||
r2 = rf()
|
||||
} else {
|
||||
r2 = ret.Error(2)
|
||||
}
|
||||
|
||||
return r0, r1, r2
|
||||
}
|
||||
|
||||
// SetCloseHandler provides a mock function with given fields: h
|
||||
func (_m *webSocketConnectionMock) SetCloseHandler(h func(int, string) error) {
|
||||
_m.Called(h)
|
||||
}
|
||||
|
||||
// WriteMessage provides a mock function with given fields: messageType, data
|
||||
func (_m *webSocketConnectionMock) WriteMessage(messageType int, data []byte) error {
|
||||
ret := _m.Called(messageType, data)
|
||||
|
||||
var r0 error
|
||||
if rf, ok := ret.Get(0).(func(int, []byte) error); ok {
|
||||
r0 = rf(messageType, data)
|
||||
} else {
|
||||
r0 = ret.Error(0)
|
||||
}
|
||||
|
||||
return r0
|
||||
}
|
490
internal/api/websocket_test.go
Normal file
490
internal/api/websocket_test.go
Normal file
@@ -0,0 +1,490 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/stretchr/testify/suite"
|
||||
"gitlab.hpi.de/codeocean/codemoon/poseidon/internal/nomad"
|
||||
"gitlab.hpi.de/codeocean/codemoon/poseidon/internal/runner"
|
||||
"gitlab.hpi.de/codeocean/codemoon/poseidon/pkg/dto"
|
||||
"gitlab.hpi.de/codeocean/codemoon/poseidon/tests"
|
||||
"gitlab.hpi.de/codeocean/codemoon/poseidon/tests/helpers"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestWebSocketTestSuite(t *testing.T) {
|
||||
suite.Run(t, new(WebSocketTestSuite))
|
||||
}
|
||||
|
||||
type WebSocketTestSuite struct {
|
||||
suite.Suite
|
||||
router *mux.Router
|
||||
executionID runner.ExecutionID
|
||||
runner runner.Runner
|
||||
apiMock *nomad.ExecutorAPIMock
|
||||
server *httptest.Server
|
||||
}
|
||||
|
||||
func (s *WebSocketTestSuite) SetupTest() {
|
||||
runnerID := "runner-id"
|
||||
s.runner, s.apiMock = newNomadAllocationWithMockedAPIClient(runnerID)
|
||||
|
||||
// default execution
|
||||
s.executionID = "execution-id"
|
||||
s.runner.Add(s.executionID, &executionRequestHead)
|
||||
mockAPIExecuteHead(s.apiMock)
|
||||
|
||||
runnerManager := &runner.ManagerMock{}
|
||||
runnerManager.On("Get", s.runner.ID()).Return(s.runner, nil)
|
||||
s.router = NewRouter(runnerManager, nil)
|
||||
s.server = httptest.NewServer(s.router)
|
||||
}
|
||||
|
||||
func (s *WebSocketTestSuite) TearDownTest() {
|
||||
s.server.Close()
|
||||
}
|
||||
|
||||
func (s *WebSocketTestSuite) TestWebsocketConnectionCanBeEstablished() {
|
||||
wsURL, err := s.webSocketURL("ws", s.runner.ID(), s.executionID)
|
||||
s.Require().NoError(err)
|
||||
_, _, err = websocket.DefaultDialer.Dial(wsURL.String(), nil)
|
||||
s.Require().NoError(err)
|
||||
}
|
||||
|
||||
func (s *WebSocketTestSuite) TestWebsocketReturns404IfExecutionDoesNotExist() {
|
||||
wsURL, err := s.webSocketURL("ws", s.runner.ID(), "invalid-execution-id")
|
||||
s.Require().NoError(err)
|
||||
_, response, err := websocket.DefaultDialer.Dial(wsURL.String(), nil)
|
||||
s.Require().ErrorIs(err, websocket.ErrBadHandshake)
|
||||
s.Equal(http.StatusNotFound, response.StatusCode)
|
||||
}
|
||||
|
||||
func (s *WebSocketTestSuite) TestWebsocketReturns400IfRequestedViaHttp() {
|
||||
wsURL, err := s.webSocketURL("http", s.runner.ID(), s.executionID)
|
||||
s.Require().NoError(err)
|
||||
response, err := http.Get(wsURL.String())
|
||||
s.Require().NoError(err)
|
||||
// This functionality is implemented by the WebSocket library.
|
||||
s.Equal(http.StatusBadRequest, response.StatusCode)
|
||||
}
|
||||
|
||||
func (s *WebSocketTestSuite) TestWebsocketConnection() {
|
||||
wsURL, err := s.webSocketURL("ws", s.runner.ID(), s.executionID)
|
||||
s.Require().NoError(err)
|
||||
connection, _, err := websocket.DefaultDialer.Dial(wsURL.String(), nil)
|
||||
s.Require().NoError(err)
|
||||
err = connection.SetReadDeadline(time.Now().Add(5 * time.Second))
|
||||
s.Require().NoError(err)
|
||||
|
||||
s.Run("Receives start message", func() {
|
||||
message, err := helpers.ReceiveNextWebSocketMessage(connection)
|
||||
s.Require().NoError(err)
|
||||
s.Equal(dto.WebSocketMetaStart, message.Type)
|
||||
})
|
||||
|
||||
s.Run("Executes the request in the runner", func() {
|
||||
<-time.After(100 * time.Millisecond)
|
||||
s.apiMock.AssertCalled(s.T(), "ExecuteCommand",
|
||||
mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything)
|
||||
})
|
||||
|
||||
s.Run("Can send input", func() {
|
||||
err = connection.WriteMessage(websocket.TextMessage, []byte("Hello World\n"))
|
||||
s.Require().NoError(err)
|
||||
})
|
||||
|
||||
messages, err := helpers.ReceiveAllWebSocketMessages(connection)
|
||||
s.Require().Error(err)
|
||||
s.True(websocket.IsCloseError(err, websocket.CloseNormalClosure))
|
||||
|
||||
s.Run("Receives output message", func() {
|
||||
stdout, _, _ := helpers.WebSocketOutputMessages(messages)
|
||||
s.Equal("Hello World", stdout)
|
||||
})
|
||||
|
||||
s.Run("Receives exit message", func() {
|
||||
controlMessages := helpers.WebSocketControlMessages(messages)
|
||||
s.Require().Equal(1, len(controlMessages))
|
||||
s.Equal(dto.WebSocketExit, controlMessages[0].Type)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *WebSocketTestSuite) TestCancelWebSocketConnection() {
|
||||
executionID := runner.ExecutionID("sleeping-execution")
|
||||
s.runner.Add(executionID, &executionRequestSleep)
|
||||
canceled := mockAPIExecuteSleep(s.apiMock)
|
||||
|
||||
wsURL, err := webSocketURL("ws", s.server, s.router, s.runner.ID(), executionID)
|
||||
s.Require().NoError(err)
|
||||
connection, _, err := websocket.DefaultDialer.Dial(wsURL.String(), nil)
|
||||
s.Require().NoError(err)
|
||||
|
||||
message, err := helpers.ReceiveNextWebSocketMessage(connection)
|
||||
s.Require().NoError(err)
|
||||
s.Equal(dto.WebSocketMetaStart, message.Type)
|
||||
|
||||
select {
|
||||
case <-canceled:
|
||||
s.Fail("ExecuteInteractively canceled unexpected")
|
||||
default:
|
||||
}
|
||||
|
||||
err = connection.WriteControl(websocket.CloseMessage,
|
||||
websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""), time.Now().Add(time.Second))
|
||||
s.Require().NoError(err)
|
||||
|
||||
select {
|
||||
case <-canceled:
|
||||
case <-time.After(time.Second):
|
||||
s.Fail("ExecuteInteractively not canceled")
|
||||
}
|
||||
}
|
||||
|
||||
func (s *WebSocketTestSuite) TestWebSocketConnectionTimeout() {
|
||||
executionID := runner.ExecutionID("time-out-execution")
|
||||
limitExecution := executionRequestSleep
|
||||
limitExecution.TimeLimit = 2
|
||||
s.runner.Add(executionID, &limitExecution)
|
||||
canceled := mockAPIExecuteSleep(s.apiMock)
|
||||
|
||||
wsURL, err := webSocketURL("ws", s.server, s.router, s.runner.ID(), executionID)
|
||||
s.Require().NoError(err)
|
||||
connection, _, err := websocket.DefaultDialer.Dial(wsURL.String(), nil)
|
||||
s.Require().NoError(err)
|
||||
|
||||
message, err := helpers.ReceiveNextWebSocketMessage(connection)
|
||||
s.Require().NoError(err)
|
||||
s.Equal(dto.WebSocketMetaStart, message.Type)
|
||||
|
||||
select {
|
||||
case <-canceled:
|
||||
s.Fail("ExecuteInteractively canceled unexpected")
|
||||
case <-time.After(time.Duration(limitExecution.TimeLimit-1) * time.Second):
|
||||
<-time.After(time.Second)
|
||||
}
|
||||
|
||||
select {
|
||||
case <-canceled:
|
||||
case <-time.After(time.Second):
|
||||
s.Fail("ExecuteInteractively not canceled")
|
||||
}
|
||||
|
||||
message, err = helpers.ReceiveNextWebSocketMessage(connection)
|
||||
s.Require().NoError(err)
|
||||
s.Equal(dto.WebSocketMetaTimeout, message.Type)
|
||||
}
|
||||
|
||||
func (s *WebSocketTestSuite) TestWebsocketStdoutAndStderr() {
|
||||
executionID := runner.ExecutionID("ls-execution")
|
||||
s.runner.Add(executionID, &executionRequestLs)
|
||||
mockAPIExecuteLs(s.apiMock)
|
||||
|
||||
wsURL, err := webSocketURL("ws", s.server, s.router, s.runner.ID(), executionID)
|
||||
s.Require().NoError(err)
|
||||
connection, _, err := websocket.DefaultDialer.Dial(wsURL.String(), nil)
|
||||
s.Require().NoError(err)
|
||||
|
||||
messages, err := helpers.ReceiveAllWebSocketMessages(connection)
|
||||
s.Require().Error(err)
|
||||
s.True(websocket.IsCloseError(err, websocket.CloseNormalClosure))
|
||||
stdout, stderr, _ := helpers.WebSocketOutputMessages(messages)
|
||||
|
||||
s.Contains(stdout, "existing-file")
|
||||
|
||||
s.Contains(stderr, "non-existing-file")
|
||||
}
|
||||
|
||||
func (s *WebSocketTestSuite) TestWebsocketError() {
|
||||
executionID := runner.ExecutionID("error-execution")
|
||||
s.runner.Add(executionID, &executionRequestError)
|
||||
mockAPIExecuteError(s.apiMock)
|
||||
|
||||
wsURL, err := webSocketURL("ws", s.server, s.router, s.runner.ID(), executionID)
|
||||
s.Require().NoError(err)
|
||||
connection, _, err := websocket.DefaultDialer.Dial(wsURL.String(), nil)
|
||||
s.Require().NoError(err)
|
||||
|
||||
messages, err := helpers.ReceiveAllWebSocketMessages(connection)
|
||||
s.Require().Error(err)
|
||||
s.True(websocket.IsCloseError(err, websocket.CloseNormalClosure))
|
||||
|
||||
_, _, errMessages := helpers.WebSocketOutputMessages(messages)
|
||||
s.Equal(1, len(errMessages))
|
||||
s.Equal("Error executing the request", errMessages[0])
|
||||
}
|
||||
|
||||
func (s *WebSocketTestSuite) TestWebsocketNonZeroExit() {
|
||||
executionID := runner.ExecutionID("exit-execution")
|
||||
s.runner.Add(executionID, &executionRequestExitNonZero)
|
||||
mockAPIExecuteExitNonZero(s.apiMock)
|
||||
|
||||
wsURL, err := webSocketURL("ws", s.server, s.router, s.runner.ID(), executionID)
|
||||
s.Require().NoError(err)
|
||||
connection, _, err := websocket.DefaultDialer.Dial(wsURL.String(), nil)
|
||||
s.Require().NoError(err)
|
||||
|
||||
messages, err := helpers.ReceiveAllWebSocketMessages(connection)
|
||||
s.Require().Error(err)
|
||||
s.True(websocket.IsCloseError(err, websocket.CloseNormalClosure))
|
||||
|
||||
controlMessages := helpers.WebSocketControlMessages(messages)
|
||||
s.Equal(2, len(controlMessages))
|
||||
s.Equal(&dto.WebSocketMessage{Type: dto.WebSocketExit, ExitCode: 42}, controlMessages[1])
|
||||
}
|
||||
|
||||
func TestWebsocketTLS(t *testing.T) {
|
||||
runnerID := "runner-id"
|
||||
r, apiMock := newNomadAllocationWithMockedAPIClient(runnerID)
|
||||
|
||||
executionID := runner.ExecutionID("execution-id")
|
||||
r.Add(executionID, &executionRequestLs)
|
||||
mockAPIExecuteLs(apiMock)
|
||||
|
||||
runnerManager := &runner.ManagerMock{}
|
||||
runnerManager.On("Get", r.ID()).Return(r, nil)
|
||||
router := NewRouter(runnerManager, nil)
|
||||
|
||||
server, err := helpers.StartTLSServer(t, router)
|
||||
require.NoError(t, err)
|
||||
defer server.Close()
|
||||
|
||||
wsURL, err := webSocketURL("wss", server, router, runnerID, executionID)
|
||||
require.NoError(t, err)
|
||||
|
||||
config := &tls.Config{RootCAs: nil, InsecureSkipVerify: true} //nolint:gosec // test needs self-signed cert
|
||||
d := websocket.Dialer{TLSClientConfig: config}
|
||||
connection, _, err := d.Dial(wsURL.String(), nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
message, err := helpers.ReceiveNextWebSocketMessage(connection)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, dto.WebSocketMetaStart, message.Type)
|
||||
_, err = helpers.ReceiveAllWebSocketMessages(connection)
|
||||
require.Error(t, err)
|
||||
assert.True(t, websocket.IsCloseError(err, websocket.CloseNormalClosure))
|
||||
}
|
||||
|
||||
func TestRawToCodeOceanWriter(t *testing.T) {
|
||||
testMessage := "test"
|
||||
var message []byte
|
||||
|
||||
connectionMock := &webSocketConnectionMock{}
|
||||
connectionMock.On("WriteMessage", mock.AnythingOfType("int"), mock.AnythingOfType("[]uint8")).
|
||||
Run(func(args mock.Arguments) {
|
||||
var ok bool
|
||||
message, ok = args.Get(1).([]byte)
|
||||
require.True(t, ok)
|
||||
}).
|
||||
Return(nil)
|
||||
connectionMock.On("CloseHandler").Return(nil)
|
||||
connectionMock.On("SetCloseHandler", mock.Anything).Return()
|
||||
|
||||
proxy, err := newWebSocketProxy(connectionMock)
|
||||
require.NoError(t, err)
|
||||
writer := &rawToCodeOceanWriter{
|
||||
proxy: proxy,
|
||||
outputType: dto.WebSocketOutputStdout,
|
||||
}
|
||||
|
||||
_, err = writer.Write([]byte(testMessage))
|
||||
require.NoError(t, err)
|
||||
|
||||
expected, err := json.Marshal(struct {
|
||||
Type string `json:"type"`
|
||||
Data string `json:"data"`
|
||||
}{string(dto.WebSocketOutputStdout), testMessage})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, expected, message)
|
||||
}
|
||||
|
||||
func TestCodeOceanToRawReaderReturnsOnlyAfterOneByteWasRead(t *testing.T) {
|
||||
reader := newCodeOceanToRawReader(nil)
|
||||
|
||||
read := make(chan bool)
|
||||
go func() {
|
||||
//nolint:makezero // we can't make zero initial length here as the reader otherwise doesn't block
|
||||
p := make([]byte, 10)
|
||||
_, err := reader.Read(p)
|
||||
require.NoError(t, err)
|
||||
read <- true
|
||||
}()
|
||||
|
||||
t.Run("Does not return immediately when there is no data", func(t *testing.T) {
|
||||
assert.False(t, tests.ChannelReceivesSomething(read, tests.ShortTimeout))
|
||||
})
|
||||
|
||||
t.Run("Returns when there is data available", func(t *testing.T) {
|
||||
reader.buffer <- byte(42)
|
||||
assert.True(t, tests.ChannelReceivesSomething(read, tests.ShortTimeout))
|
||||
})
|
||||
}
|
||||
|
||||
func TestCodeOceanToRawReaderReturnsOnlyAfterOneByteWasReadFromConnection(t *testing.T) {
|
||||
messages := make(chan io.Reader)
|
||||
|
||||
connection := &webSocketConnectionMock{}
|
||||
connection.On("WriteMessage", mock.AnythingOfType("int"), mock.AnythingOfType("[]uint8")).Return(nil)
|
||||
connection.On("CloseHandler").Return(nil)
|
||||
connection.On("SetCloseHandler", mock.Anything).Return()
|
||||
call := connection.On("NextReader")
|
||||
call.Run(func(_ mock.Arguments) {
|
||||
call.Return(websocket.TextMessage, <-messages, nil)
|
||||
})
|
||||
|
||||
reader := newCodeOceanToRawReader(connection)
|
||||
cancel := reader.startReadInputLoop()
|
||||
defer cancel()
|
||||
|
||||
read := make(chan bool)
|
||||
//nolint:makezero // this is required here to make the Read call blocking
|
||||
message := make([]byte, 10)
|
||||
go func() {
|
||||
_, err := reader.Read(message)
|
||||
require.NoError(t, err)
|
||||
read <- true
|
||||
}()
|
||||
|
||||
t.Run("Does not return immediately when there is no data", func(t *testing.T) {
|
||||
assert.False(t, tests.ChannelReceivesSomething(read, tests.ShortTimeout))
|
||||
})
|
||||
|
||||
t.Run("Returns when there is data available", func(t *testing.T) {
|
||||
messages <- strings.NewReader("Hello")
|
||||
assert.True(t, tests.ChannelReceivesSomething(read, tests.ShortTimeout))
|
||||
})
|
||||
}
|
||||
|
||||
// --- Test suite specific test helpers ---
|
||||
|
||||
func newNomadAllocationWithMockedAPIClient(runnerID string) (r runner.Runner, executorAPIMock *nomad.ExecutorAPIMock) {
|
||||
executorAPIMock = &nomad.ExecutorAPIMock{}
|
||||
r = runner.NewNomadJob(runnerID, nil, executorAPIMock, nil)
|
||||
return
|
||||
}
|
||||
|
||||
func webSocketURL(scheme string, server *httptest.Server, router *mux.Router,
|
||||
runnerID string, executionID runner.ExecutionID,
|
||||
) (*url.URL, error) {
|
||||
websocketURL, err := url.Parse(server.URL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
path, err := router.Get(WebsocketPath).URL(RunnerIDKey, runnerID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
websocketURL.Scheme = scheme
|
||||
websocketURL.Path = path.Path
|
||||
websocketURL.RawQuery = fmt.Sprintf("executionID=%s", executionID)
|
||||
return websocketURL, nil
|
||||
}
|
||||
|
||||
func (s *WebSocketTestSuite) webSocketURL(scheme, runnerID string, executionID runner.ExecutionID) (*url.URL, error) {
|
||||
return webSocketURL(scheme, s.server, s.router, runnerID, executionID)
|
||||
}
|
||||
|
||||
var executionRequestLs = dto.ExecutionRequest{Command: "ls"}
|
||||
|
||||
// mockAPIExecuteLs mocks the ExecuteCommand of an ExecutorApi to act as if
|
||||
// 'ls existing-file non-existing-file' was executed.
|
||||
func mockAPIExecuteLs(api *nomad.ExecutorAPIMock) {
|
||||
mockAPIExecute(api, &executionRequestLs,
|
||||
func(_ string, _ context.Context, _ []string, _ bool, _ io.Reader, stdout, stderr io.Writer) (int, error) {
|
||||
_, _ = stdout.Write([]byte("existing-file\n"))
|
||||
_, _ = stderr.Write([]byte("ls: cannot access 'non-existing-file': No such file or directory\n"))
|
||||
return 0, nil
|
||||
})
|
||||
}
|
||||
|
||||
var executionRequestHead = dto.ExecutionRequest{Command: "head -n 1"}
|
||||
|
||||
// mockAPIExecuteHead mocks the ExecuteCommand of an ExecutorApi to act as if 'head -n 1' was executed.
|
||||
func mockAPIExecuteHead(api *nomad.ExecutorAPIMock) {
|
||||
mockAPIExecute(api, &executionRequestHead,
|
||||
func(_ string, _ context.Context, _ []string, _ bool,
|
||||
stdin io.Reader, stdout io.Writer, stderr io.Writer,
|
||||
) (int, error) {
|
||||
scanner := bufio.NewScanner(stdin)
|
||||
for !scanner.Scan() {
|
||||
scanner = bufio.NewScanner(stdin)
|
||||
}
|
||||
_, _ = stdout.Write(scanner.Bytes())
|
||||
return 0, nil
|
||||
})
|
||||
}
|
||||
|
||||
var executionRequestSleep = dto.ExecutionRequest{Command: "sleep infinity"}
|
||||
|
||||
// mockAPIExecuteSleep mocks the ExecuteCommand method of an ExecutorAPI to sleep until the execution is canceled.
|
||||
func mockAPIExecuteSleep(api *nomad.ExecutorAPIMock) <-chan bool {
|
||||
canceled := make(chan bool, 1)
|
||||
mockAPIExecute(api, &executionRequestSleep,
|
||||
func(_ string, ctx context.Context, _ []string, _ bool,
|
||||
stdin io.Reader, stdout io.Writer, stderr io.Writer,
|
||||
) (int, error) {
|
||||
<-ctx.Done()
|
||||
close(canceled)
|
||||
return 0, ctx.Err()
|
||||
})
|
||||
return canceled
|
||||
}
|
||||
|
||||
var executionRequestError = dto.ExecutionRequest{Command: "error"}
|
||||
|
||||
// mockAPIExecuteError mocks the ExecuteCommand method of an ExecutorApi to return an error.
|
||||
func mockAPIExecuteError(api *nomad.ExecutorAPIMock) {
|
||||
mockAPIExecute(api, &executionRequestError,
|
||||
func(_ string, _ context.Context, _ []string, _ bool, _ io.Reader, _, _ io.Writer) (int, error) {
|
||||
return 0, tests.ErrDefault
|
||||
})
|
||||
}
|
||||
|
||||
var executionRequestExitNonZero = dto.ExecutionRequest{Command: "exit 42"}
|
||||
|
||||
// mockAPIExecuteExitNonZero mocks the ExecuteCommand method of an ExecutorApi to exit with exit status 42.
|
||||
func mockAPIExecuteExitNonZero(api *nomad.ExecutorAPIMock) {
|
||||
mockAPIExecute(api, &executionRequestExitNonZero,
|
||||
func(_ string, _ context.Context, _ []string, _ bool, _ io.Reader, _, _ io.Writer) (int, error) {
|
||||
return 42, nil
|
||||
})
|
||||
}
|
||||
|
||||
// mockAPIExecute mocks the ExecuteCommand method of an ExecutorApi to call the given method run when the command
|
||||
// corresponding to the given ExecutionRequest is called.
|
||||
func mockAPIExecute(api *nomad.ExecutorAPIMock, request *dto.ExecutionRequest,
|
||||
run func(runnerId string, ctx context.Context, command []string, tty bool,
|
||||
stdin io.Reader, stdout, stderr io.Writer) (int, error),
|
||||
) {
|
||||
call := api.On("ExecuteCommand",
|
||||
mock.AnythingOfType("string"),
|
||||
mock.Anything,
|
||||
request.FullCommand(),
|
||||
mock.AnythingOfType("bool"),
|
||||
mock.Anything,
|
||||
mock.Anything,
|
||||
mock.Anything)
|
||||
call.Run(func(args mock.Arguments) {
|
||||
exit, err := run(args.Get(0).(string),
|
||||
args.Get(1).(context.Context),
|
||||
args.Get(2).([]string),
|
||||
args.Get(3).(bool),
|
||||
args.Get(4).(io.Reader),
|
||||
args.Get(5).(io.Writer),
|
||||
args.Get(6).(io.Writer))
|
||||
call.ReturnArguments = mock.Arguments{exit, err}
|
||||
})
|
||||
}
|
Reference in New Issue
Block a user