diff --git a/internal/api/websocket.go b/internal/api/websocket.go index c5d2e1b..9f201ba 100644 --- a/internal/api/websocket.go +++ b/internal/api/websocket.go @@ -10,6 +10,7 @@ import ( "gitlab.hpi.de/codeocean/codemoon/poseidon/pkg/dto" "io" "net/http" + "sync" ) const CodeOceanToRawReaderBufferSize = 1024 @@ -161,11 +162,12 @@ func (rc *rawToCodeOceanWriter) Write(p []byte) (int, error) { // webSocketProxy is an encapsulation of logic for forwarding between Runners and CodeOcean. type webSocketProxy struct { - userExit chan bool - connection webSocketConnection - Stdin WebSocketReader - Stdout io.Writer - Stderr io.Writer + userExit chan bool + connection webSocketConnection + connectionMu sync.Mutex + Stdin WebSocketReader + Stdout io.Writer + Stderr io.Writer } // upgradeConnection upgrades a connection to a websocket and returns a webSocketProxy for this connection. @@ -261,7 +263,7 @@ func (wp *webSocketProxy) sendToClient(message dto.WebSocketMessage) error { wp.closeWithError("Error creating message") return fmt.Errorf("error marshaling WebSocket message: %w", err) } - err = wp.connection.WriteMessage(websocket.TextMessage, encodedMessage) + err = wp.writeMessage(websocket.TextMessage, encodedMessage) if err != nil { errorMessage := "Error writing the exit message" log.WithError(err).Warn(errorMessage) @@ -280,13 +282,19 @@ func (wp *webSocketProxy) closeNormal() { } func (wp *webSocketProxy) close(message []byte) { - err := wp.connection.WriteMessage(websocket.CloseMessage, message) + err := wp.writeMessage(websocket.CloseMessage, message) _ = wp.connection.Close() if err != nil { log.WithError(err).Warn("Error during websocket close") } } +func (wp *webSocketProxy) writeMessage(messageType int, data []byte) error { + wp.connectionMu.Lock() + defer wp.connectionMu.Unlock() + return wp.connection.WriteMessage(messageType, data) //nolint:wrapcheck // Wrap the original WriteMessage in a mutex. +} + // connectToRunner is the endpoint for websocket connections. func (r *RunnerController) connectToRunner(writer http.ResponseWriter, request *http.Request) { targetRunner, _ := runner.FromContext(request.Context())