✨ feat: add chat cache (#152)
This commit is contained in:
parent
bbaa4eec4b
commit
3c7c13758b
@ -37,6 +37,10 @@ var WeChatAuthEnabled = false
|
||||
var TurnstileCheckEnabled = false
|
||||
var RegisterEnabled = true
|
||||
|
||||
// chat cache
|
||||
var ChatCacheEnabled = false
|
||||
var ChatCacheExpireMinute = 5 // 5 Minute
|
||||
|
||||
// mj
|
||||
var MjNotifyEnabled = false
|
||||
|
||||
|
@ -1,6 +1,7 @@
|
||||
package common
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"html/template"
|
||||
"log"
|
||||
@ -248,3 +249,16 @@ func EscapeMarkdownText(text string) string {
|
||||
}
|
||||
return text
|
||||
}
|
||||
|
||||
// UnmarshalString decodes the JSON document held in data into a value of
// type T.
//
// On success it returns the decoded value and a nil error. On failure it
// returns the zero value of T together with the decoding error, so callers
// never observe a partially populated result.
func UnmarshalString[T any](data string) (form T, err error) {
	var decoded T
	if err = json.Unmarshal([]byte(data), &decoded); err != nil {
		// form is still the zero value here; decoded may be half-filled.
		return form, err
	}
	return decoded, nil
}
|
||||
|
||||
// Marshal serializes data to its JSON text. A value that cannot be encoded
// yields the empty string instead of an error.
func Marshal[T any](data T) string {
	if encoded, err := json.Marshal(data); err == nil {
		return string(encoded)
	}
	return ""
}
|
||||
|
@ -42,6 +42,7 @@ func GetStatus(c *gin.Context) {
|
||||
"display_in_currency": common.DisplayInCurrencyEnabled,
|
||||
"telegram_bot": telegram_bot,
|
||||
"mj_notify_enabled": common.MjNotifyEnabled,
|
||||
"chat_cache_enabled": common.ChatCacheEnabled,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
@ -104,6 +104,7 @@ func AddToken(c *gin.Context) {
|
||||
ExpiredTime: token.ExpiredTime,
|
||||
RemainQuota: token.RemainQuota,
|
||||
UnlimitedQuota: token.UnlimitedQuota,
|
||||
ChatCache: token.ChatCache,
|
||||
}
|
||||
err = cleanToken.Insert()
|
||||
if err != nil {
|
||||
@ -187,6 +188,7 @@ func UpdateToken(c *gin.Context) {
|
||||
cleanToken.ExpiredTime = token.ExpiredTime
|
||||
cleanToken.RemainQuota = token.RemainQuota
|
||||
cleanToken.UnlimitedQuota = token.UnlimitedQuota
|
||||
cleanToken.ChatCache = token.ChatCache
|
||||
}
|
||||
err = cleanToken.Update()
|
||||
if err != nil {
|
||||
|
37
cron/main.go
Normal file
37
cron/main.go
Normal file
@ -0,0 +1,37 @@
|
||||
package cron
|
||||
|
||||
import (
|
||||
"one-api/common"
|
||||
"one-api/model"
|
||||
"time"
|
||||
|
||||
"github.com/go-co-op/gocron/v2"
|
||||
)
|
||||
|
||||
func InitCron() {
|
||||
scheduler, err := gocron.NewScheduler()
|
||||
if err != nil {
|
||||
common.SysLog("Cron scheduler error: " + err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// 添加删除cache的任务
|
||||
_, err = scheduler.NewJob(
|
||||
gocron.DailyJob(
|
||||
1,
|
||||
gocron.NewAtTimes(
|
||||
gocron.NewAtTime(0, 5, 0),
|
||||
)),
|
||||
gocron.NewTask(func() {
|
||||
model.RemoveChatCache(time.Now().Unix())
|
||||
common.SysLog("删除过期缓存数据")
|
||||
}),
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
common.SysLog("Cron job error: " + err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
scheduler.Start()
|
||||
}
|
5
go.mod
5
go.mod
@ -30,11 +30,14 @@ require (
|
||||
filippo.io/edwards25519 v1.1.0 // indirect
|
||||
github.com/chenzhuoyu/iasm v0.9.1 // indirect
|
||||
github.com/fsnotify/fsnotify v1.7.0 // indirect
|
||||
github.com/go-co-op/gocron/v2 v2.2.9 // indirect
|
||||
github.com/gomarkdown/markdown v0.0.0-20240328165702-4d01890c35c0 // indirect
|
||||
github.com/hashicorp/hcl v1.0.0 // indirect
|
||||
github.com/jackc/puddle/v2 v2.2.1 // indirect
|
||||
github.com/jonboulle/clockwork v0.4.0 // indirect
|
||||
github.com/magiconair/properties v1.8.7 // indirect
|
||||
github.com/mitchellh/mapstructure v1.5.0 // indirect
|
||||
github.com/robfig/cron/v3 v3.0.1 // indirect
|
||||
github.com/sagikazarmark/locafero v0.4.0 // indirect
|
||||
github.com/sagikazarmark/slog-shim v0.1.0 // indirect
|
||||
github.com/sourcegraph/conc v0.3.0 // indirect
|
||||
@ -46,7 +49,7 @@ require (
|
||||
github.com/wneessen/go-mail v0.4.1 // indirect
|
||||
go.uber.org/atomic v1.9.0 // indirect
|
||||
go.uber.org/multierr v1.9.0 // indirect
|
||||
golang.org/x/exp v0.0.0-20230905200255-921286631fa9 // indirect
|
||||
golang.org/x/exp v0.0.0-20231219180239-dc181d75b848 // indirect
|
||||
golang.org/x/sync v0.6.0 // indirect
|
||||
gopkg.in/ini.v1 v1.67.0 // indirect
|
||||
)
|
||||
|
8
go.sum
8
go.sum
@ -71,6 +71,8 @@ github.com/gin-gonic/gin v1.8.1/go.mod h1:ji8BvRH1azfM+SYow9zQ6SZMvR8qOMZHmsCuWR
|
||||
github.com/gin-gonic/gin v1.9.1 h1:4idEAncQnU5cB7BeOkPtxjfCSye0AAm1R0RVIqJ+Jmg=
|
||||
github.com/gin-gonic/gin v1.9.1/go.mod h1:hPrL7YrpYKXt5YId3A/Tnip5kqbEAP+KLuI3SUcPTeU=
|
||||
github.com/globalsign/mgo v0.0.0-20181015135952-eeefdecb41b8/go.mod h1:xkRDCp4j0OGD1HRkm4kmhM+pmpv3AKq5SU7GMg4oO/Q=
|
||||
github.com/go-co-op/gocron/v2 v2.2.9 h1:aoKosYWSSdXFLecjFWX1i8+R6V7XdZb8sB2ZKAY5Yis=
|
||||
github.com/go-co-op/gocron/v2 v2.2.9/go.mod h1:mZx3gMSlFnb97k3hRqX3+GdlG3+DUwTh6B8fnsTScXg=
|
||||
github.com/go-playground/assert/v2 v2.0.1/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4=
|
||||
github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s=
|
||||
github.com/go-playground/locales v0.13.0/go.mod h1:taPMhCMXrRLJO55olJkUXHZBHCxTMfnGwq/HNwmWNS8=
|
||||
@ -152,6 +154,8 @@ github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ=
|
||||
github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8=
|
||||
github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0=
|
||||
github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4=
|
||||
github.com/jonboulle/clockwork v0.4.0 h1:p4Cf1aMWXnXAUh8lVfewRBx1zaTSYKrKMF2g3ST4RZ4=
|
||||
github.com/jonboulle/clockwork v0.4.0/go.mod h1:xgRqUGwRcjKCO1vbZUEtSLrqKoPSsUpK7fnezOII0kc=
|
||||
github.com/json-iterator/go v1.1.9/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4=
|
||||
github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=
|
||||
github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo=
|
||||
@ -219,6 +223,8 @@ github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZN
|
||||
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U=
|
||||
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/quasoft/memstore v0.0.0-20191010062613-2bce066d2b0b/go.mod h1:wTPjTepVu7uJBYgZ0SdWHQlIas582j6cn2jgk4DDdlg=
|
||||
github.com/robfig/cron/v3 v3.0.1 h1:WdRxkvbJztn8LMz/QEvLN5sBU+xKpSqwwUO1Pjr4qDs=
|
||||
github.com/robfig/cron/v3 v3.0.1/go.mod h1:eQICP3HwyT7UooqI/z+Ov+PtYAWygg1TEWWzGIFLtro=
|
||||
github.com/rogpeppe/go-internal v1.6.1/go.mod h1:xXDCJY+GAPziupqXw64V24skbSoqbTEfhy4qGm1nDQc=
|
||||
github.com/rogpeppe/go-internal v1.8.0 h1:FCbCCtXNOY3UtUuHUYaghJg4y7Fd14rXifAYUAtL9R8=
|
||||
github.com/rogpeppe/go-internal v1.8.0/go.mod h1:WmiCO8CzOY8rg0OYDC4/i/2WRWAB6poM+XZ2dLUbcbE=
|
||||
@ -284,6 +290,8 @@ golang.org/x/crypto v0.21.0 h1:X31++rzVUdKhX5sWmSOFZxx8UW/ldWx55cbf08iNAMA=
|
||||
golang.org/x/crypto v0.21.0/go.mod h1:0BP7YvVV9gBbVKyeTG0Gyn+gZm94bibOW5BjDEYAOMs=
|
||||
golang.org/x/exp v0.0.0-20230905200255-921286631fa9 h1:GoHiUyI/Tp2nVkLI2mCxVkOjsbSXD66ic0XW0js0R9g=
|
||||
golang.org/x/exp v0.0.0-20230905200255-921286631fa9/go.mod h1:S2oDrQGGwySpoQPVqRShND87VCbxmc6bL1Yd2oYrm6k=
|
||||
golang.org/x/exp v0.0.0-20231219180239-dc181d75b848 h1:+iq7lrkxmFNBM7xx+Rae2W6uyPfhPeDWD+n+JgppptE=
|
||||
golang.org/x/exp v0.0.0-20231219180239-dc181d75b848/go.mod h1:iRJReGqOEeBhDZGkGbynYwcHlctCvnjTYIamk7uXpHI=
|
||||
golang.org/x/image v0.14.0 h1:tNgSxAFe3jC4uYqvZdTr84SZoM1KfwdC9SKIFrLjFn4=
|
||||
golang.org/x/image v0.14.0/go.mod h1:HUYqC05R2ZcZ3ejNQsIHQDQiwWM4JBqmm6MKANTp4LE=
|
||||
golang.org/x/image v0.15.0 h1:kOELfmgrmJlw4Cdb7g/QGuB3CvDrXbqEIww/pNtNBm8=
|
||||
|
2
main.go
2
main.go
@ -9,6 +9,7 @@ import (
|
||||
"one-api/common/requester"
|
||||
"one-api/common/telegram"
|
||||
"one-api/controller"
|
||||
"one-api/cron"
|
||||
"one-api/middleware"
|
||||
"one-api/model"
|
||||
"one-api/relay/util"
|
||||
@ -48,6 +49,7 @@ func main() {
|
||||
|
||||
controller.InitMidjourneyTask()
|
||||
notify.InitNotifier()
|
||||
cron.InitCron()
|
||||
|
||||
initHttpServer()
|
||||
}
|
||||
|
@ -105,6 +105,7 @@ func tokenAuth(c *gin.Context, key string) {
|
||||
c.Set("id", token.UserId)
|
||||
c.Set("token_id", token.Id)
|
||||
c.Set("token_name", token.Name)
|
||||
c.Set("chat_cache", token.ChatCache)
|
||||
if len(parts) > 1 {
|
||||
if model.IsAdmin(token.UserId) {
|
||||
channelId := common.String2Int(parts[1])
|
||||
|
41
model/chat_cache.go
Normal file
41
model/chat_cache.go
Normal file
@ -0,0 +1,41 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm/clause"
|
||||
)
|
||||
|
||||
// ChatCache is the database row that stores one cached chat response.
type ChatCache struct {
|
||||
// Hash is the primary key identifying the cached request (at most 32 characters).
Hash string `json:"hash" gorm:"type:varchar(32);primaryKey"`
|
||||
// UserId is the id of the user the cached entry belongs to.
UserId int `json:"user_id" gorm:"type:int;not null;index"`
|
||||
// Data is the JSON payload stored for this entry.
Data string `json:"data" gorm:"type:json;not null"`
|
||||
// Expiration is compared against time.Now().Unix() by the queries in this file.
Expiration int64 `json:"expiration" gorm:"type:bigint;not null;index"`
|
||||
}
|
||||
|
||||
func (cache *ChatCache) Insert() error {
|
||||
return DB.Clauses(clause.OnConflict{
|
||||
UpdateAll: true,
|
||||
}).Create(cache).Error
|
||||
}
|
||||
|
||||
func GetChatCache(hash string, userId int) (*ChatCache, error) {
|
||||
var chatCache ChatCache
|
||||
// 获取当前时间戳
|
||||
now := time.Now().Unix()
|
||||
err := DB.Where("hash = ? and user_id = ? and expiration > ?", hash, userId, now).Find(&chatCache).Error
|
||||
return &chatCache, err
|
||||
}
|
||||
|
||||
func GetChatCacheListByUserId(userId int) ([]*ChatCache, error) {
|
||||
var chatCaches []*ChatCache
|
||||
// 获取当前时间戳
|
||||
now := time.Now().Unix()
|
||||
err := DB.Where("user_id = ? and expiration >", userId, now).Find(&chatCaches).Error
|
||||
return chatCaches, err
|
||||
}
|
||||
|
||||
func RemoveChatCache(expiration int64) error {
|
||||
now := time.Now().Unix()
|
||||
return DB.Where("expiration < ?", now).Delete(ChatCache{}).Error
|
||||
}
|
@ -144,6 +144,10 @@ func InitDB() (err error) {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = db.AutoMigrate(&ChatCache{})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
common.SysLog("database migrated")
|
||||
err = createRootAccountIfNeed()
|
||||
return err
|
||||
|
@ -76,6 +76,9 @@ func InitOptionMap() {
|
||||
|
||||
common.OptionMap["MjNotifyEnabled"] = strconv.FormatBool(common.MjNotifyEnabled)
|
||||
|
||||
common.OptionMap["ChatCacheEnabled"] = strconv.FormatBool(common.ChatCacheEnabled)
|
||||
common.OptionMap["ChatCacheExpireMinute"] = strconv.Itoa(common.ChatCacheExpireMinute)
|
||||
|
||||
common.OptionMapRWMutex.Unlock()
|
||||
loadOptionsFromDatabase()
|
||||
}
|
||||
@ -115,14 +118,15 @@ func UpdateOption(key string, value string) error {
|
||||
}
|
||||
|
||||
var optionIntMap = map[string]*int{
|
||||
"SMTPPort": &common.SMTPPort,
|
||||
"QuotaForNewUser": &common.QuotaForNewUser,
|
||||
"QuotaForInviter": &common.QuotaForInviter,
|
||||
"QuotaForInvitee": &common.QuotaForInvitee,
|
||||
"QuotaRemindThreshold": &common.QuotaRemindThreshold,
|
||||
"PreConsumedQuota": &common.PreConsumedQuota,
|
||||
"RetryTimes": &common.RetryTimes,
|
||||
"RetryCooldownSeconds": &common.RetryCooldownSeconds,
|
||||
"SMTPPort": &common.SMTPPort,
|
||||
"QuotaForNewUser": &common.QuotaForNewUser,
|
||||
"QuotaForInviter": &common.QuotaForInviter,
|
||||
"QuotaForInvitee": &common.QuotaForInvitee,
|
||||
"QuotaRemindThreshold": &common.QuotaRemindThreshold,
|
||||
"PreConsumedQuota": &common.PreConsumedQuota,
|
||||
"RetryTimes": &common.RetryTimes,
|
||||
"RetryCooldownSeconds": &common.RetryCooldownSeconds,
|
||||
"ChatCacheExpireMinute": &common.ChatCacheExpireMinute,
|
||||
}
|
||||
|
||||
var optionBoolMap = map[string]*bool{
|
||||
@ -141,6 +145,7 @@ var optionBoolMap = map[string]*bool{
|
||||
"DisplayInCurrencyEnabled": &common.DisplayInCurrencyEnabled,
|
||||
"DisplayTokenStatEnabled": &common.DisplayTokenStatEnabled,
|
||||
"MjNotifyEnabled": &common.MjNotifyEnabled,
|
||||
"ChatCacheEnabled": &common.ChatCacheEnabled,
|
||||
}
|
||||
|
||||
var optionStringMap = map[string]*string{
|
||||
|
@ -2,6 +2,7 @@ package model
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"one-api/common"
|
||||
"one-api/common/stmp"
|
||||
|
||||
@ -20,6 +21,7 @@ type Token struct {
|
||||
RemainQuota int `json:"remain_quota" gorm:"default:0"`
|
||||
UnlimitedQuota bool `json:"unlimited_quota" gorm:"default:false"`
|
||||
UsedQuota int `json:"used_quota" gorm:"default:0"` // used quota
|
||||
ChatCache bool `json:"chat_cache" gorm:"default:false"`
|
||||
}
|
||||
|
||||
var allowedTokenOrderFields = map[string]bool{
|
||||
@ -40,7 +42,7 @@ func GetUserTokensList(userId int, params *GenericParams) (*DataResult[Token], e
|
||||
db = db.Where("name LIKE ?", params.Keyword+"%")
|
||||
}
|
||||
|
||||
return PaginateAndOrder[Token](db, &params.PaginationParams, &tokens, allowedTokenOrderFields)
|
||||
return PaginateAndOrder(db, &params.PaginationParams, &tokens, allowedTokenOrderFields)
|
||||
}
|
||||
|
||||
// 获取状态为可用的令牌
|
||||
@ -114,13 +116,26 @@ func GetTokenById(id int) (*Token, error) {
|
||||
}
|
||||
|
||||
func (token *Token) Insert() error {
|
||||
if token.ChatCache && !common.ChatCacheEnabled {
|
||||
token.ChatCache = false
|
||||
}
|
||||
|
||||
err := DB.Create(token).Error
|
||||
return err
|
||||
}
|
||||
|
||||
// Update Make sure your token's fields is completed, because this will update non-zero values
|
||||
func (token *Token) Update() error {
|
||||
err := DB.Model(token).Select("name", "status", "expired_time", "remain_quota", "unlimited_quota").Updates(token).Error
|
||||
if token.ChatCache && !common.ChatCacheEnabled {
|
||||
token.ChatCache = false
|
||||
}
|
||||
|
||||
err := DB.Model(token).Select("name", "status", "expired_time", "remain_quota", "unlimited_quota", "chat_cache").Updates(token).Error
|
||||
// 防止Redis缓存不生效,直接删除
|
||||
if err == nil && common.RedisEnabled {
|
||||
common.RedisDel(fmt.Sprintf("token:%s", token.Key))
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
|
@ -1,6 +1,7 @@
|
||||
package relay
|
||||
|
||||
import (
|
||||
"one-api/relay/util"
|
||||
"one-api/types"
|
||||
|
||||
providersBase "one-api/providers/base"
|
||||
@ -13,17 +14,33 @@ type relayBase struct {
|
||||
provider providersBase.ProviderInterface
|
||||
originalModel string
|
||||
modelName string
|
||||
cache *util.ChatCacheProps
|
||||
}
|
||||
|
||||
type RelayBaseInterface interface {
|
||||
send() (err *types.OpenAIErrorWithStatusCode, done bool)
|
||||
getPromptTokens() (int, error)
|
||||
setRequest() error
|
||||
getRequest() any
|
||||
setProvider(modelName string) error
|
||||
getProvider() providersBase.ProviderInterface
|
||||
getOriginalModel() string
|
||||
getModelName() string
|
||||
getContext() *gin.Context
|
||||
SetChatCache(allow bool)
|
||||
GetChatCache() *util.ChatCacheProps
|
||||
}
|
||||
|
||||
func (r *relayBase) SetChatCache(allow bool) {
|
||||
r.cache = util.NewChatCacheProps(r.c, allow)
|
||||
}
|
||||
|
||||
func (r *relayBase) GetChatCache() *util.ChatCacheProps {
|
||||
return r.cache
|
||||
}
|
||||
|
||||
func (r *relayBase) getRequest() interface{} {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *relayBase) setProvider(modelName string) error {
|
||||
|
@ -37,6 +37,10 @@ func (r *relayChat) setRequest() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *relayChat) getRequest() interface{} {
|
||||
return &r.chatRequest
|
||||
}
|
||||
|
||||
func (r *relayChat) getPromptTokens() (int, error) {
|
||||
return common.CountTokenMessages(r.chatRequest.Messages, r.modelName), nil
|
||||
}
|
||||
@ -58,7 +62,7 @@ func (r *relayChat) send() (err *types.OpenAIErrorWithStatusCode, done bool) {
|
||||
return
|
||||
}
|
||||
|
||||
err = responseStreamClient(r.c, response)
|
||||
err = responseStreamClient(r.c, response, r.cache)
|
||||
} else {
|
||||
var response *types.ChatCompletionResponse
|
||||
response, err = chatProvider.CreateChatCompletion(&r.chatRequest)
|
||||
@ -66,6 +70,7 @@ func (r *relayChat) send() (err *types.OpenAIErrorWithStatusCode, done bool) {
|
||||
return
|
||||
}
|
||||
err = responseJsonClient(r.c, response)
|
||||
r.cache.SetResponse(response)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
|
@ -13,6 +13,7 @@ import (
|
||||
"one-api/model"
|
||||
"one-api/providers"
|
||||
providersBase "one-api/providers/base"
|
||||
"one-api/relay/util"
|
||||
"one-api/types"
|
||||
"strings"
|
||||
|
||||
@ -20,29 +21,37 @@ import (
|
||||
)
|
||||
|
||||
func Path2Relay(c *gin.Context, path string) RelayBaseInterface {
|
||||
allowCache := false
|
||||
var relay RelayBaseInterface
|
||||
if strings.HasPrefix(path, "/v1/chat/completions") {
|
||||
return NewRelayChat(c)
|
||||
allowCache = true
|
||||
relay = NewRelayChat(c)
|
||||
} else if strings.HasPrefix(path, "/v1/completions") {
|
||||
return NewRelayCompletions(c)
|
||||
allowCache = true
|
||||
relay = NewRelayCompletions(c)
|
||||
} else if strings.HasPrefix(path, "/v1/embeddings") {
|
||||
return NewRelayEmbeddings(c)
|
||||
relay = NewRelayEmbeddings(c)
|
||||
} else if strings.HasPrefix(path, "/v1/moderations") {
|
||||
return NewRelayModerations(c)
|
||||
relay = NewRelayModerations(c)
|
||||
} else if strings.HasPrefix(path, "/v1/images/generations") {
|
||||
return NewRelayImageGenerations(c)
|
||||
relay = NewRelayImageGenerations(c)
|
||||
} else if strings.HasPrefix(path, "/v1/images/edits") {
|
||||
return NewRelayImageEdits(c)
|
||||
relay = NewRelayImageEdits(c)
|
||||
} else if strings.HasPrefix(path, "/v1/images/variations") {
|
||||
return NewRelayImageVariations(c)
|
||||
relay = NewRelayImageVariations(c)
|
||||
} else if strings.HasPrefix(path, "/v1/audio/speech") {
|
||||
return NewRelaySpeech(c)
|
||||
relay = NewRelaySpeech(c)
|
||||
} else if strings.HasPrefix(path, "/v1/audio/transcriptions") {
|
||||
return NewRelayTranscriptions(c)
|
||||
relay = NewRelayTranscriptions(c)
|
||||
} else if strings.HasPrefix(path, "/v1/audio/translations") {
|
||||
return NewRelayTranslations(c)
|
||||
relay = NewRelayTranslations(c)
|
||||
}
|
||||
|
||||
return nil
|
||||
if relay != nil {
|
||||
relay.SetChatCache(allowCache)
|
||||
}
|
||||
|
||||
return relay
|
||||
}
|
||||
|
||||
func GetProvider(c *gin.Context, modeName string) (provider providersBase.ProviderInterface, newModelName string, fail error) {
|
||||
@ -120,7 +129,7 @@ func responseJsonClient(c *gin.Context, data interface{}) *types.OpenAIErrorWith
|
||||
return nil
|
||||
}
|
||||
|
||||
func responseStreamClient(c *gin.Context, stream requester.StreamReaderInterface[string]) *types.OpenAIErrorWithStatusCode {
|
||||
func responseStreamClient(c *gin.Context, stream requester.StreamReaderInterface[string], cache *util.ChatCacheProps) (errWithOP *types.OpenAIErrorWithStatusCode) {
|
||||
requester.SetEventStreamHeaders(c)
|
||||
dataChan, errChan := stream.Recv()
|
||||
|
||||
@ -128,19 +137,24 @@ func responseStreamClient(c *gin.Context, stream requester.StreamReaderInterface
|
||||
c.Stream(func(w io.Writer) bool {
|
||||
select {
|
||||
case data := <-dataChan:
|
||||
fmt.Fprintln(w, "data: "+data+"\n")
|
||||
streamData := "data: " + data + "\n\n"
|
||||
fmt.Fprint(w, streamData)
|
||||
cache.SetResponse(streamData)
|
||||
return true
|
||||
case err := <-errChan:
|
||||
if !errors.Is(err, io.EOF) {
|
||||
fmt.Fprintln(w, "data: "+err.Error()+"\n")
|
||||
fmt.Fprint(w, "data: "+err.Error()+"\n\n")
|
||||
errWithOP = common.ErrorWrapper(err, "stream_error", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
fmt.Fprintln(w, "data: [DONE]")
|
||||
streamData := "data: [DONE]\n"
|
||||
fmt.Fprint(w, streamData)
|
||||
cache.SetResponse(streamData)
|
||||
return false
|
||||
}
|
||||
})
|
||||
|
||||
return nil
|
||||
return errWithOP
|
||||
}
|
||||
|
||||
func responseMultipart(c *gin.Context, resp *http.Response) *types.OpenAIErrorWithStatusCode {
|
||||
@ -174,6 +188,22 @@ func responseCustom(c *gin.Context, response *types.AudioResponseWrapper) *types
|
||||
return nil
|
||||
}
|
||||
|
||||
func responseCache(c *gin.Context, response string) {
|
||||
// 检查是否是 data: 开头的流式数据
|
||||
isStream := strings.HasPrefix(response, "data: ")
|
||||
|
||||
if isStream {
|
||||
requester.SetEventStreamHeaders(c)
|
||||
c.Stream(func(w io.Writer) bool {
|
||||
fmt.Fprint(w, response)
|
||||
return false
|
||||
})
|
||||
} else {
|
||||
c.Data(http.StatusOK, "application/json", []byte(response))
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func shouldRetry(c *gin.Context, statusCode int) bool {
|
||||
channelId := c.GetInt("specific_channel_id")
|
||||
if channelId > 0 {
|
||||
|
@ -37,6 +37,10 @@ func (r *relayCompletions) setRequest() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *relayCompletions) getRequest() interface{} {
|
||||
return &r.request
|
||||
}
|
||||
|
||||
func (r *relayCompletions) getPromptTokens() (int, error) {
|
||||
return common.CountTokenInput(r.request.Prompt, r.modelName), nil
|
||||
}
|
||||
@ -58,7 +62,7 @@ func (r *relayCompletions) send() (err *types.OpenAIErrorWithStatusCode, done bo
|
||||
return
|
||||
}
|
||||
|
||||
err = responseStreamClient(r.c, response)
|
||||
err = responseStreamClient(r.c, response, r.cache)
|
||||
} else {
|
||||
var response *types.CompletionResponse
|
||||
response, err = provider.CreateCompletion(&r.request)
|
||||
@ -66,6 +70,7 @@ func (r *relayCompletions) send() (err *types.OpenAIErrorWithStatusCode, done bo
|
||||
return
|
||||
}
|
||||
err = responseJsonClient(r.c, response)
|
||||
r.cache.SetResponse(response)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
|
@ -7,6 +7,7 @@ import (
|
||||
"one-api/model"
|
||||
"one-api/relay/util"
|
||||
"one-api/types"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
@ -23,6 +24,18 @@ func Relay(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
cacheProps := relay.GetChatCache()
|
||||
cacheProps.SetHash(relay.getRequest())
|
||||
|
||||
// 获取缓存
|
||||
cache := cacheProps.GetCache()
|
||||
|
||||
if cache != nil {
|
||||
// 说明有缓存, 直接返回缓存内容
|
||||
cacheProcessing(c, cache)
|
||||
return
|
||||
}
|
||||
|
||||
if err := relay.setProvider(relay.getOriginalModel()); err != nil {
|
||||
common.AbortWithMessage(c, http.StatusServiceUnavailable, err.Error())
|
||||
return
|
||||
@ -103,5 +116,25 @@ func RelayHandler(relay RelayBaseInterface) (err *types.OpenAIErrorWithStatusCod
|
||||
}
|
||||
|
||||
quota.Consume(relay.getContext(), usage)
|
||||
cacheProps := relay.GetChatCache()
|
||||
go cacheProps.StoreCache(relay.getContext().GetInt("channel_id"), usage.PromptTokens, usage.CompletionTokens, relay.getModelName())
|
||||
return
|
||||
}
|
||||
|
||||
func cacheProcessing(c *gin.Context, cacheProps *util.ChatCacheProps) {
|
||||
responseCache(c, cacheProps.Response)
|
||||
|
||||
// 写入日志
|
||||
tokenName := c.GetString("token_name")
|
||||
|
||||
requestTime := 0
|
||||
requestStartTimeValue := c.Request.Context().Value("requestStartTime")
|
||||
if requestStartTimeValue != nil {
|
||||
requestStartTime, ok := requestStartTimeValue.(time.Time)
|
||||
if ok {
|
||||
requestTime = int(time.Since(requestStartTime).Milliseconds())
|
||||
}
|
||||
}
|
||||
|
||||
model.RecordConsumeLog(c.Request.Context(), cacheProps.UserId, cacheProps.ChannelID, cacheProps.PromptTokens, cacheProps.CompletionTokens, cacheProps.ModelName, tokenName, 0, "缓存", requestTime)
|
||||
}
|
||||
|
128
relay/util/cache.go
Normal file
128
relay/util/cache.go
Normal file
@ -0,0 +1,128 @@
|
||||
package util
|
||||
|
||||
import (
|
||||
"crypto/md5"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"one-api/common"
|
||||
"one-api/model"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// ChatCacheProps carries the cache state of one relayed request: the values
// that are serialized into the cache entry, plus runtime-only fields
// (tagged json:"-") that are never serialized.
type ChatCacheProps struct {
|
||||
// UserId is read from the "id" key of the gin context.
UserId int `json:"user_id"`
|
||||
// TokenId is read from the "token_id" key of the gin context.
TokenId int `json:"token_id"`
|
||||
// ChannelID, PromptTokens, CompletionTokens and ModelName are filled in by StoreCache.
ChannelID int `json:"channel_id"`
|
||||
PromptTokens int `json:"prompt_tokens"`
|
||||
CompletionTokens int `json:"completion_tokens"`
|
||||
ModelName string `json:"model_name"`
|
||||
// Response is the recorded response body accumulated by SetResponse.
Response string `json:"response"`
|
||||
|
||||
// Hash is the cache key computed by hash.
Hash string `json:"-"`
|
||||
// Cache reports whether caching is active for this request.
Cache bool `json:"-"`
|
||||
// Driver is the storage backend used by GetCache and StoreCache.
Driver CacheDriver `json:"-"`
|
||||
}
|
||||
|
||||
// CacheDriver is the storage backend for chat-cache entries. This package
// provides ChatCacheDB and ChatCacheRedis as implementations.
type CacheDriver interface {
|
||||
// Get returns the cached entry for hash and userId, or nil when none is available.
Get(hash string, userId int) *ChatCacheProps
|
||||
// Set stores props under hash; the implementations in this package treat expire as minutes.
Set(hash string, props *ChatCacheProps, expire int64) error
|
||||
}
|
||||
|
||||
func GetDebugList(userId int) ([]*ChatCacheProps, error) {
|
||||
caches, err := model.GetChatCacheListByUserId(userId)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var props []*ChatCacheProps
|
||||
for _, cache := range caches {
|
||||
prop, err := common.UnmarshalString[ChatCacheProps](cache.Data)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
props = append(props, &prop)
|
||||
}
|
||||
|
||||
return props, nil
|
||||
}
|
||||
|
||||
func NewChatCacheProps(c *gin.Context, allow bool) *ChatCacheProps {
|
||||
props := &ChatCacheProps{
|
||||
Cache: false,
|
||||
}
|
||||
|
||||
if !allow {
|
||||
return props
|
||||
}
|
||||
|
||||
if common.ChatCacheEnabled && c.GetBool("chat_cache") {
|
||||
props.Cache = true
|
||||
}
|
||||
|
||||
if common.RedisEnabled {
|
||||
props.Driver = &ChatCacheRedis{}
|
||||
} else {
|
||||
props.Driver = &ChatCacheDB{}
|
||||
}
|
||||
|
||||
props.UserId = c.GetInt("id")
|
||||
props.TokenId = c.GetInt("token_id")
|
||||
|
||||
return props
|
||||
}
|
||||
|
||||
func (p *ChatCacheProps) SetHash(request any) {
|
||||
if !p.needCache() || request == nil {
|
||||
return
|
||||
}
|
||||
|
||||
p.hash(common.Marshal(request))
|
||||
}
|
||||
|
||||
func (p *ChatCacheProps) SetResponse(response any) {
|
||||
if !p.needCache() || response == nil {
|
||||
return
|
||||
}
|
||||
|
||||
if str, ok := response.(string); ok {
|
||||
p.Response += str
|
||||
return
|
||||
}
|
||||
|
||||
p.Response = common.Marshal(response)
|
||||
}
|
||||
|
||||
func (p *ChatCacheProps) StoreCache(channelId, promptTokens, completionTokens int, modelName string) error {
|
||||
if !p.needCache() || p.Response == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
p.ChannelID = channelId
|
||||
p.PromptTokens = promptTokens
|
||||
p.CompletionTokens = completionTokens
|
||||
p.ModelName = modelName
|
||||
|
||||
return p.Driver.Set(p.getHash(), p, int64(common.ChatCacheExpireMinute))
|
||||
}
|
||||
|
||||
func (p *ChatCacheProps) GetCache() *ChatCacheProps {
|
||||
if !p.needCache() {
|
||||
return nil
|
||||
}
|
||||
|
||||
return p.Driver.Get(p.getHash(), p.UserId)
|
||||
}
|
||||
|
||||
func (p *ChatCacheProps) needCache() bool {
|
||||
return common.ChatCacheEnabled && p.Cache
|
||||
}
|
||||
|
||||
// getHash returns the cache key previously stored in p.Hash.
func (p *ChatCacheProps) getHash() string {
|
||||
return p.Hash
|
||||
}
|
||||
|
||||
func (p *ChatCacheProps) hash(request string) {
|
||||
hash := md5.Sum([]byte(fmt.Sprintf("%d-%d-%s", p.UserId, p.TokenId, request)))
|
||||
p.Hash = hex.EncodeToString(hash[:])
|
||||
}
|
47
relay/util/cache_db.go
Normal file
47
relay/util/cache_db.go
Normal file
@ -0,0 +1,47 @@
|
||||
package util
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"one-api/common"
|
||||
"one-api/model"
|
||||
"time"
|
||||
)
|
||||
|
||||
type ChatCacheDB struct{}
|
||||
|
||||
func (db *ChatCacheDB) Get(hash string, userId int) *ChatCacheProps {
|
||||
cache, _ := model.GetChatCache(hash, userId)
|
||||
if cache == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
props, err := common.UnmarshalString[ChatCacheProps](cache.Data)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return &props
|
||||
}
|
||||
|
||||
// Set stores props under hash by delegating to SetCacheDB.
func (db *ChatCacheDB) Set(hash string, props *ChatCacheProps, expire int64) error {
|
||||
return SetCacheDB(hash, props, expire)
|
||||
}
|
||||
|
||||
func SetCacheDB(hash string, props *ChatCacheProps, expire int64) error {
|
||||
data := common.Marshal(props)
|
||||
if data == "" {
|
||||
return errors.New("marshal error")
|
||||
}
|
||||
|
||||
expire = expire * 60
|
||||
expire += time.Now().Unix()
|
||||
|
||||
cache := &model.ChatCache{
|
||||
Hash: hash,
|
||||
UserId: props.UserId,
|
||||
Data: data,
|
||||
Expiration: expire,
|
||||
}
|
||||
|
||||
return cache.Insert()
|
||||
}
|
44
relay/util/cache_redis.go
Normal file
44
relay/util/cache_redis.go
Normal file
@ -0,0 +1,44 @@
|
||||
package util
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"one-api/common"
|
||||
"time"
|
||||
)
|
||||
|
||||
type ChatCacheRedis struct{}
|
||||
|
||||
var chatCacheKey = "chat_cache"
|
||||
|
||||
func (r *ChatCacheRedis) Get(hash string, userId int) *ChatCacheProps {
|
||||
cache, err := common.RedisGet(r.getKey(hash, userId))
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
props, err := common.UnmarshalString[ChatCacheProps](cache)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return &props
|
||||
}
|
||||
|
||||
func (r *ChatCacheRedis) Set(hash string, props *ChatCacheProps, expire int64) error {
|
||||
|
||||
if !props.Cache {
|
||||
return nil
|
||||
}
|
||||
|
||||
data := common.Marshal(&props)
|
||||
if data == "" {
|
||||
return errors.New("marshal error")
|
||||
}
|
||||
|
||||
return common.RedisSet(r.getKey(hash, props.UserId), data, time.Duration(expire)*time.Minute)
|
||||
}
|
||||
|
||||
// getKey builds the Redis key for an entry as "<chatCacheKey>:<userId>:<hash>".
func (r *ChatCacheRedis) getKey(hash string, userId int) string {
|
||||
return fmt.Sprintf("%s:%d:%s", chatCacheKey, userId, hash)
|
||||
}
|
@ -30,7 +30,9 @@ const OperationSetting = () => {
|
||||
ApproximateTokenEnabled: '',
|
||||
RetryTimes: 0,
|
||||
RetryCooldownSeconds: 0,
|
||||
MjNotifyEnabled: ''
|
||||
MjNotifyEnabled: '',
|
||||
ChatCacheEnabled: '',
|
||||
ChatCacheExpireMinute: 5
|
||||
});
|
||||
const [originInputs, setOriginInputs] = useState({});
|
||||
let [loading, setLoading] = useState(false);
|
||||
@ -152,6 +154,11 @@ const OperationSetting = () => {
|
||||
await updateOption('RetryCooldownSeconds', inputs.RetryCooldownSeconds);
|
||||
}
|
||||
break;
|
||||
case 'other':
|
||||
if (originInputs['ChatCacheExpireMinute'] !== inputs.ChatCacheExpireMinute) {
|
||||
await updateOption('ChatCacheExpireMinute', inputs.ChatCacheExpireMinute);
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
||||
showSuccess('保存成功!');
|
||||
@ -292,7 +299,34 @@ const OperationSetting = () => {
|
||||
label="Midjourney 允许回调(会泄露服务器ip地址)"
|
||||
control={<Checkbox checked={inputs.MjNotifyEnabled === 'true'} onChange={handleInputChange} name="MjNotifyEnabled" />}
|
||||
/>
|
||||
<FormControlLabel
|
||||
sx={{ marginLeft: '0px' }}
|
||||
label="是否开启聊天缓存(如果没有启用Redis,将会存储在数据库中)"
|
||||
control={<Checkbox checked={inputs.ChatCacheEnabled === 'true'} onChange={handleInputChange} name="ChatCacheEnabled" />}
|
||||
/>
|
||||
</Stack>
|
||||
<Stack direction={{ sm: 'column', md: 'row' }} spacing={{ xs: 3, sm: 2, md: 4 }}>
|
||||
<FormControl>
|
||||
<InputLabel htmlFor="ChatCacheExpireMinute">缓存时间(分钟)</InputLabel>
|
||||
<OutlinedInput
|
||||
id="ChatCacheExpireMinute"
|
||||
name="ChatCacheExpireMinute"
|
||||
value={inputs.ChatCacheExpireMinute}
|
||||
onChange={handleInputChange}
|
||||
label="缓存时间(分钟)"
|
||||
placeholder="开启缓存时,数据缓存的时间"
|
||||
disabled={loading}
|
||||
/>
|
||||
</FormControl>
|
||||
</Stack>
|
||||
<Button
|
||||
variant="contained"
|
||||
onClick={() => {
|
||||
submitConfig('other').then();
|
||||
}}
|
||||
>
|
||||
保存其他设置
|
||||
</Button>
|
||||
</Stack>
|
||||
</SubCard>
|
||||
<SubCard title="日志设置">
|
||||
|
@ -17,6 +17,7 @@ import {
|
||||
OutlinedInput,
|
||||
InputAdornment,
|
||||
Switch,
|
||||
FormControlLabel,
|
||||
FormHelperText
|
||||
} from '@mui/material';
|
||||
|
||||
@ -25,6 +26,7 @@ import { LocalizationProvider } from '@mui/x-date-pickers/LocalizationProvider';
|
||||
import { DateTimePicker } from '@mui/x-date-pickers/DateTimePicker';
|
||||
import { renderQuotaWithPrompt, showSuccess, showError } from 'utils/common';
|
||||
import { API } from 'utils/api';
|
||||
import { useSelector } from 'react-redux';
|
||||
require('dayjs/locale/zh-cn');
|
||||
|
||||
const validationSchema = Yup.object().shape({
|
||||
@ -40,12 +42,14 @@ const originInputs = {
|
||||
name: '',
|
||||
remain_quota: 0,
|
||||
expired_time: -1,
|
||||
unlimited_quota: false
|
||||
unlimited_quota: false,
|
||||
chat_cache: false
|
||||
};
|
||||
|
||||
const EditModal = ({ open, tokenId, onCancel, onOk }) => {
|
||||
const theme = useTheme();
|
||||
const [inputs, setInputs] = useState(originInputs);
|
||||
const siteInfo = useSelector((state) => state.siteInfo);
|
||||
|
||||
const submit = async (values, { setErrors, setStatus, setSubmitting }) => {
|
||||
setSubmitting(true);
|
||||
@ -163,17 +167,22 @@ const EditModal = ({ open, tokenId, onCancel, onOk }) => {
|
||||
)}
|
||||
</FormControl>
|
||||
)}
|
||||
<Switch
|
||||
checked={values.expired_time === -1}
|
||||
onClick={() => {
|
||||
if (values.expired_time === -1) {
|
||||
setFieldValue('expired_time', Math.floor(Date.now() / 1000));
|
||||
} else {
|
||||
setFieldValue('expired_time', -1);
|
||||
}
|
||||
}}
|
||||
/>{' '}
|
||||
永不过期
|
||||
<FormControlLabel
|
||||
control={
|
||||
<Switch
|
||||
checked={values.expired_time === -1}
|
||||
onClick={() => {
|
||||
if (values.expired_time === -1) {
|
||||
setFieldValue('expired_time', Math.floor(Date.now() / 1000));
|
||||
} else {
|
||||
setFieldValue('expired_time', -1);
|
||||
}
|
||||
}}
|
||||
/>
|
||||
}
|
||||
label="永不过期"
|
||||
/>
|
||||
|
||||
<FormControl fullWidth error={Boolean(touched.remain_quota && errors.remain_quota)} sx={{ ...theme.typography.otherInput }}>
|
||||
<InputLabel htmlFor="channel-remain_quota-label">额度</InputLabel>
|
||||
<OutlinedInput
|
||||
@ -195,13 +204,32 @@ const EditModal = ({ open, tokenId, onCancel, onOk }) => {
|
||||
</FormHelperText>
|
||||
)}
|
||||
</FormControl>
|
||||
<Switch
|
||||
checked={values.unlimited_quota === true}
|
||||
onClick={() => {
|
||||
setFieldValue('unlimited_quota', !values.unlimited_quota);
|
||||
}}
|
||||
/>{' '}
|
||||
无限额度
|
||||
<FormControl fullWidth>
|
||||
<FormControlLabel
|
||||
control={
|
||||
<Switch
|
||||
checked={values.unlimited_quota === true}
|
||||
onClick={() => {
|
||||
setFieldValue('unlimited_quota', !values.unlimited_quota);
|
||||
}}
|
||||
/>
|
||||
}
|
||||
label="无限额度"
|
||||
/>
|
||||
{siteInfo.chat_cache_enabled && (
|
||||
<FormControlLabel
|
||||
control={
|
||||
<Switch
|
||||
checked={values.chat_cache}
|
||||
onClick={() => {
|
||||
setFieldValue('chat_cache', !values.chat_cache);
|
||||
}}
|
||||
/>
|
||||
}
|
||||
label="是否开启缓存(开启后,将会缓存聊天记录,以减少消费)"
|
||||
/>
|
||||
)}
|
||||
</FormControl>
|
||||
<DialogActions>
|
||||
<Button onClick={onCancel}>取消</Button>
|
||||
<Button disableElevation disabled={isSubmitting} type="submit" variant="contained" color="primary">
|
||||
|
Loading…
Reference in New Issue
Block a user