From 4339f45f747bfb19cb2d36e0f603ab64afeb383a Mon Sep 17 00:00:00 2001 From: JustSong Date: Sun, 11 Jun 2023 09:37:36 +0800 Subject: [PATCH 1/2] feat: support /v1/moderations now (close #117) --- common/model-ratio.go | 4 ++-- controller/model.go | 18 ++++++++++++++++++ controller/relay.go | 12 ++++++++++++ middleware/distributor.go | 6 ++++++ router/relay-router.go | 2 +- 5 files changed, 39 insertions(+), 3 deletions(-) diff --git a/common/model-ratio.go b/common/model-ratio.go index 2b975176..bc7e7be3 100644 --- a/common/model-ratio.go +++ b/common/model-ratio.go @@ -26,8 +26,8 @@ var ModelRatio = map[string]float64{ "ada": 10, "text-embedding-ada-002": 0.2, "text-search-ada-doc-001": 10, - "text-moderation-stable": 10, - "text-moderation-latest": 10, + "text-moderation-stable": 0.1, + "text-moderation-latest": 0.1, } func ModelRatio2JSONString() string { diff --git a/controller/model.go b/controller/model.go index 9685eb82..dd3777f7 100644 --- a/controller/model.go +++ b/controller/model.go @@ -161,6 +161,24 @@ func init() { Root: "text-ada-001", Parent: nil, }, + { + Id: "text-moderation-latest", + Object: "model", + Created: 1677649963, + OwnedBy: "openai", + Permission: permission, + Root: "text-moderation-latest", + Parent: nil, + }, + { + Id: "text-moderation-stable", + Object: "model", + Created: 1677649963, + OwnedBy: "openai", + Permission: permission, + Root: "text-moderation-stable", + Parent: nil, + }, } openAIModelsMap = make(map[string]OpenAIModels) for _, model := range openAIModels { diff --git a/controller/relay.go b/controller/relay.go index ac68f73d..a581d3cc 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -24,6 +24,7 @@ const ( RelayModeChatCompletions RelayModeCompletions RelayModeEmbeddings + RelayModeModeration ) // https://platform.openai.com/docs/api-reference/chat @@ -37,6 +38,7 @@ type GeneralOpenAIRequest struct { Temperature float64 `json:"temperature"` TopP float64 `json:"top_p"` N int `json:"n"` + Input string `json:"input"` } type ChatRequest struct { @@ -100,6 +102,8 @@ func Relay(c *gin.Context) { relayMode = RelayModeCompletions } else if strings.HasPrefix(c.Request.URL.Path, "/v1/embeddings") { relayMode = RelayModeEmbeddings + } else if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") { + relayMode = RelayModeModeration } err := relayHelper(c, relayMode) if err != nil { @@ -143,6 +147,9 @@ func relayHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { return errorWrapper(err, "bind_request_body_failed", http.StatusBadRequest) } } + if relayMode == RelayModeModeration && textRequest.Model == "" { + textRequest.Model = "text-moderation-latest" + } baseURL := common.ChannelBaseURLs[channelType] requestURL := c.Request.URL.String() if channelType == common.ChannelTypeCustom { @@ -180,6 +187,8 @@ func relayHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { promptTokens = countTokenMessages(textRequest.Messages, textRequest.Model) case RelayModeCompletions: promptTokens = countTokenText(textRequest.Prompt, textRequest.Model) + case RelayModeModeration: + promptTokens = countTokenText(textRequest.Input, textRequest.Model) } preConsumedTokens := common.PreConsumedQuota if textRequest.MaxTokens != 0 { @@ -239,6 +248,9 @@ func relayHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { quota = textResponse.Usage.PromptTokens + textResponse.Usage.CompletionTokens*completionRatio } quota = int(float64(quota) * ratio) + if ratio != 0 && quota <= 0 { + quota = 1 + } quotaDelta := quota - preConsumedQuota err := model.PostConsumeTokenQuota(tokenId, quotaDelta) if err != nil { diff --git a/middleware/distributor.go b/middleware/distributor.go index 6fa2bc28..0f4221bf 100644 --- a/middleware/distributor.go +++ b/middleware/distributor.go @@ -7,6 +7,7 @@ import ( "one-api/common" "one-api/model" "strconv" + "strings" ) type ModelRequest struct { @@ -64,6 +65,11 @@ func Distribute() func(c *gin.Context) { c.Abort() return } + if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") { + if modelRequest.Model == "" { + modelRequest.Model = "text-moderation-stable" + } + } userId := c.GetInt("id") userGroup, _ := model.GetUserGroup(userId) channel, err = model.GetRandomSatisfiedChannel(userGroup, modelRequest.Model) diff --git a/router/relay-router.go b/router/relay-router.go index 759e5f60..0b697af8 100644 --- a/router/relay-router.go +++ b/router/relay-router.go @@ -37,6 +37,6 @@ func SetRelayRouter(router *gin.Engine) { relayV1Router.POST("/fine-tunes/:id/cancel", controller.RelayNotImplemented) relayV1Router.GET("/fine-tunes/:id/events", controller.RelayNotImplemented) relayV1Router.DELETE("/models/:model", controller.RelayNotImplemented) - relayV1Router.POST("/moderations", controller.RelayNotImplemented) + relayV1Router.POST("/moderations", controller.Relay) } } From f97a9ce597650d01c95353eabfde9f61c9e3d77a Mon Sep 17 00:00:00 2001 From: JustSong Date: Sun, 11 Jun 2023 09:49:57 +0800 Subject: [PATCH 2/2] fix: correct OpenAI error code's type --- controller/channel-test.go | 2 +- controller/relay.go | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/controller/channel-test.go b/controller/channel-test.go index 838a2738..ec865b23 100644 --- a/controller/channel-test.go +++ b/controller/channel-test.go @@ -59,7 +59,7 @@ func testChannel(channel *model.Channel, request *ChatRequest) error { return err } if response.Usage.CompletionTokens == 0 { - return errors.New(fmt.Sprintf("type %s, code %s, message %s", response.Error.Type, response.Error.Code, response.Error.Message)) + return errors.New(fmt.Sprintf("type %s, code %v, message %s", response.Error.Type, response.Error.Code, response.Error.Message)) } return nil } diff --git a/controller/relay.go b/controller/relay.go index a581d3cc..63b77e4a 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -65,7 +65,7 @@ type OpenAIError struct { Message string `json:"message"` Type string `json:"type"` Param string `json:"param"` - Code string `json:"code"` + Code any `json:"code"` } type OpenAIErrorWithStatusCode struct {