From ed735f284f4144257afd5496393f4951640c2655 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Maximilian=20Pa=C3=9F?= Date: Mon, 10 May 2021 11:31:05 +0200 Subject: [PATCH] Add tests for websocket connection Co-authored-by: Konrad Hanff --- api/runners.go | 2 +- api/websocket.go | 14 ++++--- api/websocket_test.go | 85 +++++++++++++++++++++++++++++++++++++++++-- 3 files changed, 92 insertions(+), 9 deletions(-) diff --git a/api/runners.go b/api/runners.go index 62d2d33..9f6d165 100644 --- a/api/runners.go +++ b/api/runners.go @@ -89,7 +89,7 @@ func findRunnerMiddleware(runnerPool environment.RunnerPool) func(handler http.H runnerId := mux.Vars(request)[RunnerIdKey] r, ok := runnerPool.Get(runnerId) if !ok { - writer.WriteHeader(http.StatusNotFound) + writeNotFound(writer, errors.New("no runner with this id")) return } ctx := runner.NewContext(request.Context(), r.(runner.Runner)) diff --git a/api/websocket.go b/api/websocket.go index d3cedf7..ce81e74 100644 --- a/api/websocket.go +++ b/api/websocket.go @@ -15,20 +15,24 @@ var connUpgrade = websocket.Upgrader{ // connectToRunner is a placeholder for now and will become the endpoint for websocket connections. func connectToRunner(writer http.ResponseWriter, request *http.Request) { - r, ok := runner.FromContext(request.Context()) + r, _ := runner.FromContext(request.Context()) + executionId := request.URL.Query().Get(ExecutionIdKey) + executionRequest, ok := r.Execution(runner.ExecutionId(executionId)) if !ok { - log.Error("Runner not set in request context.") - writeInternalServerError(writer, errors.New("findRunnerMiddleware failure"), dto.ErrorUnknown) + writeNotFound(writer, errors.New("executionId does not exist")) return } - executionId := request.URL.Query().Get(ExecutionIdKey) + log. + WithField("executionId", executionId). + WithField("command", executionRequest.Command). + Info("Running execution") connClient, err := connUpgrade.Upgrade(writer, request, nil) if err != nil { writeInternalServerError(writer, err, dto.ErrorUnknown) return } defer func(connClient *websocket.Conn) { - err := connClient.Close() + err := connClient.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, "")) if err != nil { writeInternalServerError(writer, err, dto.ErrorUnknown) } diff --git a/api/websocket_test.go b/api/websocket_test.go index b465457..5e7434f 100644 --- a/api/websocket_test.go +++ b/api/websocket_test.go @@ -1,12 +1,91 @@ package api import ( + "errors" + "fmt" + "github.com/gorilla/mux" + "github.com/gorilla/websocket" + "github.com/stretchr/testify/suite" + "gitlab.hpi.de/codeocean/codemoon/poseidon/api/dto" + "gitlab.hpi.de/codeocean/codemoon/poseidon/environment/pool" + "gitlab.hpi.de/codeocean/codemoon/poseidon/runner" + "net/http" + "net/http/httptest" + "net/url" "testing" ) -func TestInvalidExecutionId(t *testing.T) { - +type WebsocketTestSuite struct { + suite.Suite + runner runner.Runner + server *httptest.Server + router *mux.Router + executionId runner.ExecutionId } -func TestEstablishWebsocketConnection(t *testing.T) { +func (suite *WebsocketTestSuite) SetupSuite() { + runnerPool := pool.NewLocalRunnerPool() + suite.runner = runner.NewExerciseRunner("testRunner") + runnerPool.AddRunner(suite.runner) + var err error + suite.executionId, err = suite.runner.AddExecution(dto.ExecutionRequest{ + Command: "command", + TimeLimit: 10, + Environment: nil, + }) + if !suite.NoError(err) { + return + } + + router := mux.NewRouter() + router.Use(findRunnerMiddleware(runnerPool)) + router.HandleFunc(fmt.Sprintf("%s/{%s}%s", RouteRunners, RunnerIdKey, WebsocketPath), connectToRunner).Methods(http.MethodGet).Name(WebsocketPath) + suite.server = httptest.NewServer(router) + suite.router = router +} + +func (suite *WebsocketTestSuite) url(scheme, runnerId string, executionId runner.ExecutionId) (*url.URL, error) { + websocketUrl, err := url.Parse(suite.server.URL) + if !suite.NoError(err, "Error: parsing test server url") { + return nil, errors.New("could not parse server url") + } + path, err := suite.router.Get(WebsocketPath).URL(RunnerIdKey, runnerId) + if !suite.NoError(err) { + return nil, errors.New("could not set runnerId") + } + websocketUrl.Scheme = scheme + websocketUrl.Path = path.Path + websocketUrl.RawQuery = fmt.Sprintf("executionId=%s", executionId) + return websocketUrl, nil +} + +func (suite *WebsocketTestSuite) TearDownSuite() { + suite.server.Close() +} + +func TestWebsocketTestSuite(t *testing.T) { + suite.Run(t, new(WebsocketTestSuite)) +} + +func (suite *WebsocketTestSuite) TestEstablishWebsocketConnection() { + path, err := suite.url("ws", suite.runner.Id(), suite.executionId) + if !suite.NoError(err) { + return + } + _, _, err = websocket.DefaultDialer.Dial(path.String(), nil) + if !suite.NoError(err) { + return + } +} + +func (suite *WebsocketTestSuite) TestWebsocketReturns404IfExecutionDoesNotExist() { + wsUrl, err := suite.url("http", suite.runner.Id(), "invalid-execution-id") + if !suite.NoError(err) { + return + } + response, err := http.Get(wsUrl.String()) + if !suite.NoError(err) { + return + } + suite.Equal(http.StatusNotFound, response.StatusCode) }