From 09604997a71f02d74ec73bbbb5bb40074307bf1e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Maximilian=20Pa=C3=9F?= <22845248+mpass99@users.noreply.github.com> Date: Mon, 21 Aug 2023 17:16:25 +0200 Subject: [PATCH] Implement MergeContext that has multiple contexts as parent and chooses the earliest deadline. --- pkg/util/merge_context.go | 67 +++++++++++++++++++++++++++++++++ pkg/util/merge_context_test.go | 69 ++++++++++++++++++++++++++++++++++ 2 files changed, 136 insertions(+) create mode 100644 pkg/util/merge_context.go create mode 100644 pkg/util/merge_context_test.go diff --git a/pkg/util/merge_context.go b/pkg/util/merge_context.go new file mode 100644 index 0000000..9a44ec1 --- /dev/null +++ b/pkg/util/merge_context.go @@ -0,0 +1,67 @@ +package util + +import ( + "context" + "fmt" + "reflect" + "time" +) + +// mergeContext combines multiple contexts. +type mergeContext struct { + contexts []context.Context +} + +func NewMergeContext(contexts []context.Context) context.Context { + return mergeContext{contexts: contexts} +} + +// Deadline returns the earliest Deadline of all contexts. +func (m mergeContext) Deadline() (deadline time.Time, ok bool) { + for _, ctx := range m.contexts { + if anotherDeadline, anotherOk := ctx.Deadline(); anotherOk { + if ok && anotherDeadline.After(deadline) { + continue + } + deadline = anotherDeadline + ok = anotherOk + } + } + return deadline, ok +} + +// Done notifies when the first context is done. +func (m mergeContext) Done() <-chan struct{} { + ch := make(chan struct{}) + cases := make([]reflect.SelectCase, 0, len(m.contexts)) + for _, ctx := range m.contexts { + cases = append(cases, reflect.SelectCase{Dir: reflect.SelectRecv, Chan: reflect.ValueOf(ctx.Done())}) + } + go func(cases []reflect.SelectCase, ch chan struct{}) { + _, _, _ = reflect.Select(cases) + ch <- struct{}{} + close(ch) + }(cases, ch) + return ch +} + +// Err returns the error of any (random) context and nil iff no context has an error. +func (m mergeContext) Err() error { + for _, ctx := range m.contexts { + if ctx.Err() != nil { + return fmt.Errorf("mergeContext wrapped: %w", ctx.Err()) + } + } + return nil +} + +// Value returns the value for the key if any context has it. +// If multiple contexts have a value for the key, the result is any (random) of them. +func (m mergeContext) Value(key any) any { + for _, ctx := range m.contexts { + if value := ctx.Value(key); value != nil { + return value + } + } + return nil +} diff --git a/pkg/util/merge_context_test.go b/pkg/util/merge_context_test.go new file mode 100644 index 0000000..3dabf34 --- /dev/null +++ b/pkg/util/merge_context_test.go @@ -0,0 +1,69 @@ +package util + +import ( + "context" + "github.com/openHPI/poseidon/pkg/dto" + "github.com/openHPI/poseidon/tests" + "github.com/stretchr/testify/assert" + "testing" + "time" +) + +func TestMergeContext_Deadline(t *testing.T) { + ctxWithoutDeadline := context.Background() + earlyDeadline := time.Now().Add(time.Second) + ctxWithEarlyDeadline, cancel := context.WithDeadline(context.Background(), earlyDeadline) + defer cancel() + ctxWithLateDeadline, cancel := context.WithDeadline(context.Background(), time.Now().Add(time.Hour)) + defer cancel() + + ctx := NewMergeContext([]context.Context{ctxWithoutDeadline, ctxWithEarlyDeadline, ctxWithLateDeadline}) + deadline, ok := ctx.Deadline() + + assert.True(t, ok) + assert.Equal(t, earlyDeadline, deadline, "The ealiest deadline is returned") +} + +func TestMergeContext_Done(t *testing.T) { + ctxWithoutDeadline := context.Background() + ctxWithEarlyDeadline, cancel := context.WithTimeout(context.Background(), 2*tests.ShortTimeout) + defer cancel() + ctxWithLateDeadline, cancel := context.WithTimeout(context.Background(), time.Minute) + defer cancel() + + ctx := NewMergeContext([]context.Context{ctxWithoutDeadline, ctxWithEarlyDeadline, ctxWithLateDeadline}) + + select { + case <-ctx.Done(): + assert.Fail(t, "mergeContext is done before any of its parents") + return + case <-time.After(tests.ShortTimeout): + } + + select { + case <-ctx.Done(): + case <-time.After(3 * tests.ShortTimeout): + assert.Fail(t, "mergeContext is not done after the earliest of its parents") + return + } +} + +func TestMergeContext_Err(t *testing.T) { + ctxWithoutDeadline := context.Background() + ctxCancelled, cancel := context.WithCancel(context.Background()) + ctx := NewMergeContext([]context.Context{ctxWithoutDeadline, ctxCancelled}) + + assert.NoError(t, ctx.Err()) + cancel() + assert.Error(t, ctx.Err()) +} + +func TestMergeContext_Value(t *testing.T) { + ctxWithAValue := context.WithValue(context.Background(), dto.ContextKey("keyA"), "valueA") + ctxWithAnotherValue := context.WithValue(context.Background(), dto.ContextKey("keyB"), "valueB") + ctx := NewMergeContext([]context.Context{ctxWithAValue, ctxWithAnotherValue}) + + assert.Equal(t, "valueA", ctx.Value(dto.ContextKey("keyA"))) + assert.Equal(t, "valueB", ctx.Value(dto.ContextKey("keyB"))) + assert.Nil(t, ctx.Value("keyC")) +}