diff --git a/common/constants.go b/common/constants.go index b44ca170..192303b2 100644 --- a/common/constants.go +++ b/common/constants.go @@ -37,6 +37,10 @@ var WeChatAuthEnabled = false var TurnstileCheckEnabled = false var RegisterEnabled = true +// chat cache +var ChatCacheEnabled = false +var ChatCacheExpireMinute = 5 // 5 Minute + // mj var MjNotifyEnabled = false diff --git a/common/utils.go b/common/utils.go index ac6bafe6..8b0bd555 100644 --- a/common/utils.go +++ b/common/utils.go @@ -1,6 +1,7 @@ package common import ( + "encoding/json" "fmt" "html/template" "log" @@ -248,3 +249,16 @@ func EscapeMarkdownText(text string) string { } return text } + +func UnmarshalString[T interface{}](data string) (form T, err error) { + err = json.Unmarshal([]byte(data), &form) + return form, err +} + +func Marshal[T interface{}](data T) string { + res, err := json.Marshal(data) + if err != nil { + return "" + } + return string(res) +} diff --git a/controller/misc.go b/controller/misc.go index 3eae703d..1ad29e63 100644 --- a/controller/misc.go +++ b/controller/misc.go @@ -42,6 +42,7 @@ func GetStatus(c *gin.Context) { "display_in_currency": common.DisplayInCurrencyEnabled, "telegram_bot": telegram_bot, "mj_notify_enabled": common.MjNotifyEnabled, + "chat_cache_enabled": common.ChatCacheEnabled, }, }) } diff --git a/controller/token.go b/controller/token.go index 608d630b..67a576be 100644 --- a/controller/token.go +++ b/controller/token.go @@ -104,6 +104,7 @@ func AddToken(c *gin.Context) { ExpiredTime: token.ExpiredTime, RemainQuota: token.RemainQuota, UnlimitedQuota: token.UnlimitedQuota, + ChatCache: token.ChatCache, } err = cleanToken.Insert() if err != nil { @@ -187,6 +188,7 @@ func UpdateToken(c *gin.Context) { cleanToken.ExpiredTime = token.ExpiredTime cleanToken.RemainQuota = token.RemainQuota cleanToken.UnlimitedQuota = token.UnlimitedQuota + cleanToken.ChatCache = token.ChatCache } err = cleanToken.Update() if err != nil { diff --git a/cron/main.go b/cron/main.go new file mode 100644 index 00000000..8e98da0f --- /dev/null +++ b/cron/main.go @@ -0,0 +1,37 @@ +package cron + +import ( + "one-api/common" + "one-api/model" + "time" + + "github.com/go-co-op/gocron/v2" +) + +func InitCron() { + scheduler, err := gocron.NewScheduler() + if err != nil { + common.SysLog("Cron scheduler error: " + err.Error()) + return + } + + // 添加删除cache的任务 + _, err = scheduler.NewJob( + gocron.DailyJob( + 1, + gocron.NewAtTimes( + gocron.NewAtTime(0, 5, 0), + )), + gocron.NewTask(func() { + model.RemoveChatCache(time.Now().Unix()) + common.SysLog("删除过期缓存数据") + }), + ) + + if err != nil { + common.SysLog("Cron job error: " + err.Error()) + return + } + + scheduler.Start() +} diff --git a/go.mod b/go.mod index da655611..57d6cd1a 100644 --- a/go.mod +++ b/go.mod @@ -30,11 +30,14 @@ require ( filippo.io/edwards25519 v1.1.0 // indirect github.com/chenzhuoyu/iasm v0.9.1 // indirect github.com/fsnotify/fsnotify v1.7.0 // indirect + github.com/go-co-op/gocron/v2 v2.2.9 // indirect github.com/gomarkdown/markdown v0.0.0-20240328165702-4d01890c35c0 // indirect github.com/hashicorp/hcl v1.0.0 // indirect github.com/jackc/puddle/v2 v2.2.1 // indirect + github.com/jonboulle/clockwork v0.4.0 // indirect github.com/magiconair/properties v1.8.7 // indirect github.com/mitchellh/mapstructure v1.5.0 // indirect + github.com/robfig/cron/v3 v3.0.1 // indirect github.com/sagikazarmark/locafero v0.4.0 // indirect github.com/sagikazarmark/slog-shim v0.1.0 // indirect github.com/sourcegraph/conc v0.3.0 // indirect @@ -46,7 +49,7 @@ require ( github.com/wneessen/go-mail v0.4.1 // indirect go.uber.org/atomic v1.9.0 // indirect go.uber.org/multierr v1.9.0 // indirect - golang.org/x/exp v0.0.0-20230905200255-921286631fa9 // indirect + golang.org/x/exp v0.0.0-20231219180239-dc181d75b848 // indirect golang.org/x/sync v0.6.0 // indirect gopkg.in/ini.v1 v1.67.0 // indirect ) diff --git a/go.sum b/go.sum index a1fa358e..bfa82581 100644 --- a/go.sum +++ b/go.sum @@ -71,6 +71,8 @@ github.com/gin-gonic/gin v1.8.1/go.mod h1:ji8BvRH1azfM+SYow9zQ6SZMvR8qOMZHmsCuWR github.com/gin-gonic/gin v1.9.1 h1:4idEAncQnU5cB7BeOkPtxjfCSye0AAm1R0RVIqJ+Jmg= github.com/gin-gonic/gin v1.9.1/go.mod h1:hPrL7YrpYKXt5YId3A/Tnip5kqbEAP+KLuI3SUcPTeU= github.com/globalsign/mgo v0.0.0-20181015135952-eeefdecb41b8/go.mod h1:xkRDCp4j0OGD1HRkm4kmhM+pmpv3AKq5SU7GMg4oO/Q= +github.com/go-co-op/gocron/v2 v2.2.9 h1:aoKosYWSSdXFLecjFWX1i8+R6V7XdZb8sB2ZKAY5Yis= +github.com/go-co-op/gocron/v2 v2.2.9/go.mod h1:mZx3gMSlFnb97k3hRqX3+GdlG3+DUwTh6B8fnsTScXg= github.com/go-playground/assert/v2 v2.0.1/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4= github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s= github.com/go-playground/locales v0.13.0/go.mod h1:taPMhCMXrRLJO55olJkUXHZBHCxTMfnGwq/HNwmWNS8= @@ -152,6 +154,8 @@ github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ= github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0= github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4= +github.com/jonboulle/clockwork v0.4.0 h1:p4Cf1aMWXnXAUh8lVfewRBx1zaTSYKrKMF2g3ST4RZ4= +github.com/jonboulle/clockwork v0.4.0/go.mod h1:xgRqUGwRcjKCO1vbZUEtSLrqKoPSsUpK7fnezOII0kc= github.com/json-iterator/go v1.1.9/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4= github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM= github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo= @@ -219,6 +223,8 @@ github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZN github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U= github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/quasoft/memstore v0.0.0-20191010062613-2bce066d2b0b/go.mod h1:wTPjTepVu7uJBYgZ0SdWHQlIas582j6cn2jgk4DDdlg= +github.com/robfig/cron/v3 v3.0.1 h1:WdRxkvbJztn8LMz/QEvLN5sBU+xKpSqwwUO1Pjr4qDs= +github.com/robfig/cron/v3 v3.0.1/go.mod h1:eQICP3HwyT7UooqI/z+Ov+PtYAWygg1TEWWzGIFLtro= github.com/rogpeppe/go-internal v1.6.1/go.mod h1:xXDCJY+GAPziupqXw64V24skbSoqbTEfhy4qGm1nDQc= github.com/rogpeppe/go-internal v1.8.0 h1:FCbCCtXNOY3UtUuHUYaghJg4y7Fd14rXifAYUAtL9R8= github.com/rogpeppe/go-internal v1.8.0/go.mod h1:WmiCO8CzOY8rg0OYDC4/i/2WRWAB6poM+XZ2dLUbcbE= @@ -284,6 +290,8 @@ golang.org/x/crypto v0.21.0 h1:X31++rzVUdKhX5sWmSOFZxx8UW/ldWx55cbf08iNAMA= golang.org/x/crypto v0.21.0/go.mod h1:0BP7YvVV9gBbVKyeTG0Gyn+gZm94bibOW5BjDEYAOMs= golang.org/x/exp v0.0.0-20230905200255-921286631fa9 h1:GoHiUyI/Tp2nVkLI2mCxVkOjsbSXD66ic0XW0js0R9g= golang.org/x/exp v0.0.0-20230905200255-921286631fa9/go.mod h1:S2oDrQGGwySpoQPVqRShND87VCbxmc6bL1Yd2oYrm6k= +golang.org/x/exp v0.0.0-20231219180239-dc181d75b848 h1:+iq7lrkxmFNBM7xx+Rae2W6uyPfhPeDWD+n+JgppptE= +golang.org/x/exp v0.0.0-20231219180239-dc181d75b848/go.mod h1:iRJReGqOEeBhDZGkGbynYwcHlctCvnjTYIamk7uXpHI= golang.org/x/image v0.14.0 h1:tNgSxAFe3jC4uYqvZdTr84SZoM1KfwdC9SKIFrLjFn4= golang.org/x/image v0.14.0/go.mod h1:HUYqC05R2ZcZ3ejNQsIHQDQiwWM4JBqmm6MKANTp4LE= golang.org/x/image v0.15.0 h1:kOELfmgrmJlw4Cdb7g/QGuB3CvDrXbqEIww/pNtNBm8= diff --git a/main.go b/main.go index df5478df..ac44532e 100644 --- a/main.go +++ b/main.go @@ -9,6 +9,7 @@ import ( "one-api/common/requester" "one-api/common/telegram" "one-api/controller" + "one-api/cron" "one-api/middleware" "one-api/model" "one-api/relay/util" @@ -48,6 +49,7 @@ func main() { controller.InitMidjourneyTask() notify.InitNotifier() + cron.InitCron() initHttpServer() } diff --git a/middleware/auth.go b/middleware/auth.go index 91e3dd56..30ccc6ef 100644 --- a/middleware/auth.go +++ b/middleware/auth.go @@ -105,6 +105,7 @@ func tokenAuth(c *gin.Context, key string) { c.Set("id", token.UserId) c.Set("token_id", token.Id) c.Set("token_name", token.Name) + c.Set("chat_cache", token.ChatCache) if len(parts) > 1 { if model.IsAdmin(token.UserId) { channelId := common.String2Int(parts[1]) diff --git a/model/chat_cache.go b/model/chat_cache.go new file mode 100644 index 00000000..ae33954f --- /dev/null +++ b/model/chat_cache.go @@ -0,0 +1,41 @@ +package model + +import ( + "time" + + "gorm.io/gorm/clause" +) + +type ChatCache struct { + Hash string `json:"hash" gorm:"type:varchar(32);primaryKey"` + UserId int `json:"user_id" gorm:"type:int;not null;index"` + Data string `json:"data" gorm:"type:json;not null"` + Expiration int64 `json:"expiration" gorm:"type:bigint;not null;index"` +} + +func (cache *ChatCache) Insert() error { + return DB.Clauses(clause.OnConflict{ + UpdateAll: true, + }).Create(cache).Error +} + +func GetChatCache(hash string, userId int) (*ChatCache, error) { + var chatCache ChatCache + // 获取当前时间戳 + now := time.Now().Unix() + err := DB.Where("hash = ? and user_id = ? and expiration > ?", hash, userId, now).Find(&chatCache).Error + return &chatCache, err +} + +func GetChatCacheListByUserId(userId int) ([]*ChatCache, error) { + var chatCaches []*ChatCache + // 获取当前时间戳 + now := time.Now().Unix() + err := DB.Where("user_id = ? and expiration >", userId, now).Find(&chatCaches).Error + return chatCaches, err +} + +func RemoveChatCache(expiration int64) error { + now := time.Now().Unix() + return DB.Where("expiration < ?", now).Delete(ChatCache{}).Error +} diff --git a/model/main.go b/model/main.go index 5e682af2..72169a34 100644 --- a/model/main.go +++ b/model/main.go @@ -144,6 +144,10 @@ func InitDB() (err error) { if err != nil { return err } + err = db.AutoMigrate(&ChatCache{}) + if err != nil { + return err + } common.SysLog("database migrated") err = createRootAccountIfNeed() return err diff --git a/model/option.go b/model/option.go index 94103bd8..fd5cb05f 100644 --- a/model/option.go +++ b/model/option.go @@ -76,6 +76,9 @@ func InitOptionMap() { common.OptionMap["MjNotifyEnabled"] = strconv.FormatBool(common.MjNotifyEnabled) + common.OptionMap["ChatCacheEnabled"] = strconv.FormatBool(common.ChatCacheEnabled) + common.OptionMap["ChatCacheExpireMinute"] = strconv.Itoa(common.ChatCacheExpireMinute) + common.OptionMapRWMutex.Unlock() loadOptionsFromDatabase() } @@ -115,14 +118,15 @@ func UpdateOption(key string, value string) error { } var optionIntMap = map[string]*int{ - "SMTPPort": &common.SMTPPort, - "QuotaForNewUser": &common.QuotaForNewUser, - "QuotaForInviter": &common.QuotaForInviter, - "QuotaForInvitee": &common.QuotaForInvitee, - "QuotaRemindThreshold": &common.QuotaRemindThreshold, - "PreConsumedQuota": &common.PreConsumedQuota, - "RetryTimes": &common.RetryTimes, - "RetryCooldownSeconds": &common.RetryCooldownSeconds, + "SMTPPort": &common.SMTPPort, + "QuotaForNewUser": &common.QuotaForNewUser, + "QuotaForInviter": &common.QuotaForInviter, + "QuotaForInvitee": &common.QuotaForInvitee, + "QuotaRemindThreshold": &common.QuotaRemindThreshold, + "PreConsumedQuota": &common.PreConsumedQuota, + "RetryTimes": &common.RetryTimes, + "RetryCooldownSeconds": &common.RetryCooldownSeconds, + "ChatCacheExpireMinute": &common.ChatCacheExpireMinute, } var optionBoolMap = map[string]*bool{ @@ -141,6 +145,7 @@ var optionBoolMap = map[string]*bool{ "DisplayInCurrencyEnabled": &common.DisplayInCurrencyEnabled, "DisplayTokenStatEnabled": &common.DisplayTokenStatEnabled, "MjNotifyEnabled": &common.MjNotifyEnabled, + "ChatCacheEnabled": &common.ChatCacheEnabled, } var optionStringMap = map[string]*string{ diff --git a/model/token.go b/model/token.go index 689ec49e..c07e0548 100644 --- a/model/token.go +++ b/model/token.go @@ -2,6 +2,7 @@ package model import ( "errors" + "fmt" "one-api/common" "one-api/common/stmp" @@ -20,6 +21,7 @@ type Token struct { RemainQuota int `json:"remain_quota" gorm:"default:0"` UnlimitedQuota bool `json:"unlimited_quota" gorm:"default:false"` UsedQuota int `json:"used_quota" gorm:"default:0"` // used quota + ChatCache bool `json:"chat_cache" gorm:"default:false"` } var allowedTokenOrderFields = map[string]bool{ @@ -40,7 +42,7 @@ func GetUserTokensList(userId int, params *GenericParams) (*DataResult[Token], e db = db.Where("name LIKE ?", params.Keyword+"%") } - return PaginateAndOrder[Token](db, ¶ms.PaginationParams, &tokens, allowedTokenOrderFields) + return PaginateAndOrder(db, ¶ms.PaginationParams, &tokens, allowedTokenOrderFields) } // 获取状态为可用的令牌 @@ -114,13 +116,26 @@ func GetTokenById(id int) (*Token, error) { } func (token *Token) Insert() error { + if token.ChatCache && !common.ChatCacheEnabled { + token.ChatCache = false + } + err := DB.Create(token).Error return err } // Update Make sure your token's fields is completed, because this will update non-zero values func (token *Token) Update() error { - err := DB.Model(token).Select("name", "status", "expired_time", "remain_quota", "unlimited_quota").Updates(token).Error + if token.ChatCache && !common.ChatCacheEnabled { + token.ChatCache = false + } + + err := DB.Model(token).Select("name", "status", "expired_time", "remain_quota", "unlimited_quota", "chat_cache").Updates(token).Error + // 防止Redis缓存不生效,直接删除 + if err == nil && common.RedisEnabled { + common.RedisDel(fmt.Sprintf("token:%s", token.Key)) + } + return err } diff --git a/relay/base.go b/relay/base.go index 5891010c..72d42aae 100644 --- a/relay/base.go +++ b/relay/base.go @@ -1,6 +1,7 @@ package relay import ( + "one-api/relay/util" "one-api/types" providersBase "one-api/providers/base" @@ -13,17 +14,33 @@ type relayBase struct { provider providersBase.ProviderInterface originalModel string modelName string + cache *util.ChatCacheProps } type RelayBaseInterface interface { send() (err *types.OpenAIErrorWithStatusCode, done bool) getPromptTokens() (int, error) setRequest() error + getRequest() any setProvider(modelName string) error getProvider() providersBase.ProviderInterface getOriginalModel() string getModelName() string getContext() *gin.Context + SetChatCache(allow bool) + GetChatCache() *util.ChatCacheProps +} + +func (r *relayBase) SetChatCache(allow bool) { + r.cache = util.NewChatCacheProps(r.c, allow) +} + +func (r *relayBase) GetChatCache() *util.ChatCacheProps { + return r.cache +} + +func (r *relayBase) getRequest() interface{} { + return nil } func (r *relayBase) setProvider(modelName string) error { diff --git a/relay/chat.go b/relay/chat.go index 33e2b469..c7ccd9ca 100644 --- a/relay/chat.go +++ b/relay/chat.go @@ -37,6 +37,10 @@ func (r *relayChat) setRequest() error { return nil } +func (r *relayChat) getRequest() interface{} { + return &r.chatRequest +} + func (r *relayChat) getPromptTokens() (int, error) { return common.CountTokenMessages(r.chatRequest.Messages, r.modelName), nil } @@ -58,7 +62,7 @@ func (r *relayChat) send() (err *types.OpenAIErrorWithStatusCode, done bool) { return } - err = responseStreamClient(r.c, response) + err = responseStreamClient(r.c, response, r.cache) } else { var response *types.ChatCompletionResponse response, err = chatProvider.CreateChatCompletion(&r.chatRequest) @@ -66,6 +70,7 @@ func (r *relayChat) send() (err *types.OpenAIErrorWithStatusCode, done bool) { return } err = responseJsonClient(r.c, response) + r.cache.SetResponse(response) } if err != nil { diff --git a/relay/common.go b/relay/common.go index 7b2d9ffd..9289da8f 100644 --- a/relay/common.go +++ b/relay/common.go @@ -13,6 +13,7 @@ import ( "one-api/model" "one-api/providers" providersBase "one-api/providers/base" + "one-api/relay/util" "one-api/types" "strings" @@ -20,29 +21,37 @@ import ( ) func Path2Relay(c *gin.Context, path string) RelayBaseInterface { + allowCache := false + var relay RelayBaseInterface if strings.HasPrefix(path, "/v1/chat/completions") { - return NewRelayChat(c) + allowCache = true + relay = NewRelayChat(c) } else if strings.HasPrefix(path, "/v1/completions") { - return NewRelayCompletions(c) + allowCache = true + relay = NewRelayCompletions(c) } else if strings.HasPrefix(path, "/v1/embeddings") { - return NewRelayEmbeddings(c) + relay = NewRelayEmbeddings(c) } else if strings.HasPrefix(path, "/v1/moderations") { - return NewRelayModerations(c) + relay = NewRelayModerations(c) } else if strings.HasPrefix(path, "/v1/images/generations") { - return NewRelayImageGenerations(c) + relay = NewRelayImageGenerations(c) } else if strings.HasPrefix(path, "/v1/images/edits") { - return NewRelayImageEdits(c) + relay = NewRelayImageEdits(c) } else if strings.HasPrefix(path, "/v1/images/variations") { - return NewRelayImageVariations(c) + relay = NewRelayImageVariations(c) } else if strings.HasPrefix(path, "/v1/audio/speech") { - return NewRelaySpeech(c) + relay = NewRelaySpeech(c) } else if strings.HasPrefix(path, "/v1/audio/transcriptions") { - return NewRelayTranscriptions(c) + relay = NewRelayTranscriptions(c) } else if strings.HasPrefix(path, "/v1/audio/translations") { - return NewRelayTranslations(c) + relay = NewRelayTranslations(c) } - return nil + if relay != nil { + relay.SetChatCache(allowCache) + } + + return relay } func GetProvider(c *gin.Context, modeName string) (provider providersBase.ProviderInterface, newModelName string, fail error) { @@ -120,7 +129,7 @@ func responseJsonClient(c *gin.Context, data interface{}) *types.OpenAIErrorWith return nil } -func responseStreamClient(c *gin.Context, stream requester.StreamReaderInterface[string]) *types.OpenAIErrorWithStatusCode { +func responseStreamClient(c *gin.Context, stream requester.StreamReaderInterface[string], cache *util.ChatCacheProps) (errWithOP *types.OpenAIErrorWithStatusCode) { requester.SetEventStreamHeaders(c) dataChan, errChan := stream.Recv() @@ -128,19 +137,24 @@ func responseStreamClient(c *gin.Context, stream requester.StreamReaderInterface c.Stream(func(w io.Writer) bool { select { case data := <-dataChan: - fmt.Fprintln(w, "data: "+data+"\n") + streamData := "data: " + data + "\n\n" + fmt.Fprint(w, streamData) + cache.SetResponse(streamData) return true case err := <-errChan: if !errors.Is(err, io.EOF) { - fmt.Fprintln(w, "data: "+err.Error()+"\n") + fmt.Fprint(w, "data: "+err.Error()+"\n\n") + errWithOP = common.ErrorWrapper(err, "stream_error", http.StatusInternalServerError) } - fmt.Fprintln(w, "data: [DONE]") + streamData := "data: [DONE]\n" + fmt.Fprint(w, streamData) + cache.SetResponse(streamData) return false } }) - return nil + return errWithOP } func responseMultipart(c *gin.Context, resp *http.Response) *types.OpenAIErrorWithStatusCode { @@ -174,6 +188,22 @@ func responseCustom(c *gin.Context, response *types.AudioResponseWrapper) *types return nil } +func responseCache(c *gin.Context, response string) { + // 检查是否是 data: 开头的流式数据 + isStream := strings.HasPrefix(response, "data: ") + + if isStream { + requester.SetEventStreamHeaders(c) + c.Stream(func(w io.Writer) bool { + fmt.Fprint(w, response) + return false + }) + } else { + c.Data(http.StatusOK, "application/json", []byte(response)) + } + +} + func shouldRetry(c *gin.Context, statusCode int) bool { channelId := c.GetInt("specific_channel_id") if channelId > 0 { diff --git a/relay/completions.go b/relay/completions.go index fdfbd03e..4abeecb1 100644 --- a/relay/completions.go +++ b/relay/completions.go @@ -37,6 +37,10 @@ func (r *relayCompletions) setRequest() error { return nil } +func (r *relayCompletions) getRequest() interface{} { + return &r.request +} + func (r *relayCompletions) getPromptTokens() (int, error) { return common.CountTokenInput(r.request.Prompt, r.modelName), nil } @@ -58,7 +62,7 @@ func (r *relayCompletions) send() (err *types.OpenAIErrorWithStatusCode, done bo return } - err = responseStreamClient(r.c, response) + err = responseStreamClient(r.c, response, r.cache) } else { var response *types.CompletionResponse response, err = provider.CreateCompletion(&r.request) @@ -66,6 +70,7 @@ func (r *relayCompletions) send() (err *types.OpenAIErrorWithStatusCode, done bo return } err = responseJsonClient(r.c, response) + r.cache.SetResponse(response) } if err != nil { diff --git a/relay/main.go b/relay/main.go index 5da7a340..3630ba2b 100644 --- a/relay/main.go +++ b/relay/main.go @@ -7,6 +7,7 @@ import ( "one-api/model" "one-api/relay/util" "one-api/types" + "time" "github.com/gin-gonic/gin" ) @@ -23,6 +24,18 @@ func Relay(c *gin.Context) { return } + cacheProps := relay.GetChatCache() + cacheProps.SetHash(relay.getRequest()) + + // 获取缓存 + cache := cacheProps.GetCache() + + if cache != nil { + // 说明有缓存, 直接返回缓存内容 + cacheProcessing(c, cache) + return + } + if err := relay.setProvider(relay.getOriginalModel()); err != nil { common.AbortWithMessage(c, http.StatusServiceUnavailable, err.Error()) return @@ -103,5 +116,25 @@ func RelayHandler(relay RelayBaseInterface) (err *types.OpenAIErrorWithStatusCod } quota.Consume(relay.getContext(), usage) + cacheProps := relay.GetChatCache() + go cacheProps.StoreCache(relay.getContext().GetInt("channel_id"), usage.PromptTokens, usage.CompletionTokens, relay.getModelName()) return } + +func cacheProcessing(c *gin.Context, cacheProps *util.ChatCacheProps) { + responseCache(c, cacheProps.Response) + + // 写入日志 + tokenName := c.GetString("token_name") + + requestTime := 0 + requestStartTimeValue := c.Request.Context().Value("requestStartTime") + if requestStartTimeValue != nil { + requestStartTime, ok := requestStartTimeValue.(time.Time) + if ok { + requestTime = int(time.Since(requestStartTime).Milliseconds()) + } + } + + model.RecordConsumeLog(c.Request.Context(), cacheProps.UserId, cacheProps.ChannelID, cacheProps.PromptTokens, cacheProps.CompletionTokens, cacheProps.ModelName, tokenName, 0, "缓存", requestTime) +} diff --git a/relay/util/cache.go b/relay/util/cache.go new file mode 100644 index 00000000..f8f25c11 --- /dev/null +++ b/relay/util/cache.go @@ -0,0 +1,128 @@ +package util + +import ( + "crypto/md5" + "encoding/hex" + "fmt" + "one-api/common" + "one-api/model" + + "github.com/gin-gonic/gin" +) + +type ChatCacheProps struct { + UserId int `json:"user_id"` + TokenId int `json:"token_id"` + ChannelID int `json:"channel_id"` + PromptTokens int `json:"prompt_tokens"` + CompletionTokens int `json:"completion_tokens"` + ModelName string `json:"model_name"` + Response string `json:"response"` + + Hash string `json:"-"` + Cache bool `json:"-"` + Driver CacheDriver `json:"-"` +} + +type CacheDriver interface { + Get(hash string, userId int) *ChatCacheProps + Set(hash string, props *ChatCacheProps, expire int64) error +} + +func GetDebugList(userId int) ([]*ChatCacheProps, error) { + caches, err := model.GetChatCacheListByUserId(userId) + if err != nil { + return nil, err + } + + var props []*ChatCacheProps + for _, cache := range caches { + prop, err := common.UnmarshalString[ChatCacheProps](cache.Data) + if err != nil { + continue + } + props = append(props, &prop) + } + + return props, nil +} + +func NewChatCacheProps(c *gin.Context, allow bool) *ChatCacheProps { + props := &ChatCacheProps{ + Cache: false, + } + + if !allow { + return props + } + + if common.ChatCacheEnabled && c.GetBool("chat_cache") { + props.Cache = true + } + + if common.RedisEnabled { + props.Driver = &ChatCacheRedis{} + } else { + props.Driver = &ChatCacheDB{} + } + + props.UserId = c.GetInt("id") + props.TokenId = c.GetInt("token_id") + + return props +} + +func (p *ChatCacheProps) SetHash(request any) { + if !p.needCache() || request == nil { + return + } + + p.hash(common.Marshal(request)) +} + +func (p *ChatCacheProps) SetResponse(response any) { + if !p.needCache() || response == nil { + return + } + + if str, ok := response.(string); ok { + p.Response += str + return + } + + p.Response = common.Marshal(response) +} + +func (p *ChatCacheProps) StoreCache(channelId, promptTokens, completionTokens int, modelName string) error { + if !p.needCache() || p.Response == "" { + return nil + } + + p.ChannelID = channelId + p.PromptTokens = promptTokens + p.CompletionTokens = completionTokens + p.ModelName = modelName + + return p.Driver.Set(p.getHash(), p, int64(common.ChatCacheExpireMinute)) +} + +func (p *ChatCacheProps) GetCache() *ChatCacheProps { + if !p.needCache() { + return nil + } + + return p.Driver.Get(p.getHash(), p.UserId) +} + +func (p *ChatCacheProps) needCache() bool { + return common.ChatCacheEnabled && p.Cache +} + +func (p *ChatCacheProps) getHash() string { + return p.Hash +} + +func (p *ChatCacheProps) hash(request string) { + hash := md5.Sum([]byte(fmt.Sprintf("%d-%d-%s", p.UserId, p.TokenId, request))) + p.Hash = hex.EncodeToString(hash[:]) +} diff --git a/relay/util/cache_db.go b/relay/util/cache_db.go new file mode 100644 index 00000000..22466449 --- /dev/null +++ b/relay/util/cache_db.go @@ -0,0 +1,47 @@ +package util + +import ( + "errors" + "one-api/common" + "one-api/model" + "time" +) + +type ChatCacheDB struct{} + +func (db *ChatCacheDB) Get(hash string, userId int) *ChatCacheProps { + cache, _ := model.GetChatCache(hash, userId) + if cache == nil { + return nil + } + + props, err := common.UnmarshalString[ChatCacheProps](cache.Data) + if err != nil { + return nil + } + + return &props +} + +func (db *ChatCacheDB) Set(hash string, props *ChatCacheProps, expire int64) error { + return SetCacheDB(hash, props, expire) +} + +func SetCacheDB(hash string, props *ChatCacheProps, expire int64) error { + data := common.Marshal(props) + if data == "" { + return errors.New("marshal error") + } + + expire = expire * 60 + expire += time.Now().Unix() + + cache := &model.ChatCache{ + Hash: hash, + UserId: props.UserId, + Data: data, + Expiration: expire, + } + + return cache.Insert() +} diff --git a/relay/util/cache_redis.go b/relay/util/cache_redis.go new file mode 100644 index 00000000..a9eaa81f --- /dev/null +++ b/relay/util/cache_redis.go @@ -0,0 +1,44 @@ +package util + +import ( + "errors" + "fmt" + "one-api/common" + "time" +) + +type ChatCacheRedis struct{} + +var chatCacheKey = "chat_cache" + +func (r *ChatCacheRedis) Get(hash string, userId int) *ChatCacheProps { + cache, err := common.RedisGet(r.getKey(hash, userId)) + if err != nil { + return nil + } + + props, err := common.UnmarshalString[ChatCacheProps](cache) + if err != nil { + return nil + } + + return &props +} + +func (r *ChatCacheRedis) Set(hash string, props *ChatCacheProps, expire int64) error { + + if !props.Cache { + return nil + } + + data := common.Marshal(&props) + if data == "" { + return errors.New("marshal error") + } + + return common.RedisSet(r.getKey(hash, props.UserId), data, time.Duration(expire)*time.Minute) +} + +func (r *ChatCacheRedis) getKey(hash string, userId int) string { + return fmt.Sprintf("%s:%d:%s", chatCacheKey, userId, hash) +} diff --git a/web/src/views/Setting/component/OperationSetting.js b/web/src/views/Setting/component/OperationSetting.js index 6712e6ca..43b3e32d 100644 --- a/web/src/views/Setting/component/OperationSetting.js +++ b/web/src/views/Setting/component/OperationSetting.js @@ -30,7 +30,9 @@ const OperationSetting = () => { ApproximateTokenEnabled: '', RetryTimes: 0, RetryCooldownSeconds: 0, - MjNotifyEnabled: '' + MjNotifyEnabled: '', + ChatCacheEnabled: '', + ChatCacheExpireMinute: 5 }); const [originInputs, setOriginInputs] = useState({}); let [loading, setLoading] = useState(false); @@ -152,6 +154,11 @@ const OperationSetting = () => { await updateOption('RetryCooldownSeconds', inputs.RetryCooldownSeconds); } break; + case 'other': + if (originInputs['ChatCacheExpireMinute'] !== inputs.ChatCacheExpireMinute) { + await updateOption('ChatCacheExpireMinute', inputs.ChatCacheExpireMinute); + } + break; } showSuccess('保存成功!'); @@ -292,7 +299,34 @@ const OperationSetting = () => { label="Midjourney 允许回调(会泄露服务器ip地址)" control={} /> + } + /> + + + 缓存时间(分钟) + + + + diff --git a/web/src/views/Token/component/EditModal.js b/web/src/views/Token/component/EditModal.js index 1a70ae2a..5ad5a96a 100644 --- a/web/src/views/Token/component/EditModal.js +++ b/web/src/views/Token/component/EditModal.js @@ -17,6 +17,7 @@ import { OutlinedInput, InputAdornment, Switch, + FormControlLabel, FormHelperText } from '@mui/material'; @@ -25,6 +26,7 @@ import { LocalizationProvider } from '@mui/x-date-pickers/LocalizationProvider'; import { DateTimePicker } from '@mui/x-date-pickers/DateTimePicker'; import { renderQuotaWithPrompt, showSuccess, showError } from 'utils/common'; import { API } from 'utils/api'; +import { useSelector } from 'react-redux'; require('dayjs/locale/zh-cn'); const validationSchema = Yup.object().shape({ @@ -40,12 +42,14 @@ const originInputs = { name: '', remain_quota: 0, expired_time: -1, - unlimited_quota: false + unlimited_quota: false, + chat_cache: false }; const EditModal = ({ open, tokenId, onCancel, onOk }) => { const theme = useTheme(); const [inputs, setInputs] = useState(originInputs); + const siteInfo = useSelector((state) => state.siteInfo); const submit = async (values, { setErrors, setStatus, setSubmitting }) => { setSubmitting(true); @@ -163,17 +167,22 @@ const EditModal = ({ open, tokenId, onCancel, onOk }) => { )} )} - { - if (values.expired_time === -1) { - setFieldValue('expired_time', Math.floor(Date.now() / 1000)); - } else { - setFieldValue('expired_time', -1); - } - }} - />{' '} - 永不过期 + { + if (values.expired_time === -1) { + setFieldValue('expired_time', Math.floor(Date.now() / 1000)); + } else { + setFieldValue('expired_time', -1); + } + }} + /> + } + label="永不过期" + /> + 额度 { )} - { - setFieldValue('unlimited_quota', !values.unlimited_quota); - }} - />{' '} - 无限额度 + + { + setFieldValue('unlimited_quota', !values.unlimited_quota); + }} + /> + } + label="无限额度" + /> + {siteInfo.chat_cache_enabled && ( + { + setFieldValue('chat_cache', !values.chat_cache); + }} + /> + } + label="是否开启缓存(开启后,将会缓存聊天记录,以减少消费)" + /> + )} +