Send SIGQUIT when cancelling an execution

When the context passed to Nomad Allocation Exec is cancelled, the
process is not terminated. Instead, just the WebSocket connection is
closed. In order to terminate long-running processes, a special
character is injected into the standard input stream. This character is
parsed by the tty line discipline (tty has to be true). The line
discipline sends a SIGQUIT signal to the process, terminating it and
producing a core dump (in a file called 'core'). The SIGQUIT signal can
be caught but isn't by default, which is why the runner is destroyed if
the program does not terminate during a grace period after the signal
was sent.
This commit is contained in:
Konrad Hanff
2021-07-21 09:12:44 +02:00
parent 91537a7364
commit 8d24bda61a
10 changed files with 237 additions and 63 deletions

View File

@ -24,8 +24,11 @@ type webSocketConnection interface {
SetCloseHandler(handler func(code int, text string) error)
}
// WebSocketReader is an interface that is intended for providing abstraction around reading from a WebSocket.
// Besides io.Reader, it also implements io.Writer. The Write method is used to inject data into the WebSocket stream.
type WebSocketReader interface {
// io.Reader hands out the data that is forwarded to the execution.
io.Reader
// io.Writer lets the server itself inject data that a later Read returns.
io.Writer
// startReadInputLoop returns a context.CancelFunc.
// NOTE(review): presumably it starts consuming messages from the WebSocket connection and the returned
// function stops that loop again - confirm against the implementation.
startReadInputLoop() context.CancelFunc
}
@ -38,12 +41,17 @@ type codeOceanToRawReader struct {
// and retrieve it when Read(..) is called. Since channels are thread-safe, we use one here
// instead of bytes.Buffer.
buffer chan byte
// The priorityBuffer is a buffer for injecting data into stdin of the execution from Poseidon,
// for example the character that causes the tty to generate a SIGQUIT signal.
// It is always read before the regular buffer.
priorityBuffer chan byte
}
func newCodeOceanToRawReader(connection webSocketConnection) *codeOceanToRawReader {
return &codeOceanToRawReader{
connection: connection,
buffer: make(chan byte, CodeOceanToRawReaderBufferSize),
connection: connection,
buffer: make(chan byte, CodeOceanToRawReaderBufferSize),
priorityBuffer: make(chan byte, CodeOceanToRawReaderBufferSize),
}
}
@ -101,16 +109,20 @@ func (cr *codeOceanToRawReader) startReadInputLoop() context.CancelFunc {
}
// Read implements the io.Reader interface.
// It returns bytes from the buffer.
// It returns bytes from the buffer or priorityBuffer.
func (cr *codeOceanToRawReader) Read(p []byte) (int, error) {
if len(p) == 0 {
return 0, nil
}
// Ensure to not return until at least one byte has been read to avoid busy waiting.
p[0] = <-cr.buffer
select {
case p[0] = <-cr.priorityBuffer:
case p[0] = <-cr.buffer:
}
var n int
for n = 1; n < len(p); n++ {
select {
case p[n] = <-cr.priorityBuffer:
case p[n] = <-cr.buffer:
default:
return n, nil
@ -119,6 +131,20 @@ func (cr *codeOceanToRawReader) Read(p []byte) (int, error) {
return n, nil
}
// Write implements the io.Writer interface.
// Every byte of p is queued in the priorityBuffer, from which Read takes bytes in addition to the
// regular WebSocket buffer. Write never blocks: if the priorityBuffer is full, it stops at the first
// byte that does not fit.
//
// It returns the number of bytes that were actually queued. In accordance with the io.Writer contract,
// a non-nil error (io.ErrShortWrite) is returned whenever fewer than len(p) bytes were queued.
func (cr *codeOceanToRawReader) Write(p []byte) (n int, err error) {
	for _, c := range p {
		select {
		case cr.priorityBuffer <- c:
			// Count the bytes that were really enqueued instead of reusing the loop index.
			n++
		default:
			// A bare break would only leave the select statement, so return explicitly
			// to stop writing and to report the short write to the caller.
			return n, io.ErrShortWrite
		}
	}
	return n, nil
}
// rawToCodeOceanWriter is an io.Writer implementation that, when written to, wraps the written data in the appropriate
// json structure and sends it to the CodeOcean via WebSocket.
type rawToCodeOceanWriter struct {

View File

@ -223,7 +223,7 @@ func (s *WebSocketTestSuite) TestWebsocketError() {
s.True(websocket.IsCloseError(err, websocket.CloseNormalClosure))
_, _, errMessages := helpers.WebSocketOutputMessages(messages)
s.Equal(1, len(errMessages))
s.Require().Equal(1, len(errMessages))
s.Equal("Error executing the request", errMessages[0])
}
@ -370,10 +370,12 @@ func TestCodeOceanToRawReaderReturnsOnlyAfterOneByteWasReadFromConnection(t *tes
// --- Test suite specific test helpers ---
func newNomadAllocationWithMockedAPIClient(runnerID string) (r runner.Runner, executorAPIMock *nomad.ExecutorAPIMock) {
executorAPIMock = &nomad.ExecutorAPIMock{}
r = runner.NewNomadJob(runnerID, nil, executorAPIMock, nil)
return
func newNomadAllocationWithMockedAPIClient(runnerID string) (runner.Runner, *nomad.ExecutorAPIMock) {
executorAPIMock := &nomad.ExecutorAPIMock{}
manager := &runner.ManagerMock{}
manager.On("Return", mock.Anything).Return(nil)
r := runner.NewNomadJob(runnerID, nil, executorAPIMock, manager)
return r, executorAPIMock
}
func webSocketURL(scheme string, server *httptest.Server, router *mux.Router,
@ -429,14 +431,21 @@ func mockAPIExecuteHead(api *nomad.ExecutorAPIMock) {
var executionRequestSleep = dto.ExecutionRequest{Command: "sleep infinity"}
// mockAPIExecuteSleep mocks the ExecuteCommand method of an ExecutorAPI to sleep until the execution is canceled.
// mockAPIExecuteSleep mocks the ExecuteCommand method of an ExecutorAPI to sleep
// until the execution receives a SIGQUIT.
func mockAPIExecuteSleep(api *nomad.ExecutorAPIMock) <-chan bool {
canceled := make(chan bool, 1)
mockAPIExecute(api, &executionRequestSleep,
func(_ string, ctx context.Context, _ []string, _ bool,
stdin io.Reader, stdout io.Writer, stderr io.Writer,
) (int, error) {
<-ctx.Done()
var err error
buffer := make([]byte, 1) //nolint:makezero // if the length is zero, the Read call never reads anything
for n := 0; !(n == 1 && buffer[0] == runner.SIGQUIT); n, err = stdin.Read(buffer) {
if err != nil {
return 0, fmt.Errorf("error while reading stdin: %w", err)
}
}
close(canceled)
return 0, ctx.Err()
})

View File

@ -8,7 +8,7 @@ import (
"github.com/hashicorp/nomad/nomad/structs"
"gitlab.hpi.de/codeocean/codemoon/poseidon/internal/config"
"gitlab.hpi.de/codeocean/codemoon/poseidon/pkg/logging"
"gitlab.hpi.de/codeocean/codemoon/poseidon/pkg/nullreader"
"gitlab.hpi.de/codeocean/codemoon/poseidon/pkg/nullio"
"io"
"net/url"
"strconv"
@ -348,7 +348,7 @@ func (a *APIClient) executeCommandInteractivelyWithStderr(allocationID string, c
go func() {
// Catch stderr in separate execution.
exit, err := a.Execute(allocationID, ctx, stderrFifoCommand(currentNanoTime), true,
nullreader.NullReader{}, stderr, io.Discard)
nullio.Reader{}, stderr, io.Discard)
if err != nil {
log.WithError(err).WithField("runner", allocationID).Warn("Stderr task finished with error")
}

View File

@ -12,7 +12,7 @@ import (
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
"gitlab.hpi.de/codeocean/codemoon/poseidon/internal/config"
"gitlab.hpi.de/codeocean/codemoon/poseidon/pkg/nullreader"
"gitlab.hpi.de/codeocean/codemoon/poseidon/pkg/nullio"
"gitlab.hpi.de/codeocean/codemoon/poseidon/tests"
"io"
"net/url"
@ -679,7 +679,7 @@ func (s *ExecuteCommandTestSuite) TestWithSeparateStderr() {
})
exitCode, err := s.nomadAPIClient.
ExecuteCommand(s.allocationID, s.ctx, s.testCommandArray, withTTY, nullreader.NullReader{}, &stdout, &stderr)
ExecuteCommand(s.allocationID, s.ctx, s.testCommandArray, withTTY, nullio.Reader{}, &stdout, &stderr)
s.Require().NoError(err)
s.apiMock.AssertNumberOfCalls(s.T(), "Execute", 2)
@ -710,7 +710,7 @@ func (s *ExecuteCommandTestSuite) TestWithSeparateStderrReturnsCommandError() {
s.mockExecute(s.testCommandArray, 1, tests.ErrDefault, func(args mock.Arguments) {})
s.mockExecute(mock.AnythingOfType("[]string"), 1, nil, func(args mock.Arguments) {})
_, err := s.nomadAPIClient.
ExecuteCommand(s.allocationID, s.ctx, s.testCommandArray, withTTY, nullreader.NullReader{}, io.Discard, io.Discard)
ExecuteCommand(s.allocationID, s.ctx, s.testCommandArray, withTTY, nullio.Reader{}, io.Discard, io.Discard)
s.Equal(tests.ErrDefault, err)
}
@ -732,7 +732,7 @@ func (s *ExecuteCommandTestSuite) TestWithoutSeparateStderr() {
})
exitCode, err := s.nomadAPIClient.
ExecuteCommand(s.allocationID, s.ctx, s.testCommandArray, withTTY, nullreader.NullReader{}, &stdout, &stderr)
ExecuteCommand(s.allocationID, s.ctx, s.testCommandArray, withTTY, nullio.Reader{}, &stdout, &stderr)
s.Require().NoError(err)
s.apiMock.AssertNumberOfCalls(s.T(), "Execute", 1)
@ -745,7 +745,7 @@ func (s *ExecuteCommandTestSuite) TestWithoutSeparateStderrReturnsCommandError()
config.Config.Server.InteractiveStderr = false
s.mockExecute(s.testCommandArray, 1, tests.ErrDefault, func(args mock.Arguments) {})
_, err := s.nomadAPIClient.
ExecuteCommand(s.allocationID, s.ctx, s.testCommandArray, withTTY, nullreader.NullReader{}, io.Discard, io.Discard)
ExecuteCommand(s.allocationID, s.ctx, s.testCommandArray, withTTY, nullio.Reader{}, io.Discard, io.Discard)
s.ErrorIs(err, tests.ErrDefault)
}

View File

@ -25,6 +25,11 @@ type ExecutionID string
const (
// runnerContextKey is the key used to store runners in context.Context.
runnerContextKey ContextKey = "runner"
// SIGQUIT is the character that causes a tty to send the SIGQUIT signal to the controlled process.
// 0x1c is the ASCII file separator control character (typed as Ctrl-\), which is the usual default
// for the terminal's QUIT special character.
// NOTE(review): a process that remaps or disables the QUIT character of its tty will not react to it.
SIGQUIT = 0x1c
// executionTimeoutGracePeriod is the time to wait after sending a SIGQUIT signal to a timed out execution.
// If the execution does not return after this grace period, the runner is destroyed.
executionTimeoutGracePeriod = 3 * time.Second
)
var (
@ -143,7 +148,7 @@ type Runner interface {
// Output from the runner is forwarded immediately.
ExecuteInteractively(
request *dto.ExecutionRequest,
stdin io.Reader,
stdin io.ReadWriter,
stdout,
stderr io.Writer,
) (exit <-chan ExitInfo, cancel context.CancelFunc)
@ -151,6 +156,9 @@ type Runner interface {
// UpdateFileSystem processes a dto.UpdateFileSystemRequest by first deleting each given dto.FilePath recursively
// and then copying each given dto.File to the runner.
UpdateFileSystem(request *dto.UpdateFileSystemRequest) error
// Destroy destroys the Runner in Nomad.
Destroy() error
}
// NomadJob is an abstraction to communicate with Nomad environments.
@ -160,6 +168,7 @@ type NomadJob struct {
id string
portMappings []nomadApi.PortMapping
api nomad.ExecutorAPI
manager Manager
}
// NewNomadJob creates a new NomadJob with the provided id.
@ -171,6 +180,7 @@ func NewNomadJob(id string, portMappings []nomadApi.PortMapping,
portMappings: portMappings,
api: apiClient,
ExecutionStorage: NewLocalExecutionStorage(),
manager: manager,
}
job.InactivityTimer = NewInactivityTimer(job, manager)
return job
@ -196,30 +206,80 @@ type ExitInfo struct {
Err error
}
func (r *NomadJob) ExecuteInteractively(
request *dto.ExecutionRequest,
stdin io.Reader,
stdout, stderr io.Writer,
) (<-chan ExitInfo, context.CancelFunc) {
r.ResetTimeout()
command := request.FullCommand()
var ctx context.Context
var cancel context.CancelFunc
func prepareExecution(request *dto.ExecutionRequest) (
command []string, ctx context.Context, cancel context.CancelFunc,
) {
command = request.FullCommand()
if request.TimeLimit == 0 {
ctx, cancel = context.WithCancel(context.Background())
} else {
ctx, cancel = context.WithTimeout(context.Background(), time.Duration(request.TimeLimit)*time.Second)
}
exit := make(chan ExitInfo)
go func() {
exitCode, err := r.api.ExecuteCommand(r.id, ctx, command, true, stdin, stdout, stderr)
if err == nil && r.TimeoutPassed() {
err = ErrorRunnerInactivityTimeout
}
exit <- ExitInfo{uint8(exitCode), err}
return command, ctx, cancel
}
// executeCommand runs command on the runner with a tty attached and reports the outcome on exit.
// It blocks until the executor API call returns. If that call succeeded but the runner's inactivity
// timeout has passed in the meantime, ErrorRunnerInactivityTimeout is reported as the error.
func (r *NomadJob) executeCommand(ctx context.Context, command []string,
	stdin io.ReadWriter, stdout, stderr io.Writer, exit chan<- ExitInfo,
) {
	code, execErr := r.api.ExecuteCommand(r.id, ctx, command, true, stdin, stdout, stderr)
	// Only consult the inactivity timer when the execution itself did not fail.
	if execErr == nil {
		if r.TimeoutPassed() {
			execErr = ErrorRunnerInactivityTimeout
		}
	}
	exit <- ExitInfo{Code: uint8(code), Err: execErr}
}
func (r *NomadJob) handleExitOrContextDone(ctx context.Context, cancelExecute context.CancelFunc,
exitInternal <-chan ExitInfo, exit chan<- ExitInfo, stdin io.ReadWriter,
) {
defer cancelExecute()
select {
case exitInfo := <-exitInternal:
exit <- exitInfo
close(exit)
}()
return
case <-ctx.Done():
// From this time on until the WebSocket connection to the client is closed in /internal/api/websocket.go
// waitForExit, output can still be forwarded to the client. We accept this race condition because adding
// a locking mechanism would complicate the interfaces used (currently io.Writer).
exit <- ExitInfo{255, ctx.Err()}
close(exit)
}
// This injects the SIGQUIT character into the stdin. This character is parsed by the tty line discipline
// (tty has to be true) and converted to a SIGQUIT signal sent to the foreground process attached to the tty.
// By default, SIGQUIT causes the process to terminate and produces a core dump. Processes can catch this signal
// and ignore it, which is why we destroy the runner if the process does not terminate after a grace period.
n, err := stdin.Write([]byte{SIGQUIT})
if n != 1 {
log.WithField("runner", r.id).Warn("Could not send SIGQUIT because nothing was written")
}
if err != nil {
log.WithField("runner", r.id).WithError(err).Warn("Could not send SIGQUIT due to error")
}
select {
case <-exitInternal:
log.WithField("runner", r.id).Debug("Execution terminated after SIGQUIT")
case <-time.After(executionTimeoutGracePeriod):
log.WithField("runner", r.id).Info("Execution did not quit after SIGQUIT")
if err := r.Destroy(); err != nil {
log.WithField("runner", r.id).Error("Error when destroying runner")
}
}
}
// ExecuteInteractively starts the command of the request in the runner and returns immediately.
// The inactivity timer of the runner is reset before the execution is started.
//
// stdin is used both as the input of the execution and, via its Write method, to inject the SIGQUIT
// character once the execution is cancelled or its time limit is exceeded.
//
// It returns a channel that receives the ExitInfo of the execution and is buffered, so that the result
// can be delivered even if the caller is not yet receiving, and a context.CancelFunc to cancel the execution.
func (r *NomadJob) ExecuteInteractively(
	request *dto.ExecutionRequest,
	stdin io.ReadWriter,
	stdout, stderr io.Writer,
) (<-chan ExitInfo, context.CancelFunc) {
	r.ResetTimeout()

	command, ctx, cancel := prepareExecution(request)
	// exitInternal is buffered because its receiver stops listening once the grace period after the
	// SIGQUIT has elapsed. With an unbuffered channel, the goroutine running executeCommand would then
	// block forever on its send and leak.
	exitInternal := make(chan ExitInfo, 1)
	exit := make(chan ExitInfo, 1)
	// The execution gets its own context so that it is not aborted before the SIGQUIT could be sent.
	ctxExecute, cancelExecute := context.WithCancel(context.Background())
	go r.executeCommand(ctxExecute, command, stdin, stdout, stderr, exitInternal)
	go r.handleExitOrContextDone(ctx, cancelExecute, exitInternal, exit, stdin)
	return exit, cancel
}
@ -255,6 +315,13 @@ func (r *NomadJob) UpdateFileSystem(copyRequest *dto.UpdateFileSystemRequest) er
return nil
}
// Destroy hands the runner back to its manager by calling the manager's Return method.
// An error from the manager is wrapped and returned.
func (r *NomadJob) Destroy() error {
	err := r.manager.Return(r)
	if err == nil {
		return nil
	}
	return fmt.Errorf("error while destroying runner: %w", err)
}
func createTarArchiveForFiles(filesToCopy []dto.File, w io.Writer) error {
tarWriter := tar.NewWriter(w)
for _, file := range filesToCopy {

View File

@ -23,12 +23,26 @@ func (_m *RunnerMock) Add(id ExecutionID, executionRequest *dto.ExecutionRequest
_m.Called(id, executionRequest)
}
// Destroy provides a mock function with given fields:
// It records the call on the mock and returns the error that was configured for it.
func (_m *RunnerMock) Destroy() error {
ret := _m.Called()
var r0 error
// The configured return value is either a func() error that computes the result or a plain error value.
if rf, ok := ret.Get(0).(func() error); ok {
r0 = rf()
} else {
r0 = ret.Error(0)
}
return r0
}
// ExecuteInteractively provides a mock function with given fields: request, stdin, stdout, stderr
func (_m *RunnerMock) ExecuteInteractively(request *dto.ExecutionRequest, stdin io.Reader, stdout io.Writer, stderr io.Writer) (<-chan ExitInfo, context.CancelFunc) {
func (_m *RunnerMock) ExecuteInteractively(request *dto.ExecutionRequest, stdin io.ReadWriter, stdout io.Writer, stderr io.Writer) (<-chan ExitInfo, context.CancelFunc) {
ret := _m.Called(request, stdin, stdout, stderr)
var r0 <-chan ExitInfo
if rf, ok := ret.Get(0).(func(*dto.ExecutionRequest, io.Reader, io.Writer, io.Writer) <-chan ExitInfo); ok {
if rf, ok := ret.Get(0).(func(*dto.ExecutionRequest, io.ReadWriter, io.Writer, io.Writer) <-chan ExitInfo); ok {
r0 = rf(request, stdin, stdout, stderr)
} else {
if ret.Get(0) != nil {
@ -37,7 +51,7 @@ func (_m *RunnerMock) ExecuteInteractively(request *dto.ExecutionRequest, stdin
}
var r1 context.CancelFunc
if rf, ok := ret.Get(1).(func(*dto.ExecutionRequest, io.Reader, io.Writer, io.Writer) context.CancelFunc); ok {
if rf, ok := ret.Get(1).(func(*dto.ExecutionRequest, io.ReadWriter, io.Writer, io.Writer) context.CancelFunc); ok {
r1 = rf(request, stdin, stdout, stderr)
} else {
if ret.Get(1) != nil {
@ -48,7 +62,7 @@ func (_m *RunnerMock) ExecuteInteractively(request *dto.ExecutionRequest, stdin
return r0, r1
}
// Id provides a mock function with given fields:
// ID provides a mock function with given fields:
func (_m *RunnerMock) ID() string {
ret := _m.Called()

View File

@ -12,6 +12,7 @@ import (
"github.com/stretchr/testify/suite"
"gitlab.hpi.de/codeocean/codemoon/poseidon/internal/nomad"
"gitlab.hpi.de/codeocean/codemoon/poseidon/pkg/dto"
"gitlab.hpi.de/codeocean/codemoon/poseidon/pkg/nullio"
"gitlab.hpi.de/codeocean/codemoon/poseidon/tests"
"io"
"regexp"
@ -82,6 +83,15 @@ func TestFromContextReturnsIsNotOkWhenContextHasNoRunner(t *testing.T) {
assert.False(t, ok)
}
// TestDestroyReturnsRunner checks that destroying a runner returns it to its manager.
func TestDestroyReturnsRunner(t *testing.T) {
	managerMock := new(ManagerMock)
	managerMock.On("Return", mock.Anything).Return(nil)
	r := NewRunner(tests.DefaultRunnerID, managerMock)

	destroyErr := r.Destroy()

	assert.NoError(t, destroyErr)
	managerMock.AssertCalled(t, "Return", r)
}
func TestExecuteInteractivelyTestSuite(t *testing.T) {
suite.Run(t, new(ExecuteInteractivelyTestSuite))
}
@ -91,6 +101,7 @@ type ExecuteInteractivelyTestSuite struct {
runner *NomadJob
apiMock *nomad.ExecutorAPIMock
timer *InactivityTimerMock
manager *ManagerMock
mockedExecuteCommandCall *mock.Call
mockedTimeoutPassedCall *mock.Call
}
@ -103,11 +114,15 @@ func (s *ExecuteInteractivelyTestSuite) SetupTest() {
s.timer = &InactivityTimerMock{}
s.timer.On("ResetTimeout").Return()
s.mockedTimeoutPassedCall = s.timer.On("TimeoutPassed").Return(false)
s.manager = &ManagerMock{}
s.manager.On("Return", mock.Anything).Return(nil)
s.runner = &NomadJob{
ExecutionStorage: NewLocalExecutionStorage(),
InactivityTimer: s.timer,
id: tests.DefaultRunnerID,
api: s.apiMock,
manager: s.manager,
}
}
@ -122,15 +137,12 @@ func (s *ExecuteInteractivelyTestSuite) TestCallsApi() {
func (s *ExecuteInteractivelyTestSuite) TestReturnsAfterTimeout() {
s.mockedExecuteCommandCall.Run(func(args mock.Arguments) {
ctx, ok := args.Get(1).(context.Context)
s.Require().True(ok)
<-ctx.Done()
}).
Return(0, nil)
select {}
}).Return(0, nil)
timeLimit := 1
execution := &dto.ExecutionRequest{TimeLimit: timeLimit}
exit, _ := s.runner.ExecuteInteractively(execution, nil, nil, nil)
exit, _ := s.runner.ExecuteInteractively(execution, &nullio.ReadWriter{}, nil, nil)
select {
case <-exit:
@ -142,20 +154,52 @@ func (s *ExecuteInteractivelyTestSuite) TestReturnsAfterTimeout() {
case <-time.After(time.Duration(timeLimit) * time.Second):
s.FailNow("ExecuteInteractively should return after the time limit")
case exitInfo := <-exit:
s.Equal(uint8(0), exitInfo.Code)
s.Equal(uint8(255), exitInfo.Code)
}
}
// TestSendsSignalAfterTimeout checks that the SIGQUIT character is written to the stdin of an
// execution once its time limit has been exceeded.
func (s *ExecuteInteractivelyTestSuite) TestSendsSignalAfterTimeout() {
	quit := make(chan struct{})
	s.mockedExecuteCommandCall.Run(func(args mock.Arguments) {
		stdin, ok := args.Get(4).(io.Reader)
		s.Require().True(ok)
		buffer := make([]byte, 1) //nolint:makezero // If the length is zero, the Read call never reads anything.
		// Poll stdin until the SIGQUIT character shows up.
		// Read returns EOF errors while stdin is empty, but that is expected, so the error is ignored.
		for n := 0; !(n == 1 && buffer[0] == SIGQUIT); n, _ = stdin.Read(buffer) { //nolint:errcheck // see above
			// Actually wait between two polls. Calling time.After without receiving from the returned
			// channel does not block and would turn this loop into a busy loop.
			time.Sleep(tests.ShortTimeout)
		}
		close(quit)
	}).Return(0, nil)
	timeLimit := 1
	execution := &dto.ExecutionRequest{TimeLimit: timeLimit}
	_, _ = s.runner.ExecuteInteractively(execution, bytes.NewBuffer(make([]byte, 1)), nil, nil)
	select {
	case <-time.After(2 * (time.Duration(timeLimit) * time.Second)):
		s.FailNow("The execution should receive a SIGQUIT after the timeout")
	case <-quit:
	}
}
// TestDestroysRunnerAfterTimeoutAndSignal checks that the runner is returned to the manager when the
// execution neither finishes within its time limit nor within the grace period that follows it.
func (s *ExecuteInteractivelyTestSuite) TestDestroysRunnerAfterTimeoutAndSignal() {
	// The mocked execution never returns, no matter what is written to its stdin.
	s.mockedExecuteCommandCall.Run(func(_ mock.Arguments) {
		select {}
	})

	timeLimit := 1
	request := &dto.ExecutionRequest{TimeLimit: timeLimit}
	stdin := bytes.NewBuffer(make([]byte, 1))
	_, _ = s.runner.ExecuteInteractively(request, stdin, nil, nil)

	// Wait for the time limit, the grace period and a small safety margin.
	waitTime := executionTimeoutGracePeriod + time.Duration(timeLimit)*time.Second + tests.ShortTimeout
	time.Sleep(waitTime)
	s.manager.AssertCalled(s.T(), "Return", s.runner)
}
// TestResetTimerGetsCalled checks that starting an execution resets the inactivity timer of the runner.
func (s *ExecuteInteractivelyTestSuite) TestResetTimerGetsCalled() {
	s.runner.ExecuteInteractively(&dto.ExecutionRequest{}, nil, nil, nil)

	s.timer.AssertCalled(s.T(), "ResetTimeout")
}
func (s *ExecuteInteractivelyTestSuite) TestExitHasTimeoutErrorIfExecutionTimesOut() {
func (s *ExecuteInteractivelyTestSuite) TestExitHasTimeoutErrorIfRunnerTimesOut() {
s.mockedTimeoutPassedCall.Return(true)
execution := &dto.ExecutionRequest{}
exitChannel, _ := s.runner.ExecuteInteractively(execution, nil, nil, nil)
exitChannel, _ := s.runner.ExecuteInteractively(execution, &nullio.ReadWriter{}, nil, nil)
exit := <-exitChannel
s.Equal(ErrorRunnerInactivityTimeout, exit.Err)
}