diff --git a/common/constants.go b/common/constants.go index 429803dc..5d405bfd 100644 --- a/common/constants.go +++ b/common/constants.go @@ -173,26 +173,27 @@ const ( ) const ( - ChannelTypeUnknown = 0 - ChannelTypeOpenAI = 1 - ChannelTypeAPI2D = 2 - ChannelTypeAzure = 3 - ChannelTypeCloseAI = 4 - ChannelTypeOpenAISB = 5 - ChannelTypeOpenAIMax = 6 - ChannelTypeOhMyGPT = 7 - ChannelTypeCustom = 8 - ChannelTypeAILS = 9 - ChannelTypeAIProxy = 10 - ChannelTypePaLM = 11 - ChannelTypeAPI2GPT = 12 - ChannelTypeAIGC2D = 13 - ChannelTypeAnthropic = 14 - ChannelTypeBaidu = 15 - ChannelTypeZhipu = 16 - ChannelTypeAli = 17 - ChannelTypeXunfei = 18 - ChannelType360 = 19 + ChannelTypeUnknown = 0 + ChannelTypeOpenAI = 1 + ChannelTypeAPI2D = 2 + ChannelTypeAzure = 3 + ChannelTypeCloseAI = 4 + ChannelTypeOpenAISB = 5 + ChannelTypeOpenAIMax = 6 + ChannelTypeOhMyGPT = 7 + ChannelTypeCustom = 8 + ChannelTypeAILS = 9 + ChannelTypeAIProxy = 10 + ChannelTypePaLM = 11 + ChannelTypeAPI2GPT = 12 + ChannelTypeAIGC2D = 13 + ChannelTypeAnthropic = 14 + ChannelTypeBaidu = 15 + ChannelTypeZhipu = 16 + ChannelTypeAli = 17 + ChannelTypeXunfei = 18 + ChannelType360 = 19 + ChannelTypeOpenRouter = 20 ) var ChannelBaseURLs = []string{ @@ -216,4 +217,5 @@ var ChannelBaseURLs = []string{ "https://dashscope.aliyuncs.com", // 17 "", // 18 "https://ai.360.cn", // 19 + "https://openrouter.ai/api", // 20 } diff --git a/common/model-ratio.go b/common/model-ratio.go index 3f4f64b7..70758805 100644 --- a/common/model-ratio.go +++ b/common/model-ratio.go @@ -31,7 +31,7 @@ var ModelRatio = map[string]float64{ "text-davinci-003": 10, "text-davinci-edit-001": 10, "code-davinci-edit-001": 10, - "whisper-1": 10, + "whisper-1": 15, // $0.006 / minute -> $0.006 / 150 words -> $0.006 / 200 tokens -> $0.03 / 1k tokens "davinci": 10, "curie": 10, "babbage": 10, diff --git a/controller/model.go b/controller/model.go index a8ac6a65..88f95f7b 100644 --- a/controller/model.go +++ b/controller/model.go @@ -63,6 +63,15 @@ func init() { Root: "dall-e", Parent: nil, }, + { + Id: "whisper-1", + Object: "model", + Created: 1677649963, + OwnedBy: "openai", + Permission: permission, + Root: "whisper-1", + Parent: nil, + }, { Id: "gpt-3.5-turbo", Object: "model", diff --git a/controller/relay-audio.go b/controller/relay-audio.go new file mode 100644 index 00000000..277ab404 --- /dev/null +++ b/controller/relay-audio.go @@ -0,0 +1,147 @@ +package controller + +import ( + "bytes" + "encoding/json" + "fmt" + "io" + "net/http" + "one-api/common" + "one-api/model" + + "github.com/gin-gonic/gin" +) + +func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { + audioModel := "whisper-1" + + tokenId := c.GetInt("token_id") + channelType := c.GetInt("channel") + userId := c.GetInt("id") + group := c.GetString("group") + + preConsumedTokens := common.PreConsumedQuota + modelRatio := common.GetModelRatio(audioModel) + groupRatio := common.GetGroupRatio(group) + ratio := modelRatio * groupRatio + preConsumedQuota := int(float64(preConsumedTokens) * ratio) + userQuota, err := model.CacheGetUserQuota(userId) + if err != nil { + return errorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError) + } + err = model.CacheDecreaseUserQuota(userId, preConsumedQuota) + if err != nil { + return errorWrapper(err, "decrease_user_quota_failed", http.StatusInternalServerError) + } + if userQuota > 100*preConsumedQuota { + // in this case, we do not pre-consume quota + // because the user has enough quota + preConsumedQuota = 0 + } + if preConsumedQuota > 0 { + err := model.PreConsumeTokenQuota(tokenId, preConsumedQuota) + if err != nil { + return errorWrapper(err, "pre_consume_token_quota_failed", http.StatusForbidden) + } + } + + // map model name + modelMapping := c.GetString("model_mapping") + if modelMapping != "" { + modelMap := make(map[string]string) + err := json.Unmarshal([]byte(modelMapping), &modelMap) + if err != nil { + return errorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError) + } + if modelMap[audioModel] != "" { + audioModel = modelMap[audioModel] + } + } + + baseURL := common.ChannelBaseURLs[channelType] + requestURL := c.Request.URL.String() + + if c.GetString("base_url") != "" { + baseURL = c.GetString("base_url") + } + + fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL) + requestBody := c.Request.Body + + req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody) + if err != nil { + return errorWrapper(err, "new_request_failed", http.StatusInternalServerError) + } + req.Header.Set("Authorization", c.Request.Header.Get("Authorization")) + req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type")) + req.Header.Set("Accept", c.Request.Header.Get("Accept")) + + resp, err := httpClient.Do(req) + if err != nil { + return errorWrapper(err, "do_request_failed", http.StatusInternalServerError) + } + + err = req.Body.Close() + if err != nil { + return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError) + } + err = c.Request.Body.Close() + if err != nil { + return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError) + } + var audioResponse AudioResponse + + defer func() { + go func() { + quota := countTokenText(audioResponse.Text, audioModel) + quotaDelta := quota - preConsumedQuota + err := model.PostConsumeTokenQuota(tokenId, quotaDelta) + if err != nil { + common.SysError("error consuming token remain quota: " + err.Error()) + } + err = model.CacheUpdateUserQuota(userId) + if err != nil { + common.SysError("error update user quota cache: " + err.Error()) + } + if quota != 0 { + tokenName := c.GetString("token_name") + logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio) + model.RecordConsumeLog(userId, 0, 0, audioModel, tokenName, quota, logContent) + model.UpdateUserUsedQuotaAndRequestCount(userId, quota) + channelId := c.GetInt("channel_id") + model.UpdateChannelUsedQuota(channelId, quota) + } + }() + }() + + responseBody, err := io.ReadAll(resp.Body) + + if err != nil { + return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError) + } + err = resp.Body.Close() + if err != nil { + return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError) + } + err = json.Unmarshal(responseBody, &audioResponse) + if err != nil { + return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError) + } + + resp.Body = io.NopCloser(bytes.NewBuffer(responseBody)) + + for k, v := range resp.Header { + c.Writer.Header().Set(k, v[0]) + } + c.Writer.WriteHeader(resp.StatusCode) + + _, err = io.Copy(c.Writer, resp.Body) + if err != nil { + return errorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError) + } + err = resp.Body.Close() + if err != nil { + return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError) + } + return nil +} diff --git a/controller/relay-text.go b/controller/relay-text.go index c245be94..b10de594 100644 --- a/controller/relay-text.go +++ b/controller/relay-text.go @@ -282,6 +282,10 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { req.Header.Set("api-key", apiKey) } else { req.Header.Set("Authorization", c.Request.Header.Get("Authorization")) + if channelType == common.ChannelTypeOpenRouter { + req.Header.Set("HTTP-Referer", "https://github.com/songquanpeng/one-api") + req.Header.Set("X-Title", "One API") + } } case APITypeClaude: req.Header.Set("x-api-key", apiKey) @@ -315,6 +319,10 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError) } isStream = isStream || strings.HasPrefix(resp.Header.Get("Content-Type"), "text/event-stream") + + if resp.StatusCode != http.StatusOK { + return relayErrorHandler(resp) + } } if resp.StatusCode != http.StatusOK { return errorWrapper( diff --git a/controller/relay-utils.go b/controller/relay-utils.go index 668b37bf..3773dbbb 100644 --- a/controller/relay-utils.go +++ b/controller/relay-utils.go @@ -1,11 +1,15 @@ package controller import ( + "encoding/json" "fmt" - "github.com/gin-gonic/gin" - "github.com/pkoukk/tiktoken-go" + "io" "net/http" "one-api/common" + "strconv" + + "github.com/gin-gonic/gin" + "github.com/pkoukk/tiktoken-go" ) var stopFinishReason = "stop" @@ -137,3 +141,30 @@ func setEventStreamHeaders(c *gin.Context) { c.Writer.Header().Set("Transfer-Encoding", "chunked") c.Writer.Header().Set("X-Accel-Buffering", "no") } + +func relayErrorHandler(resp *http.Response) (openAIErrorWithStatusCode *OpenAIErrorWithStatusCode) { + openAIErrorWithStatusCode = &OpenAIErrorWithStatusCode{ + StatusCode: resp.StatusCode, + OpenAIError: OpenAIError{ + Message: fmt.Sprintf("bad response status code %d", resp.StatusCode), + Type: "one_api_error", + Code: "bad_response_status_code", + Param: strconv.Itoa(resp.StatusCode), + }, + } + responseBody, err := io.ReadAll(resp.Body) + if err != nil { + return + } + err = resp.Body.Close() + if err != nil { + return + } + var textResponse TextResponse + err = json.Unmarshal(responseBody, &textResponse) + if err != nil { + return + } + openAIErrorWithStatusCode.OpenAIError = textResponse.Error + return +} diff --git a/controller/relay-xunfei.go b/controller/relay-xunfei.go index 87037e34..3b6fe5a0 100644 --- a/controller/relay-xunfei.go +++ b/controller/relay-xunfei.go @@ -75,7 +75,7 @@ type XunfeiChatResponse struct { } `json:"payload"` } -func requestOpenAI2Xunfei(request GeneralOpenAIRequest, xunfeiAppId string) *XunfeiChatRequest { +func requestOpenAI2Xunfei(request GeneralOpenAIRequest, xunfeiAppId string, domain string) *XunfeiChatRequest { messages := make([]XunfeiMessage, 0, len(request.Messages)) for _, message := range request.Messages { if message.Role == "system" { @@ -96,7 +96,7 @@ func requestOpenAI2Xunfei(request GeneralOpenAIRequest, xunfeiAppId string) *Xun } xunfeiRequest := XunfeiChatRequest{} xunfeiRequest.Header.AppId = xunfeiAppId - xunfeiRequest.Parameter.Chat.Domain = "general" + xunfeiRequest.Parameter.Chat.Domain = domain xunfeiRequest.Parameter.Chat.Temperature = request.Temperature xunfeiRequest.Parameter.Chat.TopK = request.N xunfeiRequest.Parameter.Chat.MaxTokens = request.MaxTokens @@ -178,15 +178,28 @@ func buildXunfeiAuthUrl(hostUrl string, apiKey, apiSecret string) string { func xunfeiStreamHandler(c *gin.Context, textRequest GeneralOpenAIRequest, appId string, apiSecret string, apiKey string) (*OpenAIErrorWithStatusCode, *Usage) { var usage Usage + query := c.Request.URL.Query() + apiVersion := query.Get("api-version") + if apiVersion == "" { + apiVersion = c.GetString("api_version") + } + if apiVersion == "" { + apiVersion = "v1.1" + common.SysLog("api_version not found, use default: " + apiVersion) + } + domain := "general" + if apiVersion == "v2.1" { + domain = "generalv2" + } + hostUrl := fmt.Sprintf("wss://spark-api.xf-yun.com/%s/chat", apiVersion) d := websocket.Dialer{ HandshakeTimeout: 5 * time.Second, } - hostUrl := "wss://aichat.xf-yun.com/v1/chat" conn, resp, err := d.Dial(buildXunfeiAuthUrl(hostUrl, apiKey, apiSecret), nil) if err != nil || resp.StatusCode != 101 { return errorWrapper(err, "dial_failed", http.StatusInternalServerError), nil } - data := requestOpenAI2Xunfei(textRequest, appId) + data := requestOpenAI2Xunfei(textRequest, appId, domain) err = conn.WriteJSON(data) if err != nil { return errorWrapper(err, "write_json_failed", http.StatusInternalServerError), nil diff --git a/controller/relay.go b/controller/relay.go index c266621a..4068e520 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -24,6 +24,7 @@ const ( RelayModeModerations RelayModeImagesGenerations RelayModeEdits + RelayModeAudio ) // https://platform.openai.com/docs/api-reference/chat @@ -64,6 +65,10 @@ type ImageRequest struct { Size string `json:"size"` } +type AudioResponse struct { + Text string `json:"text,omitempty"` +} + type Usage struct { PromptTokens int `json:"prompt_tokens"` CompletionTokens int `json:"completion_tokens"` @@ -160,11 +165,15 @@ func Relay(c *gin.Context) { relayMode = RelayModeImagesGenerations } else if strings.HasPrefix(c.Request.URL.Path, "/v1/edits") { relayMode = RelayModeEdits + } else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio") { + relayMode = RelayModeAudio } var err *OpenAIErrorWithStatusCode switch relayMode { case RelayModeImagesGenerations: err = relayImageHelper(c, relayMode) + case RelayModeAudio: + err = relayAudioHelper(c, relayMode) default: err = relayTextHelper(c, relayMode) } diff --git a/i18n/en.json b/i18n/en.json index e8916e52..54a5d688 100644 --- a/i18n/en.json +++ b/i18n/en.json @@ -613,5 +613,7 @@ "点击查看": "click to view", "使用明细(总消耗额度:" : "Usage Details (Total Consumption Quota: ", ")": ")", - "360 智脑": "360 AI" + "360 智脑": "360 AI", + "模型版本": "Model version", + "请输入星火大模型版本,注意是接口地址中的版本号,例如:v2.1": "Please enter the version of the Starfire model, note that it is the version number in the interface address, for example: v2.1" } diff --git a/middleware/distributor.go b/middleware/distributor.go index 1940c69c..acdbc60d 100644 --- a/middleware/distributor.go +++ b/middleware/distributor.go @@ -59,7 +59,10 @@ func Distribute() func(c *gin.Context) { } else { // Select a channel for the user var modelRequest ModelRequest - err := common.UnmarshalBodyReusable(c, &modelRequest) + var err error + if !strings.HasPrefix(c.Request.URL.Path, "/v1/audio") { + err = common.UnmarshalBodyReusable(c, &modelRequest) + } if err != nil { c.JSON(http.StatusBadRequest, gin.H{ "error": gin.H{ @@ -85,7 +88,12 @@ func Distribute() func(c *gin.Context) { modelRequest.Model = "dall-e" } } - channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, modelRequest.Model, modelRequest.Stream) + if strings.HasPrefix(c.Request.URL.Path, "/v1/audio") { + if modelRequest.Model == "" { + modelRequest.Model = "whisper-1" + } + } + channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, modelRequest.Model) if err != nil { message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", userGroup, modelRequest.Model) if channel != nil { @@ -108,7 +116,7 @@ func Distribute() func(c *gin.Context) { c.Set("model_mapping", channel.ModelMapping) c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key)) c.Set("base_url", channel.BaseURL) - if channel.Type == common.ChannelTypeAzure { + if channel.Type == common.ChannelTypeAzure || channel.Type == common.ChannelTypeXunfei { c.Set("api_version", channel.Other) } c.Next() diff --git a/router/relay-router.go b/router/relay-router.go index c3c84d8b..a76e42cf 100644 --- a/router/relay-router.go +++ b/router/relay-router.go @@ -26,8 +26,8 @@ func SetRelayRouter(router *gin.Engine) { relayV1Router.POST("/images/variations", controller.RelayNotImplemented) relayV1Router.POST("/embeddings", controller.Relay) relayV1Router.POST("/engines/:model/embeddings", controller.Relay) - relayV1Router.POST("/audio/transcriptions", controller.RelayNotImplemented) - relayV1Router.POST("/audio/translations", controller.RelayNotImplemented) + relayV1Router.POST("/audio/transcriptions", controller.Relay) + relayV1Router.POST("/audio/translations", controller.Relay) relayV1Router.GET("/files", controller.RelayNotImplemented) relayV1Router.POST("/files", controller.RelayNotImplemented) relayV1Router.DELETE("/files/:id", controller.RelayNotImplemented) diff --git a/web/src/constants/channel.constants.js b/web/src/constants/channel.constants.js index a14c4e0f..b1631479 100644 --- a/web/src/constants/channel.constants.js +++ b/web/src/constants/channel.constants.js @@ -9,6 +9,7 @@ export const CHANNEL_OPTIONS = [ { key: 16, text: '智谱 ChatGLM', value: 16, color: 'violet' }, { key: 19, text: '360 智脑', value: 19, color: 'blue' }, { key: 8, text: '自定义渠道', value: 8, color: 'pink' }, + { key: 20, text: '代理:OpenRouter', value: 20, color: 'black' }, { key: 2, text: '代理:API2D', value: 2, color: 'blue' }, { key: 5, text: '代理:OpenAI-SB', value: 5, color: 'brown' }, { key: 7, text: '代理:OhMyGPT', value: 7, color: 'purple' }, diff --git a/web/src/pages/Channel/EditChannel.js b/web/src/pages/Channel/EditChannel.js index 97f10518..6ce136b9 100644 --- a/web/src/pages/Channel/EditChannel.js +++ b/web/src/pages/Channel/EditChannel.js @@ -1,6 +1,6 @@ import React, { useEffect, useState } from 'react'; import { Button, Form, Header, Input, Message, Segment } from 'semantic-ui-react'; -import { useParams, useNavigate } from 'react-router-dom'; +import { useNavigate, useParams } from 'react-router-dom'; import { API, showError, showInfo, showSuccess, verifyJSON } from '../../helpers'; import { CHANNEL_OPTIONS } from '../../constants'; @@ -19,7 +19,7 @@ const EditChannel = () => { const handleCancel = () => { navigate('/channel'); }; - + const originInputs = { name: '', type: 1, @@ -64,7 +64,7 @@ const EditChannel = () => { localModels = ['SparkDesk']; break; case 19: - localModels = ['360GPT_S2_V9', 'embedding-bert-512-v1', 'embedding_s1_v1', 'semantic_similarity_s1_v1', '360GPT_S2_V9.4'] + localModels = ['360GPT_S2_V9', 'embedding-bert-512-v1', 'embedding_s1_v1', 'semantic_similarity_s1_v1', '360GPT_S2_V9.4']; break; } setInputs((inputs) => ({ ...inputs, models: localModels })); @@ -176,6 +176,9 @@ const EditChannel = () => { if (localInputs.type === 3 && localInputs.other === '') { localInputs.other = '2023-06-01-preview'; } + if (localInputs.type === 18 && localInputs.other === '') { + localInputs.other = 'v2.1'; + } if (localInputs.model_mapping === '') { localInputs.model_mapping = '{}'; } @@ -288,6 +291,20 @@ const EditChannel = () => { options={groupOptions} /> + { + inputs.type === 18 && ( + + + + ) + }