Refactored code to handle both string and structured message content
ckt1031 2023-11-17 16:58:24 +08:00
parent 58bb3ab6f6
commit 209d248535
10 changed files with 66 additions and 31 deletions
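The heart of the change is the Message.Content field, which moves from string to an empty interface so a message can carry either plain text or a list of typed content parts (see the final hunk below). A minimal standalone sketch of the pattern follows; the contentToString helper is illustrative only and is not part of this commit, which performs the equivalent type assertions inline in each adapter:

package main

import "fmt"

// Types as introduced by this commit.
type MessageImage struct {
	URL    string `json:"url"`
	Detail string `json:"detail"`
}

type MessageContent struct {
	Type     string       `json:"type"`
	Text     string       `json:"text"`
	ImageURL MessageImage `json:"image_url"`
}

type ContentInterface interface{}

type Message struct {
	Role string `json:"role"`
	// Content is either a string or a []MessageContent
	Content ContentInterface `json:"content"`
	Name    *string          `json:"name,omitempty"`
}

// contentToString flattens either content form to plain text
// (hypothetical helper, shown only for illustration).
func contentToString(content ContentInterface) string {
	switch v := content.(type) {
	case string:
		return v
	case []MessageContent:
		text := ""
		for _, part := range v {
			if part.Type == "text" {
				text += part.Text
			}
		}
		return text
	default:
		return ""
	}
}

func main() {
	plain := Message{Role: "user", Content: "hello"}
	structured := Message{Role: "user", Content: []MessageContent{
		{Type: "text", Text: "describe this image"},
		{Type: "image_url", ImageURL: MessageImage{URL: "https://example.com/cat.png", Detail: "low"}},
	}}
	fmt.Println(contentToString(plain.Content))      // hello
	fmt.Println(contentToString(structured.Content)) // describe this image
}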

View File

@@ -4,12 +4,13 @@ import (
"bufio"
"encoding/json"
"fmt"
"github.com/gin-gonic/gin"
"io"
"net/http"
"one-api/common"
"strconv"
"strings"
"github.com/gin-gonic/gin"
)
// https://docs.aiproxy.io/dev/library#使用已经定制好的知识库进行对话问答
@@ -48,7 +49,7 @@ type AIProxyLibraryStreamResponse struct {
func requestOpenAI2AIProxyLibrary(request GeneralOpenAIRequest) *AIProxyLibraryRequest {
query := ""
if len(request.Messages) != 0 {
query = request.Messages[len(request.Messages)-1].Content
query = request.Messages[len(request.Messages)-1].Content.(string)
}
return &AIProxyLibraryRequest{
Model: request.Model,

View File

@@ -3,11 +3,12 @@ package controller
import (
"bufio"
"encoding/json"
"github.com/gin-gonic/gin"
"io"
"net/http"
"one-api/common"
"strings"
"github.com/gin-gonic/gin"
)
// https://help.aliyun.com/document_detail/613695.html?spm=a2c4g.2399480.0.0.1adb778fAdzP9w#341800c0f8w0r
@@ -88,18 +89,18 @@ func requestOpenAI2Ali(request GeneralOpenAIRequest) *AliChatRequest {
message := request.Messages[i]
if message.Role == "system" {
messages = append(messages, AliMessage{
User: message.Content,
User: message.Content.(string),
Bot: "Okay",
})
continue
} else {
if i == len(request.Messages)-1 {
prompt = message.Content
prompt = message.Content.(string)
break
}
messages = append(messages, AliMessage{
User: message.Content,
Bot: request.Messages[i+1].Content,
User: message.Content.(string),
Bot: request.Messages[i+1].Content.(string),
})
i++
}

View File

@@ -5,13 +5,14 @@ import (
"encoding/json"
"errors"
"fmt"
"github.com/gin-gonic/gin"
"io"
"net/http"
"one-api/common"
"strings"
"sync"
"time"
"github.com/gin-gonic/gin"
)
// https://cloud.baidu.com/doc/WENXINWORKSHOP/s/flfmc9do2
@@ -89,7 +90,7 @@ func requestOpenAI2Baidu(request GeneralOpenAIRequest) *BaiduChatRequest {
if message.Role == "system" {
messages = append(messages, BaiduMessage{
Role: "user",
Content: message.Content,
Content: message.Content.(string),
})
messages = append(messages, BaiduMessage{
Role: "assistant",
@@ -98,7 +99,7 @@ func requestOpenAI2Baidu(request GeneralOpenAIRequest) *BaiduChatRequest {
} else {
messages = append(messages, BaiduMessage{
Role: message.Role,
Content: message.Content,
Content: message.Content.(string),
})
}
}

View File

@@ -4,11 +4,12 @@ import (
"bufio"
"bytes"
"encoding/json"
"github.com/gin-gonic/gin"
"io"
"net/http"
"one-api/common"
"strings"
"github.com/gin-gonic/gin"
)
func openaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*OpenAIErrorWithStatusCode, string) {
@@ -132,7 +133,9 @@ func openaiHandler(c *gin.Context, resp *http.Response, consumeQuota bool, promp
if textResponse.Usage.TotalTokens == 0 {
completionTokens := 0
for _, choice := range textResponse.Choices {
completionTokens += countTokenText(choice.Message.Content, model)
if content, ok := choice.Message.Content.(string); ok {
completionTokens += countTokenText(content, model)
}
}
textResponse.Usage = Usage{
PromptTokens: promptTokens,

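Unlike the adapters above and below, which assert Content.(string) directly, this handler uses the comma-ok form, so completion tokens are only counted when the returned content is a plain string and a structured value is skipped instead of causing a panic. The difference in isolation (a standalone snippet reusing the MessageContent type from the sketch near the top, not code from this commit):

var content interface{} = []MessageContent{{Type: "text", Text: "hi"}}

if s, ok := content.(string); ok {
	// comma-ok form: ok is false here, so the branch is skipped safely
	_ = s
}

// s := content.(string) // direct assertion: panics at runtime because
//                       // content does not hold a string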
View File

@@ -3,10 +3,11 @@ package controller
import (
"encoding/json"
"fmt"
"github.com/gin-gonic/gin"
"io"
"net/http"
"one-api/common"
"github.com/gin-gonic/gin"
)
// https://developers.generativeai.google/api/rest/generativelanguage/models/generateMessage#request-body
@@ -59,7 +60,7 @@ func requestOpenAI2PaLM(textRequest GeneralOpenAIRequest) *PaLMChatRequest {
}
for _, message := range textRequest.Messages {
palmMessage := PaLMChatMessage{
Content: message.Content,
Content: message.Content.(string),
}
if message.Role == "user" {
palmMessage.Author = "0"

View File

@@ -8,13 +8,14 @@ import (
"encoding/json"
"errors"
"fmt"
"github.com/gin-gonic/gin"
"io"
"net/http"
"one-api/common"
"sort"
"strconv"
"strings"
"github.com/gin-gonic/gin"
)
// https://cloud.tencent.com/document/product/1729/97732
@@ -84,7 +85,7 @@ func requestOpenAI2Tencent(request GeneralOpenAIRequest) *TencentChatRequest {
if message.Role == "system" {
messages = append(messages, TencentMessage{
Role: "user",
Content: message.Content,
Content: message.Content.(string),
})
messages = append(messages, TencentMessage{
Role: "assistant",
@@ -93,7 +94,7 @@ func requestOpenAI2Tencent(request GeneralOpenAIRequest) *TencentChatRequest {
continue
}
messages = append(messages, TencentMessage{
Content: message.Content,
Content: message.Content.(string),
Role: message.Role,
})
}

View File

@@ -3,13 +3,14 @@ package controller
import (
"encoding/json"
"fmt"
"github.com/gin-gonic/gin"
"github.com/pkoukk/tiktoken-go"
"io"
"net/http"
"one-api/common"
"strconv"
"strings"
"github.com/gin-gonic/gin"
"github.com/pkoukk/tiktoken-go"
)
var stopFinishReason = "stop"
@@ -84,7 +85,17 @@ func countTokenMessages(messages []Message, model string) int {
tokenNum := 0
for _, message := range messages {
tokenNum += tokensPerMessage
tokenNum += getTokenNum(tokenEncoder, message.Content)
if content, ok := message.Content.(string); ok {
tokenNum += getTokenNum(tokenEncoder, content)
} else if content, ok := message.Content.([]MessageContent); ok {
for _, item := range content {
if item.Type == "text" {
tokenNum += getTokenNum(tokenEncoder, item.Text)
}
}
}
tokenNum += getTokenNum(tokenEncoder, message.Role)
if message.Name != nil {
tokenNum += tokensPerName

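With this hunk, only the text parts of structured content contribute to the prompt token count; image_url parts are ignored by the loop. A rough standalone sketch of that accounting, reusing the Message and MessageContent types from the sketch near the top and substituting a simple word counter for the repo's tokenizer-backed getTokenNum (an assumption made purely for illustration):

// countTokens stands in for getTokenNum(tokenEncoder, text); it simply
// counts whitespace-separated words here (requires importing "strings").
func countTokens(text string) int {
	return len(strings.Fields(text))
}

func countMessageTokens(msg Message, tokensPerMessage int) int {
	n := tokensPerMessage
	switch content := msg.Content.(type) {
	case string:
		n += countTokens(content)
	case []MessageContent:
		for _, part := range content {
			if part.Type == "text" { // image parts add nothing here
				n += countTokens(part.Text)
			}
		}
	}
	n += countTokens(msg.Role)
	return n
}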
View File

@@ -6,14 +6,15 @@ import (
"encoding/base64"
"encoding/json"
"fmt"
"github.com/gin-gonic/gin"
"github.com/gorilla/websocket"
"io"
"net/http"
"net/url"
"one-api/common"
"strings"
"time"
"github.com/gin-gonic/gin"
"github.com/gorilla/websocket"
)
// https://console.xfyun.cn/services/cbm
@@ -81,7 +82,7 @@ func requestOpenAI2Xunfei(request GeneralOpenAIRequest, xunfeiAppId string, doma
if message.Role == "system" {
messages = append(messages, XunfeiMessage{
Role: "user",
Content: message.Content,
Content: message.Content.(string),
})
messages = append(messages, XunfeiMessage{
Role: "assistant",
@@ -90,7 +91,7 @@ func requestOpenAI2Xunfei(request GeneralOpenAIRequest, xunfeiAppId string, doma
} else {
messages = append(messages, XunfeiMessage{
Role: message.Role,
Content: message.Content,
Content: message.Content.(string),
})
}
}

View File

@@ -3,14 +3,15 @@ package controller
import (
"bufio"
"encoding/json"
"github.com/gin-gonic/gin"
"github.com/golang-jwt/jwt"
"io"
"net/http"
"one-api/common"
"strings"
"sync"
"time"
"github.com/gin-gonic/gin"
"github.com/golang-jwt/jwt"
)
// https://open.bigmodel.cn/doc/api#chatglm_std
@@ -114,7 +115,7 @@ func requestOpenAI2Zhipu(request GeneralOpenAIRequest) *ZhipuRequest {
if message.Role == "system" {
messages = append(messages, ZhipuMessage{
Role: "system",
Content: message.Content,
Content: message.Content.(string),
})
messages = append(messages, ZhipuMessage{
Role: "user",
@@ -123,7 +124,7 @@ func requestOpenAI2Zhipu(request GeneralOpenAIRequest) *ZhipuRequest {
} else {
messages = append(messages, ZhipuMessage{
Role: message.Role,
Content: message.Content,
Content: message.Content.(string),
})
}
}

View File

@@ -10,10 +10,24 @@ import (
"github.com/gin-gonic/gin"
)
type MessageImage struct {
URL string `json:"url"`
Detail string `json:"detail"`
}
type MessageContent struct {
Type string `json:"type"`
Text string `json:"text"`
ImageURL MessageImage `json:"image_url"`
}
type ContentInterface interface{}
type Message struct {
Role string `json:"role"`
Content string `json:"content"`
Name *string `json:"name,omitempty"`
Role string `json:"role"`
// Content string or MessageContent
Content ContentInterface `json:"content"`
Name *string `json:"name,omitempty"`
}
const (
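Because Content is now an interface, the same struct marshals both shapes of the chat payload. A quick illustration with made-up values, reusing the types above (requires encoding/json and fmt); note that MessageContent has no omitempty tags, so empty sub-fields appear in the structured output:

plain, _ := json.Marshal(Message{Role: "user", Content: "hello"})
fmt.Println(string(plain))
// {"role":"user","content":"hello"}

structured, _ := json.Marshal(Message{Role: "user", Content: []MessageContent{
	{Type: "text", Text: "what is in this image?"},
	{Type: "image_url", ImageURL: MessageImage{URL: "https://example.com/cat.png", Detail: "low"}},
}})
fmt.Println(string(structured))
// {"role":"user","content":[{"type":"text","text":"what is in this image?","image_url":{"url":"","detail":""}},
//   {"type":"image_url","text":"","image_url":{"url":"https://example.com/cat.png","detail":"low"}}]}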