@@ -14,33 +14,44 @@ import (
1414 "github.com/zeromicro/go-zero/core/stores/sqlx"
1515)
1616
17- type MigrateOpts struct {
18- PreProcessSqlFunc func (content string ) string
19- Source string
17+ type MigrateUpOpts struct {
18+ PreProcessSqlFunc func (version uint , content string ) string
19+ BeforeMigrateUpFuncs map [uint ]func (version uint ) error
20+ AfterMigrateUpFuncs map [uint ]func (version uint ) error
21+ Source string
2022}
2123
22- func (opts MigrateOpts ) DefaultOptions () MigrateOpts {
23- return MigrateOpts {
24- PreProcessSqlFunc : func (content string ) string {
25- return content
26- },
24+ func (opts MigrateUpOpts ) DefaultOptions () MigrateUpOpts {
25+ return MigrateUpOpts {
2726 Source : "file://desc/sql_migration" ,
2827 }
2928}
3029
31- func WithPreProcessSqlFunc (f func (string ) string ) opts.Opt [MigrateOpts ] {
32- return func (opts * MigrateOpts ) {
30+ func WithPreProcessSqlFunc (f func (uint , string ) string ) opts.Opt [MigrateUpOpts ] {
31+ return func (opts * MigrateUpOpts ) {
3332 opts .PreProcessSqlFunc = f
3433 }
3534}
3635
37- func WithSource (source string ) opts.Opt [MigrateOpts ] {
38- return func (opts * MigrateOpts ) {
36+ func WithSource (source string ) opts.Opt [MigrateUpOpts ] {
37+ return func (opts * MigrateUpOpts ) {
3938 opts .Source = source
4039 }
4140}
4241
43- func Migrate (ctx context.Context , c sqlx.SqlConf , op ... opts.Opt [MigrateOpts ]) error {
42+ func WithBeforeMigrateUpFunc (mapFuncs map [uint ]func (uint ) error ) opts.Opt [MigrateUpOpts ] {
43+ return func (opts * MigrateUpOpts ) {
44+ opts .BeforeMigrateUpFuncs = mapFuncs
45+ }
46+ }
47+
48+ func WithAfterMigrateUpFunc (mapFunc map [uint ]func (uint ) error ) opts.Opt [MigrateUpOpts ] {
49+ return func (opts * MigrateUpOpts ) {
50+ opts .AfterMigrateUpFuncs = mapFunc
51+ }
52+ }
53+
54+ func MigrateUp (ctx context.Context , c sqlx.SqlConf , op ... opts.Opt [MigrateUpOpts ]) error {
4455 ops := opts .DefaultApply (op ... )
4556 var databaseUrl string
4657 switch c .DriverName {
@@ -49,16 +60,17 @@ func Migrate(ctx context.Context, c sqlx.SqlConf, op ...opts.Opt[MigrateOpts]) e
4960 case "pgx" :
5061 databaseUrl = "pgx5://" + strings .TrimPrefix (c .DataSource , "postgres://" )
5162 }
52- if err := sqlMigrate ( ops .Source , databaseUrl , c , ops ); err != nil {
63+ if err := migrateUp ( ctx , ops .Source , databaseUrl , c , ops ); err != nil {
5364 return err
5465 }
5566 return nil
5667}
5768
5869type customFileSource struct {
5970 * file.File
60- driverName string
61- preProcessSqlFunc func (content string ) string
71+ preProcessSqlFunc func (version uint , content string ) string
72+ sqlConf sqlx.SqlConf
73+ ctx context.Context
6274}
6375
6476func (c * customFileSource ) ReadUp (version uint ) (r io.ReadCloser , identifier string , err error ) {
@@ -75,7 +87,8 @@ func (c *customFileSource) ReadUp(version uint) (r io.ReadCloser, identifier str
7587 if err = rc .Close (); err != nil {
7688 return nil , "" , err
7789 }
78- return io .NopCloser (strings .NewReader (c .preProcessSqlFunc (string (content )))), id , nil
90+
91+ return io .NopCloser (strings .NewReader (c .preProcessSqlFunc (version , string (content )))), id , nil
7992}
8093
8194func (c * customFileSource ) ReadDown (version uint ) (r io.ReadCloser , identifier string , err error ) {
@@ -93,11 +106,10 @@ func (c *customFileSource) ReadDown(version uint) (r io.ReadCloser, identifier s
93106 return nil , "" , err
94107 }
95108
96- modifiedContent := c .preProcessSqlFunc (string (content ))
97- return io .NopCloser (strings .NewReader (modifiedContent )), id , nil
109+ return io .NopCloser (strings .NewReader (c .preProcessSqlFunc (version , string (content )))), id , nil
98110}
99111
100- func sqlMigrate ( sourceUrl , databaseUrl string , c sqlx.SqlConf , ops MigrateOpts ) error {
112+ func migrateUp ( ctx context. Context , sourceUrl , databaseUrl string , c sqlx.SqlConf , ops MigrateUpOpts ) error {
101113 fileDriver := & file.File {}
102114 fileSource , err := fileDriver .Open (sourceUrl )
103115 if err != nil {
@@ -106,27 +118,65 @@ func sqlMigrate(sourceUrl, databaseUrl string, c sqlx.SqlConf, ops MigrateOpts)
106118
107119 customSource := & customFileSource {
108120 File : fileSource .(* file.File ),
109- driverName : c .DriverName ,
110121 preProcessSqlFunc : ops .PreProcessSqlFunc ,
122+ sqlConf : c ,
123+ ctx : ctx ,
111124 }
112125
113- m , err := migrate .NewWithSourceInstance ("file" , customSource , databaseUrl )
114- if err != nil {
115- return err
126+ var m * migrate.Migrate
127+ if ops .PreProcessSqlFunc == nil {
128+ m , err = migrate .New (sourceUrl , databaseUrl )
129+ if err != nil {
130+ return err
131+ }
132+ } else {
133+ m , err = migrate .NewWithSourceInstance ("file" , customSource , databaseUrl )
134+ if err != nil {
135+ return err
136+ }
116137 }
138+ defer m .Close ()
117139
118- if err = m .Up (); err != nil {
119- if errors .Is (err , migrate .ErrNoChange ) {
120- return nil
140+ if ops .BeforeMigrateUpFuncs == nil && ops .AfterMigrateUpFuncs == nil {
141+ err = m .Up ()
142+ if err != nil && ! errors .Is (err , migrate .ErrNoChange ) {
143+ return err
121144 }
145+ return nil
122146 }
123-
124- sourceErr , databaseErr := m .Close ()
125- if sourceErr != nil {
126- return sourceErr
147+ // 获取当前版本
148+ currentVersion , _ , err := m .Version ()
149+ if err != nil {
150+ if errors .Is (err , migrate .ErrNilVersion ) {
151+ // 不存在的话, 直接返回 Up
152+ return m .Up ()
153+ }
127154 }
128- if databaseErr != nil {
129- return databaseErr
155+
156+ for {
157+ nextVersion , err := customSource .Next (currentVersion )
158+ if err == nil && nextVersion > currentVersion {
159+ if f , ok := ops .BeforeMigrateUpFuncs [nextVersion ]; ok {
160+ if err = f (nextVersion ); err != nil {
161+ return err
162+ }
163+ }
164+ if err = m .Steps (1 ); err != nil {
165+ return err
166+ }
167+ if f , ok := ops .AfterMigrateUpFuncs [nextVersion ]; ok {
168+ if err = f (nextVersion ); err != nil {
169+ if stepDownErr := m .Steps (- 1 ); stepDownErr != nil {
170+ return stepDownErr
171+ }
172+ return err
173+ }
174+ }
175+ currentVersion = nextVersion
176+ } else {
177+ break
178+ }
130179 }
180+
131181 return nil
132182}
0 commit comments