Example Usage:
$ go build -i -o protoc-gen-grpc-gateway-ts .
$ protoc -I ./lib/proto ./lib/proto/*.proto \
    --plugin=protoc-gen-grpc_gateway_ts=$(which protoc-gen-grpc-gateway-ts) \
    --grpc_gateway_ts_out=lib/ts/src
| package main | |
| import ( | |
| "fmt" | |
| "io" | |
| "io/ioutil" | |
| "log" | |
| "os" | |
| "path" | |
| "sort" | |
| "strings" | |
| "github.com/golang/protobuf/proto" | |
| "github.com/golang/protobuf/protoc-gen-go/descriptor" | |
| gen "github.com/golang/protobuf/protoc-gen-go/generator" | |
| plugin "github.com/golang/protobuf/protoc-gen-go/plugin" | |
| ) | |
| type generator struct { | |
| *gen.Generator | |
| reader io.Reader | |
| writer io.Writer | |
| } | |
// messageField describes a single property of a generated TypeScript type.
type messageField struct {
	name string // property name; populated from the proto field's JSON name
	t    string // TypeScript type expression, as returned by parseField
}
// namespace derives the dotted namespace prefix for a proto file from the
// directory part of its path, e.g. "lib/proto/foo.proto" -> ".lib.proto".
//
// A file with no directory component yields the empty namespace, so that
// qualifiedName produces ".Name" rather than "...Name" (path.Dir reports "."
// for such paths, which previously leaked into the result as "..").
func namespace(name string) string {
	dir := path.Dir(name)
	if dir == "." {
		return ""
	}
	return fmt.Sprintf(".%s", strings.Replace(dir, "/", ".", -1))
}
// qualifiedName joins a namespace and a simple type name with a single dot.
func qualifiedName(ns, name string) string {
	// All operands are strings, so Sprint concatenates them without padding.
	return fmt.Sprint(ns, ".", name)
}
| type serviceDefinition struct { | |
| desc *descriptor.ServiceDescriptorProto | |
| dependencies map[string]bool | |
| } | |
| type serviceMap map[string]*serviceDefinition | |
| type messageDefinition struct { | |
| desc *descriptor.DescriptorProto | |
| dependencies map[string]bool | |
| name string | |
| fields []messageField | |
| location string | |
| } | |
| type messageMap map[string]*messageDefinition | |
| type enumDefinition struct { | |
| desc *descriptor.EnumDescriptorProto | |
| location string | |
| } | |
| type enumMap map[string]*enumDefinition | |
// fatal logs msg with the plugin's error prefix, followed by err's text when
// err is non-nil, and then terminates the process with exit status 1. It does
// not return.
//
// A nil err is tolerated so that callers reporting a condition that has no
// underlying error still get their message logged instead of a nil-pointer
// panic.
func fatal(err error, msg string) {
	if err == nil {
		log.Printf("protoc-gen-grpc-gateway-ts: error: %s", msg)
	} else {
		log.Printf("protoc-gen-grpc-gateway-ts: error: %s: %s", msg, err.Error())
	}
	os.Exit(1)
}
| func New() *generator { | |
| return &generator{ | |
| Generator: gen.New(), | |
| reader: os.Stdin, | |
| writer: os.Stdout, | |
| } | |
| } | |
| // P prints the arguments to the generated output. It handles strings and int32s, plus | |
| // handling indirections because they may be *string, etc. | |
| func (g *generator) print(indent int, str ...interface{}) { | |
| g.WriteString(strings.Repeat(" ", 2*indent)) | |
| for _, v := range str { | |
| switch s := v.(type) { | |
| case string: | |
| g.WriteString(s) | |
| case *string: | |
| g.WriteString(*s) | |
| case bool: | |
| fmt.Fprintf(g, "%t", s) | |
| case *bool: | |
| fmt.Fprintf(g, "%t", *s) | |
| case int: | |
| fmt.Fprintf(g, "%d", s) | |
| case int32: | |
| fmt.Fprintf(g, "%d", s) | |
| case *int32: | |
| fmt.Fprintf(g, "%d", *s) | |
| case *int64: | |
| fmt.Fprintf(g, "%d", *s) | |
| case float64: | |
| fmt.Fprintf(g, "%g", s) | |
| case *float64: | |
| fmt.Fprintf(g, "%g", *s) | |
| default: | |
| g.Fail(fmt.Sprintf("unknown type in printer: %T", v)) | |
| } | |
| } | |
| g.WriteString("\n") | |
| } | |
| func (g *generator) processFile(fd *descriptor.FileDescriptorProto, definitions map[string]bool, mm messageMap, em enumMap) (*plugin.CodeGeneratorResponse_File, error) { | |
| fn := fd.GetName() | |
| ns := namespace(fn) | |
| needsWrite := false | |
| // Iterate message and enum dependencies to build a map for an import statement. | |
| importMap := make(map[string]map[string]bool) | |
| for _, message := range fd.MessageType { | |
| qn := qualifiedName(ns, message.GetName()) | |
| if _, ok := definitions[qn]; ok { | |
| needsWrite = true | |
| md := mm[qn] | |
| for dep, _ := range md.dependencies { | |
| messageImportDef := mm[dep] | |
| if messageImportDef != nil && messageImportDef.location != fn { | |
| // This is a message dependency | |
| s, ok := importMap[messageImportDef.location] | |
| if !ok { | |
| s = make(map[string]bool) | |
| } | |
| s[dep] = true | |
| importMap[messageImportDef.location] = s | |
| continue | |
| } | |
| enumImportDef := em[dep] | |
| if enumImportDef != nil && enumImportDef.location != fn { | |
| // This is an enum dependency | |
| s, ok := importMap[enumImportDef.location] | |
| if !ok { | |
| s = make(map[string]bool) | |
| } | |
| s[dep] = true | |
| importMap[enumImportDef.location] = s | |
| } | |
| } | |
| } | |
| } | |
| if !needsWrite { | |
| // We didn't have | |
| for _, enum := range fd.EnumType { | |
| qn := qualifiedName(ns, enum.GetName()) | |
| if _, ok := definitions[qn]; ok { | |
| needsWrite = true | |
| break | |
| } | |
| } | |
| } | |
| if !needsWrite { | |
| // There's nothing for us to do for this file. | |
| return nil, nil | |
| } | |
| i := 0 | |
| // Write the import statements, if necessary. | |
| multiImports := make([]string, 0) | |
| singleImports := make([]string, 0) | |
| for f, imports := range importMap { | |
| loc := fmt.Sprintf("../%s", filename(f)) | |
| loc = loc[:strings.LastIndex(loc, ".")] | |
| ims := make([]string, 0) | |
| for imp, _ := range imports { | |
| ims = append(ims, strings.Trim(path.Ext(imp), ".")) | |
| } | |
| sort.Strings(ims) | |
| statement := fmt.Sprintf("import { %s } from '%s'", strings.Join(ims, ", "), loc) | |
| if len(ims) > 1 { | |
| multiImports = append(multiImports, statement) | |
| } else { | |
| singleImports = append(singleImports, statement) | |
| } | |
| i++ | |
| } | |
| sort.Strings(multiImports) | |
| sort.Strings(singleImports) | |
| for _, is := range multiImports { | |
| g.print(0, is) | |
| } | |
| for _, is := range singleImports { | |
| g.print(0, is) | |
| } | |
| for _, enum := range fd.EnumType { | |
| qn := qualifiedName(ns, enum.GetName()) | |
| if _, ok := definitions[qn]; !ok { | |
| continue | |
| } | |
| // This enum is used by at least one other message or service, so we need to write it out. | |
| if i != 0 { | |
| g.print(0, "") | |
| } | |
| g.print(0, "export enum ", enum.GetName(), " {") | |
| for _, v := range enum.GetValue() { | |
| g.print(1, v.GetName(), " = ", v.GetNumber(), ",") | |
| } | |
| g.print(0, "}") | |
| i++ | |
| } | |
| for _, message := range fd.MessageType { | |
| qn := qualifiedName(ns, message.GetName()) | |
| if _, ok := definitions[qn]; !ok { | |
| continue | |
| } | |
| // This message is used by at least one other message or service, so we need to write it out. | |
| if i != 0 { | |
| g.print(0, "") | |
| } | |
| md := mm[qn] | |
| fieldCount := len(md.fields) | |
| brackets := "{" | |
| if fieldCount == 0 { | |
| brackets = "{}" | |
| } | |
| g.print(0, "export type ", md.name, " = ", brackets) | |
| for _, field := range md.fields { | |
| g.print(1, field.name, ": ", field.t) | |
| } | |
| if fieldCount > 0 { | |
| g.print(0, "}") | |
| } | |
| i++ | |
| } | |
| file := &plugin.CodeGeneratorResponse_File{ | |
| Name: proto.String(filename(fn)), | |
| Content: proto.String(g.String()), | |
| } | |
| g.Reset() | |
| return file, nil | |
| } | |
| // parseField parses the supplied field to extract its type. If it is a map, enum, or | |
| // message type, adds dependent messages and enums to the supplied dependencies map. | |
| func parseField(message *descriptor.DescriptorProto, field *descriptor.FieldDescriptorProto, dependencies map[string]bool) string { | |
| isMap := false | |
| t := "" | |
| switch field.GetType() { | |
| case descriptor.FieldDescriptorProto_TYPE_INT32: | |
| t = "number" | |
| case descriptor.FieldDescriptorProto_TYPE_INT64: | |
| t = "string" | |
| case descriptor.FieldDescriptorProto_TYPE_UINT64: | |
| t = "string" | |
| case descriptor.FieldDescriptorProto_TYPE_DOUBLE: | |
| t = "number" | |
| case descriptor.FieldDescriptorProto_TYPE_FLOAT: | |
| t = "number" | |
| case descriptor.FieldDescriptorProto_TYPE_FIXED64: | |
| t = "string" | |
| case descriptor.FieldDescriptorProto_TYPE_FIXED32: | |
| t = "number" | |
| case descriptor.FieldDescriptorProto_TYPE_BOOL: | |
| t = "boolean" | |
| case descriptor.FieldDescriptorProto_TYPE_STRING: | |
| t = "string" | |
| case descriptor.FieldDescriptorProto_TYPE_BYTES: | |
| t = "string" | |
| case descriptor.FieldDescriptorProto_TYPE_MESSAGE: | |
| messageType := strings.Trim(path.Ext(field.GetTypeName()), ".") | |
| // Handle Maps | |
| for _, nested := range message.GetNestedType() { | |
| if nested.GetName() == messageType && nested.GetOptions().GetMapEntry() { | |
| // This is a map, not a message. | |
| isMap = true | |
| keyType := "" | |
| valueType := "" | |
| for _, nestedField := range nested.GetField() { | |
| if nestedField.GetName() == "value" { | |
| valueType = parseField(message, nestedField, dependencies) | |
| } else if nestedField.GetName() == "key" { | |
| keyType = parseField(message, nestedField, dependencies) | |
| } | |
| } | |
| messageType = fmt.Sprintf("{ [key: %s]: %s }", keyType, valueType) | |
| break | |
| } | |
| } | |
| if !isMap { | |
| dependencies[field.GetTypeName()] = true | |
| } | |
| t = messageType | |
| case descriptor.FieldDescriptorProto_TYPE_ENUM: | |
| t = strings.Trim(path.Ext(field.GetTypeName()), ".") | |
| dependencies[field.GetTypeName()] = true | |
| default: | |
| fatal(fmt.Errorf("Unknown field type %d for %+v in %s", field.GetType(), field, message.GetName()), "Error") | |
| } | |
| if !isMap && field.GetLabel() == descriptor.FieldDescriptorProto_LABEL_REPEATED { | |
| t = fmt.Sprintf("%s[]", t) | |
| } | |
| return t | |
| } | |
| func (g *generator) collectDefinitions(request *plugin.CodeGeneratorRequest) (map[string]bool, messageMap, enumMap) { | |
| mm := make(messageMap) | |
| sm := make(serviceMap) | |
| em := make(enumMap) | |
| for _, fd := range request.ProtoFile { | |
| ns := namespace(fd.GetName()) | |
| for _, enum := range fd.EnumType { | |
| em[qualifiedName(ns, enum.GetName())] = &enumDefinition{ | |
| desc: enum, | |
| location: fd.GetName(), | |
| } | |
| } | |
| for _, message := range fd.MessageType { | |
| dependencies := make(map[string]bool) | |
| fields := []messageField{} | |
| for _, field := range message.GetField() { | |
| fields = append(fields, messageField{ | |
| name: field.GetJsonName(), | |
| t: parseField(message, field, dependencies), | |
| }) | |
| } | |
| mm[qualifiedName(ns, message.GetName())] = &messageDefinition{ | |
| desc: message, | |
| dependencies: dependencies, | |
| name: message.GetName(), | |
| fields: fields, | |
| location: fd.GetName(), | |
| } | |
| } | |
| for _, service := range fd.Service { | |
| serviceName := gen.CamelCase(service.GetName()) | |
| method := service.GetMethod() | |
| dependencies := make(map[string]bool) | |
| for _, m := range method { | |
| dependencies[m.GetInputType()] = true | |
| dependencies[m.GetOutputType()] = true | |
| } | |
| sm[serviceName] = &serviceDefinition{ | |
| desc: service, | |
| dependencies: dependencies, | |
| } | |
| } | |
| } | |
| // Iterate through everything we've found, starting from our services, so we | |
| // only write types for things we've defined or imported. | |
| collected := make(map[string]bool) | |
| seen := make(map[string]bool) | |
| for _, definition := range sm { | |
| for sd, _ := range definition.dependencies { | |
| // These definitions are Requests and Responses | |
| collected[sd] = true | |
| collectDependencies(sd, mm, collected, seen) | |
| } | |
| } | |
| return collected, mm, em | |
| } | |
| func collectDependencies(key string, mm messageMap, collected map[string]bool, seen map[string]bool) { | |
| if _, repeated := seen[key]; repeated { | |
| return | |
| } | |
| seen[key] = true | |
| if next, ok := mm[key]; ok { | |
| for dependency, _ := range next.dependencies { | |
| collected[dependency] = true | |
| collectDependencies(dependency, mm, collected, seen) | |
| } | |
| } | |
| } | |
// filename maps a proto file name to the name of its generated TypeScript
// file: a ".proto" extension, if present, is dropped and "_gw.ts" is appended.
func filename(name string) string {
	base := name
	if path.Ext(name) == ".proto" {
		base = name[:len(name)-len(".proto")]
	}
	return fmt.Sprintf("%s_gw.ts", base)
}
| func (g *generator) Generate() { | |
| input, err := ioutil.ReadAll(g.reader) | |
| if err != nil { | |
| fatal(err, "Could not read input.") | |
| } | |
| request := g.Request | |
| if err := proto.Unmarshal(input, request); err != nil { | |
| fatal(err, "Could not parse input proto.") | |
| } | |
| if len(request.FileToGenerate) == 0 { | |
| fatal(err, "No input files.") | |
| } | |
| g.CommandLineParameters(g.Request.GetParameter()) | |
| g.WrapTypes() | |
| g.SetPackageNames() | |
| g.BuildTypeNameMap() | |
| g.GenerateAllFiles() | |
| g.Reset() | |
| response := new(plugin.CodeGeneratorResponse) | |
| messages, mm, em := g.collectDefinitions(request) | |
| for _, fd := range request.ProtoFile { | |
| file, err := g.processFile(fd, messages, mm, em) | |
| if err != nil { | |
| fatal(err, fmt.Sprintf("Couldn't write file for %s", fd.GetName())) | |
| } | |
| if file != nil { | |
| response.File = append(response.File, file) | |
| } | |
| } | |
| output, err := proto.Marshal(response) | |
| if err != nil { | |
| fatal(err, "Couldn't marshal output proto") | |
| } | |
| _, err = g.writer.Write(output) | |
| if err != nil { | |
| fatal(err, "Couldn't write files") | |
| } | |
| } | |
| func main() { | |
| g := New() | |
| g.Generate() | |
| } |