Files
poseidon/internal/api/websocket.go
Maximilian Paß eabe3a1b27 Add the Environment ID to the influxdb data.
Also move the interface of an execution environment into its own file, execution_environment.go.
2022-04-18 13:17:49 +02:00

368 lines
12 KiB
Go

package api
import (
"context"
"encoding/json"
"errors"
"fmt"
"github.com/gorilla/websocket"
"github.com/openHPI/poseidon/internal/runner"
"github.com/openHPI/poseidon/pkg/dto"
"io"
"net/http"
"sync"
)
// CodeOceanToRawReaderBufferSize is the capacity, in bytes, of each of the two buffered channels
// (buffer and priorityBuffer) that newCodeOceanToRawReader allocates.
const CodeOceanToRawReaderBufferSize = 1024
// ErrUnknownExecutionID is reported by connectToRunner when the requested execution id
// does not exist on the target runner.
var ErrUnknownExecutionID = errors.New("execution id unknown")
// webSocketConnection lists the connection methods this file relies on.
// upgradeConnection returns the result of websocket.Upgrader.Upgrade as this type.
// NOTE(review): presumably the interface exists so that the connection can be replaced in tests - confirm.
type webSocketConnection interface {
// WriteMessage sends one message of the given type; calls in this file go through webSocketProxy.writeMessage.
WriteMessage(messageType int, data []byte) error
// Close closes the underlying connection.
Close() error
// NextReader provides the type of and a reader for the next incoming message.
NextReader() (messageType int, r io.Reader, err error)
// CloseHandler returns the currently installed handler for close messages.
CloseHandler() func(code int, text string) error
// SetCloseHandler replaces the handler for close messages.
SetCloseHandler(handler func(code int, text string) error)
}
// WebSocketReader is an interface that is intended for providing abstraction around reading from a WebSocket.
// Besides io.Reader, it also implements io.Writer. The Write method is used to inject data into the WebSocket stream.
type WebSocketReader interface {
io.Reader
io.Writer
// startReadInputLoop starts consuming messages from the WebSocket in the background.
// It has to be called before Read can return data that was received from the client.
startReadInputLoop()
}
// codeOceanToRawReader is an io.Reader implementation that provides the content of the WebSocket connection
// to CodeOcean. You have to start the Reader by calling startReadInputLoop. After that you can use the Read function.
type codeOceanToRawReader struct {
// connection is the WebSocket whose incoming messages are read by readInputLoop.
connection webSocketConnection
// wsCtx is the context in that messages from CodeOcean are read.
wsCtx context.Context
// executorCtx is the context in that messages are forwarded to the executor.
// Read returns io.EOF once this context is done.
executorCtx context.Context
// A buffered channel of bytes is used to store data coming from CodeOcean via WebSocket
// and retrieve it when Read(..) is called. Since channels are thread-safe, we use one here
// instead of bytes.Buffer.
buffer chan byte
// The priorityBuffer is a buffer for injecting data into stdin of the execution from Poseidon,
// for example the character that causes the tty to generate a SIGQUIT signal.
// It is filled by Write and is intended to be read before the regular buffer.
priorityBuffer chan byte
}
// newCodeOceanToRawReader builds a reader for the passed connection.
// wsCtx bounds reading from the WebSocket, executorCtx bounds forwarding to the executor.
// Both byte channels are created with a capacity of CodeOceanToRawReaderBufferSize.
// The returned reader does not consume the connection until startReadInputLoop is called.
func newCodeOceanToRawReader(connection webSocketConnection, wsCtx, executorCtx context.Context) *codeOceanToRawReader {
	reader := new(codeOceanToRawReader)
	reader.connection = connection
	reader.wsCtx = wsCtx
	reader.executorCtx = executorCtx
	reader.buffer = make(chan byte, CodeOceanToRawReaderBufferSize)
	reader.priorityBuffer = make(chan byte, CodeOceanToRawReaderBufferSize)
	return reader
}
// readInputLoop reads from the WebSocket connection and buffers the user's input.
// This is necessary because input must be read for the connection to handle special messages like close and call the
// CloseHandler.
// The loop ends when wsCtx is canceled or when handleInput reports that reading is done.
func (cr *codeOceanToRawReader) readInputLoop() {
// readMessage signals that the goroutine below has stored the result of NextReader.
readMessage := make(chan bool)
loopContext, cancelInputLoop := context.WithCancel(cr.wsCtx)
defer cancelInputLoop()
// readingContext is canceled when this function returns. It releases a reader goroutine that
// finishes NextReader after the loop has already stopped receiving from readMessage.
readingContext, cancelNextMessage := context.WithCancel(loopContext)
defer cancelNextMessage()
for loopContext.Err() == nil {
// The result variables are declared per iteration and are written by the goroutine only;
// this function reads them only after receiving from readMessage.
var messageType int
var reader io.Reader
var err error
// NextReader is called in its own goroutine so that the select below can also react to loopContext.
go func() {
messageType, reader, err = cr.connection.NextReader()
select {
case <-readingContext.Done():
case readMessage <- true:
}
}()
select {
case <-loopContext.Done():
return
case <-readMessage:
}
// Forward the message to the buffer; stop on errors, on unexpected message types or on cancellation.
if handleInput(messageType, reader, err, cr.buffer, loopContext) {
return
}
}
}
// handleInput processes the result of one NextReader call.
// On success, the complete text message is pushed byte by byte into buffer and false is returned.
// It returns true (done) when the read failed, when the message is not a text message,
// when the message body could not be read, or when ctx is canceled while the bytes are being buffered.
func handleInput(messageType int, reader io.Reader, err error, buffer chan byte, ctx context.Context) (done bool) {
	switch {
	case err != nil && websocket.IsCloseError(err, websocket.CloseNormalClosure):
		log.Debug("ReadInputLoop: The client closed the connection!")
		// The close handler will do something soon.
		return true
	case err != nil:
		log.WithError(err).Warn("Error reading client message")
		return true
	case messageType != websocket.TextMessage:
		log.WithField("messageType", messageType).Warn("Received message of wrong type")
		return true
	}

	payload, readErr := io.ReadAll(reader)
	if readErr != nil {
		log.WithError(readErr).Warn("error while reading WebSocket message")
		return true
	}
	log.WithField("message", string(payload)).Trace("Received message from client")

	// Sending blocks while the buffer is full; cancellation of ctx aborts the hand-over.
	for i := 0; i < len(payload); i++ {
		select {
		case <-ctx.Done():
			return true
		case buffer <- payload[i]:
		}
	}
	return false
}
// startReadInputLoop starts the read input loop asynchronously.
// It returns immediately; readInputLoop keeps running in its own goroutine.
func (cr *codeOceanToRawReader) startReadInputLoop() {
go cr.readInputLoop()
}
// Read implements the io.Reader interface.
// It blocks until at least one byte is available or executorCtx is done; in the latter case
// it returns 0 and io.EOF. After the first byte, p is filled with whatever is immediately
// available and the number of bytes stored is returned with a nil error.
// Bytes from priorityBuffer are always returned before bytes from buffer.
// For an empty p it returns 0, nil without blocking.
func (cr *codeOceanToRawReader) Read(p []byte) (int, error) {
	if len(p) == 0 {
		return 0, nil
	}

	// A select with several ready cases picks one at random. Injected data is therefore polled
	// on its own first, otherwise regular input could overtake it.
	select {
	case p[0] = <-cr.priorityBuffer:
	default:
		// Ensure to not return until at least one byte has been read to avoid busy waiting.
		select {
		case <-cr.executorCtx.Done():
			return 0, io.EOF
		case p[0] = <-cr.priorityBuffer:
		case p[0] = <-cr.buffer:
		}
	}

	n := 1
	for ; n < len(p); n++ {
		// Same precedence for the remaining bytes: drain injected data before regular input.
		select {
		case p[n] = <-cr.priorityBuffer:
			continue
		default:
		}
		select {
		case p[n] = <-cr.buffer:
		default:
			// Nothing is immediately available; hand over what has been collected so far.
			return n, nil
		}
	}
	return n, nil
}
// Write implements the io.Writer interface.
// Data written to a codeOceanToRawReader using this method is returned by Read before other data from the WebSocket.
// Write never blocks. It returns the number of bytes that were queued in the priorityBuffer.
// When the priorityBuffer is full, it stops at the first byte that does not fit and returns
// io.ErrShortWrite, so that no byte is dropped silently and no later byte overtakes a rejected one.
func (cr *codeOceanToRawReader) Write(p []byte) (n int, err error) {
	for _, c := range p {
		select {
		case cr.priorityBuffer <- c:
			n++
		default:
			// A plain break would only leave the select, hence the explicit return.
			return n, io.ErrShortWrite
		}
	}
	return n, nil
}
// rawToCodeOceanWriter is an io.Writer implementation that, when written to, wraps the written data in the appropriate
// json structure and sends it to the CodeOcean via WebSocket.
type rawToCodeOceanWriter struct {
// proxy provides the WebSocket context and the sendToClient method used by Write.
proxy *webSocketProxy
// outputType is set as the Type of every message created by Write.
outputType dto.WebSocketMessageType
}
// Write implements the io.Writer interface.
// The passed data is sent to CodeOcean as one WebSocket message of the writer's output type.
// Once the WebSocket context is done, the data is discarded and 0, nil is returned.
// NOTE(review): returning 0 with a nil error for a non-empty p is short of the io.Writer contract;
// confirm that no caller depends on it before changing it.
func (rc *rawToCodeOceanWriter) Write(p []byte) (int, error) {
	proxy := rc.proxy
	if proxy.webSocketCtx.Err() != nil {
		return 0, nil
	}

	message := dto.WebSocketMessage{Type: rc.outputType, Data: string(p)}
	if err := proxy.sendToClient(message); err != nil {
		return len(p), err
	}
	return len(p), nil
}
// webSocketProxy is an encapsulation of logic for forwarding between Runners and CodeOcean.
type webSocketProxy struct {
// webSocketCtx is done as soon as the communication with the client has ended, either because
// the client sent a close message or because cancelWebSocket was called.
webSocketCtx context.Context
// cancelWebSocket cancels webSocketCtx.
cancelWebSocket context.CancelFunc
// connection is the WebSocket connection to the client.
connection webSocketConnection
// connectionMu serializes the WriteMessage calls on connection (see writeMessage).
connectionMu sync.Mutex
// Stdin provides the client's input; Stdout and Stderr forward output to the client.
// connectToRunner passes all three to ExecuteInteractively.
Stdin WebSocketReader
Stdout io.Writer
Stderr io.Writer
}
// upgradeConnection upgrades the HTTP request to a WebSocket and returns the resulting connection.
// A failed upgrade is logged and returned as a wrapped error together with a nil connection.
func upgradeConnection(writer http.ResponseWriter, request *http.Request) (webSocketConnection, error) {
	var upgrader websocket.Upgrader
	connection, upgradeErr := upgrader.Upgrade(writer, request, nil)
	if upgradeErr == nil {
		return connection, nil
	}

	log.WithError(upgradeErr).Warn("Connection upgrade failed")
	return nil, fmt.Errorf("error upgrading the connection: %w", upgradeErr)
}
// newWebSocketProxy returns an initiated and started webSocketProxy.
// The WebSocket context is derived from proxyCtx. Stdin reads from the connection, while
// Stdout and Stderr wrap written data into messages for the client.
// As the proxy is already started, a start message is sent to the client; if that fails,
// the WebSocket context is canceled and the error is returned.
// Finally the connection's close handler is extended so that a close message from the client
// cancels the WebSocket context.
func newWebSocketProxy(connection webSocketConnection, proxyCtx context.Context) (*webSocketProxy, error) {
	wsCtx, cancelWsCommunication := context.WithCancel(proxyCtx)

	proxy := &webSocketProxy{
		webSocketCtx:    wsCtx,
		cancelWebSocket: cancelWsCommunication,
		connection:      connection,
		Stdin:           newCodeOceanToRawReader(connection, wsCtx, proxyCtx),
	}
	proxy.Stdout = &rawToCodeOceanWriter{proxy: proxy, outputType: dto.WebSocketOutputStdout}
	proxy.Stderr = &rawToCodeOceanWriter{proxy: proxy, outputType: dto.WebSocketOutputStderr}

	startMessage := dto.WebSocketMessage{Type: dto.WebSocketMetaStart}
	if err := proxy.sendToClient(startMessage); err != nil {
		cancelWsCommunication()
		return nil, err
	}

	previousCloseHandler := connection.CloseHandler()
	connection.SetCloseHandler(func(code int, text string) error {
		//nolint:errcheck // The default close handler always returns nil, so the error can be safely ignored.
		_ = previousCloseHandler(code, text)
		cancelWsCommunication()
		return nil
	})
	return proxy, nil
}
// waitForExit starts reading the client's input and then waits for either the runner to report an exit
// (the command terminated) or the client to close the WebSocket.
// When the client leaves first, the execution is canceled and the exit channel is drained.
// Otherwise the client is informed about the outcome (timeout, error or exit code) and,
// if that message could be sent, the connection is closed normally.
func (wp *webSocketProxy) waitForExit(exit <-chan runner.ExitInfo, cancelExecution context.CancelFunc) {
	wp.Stdin.startReadInputLoop()

	var exitInfo runner.ExitInfo
	select {
	case exitInfo = <-exit:
		log.Info("Execution returned")
		wp.cancelWebSocket()
	case <-wp.webSocketCtx.Done():
		log.Info("Client closed the connection")
		cancelExecution()
		<-exit // /internal/runner/runner.go handleExitOrContextDone does not require client connection anymore.
		<-exit // The goroutine closes this channel indicating that it does not use the connection to the executor anymore.
		return
	}

	timedOut := errors.Is(exitInfo.Err, context.DeadlineExceeded) ||
		errors.Is(exitInfo.Err, runner.ErrorRunnerInactivityTimeout)

	// Each branch sends exactly one message; a failed send is logged and skips the normal close below.
	switch {
	case timedOut:
		if err := wp.sendToClient(dto.WebSocketMessage{Type: dto.WebSocketMetaTimeout}); err != nil {
			log.WithError(err).Warn("Failed to send timeout message to client")
			return
		}
	case exitInfo.Err != nil:
		errorMessage := "Error executing the request"
		log.WithError(exitInfo.Err).Warn(errorMessage)
		outputError := dto.WebSocketMessage{Type: dto.WebSocketOutputError, Data: errorMessage}
		if err := wp.sendToClient(outputError); err != nil {
			log.WithError(err).Warn("Failed to send output error message to client")
			return
		}
	default:
		log.WithField("exit_code", exitInfo.Code).Debug()
		exitMessage := dto.WebSocketMessage{
			Type:     dto.WebSocketExit,
			ExitCode: exitInfo.Code,
		}
		if err := wp.sendToClient(exitMessage); err != nil {
			log.WithError(err).Warn("Error sending exit message")
			return
		}
	}
	wp.closeNormal()
}
// sendToClient encodes the message as JSON and sends it to the client as a text message.
// If encoding or writing fails, the failure is logged, the connection is closed with an error
// and a wrapped error is returned.
func (wp *webSocketProxy) sendToClient(message dto.WebSocketMessage) error {
	payload, marshalErr := json.Marshal(message)
	if marshalErr != nil {
		log.WithField("message", message).WithError(marshalErr).Warn("Marshal error")
		wp.closeWithError("Error creating message")
		return fmt.Errorf("error marshaling WebSocket message: %w", marshalErr)
	}

	log.WithField("message", message).Trace("Sending message to client")

	if writeErr := wp.writeMessage(websocket.TextMessage, payload); writeErr != nil {
		errorMessage := "Error writing the message"
		log.WithField("message", message).WithError(writeErr).Warn(errorMessage)
		wp.closeWithError(errorMessage)
		return fmt.Errorf("error writing WebSocket message: %w", writeErr)
	}
	return nil
}
// closeWithError closes the connection with the internal-server-error close code
// and the passed message as close reason.
func (wp *webSocketProxy) closeWithError(message string) {
	closeFrame := websocket.FormatCloseMessage(websocket.CloseInternalServerErr, message)
	wp.close(closeFrame)
}
// closeNormal closes the connection with the normal-closure close code and an empty close reason.
func (wp *webSocketProxy) closeNormal() {
	closeFrame := websocket.FormatCloseMessage(websocket.CloseNormalClosure, "")
	wp.close(closeFrame)
}
// close sends the passed close message to the client and closes the connection afterwards.
// The connection is closed even if sending failed; only the send failure is logged.
func (wp *webSocketProxy) close(message []byte) {
	writeErr := wp.writeMessage(websocket.CloseMessage, message)
	_ = wp.connection.Close()
	if writeErr == nil {
		return
	}
	log.WithError(writeErr).Warn("Error during websocket close")
}
// writeMessage writes one message to the connection while holding connectionMu,
// so that writes issued through this method never run concurrently.
// The error of the underlying WriteMessage is returned unchanged.
func (wp *webSocketProxy) writeMessage(messageType int, data []byte) error {
	mu := &wp.connectionMu
	mu.Lock()
	defer mu.Unlock()

	conn := wp.connection
	return conn.WriteMessage(messageType, data) //nolint:wrapcheck // Wrap the original WriteMessage in a mutex.
}
// connectToRunner is the endpoint for websocket connections.
// It looks up the execution id given as query parameter on the runner of the request context,
// answers with "not found" for unknown ids, upgrades the connection to a WebSocket and then
// forwards input and output between the client and the interactive execution until one side ends.
func (r *RunnerController) connectToRunner(writer http.ResponseWriter, request *http.Request) {
	// NOTE(review): the presence flag is ignored; presumably a middleware guarantees
	// that the runner is set in the request context - confirm.
	targetRunner, _ := runner.FromContext(request.Context())
	addMonitoringData(request, targetRunner)

	executionID := request.URL.Query().Get(ExecutionIDKey)
	if !targetRunner.ExecutionExists(executionID) {
		writeNotFound(writer, ErrUnknownExecutionID)
		return
	}

	connection, err := upgradeConnection(writer, request)
	if err != nil {
		// A failed Upgrade has already replied to the client with an HTTP error (or has hijacked
		// and closed the connection), and upgradeConnection has logged the failure.
		// Writing another error response here would only cause a superfluous second response.
		return
	}

	// The proxy context lives as long as this handler; it is independent of the request context.
	proxyCtx, cancelProxy := context.WithCancel(context.Background())
	defer cancelProxy()

	proxy, err := newWebSocketProxy(connection, proxyCtx)
	if err != nil {
		// sendToClient has already logged the failure and closed the connection.
		return
	}

	log.WithField("runnerId", targetRunner.ID()).WithField("executionID", executionID).Info("Running execution")
	exit, cancel, err := targetRunner.ExecuteInteractively(executionID, proxy.Stdin, proxy.Stdout, proxy.Stderr)
	if err != nil {
		proxy.closeWithError(fmt.Sprintf("execution failed with: %v", err))
		return
	}

	proxy.waitForExit(exit, cancel)
}