//                           _       _
//  __      _____  __ ___   ___  __ _| |_ ___
//  \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
//   \ V  V /  __/ (_| |\ V /| | (_| | ||  __/
//    \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
//
//  Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
//
//  CONTACT: hello@weaviate.io
//
package oidc | |
import ( | |
"fmt" | |
"testing" | |
"time" | |
errors "github.com/go-openapi/errors" | |
"github.com/golang-jwt/jwt/v4" | |
"github.com/stretchr/testify/assert" | |
"github.com/stretchr/testify/require" | |
"github.com/weaviate/weaviate/usecases/config" | |
) | |
func Test_Middleware_NotConfigured(t *testing.T) { | |
cfg := config.Config{ | |
Authentication: config.Authentication{ | |
OIDC: config.OIDC{ | |
Enabled: false, | |
}, | |
}, | |
} | |
expectedErr := errors.New(401, "oidc auth is not configured, please try another auth scheme or set up weaviate with OIDC configured") | |
client, err := New(cfg) | |
require.Nil(t, err) | |
principal, err := client.ValidateAndExtract("token-doesnt-matter", []string{}) | |
assert.Nil(t, principal) | |
assert.Equal(t, expectedErr, err) | |
} | |
func Test_Middleware_IncompleteConfiguration(t *testing.T) { | |
cfg := config.Config{ | |
Authentication: config.Authentication{ | |
OIDC: config.OIDC{ | |
Enabled: true, | |
}, | |
}, | |
} | |
expectedErr := fmt.Errorf("oidc init: invalid config: missing required field 'issuer', " + | |
"missing required field 'username_claim', missing required field 'client_id': either set a client_id or explicitly disable the check with 'skip_client_id_check: true'") | |
_, err := New(cfg) | |
assert.Equal(t, expectedErr, err) | |
} | |
type claims struct { | |
jwt.StandardClaims | |
Email string `json:"email"` | |
Groups []string `json:"groups"` | |
} | |
func Test_Middleware_WithValidToken(t *testing.T) { | |
t.Run("without groups set", func(t *testing.T) { | |
server := newOIDCServer(t) | |
defer server.Close() | |
cfg := config.Config{ | |
Authentication: config.Authentication{ | |
OIDC: config.OIDC{ | |
Enabled: true, | |
Issuer: server.URL, | |
ClientID: "best_client", | |
SkipClientIDCheck: false, | |
UsernameClaim: "sub", | |
}, | |
}, | |
} | |
token := token(t, "best-user", server.URL, "best_client") | |
client, err := New(cfg) | |
require.Nil(t, err) | |
principal, err := client.ValidateAndExtract(token, []string{}) | |
require.Nil(t, err) | |
assert.Equal(t, "best-user", principal.Username) | |
}) | |
t.Run("with a non-standard username claim", func(t *testing.T) { | |
server := newOIDCServer(t) | |
defer server.Close() | |
cfg := config.Config{ | |
Authentication: config.Authentication{ | |
OIDC: config.OIDC{ | |
Enabled: true, | |
Issuer: server.URL, | |
ClientID: "best_client", | |
SkipClientIDCheck: false, | |
UsernameClaim: "email", | |
GroupsClaim: "groups", | |
}, | |
}, | |
} | |
token := tokenWithEmail(t, "best-user", server.URL, "best_client", "[email protected]") | |
client, err := New(cfg) | |
require.Nil(t, err) | |
principal, err := client.ValidateAndExtract(token, []string{}) | |
require.Nil(t, err) | |
assert.Equal(t, "[email protected]", principal.Username) | |
}) | |
t.Run("with groups claim", func(t *testing.T) { | |
server := newOIDCServer(t) | |
defer server.Close() | |
cfg := config.Config{ | |
Authentication: config.Authentication{ | |
OIDC: config.OIDC{ | |
Enabled: true, | |
Issuer: server.URL, | |
ClientID: "best_client", | |
SkipClientIDCheck: false, | |
UsernameClaim: "sub", | |
GroupsClaim: "groups", | |
}, | |
}, | |
} | |
token := tokenWithGroups(t, "best-user", server.URL, "best_client", []string{"group1", "group2"}) | |
client, err := New(cfg) | |
require.Nil(t, err) | |
principal, err := client.ValidateAndExtract(token, []string{}) | |
require.Nil(t, err) | |
assert.Equal(t, "best-user", principal.Username) | |
assert.Equal(t, []string{"group1", "group2"}, principal.Groups) | |
}) | |
} | |
func token(t *testing.T, subject string, issuer string, aud string) string { | |
return tokenWithEmail(t, subject, issuer, aud, "") | |
} | |
func tokenWithEmail(t *testing.T, subject string, issuer string, aud string, email string) string { | |
claims := claims{ | |
Email: email, | |
} | |
return tokenWithClaims(t, subject, issuer, aud, claims) | |
} | |
func tokenWithGroups(t *testing.T, subject string, issuer string, aud string, groups []string) string { | |
claims := claims{ | |
Groups: groups, | |
} | |
return tokenWithClaims(t, subject, issuer, aud, claims) | |
} | |
func tokenWithClaims(t *testing.T, subject string, issuer string, aud string, claims claims) string { | |
//nolint:staticcheck // is deprecated, but for the purpose of this test, this doesn't matter | |
claims.StandardClaims = jwt.StandardClaims{ | |
Subject: subject, | |
Issuer: issuer, | |
Audience: aud, | |
ExpiresAt: time.Now().Add(10 * time.Second).Unix(), | |
} | |
token, err := signToken(claims) | |
require.Nil(t, err, "signing token should not error") | |
return token | |
} | |