Execute commands in runner via WebSocket

This enables executing commands in runners and forwarding input and
output between the runner and the websocket to the client.

Co-authored-by: Maximilian Paß <maximilian.pass@student.hpi.uni-potsdam.de>
This commit is contained in:
Konrad Hanff
2021-05-20 08:47:51 +02:00
parent 892f902377
commit 3afcdeaba8
10 changed files with 479 additions and 78 deletions

View File

@ -1,5 +1,11 @@
package dto
import (
"encoding/json"
"errors"
"fmt"
)
// RunnerRequest is the expected json structure of the request body for the ProvideRunner function.
type RunnerRequest struct {
ExecutionEnvironmentId int `json:"executionEnvironmentId"`
@ -13,6 +19,16 @@ type ExecutionRequest struct {
Environment map[string]string
}
func (er *ExecutionRequest) FullCommand() []string {
var command []string
command = append(command, "env", "-")
for variable, value := range er.Environment {
command = append(command, fmt.Sprintf("%s=%s", variable, value))
}
command = append(command, "sh", "-c", er.Command)
return command
}
// ExecutionEnvironmentRequest is the expected json structure of the request body for the create execution environment function.
// nolint:unused,structcheck
type ExecutionEnvironmentRequest struct {
@ -33,9 +49,102 @@ type RunnerResponse struct {
// TODO: specify content of the struct
type FileCreation struct{}
// WebsocketResponse is the expected response when creating an execution for a runner.
type WebsocketResponse struct {
WebsocketUrl string `json:"websocketUrl"`
// ExecutionResponse is the expected response when creating an execution for a runner.
type ExecutionResponse struct {
WebSocketUrl string `json:"websocketUrl"`
}
// WebSocketMessageType is the type for the messages from Poseidon to the client.
type WebSocketMessageType string
const (
WebSocketOutputStdout WebSocketMessageType = "stdout"
WebSocketOutputStderr WebSocketMessageType = "stderr"
WebSocketOutputError WebSocketMessageType = "error"
WebSocketMetaStart WebSocketMessageType = "start"
WebSocketMetaTimeout WebSocketMessageType = "timeout"
WebSocketExit WebSocketMessageType = "exit"
)
// WebSocketMessage is the type for all messages send in the WebSocket to the client.
// Depending on the MessageType the Data or ExitCode might not be included in the marshaled json message.
type WebSocketMessage struct {
Type WebSocketMessageType
Data string
ExitCode uint8
}
// MarshalJSON implements the json.Marshaler interface.
// This converts the WebSocketMessage into the expected schema (see docs/websocket.schema.json).
func (m WebSocketMessage) MarshalJSON() ([]byte, error) {
switch m.Type {
case WebSocketOutputStdout, WebSocketOutputStderr, WebSocketOutputError:
return json.Marshal(struct {
MessageType WebSocketMessageType `json:"type"`
Data string `json:"data"`
}{m.Type, m.Data})
case WebSocketMetaStart, WebSocketMetaTimeout:
return json.Marshal(struct {
MessageType WebSocketMessageType `json:"type"`
}{m.Type})
case WebSocketExit:
return json.Marshal(struct {
MessageType WebSocketMessageType `json:"type"`
ExitCode uint8 `json:"data"`
}{m.Type, m.ExitCode})
}
return nil, errors.New("unhandled WebSocket message type")
}
// UnmarshalJSON implements the json.Unmarshaler interface.
// It is used by tests in order to ReceiveNextWebSocketMessage.
func (m *WebSocketMessage) UnmarshalJSON(rawMessage []byte) error {
messageMap := make(map[string]interface{})
err := json.Unmarshal(rawMessage, &messageMap)
if err != nil {
return err
}
messageType, ok := messageMap["type"]
if !ok {
return errors.New("missing key type")
}
messageTypeString, ok := messageType.(string)
if !ok {
return errors.New("value of key type must be a string")
}
switch messageType := WebSocketMessageType(messageTypeString); messageType {
case WebSocketExit:
data, ok := messageMap["data"]
if !ok {
return errors.New("missing key data")
}
// json.Unmarshal converts any number to a float64 in the massageMap, so we must first cast it to the float.
exit, ok := data.(float64)
if !ok {
return errors.New("value of key data must be a number")
}
if exit != float64(uint8(exit)) {
return errors.New("value of key data must be uint8")
}
m.Type = messageType
m.ExitCode = uint8(exit)
case WebSocketOutputStdout, WebSocketOutputStderr, WebSocketOutputError:
data, ok := messageMap["data"]
if !ok {
return errors.New("missing key data")
}
text, ok := data.(string)
if !ok {
return errors.New("value of key data must be a string")
}
m.Type = messageType
m.Data = text
case WebSocketMetaStart, WebSocketMetaTimeout:
m.Type = messageType
default:
return errors.New("unknown WebSocket message type")
}
return nil
}
// ClientError is the response interface if the request is not valid.

View File

@ -2,6 +2,7 @@ package api
import (
"fmt"
"github.com/google/uuid"
"github.com/gorilla/mux"
"gitlab.hpi.de/codeocean/codemoon/poseidon/api/dto"
"gitlab.hpi.de/codeocean/codemoon/poseidon/config"
@ -81,20 +82,22 @@ func (r *RunnerController) execute(writer http.ResponseWriter, request *http.Req
writeInternalServerError(writer, err, dto.ErrorUnknown)
return
}
id, err := targetRunner.AddExecution(*executionRequest)
newUuid, err := uuid.NewRandom()
if err != nil {
log.WithError(err).Error("Could not store execution.")
log.WithError(err).Error("Could not create execution id")
writeInternalServerError(writer, err, dto.ErrorUnknown)
return
}
websocketUrl := url.URL{
id := runner.ExecutionId(newUuid.String())
targetRunner.Add(id, executionRequest)
webSocketUrl := url.URL{
Scheme: scheme,
Host: request.Host,
Path: path.String(),
RawQuery: fmt.Sprintf("%s=%s", ExecutionIdKey, id),
}
sendJson(writer, &dto.WebsocketResponse{WebsocketUrl: websocketUrl.String()}, http.StatusOK)
sendJson(writer, &dto.ExecutionResponse{WebSocketUrl: webSocketUrl.String()}, http.StatusOK)
}
// The findRunnerMiddleware looks up the runnerId for routes containing it

View File

@ -1,34 +1,264 @@
package api
import (
"context"
"encoding/json"
"errors"
"github.com/gorilla/websocket"
"gitlab.hpi.de/codeocean/codemoon/poseidon/api/dto"
"gitlab.hpi.de/codeocean/codemoon/poseidon/runner"
"io"
"net/http"
)
type webSocketConnection interface {
WriteMessage(messageType int, data []byte) error
Close() error
NextReader() (messageType int, r io.Reader, err error)
CloseHandler() func(code int, text string) error
SetCloseHandler(handler func(code int, text string) error)
}
type WebSocketReader interface {
io.Reader
readInputLoop() context.CancelFunc
}
// codeOceanToRawReader is an io.Reader implementation that provides the content of the WebSocket connection to CodeOcean.
// You have to start the Reader by calling readInputLoop. After that you can use the Read function.
type codeOceanToRawReader struct {
connection webSocketConnection
// A buffered channel of bytes is used to store data coming from CodeOcean via WebSocket
// and retrieve it when Read(..) is called. Since channels are thread-safe, we use one here
// instead of bytes.Buffer.
buffer chan byte
}
func newCodeOceanToRawReader(connection webSocketConnection) *codeOceanToRawReader {
return &codeOceanToRawReader{
connection: connection,
buffer: make(chan byte, 1024),
}
}
// readInputLoop asynchronously reads from the WebSocket connection and buffers the user's input.
// This is necessary because input must be read for the connection to handle special messages like close and call the
// CloseHandler.
func (cr *codeOceanToRawReader) readInputLoop() context.CancelFunc {
ctx, cancel := context.WithCancel(context.Background())
go func() {
readMessage := make(chan bool)
for {
var messageType int
var reader io.Reader
var err error
go func() {
messageType, reader, err = cr.connection.NextReader()
readMessage <- true
}()
select {
case <-readMessage:
case <-ctx.Done():
return
}
if err != nil {
log.WithError(err).Warn("Error reading client message")
return
}
if messageType != websocket.TextMessage {
log.WithField("messageType", messageType).Warn("Received message of wrong type")
return
}
message, err := io.ReadAll(reader)
if err != nil {
log.WithError(err).Warn("error while reading WebSocket message")
return
}
for _, character := range message {
select {
case cr.buffer <- character:
case <-ctx.Done():
return
}
}
}
}()
return cancel
}
// Read implements the io.Reader interface.
// It returns bytes from the buffer.
func (cr *codeOceanToRawReader) Read(p []byte) (n int, err error) {
for n = 0; n < len(p); n++ {
select {
case p[n] = <-cr.buffer:
default:
return
}
}
return
}
// rawToCodeOceanWriter is an io.Writer implementation that, when written to, wraps the written data in the appropriate
// json structure and sends it to the CodeOcean via WebSocket.
type rawToCodeOceanWriter struct {
proxy *webSocketProxy
outputType dto.WebSocketMessageType
}
// Write implements the io.Writer interface.
// The passed data is forwarded to the WebSocket to CodeOcean.
func (rc *rawToCodeOceanWriter) Write(p []byte) (int, error) {
err := rc.proxy.sendToClient(dto.WebSocketMessage{Type: rc.outputType, Data: string(p)})
return len(p), err
}
// webSocketProxy is an encapsulation of logic for forwarding between Runners and CodeOcean.
type webSocketProxy struct {
userExit chan bool
connection webSocketConnection
Stdin WebSocketReader
Stdout io.Writer
Stderr io.Writer
}
// upgradeConnection upgrades a connection to a websocket and returns a webSocketProxy for this connection.
func upgradeConnection(writer http.ResponseWriter, request *http.Request) (webSocketConnection, error) {
connUpgrader := websocket.Upgrader{}
connection, err := connUpgrader.Upgrade(writer, request, nil)
if err != nil {
log.WithError(err).Warn("Connection upgrade failed")
return nil, err
}
return connection, nil
}
// newWebSocketProxy returns a initiated and started webSocketProxy.
// As this proxy is already started, a start message is send to the client.
func newWebSocketProxy(connection webSocketConnection) (*webSocketProxy, error) {
stdin := newCodeOceanToRawReader(connection)
proxy := &webSocketProxy{
connection: connection,
Stdin: stdin,
userExit: make(chan bool),
}
proxy.Stdout = &rawToCodeOceanWriter{proxy: proxy, outputType: dto.WebSocketOutputStdout}
proxy.Stderr = &rawToCodeOceanWriter{proxy: proxy, outputType: dto.WebSocketOutputStderr}
err := proxy.sendToClient(dto.WebSocketMessage{Type: dto.WebSocketMetaStart})
if err != nil {
return nil, err
}
closeHandler := connection.CloseHandler()
connection.SetCloseHandler(func(code int, text string) error {
// The default close handler always returns nil, so the error can be safely ignored.
_ = closeHandler(code, text)
close(proxy.userExit)
return nil
})
return proxy, nil
}
// waitForExit waits for an exit of either the runner (when the command terminates) or the client closing the WebSocket
// and handles WebSocket exit messages.
func (wp *webSocketProxy) waitForExit(exit <-chan runner.ExitInfo, cancelExecution context.CancelFunc) {
defer wp.close()
cancelInputLoop := wp.Stdin.readInputLoop()
var exitInfo runner.ExitInfo
select {
case exitInfo = <-exit:
cancelInputLoop()
log.Info("Execution returned")
case <-wp.userExit:
cancelInputLoop()
cancelExecution()
log.Info("Client closed the connection")
return
}
if exitInfo.Err == context.DeadlineExceeded {
err := wp.sendToClient(dto.WebSocketMessage{Type: dto.WebSocketMetaTimeout})
if err != nil {
return
}
return
} else if exitInfo.Err != nil {
errorMessage := "Error executing the request"
log.WithError(exitInfo.Err).Warn(errorMessage)
_ = wp.sendToClient(dto.WebSocketMessage{Type: dto.WebSocketOutputError, Data: errorMessage})
return
}
log.WithField("exit_code", exitInfo.Code).Debug()
err := wp.sendToClient(dto.WebSocketMessage{
Type: dto.WebSocketExit,
ExitCode: exitInfo.Code,
})
if err != nil {
return
}
}
func (wp *webSocketProxy) sendToClient(message dto.WebSocketMessage) error {
encodedMessage, err := json.Marshal(message)
if err != nil {
log.WithField("message", message).WithError(err).Warn("Marshal error")
wp.closeWithError("Error creating message")
return err
}
err = wp.connection.WriteMessage(websocket.TextMessage, encodedMessage)
if err != nil {
errorMessage := "Error writing the exit message"
log.WithError(err).Warn(errorMessage)
wp.closeWithError(errorMessage)
return err
}
return nil
}
func (wp *webSocketProxy) closeWithError(message string) {
err := wp.connection.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseInternalServerErr, message))
if err != nil {
log.WithError(err).Warn("Error during websocket close")
}
}
func (wp *webSocketProxy) close() {
err := wp.connection.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""))
_ = wp.connection.Close()
if err != nil {
log.WithError(err).Warn("Error during websocket close")
}
}
// connectToRunner is the endpoint for websocket connections.
func (r *RunnerController) connectToRunner(writer http.ResponseWriter, request *http.Request) {
targetRunner, _ := runner.FromContext(request.Context())
executionId := runner.ExecutionId(request.URL.Query().Get(ExecutionIdKey))
_, ok := targetRunner.Execution(executionId)
executionRequest, ok := targetRunner.Pop(executionId)
if !ok {
writeNotFound(writer, errors.New("executionId does not exist"))
return
}
log.
WithField("runnerId", targetRunner.Id()).
WithField("executionId", executionId).
Info("Running execution")
connUpgrader := websocket.Upgrader{}
connClient, err := connUpgrader.Upgrade(writer, request, nil)
if err != nil {
log.WithError(err).Warn("Connection upgrade failed")
return
}
err = connClient.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""))
connection, err := upgradeConnection(writer, request)
if err != nil {
writeInternalServerError(writer, err, dto.ErrorUnknown)
return
}
proxy, err := newWebSocketProxy(connection)
if err != nil {
return
}
log.WithField("runnerId", targetRunner.Id()).WithField("executionId", executionId).Info("Running execution")
exit, cancel := targetRunner.Execute(executionRequest, proxy.Stdin, proxy.Stdout, proxy.Stderr)
proxy.waitForExit(exit, cancel)
}