rag/internal/cmd/migrate.go
2024-07-14 19:48:07 +08:00

189 lines
4.5 KiB
Go

package cmd
import (
"ariga.io/atlas/sql/sqltool"
"context"
"entgo.io/ent/dialect"
"entgo.io/ent/dialect/sql/schema"
"fmt"
entmigrate "framework_v2/internal/ent/migrate"
"framework_v2/internal/migrations"
_ "github.com/jackc/pgx/v5/stdlib"
"github.com/pressly/goose/v3"
"github.com/spf13/cobra"
"log"
"os"
"strings"
"time"
)
// dsnCommand prints the database connection string, assembled as
// config.DB.Driver + "://" + config.DB.DSN, to standard output with no
// trailing newline (so it can be captured directly by shell scripts).
var dsnCommand = &cobra.Command{
	Use:   "dsn",
	Short: "生成 DSN",
	Long:  "生成 DSN",
	Run: func(_ *cobra.Command, _ []string) {
		dsn := config.DB.Driver + "://" + config.DB.DSN
		fmt.Print(dsn)
	},
}
// migrateCommand runs goose migrations; its long help describes it as the
// migration path intended for production. With no positional arguments it
// prints its help text; otherwise all arguments are handed to RunMigrate,
// with the first one treated as the goose command.
var migrateCommand = &cobra.Command{
	Use:   "migrate [command]",
	Short: "goose 迁移,用法 <command>",
	Long:  "适用于生产环境的数据库迁移",
	Run: func(c *cobra.Command, argv []string) {
		if len(argv) > 0 {
			RunMigrate(argv)
			return
		}
		// Printing help is best-effort; its error is intentionally ignored.
		_ = c.Help()
	},
}
// genMigrateCommand creates a new versioned migration from the ent schema in
// internal/ent; its long help notes that `go generate ./internal/ent` must be
// run first. The migration name is the positional argument. When it is
// missing, the command prints its help instead of failing with a fatal log,
// consistent with migrateCommand and createGoMigrateCommand.
var genMigrateCommand = &cobra.Command{
	Use:   "gen-migrate [name]",
	Short: "新建 ent 迁移",
	Long:  "从 internal/ent 中新建迁移。在这之前,需要运行 go generate ./internal/ent",
	Run: func(cmd *cobra.Command, args []string) {
		if len(args) == 0 {
			// Printing help is best-effort; its error is intentionally ignored.
			_ = cmd.Help()
			return
		}
		// generateMigration reads the migration name from os.Args itself.
		generateMigration()
	},
}
// createGoMigrateCommand scaffolds a new goose Go migration file named after
// its first positional argument. When no name is given, it prints its help.
//
// Use carries a "[name]" placeholder so the help output shows the required
// argument, matching "gen-migrate [name]". Cobra derives the command name from
// the first word of Use, so the command is still invoked as "create-migrate".
var createGoMigrateCommand = &cobra.Command{
	Use:   "create-migrate [name]",
	Short: "新建 go 迁移",
	Long:  "新建 goose 的 go 迁移。",
	Run: func(cmd *cobra.Command, args []string) {
		if len(args) == 0 {
			// Printing help is best-effort; its error is intentionally ignored.
			_ = cmd.Help()
			return
		}
		createGooseMigration(args[0])
	},
}
// RunMigrate executes a goose command against the "postgres" database at
// config.DB.DSN2, using migrations read from migrations.MigrationFS.
//
// args[0] is the goose command name. Every remaining element is forwarded to
// goose as that command's arguments. Any failure, including a missing command,
// terminates the process via log.Fatal*.
func RunMigrate(args []string) {
	if len(args) == 0 {
		log.Fatalln("goose: a migration command is required")
	}
	command := args[0]
	// args[0] is the command, so its arguments start at index 1. (The goose CLI
	// uses args[3:] because its argv also carries the driver and DSN.)
	arguments := args[1:]

	// Configure goose before opening a connection so setup errors fail fast.
	goose.SetBaseFS(migrations.MigrationFS)
	if err := goose.SetDialect("postgres"); err != nil {
		log.Fatalf("goose: failed to set dialect: %v", err)
	}

	db, err := goose.OpenDBWithDriver("postgres", config.DB.DSN2)
	if err != nil {
		log.Fatalf("goose: failed to open DB: %v", err)
	}

	runErr := goose.RunContext(context.Background(), command, db, ".", arguments...)
	// Close explicitly rather than via defer: log.Fatalf exits immediately and
	// would skip deferred calls, leaking the connection.
	closeErr := db.Close()
	if runErr != nil {
		if closeErr != nil {
			log.Printf("goose: failed to close DB: %v", closeErr)
		}
		log.Fatalf("goose %v: %v", command, runErr)
	}
	if closeErr != nil {
		log.Fatalf("goose: failed to close DB: %v", closeErr)
	}
}
// generateMigration asks entmigrate.NamedDiff to write a new goose-formatted
// migration into internal/migrations. It uses the database URL
// config.DB.Driver + "://" + config.DB.DSN and the PostgreSQL ent dialect.
//
// The migration name is read positionally from os.Args[2], so the process must
// be invoked as `<binary> gen-migrate <name>`. NOTE(review): global flags
// placed before the subcommand would shift this index; confirm no such flags
// are used. Any failure terminates the process via log.Fatal*.
func generateMigration() {
	// Validate the command line before touching the filesystem or database.
	if len(os.Args) != 3 || strings.TrimSpace(os.Args[2]) == "" {
		log.Fatalln("migration name is required. Use: 'gen-migrate <name>'")
	}
	name := os.Args[2]

	ctx := context.Background()
	dir, err := sqltool.NewGooseDir("internal/migrations")
	if err != nil {
		log.Fatalf("failed creating atlas migration directory: %v", err)
	}
	// Migrate diff options.
	opts := []schema.MigrateOption{
		schema.WithDir(dir),                          // write goose-formatted files into internal/migrations
		schema.WithMigrationMode(schema.ModeInspect), // inspect the connected database for the current state
		schema.WithDialect(dialect.Postgres),         // Ent dialect to use
	}
	if err := entmigrate.NamedDiff(ctx, config.DB.Driver+"://"+config.DB.DSN, name, opts...); err != nil {
		log.Fatalf("failed generating migration file: %v", err)
	}
}
// createGooseMigration writes a new goose Go migration skeleton to
// internal/migrations/<version>_<name>.go, relative to the working directory.
// <version> is the current local time formatted as yyyymmddhhmmss, for
// example 20240714194807_add_index.go.
//
// The generated file registers Up<version><ident> and Down<version><ident>
// via goose.AddMigrationContext. <ident> is name with characters that are not
// valid in a Go identifier replaced by '_' (see migrationIdentifier). The file
// name keeps name verbatim. The generated Up/Down bodies are sample statements
// that must be replaced before the migration is applied. The process exits via
// log.Fatalf if the file cannot be written.
func createGooseMigration(name string) {
	// Read the clock once: calling time.Now() per component can straddle a
	// second/minute/day boundary and yield an inconsistent version.
	version := time.Now().Format("20060102150405")
	fileName := fmt.Sprintf("%s_%s.go", version, name)
	// Template for the generated file; <FuncName> is substituted below.
	var template = `package migrations
import (
"context"
"database/sql"
"github.com/pressly/goose/v3"
)
func init() {
goose.AddMigrationContext(Up<FuncName>, Down<FuncName>)
}
func Up<FuncName>(ctx context.Context, tx *sql.Tx) error {
_, err := tx.ExecContext(ctx, "UPDATE users SET username='admin' WHERE username='root';")
return err
}
func Down<FuncName>(ctx context.Context, tx *sql.Tx) error {
_, err := tx.ExecContext(ctx, "UPDATE users SET username='root' WHERE username='admin';")
return err
}
`
	template = strings.ReplaceAll(template, "<FuncName>", version+migrationIdentifier(name))
	if err := os.WriteFile("internal/migrations/"+fileName, []byte(template), 0644); err != nil {
		log.Fatalf("failed creating migration file: %v", err)
	}
}

// migrationIdentifier makes name safe to append to a Go function name. Every
// ASCII character other than a letter, digit or underscore becomes '_'.
// Non-ASCII runes are kept unchanged, since Go identifiers may contain Unicode
// letters.
func migrationIdentifier(name string) string {
	return strings.Map(func(r rune) rune {
		switch {
		case r >= 0x80, r == '_',
			'a' <= r && r <= 'z', 'A' <= r && r <= 'Z', '0' <= r && r <= '9':
			return r
		default:
			return '_'
		}
	}, name)
}