Skip to content

Commit dea6ee0

Browse files
committed
feat(core): migrate support custom funcs for each version migration
1 parent d22dc93 commit dea6ee0

1 file changed

Lines changed: 83 additions & 33 deletions

File tree

core/stores/migrate/migrate.go

Lines changed: 83 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -14,33 +14,44 @@ import (
1414
"github.com/zeromicro/go-zero/core/stores/sqlx"
1515
)
1616

17-
type MigrateOpts struct {
18-
PreProcessSqlFunc func(content string) string
19-
Source string
17+
type MigrateUpOpts struct {
18+
PreProcessSqlFunc func(version uint, content string) string
19+
BeforeMigrateUpFuncs map[uint]func(version uint) error
20+
AfterMigrateUpFuncs map[uint]func(version uint) error
21+
Source string
2022
}
2123

22-
func (opts MigrateOpts) DefaultOptions() MigrateOpts {
23-
return MigrateOpts{
24-
PreProcessSqlFunc: func(content string) string {
25-
return content
26-
},
24+
func (opts MigrateUpOpts) DefaultOptions() MigrateUpOpts {
25+
return MigrateUpOpts{
2726
Source: "file://desc/sql_migration",
2827
}
2928
}
3029

31-
func WithPreProcessSqlFunc(f func(string) string) opts.Opt[MigrateOpts] {
32-
return func(opts *MigrateOpts) {
30+
func WithPreProcessSqlFunc(f func(uint, string) string) opts.Opt[MigrateUpOpts] {
31+
return func(opts *MigrateUpOpts) {
3332
opts.PreProcessSqlFunc = f
3433
}
3534
}
3635

37-
func WithSource(source string) opts.Opt[MigrateOpts] {
38-
return func(opts *MigrateOpts) {
36+
func WithSource(source string) opts.Opt[MigrateUpOpts] {
37+
return func(opts *MigrateUpOpts) {
3938
opts.Source = source
4039
}
4140
}
4241

43-
func Migrate(ctx context.Context, c sqlx.SqlConf, op ...opts.Opt[MigrateOpts]) error {
42+
func WithBeforeMigrateUpFunc(mapFuncs map[uint]func(uint) error) opts.Opt[MigrateUpOpts] {
43+
return func(opts *MigrateUpOpts) {
44+
opts.BeforeMigrateUpFuncs = mapFuncs
45+
}
46+
}
47+
48+
func WithAfterMigrateUpFunc(mapFunc map[uint]func(uint) error) opts.Opt[MigrateUpOpts] {
49+
return func(opts *MigrateUpOpts) {
50+
opts.AfterMigrateUpFuncs = mapFunc
51+
}
52+
}
53+
54+
func MigrateUp(ctx context.Context, c sqlx.SqlConf, op ...opts.Opt[MigrateUpOpts]) error {
4455
ops := opts.DefaultApply(op...)
4556
var databaseUrl string
4657
switch c.DriverName {
@@ -49,16 +60,17 @@ func Migrate(ctx context.Context, c sqlx.SqlConf, op ...opts.Opt[MigrateOpts]) e
4960
case "pgx":
5061
databaseUrl = "pgx5://" + strings.TrimPrefix(c.DataSource, "postgres://")
5162
}
52-
if err := sqlMigrate(ops.Source, databaseUrl, c, ops); err != nil {
63+
if err := migrateUp(ctx, ops.Source, databaseUrl, c, ops); err != nil {
5364
return err
5465
}
5566
return nil
5667
}
5768

5869
type customFileSource struct {
5970
*file.File
60-
driverName string
61-
preProcessSqlFunc func(content string) string
71+
preProcessSqlFunc func(version uint, content string) string
72+
sqlConf sqlx.SqlConf
73+
ctx context.Context
6274
}
6375

6476
func (c *customFileSource) ReadUp(version uint) (r io.ReadCloser, identifier string, err error) {
@@ -75,7 +87,8 @@ func (c *customFileSource) ReadUp(version uint) (r io.ReadCloser, identifier str
7587
if err = rc.Close(); err != nil {
7688
return nil, "", err
7789
}
78-
return io.NopCloser(strings.NewReader(c.preProcessSqlFunc(string(content)))), id, nil
90+
91+
return io.NopCloser(strings.NewReader(c.preProcessSqlFunc(version, string(content)))), id, nil
7992
}
8093

8194
func (c *customFileSource) ReadDown(version uint) (r io.ReadCloser, identifier string, err error) {
@@ -93,11 +106,10 @@ func (c *customFileSource) ReadDown(version uint) (r io.ReadCloser, identifier s
93106
return nil, "", err
94107
}
95108

96-
modifiedContent := c.preProcessSqlFunc(string(content))
97-
return io.NopCloser(strings.NewReader(modifiedContent)), id, nil
109+
return io.NopCloser(strings.NewReader(c.preProcessSqlFunc(version, string(content)))), id, nil
98110
}
99111

100-
func sqlMigrate(sourceUrl, databaseUrl string, c sqlx.SqlConf, ops MigrateOpts) error {
112+
func migrateUp(ctx context.Context, sourceUrl, databaseUrl string, c sqlx.SqlConf, ops MigrateUpOpts) error {
101113
fileDriver := &file.File{}
102114
fileSource, err := fileDriver.Open(sourceUrl)
103115
if err != nil {
@@ -106,27 +118,65 @@ func sqlMigrate(sourceUrl, databaseUrl string, c sqlx.SqlConf, ops MigrateOpts)
106118

107119
customSource := &customFileSource{
108120
File: fileSource.(*file.File),
109-
driverName: c.DriverName,
110121
preProcessSqlFunc: ops.PreProcessSqlFunc,
122+
sqlConf: c,
123+
ctx: ctx,
111124
}
112125

113-
m, err := migrate.NewWithSourceInstance("file", customSource, databaseUrl)
114-
if err != nil {
115-
return err
126+
var m *migrate.Migrate
127+
if ops.PreProcessSqlFunc == nil {
128+
m, err = migrate.New(sourceUrl, databaseUrl)
129+
if err != nil {
130+
return err
131+
}
132+
} else {
133+
m, err = migrate.NewWithSourceInstance("file", customSource, databaseUrl)
134+
if err != nil {
135+
return err
136+
}
116137
}
138+
defer m.Close()
117139

118-
if err = m.Up(); err != nil {
119-
if errors.Is(err, migrate.ErrNoChange) {
120-
return nil
140+
if ops.BeforeMigrateUpFuncs == nil && ops.AfterMigrateUpFuncs == nil {
141+
err = m.Up()
142+
if err != nil && !errors.Is(err, migrate.ErrNoChange) {
143+
return err
121144
}
145+
return nil
122146
}
123-
124-
sourceErr, databaseErr := m.Close()
125-
if sourceErr != nil {
126-
return sourceErr
147+
// 获取当前版本
148+
currentVersion, _, err := m.Version()
149+
if err != nil {
150+
if errors.Is(err, migrate.ErrNilVersion) {
151+
// 不存在的话, 直接返回 Up
152+
return m.Up()
153+
}
127154
}
128-
if databaseErr != nil {
129-
return databaseErr
155+
156+
for {
157+
nextVersion, err := customSource.Next(currentVersion)
158+
if err == nil && nextVersion > currentVersion {
159+
if f, ok := ops.BeforeMigrateUpFuncs[nextVersion]; ok {
160+
if err = f(nextVersion); err != nil {
161+
return err
162+
}
163+
}
164+
if err = m.Steps(1); err != nil {
165+
return err
166+
}
167+
if f, ok := ops.AfterMigrateUpFuncs[nextVersion]; ok {
168+
if err = f(nextVersion); err != nil {
169+
if stepDownErr := m.Steps(-1); stepDownErr != nil {
170+
return stepDownErr
171+
}
172+
return err
173+
}
174+
}
175+
currentVersion = nextVersion
176+
} else {
177+
break
178+
}
130179
}
180+
131181
return nil
132182
}

0 commit comments

Comments
 (0)