|
1 | 1 | package data |
2 | 2 |
|
3 | 3 | import ( |
| 4 | + "fmt" |
| 5 | + "time" |
| 6 | + |
4 | 7 | "gorm.io/gorm" |
5 | 8 |
|
| 9 | + "github.com/infrahq/infra/internal/server/data/querybuilder" |
6 | 10 | "github.com/infrahq/infra/internal/server/models" |
7 | 11 | "github.com/infrahq/infra/uid" |
8 | 12 | ) |
9 | 13 |
|
10 | | -func CreateGroup(db GormTxn, group *models.Group) error { |
11 | | - return add(db, group) |
| 14 | +type groupsTable models.Group |
| 15 | + |
| 16 | +func (g groupsTable) Table() string { |
| 17 | + return "groups" |
| 18 | +} |
| 19 | + |
| 20 | +func (g groupsTable) Columns() []string { |
| 21 | + return []string{"created_at", "created_by", "created_by_provider", "deleted_at", "id", "name", "organization_id", "updated_at"} |
| 22 | +} |
| 23 | + |
| 24 | +func (g groupsTable) Values() []any { |
| 25 | + return []any{g.CreatedAt, g.CreatedBy, g.CreatedByProvider, g.DeletedAt, g.ID, g.Name, g.OrganizationID, g.UpdatedAt} |
| 26 | +} |
| 27 | + |
| 28 | +func (g *groupsTable) ScanFields() []any { |
| 29 | + return []any{&g.CreatedAt, &g.CreatedBy, &g.CreatedByProvider, &g.DeletedAt, &g.ID, &g.Name, &g.OrganizationID, &g.UpdatedAt} |
| 30 | +} |
| 31 | + |
| 32 | +func CreateGroup(tx WriteTxn, group *models.Group) error { |
| 33 | + return insert(tx, (*groupsTable)(group)) |
12 | 34 | } |
13 | 35 |
|
14 | 36 | func GetGroup(db GormTxn, selectors ...SelectorFunc) (*models.Group, error) { |
@@ -71,59 +93,49 @@ func groupIDsForUser(tx ReadTxn, userID uid.ID) ([]uid.ID, error) { |
71 | 93 | return result, rows.Err() |
72 | 94 | } |
73 | 95 |
|
74 | | -func DeleteGroups(db GormTxn, selectors ...SelectorFunc) error { |
75 | | - toDelete, err := ListGroups(db, nil, selectors...) |
| 96 | +func DeleteGroup(tx WriteTxn, id uid.ID) error { |
| 97 | + err := DeleteGrants(tx, DeleteGrantsOptions{BySubject: uid.NewGroupPolymorphicID(id)}) |
76 | 98 | if err != nil { |
77 | | - return err |
| 99 | + return fmt.Errorf("remove grants: %w", err) |
78 | 100 | } |
79 | 101 |
|
80 | | - ids := make([]uid.ID, 0) |
81 | | - for _, g := range toDelete { |
82 | | - ids = append(ids, g.ID) |
83 | | - |
84 | | - err := DeleteGrants(db, DeleteGrantsOptions{BySubject: g.PolyID()}) |
85 | | - if err != nil { |
86 | | - return err |
87 | | - } |
88 | | - |
89 | | - identities, err := ListIdentities(db, nil, []SelectorFunc{ByOptionalIdentityGroupID(g.ID)}...) |
90 | | - if err != nil { |
91 | | - return err |
92 | | - } |
93 | | - |
94 | | - var uidsToRemove []uid.ID |
95 | | - for _, id := range identities { |
96 | | - uidsToRemove = append(uidsToRemove, id.ID) |
97 | | - } |
98 | | - err = RemoveUsersFromGroup(db, g.ID, uidsToRemove) |
99 | | - if err != nil { |
100 | | - return err |
101 | | - } |
| 102 | + _, err = tx.Exec(`DELETE from identities_groups WHERE group_id = ?`, id) |
| 103 | + if err != nil { |
| 104 | + return fmt.Errorf("remove useres from group: %w", err) |
102 | 105 | } |
103 | 106 |
|
104 | | - return deleteAll[models.Group](db, ByIDs(ids)) |
| 107 | + stmt := ` |
| 108 | + UPDATE groups |
| 109 | + SET deleted_at = ? |
| 110 | + WHERE id = ? |
| 111 | + AND deleted_at is null |
| 112 | + AND organization_id = ?` |
| 113 | + _, err = tx.Exec(stmt, time.Now(), id, tx.OrganizationID()) |
| 114 | + return handleError(err) |
105 | 115 | } |
106 | 116 |
|
107 | | -func AddUsersToGroup(db GormTxn, groupID uid.ID, idsToAdd []uid.ID) error { |
108 | | - for _, id := range idsToAdd { |
109 | | - // This is effectively an "INSERT OR IGNORE" or "INSERT ... ON CONFLICT ... DO NOTHING" statement which |
110 | | - // works across both sqlite and postgres |
111 | | - _, err := db.Exec("INSERT INTO identities_groups (group_id, identity_id) SELECT ?, ? WHERE NOT EXISTS (SELECT 1 FROM identities_groups WHERE group_id = ? AND identity_id = ?)", groupID, id, groupID, id) |
112 | | - if err != nil { |
113 | | - return err |
| 117 | +func AddUsersToGroup(tx WriteTxn, groupID uid.ID, idsToAdd []uid.ID) error { |
| 118 | + query := querybuilder.New("INSERT INTO identities_groups(group_id, identity_id)") |
| 119 | + query.B("VALUES") |
| 120 | + for i, id := range idsToAdd { |
| 121 | + query.B("(?, ?)", groupID, id) |
| 122 | + if i+1 != len(idsToAdd) { |
| 123 | + query.B(",") |
114 | 124 | } |
115 | 125 | } |
116 | | - return nil |
| 126 | + query.B("ON CONFLICT DO NOTHING") |
| 127 | + |
| 128 | + _, err := tx.Exec(query.String(), query.Args...) |
| 129 | + return handleError(err) |
117 | 130 | } |
118 | 131 |
|
119 | | -func RemoveUsersFromGroup(db GormTxn, groupID uid.ID, idsToRemove []uid.ID) error { |
120 | | - for _, id := range idsToRemove { |
121 | | - _, err := db.Exec("DELETE FROM identities_groups WHERE identity_id = ? AND group_id = ?", id, groupID) |
122 | | - if err != nil { |
123 | | - return err |
124 | | - } |
125 | | - } |
126 | | - return nil |
| 132 | +// RemoveUsersFromGroup removes any user ID listed in idsToRemove from the group |
| 133 | +// with ID groupID. |
| 134 | +// Note that DeleteGroup also removes users from the group. |
| 135 | +func RemoveUsersFromGroup(tx WriteTxn, groupID uid.ID, idsToRemove []uid.ID) error { |
| 136 | + stmt := `DELETE FROM identities_groups WHERE group_id = ? AND identity_id IN (?)` |
| 137 | + _, err := tx.Exec(stmt, groupID, idsToRemove) |
| 138 | + return handleError(err) |
127 | 139 | } |
128 | 140 |
|
129 | 141 | // TODO: do this with a join in ListGroups and GetGroup |
|
0 commit comments