Refactored code to handle both string and
structured message content
This commit is contained in:
parent
58bb3ab6f6
commit
209d248535
@ -4,12 +4,13 @@ import (
|
||||
"bufio"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// https://docs.aiproxy.io/dev/library#使用已经定制好的知识库进行对话问答
|
||||
@ -48,7 +49,7 @@ type AIProxyLibraryStreamResponse struct {
|
||||
func requestOpenAI2AIProxyLibrary(request GeneralOpenAIRequest) *AIProxyLibraryRequest {
|
||||
query := ""
|
||||
if len(request.Messages) != 0 {
|
||||
query = request.Messages[len(request.Messages)-1].Content
|
||||
query = request.Messages[len(request.Messages)-1].Content.(string)
|
||||
}
|
||||
return &AIProxyLibraryRequest{
|
||||
Model: request.Model,
|
||||
|
@ -3,11 +3,12 @@ package controller
|
||||
import (
|
||||
"bufio"
|
||||
"encoding/json"
|
||||
"github.com/gin-gonic/gin"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// https://help.aliyun.com/document_detail/613695.html?spm=a2c4g.2399480.0.0.1adb778fAdzP9w#341800c0f8w0r
|
||||
@ -88,18 +89,18 @@ func requestOpenAI2Ali(request GeneralOpenAIRequest) *AliChatRequest {
|
||||
message := request.Messages[i]
|
||||
if message.Role == "system" {
|
||||
messages = append(messages, AliMessage{
|
||||
User: message.Content,
|
||||
User: message.Content.(string),
|
||||
Bot: "Okay",
|
||||
})
|
||||
continue
|
||||
} else {
|
||||
if i == len(request.Messages)-1 {
|
||||
prompt = message.Content
|
||||
prompt = message.Content.(string)
|
||||
break
|
||||
}
|
||||
messages = append(messages, AliMessage{
|
||||
User: message.Content,
|
||||
Bot: request.Messages[i+1].Content,
|
||||
User: message.Content.(string),
|
||||
Bot: request.Messages[i+1].Content.(string),
|
||||
})
|
||||
i++
|
||||
}
|
||||
|
@ -5,13 +5,14 @@ import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// https://cloud.baidu.com/doc/WENXINWORKSHOP/s/flfmc9do2
|
||||
@ -89,7 +90,7 @@ func requestOpenAI2Baidu(request GeneralOpenAIRequest) *BaiduChatRequest {
|
||||
if message.Role == "system" {
|
||||
messages = append(messages, BaiduMessage{
|
||||
Role: "user",
|
||||
Content: message.Content,
|
||||
Content: message.Content.(string),
|
||||
})
|
||||
messages = append(messages, BaiduMessage{
|
||||
Role: "assistant",
|
||||
@ -98,7 +99,7 @@ func requestOpenAI2Baidu(request GeneralOpenAIRequest) *BaiduChatRequest {
|
||||
} else {
|
||||
messages = append(messages, BaiduMessage{
|
||||
Role: message.Role,
|
||||
Content: message.Content,
|
||||
Content: message.Content.(string),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
@ -4,11 +4,12 @@ import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"github.com/gin-gonic/gin"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func openaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*OpenAIErrorWithStatusCode, string) {
|
||||
@ -132,7 +133,9 @@ func openaiHandler(c *gin.Context, resp *http.Response, consumeQuota bool, promp
|
||||
if textResponse.Usage.TotalTokens == 0 {
|
||||
completionTokens := 0
|
||||
for _, choice := range textResponse.Choices {
|
||||
completionTokens += countTokenText(choice.Message.Content, model)
|
||||
if content, ok := choice.Message.Content.(string); ok {
|
||||
completionTokens += countTokenText(content, model)
|
||||
}
|
||||
}
|
||||
textResponse.Usage = Usage{
|
||||
PromptTokens: promptTokens,
|
||||
|
@ -3,10 +3,11 @@ package controller
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// https://developers.generativeai.google/api/rest/generativelanguage/models/generateMessage#request-body
|
||||
@ -59,7 +60,7 @@ func requestOpenAI2PaLM(textRequest GeneralOpenAIRequest) *PaLMChatRequest {
|
||||
}
|
||||
for _, message := range textRequest.Messages {
|
||||
palmMessage := PaLMChatMessage{
|
||||
Content: message.Content,
|
||||
Content: message.Content.(string),
|
||||
}
|
||||
if message.Role == "user" {
|
||||
palmMessage.Author = "0"
|
||||
|
@ -8,13 +8,14 @@ import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// https://cloud.tencent.com/document/product/1729/97732
|
||||
@ -84,7 +85,7 @@ func requestOpenAI2Tencent(request GeneralOpenAIRequest) *TencentChatRequest {
|
||||
if message.Role == "system" {
|
||||
messages = append(messages, TencentMessage{
|
||||
Role: "user",
|
||||
Content: message.Content,
|
||||
Content: message.Content.(string),
|
||||
})
|
||||
messages = append(messages, TencentMessage{
|
||||
Role: "assistant",
|
||||
@ -93,7 +94,7 @@ func requestOpenAI2Tencent(request GeneralOpenAIRequest) *TencentChatRequest {
|
||||
continue
|
||||
}
|
||||
messages = append(messages, TencentMessage{
|
||||
Content: message.Content,
|
||||
Content: message.Content.(string),
|
||||
Role: message.Role,
|
||||
})
|
||||
}
|
||||
|
@ -3,13 +3,14 @@ package controller
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/pkoukk/tiktoken-go"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/pkoukk/tiktoken-go"
|
||||
)
|
||||
|
||||
var stopFinishReason = "stop"
|
||||
@ -84,7 +85,17 @@ func countTokenMessages(messages []Message, model string) int {
|
||||
tokenNum := 0
|
||||
for _, message := range messages {
|
||||
tokenNum += tokensPerMessage
|
||||
tokenNum += getTokenNum(tokenEncoder, message.Content)
|
||||
|
||||
if content, ok := message.Content.(string); ok {
|
||||
tokenNum += getTokenNum(tokenEncoder, content)
|
||||
} else if content, ok := message.Content.([]MessageContent); ok {
|
||||
for _, item := range content {
|
||||
if item.Type == "text" {
|
||||
tokenNum += getTokenNum(tokenEncoder, item.Text)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
tokenNum += getTokenNum(tokenEncoder, message.Role)
|
||||
if message.Name != nil {
|
||||
tokenNum += tokensPerName
|
||||
|
@ -6,14 +6,15 @@ import (
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/gorilla/websocket"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"one-api/common"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/gorilla/websocket"
|
||||
)
|
||||
|
||||
// https://console.xfyun.cn/services/cbm
|
||||
@ -81,7 +82,7 @@ func requestOpenAI2Xunfei(request GeneralOpenAIRequest, xunfeiAppId string, doma
|
||||
if message.Role == "system" {
|
||||
messages = append(messages, XunfeiMessage{
|
||||
Role: "user",
|
||||
Content: message.Content,
|
||||
Content: message.Content.(string),
|
||||
})
|
||||
messages = append(messages, XunfeiMessage{
|
||||
Role: "assistant",
|
||||
@ -90,7 +91,7 @@ func requestOpenAI2Xunfei(request GeneralOpenAIRequest, xunfeiAppId string, doma
|
||||
} else {
|
||||
messages = append(messages, XunfeiMessage{
|
||||
Role: message.Role,
|
||||
Content: message.Content,
|
||||
Content: message.Content.(string),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
@ -3,14 +3,15 @@ package controller
|
||||
import (
|
||||
"bufio"
|
||||
"encoding/json"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/golang-jwt/jwt"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/golang-jwt/jwt"
|
||||
)
|
||||
|
||||
// https://open.bigmodel.cn/doc/api#chatglm_std
|
||||
@ -114,7 +115,7 @@ func requestOpenAI2Zhipu(request GeneralOpenAIRequest) *ZhipuRequest {
|
||||
if message.Role == "system" {
|
||||
messages = append(messages, ZhipuMessage{
|
||||
Role: "system",
|
||||
Content: message.Content,
|
||||
Content: message.Content.(string),
|
||||
})
|
||||
messages = append(messages, ZhipuMessage{
|
||||
Role: "user",
|
||||
@ -123,7 +124,7 @@ func requestOpenAI2Zhipu(request GeneralOpenAIRequest) *ZhipuRequest {
|
||||
} else {
|
||||
messages = append(messages, ZhipuMessage{
|
||||
Role: message.Role,
|
||||
Content: message.Content,
|
||||
Content: message.Content.(string),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
@ -10,10 +10,24 @@ import (
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type MessageImage struct {
|
||||
URL string `json:"url"`
|
||||
Detail string `json:"detail"`
|
||||
}
|
||||
|
||||
type MessageContent struct {
|
||||
Type string `json:"type"`
|
||||
Text string `json:"text"`
|
||||
ImageURL MessageImage `json:"image_url"`
|
||||
}
|
||||
|
||||
type ContentInterface interface{}
|
||||
|
||||
type Message struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
Name *string `json:"name,omitempty"`
|
||||
Role string `json:"role"`
|
||||
// Content string or MessageContent
|
||||
Content ContentInterface `json:"content"`
|
||||
Name *string `json:"name,omitempty"`
|
||||
}
|
||||
|
||||
const (
|
||||
|
Loading…
Reference in New Issue
Block a user