Skip to content

Commit 845b884

Browse files
torsmcopybara-github
authored andcommitted
Extract transaction logic in DisableBroadcasts to a helper function
PiperOrigin-RevId: 903820050
1 parent e02ffe4 commit 845b884

3 files changed

Lines changed: 138 additions & 126 deletions

File tree

fleetspeak/src/server/mysql/broadcaststore.go

Lines changed: 42 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -163,52 +163,56 @@ func (d *Datastore) DisableBroadcasts(ctx context.Context, bIDs []ids.BroadcastI
163163
return nil
164164
}
165165
return d.runInTx(ctx, false, func(tx *sql.Tx) error {
166-
for _, bID := range bIDs {
167-
var totalSent, totalLimit uint64
168-
rs, err := tx.QueryContext(ctx, "SELECT sent, message_limit FROM broadcast_allocations WHERE broadcast_id = ?", bID.Bytes())
169-
if err != nil {
170-
return err
171-
}
172-
for rs.Next() {
173-
var sent, limit uint64
174-
if err := rs.Scan(&sent, &limit); err != nil {
175-
rs.Close()
176-
return err
177-
}
178-
totalSent += sent
179-
totalLimit += limit
180-
}
181-
rs.Close()
182-
if err := rs.Err(); err != nil {
183-
return err
184-
}
166+
return d.tryDisableBroadcasts(ctx, tx, bIDs)
167+
})
168+
}
185169

186-
if _, err := tx.ExecContext(ctx, "DELETE FROM broadcast_allocations WHERE broadcast_id = ?", bID.Bytes()); err != nil {
170+
func (d *Datastore) tryDisableBroadcasts(ctx context.Context, tx *sql.Tx, bIDs []ids.BroadcastID) error {
171+
for _, bID := range bIDs {
172+
var totalSent, totalLimit uint64
173+
rs, err := tx.QueryContext(ctx, "SELECT sent, message_limit FROM broadcast_allocations WHERE broadcast_id = ?", bID.Bytes())
174+
if err != nil {
175+
return err
176+
}
177+
for rs.Next() {
178+
var sent, limit uint64
179+
if err := rs.Scan(&sent, &limit); err != nil {
180+
rs.Close()
187181
return err
188182
}
183+
totalSent += sent
184+
totalLimit += limit
185+
}
186+
rs.Close()
187+
if err := rs.Err(); err != nil {
188+
return err
189+
}
189190

190-
var bSent, bAllocated uint64
191-
row := tx.QueryRowContext(ctx, "SELECT sent, allocated FROM broadcasts WHERE broadcast_id = ?", bID.Bytes())
192-
if err := row.Scan(&bSent, &bAllocated); err != nil {
193-
if err == sql.ErrNoRows {
194-
continue
195-
}
196-
return err
197-
}
191+
if _, err := tx.ExecContext(ctx, "DELETE FROM broadcast_allocations WHERE broadcast_id = ?", bID.Bytes()); err != nil {
192+
return err
193+
}
198194

199-
newAllocated := bAllocated
200-
if newAllocated >= totalLimit {
201-
newAllocated -= totalLimit
202-
} else {
203-
newAllocated = 0
195+
var bSent, bAllocated uint64
196+
row := tx.QueryRowContext(ctx, "SELECT sent, allocated FROM broadcasts WHERE broadcast_id = ?", bID.Bytes())
197+
if err := row.Scan(&bSent, &bAllocated); err != nil {
198+
if err == sql.ErrNoRows {
199+
continue
204200
}
201+
return err
202+
}
205203

206-
if _, err := tx.ExecContext(ctx, "UPDATE broadcasts SET message_limit = 0, sent = ?, allocated = ? WHERE broadcast_id = ?", bSent+totalSent, newAllocated, bID.Bytes()); err != nil {
207-
return err
208-
}
204+
newAllocated := bAllocated
205+
if newAllocated >= totalLimit {
206+
newAllocated -= totalLimit
207+
} else {
208+
newAllocated = 0
209209
}
210-
return nil
211-
})
210+
211+
if _, err := tx.ExecContext(ctx, "UPDATE broadcasts SET message_limit = 0, sent = ?, allocated = ? WHERE broadcast_id = ?", bSent+totalSent, newAllocated, bID.Bytes()); err != nil {
212+
return err
213+
}
214+
}
215+
return nil
212216
}
213217

214218
// SaveBroadcastMessage implements db.BroadcastStore.

fleetspeak/src/server/spanner/broadcaststore.go

Lines changed: 54 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -81,67 +81,71 @@ func (d *Datastore) DisableBroadcasts(ctx context.Context, bIDs []ids.BroadcastI
8181
return nil
8282
}
8383
_, err := d.dbClient.ReadWriteTransaction(ctx, func(ctx context.Context, txn *spanner.ReadWriteTransaction) error {
84-
for _, bID := range bIDs {
85-
var totalSent, totalLimit int64
86-
var allocKeys []spanner.Key
87-
88-
// Read active allocations to sum up their sent/limit stats before deleting
89-
// them. This allows us to attribute any messages sent during this window
90-
// to the parent broadcast's count and maintain the b.Allocated budget invariant.
91-
iter := txn.Read(ctx, d.broadcastAllocations, spanner.Key{bID.Bytes()}.AsPrefix(), []string{"AllocationID", "Sent", "MessageLimit"})
92-
for {
93-
row, err := iter.Next()
94-
if err == iterator.Done {
95-
break
96-
}
97-
if err != nil {
98-
iter.Stop()
99-
return err
100-
}
101-
var aid []byte
102-
var sent, limit int64
103-
if err := row.Columns(&aid, &sent, &limit); err != nil {
104-
iter.Stop()
105-
return err
106-
}
107-
totalSent += sent
108-
totalLimit += limit
109-
allocKeys = append(allocKeys, spanner.Key{bID.Bytes(), aid})
110-
}
111-
iter.Stop()
84+
return d.tryDisableBroadcasts(ctx, txn, bIDs)
85+
})
86+
return err
87+
}
11288

113-
row, err := txn.ReadRow(ctx, d.broadcasts, spanner.Key{bID.Bytes()}, []string{"Sent", "Allocated"})
89+
func (d *Datastore) tryDisableBroadcasts(ctx context.Context, txn *spanner.ReadWriteTransaction, bIDs []ids.BroadcastID) error {
90+
for _, bID := range bIDs {
91+
var totalSent, totalLimit int64
92+
var allocKeys []spanner.Key
93+
94+
// Read active allocations to sum up their sent/limit stats before deleting
95+
// them. This allows us to attribute any messages sent during this window
96+
// to the parent broadcast's count and maintain the b.Allocated budget invariant.
97+
iter := txn.Read(ctx, d.broadcastAllocations, spanner.Key{bID.Bytes()}.AsPrefix(), []string{"AllocationID", "Sent", "MessageLimit"})
98+
for {
99+
row, err := iter.Next()
100+
if err == iterator.Done {
101+
break
102+
}
114103
if err != nil {
115-
if spanner.ErrCode(err) == codes.NotFound {
116-
continue
117-
}
104+
iter.Stop()
118105
return err
119106
}
120-
var bSent, allocated int64
121-
if err := row.Columns(&bSent, &allocated); err != nil {
107+
var aid []byte
108+
var sent, limit int64
109+
if err := row.Columns(&aid, &sent, &limit); err != nil {
110+
iter.Stop()
122111
return err
123112
}
113+
totalSent += sent
114+
totalLimit += limit
115+
allocKeys = append(allocKeys, spanner.Key{bID.Bytes(), aid})
116+
}
117+
iter.Stop()
124118

125-
newAllocated := allocated
126-
if newAllocated >= totalLimit {
127-
newAllocated -= totalLimit
128-
} else {
129-
newAllocated = 0
119+
row, err := txn.ReadRow(ctx, d.broadcasts, spanner.Key{bID.Bytes()}, []string{"Sent", "Allocated"})
120+
if err != nil {
121+
if spanner.ErrCode(err) == codes.NotFound {
122+
continue
130123
}
124+
return err
125+
}
126+
var bSent, allocated int64
127+
if err := row.Columns(&bSent, &allocated); err != nil {
128+
return err
129+
}
131130

132-
var ms []*spanner.Mutation
133-
if len(allocKeys) > 0 {
134-
ms = append(ms, spanner.Delete(d.broadcastAllocations, spanner.KeySetFromKeys(allocKeys...)))
135-
}
136-
ms = append(ms, spanner.Update(d.broadcasts, []string{"BroadcastID", "MessageLimit", "Sent", "Allocated"}, []any{bID.Bytes(), int64(0), bSent + totalSent, newAllocated}))
131+
newAllocated := allocated
132+
if newAllocated >= totalLimit {
133+
newAllocated -= totalLimit
134+
} else {
135+
newAllocated = 0
136+
}
137137

138-
if err := txn.BufferWrite(ms); err != nil {
139-
return err
140-
}
138+
var ms []*spanner.Mutation
139+
if len(allocKeys) > 0 {
140+
ms = append(ms, spanner.Delete(d.broadcastAllocations, spanner.KeySetFromKeys(allocKeys...)))
141141
}
142-
return nil
143-
})
144-
return err
142+
ms = append(ms, spanner.Update(d.broadcasts, []string{"BroadcastID", "MessageLimit", "Sent", "Allocated"}, []any{bID.Bytes(), int64(0), bSent + totalSent, newAllocated}))
143+
144+
if err := txn.BufferWrite(ms); err != nil {
145+
return err
146+
}
147+
}
148+
return nil
145149
}
146150

147151
// SaveBroadcastMessage implements db.BroadcastStore.

fleetspeak/src/server/sqlite/broadcaststore.go

Lines changed: 42 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -168,52 +168,56 @@ func (d *Datastore) DisableBroadcasts(ctx context.Context, bIDs []ids.BroadcastI
168168
d.l.Lock()
169169
defer d.l.Unlock()
170170
return d.runInTx(func(tx *sql.Tx) error {
171-
for _, bID := range bIDs {
172-
var totalSent, totalLimit uint64
173-
rs, err := tx.QueryContext(ctx, "SELECT sent, message_limit FROM broadcast_allocations WHERE broadcast_id = ?", bID.String())
174-
if err != nil {
175-
return err
176-
}
177-
for rs.Next() {
178-
var sent, limit uint64
179-
if err := rs.Scan(&sent, &limit); err != nil {
180-
rs.Close()
181-
return err
182-
}
183-
totalSent += sent
184-
totalLimit += limit
185-
}
186-
rs.Close()
187-
if err := rs.Err(); err != nil {
188-
return err
189-
}
171+
return d.tryDisableBroadcasts(ctx, tx, bIDs)
172+
})
173+
}
190174

191-
if _, err := tx.ExecContext(ctx, "DELETE FROM broadcast_allocations WHERE broadcast_id = ?", bID.String()); err != nil {
175+
func (d *Datastore) tryDisableBroadcasts(ctx context.Context, tx *sql.Tx, bIDs []ids.BroadcastID) error {
176+
for _, bID := range bIDs {
177+
var totalSent, totalLimit uint64
178+
rs, err := tx.QueryContext(ctx, "SELECT sent, message_limit FROM broadcast_allocations WHERE broadcast_id = ?", bID.String())
179+
if err != nil {
180+
return err
181+
}
182+
for rs.Next() {
183+
var sent, limit uint64
184+
if err := rs.Scan(&sent, &limit); err != nil {
185+
rs.Close()
192186
return err
193187
}
188+
totalSent += sent
189+
totalLimit += limit
190+
}
191+
rs.Close()
192+
if err := rs.Err(); err != nil {
193+
return err
194+
}
194195

195-
var bSent, bAllocated uint64
196-
row := tx.QueryRowContext(ctx, "SELECT sent, allocated FROM broadcasts WHERE broadcast_id = ?", bID.String())
197-
if err := row.Scan(&bSent, &bAllocated); err != nil {
198-
if err == sql.ErrNoRows {
199-
continue
200-
}
201-
return err
202-
}
196+
if _, err := tx.ExecContext(ctx, "DELETE FROM broadcast_allocations WHERE broadcast_id = ?", bID.String()); err != nil {
197+
return err
198+
}
203199

204-
newAllocated := bAllocated
205-
if newAllocated >= totalLimit {
206-
newAllocated -= totalLimit
207-
} else {
208-
newAllocated = 0
200+
var bSent, bAllocated uint64
201+
row := tx.QueryRowContext(ctx, "SELECT sent, allocated FROM broadcasts WHERE broadcast_id = ?", bID.String())
202+
if err := row.Scan(&bSent, &bAllocated); err != nil {
203+
if err == sql.ErrNoRows {
204+
continue
209205
}
206+
return err
207+
}
210208

211-
if _, err := tx.ExecContext(ctx, "UPDATE broadcasts SET message_limit = 0, sent = ?, allocated = ? WHERE broadcast_id = ?", bSent+totalSent, newAllocated, bID.String()); err != nil {
212-
return err
213-
}
209+
newAllocated := bAllocated
210+
if newAllocated >= totalLimit {
211+
newAllocated -= totalLimit
212+
} else {
213+
newAllocated = 0
214214
}
215-
return nil
216-
})
215+
216+
if _, err := tx.ExecContext(ctx, "UPDATE broadcasts SET message_limit = 0, sent = ?, allocated = ? WHERE broadcast_id = ?", bSent+totalSent, newAllocated, bID.String()); err != nil {
217+
return err
218+
}
219+
}
220+
return nil
217221
}
218222

219223
func (d *Datastore) SaveBroadcastMessage(ctx context.Context, msg *fspb.Message, bID ids.BroadcastID, cID common.ClientID, aID ids.AllocationID) error {

0 commit comments

Comments
 (0)