feat: add chat cache (#152)

This commit is contained in:
Buer 2024-04-16 10:36:18 +08:00 committed by GitHub
parent bbaa4eec4b
commit 3c7c13758b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
23 changed files with 557 additions and 49 deletions

View File

@ -37,6 +37,10 @@ var WeChatAuthEnabled = false
var TurnstileCheckEnabled = false var TurnstileCheckEnabled = false
var RegisterEnabled = true var RegisterEnabled = true
// chat cache
var ChatCacheEnabled = false
var ChatCacheExpireMinute = 5 // 5 Minute
// mj // mj
var MjNotifyEnabled = false var MjNotifyEnabled = false

View File

@ -1,6 +1,7 @@
package common package common
import ( import (
"encoding/json"
"fmt" "fmt"
"html/template" "html/template"
"log" "log"
@ -248,3 +249,16 @@ func EscapeMarkdownText(text string) string {
} }
return text return text
} }
func UnmarshalString[T interface{}](data string) (form T, err error) {
err = json.Unmarshal([]byte(data), &form)
return form, err
}
func Marshal[T interface{}](data T) string {
res, err := json.Marshal(data)
if err != nil {
return ""
}
return string(res)
}

View File

@ -42,6 +42,7 @@ func GetStatus(c *gin.Context) {
"display_in_currency": common.DisplayInCurrencyEnabled, "display_in_currency": common.DisplayInCurrencyEnabled,
"telegram_bot": telegram_bot, "telegram_bot": telegram_bot,
"mj_notify_enabled": common.MjNotifyEnabled, "mj_notify_enabled": common.MjNotifyEnabled,
"chat_cache_enabled": common.ChatCacheEnabled,
}, },
}) })
} }

View File

@ -104,6 +104,7 @@ func AddToken(c *gin.Context) {
ExpiredTime: token.ExpiredTime, ExpiredTime: token.ExpiredTime,
RemainQuota: token.RemainQuota, RemainQuota: token.RemainQuota,
UnlimitedQuota: token.UnlimitedQuota, UnlimitedQuota: token.UnlimitedQuota,
ChatCache: token.ChatCache,
} }
err = cleanToken.Insert() err = cleanToken.Insert()
if err != nil { if err != nil {
@ -187,6 +188,7 @@ func UpdateToken(c *gin.Context) {
cleanToken.ExpiredTime = token.ExpiredTime cleanToken.ExpiredTime = token.ExpiredTime
cleanToken.RemainQuota = token.RemainQuota cleanToken.RemainQuota = token.RemainQuota
cleanToken.UnlimitedQuota = token.UnlimitedQuota cleanToken.UnlimitedQuota = token.UnlimitedQuota
cleanToken.ChatCache = token.ChatCache
} }
err = cleanToken.Update() err = cleanToken.Update()
if err != nil { if err != nil {

37
cron/main.go Normal file
View File

@ -0,0 +1,37 @@
package cron
import (
"one-api/common"
"one-api/model"
"time"
"github.com/go-co-op/gocron/v2"
)
func InitCron() {
scheduler, err := gocron.NewScheduler()
if err != nil {
common.SysLog("Cron scheduler error: " + err.Error())
return
}
// 添加删除cache的任务
_, err = scheduler.NewJob(
gocron.DailyJob(
1,
gocron.NewAtTimes(
gocron.NewAtTime(0, 5, 0),
)),
gocron.NewTask(func() {
model.RemoveChatCache(time.Now().Unix())
common.SysLog("删除过期缓存数据")
}),
)
if err != nil {
common.SysLog("Cron job error: " + err.Error())
return
}
scheduler.Start()
}

5
go.mod
View File

@ -30,11 +30,14 @@ require (
filippo.io/edwards25519 v1.1.0 // indirect filippo.io/edwards25519 v1.1.0 // indirect
github.com/chenzhuoyu/iasm v0.9.1 // indirect github.com/chenzhuoyu/iasm v0.9.1 // indirect
github.com/fsnotify/fsnotify v1.7.0 // indirect github.com/fsnotify/fsnotify v1.7.0 // indirect
github.com/go-co-op/gocron/v2 v2.2.9 // indirect
github.com/gomarkdown/markdown v0.0.0-20240328165702-4d01890c35c0 // indirect github.com/gomarkdown/markdown v0.0.0-20240328165702-4d01890c35c0 // indirect
github.com/hashicorp/hcl v1.0.0 // indirect github.com/hashicorp/hcl v1.0.0 // indirect
github.com/jackc/puddle/v2 v2.2.1 // indirect github.com/jackc/puddle/v2 v2.2.1 // indirect
github.com/jonboulle/clockwork v0.4.0 // indirect
github.com/magiconair/properties v1.8.7 // indirect github.com/magiconair/properties v1.8.7 // indirect
github.com/mitchellh/mapstructure v1.5.0 // indirect github.com/mitchellh/mapstructure v1.5.0 // indirect
github.com/robfig/cron/v3 v3.0.1 // indirect
github.com/sagikazarmark/locafero v0.4.0 // indirect github.com/sagikazarmark/locafero v0.4.0 // indirect
github.com/sagikazarmark/slog-shim v0.1.0 // indirect github.com/sagikazarmark/slog-shim v0.1.0 // indirect
github.com/sourcegraph/conc v0.3.0 // indirect github.com/sourcegraph/conc v0.3.0 // indirect
@ -46,7 +49,7 @@ require (
github.com/wneessen/go-mail v0.4.1 // indirect github.com/wneessen/go-mail v0.4.1 // indirect
go.uber.org/atomic v1.9.0 // indirect go.uber.org/atomic v1.9.0 // indirect
go.uber.org/multierr v1.9.0 // indirect go.uber.org/multierr v1.9.0 // indirect
golang.org/x/exp v0.0.0-20230905200255-921286631fa9 // indirect golang.org/x/exp v0.0.0-20231219180239-dc181d75b848 // indirect
golang.org/x/sync v0.6.0 // indirect golang.org/x/sync v0.6.0 // indirect
gopkg.in/ini.v1 v1.67.0 // indirect gopkg.in/ini.v1 v1.67.0 // indirect
) )

8
go.sum
View File

@ -71,6 +71,8 @@ github.com/gin-gonic/gin v1.8.1/go.mod h1:ji8BvRH1azfM+SYow9zQ6SZMvR8qOMZHmsCuWR
github.com/gin-gonic/gin v1.9.1 h1:4idEAncQnU5cB7BeOkPtxjfCSye0AAm1R0RVIqJ+Jmg= github.com/gin-gonic/gin v1.9.1 h1:4idEAncQnU5cB7BeOkPtxjfCSye0AAm1R0RVIqJ+Jmg=
github.com/gin-gonic/gin v1.9.1/go.mod h1:hPrL7YrpYKXt5YId3A/Tnip5kqbEAP+KLuI3SUcPTeU= github.com/gin-gonic/gin v1.9.1/go.mod h1:hPrL7YrpYKXt5YId3A/Tnip5kqbEAP+KLuI3SUcPTeU=
github.com/globalsign/mgo v0.0.0-20181015135952-eeefdecb41b8/go.mod h1:xkRDCp4j0OGD1HRkm4kmhM+pmpv3AKq5SU7GMg4oO/Q= github.com/globalsign/mgo v0.0.0-20181015135952-eeefdecb41b8/go.mod h1:xkRDCp4j0OGD1HRkm4kmhM+pmpv3AKq5SU7GMg4oO/Q=
github.com/go-co-op/gocron/v2 v2.2.9 h1:aoKosYWSSdXFLecjFWX1i8+R6V7XdZb8sB2ZKAY5Yis=
github.com/go-co-op/gocron/v2 v2.2.9/go.mod h1:mZx3gMSlFnb97k3hRqX3+GdlG3+DUwTh6B8fnsTScXg=
github.com/go-playground/assert/v2 v2.0.1/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4= github.com/go-playground/assert/v2 v2.0.1/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4=
github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s= github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s=
github.com/go-playground/locales v0.13.0/go.mod h1:taPMhCMXrRLJO55olJkUXHZBHCxTMfnGwq/HNwmWNS8= github.com/go-playground/locales v0.13.0/go.mod h1:taPMhCMXrRLJO55olJkUXHZBHCxTMfnGwq/HNwmWNS8=
@ -152,6 +154,8 @@ github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ=
github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8=
github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0= github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0=
github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4= github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4=
github.com/jonboulle/clockwork v0.4.0 h1:p4Cf1aMWXnXAUh8lVfewRBx1zaTSYKrKMF2g3ST4RZ4=
github.com/jonboulle/clockwork v0.4.0/go.mod h1:xgRqUGwRcjKCO1vbZUEtSLrqKoPSsUpK7fnezOII0kc=
github.com/json-iterator/go v1.1.9/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4= github.com/json-iterator/go v1.1.9/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4=
github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM= github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=
github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo= github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo=
@ -219,6 +223,8 @@ github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZN
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U= github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U=
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/quasoft/memstore v0.0.0-20191010062613-2bce066d2b0b/go.mod h1:wTPjTepVu7uJBYgZ0SdWHQlIas582j6cn2jgk4DDdlg= github.com/quasoft/memstore v0.0.0-20191010062613-2bce066d2b0b/go.mod h1:wTPjTepVu7uJBYgZ0SdWHQlIas582j6cn2jgk4DDdlg=
github.com/robfig/cron/v3 v3.0.1 h1:WdRxkvbJztn8LMz/QEvLN5sBU+xKpSqwwUO1Pjr4qDs=
github.com/robfig/cron/v3 v3.0.1/go.mod h1:eQICP3HwyT7UooqI/z+Ov+PtYAWygg1TEWWzGIFLtro=
github.com/rogpeppe/go-internal v1.6.1/go.mod h1:xXDCJY+GAPziupqXw64V24skbSoqbTEfhy4qGm1nDQc= github.com/rogpeppe/go-internal v1.6.1/go.mod h1:xXDCJY+GAPziupqXw64V24skbSoqbTEfhy4qGm1nDQc=
github.com/rogpeppe/go-internal v1.8.0 h1:FCbCCtXNOY3UtUuHUYaghJg4y7Fd14rXifAYUAtL9R8= github.com/rogpeppe/go-internal v1.8.0 h1:FCbCCtXNOY3UtUuHUYaghJg4y7Fd14rXifAYUAtL9R8=
github.com/rogpeppe/go-internal v1.8.0/go.mod h1:WmiCO8CzOY8rg0OYDC4/i/2WRWAB6poM+XZ2dLUbcbE= github.com/rogpeppe/go-internal v1.8.0/go.mod h1:WmiCO8CzOY8rg0OYDC4/i/2WRWAB6poM+XZ2dLUbcbE=
@ -284,6 +290,8 @@ golang.org/x/crypto v0.21.0 h1:X31++rzVUdKhX5sWmSOFZxx8UW/ldWx55cbf08iNAMA=
golang.org/x/crypto v0.21.0/go.mod h1:0BP7YvVV9gBbVKyeTG0Gyn+gZm94bibOW5BjDEYAOMs= golang.org/x/crypto v0.21.0/go.mod h1:0BP7YvVV9gBbVKyeTG0Gyn+gZm94bibOW5BjDEYAOMs=
golang.org/x/exp v0.0.0-20230905200255-921286631fa9 h1:GoHiUyI/Tp2nVkLI2mCxVkOjsbSXD66ic0XW0js0R9g= golang.org/x/exp v0.0.0-20230905200255-921286631fa9 h1:GoHiUyI/Tp2nVkLI2mCxVkOjsbSXD66ic0XW0js0R9g=
golang.org/x/exp v0.0.0-20230905200255-921286631fa9/go.mod h1:S2oDrQGGwySpoQPVqRShND87VCbxmc6bL1Yd2oYrm6k= golang.org/x/exp v0.0.0-20230905200255-921286631fa9/go.mod h1:S2oDrQGGwySpoQPVqRShND87VCbxmc6bL1Yd2oYrm6k=
golang.org/x/exp v0.0.0-20231219180239-dc181d75b848 h1:+iq7lrkxmFNBM7xx+Rae2W6uyPfhPeDWD+n+JgppptE=
golang.org/x/exp v0.0.0-20231219180239-dc181d75b848/go.mod h1:iRJReGqOEeBhDZGkGbynYwcHlctCvnjTYIamk7uXpHI=
golang.org/x/image v0.14.0 h1:tNgSxAFe3jC4uYqvZdTr84SZoM1KfwdC9SKIFrLjFn4= golang.org/x/image v0.14.0 h1:tNgSxAFe3jC4uYqvZdTr84SZoM1KfwdC9SKIFrLjFn4=
golang.org/x/image v0.14.0/go.mod h1:HUYqC05R2ZcZ3ejNQsIHQDQiwWM4JBqmm6MKANTp4LE= golang.org/x/image v0.14.0/go.mod h1:HUYqC05R2ZcZ3ejNQsIHQDQiwWM4JBqmm6MKANTp4LE=
golang.org/x/image v0.15.0 h1:kOELfmgrmJlw4Cdb7g/QGuB3CvDrXbqEIww/pNtNBm8= golang.org/x/image v0.15.0 h1:kOELfmgrmJlw4Cdb7g/QGuB3CvDrXbqEIww/pNtNBm8=

View File

@ -9,6 +9,7 @@ import (
"one-api/common/requester" "one-api/common/requester"
"one-api/common/telegram" "one-api/common/telegram"
"one-api/controller" "one-api/controller"
"one-api/cron"
"one-api/middleware" "one-api/middleware"
"one-api/model" "one-api/model"
"one-api/relay/util" "one-api/relay/util"
@ -48,6 +49,7 @@ func main() {
controller.InitMidjourneyTask() controller.InitMidjourneyTask()
notify.InitNotifier() notify.InitNotifier()
cron.InitCron()
initHttpServer() initHttpServer()
} }

View File

@ -105,6 +105,7 @@ func tokenAuth(c *gin.Context, key string) {
c.Set("id", token.UserId) c.Set("id", token.UserId)
c.Set("token_id", token.Id) c.Set("token_id", token.Id)
c.Set("token_name", token.Name) c.Set("token_name", token.Name)
c.Set("chat_cache", token.ChatCache)
if len(parts) > 1 { if len(parts) > 1 {
if model.IsAdmin(token.UserId) { if model.IsAdmin(token.UserId) {
channelId := common.String2Int(parts[1]) channelId := common.String2Int(parts[1])

41
model/chat_cache.go Normal file
View File

@ -0,0 +1,41 @@
package model
import (
"time"
"gorm.io/gorm/clause"
)
type ChatCache struct {
Hash string `json:"hash" gorm:"type:varchar(32);primaryKey"`
UserId int `json:"user_id" gorm:"type:int;not null;index"`
Data string `json:"data" gorm:"type:json;not null"`
Expiration int64 `json:"expiration" gorm:"type:bigint;not null;index"`
}
func (cache *ChatCache) Insert() error {
return DB.Clauses(clause.OnConflict{
UpdateAll: true,
}).Create(cache).Error
}
func GetChatCache(hash string, userId int) (*ChatCache, error) {
var chatCache ChatCache
// 获取当前时间戳
now := time.Now().Unix()
err := DB.Where("hash = ? and user_id = ? and expiration > ?", hash, userId, now).Find(&chatCache).Error
return &chatCache, err
}
func GetChatCacheListByUserId(userId int) ([]*ChatCache, error) {
var chatCaches []*ChatCache
// 获取当前时间戳
now := time.Now().Unix()
err := DB.Where("user_id = ? and expiration >", userId, now).Find(&chatCaches).Error
return chatCaches, err
}
func RemoveChatCache(expiration int64) error {
now := time.Now().Unix()
return DB.Where("expiration < ?", now).Delete(ChatCache{}).Error
}

View File

@ -144,6 +144,10 @@ func InitDB() (err error) {
if err != nil { if err != nil {
return err return err
} }
err = db.AutoMigrate(&ChatCache{})
if err != nil {
return err
}
common.SysLog("database migrated") common.SysLog("database migrated")
err = createRootAccountIfNeed() err = createRootAccountIfNeed()
return err return err

View File

@ -76,6 +76,9 @@ func InitOptionMap() {
common.OptionMap["MjNotifyEnabled"] = strconv.FormatBool(common.MjNotifyEnabled) common.OptionMap["MjNotifyEnabled"] = strconv.FormatBool(common.MjNotifyEnabled)
common.OptionMap["ChatCacheEnabled"] = strconv.FormatBool(common.ChatCacheEnabled)
common.OptionMap["ChatCacheExpireMinute"] = strconv.Itoa(common.ChatCacheExpireMinute)
common.OptionMapRWMutex.Unlock() common.OptionMapRWMutex.Unlock()
loadOptionsFromDatabase() loadOptionsFromDatabase()
} }
@ -115,14 +118,15 @@ func UpdateOption(key string, value string) error {
} }
var optionIntMap = map[string]*int{ var optionIntMap = map[string]*int{
"SMTPPort": &common.SMTPPort, "SMTPPort": &common.SMTPPort,
"QuotaForNewUser": &common.QuotaForNewUser, "QuotaForNewUser": &common.QuotaForNewUser,
"QuotaForInviter": &common.QuotaForInviter, "QuotaForInviter": &common.QuotaForInviter,
"QuotaForInvitee": &common.QuotaForInvitee, "QuotaForInvitee": &common.QuotaForInvitee,
"QuotaRemindThreshold": &common.QuotaRemindThreshold, "QuotaRemindThreshold": &common.QuotaRemindThreshold,
"PreConsumedQuota": &common.PreConsumedQuota, "PreConsumedQuota": &common.PreConsumedQuota,
"RetryTimes": &common.RetryTimes, "RetryTimes": &common.RetryTimes,
"RetryCooldownSeconds": &common.RetryCooldownSeconds, "RetryCooldownSeconds": &common.RetryCooldownSeconds,
"ChatCacheExpireMinute": &common.ChatCacheExpireMinute,
} }
var optionBoolMap = map[string]*bool{ var optionBoolMap = map[string]*bool{
@ -141,6 +145,7 @@ var optionBoolMap = map[string]*bool{
"DisplayInCurrencyEnabled": &common.DisplayInCurrencyEnabled, "DisplayInCurrencyEnabled": &common.DisplayInCurrencyEnabled,
"DisplayTokenStatEnabled": &common.DisplayTokenStatEnabled, "DisplayTokenStatEnabled": &common.DisplayTokenStatEnabled,
"MjNotifyEnabled": &common.MjNotifyEnabled, "MjNotifyEnabled": &common.MjNotifyEnabled,
"ChatCacheEnabled": &common.ChatCacheEnabled,
} }
var optionStringMap = map[string]*string{ var optionStringMap = map[string]*string{

View File

@ -2,6 +2,7 @@ package model
import ( import (
"errors" "errors"
"fmt"
"one-api/common" "one-api/common"
"one-api/common/stmp" "one-api/common/stmp"
@ -20,6 +21,7 @@ type Token struct {
RemainQuota int `json:"remain_quota" gorm:"default:0"` RemainQuota int `json:"remain_quota" gorm:"default:0"`
UnlimitedQuota bool `json:"unlimited_quota" gorm:"default:false"` UnlimitedQuota bool `json:"unlimited_quota" gorm:"default:false"`
UsedQuota int `json:"used_quota" gorm:"default:0"` // used quota UsedQuota int `json:"used_quota" gorm:"default:0"` // used quota
ChatCache bool `json:"chat_cache" gorm:"default:false"`
} }
var allowedTokenOrderFields = map[string]bool{ var allowedTokenOrderFields = map[string]bool{
@ -40,7 +42,7 @@ func GetUserTokensList(userId int, params *GenericParams) (*DataResult[Token], e
db = db.Where("name LIKE ?", params.Keyword+"%") db = db.Where("name LIKE ?", params.Keyword+"%")
} }
return PaginateAndOrder[Token](db, &params.PaginationParams, &tokens, allowedTokenOrderFields) return PaginateAndOrder(db, &params.PaginationParams, &tokens, allowedTokenOrderFields)
} }
// 获取状态为可用的令牌 // 获取状态为可用的令牌
@ -114,13 +116,26 @@ func GetTokenById(id int) (*Token, error) {
} }
func (token *Token) Insert() error { func (token *Token) Insert() error {
if token.ChatCache && !common.ChatCacheEnabled {
token.ChatCache = false
}
err := DB.Create(token).Error err := DB.Create(token).Error
return err return err
} }
// Update Make sure your token's fields is completed, because this will update non-zero values // Update Make sure your token's fields is completed, because this will update non-zero values
func (token *Token) Update() error { func (token *Token) Update() error {
err := DB.Model(token).Select("name", "status", "expired_time", "remain_quota", "unlimited_quota").Updates(token).Error if token.ChatCache && !common.ChatCacheEnabled {
token.ChatCache = false
}
err := DB.Model(token).Select("name", "status", "expired_time", "remain_quota", "unlimited_quota", "chat_cache").Updates(token).Error
// 防止Redis缓存不生效直接删除
if err == nil && common.RedisEnabled {
common.RedisDel(fmt.Sprintf("token:%s", token.Key))
}
return err return err
} }

View File

@ -1,6 +1,7 @@
package relay package relay
import ( import (
"one-api/relay/util"
"one-api/types" "one-api/types"
providersBase "one-api/providers/base" providersBase "one-api/providers/base"
@ -13,17 +14,33 @@ type relayBase struct {
provider providersBase.ProviderInterface provider providersBase.ProviderInterface
originalModel string originalModel string
modelName string modelName string
cache *util.ChatCacheProps
} }
type RelayBaseInterface interface { type RelayBaseInterface interface {
send() (err *types.OpenAIErrorWithStatusCode, done bool) send() (err *types.OpenAIErrorWithStatusCode, done bool)
getPromptTokens() (int, error) getPromptTokens() (int, error)
setRequest() error setRequest() error
getRequest() any
setProvider(modelName string) error setProvider(modelName string) error
getProvider() providersBase.ProviderInterface getProvider() providersBase.ProviderInterface
getOriginalModel() string getOriginalModel() string
getModelName() string getModelName() string
getContext() *gin.Context getContext() *gin.Context
SetChatCache(allow bool)
GetChatCache() *util.ChatCacheProps
}
func (r *relayBase) SetChatCache(allow bool) {
r.cache = util.NewChatCacheProps(r.c, allow)
}
func (r *relayBase) GetChatCache() *util.ChatCacheProps {
return r.cache
}
func (r *relayBase) getRequest() interface{} {
return nil
} }
func (r *relayBase) setProvider(modelName string) error { func (r *relayBase) setProvider(modelName string) error {

View File

@ -37,6 +37,10 @@ func (r *relayChat) setRequest() error {
return nil return nil
} }
func (r *relayChat) getRequest() interface{} {
return &r.chatRequest
}
func (r *relayChat) getPromptTokens() (int, error) { func (r *relayChat) getPromptTokens() (int, error) {
return common.CountTokenMessages(r.chatRequest.Messages, r.modelName), nil return common.CountTokenMessages(r.chatRequest.Messages, r.modelName), nil
} }
@ -58,7 +62,7 @@ func (r *relayChat) send() (err *types.OpenAIErrorWithStatusCode, done bool) {
return return
} }
err = responseStreamClient(r.c, response) err = responseStreamClient(r.c, response, r.cache)
} else { } else {
var response *types.ChatCompletionResponse var response *types.ChatCompletionResponse
response, err = chatProvider.CreateChatCompletion(&r.chatRequest) response, err = chatProvider.CreateChatCompletion(&r.chatRequest)
@ -66,6 +70,7 @@ func (r *relayChat) send() (err *types.OpenAIErrorWithStatusCode, done bool) {
return return
} }
err = responseJsonClient(r.c, response) err = responseJsonClient(r.c, response)
r.cache.SetResponse(response)
} }
if err != nil { if err != nil {

View File

@ -13,6 +13,7 @@ import (
"one-api/model" "one-api/model"
"one-api/providers" "one-api/providers"
providersBase "one-api/providers/base" providersBase "one-api/providers/base"
"one-api/relay/util"
"one-api/types" "one-api/types"
"strings" "strings"
@ -20,29 +21,37 @@ import (
) )
func Path2Relay(c *gin.Context, path string) RelayBaseInterface { func Path2Relay(c *gin.Context, path string) RelayBaseInterface {
allowCache := false
var relay RelayBaseInterface
if strings.HasPrefix(path, "/v1/chat/completions") { if strings.HasPrefix(path, "/v1/chat/completions") {
return NewRelayChat(c) allowCache = true
relay = NewRelayChat(c)
} else if strings.HasPrefix(path, "/v1/completions") { } else if strings.HasPrefix(path, "/v1/completions") {
return NewRelayCompletions(c) allowCache = true
relay = NewRelayCompletions(c)
} else if strings.HasPrefix(path, "/v1/embeddings") { } else if strings.HasPrefix(path, "/v1/embeddings") {
return NewRelayEmbeddings(c) relay = NewRelayEmbeddings(c)
} else if strings.HasPrefix(path, "/v1/moderations") { } else if strings.HasPrefix(path, "/v1/moderations") {
return NewRelayModerations(c) relay = NewRelayModerations(c)
} else if strings.HasPrefix(path, "/v1/images/generations") { } else if strings.HasPrefix(path, "/v1/images/generations") {
return NewRelayImageGenerations(c) relay = NewRelayImageGenerations(c)
} else if strings.HasPrefix(path, "/v1/images/edits") { } else if strings.HasPrefix(path, "/v1/images/edits") {
return NewRelayImageEdits(c) relay = NewRelayImageEdits(c)
} else if strings.HasPrefix(path, "/v1/images/variations") { } else if strings.HasPrefix(path, "/v1/images/variations") {
return NewRelayImageVariations(c) relay = NewRelayImageVariations(c)
} else if strings.HasPrefix(path, "/v1/audio/speech") { } else if strings.HasPrefix(path, "/v1/audio/speech") {
return NewRelaySpeech(c) relay = NewRelaySpeech(c)
} else if strings.HasPrefix(path, "/v1/audio/transcriptions") { } else if strings.HasPrefix(path, "/v1/audio/transcriptions") {
return NewRelayTranscriptions(c) relay = NewRelayTranscriptions(c)
} else if strings.HasPrefix(path, "/v1/audio/translations") { } else if strings.HasPrefix(path, "/v1/audio/translations") {
return NewRelayTranslations(c) relay = NewRelayTranslations(c)
} }
return nil if relay != nil {
relay.SetChatCache(allowCache)
}
return relay
} }
func GetProvider(c *gin.Context, modeName string) (provider providersBase.ProviderInterface, newModelName string, fail error) { func GetProvider(c *gin.Context, modeName string) (provider providersBase.ProviderInterface, newModelName string, fail error) {
@ -120,7 +129,7 @@ func responseJsonClient(c *gin.Context, data interface{}) *types.OpenAIErrorWith
return nil return nil
} }
func responseStreamClient(c *gin.Context, stream requester.StreamReaderInterface[string]) *types.OpenAIErrorWithStatusCode { func responseStreamClient(c *gin.Context, stream requester.StreamReaderInterface[string], cache *util.ChatCacheProps) (errWithOP *types.OpenAIErrorWithStatusCode) {
requester.SetEventStreamHeaders(c) requester.SetEventStreamHeaders(c)
dataChan, errChan := stream.Recv() dataChan, errChan := stream.Recv()
@ -128,19 +137,24 @@ func responseStreamClient(c *gin.Context, stream requester.StreamReaderInterface
c.Stream(func(w io.Writer) bool { c.Stream(func(w io.Writer) bool {
select { select {
case data := <-dataChan: case data := <-dataChan:
fmt.Fprintln(w, "data: "+data+"\n") streamData := "data: " + data + "\n\n"
fmt.Fprint(w, streamData)
cache.SetResponse(streamData)
return true return true
case err := <-errChan: case err := <-errChan:
if !errors.Is(err, io.EOF) { if !errors.Is(err, io.EOF) {
fmt.Fprintln(w, "data: "+err.Error()+"\n") fmt.Fprint(w, "data: "+err.Error()+"\n\n")
errWithOP = common.ErrorWrapper(err, "stream_error", http.StatusInternalServerError)
} }
fmt.Fprintln(w, "data: [DONE]") streamData := "data: [DONE]\n"
fmt.Fprint(w, streamData)
cache.SetResponse(streamData)
return false return false
} }
}) })
return nil return errWithOP
} }
func responseMultipart(c *gin.Context, resp *http.Response) *types.OpenAIErrorWithStatusCode { func responseMultipart(c *gin.Context, resp *http.Response) *types.OpenAIErrorWithStatusCode {
@ -174,6 +188,22 @@ func responseCustom(c *gin.Context, response *types.AudioResponseWrapper) *types
return nil return nil
} }
func responseCache(c *gin.Context, response string) {
// 检查是否是 data: 开头的流式数据
isStream := strings.HasPrefix(response, "data: ")
if isStream {
requester.SetEventStreamHeaders(c)
c.Stream(func(w io.Writer) bool {
fmt.Fprint(w, response)
return false
})
} else {
c.Data(http.StatusOK, "application/json", []byte(response))
}
}
func shouldRetry(c *gin.Context, statusCode int) bool { func shouldRetry(c *gin.Context, statusCode int) bool {
channelId := c.GetInt("specific_channel_id") channelId := c.GetInt("specific_channel_id")
if channelId > 0 { if channelId > 0 {

View File

@ -37,6 +37,10 @@ func (r *relayCompletions) setRequest() error {
return nil return nil
} }
func (r *relayCompletions) getRequest() interface{} {
return &r.request
}
func (r *relayCompletions) getPromptTokens() (int, error) { func (r *relayCompletions) getPromptTokens() (int, error) {
return common.CountTokenInput(r.request.Prompt, r.modelName), nil return common.CountTokenInput(r.request.Prompt, r.modelName), nil
} }
@ -58,7 +62,7 @@ func (r *relayCompletions) send() (err *types.OpenAIErrorWithStatusCode, done bo
return return
} }
err = responseStreamClient(r.c, response) err = responseStreamClient(r.c, response, r.cache)
} else { } else {
var response *types.CompletionResponse var response *types.CompletionResponse
response, err = provider.CreateCompletion(&r.request) response, err = provider.CreateCompletion(&r.request)
@ -66,6 +70,7 @@ func (r *relayCompletions) send() (err *types.OpenAIErrorWithStatusCode, done bo
return return
} }
err = responseJsonClient(r.c, response) err = responseJsonClient(r.c, response)
r.cache.SetResponse(response)
} }
if err != nil { if err != nil {

View File

@ -7,6 +7,7 @@ import (
"one-api/model" "one-api/model"
"one-api/relay/util" "one-api/relay/util"
"one-api/types" "one-api/types"
"time"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )
@ -23,6 +24,18 @@ func Relay(c *gin.Context) {
return return
} }
cacheProps := relay.GetChatCache()
cacheProps.SetHash(relay.getRequest())
// 获取缓存
cache := cacheProps.GetCache()
if cache != nil {
// 说明有缓存, 直接返回缓存内容
cacheProcessing(c, cache)
return
}
if err := relay.setProvider(relay.getOriginalModel()); err != nil { if err := relay.setProvider(relay.getOriginalModel()); err != nil {
common.AbortWithMessage(c, http.StatusServiceUnavailable, err.Error()) common.AbortWithMessage(c, http.StatusServiceUnavailable, err.Error())
return return
@ -103,5 +116,25 @@ func RelayHandler(relay RelayBaseInterface) (err *types.OpenAIErrorWithStatusCod
} }
quota.Consume(relay.getContext(), usage) quota.Consume(relay.getContext(), usage)
cacheProps := relay.GetChatCache()
go cacheProps.StoreCache(relay.getContext().GetInt("channel_id"), usage.PromptTokens, usage.CompletionTokens, relay.getModelName())
return return
} }
func cacheProcessing(c *gin.Context, cacheProps *util.ChatCacheProps) {
responseCache(c, cacheProps.Response)
// 写入日志
tokenName := c.GetString("token_name")
requestTime := 0
requestStartTimeValue := c.Request.Context().Value("requestStartTime")
if requestStartTimeValue != nil {
requestStartTime, ok := requestStartTimeValue.(time.Time)
if ok {
requestTime = int(time.Since(requestStartTime).Milliseconds())
}
}
model.RecordConsumeLog(c.Request.Context(), cacheProps.UserId, cacheProps.ChannelID, cacheProps.PromptTokens, cacheProps.CompletionTokens, cacheProps.ModelName, tokenName, 0, "缓存", requestTime)
}

128
relay/util/cache.go Normal file
View File

@ -0,0 +1,128 @@
package util
import (
"crypto/md5"
"encoding/hex"
"fmt"
"one-api/common"
"one-api/model"
"github.com/gin-gonic/gin"
)
type ChatCacheProps struct {
UserId int `json:"user_id"`
TokenId int `json:"token_id"`
ChannelID int `json:"channel_id"`
PromptTokens int `json:"prompt_tokens"`
CompletionTokens int `json:"completion_tokens"`
ModelName string `json:"model_name"`
Response string `json:"response"`
Hash string `json:"-"`
Cache bool `json:"-"`
Driver CacheDriver `json:"-"`
}
type CacheDriver interface {
Get(hash string, userId int) *ChatCacheProps
Set(hash string, props *ChatCacheProps, expire int64) error
}
func GetDebugList(userId int) ([]*ChatCacheProps, error) {
caches, err := model.GetChatCacheListByUserId(userId)
if err != nil {
return nil, err
}
var props []*ChatCacheProps
for _, cache := range caches {
prop, err := common.UnmarshalString[ChatCacheProps](cache.Data)
if err != nil {
continue
}
props = append(props, &prop)
}
return props, nil
}
func NewChatCacheProps(c *gin.Context, allow bool) *ChatCacheProps {
props := &ChatCacheProps{
Cache: false,
}
if !allow {
return props
}
if common.ChatCacheEnabled && c.GetBool("chat_cache") {
props.Cache = true
}
if common.RedisEnabled {
props.Driver = &ChatCacheRedis{}
} else {
props.Driver = &ChatCacheDB{}
}
props.UserId = c.GetInt("id")
props.TokenId = c.GetInt("token_id")
return props
}
func (p *ChatCacheProps) SetHash(request any) {
if !p.needCache() || request == nil {
return
}
p.hash(common.Marshal(request))
}
func (p *ChatCacheProps) SetResponse(response any) {
if !p.needCache() || response == nil {
return
}
if str, ok := response.(string); ok {
p.Response += str
return
}
p.Response = common.Marshal(response)
}
func (p *ChatCacheProps) StoreCache(channelId, promptTokens, completionTokens int, modelName string) error {
if !p.needCache() || p.Response == "" {
return nil
}
p.ChannelID = channelId
p.PromptTokens = promptTokens
p.CompletionTokens = completionTokens
p.ModelName = modelName
return p.Driver.Set(p.getHash(), p, int64(common.ChatCacheExpireMinute))
}
func (p *ChatCacheProps) GetCache() *ChatCacheProps {
if !p.needCache() {
return nil
}
return p.Driver.Get(p.getHash(), p.UserId)
}
func (p *ChatCacheProps) needCache() bool {
return common.ChatCacheEnabled && p.Cache
}
func (p *ChatCacheProps) getHash() string {
return p.Hash
}
func (p *ChatCacheProps) hash(request string) {
hash := md5.Sum([]byte(fmt.Sprintf("%d-%d-%s", p.UserId, p.TokenId, request)))
p.Hash = hex.EncodeToString(hash[:])
}

47
relay/util/cache_db.go Normal file
View File

@ -0,0 +1,47 @@
package util
import (
"errors"
"one-api/common"
"one-api/model"
"time"
)
type ChatCacheDB struct{}
func (db *ChatCacheDB) Get(hash string, userId int) *ChatCacheProps {
cache, _ := model.GetChatCache(hash, userId)
if cache == nil {
return nil
}
props, err := common.UnmarshalString[ChatCacheProps](cache.Data)
if err != nil {
return nil
}
return &props
}
func (db *ChatCacheDB) Set(hash string, props *ChatCacheProps, expire int64) error {
return SetCacheDB(hash, props, expire)
}
func SetCacheDB(hash string, props *ChatCacheProps, expire int64) error {
data := common.Marshal(props)
if data == "" {
return errors.New("marshal error")
}
expire = expire * 60
expire += time.Now().Unix()
cache := &model.ChatCache{
Hash: hash,
UserId: props.UserId,
Data: data,
Expiration: expire,
}
return cache.Insert()
}

44
relay/util/cache_redis.go Normal file
View File

@ -0,0 +1,44 @@
package util
import (
"errors"
"fmt"
"one-api/common"
"time"
)
type ChatCacheRedis struct{}
var chatCacheKey = "chat_cache"
func (r *ChatCacheRedis) Get(hash string, userId int) *ChatCacheProps {
cache, err := common.RedisGet(r.getKey(hash, userId))
if err != nil {
return nil
}
props, err := common.UnmarshalString[ChatCacheProps](cache)
if err != nil {
return nil
}
return &props
}
func (r *ChatCacheRedis) Set(hash string, props *ChatCacheProps, expire int64) error {
if !props.Cache {
return nil
}
data := common.Marshal(&props)
if data == "" {
return errors.New("marshal error")
}
return common.RedisSet(r.getKey(hash, props.UserId), data, time.Duration(expire)*time.Minute)
}
func (r *ChatCacheRedis) getKey(hash string, userId int) string {
return fmt.Sprintf("%s:%d:%s", chatCacheKey, userId, hash)
}

View File

@ -30,7 +30,9 @@ const OperationSetting = () => {
ApproximateTokenEnabled: '', ApproximateTokenEnabled: '',
RetryTimes: 0, RetryTimes: 0,
RetryCooldownSeconds: 0, RetryCooldownSeconds: 0,
MjNotifyEnabled: '' MjNotifyEnabled: '',
ChatCacheEnabled: '',
ChatCacheExpireMinute: 5
}); });
const [originInputs, setOriginInputs] = useState({}); const [originInputs, setOriginInputs] = useState({});
let [loading, setLoading] = useState(false); let [loading, setLoading] = useState(false);
@ -152,6 +154,11 @@ const OperationSetting = () => {
await updateOption('RetryCooldownSeconds', inputs.RetryCooldownSeconds); await updateOption('RetryCooldownSeconds', inputs.RetryCooldownSeconds);
} }
break; break;
case 'other':
if (originInputs['ChatCacheExpireMinute'] !== inputs.ChatCacheExpireMinute) {
await updateOption('ChatCacheExpireMinute', inputs.ChatCacheExpireMinute);
}
break;
} }
showSuccess('保存成功!'); showSuccess('保存成功!');
@ -292,7 +299,34 @@ const OperationSetting = () => {
label="Midjourney 允许回调会泄露服务器ip地址" label="Midjourney 允许回调会泄露服务器ip地址"
control={<Checkbox checked={inputs.MjNotifyEnabled === 'true'} onChange={handleInputChange} name="MjNotifyEnabled" />} control={<Checkbox checked={inputs.MjNotifyEnabled === 'true'} onChange={handleInputChange} name="MjNotifyEnabled" />}
/> />
<FormControlLabel
sx={{ marginLeft: '0px' }}
label="是否开启聊天缓存(如果没有启用Redis将会存储在数据库中)"
control={<Checkbox checked={inputs.ChatCacheEnabled === 'true'} onChange={handleInputChange} name="ChatCacheEnabled" />}
/>
</Stack> </Stack>
<Stack direction={{ sm: 'column', md: 'row' }} spacing={{ xs: 3, sm: 2, md: 4 }}>
<FormControl>
<InputLabel htmlFor="ChatCacheExpireMinute">缓存时间(分钟)</InputLabel>
<OutlinedInput
id="ChatCacheExpireMinute"
name="ChatCacheExpireMinute"
value={inputs.ChatCacheExpireMinute}
onChange={handleInputChange}
label="缓存时间(分钟)"
placeholder="开启缓存时,数据缓存的时间"
disabled={loading}
/>
</FormControl>
</Stack>
<Button
variant="contained"
onClick={() => {
submitConfig('other').then();
}}
>
保存其他设置
</Button>
</Stack> </Stack>
</SubCard> </SubCard>
<SubCard title="日志设置"> <SubCard title="日志设置">

View File

@ -17,6 +17,7 @@ import {
OutlinedInput, OutlinedInput,
InputAdornment, InputAdornment,
Switch, Switch,
FormControlLabel,
FormHelperText FormHelperText
} from '@mui/material'; } from '@mui/material';
@ -25,6 +26,7 @@ import { LocalizationProvider } from '@mui/x-date-pickers/LocalizationProvider';
import { DateTimePicker } from '@mui/x-date-pickers/DateTimePicker'; import { DateTimePicker } from '@mui/x-date-pickers/DateTimePicker';
import { renderQuotaWithPrompt, showSuccess, showError } from 'utils/common'; import { renderQuotaWithPrompt, showSuccess, showError } from 'utils/common';
import { API } from 'utils/api'; import { API } from 'utils/api';
import { useSelector } from 'react-redux';
require('dayjs/locale/zh-cn'); require('dayjs/locale/zh-cn');
const validationSchema = Yup.object().shape({ const validationSchema = Yup.object().shape({
@ -40,12 +42,14 @@ const originInputs = {
name: '', name: '',
remain_quota: 0, remain_quota: 0,
expired_time: -1, expired_time: -1,
unlimited_quota: false unlimited_quota: false,
chat_cache: false
}; };
const EditModal = ({ open, tokenId, onCancel, onOk }) => { const EditModal = ({ open, tokenId, onCancel, onOk }) => {
const theme = useTheme(); const theme = useTheme();
const [inputs, setInputs] = useState(originInputs); const [inputs, setInputs] = useState(originInputs);
const siteInfo = useSelector((state) => state.siteInfo);
const submit = async (values, { setErrors, setStatus, setSubmitting }) => { const submit = async (values, { setErrors, setStatus, setSubmitting }) => {
setSubmitting(true); setSubmitting(true);
@ -163,17 +167,22 @@ const EditModal = ({ open, tokenId, onCancel, onOk }) => {
)} )}
</FormControl> </FormControl>
)} )}
<Switch <FormControlLabel
checked={values.expired_time === -1} control={
onClick={() => { <Switch
if (values.expired_time === -1) { checked={values.expired_time === -1}
setFieldValue('expired_time', Math.floor(Date.now() / 1000)); onClick={() => {
} else { if (values.expired_time === -1) {
setFieldValue('expired_time', -1); setFieldValue('expired_time', Math.floor(Date.now() / 1000));
} } else {
}} setFieldValue('expired_time', -1);
/>{' '} }
永不过期 }}
/>
}
label="永不过期"
/>
<FormControl fullWidth error={Boolean(touched.remain_quota && errors.remain_quota)} sx={{ ...theme.typography.otherInput }}> <FormControl fullWidth error={Boolean(touched.remain_quota && errors.remain_quota)} sx={{ ...theme.typography.otherInput }}>
<InputLabel htmlFor="channel-remain_quota-label">额度</InputLabel> <InputLabel htmlFor="channel-remain_quota-label">额度</InputLabel>
<OutlinedInput <OutlinedInput
@ -195,13 +204,32 @@ const EditModal = ({ open, tokenId, onCancel, onOk }) => {
</FormHelperText> </FormHelperText>
)} )}
</FormControl> </FormControl>
<Switch <FormControl fullWidth>
checked={values.unlimited_quota === true} <FormControlLabel
onClick={() => { control={
setFieldValue('unlimited_quota', !values.unlimited_quota); <Switch
}} checked={values.unlimited_quota === true}
/>{' '} onClick={() => {
无限额度 setFieldValue('unlimited_quota', !values.unlimited_quota);
}}
/>
}
label="无限额度"
/>
{siteInfo.chat_cache_enabled && (
<FormControlLabel
control={
<Switch
checked={values.chat_cache}
onClick={() => {
setFieldValue('chat_cache', !values.chat_cache);
}}
/>
}
label="是否开启缓存(开启后,将会缓存聊天记录,以减少消费)"
/>
)}
</FormControl>
<DialogActions> <DialogActions>
<Button onClick={onCancel}>取消</Button> <Button onClick={onCancel}>取消</Button>
<Button disableElevation disabled={isSubmitting} type="submit" variant="contained" color="primary"> <Button disableElevation disabled={isSubmitting} type="submit" variant="contained" color="primary">