Fix webSocket error
that was introduced by closing the WebSocket connection without stopping the inputLoop.
This commit is contained in:
@ -38,8 +38,10 @@ type WebSocketReader interface {
|
|||||||
type codeOceanToRawReader struct {
|
type codeOceanToRawReader struct {
|
||||||
connection webSocketConnection
|
connection webSocketConnection
|
||||||
|
|
||||||
// ctx is used to cancel the reading routine.
|
// wsCtx is the context in that messages from CodeOcean are read.
|
||||||
ctx context.Context
|
wsCtx context.Context
|
||||||
|
// executorCtx is the context in that messages are forwarded to the executor.
|
||||||
|
executorCtx context.Context
|
||||||
|
|
||||||
// A buffered channel of bytes is used to store data coming from CodeOcean via WebSocket
|
// A buffered channel of bytes is used to store data coming from CodeOcean via WebSocket
|
||||||
// and retrieve it when Read(..) is called. Since channels are thread-safe, we use one here
|
// and retrieve it when Read(..) is called. Since channels are thread-safe, we use one here
|
||||||
@ -51,10 +53,11 @@ type codeOceanToRawReader struct {
|
|||||||
priorityBuffer chan byte
|
priorityBuffer chan byte
|
||||||
}
|
}
|
||||||
|
|
||||||
func newCodeOceanToRawReader(connection webSocketConnection, ctx context.Context) *codeOceanToRawReader {
|
func newCodeOceanToRawReader(connection webSocketConnection, wsCtx, executorCtx context.Context) *codeOceanToRawReader {
|
||||||
return &codeOceanToRawReader{
|
return &codeOceanToRawReader{
|
||||||
connection: connection,
|
connection: connection,
|
||||||
ctx: ctx,
|
wsCtx: wsCtx,
|
||||||
|
executorCtx: executorCtx,
|
||||||
buffer: make(chan byte, CodeOceanToRawReaderBufferSize),
|
buffer: make(chan byte, CodeOceanToRawReaderBufferSize),
|
||||||
priorityBuffer: make(chan byte, CodeOceanToRawReaderBufferSize),
|
priorityBuffer: make(chan byte, CodeOceanToRawReaderBufferSize),
|
||||||
}
|
}
|
||||||
@ -65,7 +68,7 @@ func newCodeOceanToRawReader(connection webSocketConnection, ctx context.Context
|
|||||||
// CloseHandler.
|
// CloseHandler.
|
||||||
func (cr *codeOceanToRawReader) readInputLoop() {
|
func (cr *codeOceanToRawReader) readInputLoop() {
|
||||||
readMessage := make(chan bool)
|
readMessage := make(chan bool)
|
||||||
loopContext, cancelInputLoop := context.WithCancel(cr.ctx)
|
loopContext, cancelInputLoop := context.WithCancel(cr.wsCtx)
|
||||||
defer cancelInputLoop()
|
defer cancelInputLoop()
|
||||||
readingContext, cancelNextMessage := context.WithCancel(loopContext)
|
readingContext, cancelNextMessage := context.WithCancel(loopContext)
|
||||||
defer cancelNextMessage()
|
defer cancelNextMessage()
|
||||||
@ -138,7 +141,7 @@ func (cr *codeOceanToRawReader) Read(p []byte) (int, error) {
|
|||||||
|
|
||||||
// Ensure to not return until at least one byte has been read to avoid busy waiting.
|
// Ensure to not return until at least one byte has been read to avoid busy waiting.
|
||||||
select {
|
select {
|
||||||
case <-cr.ctx.Done():
|
case <-cr.executorCtx.Done():
|
||||||
return 0, io.EOF
|
return 0, io.EOF
|
||||||
case p[0] = <-cr.priorityBuffer:
|
case p[0] = <-cr.priorityBuffer:
|
||||||
case p[0] = <-cr.buffer:
|
case p[0] = <-cr.buffer:
|
||||||
@ -208,23 +211,23 @@ func upgradeConnection(writer http.ResponseWriter, request *http.Request) (webSo
|
|||||||
return connection, nil
|
return connection, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// newWebSocketProxy returns a initiated and started webSocketProxy.
|
// newWebSocketProxy returns an initiated and started webSocketProxy.
|
||||||
// As this proxy is already started, a start message is send to the client.
|
// As this proxy is already started, a start message is send to the client.
|
||||||
func newWebSocketProxy(connection webSocketConnection, ctx context.Context) (*webSocketProxy, error) {
|
func newWebSocketProxy(connection webSocketConnection, proxyCtx context.Context) (*webSocketProxy, error) {
|
||||||
stdin := newCodeOceanToRawReader(connection, ctx)
|
wsCtx, cancelWsCommunication := context.WithCancel(proxyCtx)
|
||||||
inputCtx, inputCancel := context.WithCancel(ctx)
|
stdin := newCodeOceanToRawReader(connection, wsCtx, proxyCtx)
|
||||||
proxy := &webSocketProxy{
|
proxy := &webSocketProxy{
|
||||||
connection: connection,
|
connection: connection,
|
||||||
Stdin: stdin,
|
Stdin: stdin,
|
||||||
webSocketCtx: inputCtx,
|
webSocketCtx: wsCtx,
|
||||||
cancelWebSocket: inputCancel,
|
cancelWebSocket: cancelWsCommunication,
|
||||||
}
|
}
|
||||||
proxy.Stdout = &rawToCodeOceanWriter{proxy: proxy, outputType: dto.WebSocketOutputStdout}
|
proxy.Stdout = &rawToCodeOceanWriter{proxy: proxy, outputType: dto.WebSocketOutputStdout}
|
||||||
proxy.Stderr = &rawToCodeOceanWriter{proxy: proxy, outputType: dto.WebSocketOutputStderr}
|
proxy.Stderr = &rawToCodeOceanWriter{proxy: proxy, outputType: dto.WebSocketOutputStderr}
|
||||||
|
|
||||||
err := proxy.sendToClient(dto.WebSocketMessage{Type: dto.WebSocketMetaStart})
|
err := proxy.sendToClient(dto.WebSocketMessage{Type: dto.WebSocketMetaStart})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
inputCancel()
|
cancelWsCommunication()
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -232,7 +235,7 @@ func newWebSocketProxy(connection webSocketConnection, ctx context.Context) (*we
|
|||||||
connection.SetCloseHandler(func(code int, text string) error {
|
connection.SetCloseHandler(func(code int, text string) error {
|
||||||
//nolint:errcheck // The default close handler always returns nil, so the error can be safely ignored.
|
//nolint:errcheck // The default close handler always returns nil, so the error can be safely ignored.
|
||||||
_ = closeHandler(code, text)
|
_ = closeHandler(code, text)
|
||||||
inputCancel()
|
cancelWsCommunication()
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
return proxy, nil
|
return proxy, nil
|
||||||
|
@ -8,11 +8,14 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"github.com/gorilla/mux"
|
"github.com/gorilla/mux"
|
||||||
"github.com/gorilla/websocket"
|
"github.com/gorilla/websocket"
|
||||||
|
"github.com/openHPI/poseidon/internal/environment"
|
||||||
"github.com/openHPI/poseidon/internal/nomad"
|
"github.com/openHPI/poseidon/internal/nomad"
|
||||||
"github.com/openHPI/poseidon/internal/runner"
|
"github.com/openHPI/poseidon/internal/runner"
|
||||||
"github.com/openHPI/poseidon/pkg/dto"
|
"github.com/openHPI/poseidon/pkg/dto"
|
||||||
"github.com/openHPI/poseidon/tests"
|
"github.com/openHPI/poseidon/tests"
|
||||||
"github.com/openHPI/poseidon/tests/helpers"
|
"github.com/openHPI/poseidon/tests/helpers"
|
||||||
|
"github.com/sirupsen/logrus"
|
||||||
|
"github.com/sirupsen/logrus/hooks/test"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/mock"
|
"github.com/stretchr/testify/mock"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
@ -44,7 +47,7 @@ func (s *WebSocketTestSuite) SetupTest() {
|
|||||||
s.runner, s.apiMock = newNomadAllocationWithMockedAPIClient(runnerID)
|
s.runner, s.apiMock = newNomadAllocationWithMockedAPIClient(runnerID)
|
||||||
|
|
||||||
// default execution
|
// default execution
|
||||||
s.executionID = "execution-id"
|
s.executionID = tests.DefaultExecutionID
|
||||||
s.runner.StoreExecution(s.executionID, &executionRequestHead)
|
s.runner.StoreExecution(s.executionID, &executionRequestHead)
|
||||||
mockAPIExecuteHead(s.apiMock)
|
mockAPIExecuteHead(s.apiMock)
|
||||||
|
|
||||||
@ -250,7 +253,7 @@ func TestWebsocketTLS(t *testing.T) {
|
|||||||
runnerID := "runner-id"
|
runnerID := "runner-id"
|
||||||
r, apiMock := newNomadAllocationWithMockedAPIClient(runnerID)
|
r, apiMock := newNomadAllocationWithMockedAPIClient(runnerID)
|
||||||
|
|
||||||
executionID := "execution-id"
|
executionID := tests.DefaultExecutionID
|
||||||
r.StoreExecution(executionID, &executionRequestLs)
|
r.StoreExecution(executionID, &executionRequestLs)
|
||||||
mockAPIExecuteLs(apiMock)
|
mockAPIExecuteLs(apiMock)
|
||||||
|
|
||||||
@ -315,8 +318,9 @@ func TestRawToCodeOceanWriter(t *testing.T) {
|
|||||||
|
|
||||||
func TestCodeOceanToRawReaderReturnsOnlyAfterOneByteWasRead(t *testing.T) {
|
func TestCodeOceanToRawReaderReturnsOnlyAfterOneByteWasRead(t *testing.T) {
|
||||||
readingCtx, cancel := context.WithCancel(context.Background())
|
readingCtx, cancel := context.WithCancel(context.Background())
|
||||||
|
forwardingCtx := readingCtx
|
||||||
defer cancel()
|
defer cancel()
|
||||||
reader := newCodeOceanToRawReader(nil, readingCtx)
|
reader := newCodeOceanToRawReader(nil, readingCtx, forwardingCtx)
|
||||||
|
|
||||||
read := make(chan bool)
|
read := make(chan bool)
|
||||||
go func() {
|
go func() {
|
||||||
@ -350,8 +354,9 @@ func TestCodeOceanToRawReaderReturnsOnlyAfterOneByteWasReadFromConnection(t *tes
|
|||||||
})
|
})
|
||||||
|
|
||||||
readingCtx, cancel := context.WithCancel(context.Background())
|
readingCtx, cancel := context.WithCancel(context.Background())
|
||||||
|
forwardingCtx := readingCtx
|
||||||
defer cancel()
|
defer cancel()
|
||||||
reader := newCodeOceanToRawReader(connection, readingCtx)
|
reader := newCodeOceanToRawReader(connection, readingCtx, forwardingCtx)
|
||||||
reader.startReadInputLoop()
|
reader.startReadInputLoop()
|
||||||
|
|
||||||
read := make(chan bool)
|
read := make(chan bool)
|
||||||
@ -373,6 +378,32 @@ func TestCodeOceanToRawReaderReturnsOnlyAfterOneByteWasReadFromConnection(t *tes
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestWebSocketProxyStopsReadingTheWebSocketAfterClosingIt(t *testing.T) {
|
||||||
|
apiMock := &nomad.ExecutorAPIMock{}
|
||||||
|
executionID := tests.DefaultExecutionID
|
||||||
|
r, wsURL := newRunnerWithNotMockedRunnerManager(t, apiMock, executionID)
|
||||||
|
|
||||||
|
logger, hook := test.NewNullLogger()
|
||||||
|
log = logger.WithField("pkg", "api")
|
||||||
|
|
||||||
|
r.StoreExecution(executionID, &executionRequestHead)
|
||||||
|
mockAPIExecute(apiMock, &executionRequestHead,
|
||||||
|
func(_ string, ctx context.Context, _ []string, _ bool, _ io.Reader, _, _ io.Writer) (int, error) {
|
||||||
|
return 0, nil
|
||||||
|
})
|
||||||
|
connection, _, err := websocket.DefaultDialer.Dial(wsURL.String(), nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
_, err = helpers.ReceiveAllWebSocketMessages(connection)
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.True(t, websocket.IsCloseError(err, websocket.CloseNormalClosure))
|
||||||
|
for _, logMsg := range hook.Entries {
|
||||||
|
if logMsg.Level < logrus.InfoLevel {
|
||||||
|
assert.Fail(t, logMsg.Message)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// --- Test suite specific test helpers ---
|
// --- Test suite specific test helpers ---
|
||||||
|
|
||||||
func newNomadAllocationWithMockedAPIClient(runnerID string) (runner.Runner, *nomad.ExecutorAPIMock) {
|
func newNomadAllocationWithMockedAPIClient(runnerID string) (runner.Runner, *nomad.ExecutorAPIMock) {
|
||||||
@ -383,6 +414,39 @@ func newNomadAllocationWithMockedAPIClient(runnerID string) (runner.Runner, *nom
|
|||||||
return r, executorAPIMock
|
return r, executorAPIMock
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func newRunnerWithNotMockedRunnerManager(t *testing.T, apiMock *nomad.ExecutorAPIMock, executionID string) (
|
||||||
|
r runner.Runner, wsURL *url.URL) {
|
||||||
|
t.Helper()
|
||||||
|
apiMock.On("MarkRunnerAsUsed", mock.AnythingOfType("string"), mock.AnythingOfType("int")).Return(nil)
|
||||||
|
apiMock.On("DeleteJob", mock.AnythingOfType("string")).Return(nil)
|
||||||
|
apiMock.On("RegisterRunnerJob", mock.AnythingOfType("*api.Job")).Return(nil)
|
||||||
|
call := apiMock.On("WatchEventStream", mock.Anything, mock.Anything, mock.Anything)
|
||||||
|
call.Run(func(args mock.Arguments) {
|
||||||
|
<-context.Background().Done()
|
||||||
|
call.ReturnArguments = mock.Arguments{nil}
|
||||||
|
})
|
||||||
|
runnerManager := runner.NewNomadRunnerManager(apiMock, context.Background())
|
||||||
|
router := NewRouter(runnerManager, nil)
|
||||||
|
server := httptest.NewServer(router)
|
||||||
|
|
||||||
|
runnerID := tests.DefaultRunnerID
|
||||||
|
runnerJob := runner.NewNomadJob(runnerID, nil, apiMock, runnerManager)
|
||||||
|
e, err := environment.NewNomadEnvironment("job \"template-0\" {}")
|
||||||
|
require.NoError(t, err)
|
||||||
|
eID, err := nomad.EnvironmentIDFromRunnerID(runnerID)
|
||||||
|
require.NoError(t, err)
|
||||||
|
e.SetID(eID)
|
||||||
|
e.SetPrewarmingPoolSize(0)
|
||||||
|
runnerManager.SetEnvironment(e)
|
||||||
|
e.AddRunner(runnerJob)
|
||||||
|
|
||||||
|
r, err = runnerManager.Claim(e.ID(), int(tests.DefaultTestTimeout.Seconds()))
|
||||||
|
require.NoError(t, err)
|
||||||
|
wsURL, err = webSocketURL("ws", server, router, r.ID(), executionID)
|
||||||
|
require.NoError(t, err)
|
||||||
|
return r, wsURL
|
||||||
|
}
|
||||||
|
|
||||||
func webSocketURL(scheme string, server *httptest.Server, router *mux.Router,
|
func webSocketURL(scheme string, server *httptest.Server, router *mux.Router,
|
||||||
runnerID string, executionID string,
|
runnerID string, executionID string,
|
||||||
) (*url.URL, error) {
|
) (*url.URL, error) {
|
||||||
|
@ -261,7 +261,7 @@ func TestNomadEnvironmentManager_List(t *testing.T) {
|
|||||||
func mockWatchAllocations(apiMock *nomad.ExecutorAPIMock) {
|
func mockWatchAllocations(apiMock *nomad.ExecutorAPIMock) {
|
||||||
call := apiMock.On("WatchEventStream", mock.Anything, mock.Anything, mock.Anything)
|
call := apiMock.On("WatchEventStream", mock.Anything, mock.Anything, mock.Anything)
|
||||||
call.Run(func(args mock.Arguments) {
|
call.Run(func(args mock.Arguments) {
|
||||||
<-time.After(10 * time.Minute) // 10 minutes is the default test timeout
|
<-time.After(tests.DefaultTestTimeout)
|
||||||
call.ReturnArguments = mock.Arguments{nil}
|
call.ReturnArguments = mock.Arguments{nil}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
@ -44,7 +44,7 @@ func mockRunnerQueries(apiMock *nomad.ExecutorAPIMock, returnedRunnerIds []strin
|
|||||||
apiMock.ExpectedCalls = []*mock.Call{}
|
apiMock.ExpectedCalls = []*mock.Call{}
|
||||||
call := apiMock.On("WatchEventStream", mock.Anything, mock.Anything, mock.Anything)
|
call := apiMock.On("WatchEventStream", mock.Anything, mock.Anything, mock.Anything)
|
||||||
call.Run(func(args mock.Arguments) {
|
call.Run(func(args mock.Arguments) {
|
||||||
<-time.After(10 * time.Minute) // 10 minutes is the default test timeout
|
<-time.After(tests.DefaultTestTimeout)
|
||||||
call.ReturnArguments = mock.Arguments{nil}
|
call.ReturnArguments = mock.Arguments{nil}
|
||||||
})
|
})
|
||||||
apiMock.On("LoadEnvironmentJobs").Return([]*nomadApi.Job{}, nil)
|
apiMock.On("LoadEnvironmentJobs").Return([]*nomadApi.Job{}, nil)
|
||||||
|
@ -26,6 +26,7 @@ const (
|
|||||||
DefaultExecutionID = "s0m3-3x3cu710n-1d"
|
DefaultExecutionID = "s0m3-3x3cu710n-1d"
|
||||||
DefaultMockID = "m0ck-1d"
|
DefaultMockID = "m0ck-1d"
|
||||||
ShortTimeout = 100 * time.Millisecond
|
ShortTimeout = 100 * time.Millisecond
|
||||||
|
DefaultTestTimeout = 10 * time.Minute
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
|
Reference in New Issue
Block a user