Introduce context for the codeOceanOutputWriter

that represents its lifespan.
This commit is contained in:
Maximilian Paß
2023-04-11 19:29:12 +01:00
parent 0c8fa9ccfa
commit 2aa10a130f
3 changed files with 53 additions and 49 deletions

View File

@ -19,7 +19,6 @@ var ErrUnknownExecutionID = errors.New("execution id unknown")
// webSocketProxy is an encapsulation of logic for forwarding between Runners and CodeOcean. // webSocketProxy is an encapsulation of logic for forwarding between Runners and CodeOcean.
type webSocketProxy struct { type webSocketProxy struct {
ctx context.Context ctx context.Context
cancel context.CancelFunc
Input ws.WebSocketReader Input ws.WebSocketReader
Output ws.WebSocketWriter Output ws.WebSocketWriter
} }
@ -38,16 +37,19 @@ func upgradeConnection(writer http.ResponseWriter, request *http.Request) (ws.Co
// newWebSocketProxy returns an initiated and started webSocketProxy. // newWebSocketProxy returns an initiated and started webSocketProxy.
// As this proxy is already started, a start message is send to the client. // As this proxy is already started, a start message is send to the client.
func newWebSocketProxy(connection ws.Connection, proxyCtx context.Context) *webSocketProxy { func newWebSocketProxy(connection ws.Connection, proxyCtx context.Context) *webSocketProxy {
wsCtx, cancelWsCommunication := context.WithCancel(proxyCtx) // wsCtx is detached from the proxyCtx
// as it should send all messages in the buffer even if the execution/proxy is done.
wsCtx, cancelWsCommunication := context.WithCancel(context.Background())
wsCtx = sentry.SetHubOnContext(wsCtx, sentry.GetHubFromContext(proxyCtx))
proxy := &webSocketProxy{ proxy := &webSocketProxy{
ctx: wsCtx, ctx: wsCtx,
cancel: cancelWsCommunication,
Input: ws.NewCodeOceanToRawReader(connection, wsCtx, proxyCtx), Input: ws.NewCodeOceanToRawReader(connection, wsCtx, proxyCtx),
Output: ws.NewCodeOceanOutputWriter(connection, wsCtx), Output: ws.NewCodeOceanOutputWriter(connection, wsCtx, cancelWsCommunication),
} }
connection.SetCloseHandler(func(code int, text string) error { connection.SetCloseHandler(func(code int, text string) error {
log.WithField("code", code).WithField("text", text).Debug("The client closed the connection.") log.WithContext(wsCtx).WithField("code", code).WithField("text", text).Debug("The client closed the connection.")
cancelWsCommunication() cancelWsCommunication()
return nil return nil
}) })
@ -70,7 +72,7 @@ func (wp *webSocketProxy) waitForExit(exit <-chan runner.ExitInfo, cancelExecuti
case exitInfo = <-exit: case exitInfo = <-exit:
log.WithContext(wp.ctx).Info("Execution returned") log.WithContext(wp.ctx).Info("Execution returned")
wp.Input.Stop() wp.Input.Stop()
wp.Output.SendExitInfo(&exitInfo) wp.Output.Close(&exitInfo)
} }
} }

View File

@ -28,7 +28,7 @@ func (rc *rawToCodeOceanWriter) Write(p []byte) (int, error) {
type WebSocketWriter interface { type WebSocketWriter interface {
StdOut() io.Writer StdOut() io.Writer
StdErr() io.Writer StdErr() io.Writer
SendExitInfo(info *runner.ExitInfo) Close(info *runner.ExitInfo)
} }
// codeOceanOutputWriter is a concrete WebSocketWriter implementation. // codeOceanOutputWriter is a concrete WebSocketWriter implementation.
@ -38,7 +38,7 @@ type codeOceanOutputWriter struct {
stdOut io.Writer stdOut io.Writer
stdErr io.Writer stdErr io.Writer
queue chan *writingLoopMessage queue chan *writingLoopMessage
stopped bool ctx context.Context
} }
// writingLoopMessage is an internal data structure to notify the writing loop when it should stop. // writingLoopMessage is an internal data structure to notify the writing loop when it should stop.
@ -47,13 +47,14 @@ type writingLoopMessage struct {
data *dto.WebSocketMessage data *dto.WebSocketMessage
} }
// NewCodeOceanOutputWriter provies an codeOceanOutputWriter for the time the context ctx is active. // NewCodeOceanOutputWriter provides an codeOceanOutputWriter for the time the context ctx is active.
// The codeOceanOutputWriter handles all the messages defined in the websocket.schema.json (start, timeout, stdout, ..). // The codeOceanOutputWriter handles all the messages defined in the websocket.schema.json (start, timeout, stdout, ..).
func NewCodeOceanOutputWriter(connection Connection, ctx context.Context) *codeOceanOutputWriter { func NewCodeOceanOutputWriter(
connection Connection, ctx context.Context, done context.CancelFunc) *codeOceanOutputWriter {
cw := &codeOceanOutputWriter{ cw := &codeOceanOutputWriter{
connection: connection, connection: connection,
queue: make(chan *writingLoopMessage, CodeOceanOutputWriterBufferSize), queue: make(chan *writingLoopMessage, CodeOceanOutputWriterBufferSize),
stopped: false, ctx: ctx,
} }
cw.stdOut = &rawToCodeOceanWriter{sendMessage: func(s string) { cw.stdOut = &rawToCodeOceanWriter{sendMessage: func(s string) {
cw.send(&dto.WebSocketMessage{Type: dto.WebSocketOutputStdout, Data: s}) cw.send(&dto.WebSocketMessage{Type: dto.WebSocketOutputStdout, Data: s})
@ -62,8 +63,7 @@ func NewCodeOceanOutputWriter(connection Connection, ctx context.Context) *codeO
cw.send(&dto.WebSocketMessage{Type: dto.WebSocketOutputStderr, Data: s}) cw.send(&dto.WebSocketMessage{Type: dto.WebSocketOutputStderr, Data: s})
}} }}
go cw.startWritingLoop() go cw.startWritingLoop(done)
go cw.stopWhenContextDone(ctx)
cw.send(&dto.WebSocketMessage{Type: dto.WebSocketMetaStart}) cw.send(&dto.WebSocketMessage{Type: dto.WebSocketMetaStart})
return cw return cw
} }
@ -78,76 +78,78 @@ func (cw *codeOceanOutputWriter) StdErr() io.Writer {
return cw.stdErr return cw.stdErr
} }
// SendExitInfo forwards the kind of exit (timeout, error, normal) to CodeOcean. // Close forwards the kind of exit (timeout, error, normal) to CodeOcean.
func (cw *codeOceanOutputWriter) SendExitInfo(info *runner.ExitInfo) { // This results in the closing of the WebSocket connection.
func (cw *codeOceanOutputWriter) Close(info *runner.ExitInfo) {
switch { switch {
case errors.Is(info.Err, context.DeadlineExceeded) || errors.Is(info.Err, runner.ErrorRunnerInactivityTimeout): case errors.Is(info.Err, context.DeadlineExceeded) || errors.Is(info.Err, runner.ErrorRunnerInactivityTimeout):
cw.send(&dto.WebSocketMessage{Type: dto.WebSocketMetaTimeout}) cw.send(&dto.WebSocketMessage{Type: dto.WebSocketMetaTimeout})
case info.Err != nil: case info.Err != nil:
errorMessage := "Error executing the request" errorMessage := "Error executing the request"
log.WithError(info.Err).Warn(errorMessage) log.WithContext(cw.ctx).WithError(info.Err).Warn(errorMessage)
cw.send(&dto.WebSocketMessage{Type: dto.WebSocketOutputError, Data: errorMessage}) cw.send(&dto.WebSocketMessage{Type: dto.WebSocketOutputError, Data: errorMessage})
default: default:
cw.send(&dto.WebSocketMessage{Type: dto.WebSocketExit, ExitCode: info.Code}) cw.send(&dto.WebSocketMessage{Type: dto.WebSocketExit, ExitCode: info.Code})
} }
} }
// stopWhenContextDone notifies the writing loop to stop after the context has been passed.
func (cw *codeOceanOutputWriter) stopWhenContextDone(ctx context.Context) {
<-ctx.Done()
if !cw.stopped {
cw.queue <- &writingLoopMessage{done: true}
}
}
// send forwards the passed dto.WebSocketMessage to the writing loop. // send forwards the passed dto.WebSocketMessage to the writing loop.
func (cw *codeOceanOutputWriter) send(message *dto.WebSocketMessage) { func (cw *codeOceanOutputWriter) send(message *dto.WebSocketMessage) {
if cw.stopped { select {
case <-cw.ctx.Done():
return return
default:
done := message.Type == dto.WebSocketExit ||
message.Type == dto.WebSocketMetaTimeout ||
message.Type == dto.WebSocketOutputError
cw.queue <- &writingLoopMessage{done: done, data: message}
} }
done := message.Type == dto.WebSocketExit ||
message.Type == dto.WebSocketMetaTimeout ||
message.Type == dto.WebSocketOutputError
cw.queue <- &writingLoopMessage{done: done, data: message}
} }
// startWritingLoop enables the writing loop. // startWritingLoop enables the writing loop.
// This is the central and only place where written changes to the WebSocket connection should be done. // This is the central and only place where written changes to the WebSocket connection should be done.
// It synchronizes the messages to provide state checks of the WebSocket connection. // It synchronizes the messages to provide state checks of the WebSocket connection.
func (cw *codeOceanOutputWriter) startWritingLoop() { func (cw *codeOceanOutputWriter) startWritingLoop(writingLoopDone context.CancelFunc) {
for { defer func() {
message := <-cw.queue message := websocket.FormatCloseMessage(websocket.CloseNormalClosure, "")
done := sendMessage(cw.connection, message.data) err := cw.connection.WriteMessage(websocket.CloseMessage, message)
if done || message.done { err2 := cw.connection.Close()
break if err != nil || err2 != nil {
log.WithContext(cw.ctx).WithError(err).WithField("err2", err2).Warn("Error during websocket close")
}
}()
for {
select {
case <-cw.ctx.Done():
return
case message := <-cw.queue:
done := sendMessage(cw.connection, message.data, cw.ctx)
if done || message.done {
writingLoopDone()
return
}
} }
}
cw.stopped = true
message := websocket.FormatCloseMessage(websocket.CloseNormalClosure, "")
err := cw.connection.WriteMessage(websocket.CloseMessage, message)
err2 := cw.connection.Close()
if err != nil || err2 != nil {
log.WithError(err).WithField("err2", err2).Warn("Error during websocket close")
} }
} }
// sendMessage is a helper function for the writing loop. It must not be called from somewhere else! // sendMessage is a helper function for the writing loop. It must not be called from somewhere else!
func sendMessage(connection Connection, message *dto.WebSocketMessage) (done bool) { func sendMessage(connection Connection, message *dto.WebSocketMessage, ctx context.Context) (done bool) {
if message == nil { if message == nil {
return false return false
} }
encodedMessage, err := json.Marshal(message) encodedMessage, err := json.Marshal(message)
if err != nil { if err != nil {
log.WithField("message", message).WithError(err).Warn("Marshal error") log.WithContext(ctx).WithField("message", message).WithError(err).Warn("Marshal error")
return false return false
} }
log.WithField("message", message).Trace("Sending message to client") log.WithContext(ctx).WithField("message", message).Trace("Sending message to client")
err = connection.WriteMessage(websocket.TextMessage, encodedMessage) err = connection.WriteMessage(websocket.TextMessage, encodedMessage)
if err != nil { if err != nil {
errorMessage := "Error writing the message" errorMessage := "Error writing the message"
log.WithField("message", message).WithError(err).Warn(errorMessage) log.WithContext(ctx).WithField("message", message).WithError(err).Warn(errorMessage)
return true return true
} }

View File

@ -16,7 +16,7 @@ func TestRawToCodeOceanWriter(t *testing.T) {
connectionMock, message := buildConnectionMock(t) connectionMock, message := buildConnectionMock(t)
proxyCtx, cancel := context.WithCancel(context.Background()) proxyCtx, cancel := context.WithCancel(context.Background())
defer cancel() defer cancel()
output := NewCodeOceanOutputWriter(connectionMock, proxyCtx) output := NewCodeOceanOutputWriter(connectionMock, proxyCtx, cancel)
<-message // start message <-message // start message
t.Run("StdOut", func(t *testing.T) { t.Run("StdOut", func(t *testing.T) {
@ -69,10 +69,10 @@ func TestCodeOceanOutputWriter_SendExitInfo(t *testing.T) {
connectionMock, message := buildConnectionMock(t) connectionMock, message := buildConnectionMock(t)
proxyCtx, cancel := context.WithCancel(context.Background()) proxyCtx, cancel := context.WithCancel(context.Background())
defer cancel() defer cancel()
output := NewCodeOceanOutputWriter(connectionMock, proxyCtx) output := NewCodeOceanOutputWriter(connectionMock, proxyCtx, cancel)
<-message // start message <-message // start message
output.SendExitInfo(test.info) output.Close(test.info)
expected, err := json.Marshal(test.message) expected, err := json.Marshal(test.message)
require.NoError(t, err) require.NoError(t, err)