package main import ( "bufio" "errors" "fmt" "io" "os" "path/filepath" "strings" ) // merge ./*.proto auto_gen.proto func main() { if len(os.Args) < 2 { fmt.Println("Args error, eg: proto-merge ./*.proto auto_gen.proto") os.Exit(0) } else { in := os.Args[1] out := os.Args[2] fmt.Println("Merge In:", in, "Out:", out) files, err := FindFilesByPattern(in) if err != nil { fmt.Println("Error:", err) os.Exit(1) } files = IsPorotFile(files) // 合并 var mergeLins []string genHeader := make(map[string]string) var protoHeader []string protoHeader = append(protoHeader, "syntax = \"proto3\";") for _, file := range files { mergeLins = append(mergeLins, "") mergeLins = append(mergeLins, "// "+file+" START") header, lines := scanProtoFile(file) mergeLins = append(mergeLins, lines...) mergeLins = append(mergeLins, "// END") mergeLins = append(mergeLins, "") for key, val := range header { genHeader[key] = val } } for k, v := range genHeader { if k == "package" { protoHeader = append(protoHeader, v) } } for k, v := range genHeader { if k == "option" { protoHeader = append(protoHeader, v) } } var body []string body = append(body, protoHeader...) body = append(body, mergeLins...) // 输出 err = StringToFile(out, strings.Join(body, "\n")) if err != nil { fmt.Println("Error:", err) os.Exit(1) } else { fmt.Println("Merge OK!") } } } // FindFilesByPattern takes a combined string of directory and pattern and returns a slice of matching file names. 
//
// Every match is returned as an absolute, cleaned path. An error is returned
// if the pattern is malformed or a match cannot be made absolute; when
// nothing matches, the result is a nil slice and a nil error.
func FindFilesByPattern(pathWithPattern string) ([]string, error) {
	matches, err := filepath.Glob(pathWithPattern)
	if err != nil {
		return nil, err
	}
	// Resolve each match on its own: the pattern may contain wildcards in its
	// directory part, so the matches do not necessarily share one directory.
	for i, m := range matches {
		abs, err := filepath.Abs(m)
		if err != nil {
			return nil, err
		}
		matches[i] = abs
	}
	return matches, nil
}

// IsPorotFile (sic: "Proto") filters in down to the files whose syntax
// statement selects proto3, preserving their order. Files that cannot be read
// are reported on stderr and left out.
func IsPorotFile(in []string) []string {
	var ok []string
	for _, path := range in {
		data, err := os.ReadFile(path)
		if err != nil {
			fmt.Fprintln(os.Stderr, err.Error())
			continue
		}
		if declaresProto3(string(data)) {
			ok = append(ok, path)
		}
	}
	return ok
}

// declaresProto3 reports whether the first syntax statement in src is
// `syntax = "proto3"` (single or double quotes). Only a statement that sits on
// a single line is recognised.
func declaresProto3(src string) bool {
	for _, line := range strings.Split(src, "\n") {
		// Tolerate a UTF-8 byte-order mark and surrounding whitespace.
		line = strings.TrimSpace(strings.TrimPrefix(line, "\ufeff"))
		if !strings.HasPrefix(line, "syntax") {
			continue
		}
		rest := strings.TrimSpace(strings.TrimPrefix(line, "syntax"))
		if !strings.HasPrefix(rest, "=") {
			continue
		}
		rest = strings.TrimSpace(rest[1:])
		return strings.HasPrefix(rest, `"proto3"`) || strings.HasPrefix(rest, `'proto3'`)
	}
	return false
}

// scanProtoFile reads the file at filePath line by line and splits it into
// file-level statements and everything else.
//
// Lines recognised by containsSysTag (syntax, package, option or import
// starting at column 0) go into the returned header map, keyed by that
// keyword. When a keyword occurs on several lines (typically several option
// or import statements) those lines are joined with "\n" under one key, so
// none is lost. All other lines are returned unchanged and in order.
//
// The signature has no error result, so it panics if the file cannot be
// opened or read completely.
func scanProtoFile(filePath string) (map[string]string, []string) {
	f, err := os.Open(filePath)
	if err != nil {
		panic(fmt.Errorf("open proto file %q: %w", filePath, err))
	}
	defer f.Close()

	header := make(map[string]string)
	var body []string
	sc := bufio.NewScanner(f)
	for sc.Scan() {
		line := sc.Text()
		key, isTag := containsSysTag(line)
		if !isTag {
			body = append(body, line)
			continue
		}
		if prev, seen := header[key]; seen {
			header[key] = prev + "\n" + line
		} else {
			header[key] = line
		}
	}
	// Scan stops on the first error (for example a line longer than the
	// scanner's buffer); without this check the rest of the file would be
	// silently dropped.
	if err := sc.Err(); err != nil {
		panic(fmt.Errorf("read proto file %q: %w", filePath, err))
	}
	return header, body
}

// containsSysTag reports whether line is a file-level statement introduced by
// syntax, package, option or import, and returns that keyword.
//
// The keyword must start at column 0, so indented lines (such as option
// statements nested in a message or enum) are not matched. It must also be
// followed by a non-identifier character or the end of the line, so a field
// such as "optional int32 x = 1;" is not mistaken for an option statement.
func containsSysTag(line string) (string, bool) {
	for _, tag := range [...]string{"syntax", "package", "option", "import"} {
		if !strings.HasPrefix(line, tag) {
			continue
		}
		rest := line[len(tag):]
		if rest == "" || !isIdentByte(rest[0]) {
			return tag, true
		}
	}
	return "", false
}

// isIdentByte reports whether b may appear in a protobuf identifier.
func isIdentByte(b byte) bool {
	return b == '_' ||
		('0' <= b && b <= '9') ||
		('a' <= b && b <= 'z') ||
		('A' <= b && b <= 'Z')
}

// StringToFile creates (or truncates) the file at path and writes content to
// it. Errors from creating, writing or closing the file are returned with the
// path for context.
func StringToFile(path, content string) error {
	f, err := os.Create(path)
	if err != nil {
		return errors.New("os.Create create file " + path + " error:" + err.Error())
	}
	if _, err := io.WriteString(f, content); err != nil {
		// The write error is the one worth reporting; a Close error on top of
		// it would add nothing.
		_ = f.Close()
		return errors.New("io.WriteString to " + path + " error:" + err.Error())
	}
	// Checked rather than deferred: Close may report write failures that were
	// not visible earlier (e.g. on network file systems).
	if err := f.Close(); err != nil {
		return errors.New("close file " + path + " error:" + err.Error())
	}
	return nil
}