From 9b178a28a3952b6aaa0671605ccbbf2de355865a Mon Sep 17 00:00:00 2001 From: JustSong Date: Sun, 25 Jun 2023 11:46:23 +0800 Subject: [PATCH] feat: support /v1/edits now (close #196) --- controller/model.go | 18 ++++++++++++++++++ controller/relay-text.go | 25 ++++++++++++++++--------- controller/relay.go | 8 ++++++-- router/relay-router.go | 2 +- 4 files changed, 41 insertions(+), 12 deletions(-) diff --git a/controller/model.go b/controller/model.go index 08819c72..83d0a774 100644 --- a/controller/model.go +++ b/controller/model.go @@ -224,6 +224,24 @@ func init() { Root: "text-moderation-stable", Parent: nil, }, + { + Id: "text-davinci-edit-001", + Object: "model", + Created: 1677649963, + OwnedBy: "openai", + Permission: permission, + Root: "text-davinci-edit-001", + Parent: nil, + }, + { + Id: "code-davinci-edit-001", + Object: "model", + Created: 1677649963, + OwnedBy: "openai", + Permission: permission, + Root: "code-davinci-edit-001", + Parent: nil, + }, } openAIModelsMap = make(map[string]OpenAIModels) for _, model := range openAIModels { diff --git a/controller/relay-text.go b/controller/relay-text.go index e14e0632..f2b91c02 100644 --- a/controller/relay-text.go +++ b/controller/relay-text.go @@ -27,7 +27,7 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { return errorWrapper(err, "bind_request_body_failed", http.StatusBadRequest) } } - if relayMode == RelayModeModeration && textRequest.Model == "" { + if relayMode == RelayModeModerations && textRequest.Model == "" { textRequest.Model = "text-moderation-latest" } // request validation @@ -37,16 +37,20 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { switch relayMode { case RelayModeCompletions: if textRequest.Prompt == "" { - return errorWrapper(errors.New("prompt is required"), "required_field_missing", http.StatusBadRequest) + return errorWrapper(errors.New("field prompt is required"), "required_field_missing", http.StatusBadRequest) } case RelayModeChatCompletions: - if len(textRequest.Messages) == 0 { - return errorWrapper(errors.New("messages is required"), "required_field_missing", http.StatusBadRequest) + if textRequest.Messages == nil || len(textRequest.Messages) == 0 { + return errorWrapper(errors.New("field messages is required"), "required_field_missing", http.StatusBadRequest) } case RelayModeEmbeddings: - case RelayModeModeration: + case RelayModeModerations: if textRequest.Input == "" { - return errorWrapper(errors.New("input is required"), "required_field_missing", http.StatusBadRequest) + return errorWrapper(errors.New("field input is required"), "required_field_missing", http.StatusBadRequest) + } + case RelayModeEdits: + if textRequest.Instruction == "" { + return errorWrapper(errors.New("field instruction is required"), "required_field_missing", http.StatusBadRequest) } } baseURL := common.ChannelBaseURLs[channelType] @@ -84,7 +88,7 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { promptTokens = countTokenMessages(textRequest.Messages, textRequest.Model) case RelayModeCompletions: promptTokens = countTokenInput(textRequest.Prompt, textRequest.Model) - case RelayModeModeration: + case RelayModeModerations: promptTokens = countTokenInput(textRequest.Input, textRequest.Model) } preConsumedTokens := common.PreConsumedQuota @@ -144,7 +148,10 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { defer func() { if consumeQuota { quota := 0 - completionRatio := 1.333333 // default for gpt-3 + completionRatio := 1.0 + if strings.HasPrefix(textRequest.Model, "gpt-3.5") { + completionRatio = 1.333333 + } if strings.HasPrefix(textRequest.Model, "gpt-4") { completionRatio = 2 } @@ -172,7 +179,7 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { } if quota != 0 { tokenName := c.GetString("token_name") - logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio) + logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f,补全倍率 %.2f", modelRatio, groupRatio, completionRatio) model.RecordConsumeLog(userId, promptTokens, completionTokens, textRequest.Model, tokenName, quota, logContent) model.UpdateUserUsedQuotaAndRequestCount(userId, quota) channelId := c.GetInt("channel_id") diff --git a/controller/relay.go b/controller/relay.go index b6f04c09..2910cc97 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -19,8 +19,9 @@ const ( RelayModeChatCompletions RelayModeCompletions RelayModeEmbeddings - RelayModeModeration + RelayModeModerations RelayModeImagesGenerations + RelayModeEdits ) // https://platform.openai.com/docs/api-reference/chat @@ -35,6 +36,7 @@ type GeneralOpenAIRequest struct { TopP float64 `json:"top_p"` N int `json:"n"` Input any `json:"input"` + Instruction string `json:"instruction"` } type ChatRequest struct { @@ -99,9 +101,11 @@ func Relay(c *gin.Context) { } else if strings.HasPrefix(c.Request.URL.Path, "/v1/embeddings") { relayMode = RelayModeEmbeddings } else if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") { - relayMode = RelayModeModeration + relayMode = RelayModeModerations } else if strings.HasPrefix(c.Request.URL.Path, "/v1/images/generations") { relayMode = RelayModeImagesGenerations + } else if strings.HasPrefix(c.Request.URL.Path, "/v1/edits") { + relayMode = RelayModeEdits } var err *OpenAIErrorWithStatusCode switch relayMode { diff --git a/router/relay-router.go b/router/relay-router.go index 0b697af8..cbdfef11 100644 --- a/router/relay-router.go +++ b/router/relay-router.go @@ -19,7 +19,7 @@ func SetRelayRouter(router *gin.Engine) { { relayV1Router.POST("/completions", controller.Relay) relayV1Router.POST("/chat/completions", controller.Relay) - relayV1Router.POST("/edits", controller.RelayNotImplemented) + relayV1Router.POST("/edits", controller.Relay) relayV1Router.POST("/images/generations", controller.RelayNotImplemented) relayV1Router.POST("/images/edits", controller.RelayNotImplemented) relayV1Router.POST("/images/variations", controller.RelayNotImplemented)