From bcca0cc0bc5e116daecbdf19fe49df724b67d027 Mon Sep 17 00:00:00 2001
From: JustSong
Date: Sun, 21 May 2023 14:26:59 +0800
Subject: [PATCH] feat: PaLM support is WIP (#105)

---
 common/constants.go      |  2 ++
 controller/relay-palm.go | 59 ++++++++++++++++++++++++++++++++++++++++
 controller/relay.go      | 20 ++++++++++++--
 3 files changed, 79 insertions(+), 2 deletions(-)
 create mode 100644 controller/relay-palm.go

diff --git a/common/constants.go b/common/constants.go
index 78474bd0..7c1ff298 100644
--- a/common/constants.go
+++ b/common/constants.go
@@ -129,6 +129,7 @@ const (
 	ChannelTypeCustom  = 8
 	ChannelTypeAILS    = 9
 	ChannelTypeAIProxy = 10
+	ChannelTypePaLM    = 11
 )
 
 var ChannelBaseURLs = []string{
@@ -143,4 +144,5 @@ var ChannelBaseURLs = []string{
 	"",                          // 8
 	"https://api.caipacity.com", // 9
 	"https://api.aiproxy.io",    // 10
+	"",                          // 11
 }
diff --git a/controller/relay-palm.go b/controller/relay-palm.go
new file mode 100644
index 00000000..ae739ca0
--- /dev/null
+++ b/controller/relay-palm.go
@@ -0,0 +1,59 @@
+package controller
+
+import (
+	"fmt"
+	"github.com/gin-gonic/gin"
+)
+
+type PaLMChatMessage struct {
+	Author  string `json:"author"`
+	Content string `json:"content"`
+}
+
+type PaLMFilter struct {
+	Reason  string `json:"reason"`
+	Message string `json:"message"`
+}
+
+// https://developers.generativeai.google/api/rest/generativelanguage/models/generateMessage#request-body
+type PaLMChatRequest struct {
+	Prompt         []PaLMChatMessage `json:"prompt"`
+	Temperature    float64           `json:"temperature"`
+	CandidateCount int               `json:"candidateCount"`
+	TopP           float64           `json:"topP"`
+	TopK           int               `json:"topK"`
+}
+
+// https://developers.generativeai.google/api/rest/generativelanguage/models/generateMessage#response-body
+type PaLMChatResponse struct {
+	Candidates []PaLMChatMessage `json:"candidates"`
+	Messages   []PaLMChatMessage `json:"messages"`
+	Filters    []PaLMFilter      `json:"filters"`
+}
+
+func relayPaLM(openAIRequest GeneralOpenAIRequest, c *gin.Context) *OpenAIErrorWithStatusCode {
+	// https://developers.generativeai.google/api/rest/generativelanguage/models/generateMessage
+	messages := make([]PaLMChatMessage, 0, len(openAIRequest.Messages))
+	for _, message := range openAIRequest.Messages {
+		var author string // PaLM identifies speakers by author id rather than role name
+		if message.Role == "user" {
+			author = "0"
+		} else {
+			author = "1"
+		}
+		messages = append(messages, PaLMChatMessage{
+			Author:  author,
+			Content: message.Content,
+		})
+	}
+	request := PaLMChatRequest{
+		Prompt:         messages, // carry the converted messages instead of discarding them
+		Temperature:    openAIRequest.Temperature,
+		CandidateCount: openAIRequest.N,
+		TopP:           openAIRequest.TopP,
+		TopK:           openAIRequest.MaxTokens, // TODO: OpenAI requests carry no top_k; placeholder mapping
+	}
+	// TODO: forward request to PaLM & convert response
+	fmt.Print(request)
+	return nil
+}
diff --git a/controller/relay.go b/controller/relay.go
index 4e093db5..83578141 100644
--- a/controller/relay.go
+++ b/controller/relay.go
@@ -19,6 +19,19 @@ type Message struct {
 	Name    *string `json:"name,omitempty"`
 }
 
+// https://platform.openai.com/docs/api-reference/chat
+
+type GeneralOpenAIRequest struct {
+	Model       string    `json:"model"`
+	Messages    []Message `json:"messages"`
+	Prompt      string    `json:"prompt"`
+	Stream      bool      `json:"stream"`
+	MaxTokens   int       `json:"max_tokens"`
+	Temperature float64   `json:"temperature"`
+	TopP        float64   `json:"top_p"`
+	N           int       `json:"n"`
+}
+
 type ChatRequest struct {
 	Model    string    `json:"model"`
 	Messages []Message `json:"messages"`
@@ -101,8 +114,8 @@ func relayHelper(c *gin.Context) *OpenAIErrorWithStatusCode {
 	channelType := c.GetInt("channel")
 	tokenId := c.GetInt("token_id")
 	consumeQuota := c.GetBool("consume_quota")
-	var textRequest TextRequest
-	if consumeQuota || channelType == common.ChannelTypeAzure {
+	var textRequest GeneralOpenAIRequest
+	if consumeQuota || channelType == common.ChannelTypeAzure || channelType == common.ChannelTypePaLM {
 		requestBody, err := io.ReadAll(c.Request.Body)
 		if err != nil {
 			return errorWrapper(err, "read_request_body_failed", http.StatusBadRequest)
@@ -141,6 +154,9 @@
 		model_ = strings.TrimSuffix(model_, "-0301")
 		model_ = strings.TrimSuffix(model_, "-0314")
 		fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/%s", baseURL, model_, task)
+	} else if channelType == common.ChannelTypePaLM {
+		err := relayPaLM(textRequest, c)
+		return err
 	}
 	promptTokens := countTokenMessages(textRequest.Messages, textRequest.Model)
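
Not part of the patch: a minimal sketch of what the remaining TODO in relayPaLM (forwarding the
request and decoding the response) might look like. The v1beta2 generateMessage endpoint, the
chat-bison-001 model name, the query-parameter API key, and the forwardToPaLM helper name are
assumptions drawn from the docs linked in the code, not anything this commit establishes; mapping
errors back to *OpenAIErrorWithStatusCode is likewise left out.

package controller

import (
	"bytes"
	"encoding/json"
	"fmt"
	"net/http"
)

// forwardToPaLM marshals the converted request, POSTs it to the PaLM
// generateMessage endpoint (assumed URL), and decodes the reply into
// the PaLMChatResponse struct added by this patch.
func forwardToPaLM(request PaLMChatRequest, apiKey string) (*PaLMChatResponse, error) {
	jsonData, err := json.Marshal(request)
	if err != nil {
		return nil, err
	}
	url := fmt.Sprintf(
		"https://generativelanguage.googleapis.com/v1beta2/models/chat-bison-001:generateMessage?key=%s",
		apiKey,
	)
	resp, err := http.Post(url, "application/json", bytes.NewReader(jsonData))
	if err != nil {
		return nil, err
	}
	defer resp.Body.Close()
	var palmResponse PaLMChatResponse
	if err := json.NewDecoder(resp.Body).Decode(&palmResponse); err != nil {
		return nil, err
	}
	return &palmResponse, nil
}

Decoding into PaLMChatResponse reuses the structs defined in relay-palm.go, so the eventual
conversion back to an OpenAI-style response can work directly on the Candidates field.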