feat: 支持设置渠道优先级 & 日志中显示使用的渠道ID

This commit is contained in:
Xiangyuan Liu 2023-08-30 19:12:41 +08:00
parent 56b5007379
commit 5cee70c9a4
No known key found for this signature in database
GPG Key ID: A2FCA5A37FC0F888
10 changed files with 110 additions and 21 deletions

View File

@ -18,7 +18,8 @@ func GetAllLogs(c *gin.Context) {
username := c.Query("username")
tokenName := c.Query("token_name")
modelName := c.Query("model_name")
logs, err := model.GetAllLogs(logType, startTimestamp, endTimestamp, modelName, username, tokenName, p*common.ItemsPerPage, common.ItemsPerPage)
channel, _ := strconv.Atoi(c.Query("channel"))
logs, err := model.GetAllLogs(logType, startTimestamp, endTimestamp, modelName, username, tokenName, p*common.ItemsPerPage, common.ItemsPerPage, channel)
if err != nil {
c.JSON(200, gin.H{
"success": false,
@ -44,7 +45,8 @@ func GetUserLogs(c *gin.Context) {
endTimestamp, _ := strconv.ParseInt(c.Query("end_timestamp"), 10, 64)
tokenName := c.Query("token_name")
modelName := c.Query("model_name")
logs, err := model.GetUserLogs(userId, logType, startTimestamp, endTimestamp, modelName, tokenName, p*common.ItemsPerPage, common.ItemsPerPage)
channel, _ := strconv.Atoi(c.Query("channel"))
logs, err := model.GetUserLogs(userId, logType, startTimestamp, endTimestamp, modelName, tokenName, p*common.ItemsPerPage, common.ItemsPerPage, channel)
if err != nil {
c.JSON(200, gin.H{
"success": false,
@ -101,7 +103,8 @@ func GetLogsStat(c *gin.Context) {
tokenName := c.Query("token_name")
username := c.Query("username")
modelName := c.Query("model_name")
quotaNum := model.SumUsedQuota(logType, startTimestamp, endTimestamp, modelName, username, tokenName)
channel, _ := strconv.Atoi(c.Query("channel"))
quotaNum := model.SumUsedQuota(logType, startTimestamp, endTimestamp, modelName, username, tokenName, channel)
//tokenNum := model.SumUsedToken(logType, startTimestamp, endTimestamp, modelName, username, "")
c.JSON(200, gin.H{
"success": true,
@ -120,7 +123,8 @@ func GetLogsSelfStat(c *gin.Context) {
endTimestamp, _ := strconv.ParseInt(c.Query("end_timestamp"), 10, 64)
tokenName := c.Query("token_name")
modelName := c.Query("model_name")
quotaNum := model.SumUsedQuota(logType, startTimestamp, endTimestamp, modelName, username, tokenName)
channel, _ := strconv.Atoi(c.Query("channel"))
quotaNum := model.SumUsedQuota(logType, startTimestamp, endTimestamp, modelName, username, tokenName, channel)
//tokenNum := model.SumUsedToken(logType, startTimestamp, endTimestamp, modelName, username, tokenName)
c.JSON(200, gin.H{
"success": true,

View File

@ -17,6 +17,7 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
tokenId := c.GetInt("token_id")
channelType := c.GetInt("channel")
channelID := c.GetInt("channel_id")
userId := c.GetInt("id")
group := c.GetString("group")
@ -106,7 +107,7 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
if quota != 0 {
tokenName := c.GetString("token_name")
logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
model.RecordConsumeLog(userId, 0, 0, audioModel, tokenName, quota, logContent)
model.RecordConsumeLog(userId, channelID, 0, 0, audioModel, tokenName, quota, logContent)
model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
channelId := c.GetInt("channel_id")
model.UpdateChannelUsedQuota(channelId, quota)

View File

@ -18,6 +18,7 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
tokenId := c.GetInt("token_id")
channelType := c.GetInt("channel")
channelID := c.GetInt("channel_id")
userId := c.GetInt("id")
consumeQuota := c.GetBool("consume_quota")
group := c.GetString("group")
@ -137,7 +138,7 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
if quota != 0 {
tokenName := c.GetString("token_name")
logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
model.RecordConsumeLog(userId, 0, 0, imageModel, tokenName, quota, logContent)
model.RecordConsumeLog(userId, channelID, 0, 0, imageModel, tokenName, quota, logContent)
model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
channelId := c.GetInt("channel_id")
model.UpdateChannelUsedQuota(channelId, quota)

View File

@ -36,6 +36,7 @@ func init() {
func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
channelType := c.GetInt("channel")
channelID := c.GetInt("channel_id")
tokenId := c.GetInt("token_id")
userId := c.GetInt("id")
consumeQuota := c.GetBool("consume_quota")
@ -360,7 +361,7 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
}
if quota != 0 {
logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
model.RecordConsumeLog(userId, promptTokens, completionTokens, textRequest.Model, tokenName, quota, logContent)
model.RecordConsumeLog(userId, channelID, promptTokens, completionTokens, textRequest.Model, tokenName, quota, logContent)
model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
model.UpdateChannelUsedQuota(channelId, quota)

View File

@ -10,15 +10,16 @@ type Ability struct {
Model string `json:"model" gorm:"primaryKey;autoIncrement:false"`
ChannelId int `json:"channel_id" gorm:"primaryKey;autoIncrement:false;index"`
Enabled bool `json:"enabled"`
Priority *int64 `json:"priority" gorm:"bigint;default:0"`
}
func GetRandomSatisfiedChannel(group string, model string) (*Channel, error) {
ability := Ability{}
var err error = nil
if common.UsingSQLite {
err = DB.Where("`group` = ? and model = ? and enabled = 1", group, model).Order("RANDOM()").Limit(1).First(&ability).Error
err = DB.Where("`group` = ? and model = ? and enabled = 1", group, model).Order("CASE WHEN priority <> 0 THEN priority ELSE RANDOM() END DESC ").Limit(1).First(&ability).Error
} else {
err = DB.Where("`group` = ? and model = ? and enabled = 1", group, model).Order("RAND()").Limit(1).First(&ability).Error
err = DB.Where("`group` = ? and model = ? and enabled = 1", group, model).Order("CASE WHEN priority <> 0 THEN priority ELSE RAND() END DESC").Limit(1).First(&ability).Error
}
if err != nil {
return nil, err

View File

@ -6,6 +6,7 @@ import (
"fmt"
"math/rand"
"one-api/common"
"sort"
"strconv"
"strings"
"sync"
@ -154,6 +155,17 @@ func InitChannelCache() {
}
}
}
// sort by priority
for group, model2channels := range newGroup2model2channels {
for model, channels := range model2channels {
sort.Slice(channels, func(i, j int) bool {
return channels[i].GetPriority() > channels[j].GetPriority()
})
newGroup2model2channels[group][model] = channels
}
}
channelSyncLock.Lock()
group2model2channels = newGroup2model2channels
channelSyncLock.Unlock()
@ -178,6 +190,11 @@ func CacheGetRandomSatisfiedChannel(group string, model string) (*Channel, error
if len(channels) == 0 {
return nil, errors.New("channel not found")
}
// choose by priority
firstChannel := channels[0]
if firstChannel.GetPriority() > 0 {
return firstChannel, nil
}
idx := rand.Intn(len(channels))
return channels[idx], nil
}

View File

@ -23,6 +23,14 @@ type Channel struct {
Group string `json:"group" gorm:"type:varchar(32);default:'default'"`
UsedQuota int64 `json:"used_quota" gorm:"bigint;default:0"`
ModelMapping string `json:"model_mapping" gorm:"type:varchar(1024);default:''"`
Priority *int64 `json:"priority" gorm:"bigint;default:0"`
}
func (c *Channel) GetPriority() int64 {
if c.Priority == nil {
return 0
}
return *c.Priority
}
func GetAllChannels(startIdx int, num int, selectAll bool) ([]*Channel, error) {

View File

@ -17,6 +17,7 @@ type Log struct {
Quota int `json:"quota" gorm:"default:0"`
PromptTokens int `json:"prompt_tokens" gorm:"default:0"`
CompletionTokens int `json:"completion_tokens" gorm:"default:0"`
Channel int `json:"channel" gorm:"default:0"`
}
const (
@ -44,7 +45,7 @@ func RecordLog(userId int, logType int, content string) {
}
}
func RecordConsumeLog(userId int, promptTokens int, completionTokens int, modelName string, tokenName string, quota int, content string) {
func RecordConsumeLog(userId int, channelID, promptTokens int, completionTokens int, modelName string, tokenName string, quota int, content string) {
if !common.LogConsumeEnabled {
return
}
@ -59,6 +60,7 @@ func RecordConsumeLog(userId int, promptTokens int, completionTokens int, modelN
TokenName: tokenName,
ModelName: modelName,
Quota: quota,
Channel: channelID,
}
err := DB.Create(log).Error
if err != nil {
@ -66,7 +68,7 @@ func RecordConsumeLog(userId int, promptTokens int, completionTokens int, modelN
}
}
func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, startIdx int, num int) (logs []*Log, err error) {
func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, startIdx int, num int, channel int) (logs []*Log, err error) {
var tx *gorm.DB
if logType == LogTypeUnknown {
tx = DB
@ -88,11 +90,14 @@ func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName
if endTimestamp != 0 {
tx = tx.Where("created_at <= ?", endTimestamp)
}
if channel != 0 {
tx = tx.Where("channel = ?", channel)
}
err = tx.Order("id desc").Limit(num).Offset(startIdx).Find(&logs).Error
return logs, err
}
func GetUserLogs(userId int, logType int, startTimestamp int64, endTimestamp int64, modelName string, tokenName string, startIdx int, num int) (logs []*Log, err error) {
func GetUserLogs(userId int, logType int, startTimestamp int64, endTimestamp int64, modelName string, tokenName string, startIdx int, num int, channel int) (logs []*Log, err error) {
var tx *gorm.DB
if logType == LogTypeUnknown {
tx = DB.Where("user_id = ?", userId)
@ -111,6 +116,9 @@ func GetUserLogs(userId int, logType int, startTimestamp int64, endTimestamp int
if endTimestamp != 0 {
tx = tx.Where("created_at <= ?", endTimestamp)
}
if channel != 0 {
tx = tx.Where("channel = ?", channel)
}
err = tx.Order("id desc").Limit(num).Offset(startIdx).Omit("id").Find(&logs).Error
return logs, err
}
@ -125,7 +133,7 @@ func SearchUserLogs(userId int, keyword string) (logs []*Log, err error) {
return logs, err
}
func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string) (quota int) {
func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, channel int) (quota int) {
tx := DB.Table("logs").Select("sum(quota)")
if username != "" {
tx = tx.Where("username = ?", username)
@ -142,6 +150,9 @@ func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelNa
if modelName != "" {
tx = tx.Where("model_name = ?", modelName)
}
if channel != 0 {
tx = tx.Where("channel = ?", channel)
}
tx.Where("type = ?", LogTypeConsume).Scan(&quota)
return quota
}

View File

@ -1,5 +1,5 @@
import React, { useEffect, useState } from 'react';
import { Button, Form, Label, Pagination, Popup, Table } from 'semantic-ui-react';
import {Button, Form, Input, Label, Pagination, Popup, Table} from 'semantic-ui-react';
import { Link } from 'react-router-dom';
import { API, showError, showInfo, showSuccess, timestamp2string } from '../helpers';
@ -96,7 +96,7 @@ const ChannelsTable = () => {
});
}, []);
const manageChannel = async (id, action, idx) => {
const manageChannel = async (id, action, idx, priority) => {
let data = { id };
let res;
switch (action) {
@ -111,6 +111,13 @@ const ChannelsTable = () => {
data.status = 2;
res = await API.put('/api/channel/', data);
break;
case 'priority':
if (priority === '') {
return;
}
data.priority = parseInt(priority);
res = await API.put('/api/channel/', data);
break;
}
const { success, message } = res.data;
if (success) {
@ -334,6 +341,14 @@ const ChannelsTable = () => {
>
余额
</Table.HeaderCell>
<Table.HeaderCell
style={{ cursor: 'pointer' }}
onClick={() => {
sortChannel('priority');
}}
>
优先级
</Table.HeaderCell>
<Table.HeaderCell>操作</Table.HeaderCell>
</Table.Row>
</Table.Header>
@ -372,6 +387,22 @@ const ChannelsTable = () => {
basic
/>
</Table.Cell>
<Table.Cell>
<Popup
trigger={<Input type="number" defaultValue={channel.priority} onBlur={(event) => {
manageChannel(
channel.id,
'priority',
idx,
event.target.value,
);
}}>
<input style={{maxWidth:'60px'}} />
</Input>}
content='输入优先级,越高越优先'
basic
/>
</Table.Cell>
<Table.Cell>
<div>
<Button

View File

@ -56,9 +56,10 @@ const LogsTable = () => {
token_name: '',
model_name: '',
start_timestamp: timestamp2string(0),
end_timestamp: timestamp2string(now.getTime() / 1000 + 3600)
end_timestamp: timestamp2string(now.getTime() / 1000 + 3600),
channel: 0,
});
const { username, token_name, model_name, start_timestamp, end_timestamp } = inputs;
const { username, token_name, model_name, start_timestamp, end_timestamp, channel } = inputs;
const [stat, setStat] = useState({
quota: 0,
@ -72,7 +73,7 @@ const LogsTable = () => {
const getLogSelfStat = async () => {
let localStartTimestamp = Date.parse(start_timestamp) / 1000;
let localEndTimestamp = Date.parse(end_timestamp) / 1000;
let res = await API.get(`/api/log/self/stat?type=${logType}&token_name=${token_name}&model_name=${model_name}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}`);
let res = await API.get(`/api/log/self/stat?type=${logType}&token_name=${token_name}&model_name=${model_name}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}&channel=${channel}`);
const { success, message, data } = res.data;
if (success) {
setStat(data);
@ -84,7 +85,7 @@ const LogsTable = () => {
const getLogStat = async () => {
let localStartTimestamp = Date.parse(start_timestamp) / 1000;
let localEndTimestamp = Date.parse(end_timestamp) / 1000;
let res = await API.get(`/api/log/stat?type=${logType}&username=${username}&token_name=${token_name}&model_name=${model_name}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}`);
let res = await API.get(`/api/log/stat?type=${logType}&username=${username}&token_name=${token_name}&model_name=${model_name}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}&channel=${channel}`);
const { success, message, data } = res.data;
if (success) {
setStat(data);
@ -109,9 +110,9 @@ const LogsTable = () => {
let localStartTimestamp = Date.parse(start_timestamp) / 1000;
let localEndTimestamp = Date.parse(end_timestamp) / 1000;
if (isAdminUser) {
url = `/api/log/?p=${startIdx}&type=${logType}&username=${username}&token_name=${token_name}&model_name=${model_name}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}`;
url = `/api/log/?p=${startIdx}&type=${logType}&username=${username}&token_name=${token_name}&model_name=${model_name}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}&channel=${channel}`;
} else {
url = `/api/log/self/?p=${startIdx}&type=${logType}&token_name=${token_name}&model_name=${model_name}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}`;
url = `/api/log/self/?p=${startIdx}&type=${logType}&token_name=${token_name}&model_name=${model_name}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}&channel=${channel}`;
}
const res = await API.get(url);
const { success, message, data } = res.data;
@ -217,6 +218,9 @@ const LogsTable = () => {
<Form.Input fluid label='模型名称' width={isAdminUser ? 2 : 3} value={model_name} placeholder='可选值'
name='model_name'
onChange={handleInputChange} />
<Form.Input fluid label='渠道' width={isAdminUser ? 2 : 3} value={channel} placeholder='可选值'
name='channel'
onChange={handleInputChange} />
<Form.Input fluid label='起始时间' width={4} value={start_timestamp} type='datetime-local'
name='start_timestamp'
onChange={handleInputChange} />
@ -276,6 +280,15 @@ const LogsTable = () => {
>
模型
</Table.HeaderCell>
<Table.HeaderCell
style={{ cursor: 'pointer' }}
onClick={() => {
sortLog('channel');
}}
width={2}
>
渠道
</Table.HeaderCell>
<Table.HeaderCell
style={{ cursor: 'pointer' }}
onClick={() => {
@ -334,6 +347,7 @@ const LogsTable = () => {
<Table.Cell>{log.token_name ? <Label basic>{log.token_name}</Label> : ''}</Table.Cell>
<Table.Cell>{renderType(log.type)}</Table.Cell>
<Table.Cell>{log.model_name ? <Label basic>{log.model_name}</Label> : ''}</Table.Cell>
<Table.Cell>{log.channel ? <Label basic>{log.channel}</Label> : ''}</Table.Cell>
<Table.Cell>{log.prompt_tokens ? log.prompt_tokens : ''}</Table.Cell>
<Table.Cell>{log.completion_tokens ? log.completion_tokens : ''}</Table.Cell>
<Table.Cell>{log.quota ? renderQuota(log.quota, 6) : ''}</Table.Cell>