diff --git a/common/model-ratio.go b/common/model-ratio.go index cee4559c..2b975176 100644 --- a/common/model-ratio.go +++ b/common/model-ratio.go @@ -10,7 +10,7 @@ var ModelRatio = map[string]float64{ "gpt-4-0314": 15, "gpt-4-32k": 30, "gpt-4-32k-0314": 30, - "gpt-3.5-turbo": 1, + "gpt-3.5-turbo": 1, // $0.002 / 1K tokens "gpt-3.5-turbo-0301": 1, "text-ada-001": 0.2, "text-babbage-001": 0.25, diff --git a/controller/model.go b/controller/model.go index 829c795d..9685eb82 100644 --- a/controller/model.go +++ b/controller/model.go @@ -116,6 +116,51 @@ func init() { Root: "text-embedding-ada-002", Parent: nil, }, + { + Id: "text-davinci-003", + Object: "model", + Created: 1677649963, + OwnedBy: "openai", + Permission: permission, + Root: "text-davinci-003", + Parent: nil, + }, + { + Id: "text-davinci-002", + Object: "model", + Created: 1677649963, + OwnedBy: "openai", + Permission: permission, + Root: "text-davinci-002", + Parent: nil, + }, + { + Id: "text-curie-001", + Object: "model", + Created: 1677649963, + OwnedBy: "openai", + Permission: permission, + Root: "text-curie-001", + Parent: nil, + }, + { + Id: "text-babbage-001", + Object: "model", + Created: 1677649963, + OwnedBy: "openai", + Permission: permission, + Root: "text-babbage-001", + Parent: nil, + }, + { + Id: "text-ada-001", + Object: "model", + Created: 1677649963, + OwnedBy: "openai", + Permission: permission, + Root: "text-ada-001", + Parent: nil, + }, } openAIModelsMap = make(map[string]OpenAIModels) for _, model := range openAIModels { diff --git a/controller/relay.go b/controller/relay.go index fb3b8bc4..f2fa2dd4 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -19,6 +19,13 @@ type Message struct { Name *string `json:"name,omitempty"` } +const ( + RelayModeUnknown = iota + RelayModeChatCompletions + RelayModeCompletions + RelayModeEmbeddings +) + // https://platform.openai.com/docs/api-reference/chat type GeneralOpenAIRequest struct { @@ -69,7 +76,7 @@ type TextResponse struct { Error OpenAIError `json:"error"` } -type StreamResponse struct { +type ChatCompletionsStreamResponse struct { Choices []struct { Delta struct { Content string `json:"content"` @@ -78,8 +85,23 @@ type StreamResponse struct { } `json:"choices"` } +type CompletionsStreamResponse struct { + Choices []struct { + Text string `json:"text"` + FinishReason string `json:"finish_reason"` + } `json:"choices"` +} + func Relay(c *gin.Context) { - err := relayHelper(c) + relayMode := RelayModeUnknown + if strings.HasPrefix(c.Request.URL.Path, "/v1/chat/completions") { + relayMode = RelayModeChatCompletions + } else if strings.HasPrefix(c.Request.URL.Path, "/v1/completions") { + relayMode = RelayModeCompletions + } else if strings.HasPrefix(c.Request.URL.Path, "/v1/embeddings") { + relayMode = RelayModeEmbeddings + } + err := relayHelper(c, relayMode) if err != nil { if err.StatusCode == http.StatusTooManyRequests { err.OpenAIError.Message = "负载已满,请稍后再试,或升级账户以提升服务质量。" @@ -110,7 +132,7 @@ func errorWrapper(err error, code string, statusCode int) *OpenAIErrorWithStatus } } -func relayHelper(c *gin.Context) *OpenAIErrorWithStatusCode { +func relayHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { channelType := c.GetInt("channel") tokenId := c.GetInt("token_id") consumeQuota := c.GetBool("consume_quota") @@ -148,8 +170,13 @@ func relayHelper(c *gin.Context) *OpenAIErrorWithStatusCode { err := relayPaLM(textRequest, c) return err } - - promptTokens := countTokenMessages(textRequest.Messages, textRequest.Model) + var promptTokens int + switch relayMode { + case RelayModeChatCompletions: + promptTokens = countTokenMessages(textRequest.Messages, textRequest.Model) + case RelayModeCompletions: + promptTokens = countTokenText(textRequest.Prompt, textRequest.Model) + } preConsumedTokens := common.PreConsumedQuota if textRequest.MaxTokens != 0 { preConsumedTokens = promptTokens + textRequest.MaxTokens @@ -245,14 +272,27 @@ func relayHelper(c *gin.Context) *OpenAIErrorWithStatusCode { dataChan <- data data = data[6:] if !strings.HasPrefix(data, "[DONE]") { - var streamResponse StreamResponse - err = json.Unmarshal([]byte(data), &streamResponse) - if err != nil { - common.SysError("Error unmarshalling stream response: " + err.Error()) - return - } - for _, choice := range streamResponse.Choices { - streamResponseText += choice.Delta.Content + switch relayMode { + case RelayModeChatCompletions: + var streamResponse ChatCompletionsStreamResponse + err = json.Unmarshal([]byte(data), &streamResponse) + if err != nil { + common.SysError("Error unmarshalling stream response: " + err.Error()) + return + } + for _, choice := range streamResponse.Choices { + streamResponseText += choice.Delta.Content + } + case RelayModeCompletions: + var streamResponse CompletionsStreamResponse + err = json.Unmarshal([]byte(data), &streamResponse) + if err != nil { + common.SysError("Error unmarshalling stream response: " + err.Error()) + return + } + for _, choice := range streamResponse.Choices { + streamResponseText += choice.Text + } } } } diff --git a/router/relay-router.go b/router/relay-router.go index 6d5b74a9..759e5f60 100644 --- a/router/relay-router.go +++ b/router/relay-router.go @@ -17,7 +17,7 @@ func SetRelayRouter(router *gin.Engine) { relayV1Router := router.Group("/v1") relayV1Router.Use(middleware.TokenAuth(), middleware.Distribute()) { - relayV1Router.POST("/completions", controller.RelayNotImplemented) + relayV1Router.POST("/completions", controller.Relay) relayV1Router.POST("/chat/completions", controller.Relay) relayV1Router.POST("/edits", controller.RelayNotImplemented) relayV1Router.POST("/images/generations", controller.RelayNotImplemented)