diff --git a/internal/api/websocket.go b/internal/api/websocket.go index 777a9a0..565dda8 100644 --- a/internal/api/websocket.go +++ b/internal/api/websocket.go @@ -30,7 +30,7 @@ type webSocketConnection interface { type WebSocketReader interface { io.Reader io.Writer - startReadInputLoop() context.CancelFunc + startReadInputLoop() } // codeOceanToRawReader is an io.Reader implementation that provides the content of the WebSocket connection @@ -38,6 +38,9 @@ type WebSocketReader interface { type codeOceanToRawReader struct { connection webSocketConnection + // ctx is used to cancel the reading routine. + ctx context.Context + // A buffered channel of bytes is used to store data coming from CodeOcean via WebSocket // and retrieve it when Read(..) is called. Since channels are thread-safe, we use one here // instead of bytes.Buffer. @@ -48,9 +51,10 @@ type codeOceanToRawReader struct { priorityBuffer chan byte } -func newCodeOceanToRawReader(connection webSocketConnection) *codeOceanToRawReader { +func newCodeOceanToRawReader(connection webSocketConnection, ctx context.Context) *codeOceanToRawReader { return &codeOceanToRawReader{ connection: connection, + ctx: ctx, buffer: make(chan byte, CodeOceanToRawReaderBufferSize), priorityBuffer: make(chan byte, CodeOceanToRawReaderBufferSize), } @@ -59,12 +63,14 @@ func newCodeOceanToRawReader(connection webSocketConnection) *codeOceanToRawRead // readInputLoop reads from the WebSocket connection and buffers the user's input. // This is necessary because input must be read for the connection to handle special messages like close and call the // CloseHandler. -func (cr *codeOceanToRawReader) readInputLoop(ctx context.Context) { +func (cr *codeOceanToRawReader) readInputLoop() { readMessage := make(chan bool) - readingContext, cancel := context.WithCancel(ctx) - defer cancel() + loopContext, cancelInputLoop := context.WithCancel(cr.ctx) + defer cancelInputLoop() + readingContext, cancelNextMessage := context.WithCancel(loopContext) + defer cancelNextMessage() - for ctx.Err() == nil { + for loopContext.Err() == nil { var messageType int var reader io.Reader var err error @@ -77,12 +83,12 @@ func (cr *codeOceanToRawReader) readInputLoop(ctx context.Context) { } }() select { - case <-ctx.Done(): + case <-loopContext.Done(): return case <-readMessage: } - if handleInput(messageType, reader, err, cr.buffer, ctx) { + if handleInput(messageType, reader, err, cr.buffer, loopContext) { return } } @@ -118,12 +124,9 @@ func handleInput(messageType int, reader io.Reader, err error, buffer chan byte, return false } -// startReadInputLoop start the read input loop asynchronously and returns a context.CancelFunc which can be used -// to cancel the read input loop. -func (cr *codeOceanToRawReader) startReadInputLoop() context.CancelFunc { - ctx, cancel := context.WithCancel(context.Background()) - go cr.readInputLoop(ctx) - return cancel +// startReadInputLoop start the read input loop asynchronously. +func (cr *codeOceanToRawReader) startReadInputLoop() { + go cr.readInputLoop() } // Read implements the io.Reader interface. @@ -132,8 +135,11 @@ func (cr *codeOceanToRawReader) Read(p []byte) (int, error) { if len(p) == 0 { return 0, nil } + // Ensure to not return until at least one byte has been read to avoid busy waiting. select { + case <-cr.ctx.Done(): + return 0, io.EOF case p[0] = <-cr.priorityBuffer: case p[0] = <-cr.buffer: } @@ -168,13 +174,12 @@ func (cr *codeOceanToRawReader) Write(p []byte) (n int, err error) { type rawToCodeOceanWriter struct { proxy *webSocketProxy outputType dto.WebSocketMessageType - stopped bool } // Write implements the io.Writer interface. // The passed data is forwarded to the WebSocket to CodeOcean. func (rc *rawToCodeOceanWriter) Write(p []byte) (int, error) { - if rc.stopped { + if rc.proxy.webSocketCtx.Err() != nil { return 0, nil } err := rc.proxy.sendToClient(dto.WebSocketMessage{Type: rc.outputType, Data: string(p)}) @@ -183,13 +188,13 @@ func (rc *rawToCodeOceanWriter) Write(p []byte) (int, error) { // webSocketProxy is an encapsulation of logic for forwarding between Runners and CodeOcean. type webSocketProxy struct { - userExit chan bool - connection webSocketConnection - connectionMu sync.Mutex - Stdin WebSocketReader - Stdout io.Writer - Stderr io.Writer - cancelWebSocketWrite func() + webSocketCtx context.Context + cancelWebSocket context.CancelFunc + connection webSocketConnection + connectionMu sync.Mutex + Stdin WebSocketReader + Stdout io.Writer + Stderr io.Writer } // upgradeConnection upgrades a connection to a websocket and returns a webSocketProxy for this connection. @@ -205,24 +210,21 @@ func upgradeConnection(writer http.ResponseWriter, request *http.Request) (webSo // newWebSocketProxy returns a initiated and started webSocketProxy. // As this proxy is already started, a start message is send to the client. -func newWebSocketProxy(connection webSocketConnection) (*webSocketProxy, error) { - stdin := newCodeOceanToRawReader(connection) +func newWebSocketProxy(connection webSocketConnection, ctx context.Context) (*webSocketProxy, error) { + stdin := newCodeOceanToRawReader(connection, ctx) + inputCtx, inputCancel := context.WithCancel(ctx) proxy := &webSocketProxy{ - connection: connection, - Stdin: stdin, - userExit: make(chan bool), + connection: connection, + Stdin: stdin, + webSocketCtx: inputCtx, + cancelWebSocket: inputCancel, } - stdOut := &rawToCodeOceanWriter{proxy: proxy, outputType: dto.WebSocketOutputStdout} - stdErr := &rawToCodeOceanWriter{proxy: proxy, outputType: dto.WebSocketOutputStderr} - proxy.cancelWebSocketWrite = func() { - stdOut.stopped = true - stdErr.stopped = true - } - proxy.Stdout = stdOut - proxy.Stderr = stdErr + proxy.Stdout = &rawToCodeOceanWriter{proxy: proxy, outputType: dto.WebSocketOutputStdout} + proxy.Stderr = &rawToCodeOceanWriter{proxy: proxy, outputType: dto.WebSocketOutputStderr} err := proxy.sendToClient(dto.WebSocketMessage{Type: dto.WebSocketMetaStart}) if err != nil { + inputCancel() return nil, err } @@ -230,7 +232,7 @@ func newWebSocketProxy(connection webSocketConnection) (*webSocketProxy, error) connection.SetCloseHandler(func(code int, text string) error { //nolint:errcheck // The default close handler always returns nil, so the error can be safely ignored. _ = closeHandler(code, text) - close(proxy.userExit) + inputCancel() return nil }) return proxy, nil @@ -239,19 +241,19 @@ func newWebSocketProxy(connection webSocketConnection) (*webSocketProxy, error) // waitForExit waits for an exit of either the runner (when the command terminates) or the client closing the WebSocket // and handles WebSocket exit messages. func (wp *webSocketProxy) waitForExit(exit <-chan runner.ExitInfo, cancelExecution context.CancelFunc) { - cancelInputLoop := wp.Stdin.startReadInputLoop() + wp.Stdin.startReadInputLoop() + var exitInfo runner.ExitInfo select { - case exitInfo = <-exit: - cancelInputLoop() - wp.cancelWebSocketWrite() - log.Info("Execution returned") - case <-wp.userExit: - cancelInputLoop() - wp.cancelWebSocketWrite() - cancelExecution() + case <-wp.webSocketCtx.Done(): log.Info("Client closed the connection") + cancelExecution() + <-exit // /internal/runner/runner.go handleExitOrContextDone does not require client connection anymore. + <-exit // The goroutine closes this channel indicating that it does not use the connection to the executor anymore. return + case exitInfo = <-exit: + log.Info("Execution returned") + wp.cancelWebSocket() } if errors.Is(exitInfo.Err, context.DeadlineExceeded) || errors.Is(exitInfo.Err, runner.ErrorRunnerInactivityTimeout) { @@ -339,7 +341,9 @@ func (r *RunnerController) connectToRunner(writer http.ResponseWriter, request * writeInternalServerError(writer, err, dto.ErrorUnknown) return } - proxy, err := newWebSocketProxy(connection) + proxyCtx, cancelProxy := context.WithCancel(context.Background()) + defer cancelProxy() + proxy, err := newWebSocketProxy(connection, proxyCtx) if err != nil { return } diff --git a/internal/api/websocket_test.go b/internal/api/websocket_test.go index 4b28cfd..5ce6d74 100644 --- a/internal/api/websocket_test.go +++ b/internal/api/websocket_test.go @@ -293,7 +293,9 @@ func TestRawToCodeOceanWriter(t *testing.T) { connectionMock.On("CloseHandler").Return(nil) connectionMock.On("SetCloseHandler", mock.Anything).Return() - proxy, err := newWebSocketProxy(connectionMock) + proxyCtx, cancel := context.WithCancel(context.Background()) + defer cancel() + proxy, err := newWebSocketProxy(connectionMock, proxyCtx) require.NoError(t, err) writer := &rawToCodeOceanWriter{ proxy: proxy, @@ -312,7 +314,9 @@ func TestRawToCodeOceanWriter(t *testing.T) { } func TestCodeOceanToRawReaderReturnsOnlyAfterOneByteWasRead(t *testing.T) { - reader := newCodeOceanToRawReader(nil) + readingCtx, cancel := context.WithCancel(context.Background()) + defer cancel() + reader := newCodeOceanToRawReader(nil, readingCtx) read := make(chan bool) go func() { @@ -345,9 +349,10 @@ func TestCodeOceanToRawReaderReturnsOnlyAfterOneByteWasReadFromConnection(t *tes call.Return(websocket.TextMessage, <-messages, nil) }) - reader := newCodeOceanToRawReader(connection) - cancel := reader.startReadInputLoop() + readingCtx, cancel := context.WithCancel(context.Background()) defer cancel() + reader := newCodeOceanToRawReader(connection, readingCtx) + reader.startReadInputLoop() read := make(chan bool) //nolint:makezero // this is required here to make the Read call blocking diff --git a/internal/runner/runner.go b/internal/runner/runner.go index 0b99e1b..7681c30 100644 --- a/internal/runner/runner.go +++ b/internal/runner/runner.go @@ -208,18 +208,19 @@ func (r *NomadJob) handleExitOrContextDone(ctx context.Context, cancelExecute co exitInternal <-chan ExitInfo, exit chan<- ExitInfo, stdin io.ReadWriter, ) { defer cancelExecute() + defer close(exit) // When this function has finished the connection to the executor is closed. + select { case exitInfo := <-exitInternal: exit <- exitInfo - close(exit) return case <-ctx.Done(): - // From this time on until the WebSocket connection to the client is closed in /internal/api/websocket.go - // waitForExit, output can still be forwarded to the client. We accept this race condition because adding - // a locking mechanism would complicate the interfaces used (currently io.Writer). - exit <- ExitInfo{255, ctx.Err()} - close(exit) } + + // From this time on the WebSocket connection to the client is closed in /internal/api/websocket.go + // waitForExit. Input can still be sent to the executor. + exit <- ExitInfo{255, ctx.Err()} + // This injects the SIGQUIT character into the stdin. This character is parsed by the tty line discipline // (tty has to be true) and converted to a SIGQUIT signal sent to the foreground process attached to the tty. // By default, SIGQUIT causes the process to terminate and produces a core dump. Processes can catch this signal