diff --git a/cmd/poseidon/main.go b/cmd/poseidon/main.go index 4389033..75ee82e 100644 --- a/cmd/poseidon/main.go +++ b/cmd/poseidon/main.go @@ -135,49 +135,51 @@ func runServer(server *http.Server, cancel context.CancelFunc) { } } -type managerCreator func() (runnerManager runner.Manager, environmentManager environment.ManagerHandler) +type managerCreator func(ctx context.Context) ( + runnerManager runner.Manager, environmentManager environment.ManagerHandler) // createManagerHandler adds the managers of the passed managerCreator to the chain of responsibility. func createManagerHandler(handler managerCreator, enabled bool, - nextRunnerManager runner.Manager, nextEnvironmentManager environment.ManagerHandler) ( + nextRunnerManager runner.Manager, nextEnvironmentManager environment.ManagerHandler, ctx context.Context) ( runnerManager runner.Manager, environmentManager environment.ManagerHandler) { if !enabled { return nextRunnerManager, nextEnvironmentManager } - runnerManager, environmentManager = handler() + runnerManager, environmentManager = handler(ctx) runnerManager.SetNextHandler(nextRunnerManager) environmentManager.SetNextHandler(nextEnvironmentManager) return runnerManager, environmentManager } -func createNomadManager() (runnerManager runner.Manager, environmentManager environment.ManagerHandler) { +func createNomadManager(ctx context.Context) ( + runnerManager runner.Manager, environmentManager environment.ManagerHandler) { // API initialization nomadAPIClient, err := nomad.NewExecutorAPI(&config.Config.Nomad) if err != nil { log.WithError(err).WithField("nomad config", config.Config.Nomad).Fatal("Error creating Nomad API client") } - runnerManager = runner.NewNomadRunnerManager(nomadAPIClient, context.Background()) + runnerManager = runner.NewNomadRunnerManager(nomadAPIClient, ctx) environmentManager, err = environment. - NewNomadEnvironmentManager(runnerManager, nomadAPIClient, config.Config.Server.TemplateJobFile) + NewNomadEnvironmentManager(runnerManager, nomadAPIClient, config.Config.Server.TemplateJobFile, ctx) if err != nil { log.WithError(err).Fatal("Error initializing environment manager") } return runnerManager, environmentManager } -func createAWSManager() (runnerManager runner.Manager, environmentManager environment.ManagerHandler) { +func createAWSManager(_ context.Context) (runnerManager runner.Manager, environmentManager environment.ManagerHandler) { runnerManager = runner.NewAWSRunnerManager() return runnerManager, environment.NewAWSEnvironmentManager(runnerManager) } // initServer builds the http server and configures it with the chain of responsibility for multiple managers. -func initServer() *http.Server { +func initServer(ctx context.Context) *http.Server { runnerManager, environmentManager := createManagerHandler(createNomadManager, config.Config.Nomad.Enabled, - nil, nil) + nil, nil, ctx) runnerManager, environmentManager = createManagerHandler(createAWSManager, config.Config.AWS.Enabled, - runnerManager, environmentManager) + runnerManager, environmentManager, ctx) handler := api.NewRouter(runnerManager, environmentManager) sentryHandler := sentryhttp.New(sentryhttp.Options{}).Handle(handler) @@ -239,7 +241,7 @@ func main() { stopProfiling := initProfiling(config.Config.Profiling) ctx, cancel := context.WithCancel(context.Background()) - server := initServer() + server := initServer(ctx) go runServer(server, cancel) shutdownOnOSSignal(server, ctx, stopProfiling) } diff --git a/cmd/poseidon/main_test.go b/cmd/poseidon/main_test.go index b31f678..c7acf3f 100644 --- a/cmd/poseidon/main_test.go +++ b/cmd/poseidon/main_test.go @@ -13,27 +13,35 @@ import ( ) func TestAWSDisabledUsesNomadManager(t *testing.T) { + disableRecovery, cancel := context.WithCancel(context.Background()) + cancel() + runnerManager, environmentManager := createManagerHandler(createNomadManager, true, - runner.NewAbstractManager(), &environment.AbstractManager{}) + runner.NewAbstractManager(), &environment.AbstractManager{}, disableRecovery) awsRunnerManager, awsEnvironmentManager := createManagerHandler(createAWSManager, false, - runnerManager, environmentManager) + runnerManager, environmentManager, disableRecovery) assert.Equal(t, runnerManager, awsRunnerManager) assert.Equal(t, environmentManager, awsEnvironmentManager) } func TestAWSEnabledWrappesNomadManager(t *testing.T) { + disableRecovery, cancel := context.WithCancel(context.Background()) + cancel() + runnerManager, environmentManager := createManagerHandler(createNomadManager, true, - runner.NewAbstractManager(), &environment.AbstractManager{}) + runner.NewAbstractManager(), &environment.AbstractManager{}, disableRecovery) awsRunnerManager, awsEnvironmentManager := createManagerHandler(createAWSManager, - true, runnerManager, environmentManager) + true, runnerManager, environmentManager, disableRecovery) assert.NotEqual(t, runnerManager, awsRunnerManager) assert.NotEqual(t, environmentManager, awsEnvironmentManager) } func TestShutdownOnOSSignal_Profiling(t *testing.T) { called := false + disableRecovery, cancel := context.WithCancel(context.Background()) + cancel() - server := initServer() + server := initServer(disableRecovery) go shutdownOnOSSignal(server, context.Background(), func() { called = true }) diff --git a/internal/environment/nomad_environment.go b/internal/environment/nomad_environment.go index 0d40781..0a6e96d 100644 --- a/internal/environment/nomad_environment.go +++ b/internal/environment/nomad_environment.go @@ -245,9 +245,7 @@ func (n *NomadEnvironment) Sample() (runner.Runner, bool) { r, ok := n.idleRunners.Sample() if ok && n.idleRunners.Length() < n.PrewarmingPoolSize() { go func() { - err := util.RetryExponential(func() error { - return n.createRunner(false) - }) + err := util.RetryExponentialContext(n.ctx, func() error { return n.createRunner(false) }) if err != nil { log.WithError(err).WithField(dto.KeyEnvironmentID, n.ID().ToString()). Error("Couldn't create new runner for claimed one") diff --git a/internal/environment/nomad_environment_test.go b/internal/environment/nomad_environment_test.go index f94e22c..fdb9bb5 100644 --- a/internal/environment/nomad_environment_test.go +++ b/internal/environment/nomad_environment_test.go @@ -1,6 +1,7 @@ package environment import ( + "context" "fmt" nomadApi "github.com/hashicorp/nomad/api" "github.com/openHPI/poseidon/internal/nomad" @@ -18,7 +19,7 @@ import ( func TestConfigureNetworkCreatesNewNetworkWhenNoNetworkExists(t *testing.T) { _, job := helpers.CreateTemplateJob() defaultTaskGroup := nomad.FindAndValidateDefaultTaskGroup(job) - environment := &NomadEnvironment{nil, "", job, nil, nil, nil} + environment := &NomadEnvironment{nil, "", job, nil, context.Background(), nil} if assert.Equal(t, 0, len(defaultTaskGroup.Networks)) { environment.SetNetworkAccess(true, []uint16{}) @@ -30,7 +31,7 @@ func TestConfigureNetworkCreatesNewNetworkWhenNoNetworkExists(t *testing.T) { func TestConfigureNetworkDoesNotCreateNewNetworkWhenNetworkExists(t *testing.T) { _, job := helpers.CreateTemplateJob() defaultTaskGroup := nomad.FindAndValidateDefaultTaskGroup(job) - environment := &NomadEnvironment{nil, "", job, nil, nil, nil} + environment := &NomadEnvironment{nil, "", job, nil, context.Background(), nil} networkResource := &nomadApi.NetworkResource{Mode: "cni/secure-bridge"} defaultTaskGroup.Networks = []*nomadApi.NetworkResource{networkResource} @@ -59,7 +60,7 @@ func TestConfigureNetworkSetsCorrectValues(t *testing.T) { _, testJob := helpers.CreateTemplateJob() testTaskGroup := nomad.FindAndValidateDefaultTaskGroup(testJob) testTask := nomad.FindAndValidateDefaultTask(testTaskGroup) - testEnvironment := &NomadEnvironment{nil, "", job, nil, nil, nil} + testEnvironment := &NomadEnvironment{nil, "", job, nil, context.Background(), nil} testEnvironment.SetNetworkAccess(false, ports) mode, ok := testTask.Config["network_mode"] @@ -74,7 +75,7 @@ func TestConfigureNetworkSetsCorrectValues(t *testing.T) { _, testJob := helpers.CreateTemplateJob() testTaskGroup := nomad.FindAndValidateDefaultTaskGroup(testJob) testTask := nomad.FindAndValidateDefaultTask(testTaskGroup) - testEnvironment := &NomadEnvironment{nil, "", testJob, nil, nil, nil} + testEnvironment := &NomadEnvironment{nil, "", testJob, nil, context.Background(), nil} testEnvironment.SetNetworkAccess(true, ports) require.Equal(t, 1, len(testTaskGroup.Networks)) @@ -133,7 +134,7 @@ func TestRegisterTemplateJobSucceedsWhenMonitoringEvaluationSucceeds(t *testing. apiClientMock.On("DeleteJob", mock.AnythingOfType("string")).Return(nil) environment := &NomadEnvironment{apiClientMock, "", &nomadApi.Job{}, - storage.NewLocalStorage[runner.Runner](), nil, nil} + storage.NewLocalStorage[runner.Runner](), context.Background(), nil} environment.SetID(tests.DefaultEnvironmentIDAsInteger) err := environment.Register() @@ -150,7 +151,7 @@ func TestRegisterTemplateJobReturnsErrorWhenMonitoringEvaluationFails(t *testing apiClientMock.On("DeleteJob", mock.AnythingOfType("string")).Return(nil) environment := &NomadEnvironment{apiClientMock, "", &nomadApi.Job{}, - storage.NewLocalStorage[runner.Runner](), nil, nil} + storage.NewLocalStorage[runner.Runner](), context.Background(), nil} environment.SetID(tests.DefaultEnvironmentIDAsInteger) err := environment.Register() @@ -177,7 +178,7 @@ func TestTwoSampleAddExactlyTwoRunners(t *testing.T) { _, job := helpers.CreateTemplateJob() environment := &NomadEnvironment{apiMock, templateEnvironmentJobHCL, job, - storage.NewLocalStorage[runner.Runner](), nil, nil} + storage.NewLocalStorage[runner.Runner](), context.Background(), nil} environment.SetPrewarmingPoolSize(2) runner1 := &runner.RunnerMock{} runner1.On("ID").Return(tests.DefaultRunnerID) @@ -212,7 +213,7 @@ func TestSampleDoesNotSetForcePullFlag(t *testing.T) { _, job := helpers.CreateTemplateJob() environment := &NomadEnvironment{apiMock, templateEnvironmentJobHCL, job, - storage.NewLocalStorage[runner.Runner](), nil, nil} + storage.NewLocalStorage[runner.Runner](), context.Background(), nil} runner1 := &runner.RunnerMock{} runner1.On("ID").Return(tests.DefaultRunnerID) environment.AddRunner(runner1) diff --git a/internal/environment/nomad_manager.go b/internal/environment/nomad_manager.go index 7874e68..0ac67ff 100644 --- a/internal/environment/nomad_manager.go +++ b/internal/environment/nomad_manager.go @@ -36,6 +36,7 @@ func NewNomadEnvironmentManager( runnerManager runner.Manager, apiClient nomad.ExecutorAPI, templateJobFile string, + ctx context.Context, ) (*NomadEnvironmentManager, error) { if err := loadTemplateEnvironmentJobHCL(templateJobFile); err != nil { return nil, err @@ -43,7 +44,7 @@ func NewNomadEnvironmentManager( m := &NomadEnvironmentManager{&AbstractManager{nil, runnerManager}, apiClient, templateEnvironmentJobHCL} - if err := util.RetryExponential(func() error { return m.Load() }); err != nil { + if err := util.RetryExponentialContext(ctx, func() error { return m.Load() }); err != nil { log.WithError(err).Error("Error recovering the execution environments") } runnerManager.Load() diff --git a/internal/environment/nomad_manager_test.go b/internal/environment/nomad_manager_test.go index cb3093a..86c556d 100644 --- a/internal/environment/nomad_manager_test.go +++ b/internal/environment/nomad_manager_test.go @@ -95,6 +95,9 @@ func (s *CreateOrUpdateTestSuite) TestCreateOrUpdatesSetsForcePullFlag() { } func TestNewNomadEnvironmentManager(t *testing.T) { + disableRecovery, cancel := context.WithCancel(context.Background()) + cancel() + executorAPIMock := &nomad.ExecutorAPIMock{} executorAPIMock.On("LoadEnvironmentJobs").Return([]*nomadApi.Job{}, nil) @@ -104,7 +107,7 @@ func TestNewNomadEnvironmentManager(t *testing.T) { previousTemplateEnvironmentJobHCL := templateEnvironmentJobHCL t.Run("returns error if template file does not exist", func(t *testing.T) { - _, err := NewNomadEnvironmentManager(runnerManagerMock, executorAPIMock, "/non-existent/file") + _, err := NewNomadEnvironmentManager(runnerManagerMock, executorAPIMock, "/non-existent/file", disableRecovery) assert.Error(t, err) }) @@ -115,7 +118,7 @@ func TestNewNomadEnvironmentManager(t *testing.T) { f := createTempFile(t, templateJobHCL) defer os.Remove(f.Name()) - m, err := NewNomadEnvironmentManager(runnerManagerMock, executorAPIMock, f.Name()) + m, err := NewNomadEnvironmentManager(runnerManagerMock, executorAPIMock, f.Name(), disableRecovery) assert.NoError(t, err) assert.NotNil(t, m) assert.Equal(t, templateJobHCL, m.templateEnvironmentHCL) @@ -126,7 +129,7 @@ func TestNewNomadEnvironmentManager(t *testing.T) { f := createTempFile(t, templateJobHCL) defer os.Remove(f.Name()) - m, err := NewNomadEnvironmentManager(runnerManagerMock, executorAPIMock, f.Name()) + m, err := NewNomadEnvironmentManager(runnerManagerMock, executorAPIMock, f.Name(), disableRecovery) require.NoError(t, err) _, err = NewNomadEnvironment(tests.DefaultEnvironmentIDAsInteger, nil, m.templateEnvironmentHCL) assert.Error(t, err) @@ -136,6 +139,9 @@ func TestNewNomadEnvironmentManager(t *testing.T) { } func TestNomadEnvironmentManager_Get(t *testing.T) { + disableRecovery, cancel := context.WithCancel(context.Background()) + cancel() + apiMock := &nomad.ExecutorAPIMock{} mockWatchAllocations(apiMock) apiMock.On("LoadRunnerIDs", mock.AnythingOfType("string")).Return([]string{}, nil) @@ -146,7 +152,7 @@ func TestNomadEnvironmentManager_Get(t *testing.T) { }) runnerManager := runner.NewNomadRunnerManager(apiMock, context.Background()) - m, err := NewNomadEnvironmentManager(runnerManager, apiMock, "") + m, err := NewNomadEnvironmentManager(runnerManager, apiMock, "", disableRecovery) require.NoError(t, err) t.Run("Returns error when not found", func(t *testing.T) { @@ -217,6 +223,9 @@ func TestNomadEnvironmentManager_Get(t *testing.T) { } func TestNomadEnvironmentManager_List(t *testing.T) { + disableRecovery, cancel := context.WithCancel(context.Background()) + cancel() + apiMock := &nomad.ExecutorAPIMock{} mockWatchAllocations(apiMock) call := apiMock.On("LoadEnvironmentJobs") @@ -225,7 +234,7 @@ func TestNomadEnvironmentManager_List(t *testing.T) { }) runnerManager := runner.NewNomadRunnerManager(apiMock, context.Background()) - m, err := NewNomadEnvironmentManager(runnerManager, apiMock, "") + m, err := NewNomadEnvironmentManager(runnerManager, apiMock, "", disableRecovery) require.NoError(t, err) t.Run("with no environments", func(t *testing.T) { @@ -287,7 +296,7 @@ func TestNomadEnvironmentManager_Load(t *testing.T) { _, ok := runnerManager.GetEnvironment(tests.DefaultEnvironmentIDAsInteger) require.False(t, ok) - _, err := NewNomadEnvironmentManager(runnerManager, apiMock, "") + _, err := NewNomadEnvironmentManager(runnerManager, apiMock, "", context.Background()) require.NoError(t, err) environment, ok := runnerManager.GetEnvironment(tests.DefaultEnvironmentIDAsInteger) @@ -305,7 +314,7 @@ func TestNomadEnvironmentManager_Load(t *testing.T) { _, ok := runnerManager.GetEnvironment(tests.DefaultEnvironmentIDAsInteger) require.False(t, ok) - _, err := NewNomadEnvironmentManager(runnerManager, apiMock, "") + _, err := NewNomadEnvironmentManager(runnerManager, apiMock, "", context.Background()) require.NoError(t, err) _, ok = runnerManager.GetEnvironment(tests.DefaultEnvironmentIDAsInteger) diff --git a/pkg/util/util.go b/pkg/util/util.go index a27ed7b..578b8c2 100644 --- a/pkg/util/util.go +++ b/pkg/util/util.go @@ -1,6 +1,8 @@ package util import ( + "context" + "fmt" "github.com/openHPI/poseidon/pkg/logging" "time" ) @@ -13,13 +15,15 @@ var ( InitialWaitingDuration = time.Second ) -// RetryExponentialAttempts executes the passed function +// RetryExponentialAttemptsContext executes the passed function // with exponentially increasing time in between starting at the passed sleep duration -// up to a maximum of attempts tries. -func RetryExponentialAttempts(attempts int, sleep time.Duration, f func() error) (err error) { +// up to a maximum of attempts tries as long as the context is not done. +func RetryExponentialAttemptsContext( + ctx context.Context, attempts int, sleep time.Duration, f func() error) (err error) { for i := 0; i < attempts; i++ { - err = f() - if err == nil { + if ctx.Err() != nil { + return fmt.Errorf("stopped retry due to: %w", ctx.Err()) + } else if err = f(); err == nil { return nil } else { log.WithField("count", i).WithError(err).Debug("retrying after error") @@ -30,8 +34,12 @@ func RetryExponentialAttempts(attempts int, sleep time.Duration, f func() error) return err } +func RetryExponentialContext(ctx context.Context, f func() error) error { + return RetryExponentialAttemptsContext(ctx, MaxConnectionRetriesExponential, InitialWaitingDuration, f) +} + func RetryExponentialDuration(sleep time.Duration, f func() error) error { - return RetryExponentialAttempts(MaxConnectionRetriesExponential, sleep, f) + return RetryExponentialAttemptsContext(context.Background(), MaxConnectionRetriesExponential, sleep, f) } func RetryExponential(f func() error) error {