Last active: November 22, 2024 09:56
-
-
Save jfut/cb18ecee147436298f5705ab161389fc to your computer and use it in GitHub Desktop.
echo + go-ozzo/ozzo-validation: CSRF Token Validation
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// echo: CSRF Token Management with Session | |
// | |
// Copyright (c) 2024 Jun Futagawa (jfut) | |
// | |
// This software is released under the MIT License. | |
// http://opensource.org/licenses/mit-license.php | |
package web | |
import ( | |
"encoding/gob" | |
"fmt" | |
"time" | |
"github.com/google/uuid" | |
"github.com/labstack/echo-contrib/session" | |
"github.com/labstack/echo/v4" | |
) | |
// maxCSRFtokensPerSession caps how many CSRF token entries a single session
// retains. When the cap is exceeded, entries at the front of the list (the
// oldest ones) are discarded.
const (
	maxCSRFtokensPerSession = 256
)
// CSRFTokenItem represents a single CSRF token with its associated path and expiration time | |
type CSRFTokenItem struct { | |
Token string | |
Path string | |
ExpiresAt time.Time | |
} | |
// CSRFTokens stores a collection of CSRF tokens and their expiration times | |
type CSRFTokens struct { | |
// Note: The mutex is currently unused due to concurrent map access issues when retrieving from the session | |
// Error: assignment copies lock value to csrfTokens: ... contains sync.RWMutex copy locks | |
// mutex sync.RWMutex | |
Tokens []CSRFTokenItem | |
} | |
func init() { | |
// Register CSRFTokens struct for session encoding to ensure proper serialization | |
gob.Register(CSRFTokens{}) | |
} | |
// GenerateCSRFToken creates a new CSRF token, registers it in the session under
// sessionKey once for each of the given paths, saves the session, and returns
// the token.
//
// All supplied paths share the same token value; "*" may be passed to make the
// token acceptable on any request path. If no paths are supplied, nothing is
// stored and the returned token can never validate.
//
// Each new entry expires after the session's MaxAge. If MaxAge is not positive
// (e.g. 0 for a browser-session cookie), a fallback TTL is used instead, because
// otherwise the entry would already be expired and pruned immediately. Expired
// entries are then removed, and only the newest maxCSRFtokensPerSession entries
// are kept.
//
// Errors from loading the session, generating the UUID, or saving the session
// are returned wrapped; in that case the returned token is empty.
func GenerateCSRFToken(c echo.Context, sessionKey string, paths ...string) (string, error) {
	// Lifetime used when the session does not define a positive MaxAge.
	const fallbackTTL = 24 * time.Hour

	sess, err := session.Get(DEFAULT_SESSIONID_KEY, c)
	if err != nil {
		return "", fmt.Errorf("failed_to_get_session: %w", err)
	}

	// Load existing entries. A missing key or a value of an unexpected type
	// starts a fresh list instead of panicking on the type assertion.
	csrfTokens, ok := sess.Values[sessionKey].(CSRFTokens)
	if !ok {
		csrfTokens = CSRFTokens{}
	}

	// UUID v7 is time-ordered. Report generation failures instead of panicking.
	id, err := uuid.NewV7()
	if err != nil {
		return "", fmt.Errorf("failed_to_generate_token: %w", err)
	}
	token := id.String()

	ttl := fallbackTTL
	if sess.Options != nil && sess.Options.MaxAge > 0 {
		ttl = time.Duration(sess.Options.MaxAge) * time.Second
	}
	expiresAt := time.Now().Add(ttl)
	for _, path := range paths {
		csrfTokens.Tokens = append(csrfTokens.Tokens, CSRFTokenItem{
			Token:     token,
			Path:      path,
			ExpiresAt: expiresAt,
		})
	}

	removeExpiredTokens(&csrfTokens)

	// Entries are appended in creation order, so trimming the front drops the
	// oldest ones.
	if excess := len(csrfTokens.Tokens) - maxCSRFtokensPerSession; excess > 0 {
		csrfTokens.Tokens = csrfTokens.Tokens[excess:]
	}

	sess.Values[sessionKey] = csrfTokens
	if err := SaveSession(c, sess); err != nil {
		return "", fmt.Errorf("failed_to_save_session: %w", err)
	}
	return token, nil
}
// ValidateCSRFToken checks if the provided token is valid for the current request path. | |
// It removes expired tokens from the CSRF tokens collection. | |
// If the token is valid, it removes that specific token from the collection. | |
func ValidateCSRFToken(c echo.Context, sessionKey string, inputToken string) bool { | |
sess, err := session.Get(DEFAULT_SESSIONID_KEY, c) | |
if err != nil { | |
return false | |
} | |
// Get CSRF tokens from the session | |
csrfTokens, ok := sess.Values[sessionKey].(CSRFTokens) | |
if !ok { | |
return false | |
} | |
// csrfTokens.mutex.RLock() | |
// defer csrfTokens.mutex.RUnlock() | |
// Remove expired tokens | |
removeExpiredTokens(&csrfTokens) | |
// Validate token | |
valid := false | |
for i, item := range csrfTokens.Tokens { | |
if item.Token == inputToken { | |
// NOTICE: c.Request().RequestURI and c.Request().RemoteAddr is empty in syumai/workers. | |
if item.Path == "*" || (c.Request() != nil && item.Path == c.Request().URL.Path) { | |
valid = true | |
} | |
csrfTokens.Tokens = append(csrfTokens.Tokens[:i], csrfTokens.Tokens[i+1:]...) | |
break | |
} | |
} | |
return valid | |
} | |
// removeExpiredTokens drops every entry whose ExpiresAt lies in the past.
//
// Entries are usually appended in expiry order, but that is not guaranteed
// (for example, the session MaxAge can change between calls). The whole list
// is therefore scanned instead of stopping at the first unexpired entry.
//
// The survivors are copied into a freshly allocated slice, so a CSRFTokens
// value held elsewhere (such as in a session's Values map) is never modified
// through a shared backing array.
func removeExpiredTokens(csrfTokens *CSRFTokens) {
	now := time.Now()
	kept := make([]CSRFTokenItem, 0, len(csrfTokens.Tokens))
	for _, item := range csrfTokens.Tokens {
		if !now.After(item.ExpiresAt) {
			kept = append(kept, item)
		}
	}
	csrfTokens.Tokens = kept
}
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// echo + go-ozzo/ozzo-validation: CSRF Token Validation | |
// | |
// Copyright (c) 2024 Jun Futagawa (jfut) | |
// | |
// This software is released under the MIT License. | |
// http://opensource.org/licenses/mit-license.php | |
package web | |
import ( | |
validation "github.com/go-ozzo/ozzo-validation/v4" | |
"github.com/labstack/echo/v4" | |
) | |
// CSRFTokenRule defines a validation rule for CSRF tokens | |
type CSRFTokenRule struct { | |
c echo.Context | |
name string | |
err validation.Error | |
} | |
// CSRFToken creates a new CSRFTokenRule | |
func CSRFToken(c echo.Context, name string) CSRFTokenRule { | |
return CSRFTokenRule{ | |
c: c, | |
name: name, | |
err: validation.NewError("validation_csrf_token_invalid", "Invalid CSRF token"), | |
} | |
} | |
// Validate checks if the provided value is a valid CSRF token | |
func (r CSRFTokenRule) Validate(value interface{}) error { | |
token, ok := value.(string) | |
if !ok { | |
return nil | |
} | |
if !ValidateCSRFToken(r.c, r.name, token) { | |
return r.err | |
} | |
return nil | |
} | |
// Error sets a custom error message for the CSRFTokenRule | |
func (r CSRFTokenRule) Error(message string) CSRFTokenRule { | |
r.err = r.err.SetMessage(message) | |
return r | |
} |
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.