protoc-gen-slc/main.go

253 lines
7.0 KiB
Go

package main
import (
"errors"
"fmt"
"go/format"
"io"
"os"
"path/filepath"
"protoc-gen-slc/tpl"
"regexp"
"strings"
"git.apinb.com/bsm-sdk/core/utils"
"golang.org/x/mod/modfile"
"google.golang.org/protobuf/compiler/protogen"
"google.golang.org/protobuf/types/pluginpb"
)
// ServicesName accumulates the Go name of every service seen by
// generateFiles; main hands it to generateNewServerFile once all input files
// have been processed.
var ServicesName []string
// main runs the protoc plugin. For every input file that declares at least
// one service it generates the per-service server and logic sources under
// ./internal, then writes ./internal/server/new.go for all services collected
// in ServicesName.
func main() {
	protogen.Options{}.Run(func(gen *protogen.Plugin) error {
		gen.SupportedFeatures = uint64(pluginpb.CodeGeneratorResponse_FEATURE_PROTO3_OPTIONAL)

		// MkdirAll creates missing parents (./internal) and is a no-op for
		// directories that already exist, so no existence check is needed.
		// The mode has to be an octal permission: the decimal literal 777 is
		// 0o1411, whose permission bits (0o411) deny the owner write access.
		for _, dir := range []string{"./internal/server", "./internal/logic"} {
			if err := os.MkdirAll(dir, os.ModePerm); err != nil {
				return fmt.Errorf("creating directory %q: %w", dir, err)
			}
		}

		for _, f := range gen.Files {
			// Files that were only pulled in as imports are not ours to
			// generate; files without services produce nothing.
			if !f.Generate || len(f.Services) == 0 {
				continue
			}
			if err := generateFiles(gen, f); err != nil {
				return err
			}
		}

		return generateNewServerFile(ServicesName)
	})
}
// generateFiles emits the server and logic sources for each service declared
// in file, recording every service's Go name in ServicesName along the way.
// It stops at, and returns, the first generation error.
func generateFiles(gen *protogen.Plugin, file *protogen.File) error {
	for i := range file.Services {
		svc := file.Services[i]
		ServicesName = append(ServicesName, svc.GoName)

		// Server wrapper first, then the logic stubs; the logic step is
		// skipped when the server step has already failed.
		err := generateServerFile(gen, file, svc)
		if err == nil {
			err = generateLogicFile(gen, file, svc)
		}
		if err != nil {
			return err
		}
	}
	return nil
}
// generateNewServerFile renders tpl.NewFile into ./internal/server/new.go.
//
// The template's {import} placeholder receives the import of the module's pb
// package, {register} receives one pb.Register<Service>Server call per entry
// of services, and {gw} receives one pb.Register<Service>HandlerFromEndpoint
// call per entry. The rendered code is gofmt-formatted before it is written.
//
// It returns an error if the rendered code does not format or the file cannot
// be written.
func generateNewServerFile(services []string) error {
	const target = "./internal/server/new.go"

	moduleName := getModuleName()

	code := tpl.NewFile
	newImports := []string{
		"pb \"" + moduleName + "/pb\"",
	}
	code = strings.ReplaceAll(code, "{import}", strings.Join(newImports, "\n"))

	// One gRPC registration and one gateway registration per service.
	register := make([]string, 0, len(services))
	gw := make([]string, 0, len(services))
	for _, service := range services {
		register = append(register, "pb.Register"+service+"Server(srv.Grpc, New"+service+"Server())")
		gw = append(gw, "pb.Register"+service+"HandlerFromEndpoint(srv.Ctx, srv.Mux, addr, opts)")
	}
	code = strings.ReplaceAll(code, "{register}", strings.Join(register, "\n"))
	code = strings.ReplaceAll(code, "{gw}", strings.Join(gw, "\n"))

	// Format the generated code.
	formattedCode, err := format.Source([]byte(code))
	if err != nil {
		return fmt.Errorf("failed to format generated code: %w", err)
	}

	// A failed write must not be reported as success.
	if err := StringToFile(target, string(formattedCode)); err != nil {
		return fmt.Errorf("writing %s: %w", target, err)
	}
	return nil
}
// generateServerFile renders tpl.Server for service and writes the result to
// ./internal/server/<lowercased service name>_server.go.
//
// Each method of service is rendered from tpl.Method, and the joined results
// replace the {method} placeholder of the server template. The rendered code
// is gofmt-formatted before it is written.
//
// It returns an error if the rendered code does not format or the file cannot
// be written.
func generateServerFile(gen *protogen.Plugin, file *protogen.File, service *protogen.Service) error {
	serviceLower := strings.ToLower(service.GoName)
	filename := fmt.Sprintf("./internal/server/%s_server.go", serviceLower)
	moduleName := getModuleName()

	// Fill in the file-level placeholders of the server template.
	imports := []string{
		"\"" + moduleName + "/internal/logic/" + serviceLower + "\"",
		"pb \"" + moduleName + "/pb\"",
	}
	code := tpl.Server
	code = strings.ReplaceAll(code, "{import}", strings.Join(imports, "\n"))
	code = strings.ReplaceAll(code, "{service}", service.GoName)

	// Render one method body per RPC.
	codeMethods := make([]string, 0, len(service.Methods))
	for _, method := range service.Methods {
		comment := strings.TrimSpace(method.Comments.Leading.String())
		methodCode := tpl.Method
		methodCode = strings.ReplaceAll(methodCode, "{service}", service.GoName)
		methodCode = strings.ReplaceAll(methodCode, "{serviceLower}", serviceLower)
		methodCode = strings.ReplaceAll(methodCode, "{func}", method.GoName)
		methodCode = strings.ReplaceAll(methodCode, "{comment}", comment)
		methodCode = strings.ReplaceAll(methodCode, "{input}", method.Input.GoIdent.GoName)
		methodCode = strings.ReplaceAll(methodCode, "{output}", method.Output.GoIdent.GoName)
		codeMethods = append(codeMethods, methodCode)
	}
	code = strings.ReplaceAll(code, "{method}", strings.Join(codeMethods, "\n"))

	// Format the generated code.
	formattedCode, err := format.Source([]byte(code))
	if err != nil {
		return fmt.Errorf("failed to format generated code: %w", err)
	}

	// A failed write must not be reported as success.
	if err := StringToFile(filename, string(formattedCode)); err != nil {
		return fmt.Errorf("writing %s: %w", filename, err)
	}
	return nil
}
// generateClientFile is a placeholder for client generation: it only reports
// the file name it would use, together with the Go import path of file, and
// always returns nil.
//
// The report is written to stderr because the plugin's response to protoc
// travels over stdout; stray output there would corrupt that response.
func generateClientFile(gen *protogen.Plugin, file *protogen.File, service *protogen.Service) error {
	filename := fmt.Sprintf("%s_client.pb.go", strings.ToLower(service.GoName))
	fmt.Fprintln(os.Stderr, filename, file.GoImportPath)
	return nil
}
// generateLogicFile writes one logic stub per method of service into
// ./internal/logic/<lowercased service name>/<snake_case method name>.go,
// rendered from tpl.LogicFile.
//
// A stub whose file already exists is skipped, so existing files are never
// overwritten. The rendered code is written as-is, without being passed
// through gofmt.
//
// It returns an error if the directory cannot be created or a stub cannot be
// written.
func generateLogicFile(gen *protogen.Plugin, file *protogen.File, service *protogen.Service) error {
	serviceLower := strings.ToLower(service.GoName)
	logicPath := "./internal/logic/" + serviceLower

	// MkdirAll is a no-op when the directory already exists.
	if err := os.MkdirAll(logicPath, os.ModePerm); err != nil {
		return fmt.Errorf("creating directory %q: %w", logicPath, err)
	}

	moduleName := getModuleName()
	imports := []string{
		"pb \"" + moduleName + "/pb\"",
	}

	for _, method := range service.Methods {
		filename := fmt.Sprintf("%s/%s.go", logicPath, toSnakeCase(method.GoName))
		// Presumably existing stubs hold hand-written logic; leave them alone.
		if utils.PathExists(filename) {
			continue
		}

		comment := strings.TrimSpace(method.Comments.Leading.String())

		code := tpl.LogicFile
		// The {methodName} placeholder receives the lowercased service name.
		code = strings.ReplaceAll(code, "{methodName}", serviceLower)
		code = strings.ReplaceAll(code, "{import}", strings.Join(imports, "\n"))
		code = strings.ReplaceAll(code, "{func}", method.GoName)
		code = strings.ReplaceAll(code, "{comment}", comment)
		code = strings.ReplaceAll(code, "{input}", method.Input.GoIdent.GoName)
		code = strings.ReplaceAll(code, "{output}", method.Output.GoIdent.GoName)

		// A failed write must not be reported as success.
		if err := StringToFile(filename, code); err != nil {
			return fmt.Errorf("writing %s: %w", filename, err)
		}
	}
	return nil
}
// snakeCaseBoundary matches a lowercase ASCII letter or digit immediately
// followed by an uppercase ASCII letter. It is compiled once at package
// initialisation instead of on every toSnakeCase call.
var snakeCaseBoundary = regexp.MustCompile("([a-z0-9])([A-Z])")

// toSnakeCase converts a MixedCaps identifier to snake_case: an underscore is
// inserted wherever a lowercase letter or digit is directly followed by an
// uppercase letter, and the whole result is lowercased.
//
// Runs of uppercase letters are not split, so "HTTPServer" becomes
// "httpserver" and "GetUserID" becomes "get_user_id".
func toSnakeCase(str string) string {
	return strings.ToLower(snakeCaseBoundary.ReplaceAllString(str, "${1}_${2}"))
}
// methodSignature returns the Go signature text for method, in the form
//
//	Name(ctx context.Context, req *pb.Input) (*pb.Output, error)
//
// The message types are referenced by their Go names through the pb package
// alias, matching the pb import the other generators in this file emit.
// Formatting the GoIdent values directly would embed their quoted import
// path and yield text that is not valid Go.
//
// g is unused.
func methodSignature(g *protogen.GeneratedFile, method *protogen.Method) string {
	return fmt.Sprintf("%s(ctx context.Context, req *pb.%s) (*pb.%s, error)",
		method.GoName,
		method.Input.GoIdent.GoName,
		method.Output.GoIdent.GoName)
}
// fullMethodName returns the gRPC method path for method of service, in the
// form "/<fully-qualified service name>/<method name>".
//
// The names are taken from the protobuf descriptors rather than from the Go
// identifiers, because the Go names are a CamelCase rendering that can differ
// from the names declared in the .proto file. Using the service's full name
// also avoids a stray leading dot when the file declares no package.
//
// file is unused; the service descriptor already carries the package.
func fullMethodName(file *protogen.File, service *protogen.Service, method *protogen.Method) string {
	return fmt.Sprintf("/%s/%s", service.Desc.FullName(), method.Desc.Name())
}
// getModuleName returns the module path declared by the go.mod file in the
// current working directory.
//
// The signature has no error result, so a failure is reported on stderr and
// the empty string is returned. Stderr is used because stdout carries the
// plugin's response to protoc.
func getModuleName() (modulePath string) {
	// Locate go.mod in the current working directory.
	cwd, err := os.Getwd()
	if err != nil {
		fmt.Fprintf(os.Stderr, "failed to get current working directory: %v\n", err)
		return ""
	}
	modFilePath := filepath.Join(cwd, "go.mod")

	// Read the go.mod file.
	modFileBytes, err := os.ReadFile(modFilePath)
	if err != nil {
		fmt.Fprintf(os.Stderr, "failed to read go.mod file: %v\n", err)
		return ""
	}

	// Parse the go.mod file.
	modFile, err := modfile.Parse(modFilePath, modFileBytes, nil)
	if err != nil {
		fmt.Fprintf(os.Stderr, "failed to parse go.mod file: %v\n", err)
		return ""
	}

	// Module is nil when go.mod carries no module directive.
	if modFile.Module == nil {
		fmt.Fprintf(os.Stderr, "go.mod file %s has no module directive\n", modFilePath)
		return ""
	}
	return modFile.Module.Mod.Path
}
// StringToFile writes content to the file at path, creating the file if it
// does not exist and truncating it if it does.
//
// It returns an error if the file cannot be created, written, or closed.
func StringToFile(path, content string) error {
	f, err := os.Create(path)
	if err != nil {
		return errors.New("os.Create create file " + path + " error:" + err.Error())
	}

	// Close explicitly rather than in a defer: a failed Close can mean the
	// data never reached the disk, so its error has to be reported too.
	_, writeErr := io.WriteString(f, content)
	closeErr := f.Close()

	if writeErr != nil {
		return errors.New("io.WriteString to " + path + " error:" + writeErr.Error())
	}
	if closeErr != nil {
		return errors.New("close file " + path + " error:" + closeErr.Error())
	}
	return nil
}