From 053bb85a1c217241ec45ca452a8be80ba2c0ee43 Mon Sep 17 00:00:00 2001 From: JustSong Date: Fri, 28 Apr 2023 16:58:55 +0800 Subject: [PATCH] feat: now use token as the unit of quota (close #33) --- controller/relay.go | 158 +++++++++++++++++++++++++++++++++++++---- middleware/auth.go | 12 +++- model/redemption.go | 2 +- model/token.go | 8 +-- router/relay-router.go | 27 ++++++- 5 files changed, 185 insertions(+), 22 deletions(-) diff --git a/controller/relay.go b/controller/relay.go index f51709c2..2d81bc7e 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -2,6 +2,8 @@ package controller import ( "bufio" + "bytes" + "encoding/json" "fmt" "github.com/gin-gonic/gin" "io" @@ -11,14 +13,78 @@ import ( "strings" ) +type Message struct { + Role string `json:"role"` + Content string `json:"content"` +} + +type TextRequest struct { + Model string `json:"model"` + Messages []Message `json:"messages"` + Prompt string `json:"prompt"` + //Stream bool `json:"stream"` +} + +type Usage struct { + PromptTokens int `json:"prompt_tokens"` + CompletionTokens int `json:"completion_tokens"` + TotalTokens int `json:"total_tokens"` +} + +type TextResponse struct { + Usage `json:"usage"` +} + +type StreamResponse struct { + Choices []struct { + Delta struct { + Content string `json:"content"` + } `json:"delta"` + FinishReason string `json:"finish_reason"` + } `json:"choices"` +} + func Relay(c *gin.Context) { channelType := c.GetInt("channel") tokenId := c.GetInt("token_id") - isUnlimitedQuota := c.GetBool("unlimited_quota") + consumeQuota := c.GetBool("consume_quota") baseURL := common.ChannelBaseURLs[channelType] if channelType == common.ChannelTypeCustom { baseURL = c.GetString("base_url") } + requestBody, err := io.ReadAll(c.Request.Body) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "error": gin.H{ + "message": err.Error(), + "type": "one_api_error", + }, + }) + return + } + err = c.Request.Body.Close() + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "error": gin.H{ + "message": err.Error(), + "type": "one_api_error", + }, + }) + return + } + var textRequest TextRequest + err = json.Unmarshal(requestBody, &textRequest) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "error": gin.H{ + "message": err.Error(), + "type": "one_api_error", + }, + }) + return + } + // Reset request body + c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody)) requestURL := c.Request.URL.String() req, err := http.NewRequest(c.Request.Method, fmt.Sprintf("%s%s", baseURL, requestURL), c.Request.Body) if err != nil { @@ -30,16 +96,11 @@ func Relay(c *gin.Context) { }) return } - //req.Header = c.Request.Header.Clone() - // Fix HTTP Decompression failed - // https://github.com/stoplightio/prism/issues/1064#issuecomment-824682360 - //req.Header.Del("Accept-Encoding") req.Header.Set("Authorization", c.Request.Header.Get("Authorization")) req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type")) req.Header.Set("Accept", c.Request.Header.Get("Accept")) req.Header.Set("Connection", c.Request.Header.Get("Connection")) client := &http.Client{} - resp, err := client.Do(req) if err != nil { c.JSON(http.StatusOK, gin.H{ @@ -50,20 +111,36 @@ func Relay(c *gin.Context) { }) return } + err = req.Body.Close() + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "error": gin.H{ + "message": err.Error(), + "type": "one_api_error", + }, + }) + return + } + + var textResponse TextResponse + isStream := resp.Header.Get("Content-Type") == "text/event-stream" + var streamResponseText string defer func() { - err := req.Body.Close() - if err != nil { - common.SysError("Error closing request body: " + err.Error()) - } - if !isUnlimitedQuota && requestURL == "/v1/chat/completions" { - err := model.DecreaseTokenRemainQuotaById(tokenId) + if consumeQuota { + quota := 0 + if isStream { + quota = int(float64(len(streamResponseText)) * 0.8) + } else { + quota = textResponse.Usage.TotalTokens + } + err := model.ConsumeTokenQuota(tokenId, quota) if err != nil { - common.SysError("Error decreasing token remain times: " + err.Error()) + common.SysError("Error consuming token remain quota: " + err.Error()) } } }() - isStream := resp.Header.Get("Content-Type") == "text/event-stream" + if isStream { scanner := bufio.NewScanner(resp.Body) scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { @@ -87,6 +164,18 @@ func Relay(c *gin.Context) { for scanner.Scan() { data := scanner.Text() dataChan <- data + data = data[6:] + if data != "[DONE]" { + var streamResponse StreamResponse + err = json.Unmarshal([]byte(data), &streamResponse) + if err != nil { + common.SysError("Error unmarshalling stream response: " + err.Error()) + return + } + for _, choice := range streamResponse.Choices { + streamResponseText += choice.Delta.Content + } + } } stopChan <- true }() @@ -108,6 +197,38 @@ func Relay(c *gin.Context) { for k, v := range resp.Header { c.Writer.Header().Set(k, v[0]) } + responseBody, err := io.ReadAll(resp.Body) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "error": gin.H{ + "message": err.Error(), + "type": "one_api_error", + }, + }) + return + } + err = resp.Body.Close() + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "error": gin.H{ + "message": err.Error(), + "type": "one_api_error", + }, + }) + return + } + err = json.Unmarshal(responseBody, &textResponse) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "error": gin.H{ + "message": err.Error(), + "type": "one_api_error", + }, + }) + return + } + // Reset response body + resp.Body = io.NopCloser(bytes.NewBuffer(responseBody)) _, err = io.Copy(c.Writer, resp.Body) if err != nil { c.JSON(http.StatusOK, gin.H{ @@ -120,3 +241,12 @@ func Relay(c *gin.Context) { } } } + +func RelayNotImplemented(c *gin.Context) { + c.JSON(http.StatusOK, gin.H{ + "error": gin.H{ + "message": "Not Implemented", + "type": "one_api_error", + }, + }) +} diff --git a/middleware/auth.go b/middleware/auth.go index c2cd5891..d5891b4f 100644 --- a/middleware/auth.go +++ b/middleware/auth.go @@ -110,7 +110,17 @@ func TokenAuth() func(c *gin.Context) { } c.Set("id", token.UserId) c.Set("token_id", token.Id) - c.Set("unlimited_quota", token.UnlimitedQuota) + requestURL := c.Request.URL.String() + consumeQuota := false + switch requestURL { + case "/v1/chat/completions": + consumeQuota = !token.UnlimitedQuota + case "/v1/completions": + consumeQuota = !token.UnlimitedQuota + case "/v1/edits": + consumeQuota = !token.UnlimitedQuota + } + c.Set("consume_quota", consumeQuota) if len(parts) > 1 { if model.IsAdmin(token.UserId) { c.Set("channelId", parts[1]) diff --git a/model/redemption.go b/model/redemption.go index 37036e68..838d7ecd 100644 --- a/model/redemption.go +++ b/model/redemption.go @@ -55,7 +55,7 @@ func Redeem(key string, tokenId int) (quota int, err error) { if redemption.Status != common.RedemptionCodeStatusEnabled { return 0, errors.New("该兑换码已被使用") } - err = TopUpToken(tokenId, redemption.Quota) + err = TopUpTokenQuota(tokenId, redemption.Quota) if err != nil { return 0, err } diff --git a/model/token.go b/model/token.go index f4896bbc..ff1806af 100644 --- a/model/token.go +++ b/model/token.go @@ -119,12 +119,12 @@ func DeleteTokenById(id int, userId int) (err error) { return token.Delete() } -func DecreaseTokenRemainQuotaById(id int) (err error) { - err = DB.Model(&Token{}).Where("id = ?", id).Update("remain_quota", gorm.Expr("remain_quota - ?", 1)).Error +func ConsumeTokenQuota(id int, quota int) (err error) { + err = DB.Model(&Token{}).Where("id = ?", id).Update("remain_quota", gorm.Expr("remain_quota - ?", quota)).Error return err } -func TopUpToken(id int, times int) (err error) { - err = DB.Model(&Token{}).Where("id = ?", id).Update("remain_quota", gorm.Expr("remain_quota + ?", times)).Error +func TopUpTokenQuota(id int, quota int) (err error) { + err = DB.Model(&Token{}).Where("id = ?", id).Update("remain_quota", gorm.Expr("remain_quota + ?", quota)).Error return err } diff --git a/router/relay-router.go b/router/relay-router.go index 585af78a..2fd0021e 100644 --- a/router/relay-router.go +++ b/router/relay-router.go @@ -7,12 +7,35 @@ import ( ) func SetRelayRouter(router *gin.Engine) { + // https://platform.openai.com/docs/api-reference/introduction relayV1Router := router.Group("/v1") relayV1Router.Use(middleware.TokenAuth(), middleware.Distribute()) { - relayV1Router.Any("/*path", controller.Relay) + relayV1Router.GET("/models", controller.Relay) + relayV1Router.GET("/models/:model", controller.Relay) + relayV1Router.POST("/completions", controller.RelayNotImplemented) + relayV1Router.POST("/chat/completions", controller.Relay) + relayV1Router.POST("/edits", controller.RelayNotImplemented) + relayV1Router.POST("/images/generations", controller.RelayNotImplemented) + relayV1Router.POST("/images/edits", controller.RelayNotImplemented) + relayV1Router.POST("/images/variations", controller.RelayNotImplemented) + relayV1Router.POST("/embeddings", controller.RelayNotImplemented) + relayV1Router.POST("/audio/transcriptions", controller.RelayNotImplemented) + relayV1Router.POST("/audio/translations", controller.RelayNotImplemented) + relayV1Router.GET("/files", controller.RelayNotImplemented) + relayV1Router.POST("/files", controller.RelayNotImplemented) + relayV1Router.DELETE("/files/:id", controller.RelayNotImplemented) + relayV1Router.GET("/files/:id", controller.RelayNotImplemented) + relayV1Router.GET("/files/:id/content", controller.RelayNotImplemented) + relayV1Router.POST("/fine-tunes", controller.RelayNotImplemented) + relayV1Router.GET("/fine-tunes", controller.RelayNotImplemented) + relayV1Router.GET("/fine-tunes/:id", controller.RelayNotImplemented) + relayV1Router.POST("/fine-tunes/:id/cancel", controller.RelayNotImplemented) + relayV1Router.GET("/fine-tunes/:id/events", controller.RelayNotImplemented) + relayV1Router.DELETE("/models/:model", controller.RelayNotImplemented) + relayV1Router.POST("/moderations", controller.RelayNotImplemented) } - relayDashboardRouter := router.Group("/dashboard") + relayDashboardRouter := router.Group("/dashboard") // TODO: return system's own token info relayDashboardRouter.Use(middleware.TokenAuth(), middleware.Distribute()) { relayDashboardRouter.Any("/*path", controller.Relay)