Skip to content

Instantly share code, notes, and snippets.

@melkishengue
Last active April 3, 2025 06:34
Show Gist options
  • Save melkishengue/354772f5c04d57128f9918ff99e51236 to your computer and use it in GitHub Desktop.
Save melkishengue/354772f5c04d57128f9918ff99e51236 to your computer and use it in GitHub Desktop.
A smarter GORM Redis cacher with support for table/model invalidation
package postgres
import (
"context"
"fmt"
"strings"
"time"
"github.com/go-gorm/caches/v4"
"github.com/gofiber/fiber/v2/log"
"github.com/redis/go-redis/v9"
"gorm.io/gorm"
)
// SmartRedisCacher is a Redis-backed cacher for go-gorm/caches query results
// (Get/Store) that adds targeted invalidation by table or model name on top of
// the plain key/value storage.
type SmartRedisCacher struct {
	// rdb is the Redis client used for every cache operation.
	// InvalidateModelsManually treats a nil client as "caching disabled".
	rdb *redis.Client
}
// Get loads the cached query result stored under key and decodes it into q.
//
// A missing key is reported as a cache miss: (nil, nil). Redis errors and
// decode failures are logged and returned with a nil result. On a hit, q
// itself is returned after being filled in.
func (c *SmartRedisCacher) Get(ctx context.Context, key string, q *caches.Query[any]) (*caches.Query[any], error) {
	payload, err := c.rdb.Get(ctx, key).Result()
	switch {
	case err == redis.Nil:
		// Cache miss, not a failure.
		log.Infof("Key %s not found in redis.\n", key)
		return nil, nil
	case err != nil:
		log.Errorf("failed to get data from redis: %s", err.Error())
		return nil, err
	}

	decodeErr := q.Unmarshal([]byte(payload))
	if decodeErr != nil {
		log.Errorf("failed to unmarshal data from redis: %s", decodeErr.Error())
		return nil, decodeErr
	}
	return q, nil
}
// Store serializes val and writes it to Redis under key with a fixed
// 300-second expiry.
//
// A marshalling error is returned as-is. A Redis write error is logged and
// returned.
func (c *SmartRedisCacher) Store(ctx context.Context, key string, val *caches.Query[any]) error {
	encoded, marshalErr := val.Marshal()
	if marshalErr != nil {
		return marshalErr
	}

	const expiry = 300 * time.Second
	setErr := c.rdb.Set(ctx, key, encoded, expiry).Err()
	if setErr != nil {
		log.Errorf("failed to store data in redis: %s", setErr.Error())
	}
	return setErr
}
// InvalidateTables deletes every cached query that appears to reference one of
// tableNames.
//
// Every key under the "gorm-caches::" prefix is visited with a single SCAN.
// The rest of the key (the cached SQL) is checked for the literal substrings
// returned by smartCacheTableNeedles. One pass is used instead of one MATCH
// scan per pattern because every SCAN iteration walks the entire keyspace
// whatever the MATCH pattern is. Plain substring matching also means table
// names never need glob escaping.
//
// Matching is textual and deliberately errs toward over-invalidation: for
// example, "users." also occurs in "admin_users.". An extra cache miss is
// cheap; a missed key would serve stale data. Empty and repeated table names
// are ignored. An empty name would otherwise match almost every cached query.
//
// Matched keys are deleted in batches of at most 500 per DEL command. Scan
// and delete errors are logged and returned.
func (c *SmartRedisCacher) InvalidateTables(ctx context.Context, tableNames ...string) error {
	log.Infof("Invalidating cache for tables: %v", tableNames)

	const (
		keyPrefix       = "gorm-caches::"
		scanCount       = 100
		deleteBatchSize = 500
	)

	// Distinct, non-empty table names in caller order, each with its needles.
	var tables []string
	needles := make(map[string][]string, len(tableNames))
	for _, name := range tableNames {
		if name == "" {
			continue
		}
		if _, dup := needles[name]; dup {
			continue
		}
		needles[name] = smartCacheTableNeedles(name)
		tables = append(tables, name)
	}
	if len(tables) == 0 {
		return nil
	}

	var (
		cursor uint64
		keys   []string
		seen   = make(map[string]struct{})
		found  = make(map[string]int, len(tables))
	)
	for {
		page, next, err := c.rdb.Scan(ctx, cursor, keyPrefix+"*", scanCount).Result()
		if err != nil {
			log.Errorf("Error scanning keys for tables %v: %v", tables, err)
			return err
		}
		for _, key := range page {
			// SCAN may return the same key more than once during an iteration.
			if _, dup := seen[key]; dup {
				continue
			}
			cached := strings.TrimPrefix(key, keyPrefix)
			matched := false
			for _, table := range tables {
				if smartCacheReferencesAny(cached, needles[table]) {
					found[table]++
					matched = true
				}
			}
			if matched {
				seen[key] = struct{}{}
				keys = append(keys, key)
			}
		}
		cursor = next
		if cursor == 0 {
			break
		}
	}

	for _, table := range tables {
		if n := found[table]; n > 0 {
			log.Infof("Found %d cache keys for table %s", n, table)
		}
	}

	if len(keys) == 0 {
		log.Info("No matching keys found for invalidation")
		return nil
	}

	log.Infof("Invalidating %d distinct cache keys", len(keys))
	// Bounded batches keep any single DEL command from growing with the cache.
	for start := 0; start < len(keys); start += deleteBatchSize {
		end := start + deleteBatchSize
		if end > len(keys) {
			end = len(keys)
		}
		if err := c.rdb.Del(ctx, keys[start:end]...).Err(); err != nil {
			log.Errorf("Failed to delete keys from redis: %s", err.Error())
			return err
		}
	}
	log.Info("Cache invalidation completed successfully")
	return nil
}

// smartCacheTableNeedles returns the substrings whose presence in a cached
// query's SQL marks it as depending on tableName.
//
// The quoted forms follow the quoted `FROM "table"` shape this cacher matches
// on. NOTE(review): this assumes the gorm postgres dialector double-quotes
// identifiers, including JOIN targets and column qualifiers such as
// `"orders"."id"`; confirm. The unquoted forms cover hand-written fragments,
// for example Joins("JOIN orders ...") or Table("users u").
func smartCacheTableNeedles(tableName string) []string {
	quoted := fmt.Sprintf("\"%s\"", tableName)
	return []string{
		fmt.Sprintf("FROM %s", quoted),
		fmt.Sprintf("FROM %s", tableName),
		fmt.Sprintf("JOIN %s", quoted),
		fmt.Sprintf("JOIN %s", tableName),
		fmt.Sprintf("%s.", quoted),
		fmt.Sprintf("%s.", tableName),
	}
}

// smartCacheReferencesAny reports whether s contains at least one of needles.
func smartCacheReferencesAny(s string, needles []string) bool {
	for _, needle := range needles {
		if strings.Contains(s, needle) {
			return true
		}
	}
	return false
}
// Model-specific invalidation.

// InvalidateModel invalidates cached queries for the table that backs
// modelName, a Go type name such as "OrderItem".
//
// The table name is guessed from the model name with ToSnakeCase plus
// pluralization. The naive "+s" form is always included for backward
// compatibility. A rule-based English plural is added when it differs, so
// "Category" also reaches "categories" and "Box" also reaches "boxes".
// Irregular plurals, custom TableName() methods and custom naming strategies
// are not detected; use InvalidateTables for those. An empty modelName
// invalidates nothing.
func (c *SmartRedisCacher) InvalidateModel(ctx context.Context, modelName string) error {
	return c.InvalidateTables(ctx, smartCacheModelTables(modelName)...)
}

// InvalidateModels is the multi-model form of InvalidateModel. All guessed
// table names are invalidated in a single InvalidateTables call.
func (c *SmartRedisCacher) InvalidateModels(ctx context.Context, modelNames ...string) error {
	tableNames := make([]string, 0, 2*len(modelNames))
	for _, modelName := range modelNames {
		tableNames = append(tableNames, smartCacheModelTables(modelName)...)
	}
	return c.InvalidateTables(ctx, tableNames...)
}

// smartCacheModelTables returns the candidate table names for modelName:
// the naive snake_case+"s" form and, when different, a rule-based plural.
// It returns nil for an empty modelName.
func smartCacheModelTables(modelName string) []string {
	if modelName == "" {
		return nil
	}
	base := ToSnakeCase(modelName)
	naive := base + "s"
	if plural := smartCachePluralize(base); plural != naive {
		return []string{naive, plural}
	}
	return []string{naive}
}

// smartCachePluralize applies a few common English plural rules:
// consonant+"y" -> "ies"; "ss", "sh", "ch" and "x" take "es"; otherwise "s".
func smartCachePluralize(word string) string {
	n := len(word)
	switch {
	case n >= 2 && word[n-1] == 'y' && !strings.ContainsRune("aeiou", rune(word[n-2])):
		return word[:n-1] + "ies"
	case strings.HasSuffix(word, "ss"), strings.HasSuffix(word, "sh"),
		strings.HasSuffix(word, "ch"), strings.HasSuffix(word, "x"):
		return word + "es"
	default:
		return word + "s"
	}
}
// Invalidate is the whole-cache invalidation hook. NOTE(review): it is
// presumably required by the go-gorm/caches Cacher interface and called by
// that plugin after writes; verify against the caches version in use.
//
// It intentionally does nothing and always returns nil. Targeted invalidation
// is handled by the callbacks installed via SetupCallbacks and by
// InvalidateTables/InvalidateModels. FullCacheInvalidate performs an explicit
// full flush when one is really wanted.
func (c *SmartRedisCacher) Invalidate(ctx context.Context) error {
	// Full cache invalidation is deliberately disabled for now: dropping every
	// cached query on each write is not considered useful here.
	return nil
}
// FullCacheInvalidate deletes every key under the "gorm-caches::" prefix.
//
// Keys are deleted page by page as the SCAN proceeds. Memory therefore stays
// bounded by one page, and no single DEL grows with the size of the cache.
//
// No page-count cap is applied. SCAN ends with a zero cursor once the
// keyspace has been walked, and with COUNT 100 a keyspace of roughly 100k+
// keys legitimately needs more than 1000 SCAN calls. Instead, ctx is checked
// between pages so callers can bound the run with a deadline or cancellation.
//
// Scan errors are returned. Delete errors are logged and returned. Keys
// already deleted by earlier pages stay deleted.
func (c *SmartRedisCacher) FullCacheInvalidate(ctx context.Context) error {
	log.Info("Full cache invalidation requested")

	var (
		cursor  uint64
		removed int64
	)
	for {
		if err := ctx.Err(); err != nil {
			return err
		}
		keys, next, err := c.rdb.Scan(ctx, cursor, "gorm-caches::*", 100).Result()
		if err != nil {
			return err
		}
		if len(keys) > 0 {
			n, err := c.rdb.Del(ctx, keys...).Result()
			if err != nil {
				log.Errorf("Failed to delete keys from redis: %s", err.Error())
				return err
			}
			removed += n
		}
		cursor = next
		if cursor == 0 {
			break
		}
	}

	if removed > 0 {
		log.Infof("Full invalidation: removed %d keys", removed)
	}
	return nil
}
// SetupCallbacks registers gorm callbacks that call InvalidateForScope right
// after gorm's own create, update and delete steps on db.
//
// A registration failure is logged rather than silently discarded. The
// remaining callbacks are still attempted.
//
// NOTE(review): the callbacks run immediately after "gorm:create" /
// "gorm:update" / "gorm:delete". Inside a transaction that is presumably
// before commit, so a concurrent reader could re-cache pre-commit data.
// Confirm whether that window matters here.
func (c *SmartRedisCacher) SetupCallbacks(db *gorm.DB) {
	cb := db.Callback()

	register := func(name string, reg func(string, func(*gorm.DB)) error) {
		if err := reg(name, c.InvalidateForScope); err != nil {
			log.Errorf("failed to register gorm callback %s: %s", name, err.Error())
		}
	}

	register("smart_cache:invalidate_after_create", cb.Create().After("gorm:create").Register)
	register("smart_cache:invalidate_after_update", cb.Update().After("gorm:update").Register)
	register("smart_cache:invalidate_after_delete", cb.Delete().After("gorm:delete").Register)
}
// InvalidateForScope drops cached queries affected by the write described by
// db.Statement. It is the body of the callbacks installed by SetupCallbacks.
//
// The statement's table is invalidated directly. When a parsed schema is
// available, the model name is also resolved through InvalidateModel, so both
// the real table and the name derived from the model are covered. Writes
// without a parsed schema are still invalidated by table name instead of
// being skipped; for example, db.Table("users").Update(...) leaves Schema nil.
// If neither a model nor a table can be identified, nothing is done.
//
// It is a no-op when the Redis client is nil, the same "caching disabled"
// convention InvalidateModelsManually uses. Errors are intentionally not
// propagated: InvalidateTables already logs them. A callback could only
// surface them via db.AddError, which would make a write that already
// happened look failed.
func (c *SmartRedisCacher) InvalidateForScope(db *gorm.DB) {
	if c.rdb == nil || db == nil || db.Statement == nil {
		return
	}
	stmt := db.Statement

	tableName := stmt.Table
	modelName := ""
	if stmt.Schema != nil {
		modelName = stmt.Schema.Name
		if tableName == "" {
			tableName = stmt.Schema.Table
		}
	}
	if modelName == "" && tableName == "" {
		return
	}

	log.Infof("DB operation on model %s (table: %s) - invalidating related caches", modelName, tableName)

	ctx := stmt.Context
	if ctx == nil {
		ctx = context.Background()
	}

	// Errors are logged inside InvalidateTables (see doc comment above).
	if modelName != "" {
		_ = c.InvalidateModel(ctx, modelName)
	}
	if tableName != "" {
		_ = c.InvalidateTables(ctx, tableName)
	}
}
// InvalidateModelsManually is the application-facing entry point for
// invalidating the cached queries of the given models, outside of the
// automatic write callbacks.
//
// It returns nil without touching Redis when the client is nil (caching
// disabled) or when no model names are given. Otherwise it delegates to
// InvalidateModels and returns its error.
func (c *SmartRedisCacher) InvalidateModelsManually(ctx context.Context, modelNames ...string) error {
	switch {
	case c.rdb == nil:
		// Callers may invalidate even when Redis caching was never enabled.
		log.Info("Cache is disabled.")
		return nil
	case len(modelNames) == 0:
		log.Info("No models to invalidate.")
		return nil
	}

	log.Infof("Manual invalidation requested for models: %v", modelNames)
	return c.InvalidateModels(ctx, modelNames...)
}
// ToSnakeCase converts a CamelCase identifier to snake_case.
//
// An underscore is inserted before an upper-case ASCII letter that starts a
// new word. That is a capital that follows a non-capital (lower-case letter,
// digit, ...), or the last capital of a run when a lower-case letter follows.
// Runs of capitals therefore stay together as one word: "UserID" ->
// "user_id", "HTTPServer" -> "http_server", "APIKey" -> "api_key". The
// result is not split letter by letter ("user_i_d"). An underscore already
// present is never doubled ("User_Name" -> "user_name"). The whole result is
// lower-cased.
func ToSnakeCase(camel string) string {
	isUpper := func(r rune) bool { return 'A' <= r && r <= 'Z' }
	isLower := func(r rune) bool { return 'a' <= r && r <= 'z' }

	runes := []rune(camel)
	var b strings.Builder
	b.Grow(len(camel) + 4)
	for i, r := range runes {
		if i > 0 && isUpper(r) && runes[i-1] != '_' {
			prev := runes[i-1]
			nextIsLower := i+1 < len(runes) && isLower(runes[i+1])
			if !isUpper(prev) || nextIsLower {
				b.WriteRune('_')
			}
		}
		b.WriteRune(r)
	}
	return strings.ToLower(b.String())
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment