diff --git a/api/pkg/di/container.go b/api/pkg/di/container.go index ae5e15ec..de180a0a 100644 --- a/api/pkg/di/container.go +++ b/api/pkg/di/container.go @@ -86,6 +86,7 @@ type Container struct { eventDispatcher *services.EventDispatcher logger telemetry.Logger attachmentRepository repositories.AttachmentRepository + userRistrettoCache *ristretto.Cache[string, entities.AuthContext] } // NewLiteContainer creates a Container without any routes or listeners @@ -1730,8 +1731,11 @@ func (container *Container) PhoneRistrettoCache() (cache *ristretto.Cache[string } // UserRistrettoCache creates an in-memory *ristretto.Cache[string, entities.AuthContext] -func (container *Container) UserRistrettoCache() (cache *ristretto.Cache[string, entities.AuthContext]) { - container.logger.Debug(fmt.Sprintf("creating %T", cache)) +func (container *Container) UserRistrettoCache() *ristretto.Cache[string, entities.AuthContext] { + if container.userRistrettoCache != nil { + return container.userRistrettoCache + } + container.logger.Debug(fmt.Sprintf("creating %T", container.userRistrettoCache)) ristrettoCache, err := ristretto.NewCache[string, entities.AuthContext](&ristretto.Config[string, entities.AuthContext]{ MaxCost: 5000, NumCounters: 5000 * 10, @@ -1740,6 +1744,7 @@ func (container *Container) UserRistrettoCache() (cache *ristretto.Cache[string, if err != nil { container.logger.Fatal(stacktrace.Propagate(err, "cannot create user ristretto cache")) } + container.userRistrettoCache = ristrettoCache return ristrettoCache } diff --git a/api/pkg/repositories/gorm_user_repository.go b/api/pkg/repositories/gorm_user_repository.go index e31b2848..a64e8ae0 100644 --- a/api/pkg/repositories/gorm_user_repository.go +++ b/api/pkg/repositories/gorm_user_repository.go @@ -65,8 +65,13 @@ func (repository *gormUserRepository) RotateAPIKey(ctx context.Context, userID e } user := new(entities.User) + var oldAPIKey string err = crdbgorm.ExecuteTx(ctx, repository.db, nil, func(tx *gorm.DB) error { + if err := tx.WithContext(ctx).Where("id = ?", userID).First(user).Error; err != nil { + return err + } + oldAPIKey = user.APIKey return tx.WithContext(ctx).Model(user). Clauses(clause.Returning{}). Where("id = ?", userID). @@ -78,6 +83,13 @@ func (repository *gormUserRepository) RotateAPIKey(ctx context.Context, userID e return nil, repository.tracer.WrapErrorSpan(span, stacktrace.PropagateWithCode(err, ErrCodeNotFound, msg)) } + if err == nil && oldAPIKey != "" { + // Flush pending ristretto Set operations before Del to avoid a + // buffered Set re-adding the entry after removal. + repository.cache.Wait() + repository.cache.Del(oldAPIKey) + } + return user, nil } diff --git a/tests/integration_test.go b/tests/integration_test.go index dc8c933c..5aae58ff 100644 --- a/tests/integration_test.go +++ b/tests/integration_test.go @@ -4,6 +4,7 @@ import ( "bytes" "context" "encoding/json" + "fmt" "io" "net/http" "strings" @@ -213,6 +214,90 @@ func TestSendSMS_RateLimit(t *testing.T) { } } +func TestRotateAPIKey_InvalidatesCache(t *testing.T) { + ctx := context.Background() + + // Use a dedicated test user so we don't mutate the shared userAPIKey + rotateUserAPIKey := "rotate-test-api-key" + rotateUserID := "rotate-test-user-id" + + // 1) Confirm the dedicated user's API key works and warm the cache + meURL := apiBaseURL + "/v1/users/me" + req, err := http.NewRequestWithContext(ctx, http.MethodGet, meURL, nil) + require.NoError(t, err) + req.Header.Set("x-api-key", rotateUserAPIKey) + + resp, err := http.DefaultClient.Do(req) + require.NoError(t, err) + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + require.NoError(t, err) + require.Equal(t, http.StatusOK, resp.StatusCode, "initial auth failed: %s", string(body)) + + // Parse the current API key from the response + var meResp struct { + Data struct { + ID string `json:"id"` + APIKey string `json:"api_key"` + } `json:"data"` + } + require.NoError(t, json.Unmarshal(body, &meResp)) + require.Equal(t, rotateUserID, meResp.Data.ID) + oldAPIKey := meResp.Data.APIKey + require.NotEmpty(t, oldAPIKey) + t.Logf("user ID: %s, old API key prefix: %s...", rotateUserID, oldAPIKey[:10]) + + // 2) Rotate the API key + rotateURL := fmt.Sprintf("%s/v1/users/%s/api-keys", apiBaseURL, rotateUserID) + req, err = http.NewRequestWithContext(ctx, http.MethodDelete, rotateURL, nil) + require.NoError(t, err) + req.Header.Set("x-api-key", rotateUserAPIKey) + + resp, err = http.DefaultClient.Do(req) + require.NoError(t, err) + defer resp.Body.Close() + + body, err = io.ReadAll(resp.Body) + require.NoError(t, err) + require.Equal(t, http.StatusOK, resp.StatusCode, "rotate failed: %s", string(body)) + + // Parse new API key from rotate response + var rotateResp struct { + Data struct { + APIKey string `json:"api_key"` + } `json:"data"` + } + require.NoError(t, json.Unmarshal(body, &rotateResp)) + newAPIKey := rotateResp.Data.APIKey + require.NotEmpty(t, newAPIKey) + require.NotEqual(t, oldAPIKey, newAPIKey, "API key should have changed after rotation") + t.Logf("new API key prefix: %s...", newAPIKey[:10]) + + // 3) Old API key should immediately fail (401) — this is the bug regression check + req, err = http.NewRequestWithContext(ctx, http.MethodGet, meURL, nil) + require.NoError(t, err) + req.Header.Set("x-api-key", oldAPIKey) + + resp, err = http.DefaultClient.Do(req) + require.NoError(t, err) + defer resp.Body.Close() + assert.Equal(t, http.StatusUnauthorized, resp.StatusCode, "old API key should return 401 after rotation") + + // 4) New API key should work + req, err = http.NewRequestWithContext(ctx, http.MethodGet, meURL, nil) + require.NoError(t, err) + req.Header.Set("x-api-key", newAPIKey) + + resp, err = http.DefaultClient.Do(req) + require.NoError(t, err) + defer resp.Body.Close() + + body, err = io.ReadAll(resp.Body) + require.NoError(t, err) + assert.Equal(t, http.StatusOK, resp.StatusCode, "new API key should work: %s", string(body)) +} + func TestSendSMS_OutstandingFlow(t *testing.T) { ctx := context.Background() phone := setupPhone(ctx, t, 60) diff --git a/tests/seed.sql b/tests/seed.sql index 4ae41006..36d714d9 100644 --- a/tests/seed.sql +++ b/tests/seed.sql @@ -13,6 +13,18 @@ VALUES ( NOW() ) ON CONFLICT (id) DO NOTHING; +-- Test user for API key rotation tests (isolated to avoid mutating the shared test user) +INSERT INTO users (id, email, api_key, timezone, subscription_name, created_at, updated_at) +VALUES ( + 'rotate-test-user-id', + 'rotate-test@httpsms.com', + 'rotate-test-api-key', + 'UTC', + 'pro-monthly', + NOW(), + NOW() +) ON CONFLICT (id) DO NOTHING; + -- System user (for event queue auth) INSERT INTO users (id, email, api_key, timezone, subscription_name, created_at, updated_at) VALUES (