feat: support /v1/completions (close #115)

This commit is contained in:
JustSong 2023-06-08 14:54:02 +08:00
parent 9301b3fed3
commit 4b6adaec0b
4 changed files with 100 additions and 15 deletions

View File

@ -10,7 +10,7 @@ var ModelRatio = map[string]float64{
"gpt-4-0314": 15, "gpt-4-0314": 15,
"gpt-4-32k": 30, "gpt-4-32k": 30,
"gpt-4-32k-0314": 30, "gpt-4-32k-0314": 30,
"gpt-3.5-turbo": 1, "gpt-3.5-turbo": 1, // $0.002 / 1K tokens
"gpt-3.5-turbo-0301": 1, "gpt-3.5-turbo-0301": 1,
"text-ada-001": 0.2, "text-ada-001": 0.2,
"text-babbage-001": 0.25, "text-babbage-001": 0.25,

View File

@ -116,6 +116,51 @@ func init() {
Root: "text-embedding-ada-002", Root: "text-embedding-ada-002",
Parent: nil, Parent: nil,
}, },
{
Id: "text-davinci-003",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "text-davinci-003",
Parent: nil,
},
{
Id: "text-davinci-002",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "text-davinci-002",
Parent: nil,
},
{
Id: "text-curie-001",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "text-curie-001",
Parent: nil,
},
{
Id: "text-babbage-001",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "text-babbage-001",
Parent: nil,
},
{
Id: "text-ada-001",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "text-ada-001",
Parent: nil,
},
} }
openAIModelsMap = make(map[string]OpenAIModels) openAIModelsMap = make(map[string]OpenAIModels)
for _, model := range openAIModels { for _, model := range openAIModels {

View File

@ -19,6 +19,13 @@ type Message struct {
Name *string `json:"name,omitempty"` Name *string `json:"name,omitempty"`
} }
const (
RelayModeUnknown = iota
RelayModeChatCompletions
RelayModeCompletions
RelayModeEmbeddings
)
// https://platform.openai.com/docs/api-reference/chat // https://platform.openai.com/docs/api-reference/chat
type GeneralOpenAIRequest struct { type GeneralOpenAIRequest struct {
@ -69,7 +76,7 @@ type TextResponse struct {
Error OpenAIError `json:"error"` Error OpenAIError `json:"error"`
} }
type StreamResponse struct { type ChatCompletionsStreamResponse struct {
Choices []struct { Choices []struct {
Delta struct { Delta struct {
Content string `json:"content"` Content string `json:"content"`
@ -78,8 +85,23 @@ type StreamResponse struct {
} `json:"choices"` } `json:"choices"`
} }
type CompletionsStreamResponse struct {
Choices []struct {
Text string `json:"text"`
FinishReason string `json:"finish_reason"`
} `json:"choices"`
}
func Relay(c *gin.Context) { func Relay(c *gin.Context) {
err := relayHelper(c) relayMode := RelayModeUnknown
if strings.HasPrefix(c.Request.URL.Path, "/v1/chat/completions") {
relayMode = RelayModeChatCompletions
} else if strings.HasPrefix(c.Request.URL.Path, "/v1/completions") {
relayMode = RelayModeCompletions
} else if strings.HasPrefix(c.Request.URL.Path, "/v1/embeddings") {
relayMode = RelayModeEmbeddings
}
err := relayHelper(c, relayMode)
if err != nil { if err != nil {
if err.StatusCode == http.StatusTooManyRequests { if err.StatusCode == http.StatusTooManyRequests {
err.OpenAIError.Message = "负载已满,请稍后再试,或升级账户以提升服务质量。" err.OpenAIError.Message = "负载已满,请稍后再试,或升级账户以提升服务质量。"
@ -110,7 +132,7 @@ func errorWrapper(err error, code string, statusCode int) *OpenAIErrorWithStatus
} }
} }
func relayHelper(c *gin.Context) *OpenAIErrorWithStatusCode { func relayHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
channelType := c.GetInt("channel") channelType := c.GetInt("channel")
tokenId := c.GetInt("token_id") tokenId := c.GetInt("token_id")
consumeQuota := c.GetBool("consume_quota") consumeQuota := c.GetBool("consume_quota")
@ -148,8 +170,13 @@ func relayHelper(c *gin.Context) *OpenAIErrorWithStatusCode {
err := relayPaLM(textRequest, c) err := relayPaLM(textRequest, c)
return err return err
} }
var promptTokens int
promptTokens := countTokenMessages(textRequest.Messages, textRequest.Model) switch relayMode {
case RelayModeChatCompletions:
promptTokens = countTokenMessages(textRequest.Messages, textRequest.Model)
case RelayModeCompletions:
promptTokens = countTokenText(textRequest.Prompt, textRequest.Model)
}
preConsumedTokens := common.PreConsumedQuota preConsumedTokens := common.PreConsumedQuota
if textRequest.MaxTokens != 0 { if textRequest.MaxTokens != 0 {
preConsumedTokens = promptTokens + textRequest.MaxTokens preConsumedTokens = promptTokens + textRequest.MaxTokens
@ -245,14 +272,27 @@ func relayHelper(c *gin.Context) *OpenAIErrorWithStatusCode {
dataChan <- data dataChan <- data
data = data[6:] data = data[6:]
if !strings.HasPrefix(data, "[DONE]") { if !strings.HasPrefix(data, "[DONE]") {
var streamResponse StreamResponse switch relayMode {
err = json.Unmarshal([]byte(data), &streamResponse) case RelayModeChatCompletions:
if err != nil { var streamResponse ChatCompletionsStreamResponse
common.SysError("Error unmarshalling stream response: " + err.Error()) err = json.Unmarshal([]byte(data), &streamResponse)
return if err != nil {
} common.SysError("Error unmarshalling stream response: " + err.Error())
for _, choice := range streamResponse.Choices { return
streamResponseText += choice.Delta.Content }
for _, choice := range streamResponse.Choices {
streamResponseText += choice.Delta.Content
}
case RelayModeCompletions:
var streamResponse CompletionsStreamResponse
err = json.Unmarshal([]byte(data), &streamResponse)
if err != nil {
common.SysError("Error unmarshalling stream response: " + err.Error())
return
}
for _, choice := range streamResponse.Choices {
streamResponseText += choice.Text
}
} }
} }
} }

View File

@ -17,7 +17,7 @@ func SetRelayRouter(router *gin.Engine) {
relayV1Router := router.Group("/v1") relayV1Router := router.Group("/v1")
relayV1Router.Use(middleware.TokenAuth(), middleware.Distribute()) relayV1Router.Use(middleware.TokenAuth(), middleware.Distribute())
{ {
relayV1Router.POST("/completions", controller.RelayNotImplemented) relayV1Router.POST("/completions", controller.Relay)
relayV1Router.POST("/chat/completions", controller.Relay) relayV1Router.POST("/chat/completions", controller.Relay)
relayV1Router.POST("/edits", controller.RelayNotImplemented) relayV1Router.POST("/edits", controller.RelayNotImplemented)
relayV1Router.POST("/images/generations", controller.RelayNotImplemented) relayV1Router.POST("/images/generations", controller.RelayNotImplemented)