diff --git a/controller/log.go b/controller/log.go index b65867fe..44f1c053 100644 --- a/controller/log.go +++ b/controller/log.go @@ -6,8 +6,23 @@ import ( "one-api/common" "one-api/model" "strconv" + "strings" ) +func parseIntArray(input string) []int { + values := strings.Split(input, ",") + result := make([]int, 0) + + for _, value := range values { + num, err := strconv.Atoi(strings.TrimSpace(value)) + if err == nil { + result = append(result, num) + } + } + + return result +} + func GetAllLogs(c *gin.Context) { p, _ := strconv.Atoi(c.Query("p")) if p < 0 { @@ -19,8 +34,8 @@ func GetAllLogs(c *gin.Context) { username := c.Query("username") tokenName := c.Query("token_name") modelName := c.Query("model_name") - channel, _ := strconv.Atoi(c.Query("channel")) - logs, err := model.GetAllLogs(logType, startTimestamp, endTimestamp, modelName, username, tokenName, p*common.ItemsPerPage, common.ItemsPerPage, channel) + channels := parseIntArray(c.Query("channel")) + logs, err := model.GetAllLogs(logType, startTimestamp, endTimestamp, modelName, username, tokenName, p*common.ItemsPerPage, common.ItemsPerPage, channels) if err != nil { c.JSON(http.StatusOK, gin.H{ "success": false, @@ -107,8 +122,8 @@ func GetLogsStat(c *gin.Context) { tokenName := c.Query("token_name") username := c.Query("username") modelName := c.Query("model_name") - channel, _ := strconv.Atoi(c.Query("channel")) - quotaNum := model.SumUsedQuota(logType, startTimestamp, endTimestamp, modelName, username, tokenName, channel) + channels := parseIntArray(c.Query("channel")) + quotaNum := model.SumUsedQuota(logType, startTimestamp, endTimestamp, modelName, username, tokenName, channels) //tokenNum := model.SumUsedToken(logType, startTimestamp, endTimestamp, modelName, username, "") c.JSON(http.StatusOK, gin.H{ "success": true, @@ -128,8 +143,8 @@ func GetLogsSelfStat(c *gin.Context) { endTimestamp, _ := strconv.ParseInt(c.Query("end_timestamp"), 10, 64) tokenName := c.Query("token_name") modelName := c.Query("model_name") - channel, _ := strconv.Atoi(c.Query("channel")) - quotaNum := model.SumUsedQuota(logType, startTimestamp, endTimestamp, modelName, username, tokenName, channel) + channels := parseIntArray(c.Query("channel")) + quotaNum := model.SumUsedQuota(logType, startTimestamp, endTimestamp, modelName, username, tokenName, channels) //tokenNum := model.SumUsedToken(logType, startTimestamp, endTimestamp, modelName, username, tokenName) c.JSON(http.StatusOK, gin.H{ "success": true, diff --git a/model/log.go b/model/log.go index 307928c4..76bd4e26 100644 --- a/model/log.go +++ b/model/log.go @@ -72,7 +72,7 @@ func RecordConsumeLog(ctx context.Context, userId int, channelId int, promptToke } } -func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, startIdx int, num int, channel int) (logs []*Log, err error) { +func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, startIdx int, num int, channels []int) (logs []*Log, err error) { var tx *gorm.DB if logType == LogTypeUnknown { tx = DB @@ -94,9 +94,12 @@ func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName if endTimestamp != 0 { tx = tx.Where("created_at <= ?", endTimestamp) } - if channel != 0 { - tx = tx.Where("channel_id = ?", channel) + if len(channels) > 1 { + tx = tx.Where("channel_id IN ?", channels) + } else if len(channels) == 1 && channels[0] != 0 { + tx = tx.Where("channel_id = ?", channels[0]) } + err = tx.Order("id desc").Limit(num).Offset(startIdx).Find(&logs).Error return logs, err } @@ -134,7 +137,7 @@ func SearchUserLogs(userId int, keyword string) (logs []*Log, err error) { return logs, err } -func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, channel int) (quota int) { +func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, channels []int) (quota int) { tx := DB.Table("logs").Select("ifnull(sum(quota),0)") if username != "" { tx = tx.Where("username = ?", username) @@ -151,8 +154,10 @@ func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelNa if modelName != "" { tx = tx.Where("model_name = ?", modelName) } - if channel != 0 { - tx = tx.Where("channel_id = ?", channel) + if len(channels) > 1 { + tx = tx.Where("channel_id IN ?", channels) + } else if len(channels) == 1 && channels[0] != 0 { + tx = tx.Where("channel_id = ?", channels[0]) } tx.Where("type = ?", LogTypeConsume).Scan("a) return quota