diff --git a/api/auth/auth.go b/api/auth/auth.go new file mode 100644 index 0000000..81303cc --- /dev/null +++ b/api/auth/auth.go @@ -0,0 +1,36 @@ +package auth + +import ( + "crypto/subtle" + "gitlab.hpi.de/codeocean/codemoon/poseidon/config" + "gitlab.hpi.de/codeocean/codemoon/poseidon/logging" + "net/http" +) + +var log = logging.GetLogger("api/auth") + +const TokenHeader = "X-Poseidon-Token" + +var correctAuthenticationToken []byte + +// InitializeAuthentication returns true iff the authentication is initialized successfully and can be used. +func InitializeAuthentication() bool { + token := config.Config.Server.Token + if token == "" { + return false + } + correctAuthenticationToken = []byte(token) + return true +} + +func HTTPAuthenticationMiddleware(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + token := r.Header.Get(TokenHeader) + if subtle.ConstantTimeCompare([]byte(token), correctAuthenticationToken) == 0 { + log.WithField("token", token).Warn("Incorrect token") + w.WriteHeader(http.StatusUnauthorized) + return + } + next.ServeHTTP(w, r) + }) +} diff --git a/api/auth/auth_test.go b/api/auth/auth_test.go new file mode 100644 index 0000000..09290c6 --- /dev/null +++ b/api/auth/auth_test.go @@ -0,0 +1,71 @@ +package auth + +import ( + "github.com/sirupsen/logrus" + "github.com/sirupsen/logrus/hooks/test" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/suite" + "gitlab.hpi.de/codeocean/codemoon/poseidon/config" + "net/http" + "net/http/httptest" + "testing" +) + +const testToken = "C0rr3ctT0k3n" + +type AuthenticationMiddlewareTestSuite struct { + suite.Suite + request *http.Request + recorder *httptest.ResponseRecorder + httpAuthenticationMiddleware http.Handler +} + +func (suite *AuthenticationMiddlewareTestSuite) SetupTest() { + config.Config.Server.Token = testToken + InitializeAuthentication() + suite.recorder = httptest.NewRecorder() + request, err := http.NewRequest(http.MethodGet, "/api/v1/test", nil) + if err != nil { + suite.T().Fatal(err) + } + suite.request = request + suite.httpAuthenticationMiddleware = HTTPAuthenticationMiddleware( + http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) +} + +func (suite *AuthenticationMiddlewareTestSuite) TestReturns401WhenHeaderUnset() { + suite.httpAuthenticationMiddleware.ServeHTTP(suite.recorder, suite.request) + assert.Equal(suite.T(), http.StatusUnauthorized, suite.recorder.Code) +} + +func (suite *AuthenticationMiddlewareTestSuite) TestReturns401WhenTokenWrong() { + suite.request.Header.Set(TokenHeader, "Wr0ngT0k3n") + suite.httpAuthenticationMiddleware.ServeHTTP(suite.recorder, suite.request) + assert.Equal(suite.T(), http.StatusUnauthorized, suite.recorder.Code) +} + +func (suite *AuthenticationMiddlewareTestSuite) TestWarnsWhenUnauthorized() { + var hook *test.Hook + logger, hook := test.NewNullLogger() + log = logger.WithField("pkg", "api/auth") + + suite.request.Header.Set(TokenHeader, "Wr0ngT0k3n") + suite.httpAuthenticationMiddleware.ServeHTTP(suite.recorder, suite.request) + + assert.Equal(suite.T(), http.StatusUnauthorized, suite.recorder.Code) + assert.Equal(suite.T(), logrus.WarnLevel, hook.LastEntry().Level) + assert.Equal(suite.T(), hook.LastEntry().Data["token"], "Wr0ngT0k3n") +} + +func (suite *AuthenticationMiddlewareTestSuite) TestPassesWhenTokenCorrect() { + suite.request.Header.Set(TokenHeader, testToken) + suite.httpAuthenticationMiddleware.ServeHTTP(suite.recorder, suite.request) + + assert.Equal(suite.T(), http.StatusOK, suite.recorder.Code) +} + +func TestHTTPAuthenticationMiddleware(t *testing.T) { + suite.Run(t, new(AuthenticationMiddlewareTestSuite)) +}