253 lines
7.0 KiB
Go
253 lines
7.0 KiB
Go
package main
|
|
|
|
import (
|
|
"errors"
|
|
"fmt"
|
|
"go/format"
|
|
"io"
|
|
"os"
|
|
"path/filepath"
|
|
"protoc-gen-slc/tpl"
|
|
"regexp"
|
|
"strings"
|
|
|
|
"git.apinb.com/bsm-sdk/core/utils"
|
|
"golang.org/x/mod/modfile"
|
|
"google.golang.org/protobuf/compiler/protogen"
|
|
"google.golang.org/protobuf/types/pluginpb"
|
|
)
|
|
|
|
// ServicesName accumulates the Go names of every service processed by
// generateFiles; main passes it to generateNewServerFile, which emits one
// gRPC and one gateway registration call per entry into new.go.
var ServicesName []string
|
|
|
|
func main() {
|
|
protogen.Options{}.Run(func(gen *protogen.Plugin) error {
|
|
gen.SupportedFeatures = uint64(pluginpb.CodeGeneratorResponse_FEATURE_PROTO3_OPTIONAL)
|
|
|
|
if !utils.PathExists("./internal") {
|
|
os.MkdirAll("./internal", 777)
|
|
}
|
|
|
|
if !utils.PathExists("./internal/server") {
|
|
os.MkdirAll("./internal/server", 777)
|
|
}
|
|
|
|
if !utils.PathExists("./internal/logic") {
|
|
os.MkdirAll("./internal/logic", 777)
|
|
}
|
|
|
|
for _, f := range gen.Files {
|
|
if len(f.Services) == 0 {
|
|
continue
|
|
}
|
|
if err := generateFiles(gen, f); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
err := generateNewServerFile(ServicesName)
|
|
if err != nil {
|
|
return err
|
|
|
|
}
|
|
|
|
return nil
|
|
})
|
|
}
|
|
|
|
func generateFiles(gen *protogen.Plugin, file *protogen.File) error {
|
|
for _, service := range file.Services {
|
|
ServicesName = append(ServicesName, service.GoName)
|
|
// Generate server file
|
|
if err := generateServerFile(gen, file, service); err != nil {
|
|
return err
|
|
}
|
|
|
|
// Generate logic file
|
|
if err := generateLogicFile(gen, file, service); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func generateNewServerFile(services []string) error {
|
|
moduleName := getModuleName()
|
|
|
|
//create new.go
|
|
code := tpl.NewFile
|
|
newImports := []string{
|
|
"pb \"" + moduleName + "/pb\"",
|
|
}
|
|
code = strings.ReplaceAll(code, "{import}", strings.Join(newImports, "\n"))
|
|
|
|
// register grpc
|
|
var register []string
|
|
for _, service := range services {
|
|
register = append(register, "pb.Register"+service+"Server(srv.Grpc, New"+service+"Server())")
|
|
}
|
|
code = strings.ReplaceAll(code, "{register}", strings.Join(register, "\n"))
|
|
|
|
// register grpc gw
|
|
var gw []string
|
|
for _, service := range services {
|
|
gw = append(gw, "pb.Register"+service+"HandlerFromEndpoint(srv.Ctx, srv.Mux, addr, opts)")
|
|
}
|
|
code = strings.ReplaceAll(code, "{gw}", strings.Join(gw, "\n"))
|
|
|
|
// 格式化代码
|
|
formattedCode, err := format.Source([]byte(code))
|
|
if err != nil {
|
|
return fmt.Errorf("failed to format generated code: %w", err)
|
|
}
|
|
|
|
StringToFile("./internal/server/new.go", string(formattedCode))
|
|
|
|
return nil
|
|
}
|
|
|
|
func generateServerFile(gen *protogen.Plugin, file *protogen.File, service *protogen.Service) error {
|
|
filename := fmt.Sprintf("./internal/server/%s_server.go", strings.ToLower(service.GoName))
|
|
moduleName := getModuleName()
|
|
|
|
//create servers.
|
|
code := tpl.Server
|
|
imports := []string{
|
|
"\"" + moduleName + "/internal/logic/" + strings.ToLower(service.GoName) + "\"",
|
|
"pb \"" + moduleName + "/pb\"",
|
|
}
|
|
|
|
code = strings.ReplaceAll(code, "{import}", strings.Join(imports, "\n"))
|
|
code = strings.ReplaceAll(code, "{service}", service.GoName)
|
|
|
|
var codeMethods []string
|
|
for _, method := range service.Methods {
|
|
commit := strings.TrimSpace(method.Comments.Leading.String())
|
|
methodCode := tpl.Method
|
|
methodCode = strings.ReplaceAll(methodCode, "{service}", service.GoName)
|
|
methodCode = strings.ReplaceAll(methodCode, "{serviceLower}", strings.ToLower(service.GoName))
|
|
methodCode = strings.ReplaceAll(methodCode, "{func}", method.GoName)
|
|
methodCode = strings.ReplaceAll(methodCode, "{comment}", commit)
|
|
methodCode = strings.ReplaceAll(methodCode, "{input}", method.Input.GoIdent.GoName)
|
|
methodCode = strings.ReplaceAll(methodCode, "{output}", method.Output.GoIdent.GoName)
|
|
codeMethods = append(codeMethods, methodCode)
|
|
}
|
|
code = strings.ReplaceAll(code, "{method}", strings.Join(codeMethods, "\n"))
|
|
|
|
// 格式化代码
|
|
formattedCode, err := format.Source([]byte(code))
|
|
if err != nil {
|
|
return fmt.Errorf("failed to format generated code: %w", err)
|
|
}
|
|
|
|
StringToFile(filename, string(formattedCode))
|
|
|
|
return nil
|
|
}
|
|
|
|
func generateClientFile(gen *protogen.Plugin, file *protogen.File, service *protogen.Service) error {
|
|
filename := fmt.Sprintf("%s_client.pb.go", strings.ToLower(service.GoName))
|
|
fmt.Println(filename, file.GoImportPath)
|
|
return nil
|
|
}
|
|
|
|
func generateLogicFile(gen *protogen.Plugin, file *protogen.File, service *protogen.Service) error {
|
|
logicPath := "./internal/logic/" + strings.ToLower(service.GoName)
|
|
if !utils.PathExists(logicPath) {
|
|
os.MkdirAll(logicPath, os.ModePerm)
|
|
}
|
|
moduleName := getModuleName()
|
|
for _, method := range service.Methods {
|
|
filename := fmt.Sprintf("%s/%s.go", logicPath, toSnakeCase(method.GoName))
|
|
if utils.PathExists(filename) {
|
|
continue
|
|
}
|
|
code := tpl.LogicFile
|
|
|
|
code = strings.ReplaceAll(code, "{methodName}", strings.ToLower(service.GoName))
|
|
imports := []string{
|
|
"pb \"" + moduleName + "/pb\"",
|
|
}
|
|
|
|
code = strings.ReplaceAll(code, "{import}", strings.Join(imports, "\n"))
|
|
commit := strings.TrimSpace(method.Comments.Leading.String())
|
|
code = strings.ReplaceAll(code, "{func}", method.GoName)
|
|
code = strings.ReplaceAll(code, "{comment}", commit)
|
|
code = strings.ReplaceAll(code, "{input}", method.Input.GoIdent.GoName)
|
|
code = strings.ReplaceAll(code, "{output}", method.Output.GoIdent.GoName)
|
|
|
|
// formattedCode, err := format.Source([]byte(code))
|
|
// if err != nil {
|
|
// return fmt.Errorf("failed to format generated code: %w", err)
|
|
// }
|
|
|
|
// StringToFile(filename, string(formattedCode))
|
|
StringToFile(filename, code)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// snakeCaseBoundary matches a lowercase letter or digit immediately followed
// by an uppercase letter, i.e. a camelCase word boundary. It is compiled once
// at package scope instead of on every toSnakeCase call.
var snakeCaseBoundary = regexp.MustCompile("([a-z0-9])([A-Z])")

// toSnakeCase converts a CamelCase identifier to snake_case by inserting an
// underscore at each lower/digit-to-upper boundary and lowercasing the result,
// e.g. "GetUserInfo" -> "get_user_info". Runs of capitals are not split:
// "HTTPServer" -> "httpserver".
func toSnakeCase(str string) string {
	snake := snakeCaseBoundary.ReplaceAllString(str, "${1}_${2}")
	return strings.ToLower(snake)
}
|
|
|
|
func methodSignature(g *protogen.GeneratedFile, method *protogen.Method) string {
|
|
return fmt.Sprintf("%s(ctx context.Context, req pb%s) (*%s, error)",
|
|
method.GoName,
|
|
method.Input.GoIdent,
|
|
method.Output.GoIdent)
|
|
}
|
|
|
|
func fullMethodName(file *protogen.File, service *protogen.Service, method *protogen.Method) string {
|
|
return fmt.Sprintf("/%s.%s/%s",
|
|
file.Proto.GetPackage(),
|
|
service.GoName,
|
|
method.GoName)
|
|
}
|
|
|
|
func getModuleName() (modulePath string) {
|
|
// 获取当前工作目录
|
|
cwd, err := os.Getwd()
|
|
if err != nil {
|
|
fmt.Errorf("failed to get current working directory: %w", err)
|
|
return
|
|
}
|
|
|
|
// 读取 go.mod 文件
|
|
modFilePath := filepath.Join(cwd, "go.mod")
|
|
modFileBytes, err := os.ReadFile(modFilePath)
|
|
if err != nil {
|
|
fmt.Errorf("failed to read go.mod file: %w", err)
|
|
return
|
|
}
|
|
|
|
// 解析 go.mod 文件
|
|
modFile, err := modfile.Parse(modFilePath, modFileBytes, nil)
|
|
if err != nil {
|
|
fmt.Errorf("failed to parse go.mod file: %w", err)
|
|
return
|
|
}
|
|
|
|
// 获取模块路径
|
|
return modFile.Module.Mod.Path
|
|
}
|
|
|
|
// StringToFile writes content to the file at path, creating it if necessary
// and truncating it if it already exists.
//
// It returns an error if the file cannot be created, if writing fails, or if
// closing the file fails. On many filesystems a failed Close is where a
// deferred write error surfaces, so that error must not be ignored.
func StringToFile(path, content string) error {
	startF, err := os.Create(path)
	if err != nil {
		return errors.New("os.Create create file " + path + " error:" + err.Error())
	}

	if _, err := io.WriteString(startF, content); err != nil {
		// The write error is the one worth reporting; the Close is cleanup.
		startF.Close()
		return errors.New("io.WriteString to " + path + " error:" + err.Error())
	}

	if err := startF.Close(); err != nil {
		return errors.New("close file " + path + " error:" + err.Error())
	}

	return nil
}
|