KevinStephenson
Adding in weaviate code
b110593
raw
history blame
4.75 kB
// _ _
// __ _____ __ ___ ___ __ _| |_ ___
// \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
// \ V V / __/ (_| |\ V /| | (_| | || __/
// \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
//
// Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
//
// CONTACT: [email protected]
//
package oidc
import (
"context"
"fmt"
"strings"
"github.com/coreos/go-oidc/v3/oidc"
errors "github.com/go-openapi/errors"
"github.com/weaviate/weaviate/entities/models"
"github.com/weaviate/weaviate/usecases/config"
)
// Client handles the OIDC setup at startup and provides a middleware to be
// used with the goswagger API.
//
// A Client built by New from a disabled configuration never runs init, so its
// provider and verifier stay nil. Such a client exists only to reject requests
// that arrive with an OAuth token (see ValidateAndExtract).
type Client struct {
	// config is the OIDC section of the application configuration.
	config config.OIDC
	// provider is the OIDC provider created from config.Issuer; nil when disabled.
	provider *oidc.Provider
	// verifier checks incoming ID tokens against the provider; nil when disabled.
	verifier *oidc.IDTokenVerifier
}
// New builds an OIDC Client from the application configuration.
//
// If OIDC is enabled, New validates the OIDC settings and sets up the provider
// for the configured issuer at startup. It returns an error if either step
// fails. The error wraps the underlying cause with %w, so callers can still
// inspect it with errors.Is / errors.As.
//
// If OIDC is disabled, New still returns a non-nil Client. That "disabled"
// client is valuable because it denies any request that comes in with an
// OAuth token set.
func New(cfg config.Config) (*Client, error) {
	client := &Client{
		config: cfg.Authentication.OIDC,
	}

	if !client.config.Enabled {
		// No provider or verifier is needed: the disabled client only rejects tokens.
		return client, nil
	}

	if err := client.init(); err != nil {
		return nil, fmt.Errorf("oidc init: %w", err)
	}
	return client, nil
}
// init validates the configuration, sets up the OIDC provider for the
// configured issuer, and builds the ID-token verifier from it. New calls it
// only when OIDC is enabled.
//
// Errors are wrapped with %w (the message text is unchanged), so the
// underlying cause survives for errors.Is / errors.As.
func (c *Client) init() error {
	if err := c.validateConfig(); err != nil {
		return fmt.Errorf("invalid config: %w", err)
	}

	// NOTE(review): context.Background() gives provider setup no deadline.
	// go-oidc may keep using this context after setup, so check its
	// documentation before swapping in a short-lived or cancellable context.
	provider, err := oidc.NewProvider(context.Background(), c.config.Issuer)
	if err != nil {
		return fmt.Errorf("could not setup provider: %w", err)
	}
	c.provider = provider

	// The client ID check may only be skipped when explicitly configured;
	// validateConfig enforces that a client_id is present otherwise.
	c.verifier = provider.Verifier(&oidc.Config{
		ClientID:          c.config.ClientID,
		SkipClientIDCheck: c.config.SkipClientIDCheck,
	})
	return nil
}
// validateConfig checks that every OIDC setting this client needs is present.
// It collects all problems and reports them together in one error, joined with
// ", ", so an operator can fix everything in a single pass. It returns nil when
// the configuration is complete.
func (c *Client) validateConfig() error {
	var msgs []string
	if c.config.Issuer == "" {
		msgs = append(msgs, "missing required field 'issuer'")
	}
	if c.config.UsernameClaim == "" {
		msgs = append(msgs, "missing required field 'username_claim'")
	}
	// client_id may only be omitted when the client ID check is explicitly disabled.
	if !c.config.SkipClientIDCheck && c.config.ClientID == "" {
		msgs = append(msgs, "missing required field 'client_id': "+
			"either set a client_id or explicitly disable the check with 'skip_client_id_check: true'")
	}
	if len(msgs) == 0 {
		return nil
	}
	// The joined text is data, not a format string. Passing it through "%s"
	// stops any '%' in it from being read as a formatting verb, and keeps
	// go vet's non-constant-format-string check quiet.
	return fmt.Errorf("%s", strings.Join(msgs, ", "))
}
// ValidateAndExtract can be used as a middleware for go-swagger. It verifies
// the raw token and turns its claims into a Principal.
//
// Status codes used for the returned errors:
//   - 401 if OIDC is disabled on this client, or if the token fails verification.
//   - 500 if the claims cannot be decoded or the username claim is missing or
//     not a string.
//
// scopes is not consulted here (presumably it exists to match the go-swagger
// authenticator signature; confirm against the caller).
func (c *Client) ValidateAndExtract(token string, scopes []string) (*models.Principal, error) {
	// A disabled client exists only to turn away requests that carry a token.
	if !c.config.Enabled {
		return nil, errors.New(401, "oidc auth is not configured, please try another auth scheme or set up weaviate with OIDC configured")
	}

	idToken, verifyErr := c.verifier.Verify(context.Background(), token)
	if verifyErr != nil {
		return nil, errors.New(401, verifyErr.Error())
	}

	// Failures after verification are reported as internal errors.
	internalErr := func(cause error) error {
		return errors.New(500, fmt.Sprintf("oidc: %v", cause))
	}

	claimSet, claimsErr := c.extractClaims(idToken)
	if claimsErr != nil {
		return nil, internalErr(claimsErr)
	}

	name, nameErr := c.extractUsername(claimSet)
	if nameErr != nil {
		return nil, internalErr(nameErr)
	}

	principal := &models.Principal{
		Username: name,
		Groups:   c.extractGroups(claimSet),
	}
	return principal, nil
}
// extractClaims decodes all claims of an ID token into a generic map keyed by
// claim name. If decoding fails, it returns a nil map and a descriptive error.
func (c *Client) extractClaims(token *oidc.IDToken) (map[string]interface{}, error) {
	var decoded map[string]interface{}
	err := token.Claims(&decoded)
	if err == nil {
		return decoded, nil
	}
	return nil, fmt.Errorf("could not extract claims from token: %v", err)
}
// extractUsername reads the configured username claim (config.UsernameClaim)
// from claims. The claim is mandatory: it returns an error if the claim is
// absent, or if its value is not a string.
func (c *Client) extractUsername(claims map[string]interface{}) (string, error) {
	key := c.config.UsernameClaim

	raw, present := claims[key]
	if !present {
		return "", fmt.Errorf("token doesn't contain required claim '%s'", key)
	}

	if name, isString := raw.(string); isString {
		return name, nil
	}
	return "", fmt.Errorf("claim '%s' is not a string, but %T", key, raw)
}
// extractGroups returns the string entries of the configured groups claim
// (config.GroupsClaim). It never fails, because groups are not a required
// standard in the OIDC spec and a provider may not support them.
//
// It returns nil if the claim is absent, if the claim is not a list, or if the
// list holds no string entries. Non-string entries are skipped.
func (c *Client) extractGroups(claims map[string]interface{}) []string {
	// A missing key yields a nil interface, which also fails this assertion.
	raw, isList := claims[c.config.GroupsClaim].([]interface{})
	if !isList {
		return nil
	}

	var out []string
	for _, item := range raw {
		s, isString := item.(string)
		if !isString {
			continue
		}
		out = append(out, s)
	}
	return out
}