core/database/gorm-cache/cache.go

201 lines
3.6 KiB
Go

package cache
import (
"context"
"hash/fnv"
"os"
"strconv"
"time"
"gorm.io/gorm"
"gorm.io/gorm/callbacks"
)
type Config struct {
Store Store
Prefix string
Serializer Serializer
}
type (
Serializer interface {
Serialize(v any) ([]byte, error)
Deserialize(data []byte, v any) error
}
Store interface {
// Set 写入缓存数据
Set(ctx context.Context, key string, value any, ttl time.Duration) error
// Get 获取缓存数据
Get(ctx context.Context, key string) ([]byte, error)
// SaveTagKey 将缓存key写入tag
SaveTagKey(ctx context.Context, tag, key string) error
// RemoveFromTag 根据缓存tag删除缓存
RemoveFromTag(ctx context.Context, tag string) error
}
)
type Cache struct {
store Store
// Serializer 序列化
Serializer Serializer
// prefix 缓存前缀
prefix string
}
// New
// @param conf
// @date 2022-07-02 08:09:52
func New(conf *Config) *Cache {
if conf.Store == nil {
os.Exit(1)
}
if conf.Serializer == nil {
conf.Serializer = &DefaultJSONSerializer{}
}
return &Cache{
store: conf.Store,
prefix: conf.Prefix,
Serializer: conf.Serializer,
}
}
// Name
// @date 2022-07-02 08:09:48
func (p *Cache) Name() string {
return "gorm:cache"
}
// Initialize
// @param tx
// @date 2022-07-02 08:09:47
func (p *Cache) Initialize(tx *gorm.DB) error {
return tx.Callback().Query().Replace("gorm:query", p.Query)
}
// generateKey
// @param key
// @date 2022-07-02 08:09:46
func generateKey(key string) string {
hash := fnv.New64a()
_, _ = hash.Write([]byte(key))
return strconv.FormatUint(hash.Sum64(), 36)
}
// Query
// @param tx
// @date 2022-07-02 08:09:38
func (p *Cache) Query(tx *gorm.DB) {
ctx := tx.Statement.Context
var ttl time.Duration
var hasTTL bool
if ttl, hasTTL = FromExpiration(ctx); !hasTTL {
callbacks.Query(tx)
return
}
var (
key string
hasKey bool
)
// 调用 Gorm的方法生产SQL
callbacks.BuildQuerySQL(tx)
// 是否有自定义key
if key, hasKey = FromKey(ctx); !hasKey {
key = p.prefix + generateKey(tx.Statement.SQL.String())
}
// 查询缓存数据
if err := p.QueryCache(ctx, key, tx.Statement.Dest); err == nil {
return
}
// 查询数据库
p.QueryDB(tx)
if tx.Error != nil {
return
}
// 写入缓存
if err := p.SaveCache(ctx, key, tx.Statement.Dest, ttl); err != nil {
tx.Logger.Error(ctx, err.Error())
return
}
if tag, hasTag := FromTag(ctx); hasTag {
_ = p.store.SaveTagKey(ctx, tag, key)
}
}
// QueryDB 查询数据库数据
// 这里重写Query方法 是不想执行 callbacks.BuildQuerySQL 两遍
func (p *Cache) QueryDB(tx *gorm.DB) {
if tx.Error != nil || tx.DryRun {
return
}
rows, err := tx.Statement.ConnPool.QueryContext(tx.Statement.Context, tx.Statement.SQL.String(), tx.Statement.Vars...)
if err != nil {
_ = tx.AddError(err)
return
}
defer func() {
_ = tx.AddError(rows.Close())
}()
gorm.Scan(rows, tx, 0)
}
// QueryCache 查询缓存数据
// @param ctx
// @param key
// @param dest
func (p *Cache) QueryCache(ctx context.Context, key string, dest any) error {
values, err := p.store.Get(ctx, key)
if err != nil {
return err
}
switch dest.(type) {
case *int64:
dest = 0
}
return p.Serializer.Deserialize(values, dest)
}
// SaveCache 写入缓存数据
func (p *Cache) SaveCache(ctx context.Context, key string, dest any, ttl time.Duration) error {
values, err := p.Serializer.Serialize(dest)
if err != nil {
return err
}
return p.store.Set(ctx, key, values, ttl)
}
// RemoveFromTag 根据tag删除缓存数据
// @param ctx
// @param tag
// @date 2022-07-02 08:08:59
func (p *Cache) RemoveFromTag(ctx context.Context, tag string) error {
return p.store.RemoveFromTag(ctx, tag)
}