-
-
Save StevenACoffman/a724d4643f1878aae3ea2b6338213101 to your computer and use it in GitHub Desktop.
Example of using Go's errgroup (golang.org/x/sync/errgroup) together with signal detection via signal.Notify to stop all running goroutines
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package main | |
import ( | |
"context" | |
"encoding/json" | |
"errors" | |
"fmt" | |
"log" | |
"net/http" | |
"os" | |
"os/signal" | |
"sync" | |
"syscall" | |
"time" | |
"golang.org/x/sync/errgroup" | |
) | |
// Process exit codes passed to os.Exit by main.
const (
	exitSuccess = 0 // the program succeeded (a normal shutdown counts as success)
	exitFail    = 1 // the program failed
)
func main() { | |
addr := ":" + os.Getenv("PORT") | |
if addr == ":" { | |
addr = ":3000" | |
} | |
// Health check http server | |
logger := log.New(os.Stdout, | |
"INFO: ", | |
log.Ldate|log.Ltime|log.Lshortfile) | |
api := NewHTTPServer(logger, addr) | |
// ErrGroup for graceful shutdown | |
ctx, done := context.WithCancel(context.Background()) | |
defer done() | |
g, gctx := errgroup.WithContext(ctx) | |
eg := &graceful{gctx, done, api} | |
// goroutine to check for signals to gracefully finish all functions | |
g.Go(eg.listen) | |
// just a ticker every 2s | |
g.Go(eg.doWork2s) | |
// just a ticker every 1s | |
g.Go(eg.doWork1s) | |
// start a healthcheck server | |
g.Go(func() error { | |
// will return http.ErrServerClosed error | |
return api.ListenAndServe() | |
}) | |
var timeout time.Duration = 30 | |
// force a stop after timeout | |
time.AfterFunc(timeout*time.Second, func() { | |
fmt.Printf("force finished after %ds\n", timeout) | |
done() | |
}) | |
// wait for all errgroup goroutines | |
err := g.Wait() | |
// some "errors" are actually from normal shutdown | |
switch { | |
case err == nil: | |
fmt.Println("finished clean") | |
case errors.Is(err, context.Canceled): | |
fmt.Println("context was canceled") | |
err = nil | |
case errors.Is(err, http.ErrServerClosed): | |
fmt.Println("server was closed") | |
err = nil | |
default: | |
fmt.Printf("received error: %v\n", err) | |
} | |
if err != nil { | |
os.Exit(exitFail) | |
} | |
os.Exit(exitSuccess) | |
} | |
// graceful holds the state shared by the shutdown-aware goroutines that main
// runs in its errgroup. main builds it positionally, so the field order below
// is part of its contract.
type graceful struct {
	gctx context.Context    // the errgroup's context; Done() tells workers to stop
	done context.CancelFunc // cancels the parent of gctx, stopping the whole group
	api  *http.Server       // the health-check server that shutdown stops
}
func (gf *graceful) doWork2s() error { | |
duration := 2 * time.Second | |
ticker := time.NewTicker(duration) | |
for { | |
select { | |
case <-ticker.C: | |
fmt.Printf("ticker %.0fs ticked\n", duration.Seconds()) | |
// testcase what happens if an error occurred | |
// return fmt.Errorf("test error ticker %.0f", duration.Seconds()) | |
case <-gf.gctx.Done(): | |
fmt.Printf("closing ticker %.0fs goroutine\n", duration.Seconds()) | |
// will return context.Cancelled error | |
return gf.gctx.Err() | |
} | |
} | |
} | |
func (gf *graceful) doWork1s() error { | |
duration := 1 * time.Second | |
ticker := time.NewTicker(duration) | |
for { | |
select { | |
case <-ticker.C: | |
fmt.Printf("ticker %.0fs ticked\n", duration.Seconds()) | |
// testcase what happens if an error occurred | |
//return fmt.Errorf("test error ticker %.0fs", duration) | |
case <-gf.gctx.Done(): | |
fmt.Printf("closing ticker %.0fs goroutine\n", duration.Seconds()) | |
// will return context.Cancelled errors | |
return gf.gctx.Err() | |
} | |
} | |
} | |
func (gf *graceful) listen() error { | |
signalChannel := getStopSignalsChannel() | |
select { | |
case sig := <-signalChannel: | |
fmt.Printf("Received signal: %s\n", sig) | |
// Give outstanding requests a deadline for completion. | |
gf.shutdown() | |
gf.done() | |
case <-gf.gctx.Done(): | |
fmt.Printf("closing signal goroutine\n") | |
gf.shutdown() | |
return gf.gctx.Err() | |
} | |
return nil | |
} | |
func (gf *graceful) shutdown() { | |
const timeout = 5 * time.Second | |
ctx, cancel := context.WithTimeout(context.Background(), timeout) | |
defer cancel() | |
gf.api.Shutdown(ctx) | |
} | |
// getStopSignalsChannel subscribes to the signals this program treats as a
// request to stop and returns the receive-only channel they are delivered on.
// The channel is buffered with room for one signal because signal.Notify
// never blocks when sending; a signal that arrives before anyone is receiving
// is held in the buffer rather than dropped.
func getStopSignalsChannel() <-chan os.Signal {
	stopSignals := []os.Signal{
		os.Interrupt,    // interrupt is syscall.SIGINT, Ctrl+C
		syscall.SIGQUIT, // Ctrl-\
		syscall.SIGHUP,  // "terminal is disconnected"
		syscall.SIGTERM, // "the normal way to politely ask a program to terminate"
	}
	sigs := make(chan os.Signal, 1)
	signal.Notify(sigs, stopSignals...)
	return sigs
}
// ServerHandler implements http.Handler. Requests are dispatched through a
// ServeMux that ServeHTTP builds on the first request, and handlers log
// through logger.
type ServerHandler struct {
	// once guards the one-time construction of mux (and of a default logger).
	once sync.Once
	// logger receives request logs; if nil, ServeHTTP installs a stdout logger.
	logger *log.Logger
	// mux routes requests; it stays nil until the first request is served.
	mux *http.ServeMux
}
// NewHTTPServer is factory function to initialize a new server | |
func NewHTTPServer(logger *log.Logger, addr string) *http.Server { | |
s := &ServerHandler{logger: logger} | |
h := &http.Server{ | |
Addr: addr, | |
Handler: s, | |
ReadTimeout: 10 * time.Second, | |
WriteTimeout: 10 * time.Second, | |
} | |
return h | |
} | |
// ServeHTTP satisfies Handler interface, sets up the Path Routing | |
func (s *ServerHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { | |
// on the first request only, lazily initialize | |
s.once.Do(func() { | |
if s.logger == nil { | |
s.logger = log.New(os.Stdout, | |
"INFO: ", | |
log.Ldate|log.Ltime|log.Lshortfile) | |
s.logger.Printf("Default Logger used") | |
} | |
s.mux = http.NewServeMux() | |
s.mux.HandleFunc("/redirect", s.RedirectToHome) | |
s.mux.HandleFunc("/health", s.HealthCheck) | |
s.mux.HandleFunc("/", s.HelloHome) | |
}) | |
s.mux.ServeHTTP(w, r) | |
} | |
func (s *ServerHandler) HelloHome(w http.ResponseWriter, _ *http.Request) { | |
s.logger.Println("Got Home Request") | |
w.Header().Set("Content-Type", "text/plain") | |
_, err := w.Write([]byte("Hello, World!")) | |
if err != nil { | |
s.logger.Println("error writing hello world:", err) | |
} | |
} | |
// HealthCheck verifies externally that the program is still responding | |
func (s *ServerHandler) HealthCheck(w http.ResponseWriter, _ *http.Request) { | |
s.logger.Println("Got HealthCheck Request") | |
w.WriteHeader(http.StatusOK) | |
w.Header().Set("Content-Type", "application/json") | |
resp := make(map[string]string) | |
resp["message"] = "Status OK" | |
jsonResp, err := json.Marshal(resp) | |
if err != nil { | |
s.logger.Fatalf("Error happened in JSON marshal. Err: %s", err) | |
} | |
w.Write(jsonResp) | |
return | |
} | |
// RedirectToHome Will Log the Request, and respond with a HTTP 303 to redirect to / | |
func (s *ServerHandler) RedirectToHome(w http.ResponseWriter, r *http.Request) { | |
s.logger.Printf("Redirected request %v to /\n", r.RequestURI) | |
w.Header().Add("location", "/") | |
w.WriteHeader(http.StatusSeeOther) | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment