From 4550a4589e42d0a4bc33ecb4267ad6b17e424ff3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Maximilian=20Pa=C3=9F?= <22845248+mpass99@users.noreply.github.com> Date: Fri, 3 Feb 2023 01:27:50 +0000 Subject: [PATCH] Dangerous Context Enrichment by passing the Sentry Context down our abstraction stack. This included changes in the complex context management of managing a Command Execution. --- internal/api/websocket.go | 4 ++-- internal/nomad/api_querier.go | 19 ++++++++++++++++--- internal/nomad/nomad.go | 20 +++++++++++++------- internal/nomad/nomad_test.go | 2 +- internal/runner/aws_runner.go | 3 ++- internal/runner/aws_runner_test.go | 9 ++++++--- internal/runner/nomad_runner.go | 15 +++++++++++---- internal/runner/nomad_runner_test.go | 17 ++++++++++------- internal/runner/runner.go | 1 + internal/runner/runner_mock.go | 18 +++++++++--------- 10 files changed, 71 insertions(+), 37 deletions(-) diff --git a/internal/api/websocket.go b/internal/api/websocket.go index 8a44d68..56d6b15 100644 --- a/internal/api/websocket.go +++ b/internal/api/websocket.go @@ -96,9 +96,9 @@ func (r *RunnerController) connectToRunner(writer http.ResponseWriter, request * log.WithField("runnerId", targetRunner.ID()). WithField("executionID", logging.RemoveNewlineSymbol(executionID)). Info("Running execution") - logging.StartSpan("api.runner.connect", "Execute Interactively", request.Context(), func(_ context.Context) { + logging.StartSpan("api.runner.connect", "Execute Interactively", request.Context(), func(ctx context.Context) { exit, cancel, err := targetRunner.ExecuteInteractively(executionID, - proxy.Input, proxy.Output.StdOut(), proxy.Output.StdErr()) + proxy.Input, proxy.Output.StdOut(), proxy.Output.StdErr(), ctx) if err != nil { log.WithError(err).Warn("Cannot execute request.") return // The proxy is stopped by the deferred cancel. diff --git a/internal/nomad/api_querier.go b/internal/nomad/api_querier.go index 78c9827..5e870cd 100644 --- a/internal/nomad/api_querier.go +++ b/internal/nomad/api_querier.go @@ -7,6 +7,7 @@ import ( "github.com/gorilla/websocket" nomadApi "github.com/hashicorp/nomad/api" "github.com/openHPI/poseidon/internal/config" + "github.com/openHPI/poseidon/pkg/logging" "io" ) @@ -88,18 +89,30 @@ func (nc *nomadAPIClient) Execute(runnerID string, ctx context.Context, command []string, tty bool, stdin io.Reader, stdout, stderr io.Writer, ) (int, error) { - allocations, _, err := nc.client.Jobs().Allocations(runnerID, false, nil) + var allocations []*nomadApi.AllocationListStub + var err error + logging.StartSpan("nomad.execute.list", "List Allocations for id", ctx, func(_ context.Context) { + allocations, _, err = nc.client.Jobs().Allocations(runnerID, false, nil) + }) if err != nil { return 1, fmt.Errorf("error retrieving allocations for runner: %w", err) } if len(allocations) == 0 { return 1, ErrorNoAllocationFound } - allocation, _, err := nc.client.Allocations().Info(allocations[0].ID, nil) + + var allocation *nomadApi.Allocation + logging.StartSpan("nomad.execute.info", "List Data of Allocation", ctx, func(_ context.Context) { + allocation, _, err = nc.client.Allocations().Info(allocations[0].ID, nil) + }) if err != nil { return 1, fmt.Errorf("error retrieving allocation info: %w", err) } - exitCode, err := nc.client.Allocations().Exec(ctx, allocation, TaskName, tty, command, stdin, stdout, stderr, nil, nil) + + var exitCode int + logging.StartSpan("nomad.execute.exec", "Execute Command in Allocation", ctx, func(ctx context.Context) { + exitCode, err = nc.client.Allocations().Exec(ctx, allocation, TaskName, tty, command, stdin, stdout, stderr, nil, nil) + }) switch { case err == nil: return exitCode, nil diff --git a/internal/nomad/nomad.go b/internal/nomad/nomad.go index cd24bfb..28fa447 100644 --- a/internal/nomad/nomad.go +++ b/internal/nomad/nomad.go @@ -421,16 +421,22 @@ func (a *APIClient) executeCommandInteractivelyWithStderr(allocationID string, c defer cancel() // Catch stderr in separate execution. - exit, err := a.Execute(allocationID, ctx, prepareCommandTTYStdErr(currentNanoTime, privilegedExecution), true, - nullio.Reader{Ctx: readingContext}, stderr, io.Discard) - if err != nil { - log.WithError(err).WithField("runner", allocationID).Warn("Stderr task finished with error") - } - stderrExitChan <- exit + logging.StartSpan("nomad.execute.stderr", "Execution for separate StdErr", ctx, func(ctx context.Context) { + exit, err := a.Execute(allocationID, ctx, prepareCommandTTYStdErr(currentNanoTime, privilegedExecution), true, + nullio.Reader{Ctx: readingContext}, stderr, io.Discard) + if err != nil { + log.WithError(err).WithField("runner", allocationID).Warn("Stderr task finished with error") + } + stderrExitChan <- exit + }) }() command = prepareCommandTTY(command, currentNanoTime, privilegedExecution) - exit, err := a.Execute(allocationID, ctx, command, true, stdin, stdout, io.Discard) + var exit int + var err error + logging.StartSpan("nomad.execute.tty", "Interactive Execution", ctx, func(ctx context.Context) { + exit, err = a.Execute(allocationID, ctx, command, true, stdin, stdout, io.Discard) + }) // Wait until the stderr catch command finished to make sure we receive all output. <-stderrExitChan diff --git a/internal/nomad/nomad_test.go b/internal/nomad/nomad_test.go index 9818f68..f356480 100644 --- a/internal/nomad/nomad_test.go +++ b/internal/nomad/nomad_test.go @@ -784,7 +784,7 @@ func (s *ExecuteCommandTestSuite) TestWithoutSeparateStderrReturnsCommandError() func (s *ExecuteCommandTestSuite) mockExecute(command interface{}, exitCode int, err error, runFunc func(arguments mock.Arguments)) *mock.Call { - return s.apiMock.On("Execute", s.allocationID, s.ctx, command, withTTY, + return s.apiMock.On("Execute", s.allocationID, mock.Anything, command, withTTY, mock.Anything, mock.Anything, mock.Anything). Run(runFunc). Return(exitCode, err) diff --git a/internal/runner/aws_runner.go b/internal/runner/aws_runner.go index 33b6b10..1963c12 100644 --- a/internal/runner/aws_runner.go +++ b/internal/runner/aws_runner.go @@ -87,7 +87,8 @@ func (w *AWSFunctionWorkload) ExecutionExists(id string) bool { return ok } -func (w *AWSFunctionWorkload) ExecuteInteractively(id string, _ io.ReadWriter, stdout, stderr io.Writer) ( +func (w *AWSFunctionWorkload) ExecuteInteractively( + id string, _ io.ReadWriter, stdout, stderr io.Writer, _ context.Context) ( <-chan ExitInfo, context.CancelFunc, error) { w.ResetTimeout() request, ok := w.executions.Pop(id) diff --git a/internal/runner/aws_runner_test.go b/internal/runner/aws_runner_test.go index a84d80a..08161a5 100644 --- a/internal/runner/aws_runner_test.go +++ b/internal/runner/aws_runner_test.go @@ -76,7 +76,8 @@ func TestAWSFunctionWorkload_ExecuteInteractively(t *testing.T) { cancel() r.StoreExecution(tests.DefaultEnvironmentIDAsString, &dto.ExecutionRequest{}) - exit, _, err := r.ExecuteInteractively(tests.DefaultEnvironmentIDAsString, nil, io.Discard, io.Discard) + exit, _, err := r.ExecuteInteractively( + tests.DefaultEnvironmentIDAsString, nil, io.Discard, io.Discard, context.Background()) require.NoError(t, err) <-exit assert.True(t, awsMock.hasConnected) @@ -89,7 +90,8 @@ func TestAWSFunctionWorkload_ExecuteInteractively(t *testing.T) { request := &dto.ExecutionRequest{Command: command} r.StoreExecution(tests.DefaultEnvironmentIDAsString, request) - _, cancel, err := r.ExecuteInteractively(tests.DefaultEnvironmentIDAsString, nil, io.Discard, io.Discard) + _, cancel, err := r.ExecuteInteractively( + tests.DefaultEnvironmentIDAsString, nil, io.Discard, io.Discard, context.Background()) require.NoError(t, err) <-time.After(tests.ShortTimeout) cancel() @@ -123,7 +125,8 @@ func TestAWSFunctionWorkload_UpdateFileSystem(t *testing.T) { err = r.UpdateFileSystem(&dto.UpdateFileSystemRequest{Copy: []dto.File{myFile}}, context.Background()) assert.NoError(t, err) - _, execCancel, err := r.ExecuteInteractively(tests.DefaultEnvironmentIDAsString, nil, io.Discard, io.Discard) + _, execCancel, err := r.ExecuteInteractively( + tests.DefaultEnvironmentIDAsString, nil, io.Discard, io.Discard, context.Background()) require.NoError(t, err) <-time.After(tests.ShortTimeout) execCancel() diff --git a/internal/runner/nomad_runner.go b/internal/runner/nomad_runner.go index 506b22a..1ec31e6 100644 --- a/internal/runner/nomad_runner.go +++ b/internal/runner/nomad_runner.go @@ -109,6 +109,7 @@ func (r *NomadJob) ExecuteInteractively( id string, stdin io.ReadWriter, stdout, stderr io.Writer, + requestCtx context.Context, ) (<-chan ExitInfo, context.CancelFunc, error) { request, ok := r.executions.Pop(id) if !ok { @@ -117,13 +118,19 @@ func (r *NomadJob) ExecuteInteractively( r.ResetTimeout() - command, ctx, cancel := prepareExecution(request, r.ctx) + // We have to handle three contexts + // - requestCtx: The context of the http request (including Sentry data) + // - r.ctx: The context of the runner (runner timeout) + // - executionCtx: The context of the execution (execution timeout) + // -> The executionCtx cancel that might be triggered (when the client connection breaks) + + command, executionCtx, cancel := prepareExecution(request, r.ctx) exitInternal := make(chan ExitInfo) exit := make(chan ExitInfo, 1) - ctxExecute, cancelExecute := context.WithCancel(r.ctx) + ctxExecute, cancelExecute := context.WithCancel(requestCtx) go r.executeCommand(ctxExecute, command, request.PrivilegedExecution, stdin, stdout, stderr, exitInternal) - go r.handleExitOrContextDone(ctx, cancelExecute, exitInternal, exit, stdin) + go r.handleExitOrContextDone(executionCtx, cancelExecute, exitInternal, exit, stdin) return exit, cancel, nil } @@ -166,7 +173,7 @@ func (r *NomadJob) UpdateFileSystem(copyRequest *dto.UpdateFileSystemRequest, ct updateFileCommand := (&dto.ExecutionRequest{Command: fileDeletionCommand + copyCommand}).FullCommand() stdOut := bytes.Buffer{} stdErr := bytes.Buffer{} - exitCode, err := r.api.ExecuteCommand(r.id, context.Background(), updateFileCommand, false, + exitCode, err := r.api.ExecuteCommand(r.id, ctx, updateFileCommand, false, nomad.PrivilegedExecution, // All files should be written and owned by a privileged user #211. &tarBuffer, &stdOut, &stdErr) if err != nil { diff --git a/internal/runner/nomad_runner_test.go b/internal/runner/nomad_runner_test.go index d9ec4cd..87e51b3 100644 --- a/internal/runner/nomad_runner_test.go +++ b/internal/runner/nomad_runner_test.go @@ -132,14 +132,14 @@ func (s *ExecuteInteractivelyTestSuite) SetupTest() { } func (s *ExecuteInteractivelyTestSuite) TestReturnsErrorWhenExecutionDoesNotExist() { - _, _, err := s.runner.ExecuteInteractively("non-existent-id", nil, nil, nil) + _, _, err := s.runner.ExecuteInteractively("non-existent-id", nil, nil, nil, context.Background()) s.ErrorIs(err, ErrorUnknownExecution) } func (s *ExecuteInteractivelyTestSuite) TestCallsApi() { request := &dto.ExecutionRequest{Command: "echo 'Hello World!'"} s.runner.StoreExecution(defaultExecutionID, request) - _, _, err := s.runner.ExecuteInteractively(defaultExecutionID, nil, nil, nil) + _, _, err := s.runner.ExecuteInteractively(defaultExecutionID, nil, nil, nil, context.Background()) s.Require().NoError(err) time.Sleep(tests.ShortTimeout) @@ -155,7 +155,7 @@ func (s *ExecuteInteractivelyTestSuite) TestReturnsAfterTimeout() { timeLimit := 1 executionRequest := &dto.ExecutionRequest{TimeLimit: timeLimit} s.runner.StoreExecution(defaultExecutionID, executionRequest) - exit, _, err := s.runner.ExecuteInteractively(defaultExecutionID, &nullio.ReadWriter{}, nil, nil) + exit, _, err := s.runner.ExecuteInteractively(defaultExecutionID, &nullio.ReadWriter{}, nil, nil, context.Background()) s.Require().NoError(err) select { @@ -191,7 +191,8 @@ func (s *ExecuteInteractivelyTestSuite) TestSendsSignalAfterTimeout() { timeLimit := 1 executionRequest := &dto.ExecutionRequest{TimeLimit: timeLimit} s.runner.StoreExecution(defaultExecutionID, executionRequest) - _, _, err := s.runner.ExecuteInteractively(defaultExecutionID, bytes.NewBuffer(make([]byte, 1)), nil, nil) + _, _, err := s.runner.ExecuteInteractively( + defaultExecutionID, bytes.NewBuffer(make([]byte, 1)), nil, nil, context.Background()) s.Require().NoError(err) log.Info("Before waiting") select { @@ -210,7 +211,8 @@ func (s *ExecuteInteractivelyTestSuite) TestDestroysRunnerAfterTimeoutAndSignal( executionRequest := &dto.ExecutionRequest{TimeLimit: timeLimit} s.runner.cancel = func() {} s.runner.StoreExecution(defaultExecutionID, executionRequest) - _, _, err := s.runner.ExecuteInteractively(defaultExecutionID, bytes.NewBuffer(make([]byte, 1)), nil, nil) + _, _, err := s.runner.ExecuteInteractively( + defaultExecutionID, bytes.NewBuffer(make([]byte, 1)), nil, nil, context.Background()) s.Require().NoError(err) <-time.After(executionTimeoutGracePeriod + time.Duration(timeLimit)*time.Second + tests.ShortTimeout) s.manager.AssertCalled(s.T(), "Return", s.runner) @@ -219,7 +221,7 @@ func (s *ExecuteInteractivelyTestSuite) TestDestroysRunnerAfterTimeoutAndSignal( func (s *ExecuteInteractivelyTestSuite) TestResetTimerGetsCalled() { executionRequest := &dto.ExecutionRequest{} s.runner.StoreExecution(defaultExecutionID, executionRequest) - _, _, err := s.runner.ExecuteInteractively(defaultExecutionID, nil, nil, nil) + _, _, err := s.runner.ExecuteInteractively(defaultExecutionID, nil, nil, nil, context.Background()) s.Require().NoError(err) s.timer.AssertCalled(s.T(), "ResetTimeout") } @@ -228,7 +230,8 @@ func (s *ExecuteInteractivelyTestSuite) TestExitHasTimeoutErrorIfRunnerTimesOut( s.mockedTimeoutPassedCall.Return(true) executionRequest := &dto.ExecutionRequest{} s.runner.StoreExecution(defaultExecutionID, executionRequest) - exitChannel, _, err := s.runner.ExecuteInteractively(defaultExecutionID, &nullio.ReadWriter{}, nil, nil) + exitChannel, _, err := s.runner.ExecuteInteractively( + defaultExecutionID, &nullio.ReadWriter{}, nil, nil, context.Background()) s.Require().NoError(err) exit := <-exitChannel s.Equal(ErrorRunnerInactivityTimeout, exit.Err) diff --git a/internal/runner/runner.go b/internal/runner/runner.go index fd7920d..7212dfd 100644 --- a/internal/runner/runner.go +++ b/internal/runner/runner.go @@ -43,6 +43,7 @@ type Runner interface { stdin io.ReadWriter, stdout, stderr io.Writer, + ctx context.Context, ) (exit <-chan ExitInfo, cancel context.CancelFunc, err error) // ListFileSystem streams the listing of the file system of the requested directory into the Writer provided. diff --git a/internal/runner/runner_mock.go b/internal/runner/runner_mock.go index 0d48cae..a40a796 100644 --- a/internal/runner/runner_mock.go +++ b/internal/runner/runner_mock.go @@ -48,13 +48,13 @@ func (_m *RunnerMock) Environment() dto.EnvironmentID { return r0 } -// ExecuteInteractively provides a mock function with given fields: id, stdin, stdout, stderr -func (_m *RunnerMock) ExecuteInteractively(id string, stdin io.ReadWriter, stdout io.Writer, stderr io.Writer) (<-chan ExitInfo, context.CancelFunc, error) { - ret := _m.Called(id, stdin, stdout, stderr) +// ExecuteInteractively provides a mock function with given fields: id, stdin, stdout, stderr, ctx +func (_m *RunnerMock) ExecuteInteractively(id string, stdin io.ReadWriter, stdout io.Writer, stderr io.Writer, ctx context.Context) (<-chan ExitInfo, context.CancelFunc, error) { + ret := _m.Called(id, stdin, stdout, stderr, ctx) var r0 <-chan ExitInfo - if rf, ok := ret.Get(0).(func(string, io.ReadWriter, io.Writer, io.Writer) <-chan ExitInfo); ok { - r0 = rf(id, stdin, stdout, stderr) + if rf, ok := ret.Get(0).(func(string, io.ReadWriter, io.Writer, io.Writer, context.Context) <-chan ExitInfo); ok { + r0 = rf(id, stdin, stdout, stderr, ctx) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(<-chan ExitInfo) @@ -62,8 +62,8 @@ func (_m *RunnerMock) ExecuteInteractively(id string, stdin io.ReadWriter, stdou } var r1 context.CancelFunc - if rf, ok := ret.Get(1).(func(string, io.ReadWriter, io.Writer, io.Writer) context.CancelFunc); ok { - r1 = rf(id, stdin, stdout, stderr) + if rf, ok := ret.Get(1).(func(string, io.ReadWriter, io.Writer, io.Writer, context.Context) context.CancelFunc); ok { + r1 = rf(id, stdin, stdout, stderr, ctx) } else { if ret.Get(1) != nil { r1 = ret.Get(1).(context.CancelFunc) @@ -71,8 +71,8 @@ func (_m *RunnerMock) ExecuteInteractively(id string, stdin io.ReadWriter, stdou } var r2 error - if rf, ok := ret.Get(2).(func(string, io.ReadWriter, io.Writer, io.Writer) error); ok { - r2 = rf(id, stdin, stdout, stderr) + if rf, ok := ret.Get(2).(func(string, io.ReadWriter, io.Writer, io.Writer, context.Context) error); ok { + r2 = rf(id, stdin, stdout, stderr, ctx) } else { r2 = ret.Error(2) }