diff --git a/common/model-ratio.go b/common/model-ratio.go index 3f4f64b7..70758805 100644 --- a/common/model-ratio.go +++ b/common/model-ratio.go @@ -31,7 +31,7 @@ var ModelRatio = map[string]float64{ "text-davinci-003": 10, "text-davinci-edit-001": 10, "code-davinci-edit-001": 10, - "whisper-1": 10, + "whisper-1": 15, // $0.006 / minute -> $0.006 / 150 words -> $0.006 / 200 tokens -> $0.03 / 1k tokens "davinci": 10, "curie": 10, "babbage": 10, diff --git a/controller/model.go b/controller/model.go index a8ac6a65..88f95f7b 100644 --- a/controller/model.go +++ b/controller/model.go @@ -63,6 +63,15 @@ func init() { Root: "dall-e", Parent: nil, }, + { + Id: "whisper-1", + Object: "model", + Created: 1677649963, + OwnedBy: "openai", + Permission: permission, + Root: "whisper-1", + Parent: nil, + }, { Id: "gpt-3.5-turbo", Object: "model", diff --git a/controller/relay-audio.go b/controller/relay-audio.go new file mode 100644 index 00000000..277ab404 --- /dev/null +++ b/controller/relay-audio.go @@ -0,0 +1,147 @@ +package controller + +import ( + "bytes" + "encoding/json" + "fmt" + "io" + "net/http" + "one-api/common" + "one-api/model" + + "github.com/gin-gonic/gin" +) + +func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { + audioModel := "whisper-1" + + tokenId := c.GetInt("token_id") + channelType := c.GetInt("channel") + userId := c.GetInt("id") + group := c.GetString("group") + + preConsumedTokens := common.PreConsumedQuota + modelRatio := common.GetModelRatio(audioModel) + groupRatio := common.GetGroupRatio(group) + ratio := modelRatio * groupRatio + preConsumedQuota := int(float64(preConsumedTokens) * ratio) + userQuota, err := model.CacheGetUserQuota(userId) + if err != nil { + return errorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError) + } + err = model.CacheDecreaseUserQuota(userId, preConsumedQuota) + if err != nil { + return errorWrapper(err, "decrease_user_quota_failed", http.StatusInternalServerError) + } + if userQuota > 100*preConsumedQuota { + // in this case, we do not pre-consume quota + // because the user has enough quota + preConsumedQuota = 0 + } + if preConsumedQuota > 0 { + err := model.PreConsumeTokenQuota(tokenId, preConsumedQuota) + if err != nil { + return errorWrapper(err, "pre_consume_token_quota_failed", http.StatusForbidden) + } + } + + // map model name + modelMapping := c.GetString("model_mapping") + if modelMapping != "" { + modelMap := make(map[string]string) + err := json.Unmarshal([]byte(modelMapping), &modelMap) + if err != nil { + return errorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError) + } + if modelMap[audioModel] != "" { + audioModel = modelMap[audioModel] + } + } + + baseURL := common.ChannelBaseURLs[channelType] + requestURL := c.Request.URL.String() + + if c.GetString("base_url") != "" { + baseURL = c.GetString("base_url") + } + + fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL) + requestBody := c.Request.Body + + req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody) + if err != nil { + return errorWrapper(err, "new_request_failed", http.StatusInternalServerError) + } + req.Header.Set("Authorization", c.Request.Header.Get("Authorization")) + req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type")) + req.Header.Set("Accept", c.Request.Header.Get("Accept")) + + resp, err := httpClient.Do(req) + if err != nil { + return errorWrapper(err, "do_request_failed", http.StatusInternalServerError) + } + + err = req.Body.Close() + if err != nil { + return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError) + } + err = c.Request.Body.Close() + if err != nil { + return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError) + } + var audioResponse AudioResponse + + defer func() { + go func() { + quota := countTokenText(audioResponse.Text, audioModel) + quotaDelta := quota - preConsumedQuota + err := model.PostConsumeTokenQuota(tokenId, quotaDelta) + if err != nil { + common.SysError("error consuming token remain quota: " + err.Error()) + } + err = model.CacheUpdateUserQuota(userId) + if err != nil { + common.SysError("error update user quota cache: " + err.Error()) + } + if quota != 0 { + tokenName := c.GetString("token_name") + logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio) + model.RecordConsumeLog(userId, 0, 0, audioModel, tokenName, quota, logContent) + model.UpdateUserUsedQuotaAndRequestCount(userId, quota) + channelId := c.GetInt("channel_id") + model.UpdateChannelUsedQuota(channelId, quota) + } + }() + }() + + responseBody, err := io.ReadAll(resp.Body) + + if err != nil { + return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError) + } + err = resp.Body.Close() + if err != nil { + return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError) + } + err = json.Unmarshal(responseBody, &audioResponse) + if err != nil { + return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError) + } + + resp.Body = io.NopCloser(bytes.NewBuffer(responseBody)) + + for k, v := range resp.Header { + c.Writer.Header().Set(k, v[0]) + } + c.Writer.WriteHeader(resp.StatusCode) + + _, err = io.Copy(c.Writer, resp.Body) + if err != nil { + return errorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError) + } + err = resp.Body.Close() + if err != nil { + return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError) + } + return nil +} diff --git a/controller/relay.go b/controller/relay.go index 6a2d58eb..056d42d3 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -24,6 +24,7 @@ const ( RelayModeModerations RelayModeImagesGenerations RelayModeEdits + RelayModeAudio ) // https://platform.openai.com/docs/api-reference/chat @@ -63,6 +64,10 @@ type ImageRequest struct { Size string `json:"size"` } +type AudioResponse struct { + Text string `json:"text,omitempty"` +} + type Usage struct { PromptTokens int `json:"prompt_tokens"` CompletionTokens int `json:"completion_tokens"` @@ -159,11 +164,15 @@ func Relay(c *gin.Context) { relayMode = RelayModeImagesGenerations } else if strings.HasPrefix(c.Request.URL.Path, "/v1/edits") { relayMode = RelayModeEdits + } else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio") { + relayMode = RelayModeAudio } var err *OpenAIErrorWithStatusCode switch relayMode { case RelayModeImagesGenerations: err = relayImageHelper(c, relayMode) + case RelayModeAudio: + err = relayAudioHelper(c, relayMode) default: err = relayTextHelper(c, relayMode) } diff --git a/middleware/distributor.go b/middleware/distributor.go index ebbde535..93827c95 100644 --- a/middleware/distributor.go +++ b/middleware/distributor.go @@ -58,7 +58,10 @@ func Distribute() func(c *gin.Context) { } else { // Select a channel for the user var modelRequest ModelRequest - err := common.UnmarshalBodyReusable(c, &modelRequest) + var err error + if !strings.HasPrefix(c.Request.URL.Path, "/v1/audio") { + err = common.UnmarshalBodyReusable(c, &modelRequest) + } if err != nil { c.JSON(http.StatusBadRequest, gin.H{ "error": gin.H{ @@ -84,6 +87,11 @@ func Distribute() func(c *gin.Context) { modelRequest.Model = "dall-e" } } + if strings.HasPrefix(c.Request.URL.Path, "/v1/audio") { + if modelRequest.Model == "" { + modelRequest.Model = "whisper-1" + } + } channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, modelRequest.Model) if err != nil { message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", userGroup, modelRequest.Model) diff --git a/router/relay-router.go b/router/relay-router.go index c3c84d8b..a76e42cf 100644 --- a/router/relay-router.go +++ b/router/relay-router.go @@ -26,8 +26,8 @@ func SetRelayRouter(router *gin.Engine) { relayV1Router.POST("/images/variations", controller.RelayNotImplemented) relayV1Router.POST("/embeddings", controller.Relay) relayV1Router.POST("/engines/:model/embeddings", controller.Relay) - relayV1Router.POST("/audio/transcriptions", controller.RelayNotImplemented) - relayV1Router.POST("/audio/translations", controller.RelayNotImplemented) + relayV1Router.POST("/audio/transcriptions", controller.Relay) + relayV1Router.POST("/audio/translations", controller.Relay) relayV1Router.GET("/files", controller.RelayNotImplemented) relayV1Router.POST("/files", controller.RelayNotImplemented) relayV1Router.DELETE("/files/:id", controller.RelayNotImplemented)