diff --git a/internal/nomad/nomad_test.go b/internal/nomad/nomad_test.go index 9321a87..943c314 100644 --- a/internal/nomad/nomad_test.go +++ b/internal/nomad/nomad_test.go @@ -830,31 +830,32 @@ func (s *ExecuteCommandTestSuite) TestWithSeparateStderr() { var stdout, stderr bytes.Buffer var calledStdoutCommand, calledStderrCommand string - - // mock regular call - call := s.mockExecute(mock.AnythingOfType("string"), 0, nil, func(_ mock.Arguments) {}) - call.Run(func(args mock.Arguments) { + runFn := func(args mock.Arguments) { var ok bool calledCommand, ok := args.Get(2).(string) s.Require().True(ok) - writer, ok := args.Get(5).(io.Writer) - s.Require().True(ok) - + var out string if isStderrCommand := strings.Contains(calledCommand, "mkfifo"); isStderrCommand { calledStderrCommand = calledCommand - _, err := writer.Write([]byte(s.expectedStderr)) - s.Require().NoError(err) - call.ReturnArguments = mock.Arguments{stderrExitCode, nil} + out = s.expectedStderr } else { calledStdoutCommand = calledCommand - _, err := writer.Write([]byte(s.expectedStdout)) - s.Require().NoError(err) - call.ReturnArguments = mock.Arguments{commandExitCode, nil} + out = s.expectedStdout } - }) + + writer, ok := args.Get(5).(io.Writer) + s.Require().True(ok) + _, err := writer.Write([]byte(out)) + s.Require().NoError(err) + } + + s.apiMock.On("Execute", s.allocationID, mock.Anything, mock.Anything, withTTY, + mock.AnythingOfType("nullio.Reader"), mock.Anything, mock.Anything).Run(runFn).Return(stderrExitCode, nil) + s.apiMock.On("Execute", s.allocationID, mock.Anything, mock.Anything, withTTY, + mock.AnythingOfType("*bytes.Buffer"), mock.Anything, mock.Anything).Run(runFn).Return(commandExitCode, nil) exitCode, err := s.nomadAPIClient.ExecuteCommand(s.allocationID, s.ctx, s.testCommand, withTTY, - UnprivilegedExecution, nullio.Reader{}, &stdout, &stderr) + UnprivilegedExecution, &bytes.Buffer{}, &stdout, &stderr) s.Require().NoError(err) s.apiMock.AssertNumberOfCalls(s.T(), "Execute", 2)