diff --git a/internal/environment/nomad_environment.go b/internal/environment/nomad_environment.go index b39fd34..ed3d7b3 100644 --- a/internal/environment/nomad_environment.go +++ b/internal/environment/nomad_environment.go @@ -29,6 +29,8 @@ type NomadEnvironment struct { jobHCL string job *nomadApi.Job idleRunners storage.Storage[runner.Runner] + ctx context.Context + cancel context.CancelFunc } func NewNomadEnvironment(id dto.EnvironmentID, apiClient nomad.ExecutorAPI, jobHCL string) (*NomadEnvironment, error) { @@ -37,9 +39,10 @@ func NewNomadEnvironment(id dto.EnvironmentID, apiClient nomad.ExecutorAPI, jobH return nil, fmt.Errorf("error parsing Nomad job: %w", err) } - e := &NomadEnvironment{apiClient, jobHCL, job, nil} + ctx, cancel := context.WithCancel(context.Background()) + e := &NomadEnvironment{apiClient, jobHCL, job, nil, ctx, cancel} e.idleRunners = storage.NewMonitoredLocalStorage[runner.Runner](monitoring.MeasurementIdleRunnerNomad, - runner.MonitorEnvironmentID[runner.Runner](id), time.Minute) + runner.MonitorEnvironmentID[runner.Runner](id), time.Minute, ctx) return e, nil } @@ -218,6 +221,7 @@ func (n *NomadEnvironment) Register() error { } func (n *NomadEnvironment) Delete() error { + n.cancel() err := n.removeRunners() if err != nil { return err diff --git a/internal/environment/nomad_environment_test.go b/internal/environment/nomad_environment_test.go index c278108..4dee667 100644 --- a/internal/environment/nomad_environment_test.go +++ b/internal/environment/nomad_environment_test.go @@ -18,7 +18,7 @@ import ( func TestConfigureNetworkCreatesNewNetworkWhenNoNetworkExists(t *testing.T) { _, job := helpers.CreateTemplateJob() defaultTaskGroup := nomad.FindAndValidateDefaultTaskGroup(job) - environment := &NomadEnvironment{nil, "", job, nil} + environment := &NomadEnvironment{nil, "", job, nil, nil, nil} if assert.Equal(t, 0, len(defaultTaskGroup.Networks)) { environment.SetNetworkAccess(true, []uint16{}) @@ -30,7 +30,7 @@ func TestConfigureNetworkCreatesNewNetworkWhenNoNetworkExists(t *testing.T) { func TestConfigureNetworkDoesNotCreateNewNetworkWhenNetworkExists(t *testing.T) { _, job := helpers.CreateTemplateJob() defaultTaskGroup := nomad.FindAndValidateDefaultTaskGroup(job) - environment := &NomadEnvironment{nil, "", job, nil} + environment := &NomadEnvironment{nil, "", job, nil, nil, nil} networkResource := &nomadApi.NetworkResource{Mode: "cni/secure-bridge"} defaultTaskGroup.Networks = []*nomadApi.NetworkResource{networkResource} @@ -59,7 +59,7 @@ func TestConfigureNetworkSetsCorrectValues(t *testing.T) { _, testJob := helpers.CreateTemplateJob() testTaskGroup := nomad.FindAndValidateDefaultTaskGroup(testJob) testTask := nomad.FindAndValidateDefaultTask(testTaskGroup) - testEnvironment := &NomadEnvironment{nil, "", job, nil} + testEnvironment := &NomadEnvironment{nil, "", job, nil, nil, nil} testEnvironment.SetNetworkAccess(false, ports) mode, ok := testTask.Config["network_mode"] @@ -74,7 +74,7 @@ func TestConfigureNetworkSetsCorrectValues(t *testing.T) { _, testJob := helpers.CreateTemplateJob() testTaskGroup := nomad.FindAndValidateDefaultTaskGroup(testJob) testTask := nomad.FindAndValidateDefaultTask(testTaskGroup) - testEnvironment := &NomadEnvironment{nil, "", testJob, nil} + testEnvironment := &NomadEnvironment{nil, "", testJob, nil, nil, nil} testEnvironment.SetNetworkAccess(true, ports) require.Equal(t, 1, len(testTaskGroup.Networks)) @@ -114,7 +114,8 @@ func TestRegisterFailsWhenNomadJobRegistrationFails(t *testing.T) { apiClientMock.On("LoadRunnerIDs", mock.AnythingOfType("string")).Return([]string{}, nil) apiClientMock.On("DeleteJob", mock.AnythingOfType("string")).Return(nil) - environment := &NomadEnvironment{apiClientMock, "", &nomadApi.Job{}, storage.NewLocalStorage[runner.Runner]()} + environment := &NomadEnvironment{apiClientMock, "", &nomadApi.Job{}, + storage.NewLocalStorage[runner.Runner](), nil, nil} environment.SetID(tests.DefaultEnvironmentIDAsInteger) err := environment.Register() @@ -131,7 +132,8 @@ func TestRegisterTemplateJobSucceedsWhenMonitoringEvaluationSucceeds(t *testing. apiClientMock.On("LoadRunnerIDs", mock.AnythingOfType("string")).Return([]string{}, nil) apiClientMock.On("DeleteJob", mock.AnythingOfType("string")).Return(nil) - environment := &NomadEnvironment{apiClientMock, "", &nomadApi.Job{}, storage.NewLocalStorage[runner.Runner]()} + environment := &NomadEnvironment{apiClientMock, "", &nomadApi.Job{}, + storage.NewLocalStorage[runner.Runner](), nil, nil} environment.SetID(tests.DefaultEnvironmentIDAsInteger) err := environment.Register() @@ -147,7 +149,8 @@ func TestRegisterTemplateJobReturnsErrorWhenMonitoringEvaluationFails(t *testing apiClientMock.On("LoadRunnerIDs", mock.AnythingOfType("string")).Return([]string{}, nil) apiClientMock.On("DeleteJob", mock.AnythingOfType("string")).Return(nil) - environment := &NomadEnvironment{apiClientMock, "", &nomadApi.Job{}, storage.NewLocalStorage[runner.Runner]()} + environment := &NomadEnvironment{apiClientMock, "", &nomadApi.Job{}, + storage.NewLocalStorage[runner.Runner](), nil, nil} environment.SetID(tests.DefaultEnvironmentIDAsInteger) err := environment.Register() @@ -173,7 +176,8 @@ func TestTwoSampleAddExactlyTwoRunners(t *testing.T) { apiMock.On("RegisterRunnerJob", mock.AnythingOfType("*api.Job")).Return(nil) _, job := helpers.CreateTemplateJob() - environment := &NomadEnvironment{apiMock, templateEnvironmentJobHCL, job, storage.NewLocalStorage[runner.Runner]()} + environment := &NomadEnvironment{apiMock, templateEnvironmentJobHCL, job, + storage.NewLocalStorage[runner.Runner](), nil, nil} runner1 := &runner.RunnerMock{} runner1.On("ID").Return(tests.DefaultRunnerID) runner2 := &runner.RunnerMock{} @@ -206,7 +210,8 @@ func TestSampleDoesNotSetForcePullFlag(t *testing.T) { }) _, job := helpers.CreateTemplateJob() - environment := &NomadEnvironment{apiMock, templateEnvironmentJobHCL, job, storage.NewLocalStorage[runner.Runner]()} + environment := &NomadEnvironment{apiMock, templateEnvironmentJobHCL, job, + storage.NewLocalStorage[runner.Runner](), nil, nil} runner1 := &runner.RunnerMock{} runner1.On("ID").Return(tests.DefaultRunnerID) environment.AddRunner(runner1) diff --git a/internal/environment/nomad_manager.go b/internal/environment/nomad_manager.go index 36e80cc..30901e2 100644 --- a/internal/environment/nomad_manager.go +++ b/internal/environment/nomad_manager.go @@ -1,6 +1,7 @@ package environment import ( + "context" _ "embed" "fmt" nomadApi "github.com/hashicorp/nomad/api" @@ -151,13 +152,16 @@ func (m *NomadEnvironmentManager) Load() error { // newNomadEnvironmetFromJob creates a Nomad environment from the passed Nomad job definition. func newNomadEnvironmetFromJob(job *nomadApi.Job, apiClient nomad.ExecutorAPI) *NomadEnvironment { + ctx, cancel := context.WithCancel(context.Background()) e := &NomadEnvironment{ apiClient: apiClient, jobHCL: templateEnvironmentJobHCL, job: job, + ctx: ctx, + cancel: cancel, } e.idleRunners = storage.NewMonitoredLocalStorage[runner.Runner](monitoring.MeasurementIdleRunnerNomad, - runner.MonitorEnvironmentID[runner.Runner](e.ID()), time.Minute) + runner.MonitorEnvironmentID[runner.Runner](e.ID()), time.Minute, ctx) return e } diff --git a/internal/runner/abstract_manager.go b/internal/runner/abstract_manager.go index 4250f3c..cae4dd0 100644 --- a/internal/runner/abstract_manager.go +++ b/internal/runner/abstract_manager.go @@ -1,6 +1,7 @@ package runner import ( + "context" "errors" "fmt" "github.com/influxdata/influxdb-client-go/v2/api/write" @@ -22,12 +23,13 @@ type AbstractManager struct { } // NewAbstractManager creates a new abstract runner manager that keeps track of all runners of one kind. +// Since this manager is currently directly bound to the lifespan of Poseidon, it does not need a context cancel. func NewAbstractManager() *AbstractManager { return &AbstractManager{ environments: storage.NewMonitoredLocalStorage[ExecutionEnvironment]( - monitoring.MeasurementEnvironments, monitorEnvironmentData, 0), + monitoring.MeasurementEnvironments, monitorEnvironmentData, 0, context.Background()), usedRunners: storage.NewMonitoredLocalStorage[Runner]( - monitoring.MeasurementUsedRunner, MonitorRunnersEnvironmentID, time.Hour), + monitoring.MeasurementUsedRunner, MonitorRunnersEnvironmentID, time.Hour, context.Background()), } } diff --git a/internal/runner/aws_runner.go b/internal/runner/aws_runner.go index 24e9d13..890afda 100644 --- a/internal/runner/aws_runner.go +++ b/internal/runner/aws_runner.go @@ -37,6 +37,8 @@ type AWSFunctionWorkload struct { runningExecutions map[execution.ID]context.CancelFunc onDestroy DestroyRunnerHandler environment ExecutionEnvironment + ctx context.Context + cancel context.CancelFunc } // NewAWSFunctionWorkload creates a new AWSFunctionWorkload with the provided id. @@ -47,15 +49,18 @@ func NewAWSFunctionWorkload( return nil, fmt.Errorf("failed generating runner id: %w", err) } + ctx, cancel := context.WithCancel(context.Background()) workload := &AWSFunctionWorkload{ id: newUUID.String(), fs: make(map[dto.FilePath][]byte), runningExecutions: make(map[execution.ID]context.CancelFunc), onDestroy: onDestroy, environment: environment, + ctx: ctx, + cancel: cancel, } workload.executions = storage.NewMonitoredLocalStorage[*dto.ExecutionRequest]( - monitoring.MeasurementExecutionsAWS, monitorExecutionsRunnerID(environment.ID(), workload.id), time.Minute) + monitoring.MeasurementExecutionsAWS, monitorExecutionsRunnerID(environment.ID(), workload.id), time.Minute, ctx) workload.InactivityTimer = NewInactivityTimer(workload, func(_ Runner) error { return workload.Destroy() }) @@ -92,7 +97,7 @@ func (w *AWSFunctionWorkload) ExecuteInteractively(id string, _ io.ReadWriter, s } hideEnvironmentVariables(request, "AWS") request.PrivilegedExecution = true // AWS does not support multiple users at this moment. - command, ctx, cancel := prepareExecution(request) + command, ctx, cancel := prepareExecution(request, w.ctx) exitInternal := make(chan ExitInfo) exit := make(chan ExitInfo, 1) @@ -131,9 +136,7 @@ func (w *AWSFunctionWorkload) GetFileContent(_ string, _ http.ResponseWriter, _ } func (w *AWSFunctionWorkload) Destroy() error { - for _, cancel := range w.runningExecutions { - cancel() - } + w.cancel() if err := w.onDestroy(w); err != nil { return fmt.Errorf("error while destroying aws runner: %w", err) } diff --git a/internal/runner/nomad_runner.go b/internal/runner/nomad_runner.go index b6383db..faf8309 100644 --- a/internal/runner/nomad_runner.go +++ b/internal/runner/nomad_runner.go @@ -47,6 +47,8 @@ type NomadJob struct { portMappings []nomadApi.PortMapping api nomad.ExecutorAPI onDestroy DestroyRunnerHandler + ctx context.Context + cancel context.CancelFunc } // NewNomadJob creates a new NomadJob with the provided id. @@ -55,14 +57,17 @@ type NomadJob struct { func NewNomadJob(id string, portMappings []nomadApi.PortMapping, apiClient nomad.ExecutorAPI, onDestroy DestroyRunnerHandler, ) *NomadJob { + ctx, cancel := context.WithCancel(context.Background()) job := &NomadJob{ id: id, portMappings: portMappings, api: apiClient, onDestroy: onDestroy, + ctx: ctx, + cancel: cancel, } job.executions = storage.NewMonitoredLocalStorage[*dto.ExecutionRequest]( - monitoring.MeasurementExecutionsNomad, monitorExecutionsRunnerID(job.Environment(), id), time.Minute) + monitoring.MeasurementExecutionsNomad, monitorExecutionsRunnerID(job.Environment(), id), time.Minute, ctx) job.InactivityTimer = NewInactivityTimer(job, onDestroy) return job } @@ -111,10 +116,10 @@ func (r *NomadJob) ExecuteInteractively( r.ResetTimeout() - command, ctx, cancel := prepareExecution(request) + command, ctx, cancel := prepareExecution(request, r.ctx) exitInternal := make(chan ExitInfo) exit := make(chan ExitInfo, 1) - ctxExecute, cancelExecute := context.WithCancel(context.Background()) + ctxExecute, cancelExecute := context.WithCancel(r.ctx) go r.executeCommand(ctxExecute, command, request.PrivilegedExecution, stdin, stdout, stderr, exitInternal) go r.handleExitOrContextDone(ctx, cancelExecute, exitInternal, exit, stdin) @@ -203,20 +208,21 @@ func (r *NomadJob) GetFileContent( } func (r *NomadJob) Destroy() error { + r.cancel() if err := r.onDestroy(r); err != nil { return fmt.Errorf("error while destroying runner: %w", err) } return nil } -func prepareExecution(request *dto.ExecutionRequest) ( +func prepareExecution(request *dto.ExecutionRequest, environmentCtx context.Context) ( command []string, ctx context.Context, cancel context.CancelFunc, ) { command = request.FullCommand() if request.TimeLimit == 0 { - ctx, cancel = context.WithCancel(context.Background()) + ctx, cancel = context.WithCancel(environmentCtx) } else { - ctx, cancel = context.WithTimeout(context.Background(), time.Duration(request.TimeLimit)*time.Second) + ctx, cancel = context.WithTimeout(environmentCtx, time.Duration(request.TimeLimit)*time.Second) } return command, ctx, cancel } diff --git a/internal/runner/nomad_runner_test.go b/internal/runner/nomad_runner_test.go index d5dfa13..a243c18 100644 --- a/internal/runner/nomad_runner_test.go +++ b/internal/runner/nomad_runner_test.go @@ -127,6 +127,7 @@ func (s *ExecuteInteractivelyTestSuite) SetupTest() { id: tests.DefaultRunnerID, api: s.apiMock, onDestroy: s.manager.Return, + ctx: context.Background(), } } @@ -207,6 +208,7 @@ func (s *ExecuteInteractivelyTestSuite) TestDestroysRunnerAfterTimeoutAndSignal( }) timeLimit := 1 executionRequest := &dto.ExecutionRequest{TimeLimit: timeLimit} + s.runner.cancel = func() {} s.runner.StoreExecution(defaultExecutionID, executionRequest) _, _, err := s.runner.ExecuteInteractively(defaultExecutionID, bytes.NewBuffer(make([]byte, 1)), nil, nil) s.Require().NoError(err) diff --git a/pkg/storage/storage.go b/pkg/storage/storage.go index b0c0c5b..a22610c 100644 --- a/pkg/storage/storage.go +++ b/pkg/storage/storage.go @@ -75,14 +75,14 @@ func NewLocalStorage[T any]() *localStorage[T] { // Iff callback is set, it will be called on a write operation. // Iff additionalEvents not zero, the duration will be used to periodically send additional monitoring events. func NewMonitoredLocalStorage[T any]( - measurement string, callback WriteCallback[T], additionalEvents time.Duration) *localStorage[T] { + measurement string, callback WriteCallback[T], additionalEvents time.Duration, ctx context.Context) *localStorage[T] { s := &localStorage[T]{ objects: make(map[string]T), measurement: measurement, callback: callback, } if additionalEvents != 0 { - go s.periodicallySendMonitoringData(additionalEvents) + go s.periodicallySendMonitoringData(additionalEvents, ctx) } return s } @@ -172,8 +172,7 @@ func (s *localStorage[T]) sendMonitoringData(id string, o T, eventType EventType } } -func (s *localStorage[T]) periodicallySendMonitoringData(d time.Duration) { - ctx := context.Background() +func (s *localStorage[T]) periodicallySendMonitoringData(d time.Duration, ctx context.Context) { for ctx.Err() == nil { stub := new(T) s.sendMonitoringData("", *stub, Periodically, s.Length()) diff --git a/pkg/storage/storage_test.go b/pkg/storage/storage_test.go index c08d52b..46b5e6e 100644 --- a/pkg/storage/storage_test.go +++ b/pkg/storage/storage_test.go @@ -1,6 +1,7 @@ package storage import ( + "context" "github.com/influxdata/influxdb-client-go/v2/api/write" "github.com/openHPI/poseidon/tests" "github.com/stretchr/testify/assert" @@ -123,7 +124,7 @@ func TestNewMonitoredLocalStorage_Callback(t *testing.T) { } else if eventType == Creation { callbackAdditions++ } - }, 0) + }, 0, context.Background()) assertCallbackCounts := func(test func(), totalCalls, additions, deletions int) { beforeTotal := callbackCalls @@ -174,11 +175,13 @@ func TestNewMonitoredLocalStorage_Callback(t *testing.T) { } func TestNewMonitoredLocalStorage_Periodically(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() callbackCalls := 0 NewMonitoredLocalStorage[string]("testMeasurement", func(p *write.Point, o string, eventType EventType) { callbackCalls++ assert.Equal(t, Periodically, eventType) - }, 200*time.Millisecond) + }, 200*time.Millisecond, ctx) time.Sleep(tests.ShortTimeout) assert.Equal(t, 1, callbackCalls)