diff --git a/internal/api/websocket.go b/internal/api/websocket.go index 565dda8..32834a2 100644 --- a/internal/api/websocket.go +++ b/internal/api/websocket.go @@ -38,8 +38,10 @@ type WebSocketReader interface { type codeOceanToRawReader struct { connection webSocketConnection - // ctx is used to cancel the reading routine. - ctx context.Context + // wsCtx is the context in that messages from CodeOcean are read. + wsCtx context.Context + // executorCtx is the context in that messages are forwarded to the executor. + executorCtx context.Context // A buffered channel of bytes is used to store data coming from CodeOcean via WebSocket // and retrieve it when Read(..) is called. Since channels are thread-safe, we use one here @@ -51,10 +53,11 @@ type codeOceanToRawReader struct { priorityBuffer chan byte } -func newCodeOceanToRawReader(connection webSocketConnection, ctx context.Context) *codeOceanToRawReader { +func newCodeOceanToRawReader(connection webSocketConnection, wsCtx, executorCtx context.Context) *codeOceanToRawReader { return &codeOceanToRawReader{ connection: connection, - ctx: ctx, + wsCtx: wsCtx, + executorCtx: executorCtx, buffer: make(chan byte, CodeOceanToRawReaderBufferSize), priorityBuffer: make(chan byte, CodeOceanToRawReaderBufferSize), } @@ -65,7 +68,7 @@ func newCodeOceanToRawReader(connection webSocketConnection, ctx context.Context // CloseHandler. func (cr *codeOceanToRawReader) readInputLoop() { readMessage := make(chan bool) - loopContext, cancelInputLoop := context.WithCancel(cr.ctx) + loopContext, cancelInputLoop := context.WithCancel(cr.wsCtx) defer cancelInputLoop() readingContext, cancelNextMessage := context.WithCancel(loopContext) defer cancelNextMessage() @@ -138,7 +141,7 @@ func (cr *codeOceanToRawReader) Read(p []byte) (int, error) { // Ensure to not return until at least one byte has been read to avoid busy waiting. select { - case <-cr.ctx.Done(): + case <-cr.executorCtx.Done(): return 0, io.EOF case p[0] = <-cr.priorityBuffer: case p[0] = <-cr.buffer: @@ -208,23 +211,23 @@ func upgradeConnection(writer http.ResponseWriter, request *http.Request) (webSo return connection, nil } -// newWebSocketProxy returns a initiated and started webSocketProxy. +// newWebSocketProxy returns an initiated and started webSocketProxy. // As this proxy is already started, a start message is send to the client. -func newWebSocketProxy(connection webSocketConnection, ctx context.Context) (*webSocketProxy, error) { - stdin := newCodeOceanToRawReader(connection, ctx) - inputCtx, inputCancel := context.WithCancel(ctx) +func newWebSocketProxy(connection webSocketConnection, proxyCtx context.Context) (*webSocketProxy, error) { + wsCtx, cancelWsCommunication := context.WithCancel(proxyCtx) + stdin := newCodeOceanToRawReader(connection, wsCtx, proxyCtx) proxy := &webSocketProxy{ connection: connection, Stdin: stdin, - webSocketCtx: inputCtx, - cancelWebSocket: inputCancel, + webSocketCtx: wsCtx, + cancelWebSocket: cancelWsCommunication, } proxy.Stdout = &rawToCodeOceanWriter{proxy: proxy, outputType: dto.WebSocketOutputStdout} proxy.Stderr = &rawToCodeOceanWriter{proxy: proxy, outputType: dto.WebSocketOutputStderr} err := proxy.sendToClient(dto.WebSocketMessage{Type: dto.WebSocketMetaStart}) if err != nil { - inputCancel() + cancelWsCommunication() return nil, err } @@ -232,7 +235,7 @@ func newWebSocketProxy(connection webSocketConnection, ctx context.Context) (*we connection.SetCloseHandler(func(code int, text string) error { //nolint:errcheck // The default close handler always returns nil, so the error can be safely ignored. _ = closeHandler(code, text) - inputCancel() + cancelWsCommunication() return nil }) return proxy, nil diff --git a/internal/api/websocket_test.go b/internal/api/websocket_test.go index 5ce6d74..07b2432 100644 --- a/internal/api/websocket_test.go +++ b/internal/api/websocket_test.go @@ -8,11 +8,14 @@ import ( "fmt" "github.com/gorilla/mux" "github.com/gorilla/websocket" + "github.com/openHPI/poseidon/internal/environment" "github.com/openHPI/poseidon/internal/nomad" "github.com/openHPI/poseidon/internal/runner" "github.com/openHPI/poseidon/pkg/dto" "github.com/openHPI/poseidon/tests" "github.com/openHPI/poseidon/tests/helpers" + "github.com/sirupsen/logrus" + "github.com/sirupsen/logrus/hooks/test" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" @@ -44,7 +47,7 @@ func (s *WebSocketTestSuite) SetupTest() { s.runner, s.apiMock = newNomadAllocationWithMockedAPIClient(runnerID) // default execution - s.executionID = "execution-id" + s.executionID = tests.DefaultExecutionID s.runner.StoreExecution(s.executionID, &executionRequestHead) mockAPIExecuteHead(s.apiMock) @@ -250,7 +253,7 @@ func TestWebsocketTLS(t *testing.T) { runnerID := "runner-id" r, apiMock := newNomadAllocationWithMockedAPIClient(runnerID) - executionID := "execution-id" + executionID := tests.DefaultExecutionID r.StoreExecution(executionID, &executionRequestLs) mockAPIExecuteLs(apiMock) @@ -315,8 +318,9 @@ func TestRawToCodeOceanWriter(t *testing.T) { func TestCodeOceanToRawReaderReturnsOnlyAfterOneByteWasRead(t *testing.T) { readingCtx, cancel := context.WithCancel(context.Background()) + forwardingCtx := readingCtx defer cancel() - reader := newCodeOceanToRawReader(nil, readingCtx) + reader := newCodeOceanToRawReader(nil, readingCtx, forwardingCtx) read := make(chan bool) go func() { @@ -350,8 +354,9 @@ func TestCodeOceanToRawReaderReturnsOnlyAfterOneByteWasReadFromConnection(t *tes }) readingCtx, cancel := context.WithCancel(context.Background()) + forwardingCtx := readingCtx defer cancel() - reader := newCodeOceanToRawReader(connection, readingCtx) + reader := newCodeOceanToRawReader(connection, readingCtx, forwardingCtx) reader.startReadInputLoop() read := make(chan bool) @@ -373,6 +378,32 @@ func TestCodeOceanToRawReaderReturnsOnlyAfterOneByteWasReadFromConnection(t *tes }) } +func TestWebSocketProxyStopsReadingTheWebSocketAfterClosingIt(t *testing.T) { + apiMock := &nomad.ExecutorAPIMock{} + executionID := tests.DefaultExecutionID + r, wsURL := newRunnerWithNotMockedRunnerManager(t, apiMock, executionID) + + logger, hook := test.NewNullLogger() + log = logger.WithField("pkg", "api") + + r.StoreExecution(executionID, &executionRequestHead) + mockAPIExecute(apiMock, &executionRequestHead, + func(_ string, ctx context.Context, _ []string, _ bool, _ io.Reader, _, _ io.Writer) (int, error) { + return 0, nil + }) + connection, _, err := websocket.DefaultDialer.Dial(wsURL.String(), nil) + require.NoError(t, err) + + _, err = helpers.ReceiveAllWebSocketMessages(connection) + require.Error(t, err) + assert.True(t, websocket.IsCloseError(err, websocket.CloseNormalClosure)) + for _, logMsg := range hook.Entries { + if logMsg.Level < logrus.InfoLevel { + assert.Fail(t, logMsg.Message) + } + } +} + // --- Test suite specific test helpers --- func newNomadAllocationWithMockedAPIClient(runnerID string) (runner.Runner, *nomad.ExecutorAPIMock) { @@ -383,6 +414,39 @@ func newNomadAllocationWithMockedAPIClient(runnerID string) (runner.Runner, *nom return r, executorAPIMock } +func newRunnerWithNotMockedRunnerManager(t *testing.T, apiMock *nomad.ExecutorAPIMock, executionID string) ( + r runner.Runner, wsURL *url.URL) { + t.Helper() + apiMock.On("MarkRunnerAsUsed", mock.AnythingOfType("string"), mock.AnythingOfType("int")).Return(nil) + apiMock.On("DeleteJob", mock.AnythingOfType("string")).Return(nil) + apiMock.On("RegisterRunnerJob", mock.AnythingOfType("*api.Job")).Return(nil) + call := apiMock.On("WatchEventStream", mock.Anything, mock.Anything, mock.Anything) + call.Run(func(args mock.Arguments) { + <-context.Background().Done() + call.ReturnArguments = mock.Arguments{nil} + }) + runnerManager := runner.NewNomadRunnerManager(apiMock, context.Background()) + router := NewRouter(runnerManager, nil) + server := httptest.NewServer(router) + + runnerID := tests.DefaultRunnerID + runnerJob := runner.NewNomadJob(runnerID, nil, apiMock, runnerManager) + e, err := environment.NewNomadEnvironment("job \"template-0\" {}") + require.NoError(t, err) + eID, err := nomad.EnvironmentIDFromRunnerID(runnerID) + require.NoError(t, err) + e.SetID(eID) + e.SetPrewarmingPoolSize(0) + runnerManager.SetEnvironment(e) + e.AddRunner(runnerJob) + + r, err = runnerManager.Claim(e.ID(), int(tests.DefaultTestTimeout.Seconds())) + require.NoError(t, err) + wsURL, err = webSocketURL("ws", server, router, r.ID(), executionID) + require.NoError(t, err) + return r, wsURL +} + func webSocketURL(scheme string, server *httptest.Server, router *mux.Router, runnerID string, executionID string, ) (*url.URL, error) { diff --git a/internal/environment/manager_test.go b/internal/environment/manager_test.go index 98577db..69c0a5b 100644 --- a/internal/environment/manager_test.go +++ b/internal/environment/manager_test.go @@ -261,7 +261,7 @@ func TestNomadEnvironmentManager_List(t *testing.T) { func mockWatchAllocations(apiMock *nomad.ExecutorAPIMock) { call := apiMock.On("WatchEventStream", mock.Anything, mock.Anything, mock.Anything) call.Run(func(args mock.Arguments) { - <-time.After(10 * time.Minute) // 10 minutes is the default test timeout + <-time.After(tests.DefaultTestTimeout) call.ReturnArguments = mock.Arguments{nil} }) } diff --git a/internal/runner/manager_test.go b/internal/runner/manager_test.go index b929f42..9682b09 100644 --- a/internal/runner/manager_test.go +++ b/internal/runner/manager_test.go @@ -44,7 +44,7 @@ func mockRunnerQueries(apiMock *nomad.ExecutorAPIMock, returnedRunnerIds []strin apiMock.ExpectedCalls = []*mock.Call{} call := apiMock.On("WatchEventStream", mock.Anything, mock.Anything, mock.Anything) call.Run(func(args mock.Arguments) { - <-time.After(10 * time.Minute) // 10 minutes is the default test timeout + <-time.After(tests.DefaultTestTimeout) call.ReturnArguments = mock.Arguments{nil} }) apiMock.On("LoadEnvironmentJobs").Return([]*nomadApi.Job{}, nil) diff --git a/tests/constants.go b/tests/constants.go index 5ba88bd..6d41f2a 100644 --- a/tests/constants.go +++ b/tests/constants.go @@ -26,6 +26,7 @@ const ( DefaultExecutionID = "s0m3-3x3cu710n-1d" DefaultMockID = "m0ck-1d" ShortTimeout = 100 * time.Millisecond + DefaultTestTimeout = 10 * time.Minute ) var (