diff --git a/deploy/aws/java11Exec/src/main/java/poseidon/App.java b/deploy/aws/java11Exec/src/main/java/poseidon/App.java index d0d4ff7..7beb110 100644 --- a/deploy/aws/java11Exec/src/main/java/poseidon/App.java +++ b/deploy/aws/java11Exec/src/main/java/poseidon/App.java @@ -67,9 +67,10 @@ public class App implements RequestHandler\\w*))?(?(?:.*?=.*?)+)?(?: && .*)?$"); + private static final Pattern isMakeCommand = Pattern.compile("^(?.* && )?make(?:\\s+(?\\w*))?(?(?:.*?=.*?)+)?(? && .*)?$"); // This pattern identifies the rules in a makefile. private static final Pattern makeRules = Pattern.compile("(?.*):\\r?\\n(?(?:\\t.+\\r?\\n?)*)"); @@ -131,6 +131,15 @@ class SimpleMakefile { String command = getCommand(ruleArgument); String assignments = makeCommandMatcher.group("assignments"); - return injectAssignments(command, (assignments != null) ? assignments : ""); + command = injectAssignments(command, (assignments != null) ? assignments : ""); + + if (makeCommandMatcher.group("before") != null) { + command = makeCommandMatcher.group("before") + command; + } + if (makeCommandMatcher.group("after") != null) { + command = command + makeCommandMatcher.group("after"); + } + + return command; } } diff --git a/deploy/aws/java11Exec/src/test/java/poseidon/AppTest.java b/deploy/aws/java11Exec/src/test/java/poseidon/AppTest.java index edccba8..12b2b2c 100644 --- a/deploy/aws/java11Exec/src/test/java/poseidon/AppTest.java +++ b/deploy/aws/java11Exec/src/test/java/poseidon/AppTest.java @@ -53,7 +53,7 @@ public class AppTest { @Test public void successfulResponse() { - APIGatewayProxyResponseEvent result = getApiGatewayProxyResponse(RecursiveMathContent); + APIGatewayProxyResponseEvent result = getApiGatewayProxyResponseRecursiveMath(RecursiveMathContent); assertEquals(200, result.getStatusCode().intValue()); } @@ -61,7 +61,7 @@ public class AppTest { @Test public void successfulMultilineResponse() { ByteArrayOutputStream out = setupStdOutLogs(); - APIGatewayProxyResponseEvent result = getApiGatewayProxyResponse(MultilineMathContent); + APIGatewayProxyResponseEvent result = getApiGatewayProxyResponseRecursiveMath(MultilineMathContent); restoreStdOutLogs(); assertEquals(200, result.getStatusCode().intValue()); @@ -75,7 +75,7 @@ public class AppTest { @Test public void outputWithoutTrailingNewline() { ByteArrayOutputStream out = setupStdOutLogs(); - APIGatewayProxyResponseEvent result = getApiGatewayProxyResponse(MathContentWithoutTrailingNewline); + APIGatewayProxyResponseEvent result = getApiGatewayProxyResponseRecursiveMath(MathContentWithoutTrailingNewline); restoreStdOutLogs(); assertEquals(200, result.getStatusCode().intValue()); @@ -85,6 +85,21 @@ public class AppTest { assertEquals(expectedOutput, out.toString()); } + @Test + public void makefileJustReplacesShellCommand() { + ByteArrayOutputStream out = setupStdOutLogs(); + APIGatewayProxyResponseEvent result = getApiGatewayProxyResponse("{\"action\":\"java11Exec\"," + + "\"cmd\":[\"env\", \"TEST_VAR=42\", \"sh\",\"-c\",\"make run\"]," + + "\"files\":{\"Makefile\":\"" + Base64.getEncoder().encodeToString(("run:\n\t@echo $TEST_VAR\n").getBytes(StandardCharsets.UTF_8)) + "\"}}"); + restoreStdOutLogs(); + + assertEquals(200, result.getStatusCode().intValue()); + String expectedOutput = + "{\"type\":\"stdout\",\"data\":\"42\\n\"}\n" + + "{\"type\":\"exit\",\"data\":0}\n"; + assertEquals(expectedOutput, out.toString()); + } + private PrintStream originalOut; private ByteArrayOutputStream setupStdOutLogs() { @@ -98,7 +113,7 @@ public class AppTest { System.setOut(originalOut); } - private APIGatewayProxyResponseEvent getApiGatewayProxyResponse(String content) { + private APIGatewayProxyResponseEvent getApiGatewayProxyResponse(String body) { App app = new App(); APIGatewayV2WebSocketEvent input = new APIGatewayV2WebSocketEvent(); APIGatewayV2WebSocketEvent.RequestContext ctx = new APIGatewayV2WebSocketEvent.RequestContext(); @@ -108,8 +123,12 @@ public class AppTest { Map headers = new HashMap<>(); headers.put(App.disableOutputHeaderKey, "True"); input.setHeaders(headers); - input.setBody("{\"action\":\"java11Exec\",\"cmd\":[\"sh\",\"-c\",\"javac org/example/RecursiveMath.java && java org/example/RecursiveMath\"]," + - "\"files\":{\"org/example/RecursiveMath.java\":\"" + content + "\"}}"); + input.setBody(body); return app.handleRequest(input, null); } + + private APIGatewayProxyResponseEvent getApiGatewayProxyResponseRecursiveMath(String content) { + return getApiGatewayProxyResponse("{\"action\":\"java11Exec\",\"cmd\":[\"sh\",\"-c\",\"javac org/example/RecursiveMath.java && java org/example/RecursiveMath\"]," + + "\"files\":{\"org/example/RecursiveMath.java\":\"" + content + "\"}}"); + } } diff --git a/deploy/aws/java11Exec/src/test/java/poseidon/SimpleMakefileTest.java b/deploy/aws/java11Exec/src/test/java/poseidon/SimpleMakefileTest.java index 4a7169f..5c97f51 100644 --- a/deploy/aws/java11Exec/src/test/java/poseidon/SimpleMakefileTest.java +++ b/deploy/aws/java11Exec/src/test/java/poseidon/SimpleMakefileTest.java @@ -148,4 +148,20 @@ public class SimpleMakefileTest { fail(); } catch (InvalidMakefileException ignored) {} } + + @Test + public void withBeforeAndAfterStatements() { + Map files = new HashMap<>(); + files.put("Makefile", Base64.getEncoder().encodeToString(("run:\n\t@echo TRAAAIIN\n").getBytes(StandardCharsets.UTF_8))); + + try { + String command = "echo \"Look it's a\" && sl && make run && echo WOW"; + SimpleMakefile makefile = new SimpleMakefile(files); + String cmd = makefile.parseCommand(command); + + assertEquals("echo \"Look it's a\" && sl && echo TRAAAIIN && echo WOW", cmd); + } catch (NoMakefileFoundException | InvalidMakefileException | NoMakeCommandException ignored) { + fail(); + } + } } diff --git a/tests/e2e/websocket_test.go b/tests/e2e/websocket_test.go index fe865fc..31e1e57 100644 --- a/tests/e2e/websocket_test.go +++ b/tests/e2e/websocket_test.go @@ -62,19 +62,8 @@ func (s *E2ETestSuite) TestExecuteCommandRoute() { func (s *E2ETestSuite) TestOutputToStdout() { for _, environmentID := range environmentIDs { s.Run(environmentID.ToString(), func() { - connection, err := ProvideWebSocketConnection(&s.Suite, environmentID, - &dto.ExecutionRequest{Command: "echo -n Hello World"}) - s.Require().NoError(err) - - messages, err := helpers.ReceiveAllWebSocketMessages(connection) - s.Require().Error(err) - s.Equal(&websocket.CloseError{Code: websocket.CloseNormalClosure}, err) - - controlMessages := helpers.WebSocketControlMessages(messages) - s.Require().Equal(&dto.WebSocketMessage{Type: dto.WebSocketMetaStart}, controlMessages[0]) - s.Require().Equal(&dto.WebSocketMessage{Type: dto.WebSocketExit}, controlMessages[1]) - - stdout, _, _ := helpers.WebSocketOutputMessages(messages) + stdout, _, _ := ExecuteNonInteractive(&s.Suite, environmentID, + &dto.ExecutionRequest{Command: "echo -n Hello World"}, nil) s.Require().Equal("Hello World", stdout) }) } @@ -83,23 +72,12 @@ func (s *E2ETestSuite) TestOutputToStdout() { func (s *E2ETestSuite) TestOutputToStderr() { for _, environmentID := range environmentIDs { s.Run(environmentID.ToString(), func() { - connection, err := ProvideWebSocketConnection(&s.Suite, environmentID, - &dto.ExecutionRequest{Command: "cat -invalid"}) - s.Require().NoError(err) + stdout, stderr, exitCode := ExecuteNonInteractive(&s.Suite, environmentID, + &dto.ExecutionRequest{Command: "cat -invalid"}, nil) - messages, err := helpers.ReceiveAllWebSocketMessages(connection) - s.Require().Error(err) - s.Equal(&websocket.CloseError{Code: websocket.CloseNormalClosure}, err) - - controlMessages := helpers.WebSocketControlMessages(messages) - s.Require().Equal(2, len(controlMessages)) - s.Require().Equal(&dto.WebSocketMessage{Type: dto.WebSocketMetaStart}, controlMessages[0]) - s.Require().Equal(&dto.WebSocketMessage{Type: dto.WebSocketExit, ExitCode: 1}, controlMessages[1]) - - stdout, stderr, errors := helpers.WebSocketOutputMessages(messages) s.NotContains(stdout, "cat: invalid option", "Stdout should not contain the error") s.Contains(stderr, "cat: invalid option", "Stderr should contain the error") - s.Empty(errors) + s.Equal(uint8(1), exitCode) }) } } @@ -108,7 +86,7 @@ func (s *E2ETestSuite) TestOutputToStderr() { func (s *E2ETestSuite) TestCommandHead() { hello := "Hello World!" connection, err := ProvideWebSocketConnection(&s.Suite, tests.DefaultEnvironmentIDAsInteger, - &dto.ExecutionRequest{Command: "head -n 1"}) + &dto.ExecutionRequest{Command: "head -n 1"}, nil) s.Require().NoError(err) startMessage, err := helpers.ReceiveNextWebSocketMessage(connection) @@ -128,42 +106,64 @@ func (s *E2ETestSuite) TestCommandHead() { func (s *E2ETestSuite) TestCommandMake() { for _, environmentID := range environmentIDs { s.Run(environmentID.ToString(), func() { - runnerID, err := ProvideRunner(&dto.RunnerRequest{ExecutionEnvironmentID: int(environmentID)}) - s.Require().NoError(err) - expectedOutput := "MeinText" - resp, err := CopyFiles(runnerID, &dto.UpdateFileSystemRequest{ + request := &dto.UpdateFileSystemRequest{ Copy: []dto.File{ {Path: "Makefile", Content: []byte( "run:\n\t@echo " + expectedOutput + "\n\n" + "test:\n\t@echo Hi\n"), }, }, - }) - s.Require().NoError(err) - s.Require().Equal(http.StatusNoContent, resp.StatusCode) - - webSocketURL, err := ProvideWebSocketURL(&s.Suite, runnerID, &dto.ExecutionRequest{Command: "make run"}) - s.Require().NoError(err) - connection, err := ConnectToWebSocket(webSocketURL) - s.Require().NoError(err) - - messages, err := helpers.ReceiveAllWebSocketMessages(connection) - s.Require().Error(err) - s.Equal(err, &websocket.CloseError{Code: websocket.CloseNormalClosure}) - stdout, _, _ := helpers.WebSocketOutputMessages(messages) + } + stdout, _, _ := ExecuteNonInteractive(&s.Suite, environmentID, &dto.ExecutionRequest{Command: "make run"}, request) stdout = regexp.MustCompile(`\r?\n$`).ReplaceAllString(stdout, "") s.Equal(expectedOutput, stdout) }) } } +func (s *E2ETestSuite) TestEnvironmentVariables() { + for _, environmentID := range environmentIDs { + s.Run(environmentID.ToString(), func() { + stdout, _, _ := ExecuteNonInteractive(&s.Suite, environmentID, &dto.ExecutionRequest{ + Command: "env", + Environment: map[string]string{"hello": "world"}, + }, nil) + + variables := s.expectEnvironmentVariables(stdout) + s.Contains(variables, "hello=world") + }) + } +} + +func (s *E2ETestSuite) TestCommandMakeEnvironmentVariables() { + for _, environmentID := range environmentIDs { + s.Run(environmentID.ToString(), func() { + request := &dto.UpdateFileSystemRequest{ + Copy: []dto.File{{Path: "Makefile", Content: []byte("run:\n\t@env\n")}}, + } + + stdout, _, _ := ExecuteNonInteractive(&s.Suite, environmentID, &dto.ExecutionRequest{Command: "make run"}, request) + s.expectEnvironmentVariables(stdout) + }) + } +} + +func (s *E2ETestSuite) expectEnvironmentVariables(stdout string) []string { + variables := strings.Split(strings.ReplaceAll(stdout, "\r\n", "\n"), "\n") + s.Contains(variables, "CODEOCEAN=true") + for _, envVar := range variables { + s.False(strings.HasPrefix(envVar, "AWS")) + } + return variables +} + func (s *E2ETestSuite) TestCommandReturnsAfterTimeout() { for _, environmentID := range environmentIDs { s.Run(environmentID.ToString(), func() { connection, err := ProvideWebSocketConnection(&s.Suite, environmentID, - &dto.ExecutionRequest{Command: "sleep 4", TimeLimit: 1}) + &dto.ExecutionRequest{Command: "sleep 4", TimeLimit: 1}, nil) s.Require().NoError(err) c := make(chan bool) @@ -189,52 +189,15 @@ func (s *E2ETestSuite) TestCommandReturnsAfterTimeout() { } } -func (s *E2ETestSuite) TestEnvironmentVariables() { - for _, environmentID := range environmentIDs { - s.Run(environmentID.ToString(), func() { - connection, err := ProvideWebSocketConnection(&s.Suite, environmentID, &dto.ExecutionRequest{ - Command: "env", - Environment: map[string]string{"hello": "world"}, - }) - s.Require().NoError(err) - - startMessage, err := helpers.ReceiveNextWebSocketMessage(connection) - s.Require().NoError(err) - s.Equal(dto.WebSocketMetaStart, startMessage.Type) - - messages, err := helpers.ReceiveAllWebSocketMessages(connection) - s.Require().Error(err) - s.Equal(err, &websocket.CloseError{Code: websocket.CloseNormalClosure}) - stdout, _, _ := helpers.WebSocketOutputMessages(messages) - - variables := strings.Split(strings.ReplaceAll(stdout, "\r\n", "\n"), "\n") - s.Contains(variables, "hello=world") - s.Contains(variables, "CODEOCEAN=true") - for _, envVar := range variables { - s.False(strings.HasPrefix(envVar, "AWS")) - } - }) - } -} - func (s *E2ETestSuite) TestMemoryMaxLimit_Nomad() { maxMemoryLimit := defaultNomadEnvironment.MemoryLimit // The operating system is in charge to kill the process and sometimes tolerates small exceeding of the limit. maxMemoryLimit = uint(1.1 * float64(maxMemoryLimit)) - connection, err := ProvideWebSocketConnection(&s.Suite, tests.DefaultEnvironmentIDAsInteger, &dto.ExecutionRequest{ + + stdout, stderr, _ := ExecuteNonInteractive(&s.Suite, tests.DefaultEnvironmentIDAsInteger, &dto.ExecutionRequest{ // This shell line tries to load maxMemoryLimit Bytes into the memory. Command: " /dev/null", - }) - s.Require().NoError(err) - - startMessage, err := helpers.ReceiveNextWebSocketMessage(connection) - s.Require().NoError(err) - s.Equal(dto.WebSocketMetaStart, startMessage.Type) - - messages, err := helpers.ReceiveAllWebSocketMessages(connection) - s.Require().Error(err) - s.Equal(err, &websocket.CloseError{Code: websocket.CloseNormalClosure}) - stdout, stderr, _ := helpers.WebSocketOutputMessages(messages) + }, nil) s.Empty(stdout) s.Contains(stderr, "Killed") } @@ -277,14 +240,43 @@ func (s *E2ETestSuite) ListTempDirectory(runnerID string) string { return stdout.String() } +// ExecuteNonInteractive Executes the passed executionRequest in the required environment without providing input. +func ExecuteNonInteractive(s *suite.Suite, environmentID dto.EnvironmentID, executionRequest *dto.ExecutionRequest, + copyRequest *dto.UpdateFileSystemRequest) (stdout, stderr string, exitCode uint8) { + connection, err := ProvideWebSocketConnection(s, environmentID, executionRequest, copyRequest) + s.Require().NoError(err) + + startMessage, err := helpers.ReceiveNextWebSocketMessage(connection) + s.Require().NoError(err) + s.Equal(dto.WebSocketMetaStart, startMessage.Type) + + messages, err := helpers.ReceiveAllWebSocketMessages(connection) + s.Require().Error(err) + s.Equal(err, &websocket.CloseError{Code: websocket.CloseNormalClosure}) + + controlMessages := helpers.WebSocketControlMessages(messages) + s.Require().Equal(1, len(controlMessages)) + exitMessage := controlMessages[0] + s.Require().Equal(dto.WebSocketExit, exitMessage.Type) + + stdout, stderr, errors := helpers.WebSocketOutputMessages(messages) + s.Empty(errors) + return stdout, stderr, exitMessage.ExitCode +} + // ProvideWebSocketConnection establishes a client WebSocket connection to run the passed ExecutionRequest. -func ProvideWebSocketConnection( - s *suite.Suite, environmentID dto.EnvironmentID, request *dto.ExecutionRequest) (*websocket.Conn, error) { +func ProvideWebSocketConnection(s *suite.Suite, environmentID dto.EnvironmentID, executionRequest *dto.ExecutionRequest, + copyRequest *dto.UpdateFileSystemRequest) (*websocket.Conn, error) { runnerID, err := ProvideRunner(&dto.RunnerRequest{ExecutionEnvironmentID: int(environmentID)}) if err != nil { return nil, fmt.Errorf("error providing runner: %w", err) } - webSocketURL, err := ProvideWebSocketURL(s, runnerID, request) + if copyRequest != nil { + resp, err := CopyFiles(runnerID, copyRequest) + s.Require().NoError(err) + s.Require().Equal(http.StatusNoContent, resp.StatusCode) + } + webSocketURL, err := ProvideWebSocketURL(s, runnerID, executionRequest) if err != nil { return nil, fmt.Errorf("error providing WebSocket URL: %w", err) }