feat: add chat cache (#152)

This commit is contained in:
Buer 2024-04-16 10:36:18 +08:00 committed by GitHub
parent bbaa4eec4b
commit 3c7c13758b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
23 changed files with 557 additions and 49 deletions

View File

@ -37,6 +37,10 @@ var WeChatAuthEnabled = false
var TurnstileCheckEnabled = false
var RegisterEnabled = true
// Chat cache: global switch for response caching and the lifetime of a cached entry.
var ChatCacheEnabled = false
var ChatCacheExpireMinute = 5 // cache entry lifetime in minutes (default 5)
// mj
var MjNotifyEnabled = false

View File

@ -1,6 +1,7 @@
package common
import (
"encoding/json"
"fmt"
"html/template"
"log"
@ -248,3 +249,16 @@ func EscapeMarkdownText(text string) string {
}
return text
}
// UnmarshalString decodes the JSON document held in data into a value of type T.
// The decoded value is returned together with any error reported by json.Unmarshal.
func UnmarshalString[T any](data string) (form T, err error) {
	if decodeErr := json.Unmarshal([]byte(data), &form); decodeErr != nil {
		return form, decodeErr
	}
	return form, nil
}
// Marshal encodes data as JSON and returns the result as a string.
// It returns the empty string when json.Marshal reports an error.
func Marshal[T any](data T) string {
	encoded, err := json.Marshal(data)
	if err == nil {
		return string(encoded)
	}
	return ""
}

View File

@ -42,6 +42,7 @@ func GetStatus(c *gin.Context) {
"display_in_currency": common.DisplayInCurrencyEnabled,
"telegram_bot": telegram_bot,
"mj_notify_enabled": common.MjNotifyEnabled,
"chat_cache_enabled": common.ChatCacheEnabled,
},
})
}

View File

@ -104,6 +104,7 @@ func AddToken(c *gin.Context) {
ExpiredTime: token.ExpiredTime,
RemainQuota: token.RemainQuota,
UnlimitedQuota: token.UnlimitedQuota,
ChatCache: token.ChatCache,
}
err = cleanToken.Insert()
if err != nil {
@ -187,6 +188,7 @@ func UpdateToken(c *gin.Context) {
cleanToken.ExpiredTime = token.ExpiredTime
cleanToken.RemainQuota = token.RemainQuota
cleanToken.UnlimitedQuota = token.UnlimitedQuota
cleanToken.ChatCache = token.ChatCache
}
err = cleanToken.Update()
if err != nil {

37
cron/main.go Normal file
View File

@ -0,0 +1,37 @@
package cron
import (
"one-api/common"
"one-api/model"
"time"
"github.com/go-co-op/gocron/v2"
)
// InitCron creates the background scheduler, registers the daily job that
// purges expired chat-cache rows, and starts the scheduler.
//
// Failures to create the scheduler or to register the job are logged and the
// function returns without starting anything; no error is returned to the caller.
func InitCron() {
	scheduler, err := gocron.NewScheduler()
	if err != nil {
		common.SysLog("Cron scheduler error: " + err.Error())
		return
	}

	// Register the cache clean-up task (daily, at the time given to NewAtTime).
	_, err = scheduler.NewJob(
		gocron.DailyJob(
			1,
			gocron.NewAtTimes(
				gocron.NewAtTime(0, 5, 0),
			)),
		gocron.NewTask(func() {
			// Report a failed purge instead of logging success unconditionally.
			if removeErr := model.RemoveChatCache(time.Now().Unix()); removeErr != nil {
				common.SysLog("Cron remove chat cache error: " + removeErr.Error())
				return
			}
			common.SysLog("删除过期缓存数据")
		}),
	)
	if err != nil {
		common.SysLog("Cron job error: " + err.Error())
		return
	}

	scheduler.Start()
}

5
go.mod
View File

@ -30,11 +30,14 @@ require (
filippo.io/edwards25519 v1.1.0 // indirect
github.com/chenzhuoyu/iasm v0.9.1 // indirect
github.com/fsnotify/fsnotify v1.7.0 // indirect
github.com/go-co-op/gocron/v2 v2.2.9 // indirect
github.com/gomarkdown/markdown v0.0.0-20240328165702-4d01890c35c0 // indirect
github.com/hashicorp/hcl v1.0.0 // indirect
github.com/jackc/puddle/v2 v2.2.1 // indirect
github.com/jonboulle/clockwork v0.4.0 // indirect
github.com/magiconair/properties v1.8.7 // indirect
github.com/mitchellh/mapstructure v1.5.0 // indirect
github.com/robfig/cron/v3 v3.0.1 // indirect
github.com/sagikazarmark/locafero v0.4.0 // indirect
github.com/sagikazarmark/slog-shim v0.1.0 // indirect
github.com/sourcegraph/conc v0.3.0 // indirect
@ -46,7 +49,7 @@ require (
github.com/wneessen/go-mail v0.4.1 // indirect
go.uber.org/atomic v1.9.0 // indirect
go.uber.org/multierr v1.9.0 // indirect
golang.org/x/exp v0.0.0-20230905200255-921286631fa9 // indirect
golang.org/x/exp v0.0.0-20231219180239-dc181d75b848 // indirect
golang.org/x/sync v0.6.0 // indirect
gopkg.in/ini.v1 v1.67.0 // indirect
)

8
go.sum
View File

@ -71,6 +71,8 @@ github.com/gin-gonic/gin v1.8.1/go.mod h1:ji8BvRH1azfM+SYow9zQ6SZMvR8qOMZHmsCuWR
github.com/gin-gonic/gin v1.9.1 h1:4idEAncQnU5cB7BeOkPtxjfCSye0AAm1R0RVIqJ+Jmg=
github.com/gin-gonic/gin v1.9.1/go.mod h1:hPrL7YrpYKXt5YId3A/Tnip5kqbEAP+KLuI3SUcPTeU=
github.com/globalsign/mgo v0.0.0-20181015135952-eeefdecb41b8/go.mod h1:xkRDCp4j0OGD1HRkm4kmhM+pmpv3AKq5SU7GMg4oO/Q=
github.com/go-co-op/gocron/v2 v2.2.9 h1:aoKosYWSSdXFLecjFWX1i8+R6V7XdZb8sB2ZKAY5Yis=
github.com/go-co-op/gocron/v2 v2.2.9/go.mod h1:mZx3gMSlFnb97k3hRqX3+GdlG3+DUwTh6B8fnsTScXg=
github.com/go-playground/assert/v2 v2.0.1/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4=
github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s=
github.com/go-playground/locales v0.13.0/go.mod h1:taPMhCMXrRLJO55olJkUXHZBHCxTMfnGwq/HNwmWNS8=
@ -152,6 +154,8 @@ github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ=
github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8=
github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0=
github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4=
github.com/jonboulle/clockwork v0.4.0 h1:p4Cf1aMWXnXAUh8lVfewRBx1zaTSYKrKMF2g3ST4RZ4=
github.com/jonboulle/clockwork v0.4.0/go.mod h1:xgRqUGwRcjKCO1vbZUEtSLrqKoPSsUpK7fnezOII0kc=
github.com/json-iterator/go v1.1.9/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4=
github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=
github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo=
@ -219,6 +223,8 @@ github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZN
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U=
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/quasoft/memstore v0.0.0-20191010062613-2bce066d2b0b/go.mod h1:wTPjTepVu7uJBYgZ0SdWHQlIas582j6cn2jgk4DDdlg=
github.com/robfig/cron/v3 v3.0.1 h1:WdRxkvbJztn8LMz/QEvLN5sBU+xKpSqwwUO1Pjr4qDs=
github.com/robfig/cron/v3 v3.0.1/go.mod h1:eQICP3HwyT7UooqI/z+Ov+PtYAWygg1TEWWzGIFLtro=
github.com/rogpeppe/go-internal v1.6.1/go.mod h1:xXDCJY+GAPziupqXw64V24skbSoqbTEfhy4qGm1nDQc=
github.com/rogpeppe/go-internal v1.8.0 h1:FCbCCtXNOY3UtUuHUYaghJg4y7Fd14rXifAYUAtL9R8=
github.com/rogpeppe/go-internal v1.8.0/go.mod h1:WmiCO8CzOY8rg0OYDC4/i/2WRWAB6poM+XZ2dLUbcbE=
@ -284,6 +290,8 @@ golang.org/x/crypto v0.21.0 h1:X31++rzVUdKhX5sWmSOFZxx8UW/ldWx55cbf08iNAMA=
golang.org/x/crypto v0.21.0/go.mod h1:0BP7YvVV9gBbVKyeTG0Gyn+gZm94bibOW5BjDEYAOMs=
golang.org/x/exp v0.0.0-20230905200255-921286631fa9 h1:GoHiUyI/Tp2nVkLI2mCxVkOjsbSXD66ic0XW0js0R9g=
golang.org/x/exp v0.0.0-20230905200255-921286631fa9/go.mod h1:S2oDrQGGwySpoQPVqRShND87VCbxmc6bL1Yd2oYrm6k=
golang.org/x/exp v0.0.0-20231219180239-dc181d75b848 h1:+iq7lrkxmFNBM7xx+Rae2W6uyPfhPeDWD+n+JgppptE=
golang.org/x/exp v0.0.0-20231219180239-dc181d75b848/go.mod h1:iRJReGqOEeBhDZGkGbynYwcHlctCvnjTYIamk7uXpHI=
golang.org/x/image v0.14.0 h1:tNgSxAFe3jC4uYqvZdTr84SZoM1KfwdC9SKIFrLjFn4=
golang.org/x/image v0.14.0/go.mod h1:HUYqC05R2ZcZ3ejNQsIHQDQiwWM4JBqmm6MKANTp4LE=
golang.org/x/image v0.15.0 h1:kOELfmgrmJlw4Cdb7g/QGuB3CvDrXbqEIww/pNtNBm8=

View File

@ -9,6 +9,7 @@ import (
"one-api/common/requester"
"one-api/common/telegram"
"one-api/controller"
"one-api/cron"
"one-api/middleware"
"one-api/model"
"one-api/relay/util"
@ -48,6 +49,7 @@ func main() {
controller.InitMidjourneyTask()
notify.InitNotifier()
cron.InitCron()
initHttpServer()
}

View File

@ -105,6 +105,7 @@ func tokenAuth(c *gin.Context, key string) {
c.Set("id", token.UserId)
c.Set("token_id", token.Id)
c.Set("token_name", token.Name)
c.Set("chat_cache", token.ChatCache)
if len(parts) > 1 {
if model.IsAdmin(token.UserId) {
channelId := common.String2Int(parts[1])

41
model/chat_cache.go Normal file
View File

@ -0,0 +1,41 @@
package model
import (
"time"
"gorm.io/gorm/clause"
)
// ChatCache is the database row holding one cached response.
type ChatCache struct {
// Hash is the primary key of the entry (stored as varchar(32)).
Hash string `json:"hash" gorm:"type:varchar(32);primaryKey"`
// UserId is the owner of the entry; lookups filter on it.
UserId int `json:"user_id" gorm:"type:int;not null;index"`
// Data is the serialized cache payload (JSON column).
Data string `json:"data" gorm:"type:json;not null"`
// Expiration is compared against time.Now().Unix() by the queries in this file,
// i.e. it is a Unix timestamp in seconds after which the entry is stale.
Expiration int64 `json:"expiration" gorm:"type:bigint;not null;index"`
}
// Insert creates the row using an ON CONFLICT clause with UpdateAll set, so an
// existing row with the same key is overwritten rather than rejected.
func (cache *ChatCache) Insert() error {
	upsert := clause.OnConflict{UpdateAll: true}
	result := DB.Clauses(upsert).Create(cache)
	return result.Error
}
// GetChatCache looks up the unexpired cache row for the given hash and user.
//
// The returned pointer is never nil. NOTE(review): the query uses Find, so a
// miss presumably yields a zero-value row with a nil error — confirm against
// the GORM documentation before relying on the error to detect misses.
func GetChatCache(hash string, userId int) (*ChatCache, error) {
	// Only rows whose expiration lies after the current Unix timestamp qualify.
	currentUnix := time.Now().Unix()

	row := ChatCache{}
	result := DB.Where("hash = ? and user_id = ? and expiration > ?", hash, userId, currentUnix).Find(&row)
	return &row, result.Error
}
// GetChatCacheListByUserId returns every unexpired cache row owned by userId.
func GetChatCacheListByUserId(userId int) ([]*ChatCache, error) {
	var chatCaches []*ChatCache

	// Current Unix timestamp; rows expiring at or before it are excluded.
	now := time.Now().Unix()

	// Each bound value needs its own placeholder, matching the query in GetChatCache.
	err := DB.Where("user_id = ? and expiration > ?", userId, now).Find(&chatCaches).Error
	return chatCaches, err
}
func RemoveChatCache(expiration int64) error {
now := time.Now().Unix()
return DB.Where("expiration < ?", now).Delete(ChatCache{}).Error
}

View File

@ -144,6 +144,10 @@ func InitDB() (err error) {
if err != nil {
return err
}
err = db.AutoMigrate(&ChatCache{})
if err != nil {
return err
}
common.SysLog("database migrated")
err = createRootAccountIfNeed()
return err

View File

@ -76,6 +76,9 @@ func InitOptionMap() {
common.OptionMap["MjNotifyEnabled"] = strconv.FormatBool(common.MjNotifyEnabled)
common.OptionMap["ChatCacheEnabled"] = strconv.FormatBool(common.ChatCacheEnabled)
common.OptionMap["ChatCacheExpireMinute"] = strconv.Itoa(common.ChatCacheExpireMinute)
common.OptionMapRWMutex.Unlock()
loadOptionsFromDatabase()
}
@ -115,14 +118,15 @@ func UpdateOption(key string, value string) error {
}
// optionIntMap maps the name of each integer option to the package-level
// variable in common that holds its value. Each key appears exactly once;
// duplicate keys in a map literal do not compile.
var optionIntMap = map[string]*int{
	"SMTPPort":              &common.SMTPPort,
	"QuotaForNewUser":       &common.QuotaForNewUser,
	"QuotaForInviter":       &common.QuotaForInviter,
	"QuotaForInvitee":       &common.QuotaForInvitee,
	"QuotaRemindThreshold":  &common.QuotaRemindThreshold,
	"PreConsumedQuota":      &common.PreConsumedQuota,
	"RetryTimes":            &common.RetryTimes,
	"RetryCooldownSeconds":  &common.RetryCooldownSeconds,
	"ChatCacheExpireMinute": &common.ChatCacheExpireMinute,
}
var optionBoolMap = map[string]*bool{
@ -141,6 +145,7 @@ var optionBoolMap = map[string]*bool{
"DisplayInCurrencyEnabled": &common.DisplayInCurrencyEnabled,
"DisplayTokenStatEnabled": &common.DisplayTokenStatEnabled,
"MjNotifyEnabled": &common.MjNotifyEnabled,
"ChatCacheEnabled": &common.ChatCacheEnabled,
}
var optionStringMap = map[string]*string{

View File

@ -2,6 +2,7 @@ package model
import (
"errors"
"fmt"
"one-api/common"
"one-api/common/stmp"
@ -20,6 +21,7 @@ type Token struct {
RemainQuota int `json:"remain_quota" gorm:"default:0"`
UnlimitedQuota bool `json:"unlimited_quota" gorm:"default:false"`
UsedQuota int `json:"used_quota" gorm:"default:0"` // used quota
ChatCache bool `json:"chat_cache" gorm:"default:false"`
}
var allowedTokenOrderFields = map[string]bool{
@ -40,7 +42,7 @@ func GetUserTokensList(userId int, params *GenericParams) (*DataResult[Token], e
db = db.Where("name LIKE ?", params.Keyword+"%")
}
return PaginateAndOrder[Token](db, &params.PaginationParams, &tokens, allowedTokenOrderFields)
return PaginateAndOrder(db, &params.PaginationParams, &tokens, allowedTokenOrderFields)
}
// 获取状态为可用的令牌
@ -114,13 +116,26 @@ func GetTokenById(id int) (*Token, error) {
}
// Insert stores the token as a new row. The per-token chat-cache flag is kept
// only while the feature is enabled globally; otherwise it is forced to false.
func (token *Token) Insert() error {
	token.ChatCache = token.ChatCache && common.ChatCacheEnabled
	return DB.Create(token).Error
}
// Update writes the editable columns of the token (name, status, expired_time,
// remain_quota, unlimited_quota, chat_cache) back to the database.
// Make sure the token's fields are fully populated before calling it.
//
// The chat-cache flag is forced to false while the feature is disabled
// globally. After a successful write the token's Redis entry is deleted.
func (token *Token) Update() error {
	// Normalise the flag before the single UPDATE so the stored value is consistent.
	if token.ChatCache && !common.ChatCacheEnabled {
		token.ChatCache = false
	}

	err := DB.Model(token).
		Select("name", "status", "expired_time", "remain_quota", "unlimited_quota", "chat_cache").
		Updates(token).Error

	// Delete the Redis entry directly so a stale cached token cannot linger.
	if err == nil && common.RedisEnabled {
		common.RedisDel(fmt.Sprintf("token:%s", token.Key))
	}

	return err
}

View File

@ -1,6 +1,7 @@
package relay
import (
"one-api/relay/util"
"one-api/types"
providersBase "one-api/providers/base"
@ -13,17 +14,33 @@ type relayBase struct {
provider providersBase.ProviderInterface
originalModel string
modelName string
cache *util.ChatCacheProps
}
// RelayBaseInterface is the contract shared by the per-endpoint relay types.
type RelayBaseInterface interface {
send() (err *types.OpenAIErrorWithStatusCode, done bool)
getPromptTokens() (int, error)
setRequest() error
// getRequest returns the request payload; Relay passes it to SetHash to derive the cache key.
getRequest() any
setProvider(modelName string) error
getProvider() providersBase.ProviderInterface
getOriginalModel() string
getModelName() string
getContext() *gin.Context
// SetChatCache creates the relay's chat-cache state; allow says whether the endpoint may be cached.
SetChatCache(allow bool)
// GetChatCache returns the chat-cache state created by SetChatCache.
GetChatCache() *util.ChatCacheProps
}
// SetChatCache builds the chat-cache state for this relay from its gin context.
// allow reports whether the endpoint is eligible for caching at all.
func (r *relayBase) SetChatCache(allow bool) {
	props := util.NewChatCacheProps(r.c, allow)
	r.cache = props
}
// GetChatCache returns the chat-cache state stored by SetChatCache
// (nil if SetChatCache has not been called on this relay).
func (r *relayBase) GetChatCache() *util.ChatCacheProps {
return r.cache
}
// getRequest is the default implementation: the base relay has no request
// payload, so it returns nil.
func (r *relayBase) getRequest() any {
	return nil
}
func (r *relayBase) setProvider(modelName string) error {

View File

@ -37,6 +37,10 @@ func (r *relayChat) setRequest() error {
return nil
}
// getRequest returns a pointer to the chat request held by this relay.
func (r *relayChat) getRequest() any {
	request := &r.chatRequest
	return request
}
// getPromptTokens counts the tokens of the chat messages for the relay's model.
// The error result is always nil.
func (r *relayChat) getPromptTokens() (int, error) {
	tokenCount := common.CountTokenMessages(r.chatRequest.Messages, r.modelName)
	return tokenCount, nil
}
@ -58,7 +62,7 @@ func (r *relayChat) send() (err *types.OpenAIErrorWithStatusCode, done bool) {
return
}
err = responseStreamClient(r.c, response)
err = responseStreamClient(r.c, response, r.cache)
} else {
var response *types.ChatCompletionResponse
response, err = chatProvider.CreateChatCompletion(&r.chatRequest)
@ -66,6 +70,7 @@ func (r *relayChat) send() (err *types.OpenAIErrorWithStatusCode, done bool) {
return
}
err = responseJsonClient(r.c, response)
r.cache.SetResponse(response)
}
if err != nil {

View File

@ -13,6 +13,7 @@ import (
"one-api/model"
"one-api/providers"
providersBase "one-api/providers/base"
"one-api/relay/util"
"one-api/types"
"strings"
@ -20,29 +21,37 @@ import (
)
// Path2Relay selects the relay implementation for the request path and
// attaches its chat-cache state. It returns nil when the path matches none of
// the known endpoint prefixes.
//
// Only the chat-completions and completions endpoints are marked as cacheable;
// every other relay still receives a cache object, created with allow=false.
func Path2Relay(c *gin.Context, path string) RelayBaseInterface {
	var (
		relay      RelayBaseInterface
		allowCache bool
	)

	switch {
	case strings.HasPrefix(path, "/v1/chat/completions"):
		allowCache = true
		relay = NewRelayChat(c)
	case strings.HasPrefix(path, "/v1/completions"):
		allowCache = true
		relay = NewRelayCompletions(c)
	case strings.HasPrefix(path, "/v1/embeddings"):
		relay = NewRelayEmbeddings(c)
	case strings.HasPrefix(path, "/v1/moderations"):
		relay = NewRelayModerations(c)
	case strings.HasPrefix(path, "/v1/images/generations"):
		relay = NewRelayImageGenerations(c)
	case strings.HasPrefix(path, "/v1/images/edits"):
		relay = NewRelayImageEdits(c)
	case strings.HasPrefix(path, "/v1/images/variations"):
		relay = NewRelayImageVariations(c)
	case strings.HasPrefix(path, "/v1/audio/speech"):
		relay = NewRelaySpeech(c)
	case strings.HasPrefix(path, "/v1/audio/transcriptions"):
		relay = NewRelayTranscriptions(c)
	case strings.HasPrefix(path, "/v1/audio/translations"):
		relay = NewRelayTranslations(c)
	}

	// The cache state must be set on every relay that is handed out, because
	// callers use GetChatCache() without a nil check.
	if relay != nil {
		relay.SetChatCache(allowCache)
	}

	return relay
}
func GetProvider(c *gin.Context, modeName string) (provider providersBase.ProviderInterface, newModelName string, fail error) {
@ -120,7 +129,7 @@ func responseJsonClient(c *gin.Context, data interface{}) *types.OpenAIErrorWith
return nil
}
func responseStreamClient(c *gin.Context, stream requester.StreamReaderInterface[string]) *types.OpenAIErrorWithStatusCode {
func responseStreamClient(c *gin.Context, stream requester.StreamReaderInterface[string], cache *util.ChatCacheProps) (errWithOP *types.OpenAIErrorWithStatusCode) {
requester.SetEventStreamHeaders(c)
dataChan, errChan := stream.Recv()
@ -128,19 +137,24 @@ func responseStreamClient(c *gin.Context, stream requester.StreamReaderInterface
c.Stream(func(w io.Writer) bool {
select {
case data := <-dataChan:
fmt.Fprintln(w, "data: "+data+"\n")
streamData := "data: " + data + "\n\n"
fmt.Fprint(w, streamData)
cache.SetResponse(streamData)
return true
case err := <-errChan:
if !errors.Is(err, io.EOF) {
fmt.Fprintln(w, "data: "+err.Error()+"\n")
fmt.Fprint(w, "data: "+err.Error()+"\n\n")
errWithOP = common.ErrorWrapper(err, "stream_error", http.StatusInternalServerError)
}
fmt.Fprintln(w, "data: [DONE]")
streamData := "data: [DONE]\n"
fmt.Fprint(w, streamData)
cache.SetResponse(streamData)
return false
}
})
return nil
return errWithOP
}
func responseMultipart(c *gin.Context, resp *http.Response) *types.OpenAIErrorWithStatusCode {
@ -174,6 +188,22 @@ func responseCustom(c *gin.Context, response *types.AudioResponseWrapper) *types
return nil
}
// responseCache writes a cached response body back to the client.
// A body that starts with "data: " is replayed as an event stream in a single
// write; anything else is sent as an application/json payload with status 200.
func responseCache(c *gin.Context, response string) {
	// Non-stream payloads are plain JSON documents.
	if !strings.HasPrefix(response, "data: ") {
		c.Data(http.StatusOK, "application/json", []byte(response))
		return
	}

	// Stream payloads were recorded verbatim, so they are emitted as-is.
	requester.SetEventStreamHeaders(c)
	c.Stream(func(w io.Writer) bool {
		_, _ = io.WriteString(w, response)
		return false
	})
}
func shouldRetry(c *gin.Context, statusCode int) bool {
channelId := c.GetInt("specific_channel_id")
if channelId > 0 {

View File

@ -37,6 +37,10 @@ func (r *relayCompletions) setRequest() error {
return nil
}
// getRequest returns a pointer to the completion request held by this relay.
func (r *relayCompletions) getRequest() any {
	request := &r.request
	return request
}
// getPromptTokens counts the tokens of the request prompt for the relay's model.
// The error result is always nil.
func (r *relayCompletions) getPromptTokens() (int, error) {
	tokenCount := common.CountTokenInput(r.request.Prompt, r.modelName)
	return tokenCount, nil
}
@ -58,7 +62,7 @@ func (r *relayCompletions) send() (err *types.OpenAIErrorWithStatusCode, done bo
return
}
err = responseStreamClient(r.c, response)
err = responseStreamClient(r.c, response, r.cache)
} else {
var response *types.CompletionResponse
response, err = provider.CreateCompletion(&r.request)
@ -66,6 +70,7 @@ func (r *relayCompletions) send() (err *types.OpenAIErrorWithStatusCode, done bo
return
}
err = responseJsonClient(r.c, response)
r.cache.SetResponse(response)
}
if err != nil {

View File

@ -7,6 +7,7 @@ import (
"one-api/model"
"one-api/relay/util"
"one-api/types"
"time"
"github.com/gin-gonic/gin"
)
@ -23,6 +24,18 @@ func Relay(c *gin.Context) {
return
}
cacheProps := relay.GetChatCache()
cacheProps.SetHash(relay.getRequest())
// 获取缓存
cache := cacheProps.GetCache()
if cache != nil {
// 说明有缓存, 直接返回缓存内容
cacheProcessing(c, cache)
return
}
if err := relay.setProvider(relay.getOriginalModel()); err != nil {
common.AbortWithMessage(c, http.StatusServiceUnavailable, err.Error())
return
@ -103,5 +116,25 @@ func RelayHandler(relay RelayBaseInterface) (err *types.OpenAIErrorWithStatusCod
}
quota.Consume(relay.getContext(), usage)
cacheProps := relay.GetChatCache()
go cacheProps.StoreCache(relay.getContext().GetInt("channel_id"), usage.PromptTokens, usage.CompletionTokens, relay.getModelName())
return
}
// cacheProcessing serves the request from a cached entry and then records a
// consume-log row for it, using the token counts, channel and model stored
// with the entry and the label "缓存".
func cacheProcessing(c *gin.Context, cacheProps *util.ChatCacheProps) {
	responseCache(c, cacheProps.Response)

	// Write the log entry.
	tokenName := c.GetString("token_name")

	// Elapsed time in milliseconds; stays 0 when the request context carries no
	// time.Time under "requestStartTime".
	elapsedMs := 0
	if startedAt, ok := c.Request.Context().Value("requestStartTime").(time.Time); ok {
		elapsedMs = int(time.Since(startedAt).Milliseconds())
	}

	model.RecordConsumeLog(
		c.Request.Context(),
		cacheProps.UserId,
		cacheProps.ChannelID,
		cacheProps.PromptTokens,
		cacheProps.CompletionTokens,
		cacheProps.ModelName,
		tokenName,
		0,
		"缓存",
		elapsedMs,
	)
}

128
relay/util/cache.go Normal file
View File

@ -0,0 +1,128 @@
package util
import (
"crypto/md5"
"encoding/hex"
"fmt"
"one-api/common"
"one-api/model"
"github.com/gin-gonic/gin"
)
// ChatCacheProps carries the per-request cache state and is also the payload
// that gets serialized into the cache. Fields tagged json:"-" are runtime-only
// and are not part of the stored payload.
type ChatCacheProps struct {
// UserId and TokenId are read from the gin context keys "id" and "token_id".
UserId int `json:"user_id"`
TokenId int `json:"token_id"`
// ChannelID, PromptTokens, CompletionTokens and ModelName are filled in by StoreCache.
ChannelID int `json:"channel_id"`
PromptTokens int `json:"prompt_tokens"`
CompletionTokens int `json:"completion_tokens"`
ModelName string `json:"model_name"`
// Response is the recorded response body (string chunks are appended; other values are JSON-encoded).
Response string `json:"response"`
// Hash is the hex-encoded MD5 cache key computed by hash().
Hash string `json:"-"`
// Cache reports whether caching is active for this request.
Cache bool `json:"-"`
// Driver is the storage backend chosen in NewChatCacheProps.
Driver CacheDriver `json:"-"`
}
// CacheDriver is the storage backend used for cached responses.
type CacheDriver interface {
// Get returns the cached entry for hash and userId, or nil when none can be loaded.
Get(hash string, userId int) *ChatCacheProps
// Set stores props under hash; the implementations in this package treat expire as minutes.
Set(hash string, props *ChatCacheProps, expire int64) error
}
// GetDebugList loads the unexpired database cache rows of a user and decodes
// their payloads. Rows whose payload cannot be decoded are skipped silently.
func GetDebugList(userId int) ([]*ChatCacheProps, error) {
	rows, err := model.GetChatCacheListByUserId(userId)
	if err != nil {
		return nil, err
	}

	var result []*ChatCacheProps
	for i := range rows {
		decoded, decodeErr := common.UnmarshalString[ChatCacheProps](rows[i].Data)
		if decodeErr == nil {
			result = append(result, &decoded)
		}
	}

	return result, nil
}
// NewChatCacheProps builds the cache state for one request.
//
// When allow is false the result has caching switched off and nothing else
// set. Otherwise caching is on only if it is enabled globally and the context
// flag "chat_cache" is true; the driver is Redis when Redis is enabled and the
// database otherwise, and the user and token ids are copied from the context.
func NewChatCacheProps(c *gin.Context, allow bool) *ChatCacheProps {
	props := new(ChatCacheProps) // Cache starts out false
	if !allow {
		return props
	}

	props.Cache = common.ChatCacheEnabled && c.GetBool("chat_cache")

	var driver CacheDriver = &ChatCacheDB{}
	if common.RedisEnabled {
		driver = &ChatCacheRedis{}
	}
	props.Driver = driver

	props.UserId = c.GetInt("id")
	props.TokenId = c.GetInt("token_id")

	return props
}
// SetHash derives the cache key from the JSON form of request.
// It does nothing when caching is inactive or request is nil.
func (p *ChatCacheProps) SetHash(request any) {
	if p.needCache() && request != nil {
		p.hash(common.Marshal(request))
	}
}
// SetResponse records the response to be cached. String values are appended
// to what has been recorded so far (stream chunks); any other value replaces
// the recorded response with its JSON encoding. It does nothing when caching
// is inactive or response is nil.
func (p *ChatCacheProps) SetResponse(response any) {
	if !p.needCache() || response == nil {
		return
	}

	switch chunk := response.(type) {
	case string:
		p.Response += chunk
	default:
		p.Response = common.Marshal(response)
	}
}
// StoreCache fills in the usage metadata and hands the entry to the driver,
// using common.ChatCacheExpireMinute as the lifetime. It returns nil without
// storing anything when caching is inactive or no response was recorded.
func (p *ChatCacheProps) StoreCache(channelId, promptTokens, completionTokens int, modelName string) error {
	if p.needCache() && p.Response != "" {
		p.ChannelID, p.ModelName = channelId, modelName
		p.PromptTokens, p.CompletionTokens = promptTokens, completionTokens

		return p.Driver.Set(p.getHash(), p, int64(common.ChatCacheExpireMinute))
	}
	return nil
}
// GetCache asks the driver for the entry stored under the current hash and
// user id. It returns nil when caching is inactive.
func (p *ChatCacheProps) GetCache() *ChatCacheProps {
	if p.needCache() {
		return p.Driver.Get(p.Hash, p.UserId)
	}
	return nil
}
// needCache reports whether caching applies: the request opted in and the
// feature is enabled globally.
func (p *ChatCacheProps) needCache() bool {
	if !common.ChatCacheEnabled {
		return false
	}
	return p.Cache
}
// getHash returns the cache key computed by hash() (empty if none was computed).
func (p *ChatCacheProps) getHash() string {
return p.Hash
}
// hash stores the hex-encoded MD5 digest of "<userId>-<tokenId>-<request>" as
// the cache key, so entries are scoped to one user and token.
func (p *ChatCacheProps) hash(request string) {
	payload := fmt.Sprintf("%d-%d-%s", p.UserId, p.TokenId, request)
	digest := md5.Sum([]byte(payload))
	p.Hash = hex.EncodeToString(digest[:])
}

47
relay/util/cache_db.go Normal file
View File

@ -0,0 +1,47 @@
package util
import (
"errors"
"one-api/common"
"one-api/model"
"time"
)
// ChatCacheDB is the CacheDriver that keeps cached responses in the database.
type ChatCacheDB struct{}
// Get loads the unexpired cache row for hash and userId and decodes its
// payload. It returns nil when the query fails, when no payload is stored,
// or when the payload cannot be decoded.
func (db *ChatCacheDB) Get(hash string, userId int) *ChatCacheProps {
	cache, err := model.GetChatCache(hash, userId)
	// Treat a query error or an empty payload as a miss explicitly instead of
	// relying on JSON decoding of an empty string to fail.
	if err != nil || cache == nil || cache.Data == "" {
		return nil
	}

	props, err := common.UnmarshalString[ChatCacheProps](cache.Data)
	if err != nil {
		return nil
	}

	return &props
}
// Set stores props in the database by delegating to SetCacheDB; expire is in minutes.
func (db *ChatCacheDB) Set(hash string, props *ChatCacheProps, expire int64) error {
return SetCacheDB(hash, props, expire)
}
// SetCacheDB serializes props and upserts it as a database cache row keyed by
// hash. expire is a lifetime in minutes; it is converted to an absolute Unix
// timestamp in seconds. An error is returned when props cannot be serialized.
func SetCacheDB(hash string, props *ChatCacheProps, expire int64) error {
	payload := common.Marshal(props)
	if payload == "" {
		return errors.New("marshal error")
	}

	// minutes -> seconds, then offset from the current time.
	expiresAt := expire*60 + time.Now().Unix()

	row := model.ChatCache{
		Hash:       hash,
		UserId:     props.UserId,
		Data:       payload,
		Expiration: expiresAt,
	}
	return row.Insert()
}

44
relay/util/cache_redis.go Normal file
View File

@ -0,0 +1,44 @@
package util
import (
"errors"
"fmt"
"one-api/common"
"time"
)
// ChatCacheRedis is the CacheDriver that keeps cached responses in Redis.
type ChatCacheRedis struct{}
// chatCacheKey is the prefix of every Redis key built by getKey.
var chatCacheKey = "chat_cache"
// Get reads the entry for hash and userId from Redis and decodes it.
// It returns nil when the key cannot be read or the payload cannot be decoded.
func (r *ChatCacheRedis) Get(hash string, userId int) *ChatCacheProps {
	raw, getErr := common.RedisGet(r.getKey(hash, userId))
	if getErr != nil {
		return nil
	}

	if decoded, decodeErr := common.UnmarshalString[ChatCacheProps](raw); decodeErr == nil {
		return &decoded
	}
	return nil
}
// Set serializes props and stores it in Redis under the key built from hash
// and props.UserId, with a time-to-live of expire minutes. It is a no-op when
// props.Cache is false and returns an error when props cannot be serialized.
func (r *ChatCacheRedis) Set(hash string, props *ChatCacheProps, expire int64) error {
	if !props.Cache {
		return nil
	}

	// props is already a pointer; marshal it directly, as the DB driver does,
	// rather than taking the address of the pointer.
	data := common.Marshal(props)
	if data == "" {
		return errors.New("marshal error")
	}

	return common.RedisSet(r.getKey(hash, props.UserId), data, time.Duration(expire)*time.Minute)
}
// getKey builds the Redis key "<chatCacheKey>:<userId>:<hash>".
func (r *ChatCacheRedis) getKey(hash string, userId int) string {
	const layout = "%s:%d:%s"
	key := fmt.Sprintf(layout, chatCacheKey, userId, hash)
	return key
}

View File

@ -30,7 +30,9 @@ const OperationSetting = () => {
ApproximateTokenEnabled: '',
RetryTimes: 0,
RetryCooldownSeconds: 0,
MjNotifyEnabled: ''
MjNotifyEnabled: '',
ChatCacheEnabled: '',
ChatCacheExpireMinute: 5
});
const [originInputs, setOriginInputs] = useState({});
let [loading, setLoading] = useState(false);
@ -152,6 +154,11 @@ const OperationSetting = () => {
await updateOption('RetryCooldownSeconds', inputs.RetryCooldownSeconds);
}
break;
case 'other':
if (originInputs['ChatCacheExpireMinute'] !== inputs.ChatCacheExpireMinute) {
await updateOption('ChatCacheExpireMinute', inputs.ChatCacheExpireMinute);
}
break;
}
showSuccess('保存成功!');
@ -292,7 +299,34 @@ const OperationSetting = () => {
label="Midjourney 允许回调会泄露服务器ip地址"
control={<Checkbox checked={inputs.MjNotifyEnabled === 'true'} onChange={handleInputChange} name="MjNotifyEnabled" />}
/>
<FormControlLabel
sx={{ marginLeft: '0px' }}
label="是否开启聊天缓存(如果没有启用Redis将会存储在数据库中)"
control={<Checkbox checked={inputs.ChatCacheEnabled === 'true'} onChange={handleInputChange} name="ChatCacheEnabled" />}
/>
</Stack>
<Stack direction={{ sm: 'column', md: 'row' }} spacing={{ xs: 3, sm: 2, md: 4 }}>
<FormControl>
<InputLabel htmlFor="ChatCacheExpireMinute">缓存时间(分钟)</InputLabel>
<OutlinedInput
id="ChatCacheExpireMinute"
name="ChatCacheExpireMinute"
value={inputs.ChatCacheExpireMinute}
onChange={handleInputChange}
label="缓存时间(分钟)"
placeholder="开启缓存时,数据缓存的时间"
disabled={loading}
/>
</FormControl>
</Stack>
<Button
variant="contained"
onClick={() => {
submitConfig('other').then();
}}
>
保存其他设置
</Button>
</Stack>
</SubCard>
<SubCard title="日志设置">

View File

@ -17,6 +17,7 @@ import {
OutlinedInput,
InputAdornment,
Switch,
FormControlLabel,
FormHelperText
} from '@mui/material';
@ -25,6 +26,7 @@ import { LocalizationProvider } from '@mui/x-date-pickers/LocalizationProvider';
import { DateTimePicker } from '@mui/x-date-pickers/DateTimePicker';
import { renderQuotaWithPrompt, showSuccess, showError } from 'utils/common';
import { API } from 'utils/api';
import { useSelector } from 'react-redux';
require('dayjs/locale/zh-cn');
const validationSchema = Yup.object().shape({
@ -40,12 +42,14 @@ const originInputs = {
name: '',
remain_quota: 0,
expired_time: -1,
unlimited_quota: false
unlimited_quota: false,
chat_cache: false
};
const EditModal = ({ open, tokenId, onCancel, onOk }) => {
const theme = useTheme();
const [inputs, setInputs] = useState(originInputs);
const siteInfo = useSelector((state) => state.siteInfo);
const submit = async (values, { setErrors, setStatus, setSubmitting }) => {
setSubmitting(true);
@ -163,17 +167,22 @@ const EditModal = ({ open, tokenId, onCancel, onOk }) => {
)}
</FormControl>
)}
<Switch
checked={values.expired_time === -1}
onClick={() => {
if (values.expired_time === -1) {
setFieldValue('expired_time', Math.floor(Date.now() / 1000));
} else {
setFieldValue('expired_time', -1);
}
}}
/>{' '}
永不过期
<FormControlLabel
control={
<Switch
checked={values.expired_time === -1}
onClick={() => {
if (values.expired_time === -1) {
setFieldValue('expired_time', Math.floor(Date.now() / 1000));
} else {
setFieldValue('expired_time', -1);
}
}}
/>
}
label="永不过期"
/>
<FormControl fullWidth error={Boolean(touched.remain_quota && errors.remain_quota)} sx={{ ...theme.typography.otherInput }}>
<InputLabel htmlFor="channel-remain_quota-label">额度</InputLabel>
<OutlinedInput
@ -195,13 +204,32 @@ const EditModal = ({ open, tokenId, onCancel, onOk }) => {
</FormHelperText>
)}
</FormControl>
<Switch
checked={values.unlimited_quota === true}
onClick={() => {
setFieldValue('unlimited_quota', !values.unlimited_quota);
}}
/>{' '}
无限额度
<FormControl fullWidth>
<FormControlLabel
control={
<Switch
checked={values.unlimited_quota === true}
onClick={() => {
setFieldValue('unlimited_quota', !values.unlimited_quota);
}}
/>
}
label="无限额度"
/>
{siteInfo.chat_cache_enabled && (
<FormControlLabel
control={
<Switch
checked={values.chat_cache}
onClick={() => {
setFieldValue('chat_cache', !values.chat_cache);
}}
/>
}
label="是否开启缓存(开启后,将会缓存聊天记录,以减少消费)"
/>
)}
</FormControl>
<DialogActions>
<Button onClick={onCancel}>取消</Button>
<Button disableElevation disabled={isSubmitting} type="submit" variant="contained" color="primary">