diff --git a/pkg/api/resource.go b/pkg/api/resource.go new file mode 100644 index 00000000..96c18f6e --- /dev/null +++ b/pkg/api/resource.go @@ -0,0 +1,103 @@ +package api + +import ( + "fmt" + "time" + + "gorm.io/datatypes" + "gorm.io/gorm" + + "github.com/openshift-hyperfleet/hyperfleet-api/pkg/registry" +) + +// Resource is the generic GORM model for entity types managed by the entity +// registry (Channel, Version, WIF Config, etc.). Entity kinds are +// differentiated by the Kind field. Existing Cluster and NodePool types +// are NOT migrated to this model. +type Resource struct { + Meta + Kind string `json:"kind" gorm:"size:100;not null"` + Name string `json:"name" gorm:"size:100;not null"` + Href string `json:"href,omitempty" gorm:"size:500"` + CreatedBy string `json:"created_by" gorm:"size:255;not null"` + UpdatedBy string `json:"updated_by" gorm:"size:255;not null"` + DeletedBy *string `json:"deleted_by,omitempty" gorm:"size:255"` + DeletedTime *time.Time `json:"deleted_time,omitempty"` + OwnerID *string `json:"owner_id,omitempty" gorm:"size:255"` + OwnerKind *string `json:"owner_kind,omitempty" gorm:"size:100"` + OwnerHref *string `json:"owner_href,omitempty" gorm:"size:500"` + Spec datatypes.JSON `json:"spec" gorm:"type:jsonb;not null"` + Labels datatypes.JSON `json:"labels,omitempty" gorm:"type:jsonb"` + Generation int32 `json:"generation" gorm:"default:1;not null"` +} + +type ( + ResourceList []*Resource + ResourceIndex map[string]*Resource +) + +func (l ResourceList) Index() ResourceIndex { + index := ResourceIndex{} + for _, o := range l { + index[o.ID] = o + } + return index +} + +func (r Resource) TableName() string { + return "resources" +} + +// BeforeCreate TODO: Validate the necessity for this as part of https://redhat.atlassian.net/browse/HYPERFLEET-1085 +func (r *Resource) BeforeCreate(tx *gorm.DB) error { + if r.ID == "" { + id, err := NewID() + if err != nil { + return fmt.Errorf("failed to generate resource ID: %w", err) + } + r.ID = id + } + + now := time.Now() + if r.CreatedTime.IsZero() { + r.CreatedTime = now + } + r.UpdatedTime = now + if r.Generation == 0 { + r.Generation = 1 + } + + if r.Href == "" { + desc := registry.MustGet(r.Kind) + if r.OwnerID != nil && *r.OwnerID != "" { + if r.OwnerKind == nil || *r.OwnerKind == "" { + return fmt.Errorf("owner_kind is required when owner_id is set") + } + if r.OwnerHref == nil { + parentDesc := registry.MustGet(*r.OwnerKind) + ownerHref := fmt.Sprintf("/api/hyperfleet/v1/%s/%s", + parentDesc.Plural, *r.OwnerID) + r.OwnerHref = &ownerHref + } + r.Href = fmt.Sprintf("%s/%s/%s", *r.OwnerHref, desc.Plural, r.ID) + } else { + r.Href = fmt.Sprintf("/api/hyperfleet/v1/%s/%s", desc.Plural, r.ID) + } + } + + return nil +} + +func (r *Resource) BeforeUpdate(tx *gorm.DB) error { + r.UpdatedTime = time.Now() + return nil +} + +func (r *Resource) MarkDeleted(by string, t time.Time) { + r.DeletedTime = &t + r.DeletedBy = &by +} + +func (r *Resource) IncrementGeneration() { + r.Generation++ +} diff --git a/pkg/api/resource_test.go b/pkg/api/resource_test.go new file mode 100644 index 00000000..6a428763 --- /dev/null +++ b/pkg/api/resource_test.go @@ -0,0 +1,258 @@ +package api + +import ( + "testing" + "time" + + . "github.com/onsi/gomega" + + "github.com/openshift-hyperfleet/hyperfleet-api/pkg/registry" +) + +func setupTestRegistry() { + registry.Reset() + registry.Register(registry.EntityDescriptor{ + Kind: "Channel", + Plural: "channels", + }) + registry.Register(registry.EntityDescriptor{ + Kind: "Version", + Plural: "versions", + ParentKind: "Channel", + }) +} + +func strPtr(s string) *string { + return &s +} + +func TestResourceList_Index(t *testing.T) { + RegisterTestingT(t) + + emptyList := ResourceList{} + emptyIndex := emptyList.Index() + Expect(len(emptyIndex)).To(Equal(0)) + + r1 := &Resource{} + r1.ID = "res-1" + r1.Name = "test-resource-1" + + r2 := &Resource{} + r2.ID = "res-2" + r2.Name = "test-resource-2" + + multiList := ResourceList{r1, r2} + multiIndex := multiList.Index() + Expect(len(multiIndex)).To(Equal(2)) + Expect(multiIndex["res-1"]).To(Equal(r1)) + Expect(multiIndex["res-2"]).To(Equal(r2)) + + r1Dup := &Resource{} + r1Dup.ID = "res-1" + r1Dup.Name = "duplicate" + + dupList := ResourceList{r1, r1Dup} + dupIndex := dupList.Index() + Expect(len(dupIndex)).To(Equal(1)) + Expect(dupIndex["res-1"].Name).To(Equal("duplicate")) +} + +func TestResource_BeforeCreate_IDGeneration(t *testing.T) { + RegisterTestingT(t) + setupTestRegistry() + + r := &Resource{Name: "test", Kind: "Channel"} + + err := r.BeforeCreate(nil) + Expect(err).To(BeNil()) + Expect(r.ID).ToNot(BeEmpty()) +} + +func TestResource_BeforeCreate_IDPreservation(t *testing.T) { + RegisterTestingT(t) + setupTestRegistry() + + r := &Resource{Name: "test", Kind: "Channel"} + r.ID = "pre-set-id" + + err := r.BeforeCreate(nil) + Expect(err).To(BeNil()) + Expect(r.ID).To(Equal("pre-set-id")) +} + +func TestResource_BeforeCreate_GenerationDefault(t *testing.T) { + RegisterTestingT(t) + setupTestRegistry() + + r := &Resource{Name: "test", Kind: "Channel"} + + err := r.BeforeCreate(nil) + Expect(err).To(BeNil()) + Expect(r.Generation).To(Equal(int32(1))) +} + +func TestResource_BeforeCreate_GenerationPreserved(t *testing.T) { + RegisterTestingT(t) + setupTestRegistry() + + r := &Resource{Name: "test", Kind: "Channel", Generation: 5} + + err := r.BeforeCreate(nil) + Expect(err).To(BeNil()) + Expect(r.Generation).To(Equal(int32(5))) +} + +func TestResource_BeforeCreate_Timestamps(t *testing.T) { + RegisterTestingT(t) + setupTestRegistry() + + before := time.Now() + r := &Resource{Name: "test", Kind: "Channel"} + + err := r.BeforeCreate(nil) + Expect(err).To(BeNil()) + + Expect(r.CreatedTime).ToNot(BeZero()) + Expect(r.UpdatedTime).ToNot(BeZero()) + Expect(r.CreatedTime.After(before) || r.CreatedTime.Equal(before)).To(BeTrue()) + Expect(r.UpdatedTime.After(before) || r.UpdatedTime.Equal(before)).To(BeTrue()) +} + +func TestResource_BeforeCreate_CreatedTimePreserved(t *testing.T) { + RegisterTestingT(t) + setupTestRegistry() + + fixedTime := time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC) + r := &Resource{Name: "test", Kind: "Channel"} + r.CreatedTime = fixedTime + + err := r.BeforeCreate(nil) + Expect(err).To(BeNil()) + Expect(r.CreatedTime).To(Equal(fixedTime)) +} + +func TestResource_BeforeCreate_HrefTopLevel(t *testing.T) { + RegisterTestingT(t) + setupTestRegistry() + + r := &Resource{Name: "stable", Kind: "Channel"} + err := r.BeforeCreate(nil) + Expect(err).To(BeNil()) + Expect(r.Href).To(Equal("/api/hyperfleet/v1/channels/" + r.ID)) +} + +func TestResource_BeforeCreate_HrefChild(t *testing.T) { + RegisterTestingT(t) + setupTestRegistry() + + r := &Resource{ + Name: "4-17-12", + Kind: "Version", + OwnerID: strPtr("ch-1"), + OwnerKind: strPtr("Channel"), + } + err := r.BeforeCreate(nil) + Expect(err).To(BeNil()) + Expect(r.Href).To(Equal("/api/hyperfleet/v1/channels/ch-1/versions/" + r.ID)) + Expect(*r.OwnerHref).To(Equal("/api/hyperfleet/v1/channels/ch-1")) +} + +func TestResource_BeforeCreate_OwnerKindMissing(t *testing.T) { + RegisterTestingT(t) + setupTestRegistry() + + r := &Resource{ + Name: "4-17-12", + Kind: "Version", + OwnerID: strPtr("ch-1"), + } + err := r.BeforeCreate(nil) + Expect(err).ToNot(BeNil()) + Expect(err.Error()).To(ContainSubstring("owner_kind is required")) +} + +func TestResource_BeforeCreate_OwnerKindEmpty(t *testing.T) { + RegisterTestingT(t) + setupTestRegistry() + + r := &Resource{ + Name: "4-17-12", + Kind: "Version", + OwnerID: strPtr("ch-1"), + OwnerKind: strPtr(""), + } + err := r.BeforeCreate(nil) + Expect(err).ToNot(BeNil()) + Expect(err.Error()).To(ContainSubstring("owner_kind is required")) +} + +func TestResource_BeforeCreate_HrefChildWithPresetOwnerHref(t *testing.T) { + RegisterTestingT(t) + setupTestRegistry() + + r := &Resource{ + Name: "some-label", + Kind: "Version", + OwnerID: strPtr("v-1"), + OwnerKind: strPtr("Version"), + OwnerHref: strPtr("/api/hyperfleet/v1/channels/ch-1/versions/v-1"), + } + err := r.BeforeCreate(nil) + Expect(err).To(BeNil()) + Expect(r.Href).To(Equal("/api/hyperfleet/v1/channels/ch-1/versions/v-1/versions/" + r.ID)) + Expect(*r.OwnerHref).To(Equal("/api/hyperfleet/v1/channels/ch-1/versions/v-1")) +} + +func TestResource_BeforeCreate_HrefPreserved(t *testing.T) { + RegisterTestingT(t) + setupTestRegistry() + + r := &Resource{Name: "test", Kind: "Channel", Href: "/custom/href"} + err := r.BeforeCreate(nil) + Expect(err).To(BeNil()) + Expect(r.Href).To(Equal("/custom/href")) +} + +func TestResource_BeforeUpdate_UpdatesTimestamp(t *testing.T) { + RegisterTestingT(t) + + r := &Resource{Name: "test", Kind: "Channel"} + r.UpdatedTime = time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC) + + before := time.Now() + err := r.BeforeUpdate(nil) + Expect(err).To(BeNil()) + Expect(r.UpdatedTime.After(before) || r.UpdatedTime.Equal(before)).To(BeTrue()) +} + +func TestResource_MarkDeleted(t *testing.T) { + RegisterTestingT(t) + + r := &Resource{Name: "test", Kind: "Channel"} + now := time.Now() + + r.MarkDeleted("admin", now) + + Expect(r.DeletedTime).ToNot(BeNil()) + Expect(*r.DeletedTime).To(Equal(now)) + Expect(r.DeletedBy).ToNot(BeNil()) + Expect(*r.DeletedBy).To(Equal("admin")) +} + +func TestResource_IncrementGeneration(t *testing.T) { + RegisterTestingT(t) + + r := &Resource{Name: "test", Kind: "Channel", Generation: 1} + r.IncrementGeneration() + Expect(r.Generation).To(Equal(int32(2))) + + r.IncrementGeneration() + Expect(r.Generation).To(Equal(int32(3))) +} + +func TestResource_TableName(t *testing.T) { + RegisterTestingT(t) + + r := Resource{} + Expect(r.TableName()).To(Equal("resources")) +} diff --git a/pkg/dao/mocks/resource.go b/pkg/dao/mocks/resource.go new file mode 100644 index 00000000..20d4e839 --- /dev/null +++ b/pkg/dao/mocks/resource.go @@ -0,0 +1,108 @@ +package mocks + +import ( + "context" + + "gorm.io/gorm" + + "github.com/openshift-hyperfleet/hyperfleet-api/pkg/api" + "github.com/openshift-hyperfleet/hyperfleet-api/pkg/dao" +) + +var _ dao.ResourceDao = &resourceDaoMock{} + +type resourceDaoMock struct { + resources api.ResourceList +} + +func (d *resourceDaoMock) Get(_ context.Context, kind, id string) (*api.Resource, error) { + for _, r := range d.resources { + if r.ID == id && r.Kind == kind { + return r, nil + } + } + return nil, gorm.ErrRecordNotFound +} + +func (d *resourceDaoMock) GetForUpdate(ctx context.Context, kind, id string) (*api.Resource, error) { + return d.Get(ctx, kind, id) +} + +func (d *resourceDaoMock) GetByOwner(_ context.Context, kind, id, ownerID string) (*api.Resource, error) { + for _, r := range d.resources { + if r.ID == id && r.Kind == kind && r.OwnerID != nil && *r.OwnerID == ownerID { + return r, nil + } + } + return nil, gorm.ErrRecordNotFound +} + +func (d *resourceDaoMock) Create(_ context.Context, resource *api.Resource) (*api.Resource, error) { + d.resources = append(d.resources, resource) + return resource, nil +} + +func (d *resourceDaoMock) Save(_ context.Context, resource *api.Resource) error { + for i, r := range d.resources { + if r.ID == resource.ID { + d.resources[i] = resource + return nil + } + } + d.resources = append(d.resources, resource) + return nil +} + +func (d *resourceDaoMock) Delete(_ context.Context, kind, id string) error { + for i, r := range d.resources { + if r.ID == id && r.Kind == kind { + d.resources = append(d.resources[:i], d.resources[i+1:]...) + return nil + } + } + return nil +} + +func (d *resourceDaoMock) CountByOwner(_ context.Context, kind, ownerID string) (int64, error) { + var count int64 + for _, r := range d.resources { + if r.Kind == kind && r.OwnerID != nil && *r.OwnerID == ownerID { + count++ + } + } + return count, nil +} + +func (d *resourceDaoMock) FindByType(_ context.Context, kind string) (api.ResourceList, error) { + var result api.ResourceList + for _, r := range d.resources { + if r.Kind == kind { + result = append(result, r) + } + } + return result, nil +} + +func (d *resourceDaoMock) FindByTypeAndOwner(_ context.Context, kind, ownerID string) (api.ResourceList, error) { + var result api.ResourceList + for _, r := range d.resources { + if r.Kind == kind && r.OwnerID != nil && *r.OwnerID == ownerID { + result = append(result, r) + } + } + return result, nil +} + +func (d *resourceDaoMock) FindByIDs(_ context.Context, kind string, ids []string) (api.ResourceList, error) { + idSet := make(map[string]bool, len(ids)) + for _, id := range ids { + idSet[id] = true + } + var result api.ResourceList + for _, r := range d.resources { + if r.Kind == kind && idSet[r.ID] { + result = append(result, r) + } + } + return result, nil +} diff --git a/pkg/dao/resource.go b/pkg/dao/resource.go new file mode 100644 index 00000000..7ffea32b --- /dev/null +++ b/pkg/dao/resource.go @@ -0,0 +1,137 @@ +package dao + +import ( + "context" + "fmt" + + "gorm.io/gorm/clause" + + "github.com/openshift-hyperfleet/hyperfleet-api/pkg/api" + "github.com/openshift-hyperfleet/hyperfleet-api/pkg/db" +) + +type ResourceDao interface { + Get(ctx context.Context, kind, id string) (*api.Resource, error) + GetForUpdate(ctx context.Context, kind, id string) (*api.Resource, error) + GetByOwner(ctx context.Context, kind, id, ownerID string) (*api.Resource, error) + Create(ctx context.Context, resource *api.Resource) (*api.Resource, error) + Save(ctx context.Context, resource *api.Resource) error + Delete(ctx context.Context, kind, id string) error + CountByOwner(ctx context.Context, kind, ownerID string) (int64, error) + FindByType(ctx context.Context, kind string) (api.ResourceList, error) + FindByTypeAndOwner(ctx context.Context, kind, ownerID string) (api.ResourceList, error) + FindByIDs(ctx context.Context, kind string, ids []string) (api.ResourceList, error) +} + +var _ ResourceDao = &sqlResourceDao{} + +type sqlResourceDao struct { + sessionFactory db.SessionFactory +} + +func NewResourceDao(sessionFactory db.SessionFactory) ResourceDao { + return &sqlResourceDao{sessionFactory: sessionFactory} +} + +func (d *sqlResourceDao) Get(ctx context.Context, kind, id string) (*api.Resource, error) { + g2 := d.sessionFactory.New(ctx) + var resource api.Resource + if err := g2.Take(&resource, "kind = ? AND id = ?", kind, id).Error; err != nil { + return nil, err + } + return &resource, nil +} + +func (d *sqlResourceDao) GetForUpdate(ctx context.Context, kind, id string) (*api.Resource, error) { + g2 := d.sessionFactory.New(ctx) + var resource api.Resource + if err := g2.Clauses(clause.Locking{Strength: "UPDATE"}).Take( + &resource, "kind = ? AND id = ?", kind, id).Error; err != nil { + return nil, err + } + return &resource, nil +} + +func (d *sqlResourceDao) GetByOwner(ctx context.Context, kind, id, ownerID string) (*api.Resource, error) { + g2 := d.sessionFactory.New(ctx) + var resource api.Resource + if err := g2.Take(&resource, "kind = ? AND id = ? AND owner_id = ?", kind, id, ownerID).Error; err != nil { + return nil, err + } + return &resource, nil +} + +func (d *sqlResourceDao) Create(ctx context.Context, resource *api.Resource) (*api.Resource, error) { + if resource.OwnerID != nil { + // If OwnerID is empty, convert to nil + if *resource.OwnerID == "" { + resource.OwnerID = nil + resource.OwnerKind = nil + resource.OwnerHref = nil + } else if resource.OwnerKind == nil || *resource.OwnerKind == "" { + return nil, fmt.Errorf("owner_kind is required when owner_id is set") + } + } + g2 := d.sessionFactory.New(ctx) + if err := g2.Omit(clause.Associations).Create(resource).Error; err != nil { + db.MarkForRollback(ctx, err) + return nil, err + } + return resource, nil +} + +func (d *sqlResourceDao) Save(ctx context.Context, resource *api.Resource) error { + g2 := d.sessionFactory.New(ctx) + if err := g2.Omit(clause.Associations).Save(resource).Error; err != nil { + db.MarkForRollback(ctx, err) + return err + } + return nil +} + +func (d *sqlResourceDao) Delete(ctx context.Context, kind, id string) error { + g2 := d.sessionFactory.New(ctx) + if err := g2.Omit(clause.Associations).Where("kind = ?", kind).Delete( + &api.Resource{Meta: api.Meta{ID: id}}).Error; err != nil { + db.MarkForRollback(ctx, err) + return err + } + return nil +} + +func (d *sqlResourceDao) CountByOwner(ctx context.Context, kind, ownerID string) (int64, error) { + g2 := d.sessionFactory.New(ctx) + var count int64 + if err := g2.Model(&api.Resource{}).Where( + "kind = ? AND owner_id = ?", kind, ownerID).Count(&count).Error; err != nil { + return 0, err + } + return count, nil +} + +func (d *sqlResourceDao) FindByType(ctx context.Context, kind string) (api.ResourceList, error) { + g2 := d.sessionFactory.New(ctx) + var resources api.ResourceList + if err := g2.Where("kind = ?", kind).Find(&resources).Error; err != nil { + return nil, err + } + return resources, nil +} + +func (d *sqlResourceDao) FindByTypeAndOwner(ctx context.Context, kind, ownerID string) (api.ResourceList, error) { + g2 := d.sessionFactory.New(ctx) + var resources api.ResourceList + if err := g2.Where("kind = ? AND owner_id = ?", kind, ownerID).Find(&resources).Error; err != nil { + return nil, err + } + return resources, nil +} + +func (d *sqlResourceDao) FindByIDs(ctx context.Context, kind string, ids []string) (api.ResourceList, error) { + g2 := d.sessionFactory.New(ctx) + var resources api.ResourceList + if err := g2.Where("kind = ? AND id in (?)", kind, ids).Find(&resources).Error; err != nil { + return nil, err + } + return resources, nil +} diff --git a/pkg/db/migrations/202605202128_add_resources.go b/pkg/db/migrations/202605202128_add_resources.go new file mode 100644 index 00000000..9719735e --- /dev/null +++ b/pkg/db/migrations/202605202128_add_resources.go @@ -0,0 +1,72 @@ +package migrations + +import ( + "github.com/go-gormigrate/gormigrate/v2" + "gorm.io/gorm" +) + +func addResources() *gormigrate.Migration { + return &gormigrate.Migration{ + ID: "202605202128", + Migrate: func(tx *gorm.DB) error { + if err := tx.Exec(`CREATE TABLE IF NOT EXISTS resources ( + id VARCHAR(255) PRIMARY KEY, + created_time TIMESTAMPTZ NOT NULL DEFAULT NOW(), + updated_time TIMESTAMPTZ NOT NULL DEFAULT NOW(), + deleted_time TIMESTAMPTZ NULL, + kind VARCHAR(100) NOT NULL, + name VARCHAR(100) NOT NULL, + href VARCHAR(500), + created_by VARCHAR(255) NOT NULL, + updated_by VARCHAR(255) NOT NULL, + deleted_by VARCHAR(255) NULL, + owner_id VARCHAR(255) NULL, + owner_kind VARCHAR(100) NULL, + owner_href VARCHAR(500) NULL, + spec JSONB NOT NULL, + labels JSONB NULL, + generation INTEGER NOT NULL DEFAULT 1 + );`).Error; err != nil { + return err + } + + if err := tx.Exec( + "CREATE INDEX IF NOT EXISTS idx_resources_kind ON resources (kind);", + ).Error; err != nil { + return err + } + + if err := tx.Exec( + "CREATE UNIQUE INDEX IF NOT EXISTS idx_resources_kind_name " + + "ON resources (kind, name) " + + "WHERE owner_id IS NULL AND deleted_time IS NULL;", + ).Error; err != nil { + return err + } + + if err := tx.Exec( + "CREATE UNIQUE INDEX IF NOT EXISTS idx_resources_kind_owner_name " + + "ON resources (kind, owner_id, name) " + + "WHERE owner_id IS NOT NULL AND deleted_time IS NULL;", + ).Error; err != nil { + return err + } + + if err := tx.Exec( + "CREATE INDEX IF NOT EXISTS idx_resources_owner_id " + + "ON resources (owner_id) WHERE owner_id IS NOT NULL;", + ).Error; err != nil { + return err + } + + if err := tx.Exec( + "CREATE INDEX IF NOT EXISTS idx_resources_deleted_time " + + "ON resources (deleted_time) WHERE deleted_time IS NOT NULL;", + ).Error; err != nil { + return err + } + + return nil + }, + } +} diff --git a/pkg/db/migrations/migration_structs.go b/pkg/db/migrations/migration_structs.go index ff503558..c6a6c3f6 100755 --- a/pkg/db/migrations/migration_structs.go +++ b/pkg/db/migrations/migration_structs.go @@ -36,6 +36,7 @@ var MigrationList = []*gormigrate.Migration{ addReconciledIndex(), addNodePoolOwnerDeletedIndex(), addDeletedTimeIndexes(), + addResources(), } // Model represents the base model struct. All entities will have this struct embedded. diff --git a/pkg/db/sql_helpers.go b/pkg/db/sql_helpers.go index 9fad5ab6..6cc492e9 100755 --- a/pkg/db/sql_helpers.go +++ b/pkg/db/sql_helpers.go @@ -73,6 +73,24 @@ func getField(name string, disallowedFields map[string]string) (field string, er return } + // Map user-friendly spec.xxx syntax to JSONB query: spec->>'xxx' + if strings.HasPrefix(trimmedName, "spec.") { + if _, disallowed := disallowedFields["spec"]; disallowed { + err = errors.BadRequest("%s is not a valid field name", name) + return + } + + key := strings.TrimPrefix(trimmedName, "spec.") + + if validationErr := validateLabelKey(key); validationErr != nil { + err = validationErr + return + } + + field = fmt.Sprintf("spec->>'%s'", key) + return + } + // Map user-friendly labels.xxx syntax to JSONB query: labels->>'xxx' if strings.HasPrefix(trimmedName, "labels.") { key := strings.TrimPrefix(trimmedName, "labels.") diff --git a/pkg/db/sql_helpers_test.go b/pkg/db/sql_helpers_test.go index 9f96970e..a760ea61 100644 --- a/pkg/db/sql_helpers_test.go +++ b/pkg/db/sql_helpers_test.go @@ -652,6 +652,75 @@ func TestConditionTypeValidation(t *testing.T) { } } +func TestGetField_SpecMapping(t *testing.T) { + tests := []struct { + name string + input string + expected string + expectError bool + }{ + { + name: "valid snake_case key", + input: "spec.is_default", + expected: "spec->>'is_default'", + }, + { + name: "valid single word key", + input: "spec.region", + expected: "spec->>'region'", + }, + { + name: "valid key with digits", + input: "spec.release_image_v2", + expected: "spec->>'release_image_v2'", + }, + { + name: "invalid key with uppercase", + input: "spec.ReleaseImage", + expectError: true, + }, + { + name: "invalid key with hyphens", + input: "spec.release-image", + expectError: true, + }, + { + name: "empty key", + input: "spec.", + expectError: true, + }, + { + name: "injection attempt", + input: "spec.'; DROP TABLE resources;--", + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + RegisterTestingT(t) + + field, err := getField(tt.input, map[string]string{}) + if tt.expectError { + Expect(err).ToNot(BeNil()) + } else { + Expect(err).To(BeNil()) + Expect(field).To(Equal(tt.expected)) + } + }) + } +} + +func TestGetField_SpecDisallowed(t *testing.T) { + RegisterTestingT(t) + + disallowed := map[string]string{"spec": "spec"} + + _, err := getField("spec.is_default", disallowed) + Expect(err).ToNot(BeNil()) + Expect(err.Reason).To(ContainSubstring("not a valid field name")) +} + func TestConditionStatusValidation(t *testing.T) { tests := []struct { status string diff --git a/pkg/registry/descriptor.go b/pkg/registry/descriptor.go new file mode 100644 index 00000000..fbbaeac1 --- /dev/null +++ b/pkg/registry/descriptor.go @@ -0,0 +1,22 @@ +package registry + +// OnParentDeletePolicy determines child behavior when its parent is deleted. +type OnParentDeletePolicy string + +const ( + OnParentDeleteRestrict OnParentDeletePolicy = "restrict" + OnParentDeleteCascade OnParentDeletePolicy = "cascade" +) + +// EntityDescriptor defines everything specific to a HyperFleet entity type. +// Descriptors are registered at startup via Register() in plugin init() functions. +type EntityDescriptor struct { + Kind string // discriminator value stored in Resource.Kind + Plural string // URL path segment, e.g. "channels" + ParentKind string // "" for top-level entities + SpecSchemaName string // OpenAPI component name for spec validation + OnParentDelete OnParentDeletePolicy // only meaningful when ParentKind != "" + SearchDisallowedFields []string // fields blocked from TSL search + NameMinLen int // minimum name length + NameMaxLen int // maximum name length +} diff --git a/pkg/registry/registry.go b/pkg/registry/registry.go new file mode 100644 index 00000000..68fe8b4b --- /dev/null +++ b/pkg/registry/registry.go @@ -0,0 +1,122 @@ +package registry + +import ( + "fmt" + "sync" +) + +var ( + mu sync.RWMutex + descriptors = make(map[string]EntityDescriptor) +) + +// Register adds a descriptor to the global registry. Panics on empty Kind or duplicate Kind. +func Register(d EntityDescriptor) { + mu.Lock() + defer mu.Unlock() + + if d.Kind == "" { + panic("entity kind cannot be empty") + } + if _, exists := descriptors[d.Kind]; exists { + panic(fmt.Sprintf("entity kind %q already registered", d.Kind)) + } + descriptors[d.Kind] = d +} + +// Get returns a descriptor by Kind, or (zero, false) if not found. +func Get(entityKind string) (EntityDescriptor, bool) { + mu.RLock() + defer mu.RUnlock() + + d, ok := descriptors[entityKind] + return d, ok +} + +// MustGet returns a descriptor by Kind. Panics if not found. +func MustGet(entityKind string) EntityDescriptor { + d, ok := Get(entityKind) + if !ok { + panic(fmt.Sprintf("entity kind %q not registered", entityKind)) + } + return d +} + +// All returns a snapshot of all registered descriptors. +func All() []EntityDescriptor { + mu.RLock() + defer mu.RUnlock() + + result := make([]EntityDescriptor, 0, len(descriptors)) + for _, d := range descriptors { + result = append(result, d) + } + return result +} + +// ChildrenOf returns descriptors whose ParentKind matches the given kind. +func ChildrenOf(parentKind string) []EntityDescriptor { + mu.RLock() + defer mu.RUnlock() + + var children []EntityDescriptor + for _, d := range descriptors { + if d.ParentKind == parentKind { + children = append(children, d) + } + } + return children +} + +// Validate checks registry integrity. Panics on: +// - empty Kind or Plural on any descriptor +// - any ParentKind that references an unregistered kind +// - NameMinLen > NameMaxLen (when NameMaxLen is set) +// - duplicate Plural values across descriptors +func Validate() { + mu.RLock() + defer mu.RUnlock() + + plurals := make(map[string]string, len(descriptors)) + + for _, d := range descriptors { + if d.Kind == "" { + panic("entity kind cannot be empty") + } + if d.Plural == "" { + panic(fmt.Sprintf("entity kind %q has empty plural", d.Kind)) + } + + if d.ParentKind != "" { + if _, ok := descriptors[d.ParentKind]; !ok { + panic(fmt.Sprintf( + "entity kind %q references unregistered parent kind %q", + d.Kind, d.ParentKind, + )) + } + } + + if d.NameMaxLen > 0 && d.NameMinLen > d.NameMaxLen { + panic(fmt.Sprintf( + "entity kind %q has NameMinLen (%d) > NameMaxLen (%d)", + d.Kind, d.NameMinLen, d.NameMaxLen, + )) + } + + if existing, ok := plurals[d.Plural]; ok { + panic(fmt.Sprintf( + "duplicate plural %q: registered by both %q and %q", + d.Plural, existing, d.Kind, + )) + } + plurals[d.Plural] = d.Kind + } +} + +// Reset clears all registrations. Only for use in tests. +func Reset() { + mu.Lock() + defer mu.Unlock() + + descriptors = make(map[string]EntityDescriptor) +} diff --git a/pkg/registry/registry_test.go b/pkg/registry/registry_test.go new file mode 100644 index 00000000..9e94c9b7 --- /dev/null +++ b/pkg/registry/registry_test.go @@ -0,0 +1,197 @@ +package registry + +import ( + "testing" + + . "github.com/onsi/gomega" +) + +func TestRegister_Success(t *testing.T) { + RegisterTestingT(t) + Reset() + + d := EntityDescriptor{ + Kind: "Channel", + Plural: "channels", + NameMinLen: 3, + NameMaxLen: 53, + } + + Register(d) + + got, ok := Get("Channel") + Expect(ok).To(BeTrue()) + Expect(got.Kind).To(Equal("Channel")) + Expect(got.Plural).To(Equal("channels")) + Expect(got.NameMinLen).To(Equal(3)) + Expect(got.NameMaxLen).To(Equal(53)) +} + +func TestRegister_DuplicateKind_Panics(t *testing.T) { + RegisterTestingT(t) + Reset() + + Register(EntityDescriptor{Kind: "Channel", Plural: "channels"}) + + Expect(func() { + Register(EntityDescriptor{Kind: "Channel", Plural: "ch"}) + }).To(PanicWith(ContainSubstring("already registered"))) +} + +func TestRegister_EmptyKind_Panics(t *testing.T) { + RegisterTestingT(t) + Reset() + + Expect(func() { + Register(EntityDescriptor{Kind: "", Plural: "things"}) + }).To(PanicWith(ContainSubstring("entity kind cannot be empty"))) +} + +func TestGet_NotFound(t *testing.T) { + RegisterTestingT(t) + Reset() + + _, ok := Get("NonExistent") + Expect(ok).To(BeFalse()) +} + +func TestMustGet_Success(t *testing.T) { + RegisterTestingT(t) + Reset() + + Register(EntityDescriptor{Kind: "Channel", Plural: "channels"}) + + d := MustGet("Channel") + Expect(d.Kind).To(Equal("Channel")) +} + +func TestMustGet_NotFound_Panics(t *testing.T) { + RegisterTestingT(t) + Reset() + + Expect(func() { + MustGet("NonExistent") + }).To(PanicWith(ContainSubstring("not registered"))) +} + +func TestAll_ReturnsSnapshot(t *testing.T) { + RegisterTestingT(t) + Reset() + + Register(EntityDescriptor{Kind: "Channel", Plural: "channels"}) + Register(EntityDescriptor{Kind: "Version", Plural: "versions", ParentKind: "Channel"}) + + all := All() + Expect(all).To(HaveLen(2)) + + types := make(map[string]bool) + for _, d := range all { + types[d.Kind] = true + } + Expect(types).To(HaveKey("Channel")) + Expect(types).To(HaveKey("Version")) +} + +func TestChildrenOf(t *testing.T) { + RegisterTestingT(t) + Reset() + + Register(EntityDescriptor{Kind: "Channel", Plural: "channels"}) + Register(EntityDescriptor{Kind: "Version", Plural: "versions", ParentKind: "Channel"}) + Register(EntityDescriptor{Kind: "WifConfig", Plural: "wifconfigs"}) + + children := ChildrenOf("Channel") + Expect(children).To(HaveLen(1)) + Expect(children[0].Kind).To(Equal("Version")) + + children = ChildrenOf("WifConfig") + Expect(children).To(BeEmpty()) +} + +func TestValidate_MissingParent_Panics(t *testing.T) { + RegisterTestingT(t) + Reset() + + Register(EntityDescriptor{Kind: "Version", Plural: "versions", ParentKind: "Ghost"}) + + Expect(func() { + Validate() + }).To(PanicWith(ContainSubstring("unregistered parent kind"))) +} + +func TestValidate_DuplicatePlural_Panics(t *testing.T) { + RegisterTestingT(t) + Reset() + + Register(EntityDescriptor{Kind: "Channel", Plural: "things"}) + Register(EntityDescriptor{Kind: "Version", Plural: "things"}) + + Expect(func() { + Validate() + }).To(PanicWith(ContainSubstring("duplicate plural"))) +} + +func TestValidate_EmptyPlural_Panics(t *testing.T) { + RegisterTestingT(t) + Reset() + + Register(EntityDescriptor{Kind: "Channel", Plural: ""}) + + Expect(func() { + Validate() + }).To(PanicWith(ContainSubstring("has empty plural"))) +} + +func TestValidate_NameMinExceedsMax_Panics(t *testing.T) { + RegisterTestingT(t) + Reset() + + Register(EntityDescriptor{Kind: "Channel", Plural: "channels", NameMinLen: 100, NameMaxLen: 3}) + + Expect(func() { + Validate() + }).To(PanicWith(ContainSubstring("NameMinLen (100) > NameMaxLen (3)"))) +} + +func TestValidate_Success(t *testing.T) { + RegisterTestingT(t) + Reset() + + Register(EntityDescriptor{Kind: "Channel", Plural: "channels"}) + Register(EntityDescriptor{Kind: "Version", Plural: "versions", ParentKind: "Channel"}) + + Expect(func() { + Validate() + }).ToNot(Panic()) +} + +func TestValidate_EmptyRegistry(t *testing.T) { + RegisterTestingT(t) + Reset() + + Expect(func() { + Validate() + }).ToNot(Panic()) +} + +func TestDescriptorFields(t *testing.T) { + RegisterTestingT(t) + Reset() + + Register(EntityDescriptor{ + Kind: "Version", + Plural: "versions", + NameMinLen: 3, + NameMaxLen: 53, + ParentKind: "Channel", + OnParentDelete: OnParentDeleteRestrict, + SpecSchemaName: "VersionSpec", + SearchDisallowedFields: []string{"spec"}, + }) + + d := MustGet("Version") + Expect(d.ParentKind).To(Equal("Channel")) + Expect(d.OnParentDelete).To(Equal(OnParentDeleteRestrict)) + Expect(d.SpecSchemaName).To(Equal("VersionSpec")) + Expect(d.SearchDisallowedFields).To(ConsistOf("spec")) +}