Last active
February 12, 2020 04:01
-
-
Save paulbdavis/e684956357bde4a18ab1b8d1957e21fe to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package auth | |
type tokenRefresher struct { | |
oauth *oauth2.Config | |
tokens map[string]*oauth2.Token | |
locks map[string]*sync.Mutex | |
accessed map[string]time.Time | |
} | |
func newRefresher(oauthConfig *oauth2.Config) *tokenRefresher { | |
return &tokenRefresher{ | |
oauth: oauthConfig, | |
tokens: map[string]*oauth2.Token{}, | |
locks: map[string]*sync.Mutex{}, | |
accessed: map[string]time.Time{}, | |
} | |
} | |
func (tr *tokenRefresher) maybeRefreshToken(ctx context.Context, t *oauth2.Token) (*oauth2.Token, error) { | |
key := t.AccessToken | |
log := logger.Ctx(ctx).With().Str("access token", key).Logger() | |
mx := tr.locks[key] | |
if mx == nil { | |
mx = &sync.Mutex{} | |
tr.locks[key] = mx | |
} | |
log.Debug(). | |
Msg("setting mutex lock") | |
mx.Lock() | |
defer mx.Unlock() | |
// if we already refreshed this access token, try to refresh the new one | |
refreshedToken := tr.tokens[key] | |
if refreshedToken != nil { | |
log.Debug(). | |
Interface("refreshed token", refreshedToken). | |
Msg("found a new token from this one") | |
return tr.maybeRefreshToken(ctx, refreshedToken) | |
} | |
log.Debug(). | |
Msg("no existing refresh, checking this token") | |
source := tr.oauth.TokenSource(ctx, t) | |
newToken, err := source.Token() | |
if err != nil { | |
return nil, fmt.Errorf("checking for refresh: %w", err) | |
} | |
// if this is a new token, save it to the map so that next time this token is used | |
if newToken.AccessToken != t.AccessToken { | |
log.Debug(). | |
Msg("token refreshed, adding to refresher cache") | |
tr.tokens[key] = newToken | |
} | |
tr.accessed[key] = time.Now() | |
log.Debug(). | |
Msg("finished refresh check") | |
return newToken, nil | |
} | |
func (tr *tokenRefresher) cleanup(ctx context.Context) { | |
log := logger.Ctx(ctx) | |
cutoff := time.Now().Add(-15 * time.Minute) | |
for key, accessed := range tr.accessed { | |
if accessed.Before(cutoff) { | |
log.Debug(). | |
Interface("old token", tr.tokens[key]). | |
Msg("removing old token record from refresher") | |
delete(tr.tokens, key) | |
delete(tr.locks, key) | |
delete(tr.accessed, key) | |
} | |
} | |
} | |
func CheckAuthHandler(oauthConf *oauth2.Config, next http.Handler) http.Handler { | |
refresher := newRefresher(oauthConf) | |
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | |
log := logger.Ctx(r.Context()) | |
tokenStr := r.Header.Get("x-auth-token") | |
incomingToken, err := parseTokenString(tokenStr) | |
if err != nil { | |
util.SendError(w, errors.New("unparsable token", http.StatusUnauthorized)) | |
return | |
} | |
token, err := refresher.maybeRefreshToken(r.Context(), incomingToken) | |
if err != nil { | |
util.SendError(w, err) | |
return | |
} | |
refresher.cleanup(r.Context()) | |
userInfo, err := getUserInfo(r.Context(), oauthConf, token) | |
if err != nil { | |
util.SendError(w, err) | |
return | |
} | |
ctx = context.WithValue(ctx, ContextKeyUserClient, oauthConf.Client(ctx, token)) | |
}) | |
} | |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment