Refactored code to handle both string and

structured message content
This commit is contained in:
ckt1031 2023-11-17 16:58:24 +08:00
parent 58bb3ab6f6
commit 209d248535
10 changed files with 66 additions and 31 deletions

View File

@ -4,12 +4,13 @@ import (
"bufio" "bufio"
"encoding/json" "encoding/json"
"fmt" "fmt"
"github.com/gin-gonic/gin"
"io" "io"
"net/http" "net/http"
"one-api/common" "one-api/common"
"strconv" "strconv"
"strings" "strings"
"github.com/gin-gonic/gin"
) )
// https://docs.aiproxy.io/dev/library#使用已经定制好的知识库进行对话问答 // https://docs.aiproxy.io/dev/library#使用已经定制好的知识库进行对话问答
@ -48,7 +49,7 @@ type AIProxyLibraryStreamResponse struct {
func requestOpenAI2AIProxyLibrary(request GeneralOpenAIRequest) *AIProxyLibraryRequest { func requestOpenAI2AIProxyLibrary(request GeneralOpenAIRequest) *AIProxyLibraryRequest {
query := "" query := ""
if len(request.Messages) != 0 { if len(request.Messages) != 0 {
query = request.Messages[len(request.Messages)-1].Content query = request.Messages[len(request.Messages)-1].Content.(string)
} }
return &AIProxyLibraryRequest{ return &AIProxyLibraryRequest{
Model: request.Model, Model: request.Model,

View File

@ -3,11 +3,12 @@ package controller
import ( import (
"bufio" "bufio"
"encoding/json" "encoding/json"
"github.com/gin-gonic/gin"
"io" "io"
"net/http" "net/http"
"one-api/common" "one-api/common"
"strings" "strings"
"github.com/gin-gonic/gin"
) )
// https://help.aliyun.com/document_detail/613695.html?spm=a2c4g.2399480.0.0.1adb778fAdzP9w#341800c0f8w0r // https://help.aliyun.com/document_detail/613695.html?spm=a2c4g.2399480.0.0.1adb778fAdzP9w#341800c0f8w0r
@ -88,18 +89,18 @@ func requestOpenAI2Ali(request GeneralOpenAIRequest) *AliChatRequest {
message := request.Messages[i] message := request.Messages[i]
if message.Role == "system" { if message.Role == "system" {
messages = append(messages, AliMessage{ messages = append(messages, AliMessage{
User: message.Content, User: message.Content.(string),
Bot: "Okay", Bot: "Okay",
}) })
continue continue
} else { } else {
if i == len(request.Messages)-1 { if i == len(request.Messages)-1 {
prompt = message.Content prompt = message.Content.(string)
break break
} }
messages = append(messages, AliMessage{ messages = append(messages, AliMessage{
User: message.Content, User: message.Content.(string),
Bot: request.Messages[i+1].Content, Bot: request.Messages[i+1].Content.(string),
}) })
i++ i++
} }

View File

@ -5,13 +5,14 @@ import (
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
"github.com/gin-gonic/gin"
"io" "io"
"net/http" "net/http"
"one-api/common" "one-api/common"
"strings" "strings"
"sync" "sync"
"time" "time"
"github.com/gin-gonic/gin"
) )
// https://cloud.baidu.com/doc/WENXINWORKSHOP/s/flfmc9do2 // https://cloud.baidu.com/doc/WENXINWORKSHOP/s/flfmc9do2
@ -89,7 +90,7 @@ func requestOpenAI2Baidu(request GeneralOpenAIRequest) *BaiduChatRequest {
if message.Role == "system" { if message.Role == "system" {
messages = append(messages, BaiduMessage{ messages = append(messages, BaiduMessage{
Role: "user", Role: "user",
Content: message.Content, Content: message.Content.(string),
}) })
messages = append(messages, BaiduMessage{ messages = append(messages, BaiduMessage{
Role: "assistant", Role: "assistant",
@ -98,7 +99,7 @@ func requestOpenAI2Baidu(request GeneralOpenAIRequest) *BaiduChatRequest {
} else { } else {
messages = append(messages, BaiduMessage{ messages = append(messages, BaiduMessage{
Role: message.Role, Role: message.Role,
Content: message.Content, Content: message.Content.(string),
}) })
} }
} }

View File

@ -4,11 +4,12 @@ import (
"bufio" "bufio"
"bytes" "bytes"
"encoding/json" "encoding/json"
"github.com/gin-gonic/gin"
"io" "io"
"net/http" "net/http"
"one-api/common" "one-api/common"
"strings" "strings"
"github.com/gin-gonic/gin"
) )
func openaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*OpenAIErrorWithStatusCode, string) { func openaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*OpenAIErrorWithStatusCode, string) {
@ -132,7 +133,9 @@ func openaiHandler(c *gin.Context, resp *http.Response, consumeQuota bool, promp
if textResponse.Usage.TotalTokens == 0 { if textResponse.Usage.TotalTokens == 0 {
completionTokens := 0 completionTokens := 0
for _, choice := range textResponse.Choices { for _, choice := range textResponse.Choices {
completionTokens += countTokenText(choice.Message.Content, model) if content, ok := choice.Message.Content.(string); ok {
completionTokens += countTokenText(content, model)
}
} }
textResponse.Usage = Usage{ textResponse.Usage = Usage{
PromptTokens: promptTokens, PromptTokens: promptTokens,

View File

@ -3,10 +3,11 @@ package controller
import ( import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"github.com/gin-gonic/gin"
"io" "io"
"net/http" "net/http"
"one-api/common" "one-api/common"
"github.com/gin-gonic/gin"
) )
// https://developers.generativeai.google/api/rest/generativelanguage/models/generateMessage#request-body // https://developers.generativeai.google/api/rest/generativelanguage/models/generateMessage#request-body
@ -59,7 +60,7 @@ func requestOpenAI2PaLM(textRequest GeneralOpenAIRequest) *PaLMChatRequest {
} }
for _, message := range textRequest.Messages { for _, message := range textRequest.Messages {
palmMessage := PaLMChatMessage{ palmMessage := PaLMChatMessage{
Content: message.Content, Content: message.Content.(string),
} }
if message.Role == "user" { if message.Role == "user" {
palmMessage.Author = "0" palmMessage.Author = "0"

View File

@ -8,13 +8,14 @@ import (
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
"github.com/gin-gonic/gin"
"io" "io"
"net/http" "net/http"
"one-api/common" "one-api/common"
"sort" "sort"
"strconv" "strconv"
"strings" "strings"
"github.com/gin-gonic/gin"
) )
// https://cloud.tencent.com/document/product/1729/97732 // https://cloud.tencent.com/document/product/1729/97732
@ -84,7 +85,7 @@ func requestOpenAI2Tencent(request GeneralOpenAIRequest) *TencentChatRequest {
if message.Role == "system" { if message.Role == "system" {
messages = append(messages, TencentMessage{ messages = append(messages, TencentMessage{
Role: "user", Role: "user",
Content: message.Content, Content: message.Content.(string),
}) })
messages = append(messages, TencentMessage{ messages = append(messages, TencentMessage{
Role: "assistant", Role: "assistant",
@ -93,7 +94,7 @@ func requestOpenAI2Tencent(request GeneralOpenAIRequest) *TencentChatRequest {
continue continue
} }
messages = append(messages, TencentMessage{ messages = append(messages, TencentMessage{
Content: message.Content, Content: message.Content.(string),
Role: message.Role, Role: message.Role,
}) })
} }

View File

@ -3,13 +3,14 @@ package controller
import ( import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"github.com/gin-gonic/gin"
"github.com/pkoukk/tiktoken-go"
"io" "io"
"net/http" "net/http"
"one-api/common" "one-api/common"
"strconv" "strconv"
"strings" "strings"
"github.com/gin-gonic/gin"
"github.com/pkoukk/tiktoken-go"
) )
var stopFinishReason = "stop" var stopFinishReason = "stop"
@ -84,7 +85,17 @@ func countTokenMessages(messages []Message, model string) int {
tokenNum := 0 tokenNum := 0
for _, message := range messages { for _, message := range messages {
tokenNum += tokensPerMessage tokenNum += tokensPerMessage
tokenNum += getTokenNum(tokenEncoder, message.Content)
if content, ok := message.Content.(string); ok {
tokenNum += getTokenNum(tokenEncoder, content)
} else if content, ok := message.Content.([]MessageContent); ok {
for _, item := range content {
if item.Type == "text" {
tokenNum += getTokenNum(tokenEncoder, item.Text)
}
}
}
tokenNum += getTokenNum(tokenEncoder, message.Role) tokenNum += getTokenNum(tokenEncoder, message.Role)
if message.Name != nil { if message.Name != nil {
tokenNum += tokensPerName tokenNum += tokensPerName

View File

@ -6,14 +6,15 @@ import (
"encoding/base64" "encoding/base64"
"encoding/json" "encoding/json"
"fmt" "fmt"
"github.com/gin-gonic/gin"
"github.com/gorilla/websocket"
"io" "io"
"net/http" "net/http"
"net/url" "net/url"
"one-api/common" "one-api/common"
"strings" "strings"
"time" "time"
"github.com/gin-gonic/gin"
"github.com/gorilla/websocket"
) )
// https://console.xfyun.cn/services/cbm // https://console.xfyun.cn/services/cbm
@ -81,7 +82,7 @@ func requestOpenAI2Xunfei(request GeneralOpenAIRequest, xunfeiAppId string, doma
if message.Role == "system" { if message.Role == "system" {
messages = append(messages, XunfeiMessage{ messages = append(messages, XunfeiMessage{
Role: "user", Role: "user",
Content: message.Content, Content: message.Content.(string),
}) })
messages = append(messages, XunfeiMessage{ messages = append(messages, XunfeiMessage{
Role: "assistant", Role: "assistant",
@ -90,7 +91,7 @@ func requestOpenAI2Xunfei(request GeneralOpenAIRequest, xunfeiAppId string, doma
} else { } else {
messages = append(messages, XunfeiMessage{ messages = append(messages, XunfeiMessage{
Role: message.Role, Role: message.Role,
Content: message.Content, Content: message.Content.(string),
}) })
} }
} }

View File

@ -3,14 +3,15 @@ package controller
import ( import (
"bufio" "bufio"
"encoding/json" "encoding/json"
"github.com/gin-gonic/gin"
"github.com/golang-jwt/jwt"
"io" "io"
"net/http" "net/http"
"one-api/common" "one-api/common"
"strings" "strings"
"sync" "sync"
"time" "time"
"github.com/gin-gonic/gin"
"github.com/golang-jwt/jwt"
) )
// https://open.bigmodel.cn/doc/api#chatglm_std // https://open.bigmodel.cn/doc/api#chatglm_std
@ -114,7 +115,7 @@ func requestOpenAI2Zhipu(request GeneralOpenAIRequest) *ZhipuRequest {
if message.Role == "system" { if message.Role == "system" {
messages = append(messages, ZhipuMessage{ messages = append(messages, ZhipuMessage{
Role: "system", Role: "system",
Content: message.Content, Content: message.Content.(string),
}) })
messages = append(messages, ZhipuMessage{ messages = append(messages, ZhipuMessage{
Role: "user", Role: "user",
@ -123,7 +124,7 @@ func requestOpenAI2Zhipu(request GeneralOpenAIRequest) *ZhipuRequest {
} else { } else {
messages = append(messages, ZhipuMessage{ messages = append(messages, ZhipuMessage{
Role: message.Role, Role: message.Role,
Content: message.Content, Content: message.Content.(string),
}) })
} }
} }

View File

@ -10,9 +10,23 @@ import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )
type MessageImage struct {
URL string `json:"url"`
Detail string `json:"detail"`
}
type MessageContent struct {
Type string `json:"type"`
Text string `json:"text"`
ImageURL MessageImage `json:"image_url"`
}
type ContentInterface interface{}
type Message struct { type Message struct {
Role string `json:"role"` Role string `json:"role"`
Content string `json:"content"` // Content string or MessageContent
Content ContentInterface `json:"content"`
Name *string `json:"name,omitempty"` Name *string `json:"name,omitempty"`
} }