diff --git a/cmd/poseidon/main.go b/cmd/poseidon/main.go index 711e672..d8b3049 100644 --- a/cmd/poseidon/main.go +++ b/cmd/poseidon/main.go @@ -196,16 +196,28 @@ func initServer() *http.Server { // shutdownOnOSSignal listens for a signal from the operating system // When receiving a signal the server shuts down but waits up to 15 seconds to close remaining connections. -func shutdownOnOSSignal(server *http.Server, ctx context.Context) { +func shutdownOnOSSignal(server *http.Server, ctx context.Context, stopProfiling func()) { // wait for SIGINT - signals := make(chan os.Signal, 1) - signal.Notify(signals, syscall.SIGINT, syscall.SIGTERM) + shutdownSignals := make(chan os.Signal, 1) + signal.Notify(shutdownSignals, syscall.SIGINT, syscall.SIGTERM) + + // wait for SIGUSR1 + writeProfileSignal := make(chan os.Signal, 1) + signal.Notify(writeProfileSignal, syscall.SIGUSR1) + select { case <-ctx.Done(): os.Exit(1) - case <-signals: + case <-writeProfileSignal: + log.Info("Received SIGUSR1 ...") + + stopProfiling() + // Continue listening on signals and replace `stopProfiling` with an empty function + shutdownOnOSSignal(server, ctx, func() {}) + case <-shutdownSignals: log.Info("Received SIGINT, shutting down ...") + defer stopProfiling() ctx, cancel := context.WithTimeout(context.Background(), gracefulShutdownWait) defer cancel() if err := server.Shutdown(ctx); err != nil { @@ -225,10 +237,9 @@ func main() { defer cancelInflux() stopProfiling := initProfiling(config.Config.Profiling) - defer stopProfiling() ctx, cancel := context.WithCancel(context.Background()) server := initServer() go runServer(server, cancel) - shutdownOnOSSignal(server, ctx) + shutdownOnOSSignal(server, ctx, stopProfiling) } diff --git a/cmd/poseidon/main_test.go b/cmd/poseidon/main_test.go index d0d46b7..b31f678 100644 --- a/cmd/poseidon/main_test.go +++ b/cmd/poseidon/main_test.go @@ -1,10 +1,15 @@ package main import ( + "context" "github.com/openHPI/poseidon/internal/environment" "github.com/openHPI/poseidon/internal/runner" + "github.com/openHPI/poseidon/tests" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "syscall" "testing" + "time" ) func TestAWSDisabledUsesNomadManager(t *testing.T) { @@ -24,3 +29,19 @@ func TestAWSEnabledWrappesNomadManager(t *testing.T) { assert.NotEqual(t, runnerManager, awsRunnerManager) assert.NotEqual(t, environmentManager, awsEnvironmentManager) } + +func TestShutdownOnOSSignal_Profiling(t *testing.T) { + called := false + + server := initServer() + go shutdownOnOSSignal(server, context.Background(), func() { + called = true + }) + + <-time.After(tests.ShortTimeout) + err := syscall.Kill(syscall.Getpid(), syscall.SIGUSR1) + require.NoError(t, err) + <-time.After(tests.ShortTimeout) + + assert.True(t, called) +}