diff --git a/internal/runner/nomad_runner.go b/internal/runner/nomad_runner.go index 02c03ab..ebc6cf4 100644 --- a/internal/runner/nomad_runner.go +++ b/internal/runner/nomad_runner.go @@ -149,7 +149,8 @@ func (r *NomadJob) ExecuteInteractively( } func (r *NomadJob) ListFileSystem( - path string, recursive bool, content io.Writer, privilegedExecution bool, ctx context.Context) error { + path string, recursive bool, content io.Writer, privilegedExecution bool, requestCtx context.Context) error { + ctx := util.NewMergeContext([]context.Context{r.ctx, requestCtx}) r.ResetTimeout() command := lsCommand if recursive { @@ -173,7 +174,8 @@ func (r *NomadJob) ListFileSystem( return err } -func (r *NomadJob) UpdateFileSystem(copyRequest *dto.UpdateFileSystemRequest, ctx context.Context) error { +func (r *NomadJob) UpdateFileSystem(copyRequest *dto.UpdateFileSystemRequest, requestCtx context.Context) error { + ctx := util.NewMergeContext([]context.Context{r.ctx, requestCtx}) r.ResetTimeout() var tarBuffer bytes.Buffer @@ -206,7 +208,8 @@ func (r *NomadJob) UpdateFileSystem(copyRequest *dto.UpdateFileSystemRequest, ct } func (r *NomadJob) GetFileContent( - path string, content http.ResponseWriter, privilegedExecution bool, ctx context.Context) error { + path string, content http.ResponseWriter, privilegedExecution bool, requestCtx context.Context) error { + ctx := util.NewMergeContext([]context.Context{r.ctx, requestCtx}) r.ResetTimeout() contentLengthWriter := &nullio.ContentLengthWriter{Target: content} diff --git a/internal/runner/nomad_runner_test.go b/internal/runner/nomad_runner_test.go index d90c393..da954b2 100644 --- a/internal/runner/nomad_runner_test.go +++ b/internal/runner/nomad_runner_test.go @@ -468,3 +468,30 @@ func (s *UpdateFileSystemTestSuite) TestGetFileContentReturnsErrorIfExitCodeIsNo err := s.runner.GetFileContent("", logging.NewLoggingResponseWriter(nil), false, context.Background()) s.ErrorIs(err, ErrFileNotFound) } + +func (s *UpdateFileSystemTestSuite) TestFileCopyIsCanceledOnRunnerDestroy() { + s.mockedExecuteCommandCall.Run(func(args mock.Arguments) { + ctx, ok := args.Get(1).(context.Context) + s.Require().True(ok) + + select { + case <-ctx.Done(): + s.Fail("mergeContext is done before any of its parents") + return + case <-time.After(tests.ShortTimeout): + } + + select { + case <-ctx.Done(): + case <-time.After(3 * tests.ShortTimeout): + s.Fail("mergeContext is not done after the earliest of its parents") + return + } + }) + ctx, cancel := context.WithCancel(context.Background()) + s.runner.ctx = ctx + s.runner.cancel = cancel + + <-time.After(2 * tests.ShortTimeout) + s.runner.cancel() +}