From 203d5a3a4f869484e842d5d408adbb3a0815854c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Maximilian=20Pa=C3=9F?= <22845248+mpass99@users.noreply.github.com> Date: Sun, 26 Jun 2022 20:19:23 +0200 Subject: [PATCH] #155 refactor and synchronise writing to CodeOcean. (#174) * #155 refactor and synchronise writing to CodeOcean. * Reduce complexity of input parsing. * Update typo in internal/api/ws/codeocean_writer.go Co-authored-by: Sebastian Serth --- internal/api/websocket.go | 320 ++---------------- internal/api/websocket_test.go | 101 +----- internal/api/ws/codeocean_reader.go | 177 ++++++++++ internal/api/ws/codeocean_reader_test.go | 75 ++++ internal/api/ws/codeocean_writer.go | 155 +++++++++ internal/api/ws/codeocean_writer_test.go | 100 ++++++ internal/api/ws/connection.go | 14 + .../connection_mock.go} | 37 +- 8 files changed, 569 insertions(+), 410 deletions(-) create mode 100644 internal/api/ws/codeocean_reader.go create mode 100644 internal/api/ws/codeocean_reader_test.go create mode 100644 internal/api/ws/codeocean_writer.go create mode 100644 internal/api/ws/codeocean_writer_test.go create mode 100644 internal/api/ws/connection.go rename internal/api/{websocket_connection_mock.go => ws/connection_mock.go} (58%) diff --git a/internal/api/websocket.go b/internal/api/websocket.go index 1eb93db..1370c00 100644 --- a/internal/api/websocket.go +++ b/internal/api/websocket.go @@ -2,217 +2,29 @@ package api import ( "context" - "encoding/json" "errors" "fmt" "github.com/gorilla/websocket" + "github.com/openHPI/poseidon/internal/api/ws" "github.com/openHPI/poseidon/internal/runner" "github.com/openHPI/poseidon/pkg/dto" "github.com/openHPI/poseidon/pkg/logging" "github.com/openHPI/poseidon/pkg/monitoring" - "io" "net/http" - "sync" ) -const CodeOceanToRawReaderBufferSize = 1024 - var ErrUnknownExecutionID = errors.New("execution id unknown") -type webSocketConnection interface { - WriteMessage(messageType int, data []byte) error - Close() error - NextReader() (messageType int, r io.Reader, err error) - CloseHandler() func(code int, text string) error - SetCloseHandler(handler func(code int, text string) error) -} - -// WebSocketReader is an interface that is intended for providing abstraction around reading from a WebSocket. -// Besides io.Reader, it also implements io.Writer. The Write method is used to inject data into the WebSocket stream. -type WebSocketReader interface { - io.Reader - io.Writer - startReadInputLoop() - stopReadInputLoop() -} - -// codeOceanToRawReader is an io.Reader implementation that provides the content of the WebSocket connection -// to CodeOcean. You have to start the Reader by calling readInputLoop. After that you can use the Read function. -type codeOceanToRawReader struct { - connection webSocketConnection - - // readCtx is the context in that messages from CodeOcean are read. - readCtx context.Context - cancelReadCtx context.CancelFunc - // executorCtx is the context in that messages are forwarded to the executor. - executorCtx context.Context - - // A buffered channel of bytes is used to store data coming from CodeOcean via WebSocket - // and retrieve it when Read(..) is called. Since channels are thread-safe, we use one here - // instead of bytes.Buffer. - buffer chan byte - // The priorityBuffer is a buffer for injecting data into stdin of the execution from Poseidon, - // for example the character that causes the tty to generate a SIGQUIT signal. - // It is always read before the regular buffer. - priorityBuffer chan byte -} - -func newCodeOceanToRawReader(connection webSocketConnection, wsCtx, executorCtx context.Context) *codeOceanToRawReader { - return &codeOceanToRawReader{ - connection: connection, - readCtx: wsCtx, // This context may be canceled before the executorCtx. - cancelReadCtx: func() {}, - executorCtx: executorCtx, - buffer: make(chan byte, CodeOceanToRawReaderBufferSize), - priorityBuffer: make(chan byte, CodeOceanToRawReaderBufferSize), - } -} - -// readInputLoop reads from the WebSocket connection and buffers the user's input. -// This is necessary because input must be read for the connection to handle special messages like close and call the -// CloseHandler. -func (cr *codeOceanToRawReader) readInputLoop(ctx context.Context) { - readMessage := make(chan bool) - loopContext, cancelInputLoop := context.WithCancel(ctx) - defer cancelInputLoop() - readingContext, cancelNextMessage := context.WithCancel(loopContext) - defer cancelNextMessage() - - for loopContext.Err() == nil { - var messageType int - var reader io.Reader - var err error - - go func() { - messageType, reader, err = cr.connection.NextReader() - select { - case <-readingContext.Done(): - case readMessage <- true: - } - }() - select { - case <-loopContext.Done(): - return - case <-readMessage: - } - - if handleInput(messageType, reader, err, cr.buffer, loopContext) { - return - } - } -} - -// handleInput receives a new message from the client and may forward it to the executor. -func handleInput(messageType int, reader io.Reader, err error, buffer chan byte, ctx context.Context) (done bool) { - if err != nil && websocket.IsCloseError(err, websocket.CloseNormalClosure) { - log.Debug("ReadInputLoop: The client closed the connection!") - // The close handler will do something soon. - return true - } else if err != nil { - log.WithError(err).Warn("Error reading client message") - return true - } - if messageType != websocket.TextMessage { - log.WithField("messageType", messageType).Warn("Received message of wrong type") - return true - } - - message, err := io.ReadAll(reader) - if err != nil { - log.WithError(err).Warn("error while reading WebSocket message") - return true - } - - log.WithField("message", string(message)).Trace("Received message from client") - for _, character := range message { - select { - case <-ctx.Done(): - return true - case buffer <- character: - } - } - return false -} - -// startReadInputLoop start the read input loop asynchronously. -func (cr *codeOceanToRawReader) startReadInputLoop() { - ctx, cancel := context.WithCancel(cr.readCtx) - cr.cancelReadCtx = cancel - go cr.readInputLoop(ctx) -} - -// startReadInputLoop stops the asynchronous read input loop. -func (cr *codeOceanToRawReader) stopReadInputLoop() { - cr.cancelReadCtx() -} - -// Read implements the io.Reader interface. -// It returns bytes from the buffer or priorityBuffer. -func (cr *codeOceanToRawReader) Read(p []byte) (int, error) { - if len(p) == 0 { - return 0, nil - } - - // Ensure to not return until at least one byte has been read to avoid busy waiting. - select { - case <-cr.executorCtx.Done(): - return 0, io.EOF - case p[0] = <-cr.priorityBuffer: - case p[0] = <-cr.buffer: - } - var n int - for n = 1; n < len(p); n++ { - select { - case p[n] = <-cr.priorityBuffer: - case p[n] = <-cr.buffer: - default: - return n, nil - } - } - return n, nil -} - -// Write implements the io.Writer interface. -// Data written to a codeOceanToRawReader using this method is returned by Read before other data from the WebSocket. -func (cr *codeOceanToRawReader) Write(p []byte) (n int, err error) { - var c byte - for n, c = range p { - select { - case cr.priorityBuffer <- c: - default: - break - } - } - return n, nil -} - -// rawToCodeOceanWriter is an io.Writer implementation that, when written to, wraps the written data in the appropriate -// json structure and sends it to the CodeOcean via WebSocket. -type rawToCodeOceanWriter struct { - proxy *webSocketProxy - outputType dto.WebSocketMessageType -} - -// Write implements the io.Writer interface. -// The passed data is forwarded to the WebSocket to CodeOcean. -func (rc *rawToCodeOceanWriter) Write(p []byte) (int, error) { - err := rc.proxy.sendToClient(dto.WebSocketMessage{Type: rc.outputType, Data: string(p)}) - return len(p), err -} - // webSocketProxy is an encapsulation of logic for forwarding between Runners and CodeOcean. type webSocketProxy struct { - webSocketCtx context.Context - cancelWebSocket context.CancelFunc - connection webSocketConnection - connectionMu sync.Mutex - Stdin WebSocketReader - Stdout io.Writer - Stderr io.Writer + ctx context.Context + cancel context.CancelFunc + Input ws.WebSocketReader + Output ws.WebSocketWriter } // upgradeConnection upgrades a connection to a websocket and returns a webSocketProxy for this connection. -func upgradeConnection(writer http.ResponseWriter, request *http.Request) (webSocketConnection, error) { +func upgradeConnection(writer http.ResponseWriter, request *http.Request) (ws.Connection, error) { connUpgrader := websocket.Upgrader{} connection, err := connUpgrader.Upgrade(writer, request, nil) if err != nil { @@ -224,31 +36,18 @@ func upgradeConnection(writer http.ResponseWriter, request *http.Request) (webSo // newWebSocketProxy returns an initiated and started webSocketProxy. // As this proxy is already started, a start message is send to the client. -func newWebSocketProxy(connection webSocketConnection, proxyCtx context.Context) (*webSocketProxy, error) { +func newWebSocketProxy(connection ws.Connection, proxyCtx context.Context) (*webSocketProxy, error) { wsCtx, cancelWsCommunication := context.WithCancel(proxyCtx) - stdin := newCodeOceanToRawReader(connection, wsCtx, proxyCtx) proxy := &webSocketProxy{ - connection: connection, - Stdin: stdin, - webSocketCtx: wsCtx, - cancelWebSocket: cancelWsCommunication, - } - proxy.Stdout = &rawToCodeOceanWriter{proxy: proxy, outputType: dto.WebSocketOutputStdout} - proxy.Stderr = &rawToCodeOceanWriter{proxy: proxy, outputType: dto.WebSocketOutputStderr} - - err := proxy.sendToClient(dto.WebSocketMessage{Type: dto.WebSocketMetaStart}) - if err != nil { - cancelWsCommunication() - return nil, err + ctx: wsCtx, + cancel: cancelWsCommunication, + Input: ws.NewCodeOceanToRawReader(connection, wsCtx, proxyCtx), + Output: ws.NewCodeOceanOutputWriter(connection, wsCtx), } - closeHandler := connection.CloseHandler() connection.SetCloseHandler(func(code int, text string) error { - log.Info("Before closing the connection via Handler") + log.WithField("code", code).WithField("text", text).Debug("The client closed the connection.") cancelWsCommunication() - //nolint:errcheck // The default close handler always returns nil, so the error can be safely ignored. - _ = closeHandler(code, text) - log.Info("After closing the connection via Handler") return nil }) return proxy, nil @@ -257,99 +56,21 @@ func newWebSocketProxy(connection webSocketConnection, proxyCtx context.Context) // waitForExit waits for an exit of either the runner (when the command terminates) or the client closing the WebSocket // and handles WebSocket exit messages. func (wp *webSocketProxy) waitForExit(exit <-chan runner.ExitInfo, cancelExecution context.CancelFunc) { - wp.Stdin.startReadInputLoop() + wp.Input.Start() var exitInfo runner.ExitInfo select { - case <-wp.webSocketCtx.Done(): + case <-wp.ctx.Done(): log.Info("Client closed the connection") + wp.Input.Stop() cancelExecution() <-exit // /internal/runner/runner.go handleExitOrContextDone does not require client connection anymore. <-exit // The goroutine closes this channel indicating that it does not use the connection to the executor anymore. - return case exitInfo = <-exit: log.Info("Execution returned") - wp.Stdin.stopReadInputLoop() // Here we stop reading from the client - defer wp.cancelWebSocket() // At the end of this method we stop writing to the client + wp.Input.Stop() + wp.Output.SendExitInfo(&exitInfo) } - - if errors.Is(exitInfo.Err, context.DeadlineExceeded) || errors.Is(exitInfo.Err, runner.ErrorRunnerInactivityTimeout) { - err := wp.sendToClient(dto.WebSocketMessage{Type: dto.WebSocketMetaTimeout}) - if err == nil { - wp.closeNormal() - } else { - log.WithError(err).Warn("Failed to send timeout message to client") - } - return - } else if exitInfo.Err != nil { - errorMessage := "Error executing the request" - log.WithError(exitInfo.Err).Warn(errorMessage) - err := wp.sendToClient(dto.WebSocketMessage{Type: dto.WebSocketOutputError, Data: errorMessage}) - if err == nil { - wp.closeNormal() - } else { - log.WithError(err).Warn("Failed to send output error message to client") - } - return - } - log.WithField("exit_code", exitInfo.Code).Debug() - - err := wp.sendToClient(dto.WebSocketMessage{ - Type: dto.WebSocketExit, - ExitCode: exitInfo.Code, - }) - if err != nil { - log.WithError(err).Warn("Error sending exit message") - return - } - wp.closeNormal() -} - -func (wp *webSocketProxy) sendToClient(message dto.WebSocketMessage) error { - encodedMessage, err := json.Marshal(message) - if err != nil { - log.WithField("message", message).WithError(err).Warn("Marshal error") - wp.closeWithError("Error creating message") - return fmt.Errorf("error marshaling WebSocket message: %w", err) - } - log.WithField("message", message).Trace("Sending message to client") - select { - case <-wp.webSocketCtx.Done(): - default: - log.Info("Passed WriteToCodeOceanCheck") - err = wp.writeMessage(websocket.TextMessage, encodedMessage) - if err != nil { - errorMessage := "Error writing the message" - log.WithField("message", message).WithError(err).Warn(errorMessage) - wp.closeWithError(errorMessage) - return fmt.Errorf("error writing WebSocket message: %w", err) - } - } - return nil -} - -func (wp *webSocketProxy) closeWithError(message string) { - wp.close(websocket.FormatCloseMessage(websocket.CloseInternalServerErr, message)) -} - -func (wp *webSocketProxy) closeNormal() { - wp.close(websocket.FormatCloseMessage(websocket.CloseNormalClosure, "")) -} - -func (wp *webSocketProxy) close(message []byte) { - log.Info("Before closing the connection manually") - err := wp.writeMessage(websocket.CloseMessage, message) - _ = wp.connection.Close() - if err != nil { - log.WithError(err).Warn("Error during websocket close") - } - log.Info("After closing the connection manually") -} - -func (wp *webSocketProxy) writeMessage(messageType int, data []byte) error { - wp.connectionMu.Lock() - defer wp.connectionMu.Unlock() - return wp.connection.WriteMessage(messageType, data) //nolint:wrapcheck // Wrap the original WriteMessage in a mutex. } // connectToRunner is the endpoint for websocket connections. @@ -378,10 +99,11 @@ func (r *RunnerController) connectToRunner(writer http.ResponseWriter, request * log.WithField("runnerId", targetRunner.ID()). WithField("executionID", logging.RemoveNewlineSymbol(executionID)). Info("Running execution") - exit, cancel, err := targetRunner.ExecuteInteractively(executionID, proxy.Stdin, proxy.Stdout, proxy.Stderr) + exit, cancel, err := targetRunner.ExecuteInteractively(executionID, + proxy.Input, proxy.Output.StdOut(), proxy.Output.StdErr()) if err != nil { - proxy.closeWithError(fmt.Sprintf("execution failed with: %v", err)) - return + log.WithError(err).Warn("Cannot execute request.") + return // The proxy is stopped by the defered cancel. } proxy.waitForExit(exit, cancel) diff --git a/internal/api/websocket_test.go b/internal/api/websocket_test.go index 894a77a..186accf 100644 --- a/internal/api/websocket_test.go +++ b/internal/api/websocket_test.go @@ -4,7 +4,6 @@ import ( "bufio" "context" "crypto/tls" - "encoding/json" "fmt" "github.com/gorilla/mux" "github.com/gorilla/websocket" @@ -24,7 +23,6 @@ import ( "net/http" "net/http/httptest" "net/url" - "strings" "testing" "time" ) @@ -100,7 +98,7 @@ func (s *WebSocketTestSuite) TestWebsocketConnection() { }) s.Run("Executes the request in the runner", func() { - <-time.After(100 * time.Millisecond) + <-time.After(tests.ShortTimeout) s.apiMock.AssertCalled(s.T(), "ExecuteCommand", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything) }) @@ -281,103 +279,6 @@ func TestWebsocketTLS(t *testing.T) { assert.True(t, websocket.IsCloseError(err, websocket.CloseNormalClosure)) } -func TestRawToCodeOceanWriter(t *testing.T) { - testMessage := "test" - var message []byte - - connectionMock := &webSocketConnectionMock{} - connectionMock.On("WriteMessage", mock.AnythingOfType("int"), mock.AnythingOfType("[]uint8")). - Run(func(args mock.Arguments) { - var ok bool - message, ok = args.Get(1).([]byte) - require.True(t, ok) - }). - Return(nil) - connectionMock.On("CloseHandler").Return(nil) - connectionMock.On("SetCloseHandler", mock.Anything).Return() - - proxyCtx, cancel := context.WithCancel(context.Background()) - defer cancel() - proxy, err := newWebSocketProxy(connectionMock, proxyCtx) - require.NoError(t, err) - writer := &rawToCodeOceanWriter{ - proxy: proxy, - outputType: dto.WebSocketOutputStdout, - } - - _, err = writer.Write([]byte(testMessage)) - require.NoError(t, err) - - expected, err := json.Marshal(struct { - Type string `json:"type"` - Data string `json:"data"` - }{string(dto.WebSocketOutputStdout), testMessage}) - require.NoError(t, err) - assert.Equal(t, expected, message) -} - -func TestCodeOceanToRawReaderReturnsOnlyAfterOneByteWasRead(t *testing.T) { - readingCtx, cancel := context.WithCancel(context.Background()) - forwardingCtx := readingCtx - defer cancel() - reader := newCodeOceanToRawReader(nil, readingCtx, forwardingCtx) - - read := make(chan bool) - go func() { - //nolint:makezero // we can't make zero initial length here as the reader otherwise doesn't block - p := make([]byte, 10) - _, err := reader.Read(p) - require.NoError(t, err) - read <- true - }() - - t.Run("Does not return immediately when there is no data", func(t *testing.T) { - assert.False(t, tests.ChannelReceivesSomething(read, tests.ShortTimeout)) - }) - - t.Run("Returns when there is data available", func(t *testing.T) { - reader.buffer <- byte(42) - assert.True(t, tests.ChannelReceivesSomething(read, tests.ShortTimeout)) - }) -} - -func TestCodeOceanToRawReaderReturnsOnlyAfterOneByteWasReadFromConnection(t *testing.T) { - messages := make(chan io.Reader) - - connection := &webSocketConnectionMock{} - connection.On("WriteMessage", mock.AnythingOfType("int"), mock.AnythingOfType("[]uint8")).Return(nil) - connection.On("CloseHandler").Return(nil) - connection.On("SetCloseHandler", mock.Anything).Return() - call := connection.On("NextReader") - call.Run(func(_ mock.Arguments) { - call.Return(websocket.TextMessage, <-messages, nil) - }) - - readingCtx, cancel := context.WithCancel(context.Background()) - forwardingCtx := readingCtx - defer cancel() - reader := newCodeOceanToRawReader(connection, readingCtx, forwardingCtx) - reader.startReadInputLoop() - - read := make(chan bool) - //nolint:makezero // this is required here to make the Read call blocking - message := make([]byte, 10) - go func() { - _, err := reader.Read(message) - require.NoError(t, err) - read <- true - }() - - t.Run("Does not return immediately when there is no data", func(t *testing.T) { - assert.False(t, tests.ChannelReceivesSomething(read, tests.ShortTimeout)) - }) - - t.Run("Returns when there is data available", func(t *testing.T) { - messages <- strings.NewReader("Hello") - assert.True(t, tests.ChannelReceivesSomething(read, tests.ShortTimeout)) - }) -} - func TestWebSocketProxyStopsReadingTheWebSocketAfterClosingIt(t *testing.T) { apiMock := &nomad.ExecutorAPIMock{} executionID := tests.DefaultExecutionID diff --git a/internal/api/ws/codeocean_reader.go b/internal/api/ws/codeocean_reader.go new file mode 100644 index 0000000..8d31541 --- /dev/null +++ b/internal/api/ws/codeocean_reader.go @@ -0,0 +1,177 @@ +package ws + +import ( + "context" + "github.com/gorilla/websocket" + "github.com/openHPI/poseidon/pkg/logging" + "io" +) + +const CodeOceanToRawReaderBufferSize = 1024 + +var log = logging.GetLogger("ws") + +// WebSocketReader is an interface that is intended for providing abstraction around reading from a WebSocket. +// Besides, io.Reader, it also implements io.Writer. The Write method is used to inject data into the WebSocket stream. +type WebSocketReader interface { + io.Reader + io.Writer + Start() + Stop() +} + +// codeOceanToRawReader is an io.Reader implementation that provides the content of the WebSocket connection +// to CodeOcean. You have to start the Reader by calling readInputLoop. After that you can use the Read function. +type codeOceanToRawReader struct { + connection Connection + + // readCtx is the context in that messages from CodeOcean are read. + readCtx context.Context + cancelReadCtx context.CancelFunc + // executorCtx is the context in that messages are forwarded to the executor. + executorCtx context.Context + + // A buffered channel of bytes is used to store data coming from CodeOcean via WebSocket + // and retrieve it when Read(..) is called. Since channels are thread-safe, we use one here + // instead of bytes.Buffer. + buffer chan byte + // The priorityBuffer is a buffer for injecting data into stdin of the execution from Poseidon, + // for example the character that causes the tty to generate a SIGQUIT signal. + // It is always read before the regular buffer. + priorityBuffer chan byte +} + +func NewCodeOceanToRawReader(connection Connection, wsCtx, executorCtx context.Context) *codeOceanToRawReader { + return &codeOceanToRawReader{ + connection: connection, + readCtx: wsCtx, // This context may be canceled before the executorCtx. + cancelReadCtx: func() {}, + executorCtx: executorCtx, + buffer: make(chan byte, CodeOceanToRawReaderBufferSize), + priorityBuffer: make(chan byte, CodeOceanToRawReaderBufferSize), + } +} + +// readInputLoop reads from the WebSocket connection and buffers the user's input. +// This is necessary because input must be read for the connection to handle special messages like close and call the +// CloseHandler. +func (cr *codeOceanToRawReader) readInputLoop(ctx context.Context) { + readMessage := make(chan bool) + loopContext, cancelInputLoop := context.WithCancel(ctx) + defer cancelInputLoop() + readingContext, cancelNextMessage := context.WithCancel(loopContext) + defer cancelNextMessage() + + for loopContext.Err() == nil { + var messageType int + var reader io.Reader + var err error + + go func() { + messageType, reader, err = cr.connection.NextReader() + select { + case <-readingContext.Done(): + case readMessage <- true: + } + }() + select { + case <-loopContext.Done(): + return + case <-readMessage: + } + + if inputContainsError(messageType, err) { + return + } + if handleInput(reader, cr.buffer, loopContext) { + return + } + } +} + +// handleInput receives a new message from the client and may forward it to the executor. +func handleInput(reader io.Reader, buffer chan byte, ctx context.Context) (done bool) { + message, err := io.ReadAll(reader) + if err != nil { + log.WithError(err).Warn("error while reading WebSocket message") + return true + } + + log.WithField("message", string(message)).Trace("Received message from client") + for _, character := range message { + select { + case <-ctx.Done(): + return true + case buffer <- character: + } + } + return false +} + +func inputContainsError(messageType int, err error) (done bool) { + if err != nil && websocket.IsCloseError(err, websocket.CloseNormalClosure) { + log.Debug("ReadInputLoop: The client closed the connection!") + // The close handler will do something soon. + return true + } else if err != nil { + log.WithError(err).Warn("Error reading client message") + return true + } + if messageType != websocket.TextMessage { + log.WithField("messageType", messageType).Warn("Received message of wrong type") + return true + } + return false +} + +// Start starts the read input loop asynchronously. +func (cr *codeOceanToRawReader) Start() { + ctx, cancel := context.WithCancel(cr.readCtx) + cr.cancelReadCtx = cancel + go cr.readInputLoop(ctx) +} + +// Stop stops the asynchronous read input loop. +func (cr *codeOceanToRawReader) Stop() { + cr.cancelReadCtx() +} + +// Read implements the io.Reader interface. +// It returns bytes from the buffer or priorityBuffer. +func (cr *codeOceanToRawReader) Read(p []byte) (int, error) { + if len(p) == 0 { + return 0, nil + } + + // Ensure to not return until at least one byte has been read to avoid busy waiting. + select { + case <-cr.executorCtx.Done(): + return 0, io.EOF + case p[0] = <-cr.priorityBuffer: + case p[0] = <-cr.buffer: + } + var n int + for n = 1; n < len(p); n++ { + select { + case p[n] = <-cr.priorityBuffer: + case p[n] = <-cr.buffer: + default: + return n, nil + } + } + return n, nil +} + +// Write implements the io.Writer interface. +// Data written to a codeOceanToRawReader using this method is returned by Read before other data from the WebSocket. +func (cr *codeOceanToRawReader) Write(p []byte) (n int, err error) { + var c byte + for n, c = range p { + select { + case cr.priorityBuffer <- c: + default: + break + } + } + return n, nil +} diff --git a/internal/api/ws/codeocean_reader_test.go b/internal/api/ws/codeocean_reader_test.go new file mode 100644 index 0000000..639910e --- /dev/null +++ b/internal/api/ws/codeocean_reader_test.go @@ -0,0 +1,75 @@ +package ws + +import ( + "context" + "github.com/gorilla/websocket" + "github.com/openHPI/poseidon/tests" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" + "io" + "strings" + "testing" +) + +func TestCodeOceanToRawReaderReturnsOnlyAfterOneByteWasRead(t *testing.T) { + readingCtx, cancel := context.WithCancel(context.Background()) + forwardingCtx := readingCtx + defer cancel() + reader := NewCodeOceanToRawReader(nil, readingCtx, forwardingCtx) + + read := make(chan bool) + go func() { + //nolint:makezero // we can't make zero initial length here as the reader otherwise doesn't block + p := make([]byte, 10) + _, err := reader.Read(p) + require.NoError(t, err) + read <- true + }() + + t.Run("Does not return immediately when there is no data", func(t *testing.T) { + assert.False(t, tests.ChannelReceivesSomething(read, tests.ShortTimeout)) + }) + + t.Run("Returns when there is data available", func(t *testing.T) { + reader.buffer <- byte(42) + assert.True(t, tests.ChannelReceivesSomething(read, tests.ShortTimeout)) + }) +} + +func TestCodeOceanToRawReaderReturnsOnlyAfterOneByteWasReadFromConnection(t *testing.T) { + messages := make(chan io.Reader) + + connection := &ConnectionMock{} + connection.On("WriteMessage", mock.AnythingOfType("int"), mock.AnythingOfType("[]uint8")).Return(nil) + connection.On("CloseHandler").Return(nil) + connection.On("SetCloseHandler", mock.Anything).Return() + call := connection.On("NextReader") + call.Run(func(_ mock.Arguments) { + call.Return(websocket.TextMessage, <-messages, nil) + }) + + readingCtx, cancel := context.WithCancel(context.Background()) + forwardingCtx := readingCtx + defer cancel() + reader := NewCodeOceanToRawReader(connection, readingCtx, forwardingCtx) + reader.Start() + + read := make(chan bool) + //nolint:makezero // this is required here to make the Read call blocking + message := make([]byte, 10) + go func() { + _, err := reader.Read(message) + require.NoError(t, err) + read <- true + }() + + t.Run("Does not return immediately when there is no data", func(t *testing.T) { + assert.False(t, tests.ChannelReceivesSomething(read, tests.ShortTimeout)) + }) + + t.Run("Returns when there is data available", func(t *testing.T) { + messages <- strings.NewReader("Hello") + assert.True(t, tests.ChannelReceivesSomething(read, tests.ShortTimeout)) + }) +} diff --git a/internal/api/ws/codeocean_writer.go b/internal/api/ws/codeocean_writer.go new file mode 100644 index 0000000..94d4e8a --- /dev/null +++ b/internal/api/ws/codeocean_writer.go @@ -0,0 +1,155 @@ +package ws + +import ( + "context" + "encoding/json" + "errors" + "github.com/gorilla/websocket" + "github.com/openHPI/poseidon/internal/runner" + "github.com/openHPI/poseidon/pkg/dto" + "io" +) + +// CodeOceanOutputWriterBufferSize defines the number of messages. +const CodeOceanOutputWriterBufferSize = 64 + +// rawToCodeOceanWriter is a simple io.Writer implementation that just forwards the call to sendMessage. +type rawToCodeOceanWriter struct { + sendMessage func(string) +} + +// Write implements the io.Writer interface. +func (rc *rawToCodeOceanWriter) Write(p []byte) (int, error) { + rc.sendMessage(string(p)) + return len(p), nil +} + +// WebSocketWriter is an interface that defines which data is required and which information can be passed. +type WebSocketWriter interface { + StdOut() io.Writer + StdErr() io.Writer + SendExitInfo(info *runner.ExitInfo) +} + +// codeOceanOutputWriter is a concrete WebSocketWriter implementation. +// It forwards the data written to stdOut or stdErr (Nomad, AWS) to the WebSocket connection (CodeOcean). +type codeOceanOutputWriter struct { + connection Connection + stdOut io.Writer + stdErr io.Writer + queue chan *writingLoopMessage + stopped bool +} + +// writingLoopMessage is an internal data structure to notify the writing loop when it should stop. +type writingLoopMessage struct { + done bool + data *dto.WebSocketMessage +} + +// NewCodeOceanOutputWriter provies an codeOceanOutputWriter for the time the context ctx is active. +// The codeOceanOutputWriter handles all the messages defined in the websocket.schema.json (start, timeout, stdout, ..). +func NewCodeOceanOutputWriter(connection Connection, ctx context.Context) *codeOceanOutputWriter { + cw := &codeOceanOutputWriter{ + connection: connection, + queue: make(chan *writingLoopMessage, CodeOceanOutputWriterBufferSize), + stopped: false, + } + cw.stdOut = &rawToCodeOceanWriter{sendMessage: func(s string) { + cw.send(&dto.WebSocketMessage{Type: dto.WebSocketOutputStdout, Data: s}) + }} + cw.stdErr = &rawToCodeOceanWriter{sendMessage: func(s string) { + cw.send(&dto.WebSocketMessage{Type: dto.WebSocketOutputStderr, Data: s}) + }} + + go cw.startWritingLoop() + go cw.stopWhenContextDone(ctx) + cw.send(&dto.WebSocketMessage{Type: dto.WebSocketMetaStart}) + return cw +} + +// StdOut provides an io.Writer that forwards the written data to CodeOcean as StdOut stream. +func (cw *codeOceanOutputWriter) StdOut() io.Writer { + return cw.stdOut +} + +// StdErr provides an io.Writer that forwards the written data to CodeOcean as StdErr stream. +func (cw *codeOceanOutputWriter) StdErr() io.Writer { + return cw.stdErr +} + +// SendExitInfo forwards the kind of exit (timeout, error, normal) to CodeOcean. +func (cw *codeOceanOutputWriter) SendExitInfo(info *runner.ExitInfo) { + switch { + case errors.Is(info.Err, context.DeadlineExceeded) || errors.Is(info.Err, runner.ErrorRunnerInactivityTimeout): + cw.send(&dto.WebSocketMessage{Type: dto.WebSocketMetaTimeout}) + case info.Err != nil: + errorMessage := "Error executing the request" + log.WithError(info.Err).Warn(errorMessage) + cw.send(&dto.WebSocketMessage{Type: dto.WebSocketOutputError, Data: errorMessage}) + default: + cw.send(&dto.WebSocketMessage{Type: dto.WebSocketExit, ExitCode: info.Code}) + } +} + +// stopWhenContextDone notifies the writing loop to stop after the context has been passed. +func (cw *codeOceanOutputWriter) stopWhenContextDone(ctx context.Context) { + <-ctx.Done() + if !cw.stopped { + cw.queue <- &writingLoopMessage{done: true} + } +} + +// send forwards the passed dto.WebSocketMessage to the writing loop. +func (cw *codeOceanOutputWriter) send(message *dto.WebSocketMessage) { + if cw.stopped { + return + } + done := message.Type == dto.WebSocketExit || + message.Type == dto.WebSocketMetaTimeout || + message.Type == dto.WebSocketOutputError + cw.queue <- &writingLoopMessage{done: done, data: message} +} + +// startWritingLoop enables the writing loop. +// This is the central and only place where written changes to the WebSocket connection should be done. +// It synchronizes the messages to provide state checks of the WebSocket connection. +func (cw *codeOceanOutputWriter) startWritingLoop() { + for { + message := <-cw.queue + done := sendMessage(cw.connection, message.data) + if done || message.done { + break + } + } + cw.stopped = true + message := websocket.FormatCloseMessage(websocket.CloseNormalClosure, "") + err := cw.connection.WriteMessage(websocket.CloseMessage, message) + err2 := cw.connection.Close() + if err != nil || err2 != nil { + log.WithError(err).WithField("err2", err2).Warn("Error during websocket close") + } +} + +// sendMessage is a helper function for the writing loop. It must not be called from somewhere else! +func sendMessage(connection Connection, message *dto.WebSocketMessage) (done bool) { + if message == nil { + return false + } + + encodedMessage, err := json.Marshal(message) + if err != nil { + log.WithField("message", message).WithError(err).Warn("Marshal error") + return false + } + + log.WithField("message", message).Trace("Sending message to client") + err = connection.WriteMessage(websocket.TextMessage, encodedMessage) + if err != nil { + errorMessage := "Error writing the message" + log.WithField("message", message).WithError(err).Warn(errorMessage) + return true + } + + return false +} diff --git a/internal/api/ws/codeocean_writer_test.go b/internal/api/ws/codeocean_writer_test.go new file mode 100644 index 0000000..9d0fc94 --- /dev/null +++ b/internal/api/ws/codeocean_writer_test.go @@ -0,0 +1,100 @@ +package ws + +import ( + "context" + "encoding/json" + "github.com/gorilla/websocket" + "github.com/openHPI/poseidon/internal/runner" + "github.com/openHPI/poseidon/pkg/dto" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" + "testing" +) + +func TestRawToCodeOceanWriter(t *testing.T) { + connectionMock, message := buildConnectionMock(t) + proxyCtx, cancel := context.WithCancel(context.Background()) + defer cancel() + output := NewCodeOceanOutputWriter(connectionMock, proxyCtx) + <-message // start message + + t.Run("StdOut", func(t *testing.T) { + testMessage := "testStdOut" + _, err := output.StdOut().Write([]byte(testMessage)) + require.NoError(t, err) + + expected, err := json.Marshal(struct { + Type string `json:"type"` + Data string `json:"data"` + }{string(dto.WebSocketOutputStdout), testMessage}) + require.NoError(t, err) + + assert.Equal(t, expected, <-message) + }) + + t.Run("StdErr", func(t *testing.T) { + testMessage := "testStdErr" + _, err := output.StdErr().Write([]byte(testMessage)) + require.NoError(t, err) + + expected, err := json.Marshal(struct { + Type string `json:"type"` + Data string `json:"data"` + }{string(dto.WebSocketOutputStderr), testMessage}) + require.NoError(t, err) + + assert.Equal(t, expected, <-message) + }) +} + +type sendExitInfoTestCase struct { + name string + info *runner.ExitInfo + message dto.WebSocketMessage +} + +func TestCodeOceanOutputWriter_SendExitInfo(t *testing.T) { + testCases := []sendExitInfoTestCase{ + {"Timeout", &runner.ExitInfo{Err: runner.ErrorRunnerInactivityTimeout}, + dto.WebSocketMessage{Type: dto.WebSocketMetaTimeout}}, + {"Error", &runner.ExitInfo{Err: websocket.ErrCloseSent}, + dto.WebSocketMessage{Type: dto.WebSocketOutputError, Data: "Error executing the request"}}, + {"Exit", &runner.ExitInfo{Code: 21}, + dto.WebSocketMessage{Type: dto.WebSocketExit, ExitCode: 21}}, + } + + for _, test := range testCases { + t.Run(test.name, func(t *testing.T) { + connectionMock, message := buildConnectionMock(t) + proxyCtx, cancel := context.WithCancel(context.Background()) + defer cancel() + output := NewCodeOceanOutputWriter(connectionMock, proxyCtx) + <-message // start message + + output.SendExitInfo(test.info) + expected, err := json.Marshal(test.message) + require.NoError(t, err) + + msg := <-message + assert.Equal(t, expected, msg) + }) + } +} + +func buildConnectionMock(t *testing.T) (conn *ConnectionMock, messages chan []byte) { + t.Helper() + message := make(chan []byte) + connectionMock := &ConnectionMock{} + connectionMock.On("WriteMessage", mock.AnythingOfType("int"), mock.AnythingOfType("[]uint8")). + Run(func(args mock.Arguments) { + m, ok := args.Get(1).([]byte) + require.True(t, ok) + message <- m + }). + Return(nil) + connectionMock.On("CloseHandler").Return(nil) + connectionMock.On("SetCloseHandler", mock.Anything).Return() + connectionMock.On("Close").Return() + return connectionMock, message +} diff --git a/internal/api/ws/connection.go b/internal/api/ws/connection.go new file mode 100644 index 0000000..c486073 --- /dev/null +++ b/internal/api/ws/connection.go @@ -0,0 +1,14 @@ +package ws + +import ( + "io" +) + +// Connection is an internal interface for websocket.Conn in order to mock it for unit tests. +type Connection interface { + WriteMessage(messageType int, data []byte) error + Close() error + NextReader() (messageType int, r io.Reader, err error) + CloseHandler() func(code int, text string) error + SetCloseHandler(handler func(code int, text string) error) +} diff --git a/internal/api/websocket_connection_mock.go b/internal/api/ws/connection_mock.go similarity index 58% rename from internal/api/websocket_connection_mock.go rename to internal/api/ws/connection_mock.go index 792958e..c7eb31c 100644 --- a/internal/api/websocket_connection_mock.go +++ b/internal/api/ws/connection_mock.go @@ -1,6 +1,6 @@ -// Code generated by mockery v0.0.0-dev. DO NOT EDIT. +// Code generated by mockery v2.13.1. DO NOT EDIT. -package api +package ws import ( io "io" @@ -8,13 +8,13 @@ import ( mock "github.com/stretchr/testify/mock" ) -// webSocketConnectionMock is an autogenerated mock type for the webSocketConnection type -type webSocketConnectionMock struct { +// ConnectionMock is an autogenerated mock type for the Connection type +type ConnectionMock struct { mock.Mock } // Close provides a mock function with given fields: -func (_m *webSocketConnectionMock) Close() error { +func (_m *ConnectionMock) Close() error { ret := _m.Called() var r0 error @@ -28,7 +28,7 @@ func (_m *webSocketConnectionMock) Close() error { } // CloseHandler provides a mock function with given fields: -func (_m *webSocketConnectionMock) CloseHandler() func(int, string) error { +func (_m *ConnectionMock) CloseHandler() func(int, string) error { ret := _m.Called() var r0 func(int, string) error @@ -44,7 +44,7 @@ func (_m *webSocketConnectionMock) CloseHandler() func(int, string) error { } // NextReader provides a mock function with given fields: -func (_m *webSocketConnectionMock) NextReader() (int, io.Reader, error) { +func (_m *ConnectionMock) NextReader() (int, io.Reader, error) { ret := _m.Called() var r0 int @@ -73,13 +73,13 @@ func (_m *webSocketConnectionMock) NextReader() (int, io.Reader, error) { return r0, r1, r2 } -// SetCloseHandler provides a mock function with given fields: h -func (_m *webSocketConnectionMock) SetCloseHandler(h func(int, string) error) { - _m.Called(h) +// SetCloseHandler provides a mock function with given fields: handler +func (_m *ConnectionMock) SetCloseHandler(handler func(int, string) error) { + _m.Called(handler) } // WriteMessage provides a mock function with given fields: messageType, data -func (_m *webSocketConnectionMock) WriteMessage(messageType int, data []byte) error { +func (_m *ConnectionMock) WriteMessage(messageType int, data []byte) error { ret := _m.Called(messageType, data) var r0 error @@ -91,3 +91,18 @@ func (_m *webSocketConnectionMock) WriteMessage(messageType int, data []byte) er return r0 } + +type mockConstructorTestingTNewConnectionMock interface { + mock.TestingT + Cleanup(func()) +} + +// NewConnectionMock creates a new instance of ConnectionMock. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +func NewConnectionMock(t mockConstructorTestingTNewConnectionMock) *ConnectionMock { + mock := &ConnectionMock{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +}