diff --git a/common/model-ratio.go b/common/model-ratio.go index 2b975176..bc7e7be3 100644 --- a/common/model-ratio.go +++ b/common/model-ratio.go @@ -26,8 +26,8 @@ var ModelRatio = map[string]float64{ "ada": 10, "text-embedding-ada-002": 0.2, "text-search-ada-doc-001": 10, - "text-moderation-stable": 10, - "text-moderation-latest": 10, + "text-moderation-stable": 0.1, + "text-moderation-latest": 0.1, } func ModelRatio2JSONString() string { diff --git a/controller/model.go b/controller/model.go index 9685eb82..dd3777f7 100644 --- a/controller/model.go +++ b/controller/model.go @@ -161,6 +161,24 @@ func init() { Root: "text-ada-001", Parent: nil, }, + { + Id: "text-moderation-latest", + Object: "model", + Created: 1677649963, + OwnedBy: "openai", + Permission: permission, + Root: "text-moderation-latest", + Parent: nil, + }, + { + Id: "text-moderation-stable", + Object: "model", + Created: 1677649963, + OwnedBy: "openai", + Permission: permission, + Root: "text-moderation-stable", + Parent: nil, + }, } openAIModelsMap = make(map[string]OpenAIModels) for _, model := range openAIModels { diff --git a/controller/relay.go b/controller/relay.go index ac68f73d..a581d3cc 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -24,6 +24,7 @@ const ( RelayModeChatCompletions RelayModeCompletions RelayModeEmbeddings + RelayModeModeration ) // https://platform.openai.com/docs/api-reference/chat @@ -37,6 +38,7 @@ type GeneralOpenAIRequest struct { Temperature float64 `json:"temperature"` TopP float64 `json:"top_p"` N int `json:"n"` + Input string `json:"input"` } type ChatRequest struct { @@ -100,6 +102,8 @@ func Relay(c *gin.Context) { relayMode = RelayModeCompletions } else if strings.HasPrefix(c.Request.URL.Path, "/v1/embeddings") { relayMode = RelayModeEmbeddings + } else if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") { + relayMode = RelayModeModeration } err := relayHelper(c, relayMode) if err != nil { @@ -143,6 +147,9 @@ func relayHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { return errorWrapper(err, "bind_request_body_failed", http.StatusBadRequest) } } + if relayMode == RelayModeModeration && textRequest.Model == "" { + textRequest.Model = "text-moderation-latest" + } baseURL := common.ChannelBaseURLs[channelType] requestURL := c.Request.URL.String() if channelType == common.ChannelTypeCustom { @@ -180,6 +187,8 @@ func relayHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { promptTokens = countTokenMessages(textRequest.Messages, textRequest.Model) case RelayModeCompletions: promptTokens = countTokenText(textRequest.Prompt, textRequest.Model) + case RelayModeModeration: + promptTokens = countTokenText(textRequest.Input, textRequest.Model) } preConsumedTokens := common.PreConsumedQuota if textRequest.MaxTokens != 0 { @@ -239,6 +248,9 @@ func relayHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { quota = textResponse.Usage.PromptTokens + textResponse.Usage.CompletionTokens*completionRatio } quota = int(float64(quota) * ratio) + if ratio != 0 && quota <= 0 { + quota = 1 + } quotaDelta := quota - preConsumedQuota err := model.PostConsumeTokenQuota(tokenId, quotaDelta) if err != nil { diff --git a/middleware/distributor.go b/middleware/distributor.go index 6fa2bc28..0f4221bf 100644 --- a/middleware/distributor.go +++ b/middleware/distributor.go @@ -7,6 +7,7 @@ import ( "one-api/common" "one-api/model" "strconv" + "strings" ) type ModelRequest struct { @@ -64,6 +65,11 @@ func Distribute() func(c *gin.Context) { c.Abort() return } + if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") { + if modelRequest.Model == "" { + modelRequest.Model = "text-moderation-stable" + } + } userId := c.GetInt("id") userGroup, _ := model.GetUserGroup(userId) channel, err = model.GetRandomSatisfiedChannel(userGroup, modelRequest.Model) diff --git a/router/relay-router.go b/router/relay-router.go index 759e5f60..0b697af8 100644 --- a/router/relay-router.go +++ b/router/relay-router.go @@ -37,6 +37,6 @@ func SetRelayRouter(router *gin.Engine) { relayV1Router.POST("/fine-tunes/:id/cancel", controller.RelayNotImplemented) relayV1Router.GET("/fine-tunes/:id/events", controller.RelayNotImplemented) relayV1Router.DELETE("/models/:model", controller.RelayNotImplemented) - relayV1Router.POST("/moderations", controller.RelayNotImplemented) + relayV1Router.POST("/moderations", controller.Relay) } }