diff --git a/internal/api/runners.go b/internal/api/runners.go
index afeb10a..7579db2 100644
--- a/internal/api/runners.go
+++ b/internal/api/runners.go
@@ -139,7 +139,7 @@ func (r *RunnerController) fileContent(writer http.ResponseWriter, request *http
         privilegedExecution = false
     }
 
-    writer.Header().Set("Content-Type", "application/octet-stream")
+    writer.Header().Set("Content-Disposition", "attachment; filename=\""+path+"\"")
     err = targetRunner.GetFileContent(path, writer, privilegedExecution, request.Context())
     if errors.Is(err, runner.ErrFileNotFound) {
         writeClientError(writer, err, http.StatusFailedDependency)
diff --git a/internal/runner/aws_runner.go b/internal/runner/aws_runner.go
index 6c1e13c..24e9d13 100644
--- a/internal/runner/aws_runner.go
+++ b/internal/runner/aws_runner.go
@@ -13,6 +13,7 @@ import (
     "github.com/openHPI/poseidon/pkg/monitoring"
     "github.com/openHPI/poseidon/pkg/storage"
     "io"
+    "net/http"
     "time"
 )
 
@@ -125,7 +126,7 @@ func (w *AWSFunctionWorkload) UpdateFileSystem(request *dto.UpdateFileSystemRequ
 // GetFileContent is currently not supported with this aws serverless function.
 // This is because the function execution ends with the termination of the workload code.
 // So an on-demand file streaming after the termination is not possible. Also, we do not want to copy all files.
-func (w *AWSFunctionWorkload) GetFileContent(_ string, _ io.Writer, _ bool, _ context.Context) error {
+func (w *AWSFunctionWorkload) GetFileContent(_ string, _ http.ResponseWriter, _ bool, _ context.Context) error {
     return dto.ErrNotSupported
 }
 
diff --git a/internal/runner/nomad_runner.go b/internal/runner/nomad_runner.go
index d2e3deb..fb82e46 100644
--- a/internal/runner/nomad_runner.go
+++ b/internal/runner/nomad_runner.go
@@ -15,6 +15,7 @@ import (
     "github.com/openHPI/poseidon/pkg/nullio"
     "github.com/openHPI/poseidon/pkg/storage"
     "io"
+    "net/http"
     "strings"
     "time"
 )
@@ -27,6 +28,9 @@ const (
     // executionTimeoutGracePeriod is the time to wait after sending a SIGQUIT signal to a timed out execution.
     // If the execution does not return after this grace period, the runner is destroyed.
     executionTimeoutGracePeriod = 3 * time.Second
+    // lsCommand is our format for parsing information of a file(system).
+    lsCommand          = "ls -l --time-style=+%s -1 --literal"
+    lsCommandRecursive = lsCommand + " --recursive"
 )
 
 var (
@@ -121,9 +125,9 @@ func (r *NomadJob) ExecuteInteractively(
 func (r *NomadJob) ListFileSystem(
     path string, recursive bool, content io.Writer, privilegedExecution bool, ctx context.Context) error {
     r.ResetTimeout()
-    command := "ls -l --time-style=+%s -1 --literal"
+    command := lsCommand
     if recursive {
-        command += " --recursive"
+        command = lsCommandRecursive
     }
 
     ls2json := &nullio.Ls2JsonWriter{Target: content}
@@ -174,13 +178,17 @@ func (r *NomadJob) UpdateFileSystem(copyRequest *dto.UpdateFileSystemRequest) er
     return nil
 }
 
-func (r *NomadJob) GetFileContent(path string, content io.Writer, privilegedExecution bool, ctx context.Context) error {
+func (r *NomadJob) GetFileContent(
+    path string, content http.ResponseWriter, privilegedExecution bool, ctx context.Context) error {
     r.ResetTimeout()
 
-    retrieveCommand := (&dto.ExecutionRequest{Command: fmt.Sprintf("cat %q", path)}).FullCommand()
+    contentLengthWriter := &nullio.ContentLengthWriter{Target: content}
+    retrieveCommand := (&dto.ExecutionRequest{
+        Command: fmt.Sprintf("%s %q && cat %q", lsCommand, path, path),
+    }).FullCommand()
     // Improve: Instead of using io.Discard use a **fixed-sized** buffer. With that we could improve the error message.
     exitCode, err := r.api.ExecuteCommand(r.id, ctx, retrieveCommand, false, privilegedExecution,
-        &nullio.Reader{}, content, io.Discard)
+        &nullio.Reader{}, contentLengthWriter, io.Discard)
     if err != nil {
         return fmt.Errorf("%w: nomad error during retrieve file content copy: %v",
diff --git a/internal/runner/nomad_runner_test.go b/internal/runner/nomad_runner_test.go
index 3571b41..d5dfa13 100644
--- a/internal/runner/nomad_runner_test.go
+++ b/internal/runner/nomad_runner_test.go
@@ -8,6 +8,7 @@ import (
     "fmt"
     "github.com/openHPI/poseidon/internal/nomad"
     "github.com/openHPI/poseidon/pkg/dto"
+    "github.com/openHPI/poseidon/pkg/logging"
     "github.com/openHPI/poseidon/pkg/nullio"
     "github.com/openHPI/poseidon/pkg/storage"
     "github.com/openHPI/poseidon/tests"
@@ -403,6 +404,6 @@ func NewRunner(id string, manager Accessor) Runner {
 func (s *UpdateFileSystemTestSuite) TestGetFileContentReturnsErrorIfExitCodeIsNotZero() {
     s.mockedExecuteCommandCall.RunFn = nil
     s.mockedExecuteCommandCall.Return(1, nil)
-    err := s.runner.GetFileContent("", &bytes.Buffer{}, false, context.Background())
+    err := s.runner.GetFileContent("", logging.NewLoggingResponseWriter(nil), false, context.Background())
     s.ErrorIs(err, ErrFileNotFound)
 }
diff --git a/internal/runner/runner.go b/internal/runner/runner.go
index e177f8d..cbced12 100644
--- a/internal/runner/runner.go
+++ b/internal/runner/runner.go
@@ -7,6 +7,7 @@ import (
     "github.com/openHPI/poseidon/pkg/monitoring"
     "github.com/openHPI/poseidon/pkg/storage"
     "io"
+    "net/http"
 )
 
 type ExitInfo struct {
@@ -54,7 +55,7 @@
 
     // GetFileContent streams the file content at the requested path into the Writer provided at content.
     // The result is streamed via the io.Writer in order to not overload the memory with user input.
-    GetFileContent(path string, content io.Writer, privilegedExecution bool, ctx context.Context) error
+    GetFileContent(path string, content http.ResponseWriter, privilegedExecution bool, ctx context.Context) error
 
     // Destroy destroys the Runner in Nomad.
     Destroy() error
diff --git a/internal/runner/runner_mock.go b/internal/runner/runner_mock.go
index 6a91d3b..080da63 100644
--- a/internal/runner/runner_mock.go
+++ b/internal/runner/runner_mock.go
@@ -4,10 +4,12 @@ package runner
 
 import (
     context "context"
-    io "io"
+    http "net/http"
 
     dto "github.com/openHPI/poseidon/pkg/dto"
 
+    io "io"
+
     mock "github.com/stretchr/testify/mock"
 
     time "time"
@@ -93,11 +95,11 @@ func (_m *RunnerMock) ExecutionExists(id string) bool {
 }
 
 // GetFileContent provides a mock function with given fields: path, content, privilegedExecution, ctx
-func (_m *RunnerMock) GetFileContent(path string, content io.Writer, privilegedExecution bool, ctx context.Context) error {
+func (_m *RunnerMock) GetFileContent(path string, content http.ResponseWriter, privilegedExecution bool, ctx context.Context) error {
     ret := _m.Called(path, content, privilegedExecution, ctx)
 
     var r0 error
-    if rf, ok := ret.Get(0).(func(string, io.Writer, bool, context.Context) error); ok {
+    if rf, ok := ret.Get(0).(func(string, http.ResponseWriter, bool, context.Context) error); ok {
         r0 = rf(path, content, privilegedExecution, ctx)
     } else {
         r0 = ret.Error(0)
diff --git a/pkg/nullio/content_length.go b/pkg/nullio/content_length.go
new file mode 100644
index 0000000..2877c60
--- /dev/null
+++ b/pkg/nullio/content_length.go
@@ -0,0 +1,59 @@
+package nullio
+
+import (
+    "errors"
+    "fmt"
+    "net/http"
+)
+
+var ErrRegexMatching = errors.New("could not match content length")
+
+// ContentLengthWriter implements io.Writer.
+// It parses the size from the first line as Content Length Header and streams the following data to the Target.
+// The first line is expected to follow the format headerLineRegex.
+type ContentLengthWriter struct {
+    Target           http.ResponseWriter
+    contentLengthSet bool
+    firstLine        []byte
+}
+
+func (w *ContentLengthWriter) Write(p []byte) (count int, err error) {
+    if w.contentLengthSet {
+        count, err = w.Target.Write(p)
+        if err != nil {
+            err = fmt.Errorf("could not write to target: %w", err)
+        }
+        return count, err
+    }
+
+    for i, char := range p {
+        if char != '\n' {
+            continue
+        }
+
+        w.firstLine = append(w.firstLine, p[:i]...)
+        matches := headerLineRegex.FindSubmatch(w.firstLine)
+        if len(matches) < headerLineGroupName {
+            log.WithField("line", string(w.firstLine)).Error(ErrRegexMatching.Error())
+            return 0, ErrRegexMatching
+        }
+        size := string(matches[headerLineGroupSize])
+        w.Target.Header().Set("Content-Length", size)
+        w.contentLengthSet = true
+
+        if i < len(p)-1 {
+            count, err = w.Target.Write(p[i+1:])
+            if err != nil {
+                err = fmt.Errorf("could not write to target: %w", err)
+            }
+        }
+
+        return len(p[:i]) + 1 + count, err
+    }
+
+    if !w.contentLengthSet {
+        w.firstLine = append(w.firstLine, p...)
+    }
+
+    return len(p), nil
+}
diff --git a/pkg/nullio/content_length_test.go b/pkg/nullio/content_length_test.go
new file mode 100644
index 0000000..9b247fd
--- /dev/null
+++ b/pkg/nullio/content_length_test.go
@@ -0,0 +1,48 @@
+package nullio
+
+import (
+    "bytes"
+    "github.com/stretchr/testify/assert"
+    "github.com/stretchr/testify/require"
+    "net/http"
+    "testing"
+)
+
+type responseWriterStub struct {
+    bytes.Buffer
+    header http.Header
+}
+
+func (r *responseWriterStub) Header() http.Header {
+    return r.header
+}
+func (r *responseWriterStub) WriteHeader(_ int) {
+}
+
+func TestContentLengthWriter_Write(t *testing.T) {
+    header := http.Header(make(map[string][]string))
+    buf := &responseWriterStub{header: header}
+    writer := &ContentLengthWriter{Target: buf}
+    part1 := []byte("-rw-rw-r-- 1 kali ka")
+    contentLength := "42"
+    part2 := []byte("li " + contentLength + " 1660763446 flag\nFL")
+    part3 := []byte("AG")
+
+    count, err := writer.Write(part1)
+    require.NoError(t, err)
+    assert.Equal(t, len(part1), count)
+    assert.Empty(t, buf.String())
+    assert.Equal(t, "", header.Get("Content-Length"))
+
+    count, err = writer.Write(part2)
+    require.NoError(t, err)
+    assert.Equal(t, len(part2), count)
+    assert.Equal(t, "FL", buf.String())
+    assert.Equal(t, contentLength, header.Get("Content-Length"))
+
+    count, err = writer.Write(part3)
+    require.NoError(t, err)
+    assert.Equal(t, len(part3), count)
+    assert.Equal(t, "FLAG", buf.String())
+    assert.Equal(t, contentLength, header.Get("Content-Length"))
+}
diff --git a/pkg/nullio/ls2json.go b/pkg/nullio/ls2json.go
index 8a30ab4..13fd236 100644
--- a/pkg/nullio/ls2json.go
+++ b/pkg/nullio/ls2json.go
@@ -18,11 +18,22 @@ var (
     headerLineRegex = regexp.MustCompile(`([-aAbcCdDlMnpPsw?])([-rwxXsStT]{9})([+ ])\d+ (.+?) (.+?) +(\d+) (\d+) (.*)$`)
 )
 
+const (
+    headerLineGroupEntryType   = 1
+    headerLineGroupPermissions = 2
+    headerLineGroupACL         = 3
+    headerLineGroupOwner       = 4
+    headerLineGroupGroup       = 5
+    headerLineGroupSize        = 6
+    headerLineGroupTimestamp   = 7
+    headerLineGroupName        = 8
+)
+
 // Ls2JsonWriter implements io.Writer.
 // It streams the passed data to the Target and transforms the data into the json format.
 type Ls2JsonWriter struct {
     Target         io.Writer
-    jsonStartSend  bool
+    jsonStartSent  bool
     setCommaPrefix bool
     remaining      []byte
     latestPath     []byte
@@ -64,20 +75,20 @@ func (w *Ls2JsonWriter) Write(p []byte) (int, error) {
 }
 
 func (w *Ls2JsonWriter) initializeJSONObject() (count int, err error) {
-    if !w.jsonStartSend {
+    if !w.jsonStartSent {
         count, err = w.Target.Write([]byte("{\"files\": ["))
         if count == 0 || err != nil {
             log.WithError(err).Warn("Could not write to target")
             err = fmt.Errorf("could not write to target: %w", err)
         } else {
-            w.jsonStartSend = true
+            w.jsonStartSent = true
         }
     }
     return count, err
 }
 
 func (w *Ls2JsonWriter) Close() {
-    if w.jsonStartSend {
+    if w.jsonStartSent {
         count, err := w.Target.Write([]byte("]}"))
         if count == 0 || err != nil {
             log.WithError(err).Warn("Could not Close ls2json writer")
@@ -118,22 +129,20 @@ func (w *Ls2JsonWriter) writeLine(line []byte) (count int, err error) {
 }
 
 func (w *Ls2JsonWriter) parseFileHeader(matches [][]byte) ([]byte, error) {
-    entryType := dto.EntryType(matches[1][0])
-    permissions := string(matches[2])
-    acl := string(matches[3])
+    entryType := dto.EntryType(matches[headerLineGroupEntryType][0])
+    permissions := string(matches[headerLineGroupPermissions])
+    acl := string(matches[headerLineGroupACL])
     if acl == "+" {
         permissions += "+"
     }
 
-    owner := string(matches[4])
-    group := string(matches[5])
-    size, err1 := strconv.Atoi(string(matches[6]))
-    timestamp, err2 := strconv.Atoi(string(matches[7]))
+    size, err1 := strconv.Atoi(string(matches[headerLineGroupSize]))
+    timestamp, err2 := strconv.Atoi(string(matches[headerLineGroupTimestamp]))
     if err1 != nil || err2 != nil {
         return nil, fmt.Errorf("could not parse file details: %w %+v", err1, err2)
     }
 
-    name := dto.FilePath(append(w.latestPath, matches[8]...))
+    name := dto.FilePath(append(w.latestPath, matches[headerLineGroupName]...))
     linkTarget := dto.FilePath("")
     if entryType == dto.EntryTypeLink {
         parts := strings.Split(string(name), " -> ")
@@ -153,8 +162,8 @@
         Size:             size,
         ModificationTime: timestamp,
         Permissions:      permissions,
-        Owner:            owner,
-        Group:            group,
+        Owner:            string(matches[headerLineGroupOwner]),
+        Group:            string(matches[headerLineGroupGroup]),
     })
     if err != nil {
         return nil, fmt.Errorf("could not marshal file header: %w", err)
diff --git a/tests/e2e/runners_test.go b/tests/e2e/runners_test.go
index 9c6cfcf..36e59a0 100644
--- a/tests/e2e/runners_test.go
+++ b/tests/e2e/runners_test.go
@@ -13,6 +13,7 @@ import (
     "io"
     "net/http"
     "net/url"
+    "strconv"
     "strings"
     "time"
 )
@@ -156,8 +157,9 @@ func (s *E2ETestSuite) TestListFileSystem_Nomad() {
         fileHeader := listFilesResponse.Files[0]
         s.Equal(dto.FilePath("./"+tests.DefaultFileName), fileHeader.Name)
         s.Equal(dto.EntryTypeRegularFile, fileHeader.EntryType)
-        s.Equal("user", fileHeader.Owner)
-        s.Equal("user", fileHeader.Group)
+        // ToDo: Reconsider if those files should be owned by root.
+        s.Equal("root", fileHeader.Owner)
+        s.Equal("root", fileHeader.Group)
         s.Equal("rwxr--r--", fileHeader.Permissions)
     })
 }
@@ -352,6 +354,8 @@ func (s *E2ETestSuite) TestGetFileContent_Nomad() {
         response, err := http.Get(getFileURL.String())
         s.Require().NoError(err)
         s.Equal(http.StatusOK, response.StatusCode)
+        s.Equal(strconv.Itoa(len(newFileContent)), response.Header.Get("Content-Length"))
+        s.Equal("attachment; filename=\""+tests.DefaultFileName+"\"", response.Header.Get("Content-Disposition"))
         content, err := io.ReadAll(response.Body)
         s.Require().NoError(err)
         s.Equal(newFileContent, content)