From 3c7c13758bccc49a83f34a1b88d41958f65e8240 Mon Sep 17 00:00:00 2001
From: Buer <42402987+MartialBE@users.noreply.github.com>
Date: Tue, 16 Apr 2024 10:36:18 +0800
Subject: [PATCH] =?UTF-8?q?=E2=9C=A8=20feat:=20add=20chat=20cache=20(#152)?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
---
common/constants.go | 4 +
common/utils.go | 14 ++
controller/misc.go | 1 +
controller/token.go | 2 +
cron/main.go | 37 +++++
go.mod | 5 +-
go.sum | 8 ++
main.go | 2 +
middleware/auth.go | 1 +
model/chat_cache.go | 41 ++++++
model/main.go | 4 +
model/option.go | 21 +--
model/token.go | 19 ++-
relay/base.go | 17 +++
relay/chat.go | 7 +-
relay/common.go | 62 ++++++---
relay/completions.go | 7 +-
relay/main.go | 33 +++++
relay/util/cache.go | 128 ++++++++++++++++++
relay/util/cache_db.go | 47 +++++++
relay/util/cache_redis.go | 44 ++++++
.../Setting/component/OperationSetting.js | 36 ++++-
web/src/views/Token/component/EditModal.js | 66 ++++++---
23 files changed, 557 insertions(+), 49 deletions(-)
create mode 100644 cron/main.go
create mode 100644 model/chat_cache.go
create mode 100644 relay/util/cache.go
create mode 100644 relay/util/cache_db.go
create mode 100644 relay/util/cache_redis.go
diff --git a/common/constants.go b/common/constants.go
index b44ca170..192303b2 100644
--- a/common/constants.go
+++ b/common/constants.go
@@ -37,6 +37,10 @@ var WeChatAuthEnabled = false
var TurnstileCheckEnabled = false
var RegisterEnabled = true
+// chat cache
+var ChatCacheEnabled = false
+var ChatCacheExpireMinute = 5 // 5 Minute
+
// mj
var MjNotifyEnabled = false
diff --git a/common/utils.go b/common/utils.go
index ac6bafe6..8b0bd555 100644
--- a/common/utils.go
+++ b/common/utils.go
@@ -1,6 +1,7 @@
package common
import (
+ "encoding/json"
"fmt"
"html/template"
"log"
@@ -248,3 +249,16 @@ func EscapeMarkdownText(text string) string {
}
return text
}
+
+func UnmarshalString[T interface{}](data string) (form T, err error) {
+ err = json.Unmarshal([]byte(data), &form)
+ return form, err
+}
+
+func Marshal[T any](data T) string {
+	// An encoding failure is reported to the caller as an empty string.
+	if encoded, err := json.Marshal(data); err == nil {
+		return string(encoded)
+	}
+	return ""
+}
diff --git a/controller/misc.go b/controller/misc.go
index 3eae703d..1ad29e63 100644
--- a/controller/misc.go
+++ b/controller/misc.go
@@ -42,6 +42,7 @@ func GetStatus(c *gin.Context) {
"display_in_currency": common.DisplayInCurrencyEnabled,
"telegram_bot": telegram_bot,
"mj_notify_enabled": common.MjNotifyEnabled,
+ "chat_cache_enabled": common.ChatCacheEnabled,
},
})
}
diff --git a/controller/token.go b/controller/token.go
index 608d630b..67a576be 100644
--- a/controller/token.go
+++ b/controller/token.go
@@ -104,6 +104,7 @@ func AddToken(c *gin.Context) {
ExpiredTime: token.ExpiredTime,
RemainQuota: token.RemainQuota,
UnlimitedQuota: token.UnlimitedQuota,
+ ChatCache: token.ChatCache,
}
err = cleanToken.Insert()
if err != nil {
@@ -187,6 +188,7 @@ func UpdateToken(c *gin.Context) {
cleanToken.ExpiredTime = token.ExpiredTime
cleanToken.RemainQuota = token.RemainQuota
cleanToken.UnlimitedQuota = token.UnlimitedQuota
+ cleanToken.ChatCache = token.ChatCache
}
err = cleanToken.Update()
if err != nil {
diff --git a/cron/main.go b/cron/main.go
new file mode 100644
index 00000000..8e98da0f
--- /dev/null
+++ b/cron/main.go
@@ -0,0 +1,37 @@
+package cron
+
+import (
+ "one-api/common"
+ "one-api/model"
+ "time"
+
+ "github.com/go-co-op/gocron/v2"
+)
+
+func InitCron() {
+	scheduler, err := gocron.NewScheduler()
+	if err != nil {
+		common.SysLog("Cron scheduler error: " + err.Error())
+		return
+	}
+
+	// Purge expired chat-cache rows from the database once a day at 00:05:00.
+	_, err = scheduler.NewJob(
+		gocron.DailyJob(1, gocron.NewAtTimes(gocron.NewAtTime(0, 5, 0))),
+		gocron.NewTask(func() {
+			if err := model.RemoveChatCache(time.Now().Unix()); err != nil {
+				common.SysLog("Cron remove chat cache error: " + err.Error())
+				return
+			}
+			common.SysLog("删除过期缓存数据")
+		}),
+	)
+
+	if err != nil {
+		common.SysLog("Cron job error: " + err.Error())
+		return
+	}
+
+	// Start is non-blocking; the job runs on the scheduler's own goroutines.
+	scheduler.Start()
+}
diff --git a/go.mod b/go.mod
index da655611..57d6cd1a 100644
--- a/go.mod
+++ b/go.mod
@@ -30,11 +30,14 @@ require (
filippo.io/edwards25519 v1.1.0 // indirect
github.com/chenzhuoyu/iasm v0.9.1 // indirect
github.com/fsnotify/fsnotify v1.7.0 // indirect
+	github.com/go-co-op/gocron/v2 v2.2.9
github.com/gomarkdown/markdown v0.0.0-20240328165702-4d01890c35c0 // indirect
github.com/hashicorp/hcl v1.0.0 // indirect
github.com/jackc/puddle/v2 v2.2.1 // indirect
+ github.com/jonboulle/clockwork v0.4.0 // indirect
github.com/magiconair/properties v1.8.7 // indirect
github.com/mitchellh/mapstructure v1.5.0 // indirect
+ github.com/robfig/cron/v3 v3.0.1 // indirect
github.com/sagikazarmark/locafero v0.4.0 // indirect
github.com/sagikazarmark/slog-shim v0.1.0 // indirect
github.com/sourcegraph/conc v0.3.0 // indirect
@@ -46,7 +49,7 @@ require (
github.com/wneessen/go-mail v0.4.1 // indirect
go.uber.org/atomic v1.9.0 // indirect
go.uber.org/multierr v1.9.0 // indirect
- golang.org/x/exp v0.0.0-20230905200255-921286631fa9 // indirect
+ golang.org/x/exp v0.0.0-20231219180239-dc181d75b848 // indirect
golang.org/x/sync v0.6.0 // indirect
gopkg.in/ini.v1 v1.67.0 // indirect
)
diff --git a/go.sum b/go.sum
index a1fa358e..bfa82581 100644
--- a/go.sum
+++ b/go.sum
@@ -71,6 +71,8 @@ github.com/gin-gonic/gin v1.8.1/go.mod h1:ji8BvRH1azfM+SYow9zQ6SZMvR8qOMZHmsCuWR
github.com/gin-gonic/gin v1.9.1 h1:4idEAncQnU5cB7BeOkPtxjfCSye0AAm1R0RVIqJ+Jmg=
github.com/gin-gonic/gin v1.9.1/go.mod h1:hPrL7YrpYKXt5YId3A/Tnip5kqbEAP+KLuI3SUcPTeU=
github.com/globalsign/mgo v0.0.0-20181015135952-eeefdecb41b8/go.mod h1:xkRDCp4j0OGD1HRkm4kmhM+pmpv3AKq5SU7GMg4oO/Q=
+github.com/go-co-op/gocron/v2 v2.2.9 h1:aoKosYWSSdXFLecjFWX1i8+R6V7XdZb8sB2ZKAY5Yis=
+github.com/go-co-op/gocron/v2 v2.2.9/go.mod h1:mZx3gMSlFnb97k3hRqX3+GdlG3+DUwTh6B8fnsTScXg=
github.com/go-playground/assert/v2 v2.0.1/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4=
github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s=
github.com/go-playground/locales v0.13.0/go.mod h1:taPMhCMXrRLJO55olJkUXHZBHCxTMfnGwq/HNwmWNS8=
@@ -152,6 +154,8 @@ github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ=
github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8=
github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0=
github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4=
+github.com/jonboulle/clockwork v0.4.0 h1:p4Cf1aMWXnXAUh8lVfewRBx1zaTSYKrKMF2g3ST4RZ4=
+github.com/jonboulle/clockwork v0.4.0/go.mod h1:xgRqUGwRcjKCO1vbZUEtSLrqKoPSsUpK7fnezOII0kc=
github.com/json-iterator/go v1.1.9/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4=
github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=
github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo=
@@ -219,6 +223,8 @@ github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZN
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U=
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/quasoft/memstore v0.0.0-20191010062613-2bce066d2b0b/go.mod h1:wTPjTepVu7uJBYgZ0SdWHQlIas582j6cn2jgk4DDdlg=
+github.com/robfig/cron/v3 v3.0.1 h1:WdRxkvbJztn8LMz/QEvLN5sBU+xKpSqwwUO1Pjr4qDs=
+github.com/robfig/cron/v3 v3.0.1/go.mod h1:eQICP3HwyT7UooqI/z+Ov+PtYAWygg1TEWWzGIFLtro=
github.com/rogpeppe/go-internal v1.6.1/go.mod h1:xXDCJY+GAPziupqXw64V24skbSoqbTEfhy4qGm1nDQc=
github.com/rogpeppe/go-internal v1.8.0 h1:FCbCCtXNOY3UtUuHUYaghJg4y7Fd14rXifAYUAtL9R8=
github.com/rogpeppe/go-internal v1.8.0/go.mod h1:WmiCO8CzOY8rg0OYDC4/i/2WRWAB6poM+XZ2dLUbcbE=
@@ -284,6 +290,8 @@ golang.org/x/crypto v0.21.0 h1:X31++rzVUdKhX5sWmSOFZxx8UW/ldWx55cbf08iNAMA=
golang.org/x/crypto v0.21.0/go.mod h1:0BP7YvVV9gBbVKyeTG0Gyn+gZm94bibOW5BjDEYAOMs=
golang.org/x/exp v0.0.0-20230905200255-921286631fa9 h1:GoHiUyI/Tp2nVkLI2mCxVkOjsbSXD66ic0XW0js0R9g=
golang.org/x/exp v0.0.0-20230905200255-921286631fa9/go.mod h1:S2oDrQGGwySpoQPVqRShND87VCbxmc6bL1Yd2oYrm6k=
+golang.org/x/exp v0.0.0-20231219180239-dc181d75b848 h1:+iq7lrkxmFNBM7xx+Rae2W6uyPfhPeDWD+n+JgppptE=
+golang.org/x/exp v0.0.0-20231219180239-dc181d75b848/go.mod h1:iRJReGqOEeBhDZGkGbynYwcHlctCvnjTYIamk7uXpHI=
golang.org/x/image v0.14.0 h1:tNgSxAFe3jC4uYqvZdTr84SZoM1KfwdC9SKIFrLjFn4=
golang.org/x/image v0.14.0/go.mod h1:HUYqC05R2ZcZ3ejNQsIHQDQiwWM4JBqmm6MKANTp4LE=
golang.org/x/image v0.15.0 h1:kOELfmgrmJlw4Cdb7g/QGuB3CvDrXbqEIww/pNtNBm8=
diff --git a/main.go b/main.go
index df5478df..ac44532e 100644
--- a/main.go
+++ b/main.go
@@ -9,6 +9,7 @@ import (
"one-api/common/requester"
"one-api/common/telegram"
"one-api/controller"
+ "one-api/cron"
"one-api/middleware"
"one-api/model"
"one-api/relay/util"
@@ -48,6 +49,7 @@ func main() {
controller.InitMidjourneyTask()
notify.InitNotifier()
+ cron.InitCron()
initHttpServer()
}
diff --git a/middleware/auth.go b/middleware/auth.go
index 91e3dd56..30ccc6ef 100644
--- a/middleware/auth.go
+++ b/middleware/auth.go
@@ -105,6 +105,7 @@ func tokenAuth(c *gin.Context, key string) {
c.Set("id", token.UserId)
c.Set("token_id", token.Id)
c.Set("token_name", token.Name)
+ c.Set("chat_cache", token.ChatCache)
if len(parts) > 1 {
if model.IsAdmin(token.UserId) {
channelId := common.String2Int(parts[1])
diff --git a/model/chat_cache.go b/model/chat_cache.go
new file mode 100644
index 00000000..ae33954f
--- /dev/null
+++ b/model/chat_cache.go
@@ -0,0 +1,41 @@
+package model
+
+import (
+ "time"
+
+ "gorm.io/gorm/clause"
+)
+
+type ChatCache struct {
+ Hash string `json:"hash" gorm:"type:varchar(32);primaryKey"`
+ UserId int `json:"user_id" gorm:"type:int;not null;index"`
+ Data string `json:"data" gorm:"type:json;not null"`
+ Expiration int64 `json:"expiration" gorm:"type:bigint;not null;index"`
+}
+
+func (cache *ChatCache) Insert() error {
+ return DB.Clauses(clause.OnConflict{
+ UpdateAll: true,
+ }).Create(cache).Error
+}
+
+func GetChatCache(hash string, userId int) (*ChatCache, error) {
+	var chatCache ChatCache
+	// Current Unix timestamp: only rows that have not expired yet are matched.
+	now := time.Now().Unix()
+	err := DB.Where("hash = ? and user_id = ? and expiration > ?", hash, userId, now).Find(&chatCache).Error
+	return &chatCache, err
+}
+
+func GetChatCacheListByUserId(userId int) ([]*ChatCache, error) {
+	var chatCaches []*ChatCache
+	// Current Unix timestamp: only the user's unexpired entries are listed.
+	now := time.Now().Unix()
+	err := DB.Where("user_id = ? and expiration > ?", userId, now).Find(&chatCaches).Error
+	return chatCaches, err
+}
+
+func RemoveChatCache(expiration int64) error {
+	// Delete every row whose expiration is earlier than the given Unix timestamp.
+	return DB.Where("expiration < ?", expiration).Delete(ChatCache{}).Error
+}
diff --git a/model/main.go b/model/main.go
index 5e682af2..72169a34 100644
--- a/model/main.go
+++ b/model/main.go
@@ -144,6 +144,10 @@ func InitDB() (err error) {
if err != nil {
return err
}
+ err = db.AutoMigrate(&ChatCache{})
+ if err != nil {
+ return err
+ }
common.SysLog("database migrated")
err = createRootAccountIfNeed()
return err
diff --git a/model/option.go b/model/option.go
index 94103bd8..fd5cb05f 100644
--- a/model/option.go
+++ b/model/option.go
@@ -76,6 +76,9 @@ func InitOptionMap() {
common.OptionMap["MjNotifyEnabled"] = strconv.FormatBool(common.MjNotifyEnabled)
+ common.OptionMap["ChatCacheEnabled"] = strconv.FormatBool(common.ChatCacheEnabled)
+ common.OptionMap["ChatCacheExpireMinute"] = strconv.Itoa(common.ChatCacheExpireMinute)
+
common.OptionMapRWMutex.Unlock()
loadOptionsFromDatabase()
}
@@ -115,14 +118,15 @@ func UpdateOption(key string, value string) error {
}
var optionIntMap = map[string]*int{
- "SMTPPort": &common.SMTPPort,
- "QuotaForNewUser": &common.QuotaForNewUser,
- "QuotaForInviter": &common.QuotaForInviter,
- "QuotaForInvitee": &common.QuotaForInvitee,
- "QuotaRemindThreshold": &common.QuotaRemindThreshold,
- "PreConsumedQuota": &common.PreConsumedQuota,
- "RetryTimes": &common.RetryTimes,
- "RetryCooldownSeconds": &common.RetryCooldownSeconds,
+ "SMTPPort": &common.SMTPPort,
+ "QuotaForNewUser": &common.QuotaForNewUser,
+ "QuotaForInviter": &common.QuotaForInviter,
+ "QuotaForInvitee": &common.QuotaForInvitee,
+ "QuotaRemindThreshold": &common.QuotaRemindThreshold,
+ "PreConsumedQuota": &common.PreConsumedQuota,
+ "RetryTimes": &common.RetryTimes,
+ "RetryCooldownSeconds": &common.RetryCooldownSeconds,
+ "ChatCacheExpireMinute": &common.ChatCacheExpireMinute,
}
var optionBoolMap = map[string]*bool{
@@ -141,6 +145,7 @@ var optionBoolMap = map[string]*bool{
"DisplayInCurrencyEnabled": &common.DisplayInCurrencyEnabled,
"DisplayTokenStatEnabled": &common.DisplayTokenStatEnabled,
"MjNotifyEnabled": &common.MjNotifyEnabled,
+ "ChatCacheEnabled": &common.ChatCacheEnabled,
}
var optionStringMap = map[string]*string{
diff --git a/model/token.go b/model/token.go
index 689ec49e..c07e0548 100644
--- a/model/token.go
+++ b/model/token.go
@@ -2,6 +2,7 @@ package model
import (
"errors"
+ "fmt"
"one-api/common"
"one-api/common/stmp"
@@ -20,6 +21,7 @@ type Token struct {
RemainQuota int `json:"remain_quota" gorm:"default:0"`
UnlimitedQuota bool `json:"unlimited_quota" gorm:"default:false"`
UsedQuota int `json:"used_quota" gorm:"default:0"` // used quota
+ ChatCache bool `json:"chat_cache" gorm:"default:false"`
}
var allowedTokenOrderFields = map[string]bool{
@@ -40,7 +42,7 @@ func GetUserTokensList(userId int, params *GenericParams) (*DataResult[Token], e
db = db.Where("name LIKE ?", params.Keyword+"%")
}
- return PaginateAndOrder[Token](db, ¶ms.PaginationParams, &tokens, allowedTokenOrderFields)
+ return PaginateAndOrder(db, ¶ms.PaginationParams, &tokens, allowedTokenOrderFields)
}
// 获取状态为可用的令牌
@@ -114,13 +116,26 @@ func GetTokenById(id int) (*Token, error) {
}
func (token *Token) Insert() error {
+ if token.ChatCache && !common.ChatCacheEnabled {
+ token.ChatCache = false
+ }
+
err := DB.Create(token).Error
return err
}
// Update Make sure your token's fields is completed, because this will update non-zero values
func (token *Token) Update() error {
- err := DB.Model(token).Select("name", "status", "expired_time", "remain_quota", "unlimited_quota").Updates(token).Error
+ if token.ChatCache && !common.ChatCacheEnabled {
+ token.ChatCache = false
+ }
+
+ err := DB.Model(token).Select("name", "status", "expired_time", "remain_quota", "unlimited_quota", "chat_cache").Updates(token).Error
+ // 防止Redis缓存不生效,直接删除
+ if err == nil && common.RedisEnabled {
+ common.RedisDel(fmt.Sprintf("token:%s", token.Key))
+ }
+
return err
}
diff --git a/relay/base.go b/relay/base.go
index 5891010c..72d42aae 100644
--- a/relay/base.go
+++ b/relay/base.go
@@ -1,6 +1,7 @@
package relay
import (
+ "one-api/relay/util"
"one-api/types"
providersBase "one-api/providers/base"
@@ -13,17 +14,33 @@ type relayBase struct {
provider providersBase.ProviderInterface
originalModel string
modelName string
+ cache *util.ChatCacheProps
}
type RelayBaseInterface interface {
send() (err *types.OpenAIErrorWithStatusCode, done bool)
getPromptTokens() (int, error)
setRequest() error
+ getRequest() any
setProvider(modelName string) error
getProvider() providersBase.ProviderInterface
getOriginalModel() string
getModelName() string
getContext() *gin.Context
+ SetChatCache(allow bool)
+ GetChatCache() *util.ChatCacheProps
+}
+
+func (r *relayBase) SetChatCache(allow bool) {
+ r.cache = util.NewChatCacheProps(r.c, allow)
+}
+
+func (r *relayBase) GetChatCache() *util.ChatCacheProps {
+ return r.cache
+}
+
+func (r *relayBase) getRequest() interface{} {
+ return nil
}
func (r *relayBase) setProvider(modelName string) error {
diff --git a/relay/chat.go b/relay/chat.go
index 33e2b469..c7ccd9ca 100644
--- a/relay/chat.go
+++ b/relay/chat.go
@@ -37,6 +37,10 @@ func (r *relayChat) setRequest() error {
return nil
}
+func (r *relayChat) getRequest() interface{} {
+ return &r.chatRequest
+}
+
func (r *relayChat) getPromptTokens() (int, error) {
return common.CountTokenMessages(r.chatRequest.Messages, r.modelName), nil
}
@@ -58,7 +62,7 @@ func (r *relayChat) send() (err *types.OpenAIErrorWithStatusCode, done bool) {
return
}
- err = responseStreamClient(r.c, response)
+ err = responseStreamClient(r.c, response, r.cache)
} else {
var response *types.ChatCompletionResponse
response, err = chatProvider.CreateChatCompletion(&r.chatRequest)
@@ -66,6 +70,7 @@ func (r *relayChat) send() (err *types.OpenAIErrorWithStatusCode, done bool) {
return
}
err = responseJsonClient(r.c, response)
+ r.cache.SetResponse(response)
}
if err != nil {
diff --git a/relay/common.go b/relay/common.go
index 7b2d9ffd..9289da8f 100644
--- a/relay/common.go
+++ b/relay/common.go
@@ -13,6 +13,7 @@ import (
"one-api/model"
"one-api/providers"
providersBase "one-api/providers/base"
+ "one-api/relay/util"
"one-api/types"
"strings"
@@ -20,29 +21,37 @@ import (
)
func Path2Relay(c *gin.Context, path string) RelayBaseInterface {
+ allowCache := false
+ var relay RelayBaseInterface
if strings.HasPrefix(path, "/v1/chat/completions") {
- return NewRelayChat(c)
+ allowCache = true
+ relay = NewRelayChat(c)
} else if strings.HasPrefix(path, "/v1/completions") {
- return NewRelayCompletions(c)
+ allowCache = true
+ relay = NewRelayCompletions(c)
} else if strings.HasPrefix(path, "/v1/embeddings") {
- return NewRelayEmbeddings(c)
+ relay = NewRelayEmbeddings(c)
} else if strings.HasPrefix(path, "/v1/moderations") {
- return NewRelayModerations(c)
+ relay = NewRelayModerations(c)
} else if strings.HasPrefix(path, "/v1/images/generations") {
- return NewRelayImageGenerations(c)
+ relay = NewRelayImageGenerations(c)
} else if strings.HasPrefix(path, "/v1/images/edits") {
- return NewRelayImageEdits(c)
+ relay = NewRelayImageEdits(c)
} else if strings.HasPrefix(path, "/v1/images/variations") {
- return NewRelayImageVariations(c)
+ relay = NewRelayImageVariations(c)
} else if strings.HasPrefix(path, "/v1/audio/speech") {
- return NewRelaySpeech(c)
+ relay = NewRelaySpeech(c)
} else if strings.HasPrefix(path, "/v1/audio/transcriptions") {
- return NewRelayTranscriptions(c)
+ relay = NewRelayTranscriptions(c)
} else if strings.HasPrefix(path, "/v1/audio/translations") {
- return NewRelayTranslations(c)
+ relay = NewRelayTranslations(c)
}
- return nil
+ if relay != nil {
+ relay.SetChatCache(allowCache)
+ }
+
+ return relay
}
func GetProvider(c *gin.Context, modeName string) (provider providersBase.ProviderInterface, newModelName string, fail error) {
@@ -120,7 +129,7 @@ func responseJsonClient(c *gin.Context, data interface{}) *types.OpenAIErrorWith
return nil
}
-func responseStreamClient(c *gin.Context, stream requester.StreamReaderInterface[string]) *types.OpenAIErrorWithStatusCode {
+func responseStreamClient(c *gin.Context, stream requester.StreamReaderInterface[string], cache *util.ChatCacheProps) (errWithOP *types.OpenAIErrorWithStatusCode) {
requester.SetEventStreamHeaders(c)
dataChan, errChan := stream.Recv()
@@ -128,19 +137,24 @@ func responseStreamClient(c *gin.Context, stream requester.StreamReaderInterface
c.Stream(func(w io.Writer) bool {
select {
case data := <-dataChan:
- fmt.Fprintln(w, "data: "+data+"\n")
+ streamData := "data: " + data + "\n\n"
+ fmt.Fprint(w, streamData)
+ cache.SetResponse(streamData)
return true
case err := <-errChan:
if !errors.Is(err, io.EOF) {
- fmt.Fprintln(w, "data: "+err.Error()+"\n")
+ fmt.Fprint(w, "data: "+err.Error()+"\n\n")
+ errWithOP = common.ErrorWrapper(err, "stream_error", http.StatusInternalServerError)
}
- fmt.Fprintln(w, "data: [DONE]")
+ streamData := "data: [DONE]\n"
+ fmt.Fprint(w, streamData)
+ cache.SetResponse(streamData)
return false
}
})
- return nil
+ return errWithOP
}
func responseMultipart(c *gin.Context, resp *http.Response) *types.OpenAIErrorWithStatusCode {
@@ -174,6 +188,22 @@ func responseCustom(c *gin.Context, response *types.AudioResponseWrapper) *types
return nil
}
+func responseCache(c *gin.Context, response string) {
+	// Cached stream responses are stored as raw SSE text and therefore begin
+	// with "data: "; anything else is replayed as a JSON body.
+	if !strings.HasPrefix(response, "data: ") {
+		c.Data(http.StatusOK, "application/json", []byte(response))
+		return
+	}
+
+	requester.SetEventStreamHeaders(c)
+	// Emit the whole cached payload in one step; returning false ends the stream.
+	c.Stream(func(w io.Writer) bool {
+		fmt.Fprint(w, response)
+		return false
+	})
+}
+
func shouldRetry(c *gin.Context, statusCode int) bool {
channelId := c.GetInt("specific_channel_id")
if channelId > 0 {
diff --git a/relay/completions.go b/relay/completions.go
index fdfbd03e..4abeecb1 100644
--- a/relay/completions.go
+++ b/relay/completions.go
@@ -37,6 +37,10 @@ func (r *relayCompletions) setRequest() error {
return nil
}
+func (r *relayCompletions) getRequest() interface{} {
+ return &r.request
+}
+
func (r *relayCompletions) getPromptTokens() (int, error) {
return common.CountTokenInput(r.request.Prompt, r.modelName), nil
}
@@ -58,7 +62,7 @@ func (r *relayCompletions) send() (err *types.OpenAIErrorWithStatusCode, done bo
return
}
- err = responseStreamClient(r.c, response)
+ err = responseStreamClient(r.c, response, r.cache)
} else {
var response *types.CompletionResponse
response, err = provider.CreateCompletion(&r.request)
@@ -66,6 +70,7 @@ func (r *relayCompletions) send() (err *types.OpenAIErrorWithStatusCode, done bo
return
}
err = responseJsonClient(r.c, response)
+ r.cache.SetResponse(response)
}
if err != nil {
diff --git a/relay/main.go b/relay/main.go
index 5da7a340..3630ba2b 100644
--- a/relay/main.go
+++ b/relay/main.go
@@ -7,6 +7,7 @@ import (
"one-api/model"
"one-api/relay/util"
"one-api/types"
+ "time"
"github.com/gin-gonic/gin"
)
@@ -23,6 +24,18 @@ func Relay(c *gin.Context) {
return
}
+ cacheProps := relay.GetChatCache()
+ cacheProps.SetHash(relay.getRequest())
+
+ // 获取缓存
+ cache := cacheProps.GetCache()
+
+ if cache != nil {
+ // 说明有缓存, 直接返回缓存内容
+ cacheProcessing(c, cache)
+ return
+ }
+
if err := relay.setProvider(relay.getOriginalModel()); err != nil {
common.AbortWithMessage(c, http.StatusServiceUnavailable, err.Error())
return
@@ -103,5 +116,25 @@ func RelayHandler(relay RelayBaseInterface) (err *types.OpenAIErrorWithStatusCod
}
quota.Consume(relay.getContext(), usage)
+ cacheProps := relay.GetChatCache()
+ go cacheProps.StoreCache(relay.getContext().GetInt("channel_id"), usage.PromptTokens, usage.CompletionTokens, relay.getModelName())
return
}
+
+func cacheProcessing(c *gin.Context, cacheProps *util.ChatCacheProps) {
+	responseCache(c, cacheProps.Response)
+
+	// Write the log entry for this cached response.
+	tokenName := c.GetString("token_name")
+	// Elapsed milliseconds since the request started; stays 0 when the context carries no start time.
+	requestTime := 0
+	requestStartTimeValue := c.Request.Context().Value("requestStartTime")
+	if requestStartTimeValue != nil {
+		requestStartTime, ok := requestStartTimeValue.(time.Time)
+		if ok {
+			requestTime = int(time.Since(requestStartTime).Milliseconds())
+		}
+	}
+
+	model.RecordConsumeLog(c.Request.Context(), cacheProps.UserId, cacheProps.ChannelID, cacheProps.PromptTokens, cacheProps.CompletionTokens, cacheProps.ModelName, tokenName, 0, "缓存", requestTime)
+}
diff --git a/relay/util/cache.go b/relay/util/cache.go
new file mode 100644
index 00000000..f8f25c11
--- /dev/null
+++ b/relay/util/cache.go
@@ -0,0 +1,128 @@
+package util
+
+import (
+ "crypto/md5"
+ "encoding/hex"
+ "fmt"
+ "one-api/common"
+ "one-api/model"
+
+ "github.com/gin-gonic/gin"
+)
+
+type ChatCacheProps struct {
+ UserId int `json:"user_id"`
+ TokenId int `json:"token_id"`
+ ChannelID int `json:"channel_id"`
+ PromptTokens int `json:"prompt_tokens"`
+ CompletionTokens int `json:"completion_tokens"`
+ ModelName string `json:"model_name"`
+ Response string `json:"response"`
+
+ Hash string `json:"-"`
+ Cache bool `json:"-"`
+ Driver CacheDriver `json:"-"`
+}
+
+type CacheDriver interface {
+ Get(hash string, userId int) *ChatCacheProps
+ Set(hash string, props *ChatCacheProps, expire int64) error
+}
+
+func GetDebugList(userId int) ([]*ChatCacheProps, error) {
+ caches, err := model.GetChatCacheListByUserId(userId)
+ if err != nil {
+ return nil, err
+ }
+
+ var props []*ChatCacheProps
+ for _, cache := range caches {
+ prop, err := common.UnmarshalString[ChatCacheProps](cache.Data)
+ if err != nil {
+ continue
+ }
+ props = append(props, &prop)
+ }
+
+ return props, nil
+}
+
+func NewChatCacheProps(c *gin.Context, allow bool) *ChatCacheProps {
+ props := &ChatCacheProps{
+ Cache: false,
+ }
+
+ if !allow {
+ return props
+ }
+
+ if common.ChatCacheEnabled && c.GetBool("chat_cache") {
+ props.Cache = true
+ }
+
+ if common.RedisEnabled {
+ props.Driver = &ChatCacheRedis{}
+ } else {
+ props.Driver = &ChatCacheDB{}
+ }
+
+ props.UserId = c.GetInt("id")
+ props.TokenId = c.GetInt("token_id")
+
+ return props
+}
+
+func (p *ChatCacheProps) SetHash(request any) {
+ if !p.needCache() || request == nil {
+ return
+ }
+
+ p.hash(common.Marshal(request))
+}
+
+func (p *ChatCacheProps) SetResponse(response any) {
+ if !p.needCache() || response == nil {
+ return
+ }
+
+ if str, ok := response.(string); ok {
+ p.Response += str
+ return
+ }
+
+ p.Response = common.Marshal(response)
+}
+
+func (p *ChatCacheProps) StoreCache(channelId, promptTokens, completionTokens int, modelName string) error {
+ if !p.needCache() || p.Response == "" {
+ return nil
+ }
+
+ p.ChannelID = channelId
+ p.PromptTokens = promptTokens
+ p.CompletionTokens = completionTokens
+ p.ModelName = modelName
+
+ return p.Driver.Set(p.getHash(), p, int64(common.ChatCacheExpireMinute))
+}
+
+func (p *ChatCacheProps) GetCache() *ChatCacheProps {
+ if !p.needCache() {
+ return nil
+ }
+
+ return p.Driver.Get(p.getHash(), p.UserId)
+}
+
+func (p *ChatCacheProps) needCache() bool {
+ return common.ChatCacheEnabled && p.Cache
+}
+
+func (p *ChatCacheProps) getHash() string {
+ return p.Hash
+}
+
+func (p *ChatCacheProps) hash(request string) {
+ hash := md5.Sum([]byte(fmt.Sprintf("%d-%d-%s", p.UserId, p.TokenId, request)))
+ p.Hash = hex.EncodeToString(hash[:])
+}
diff --git a/relay/util/cache_db.go b/relay/util/cache_db.go
new file mode 100644
index 00000000..22466449
--- /dev/null
+++ b/relay/util/cache_db.go
@@ -0,0 +1,47 @@
+package util
+
+import (
+ "errors"
+ "one-api/common"
+ "one-api/model"
+ "time"
+)
+
+type ChatCacheDB struct{}
+
+func (db *ChatCacheDB) Get(hash string, userId int) *ChatCacheProps {
+	cache, err := model.GetChatCache(hash, userId)
+	// A query error, or a row without a payload, is treated as a cache miss.
+	if err != nil || cache == nil || cache.Data == "" {
+		return nil
+	}
+
+	props, err := common.UnmarshalString[ChatCacheProps](cache.Data)
+	if err != nil {
+		return nil
+	}
+	return &props
+}
+
+func (db *ChatCacheDB) Set(hash string, props *ChatCacheProps, expire int64) error {
+ return SetCacheDB(hash, props, expire)
+}
+
+func SetCacheDB(hash string, props *ChatCacheProps, expire int64) error {
+ data := common.Marshal(props)
+ if data == "" {
+ return errors.New("marshal error")
+ }
+
+ expire = expire * 60
+ expire += time.Now().Unix()
+
+ cache := &model.ChatCache{
+ Hash: hash,
+ UserId: props.UserId,
+ Data: data,
+ Expiration: expire,
+ }
+
+ return cache.Insert()
+}
diff --git a/relay/util/cache_redis.go b/relay/util/cache_redis.go
new file mode 100644
index 00000000..a9eaa81f
--- /dev/null
+++ b/relay/util/cache_redis.go
@@ -0,0 +1,44 @@
+package util
+
+import (
+ "errors"
+ "fmt"
+ "one-api/common"
+ "time"
+)
+
+type ChatCacheRedis struct{}
+
+var chatCacheKey = "chat_cache"
+
+func (r *ChatCacheRedis) Get(hash string, userId int) *ChatCacheProps {
+ cache, err := common.RedisGet(r.getKey(hash, userId))
+ if err != nil {
+ return nil
+ }
+
+ props, err := common.UnmarshalString[ChatCacheProps](cache)
+ if err != nil {
+ return nil
+ }
+
+ return &props
+}
+
+func (r *ChatCacheRedis) Set(hash string, props *ChatCacheProps, expire int64) error {
+	// Nothing is written unless props.Cache is set.
+	if !props.Cache {
+		return nil
+	}
+
+	data := common.Marshal(props)
+	if data == "" {
+		return errors.New("marshal error")
+	}
+	// expire is a number of minutes.
+	return common.RedisSet(r.getKey(hash, props.UserId), data, time.Duration(expire)*time.Minute)
+}
+
+func (r *ChatCacheRedis) getKey(hash string, userId int) string {
+ return fmt.Sprintf("%s:%d:%s", chatCacheKey, userId, hash)
+}
diff --git a/web/src/views/Setting/component/OperationSetting.js b/web/src/views/Setting/component/OperationSetting.js
index 6712e6ca..43b3e32d 100644
--- a/web/src/views/Setting/component/OperationSetting.js
+++ b/web/src/views/Setting/component/OperationSetting.js
@@ -30,7 +30,9 @@ const OperationSetting = () => {
ApproximateTokenEnabled: '',
RetryTimes: 0,
RetryCooldownSeconds: 0,
- MjNotifyEnabled: ''
+ MjNotifyEnabled: '',
+ ChatCacheEnabled: '',
+ ChatCacheExpireMinute: 5
});
const [originInputs, setOriginInputs] = useState({});
let [loading, setLoading] = useState(false);
@@ -152,6 +154,11 @@ const OperationSetting = () => {
await updateOption('RetryCooldownSeconds', inputs.RetryCooldownSeconds);
}
break;
+ case 'other':
+ if (originInputs['ChatCacheExpireMinute'] !== inputs.ChatCacheExpireMinute) {
+ await updateOption('ChatCacheExpireMinute', inputs.ChatCacheExpireMinute);
+ }
+ break;
}
showSuccess('保存成功!');
@@ -292,7 +299,34 @@ const OperationSetting = () => {
label="Midjourney 允许回调(会泄露服务器ip地址)"
control={}
/>
+ }
+ />
+
+
+ 缓存时间(分钟)
+
+
+
+
diff --git a/web/src/views/Token/component/EditModal.js b/web/src/views/Token/component/EditModal.js
index 1a70ae2a..5ad5a96a 100644
--- a/web/src/views/Token/component/EditModal.js
+++ b/web/src/views/Token/component/EditModal.js
@@ -17,6 +17,7 @@ import {
OutlinedInput,
InputAdornment,
Switch,
+ FormControlLabel,
FormHelperText
} from '@mui/material';
@@ -25,6 +26,7 @@ import { LocalizationProvider } from '@mui/x-date-pickers/LocalizationProvider';
import { DateTimePicker } from '@mui/x-date-pickers/DateTimePicker';
import { renderQuotaWithPrompt, showSuccess, showError } from 'utils/common';
import { API } from 'utils/api';
+import { useSelector } from 'react-redux';
require('dayjs/locale/zh-cn');
const validationSchema = Yup.object().shape({
@@ -40,12 +42,14 @@ const originInputs = {
name: '',
remain_quota: 0,
expired_time: -1,
- unlimited_quota: false
+ unlimited_quota: false,
+ chat_cache: false
};
const EditModal = ({ open, tokenId, onCancel, onOk }) => {
const theme = useTheme();
const [inputs, setInputs] = useState(originInputs);
+ const siteInfo = useSelector((state) => state.siteInfo);
const submit = async (values, { setErrors, setStatus, setSubmitting }) => {
setSubmitting(true);
@@ -163,17 +167,22 @@ const EditModal = ({ open, tokenId, onCancel, onOk }) => {
)}
)}
- {
- if (values.expired_time === -1) {
- setFieldValue('expired_time', Math.floor(Date.now() / 1000));
- } else {
- setFieldValue('expired_time', -1);
- }
- }}
- />{' '}
- 永不过期
+ {
+ if (values.expired_time === -1) {
+ setFieldValue('expired_time', Math.floor(Date.now() / 1000));
+ } else {
+ setFieldValue('expired_time', -1);
+ }
+ }}
+ />
+ }
+ label="永不过期"
+ />
+
额度
{
)}
- {
- setFieldValue('unlimited_quota', !values.unlimited_quota);
- }}
- />{' '}
- 无限额度
+
+ {
+ setFieldValue('unlimited_quota', !values.unlimited_quota);
+ }}
+ />
+ }
+ label="无限额度"
+ />
+ {siteInfo.chat_cache_enabled && (
+ {
+ setFieldValue('chat_cache', !values.chat_cache);
+ }}
+ />
+ }
+ label="是否开启缓存(开启后,将会缓存聊天记录,以减少消费)"
+ />
+ )}
+