Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
222 changes: 182 additions & 40 deletions base/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ package base

import (
"context"
"crypto/tls"
"encoding/base64"
"errors"
"fmt"
"net/http"
Expand All @@ -13,103 +15,243 @@ import (

"connectrpc.com/connect"
"github.com/go-logr/logr"
"github.com/zalando/go-keyring"

apikeyv1 "github.com/cerbos/cloud-api/genpb/cerbos/cloud/apikey/v1"
"github.com/cerbos/cloud-api/genpb/cerbos/cloud/apikey/v1/apikeyv1connect"
authv1 "github.com/cerbos/cloud-api/genpb/cerbos/cloud/auth/v1"
)

const (
AuthTokenHeader = "x-cerbos-auth" //nolint:gosec
earlyExpiry = 5 * time.Minute
)

var ErrAuthenticationFailed = errors.New("failed to authenticate: invalid credentials")
var (
ErrAuthenticationFailed = errors.New("failed to authenticate: invalid credentials")
ErrNoSavedCredentials = errors.New("no saved credentials")
)

type authClient struct {
type tokenSetter struct {
expiresAt time.Time
apiKeyClient apikeyv1connect.ApiKeyServiceClient
savedCredentials *authv1.SavedCredentials
logger logr.Logger
accessToken string
clientID string
clientSecret string
invalidCredentials bool
mutex sync.RWMutex
invalidCredentials bool
}

func newAuthClient(conf ClientConf, httpClient *http.Client, clientOptions ...connect.ClientOption) *authClient {
return &authClient{
apiKeyClient: apikeyv1connect.NewApiKeyServiceClient(httpClient, conf.APIEndpoint, clientOptions...),
clientID: conf.Credentials.ClientID,
clientSecret: conf.Credentials.ClientSecret,
logger: conf.Logger.WithName("auth"),
func newTokenSetter(conf ClientConf, httpClient *http.Client, clientOptions ...connect.ClientOption) *tokenSetter {
return &tokenSetter{
apiKeyClient: apikeyv1connect.NewApiKeyServiceClient(httpClient, conf.APIEndpoint, clientOptions...),
clientID: conf.Credentials.ClientID,
clientSecret: conf.Credentials.ClientSecret,
savedCredentials: conf.Credentials.SavedCredentials,
logger: conf.Logger.WithName("auth"),
}
}

func (a *authClient) SetAuthTokenHeader(ctx context.Context, headers http.Header) error {
accessToken, err := a.authenticate(ctx)
func (ts *tokenSetter) SetHeader(ctx context.Context, headers http.Header) error {
accessToken, err := ts.authenticate(ctx)
if err != nil {
a.logger.V(1).Error(err, "Failed to authenticate")
ts.logger.V(1).Error(err, "Failed to authenticate")
return err
}

headers.Set(AuthTokenHeader, accessToken)
return nil
}

func (a *authClient) authenticate(ctx context.Context) (string, error) {
a.mutex.RLock()
if a.invalidCredentials {
a.mutex.RUnlock()
a.logger.V(4).Info("Short-circuiting auth because credentials are invalid")
func (ts *tokenSetter) authenticate(ctx context.Context) (string, error) {
ts.mutex.RLock()
if ts.invalidCredentials {
ts.mutex.RUnlock()
ts.logger.V(4).Info("Short-circuiting auth because credentials are invalid")
return "", ErrAuthenticationFailed
}
accessToken, ok := a.currentAccessToken()
a.mutex.RUnlock()
accessToken, ok := ts.currentAccessToken()
ts.mutex.RUnlock()
if ok {
a.logger.V(4).Info("Using existing token")
ts.logger.V(4).Info("Using existing token")
return accessToken, nil
}

a.mutex.Lock()
defer a.mutex.Unlock()
ts.mutex.Lock()
defer ts.mutex.Unlock()

if a.invalidCredentials {
a.logger.V(4).Info("Short-circuiting auth because credentials are invalid")
if ts.invalidCredentials {
ts.logger.V(4).Info("Short-circuiting auth because credentials are invalid")
return "", ErrAuthenticationFailed
}

accessToken, ok = a.currentAccessToken()
accessToken, ok = ts.currentAccessToken()
if ok {
a.logger.V(4).Info("Using existing token")
ts.logger.V(4).Info("Using existing token")
return accessToken, nil
}

a.logger.V(4).Info("Obtaining new access token")
response, err := a.apiKeyClient.IssueAccessToken(ctx, connect.NewRequest(&apikeyv1.IssueAccessTokenRequest{
ClientId: a.clientID,
ClientSecret: a.clientSecret,
}))
ts.logger.V(4).Info("Obtaining new access token")
var expiresIn time.Duration
var err error
//nolint:nestif
if ts.savedCredentials != nil {
// Saved credentials can only be device tokens. See credentials.NewFromSavedCredentials.
deviceToken := ts.savedCredentials.GetDeviceToken()
expiresIn = deviceToken.GetExpiresAt().AsTime().Sub(time.Now().UTC())
if expiresIn > earlyExpiry {
ts.accessToken = deviceToken.GetAccessToken()
} else {
var response *connect.Response[apikeyv1.RefreshDeviceTokenResponse]
response, err = ts.apiKeyClient.RefreshDeviceToken(ctx, connect.NewRequest(&apikeyv1.RefreshDeviceTokenRequest{
DeviceToken: deviceToken,
}))
if err == nil {
ts.accessToken = response.Msg.GetDeviceToken().GetAccessToken()
ts.savedCredentials = &authv1.SavedCredentials{
ApiEndpoint: ts.savedCredentials.GetApiEndpoint(),
Credentials: &authv1.SavedCredentials_DeviceToken{
DeviceToken: response.Msg.GetDeviceToken(),
},
}
expiresIn = ts.savedCredentials.GetDeviceToken().GetExpiresAt().AsTime().Sub(time.Now().UTC())
// Refresh token rotates so we need to save it.
_ = SaveCredentials(ts.savedCredentials)
}
}
} else {
var response *connect.Response[apikeyv1.IssueAccessTokenResponse]
response, err = ts.apiKeyClient.IssueAccessToken(ctx, connect.NewRequest(&apikeyv1.IssueAccessTokenRequest{
ClientId: ts.clientID,
ClientSecret: ts.clientSecret,
}))
if err == nil {
ts.accessToken = response.Msg.GetAccessToken()
expiresIn = response.Msg.ExpiresIn.AsDuration()
}
}

if err != nil {
a.logger.V(1).Error(err, "Failed to authenticate")
ts.logger.V(1).Error(err, "Failed to authenticate")
if connect.CodeOf(err) == connect.CodeUnauthenticated {
a.invalidCredentials = true
ts.invalidCredentials = true
return "", ErrAuthenticationFailed
}
return "", fmt.Errorf("failed to authenticate: %w", err)
}

expiresIn := response.Msg.ExpiresIn.AsDuration()
if expiresIn > earlyExpiry {
expiresIn -= earlyExpiry
}

a.accessToken = response.Msg.AccessToken
a.expiresAt = time.Now().Add(expiresIn)
a.logger.V(4).Info("Obtained new access token")
ts.expiresAt = time.Now().Add(expiresIn)
ts.logger.V(4).Info("Obtained new access token")

return ts.accessToken, nil
}

func (ts *tokenSetter) currentAccessToken() (string, bool) {
return ts.accessToken, ts.accessToken != "" && ts.expiresAt.After(time.Now())
}

func DeviceLogin(ctx context.Context, apiEndpoint string, tlsConf *tls.Config) error {
credentials, err := startDeviceRegistrationFlow(ctx, apiEndpoint, tlsConf)
if err != nil {
return err
}

return SaveCredentials(credentials)
}

func startDeviceRegistrationFlow(ctx context.Context, apiEndpoint string, tlsConf *tls.Config) (*authv1.SavedCredentials, error) {
httpClient := mkHTTPClient(ClientConf{TLS: tlsConf})
apiClient := apikeyv1connect.NewApiKeyServiceClient(httpClient, apiEndpoint)
stream, err := apiClient.RegisterDevice(ctx, connect.NewRequest(&apikeyv1.RegisterDeviceRequest{}))
if err != nil {
return nil, fmt.Errorf("failed to start device registration: %w", err)
}

defer stream.Close()

for stream.Receive() {
msg := stream.Msg()
switch m := msg.GetMessage().(type) {
case *apikeyv1.RegisterDeviceResponse_VerificationUrl:
fmt.Printf("Log in and connect this machine to your account by visiting %s\n", m.VerificationUrl) //nolint:forbidigo
case *apikeyv1.RegisterDeviceResponse_DeviceToken:
return &authv1.SavedCredentials{
ApiEndpoint: apiEndpoint,
Credentials: &authv1.SavedCredentials_DeviceToken{
DeviceToken: &authv1.DeviceToken{
AccessToken: m.DeviceToken.GetAccessToken(),
RefreshToken: m.DeviceToken.GetRefreshToken(),
ExpiresAt: m.DeviceToken.GetExpiresAt(),
TokenType: m.DeviceToken.GetTokenType(),
},
},
}, nil
}
}

if err := stream.Err(); err != nil {
return nil, fmt.Errorf("device registration failed: %w", err)
}

return nil, nil
}

func ClientLogin(ctx context.Context, apiEndpoint string, tlsConf *tls.Config, clientID, clientSecret string) error {
httpClient := mkHTTPClient(ClientConf{TLS: tlsConf})
apiClient := apikeyv1connect.NewApiKeyServiceClient(httpClient, apiEndpoint)
if _, err := apiClient.IssueAccessToken(ctx, connect.NewRequest(&apikeyv1.IssueAccessTokenRequest{
ClientId: clientID,
ClientSecret: clientSecret,
})); err != nil {
return fmt.Errorf("failed to authenticate: %w", err)
}

return a.accessToken, nil
return SaveCredentials(&authv1.SavedCredentials{
ApiEndpoint: apiEndpoint,
Credentials: &authv1.SavedCredentials_ClientCredentials{
ClientCredentials: &authv1.ClientCredentials{
ClientId: clientID,
ClientSecret: clientSecret,
},
},
})
}

func (a *authClient) currentAccessToken() (string, bool) {
return a.accessToken, a.accessToken != "" && a.expiresAt.After(time.Now())
func SaveCredentials(creds *authv1.SavedCredentials) error {
credBytes, err := creds.MarshalVT()
if err != nil {
return fmt.Errorf("failed to marshal credentials: %w", err)
}

credEncoded := base64.StdEncoding.EncodeToString(credBytes)
if err := keyring.Set(creds.GetApiEndpoint(), "cerbos", credEncoded); err != nil {
return fmt.Errorf("failed to save credentials to key ring: %w", err)
}

return nil
}

func GetSavedCredentials(apiEndpoint string) (*authv1.SavedCredentials, error) {
credEncoded, err := keyring.Get(apiEndpoint, "cerbos")
if err != nil {
return nil, fmt.Errorf("failed to get credentials from key ring: %w", err)
}

credBytes, err := base64.StdEncoding.DecodeString(credEncoded)
if err != nil {
return nil, fmt.Errorf("failed to decode credentials: %w", err)
}

creds := &authv1.SavedCredentials{}
if err := creds.UnmarshalVT(credBytes); err != nil {
return nil, fmt.Errorf("failed to unmarshal credentials: %w", err)
}

return creds, nil
}
58 changes: 58 additions & 0 deletions base/auth_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
// Copyright 2021-2026 Zenauth Ltd.
// SPDX-License-Identifier: Apache-2.0

package base_test

import (
"testing"
"time"

"github.com/google/go-cmp/cmp"
"github.com/stretchr/testify/require"
"github.com/zalando/go-keyring"
"google.golang.org/protobuf/testing/protocmp"
"google.golang.org/protobuf/types/known/timestamppb"

"github.com/cerbos/cloud-api/base"
authv1 "github.com/cerbos/cloud-api/genpb/cerbos/cloud/auth/v1"
)

func TestSaveAndLoadCredentials(t *testing.T) {
keyring.MockInit()

t.Run("ClientCredentials", func(t *testing.T) {
want := &authv1.SavedCredentials{
ApiEndpoint: "https://api.example.com",
Credentials: &authv1.SavedCredentials_ClientCredentials{
ClientCredentials: &authv1.ClientCredentials{
ClientId: "client",
ClientSecret: "secret",
},
},
}

require.NoError(t, base.SaveCredentials(want))
have, err := base.GetSavedCredentials("https://api.example.com")
require.NoError(t, err)
require.Empty(t, cmp.Diff(want, have, protocmp.Transform()))
})

t.Run("DeviceToken", func(t *testing.T) {
want := &authv1.SavedCredentials{
ApiEndpoint: "https://device.example.com",
Credentials: &authv1.SavedCredentials_DeviceToken{
DeviceToken: &authv1.DeviceToken{
AccessToken: "access",
RefreshToken: "refresh",
ExpiresAt: timestamppb.New(time.Now().UTC().Add(30 * time.Minute)),
TokenType: "Bearer",
},
},
}

require.NoError(t, base.SaveCredentials(want))
have, err := base.GetSavedCredentials("https://device.example.com")
require.NoError(t, err)
require.Empty(t, cmp.Diff(want, have, protocmp.Transform()))
})
}
4 changes: 2 additions & 2 deletions base/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,10 +45,10 @@ func NewClient(conf ClientConf) (c Client, opts []connect.ClientOption, _ error)
}

retryableHTTPClient := mkRetryableHTTPClient(conf)
authClient := newAuthClient(conf, retryableHTTPClient, opts...)
tokenSetter := newTokenSetter(conf, retryableHTTPClient, opts...)

circuitBreaker := newCircuitBreakerInterceptor()
opts = append(opts, connect.WithInterceptors(circuitBreaker, newAuthInterceptor(authClient)))
opts = append(opts, connect.WithInterceptors(circuitBreaker, newAuthInterceptor(tokenSetter)))

return Client{
ClientConf: conf,
Expand Down
10 changes: 5 additions & 5 deletions base/interceptors.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,16 +73,16 @@ func (uas uaStreamingClientConn) RequestHeader() http.Header {
}

type authInterceptor struct {
authClient *authClient
tokenSetter *tokenSetter
}

func newAuthInterceptor(authClient *authClient) authInterceptor {
return authInterceptor{authClient: authClient}
func newAuthInterceptor(tokenSetter *tokenSetter) authInterceptor {
return authInterceptor{tokenSetter: tokenSetter}
}

func (ai authInterceptor) WrapUnary(next connect.UnaryFunc) connect.UnaryFunc {
return connect.UnaryFunc(func(ctx context.Context, req connect.AnyRequest) (connect.AnyResponse, error) {
err := ai.authClient.SetAuthTokenHeader(ctx, req.Header())
err := ai.tokenSetter.SetHeader(ctx, req.Header())
if err != nil {
return nil, err
}
Expand All @@ -94,7 +94,7 @@ func (ai authInterceptor) WrapUnary(next connect.UnaryFunc) connect.UnaryFunc {
func (ai authInterceptor) WrapStreamingClient(next connect.StreamingClientFunc) connect.StreamingClientFunc {
return func(ctx context.Context, spec connect.Spec) connect.StreamingClientConn {
conn := next(ctx, spec)
err := ai.authClient.SetAuthTokenHeader(ctx, conn.RequestHeader())
err := ai.tokenSetter.SetHeader(ctx, conn.RequestHeader())

return authStreamingClientConn{
StreamingClientConn: conn,
Expand Down
Loading
Loading