Skip to content

Commit 57d2d27

Browse files
muirCopilot
andauthored
[fix] support delimiter (#331)
Turns out that supporting delimiter goes beyond the change in sqltoken. --------- Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
1 parent 152db66 commit 57d2d27

8 files changed

Lines changed: 132 additions & 9 deletions

File tree

api.go

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,7 @@ type MigrationBase struct {
7474
nonTransactional bool // set automatically or by ForceNonTransactional / inference
7575
forcedTx *bool // if not nil, explicitly chosen transactional mode (true=transactional, false=non-transactional)
7676
notes map[string]any
77+
preserveComments bool
7778
}
7879

7980
func (m MigrationBase) Copy() MigrationBase {
@@ -328,6 +329,17 @@ func ForceTransactional() MigrationOption {
328329
}
329330
}
330331

332+
// PreserveComments prevents stripping of SQL comments before execution.
333+
// This is primarily useful for testing scenarios where comment-only
334+
// statements are needed to exercise specific code paths.
335+
// PreserveComments can break DELIMITER handling so do not use in conjunction
336+
// with SQL that includes DELIMITERs.
337+
func PreserveComments() MigrationOption {
338+
return func(m Migration) {
339+
m.Base().preserveComments = true
340+
}
341+
}
342+
331343
// ApplyForceOverride overrides transactionality for any prior force call (ForceTransactional
332344
// or ForceNonTransactional)
333345
func (m *MigrationBase) ApplyForceOverride() {
@@ -416,6 +428,9 @@ func (m *MigrationBase) ForcedTransactional() bool { return m.forcedTx != nil &&
416428
// ForcedNonTransactional reports if ForceNonTransactional() was explicitly called.
417429
func (m *MigrationBase) ForcedNonTransactional() bool { return m.forcedTx != nil && !*m.forcedTx }
418430

431+
// PreserveComments reports if PreserveComments() was set on this migration.
432+
func (m *MigrationBase) PreserveComments() bool { return m.preserveComments }
433+
419434
func (n MigrationName) String() string {
420435
return n.Library + ": " + n.Name
421436
}

internal/mhelp/run_sql.go

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import (
88
"github.com/muir/libschema"
99
"github.com/muir/libschema/classifysql"
1010
"github.com/muir/libschema/internal"
11+
"github.com/muir/sqltoken"
1112
"github.com/pkg/errors"
1213
)
1314

@@ -16,7 +17,42 @@ type CanExecContext interface {
1617
}
1718

1819
func RunSQL(ctx context.Context, log *internal.Log, tx CanExecContext, statements classifysql.Statements, rowsAffected *int64, m libschema.Migration, d *libschema.Database) error {
19-
for _, commandSQL := range statements.TokensList().Strings() {
20+
for _, tokens := range statements.TokensList() {
21+
if !m.Base().PreserveComments() {
22+
tokens = tokens.Strip()
23+
}
24+
// Strip leading DelimiterStatement (e.g., "DELIMITER //\n")
25+
if len(tokens) > 0 && tokens[0].Type == sqltoken.DelimiterStatement {
26+
log.Debug("Stripping leading DelimiterStatement from migration", map[string]any{
27+
"name": m.Base().Name.Name,
28+
"library": m.Base().Name.Library,
29+
})
30+
tokens = tokens[1:]
31+
}
32+
// Strip trailing DelimiterStatement (e.g., "DELIMITER ;\n") and any whitespace before it
33+
for len(tokens) > 0 && (tokens[len(tokens)-1].Type == sqltoken.DelimiterStatement || tokens[len(tokens)-1].Type == sqltoken.Whitespace) {
34+
if tokens[len(tokens)-1].Type == sqltoken.DelimiterStatement {
35+
log.Debug("Stripping trailing DelimiterStatement from migration", map[string]any{
36+
"name": m.Base().Name.Name,
37+
"library": m.Base().Name.Library,
38+
})
39+
}
40+
tokens = tokens[:len(tokens)-1]
41+
}
42+
// Strip trailing Delimiter (e.g., "//") and any whitespace before it
43+
for len(tokens) > 0 && (tokens[len(tokens)-1].Type == sqltoken.Delimiter || tokens[len(tokens)-1].Type == sqltoken.Whitespace) {
44+
if tokens[len(tokens)-1].Type == sqltoken.Delimiter {
45+
log.Debug("Stripping trailing Delimiter from migration", map[string]any{
46+
"name": m.Base().Name.Name,
47+
"library": m.Base().Name.Library,
48+
})
49+
}
50+
tokens = tokens[:len(tokens)-1]
51+
}
52+
if len(tokens) == 0 {
53+
continue
54+
}
55+
commandSQL := tokens.String()
2056
result, err := tx.ExecContext(ctx, commandSQL)
2157
if d.Options.DebugLogging {
2258
log.Debug("Executed SQL", map[string]any{

lsmysql/mysql.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -255,9 +255,9 @@ func (p *MySQL) DoOneMigration(ctx context.Context, log *internal.Log, d *libsch
255255
}
256256
sqlText = genSQL
257257
}
258-
sqlText = strings.TrimSpace(sqlText)
259-
m.Base().SetNote("sql", sqlText)
260-
if sqlText == "" {
258+
trimmedSQLText := strings.TrimSpace(sqlText)
259+
m.Base().SetNote("sql", trimmedSQLText)
260+
if trimmedSQLText == "" {
261261
return nil
262262
}
263263
statements, err := classifysql.ClassifyTokens(p.dialect, 0, sqlText)

lsmysql/mysql_test.go

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -273,3 +273,53 @@ func testMysqlNotAllowed(t *testing.T, dsn string, createPostfix string, driverN
273273
}
274274
}
275275
}
276+
277+
func TestMysqlMigrationWithDelimiter(t *testing.T) {
278+
testMysqlOneMigration(t, `
279+
DELIMITER //
280+
CREATE PROCEDURE charge_account(IN id BIGINT, IN amount DECIMAL(18,4))
281+
BEGIN
282+
DECLARE balance DECIMAL(18,4);
283+
SELECT remaining_balance INTO balance
284+
FROM account_balance
285+
WHERE account_id = id;
286+
IF balance > amount THEN
287+
UPDATE account_balance
288+
SET remaining_balance = balance - amount
289+
WHERE account_id = id;
290+
END IF;
291+
END //
292+
DELIMITER ;
293+
`)
294+
}
295+
296+
func testMysqlOneMigration(t *testing.T, sqlText string) {
297+
t.Parallel()
298+
dsn := os.Getenv("LIBSCHEMA_MYSQL_TEST_DSN")
299+
if dsn == "" {
300+
t.Skip("Set $LIBSCHEMA_MYSQL_TEST_DSN to test libschema/lsmysql")
301+
}
302+
testOneMigration(t, dsn, sqlText, mysqlNew)
303+
}
304+
305+
func testOneMigration(t *testing.T, dsn string, sqlText string, driverNew driverNew) {
306+
options, cleanup := lstesting.FakeSchema(t, "")
307+
308+
t.Log("Doing migrations in database/schema", options.SchemaOverride)
309+
310+
options.DebugLogging = true
311+
db, err := sql.Open("mysql", dsn)
312+
require.NoError(t, err, "open database")
313+
defer func() {
314+
assert.NoError(t, db.Close())
315+
}()
316+
defer cleanup(db)
317+
318+
s := libschema.New(context.Background(), options)
319+
dbase, _, err := driverNew(t, "test", s, db)
320+
require.NoError(t, err, "libschema NewDatabase")
321+
322+
dbase.Migrations("L1", lsmysql.Script("M1", sqlText))
323+
err = s.Migrate(context.Background())
324+
assert.NoError(t, err)
325+
}

lsmysql/singlestore_test.go

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,27 @@ func TestSingleStoreNotAllowed(t *testing.T) {
6060
testMysqlNotAllowed(t, dsn, "", singleStoreNew)
6161
}
6262

63+
func TestSingleStoreMigrationWithDelimiter(t *testing.T) {
64+
testSingleStoreOneMigration(t, `
65+
DELIMITER //
66+
CREATE OR REPLACE PROCEDURE test_proc()
67+
AS
68+
BEGIN
69+
ECHO SELECT 1;
70+
END //
71+
DELIMITER ;
72+
`)
73+
}
74+
75+
func testSingleStoreOneMigration(t *testing.T, sqlText string) {
76+
t.Parallel()
77+
dsn := os.Getenv("LIBSCHEMA_SINGLESTORE_TEST_DSN")
78+
if dsn == "" {
79+
t.Skip("Set $LIBSCHEMA_SINGLESTORE_TEST_DSN to test SingleStore support in libschema/lsmysql")
80+
}
81+
testOneMigration(t, dsn, sqlText, singleStoreNew)
82+
}
83+
6384
func TestSingleStoreFailedMigration(t *testing.T) {
6485
t.Parallel()
6586
dsn := os.Getenv("LIBSCHEMA_SINGLESTORE_TEST_DSN")

lspostgres/bad_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,7 @@ func TestBadMigrationsPostgres(t *testing.T) {
121121
define: func(dbase *libschema.Database) {
122122
dbase.Migrations("L9",
123123
lspostgres.Script("T4", `CREATE TABLE T1 (id text)`),
124-
lspostgres.Script("T5", ` -- just a comment`, libschema.RepeatUntilNoOp()),
124+
lspostgres.Script("T5", ` -- just a comment`, libschema.RepeatUntilNoOp(), libschema.PreserveComments()),
125125
)
126126
},
127127
},

lspostgres/non_tx_test.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -217,7 +217,8 @@ func TestRowsAffectedErrorLogged(t *testing.T) {
217217
require.NoError(t, err)
218218
lib := fmt.Sprintf("RA_%d", time.Now().UnixNano())
219219
comment := lspostgres.Script("NOTHING",
220-
" -- just a comment")
220+
" -- just a comment",
221+
libschema.PreserveComments())
221222
dbase.Migrations(lib, comment)
222223
require.NoError(t, s.Migrate(ctx))
223224
entries := capLog.Entries()

lspostgres/postgres.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -219,9 +219,9 @@ func (p *Postgres) DoOneMigration(ctx context.Context, log *internal.Log, d *lib
219219
}
220220
scriptSQL = sqlText
221221
}
222-
scriptSQL = strings.TrimSpace(scriptSQL)
223-
m.Base().SetNote("sql", scriptSQL)
224-
if scriptSQL == "" {
222+
trimmedScriptSQL := strings.TrimSpace(scriptSQL)
223+
m.Base().SetNote("sql", trimmedScriptSQL)
224+
if trimmedScriptSQL == "" {
225225
return nil
226226
}
227227
// Classification & downgrade via classifysql

0 commit comments

Comments
 (0)