package main import ( "errors" "fmt" "go/format" "io" "os" "path/filepath" "protoc-gen-slc/tpl" "regexp" "strings" "git.apinb.com/bsm-sdk/core/utils" "golang.org/x/mod/modfile" "google.golang.org/protobuf/compiler/protogen" "google.golang.org/protobuf/types/pluginpb" ) var ServicesName []string func main() { protogen.Options{}.Run(func(gen *protogen.Plugin) error { gen.SupportedFeatures = uint64(pluginpb.CodeGeneratorResponse_FEATURE_PROTO3_OPTIONAL) if !utils.PathExists("./internal") { os.MkdirAll("./internal", 777) } if !utils.PathExists("./internal/server") { os.MkdirAll("./internal/server", 777) } if !utils.PathExists("./internal/logic") { os.MkdirAll("./internal/logic", 777) } for _, f := range gen.Files { if len(f.Services) == 0 { continue } if err := generateFiles(gen, f); err != nil { return err } } err := generateNewServerFile(ServicesName) if err != nil { return err } return nil }) } func generateFiles(gen *protogen.Plugin, file *protogen.File) error { for _, service := range file.Services { ServicesName = append(ServicesName, service.GoName) // Generate server file if err := generateServerFile(gen, file, service); err != nil { return err } // Generate logic file if err := generateLogicFile(gen, file, service); err != nil { return err } } return nil } func generateNewServerFile(services []string) error { moduleName := getModuleName() //create new.go code := tpl.NewFile newImports := []string{ "pb \"" + moduleName + "/pb\"", } code = strings.ReplaceAll(code, "{import}", strings.Join(newImports, "\n")) // register grpc var register []string for _, service := range services { register = append(register, "pb.Register"+service+"Server(srv.Grpc, New"+service+"Server())") } code = strings.ReplaceAll(code, "{register}", strings.Join(register, "\n")) // register grpc gw var gw []string for _, service := range services { gw = append(gw, "pb.Register"+service+"HandlerFromEndpoint(srv.Ctx, srv.Mux, addr, opts)") } code = strings.ReplaceAll(code, "{gw}", strings.Join(gw, "\n")) // 格式化代码 formattedCode, err := format.Source([]byte(code)) if err != nil { return fmt.Errorf("failed to format generated code: %w", err) } StringToFile("./internal/server/new.go", string(formattedCode)) return nil } func generateServerFile(gen *protogen.Plugin, file *protogen.File, service *protogen.Service) error { filename := fmt.Sprintf("./internal/server/%s_server.go", strings.ToLower(service.GoName)) moduleName := getModuleName() //create servers. code := tpl.Server imports := []string{ "\"" + moduleName + "/internal/logic/" + strings.ToLower(service.GoName) + "\"", "pb \"" + moduleName + "/pb\"", } code = strings.ReplaceAll(code, "{import}", strings.Join(imports, "\n")) code = strings.ReplaceAll(code, "{service}", service.GoName) var codeMethods []string for _, method := range service.Methods { commit := strings.TrimSpace(method.Comments.Leading.String()) methodCode := tpl.Method methodCode = strings.ReplaceAll(methodCode, "{service}", service.GoName) methodCode = strings.ReplaceAll(methodCode, "{serviceLower}", strings.ToLower(service.GoName)) methodCode = strings.ReplaceAll(methodCode, "{func}", method.GoName) methodCode = strings.ReplaceAll(methodCode, "{comment}", commit) methodCode = strings.ReplaceAll(methodCode, "{input}", method.Input.GoIdent.GoName) methodCode = strings.ReplaceAll(methodCode, "{output}", method.Output.GoIdent.GoName) codeMethods = append(codeMethods, methodCode) } code = strings.ReplaceAll(code, "{method}", strings.Join(codeMethods, "\n")) // 格式化代码 formattedCode, err := format.Source([]byte(code)) if err != nil { return fmt.Errorf("failed to format generated code: %w", err) } StringToFile(filename, string(formattedCode)) return nil } func generateClientFile(gen *protogen.Plugin, file *protogen.File, service *protogen.Service) error { filename := fmt.Sprintf("%s_client.pb.go", strings.ToLower(service.GoName)) fmt.Println(filename, file.GoImportPath) return nil } func generateLogicFile(gen *protogen.Plugin, file *protogen.File, service *protogen.Service) error { logicPath := "./internal/logic/" + strings.ToLower(service.GoName) if !utils.PathExists(logicPath) { os.MkdirAll(logicPath, os.ModePerm) } moduleName := getModuleName() for _, method := range service.Methods { filename := fmt.Sprintf("%s/%s.go", logicPath, toSnakeCase(method.GoName)) if utils.PathExists(filename) { continue } code := tpl.LogicFile code = strings.ReplaceAll(code, "{methodName}", strings.ToLower(service.GoName)) imports := []string{ "pb \"" + moduleName + "/pb\"", } code = strings.ReplaceAll(code, "{import}", strings.Join(imports, "\n")) commit := strings.TrimSpace(method.Comments.Leading.String()) code = strings.ReplaceAll(code, "{func}", method.GoName) code = strings.ReplaceAll(code, "{comment}", commit) code = strings.ReplaceAll(code, "{input}", method.Input.GoIdent.GoName) code = strings.ReplaceAll(code, "{output}", method.Output.GoIdent.GoName) // formattedCode, err := format.Source([]byte(code)) // if err != nil { // return fmt.Errorf("failed to format generated code: %w", err) // } // StringToFile(filename, string(formattedCode)) StringToFile(filename, code) } return nil } func toSnakeCase(str string) string { // Use a regular expression to find uppercase letters and insert an underscore before them re := regexp.MustCompile("([a-z0-9])([A-Z])") snake := re.ReplaceAllString(str, "${1}_${2}") // Convert the entire string to lowercase return strings.ToLower(snake) } func methodSignature(g *protogen.GeneratedFile, method *protogen.Method) string { return fmt.Sprintf("%s(ctx context.Context, req pb%s) (*%s, error)", method.GoName, method.Input.GoIdent, method.Output.GoIdent) } func fullMethodName(file *protogen.File, service *protogen.Service, method *protogen.Method) string { return fmt.Sprintf("/%s.%s/%s", file.Proto.GetPackage(), service.GoName, method.GoName) } func getModuleName() (modulePath string) { // 获取当前工作目录 cwd, err := os.Getwd() if err != nil { fmt.Errorf("failed to get current working directory: %w", err) return } // 读取 go.mod 文件 modFilePath := filepath.Join(cwd, "go.mod") modFileBytes, err := os.ReadFile(modFilePath) if err != nil { fmt.Errorf("failed to read go.mod file: %w", err) return } // 解析 go.mod 文件 modFile, err := modfile.Parse(modFilePath, modFileBytes, nil) if err != nil { fmt.Errorf("failed to parse go.mod file: %w", err) return } // 获取模块路径 return modFile.Module.Mod.Path } // 将字符串写入文件 func StringToFile(path, content string) error { startF, err := os.Create(path) if err != nil { return errors.New("os.Create create file " + path + " error:" + err.Error()) } defer startF.Close() _, err = io.WriteString(startF, content) if err != nil { return errors.New("io.WriteString to " + path + " error:" + err.Error()) } return nil }