// proto-merge/main.go
//
// 186 lines
// 3.6 KiB
// Go

package main
import (
"bufio"
"errors"
"fmt"
"io"
"os"
"path/filepath"
"strings"
)
// main merges the proto3 files matched by a glob pattern into one output file.
//
// Usage: proto-merge ./*.proto auto_gen.proto
//
// The output starts with a fixed `syntax = "proto3";` line. It is followed by
// the "package" and "option" header entries collected from the inputs (for
// each key, the entry from the last file that had one wins) and then by the
// body lines of every input, each wrapped in "// <file> START" and "// END"
// marker comments. Any other header entries returned by scanProtoFile are
// discarded.
//
// The process exits with status 1 on a usage error or on any failure.
func main() {
	// Both the input pattern and the output path are required, so os.Args
	// must hold at least three entries (the program name included).
	if len(os.Args) < 3 {
		fmt.Println("Args error, eg: proto-merge ./*.proto auto_gen.proto")
		os.Exit(1)
	}
	in, out := os.Args[1], os.Args[2]
	fmt.Println("Merge In:", in, "Out:", out)

	files, err := FindFilesByPattern(in)
	if err != nil {
		fmt.Println("Error:", err)
		os.Exit(1)
	}
	files = IsPorotFile(files)

	// The output file usually sits next to the inputs and matches the same
	// pattern. Resolve it once so it can be skipped below; otherwise the
	// result of a previous run would be merged into itself.
	outAbs, err := filepath.Abs(out)
	if err != nil {
		fmt.Println("Error:", err)
		os.Exit(1)
	}

	// Merge.
	var mergeLines []string
	genHeader := make(map[string]string)
	for _, file := range files {
		if abs, absErr := filepath.Abs(file); absErr == nil && abs == outAbs {
			continue
		}
		header, lines := scanProtoFile(file)
		mergeLines = append(mergeLines, "", "// "+file+" START")
		mergeLines = append(mergeLines, lines...)
		mergeLines = append(mergeLines, "// END", "")
		for key, val := range header {
			genHeader[key] = val
		}
	}

	// Emit the package line before the option line.
	protoHeader := []string{"syntax = \"proto3\";"}
	for _, key := range []string{"package", "option"} {
		if val, ok := genHeader[key]; ok {
			protoHeader = append(protoHeader, val)
		}
	}

	// Output.
	body := append(protoHeader, mergeLines...)
	if err := StringToFile(out, strings.Join(body, "\n")); err != nil {
		fmt.Println("Error:", err)
		os.Exit(1)
	}
	fmt.Println("Merge OK!")
}
// FindFilesByPattern returns the absolute paths of all files matching
// pathWithPattern, a glob in filepath.Glob syntax that may include a directory
// part (for example "./*.proto").
//
// It returns an error when the pattern is malformed or when a match cannot be
// converted to an absolute path. A pattern that matches nothing yields an
// empty result and a nil error.
func FindFilesByPattern(pathWithPattern string) ([]string, error) {
	matches, err := filepath.Glob(pathWithPattern)
	if err != nil {
		return nil, err
	}
	// Resolve every match on its own instead of gluing its base name onto the
	// pattern's directory: that stays correct when the directory part contains
	// wildcards or happens to contain the file's base name.
	for i, match := range matches {
		abs, err := filepath.Abs(match)
		if err != nil {
			return nil, fmt.Errorf("resolving absolute path of %q: %w", match, err)
		}
		matches[i] = abs
	}
	return matches, nil
}
// IsPorotFile filters in down to the paths whose file content contains the
// text "proto3", keeping the input order.
//
// A file that cannot be read has its error printed to standard output; the
// bytes that were read (if any) are still checked. The result is nil when no
// path qualifies.
func IsPorotFile(in []string) []string {
	var matched []string
	for _, path := range in {
		data, readErr := os.ReadFile(path)
		if readErr != nil {
			fmt.Println(readErr.Error())
		}
		if !strings.Contains(string(data), "proto3") {
			continue
		}
		matched = append(matched, path)
	}
	return matched
}
// scanProtoFile reads the file at filePath line by line and separates header
// statements from body lines.
//
// Every line for which containsSysTag reports a keyword is stored in the
// returned map under that keyword; when a keyword occurs more than once only
// the last such line is kept. All other lines, blank ones included, are
// returned in file order as the body.
//
// The function has no error return, so it panics when the file cannot be
// opened or when reading fails part-way. Panicking on a read failure keeps a
// truncated body from being returned as if it were complete.
func scanProtoFile(filePath string) (map[string]string, []string) {
	file, err := os.Open(filePath)
	if err != nil {
		panic(fmt.Sprintf("无法打开文件 %s: %v", filePath, err))
	}
	defer file.Close()

	header := make(map[string]string)
	var body []string
	scanner := bufio.NewScanner(file)
	for scanner.Scan() {
		line := scanner.Text()
		if key, ok := containsSysTag(line); ok {
			header[key] = line
			continue
		}
		body = append(body, line)
	}
	// Scan returns false both at end of file and on failure (an I/O error or
	// a line exceeding the scanner's buffer); only Err tells them apart.
	if err := scanner.Err(); err != nil {
		panic(fmt.Sprintf("scanProtoFile: reading %s: %v", filePath, err))
	}
	return header, body
}
// containsSysTag reports whether line begins with one of the keywords
// "syntax", "package", "option" or "import", and returns the matching keyword.
//
// The keyword must start at the very first byte of the line (indented lines
// never match) and must be a whole word: the byte after it, if any, may not be
// an ASCII letter, digit or underscore. A line such as
// "optional string name = 1;" is therefore not reported as an "option" line.
// When nothing matches, it returns "" and false.
func containsSysTag(line string) (string, bool) {
	tags := [...]string{"syntax", "package", "option", "import"}
	for _, tag := range tags {
		if !strings.HasPrefix(line, tag) {
			continue
		}
		rest := line[len(tag):]
		if rest == "" {
			return tag, true
		}
		c := rest[0]
		isIdent := c == '_' ||
			(c >= '0' && c <= '9') ||
			(c >= 'a' && c <= 'z') ||
			(c >= 'A' && c <= 'Z')
		if !isIdent {
			return tag, true
		}
	}
	return "", false
}
// StringToFile writes content to the file at path, creating the file if it
// does not exist and truncating it if it does.
//
// It returns an error when the file cannot be created, when the write fails,
// or when closing the file fails. The close error is checked because, for a
// file that was just written, it can be the only sign that the data did not
// reach storage.
func StringToFile(path, content string) error {
	f, err := os.Create(path)
	if err != nil {
		return errors.New("os.Create create file " + path + " error:" + err.Error())
	}

	// Close explicitly rather than via defer so its error can be reported;
	// the file is closed exactly once on every path below.
	_, writeErr := io.WriteString(f, content)
	closeErr := f.Close()
	if writeErr != nil {
		return errors.New("io.WriteString to " + path + " error:" + writeErr.Error())
	}
	if closeErr != nil {
		return errors.New("close file " + path + " error:" + closeErr.Error())
	}
	return nil
}