diff --git a/common/model-ratio.go b/common/model-ratio.go index 92b262e6..0b0a0f8a 100644 --- a/common/model-ratio.go +++ b/common/model-ratio.go @@ -17,8 +17,3 @@ var DalleGenerationImageAmounts = map[string][2]int{ "dall-e-2": {1, 10}, "dall-e-3": {1, 1}, // OpenAI allows n=1 currently. } - -var DalleImagePromptLengthLimitations = map[string]int{ - "dall-e-2": 1000, - "dall-e-3": 4000, -} diff --git a/common/notify/channel/email.go b/common/notify/channel/email.go index 05425b1a..05613a5a 100644 --- a/common/notify/channel/email.go +++ b/common/notify/channel/email.go @@ -25,7 +25,7 @@ func (e *Email) Name() string { return "Email" } -func (e *Email) Send(ctx context.Context, title, message string) error { +func (e *Email) Send(_ context.Context, title, message string) error { to := e.To if to == "" { to = config.RootUserEmail diff --git a/common/token.go b/common/token.go index 8d5090c3..ba1d985e 100644 --- a/common/token.go +++ b/common/token.go @@ -148,7 +148,7 @@ const ( // https://platform.openai.com/docs/guides/vision/calculating-costs // https://github.com/openai/openai-cookbook/blob/05e3f9be4c7a2ae7ecf029a7c32065b024730ebe/examples/How_to_count_tokens_with_tiktoken.ipynb func countImageTokens(url string, detail string) (_ int, err error) { - var fetchSize = true + // var fetchSize = true var width, height int // Reference: https://platform.openai.com/docs/guides/vision/low-or-high-fidelity-image-understanding // detail == "auto" is undocumented on how it works, it just said the model will use the auto setting which will look at the image input size and decide if it should use the low or high setting. @@ -183,11 +183,9 @@ func countImageTokens(url string, detail string) (_ int, err error) { case "low": return lowDetailCost, nil case "high": - if fetchSize { - width, height, err = image.GetImageSize(url) - if err != nil { - return 0, err - } + width, height, err = image.GetImageSize(url) + if err != nil { + return 0, err } if width > 2048 || height > 2048 { // max(width, height) > 2048 ratio := float64(2048) / math.Max(float64(width), float64(height)) diff --git a/controller/channel-billing.go b/controller/channel-billing.go index 2f3254d1..3b3ac0c9 100644 --- a/controller/channel-billing.go +++ b/controller/channel-billing.go @@ -80,7 +80,7 @@ func UpdateChannelBalance(c *gin.Context) { }) return } - channel, err := model.GetChannelById(id, true) + channel, err := model.GetChannelById(id) if err != nil { c.JSON(http.StatusOK, gin.H{ "success": false, diff --git a/controller/channel-test.go b/controller/channel-test.go index 677d48c7..47d17b6b 100644 --- a/controller/channel-test.go +++ b/controller/channel-test.go @@ -100,7 +100,7 @@ func TestChannel(c *gin.Context) { }) return } - channel, err := model.GetChannelById(id, true) + channel, err := model.GetChannelById(id) if err != nil { c.JSON(http.StatusOK, gin.H{ "success": false, diff --git a/controller/channel.go b/controller/channel.go index 0f401a45..7de7f115 100644 --- a/controller/channel.go +++ b/controller/channel.go @@ -40,7 +40,7 @@ func GetChannel(c *gin.Context) { }) return } - channel, err := model.GetChannelById(id, false) + channel, err := model.GetChannelById(id) if err != nil { c.JSON(http.StatusOK, gin.H{ "success": false, diff --git a/controller/common.go b/controller/common.go index 2a3e17f3..943c173f 100644 --- a/controller/common.go +++ b/controller/common.go @@ -96,18 +96,6 @@ func EnableChannel(channelId int, channelName string, sendNotify bool) { notify.Send(subject, content) } -func RelayNotImplemented(c *gin.Context) { - err := types.OpenAIError{ - Message: "API not implemented", - Type: "one_api_error", - Param: "", - Code: "api_not_implemented", - } - c.JSON(http.StatusNotImplemented, gin.H{ - "error": err, - }) -} - func RelayNotFound(c *gin.Context) { err := types.OpenAIError{ Message: fmt.Sprintf("Invalid URL (%s %s)", c.Request.Method, c.Request.URL.Path), diff --git a/controller/log.go b/controller/log.go index bc30fea7..2ab8e8f5 100644 --- a/controller/log.go +++ b/controller/log.go @@ -49,50 +49,15 @@ func GetUserLogsList(c *gin.Context) { }) } -func SearchAllLogs(c *gin.Context) { - keyword := c.Query("keyword") - logs, err := model.SearchAllLogs(keyword) - if err != nil { - c.JSON(http.StatusOK, gin.H{ - "success": false, - "message": err.Error(), - }) - return - } - c.JSON(http.StatusOK, gin.H{ - "success": true, - "message": "", - "data": logs, - }) -} - -func SearchUserLogs(c *gin.Context) { - keyword := c.Query("keyword") - userId := c.GetInt("id") - logs, err := model.SearchUserLogs(userId, keyword) - if err != nil { - c.JSON(http.StatusOK, gin.H{ - "success": false, - "message": err.Error(), - }) - return - } - c.JSON(http.StatusOK, gin.H{ - "success": true, - "message": "", - "data": logs, - }) -} - func GetLogsStat(c *gin.Context) { - logType, _ := strconv.Atoi(c.Query("type")) + // logType, _ := strconv.Atoi(c.Query("type")) startTimestamp, _ := strconv.ParseInt(c.Query("start_timestamp"), 10, 64) endTimestamp, _ := strconv.ParseInt(c.Query("end_timestamp"), 10, 64) tokenName := c.Query("token_name") username := c.Query("username") modelName := c.Query("model_name") channel, _ := strconv.Atoi(c.Query("channel")) - quotaNum := model.SumUsedQuota(logType, startTimestamp, endTimestamp, modelName, username, tokenName, channel) + quotaNum := model.SumUsedQuota(startTimestamp, endTimestamp, modelName, username, tokenName, channel) //tokenNum := model.SumUsedToken(logType, startTimestamp, endTimestamp, modelName, username, "") c.JSON(http.StatusOK, gin.H{ "success": true, @@ -106,13 +71,13 @@ func GetLogsStat(c *gin.Context) { func GetLogsSelfStat(c *gin.Context) { username := c.GetString("username") - logType, _ := strconv.Atoi(c.Query("type")) + // logType, _ := strconv.Atoi(c.Query("type")) startTimestamp, _ := strconv.ParseInt(c.Query("start_timestamp"), 10, 64) endTimestamp, _ := strconv.ParseInt(c.Query("end_timestamp"), 10, 64) tokenName := c.Query("token_name") modelName := c.Query("model_name") channel, _ := strconv.Atoi(c.Query("channel")) - quotaNum := model.SumUsedQuota(logType, startTimestamp, endTimestamp, modelName, username, tokenName, channel) + quotaNum := model.SumUsedQuota(startTimestamp, endTimestamp, modelName, username, tokenName, channel) //tokenNum := model.SumUsedToken(logType, startTimestamp, endTimestamp, modelName, username, tokenName) c.JSON(http.StatusOK, gin.H{ "success": true, diff --git a/controller/token.go b/controller/token.go index 43eed435..4da6f2f2 100644 --- a/controller/token.go +++ b/controller/token.go @@ -90,30 +90,6 @@ func GetPlaygroundToken(c *gin.Context) { }) } -func GetTokenStatus(c *gin.Context) { - tokenId := c.GetInt("token_id") - userId := c.GetInt("id") - token, err := model.GetTokenByIds(tokenId, userId) - if err != nil { - c.JSON(http.StatusOK, gin.H{ - "success": false, - "message": err.Error(), - }) - return - } - expiredAt := token.ExpiredTime - if expiredAt == -1 { - expiredAt = 0 - } - c.JSON(http.StatusOK, gin.H{ - "object": "credit_summary", - "total_granted": token.RemainQuota, - "total_used": 0, // not supported currently - "total_available": token.RemainQuota, - "expires_at": expiredAt * 1000, - }) -} - func AddToken(c *gin.Context) { token := model.Token{} err := c.ShouldBindJSON(&token) diff --git a/controller/user.go b/controller/user.go index 82b4a556..2bf44122 100644 --- a/controller/user.go +++ b/controller/user.go @@ -476,32 +476,6 @@ func DeleteUser(c *gin.Context) { } } -func DeleteSelf(c *gin.Context) { - id := c.GetInt("id") - user, _ := model.GetUserById(id, false) - - if user.Role == config.RoleRootUser { - c.JSON(http.StatusOK, gin.H{ - "success": false, - "message": "不能删除超级管理员账户", - }) - return - } - - err := model.DeleteUserById(id) - if err != nil { - c.JSON(http.StatusOK, gin.H{ - "success": false, - "message": err.Error(), - }) - return - } - c.JSON(http.StatusOK, gin.H{ - "success": true, - "message": "", - }) -} - func CreateUser(c *gin.Context) { var user model.User err := json.NewDecoder(c.Request.Body).Decode(&user) diff --git a/cron/main.go b/cron/main.go index 198d7540..d80b4d85 100644 --- a/cron/main.go +++ b/cron/main.go @@ -3,7 +3,6 @@ package cron import ( "one-api/common/logger" "one-api/model" - "time" "github.com/go-co-op/gocron/v2" ) @@ -23,7 +22,7 @@ func InitCron() { gocron.NewAtTime(0, 5, 0), )), gocron.NewTask(func() { - model.RemoveChatCache(time.Now().Unix()) + model.RemoveChatCache() logger.SysLog("删除过期缓存数据") }), ) diff --git a/model/ability.go b/model/ability.go index 2b973e80..8c73c5e7 100644 --- a/model/ability.go +++ b/model/ability.go @@ -15,48 +15,6 @@ type Ability struct { Weight *uint `json:"weight" gorm:"default:1"` } -func GetRandomSatisfiedChannel(group string, model string) (*Channel, error) { - ability := Ability{} - groupCol := "`group`" - trueVal := "1" - if common.UsingPostgreSQL { - groupCol = `"group"` - trueVal = "true" - } - - var err error = nil - maxPrioritySubQuery := DB.Model(&Ability{}).Select("MAX(priority)").Where(groupCol+" = ? and model = ? and enabled = "+trueVal, group, model) - channelQuery := DB.Where(groupCol+" = ? and model = ? and enabled = "+trueVal+" and priority = (?)", group, model, maxPrioritySubQuery) - if common.UsingSQLite || common.UsingPostgreSQL { - err = channelQuery.Order("RANDOM()").First(&ability).Error - } else { - err = channelQuery.Order("RAND()").First(&ability).Error - } - if err != nil { - return nil, err - } - channel := Channel{} - channel.Id = ability.ChannelId - err = DB.First(&channel, "id = ?", ability.ChannelId).Error - return &channel, err -} - -func GetGroupModels(group string) ([]string, error) { - var models []string - groupCol := "`group`" - trueVal := "1" - if common.UsingPostgreSQL { - groupCol = `"group"` - trueVal = "true" - } - - err := DB.Model(&Ability{}).Where(groupCol+" = ? and enabled = ? ", group, trueVal).Distinct("model").Pluck("model", &models).Error - if err != nil { - return nil, err - } - return models, nil -} - func (channel *Channel) AddAbilities() error { models_ := strings.Split(channel.Models, ",") groups_ := strings.Split(channel.Group, ",") @@ -102,17 +60,6 @@ func UpdateAbilityStatus(channelId int, status bool) error { return DB.Model(&Ability{}).Where("channel_id = ?", channelId).Select("enabled").Update("enabled", status).Error } -func GetEnabledAbility() ([]*Ability, error) { - trueVal := "1" - if common.UsingPostgreSQL { - trueVal = "true" - } - - var abilities []*Ability - err := DB.Where("enabled = ?", trueVal).Order("priority desc, weight desc").Find(&abilities).Error - return abilities, err -} - type AbilityChannelGroup struct { Group string `json:"group"` Model string `json:"model"` diff --git a/model/channel.go b/model/channel.go index ab38cefc..b7b01d58 100644 --- a/model/channel.go +++ b/model/channel.go @@ -100,7 +100,7 @@ func GetAllChannels() ([]*Channel, error) { return channels, err } -func GetChannelById(id int, selectAll bool) (*Channel, error) { +func GetChannelById(id int) (*Channel, error) { channel := Channel{Id: id} var err error = nil err = DB.First(&channel, "id = ?", id).Error @@ -312,11 +312,6 @@ func updateChannelUsedQuota(id int, quota int) { } } -func DeleteChannelByStatus(status int64) (int64, error) { - result := DB.Where("status = ?", status).Delete(&Channel{}) - return result.RowsAffected, result.Error -} - func DeleteDisabledChannel() (int64, error) { result := DB.Where("status = ? or status = ?", config.ChannelStatusAutoDisabled, config.ChannelStatusManuallyDisabled).Delete(&Channel{}) // 同时删除Ability diff --git a/model/chat_cache.go b/model/chat_cache.go index ae33954f..5038e127 100644 --- a/model/chat_cache.go +++ b/model/chat_cache.go @@ -35,7 +35,7 @@ func GetChatCacheListByUserId(userId int) ([]*ChatCache, error) { return chatCaches, err } -func RemoveChatCache(expiration int64) error { +func RemoveChatCache() error { now := time.Now().Unix() return DB.Where("expiration < ?", now).Delete(ChatCache{}).Error } diff --git a/model/log.go b/model/log.go index 9bac2cc7..57ee33d5 100644 --- a/model/log.go +++ b/model/log.go @@ -165,7 +165,7 @@ func SearchUserLogs(userId int, keyword string) (logs []*Log, err error) { return logs, err } -func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, channel int) (quota int) { +func SumUsedQuota(startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, channel int) (quota int) { tx := DB.Table("logs").Select(assembleSumSelectStr("quota")) if username != "" { tx = tx.Where("username = ?", username) @@ -189,27 +189,6 @@ func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelNa return quota } -func SumUsedToken(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string) (token int) { - tx := DB.Table("logs").Select(assembleSumSelectStr("prompt_tokens") + " + " + assembleSumSelectStr("completion_tokens")) - if username != "" { - tx = tx.Where("username = ?", username) - } - if tokenName != "" { - tx = tx.Where("token_name = ?", tokenName) - } - if startTimestamp != 0 { - tx = tx.Where("created_at >= ?", startTimestamp) - } - if endTimestamp != 0 { - tx = tx.Where("created_at <= ?", endTimestamp) - } - if modelName != "" { - tx = tx.Where("model_name = ?", modelName) - } - tx.Where("type = ?", LogTypeConsume).Scan(&token) - return token -} - func DeleteOldLog(targetTimestamp int64) (int64, error) { result := DB.Where("created_at < ?", targetTimestamp).Delete(&Log{}) return result.RowsAffected, result.Error diff --git a/model/midjourney.go b/model/midjourney.go index a300db3c..1d3c9a70 100644 --- a/model/midjourney.go +++ b/model/midjourney.go @@ -148,19 +148,6 @@ func GetByMJIds(userId int, mjIds []string) []*Midjourney { return mj } -func GetMjByuId(id int) *Midjourney { - var mj *Midjourney - err := DB.Where("id = ?", id).First(&mj).Error - if err != nil { - return nil - } - return mj -} - -func UpdateProgress(id int, progress string) error { - return DB.Model(&Midjourney{}).Where("id = ?", id).Update("progress", progress).Error -} - func (midjourney *Midjourney) Insert() error { return DB.Create(midjourney).Error } diff --git a/model/token.go b/model/token.go index a6281af0..44e49fbb 100644 --- a/model/token.go +++ b/model/token.go @@ -48,12 +48,6 @@ func GetUserTokensList(userId int, params *GenericParams) (*DataResult[Token], e return PaginateAndOrder(db, ¶ms.PaginationParams, &tokens, allowedTokenOrderFields) } -// 获取状态为可用的令牌 -func GetUserEnabledTokens(userId int) (tokens []*Token, err error) { - err = DB.Where("user_id = ? and status = ?", userId, config.TokenStatusEnabled).Find(&tokens).Error - return tokens, err -} - func ValidateUserToken(key string) (token *Token, err error) { if key == "" { return nil, errors.New("未提供令牌") diff --git a/model/user.go b/model/user.go index 2393ac5e..380dce5d 100644 --- a/model/user.go +++ b/model/user.go @@ -285,10 +285,6 @@ func IsTelegramIdAlreadyTaken(telegramId int64) bool { return DB.Where("telegram_id = ?", telegramId).Find(&User{}).RowsAffected == 1 } -func IsUsernameAlreadyTaken(username string) bool { - return DB.Where("username = ?", username).Find(&User{}).RowsAffected == 1 -} - func ResetUserPasswordByEmail(email string, password string) error { if email == "" || password == "" { return errors.New("邮箱地址或密码为空!") @@ -348,11 +344,6 @@ func GetUserUsedQuota(id int) (quota int, err error) { return quota, err } -func GetUserEmail(id int) (email string, err error) { - err = DB.Model(&User{}).Where("id = ?", id).Select("email").Find(&email).Error - return email, err -} - func GetUserGroup(id int) (group string, err error) { groupCol := "`group`" if common.UsingPostgreSQL { diff --git a/providers/azureSpeech/base.go b/providers/azureSpeech/base.go index 52bc46ac..af3b6822 100644 --- a/providers/azureSpeech/base.go +++ b/providers/azureSpeech/base.go @@ -28,7 +28,7 @@ type AzureSpeechProvider struct { base.BaseProvider } -func (p *AzureSpeechProvider) GetFullRequestURL(requestURL string, modelName string) string { +func (p *AzureSpeechProvider) GetFullRequestURL(requestURL string) string { baseURL := "" if p.Channel.Other != "" { baseURL = fmt.Sprintf("https://%s.tts.speech.microsoft.com", p.Channel.Other) diff --git a/providers/azureSpeech/speech.go b/providers/azureSpeech/speech.go index 2210e5d6..695c1581 100644 --- a/providers/azureSpeech/speech.go +++ b/providers/azureSpeech/speech.go @@ -87,7 +87,7 @@ func (p *AzureSpeechProvider) CreateSpeech(request *types.SpeechAudioRequest) (* if errWithCode != nil { return nil, errWithCode } - fullRequestURL := p.GetFullRequestURL(url, request.Model) + fullRequestURL := p.GetFullRequestURL(url) headers := p.GetRequestHeaders() responseFormatr := outputFormatMap[request.ResponseFormat] if responseFormatr == "" { diff --git a/providers/baidu/chat.go b/providers/baidu/chat.go index 6b4551fa..92f44b1c 100644 --- a/providers/baidu/chat.go +++ b/providers/baidu/chat.go @@ -210,7 +210,7 @@ func (h *baiduStreamHandler) handlerStream(rawLine *[]byte, dataChan chan string return } - h.convertToOpenaiStream(&baiduResponse, dataChan, errChan) + h.convertToOpenaiStream(&baiduResponse, dataChan) if baiduResponse.IsEnd { errChan <- io.EOF @@ -219,7 +219,7 @@ func (h *baiduStreamHandler) handlerStream(rawLine *[]byte, dataChan chan string } } -func (h *baiduStreamHandler) convertToOpenaiStream(baiduResponse *BaiduChatStreamResponse, dataChan chan string, errChan chan error) { +func (h *baiduStreamHandler) convertToOpenaiStream(baiduResponse *BaiduChatStreamResponse, dataChan chan string) { choice := types.ChatCompletionStreamChoice{ Index: 0, Delta: types.ChatCompletionStreamChoiceDelta{ diff --git a/providers/base/common.go b/providers/base/common.go index 9b952b64..5dc8f502 100644 --- a/providers/base/common.go +++ b/providers/base/common.go @@ -48,7 +48,7 @@ func (p *BaseProvider) GetBaseURL() string { } // 获取完整请求URL -func (p *BaseProvider) GetFullRequestURL(requestURL string, modelName string) string { +func (p *BaseProvider) GetFullRequestURL(requestURL string, _ string) string { baseURL := strings.TrimSuffix(p.GetBaseURL(), "/") return fmt.Sprintf("%s%s", baseURL, requestURL) diff --git a/providers/bedrock/sigv4/header.go b/providers/bedrock/sigv4/header.go index 8db400bd..cf7511b4 100644 --- a/providers/bedrock/sigv4/header.go +++ b/providers/bedrock/sigv4/header.go @@ -58,10 +58,6 @@ var requiredHeaders = map[string]struct{}{ "X-Amz-Tagging": {}, } -// headerPredicate is a function that evaluates whether a header is of the -// specific type. For example, whether a header should be ignored during signing. -type headerPredicate func(header string) bool - // isIgnoredHeader returns true if header must be ignored during signing. func isIgnoredHeader(header string) bool { _, ok := ignoreHeaders[header] diff --git a/providers/bedrock/sigv4/helper.go b/providers/bedrock/sigv4/helper.go index 5ee90771..9e468eb9 100644 --- a/providers/bedrock/sigv4/helper.go +++ b/providers/bedrock/sigv4/helper.go @@ -9,7 +9,6 @@ import ( "net/url" "sort" "strings" - "time" ) var ( @@ -47,17 +46,6 @@ func hasPrefixFold(s, prefix string) bool { strings.EqualFold(s[0:len(prefix)], prefix) } -// isSameDay returns true if a and b are the same date (dd-mm-yyyy). -func isSameDay(a, b time.Time) bool { - xYear, xMonth, xDay := a.Date() - yYear, yMonth, yDay := b.Date() - - if xYear != yYear || xMonth != yMonth { - return false - } - return xDay == yDay -} - // hostOrURLHost returns r.Host, or if empty, r.URL.Host. func hostOrURLHost(r *http.Request) string { if r.Host != "" { @@ -271,32 +259,3 @@ func writeCanonicalString(w *bufio.Writer, s string) { w.WriteByte(s[i]) } } - -type debugHasher struct { - buf []byte -} - -func (dh *debugHasher) Write(b []byte) (int, error) { - dh.buf = append(dh.buf, b...) - return len(b), nil -} - -func (dh *debugHasher) Sum(b []byte) []byte { - return nil -} - -func (dh *debugHasher) Reset() { - // do nothing -} - -func (dh *debugHasher) Size() int { - return 0 -} - -func (dh *debugHasher) BlockSize() int { - return sha256.BlockSize -} - -func (dh *debugHasher) Println() { - fmt.Printf("---%s---\n", dh.buf) -} diff --git a/providers/claude/base.go b/providers/claude/base.go index ce3a7d88..b1f2aead 100644 --- a/providers/claude/base.go +++ b/providers/claude/base.go @@ -73,7 +73,7 @@ func (p *ClaudeProvider) GetRequestHeaders() (headers map[string]string) { return headers } -func (p *ClaudeProvider) GetFullRequestURL(requestURL string, modelName string) string { +func (p *ClaudeProvider) GetFullRequestURL(requestURL string) string { baseURL := strings.TrimSuffix(p.GetBaseURL(), "/") if strings.HasPrefix(baseURL, "https://gateway.ai.cloudflare.com") { requestURL = strings.TrimPrefix(requestURL, "/v1") diff --git a/providers/claude/chat.go b/providers/claude/chat.go index 7024ab20..3ebe803f 100644 --- a/providers/claude/chat.go +++ b/providers/claude/chat.go @@ -71,7 +71,7 @@ func (p *ClaudeProvider) getChatRequest(request *types.ChatCompletionRequest) (* } // 获取请求地址 - fullRequestURL := p.GetFullRequestURL(url, request.Model) + fullRequestURL := p.GetFullRequestURL(url) if fullRequestURL == "" { return nil, common.ErrorWrapper(nil, "invalid_claude_config", http.StatusInternalServerError) } diff --git a/providers/hunyuan/sign.go b/providers/hunyuan/sign.go index ee1bb60c..2a4d8e58 100644 --- a/providers/hunyuan/sign.go +++ b/providers/hunyuan/sign.go @@ -28,7 +28,7 @@ func hmacsha256(s, key string) string { func (p *HunyuanProvider) sign(body any, action, method string) (*http.Request, *types.OpenAIErrorWithStatusCode) { service := "hunyuan" version := "2023-09-01" - region := "" + // region := "" host := strings.Replace(p.GetBaseURL(), "https://", "", 1) algorithm := "TC3-HMAC-SHA256" var timestamp = time.Now().Unix() @@ -88,9 +88,9 @@ func (p *HunyuanProvider) sign(body any, action, method string) (*http.Request, "Content-Type": contentType, "Authorization": authorization, } - if region != "" { - headers["X-TC-Region"] = region - } + // if region != "" { + // headers["X-TC-Region"] = region + // } req, err := p.Requester.NewRequest(method, p.GetBaseURL(), p.Requester.WithBody(body), p.Requester.WithHeader(headers)) if err != nil { diff --git a/providers/midjourney/constant.go b/providers/midjourney/constant.go index 770ee685..09015b6b 100644 --- a/providers/midjourney/constant.go +++ b/providers/midjourney/constant.go @@ -49,21 +49,3 @@ const ( MjActionPan = "PAN" MjActionSwapFace = "SWAP_FACE" ) - -var MidjourneyModel2Action = map[string]string{ - "mj_imagine": MjActionImagine, - "mj_describe": MjActionDescribe, - "mj_blend": MjActionBlend, - "mj_upscale": MjActionUpscale, - "mj_variation": MjActionVariation, - "mj_reroll": MjActionReRoll, - "mj_modal": MjActionModal, - "mj_inpaint": MjActionInPaint, - "mj_zoom": MjActionZoom, - "mj_custom_zoom": MjActionCustomZoom, - "mj_shorten": MjActionShorten, - "mj_high_variation": MjActionHighVariation, - "mj_low_variation": MjActionLowVariation, - "mj_pan": MjActionPan, - "swap_face": MjActionSwapFace, -} diff --git a/providers/minimax/base.go b/providers/minimax/base.go index 9057daa3..4f6ac6bb 100644 --- a/providers/minimax/base.go +++ b/providers/minimax/base.go @@ -59,7 +59,7 @@ func errorHandle(minimaxError *BaseResp) *types.OpenAIError { } } -func (p *MiniMaxProvider) GetFullRequestURL(requestURL string, modelName string) string { +func (p *MiniMaxProvider) GetFullRequestURL(requestURL string) string { baseURL := strings.TrimSuffix(p.GetBaseURL(), "/") keys := strings.Split(p.Channel.Key, "|") if len(keys) != 2 { diff --git a/providers/minimax/chat.go b/providers/minimax/chat.go index f2a612c0..fcc83422 100644 --- a/providers/minimax/chat.go +++ b/providers/minimax/chat.go @@ -62,7 +62,7 @@ func (p *MiniMaxProvider) getChatRequest(request *types.ChatCompletionRequest) ( } // 获取请求地址 - fullRequestURL := p.GetFullRequestURL(url, request.Model) + fullRequestURL := p.GetFullRequestURL(url) if fullRequestURL == "" { return nil, common.ErrorWrapper(errors.New("API KEY is filled in incorrectly"), "invalid_minimax_config", http.StatusInternalServerError) } diff --git a/providers/minimax/embeddings.go b/providers/minimax/embeddings.go index f1a2f9e7..d2706951 100644 --- a/providers/minimax/embeddings.go +++ b/providers/minimax/embeddings.go @@ -13,7 +13,7 @@ func (p *MiniMaxProvider) CreateEmbeddings(request *types.EmbeddingRequest) (*ty return nil, errWithCode } // 获取请求地址 - fullRequestURL := p.GetFullRequestURL(url, request.Model) + fullRequestURL := p.GetFullRequestURL(url) if fullRequestURL == "" { return nil, common.ErrorWrapper(nil, "invalid_minimax_config", http.StatusInternalServerError) } diff --git a/providers/openai/transcriptions.go b/providers/openai/transcriptions.go index b24dd9fa..b2fba15a 100644 --- a/providers/openai/transcriptions.go +++ b/providers/openai/transcriptions.go @@ -165,7 +165,7 @@ func getTextContent(text, format string) string { func extractTextFromVTT(vttContent string) string { scanner := bufio.NewScanner(strings.NewReader(vttContent)) re := regexp.MustCompile(`\d{2}:\d{2}:\d{2}\.\d{3} --> \d{2}:\d{2}:\d{2}\.\d{3}`) - text := []string{} + var text []string isStart := true for scanner.Scan() { @@ -185,7 +185,7 @@ func extractTextFromVTT(vttContent string) string { func extractTextFromSRT(srtContent string) string { scanner := bufio.NewScanner(strings.NewReader(srtContent)) re := regexp.MustCompile(`\d{2}:\d{2}:\d{2},\d{3} --> \d{2}:\d{2}:\d{2},\d{3}`) - text := []string{} + var text []string isContent := false for scanner.Scan() { diff --git a/providers/palm/base.go b/providers/palm/base.go index 46cc851e..26f54581 100644 --- a/providers/palm/base.go +++ b/providers/palm/base.go @@ -69,7 +69,7 @@ func (p *PalmProvider) GetRequestHeaders() (headers map[string]string) { } // 获取完整请求 URL -func (p *PalmProvider) GetFullRequestURL(requestURL string, modelName string) string { +func (p *PalmProvider) GetFullRequestURL(requestURL string) string { baseURL := strings.TrimSuffix(p.GetBaseURL(), "/") return fmt.Sprintf("%s%s", baseURL, requestURL) diff --git a/providers/palm/chat.go b/providers/palm/chat.go index 754e7d91..36672f42 100644 --- a/providers/palm/chat.go +++ b/providers/palm/chat.go @@ -61,7 +61,7 @@ func (p *PalmProvider) getChatRequest(request *types.ChatCompletionRequest) (*ht return nil, errWithCode } // 获取请求地址 - fullRequestURL := p.GetFullRequestURL(url, request.Model) + fullRequestURL := p.GetFullRequestURL(url) if fullRequestURL == "" { return nil, common.ErrorWrapper(nil, "invalid_palm_config", http.StatusInternalServerError) } diff --git a/providers/xunfei/base.go b/providers/xunfei/base.go index 1145cd46..7f2eabaf 100644 --- a/providers/xunfei/base.go +++ b/providers/xunfei/base.go @@ -63,7 +63,7 @@ func (p *XunfeiProvider) GetRequestHeaders() (headers map[string]string) { } // 获取完整请求 URL -func (p *XunfeiProvider) GetFullRequestURL(requestURL string, modelName string) string { +func (p *XunfeiProvider) GetFullRequestURL(modelName string) string { splits := strings.Split(p.Channel.Key, "|") if len(splits) != 3 { return "" diff --git a/providers/xunfei/chat.go b/providers/xunfei/chat.go index fc740076..a72d9d83 100644 --- a/providers/xunfei/chat.go +++ b/providers/xunfei/chat.go @@ -60,12 +60,12 @@ func (p *XunfeiProvider) CreateChatCompletionStream(request *types.ChatCompletio } func (p *XunfeiProvider) getChatRequest(request *types.ChatCompletionRequest) (*websocket.Conn, *types.OpenAIErrorWithStatusCode) { - url, errWithCode := p.GetSupportedAPIUri(config.RelayModeChatCompletions) + _, errWithCode := p.GetSupportedAPIUri(config.RelayModeChatCompletions) if errWithCode != nil { return nil, errWithCode } - authUrl := p.GetFullRequestURL(url, request.Model) + authUrl := p.GetFullRequestURL(request.Model) wsConn, err := p.wsRequester.NewRequest(authUrl, nil) if err != nil { diff --git a/providers/zhipu/base.go b/providers/zhipu/base.go index f5e30262..e593a3a6 100644 --- a/providers/zhipu/base.go +++ b/providers/zhipu/base.go @@ -77,7 +77,7 @@ func (p *ZhipuProvider) GetRequestHeaders() (headers map[string]string) { } // 获取完整请求 URL -func (p *ZhipuProvider) GetFullRequestURL(requestURL string, modelName string) string { +func (p *ZhipuProvider) GetFullRequestURL(requestURL string) string { baseURL := strings.TrimSuffix(p.GetBaseURL(), "/") return fmt.Sprintf("%s%s", baseURL, requestURL) diff --git a/providers/zhipu/chat.go b/providers/zhipu/chat.go index 536f68d9..fabe557a 100644 --- a/providers/zhipu/chat.go +++ b/providers/zhipu/chat.go @@ -61,7 +61,7 @@ func (p *ZhipuProvider) getChatRequest(request *types.ChatCompletionRequest) (*h } // 获取请求地址 - fullRequestURL := p.GetFullRequestURL(url, request.Model) + fullRequestURL := p.GetFullRequestURL(url) if fullRequestURL == "" { return nil, common.ErrorWrapper(nil, "invalid_zhipu_config", http.StatusInternalServerError) } @@ -100,7 +100,7 @@ func (p *ZhipuProvider) convertToChatOpenai(response *ZhipuResponse, request *ty } if len(openaiResponse.Choices) > 0 && openaiResponse.Choices[0].Message.ToolCalls != nil && request.Functions != nil { - for i, _ := range openaiResponse.Choices { + for i := range openaiResponse.Choices { openaiResponse.Choices[i].CheckChoice(request) } } @@ -112,7 +112,7 @@ func (p *ZhipuProvider) convertToChatOpenai(response *ZhipuResponse, request *ty func (p *ZhipuProvider) convertFromChatOpenai(request *types.ChatCompletionRequest) *ZhipuRequest { request.ClearEmptyMessages() - for i, _ := range request.Messages { + for i := range request.Messages { request.Messages[i].Role = convertRole(request.Messages[i].Role) if request.Messages[i].FunctionCall != nil { request.Messages[i].FuncToToolCalls() diff --git a/providers/zhipu/embeddings.go b/providers/zhipu/embeddings.go index 681f8fea..78947ec1 100644 --- a/providers/zhipu/embeddings.go +++ b/providers/zhipu/embeddings.go @@ -13,7 +13,7 @@ func (p *ZhipuProvider) CreateEmbeddings(request *types.EmbeddingRequest) (*type return nil, errWithCode } // 获取请求地址 - fullRequestURL := p.GetFullRequestURL(url, request.Model) + fullRequestURL := p.GetFullRequestURL(url) if fullRequestURL == "" { return nil, common.ErrorWrapper(nil, "invalid_zhipu_config", http.StatusInternalServerError) } diff --git a/providers/zhipu/image_generations.go b/providers/zhipu/image_generations.go index 52622cf6..cfbf55bc 100644 --- a/providers/zhipu/image_generations.go +++ b/providers/zhipu/image_generations.go @@ -14,7 +14,7 @@ func (p *ZhipuProvider) CreateImageGenerations(request *types.ImageRequest) (*ty return nil, errWithCode } // 获取请求地址 - fullRequestURL := p.GetFullRequestURL(url, request.Model) + fullRequestURL := p.GetFullRequestURL(url) if fullRequestURL == "" { return nil, common.ErrorWrapper(nil, "invalid_zhipu_config", http.StatusInternalServerError) } diff --git a/relay/common.go b/relay/common.go index 97a68b4c..48b67f58 100644 --- a/relay/common.go +++ b/relay/common.go @@ -90,7 +90,7 @@ func fetchChannel(c *gin.Context, modelName string) (channel *model.Channel, fai } func fetchChannelById(channelId int) (*model.Channel, error) { - channel, err := model.GetChannelById(channelId, true) + channel, err := model.GetChannelById(channelId) if err != nil { return nil, errors.New("无效的渠道 Id") } diff --git a/relay/relay_util/cache.go b/relay/relay_util/cache.go index 5130c753..9b97c43e 100644 --- a/relay/relay_util/cache.go +++ b/relay/relay_util/cache.go @@ -4,12 +4,10 @@ import ( "crypto/md5" "encoding/hex" "fmt" + "github.com/gin-gonic/gin" "one-api/common" "one-api/common/config" "one-api/common/utils" - "one-api/model" - - "github.com/gin-gonic/gin" ) type ChatCacheProps struct { @@ -31,24 +29,6 @@ type CacheDriver interface { Set(hash string, props *ChatCacheProps, expire int64) error } -func GetDebugList(userId int) ([]*ChatCacheProps, error) { - caches, err := model.GetChatCacheListByUserId(userId) - if err != nil { - return nil, err - } - - var props []*ChatCacheProps - for _, cache := range caches { - prop, err := utils.UnmarshalString[ChatCacheProps](cache.Data) - if err != nil { - continue - } - props = append(props, &prop) - } - - return props, nil -} - func NewChatCacheProps(c *gin.Context, allow bool) *ChatCacheProps { props := &ChatCacheProps{ Cache: false, diff --git a/relay/relay_util/quota.go b/relay/relay_util/quota.go index e0e0add2..e1a92cbb 100644 --- a/relay/relay_util/quota.go +++ b/relay/relay_util/quota.go @@ -99,7 +99,7 @@ func (q *Quota) completedQuotaConsumption(usage *types.Usage, tokenName string, quota = int(1000 * q.inputRatio) } else { completionRatio := q.price.GetOutput() * q.groupRatio - quota = int(math.Ceil(((float64(promptTokens) * q.inputRatio) + (float64(completionTokens) * completionRatio)))) + quota = int(math.Ceil((float64(promptTokens) * q.inputRatio) + (float64(completionTokens) * completionRatio))) } if q.inputRatio != 0 && quota <= 0 {