feat: now use token as the unit of quota (close #33)

This commit is contained in:
JustSong 2023-04-28 16:58:55 +08:00
parent 601fa5cea8
commit 053bb85a1c
5 changed files with 185 additions and 22 deletions

View File

@ -2,6 +2,8 @@ package controller
import ( import (
"bufio" "bufio"
"bytes"
"encoding/json"
"fmt" "fmt"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"io" "io"
@ -11,14 +13,78 @@ import (
"strings" "strings"
) )
type Message struct {
Role string `json:"role"`
Content string `json:"content"`
}
type TextRequest struct {
Model string `json:"model"`
Messages []Message `json:"messages"`
Prompt string `json:"prompt"`
//Stream bool `json:"stream"`
}
type Usage struct {
PromptTokens int `json:"prompt_tokens"`
CompletionTokens int `json:"completion_tokens"`
TotalTokens int `json:"total_tokens"`
}
type TextResponse struct {
Usage `json:"usage"`
}
type StreamResponse struct {
Choices []struct {
Delta struct {
Content string `json:"content"`
} `json:"delta"`
FinishReason string `json:"finish_reason"`
} `json:"choices"`
}
func Relay(c *gin.Context) { func Relay(c *gin.Context) {
channelType := c.GetInt("channel") channelType := c.GetInt("channel")
tokenId := c.GetInt("token_id") tokenId := c.GetInt("token_id")
isUnlimitedQuota := c.GetBool("unlimited_quota") consumeQuota := c.GetBool("consume_quota")
baseURL := common.ChannelBaseURLs[channelType] baseURL := common.ChannelBaseURLs[channelType]
if channelType == common.ChannelTypeCustom { if channelType == common.ChannelTypeCustom {
baseURL = c.GetString("base_url") baseURL = c.GetString("base_url")
} }
requestBody, err := io.ReadAll(c.Request.Body)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"error": gin.H{
"message": err.Error(),
"type": "one_api_error",
},
})
return
}
err = c.Request.Body.Close()
if err != nil {
c.JSON(http.StatusOK, gin.H{
"error": gin.H{
"message": err.Error(),
"type": "one_api_error",
},
})
return
}
var textRequest TextRequest
err = json.Unmarshal(requestBody, &textRequest)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"error": gin.H{
"message": err.Error(),
"type": "one_api_error",
},
})
return
}
// Reset request body
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
requestURL := c.Request.URL.String() requestURL := c.Request.URL.String()
req, err := http.NewRequest(c.Request.Method, fmt.Sprintf("%s%s", baseURL, requestURL), c.Request.Body) req, err := http.NewRequest(c.Request.Method, fmt.Sprintf("%s%s", baseURL, requestURL), c.Request.Body)
if err != nil { if err != nil {
@ -30,16 +96,11 @@ func Relay(c *gin.Context) {
}) })
return return
} }
//req.Header = c.Request.Header.Clone()
// Fix HTTP Decompression failed
// https://github.com/stoplightio/prism/issues/1064#issuecomment-824682360
//req.Header.Del("Accept-Encoding")
req.Header.Set("Authorization", c.Request.Header.Get("Authorization")) req.Header.Set("Authorization", c.Request.Header.Get("Authorization"))
req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type")) req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
req.Header.Set("Accept", c.Request.Header.Get("Accept")) req.Header.Set("Accept", c.Request.Header.Get("Accept"))
req.Header.Set("Connection", c.Request.Header.Get("Connection")) req.Header.Set("Connection", c.Request.Header.Get("Connection"))
client := &http.Client{} client := &http.Client{}
resp, err := client.Do(req) resp, err := client.Do(req)
if err != nil { if err != nil {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
@ -50,20 +111,36 @@ func Relay(c *gin.Context) {
}) })
return return
} }
err = req.Body.Close()
if err != nil {
c.JSON(http.StatusOK, gin.H{
"error": gin.H{
"message": err.Error(),
"type": "one_api_error",
},
})
return
}
var textResponse TextResponse
isStream := resp.Header.Get("Content-Type") == "text/event-stream"
var streamResponseText string
defer func() { defer func() {
err := req.Body.Close() if consumeQuota {
if err != nil { quota := 0
common.SysError("Error closing request body: " + err.Error()) if isStream {
} quota = int(float64(len(streamResponseText)) * 0.8)
if !isUnlimitedQuota && requestURL == "/v1/chat/completions" { } else {
err := model.DecreaseTokenRemainQuotaById(tokenId) quota = textResponse.Usage.TotalTokens
}
err := model.ConsumeTokenQuota(tokenId, quota)
if err != nil { if err != nil {
common.SysError("Error decreasing token remain times: " + err.Error()) common.SysError("Error consuming token remain quota: " + err.Error())
} }
} }
}() }()
isStream := resp.Header.Get("Content-Type") == "text/event-stream"
if isStream { if isStream {
scanner := bufio.NewScanner(resp.Body) scanner := bufio.NewScanner(resp.Body)
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
@ -87,6 +164,18 @@ func Relay(c *gin.Context) {
for scanner.Scan() { for scanner.Scan() {
data := scanner.Text() data := scanner.Text()
dataChan <- data dataChan <- data
data = data[6:]
if data != "[DONE]" {
var streamResponse StreamResponse
err = json.Unmarshal([]byte(data), &streamResponse)
if err != nil {
common.SysError("Error unmarshalling stream response: " + err.Error())
return
}
for _, choice := range streamResponse.Choices {
streamResponseText += choice.Delta.Content
}
}
} }
stopChan <- true stopChan <- true
}() }()
@ -108,6 +197,38 @@ func Relay(c *gin.Context) {
for k, v := range resp.Header { for k, v := range resp.Header {
c.Writer.Header().Set(k, v[0]) c.Writer.Header().Set(k, v[0])
} }
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"error": gin.H{
"message": err.Error(),
"type": "one_api_error",
},
})
return
}
err = resp.Body.Close()
if err != nil {
c.JSON(http.StatusOK, gin.H{
"error": gin.H{
"message": err.Error(),
"type": "one_api_error",
},
})
return
}
err = json.Unmarshal(responseBody, &textResponse)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"error": gin.H{
"message": err.Error(),
"type": "one_api_error",
},
})
return
}
// Reset response body
resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
_, err = io.Copy(c.Writer, resp.Body) _, err = io.Copy(c.Writer, resp.Body)
if err != nil { if err != nil {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
@ -120,3 +241,12 @@ func Relay(c *gin.Context) {
} }
} }
} }
func RelayNotImplemented(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{
"error": gin.H{
"message": "Not Implemented",
"type": "one_api_error",
},
})
}

View File

@ -110,7 +110,17 @@ func TokenAuth() func(c *gin.Context) {
} }
c.Set("id", token.UserId) c.Set("id", token.UserId)
c.Set("token_id", token.Id) c.Set("token_id", token.Id)
c.Set("unlimited_quota", token.UnlimitedQuota) requestURL := c.Request.URL.String()
consumeQuota := false
switch requestURL {
case "/v1/chat/completions":
consumeQuota = !token.UnlimitedQuota
case "/v1/completions":
consumeQuota = !token.UnlimitedQuota
case "/v1/edits":
consumeQuota = !token.UnlimitedQuota
}
c.Set("consume_quota", consumeQuota)
if len(parts) > 1 { if len(parts) > 1 {
if model.IsAdmin(token.UserId) { if model.IsAdmin(token.UserId) {
c.Set("channelId", parts[1]) c.Set("channelId", parts[1])

View File

@ -55,7 +55,7 @@ func Redeem(key string, tokenId int) (quota int, err error) {
if redemption.Status != common.RedemptionCodeStatusEnabled { if redemption.Status != common.RedemptionCodeStatusEnabled {
return 0, errors.New("该兑换码已被使用") return 0, errors.New("该兑换码已被使用")
} }
err = TopUpToken(tokenId, redemption.Quota) err = TopUpTokenQuota(tokenId, redemption.Quota)
if err != nil { if err != nil {
return 0, err return 0, err
} }

View File

@ -119,12 +119,12 @@ func DeleteTokenById(id int, userId int) (err error) {
return token.Delete() return token.Delete()
} }
func DecreaseTokenRemainQuotaById(id int) (err error) { func ConsumeTokenQuota(id int, quota int) (err error) {
err = DB.Model(&Token{}).Where("id = ?", id).Update("remain_quota", gorm.Expr("remain_quota - ?", 1)).Error err = DB.Model(&Token{}).Where("id = ?", id).Update("remain_quota", gorm.Expr("remain_quota - ?", quota)).Error
return err return err
} }
func TopUpToken(id int, times int) (err error) { func TopUpTokenQuota(id int, quota int) (err error) {
err = DB.Model(&Token{}).Where("id = ?", id).Update("remain_quota", gorm.Expr("remain_quota + ?", times)).Error err = DB.Model(&Token{}).Where("id = ?", id).Update("remain_quota", gorm.Expr("remain_quota + ?", quota)).Error
return err return err
} }

View File

@ -7,12 +7,35 @@ import (
) )
func SetRelayRouter(router *gin.Engine) { func SetRelayRouter(router *gin.Engine) {
// https://platform.openai.com/docs/api-reference/introduction
relayV1Router := router.Group("/v1") relayV1Router := router.Group("/v1")
relayV1Router.Use(middleware.TokenAuth(), middleware.Distribute()) relayV1Router.Use(middleware.TokenAuth(), middleware.Distribute())
{ {
relayV1Router.Any("/*path", controller.Relay) relayV1Router.GET("/models", controller.Relay)
relayV1Router.GET("/models/:model", controller.Relay)
relayV1Router.POST("/completions", controller.RelayNotImplemented)
relayV1Router.POST("/chat/completions", controller.Relay)
relayV1Router.POST("/edits", controller.RelayNotImplemented)
relayV1Router.POST("/images/generations", controller.RelayNotImplemented)
relayV1Router.POST("/images/edits", controller.RelayNotImplemented)
relayV1Router.POST("/images/variations", controller.RelayNotImplemented)
relayV1Router.POST("/embeddings", controller.RelayNotImplemented)
relayV1Router.POST("/audio/transcriptions", controller.RelayNotImplemented)
relayV1Router.POST("/audio/translations", controller.RelayNotImplemented)
relayV1Router.GET("/files", controller.RelayNotImplemented)
relayV1Router.POST("/files", controller.RelayNotImplemented)
relayV1Router.DELETE("/files/:id", controller.RelayNotImplemented)
relayV1Router.GET("/files/:id", controller.RelayNotImplemented)
relayV1Router.GET("/files/:id/content", controller.RelayNotImplemented)
relayV1Router.POST("/fine-tunes", controller.RelayNotImplemented)
relayV1Router.GET("/fine-tunes", controller.RelayNotImplemented)
relayV1Router.GET("/fine-tunes/:id", controller.RelayNotImplemented)
relayV1Router.POST("/fine-tunes/:id/cancel", controller.RelayNotImplemented)
relayV1Router.GET("/fine-tunes/:id/events", controller.RelayNotImplemented)
relayV1Router.DELETE("/models/:model", controller.RelayNotImplemented)
relayV1Router.POST("/moderations", controller.RelayNotImplemented)
} }
relayDashboardRouter := router.Group("/dashboard") relayDashboardRouter := router.Group("/dashboard") // TODO: return system's own token info
relayDashboardRouter.Use(middleware.TokenAuth(), middleware.Distribute()) relayDashboardRouter.Use(middleware.TokenAuth(), middleware.Distribute())
{ {
relayDashboardRouter.Any("/*path", controller.Relay) relayDashboardRouter.Any("/*path", controller.Relay)