Add Content-Length and Content-Disposition Header

for GetFileContent route.
This commit is contained in:
Maximilian Paß
2022-09-28 21:49:35 +01:00
parent 0c70ad3b24
commit 195f88177e
10 changed files with 161 additions and 28 deletions

View File

@ -139,7 +139,7 @@ func (r *RunnerController) fileContent(writer http.ResponseWriter, request *http
privilegedExecution = false privilegedExecution = false
} }
writer.Header().Set("Content-Type", "application/octet-stream") writer.Header().Set("Content-Disposition", "attachment; filename=\""+path+"\"")
err = targetRunner.GetFileContent(path, writer, privilegedExecution, request.Context()) err = targetRunner.GetFileContent(path, writer, privilegedExecution, request.Context())
if errors.Is(err, runner.ErrFileNotFound) { if errors.Is(err, runner.ErrFileNotFound) {
writeClientError(writer, err, http.StatusFailedDependency) writeClientError(writer, err, http.StatusFailedDependency)

View File

@ -13,6 +13,7 @@ import (
"github.com/openHPI/poseidon/pkg/monitoring" "github.com/openHPI/poseidon/pkg/monitoring"
"github.com/openHPI/poseidon/pkg/storage" "github.com/openHPI/poseidon/pkg/storage"
"io" "io"
"net/http"
"time" "time"
) )
@ -125,7 +126,7 @@ func (w *AWSFunctionWorkload) UpdateFileSystem(request *dto.UpdateFileSystemRequ
// GetFileContent is currently not supported with this aws serverless function. // GetFileContent is currently not supported with this aws serverless function.
// This is because the function execution ends with the termination of the workload code. // This is because the function execution ends with the termination of the workload code.
// So an on-demand file streaming after the termination is not possible. Also, we do not want to copy all files. // So an on-demand file streaming after the termination is not possible. Also, we do not want to copy all files.
func (w *AWSFunctionWorkload) GetFileContent(_ string, _ io.Writer, _ bool, _ context.Context) error { func (w *AWSFunctionWorkload) GetFileContent(_ string, _ http.ResponseWriter, _ bool, _ context.Context) error {
return dto.ErrNotSupported return dto.ErrNotSupported
} }

View File

@ -15,6 +15,7 @@ import (
"github.com/openHPI/poseidon/pkg/nullio" "github.com/openHPI/poseidon/pkg/nullio"
"github.com/openHPI/poseidon/pkg/storage" "github.com/openHPI/poseidon/pkg/storage"
"io" "io"
"net/http"
"strings" "strings"
"time" "time"
) )
@ -27,6 +28,9 @@ const (
// executionTimeoutGracePeriod is the time to wait after sending a SIGQUIT signal to a timed out execution. // executionTimeoutGracePeriod is the time to wait after sending a SIGQUIT signal to a timed out execution.
// If the execution does not return after this grace period, the runner is destroyed. // If the execution does not return after this grace period, the runner is destroyed.
executionTimeoutGracePeriod = 3 * time.Second executionTimeoutGracePeriod = 3 * time.Second
// lsCommand is our format for parsing information of a file(system).
lsCommand = "ls -l --time-style=+%s -1 --literal"
lsCommandRecursive = lsCommand + " --recursive"
) )
var ( var (
@ -121,9 +125,9 @@ func (r *NomadJob) ExecuteInteractively(
func (r *NomadJob) ListFileSystem( func (r *NomadJob) ListFileSystem(
path string, recursive bool, content io.Writer, privilegedExecution bool, ctx context.Context) error { path string, recursive bool, content io.Writer, privilegedExecution bool, ctx context.Context) error {
r.ResetTimeout() r.ResetTimeout()
command := "ls -l --time-style=+%s -1 --literal" command := lsCommand
if recursive { if recursive {
command += " --recursive" command = lsCommandRecursive
} }
ls2json := &nullio.Ls2JsonWriter{Target: content} ls2json := &nullio.Ls2JsonWriter{Target: content}
@ -174,13 +178,17 @@ func (r *NomadJob) UpdateFileSystem(copyRequest *dto.UpdateFileSystemRequest) er
return nil return nil
} }
func (r *NomadJob) GetFileContent(path string, content io.Writer, privilegedExecution bool, ctx context.Context) error { func (r *NomadJob) GetFileContent(
path string, content http.ResponseWriter, privilegedExecution bool, ctx context.Context) error {
r.ResetTimeout() r.ResetTimeout()
retrieveCommand := (&dto.ExecutionRequest{Command: fmt.Sprintf("cat %q", path)}).FullCommand() contentLengthWriter := &nullio.ContentLengthWriter{Target: content}
retrieveCommand := (&dto.ExecutionRequest{
Command: fmt.Sprintf("%s %q && cat %q", lsCommand, path, path),
}).FullCommand()
// Improve: Instead of using io.Discard use a **fixed-sized** buffer. With that we could improve the error message. // Improve: Instead of using io.Discard use a **fixed-sized** buffer. With that we could improve the error message.
exitCode, err := r.api.ExecuteCommand(r.id, ctx, retrieveCommand, false, privilegedExecution, exitCode, err := r.api.ExecuteCommand(r.id, ctx, retrieveCommand, false, privilegedExecution,
&nullio.Reader{}, content, io.Discard) &nullio.Reader{}, contentLengthWriter, io.Discard)
if err != nil { if err != nil {
return fmt.Errorf("%w: nomad error during retrieve file content copy: %v", return fmt.Errorf("%w: nomad error during retrieve file content copy: %v",

View File

@ -8,6 +8,7 @@ import (
"fmt" "fmt"
"github.com/openHPI/poseidon/internal/nomad" "github.com/openHPI/poseidon/internal/nomad"
"github.com/openHPI/poseidon/pkg/dto" "github.com/openHPI/poseidon/pkg/dto"
"github.com/openHPI/poseidon/pkg/logging"
"github.com/openHPI/poseidon/pkg/nullio" "github.com/openHPI/poseidon/pkg/nullio"
"github.com/openHPI/poseidon/pkg/storage" "github.com/openHPI/poseidon/pkg/storage"
"github.com/openHPI/poseidon/tests" "github.com/openHPI/poseidon/tests"
@ -403,6 +404,6 @@ func NewRunner(id string, manager Accessor) Runner {
func (s *UpdateFileSystemTestSuite) TestGetFileContentReturnsErrorIfExitCodeIsNotZero() { func (s *UpdateFileSystemTestSuite) TestGetFileContentReturnsErrorIfExitCodeIsNotZero() {
s.mockedExecuteCommandCall.RunFn = nil s.mockedExecuteCommandCall.RunFn = nil
s.mockedExecuteCommandCall.Return(1, nil) s.mockedExecuteCommandCall.Return(1, nil)
err := s.runner.GetFileContent("", &bytes.Buffer{}, false, context.Background()) err := s.runner.GetFileContent("", logging.NewLoggingResponseWriter(nil), false, context.Background())
s.ErrorIs(err, ErrFileNotFound) s.ErrorIs(err, ErrFileNotFound)
} }

View File

@ -7,6 +7,7 @@ import (
"github.com/openHPI/poseidon/pkg/monitoring" "github.com/openHPI/poseidon/pkg/monitoring"
"github.com/openHPI/poseidon/pkg/storage" "github.com/openHPI/poseidon/pkg/storage"
"io" "io"
"net/http"
) )
type ExitInfo struct { type ExitInfo struct {
@ -54,7 +55,7 @@ type Runner interface {
// GetFileContent streams the file content at the requested path into the Writer provided at content. // GetFileContent streams the file content at the requested path into the Writer provided at content.
// The result is streamed via the io.Writer in order to not overload the memory with user input. // The result is streamed via the io.Writer in order to not overload the memory with user input.
GetFileContent(path string, content io.Writer, privilegedExecution bool, ctx context.Context) error GetFileContent(path string, content http.ResponseWriter, privilegedExecution bool, ctx context.Context) error
// Destroy destroys the Runner in Nomad. // Destroy destroys the Runner in Nomad.
Destroy() error Destroy() error

View File

@ -4,10 +4,12 @@ package runner
import ( import (
context "context" context "context"
io "io" http "net/http"
dto "github.com/openHPI/poseidon/pkg/dto" dto "github.com/openHPI/poseidon/pkg/dto"
io "io"
mock "github.com/stretchr/testify/mock" mock "github.com/stretchr/testify/mock"
time "time" time "time"
@ -93,11 +95,11 @@ func (_m *RunnerMock) ExecutionExists(id string) bool {
} }
// GetFileContent provides a mock function with given fields: path, content, privilegedExecution, ctx // GetFileContent provides a mock function with given fields: path, content, privilegedExecution, ctx
func (_m *RunnerMock) GetFileContent(path string, content io.Writer, privilegedExecution bool, ctx context.Context) error { func (_m *RunnerMock) GetFileContent(path string, content http.ResponseWriter, privilegedExecution bool, ctx context.Context) error {
ret := _m.Called(path, content, privilegedExecution, ctx) ret := _m.Called(path, content, privilegedExecution, ctx)
var r0 error var r0 error
if rf, ok := ret.Get(0).(func(string, io.Writer, bool, context.Context) error); ok { if rf, ok := ret.Get(0).(func(string, http.ResponseWriter, bool, context.Context) error); ok {
r0 = rf(path, content, privilegedExecution, ctx) r0 = rf(path, content, privilegedExecution, ctx)
} else { } else {
r0 = ret.Error(0) r0 = ret.Error(0)

View File

@ -0,0 +1,59 @@
package nullio
import (
"errors"
"fmt"
"net/http"
)
var ErrRegexMatching = errors.New("could not match content length")
// ContentLengthWriter implements io.Writer.
// It parses the size from the first line as Content Length Header and streams the following data to the Target.
// The first line is expected to follow the format headerLineRegex.
type ContentLengthWriter struct {
Target http.ResponseWriter
contentLengthSet bool
firstLine []byte
}
func (w *ContentLengthWriter) Write(p []byte) (count int, err error) {
if w.contentLengthSet {
count, err = w.Target.Write(p)
if err != nil {
err = fmt.Errorf("could not write to target: %w", err)
}
return count, err
}
for i, char := range p {
if char != '\n' {
continue
}
w.firstLine = append(w.firstLine, p[:i]...)
matches := headerLineRegex.FindSubmatch(w.firstLine)
if len(matches) < headerLineGroupName {
log.WithField("line", string(w.firstLine)).Error(ErrRegexMatching.Error())
return 0, ErrRegexMatching
}
size := string(matches[headerLineGroupSize])
w.Target.Header().Set("Content-Length", size)
w.contentLengthSet = true
if i < len(p)-1 {
count, err = w.Target.Write(p[i+1:])
if err != nil {
err = fmt.Errorf("could not write to target: %w", err)
}
}
return len(p[:i]) + 1 + count, err
}
if !w.contentLengthSet {
w.firstLine = append(w.firstLine, p...)
}
return len(p), nil
}

View File

@ -0,0 +1,48 @@
package nullio
import (
"bytes"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"net/http"
"testing"
)
type responseWriterStub struct {
bytes.Buffer
header http.Header
}
func (r *responseWriterStub) Header() http.Header {
return r.header
}
func (r *responseWriterStub) WriteHeader(_ int) {
}
func TestContentLengthWriter_Write(t *testing.T) {
header := http.Header(make(map[string][]string))
buf := &responseWriterStub{header: header}
writer := &ContentLengthWriter{Target: buf}
part1 := []byte("-rw-rw-r-- 1 kali ka")
contentLength := "42"
part2 := []byte("li " + contentLength + " 1660763446 flag\nFL")
part3 := []byte("AG")
count, err := writer.Write(part1)
require.NoError(t, err)
assert.Equal(t, len(part1), count)
assert.Empty(t, buf.String())
assert.Equal(t, "", header.Get("Content-Length"))
count, err = writer.Write(part2)
require.NoError(t, err)
assert.Equal(t, len(part2), count)
assert.Equal(t, "FL", buf.String())
assert.Equal(t, contentLength, header.Get("Content-Length"))
count, err = writer.Write(part3)
require.NoError(t, err)
assert.Equal(t, len(part3), count)
assert.Equal(t, "FLAG", buf.String())
assert.Equal(t, contentLength, header.Get("Content-Length"))
}

View File

@ -18,11 +18,22 @@ var (
headerLineRegex = regexp.MustCompile(`([-aAbcCdDlMnpPsw?])([-rwxXsStT]{9})([+ ])\d+ (.+?) (.+?) +(\d+) (\d+) (.*)$`) headerLineRegex = regexp.MustCompile(`([-aAbcCdDlMnpPsw?])([-rwxXsStT]{9})([+ ])\d+ (.+?) (.+?) +(\d+) (\d+) (.*)$`)
) )
const (
headerLineGroupEntryType = 1
headerLineGroupPermissions = 2
headerLineGroupACL = 3
headerLineGroupOwner = 4
headerLineGroupGroup = 5
headerLineGroupSize = 6
headerLineGroupTimestamp = 7
headerLineGroupName = 8
)
// Ls2JsonWriter implements io.Writer. // Ls2JsonWriter implements io.Writer.
// It streams the passed data to the Target and transforms the data into the json format. // It streams the passed data to the Target and transforms the data into the json format.
type Ls2JsonWriter struct { type Ls2JsonWriter struct {
Target io.Writer Target io.Writer
jsonStartSend bool jsonStartSent bool
setCommaPrefix bool setCommaPrefix bool
remaining []byte remaining []byte
latestPath []byte latestPath []byte
@ -64,20 +75,20 @@ func (w *Ls2JsonWriter) Write(p []byte) (int, error) {
} }
func (w *Ls2JsonWriter) initializeJSONObject() (count int, err error) { func (w *Ls2JsonWriter) initializeJSONObject() (count int, err error) {
if !w.jsonStartSend { if !w.jsonStartSent {
count, err = w.Target.Write([]byte("{\"files\": [")) count, err = w.Target.Write([]byte("{\"files\": ["))
if count == 0 || err != nil { if count == 0 || err != nil {
log.WithError(err).Warn("Could not write to target") log.WithError(err).Warn("Could not write to target")
err = fmt.Errorf("could not write to target: %w", err) err = fmt.Errorf("could not write to target: %w", err)
} else { } else {
w.jsonStartSend = true w.jsonStartSent = true
} }
} }
return count, err return count, err
} }
func (w *Ls2JsonWriter) Close() { func (w *Ls2JsonWriter) Close() {
if w.jsonStartSend { if w.jsonStartSent {
count, err := w.Target.Write([]byte("]}")) count, err := w.Target.Write([]byte("]}"))
if count == 0 || err != nil { if count == 0 || err != nil {
log.WithError(err).Warn("Could not Close ls2json writer") log.WithError(err).Warn("Could not Close ls2json writer")
@ -118,22 +129,20 @@ func (w *Ls2JsonWriter) writeLine(line []byte) (count int, err error) {
} }
func (w *Ls2JsonWriter) parseFileHeader(matches [][]byte) ([]byte, error) { func (w *Ls2JsonWriter) parseFileHeader(matches [][]byte) ([]byte, error) {
entryType := dto.EntryType(matches[1][0]) entryType := dto.EntryType(matches[headerLineGroupEntryType][0])
permissions := string(matches[2]) permissions := string(matches[headerLineGroupPermissions])
acl := string(matches[3]) acl := string(matches[headerLineGroupACL])
if acl == "+" { if acl == "+" {
permissions += "+" permissions += "+"
} }
owner := string(matches[4]) size, err1 := strconv.Atoi(string(matches[headerLineGroupSize]))
group := string(matches[5]) timestamp, err2 := strconv.Atoi(string(matches[headerLineGroupTimestamp]))
size, err1 := strconv.Atoi(string(matches[6]))
timestamp, err2 := strconv.Atoi(string(matches[7]))
if err1 != nil || err2 != nil { if err1 != nil || err2 != nil {
return nil, fmt.Errorf("could not parse file details: %w %+v", err1, err2) return nil, fmt.Errorf("could not parse file details: %w %+v", err1, err2)
} }
name := dto.FilePath(append(w.latestPath, matches[8]...)) name := dto.FilePath(append(w.latestPath, matches[headerLineGroupName]...))
linkTarget := dto.FilePath("") linkTarget := dto.FilePath("")
if entryType == dto.EntryTypeLink { if entryType == dto.EntryTypeLink {
parts := strings.Split(string(name), " -> ") parts := strings.Split(string(name), " -> ")
@ -153,8 +162,8 @@ func (w *Ls2JsonWriter) parseFileHeader(matches [][]byte) ([]byte, error) {
Size: size, Size: size,
ModificationTime: timestamp, ModificationTime: timestamp,
Permissions: permissions, Permissions: permissions,
Owner: owner, Owner: string(matches[headerLineGroupOwner]),
Group: group, Group: string(matches[headerLineGroupGroup]),
}) })
if err != nil { if err != nil {
return nil, fmt.Errorf("could not marshal file header: %w", err) return nil, fmt.Errorf("could not marshal file header: %w", err)

View File

@ -13,6 +13,7 @@ import (
"io" "io"
"net/http" "net/http"
"net/url" "net/url"
"strconv"
"strings" "strings"
"time" "time"
) )
@ -156,8 +157,9 @@ func (s *E2ETestSuite) TestListFileSystem_Nomad() {
fileHeader := listFilesResponse.Files[0] fileHeader := listFilesResponse.Files[0]
s.Equal(dto.FilePath("./"+tests.DefaultFileName), fileHeader.Name) s.Equal(dto.FilePath("./"+tests.DefaultFileName), fileHeader.Name)
s.Equal(dto.EntryTypeRegularFile, fileHeader.EntryType) s.Equal(dto.EntryTypeRegularFile, fileHeader.EntryType)
s.Equal("user", fileHeader.Owner) // ToDo: Reconsider if those files should be owned by root.
s.Equal("user", fileHeader.Group) s.Equal("root", fileHeader.Owner)
s.Equal("root", fileHeader.Group)
s.Equal("rwxr--r--", fileHeader.Permissions) s.Equal("rwxr--r--", fileHeader.Permissions)
}) })
} }
@ -352,6 +354,8 @@ func (s *E2ETestSuite) TestGetFileContent_Nomad() {
response, err := http.Get(getFileURL.String()) response, err := http.Get(getFileURL.String())
s.Require().NoError(err) s.Require().NoError(err)
s.Equal(http.StatusOK, response.StatusCode) s.Equal(http.StatusOK, response.StatusCode)
s.Equal(strconv.Itoa(len(newFileContent)), response.Header.Get("Content-Length"))
s.Equal("attachment; filename=\""+tests.DefaultFileName+"\"", response.Header.Get("Content-Disposition"))
content, err := io.ReadAll(response.Body) content, err := io.ReadAll(response.Body)
s.Require().NoError(err) s.Require().NoError(err)
s.Equal(newFileContent, content) s.Equal(newFileContent, content)