Last active
April 3, 2025 06:34
-
-
Save melkishengue/354772f5c04d57128f9918ff99e51236 to your computer and use it in GitHub Desktop.
A smarter GORM Redis cacher with support for table/model invalidation
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package postgres | |
import ( | |
"context" | |
"fmt" | |
"strings" | |
"time" | |
"github.com/go-gorm/caches/v4" | |
"github.com/gofiber/fiber/v2/log" | |
"github.com/redis/go-redis/v9" | |
"gorm.io/gorm" | |
) | |
type SmartRedisCacher struct { | |
rdb *redis.Client | |
} | |
func (c *SmartRedisCacher) Get(ctx context.Context, key string, q *caches.Query[any]) (*caches.Query[any], error) { | |
res, err := c.rdb.Get(ctx, key).Result() | |
if err == redis.Nil { | |
log.Infof("Key %s not found in redis.\n", key) | |
return nil, nil | |
} | |
if err != nil { | |
log.Errorf("failed to get data from redis: %s", err.Error()) | |
return nil, err | |
} | |
if err := q.Unmarshal([]byte(res)); err != nil { | |
log.Errorf("failed to unmarshal data from redis: %s", err.Error()) | |
return nil, err | |
} | |
return q, nil | |
} | |
func (c *SmartRedisCacher) Store(ctx context.Context, key string, val *caches.Query[any]) error { | |
res, err := val.Marshal() | |
if err != nil { | |
return err | |
} | |
if err := c.rdb.Set(ctx, key, res, 300*time.Second).Err(); err != nil { | |
log.Errorf("failed to store data in redis: %s", err.Error()) | |
return err | |
} | |
return nil | |
} | |
func (c *SmartRedisCacher) InvalidateTables(ctx context.Context, tableNames ...string) error { | |
log.Infof("Invalidating cache for tables: %v", tableNames) | |
if len(tableNames) == 0 { | |
return nil | |
} | |
var allKeys []string | |
// For each table, find and collect keys that contain the table name | |
for _, tableName := range tableNames { | |
var ( | |
cursor uint64 | |
keys []string | |
) | |
// Create patterns to match - look for SQL references to this table | |
patterns := []string{ | |
// Standard table reference pattern - "FROM table_name" | |
fmt.Sprintf("*FROM \"%s\"*", tableName), | |
// For joined tables - "JOIN table_name" | |
fmt.Sprintf("*JOIN %s*", tableName), | |
// For references in WHERE clauses - "table_name." | |
fmt.Sprintf("*%s.*", tableName), | |
} | |
// Scan for keys matching each pattern | |
for _, pattern := range patterns { | |
fullPattern := fmt.Sprintf("gorm-caches::%s", pattern) | |
for { | |
k, nextCursor, err := c.rdb.Scan(ctx, cursor, fullPattern, 100).Result() | |
if err != nil { | |
log.Errorf("Error scanning keys for table %s: %v", tableName, err) | |
return err | |
} | |
keys = append(keys, k...) | |
cursor = nextCursor | |
if cursor == 0 { | |
break | |
} | |
} | |
} | |
if len(keys) > 0 { | |
log.Infof("Found %d cache keys for table %s", len(keys), tableName) | |
allKeys = append(allKeys, keys...) | |
} | |
} | |
// Remove duplicate keys | |
uniqueKeys := make(map[string]struct{}) | |
for _, key := range allKeys { | |
uniqueKeys[key] = struct{}{} | |
} | |
// Convert back to slice | |
distinctKeys := make([]string, 0, len(uniqueKeys)) | |
for key := range uniqueKeys { | |
distinctKeys = append(distinctKeys, key) | |
} | |
// Delete all the collected keys | |
if len(distinctKeys) > 0 { | |
log.Infof("Invalidating %d distinct cache keys", len(distinctKeys)) | |
if _, err := c.rdb.Del(ctx, distinctKeys...).Result(); err != nil { | |
log.Errorf("Failed to delete keys from redis: %s", err.Error()) | |
return err | |
} | |
log.Info("Cache invalidation completed successfully") | |
} else { | |
log.Info("No matching keys found for invalidation") | |
} | |
return nil | |
} | |
// Model-specific invalidation | |
func (c *SmartRedisCacher) InvalidateModel(ctx context.Context, modelName string) error { | |
// Convert model name to table name (usually pluralization and snake_case) | |
tableName := ToSnakeCase(modelName) + "s" // Simple conversion, might need more sophistication | |
return c.InvalidateTables(ctx, tableName) | |
} | |
func (c *SmartRedisCacher) InvalidateModels(ctx context.Context, modelNames ...string) error { | |
tableNames := []string{} | |
// Convert model name to table name (usually pluralization and snake_case) | |
for _, modelName := range modelNames { | |
tableNames = append(tableNames, ToSnakeCase(modelName)+"s") | |
} | |
return c.InvalidateTables(ctx, tableNames...) | |
} | |
func (c *SmartRedisCacher) Invalidate(ctx context.Context) error { | |
// completely disable full cache invalidation for now as I dont see any value in that | |
return nil | |
} | |
func (c *SmartRedisCacher) FullCacheInvalidate(ctx context.Context) error { | |
log.Info("Full cache invalidation requested") | |
var ( | |
cursor uint64 | |
keys []string | |
) | |
for attempts := 0; ; attempts++ { | |
if attempts > 1000 { | |
return fmt.Errorf("infinite loop detected in Invalidate") | |
} | |
k, nextCursor, err := c.rdb.Scan(ctx, cursor, "gorm-caches::*", 100).Result() | |
if err != nil { | |
return err | |
} | |
keys = append(keys, k...) | |
cursor = nextCursor | |
if cursor == 0 { | |
break | |
} | |
} | |
if len(keys) > 0 { | |
log.Infof("Full invalidation: removing %d keys", len(keys)) | |
if _, err := c.rdb.Del(ctx, keys...).Result(); err != nil { | |
log.Errorf("Failed to delete keys from redis: %s", err.Error()) | |
return err | |
} | |
} | |
return nil | |
} | |
func (c *SmartRedisCacher) SetupCallbacks(db *gorm.DB) { | |
callback := db.Callback() | |
callback.Create().After("gorm:create").Register("smart_cache:invalidate_after_create", func(db *gorm.DB) { | |
c.InvalidateForScope(db) | |
}) | |
callback.Update().After("gorm:update").Register("smart_cache:invalidate_after_update", func(db *gorm.DB) { | |
c.InvalidateForScope(db) | |
}) | |
callback.Delete().After("gorm:delete").Register("smart_cache:invalidate_after_delete", func(db *gorm.DB) { | |
c.InvalidateForScope(db) | |
}) | |
} | |
// Helper method to extract model information and invalidate cache | |
func (c *SmartRedisCacher) InvalidateForScope(db *gorm.DB) { | |
if db.Statement.Schema == nil { | |
return | |
} | |
// Get the model name and table name | |
modelName := db.Statement.Schema.Name | |
tableName := db.Statement.Table | |
// Log the operation | |
log.Infof("DB operation on model %s (table: %s) - invalidating related caches", modelName, tableName) | |
// Use our smart invalidation | |
ctx := db.Statement.Context | |
if ctx == nil { | |
ctx = context.Background() | |
} | |
// Invalidate by both model name and table name to be thorough | |
_ = c.InvalidateModel(ctx, modelName) | |
_ = c.InvalidateTables(ctx, tableName) | |
} | |
func (c *SmartRedisCacher) InvalidateModelsManually(ctx context.Context, modelNames ...string) error { | |
// this can happen if redis caching is not enabled but the code tries to invalidate the cache | |
if c.rdb == nil { | |
log.Info("Cache is disabled.") | |
return nil | |
} | |
if len(modelNames) == 0 { | |
log.Info("No models to invalidate.") | |
return nil | |
} | |
log.Infof("Manual invalidation requested for models: %v", modelNames) | |
return c.InvalidateModels(ctx, modelNames...) | |
} | |
// ToSnakeCase converts a CamelCase identifier to snake_case, e.g.
// "OrderItem" -> "order_item".
//
// A run of capitals is kept together as one word, so initialisms convert
// sensibly: "UserID" -> "user_id", "HTTPRequest" -> "http_request".
// No underscore is inserted after an existing underscore
// ("Foo_Bar" -> "foo_bar").
//
// Only ASCII A-Z start a new word. The whole result is lower-cased with
// strings.ToLower.
func ToSnakeCase(camel string) string {
	isUpper := func(r rune) bool { return 'A' <= r && r <= 'Z' }
	isLower := func(r rune) bool { return 'a' <= r && r <= 'z' }

	runes := []rune(camel)
	var b strings.Builder
	b.Grow(len(camel) + 4)
	for i, r := range runes {
		if i > 0 && isUpper(r) {
			prev := runes[i-1]
			switch {
			case prev == '_':
				// Already separated.
			case !isUpper(prev):
				// lower/digit/other -> Upper starts a new word: "userId" -> "user_id".
				b.WriteRune('_')
			case i+1 < len(runes) && isLower(runes[i+1]):
				// The last capital of a run that begins a new word: "HTTPRequest" -> "http_request".
				b.WriteRune('_')
			}
		}
		b.WriteRune(r)
	}
	return strings.ToLower(b.String())
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment