diff --git a/internal/api/runners.go b/internal/api/runners.go index 9ed4a6e..ca411e6 100644 --- a/internal/api/runners.go +++ b/internal/api/runners.go @@ -8,6 +8,7 @@ import ( "github.com/openHPI/poseidon/internal/config" "github.com/openHPI/poseidon/internal/runner" "github.com/openHPI/poseidon/pkg/dto" + "io" "net/http" "net/url" ) @@ -131,6 +132,12 @@ func (r *RunnerController) findRunnerMiddleware(next http.Handler) http.Handler runnerID := mux.Vars(request)[RunnerIDKey] targetRunner, err := r.manager.Get(runnerID) if err != nil { + // We discard the request body because an early write causes errors for some clients. + // See https://github.com/openHPI/poseidon/issues/54 + _, readErr := io.ReadAll(request.Body) + if readErr != nil { + log.WithError(readErr).Warn("Failed to discard the request body") + } writeNotFound(writer, err) return } diff --git a/internal/api/runners_test.go b/internal/api/runners_test.go index 2f6c5f0..5e06a26 100644 --- a/internal/api/runners_test.go +++ b/internal/api/runners_test.go @@ -17,6 +17,8 @@ import ( "testing" ) +const invalidID = "some-invalid-runner-id" + type MiddlewareTestSuite struct { suite.Suite manager *runner.ManagerMock @@ -67,7 +69,6 @@ func (s *MiddlewareTestSuite) TestFindRunnerMiddlewareIfRunnerExists() { } func (s *MiddlewareTestSuite) TestFindRunnerMiddlewareIfRunnerDoesNotExist() { - invalidID := "some-invalid-runner-id" s.manager.On("Get", invalidID).Return(nil, runner.ErrRunnerNotFound) recorder := httptest.NewRecorder() @@ -76,6 +77,22 @@ func (s *MiddlewareTestSuite) TestFindRunnerMiddlewareIfRunnerDoesNotExist() { s.Equal(http.StatusNotFound, recorder.Code) } +func (s *MiddlewareTestSuite) TestFindRunnerMiddlewareDoesNotEarlyRespond() { + body := strings.NewReader(strings.Repeat("A", 798968)) + + path, err := s.router.Get("test-runner-id").URL(RunnerIDKey, invalidID) + s.Require().NoError(err) + request, err := http.NewRequest(http.MethodPost, path.String(), body) + s.Require().NoError(err) + + s.manager.On("Get", mock.AnythingOfType("string")).Return(nil, runner.ErrRunnerNotFound) + recorder := httptest.NewRecorder() + s.router.ServeHTTP(recorder, request) + + s.Equal(http.StatusNotFound, recorder.Code) + s.Equal(0, body.Len()) // No data should be unread +} + func TestRunnerRouteTestSuite(t *testing.T) { suite.Run(t, new(RunnerRouteTestSuite)) } @@ -261,7 +278,6 @@ func (s *UpdateFileSystemRouteTestSuite) TestUpdateFileSystemReturnsBadRequestOn } func (s *UpdateFileSystemRouteTestSuite) TestUpdateFileSystemToNonExistingRunnerReturnsNotFound() { - invalidID := "some-invalid-runner-id" s.runnerManager.On("Get", invalidID).Return(nil, runner.ErrRunnerNotFound) path, err := s.router.Get(UpdateFileSystemPath).URL(RunnerIDKey, invalidID) s.Require().NoError(err)