diff --git a/internal/api/websocket.go b/internal/api/websocket.go index bee45c3..bced43a 100644 --- a/internal/api/websocket.go +++ b/internal/api/websocket.go @@ -19,7 +19,6 @@ var ErrUnknownExecutionID = errors.New("execution id unknown") // webSocketProxy is an encapsulation of logic for forwarding between Runners and CodeOcean. type webSocketProxy struct { ctx context.Context - cancel context.CancelFunc Input ws.WebSocketReader Output ws.WebSocketWriter } @@ -38,16 +37,19 @@ func upgradeConnection(writer http.ResponseWriter, request *http.Request) (ws.Co // newWebSocketProxy returns an initiated and started webSocketProxy. // As this proxy is already started, a start message is send to the client. func newWebSocketProxy(connection ws.Connection, proxyCtx context.Context) *webSocketProxy { - wsCtx, cancelWsCommunication := context.WithCancel(proxyCtx) + // wsCtx is detached from the proxyCtx + // as it should send all messages in the buffer even if the execution/proxy is done. + wsCtx, cancelWsCommunication := context.WithCancel(context.Background()) + wsCtx = sentry.SetHubOnContext(wsCtx, sentry.GetHubFromContext(proxyCtx)) + proxy := &webSocketProxy{ ctx: wsCtx, - cancel: cancelWsCommunication, Input: ws.NewCodeOceanToRawReader(connection, wsCtx, proxyCtx), - Output: ws.NewCodeOceanOutputWriter(connection, wsCtx), + Output: ws.NewCodeOceanOutputWriter(connection, wsCtx, cancelWsCommunication), } connection.SetCloseHandler(func(code int, text string) error { - log.WithField("code", code).WithField("text", text).Debug("The client closed the connection.") + log.WithContext(wsCtx).WithField("code", code).WithField("text", text).Debug("The client closed the connection.") cancelWsCommunication() return nil }) @@ -70,7 +72,7 @@ func (wp *webSocketProxy) waitForExit(exit <-chan runner.ExitInfo, cancelExecuti case exitInfo = <-exit: log.WithContext(wp.ctx).Info("Execution returned") wp.Input.Stop() - wp.Output.SendExitInfo(&exitInfo) + wp.Output.Close(&exitInfo) } } diff --git a/internal/api/ws/codeocean_writer.go b/internal/api/ws/codeocean_writer.go index 94d4e8a..ddb0976 100644 --- a/internal/api/ws/codeocean_writer.go +++ b/internal/api/ws/codeocean_writer.go @@ -28,7 +28,7 @@ func (rc *rawToCodeOceanWriter) Write(p []byte) (int, error) { type WebSocketWriter interface { StdOut() io.Writer StdErr() io.Writer - SendExitInfo(info *runner.ExitInfo) + Close(info *runner.ExitInfo) } // codeOceanOutputWriter is a concrete WebSocketWriter implementation. @@ -38,7 +38,7 @@ type codeOceanOutputWriter struct { stdOut io.Writer stdErr io.Writer queue chan *writingLoopMessage - stopped bool + ctx context.Context } // writingLoopMessage is an internal data structure to notify the writing loop when it should stop. @@ -47,13 +47,14 @@ type writingLoopMessage struct { data *dto.WebSocketMessage } -// NewCodeOceanOutputWriter provies an codeOceanOutputWriter for the time the context ctx is active. +// NewCodeOceanOutputWriter provides an codeOceanOutputWriter for the time the context ctx is active. // The codeOceanOutputWriter handles all the messages defined in the websocket.schema.json (start, timeout, stdout, ..). -func NewCodeOceanOutputWriter(connection Connection, ctx context.Context) *codeOceanOutputWriter { +func NewCodeOceanOutputWriter( + connection Connection, ctx context.Context, done context.CancelFunc) *codeOceanOutputWriter { cw := &codeOceanOutputWriter{ connection: connection, queue: make(chan *writingLoopMessage, CodeOceanOutputWriterBufferSize), - stopped: false, + ctx: ctx, } cw.stdOut = &rawToCodeOceanWriter{sendMessage: func(s string) { cw.send(&dto.WebSocketMessage{Type: dto.WebSocketOutputStdout, Data: s}) @@ -62,8 +63,7 @@ func NewCodeOceanOutputWriter(connection Connection, ctx context.Context) *codeO cw.send(&dto.WebSocketMessage{Type: dto.WebSocketOutputStderr, Data: s}) }} - go cw.startWritingLoop() - go cw.stopWhenContextDone(ctx) + go cw.startWritingLoop(done) cw.send(&dto.WebSocketMessage{Type: dto.WebSocketMetaStart}) return cw } @@ -78,76 +78,78 @@ func (cw *codeOceanOutputWriter) StdErr() io.Writer { return cw.stdErr } -// SendExitInfo forwards the kind of exit (timeout, error, normal) to CodeOcean. -func (cw *codeOceanOutputWriter) SendExitInfo(info *runner.ExitInfo) { +// Close forwards the kind of exit (timeout, error, normal) to CodeOcean. +// This results in the closing of the WebSocket connection. +func (cw *codeOceanOutputWriter) Close(info *runner.ExitInfo) { switch { case errors.Is(info.Err, context.DeadlineExceeded) || errors.Is(info.Err, runner.ErrorRunnerInactivityTimeout): cw.send(&dto.WebSocketMessage{Type: dto.WebSocketMetaTimeout}) case info.Err != nil: errorMessage := "Error executing the request" - log.WithError(info.Err).Warn(errorMessage) + log.WithContext(cw.ctx).WithError(info.Err).Warn(errorMessage) cw.send(&dto.WebSocketMessage{Type: dto.WebSocketOutputError, Data: errorMessage}) default: cw.send(&dto.WebSocketMessage{Type: dto.WebSocketExit, ExitCode: info.Code}) } } -// stopWhenContextDone notifies the writing loop to stop after the context has been passed. -func (cw *codeOceanOutputWriter) stopWhenContextDone(ctx context.Context) { - <-ctx.Done() - if !cw.stopped { - cw.queue <- &writingLoopMessage{done: true} - } -} - // send forwards the passed dto.WebSocketMessage to the writing loop. func (cw *codeOceanOutputWriter) send(message *dto.WebSocketMessage) { - if cw.stopped { + select { + case <-cw.ctx.Done(): return + default: + done := message.Type == dto.WebSocketExit || + message.Type == dto.WebSocketMetaTimeout || + message.Type == dto.WebSocketOutputError + cw.queue <- &writingLoopMessage{done: done, data: message} } - done := message.Type == dto.WebSocketExit || - message.Type == dto.WebSocketMetaTimeout || - message.Type == dto.WebSocketOutputError - cw.queue <- &writingLoopMessage{done: done, data: message} } // startWritingLoop enables the writing loop. // This is the central and only place where written changes to the WebSocket connection should be done. // It synchronizes the messages to provide state checks of the WebSocket connection. -func (cw *codeOceanOutputWriter) startWritingLoop() { - for { - message := <-cw.queue - done := sendMessage(cw.connection, message.data) - if done || message.done { - break +func (cw *codeOceanOutputWriter) startWritingLoop(writingLoopDone context.CancelFunc) { + defer func() { + message := websocket.FormatCloseMessage(websocket.CloseNormalClosure, "") + err := cw.connection.WriteMessage(websocket.CloseMessage, message) + err2 := cw.connection.Close() + if err != nil || err2 != nil { + log.WithContext(cw.ctx).WithError(err).WithField("err2", err2).Warn("Error during websocket close") + } + }() + + for { + select { + case <-cw.ctx.Done(): + return + case message := <-cw.queue: + done := sendMessage(cw.connection, message.data, cw.ctx) + if done || message.done { + writingLoopDone() + return + } } - } - cw.stopped = true - message := websocket.FormatCloseMessage(websocket.CloseNormalClosure, "") - err := cw.connection.WriteMessage(websocket.CloseMessage, message) - err2 := cw.connection.Close() - if err != nil || err2 != nil { - log.WithError(err).WithField("err2", err2).Warn("Error during websocket close") } } // sendMessage is a helper function for the writing loop. It must not be called from somewhere else! -func sendMessage(connection Connection, message *dto.WebSocketMessage) (done bool) { +func sendMessage(connection Connection, message *dto.WebSocketMessage, ctx context.Context) (done bool) { if message == nil { return false } encodedMessage, err := json.Marshal(message) if err != nil { - log.WithField("message", message).WithError(err).Warn("Marshal error") + log.WithContext(ctx).WithField("message", message).WithError(err).Warn("Marshal error") return false } - log.WithField("message", message).Trace("Sending message to client") + log.WithContext(ctx).WithField("message", message).Trace("Sending message to client") err = connection.WriteMessage(websocket.TextMessage, encodedMessage) if err != nil { errorMessage := "Error writing the message" - log.WithField("message", message).WithError(err).Warn(errorMessage) + log.WithContext(ctx).WithField("message", message).WithError(err).Warn(errorMessage) return true } diff --git a/internal/api/ws/codeocean_writer_test.go b/internal/api/ws/codeocean_writer_test.go index 9d0fc94..675191e 100644 --- a/internal/api/ws/codeocean_writer_test.go +++ b/internal/api/ws/codeocean_writer_test.go @@ -16,7 +16,7 @@ func TestRawToCodeOceanWriter(t *testing.T) { connectionMock, message := buildConnectionMock(t) proxyCtx, cancel := context.WithCancel(context.Background()) defer cancel() - output := NewCodeOceanOutputWriter(connectionMock, proxyCtx) + output := NewCodeOceanOutputWriter(connectionMock, proxyCtx, cancel) <-message // start message t.Run("StdOut", func(t *testing.T) { @@ -69,10 +69,10 @@ func TestCodeOceanOutputWriter_SendExitInfo(t *testing.T) { connectionMock, message := buildConnectionMock(t) proxyCtx, cancel := context.WithCancel(context.Background()) defer cancel() - output := NewCodeOceanOutputWriter(connectionMock, proxyCtx) + output := NewCodeOceanOutputWriter(connectionMock, proxyCtx, cancel) <-message // start message - output.SendExitInfo(test.info) + output.Close(test.info) expected, err := json.Marshal(test.message) require.NoError(t, err)