diff --git a/controller/log.go b/controller/log.go index ba043349..c0c3f69f 100644 --- a/controller/log.go +++ b/controller/log.go @@ -18,7 +18,8 @@ func GetAllLogs(c *gin.Context) { username := c.Query("username") tokenName := c.Query("token_name") modelName := c.Query("model_name") - logs, err := model.GetAllLogs(logType, startTimestamp, endTimestamp, modelName, username, tokenName, p*common.ItemsPerPage, common.ItemsPerPage) + channel, _ := strconv.Atoi(c.Query("channel")) + logs, err := model.GetAllLogs(logType, startTimestamp, endTimestamp, modelName, username, tokenName, p*common.ItemsPerPage, common.ItemsPerPage, channel) if err != nil { c.JSON(200, gin.H{ "success": false, @@ -44,7 +45,8 @@ func GetUserLogs(c *gin.Context) { endTimestamp, _ := strconv.ParseInt(c.Query("end_timestamp"), 10, 64) tokenName := c.Query("token_name") modelName := c.Query("model_name") - logs, err := model.GetUserLogs(userId, logType, startTimestamp, endTimestamp, modelName, tokenName, p*common.ItemsPerPage, common.ItemsPerPage) + channel, _ := strconv.Atoi(c.Query("channel")) + logs, err := model.GetUserLogs(userId, logType, startTimestamp, endTimestamp, modelName, tokenName, p*common.ItemsPerPage, common.ItemsPerPage, channel) if err != nil { c.JSON(200, gin.H{ "success": false, @@ -101,7 +103,8 @@ func GetLogsStat(c *gin.Context) { tokenName := c.Query("token_name") username := c.Query("username") modelName := c.Query("model_name") - quotaNum := model.SumUsedQuota(logType, startTimestamp, endTimestamp, modelName, username, tokenName) + channel, _ := strconv.Atoi(c.Query("channel")) + quotaNum := model.SumUsedQuota(logType, startTimestamp, endTimestamp, modelName, username, tokenName, channel) //tokenNum := model.SumUsedToken(logType, startTimestamp, endTimestamp, modelName, username, "") c.JSON(200, gin.H{ "success": true, @@ -120,7 +123,8 @@ func GetLogsSelfStat(c *gin.Context) { endTimestamp, _ := strconv.ParseInt(c.Query("end_timestamp"), 10, 64) tokenName := c.Query("token_name") modelName := c.Query("model_name") - quotaNum := model.SumUsedQuota(logType, startTimestamp, endTimestamp, modelName, username, tokenName) + channel, _ := strconv.Atoi(c.Query("channel")) + quotaNum := model.SumUsedQuota(logType, startTimestamp, endTimestamp, modelName, username, tokenName, channel) //tokenNum := model.SumUsedToken(logType, startTimestamp, endTimestamp, modelName, username, tokenName) c.JSON(200, gin.H{ "success": true, diff --git a/controller/relay-audio.go b/controller/relay-audio.go index 277ab404..fe8fac34 100644 --- a/controller/relay-audio.go +++ b/controller/relay-audio.go @@ -17,6 +17,7 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode tokenId := c.GetInt("token_id") channelType := c.GetInt("channel") + channelID := c.GetInt("channel_id") userId := c.GetInt("id") group := c.GetString("group") @@ -106,7 +107,7 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode if quota != 0 { tokenName := c.GetString("token_name") logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio) - model.RecordConsumeLog(userId, 0, 0, audioModel, tokenName, quota, logContent) + model.RecordConsumeLog(userId, channelID, 0, 0, audioModel, tokenName, quota, logContent) model.UpdateUserUsedQuotaAndRequestCount(userId, quota) channelId := c.GetInt("channel_id") model.UpdateChannelUsedQuota(channelId, quota) diff --git a/controller/relay-image.go b/controller/relay-image.go index de623288..90092160 100644 --- a/controller/relay-image.go +++ b/controller/relay-image.go @@ -18,6 +18,7 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode tokenId := c.GetInt("token_id") channelType := c.GetInt("channel") + channelID := c.GetInt("channel_id") userId := c.GetInt("id") consumeQuota := c.GetBool("consume_quota") group := c.GetString("group") @@ -137,7 +138,7 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode if quota != 0 { tokenName := c.GetString("token_name") logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio) - model.RecordConsumeLog(userId, 0, 0, imageModel, tokenName, quota, logContent) + model.RecordConsumeLog(userId, channelID, 0, 0, imageModel, tokenName, quota, logContent) model.UpdateUserUsedQuotaAndRequestCount(userId, quota) channelId := c.GetInt("channel_id") model.UpdateChannelUsedQuota(channelId, quota) diff --git a/controller/relay-text.go b/controller/relay-text.go index 624b9d01..0b92d427 100644 --- a/controller/relay-text.go +++ b/controller/relay-text.go @@ -36,6 +36,7 @@ func init() { func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { channelType := c.GetInt("channel") + channelID := c.GetInt("channel_id") tokenId := c.GetInt("token_id") userId := c.GetInt("id") consumeQuota := c.GetBool("consume_quota") @@ -360,7 +361,7 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { } if quota != 0 { logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio) - model.RecordConsumeLog(userId, promptTokens, completionTokens, textRequest.Model, tokenName, quota, logContent) + model.RecordConsumeLog(userId, channelID, promptTokens, completionTokens, textRequest.Model, tokenName, quota, logContent) model.UpdateUserUsedQuotaAndRequestCount(userId, quota) model.UpdateChannelUsedQuota(channelId, quota) diff --git a/model/ability.go b/model/ability.go index e87c3940..b7577b11 100644 --- a/model/ability.go +++ b/model/ability.go @@ -10,15 +10,16 @@ type Ability struct { Model string `json:"model" gorm:"primaryKey;autoIncrement:false"` ChannelId int `json:"channel_id" gorm:"primaryKey;autoIncrement:false;index"` Enabled bool `json:"enabled"` + Priority *int64 `json:"priority" gorm:"bigint;default:0"` } func GetRandomSatisfiedChannel(group string, model string) (*Channel, error) { ability := Ability{} var err error = nil if common.UsingSQLite { - err = DB.Where("`group` = ? and model = ? and enabled = 1", group, model).Order("RANDOM()").Limit(1).First(&ability).Error + err = DB.Where("`group` = ? and model = ? and enabled = 1", group, model).Order("CASE WHEN priority <> 0 THEN priority ELSE RANDOM() END DESC ").Limit(1).First(&ability).Error } else { - err = DB.Where("`group` = ? and model = ? and enabled = 1", group, model).Order("RAND()").Limit(1).First(&ability).Error + err = DB.Where("`group` = ? and model = ? and enabled = 1", group, model).Order("CASE WHEN priority <> 0 THEN priority ELSE RAND() END DESC").Limit(1).First(&ability).Error } if err != nil { return nil, err diff --git a/model/cache.go b/model/cache.go index 55fbba9b..92576165 100644 --- a/model/cache.go +++ b/model/cache.go @@ -6,6 +6,7 @@ import ( "fmt" "math/rand" "one-api/common" + "sort" "strconv" "strings" "sync" @@ -154,6 +155,17 @@ func InitChannelCache() { } } } + + // sort by priority + for group, model2channels := range newGroup2model2channels { + for model, channels := range model2channels { + sort.Slice(channels, func(i, j int) bool { + return channels[i].GetPriority() > channels[j].GetPriority() + }) + newGroup2model2channels[group][model] = channels + } + } + channelSyncLock.Lock() group2model2channels = newGroup2model2channels channelSyncLock.Unlock() @@ -178,6 +190,11 @@ func CacheGetRandomSatisfiedChannel(group string, model string) (*Channel, error if len(channels) == 0 { return nil, errors.New("channel not found") } + // choose by priority + firstChannel := channels[0] + if firstChannel.GetPriority() > 0 { + return firstChannel, nil + } idx := rand.Intn(len(channels)) return channels[idx], nil } diff --git a/model/channel.go b/model/channel.go index 7cc9fa9b..1f5d9571 100644 --- a/model/channel.go +++ b/model/channel.go @@ -23,6 +23,14 @@ type Channel struct { Group string `json:"group" gorm:"type:varchar(32);default:'default'"` UsedQuota int64 `json:"used_quota" gorm:"bigint;default:0"` ModelMapping string `json:"model_mapping" gorm:"type:varchar(1024);default:''"` + Priority *int64 `json:"priority" gorm:"bigint;default:0"` +} + +func (c *Channel) GetPriority() int64 { + if c.Priority == nil { + return 0 + } + return *c.Priority } func GetAllChannels(startIdx int, num int, selectAll bool) ([]*Channel, error) { diff --git a/model/log.go b/model/log.go index b0d6409a..aa659d39 100644 --- a/model/log.go +++ b/model/log.go @@ -17,6 +17,7 @@ type Log struct { Quota int `json:"quota" gorm:"default:0"` PromptTokens int `json:"prompt_tokens" gorm:"default:0"` CompletionTokens int `json:"completion_tokens" gorm:"default:0"` + Channel int `json:"channel" gorm:"default:0"` } const ( @@ -44,7 +45,7 @@ func RecordLog(userId int, logType int, content string) { } } -func RecordConsumeLog(userId int, promptTokens int, completionTokens int, modelName string, tokenName string, quota int, content string) { +func RecordConsumeLog(userId int, channelID, promptTokens int, completionTokens int, modelName string, tokenName string, quota int, content string) { if !common.LogConsumeEnabled { return } @@ -59,6 +60,7 @@ func RecordConsumeLog(userId int, promptTokens int, completionTokens int, modelN TokenName: tokenName, ModelName: modelName, Quota: quota, + Channel: channelID, } err := DB.Create(log).Error if err != nil { @@ -66,7 +68,7 @@ func RecordConsumeLog(userId int, promptTokens int, completionTokens int, modelN } } -func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, startIdx int, num int) (logs []*Log, err error) { +func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, startIdx int, num int, channel int) (logs []*Log, err error) { var tx *gorm.DB if logType == LogTypeUnknown { tx = DB @@ -88,11 +90,14 @@ func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName if endTimestamp != 0 { tx = tx.Where("created_at <= ?", endTimestamp) } + if channel != 0 { + tx = tx.Where("channel = ?", channel) + } err = tx.Order("id desc").Limit(num).Offset(startIdx).Find(&logs).Error return logs, err } -func GetUserLogs(userId int, logType int, startTimestamp int64, endTimestamp int64, modelName string, tokenName string, startIdx int, num int) (logs []*Log, err error) { +func GetUserLogs(userId int, logType int, startTimestamp int64, endTimestamp int64, modelName string, tokenName string, startIdx int, num int, channel int) (logs []*Log, err error) { var tx *gorm.DB if logType == LogTypeUnknown { tx = DB.Where("user_id = ?", userId) @@ -111,6 +116,9 @@ func GetUserLogs(userId int, logType int, startTimestamp int64, endTimestamp int if endTimestamp != 0 { tx = tx.Where("created_at <= ?", endTimestamp) } + if channel != 0 { + tx = tx.Where("channel = ?", channel) + } err = tx.Order("id desc").Limit(num).Offset(startIdx).Omit("id").Find(&logs).Error return logs, err } @@ -125,7 +133,7 @@ func SearchUserLogs(userId int, keyword string) (logs []*Log, err error) { return logs, err } -func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string) (quota int) { +func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, channel int) (quota int) { tx := DB.Table("logs").Select("sum(quota)") if username != "" { tx = tx.Where("username = ?", username) @@ -142,6 +150,9 @@ func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelNa if modelName != "" { tx = tx.Where("model_name = ?", modelName) } + if channel != 0 { + tx = tx.Where("channel = ?", channel) + } tx.Where("type = ?", LogTypeConsume).Scan("a) return quota } diff --git a/web/src/components/ChannelsTable.js b/web/src/components/ChannelsTable.js index 5eb39783..708d9daf 100644 --- a/web/src/components/ChannelsTable.js +++ b/web/src/components/ChannelsTable.js @@ -1,5 +1,5 @@ import React, { useEffect, useState } from 'react'; -import { Button, Form, Label, Pagination, Popup, Table } from 'semantic-ui-react'; +import {Button, Form, Input, Label, Pagination, Popup, Table} from 'semantic-ui-react'; import { Link } from 'react-router-dom'; import { API, showError, showInfo, showSuccess, timestamp2string } from '../helpers'; @@ -96,7 +96,7 @@ const ChannelsTable = () => { }); }, []); - const manageChannel = async (id, action, idx) => { + const manageChannel = async (id, action, idx, priority) => { let data = { id }; let res; switch (action) { @@ -111,6 +111,13 @@ const ChannelsTable = () => { data.status = 2; res = await API.put('/api/channel/', data); break; + case 'priority': + if (priority === '') { + return; + } + data.priority = parseInt(priority); + res = await API.put('/api/channel/', data); + break; } const { success, message } = res.data; if (success) { @@ -334,6 +341,14 @@ const ChannelsTable = () => { > 余额 + { + sortChannel('priority'); + }} + > + 优先级 + 操作 @@ -372,6 +387,22 @@ const ChannelsTable = () => { basic /> + + { + manageChannel( + channel.id, + 'priority', + idx, + event.target.value, + ); + }}> + + } + content='输入优先级,越高越优先' + basic + /> +