diff --git a/README.md b/README.md index 69bb10ef..1cb30591 100644 --- a/README.md +++ b/README.md @@ -78,6 +78,7 @@ _✨ 通过标准的 OpenAI API 格式访问所有的大模型,开箱即用 + [x] [百川大模型](https://platform.baichuan-ai.com) + [ ] [字节云雀大模型](https://www.volcengine.com/product/ark) (WIP) + [x] [MINIMAX](https://api.minimax.chat/) + + [x] [Groq](https://wow.groq.com/) 2. 支持配置镜像以及众多[第三方代理服务](https://iamazing.cn/page/openai-api-third-party-services)。 3. 支持通过**负载均衡**的方式访问多个渠道。 4. 支持 **stream 模式**,可以通过流式传输实现打字机效果。 @@ -105,6 +106,7 @@ _✨ 通过标准的 OpenAI API 格式访问所有的大模型,开箱即用 + [GitHub 开放授权](https://github.com/settings/applications/new)。 + 微信公众号授权(需要额外部署 [WeChat Server](https://github.com/songquanpeng/wechat-server))。 23. 支持主题切换,设置环境变量 `THEME` 即可,默认为 `default`,欢迎 PR 更多主题,具体参考[此处](./web/README.md)。 +24. 配合 [Message Pusher](https://github.com/songquanpeng/message-pusher) 可将报警信息推送到多种 App 上。 ## 部署 ### 基于 Docker 进行部署 @@ -374,6 +376,9 @@ graph LR 16. `SQLITE_BUSY_TIMEOUT`:SQLite 锁等待超时设置,单位为毫秒,默认 `3000`。 17. `GEMINI_SAFETY_SETTING`:Gemini 的安全设置,默认 `BLOCK_NONE`。 18. `THEME`:系统的主题设置,默认为 `default`,具体可选值参考[此处](./web/README.md)。 +19. `ENABLE_METRIC`:是否根据请求成功率禁用渠道,默认不开启,可选值为 `true` 和 `false`。 +20. `METRIC_QUEUE_SIZE`:请求成功率统计队列大小,默认为 `10`。 +21. `METRIC_SUCCESS_RATE_THRESHOLD`:请求成功率阈值,默认为 `0.8`。 ### 命令行参数 1. `--port `: 指定服务器监听的端口号,默认为 `3000`。 diff --git a/common/blacklist/main.go b/common/blacklist/main.go new file mode 100644 index 00000000..f84ce6ae --- /dev/null +++ b/common/blacklist/main.go @@ -0,0 +1,29 @@ +package blacklist + +import ( + "fmt" + "sync" +) + +var blackList sync.Map + +func init() { + blackList = sync.Map{} +} + +func userId2Key(id int) string { + return fmt.Sprintf("userid_%d", id) +} + +func BanUser(id int) { + blackList.Store(userId2Key(id), true) +} + +func UnbanUser(id int) { + blackList.Delete(userId2Key(id)) +} + +func IsUserBanned(id int) bool { + _, ok := blackList.Load(userId2Key(id)) + return ok +} diff --git a/common/config/config.go b/common/config/config.go index dd0236b4..53af824f 100644 --- a/common/config/config.go +++ b/common/config/config.go @@ -52,6 +52,7 @@ var EmailDomainWhitelist = []string{ } var DebugEnabled = os.Getenv("DEBUG") == "true" +var DebugSQLEnabled = os.Getenv("DEBUG_SQL") == "true" var MemoryCacheEnabled = os.Getenv("MEMORY_CACHE_ENABLED") == "true" var LogConsumeEnabled = true @@ -69,6 +70,9 @@ var WeChatServerAddress = "" var WeChatServerToken = "" var WeChatAccountQRCodeImageURL = "" +var MessagePusherAddress = "" +var MessagePusherToken = "" + var TurnstileSiteKey = "" var TurnstileSecretKey = "" @@ -125,3 +129,9 @@ var ( ) var RateLimitKeyExpirationDuration = 20 * time.Minute + +var EnableMetric = helper.GetOrDefaultEnvBool("ENABLE_METRIC", false) +var MetricQueueSize = helper.GetOrDefaultEnvInt("METRIC_QUEUE_SIZE", 10) +var MetricSuccessRateThreshold = helper.GetOrDefaultEnvFloat64("METRIC_SUCCESS_RATE_THRESHOLD", 0.8) +var MetricSuccessChanSize = helper.GetOrDefaultEnvInt("METRIC_SUCCESS_CHAN_SIZE", 1024) +var MetricFailChanSize = helper.GetOrDefaultEnvInt("METRIC_FAIL_CHAN_SIZE", 128) diff --git a/common/constants.go b/common/constants.go index 63c627bc..de71bc7a 100644 --- a/common/constants.go +++ b/common/constants.go @@ -15,6 +15,7 @@ const ( const ( UserStatusEnabled = 1 // don't use 0, 0 is the default value! UserStatusDisabled = 2 // also don't use 0 + UserStatusDeleted = 3 ) const ( @@ -67,6 +68,7 @@ const ( ChannelTypeBaichuan ChannelTypeMinimax ChannelTypeMistral + ChannelTypeGroq ChannelTypeDummy ) @@ -101,6 +103,7 @@ var ChannelBaseURLs = []string{ "https://api.baichuan-ai.com", // 26 "https://api.minimax.chat", // 27 "https://api.mistral.ai", // 28 + "https://api.groq.com/openai", // 29 } const ( diff --git a/common/database.go b/common/database.go index 9b52a0d5..df60bdd5 100644 --- a/common/database.go +++ b/common/database.go @@ -4,6 +4,7 @@ import "github.com/songquanpeng/one-api/common/helper" var UsingSQLite = false var UsingPostgreSQL = false +var UsingMySQL = false var SQLitePath = "one-api.db" var SQLiteBusyTimeout = helper.GetOrDefaultEnvInt("SQLITE_BUSY_TIMEOUT", 3000) diff --git a/common/helper/helper.go b/common/helper/helper.go index babe422b..23578842 100644 --- a/common/helper/helper.go +++ b/common/helper/helper.go @@ -195,6 +195,13 @@ func Max(a int, b int) int { } } +func GetOrDefaultEnvBool(env string, defaultValue bool) bool { + if env == "" || os.Getenv(env) == "" { + return defaultValue + } + return os.Getenv(env) == "true" +} + func GetOrDefaultEnvInt(env string, defaultValue int) int { if env == "" || os.Getenv(env) == "" { return defaultValue @@ -207,6 +214,18 @@ func GetOrDefaultEnvInt(env string, defaultValue int) int { return num } +func GetOrDefaultEnvFloat64(env string, defaultValue float64) float64 { + if env == "" || os.Getenv(env) == "" { + return defaultValue + } + num, err := strconv.ParseFloat(os.Getenv(env), 64) + if err != nil { + logger.SysError(fmt.Sprintf("failed to parse %s: %s, using default value: %f", env, err.Error(), defaultValue)) + return defaultValue + } + return num +} + func GetOrDefaultEnvString(env string, defaultValue string) string { if env == "" || os.Getenv(env) == "" { return defaultValue diff --git a/common/logger/logger.go b/common/logger/logger.go index 8232b2fc..41b98ca3 100644 --- a/common/logger/logger.go +++ b/common/logger/logger.go @@ -19,9 +19,6 @@ const ( loggerError = "ERR" ) -const maxLogCount = 1000000 - -var logCount int var setupLogLock sync.Mutex var setupLogWorking bool @@ -96,9 +93,7 @@ func logHelper(ctx context.Context, level string, msg string) { id := ctx.Value(RequestIdKey) now := time.Now() _, _ = fmt.Fprintf(writer, "[%s] %v | %s | %s \n", level, now.Format("2006/01/02 - 15:04:05"), id, msg) - logCount++ // we don't need accurate count, so no lock here - if logCount > maxLogCount && !setupLogWorking { - logCount = 0 + if !setupLogWorking { setupLogWorking = true go func() { SetupLogger() diff --git a/common/email.go b/common/message/email.go similarity index 96% rename from common/email.go rename to common/message/email.go index 2689da6a..b06782db 100644 --- a/common/email.go +++ b/common/message/email.go @@ -1,4 +1,4 @@ -package common +package message import ( "crypto/rand" @@ -12,6 +12,9 @@ import ( ) func SendEmail(subject string, receiver string, content string) error { + if receiver == "" { + return fmt.Errorf("receiver is empty") + } if config.SMTPFrom == "" { // for compatibility config.SMTPFrom = config.SMTPAccount } diff --git a/common/message/main.go b/common/message/main.go new file mode 100644 index 00000000..5ce82a64 --- /dev/null +++ b/common/message/main.go @@ -0,0 +1,22 @@ +package message + +import ( + "fmt" + "github.com/songquanpeng/one-api/common/config" +) + +const ( + ByAll = "all" + ByEmail = "email" + ByMessagePusher = "message_pusher" +) + +func Notify(by string, title string, description string, content string) error { + if by == ByEmail { + return SendEmail(title, config.RootUserEmail, content) + } + if by == ByMessagePusher { + return SendMessage(title, description, content) + } + return fmt.Errorf("unknown notify method: %s", by) +} diff --git a/common/message/message-pusher.go b/common/message/message-pusher.go new file mode 100644 index 00000000..69949b4b --- /dev/null +++ b/common/message/message-pusher.go @@ -0,0 +1,53 @@ +package message + +import ( + "bytes" + "encoding/json" + "errors" + "github.com/songquanpeng/one-api/common/config" + "net/http" +) + +type request struct { + Title string `json:"title"` + Description string `json:"description"` + Content string `json:"content"` + URL string `json:"url"` + Channel string `json:"channel"` + Token string `json:"token"` +} + +type response struct { + Success bool `json:"success"` + Message string `json:"message"` +} + +func SendMessage(title string, description string, content string) error { + if config.MessagePusherAddress == "" { + return errors.New("message pusher address is not set") + } + req := request{ + Title: title, + Description: description, + Content: content, + Token: config.MessagePusherToken, + } + data, err := json.Marshal(req) + if err != nil { + return err + } + resp, err := http.Post(config.MessagePusherAddress, + "application/json", bytes.NewBuffer(data)) + if err != nil { + return err + } + var res response + err = json.NewDecoder(resp.Body).Decode(&res) + if err != nil { + return err + } + if !res.Success { + return errors.New(res.Message) + } + return nil +} diff --git a/common/model-ratio.go b/common/model-ratio.go index ac0b37ad..5b0a759b 100644 --- a/common/model-ratio.go +++ b/common/model-ratio.go @@ -125,6 +125,11 @@ var ModelRatio = map[string]float64{ "mistral-medium-latest": 2.7 / 1000 * USD, "mistral-large-latest": 8.0 / 1000 * USD, "mistral-embed": 0.1 / 1000 * USD, + // https://wow.groq.com/ + "llama2-70b-4096": 0.7 / 1000 * USD, + "llama2-7b-2048": 0.1 / 1000 * USD, + "mixtral-8x7b-32768": 0.27 / 1000 * USD, + "gemma-7b-it": 0.1 / 1000 * USD, } var CompletionRatio = map[string]float64{} @@ -143,6 +148,26 @@ func init() { } } +func AddNewMissingRatio(oldRatio string) string { + newRatio := make(map[string]float64) + err := json.Unmarshal([]byte(oldRatio), &newRatio) + if err != nil { + logger.SysError("error unmarshalling old ratio: " + err.Error()) + return oldRatio + } + for k, v := range DefaultModelRatio { + if _, ok := newRatio[k]; !ok { + newRatio[k] = v + } + } + jsonBytes, err := json.Marshal(newRatio) + if err != nil { + logger.SysError("error marshalling new ratio: " + err.Error()) + return oldRatio + } + return string(jsonBytes) +} + func ModelRatio2JSONString() string { jsonBytes, err := json.Marshal(ModelRatio) if err != nil { @@ -209,7 +234,7 @@ func GetCompletionRatio(name string) float64 { return 2 } } - return 1.333333 + return 4.0 / 3.0 } if strings.HasPrefix(name, "gpt-4") { if strings.HasSuffix(name, "preview") { @@ -226,5 +251,9 @@ func GetCompletionRatio(name string) float64 { if strings.HasPrefix(name, "mistral-") { return 3 } + switch name { + case "llama2-70b-4096": + return 0.8 / 0.7 + } return 1 } diff --git a/controller/channel-billing.go b/controller/channel-billing.go index abeab26a..03c97349 100644 --- a/controller/channel-billing.go +++ b/controller/channel-billing.go @@ -8,6 +8,7 @@ import ( "github.com/songquanpeng/one-api/common/config" "github.com/songquanpeng/one-api/common/logger" "github.com/songquanpeng/one-api/model" + "github.com/songquanpeng/one-api/monitor" "github.com/songquanpeng/one-api/relay/util" "io" "net/http" @@ -295,7 +296,7 @@ func UpdateChannelBalance(c *gin.Context) { } func updateAllChannelsBalance() error { - channels, err := model.GetAllChannels(0, 0, true) + channels, err := model.GetAllChannels(0, 0, "all") if err != nil { return err } @@ -313,7 +314,7 @@ func updateAllChannelsBalance() error { } else { // err is nil & balance <= 0 means quota is used up if balance <= 0 { - disableChannel(channel.Id, channel.Name, "余额不足") + monitor.DisableChannel(channel.Id, channel.Name, "余额不足") } } time.Sleep(config.RequestInterval) @@ -322,15 +323,14 @@ func updateAllChannelsBalance() error { } func UpdateAllChannelsBalance(c *gin.Context) { - // TODO: make it async - err := updateAllChannelsBalance() - if err != nil { - c.JSON(http.StatusOK, gin.H{ - "success": false, - "message": err.Error(), - }) - return - } + //err := updateAllChannelsBalance() + //if err != nil { + // c.JSON(http.StatusOK, gin.H{ + // "success": false, + // "message": err.Error(), + // }) + // return + //} c.JSON(http.StatusOK, gin.H{ "success": true, "message": "", diff --git a/controller/channel-test.go b/controller/channel-test.go index 7007e205..6d18305a 100644 --- a/controller/channel-test.go +++ b/controller/channel-test.go @@ -8,8 +8,10 @@ import ( "github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common/config" "github.com/songquanpeng/one-api/common/logger" + "github.com/songquanpeng/one-api/common/message" "github.com/songquanpeng/one-api/middleware" "github.com/songquanpeng/one-api/model" + "github.com/songquanpeng/one-api/monitor" "github.com/songquanpeng/one-api/relay/constant" "github.com/songquanpeng/one-api/relay/helper" relaymodel "github.com/songquanpeng/one-api/relay/model" @@ -148,33 +150,7 @@ func TestChannel(c *gin.Context) { var testAllChannelsLock sync.Mutex var testAllChannelsRunning bool = false -func notifyRootUser(subject string, content string) { - if config.RootUserEmail == "" { - config.RootUserEmail = model.GetRootUserEmail() - } - err := common.SendEmail(subject, config.RootUserEmail, content) - if err != nil { - logger.SysError(fmt.Sprintf("failed to send email: %s", err.Error())) - } -} - -// disable & notify -func disableChannel(channelId int, channelName string, reason string) { - model.UpdateChannelStatusById(channelId, common.ChannelStatusAutoDisabled) - subject := fmt.Sprintf("通道「%s」(#%d)已被禁用", channelName, channelId) - content := fmt.Sprintf("通道「%s」(#%d)已被禁用,原因:%s", channelName, channelId, reason) - notifyRootUser(subject, content) -} - -// enable & notify -func enableChannel(channelId int, channelName string) { - model.UpdateChannelStatusById(channelId, common.ChannelStatusEnabled) - subject := fmt.Sprintf("通道「%s」(#%d)已被启用", channelName, channelId) - content := fmt.Sprintf("通道「%s」(#%d)已被启用", channelName, channelId) - notifyRootUser(subject, content) -} - -func testAllChannels(notify bool) error { +func testChannels(notify bool, scope string) error { if config.RootUserEmail == "" { config.RootUserEmail = model.GetRootUserEmail() } @@ -185,7 +161,7 @@ func testAllChannels(notify bool) error { } testAllChannelsRunning = true testAllChannelsLock.Unlock() - channels, err := model.GetAllChannels(0, 0, true) + channels, err := model.GetAllChannels(0, 0, scope) if err != nil { return err } @@ -202,13 +178,17 @@ func testAllChannels(notify bool) error { milliseconds := tok.Sub(tik).Milliseconds() if isChannelEnabled && milliseconds > disableThreshold { err = errors.New(fmt.Sprintf("响应时间 %.2fs 超过阈值 %.2fs", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0)) - disableChannel(channel.Id, channel.Name, err.Error()) + if config.AutomaticDisableChannelEnabled { + monitor.DisableChannel(channel.Id, channel.Name, err.Error()) + } else { + _ = message.Notify(message.ByAll, fmt.Sprintf("渠道 %s (%d)测试超时", channel.Name, channel.Id), "", err.Error()) + } } if isChannelEnabled && util.ShouldDisableChannel(openaiErr, -1) { - disableChannel(channel.Id, channel.Name, err.Error()) + monitor.DisableChannel(channel.Id, channel.Name, err.Error()) } if !isChannelEnabled && util.ShouldEnableChannel(err, openaiErr) { - enableChannel(channel.Id, channel.Name) + monitor.EnableChannel(channel.Id, channel.Name) } channel.UpdateResponseTime(milliseconds) time.Sleep(config.RequestInterval) @@ -217,7 +197,7 @@ func testAllChannels(notify bool) error { testAllChannelsRunning = false testAllChannelsLock.Unlock() if notify { - err := common.SendEmail("通道测试完成", config.RootUserEmail, "通道测试完成,如果没有收到禁用通知,说明所有通道都正常") + err := message.Notify(message.ByAll, "通道测试完成", "", "通道测试完成,如果没有收到禁用通知,说明所有通道都正常") if err != nil { logger.SysError(fmt.Sprintf("failed to send email: %s", err.Error())) } @@ -226,8 +206,12 @@ func testAllChannels(notify bool) error { return nil } -func TestAllChannels(c *gin.Context) { - err := testAllChannels(true) +func TestChannels(c *gin.Context) { + scope := c.Query("scope") + if scope == "" { + scope = "all" + } + err := testChannels(true, scope) if err != nil { c.JSON(http.StatusOK, gin.H{ "success": false, @@ -246,7 +230,7 @@ func AutomaticallyTestChannels(frequency int) { for { time.Sleep(time.Duration(frequency) * time.Minute) logger.SysLog("testing all channels") - _ = testAllChannels(false) + _ = testChannels(false, "all") logger.SysLog("channel test finished") } } diff --git a/controller/channel.go b/controller/channel.go index bdfa00d9..37bfb99d 100644 --- a/controller/channel.go +++ b/controller/channel.go @@ -15,7 +15,7 @@ func GetAllChannels(c *gin.Context) { if p < 0 { p = 0 } - channels, err := model.GetAllChannels(p*config.ItemsPerPage, config.ItemsPerPage, false) + channels, err := model.GetAllChannels(p*config.ItemsPerPage, config.ItemsPerPage, "limited") if err != nil { c.JSON(http.StatusOK, gin.H{ "success": false, diff --git a/controller/misc.go b/controller/misc.go index 036bdbd1..f27fdb12 100644 --- a/controller/misc.go +++ b/controller/misc.go @@ -5,6 +5,7 @@ import ( "fmt" "github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common/config" + "github.com/songquanpeng/one-api/common/message" "github.com/songquanpeng/one-api/model" "net/http" "strings" @@ -110,7 +111,7 @@ func SendEmailVerification(c *gin.Context) { content := fmt.Sprintf("

您好,你正在进行%s邮箱验证。

"+ "

您的验证码为: %s

"+ "

验证码 %d 分钟内有效,如果不是本人操作,请忽略。

", config.SystemName, code, common.VerificationValidMinutes) - err := common.SendEmail(subject, email, content) + err := message.SendEmail(subject, email, content) if err != nil { c.JSON(http.StatusOK, gin.H{ "success": false, @@ -149,7 +150,7 @@ func SendPasswordResetEmail(c *gin.Context) { "

点击 此处 进行密码重置。

"+ "

如果链接无法点击,请尝试点击下面的链接或将其复制到浏览器中打开:
%s

"+ "

重置链接 %d 分钟内有效,如果不是本人操作,请忽略。

", config.SystemName, link, link, common.VerificationValidMinutes) - err := common.SendEmail(subject, email, content) + err := message.SendEmail(subject, email, content) if err != nil { c.JSON(http.StatusOK, gin.H{ "success": false, diff --git a/controller/model.go b/controller/model.go index 0486634c..4c5476b4 100644 --- a/controller/model.go +++ b/controller/model.go @@ -4,11 +4,7 @@ import ( "fmt" "github.com/gin-gonic/gin" "github.com/songquanpeng/one-api/common" - "github.com/songquanpeng/one-api/relay/channel/ai360" - "github.com/songquanpeng/one-api/relay/channel/baichuan" - "github.com/songquanpeng/one-api/relay/channel/minimax" - "github.com/songquanpeng/one-api/relay/channel/mistral" - "github.com/songquanpeng/one-api/relay/channel/moonshot" + "github.com/songquanpeng/one-api/relay/channel/openai" "github.com/songquanpeng/one-api/relay/constant" "github.com/songquanpeng/one-api/relay/helper" relaymodel "github.com/songquanpeng/one-api/relay/model" @@ -83,60 +79,22 @@ func init() { }) } } - for _, modelName := range ai360.ModelList { - openAIModels = append(openAIModels, OpenAIModels{ - Id: modelName, - Object: "model", - Created: 1626777600, - OwnedBy: "360", - Permission: permission, - Root: modelName, - Parent: nil, - }) - } - for _, modelName := range moonshot.ModelList { - openAIModels = append(openAIModels, OpenAIModels{ - Id: modelName, - Object: "model", - Created: 1626777600, - OwnedBy: "moonshot", - Permission: permission, - Root: modelName, - Parent: nil, - }) - } - for _, modelName := range baichuan.ModelList { - openAIModels = append(openAIModels, OpenAIModels{ - Id: modelName, - Object: "model", - Created: 1626777600, - OwnedBy: "baichuan", - Permission: permission, - Root: modelName, - Parent: nil, - }) - } - for _, modelName := range minimax.ModelList { - openAIModels = append(openAIModels, OpenAIModels{ - Id: modelName, - Object: "model", - Created: 1626777600, - OwnedBy: "minimax", - Permission: permission, - Root: modelName, - Parent: nil, - }) - } - for _, modelName := range mistral.ModelList { - openAIModels = append(openAIModels, OpenAIModels{ - Id: modelName, - Object: "model", - Created: 1626777600, - OwnedBy: "mistralai", - Permission: permission, - Root: modelName, - Parent: nil, - }) + for _, channelType := range openai.CompatibleChannels { + if channelType == common.ChannelTypeAzure { + continue + } + channelName, channelModelList := openai.GetCompatibleChannelMeta(channelType) + for _, modelName := range channelModelList { + openAIModels = append(openAIModels, OpenAIModels{ + Id: modelName, + Object: "model", + Created: 1626777600, + OwnedBy: channelName, + Permission: permission, + Root: modelName, + Parent: nil, + }) + } } openAIModelsMap = make(map[string]OpenAIModels) for _, model := range openAIModels { diff --git a/controller/relay.go b/controller/relay.go index 9b2d462c..b34768df 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -11,6 +11,7 @@ import ( "github.com/songquanpeng/one-api/common/logger" "github.com/songquanpeng/one-api/middleware" dbmodel "github.com/songquanpeng/one-api/model" + "github.com/songquanpeng/one-api/monitor" "github.com/songquanpeng/one-api/relay/constant" "github.com/songquanpeng/one-api/relay/controller" "github.com/songquanpeng/one-api/relay/model" @@ -45,11 +46,12 @@ func Relay(c *gin.Context) { requestBody, _ := common.GetRequestBody(c) logger.Debugf(ctx, "request body: %s", string(requestBody)) } + channelId := c.GetInt("channel_id") bizErr := relay(c, relayMode) if bizErr == nil { + monitor.Emit(channelId, true) return } - channelId := c.GetInt("channel_id") lastFailedChannelId := channelId channelName := c.GetString("channel_name") group := c.GetString("group") @@ -117,7 +119,9 @@ func processChannelRelayError(ctx context.Context, channelId int, channelName st logger.Errorf(ctx, "relay error (channel #%d): %s", channelId, err.Message) // https://platform.openai.com/docs/guides/error-codes/api-errors if util.ShouldDisableChannel(&err.Error, err.StatusCode) { - disableChannel(channelId, channelName, err.Message) + monitor.DisableChannel(channelId, channelName, err.Message) + } else { + monitor.Emit(channelId, false) } } diff --git a/main.go b/main.go index 1f43a45f..83d7e7ed 100644 --- a/main.go +++ b/main.go @@ -64,13 +64,6 @@ func main() { go model.SyncOptions(config.SyncFrequency) go model.SyncChannelCache(config.SyncFrequency) } - if os.Getenv("CHANNEL_UPDATE_FREQUENCY") != "" { - frequency, err := strconv.Atoi(os.Getenv("CHANNEL_UPDATE_FREQUENCY")) - if err != nil { - logger.FatalLog("failed to parse CHANNEL_UPDATE_FREQUENCY: " + err.Error()) - } - go controller.AutomaticallyUpdateChannels(frequency) - } if os.Getenv("CHANNEL_TEST_FREQUENCY") != "" { frequency, err := strconv.Atoi(os.Getenv("CHANNEL_TEST_FREQUENCY")) if err != nil { @@ -83,6 +76,9 @@ func main() { logger.SysLog("batch update enabled with interval " + strconv.Itoa(config.BatchUpdateInterval) + "s") model.InitBatchUpdater() } + if config.EnableMetric { + logger.SysLog("metric enabled, will disable channel if too much request failed") + } openai.InitTokenEncoders() // Initialize HTTP server diff --git a/middleware/auth.go b/middleware/auth.go index 9d25f395..30997efd 100644 --- a/middleware/auth.go +++ b/middleware/auth.go @@ -4,6 +4,7 @@ import ( "github.com/gin-contrib/sessions" "github.com/gin-gonic/gin" "github.com/songquanpeng/one-api/common" + "github.com/songquanpeng/one-api/common/blacklist" "github.com/songquanpeng/one-api/model" "net/http" "strings" @@ -42,11 +43,14 @@ func authHelper(c *gin.Context, minRole int) { return } } - if status.(int) == common.UserStatusDisabled { + if status.(int) == common.UserStatusDisabled || blacklist.IsUserBanned(id.(int)) { c.JSON(http.StatusOK, gin.H{ "success": false, "message": "用户已被封禁", }) + session := sessions.Default(c) + session.Clear() + _ = session.Save() c.Abort() return } @@ -99,7 +103,7 @@ func TokenAuth() func(c *gin.Context) { abortWithMessage(c, http.StatusInternalServerError, err.Error()) return } - if !userEnabled { + if !userEnabled || blacklist.IsUserBanned(token.UserId) { abortWithMessage(c, http.StatusForbidden, "用户已被封禁") return } diff --git a/middleware/recover.go b/middleware/recover.go index 02e3e3bb..cfc3f827 100644 --- a/middleware/recover.go +++ b/middleware/recover.go @@ -3,6 +3,7 @@ package middleware import ( "fmt" "github.com/gin-gonic/gin" + "github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common/logger" "net/http" "runtime/debug" @@ -12,11 +13,15 @@ func RelayPanicRecover() gin.HandlerFunc { return func(c *gin.Context) { defer func() { if err := recover(); err != nil { - logger.SysError(fmt.Sprintf("panic detected: %v", err)) - logger.SysError(fmt.Sprintf("stacktrace from panic: %s", string(debug.Stack()))) + ctx := c.Request.Context() + logger.Errorf(ctx, fmt.Sprintf("panic detected: %v", err)) + logger.Errorf(ctx, fmt.Sprintf("stacktrace from panic: %s", string(debug.Stack()))) + logger.Errorf(ctx, fmt.Sprintf("request: %s %s", c.Request.Method, c.Request.URL.Path)) + body, _ := common.GetRequestBody(c) + logger.Errorf(ctx, fmt.Sprintf("request body: %s", string(body))) c.JSON(http.StatusInternalServerError, gin.H{ "error": gin.H{ - "message": fmt.Sprintf("Panic detected, error: %v. Please submit a issue here: https://github.com/songquanpeng/one-api", err), + "message": fmt.Sprintf("Panic detected, error: %v. Please submit an issue with the related log here: https://github.com/songquanpeng/one-api", err), "type": "one_api_panic", }, }) diff --git a/model/channel.go b/model/channel.go index 19af2263..605c6d17 100644 --- a/model/channel.go +++ b/model/channel.go @@ -13,7 +13,7 @@ import ( type Channel struct { Id int `json:"id"` Type int `json:"type" gorm:"default:0"` - Key string `json:"key" gorm:"not null;index"` + Key string `json:"key" gorm:"type:text"` Status int `json:"status" gorm:"default:1"` Name string `json:"name" gorm:"index"` Weight *uint `json:"weight" gorm:"default:0"` @@ -32,23 +32,22 @@ type Channel struct { Config string `json:"config"` } -func GetAllChannels(startIdx int, num int, selectAll bool) ([]*Channel, error) { +func GetAllChannels(startIdx int, num int, scope string) ([]*Channel, error) { var channels []*Channel var err error - if selectAll { + switch scope { + case "all": err = DB.Order("id desc").Find(&channels).Error - } else { + case "disabled": + err = DB.Order("id desc").Where("status = ? or status = ?", common.ChannelStatusAutoDisabled, common.ChannelStatusManuallyDisabled).Find(&channels).Error + default: err = DB.Order("id desc").Limit(num).Offset(startIdx).Omit("key").Find(&channels).Error } return channels, err } func SearchChannels(keyword string) (channels []*Channel, err error) { - keyCol := "`key`" - if common.UsingPostgreSQL { - keyCol = `"key"` - } - err = DB.Omit("key").Where("id = ? or name LIKE ? or "+keyCol+" = ?", helper.String2Int(keyword), keyword+"%", keyword).Find(&channels).Error + err = DB.Omit("key").Where("id = ? or name LIKE ?", helper.String2Int(keyword), keyword+"%").Find(&channels).Error return channels, err } diff --git a/model/main.go b/model/main.go index 18ed01d0..f27cdb6f 100644 --- a/model/main.go +++ b/model/main.go @@ -56,6 +56,7 @@ func chooseDB() (*gorm.DB, error) { } // Use MySQL logger.SysLog("using MySQL as database") + common.UsingMySQL = true return gorm.Open(mysql.Open(dsn), &gorm.Config{ PrepareStmt: true, // precompile SQL }) @@ -72,7 +73,7 @@ func chooseDB() (*gorm.DB, error) { func InitDB() (err error) { db, err := chooseDB() if err == nil { - if config.DebugEnabled { + if config.DebugSQLEnabled { db = db.Debug() } DB = db @@ -87,6 +88,9 @@ func InitDB() (err error) { if !config.IsMasterNode { return nil } + if common.UsingMySQL { + _, _ = sqlDB.Exec("DROP INDEX idx_channels_key ON channels;") // TODO: delete this line when most users have upgraded + } logger.SysLog("database migration started") err = db.AutoMigrate(&Channel{}) if err != nil { diff --git a/model/option.go b/model/option.go index 6002c795..e129b9f0 100644 --- a/model/option.go +++ b/model/option.go @@ -57,6 +57,8 @@ func InitOptionMap() { config.OptionMap["WeChatServerAddress"] = "" config.OptionMap["WeChatServerToken"] = "" config.OptionMap["WeChatAccountQRCodeImageURL"] = "" + config.OptionMap["MessagePusherAddress"] = "" + config.OptionMap["MessagePusherToken"] = "" config.OptionMap["TurnstileSiteKey"] = "" config.OptionMap["TurnstileSecretKey"] = "" config.OptionMap["QuotaForNewUser"] = strconv.Itoa(config.QuotaForNewUser) @@ -79,6 +81,9 @@ func InitOptionMap() { func loadOptionsFromDatabase() { options, _ := AllOption() for _, option := range options { + if option.Key == "ModelRatio" { + option.Value = common.AddNewMissingRatio(option.Value) + } err := updateOptionMap(option.Key, option.Value) if err != nil { logger.SysError("failed to update option map: " + err.Error()) @@ -179,6 +184,10 @@ func updateOptionMap(key string, value string) (err error) { config.WeChatServerToken = value case "WeChatAccountQRCodeImageURL": config.WeChatAccountQRCodeImageURL = value + case "MessagePusherAddress": + config.MessagePusherAddress = value + case "MessagePusherToken": + config.MessagePusherToken = value case "TurnstileSiteKey": config.TurnstileSiteKey = value case "TurnstileSecretKey": diff --git a/model/token.go b/model/token.go index d0a0648a..c4669e0b 100644 --- a/model/token.go +++ b/model/token.go @@ -7,6 +7,7 @@ import ( "github.com/songquanpeng/one-api/common/config" "github.com/songquanpeng/one-api/common/helper" "github.com/songquanpeng/one-api/common/logger" + "github.com/songquanpeng/one-api/common/message" "gorm.io/gorm" ) @@ -213,7 +214,7 @@ func PreConsumeTokenQuota(tokenId int, quota int) (err error) { } if email != "" { topUpLink := fmt.Sprintf("%s/topup", config.ServerAddress) - err = common.SendEmail(prompt, email, + err = message.SendEmail(prompt, email, fmt.Sprintf("%s,当前剩余额度为 %d,为了不影响您的使用,请及时充值。
充值链接:%s", prompt, userQuota, topUpLink, topUpLink)) if err != nil { logger.SysError("failed to send email" + err.Error()) diff --git a/model/user.go b/model/user.go index ca6e28bc..5973c8c9 100644 --- a/model/user.go +++ b/model/user.go @@ -5,6 +5,7 @@ import ( "fmt" "regexp" "github.com/songquanpeng/one-api/common" + "github.com/songquanpeng/one-api/common/blacklist" "github.com/songquanpeng/one-api/common/config" "github.com/songquanpeng/one-api/common/helper" "github.com/songquanpeng/one-api/common/logger" @@ -41,7 +42,7 @@ func GetMaxUserId() int { } func GetAllUsers(startIdx int, num int) (users []*User, err error) { - err = DB.Order("id desc").Limit(num).Offset(startIdx).Omit("password").Find(&users).Error + err = DB.Order("id desc").Limit(num).Offset(startIdx).Omit("password").Where("status != ?", common.UserStatusDeleted).Find(&users).Error return users, err } @@ -131,6 +132,11 @@ func (user *User) Update(updatePassword bool) error { return err } } + if user.Status == common.UserStatusDisabled { + blacklist.BanUser(user.Id) + } else if user.Status == common.UserStatusEnabled { + blacklist.UnbanUser(user.Id) + } err = DB.Model(user).Updates(user).Error return err } @@ -139,7 +145,10 @@ func (user *User) Delete() error { if user.Id == 0 { return errors.New("id 为空!") } - err := DB.Delete(user).Error + blacklist.BanUser(user.Id) + user.Username = fmt.Sprintf("deleted_%s", helper.GetUUID()) + user.Status = common.UserStatusDeleted + err := DB.Model(user).Updates(user).Error return err } diff --git a/monitor/channel.go b/monitor/channel.go new file mode 100644 index 00000000..597ab11a --- /dev/null +++ b/monitor/channel.go @@ -0,0 +1,55 @@ +package monitor + +import ( + "fmt" + "github.com/songquanpeng/one-api/common" + "github.com/songquanpeng/one-api/common/config" + "github.com/songquanpeng/one-api/common/logger" + "github.com/songquanpeng/one-api/common/message" + "github.com/songquanpeng/one-api/model" +) + +func notifyRootUser(subject string, content string) { + if config.MessagePusherAddress != "" { + err := message.SendMessage(subject, content, content) + if err != nil { + logger.SysError(fmt.Sprintf("failed to send message: %s", err.Error())) + } else { + return + } + } + if config.RootUserEmail == "" { + config.RootUserEmail = model.GetRootUserEmail() + } + err := message.SendEmail(subject, config.RootUserEmail, content) + if err != nil { + logger.SysError(fmt.Sprintf("failed to send email: %s", err.Error())) + } +} + +// DisableChannel disable & notify +func DisableChannel(channelId int, channelName string, reason string) { + model.UpdateChannelStatusById(channelId, common.ChannelStatusAutoDisabled) + logger.SysLog(fmt.Sprintf("channel #%d has been disabled: %s", channelId, reason)) + subject := fmt.Sprintf("通道「%s」(#%d)已被禁用", channelName, channelId) + content := fmt.Sprintf("通道「%s」(#%d)已被禁用,原因:%s", channelName, channelId, reason) + notifyRootUser(subject, content) +} + +func MetricDisableChannel(channelId int, successRate float64) { + model.UpdateChannelStatusById(channelId, common.ChannelStatusAutoDisabled) + logger.SysLog(fmt.Sprintf("channel #%d has been disabled due to low success rate: %.2f", channelId, successRate*100)) + subject := fmt.Sprintf("通道 #%d 已被禁用", channelId) + content := fmt.Sprintf("该渠道在最近 %d 次调用中成功率为 %.2f%%,低于阈值 %.2f%%,因此被系统自动禁用。", + config.MetricQueueSize, successRate*100, config.MetricSuccessRateThreshold*100) + notifyRootUser(subject, content) +} + +// EnableChannel enable & notify +func EnableChannel(channelId int, channelName string) { + model.UpdateChannelStatusById(channelId, common.ChannelStatusEnabled) + logger.SysLog(fmt.Sprintf("channel #%d has been enabled", channelId)) + subject := fmt.Sprintf("通道「%s」(#%d)已被启用", channelName, channelId) + content := fmt.Sprintf("通道「%s」(#%d)已被启用", channelName, channelId) + notifyRootUser(subject, content) +} diff --git a/monitor/metric.go b/monitor/metric.go new file mode 100644 index 00000000..98bc546e --- /dev/null +++ b/monitor/metric.go @@ -0,0 +1,79 @@ +package monitor + +import ( + "github.com/songquanpeng/one-api/common/config" +) + +var store = make(map[int][]bool) +var metricSuccessChan = make(chan int, config.MetricSuccessChanSize) +var metricFailChan = make(chan int, config.MetricFailChanSize) + +func consumeSuccess(channelId int) { + if len(store[channelId]) > config.MetricQueueSize { + store[channelId] = store[channelId][1:] + } + store[channelId] = append(store[channelId], true) +} + +func consumeFail(channelId int) (bool, float64) { + if len(store[channelId]) > config.MetricQueueSize { + store[channelId] = store[channelId][1:] + } + store[channelId] = append(store[channelId], false) + successCount := 0 + for _, success := range store[channelId] { + if success { + successCount++ + } + } + successRate := float64(successCount) / float64(len(store[channelId])) + if len(store[channelId]) < config.MetricQueueSize { + return false, successRate + } + if successRate < config.MetricSuccessRateThreshold { + store[channelId] = make([]bool, 0) + return true, successRate + } + return false, successRate +} + +func metricSuccessConsumer() { + for { + select { + case channelId := <-metricSuccessChan: + consumeSuccess(channelId) + } + } +} + +func metricFailConsumer() { + for { + select { + case channelId := <-metricFailChan: + disable, successRate := consumeFail(channelId) + if disable { + go MetricDisableChannel(channelId, successRate) + } + } + } +} + +func init() { + if config.EnableMetric { + go metricSuccessConsumer() + go metricFailConsumer() + } +} + +func Emit(channelId int, success bool) { + if !config.EnableMetric { + return + } + go func() { + if success { + metricSuccessChan <- channelId + } else { + metricFailChan <- channelId + } + }() +} diff --git a/relay/channel/anthropic/adaptor.go b/relay/channel/anthropic/adaptor.go index 5e291690..a165b35c 100644 --- a/relay/channel/anthropic/adaptor.go +++ b/relay/channel/anthropic/adaptor.go @@ -59,5 +59,5 @@ func (a *Adaptor) GetModelList() []string { } func (a *Adaptor) GetChannelName() string { - return "authropic" + return "anthropic" } diff --git a/relay/channel/baidu/adaptor.go b/relay/channel/baidu/adaptor.go index 066a8107..1a96997a 100644 --- a/relay/channel/baidu/adaptor.go +++ b/relay/channel/baidu/adaptor.go @@ -2,6 +2,7 @@ package baidu import ( "errors" + "fmt" "github.com/gin-gonic/gin" "github.com/songquanpeng/one-api/relay/channel" "github.com/songquanpeng/one-api/relay/constant" @@ -9,6 +10,7 @@ import ( "github.com/songquanpeng/one-api/relay/util" "io" "net/http" + "strings" ) type Adaptor struct { @@ -20,25 +22,33 @@ func (a *Adaptor) Init(meta *util.RelayMeta) { func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) { // https://cloud.baidu.com/doc/WENXINWORKSHOP/s/clntwmv7t - var fullRequestURL string - switch meta.ActualModelName { - case "ERNIE-Bot-4": - fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro" - case "ERNIE-Bot-8K": - fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie_bot_8k" - case "ERNIE-Bot": - fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions" - case "ERNIE-Speed": - fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie_speed" - case "ERNIE-Bot-turbo": - fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant" - case "BLOOMZ-7B": - fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/bloomz_7b1" - case "Embedding-V1": - fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings/embedding-v1" - default: - fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/" + meta.ActualModelName + suffix := "chat/" + if strings.HasPrefix("Embedding", meta.ActualModelName) { + suffix = "embeddings/" } + switch meta.ActualModelName { + case "ERNIE-4.0": + suffix += "completions_pro" + case "ERNIE-Bot-4": + suffix += "completions_pro" + case "ERNIE-3.5-8K": + suffix += "completions" + case "ERNIE-Bot-8K": + suffix += "ernie_bot_8k" + case "ERNIE-Bot": + suffix += "completions" + case "ERNIE-Speed": + suffix += "ernie_speed" + case "ERNIE-Bot-turbo": + suffix += "eb-instant" + case "BLOOMZ-7B": + suffix += "bloomz_7b1" + case "Embedding-V1": + suffix += "embedding-v1" + default: + suffix += meta.ActualModelName + } + fullRequestURL := fmt.Sprintf("%s/rpc/2.0/ai_custom/v1/wenxinworkshop/%s", meta.BaseURL, suffix) var accessToken string var err error if accessToken, err = GetAccessToken(meta.APIKey); err != nil { diff --git a/relay/channel/groq/constants.go b/relay/channel/groq/constants.go new file mode 100644 index 00000000..fc9a9ebd --- /dev/null +++ b/relay/channel/groq/constants.go @@ -0,0 +1,10 @@ +package groq + +// https://console.groq.com/docs/models + +var ModelList = []string{ + "gemma-7b-it", + "llama2-7b-2048", + "llama2-70b-4096", + "mixtral-8x7b-32768", +} diff --git a/relay/channel/openai/adaptor.go b/relay/channel/openai/adaptor.go index 5a04a768..47594030 100644 --- a/relay/channel/openai/adaptor.go +++ b/relay/channel/openai/adaptor.go @@ -6,11 +6,7 @@ import ( "github.com/gin-gonic/gin" "github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/relay/channel" - "github.com/songquanpeng/one-api/relay/channel/ai360" - "github.com/songquanpeng/one-api/relay/channel/baichuan" "github.com/songquanpeng/one-api/relay/channel/minimax" - "github.com/songquanpeng/one-api/relay/channel/mistral" - "github.com/songquanpeng/one-api/relay/channel/moonshot" "github.com/songquanpeng/one-api/relay/model" "github.com/songquanpeng/one-api/relay/util" "io" @@ -86,37 +82,11 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.Rel } func (a *Adaptor) GetModelList() []string { - switch a.ChannelType { - case common.ChannelType360: - return ai360.ModelList - case common.ChannelTypeMoonshot: - return moonshot.ModelList - case common.ChannelTypeBaichuan: - return baichuan.ModelList - case common.ChannelTypeMinimax: - return minimax.ModelList - case common.ChannelTypeMistral: - return mistral.ModelList - default: - return ModelList - } + _, modelList := GetCompatibleChannelMeta(a.ChannelType) + return modelList } func (a *Adaptor) GetChannelName() string { - switch a.ChannelType { - case common.ChannelTypeAzure: - return "azure" - case common.ChannelType360: - return "360" - case common.ChannelTypeMoonshot: - return "moonshot" - case common.ChannelTypeBaichuan: - return "baichuan" - case common.ChannelTypeMinimax: - return "minimax" - case common.ChannelTypeMistral: - return "mistralai" - default: - return "openai" - } + channelName, _ := GetCompatibleChannelMeta(a.ChannelType) + return channelName } diff --git a/relay/channel/openai/compatible.go b/relay/channel/openai/compatible.go new file mode 100644 index 00000000..767eec4b --- /dev/null +++ b/relay/channel/openai/compatible.go @@ -0,0 +1,42 @@ +package openai + +import ( + "github.com/songquanpeng/one-api/common" + "github.com/songquanpeng/one-api/relay/channel/ai360" + "github.com/songquanpeng/one-api/relay/channel/baichuan" + "github.com/songquanpeng/one-api/relay/channel/groq" + "github.com/songquanpeng/one-api/relay/channel/minimax" + "github.com/songquanpeng/one-api/relay/channel/mistral" + "github.com/songquanpeng/one-api/relay/channel/moonshot" +) + +var CompatibleChannels = []int{ + common.ChannelTypeAzure, + common.ChannelType360, + common.ChannelTypeMoonshot, + common.ChannelTypeBaichuan, + common.ChannelTypeMinimax, + common.ChannelTypeMistral, + common.ChannelTypeGroq, +} + +func GetCompatibleChannelMeta(channelType int) (string, []string) { + switch channelType { + case common.ChannelTypeAzure: + return "azure", ModelList + case common.ChannelType360: + return "360", ai360.ModelList + case common.ChannelTypeMoonshot: + return "moonshot", moonshot.ModelList + case common.ChannelTypeBaichuan: + return "baichuan", baichuan.ModelList + case common.ChannelTypeMinimax: + return "minimax", minimax.ModelList + case common.ChannelTypeMistral: + return "mistralai", mistral.ModelList + case common.ChannelTypeGroq: + return "groq", groq.ModelList + default: + return "openai", ModelList + } +} diff --git a/relay/channel/tencent/main.go b/relay/channel/tencent/main.go index fa26651b..cfdc0bfd 100644 --- a/relay/channel/tencent/main.go +++ b/relay/channel/tencent/main.go @@ -28,17 +28,6 @@ func ConvertRequest(request model.GeneralOpenAIRequest) *ChatRequest { messages := make([]Message, 0, len(request.Messages)) for i := 0; i < len(request.Messages); i++ { message := request.Messages[i] - if message.Role == "system" { - messages = append(messages, Message{ - Role: "user", - Content: message.StringContent(), - }) - messages = append(messages, Message{ - Role: "assistant", - Content: "Okay", - }) - continue - } messages = append(messages, Message{ Content: message.StringContent(), Role: message.Role, diff --git a/relay/channel/xunfei/main.go b/relay/channel/xunfei/main.go index 620e808f..f89aea2b 100644 --- a/relay/channel/xunfei/main.go +++ b/relay/channel/xunfei/main.go @@ -27,21 +27,10 @@ import ( func requestOpenAI2Xunfei(request model.GeneralOpenAIRequest, xunfeiAppId string, domain string) *ChatRequest { messages := make([]Message, 0, len(request.Messages)) for _, message := range request.Messages { - if message.Role == "system" { - messages = append(messages, Message{ - Role: "user", - Content: message.StringContent(), - }) - messages = append(messages, Message{ - Role: "assistant", - Content: "Okay", - }) - } else { - messages = append(messages, Message{ - Role: message.Role, - Content: message.StringContent(), - }) - } + messages = append(messages, Message{ + Role: message.Role, + Content: message.StringContent(), + }) } xunfeiRequest := ChatRequest{} xunfeiRequest.Header.AppId = xunfeiAppId diff --git a/relay/channel/zhipu/main.go b/relay/channel/zhipu/main.go index 7c3e83f3..a46fd537 100644 --- a/relay/channel/zhipu/main.go +++ b/relay/channel/zhipu/main.go @@ -76,21 +76,10 @@ func GetToken(apikey string) string { func ConvertRequest(request model.GeneralOpenAIRequest) *Request { messages := make([]Message, 0, len(request.Messages)) for _, message := range request.Messages { - if message.Role == "system" { - messages = append(messages, Message{ - Role: "system", - Content: message.StringContent(), - }) - messages = append(messages, Message{ - Role: "user", - Content: "Okay", - }) - } else { - messages = append(messages, Message{ - Role: message.Role, - Content: message.StringContent(), - }) - } + messages = append(messages, Message{ + Role: message.Role, + Content: message.StringContent(), + }) } return &Request{ Prompt: messages, diff --git a/relay/controller/text.go b/relay/controller/text.go index 59c5f637..781170f4 100644 --- a/relay/controller/text.go +++ b/relay/controller/text.go @@ -83,11 +83,12 @@ func RelayTextHelper(c *gin.Context) *model.ErrorWithStatusCode { logger.Errorf(ctx, "DoRequest failed: %s", err.Error()) return openai.ErrorWrapper(err, "do_request_failed", http.StatusInternalServerError) } - meta.IsStream = meta.IsStream || strings.HasPrefix(resp.Header.Get("Content-Type"), "text/event-stream") - if resp.StatusCode != http.StatusOK { + errorHappened := (resp.StatusCode != http.StatusOK) || (meta.IsStream && resp.Header.Get("Content-Type") == "application/json") + if errorHappened { util.ReturnPreConsumedQuota(ctx, preConsumedQuota, meta.TokenId) return util.RelayErrorHandler(resp) } + meta.IsStream = meta.IsStream || strings.HasPrefix(resp.Header.Get("Content-Type"), "text/event-stream") // do response usage, respErr := adaptor.DoResponse(c, resp, meta) diff --git a/relay/util/common.go b/relay/util/common.go index 6d993378..20257488 100644 --- a/relay/util/common.go +++ b/relay/util/common.go @@ -27,7 +27,16 @@ func ShouldDisableChannel(err *relaymodel.Error, statusCode int) bool { if statusCode == http.StatusUnauthorized { return true } - if err.Type == "insufficient_quota" || err.Code == "invalid_api_key" || err.Code == "account_deactivated" { + switch err.Type { + case "insufficient_quota": + return true + // https://docs.anthropic.com/claude/reference/errors + case "authentication_error": + return true + case "permission_error": + return true + } + if err.Code == "invalid_api_key" || err.Code == "account_deactivated" { return true } return false @@ -101,6 +110,9 @@ func RelayErrorHandler(resp *http.Response) (ErrorWithStatusCode *relaymodel.Err if err != nil { return } + if config.DebugEnabled { + logger.SysLog(fmt.Sprintf("error happened, status code: %d, response: \n%s", resp.StatusCode, string(responseBody))) + } err = resp.Body.Close() if err != nil { return diff --git a/router/api-router.go b/router/api-router.go index dc1fdc6b..5b755ede 100644 --- a/router/api-router.go +++ b/router/api-router.go @@ -70,7 +70,7 @@ func SetApiRouter(router *gin.Engine) { channelRoute.GET("/search", controller.SearchChannels) channelRoute.GET("/models", controller.ListModels) channelRoute.GET("/:id", controller.GetChannel) - channelRoute.GET("/test", controller.TestAllChannels) + channelRoute.GET("/test", controller.TestChannels) channelRoute.GET("/test/:id", controller.TestChannel) channelRoute.GET("/update_balance", controller.UpdateAllChannelsBalance) channelRoute.GET("/update_balance/:id", controller.UpdateChannelBalance) diff --git a/web/berry/src/constants/ChannelConstants.js b/web/berry/src/constants/ChannelConstants.js index 31c45048..8e9fc97c 100644 --- a/web/berry/src/constants/ChannelConstants.js +++ b/web/berry/src/constants/ChannelConstants.js @@ -15,7 +15,7 @@ export const CHANNEL_OPTIONS = { key: 3, text: 'Azure OpenAI', value: 3, - color: 'orange' + color: 'secondary' }, 11: { key: 11, @@ -89,6 +89,12 @@ export const CHANNEL_OPTIONS = { value: 27, color: 'default' }, + 29: { + key: 29, + text: 'Groq', + value: 29, + color: 'default' + }, 8: { key: 8, text: '自定义渠道', diff --git a/web/berry/src/views/Channel/type/Config.js b/web/berry/src/views/Channel/type/Config.js index 4dec33de..897db189 100644 --- a/web/berry/src/views/Channel/type/Config.js +++ b/web/berry/src/views/Channel/type/Config.js @@ -163,6 +163,9 @@ const typeConfig = { }, modelGroup: "minimax", }, + 29: { + modelGroup: "groq", + }, }; export { defaultConfig, typeConfig }; diff --git a/web/default/src/components/ChannelsTable.js b/web/default/src/components/ChannelsTable.js index 358b9262..5f837d03 100644 --- a/web/default/src/components/ChannelsTable.js +++ b/web/default/src/components/ChannelsTable.js @@ -240,11 +240,11 @@ const ChannelsTable = () => { } }; - const testAllChannels = async () => { - const res = await API.get(`/api/channel/test`); + const testChannels = async (scope) => { + const res = await API.get(`/api/channel/test?scope=${scope}`); const { success, message } = res.data; if (success) { - showInfo('已成功开始测试所有通道,请刷新页面查看结果。'); + showInfo('已成功开始测试通道,请刷新页面查看结果。'); } else { showError(message); } @@ -529,9 +529,12 @@ const ChannelsTable = () => { - + {/**/} { const [disableButton, setDisableButton] = useState(false); const [countdown, setCountdown] = useState(30); + useEffect(() => { + let status = localStorage.getItem('status'); + if (status) { + status = JSON.parse(status); + if (status.turnstile_check) { + setTurnstileEnabled(true); + setTurnstileSiteKey(status.turnstile_site_key); + } + } + }, []); + useEffect(() => { let countdownInterval = null; if (disableButton && countdown > 0) { diff --git a/web/default/src/components/SystemSetting.js b/web/default/src/components/SystemSetting.js index 7b34ce5b..09b98665 100644 --- a/web/default/src/components/SystemSetting.js +++ b/web/default/src/components/SystemSetting.js @@ -22,6 +22,8 @@ const SystemSetting = () => { WeChatServerAddress: '', WeChatServerToken: '', WeChatAccountQRCodeImageURL: '', + MessagePusherAddress: '', + MessagePusherToken: '', TurnstileCheckEnabled: '', TurnstileSiteKey: '', TurnstileSecretKey: '', @@ -183,6 +185,21 @@ const SystemSetting = () => { } }; + const submitMessagePusher = async () => { + if (originInputs['MessagePusherAddress'] !== inputs.MessagePusherAddress) { + await updateOption( + 'MessagePusherAddress', + removeTrailingSlash(inputs.MessagePusherAddress) + ); + } + if ( + originInputs['MessagePusherToken'] !== inputs.MessagePusherToken && + inputs.MessagePusherToken !== '' + ) { + await updateOption('MessagePusherToken', inputs.MessagePusherToken); + } + }; + const submitGitHubOAuth = async () => { if (originInputs['GitHubClientId'] !== inputs.GitHubClientId) { await updateOption('GitHubClientId', inputs.GitHubClientId); @@ -496,6 +513,42 @@ const SystemSetting = () => { 保存 WeChat Server 设置 +
+ 配置 Message Pusher + + 用以推送报警信息, + + 点击此处 + + 了解 Message Pusher + +
+ + + + + + 保存 Message Pusher 设置 + +
配置 Turnstile diff --git a/web/default/src/constants/channel.constants.js b/web/default/src/constants/channel.constants.js index b21bb15d..f6db46c3 100644 --- a/web/default/src/constants/channel.constants.js +++ b/web/default/src/constants/channel.constants.js @@ -14,6 +14,7 @@ export const CHANNEL_OPTIONS = [ { key: 23, text: '腾讯混元', value: 23, color: 'teal' }, { key: 26, text: '百川大模型', value: 26, color: 'orange' }, { key: 27, text: 'MiniMax', value: 27, color: 'red' }, + { key: 29, text: 'Groq', value: 29, color: 'orange' }, { key: 8, text: '自定义渠道', value: 8, color: 'pink' }, { key: 22, text: '知识库:FastGPT', value: 22, color: 'blue' }, { key: 21, text: '知识库:AI Proxy', value: 21, color: 'purple' }, diff --git a/web/default/src/pages/Channel/EditChannel.js b/web/default/src/pages/Channel/EditChannel.js index 59cce0d4..4de8e87a 100644 --- a/web/default/src/pages/Channel/EditChannel.js +++ b/web/default/src/pages/Channel/EditChannel.js @@ -1,7 +1,7 @@ import React, { useEffect, useState } from 'react'; import { Button, Form, Header, Input, Message, Segment } from 'semantic-ui-react'; import { useNavigate, useParams } from 'react-router-dom'; -import { API, getChannelModels, showError, showInfo, showSuccess, verifyJSON } from '../../helpers'; +import { API, copy, getChannelModels, showError, showInfo, showSuccess, verifyJSON } from '../../helpers'; import { CHANNEL_OPTIONS } from '../../constants'; const MODEL_MAPPING_EXAMPLE = { @@ -214,6 +214,7 @@ const EditChannel = () => { label='类型' name='type' required + search options={CHANNEL_OPTIONS} value={inputs.type} onChange={handleInputChange} @@ -342,6 +343,8 @@ const EditChannel = () => { required fluid multiple + search + onLabelClick={(e, { value }) => {copy(value).then()}} selection onChange={handleInputChange} value={inputs.models}