Skip to content

Instantly share code, notes, and snippets.

@ja7ad
Created October 21, 2024 16:12
Show Gist options
  • Save ja7ad/4b1424575770725fbe555ac1547dfeb1 to your computer and use it in GitHub Desktop.
Save ja7ad/4b1424575770725fbe555ac1547dfeb1 to your computer and use it in GitHub Desktop.
grpc gateway transport
package transport
import (
	"context"
	"net/http"
	"strings"
	"time"

	"github.com/Ja7ad/swaggerui"
	"github.com/grpc-ecosystem/grpc-gateway/v2/runtime"
	"google.golang.org/grpc"
	"google.golang.org/grpc/credentials/insecure"
)
// HTTPBootstrapper is the HTTP transport contract. Besides the embedded
// BaseTransporter (presumably the Start/Notify/Shutdown lifecycle that
// HTTPServer implements — confirm against its declaration), it exposes server
// timeout configuration, route registration and grpc-gateway endpoint
// registration.
type HTTPBootstrapper interface {
	BaseTransporter

	// SetReadTimeout sets the server's read timeout.
	SetReadTimeout(t time.Duration)
	// SetWriteTimeout sets the server's write timeout.
	SetWriteTimeout(t time.Duration)

	// AddHandler mounts an http.Handler at routerPath.
	AddHandler(routerPath string, handler http.Handler)
	// AddHandlerFunc mounts an http.HandlerFunc at routerPath.
	AddHandlerFunc(routerPath string, handlerFunc http.HandlerFunc)

	// RegisterServiceEndpoint registers a grpc-gateway endpoint by invoking
	// the supplied registration function and returns that function's error.
	RegisterServiceEndpoint(endpoint func(ctx context.Context, mux *runtime.ServeMux, endpoint string, opts []grpc.DialOption) error) error
}
// HTTPServer is the HTTPBootstrapper implementation built by NewHTTPServer.
// It pairs an http.Server with the grpc-gateway mux mounted at "/" of an
// http.ServeMux.
type HTTPServer struct {
	// ctx is the context given to NewHTTPServer; it is passed to gateway
	// endpoint registrations in RegisterServiceEndpoint.
	ctx context.Context
	// server is the underlying HTTP server started by Start.
	server *http.Server
	// notify receives the result of ListenAndServe and is closed afterwards
	// (see Start); exposed read-only through Notify.
	notify chan error
	// shutdownTimeout is initialised by NewHTTPServer from
	// _defaultShutdownTimeout.
	shutdownTimeout time.Duration
	// rMux is the grpc-gateway mux, mounted at "/" and passed to endpoint
	// registrations.
	rMux *runtime.ServeMux
	// mux is the root http.ServeMux; AddHandler/AddHandlerFunc register here.
	mux *http.ServeMux
	// grpcAddress is the gRPC server address handed to endpoint
	// registrations for the gateway to dial.
	grpcAddress string
}
// NewHTTPServer creates the HTTP transport: a grpc-gateway runtime.ServeMux
// mounted at "/" of an http.ServeMux, optionally extended with the swagger
// document and Swagger UI, and wrapped in CORS handling plus the caller's
// middleware. The returned server is not started; call Start.
//
// Parameters:
//   - ctx: stored and later passed to gateway endpoint registrations
//     (RegisterServiceEndpoint).
//   - httpAddress: listen address of the HTTP server.
//   - grpcAddress: address of the gRPC server the gateway forwards to.
//   - development: when true, serves the swagger document at /swagger.json
//     and Swagger UI under /api-docs/.
//   - swagger: swagger/OpenAPI document bytes (only used in development).
//   - customHeaders: extra request headers allowed by CORS.
//   - origins: allowed CORS origins; empty allows any origin ("*").
//   - middleware: outermost handler wrapper; nil means no extra middleware.
//   - muxOpts: grpc-gateway mux options. Only when none are given is
//     ErrorHandler installed as the gateway error handler.
func NewHTTPServer(
	ctx context.Context,
	httpAddress, grpcAddress string,
	development bool,
	swagger []byte,
	customHeaders []string,
	origins []string,
	middleware func(handler http.Handler) http.Handler,
	muxOpts ...runtime.ServeMuxOption,
) HTTPBootstrapper {
	if len(muxOpts) == 0 {
		muxOpts = []runtime.ServeMuxOption{runtime.WithErrorHandler(ErrorHandler)}
	}
	rMux := runtime.NewServeMux(muxOpts...)

	// The helpers below live in this package, so they are called unqualified
	// (the same way AllowCORS is).
	muxHandlers := SetRuntimeAsRootHandler(http.NewServeMux(), rMux)
	if development {
		muxHandlers = SwaggerHandler(muxHandlers, "swagger.json", swagger)
		muxHandlers.Handle("/api-docs/", http.StripPrefix("/api-docs", swaggerui.Handler(swagger)))
	}

	var handler http.Handler = AllowCORS(muxHandlers, origins, customHeaders...)
	// middleware is optional; calling a nil func value would panic.
	if middleware != nil {
		handler = middleware(handler)
	}

	return &HTTPServer{
		ctx: ctx,
		server: &http.Server{
			Handler:      handler,
			Addr:         httpAddress,
			ReadTimeout:  _defaultReadTimeout,
			WriteTimeout: _defaultWriteTimeout,
		},
		// Buffered so the goroutine launched by Start can deliver the
		// ListenAndServe result and exit even if Notify() is never read;
		// with an unbuffered channel that goroutine would block forever.
		notify:          make(chan error, 1),
		shutdownTimeout: _defaultShutdownTimeout,
		grpcAddress:     grpcAddress,
		rMux:            rMux,
		mux:             muxHandlers,
	}
}
// SetReadTimeout replaces the http.Server ReadTimeout that NewHTTPServer
// initialised to _defaultReadTimeout.
// NOTE(review): presumably meant to be called before Start; changing it while
// the server is serving would race with net/http reading the field.
func (s *HTTPServer) SetReadTimeout(t time.Duration) {
	srv := s.server
	srv.ReadTimeout = t
}
// SetWriteTimeout replaces the http.Server WriteTimeout that NewHTTPServer
// initialised to _defaultWriteTimeout.
// NOTE(review): presumably meant to be called before Start; changing it while
// the server is serving would race with net/http reading the field.
func (s *HTTPServer) SetWriteTimeout(t time.Duration) {
	srv := s.server
	srv.WriteTimeout = t
}
// Start runs ListenAndServe in a new goroutine and returns immediately.
// That goroutine sends the error ListenAndServe returns on the notify channel
// and then closes the channel. After a graceful Shutdown, net/http reports
// that error as http.ErrServerClosed. Read the channel via Notify.
// Start is not meant to be called more than once: every call closes the same
// channel.
func (s *HTTPServer) Start() {
	serve := func() {
		defer close(s.notify)
		s.notify <- s.server.ListenAndServe()
	}
	go serve()
}
// Notify exposes, receive-only, the channel on which Start delivers the error
// returned by ListenAndServe. The channel is closed after that single value.
func (s *HTTPServer) Notify() <-chan error {
	var errs <-chan error = s.notify
	return errs
}
// Shutdown gracefully stops the server via http.Server.Shutdown. That call
// stops accepting new connections and waits for active ones to become idle.
//
// If ctx has no deadline and a positive shutdownTimeout is configured
// (NewHTTPServer sets _defaultShutdownTimeout), the wait is bounded by that
// timeout, so a stuck connection cannot block shutdown forever. A deadline
// already present on ctx is respected unchanged.
// It returns the error from http.Server.Shutdown (e.g. the context error
// when the wait times out).
func (s *HTTPServer) Shutdown(ctx context.Context) error {
	if _, hasDeadline := ctx.Deadline(); !hasDeadline && s.shutdownTimeout > 0 {
		var cancel context.CancelFunc
		ctx, cancel = context.WithTimeout(ctx, s.shutdownTimeout)
		defer cancel()
	}
	return s.server.Shutdown(ctx)
}
// AddHandler registers handler for routerPath on the root http.ServeMux, the
// same mux that routes "/" to the grpc-gateway. Like http.ServeMux.Handle, it
// panics on an invalid or already-registered pattern. That includes "/",
// which NewHTTPServer assigns to the gateway.
func (s *HTTPServer) AddHandler(routerPath string, handler http.Handler) {
	root := s.mux
	root.Handle(routerPath, handler)
}
// AddHandlerFunc is the function-valued form of AddHandler. It registers
// handlerFunc for routerPath on the root http.ServeMux, with the same panic
// behaviour as http.ServeMux.HandleFunc for invalid or duplicate patterns.
func (s *HTTPServer) AddHandlerFunc(routerPath string, handlerFunc http.HandlerFunc) {
	root := s.mux
	root.HandleFunc(routerPath, handlerFunc)
}
// RegisterServiceEndpoint calls endpoint with four arguments: the context
// given to NewHTTPServer, the gateway mux, the configured gRPC address, and a
// dial option using insecure (plaintext, no TLS) transport credentials.
// It returns endpoint's error. Presumably endpoint is a generated
// Register<Service>HandlerFromEndpoint function; confirm at call sites.
func (s *HTTPServer) RegisterServiceEndpoint(endpoint func(ctx context.Context, mux *runtime.ServeMux, endpoint string, opts []grpc.DialOption) error) error {
	plaintext := grpc.WithTransportCredentials(insecure.NewCredentials())
	dialOpts := []grpc.DialOption{plaintext}
	return endpoint(s.ctx, s.rMux, s.grpcAddress, dialOpts)
}
// ErrorHandler is a grpc-gateway error handler. It sets the response
// Content-Type to application/json, then delegates to
// runtime.DefaultHTTPErrorHandler to translate the gRPC status carried by
// err into the HTTP response.
// NOTE(review): DefaultHTTPErrorHandler also sets Content-Type from the
// marshaler m, so the header set here may be overwritten. Confirm whether it
// is still needed.
func ErrorHandler(ctx context.Context, mux *runtime.ServeMux, m runtime.Marshaler, w http.ResponseWriter, r *http.Request, err error) {
	hdr := w.Header()
	hdr.Set("Content-Type", "application/json")
	runtime.DefaultHTTPErrorHandler(ctx, mux, m, w, r, err)
}
// AllowCORS wraps h with CORS handling.
//
// When origins is non-empty it acts as an allow-list. Requests without an
// Origin header, or whose Origin is rejected by checkOrigin, get
// 403 Forbidden and never reach h. An allowed Origin is echoed back in
// Access-Control-Allow-Origin together with "Vary: Origin". The Vary header
// stops shared caches from serving one origin's response to another.
// When origins is empty every request is allowed and
// Access-Control-Allow-Origin is "*".
//
// Allowed request headers are a fixed base set followed by customHeaders.
// OPTIONS (preflight) requests get 200 OK with the CORS headers and are not
// forwarded to h.
func AllowCORS(h http.Handler, origins []string, customHeaders ...string) http.Handler {
	// Nothing below depends on the request, so compute it once instead of
	// rebuilding and joining the header list on every call.
	baseHeaders := []string{
		"Accept",
		"Content-Type",
		"Content-Length",
		"Accept-Encoding",
		"Authorization",
		"ResponseType",
	}
	allowHeaders := strings.Join(append(baseHeaders, customHeaders...), ", ")
	const allowMethods = "GET, POST, PUT, PATCH, DELETE, OPTIONS"
	restricted := len(origins) != 0

	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		allowOrigin := "*"
		if restricted {
			origin := r.Header.Get("Origin")
			if origin == "" || !checkOrigin(origin, origins) {
				w.WriteHeader(http.StatusForbidden)
				return
			}
			allowOrigin = origin
			// The response now varies by Origin; tell caches so.
			w.Header().Add("Vary", "Origin")
		}

		w.Header().Set("Access-Control-Allow-Origin", allowOrigin)
		w.Header().Set("Access-Control-Allow-Methods", allowMethods)
		w.Header().Set("Access-Control-Allow-Headers", allowHeaders)

		if r.Method == http.MethodOptions {
			w.WriteHeader(http.StatusOK)
			return
		}
		h.ServeHTTP(w, r)
	})
}
// SetRuntimeAsRootHandler mounts the grpc-gateway mux rMux at the catch-all
// pattern "/" of mux. Every request not matched by a more specific pattern is
// then handled by the gateway. It returns the same mux pointer so calls can be
// chained.
func SetRuntimeAsRootHandler(mux *http.ServeMux, rMux *runtime.ServeMux) *http.ServeMux {
	var gateway http.Handler = rMux
	mux.Handle("/", gateway)
	return mux
}
// SwaggerHandler registers a handler on mux that serves the embedded swagger
// document at "/"+swaggerFileName (e.g. swagger.json or swagger.yaml). It
// returns mux so calls can be chained.
//
// Content-Type comes from the file extension, compared case-insensitively:
// ".json" gives application/json, and ".yaml"/".yml" give application/yaml.
// Any other extension sets no Content-Type, so net/http sniffs one from the
// body. Previously no type was set at all, and JSON was served as text/plain.
func SwaggerHandler(mux *http.ServeMux, swaggerFileName string, swagger []byte) *http.ServeMux {
	var contentType string
	switch lower := strings.ToLower(swaggerFileName); {
	case strings.HasSuffix(lower, ".json"):
		contentType = "application/json"
	case strings.HasSuffix(lower, ".yaml"), strings.HasSuffix(lower, ".yml"):
		contentType = "application/yaml"
	}

	mux.HandleFunc("/"+swaggerFileName, func(w http.ResponseWriter, _ *http.Request) {
		if contentType != "" {
			w.Header().Set("Content-Type", contentType)
		}
		// A write error means the client went away; there is nothing useful
		// to do about it in the handler.
		_, _ = w.Write(swagger)
	})
	return mux
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment