Files
poseidon/internal/api/websocket.go
sirkrypt0 36dc99f019 Don't embed the execution.Storer interface into a runner
Previously, the execution.Storer interface was embedded in the Runner interface.
However, this resulted in calls like runner.Add(...) to add an execution to the
store which is kind of ugly. Thus, we decided to add only the required functions to
the runner interface and make the execution.Storer a field of the implementation.
2021-07-30 14:28:06 +02:00

326 lines
10 KiB
Go

package api
import (
"context"
"encoding/json"
"errors"
"fmt"
"github.com/gorilla/websocket"
"gitlab.hpi.de/codeocean/codemoon/poseidon/internal/runner"
"gitlab.hpi.de/codeocean/codemoon/poseidon/pkg/dto"
"io"
"net/http"
"sync"
)
// CodeOceanToRawReaderBufferSize is the capacity, in bytes, of each input buffer of a codeOceanToRawReader.
const (
	CodeOceanToRawReaderBufferSize = 1024
)

// ErrUnknownExecutionID is reported via writeNotFound when the requested execution ID does not exist on the runner.
var (
	ErrUnknownExecutionID = errors.New("execution id unknown")
)
// webSocketConnection is the subset of the *websocket.Conn API that this file relies on.
type webSocketConnection interface {
	// NextReader returns the type and a reader for the next message received from the peer.
	NextReader() (messageType int, r io.Reader, err error)
	// WriteMessage sends a single message of the given type.
	WriteMessage(messageType int, data []byte) error
	// Close closes the connection.
	Close() error
	// CloseHandler returns the currently installed handler for close messages from the peer.
	CloseHandler() func(code int, text string) error
	// SetCloseHandler replaces the handler for close messages from the peer.
	SetCloseHandler(handler func(code int, text string) error)
}
// WebSocketReader abstracts reading user input from a WebSocket.
// Apart from reading (io.Reader), it also accepts writes (io.Writer): Write injects data into the stream that is
// subsequently read. startReadInputLoop starts consuming the WebSocket and returns a function that stops it again.
type WebSocketReader interface {
	io.ReadWriter
	startReadInputLoop() context.CancelFunc
}
// codeOceanToRawReader is an io.Reader that provides the input CodeOcean sends over the WebSocket connection.
// It has to be started with startReadInputLoop before Read can return data received over the WebSocket.
type codeOceanToRawReader struct {
	// connection is the WebSocket connection to CodeOcean that the input is read from.
	connection webSocketConnection

	// priorityBuffer holds data that Poseidon itself injects into the stdin of the execution via Write, for
	// example the character that makes the tty generate a SIGQUIT signal. It is intended to be read before buffer.
	priorityBuffer chan byte

	// buffer holds the bytes received from CodeOcean via the WebSocket until Read consumes them. A buffered
	// channel is used instead of a bytes.Buffer because channels are safe for concurrent use.
	buffer chan byte
}
// newCodeOceanToRawReader creates a reader for connection whose regular and priority buffers both have a
// capacity of CodeOceanToRawReaderBufferSize bytes. The read input loop is not started.
func newCodeOceanToRawReader(connection webSocketConnection) *codeOceanToRawReader {
	reader := new(codeOceanToRawReader)
	reader.connection = connection
	reader.buffer = make(chan byte, CodeOceanToRawReaderBufferSize)
	reader.priorityBuffer = make(chan byte, CodeOceanToRawReaderBufferSize)
	return reader
}
// readInputLoop reads messages from the WebSocket connection and buffers the user's input until ctx is canceled or
// reading fails.
// Reading is necessary because input must be read for the connection to handle special messages like close and call
// the CloseHandler.
// Only text messages are accepted; any other message type ends the loop.
func (cr *codeOceanToRawReader) readInputLoop(ctx context.Context) {
	type nextReaderResult struct {
		messageType int
		reader      io.Reader
		err         error
	}
	// NextReader takes no context, so it is called in a separate goroutine to be able to react to ctx.
	// The channel is buffered so that this goroutine can always deliver its result and terminate, even if the loop
	// has already returned because ctx was canceled. With an unbuffered channel, that goroutine would block forever.
	results := make(chan nextReaderResult, 1)
	for {
		go func() {
			messageType, reader, err := cr.connection.NextReader()
			results <- nextReaderResult{messageType: messageType, reader: reader, err: err}
		}()

		var result nextReaderResult
		select {
		case <-ctx.Done():
			return
		case result = <-results:
		}

		if result.err != nil {
			// Only *websocket.Conn exposes the remote address; other implementations must not cause a panic.
			if conn, ok := cr.connection.(*websocket.Conn); ok {
				log.WithField("remote", conn.UnderlyingConn().RemoteAddr()).
					WithError(result.err).Warn("Error reading client message")
			} else {
				log.WithError(result.err).Warn("Error reading client message")
			}
			return
		}
		if result.messageType != websocket.TextMessage {
			log.WithField("messageType", result.messageType).Warn("Received message of wrong type")
			return
		}

		message, err := io.ReadAll(result.reader)
		if err != nil {
			log.WithError(err).Warn("error while reading WebSocket message")
			return
		}
		// Hand over byte by byte so that a full buffer blocks here but cancellation is still honored.
		for _, character := range message {
			select {
			case <-ctx.Done():
				return
			case cr.buffer <- character:
			}
		}
	}
}
// startReadInputLoop runs readInputLoop in a new goroutine and returns a context.CancelFunc that stops it.
func (cr *codeOceanToRawReader) startReadInputLoop() context.CancelFunc {
	loopCtx, stopLoop := context.WithCancel(context.Background())
	go cr.readInputLoop(loopCtx)
	return stopLoop
}
// Read implements the io.Reader interface.
// It blocks until at least one byte is available (to avoid busy waiting by the caller) and then returns as many
// bytes as are immediately available, up to len(p). Bytes from priorityBuffer are preferred over bytes from buffer.
// A plain select over both channels would pick randomly when both are ready, so priority data is checked first
// in a non-blocking select.
func (cr *codeOceanToRawReader) Read(p []byte) (int, error) {
	if len(p) == 0 {
		return 0, nil
	}

	select {
	case p[0] = <-cr.priorityBuffer:
	default:
		// No priority data pending: wait for whichever buffer delivers first.
		select {
		case p[0] = <-cr.priorityBuffer:
		case p[0] = <-cr.buffer:
		}
	}

	n := 1
	for n < len(p) {
		select {
		case p[n] = <-cr.priorityBuffer:
		default:
			select {
			case p[n] = <-cr.buffer:
			default:
				// Neither buffer has data ready; return what we have instead of blocking.
				return n, nil
			}
		}
		n++
	}
	return n, nil
}
// Write implements the io.Writer interface.
// Data written to a codeOceanToRawReader using this method is meant to be returned by Read before other data from
// the WebSocket. Write never blocks: if priorityBuffer is full, it stops and returns the number of bytes accepted
// so far together with io.ErrShortWrite, as required by the io.Writer contract for n < len(p).
func (cr *codeOceanToRawReader) Write(p []byte) (n int, err error) {
	for _, c := range p {
		select {
		case cr.priorityBuffer <- c:
			n++
		default:
			return n, io.ErrShortWrite
		}
	}
	return n, nil
}
// rawToCodeOceanWriter is an io.Writer implementation that, when written to, wraps the written data in the appropriate
// json structure and sends it to the CodeOcean via WebSocket.
type rawToCodeOceanWriter struct {
proxy *webSocketProxy
outputType dto.WebSocketMessageType
}
// Write implements the io.Writer interface.
// The passed data is forwarded as a single message of type rc.outputType to the WebSocket to CodeOcean.
// If sending fails, zero bytes are reported as written because the message was not (completely) delivered.
func (rc *rawToCodeOceanWriter) Write(p []byte) (int, error) {
	if err := rc.proxy.sendToClient(dto.WebSocketMessage{Type: rc.outputType, Data: string(p)}); err != nil {
		return 0, err
	}
	return len(p), nil
}
// webSocketProxy forwards data between a runner execution and CodeOcean over a WebSocket connection.
type webSocketProxy struct {
	// connection is the WebSocket connection to CodeOcean. Writes go through writeMessage, which serializes them
	// with connectionMu.
	connection   webSocketConnection
	connectionMu sync.Mutex

	// userExit is closed when the client closes the connection.
	userExit chan bool

	// Stdin provides the client's input; Stdout and Stderr send output to the client.
	Stdin  WebSocketReader
	Stdout io.Writer
	Stderr io.Writer
}
// upgradeConnection upgrades the HTTP connection of request to a WebSocket connection using a default
// websocket.Upgrader and returns it. On failure, the error is logged and returned wrapped.
func upgradeConnection(writer http.ResponseWriter, request *http.Request) (webSocketConnection, error) {
	var upgrader websocket.Upgrader
	conn, err := upgrader.Upgrade(writer, request, nil)
	if err == nil {
		return conn, nil
	}
	log.WithError(err).Warn("Connection upgrade failed")
	return nil, fmt.Errorf("error upgrading the connection: %w", err)
}
// newWebSocketProxy returns an initialized and started webSocketProxy for connection.
// Starting means that a start meta message is sent to the client. If that fails, the error is returned; in that case
// sendToClient has already closed the connection.
// It also installs a close handler that first delegates to the previously installed handler and then closes userExit,
// which signals waitForExit that the client closed the connection.
func newWebSocketProxy(connection webSocketConnection) (*webSocketProxy, error) {
	proxy := &webSocketProxy{
		connection: connection,
		Stdin:      newCodeOceanToRawReader(connection),
		userExit:   make(chan bool),
	}
	proxy.Stdout = &rawToCodeOceanWriter{proxy: proxy, outputType: dto.WebSocketOutputStdout}
	proxy.Stderr = &rawToCodeOceanWriter{proxy: proxy, outputType: dto.WebSocketOutputStderr}

	if err := proxy.sendToClient(dto.WebSocketMessage{Type: dto.WebSocketMetaStart}); err != nil {
		return nil, err
	}

	previousHandler := connection.CloseHandler()
	connection.SetCloseHandler(func(code int, text string) error {
		//nolint:errcheck // The default close handler always returns nil, so the error can be safely ignored.
		_ = previousHandler(code, text)
		close(proxy.userExit)
		return nil
	})
	return proxy, nil
}
// waitForExit waits for an exit of either the runner (when the command terminates) or the client closing the WebSocket
// and handles WebSocket exit messages.
//
// The read input loop of Stdin runs while waiting and is canceled in both cases. If the client closed the connection,
// cancelExecution is called and the connection is released without sending anything further. Otherwise, the client
// receives a timeout meta message (for context.DeadlineExceeded or runner.ErrorRunnerInactivityTimeout), an error
// output message (for any other error) or an exit message with the exit code, followed by a normal close.
// If sending that message fails, sendToClient has already closed the connection with an error.
func (wp *webSocketProxy) waitForExit(exit <-chan runner.ExitInfo, cancelExecution context.CancelFunc) {
	cancelInputLoop := wp.Stdin.startReadInputLoop()
	var exitInfo runner.ExitInfo
	select {
	case exitInfo = <-exit:
		cancelInputLoop()
		log.Info("Execution returned")
	case <-wp.userExit:
		cancelInputLoop()
		cancelExecution()
		log.Info("Client closed the connection")
		// The client initiated the close and the close handler has already run; release the connection instead of
		// leaving the underlying network connection open.
		_ = wp.connection.Close()
		return
	}

	switch {
	case errors.Is(exitInfo.Err, context.DeadlineExceeded) || errors.Is(exitInfo.Err, runner.ErrorRunnerInactivityTimeout):
		if err := wp.sendToClient(dto.WebSocketMessage{Type: dto.WebSocketMetaTimeout}); err != nil {
			log.WithError(err).Warn("Failed to send timeout message to client")
			return
		}
	case exitInfo.Err != nil:
		errorMessage := "Error executing the request"
		log.WithError(exitInfo.Err).Warn(errorMessage)
		err := wp.sendToClient(dto.WebSocketMessage{Type: dto.WebSocketOutputError, Data: errorMessage})
		if err != nil {
			log.WithError(err).Warn("Failed to send output error message to client")
			return
		}
	default:
		log.WithField("exit_code", exitInfo.Code).Debug()
		err := wp.sendToClient(dto.WebSocketMessage{
			Type:     dto.WebSocketExit,
			ExitCode: exitInfo.Code,
		})
		if err != nil {
			log.WithError(err).Warn("Error sending exit message")
			return
		}
	}
	wp.closeNormal()
}
// sendToClient encodes message as JSON and writes it to the client as a text message.
// If encoding or writing fails, the failure is logged, the connection is closed with an internal server error close
// message, and a wrapped error is returned.
func (wp *webSocketProxy) sendToClient(message dto.WebSocketMessage) error {
	payload, marshalErr := json.Marshal(message)
	if marshalErr != nil {
		log.WithField("message", message).WithError(marshalErr).Warn("Marshal error")
		wp.closeWithError("Error creating message")
		return fmt.Errorf("error marshaling WebSocket message: %w", marshalErr)
	}

	if writeErr := wp.writeMessage(websocket.TextMessage, payload); writeErr != nil {
		const closeReason = "Error writing the exit message"
		log.WithError(writeErr).Warn(closeReason)
		wp.closeWithError(closeReason)
		return fmt.Errorf("error writing WebSocket message: %w", writeErr)
	}
	return nil
}
// maxCloseReasonLength is the maximum length in bytes of a close reason. A control frame payload may hold at most
// 125 bytes (RFC 6455, section 5.5), and two of them are taken by the close status code.
const maxCloseReasonLength = 123

// closeWithError sends a close message with status CloseInternalServerErr and message as reason, then closes the
// connection. Reasons longer than maxCloseReasonLength bytes are truncated, because an oversized close frame is
// rejected as an invalid control frame and would never reach the client.
func (wp *webSocketProxy) closeWithError(message string) {
	wp.close(websocket.FormatCloseMessage(websocket.CloseInternalServerErr, truncateCloseReason(message)))
}

// truncateCloseReason shortens reason to at most maxCloseReasonLength bytes without splitting a UTF-8 character.
func truncateCloseReason(reason string) string {
	if len(reason) <= maxCloseReasonLength {
		return reason
	}
	cut := 0
	// Ranging over a string yields the byte offset at which each character starts; cut ends up at the last
	// character start that still fits, so reason[:cut] only contains complete characters.
	for i := range reason {
		if i > maxCloseReasonLength {
			break
		}
		cut = i
	}
	return reason[:cut]
}
// closeNormal sends a close message with status CloseNormalClosure and an empty reason, then closes the connection.
func (wp *webSocketProxy) closeNormal() {
	const noReason = ""
	closeMessage := websocket.FormatCloseMessage(websocket.CloseNormalClosure, noReason)
	wp.close(closeMessage)
}
// close writes message as a close message and then closes the connection, regardless of whether the write
// succeeded. A failed write is only logged; the error of Close itself is ignored.
func (wp *webSocketProxy) close(message []byte) {
	writeErr := wp.writeMessage(websocket.CloseMessage, message)
	_ = wp.connection.Close()
	if writeErr == nil {
		return
	}
	log.WithError(writeErr).Warn("Error during websocket close")
}
// writeMessage writes a single message to the connection while holding connectionMu, so that concurrent writers
// (for example the Stdout and Stderr writers) never write to the connection at the same time.
func (wp *webSocketProxy) writeMessage(messageType int, data []byte) error {
	wp.connectionMu.Lock()
	defer wp.connectionMu.Unlock()
	conn := wp.connection
	return conn.WriteMessage(messageType, data) //nolint:wrapcheck // Wrap the original WriteMessage in a mutex.
}
// connectToRunner is the endpoint for websocket connections.
// It checks that the execution given by the ExecutionIDKey query parameter exists on the runner stored in the
// request context, upgrades the connection to a WebSocket, runs the execution interactively with the WebSocket as
// stdin/stdout/stderr and waits until either the execution finishes or the client closes the connection.
func (r *RunnerController) connectToRunner(writer http.ResponseWriter, request *http.Request) {
	// NOTE(review): the second return value of runner.FromContext is ignored; presumably a middleware guarantees
	// that a runner is stored in the context — confirm.
	targetRunner, _ := runner.FromContext(request.Context())
	executionID := request.URL.Query().Get(ExecutionIDKey)
	if !targetRunner.ExecutionExists(executionID) {
		writeNotFound(writer, ErrUnknownExecutionID)
		return
	}

	connection, err := upgradeConnection(writer, request)
	if err != nil {
		// websocket.Upgrader.Upgrade already replies to the client with an HTTP error response when the upgrade
		// fails, and upgradeConnection logged the failure. Writing another response here would only cause a
		// superfluous WriteHeader call.
		return
	}

	proxy, err := newWebSocketProxy(connection)
	if err != nil {
		// Sending the start message failed; sendToClient has already closed the connection.
		return
	}

	log.WithField("runnerId", targetRunner.ID()).WithField("executionID", executionID).Info("Running execution")
	exit, cancel, err := targetRunner.ExecuteInteractively(executionID, proxy.Stdin, proxy.Stdout, proxy.Stderr)
	if err != nil {
		proxy.closeWithError(fmt.Sprintf("execution failed with: %v", err))
		return
	}

	proxy.waitForExit(exit, cancel)
}