feat: PaLM support is WIP (#105)
parent b92ec5e54c
commit bcca0cc0bc
@@ -129,6 +129,7 @@ const (
 	ChannelTypeCustom  = 8
 	ChannelTypeAILS    = 9
 	ChannelTypeAIProxy = 10
+	ChannelTypePaLM    = 11
 )

 var ChannelBaseURLs = []string{
@@ -143,4 +144,5 @@ var ChannelBaseURLs = []string{
 	"",                          // 8
 	"https://api.caipacity.com", // 9
 	"https://api.aiproxy.io",    // 10
+	"",                          // 11
 }
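Note: ChannelBaseURLs is a parallel slice indexed by the ChannelType* constants, so the new ChannelTypePaLM = 11 needs a matching entry at index 11; the empty string means no default base URL is known for PaLM yet. A minimal sketch of the lookup this index alignment enables (baseURLFor is an illustrative helper, not defined in this commit):

// Illustrative sketch only: index-aligned lookup over ChannelBaseURLs.
// baseURLFor is hypothetical; the commit only maintains the slice itself.
func baseURLFor(channelType int) string {
	if channelType < 0 || channelType >= len(ChannelBaseURLs) {
		return ""
	}
	return ChannelBaseURLs[channelType] // "" for ChannelTypePaLM (11) for now
}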
controller/relay-palm.go (new file, 59 lines)
@@ -0,0 +1,59 @@
+package controller
+
+import (
+	"fmt"
+	"github.com/gin-gonic/gin"
+)
+
+type PaLMChatMessage struct {
+	Author  string `json:"author"`
+	Content string `json:"content"`
+}
+
+type PaLMFilter struct {
+	Reason  string `json:"reason"`
+	Message string `json:"message"`
+}
+
+// https://developers.generativeai.google/api/rest/generativelanguage/models/generateMessage#request-body
+type PaLMChatRequest struct {
+	Prompt         []Message `json:"prompt"`
+	Temperature    float64   `json:"temperature"`
+	CandidateCount int       `json:"candidateCount"`
+	TopP           float64   `json:"topP"`
+	TopK           int       `json:"topK"`
+}
+
+// https://developers.generativeai.google/api/rest/generativelanguage/models/generateMessage#response-body
+type PaLMChatResponse struct {
+	Candidates []Message    `json:"candidates"`
+	Messages   []Message    `json:"messages"`
+	Filters    []PaLMFilter `json:"filters"`
+}
+
+func relayPaLM(openAIRequest GeneralOpenAIRequest, c *gin.Context) *OpenAIErrorWithStatusCode {
+	// https://developers.generativeai.google/api/rest/generativelanguage/models/generateMessage
+	messages := make([]PaLMChatMessage, 0, len(openAIRequest.Messages))
+	for _, message := range openAIRequest.Messages {
+		var author string
+		if message.Role == "user" {
+			author = "0"
+		} else {
+			author = "1"
+		}
+		messages = append(messages, PaLMChatMessage{
+			Author:  author,
+			Content: message.Content,
+		})
+	}
+	request := PaLMChatRequest{
+		Prompt:         nil,
+		Temperature:    openAIRequest.Temperature,
+		CandidateCount: openAIRequest.N,
+		TopP:           openAIRequest.TopP,
+		TopK:           openAIRequest.MaxTokens,
+	}
+	// TODO: forward request to PaLM & convert response
+	fmt.Print(request)
+	return nil
+}
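The TODO above is the heart of the WIP: relayPaLM converts the messages and assembles a request but never sends anything (note also that Prompt is left nil, so the converted messages are currently discarded). A sketch of what the missing forwarding step could look like, under two assumptions this commit does not confirm: the public PaLM REST endpoint below, and API-key auth via a ?key= query parameter.

// Sketch of the forwarding step the TODO leaves open; endpoint and auth
// scheme are assumptions taken from the PaLM REST docs linked above.
package controller

import (
	"bytes"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
)

func forwardToPaLM(request PaLMChatRequest, apiKey string) (*PaLMChatResponse, error) {
	body, err := json.Marshal(request)
	if err != nil {
		return nil, err
	}
	// Hypothetical endpoint; a real implementation would respect the
	// channel's configured base URL instead of hard-coding one.
	url := "https://generativelanguage.googleapis.com/v1beta2/models/chat-bison-001:generateMessage?key=" + apiKey
	resp, err := http.Post(url, "application/json", bytes.NewReader(body))
	if err != nil {
		return nil, err
	}
	defer resp.Body.Close()
	data, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, err
	}
	if resp.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("PaLM returned status %d: %s", resp.StatusCode, data)
	}
	var palmResponse PaLMChatResponse
	if err := json.Unmarshal(data, &palmResponse); err != nil {
		return nil, err
	}
	return &palmResponse, nil
}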
@@ -19,6 +19,19 @@ type Message struct {
 	Name *string `json:"name,omitempty"`
 }

+// https://platform.openai.com/docs/api-reference/chat
+
+type GeneralOpenAIRequest struct {
+	Model       string    `json:"model"`
+	Messages    []Message `json:"messages"`
+	Prompt      string    `json:"prompt"`
+	Stream      bool      `json:"stream"`
+	MaxTokens   int       `json:"max_tokens"`
+	Temperature float64   `json:"temperature"`
+	TopP        float64   `json:"top_p"`
+	N           int       `json:"n"`
+}
+
 type ChatRequest struct {
 	Model    string    `json:"model"`
 	Messages []Message `json:"messages"`
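GeneralOpenAIRequest is the superset type the relay now decodes every incoming body into, covering both chat and completion fields. A minimal sketch of how a standard chat-completion body maps onto it via the json tags above (assumes the types above plus encoding/json; would live inside a handler):

// Sketch: decoding an OpenAI-style chat body into GeneralOpenAIRequest.
var req GeneralOpenAIRequest
body := []byte(`{"model":"gpt-3.5-turbo","messages":[{"role":"user","content":"Hi"}],"temperature":0.7,"n":1,"max_tokens":16}`)
if err := json.Unmarshal(body, &req); err != nil {
	// reject the malformed body
}
// req.Model == "gpt-3.5-turbo", req.N == 1, req.MaxTokens == 16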
@@ -101,8 +114,8 @@ func relayHelper(c *gin.Context) *OpenAIErrorWithStatusCode {
 	channelType := c.GetInt("channel")
 	tokenId := c.GetInt("token_id")
 	consumeQuota := c.GetBool("consume_quota")
-	var textRequest TextRequest
-	if consumeQuota || channelType == common.ChannelTypeAzure {
+	var textRequest GeneralOpenAIRequest
+	if consumeQuota || channelType == common.ChannelTypeAzure || channelType == common.ChannelTypePaLM {
 		requestBody, err := io.ReadAll(c.Request.Body)
 		if err != nil {
 			return errorWrapper(err, "read_request_body_failed", http.StatusBadRequest)
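Note: widening the condition means the request body is now read and decoded into textRequest for PaLM channels even when consume_quota is off; without this, the textRequest handed to relayPaLM in the branch below would stay zero-valued.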
@@ -141,6 +154,9 @@ func relayHelper(c *gin.Context) *OpenAIErrorWithStatusCode {
 		model_ = strings.TrimSuffix(model_, "-0301")
 		model_ = strings.TrimSuffix(model_, "-0314")
 		fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/%s", baseURL, model_, task)
+	} else if channelType == common.ChannelTypePaLM {
+		err := relayPaLM(textRequest, c)
+		return err
 	}

 	promptTokens := countTokenMessages(textRequest.Messages, textRequest.Model)
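As wired up here, the PaLM branch returns early: since relayPaLM currently ends in return nil, a PaLM request reports success without writing any response, and it bypasses the prompt-token accounting below. A complete implementation would convert the PaLM response back into the OpenAI format and restore the quota logic for this path.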