diff --git a/main.go b/main.go index 34bf8c3..290e7f9 100644 --- a/main.go +++ b/main.go @@ -26,6 +26,34 @@ func main() { } func setup() { + // 输入新的 go.mod module + var newModName string + fmt.Printf("Enter new module name: ") + _, err := fmt.Scanln(&newModName) + if err != nil { + fmt.Printf("Unable get new module name: %v\n", err) + os.Exit(1) + return + } + + fmt.Printf("Do you want to setup the project to %s? (y/n)", newModName) + var answer string + _, err = fmt.Scanln(&answer) + if err != nil { + fmt.Printf("Error reading user input: %v\n", err) + os.Exit(1) + } + if answer != "y" { + fmt.Printf("Aborting setup.\n") + } + + // 修改 go.mod 文件中的 module 名称 + err = replaceInFile("go.mod", frameworkModuleName, newModName) + if err != nil { + fmt.Printf("Error replacing module name in go.mod: %v\n", err) + os.Exit(1) + } + // 读取 go.mod 中的 module 名称 modName, err := getModuleName("go.mod") if err != nil { @@ -38,17 +66,6 @@ func setup() { os.Exit(1) } - fmt.Printf("Do you want to setup the project to %s? (y/n)", modName) - var answer string - _, err = fmt.Scanln(&answer) - if err != nil { - fmt.Printf("Error reading user input: %v\n", err) - os.Exit(1) - } - if answer != "y" { - fmt.Printf("Aborting setup.\n") - } - fmt.Printf("Module name found: %s\n", modName) // 遍历当前文件夹(排除 vendor、setup.go 和版本控制文件夹) err = filepath.Walk(".", func(path string, info fs.FileInfo, err error) error {