From 209d248535d771a53145744c86ca2fa59a71ca69 Mon Sep 17 00:00:00 2001 From: ckt1031 <65409152+ckt1031@users.noreply.github.com> Date: Fri, 17 Nov 2023 16:58:24 +0800 Subject: [PATCH] Refactored code to handle both string and structured message content --- controller/relay-aiproxy.go | 5 +++-- controller/relay-ali.go | 11 ++++++----- controller/relay-baidu.go | 7 ++++--- controller/relay-openai.go | 7 +++++-- controller/relay-palm.go | 5 +++-- controller/relay-tencent.go | 7 ++++--- controller/relay-utils.go | 17 ++++++++++++++--- controller/relay-xunfei.go | 9 +++++---- controller/relay-zhipu.go | 9 +++++---- controller/relay.go | 20 +++++++++++++++++--- 10 files changed, 66 insertions(+), 31 deletions(-) diff --git a/controller/relay-aiproxy.go b/controller/relay-aiproxy.go index d0159ce8..cb2fd80e 100644 --- a/controller/relay-aiproxy.go +++ b/controller/relay-aiproxy.go @@ -4,12 +4,13 @@ import ( "bufio" "encoding/json" "fmt" - "github.com/gin-gonic/gin" "io" "net/http" "one-api/common" "strconv" "strings" + + "github.com/gin-gonic/gin" ) // https://docs.aiproxy.io/dev/library#使用已经定制好的知识库进行对话问答 @@ -48,7 +49,7 @@ type AIProxyLibraryStreamResponse struct { func requestOpenAI2AIProxyLibrary(request GeneralOpenAIRequest) *AIProxyLibraryRequest { query := "" if len(request.Messages) != 0 { - query = request.Messages[len(request.Messages)-1].Content + query = request.Messages[len(request.Messages)-1].Content.(string) } return &AIProxyLibraryRequest{ Model: request.Model, diff --git a/controller/relay-ali.go b/controller/relay-ali.go index 50dc743c..9b7cd209 100644 --- a/controller/relay-ali.go +++ b/controller/relay-ali.go @@ -3,11 +3,12 @@ package controller import ( "bufio" "encoding/json" - "github.com/gin-gonic/gin" "io" "net/http" "one-api/common" "strings" + + "github.com/gin-gonic/gin" ) // https://help.aliyun.com/document_detail/613695.html?spm=a2c4g.2399480.0.0.1adb778fAdzP9w#341800c0f8w0r @@ -88,18 +89,18 @@ func requestOpenAI2Ali(request GeneralOpenAIRequest) *AliChatRequest { message := request.Messages[i] if message.Role == "system" { messages = append(messages, AliMessage{ - User: message.Content, + User: message.Content.(string), Bot: "Okay", }) continue } else { if i == len(request.Messages)-1 { - prompt = message.Content + prompt = message.Content.(string) break } messages = append(messages, AliMessage{ - User: message.Content, - Bot: request.Messages[i+1].Content, + User: message.Content.(string), + Bot: request.Messages[i+1].Content.(string), }) i++ } diff --git a/controller/relay-baidu.go b/controller/relay-baidu.go index ed08ac04..93e2f1c4 100644 --- a/controller/relay-baidu.go +++ b/controller/relay-baidu.go @@ -5,13 +5,14 @@ import ( "encoding/json" "errors" "fmt" - "github.com/gin-gonic/gin" "io" "net/http" "one-api/common" "strings" "sync" "time" + + "github.com/gin-gonic/gin" ) // https://cloud.baidu.com/doc/WENXINWORKSHOP/s/flfmc9do2 @@ -89,7 +90,7 @@ func requestOpenAI2Baidu(request GeneralOpenAIRequest) *BaiduChatRequest { if message.Role == "system" { messages = append(messages, BaiduMessage{ Role: "user", - Content: message.Content, + Content: message.Content.(string), }) messages = append(messages, BaiduMessage{ Role: "assistant", @@ -98,7 +99,7 @@ func requestOpenAI2Baidu(request GeneralOpenAIRequest) *BaiduChatRequest { } else { messages = append(messages, BaiduMessage{ Role: message.Role, - Content: message.Content, + Content: message.Content.(string), }) } } diff --git a/controller/relay-openai.go b/controller/relay-openai.go index 6bdfbc08..513c88b7 100644 --- a/controller/relay-openai.go +++ b/controller/relay-openai.go @@ -4,11 +4,12 @@ import ( "bufio" "bytes" "encoding/json" - "github.com/gin-gonic/gin" "io" "net/http" "one-api/common" "strings" + + "github.com/gin-gonic/gin" ) func openaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*OpenAIErrorWithStatusCode, string) { @@ -132,7 +133,9 @@ func openaiHandler(c *gin.Context, resp *http.Response, consumeQuota bool, promp if textResponse.Usage.TotalTokens == 0 { completionTokens := 0 for _, choice := range textResponse.Choices { - completionTokens += countTokenText(choice.Message.Content, model) + if content, ok := choice.Message.Content.(string); ok { + completionTokens += countTokenText(content, model) + } } textResponse.Usage = Usage{ PromptTokens: promptTokens, diff --git a/controller/relay-palm.go b/controller/relay-palm.go index a705b318..6970035a 100644 --- a/controller/relay-palm.go +++ b/controller/relay-palm.go @@ -3,10 +3,11 @@ package controller import ( "encoding/json" "fmt" - "github.com/gin-gonic/gin" "io" "net/http" "one-api/common" + + "github.com/gin-gonic/gin" ) // https://developers.generativeai.google/api/rest/generativelanguage/models/generateMessage#request-body @@ -59,7 +60,7 @@ func requestOpenAI2PaLM(textRequest GeneralOpenAIRequest) *PaLMChatRequest { } for _, message := range textRequest.Messages { palmMessage := PaLMChatMessage{ - Content: message.Content, + Content: message.Content.(string), } if message.Role == "user" { palmMessage.Author = "0" diff --git a/controller/relay-tencent.go b/controller/relay-tencent.go index 024468bc..606ea1a4 100644 --- a/controller/relay-tencent.go +++ b/controller/relay-tencent.go @@ -8,13 +8,14 @@ import ( "encoding/json" "errors" "fmt" - "github.com/gin-gonic/gin" "io" "net/http" "one-api/common" "sort" "strconv" "strings" + + "github.com/gin-gonic/gin" ) // https://cloud.tencent.com/document/product/1729/97732 @@ -84,7 +85,7 @@ func requestOpenAI2Tencent(request GeneralOpenAIRequest) *TencentChatRequest { if message.Role == "system" { messages = append(messages, TencentMessage{ Role: "user", - Content: message.Content, + Content: message.Content.(string), }) messages = append(messages, TencentMessage{ Role: "assistant", @@ -93,7 +94,7 @@ func requestOpenAI2Tencent(request GeneralOpenAIRequest) *TencentChatRequest { continue } messages = append(messages, TencentMessage{ - Content: message.Content, + Content: message.Content.(string), Role: message.Role, }) } diff --git a/controller/relay-utils.go b/controller/relay-utils.go index cf5d9b69..ce9f0e49 100644 --- a/controller/relay-utils.go +++ b/controller/relay-utils.go @@ -3,13 +3,14 @@ package controller import ( "encoding/json" "fmt" - "github.com/gin-gonic/gin" - "github.com/pkoukk/tiktoken-go" "io" "net/http" "one-api/common" "strconv" "strings" + + "github.com/gin-gonic/gin" + "github.com/pkoukk/tiktoken-go" ) var stopFinishReason = "stop" @@ -84,7 +85,17 @@ func countTokenMessages(messages []Message, model string) int { tokenNum := 0 for _, message := range messages { tokenNum += tokensPerMessage - tokenNum += getTokenNum(tokenEncoder, message.Content) + + if content, ok := message.Content.(string); ok { + tokenNum += getTokenNum(tokenEncoder, content) + } else if content, ok := message.Content.([]MessageContent); ok { + for _, item := range content { + if item.Type == "text" { + tokenNum += getTokenNum(tokenEncoder, item.Text) + } + } + } + tokenNum += getTokenNum(tokenEncoder, message.Role) if message.Name != nil { tokenNum += tokensPerName diff --git a/controller/relay-xunfei.go b/controller/relay-xunfei.go index 91fb6042..b6b9fe27 100644 --- a/controller/relay-xunfei.go +++ b/controller/relay-xunfei.go @@ -6,14 +6,15 @@ import ( "encoding/base64" "encoding/json" "fmt" - "github.com/gin-gonic/gin" - "github.com/gorilla/websocket" "io" "net/http" "net/url" "one-api/common" "strings" "time" + + "github.com/gin-gonic/gin" + "github.com/gorilla/websocket" ) // https://console.xfyun.cn/services/cbm @@ -81,7 +82,7 @@ func requestOpenAI2Xunfei(request GeneralOpenAIRequest, xunfeiAppId string, doma if message.Role == "system" { messages = append(messages, XunfeiMessage{ Role: "user", - Content: message.Content, + Content: message.Content.(string), }) messages = append(messages, XunfeiMessage{ Role: "assistant", @@ -90,7 +91,7 @@ func requestOpenAI2Xunfei(request GeneralOpenAIRequest, xunfeiAppId string, doma } else { messages = append(messages, XunfeiMessage{ Role: message.Role, - Content: message.Content, + Content: message.Content.(string), }) } } diff --git a/controller/relay-zhipu.go b/controller/relay-zhipu.go index 7a4a582d..77c84f80 100644 --- a/controller/relay-zhipu.go +++ b/controller/relay-zhipu.go @@ -3,14 +3,15 @@ package controller import ( "bufio" "encoding/json" - "github.com/gin-gonic/gin" - "github.com/golang-jwt/jwt" "io" "net/http" "one-api/common" "strings" "sync" "time" + + "github.com/gin-gonic/gin" + "github.com/golang-jwt/jwt" ) // https://open.bigmodel.cn/doc/api#chatglm_std @@ -114,7 +115,7 @@ func requestOpenAI2Zhipu(request GeneralOpenAIRequest) *ZhipuRequest { if message.Role == "system" { messages = append(messages, ZhipuMessage{ Role: "system", - Content: message.Content, + Content: message.Content.(string), }) messages = append(messages, ZhipuMessage{ Role: "user", @@ -123,7 +124,7 @@ func requestOpenAI2Zhipu(request GeneralOpenAIRequest) *ZhipuRequest { } else { messages = append(messages, ZhipuMessage{ Role: message.Role, - Content: message.Content, + Content: message.Content.(string), }) } } diff --git a/controller/relay.go b/controller/relay.go index 1926110e..80805422 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -10,10 +10,24 @@ import ( "github.com/gin-gonic/gin" ) +type MessageImage struct { + URL string `json:"url"` + Detail string `json:"detail"` +} + +type MessageContent struct { + Type string `json:"type"` + Text string `json:"text"` + ImageURL MessageImage `json:"image_url"` +} + +type ContentInterface interface{} + type Message struct { - Role string `json:"role"` - Content string `json:"content"` - Name *string `json:"name,omitempty"` + Role string `json:"role"` + // Content string or MessageContent + Content ContentInterface `json:"content"` + Name *string `json:"name,omitempty"` } const (