#155 refactor and synchronise writing to CodeOcean. (#174)

* #155 refactor and synchronise writing to CodeOcean.

* Reduce complexity of input parsing.

* Update typo in internal/api/ws/codeocean_writer.go

Co-authored-by: Sebastian Serth <MrSerth@users.noreply.github.com>
This commit is contained in:
Maximilian Paß
2022-06-26 20:19:23 +02:00
committed by GitHub
parent a4d13fb8cb
commit 203d5a3a4f
8 changed files with 569 additions and 410 deletions

View File

@ -2,217 +2,29 @@ package api
import ( import (
"context" "context"
"encoding/json"
"errors" "errors"
"fmt" "fmt"
"github.com/gorilla/websocket" "github.com/gorilla/websocket"
"github.com/openHPI/poseidon/internal/api/ws"
"github.com/openHPI/poseidon/internal/runner" "github.com/openHPI/poseidon/internal/runner"
"github.com/openHPI/poseidon/pkg/dto" "github.com/openHPI/poseidon/pkg/dto"
"github.com/openHPI/poseidon/pkg/logging" "github.com/openHPI/poseidon/pkg/logging"
"github.com/openHPI/poseidon/pkg/monitoring" "github.com/openHPI/poseidon/pkg/monitoring"
"io"
"net/http" "net/http"
"sync"
) )
const CodeOceanToRawReaderBufferSize = 1024
var ErrUnknownExecutionID = errors.New("execution id unknown") var ErrUnknownExecutionID = errors.New("execution id unknown")
type webSocketConnection interface {
WriteMessage(messageType int, data []byte) error
Close() error
NextReader() (messageType int, r io.Reader, err error)
CloseHandler() func(code int, text string) error
SetCloseHandler(handler func(code int, text string) error)
}
// WebSocketReader is an interface that is intended for providing abstraction around reading from a WebSocket.
// Besides io.Reader, it also implements io.Writer. The Write method is used to inject data into the WebSocket stream.
type WebSocketReader interface {
io.Reader
io.Writer
startReadInputLoop()
stopReadInputLoop()
}
// codeOceanToRawReader is an io.Reader implementation that provides the content of the WebSocket connection
// to CodeOcean. You have to start the Reader by calling readInputLoop. After that you can use the Read function.
type codeOceanToRawReader struct {
connection webSocketConnection
// readCtx is the context in that messages from CodeOcean are read.
readCtx context.Context
cancelReadCtx context.CancelFunc
// executorCtx is the context in that messages are forwarded to the executor.
executorCtx context.Context
// A buffered channel of bytes is used to store data coming from CodeOcean via WebSocket
// and retrieve it when Read(..) is called. Since channels are thread-safe, we use one here
// instead of bytes.Buffer.
buffer chan byte
// The priorityBuffer is a buffer for injecting data into stdin of the execution from Poseidon,
// for example the character that causes the tty to generate a SIGQUIT signal.
// It is always read before the regular buffer.
priorityBuffer chan byte
}
func newCodeOceanToRawReader(connection webSocketConnection, wsCtx, executorCtx context.Context) *codeOceanToRawReader {
return &codeOceanToRawReader{
connection: connection,
readCtx: wsCtx, // This context may be canceled before the executorCtx.
cancelReadCtx: func() {},
executorCtx: executorCtx,
buffer: make(chan byte, CodeOceanToRawReaderBufferSize),
priorityBuffer: make(chan byte, CodeOceanToRawReaderBufferSize),
}
}
// readInputLoop reads from the WebSocket connection and buffers the user's input.
// This is necessary because input must be read for the connection to handle special messages like close and call the
// CloseHandler.
func (cr *codeOceanToRawReader) readInputLoop(ctx context.Context) {
readMessage := make(chan bool)
loopContext, cancelInputLoop := context.WithCancel(ctx)
defer cancelInputLoop()
readingContext, cancelNextMessage := context.WithCancel(loopContext)
defer cancelNextMessage()
for loopContext.Err() == nil {
var messageType int
var reader io.Reader
var err error
go func() {
messageType, reader, err = cr.connection.NextReader()
select {
case <-readingContext.Done():
case readMessage <- true:
}
}()
select {
case <-loopContext.Done():
return
case <-readMessage:
}
if handleInput(messageType, reader, err, cr.buffer, loopContext) {
return
}
}
}
// handleInput receives a new message from the client and may forward it to the executor.
func handleInput(messageType int, reader io.Reader, err error, buffer chan byte, ctx context.Context) (done bool) {
if err != nil && websocket.IsCloseError(err, websocket.CloseNormalClosure) {
log.Debug("ReadInputLoop: The client closed the connection!")
// The close handler will do something soon.
return true
} else if err != nil {
log.WithError(err).Warn("Error reading client message")
return true
}
if messageType != websocket.TextMessage {
log.WithField("messageType", messageType).Warn("Received message of wrong type")
return true
}
message, err := io.ReadAll(reader)
if err != nil {
log.WithError(err).Warn("error while reading WebSocket message")
return true
}
log.WithField("message", string(message)).Trace("Received message from client")
for _, character := range message {
select {
case <-ctx.Done():
return true
case buffer <- character:
}
}
return false
}
// startReadInputLoop start the read input loop asynchronously.
func (cr *codeOceanToRawReader) startReadInputLoop() {
ctx, cancel := context.WithCancel(cr.readCtx)
cr.cancelReadCtx = cancel
go cr.readInputLoop(ctx)
}
// startReadInputLoop stops the asynchronous read input loop.
func (cr *codeOceanToRawReader) stopReadInputLoop() {
cr.cancelReadCtx()
}
// Read implements the io.Reader interface.
// It returns bytes from the buffer or priorityBuffer.
func (cr *codeOceanToRawReader) Read(p []byte) (int, error) {
if len(p) == 0 {
return 0, nil
}
// Ensure to not return until at least one byte has been read to avoid busy waiting.
select {
case <-cr.executorCtx.Done():
return 0, io.EOF
case p[0] = <-cr.priorityBuffer:
case p[0] = <-cr.buffer:
}
var n int
for n = 1; n < len(p); n++ {
select {
case p[n] = <-cr.priorityBuffer:
case p[n] = <-cr.buffer:
default:
return n, nil
}
}
return n, nil
}
// Write implements the io.Writer interface.
// Data written to a codeOceanToRawReader using this method is returned by Read before other data from the WebSocket.
func (cr *codeOceanToRawReader) Write(p []byte) (n int, err error) {
var c byte
for n, c = range p {
select {
case cr.priorityBuffer <- c:
default:
break
}
}
return n, nil
}
// rawToCodeOceanWriter is an io.Writer implementation that, when written to, wraps the written data in the appropriate
// json structure and sends it to the CodeOcean via WebSocket.
type rawToCodeOceanWriter struct {
proxy *webSocketProxy
outputType dto.WebSocketMessageType
}
// Write implements the io.Writer interface.
// The passed data is forwarded to the WebSocket to CodeOcean.
func (rc *rawToCodeOceanWriter) Write(p []byte) (int, error) {
err := rc.proxy.sendToClient(dto.WebSocketMessage{Type: rc.outputType, Data: string(p)})
return len(p), err
}
// webSocketProxy is an encapsulation of logic for forwarding between Runners and CodeOcean. // webSocketProxy is an encapsulation of logic for forwarding between Runners and CodeOcean.
type webSocketProxy struct { type webSocketProxy struct {
webSocketCtx context.Context ctx context.Context
cancelWebSocket context.CancelFunc cancel context.CancelFunc
connection webSocketConnection Input ws.WebSocketReader
connectionMu sync.Mutex Output ws.WebSocketWriter
Stdin WebSocketReader
Stdout io.Writer
Stderr io.Writer
} }
// upgradeConnection upgrades a connection to a websocket and returns a webSocketProxy for this connection. // upgradeConnection upgrades a connection to a websocket and returns a webSocketProxy for this connection.
func upgradeConnection(writer http.ResponseWriter, request *http.Request) (webSocketConnection, error) { func upgradeConnection(writer http.ResponseWriter, request *http.Request) (ws.Connection, error) {
connUpgrader := websocket.Upgrader{} connUpgrader := websocket.Upgrader{}
connection, err := connUpgrader.Upgrade(writer, request, nil) connection, err := connUpgrader.Upgrade(writer, request, nil)
if err != nil { if err != nil {
@ -224,31 +36,18 @@ func upgradeConnection(writer http.ResponseWriter, request *http.Request) (webSo
// newWebSocketProxy returns an initiated and started webSocketProxy. // newWebSocketProxy returns an initiated and started webSocketProxy.
// As this proxy is already started, a start message is send to the client. // As this proxy is already started, a start message is send to the client.
func newWebSocketProxy(connection webSocketConnection, proxyCtx context.Context) (*webSocketProxy, error) { func newWebSocketProxy(connection ws.Connection, proxyCtx context.Context) (*webSocketProxy, error) {
wsCtx, cancelWsCommunication := context.WithCancel(proxyCtx) wsCtx, cancelWsCommunication := context.WithCancel(proxyCtx)
stdin := newCodeOceanToRawReader(connection, wsCtx, proxyCtx)
proxy := &webSocketProxy{ proxy := &webSocketProxy{
connection: connection, ctx: wsCtx,
Stdin: stdin, cancel: cancelWsCommunication,
webSocketCtx: wsCtx, Input: ws.NewCodeOceanToRawReader(connection, wsCtx, proxyCtx),
cancelWebSocket: cancelWsCommunication, Output: ws.NewCodeOceanOutputWriter(connection, wsCtx),
}
proxy.Stdout = &rawToCodeOceanWriter{proxy: proxy, outputType: dto.WebSocketOutputStdout}
proxy.Stderr = &rawToCodeOceanWriter{proxy: proxy, outputType: dto.WebSocketOutputStderr}
err := proxy.sendToClient(dto.WebSocketMessage{Type: dto.WebSocketMetaStart})
if err != nil {
cancelWsCommunication()
return nil, err
} }
closeHandler := connection.CloseHandler()
connection.SetCloseHandler(func(code int, text string) error { connection.SetCloseHandler(func(code int, text string) error {
log.Info("Before closing the connection via Handler") log.WithField("code", code).WithField("text", text).Debug("The client closed the connection.")
cancelWsCommunication() cancelWsCommunication()
//nolint:errcheck // The default close handler always returns nil, so the error can be safely ignored.
_ = closeHandler(code, text)
log.Info("After closing the connection via Handler")
return nil return nil
}) })
return proxy, nil return proxy, nil
@ -257,99 +56,21 @@ func newWebSocketProxy(connection webSocketConnection, proxyCtx context.Context)
// waitForExit waits for an exit of either the runner (when the command terminates) or the client closing the WebSocket // waitForExit waits for an exit of either the runner (when the command terminates) or the client closing the WebSocket
// and handles WebSocket exit messages. // and handles WebSocket exit messages.
func (wp *webSocketProxy) waitForExit(exit <-chan runner.ExitInfo, cancelExecution context.CancelFunc) { func (wp *webSocketProxy) waitForExit(exit <-chan runner.ExitInfo, cancelExecution context.CancelFunc) {
wp.Stdin.startReadInputLoop() wp.Input.Start()
var exitInfo runner.ExitInfo var exitInfo runner.ExitInfo
select { select {
case <-wp.webSocketCtx.Done(): case <-wp.ctx.Done():
log.Info("Client closed the connection") log.Info("Client closed the connection")
wp.Input.Stop()
cancelExecution() cancelExecution()
<-exit // /internal/runner/runner.go handleExitOrContextDone does not require client connection anymore. <-exit // /internal/runner/runner.go handleExitOrContextDone does not require client connection anymore.
<-exit // The goroutine closes this channel indicating that it does not use the connection to the executor anymore. <-exit // The goroutine closes this channel indicating that it does not use the connection to the executor anymore.
return
case exitInfo = <-exit: case exitInfo = <-exit:
log.Info("Execution returned") log.Info("Execution returned")
wp.Stdin.stopReadInputLoop() // Here we stop reading from the client wp.Input.Stop()
defer wp.cancelWebSocket() // At the end of this method we stop writing to the client wp.Output.SendExitInfo(&exitInfo)
} }
if errors.Is(exitInfo.Err, context.DeadlineExceeded) || errors.Is(exitInfo.Err, runner.ErrorRunnerInactivityTimeout) {
err := wp.sendToClient(dto.WebSocketMessage{Type: dto.WebSocketMetaTimeout})
if err == nil {
wp.closeNormal()
} else {
log.WithError(err).Warn("Failed to send timeout message to client")
}
return
} else if exitInfo.Err != nil {
errorMessage := "Error executing the request"
log.WithError(exitInfo.Err).Warn(errorMessage)
err := wp.sendToClient(dto.WebSocketMessage{Type: dto.WebSocketOutputError, Data: errorMessage})
if err == nil {
wp.closeNormal()
} else {
log.WithError(err).Warn("Failed to send output error message to client")
}
return
}
log.WithField("exit_code", exitInfo.Code).Debug()
err := wp.sendToClient(dto.WebSocketMessage{
Type: dto.WebSocketExit,
ExitCode: exitInfo.Code,
})
if err != nil {
log.WithError(err).Warn("Error sending exit message")
return
}
wp.closeNormal()
}
func (wp *webSocketProxy) sendToClient(message dto.WebSocketMessage) error {
encodedMessage, err := json.Marshal(message)
if err != nil {
log.WithField("message", message).WithError(err).Warn("Marshal error")
wp.closeWithError("Error creating message")
return fmt.Errorf("error marshaling WebSocket message: %w", err)
}
log.WithField("message", message).Trace("Sending message to client")
select {
case <-wp.webSocketCtx.Done():
default:
log.Info("Passed WriteToCodeOceanCheck")
err = wp.writeMessage(websocket.TextMessage, encodedMessage)
if err != nil {
errorMessage := "Error writing the message"
log.WithField("message", message).WithError(err).Warn(errorMessage)
wp.closeWithError(errorMessage)
return fmt.Errorf("error writing WebSocket message: %w", err)
}
}
return nil
}
func (wp *webSocketProxy) closeWithError(message string) {
wp.close(websocket.FormatCloseMessage(websocket.CloseInternalServerErr, message))
}
func (wp *webSocketProxy) closeNormal() {
wp.close(websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""))
}
func (wp *webSocketProxy) close(message []byte) {
log.Info("Before closing the connection manually")
err := wp.writeMessage(websocket.CloseMessage, message)
_ = wp.connection.Close()
if err != nil {
log.WithError(err).Warn("Error during websocket close")
}
log.Info("After closing the connection manually")
}
func (wp *webSocketProxy) writeMessage(messageType int, data []byte) error {
wp.connectionMu.Lock()
defer wp.connectionMu.Unlock()
return wp.connection.WriteMessage(messageType, data) //nolint:wrapcheck // Wrap the original WriteMessage in a mutex.
} }
// connectToRunner is the endpoint for websocket connections. // connectToRunner is the endpoint for websocket connections.
@ -378,10 +99,11 @@ func (r *RunnerController) connectToRunner(writer http.ResponseWriter, request *
log.WithField("runnerId", targetRunner.ID()). log.WithField("runnerId", targetRunner.ID()).
WithField("executionID", logging.RemoveNewlineSymbol(executionID)). WithField("executionID", logging.RemoveNewlineSymbol(executionID)).
Info("Running execution") Info("Running execution")
exit, cancel, err := targetRunner.ExecuteInteractively(executionID, proxy.Stdin, proxy.Stdout, proxy.Stderr) exit, cancel, err := targetRunner.ExecuteInteractively(executionID,
proxy.Input, proxy.Output.StdOut(), proxy.Output.StdErr())
if err != nil { if err != nil {
proxy.closeWithError(fmt.Sprintf("execution failed with: %v", err)) log.WithError(err).Warn("Cannot execute request.")
return return // The proxy is stopped by the defered cancel.
} }
proxy.waitForExit(exit, cancel) proxy.waitForExit(exit, cancel)

View File

@ -4,7 +4,6 @@ import (
"bufio" "bufio"
"context" "context"
"crypto/tls" "crypto/tls"
"encoding/json"
"fmt" "fmt"
"github.com/gorilla/mux" "github.com/gorilla/mux"
"github.com/gorilla/websocket" "github.com/gorilla/websocket"
@ -24,7 +23,6 @@ import (
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"net/url" "net/url"
"strings"
"testing" "testing"
"time" "time"
) )
@ -100,7 +98,7 @@ func (s *WebSocketTestSuite) TestWebsocketConnection() {
}) })
s.Run("Executes the request in the runner", func() { s.Run("Executes the request in the runner", func() {
<-time.After(100 * time.Millisecond) <-time.After(tests.ShortTimeout)
s.apiMock.AssertCalled(s.T(), "ExecuteCommand", s.apiMock.AssertCalled(s.T(), "ExecuteCommand",
mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything) mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything)
}) })
@ -281,103 +279,6 @@ func TestWebsocketTLS(t *testing.T) {
assert.True(t, websocket.IsCloseError(err, websocket.CloseNormalClosure)) assert.True(t, websocket.IsCloseError(err, websocket.CloseNormalClosure))
} }
func TestRawToCodeOceanWriter(t *testing.T) {
testMessage := "test"
var message []byte
connectionMock := &webSocketConnectionMock{}
connectionMock.On("WriteMessage", mock.AnythingOfType("int"), mock.AnythingOfType("[]uint8")).
Run(func(args mock.Arguments) {
var ok bool
message, ok = args.Get(1).([]byte)
require.True(t, ok)
}).
Return(nil)
connectionMock.On("CloseHandler").Return(nil)
connectionMock.On("SetCloseHandler", mock.Anything).Return()
proxyCtx, cancel := context.WithCancel(context.Background())
defer cancel()
proxy, err := newWebSocketProxy(connectionMock, proxyCtx)
require.NoError(t, err)
writer := &rawToCodeOceanWriter{
proxy: proxy,
outputType: dto.WebSocketOutputStdout,
}
_, err = writer.Write([]byte(testMessage))
require.NoError(t, err)
expected, err := json.Marshal(struct {
Type string `json:"type"`
Data string `json:"data"`
}{string(dto.WebSocketOutputStdout), testMessage})
require.NoError(t, err)
assert.Equal(t, expected, message)
}
func TestCodeOceanToRawReaderReturnsOnlyAfterOneByteWasRead(t *testing.T) {
readingCtx, cancel := context.WithCancel(context.Background())
forwardingCtx := readingCtx
defer cancel()
reader := newCodeOceanToRawReader(nil, readingCtx, forwardingCtx)
read := make(chan bool)
go func() {
//nolint:makezero // we can't make zero initial length here as the reader otherwise doesn't block
p := make([]byte, 10)
_, err := reader.Read(p)
require.NoError(t, err)
read <- true
}()
t.Run("Does not return immediately when there is no data", func(t *testing.T) {
assert.False(t, tests.ChannelReceivesSomething(read, tests.ShortTimeout))
})
t.Run("Returns when there is data available", func(t *testing.T) {
reader.buffer <- byte(42)
assert.True(t, tests.ChannelReceivesSomething(read, tests.ShortTimeout))
})
}
func TestCodeOceanToRawReaderReturnsOnlyAfterOneByteWasReadFromConnection(t *testing.T) {
messages := make(chan io.Reader)
connection := &webSocketConnectionMock{}
connection.On("WriteMessage", mock.AnythingOfType("int"), mock.AnythingOfType("[]uint8")).Return(nil)
connection.On("CloseHandler").Return(nil)
connection.On("SetCloseHandler", mock.Anything).Return()
call := connection.On("NextReader")
call.Run(func(_ mock.Arguments) {
call.Return(websocket.TextMessage, <-messages, nil)
})
readingCtx, cancel := context.WithCancel(context.Background())
forwardingCtx := readingCtx
defer cancel()
reader := newCodeOceanToRawReader(connection, readingCtx, forwardingCtx)
reader.startReadInputLoop()
read := make(chan bool)
//nolint:makezero // this is required here to make the Read call blocking
message := make([]byte, 10)
go func() {
_, err := reader.Read(message)
require.NoError(t, err)
read <- true
}()
t.Run("Does not return immediately when there is no data", func(t *testing.T) {
assert.False(t, tests.ChannelReceivesSomething(read, tests.ShortTimeout))
})
t.Run("Returns when there is data available", func(t *testing.T) {
messages <- strings.NewReader("Hello")
assert.True(t, tests.ChannelReceivesSomething(read, tests.ShortTimeout))
})
}
func TestWebSocketProxyStopsReadingTheWebSocketAfterClosingIt(t *testing.T) { func TestWebSocketProxyStopsReadingTheWebSocketAfterClosingIt(t *testing.T) {
apiMock := &nomad.ExecutorAPIMock{} apiMock := &nomad.ExecutorAPIMock{}
executionID := tests.DefaultExecutionID executionID := tests.DefaultExecutionID

View File

@ -0,0 +1,177 @@
package ws
import (
"context"
"github.com/gorilla/websocket"
"github.com/openHPI/poseidon/pkg/logging"
"io"
)
const CodeOceanToRawReaderBufferSize = 1024
var log = logging.GetLogger("ws")
// WebSocketReader is an interface that is intended for providing abstraction around reading from a WebSocket.
// Besides, io.Reader, it also implements io.Writer. The Write method is used to inject data into the WebSocket stream.
type WebSocketReader interface {
io.Reader
io.Writer
Start()
Stop()
}
// codeOceanToRawReader is an io.Reader implementation that provides the content of the WebSocket connection
// to CodeOcean. You have to start the Reader by calling readInputLoop. After that you can use the Read function.
type codeOceanToRawReader struct {
connection Connection
// readCtx is the context in that messages from CodeOcean are read.
readCtx context.Context
cancelReadCtx context.CancelFunc
// executorCtx is the context in that messages are forwarded to the executor.
executorCtx context.Context
// A buffered channel of bytes is used to store data coming from CodeOcean via WebSocket
// and retrieve it when Read(..) is called. Since channels are thread-safe, we use one here
// instead of bytes.Buffer.
buffer chan byte
// The priorityBuffer is a buffer for injecting data into stdin of the execution from Poseidon,
// for example the character that causes the tty to generate a SIGQUIT signal.
// It is always read before the regular buffer.
priorityBuffer chan byte
}
func NewCodeOceanToRawReader(connection Connection, wsCtx, executorCtx context.Context) *codeOceanToRawReader {
return &codeOceanToRawReader{
connection: connection,
readCtx: wsCtx, // This context may be canceled before the executorCtx.
cancelReadCtx: func() {},
executorCtx: executorCtx,
buffer: make(chan byte, CodeOceanToRawReaderBufferSize),
priorityBuffer: make(chan byte, CodeOceanToRawReaderBufferSize),
}
}
// readInputLoop reads from the WebSocket connection and buffers the user's input.
// This is necessary because input must be read for the connection to handle special messages like close and call the
// CloseHandler.
func (cr *codeOceanToRawReader) readInputLoop(ctx context.Context) {
readMessage := make(chan bool)
loopContext, cancelInputLoop := context.WithCancel(ctx)
defer cancelInputLoop()
readingContext, cancelNextMessage := context.WithCancel(loopContext)
defer cancelNextMessage()
for loopContext.Err() == nil {
var messageType int
var reader io.Reader
var err error
go func() {
messageType, reader, err = cr.connection.NextReader()
select {
case <-readingContext.Done():
case readMessage <- true:
}
}()
select {
case <-loopContext.Done():
return
case <-readMessage:
}
if inputContainsError(messageType, err) {
return
}
if handleInput(reader, cr.buffer, loopContext) {
return
}
}
}
// handleInput receives a new message from the client and may forward it to the executor.
func handleInput(reader io.Reader, buffer chan byte, ctx context.Context) (done bool) {
message, err := io.ReadAll(reader)
if err != nil {
log.WithError(err).Warn("error while reading WebSocket message")
return true
}
log.WithField("message", string(message)).Trace("Received message from client")
for _, character := range message {
select {
case <-ctx.Done():
return true
case buffer <- character:
}
}
return false
}
func inputContainsError(messageType int, err error) (done bool) {
if err != nil && websocket.IsCloseError(err, websocket.CloseNormalClosure) {
log.Debug("ReadInputLoop: The client closed the connection!")
// The close handler will do something soon.
return true
} else if err != nil {
log.WithError(err).Warn("Error reading client message")
return true
}
if messageType != websocket.TextMessage {
log.WithField("messageType", messageType).Warn("Received message of wrong type")
return true
}
return false
}
// Start starts the read input loop asynchronously.
func (cr *codeOceanToRawReader) Start() {
ctx, cancel := context.WithCancel(cr.readCtx)
cr.cancelReadCtx = cancel
go cr.readInputLoop(ctx)
}
// Stop stops the asynchronous read input loop.
func (cr *codeOceanToRawReader) Stop() {
cr.cancelReadCtx()
}
// Read implements the io.Reader interface.
// It returns bytes from the buffer or priorityBuffer.
func (cr *codeOceanToRawReader) Read(p []byte) (int, error) {
if len(p) == 0 {
return 0, nil
}
// Ensure to not return until at least one byte has been read to avoid busy waiting.
select {
case <-cr.executorCtx.Done():
return 0, io.EOF
case p[0] = <-cr.priorityBuffer:
case p[0] = <-cr.buffer:
}
var n int
for n = 1; n < len(p); n++ {
select {
case p[n] = <-cr.priorityBuffer:
case p[n] = <-cr.buffer:
default:
return n, nil
}
}
return n, nil
}
// Write implements the io.Writer interface.
// Data written to a codeOceanToRawReader using this method is returned by Read before other data from the WebSocket.
func (cr *codeOceanToRawReader) Write(p []byte) (n int, err error) {
var c byte
for n, c = range p {
select {
case cr.priorityBuffer <- c:
default:
break
}
}
return n, nil
}

View File

@ -0,0 +1,75 @@
package ws
import (
"context"
"github.com/gorilla/websocket"
"github.com/openHPI/poseidon/tests"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
"io"
"strings"
"testing"
)
func TestCodeOceanToRawReaderReturnsOnlyAfterOneByteWasRead(t *testing.T) {
readingCtx, cancel := context.WithCancel(context.Background())
forwardingCtx := readingCtx
defer cancel()
reader := NewCodeOceanToRawReader(nil, readingCtx, forwardingCtx)
read := make(chan bool)
go func() {
//nolint:makezero // we can't make zero initial length here as the reader otherwise doesn't block
p := make([]byte, 10)
_, err := reader.Read(p)
require.NoError(t, err)
read <- true
}()
t.Run("Does not return immediately when there is no data", func(t *testing.T) {
assert.False(t, tests.ChannelReceivesSomething(read, tests.ShortTimeout))
})
t.Run("Returns when there is data available", func(t *testing.T) {
reader.buffer <- byte(42)
assert.True(t, tests.ChannelReceivesSomething(read, tests.ShortTimeout))
})
}
func TestCodeOceanToRawReaderReturnsOnlyAfterOneByteWasReadFromConnection(t *testing.T) {
messages := make(chan io.Reader)
connection := &ConnectionMock{}
connection.On("WriteMessage", mock.AnythingOfType("int"), mock.AnythingOfType("[]uint8")).Return(nil)
connection.On("CloseHandler").Return(nil)
connection.On("SetCloseHandler", mock.Anything).Return()
call := connection.On("NextReader")
call.Run(func(_ mock.Arguments) {
call.Return(websocket.TextMessage, <-messages, nil)
})
readingCtx, cancel := context.WithCancel(context.Background())
forwardingCtx := readingCtx
defer cancel()
reader := NewCodeOceanToRawReader(connection, readingCtx, forwardingCtx)
reader.Start()
read := make(chan bool)
//nolint:makezero // this is required here to make the Read call blocking
message := make([]byte, 10)
go func() {
_, err := reader.Read(message)
require.NoError(t, err)
read <- true
}()
t.Run("Does not return immediately when there is no data", func(t *testing.T) {
assert.False(t, tests.ChannelReceivesSomething(read, tests.ShortTimeout))
})
t.Run("Returns when there is data available", func(t *testing.T) {
messages <- strings.NewReader("Hello")
assert.True(t, tests.ChannelReceivesSomething(read, tests.ShortTimeout))
})
}

View File

@ -0,0 +1,155 @@
package ws
import (
"context"
"encoding/json"
"errors"
"github.com/gorilla/websocket"
"github.com/openHPI/poseidon/internal/runner"
"github.com/openHPI/poseidon/pkg/dto"
"io"
)
// CodeOceanOutputWriterBufferSize defines the number of messages.
const CodeOceanOutputWriterBufferSize = 64
// rawToCodeOceanWriter is a simple io.Writer implementation that just forwards the call to sendMessage.
type rawToCodeOceanWriter struct {
sendMessage func(string)
}
// Write implements the io.Writer interface.
func (rc *rawToCodeOceanWriter) Write(p []byte) (int, error) {
rc.sendMessage(string(p))
return len(p), nil
}
// WebSocketWriter is an interface that defines which data is required and which information can be passed.
type WebSocketWriter interface {
StdOut() io.Writer
StdErr() io.Writer
SendExitInfo(info *runner.ExitInfo)
}
// codeOceanOutputWriter is a concrete WebSocketWriter implementation.
// It forwards the data written to stdOut or stdErr (Nomad, AWS) to the WebSocket connection (CodeOcean).
type codeOceanOutputWriter struct {
connection Connection
stdOut io.Writer
stdErr io.Writer
queue chan *writingLoopMessage
stopped bool
}
// writingLoopMessage is an internal data structure to notify the writing loop when it should stop.
type writingLoopMessage struct {
done bool
data *dto.WebSocketMessage
}
// NewCodeOceanOutputWriter provies an codeOceanOutputWriter for the time the context ctx is active.
// The codeOceanOutputWriter handles all the messages defined in the websocket.schema.json (start, timeout, stdout, ..).
func NewCodeOceanOutputWriter(connection Connection, ctx context.Context) *codeOceanOutputWriter {
cw := &codeOceanOutputWriter{
connection: connection,
queue: make(chan *writingLoopMessage, CodeOceanOutputWriterBufferSize),
stopped: false,
}
cw.stdOut = &rawToCodeOceanWriter{sendMessage: func(s string) {
cw.send(&dto.WebSocketMessage{Type: dto.WebSocketOutputStdout, Data: s})
}}
cw.stdErr = &rawToCodeOceanWriter{sendMessage: func(s string) {
cw.send(&dto.WebSocketMessage{Type: dto.WebSocketOutputStderr, Data: s})
}}
go cw.startWritingLoop()
go cw.stopWhenContextDone(ctx)
cw.send(&dto.WebSocketMessage{Type: dto.WebSocketMetaStart})
return cw
}
// StdOut provides an io.Writer that forwards the written data to CodeOcean as StdOut stream.
func (cw *codeOceanOutputWriter) StdOut() io.Writer {
return cw.stdOut
}
// StdErr provides an io.Writer that forwards the written data to CodeOcean as StdErr stream.
func (cw *codeOceanOutputWriter) StdErr() io.Writer {
return cw.stdErr
}
// SendExitInfo forwards the kind of exit (timeout, error, normal) to CodeOcean.
func (cw *codeOceanOutputWriter) SendExitInfo(info *runner.ExitInfo) {
switch {
case errors.Is(info.Err, context.DeadlineExceeded) || errors.Is(info.Err, runner.ErrorRunnerInactivityTimeout):
cw.send(&dto.WebSocketMessage{Type: dto.WebSocketMetaTimeout})
case info.Err != nil:
errorMessage := "Error executing the request"
log.WithError(info.Err).Warn(errorMessage)
cw.send(&dto.WebSocketMessage{Type: dto.WebSocketOutputError, Data: errorMessage})
default:
cw.send(&dto.WebSocketMessage{Type: dto.WebSocketExit, ExitCode: info.Code})
}
}
// stopWhenContextDone notifies the writing loop to stop after the context has been passed.
func (cw *codeOceanOutputWriter) stopWhenContextDone(ctx context.Context) {
<-ctx.Done()
if !cw.stopped {
cw.queue <- &writingLoopMessage{done: true}
}
}
// send forwards the passed dto.WebSocketMessage to the writing loop.
func (cw *codeOceanOutputWriter) send(message *dto.WebSocketMessage) {
if cw.stopped {
return
}
done := message.Type == dto.WebSocketExit ||
message.Type == dto.WebSocketMetaTimeout ||
message.Type == dto.WebSocketOutputError
cw.queue <- &writingLoopMessage{done: done, data: message}
}
// startWritingLoop enables the writing loop.
// This is the central and only place where written changes to the WebSocket connection should be done.
// It synchronizes the messages to provide state checks of the WebSocket connection.
func (cw *codeOceanOutputWriter) startWritingLoop() {
for {
message := <-cw.queue
done := sendMessage(cw.connection, message.data)
if done || message.done {
break
}
}
cw.stopped = true
message := websocket.FormatCloseMessage(websocket.CloseNormalClosure, "")
err := cw.connection.WriteMessage(websocket.CloseMessage, message)
err2 := cw.connection.Close()
if err != nil || err2 != nil {
log.WithError(err).WithField("err2", err2).Warn("Error during websocket close")
}
}
// sendMessage is a helper function for the writing loop. It must not be called from somewhere else!
func sendMessage(connection Connection, message *dto.WebSocketMessage) (done bool) {
if message == nil {
return false
}
encodedMessage, err := json.Marshal(message)
if err != nil {
log.WithField("message", message).WithError(err).Warn("Marshal error")
return false
}
log.WithField("message", message).Trace("Sending message to client")
err = connection.WriteMessage(websocket.TextMessage, encodedMessage)
if err != nil {
errorMessage := "Error writing the message"
log.WithField("message", message).WithError(err).Warn(errorMessage)
return true
}
return false
}

View File

@ -0,0 +1,100 @@
package ws
import (
"context"
"encoding/json"
"github.com/gorilla/websocket"
"github.com/openHPI/poseidon/internal/runner"
"github.com/openHPI/poseidon/pkg/dto"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
"testing"
)
func TestRawToCodeOceanWriter(t *testing.T) {
connectionMock, message := buildConnectionMock(t)
proxyCtx, cancel := context.WithCancel(context.Background())
defer cancel()
output := NewCodeOceanOutputWriter(connectionMock, proxyCtx)
<-message // start message
t.Run("StdOut", func(t *testing.T) {
testMessage := "testStdOut"
_, err := output.StdOut().Write([]byte(testMessage))
require.NoError(t, err)
expected, err := json.Marshal(struct {
Type string `json:"type"`
Data string `json:"data"`
}{string(dto.WebSocketOutputStdout), testMessage})
require.NoError(t, err)
assert.Equal(t, expected, <-message)
})
t.Run("StdErr", func(t *testing.T) {
testMessage := "testStdErr"
_, err := output.StdErr().Write([]byte(testMessage))
require.NoError(t, err)
expected, err := json.Marshal(struct {
Type string `json:"type"`
Data string `json:"data"`
}{string(dto.WebSocketOutputStderr), testMessage})
require.NoError(t, err)
assert.Equal(t, expected, <-message)
})
}
type sendExitInfoTestCase struct {
name string
info *runner.ExitInfo
message dto.WebSocketMessage
}
func TestCodeOceanOutputWriter_SendExitInfo(t *testing.T) {
testCases := []sendExitInfoTestCase{
{"Timeout", &runner.ExitInfo{Err: runner.ErrorRunnerInactivityTimeout},
dto.WebSocketMessage{Type: dto.WebSocketMetaTimeout}},
{"Error", &runner.ExitInfo{Err: websocket.ErrCloseSent},
dto.WebSocketMessage{Type: dto.WebSocketOutputError, Data: "Error executing the request"}},
{"Exit", &runner.ExitInfo{Code: 21},
dto.WebSocketMessage{Type: dto.WebSocketExit, ExitCode: 21}},
}
for _, test := range testCases {
t.Run(test.name, func(t *testing.T) {
connectionMock, message := buildConnectionMock(t)
proxyCtx, cancel := context.WithCancel(context.Background())
defer cancel()
output := NewCodeOceanOutputWriter(connectionMock, proxyCtx)
<-message // start message
output.SendExitInfo(test.info)
expected, err := json.Marshal(test.message)
require.NoError(t, err)
msg := <-message
assert.Equal(t, expected, msg)
})
}
}
func buildConnectionMock(t *testing.T) (conn *ConnectionMock, messages chan []byte) {
t.Helper()
message := make(chan []byte)
connectionMock := &ConnectionMock{}
connectionMock.On("WriteMessage", mock.AnythingOfType("int"), mock.AnythingOfType("[]uint8")).
Run(func(args mock.Arguments) {
m, ok := args.Get(1).([]byte)
require.True(t, ok)
message <- m
}).
Return(nil)
connectionMock.On("CloseHandler").Return(nil)
connectionMock.On("SetCloseHandler", mock.Anything).Return()
connectionMock.On("Close").Return()
return connectionMock, message
}

View File

@ -0,0 +1,14 @@
package ws
import (
"io"
)
// Connection is an internal interface for websocket.Conn in order to mock it for unit tests.
type Connection interface {
WriteMessage(messageType int, data []byte) error
Close() error
NextReader() (messageType int, r io.Reader, err error)
CloseHandler() func(code int, text string) error
SetCloseHandler(handler func(code int, text string) error)
}

View File

@ -1,6 +1,6 @@
// Code generated by mockery v0.0.0-dev. DO NOT EDIT. // Code generated by mockery v2.13.1. DO NOT EDIT.
package api package ws
import ( import (
io "io" io "io"
@ -8,13 +8,13 @@ import (
mock "github.com/stretchr/testify/mock" mock "github.com/stretchr/testify/mock"
) )
// webSocketConnectionMock is an autogenerated mock type for the webSocketConnection type // ConnectionMock is an autogenerated mock type for the Connection type
type webSocketConnectionMock struct { type ConnectionMock struct {
mock.Mock mock.Mock
} }
// Close provides a mock function with given fields: // Close provides a mock function with given fields:
func (_m *webSocketConnectionMock) Close() error { func (_m *ConnectionMock) Close() error {
ret := _m.Called() ret := _m.Called()
var r0 error var r0 error
@ -28,7 +28,7 @@ func (_m *webSocketConnectionMock) Close() error {
} }
// CloseHandler provides a mock function with given fields: // CloseHandler provides a mock function with given fields:
func (_m *webSocketConnectionMock) CloseHandler() func(int, string) error { func (_m *ConnectionMock) CloseHandler() func(int, string) error {
ret := _m.Called() ret := _m.Called()
var r0 func(int, string) error var r0 func(int, string) error
@ -44,7 +44,7 @@ func (_m *webSocketConnectionMock) CloseHandler() func(int, string) error {
} }
// NextReader provides a mock function with given fields: // NextReader provides a mock function with given fields:
func (_m *webSocketConnectionMock) NextReader() (int, io.Reader, error) { func (_m *ConnectionMock) NextReader() (int, io.Reader, error) {
ret := _m.Called() ret := _m.Called()
var r0 int var r0 int
@ -73,13 +73,13 @@ func (_m *webSocketConnectionMock) NextReader() (int, io.Reader, error) {
return r0, r1, r2 return r0, r1, r2
} }
// SetCloseHandler provides a mock function with given fields: h // SetCloseHandler provides a mock function with given fields: handler
func (_m *webSocketConnectionMock) SetCloseHandler(h func(int, string) error) { func (_m *ConnectionMock) SetCloseHandler(handler func(int, string) error) {
_m.Called(h) _m.Called(handler)
} }
// WriteMessage provides a mock function with given fields: messageType, data // WriteMessage provides a mock function with given fields: messageType, data
func (_m *webSocketConnectionMock) WriteMessage(messageType int, data []byte) error { func (_m *ConnectionMock) WriteMessage(messageType int, data []byte) error {
ret := _m.Called(messageType, data) ret := _m.Called(messageType, data)
var r0 error var r0 error
@ -91,3 +91,18 @@ func (_m *webSocketConnectionMock) WriteMessage(messageType int, data []byte) er
return r0 return r0
} }
type mockConstructorTestingTNewConnectionMock interface {
mock.TestingT
Cleanup(func())
}
// NewConnectionMock creates a new instance of ConnectionMock. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
func NewConnectionMock(t mockConstructorTestingTNewConnectionMock) *ConnectionMock {
mock := &ConnectionMock{}
mock.Mock.Test(t)
t.Cleanup(func() { mock.AssertExpectations(t) })
return mock
}