Skip to content

Instantly share code, notes, and snippets.

@jfut
Last active November 22, 2024 09:56
Show Gist options
  • Save jfut/cb18ecee147436298f5705ab161389fc to your computer and use it in GitHub Desktop.
Save jfut/cb18ecee147436298f5705ab161389fc to your computer and use it in GitHub Desktop.
echo + go-ozzo/ozzo-validation: CSRF Token Validation
// echo: CSRF Token Management with Session
//
// Copyright (c) 2024 Jun Futagawa (jfut)
//
// This software is released under the MIT License.
// http://opensource.org/licenses/mit-license.php
package web
import (
"encoding/gob"
"fmt"
"time"
"github.com/google/uuid"
"github.com/labstack/echo-contrib/session"
"github.com/labstack/echo/v4"
)
// maxCSRFtokensPerSession caps how many CSRF token entries a single session
// keeps; GenerateCSRFToken drops the oldest entries beyond this limit.
const (
	maxCSRFtokensPerSession = 256
)
// CSRFTokenItem is a single stored CSRF token entry.
//
// A token generated for several paths is stored as several items sharing the
// same Token value (one per path, see GenerateCSRFToken).
type CSRFTokenItem struct {
// Token is the token string handed to the client.
Token string
// Path is the request URL path the token may be submitted to; "*" is
// accepted on any path by ValidateCSRFToken.
Path string
// ExpiresAt is the instant after which the entry is treated as expired and
// pruned by removeExpiredTokens.
ExpiresAt time.Time
}
// CSRFTokens is the collection of CSRF token entries kept in the session.
// It is stored by value in session.Values under a caller-chosen key and is
// registered with encoding/gob in init so the session can serialize it.
type CSRFTokens struct {
// A sync.RWMutex field used to live here but is disabled: CSRFTokens is
// stored by value in session.Values and copied out with a type assertion,
// which go vet's copylocks check reports as
// "assignment copies lock value to csrfTokens: ... contains sync.RWMutex".
// A mutex inside a value that is copied on every read would guard only the
// copy, so it could not provide mutual exclusion here anyway.
// mutex sync.RWMutex
// Tokens holds the entries in insertion order: GenerateCSRFToken appends new
// entries and trims the oldest ones beyond maxCSRFtokensPerSession.
Tokens []CSRFTokenItem
}
func init() {
// Register CSRFTokens struct for session encoding to ensure proper serialization
gob.Register(CSRFTokens{})
}
// GenerateCSRFToken creates a new CSRF token, stores it in the session under
// sessionKey, saves the session, and returns the token.
//
// One entry is recorded per element of paths. Each path is either the exact
// request URL path the token may be submitted to, or "*" for any path (see
// ValidateCSRFToken). If paths is empty, nothing is recorded and the returned
// token will never validate.
//
// Entries expire after the session's MaxAge. If the session has no positive
// MaxAge (e.g. a browser-session cookie with MaxAge 0), a fixed fallback
// lifetime is used instead. Otherwise the new token would already be expired
// and would be pruned immediately.
//
// Expired entries are removed, and at most maxCSRFtokensPerSession entries are
// kept, dropping the oldest first.
//
// An error is returned if the session cannot be loaded, the token cannot be
// generated, or the session cannot be saved.
func GenerateCSRFToken(c echo.Context, sessionKey string, paths ...string) (string, error) {
	// Token lifetime used when the session does not define a positive MaxAge.
	const fallbackTokenLifetime = 24 * time.Hour

	sess, err := session.Get(DEFAULT_SESSIONID_KEY, c)
	if err != nil {
		return "", fmt.Errorf("failed_to_get_session: %w", err)
	}

	// Start from the stored collection. A missing value, or one of an
	// unexpected type (e.g. written by an older version of this code), starts a
	// fresh collection instead of panicking on the type assertion.
	csrfTokens, ok := sess.Values[sessionKey].(CSRFTokens)
	if !ok {
		csrfTokens = CSRFTokens{Tokens: []CSRFTokenItem{}}
	}

	// Generate a new UUID v7 token; report failure instead of panicking.
	id, err := uuid.NewV7()
	if err != nil {
		return "", fmt.Errorf("failed_to_generate_token: %w", err)
	}
	token := id.String()

	lifetime := fallbackTokenLifetime
	if sess.Options != nil && sess.Options.MaxAge > 0 {
		lifetime = time.Duration(sess.Options.MaxAge) * time.Second
	}
	expiresAt := time.Now().Add(lifetime)

	for _, path := range paths {
		csrfTokens.Tokens = append(csrfTokens.Tokens, CSRFTokenItem{
			Token:     token,
			Path:      path,
			ExpiresAt: expiresAt,
		})
	}

	removeExpiredTokens(&csrfTokens)

	// Cap the collection size, dropping the oldest entries first.
	if oversize := len(csrfTokens.Tokens) - maxCSRFtokensPerSession; oversize > 0 {
		csrfTokens.Tokens = csrfTokens.Tokens[oversize:]
	}

	sess.Values[sessionKey] = csrfTokens
	if err := SaveSession(c, sess); err != nil {
		return "", fmt.Errorf("failed_to_save_session: %w", err)
	}
	return token, nil
}
// ValidateCSRFToken reports whether inputToken is an unexpired CSRF token,
// stored in the session under sessionKey, that may be used on the current
// request path.
//
// A stored entry matches when its Token equals inputToken and its Path is
// either "*" or equal to c.Request().URL.Path. Every entry for the token is
// checked, so a token generated for several paths validates on any of them.
//
// Tokens are single-use. Whenever inputToken is found in the session, whether
// or not the path matched, all of its entries are removed. Expired entries are
// pruned at the same time, and the session is saved. If the token cannot be
// marked as used because saving fails, the function fails closed and returns
// false. It also returns false if the session cannot be loaded, holds no
// tokens under sessionKey, or does not contain inputToken.
func ValidateCSRFToken(c echo.Context, sessionKey string, inputToken string) bool {
	sess, err := session.Get(DEFAULT_SESSIONID_KEY, c)
	if err != nil {
		return false
	}

	csrfTokens, ok := sess.Values[sessionKey].(CSRFTokens)
	if !ok {
		return false
	}

	removeExpiredTokens(&csrfTokens)

	// NOTICE: c.Request().RequestURI and c.Request().RemoteAddr is empty in
	// syumai/workers, so the path is taken from URL.Path.
	requestPath, hasRequestPath := "", false
	if req := c.Request(); req != nil && req.URL != nil {
		requestPath, hasRequestPath = req.URL.Path, true
	}

	found, valid := false, false
	// Build a new slice instead of deleting in place: csrfTokens.Tokens may
	// share its backing array with the value still cached in sess.Values.
	remaining := make([]CSRFTokenItem, 0, len(csrfTokens.Tokens))
	for _, item := range csrfTokens.Tokens {
		if item.Token != inputToken {
			remaining = append(remaining, item)
			continue
		}
		found = true
		if item.Path == "*" || (hasRequestPath && item.Path == requestPath) {
			valid = true
		}
	}
	if !found {
		return false
	}

	// Persist the consumption so the token cannot be replayed.
	csrfTokens.Tokens = remaining
	sess.Values[sessionKey] = csrfTokens
	if err := SaveSession(c, sess); err != nil {
		return false
	}
	return valid
}
// removeExpiredTokens drops every entry whose ExpiresAt is before the current
// time, keeping the remaining entries in their original order.
//
// All entries are checked, not just a leading run. This keeps the result
// correct even when entries are not stored in expiration order (for example if
// the session MaxAge changed between GenerateCSRFToken calls).
//
// A new slice is allocated rather than filtering in place, because
// csrfTokens.Tokens may share its backing array with a value still held in the
// session.
func removeExpiredTokens(csrfTokens *CSRFTokens) {
	now := time.Now()
	kept := make([]CSRFTokenItem, 0, len(csrfTokens.Tokens))
	for _, item := range csrfTokens.Tokens {
		if now.After(item.ExpiresAt) {
			continue
		}
		kept = append(kept, item)
	}
	csrfTokens.Tokens = kept
}
// echo + go-ozzo/ozzo-validation: CSRF Token Validation
//
// Copyright (c) 2024 Jun Futagawa (jfut)
//
// This software is released under the MIT License.
// http://opensource.org/licenses/mit-license.php
package web
import (
validation "github.com/go-ozzo/ozzo-validation/v4"
"github.com/labstack/echo/v4"
)
// CSRFTokenRule is a validation rule (it provides Validate(value interface{}) error)
// that checks a submitted CSRF token via ValidateCSRFToken. Create it with
// CSRFToken and optionally customize the message with Error.
type CSRFTokenRule struct {
// c is the request context used to reach the session and the request path.
c echo.Context
// name is the session key under which the CSRF tokens are stored.
name string
// err is the error returned when validation fails.
err validation.Error
}
// CSRFToken builds a CSRFTokenRule bound to the request context c and the
// session key name. By default a failed check yields an error with code
// "validation_csrf_token_invalid" and message "Invalid CSRF token".
func CSRFToken(c echo.Context, name string) CSRFTokenRule {
	defaultErr := validation.NewError("validation_csrf_token_invalid", "Invalid CSRF token")

	var rule CSRFTokenRule
	rule.c = c
	rule.name = name
	rule.err = defaultErr
	return rule
}
// Validate checks value against the CSRF tokens stored in the session under
// r.name, using ValidateCSRFToken (which also consumes the token). It returns
// r.err when the token is not accepted and nil when it is.
//
// Pointers (and other wrappers validation.Indirect can unwrap) are dereferenced
// first. Any value that is still not a string, including a nil pointer or a
// []string from parsed form values, is rejected with r.err. Accepting it
// silently would disable CSRF protection for the field.
func (r CSRFTokenRule) Validate(value interface{}) error {
	token, ok := validation.Indirect(value).(string)
	if !ok {
		return r.err
	}
	if !ValidateCSRFToken(r.c, r.name, token) {
		return r.err
	}
	return nil
}
// Error returns a copy of the rule whose failure error has its message
// replaced by message (via SetMessage). The receiver itself is not modified.
func (r CSRFTokenRule) Error(message string) CSRFTokenRule {
	updated := r
	updated.err = updated.err.SetMessage(message)
	return updated
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment