Merge branch 'songquanpeng' into sync_upstream

This commit is contained in:
Martial BE 2023-12-25 11:23:28 +08:00
commit 47b72b850f
No known key found for this signature in database
GPG Key ID: D06C32DF0EDB9084
24 changed files with 251 additions and 49 deletions

View File

@ -15,7 +15,22 @@ import (
_ "golang.org/x/image/webp" _ "golang.org/x/image/webp"
) )
// IsImageUrl reports whether the resource at url serves an image, decided by
// the Content-Type header of a HEAD request.
//
// It returns a non-nil error only when the HEAD request itself fails; a
// reachable URL with a non-image Content-Type yields (false, nil).
func IsImageUrl(url string) (bool, error) {
	resp, err := http.Head(url)
	if err != nil {
		return false, err
	}
	// Bug fix: the response body was never closed. Even HEAD responses carry
	// a Body that must be closed so the HTTP transport can reuse the
	// underlying connection instead of leaking it.
	defer resp.Body.Close()
	if !strings.HasPrefix(resp.Header.Get("Content-Type"), "image/") {
		return false, nil
	}
	return true, nil
}
func GetImageSizeFromUrl(url string) (width int, height int, err error) { func GetImageSizeFromUrl(url string) (width int, height int, err error) {
isImage, err := IsImageUrl(url)
if !isImage {
return
}
resp, err := http.Get(url) resp, err := http.Get(url)
if err != nil { if err != nil {
return return
@ -28,6 +43,26 @@ func GetImageSizeFromUrl(url string) (width int, height int, err error) {
return img.Width, img.Height, nil return img.Width, img.Height, nil
} }
// GetImageFromUrl downloads the resource at url — after confirming via
// IsImageUrl that it serves an image — and returns its Content-Type together
// with the response body encoded as standard base64.
//
// On any failure (or when the URL is not an image) the string results are
// empty; err carries the transport error, if one occurred. Note that when
// IsImageUrl fails, its error propagates through the named return.
func GetImageFromUrl(url string) (mimeType string, data string, err error) {
	isImage, err := IsImageUrl(url)
	if !isImage {
		// Covers both "not an image" and "HEAD request failed": in the
		// latter case err is already set and is returned as-is.
		return
	}
	var resp *http.Response
	if resp, err = http.Get(url); err != nil {
		return
	}
	defer resp.Body.Close()
	// Accumulate the body in memory; images handled here are expected to be
	// small enough for a single buffer.
	var body bytes.Buffer
	if _, err = body.ReadFrom(resp.Body); err != nil {
		return
	}
	mimeType = resp.Header.Get("Content-Type")
	data = base64.StdEncoding.EncodeToString(body.Bytes())
	return
}
var ( var (
reg = regexp.MustCompile(`data:image/([^;]+);base64,`) reg = regexp.MustCompile(`data:image/([^;]+);base64,`)
) )

View File

@ -152,3 +152,20 @@ func TestGetImageSize(t *testing.T) {
}) })
} }
} }
// TestGetImageSizeFromBase64 checks that GetImageSizeFromBase64 reports the
// same dimensions for base64-encoded image bytes as the fixture table
// declares for each URL.
//
// NOTE(review): each subtest downloads the image over the network
// (`cases` appears to be a package-level fixture of url/width/height
// entries — confirm against the sibling tests in this file), so this test
// requires outbound connectivity to pass.
func TestGetImageSizeFromBase64(t *testing.T) {
for i, c := range cases {
t.Run("Decode:"+strconv.Itoa(i), func(t *testing.T) {
// Fetch the raw image bytes for this case.
resp, err := http.Get(c.url)
assert.NoError(t, err)
defer resp.Body.Close()
data, err := io.ReadAll(resp.Body)
assert.NoError(t, err)
// Round-trip through base64, mirroring how image payloads arrive
// in chat requests, then measure the decoded image.
encoded := base64.StdEncoding.EncodeToString(data)
width, height, err := img.GetImageSizeFromBase64(encoded)
assert.NoError(t, err)
assert.Equal(t, c.width, width)
assert.Equal(t, c.height, height)
})
}
}

View File

@ -84,6 +84,7 @@ var ModelRatio = map[string]float64{
"Embedding-V1": 0.1429, // ¥0.002 / 1k tokens "Embedding-V1": 0.1429, // ¥0.002 / 1k tokens
"PaLM-2": 1, "PaLM-2": 1,
"gemini-pro": 1, // $0.00025 / 1k characters -> $0.001 / 1k tokens "gemini-pro": 1, // $0.00025 / 1k characters -> $0.001 / 1k tokens
"gemini-pro-vision": 1, // $0.00025 / 1k characters -> $0.001 / 1k tokens
"chatglm_turbo": 0.3572, // ¥0.005 / 1k tokens "chatglm_turbo": 0.3572, // ¥0.005 / 1k tokens
"chatglm_pro": 0.7143, // ¥0.01 / 1k tokens "chatglm_pro": 0.7143, // ¥0.01 / 1k tokens
"chatglm_std": 0.3572, // ¥0.005 / 1k tokens "chatglm_std": 0.3572, // ¥0.005 / 1k tokens
@ -115,6 +116,9 @@ func UpdateModelRatioByJSONString(jsonStr string) error {
} }
func GetModelRatio(name string) float64 { func GetModelRatio(name string) float64 {
if strings.HasPrefix(name, "qwen-") && strings.HasSuffix(name, "-internet") {
name = strings.TrimSuffix(name, "-internet")
}
ratio, ok := ModelRatio[name] ratio, ok := ModelRatio[name]
if !ok { if !ok {
SysError("model ratio not found: " + name) SysError("model ratio not found: " + name)

View File

@ -433,6 +433,15 @@ func init() {
Root: "gemini-pro", Root: "gemini-pro",
Parent: nil, Parent: nil,
}, },
{
Id: "gemini-pro-vision",
Object: "model",
Created: 1677649963,
OwnedBy: "google",
Permission: permission,
Root: "gemini-pro-vision",
Parent: nil,
},
{ {
Id: "chatglm_turbo", Id: "chatglm_turbo",
Object: "model", Object: "model",

4
go.mod
View File

@ -16,7 +16,7 @@ require (
github.com/gorilla/websocket v1.5.0 github.com/gorilla/websocket v1.5.0
github.com/pkoukk/tiktoken-go v0.1.5 github.com/pkoukk/tiktoken-go v0.1.5
github.com/stretchr/testify v1.8.3 github.com/stretchr/testify v1.8.3
golang.org/x/crypto v0.14.0 golang.org/x/crypto v0.17.0
golang.org/x/image v0.14.0 golang.org/x/image v0.14.0
gorm.io/driver/mysql v1.4.3 gorm.io/driver/mysql v1.4.3
gorm.io/driver/postgres v1.5.2 gorm.io/driver/postgres v1.5.2
@ -59,7 +59,7 @@ require (
github.com/ugorji/go/codec v1.2.11 // indirect github.com/ugorji/go/codec v1.2.11 // indirect
golang.org/x/arch v0.3.0 // indirect golang.org/x/arch v0.3.0 // indirect
golang.org/x/net v0.17.0 // indirect golang.org/x/net v0.17.0 // indirect
golang.org/x/sys v0.13.0 // indirect golang.org/x/sys v0.15.0 // indirect
golang.org/x/text v0.14.0 // indirect golang.org/x/text v0.14.0 // indirect
google.golang.org/protobuf v1.30.0 // indirect google.golang.org/protobuf v1.30.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect

8
go.sum
View File

@ -152,8 +152,8 @@ golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUu
golang.org/x/arch v0.3.0 h1:02VY4/ZcO/gBOH6PUaoiptASxtXU10jazRCP865E97k= golang.org/x/arch v0.3.0 h1:02VY4/ZcO/gBOH6PUaoiptASxtXU10jazRCP865E97k=
golang.org/x/arch v0.3.0/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8= golang.org/x/arch v0.3.0/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
golang.org/x/crypto v0.14.0 h1:wBqGXzWJW6m1XrIKlAH0Hs1JJ7+9KBwnIO8v66Q9cHc= golang.org/x/crypto v0.17.0 h1:r8bRNjWL3GshPW3gkd+RpvzWrZAwPS49OmTGZ/uhM4k=
golang.org/x/crypto v0.14.0/go.mod h1:MVFd36DqK4CsrnJYDkBA3VC4m2GkXAM0PvzMCn4JQf4= golang.org/x/crypto v0.17.0/go.mod h1:gCAAfMLgwOJRpTjQ2zCCt2OcSfYMTeZVSRtQlPC7Nq4=
golang.org/x/image v0.14.0 h1:tNgSxAFe3jC4uYqvZdTr84SZoM1KfwdC9SKIFrLjFn4= golang.org/x/image v0.14.0 h1:tNgSxAFe3jC4uYqvZdTr84SZoM1KfwdC9SKIFrLjFn4=
golang.org/x/image v0.14.0/go.mod h1:HUYqC05R2ZcZ3ejNQsIHQDQiwWM4JBqmm6MKANTp4LE= golang.org/x/image v0.14.0/go.mod h1:HUYqC05R2ZcZ3ejNQsIHQDQiwWM4JBqmm6MKANTp4LE=
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
@ -166,8 +166,8 @@ golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBc
golang.org/x/sys v0.0.0-20210806184541-e5e7981a1069/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210806184541-e5e7981a1069/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220704084225-05e143d24a9e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220704084225-05e143d24a9e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.13.0 h1:Af8nKPmuFypiUBjVoU9V20FiaFXOcuZI21p0ycVYYGE= golang.org/x/sys v0.15.0 h1:h48lPFYpsTvQJZF4EKyI4aLHaev3CxivZmv7yZig9pc=
golang.org/x/sys v0.13.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.15.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=

View File

@ -5,6 +5,7 @@ import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"net/http" "net/http"
"one-api/common" "one-api/common"
"runtime/debug"
) )
func RelayPanicRecover() gin.HandlerFunc { func RelayPanicRecover() gin.HandlerFunc {
@ -12,6 +13,7 @@ func RelayPanicRecover() gin.HandlerFunc {
defer func() { defer func() {
if err := recover(); err != nil { if err := recover(); err != nil {
common.SysError(fmt.Sprintf("panic detected: %v", err)) common.SysError(fmt.Sprintf("panic detected: %v", err))
common.SysError(fmt.Sprintf("stacktrace from panic: %s", string(debug.Stack())))
c.JSON(http.StatusInternalServerError, gin.H{ c.JSON(http.StatusInternalServerError, gin.H{
"error": gin.H{ "error": gin.H{
"message": fmt.Sprintf("Panic detected, error: %v. Please submit a issue here: https://github.com/songquanpeng/one-api", err), "message": fmt.Sprintf("Panic detected, error: %v. Please submit a issue here: https://github.com/songquanpeng/one-api", err),

View File

@ -39,6 +39,7 @@ func (aliResponse *AliChatResponse) ResponseHandler(resp *http.Response) (OpenAI
ID: aliResponse.RequestId, ID: aliResponse.RequestId,
Object: "chat.completion", Object: "chat.completion",
Created: common.GetTimestamp(), Created: common.GetTimestamp(),
Model: aliResponse.Model,
Choices: []types.ChatCompletionChoice{choice}, Choices: []types.ChatCompletionChoice{choice},
Usage: &types.Usage{ Usage: &types.Usage{
PromptTokens: aliResponse.Usage.InputTokens, PromptTokens: aliResponse.Usage.InputTokens,
@ -50,6 +51,8 @@ func (aliResponse *AliChatResponse) ResponseHandler(resp *http.Response) (OpenAI
return return
} }
const AliEnableSearchModelSuffix = "-internet"
// 获取聊天请求体 // 获取聊天请求体
func (p *AliProvider) getChatRequestBody(request *types.ChatCompletionRequest) *AliChatRequest { func (p *AliProvider) getChatRequestBody(request *types.ChatCompletionRequest) *AliChatRequest {
messages := make([]AliMessage, 0, len(request.Messages)) messages := make([]AliMessage, 0, len(request.Messages))
@ -60,11 +63,23 @@ func (p *AliProvider) getChatRequestBody(request *types.ChatCompletionRequest) *
Role: strings.ToLower(message.Role), Role: strings.ToLower(message.Role),
}) })
} }
enableSearch := false
aliModel := request.Model
if strings.HasSuffix(aliModel, AliEnableSearchModelSuffix) {
enableSearch = true
aliModel = strings.TrimSuffix(aliModel, AliEnableSearchModelSuffix)
}
return &AliChatRequest{ return &AliChatRequest{
Model: request.Model, Model: aliModel,
Input: AliInput{ Input: AliInput{
Messages: messages, Messages: messages,
}, },
Parameters: AliParameters{
EnableSearch: enableSearch,
IncrementalOutput: request.Stream,
},
} }
} }
@ -86,7 +101,7 @@ func (p *AliProvider) ChatAction(request *types.ChatCompletionRequest, isModelMa
} }
if request.Stream { if request.Stream {
usage, errWithCode = p.sendStreamRequest(req) usage, errWithCode = p.sendStreamRequest(req, request.Model)
if errWithCode != nil { if errWithCode != nil {
return return
} }
@ -100,7 +115,9 @@ func (p *AliProvider) ChatAction(request *types.ChatCompletionRequest, isModelMa
} }
} else { } else {
aliResponse := &AliChatResponse{} aliResponse := &AliChatResponse{
Model: request.Model,
}
errWithCode = p.SendRequest(req, aliResponse, false) errWithCode = p.SendRequest(req, aliResponse, false)
if errWithCode != nil { if errWithCode != nil {
return return
@ -128,14 +145,14 @@ func (p *AliProvider) streamResponseAli2OpenAI(aliResponse *AliChatResponse) *ty
ID: aliResponse.RequestId, ID: aliResponse.RequestId,
Object: "chat.completion.chunk", Object: "chat.completion.chunk",
Created: common.GetTimestamp(), Created: common.GetTimestamp(),
Model: "ernie-bot", Model: aliResponse.Model,
Choices: []types.ChatCompletionStreamChoice{choice}, Choices: []types.ChatCompletionStreamChoice{choice},
} }
return &response return &response
} }
// 发送流请求 // 发送流请求
func (p *AliProvider) sendStreamRequest(req *http.Request) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) { func (p *AliProvider) sendStreamRequest(req *http.Request, model string) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) {
defer req.Body.Close() defer req.Body.Close()
usage = &types.Usage{} usage = &types.Usage{}
@ -181,7 +198,7 @@ func (p *AliProvider) sendStreamRequest(req *http.Request) (usage *types.Usage,
stopChan <- true stopChan <- true
}() }()
common.SetEventStreamHeaders(p.Context) common.SetEventStreamHeaders(p.Context)
lastResponseText := "" // lastResponseText := ""
p.Context.Stream(func(w io.Writer) bool { p.Context.Stream(func(w io.Writer) bool {
select { select {
case data := <-dataChan: case data := <-dataChan:
@ -196,9 +213,10 @@ func (p *AliProvider) sendStreamRequest(req *http.Request) (usage *types.Usage,
usage.CompletionTokens = aliResponse.Usage.OutputTokens usage.CompletionTokens = aliResponse.Usage.OutputTokens
usage.TotalTokens = aliResponse.Usage.InputTokens + aliResponse.Usage.OutputTokens usage.TotalTokens = aliResponse.Usage.InputTokens + aliResponse.Usage.OutputTokens
} }
aliResponse.Model = model
response := p.streamResponseAli2OpenAI(&aliResponse) response := p.streamResponseAli2OpenAI(&aliResponse)
response.Choices[0].Delta.Content = strings.TrimPrefix(response.Choices[0].Delta.Content, lastResponseText) // response.Choices[0].Delta.Content = strings.TrimPrefix(response.Choices[0].Delta.Content, lastResponseText)
lastResponseText = aliResponse.Output.Text // lastResponseText = aliResponse.Output.Text
jsonResponse, err := json.Marshal(response) jsonResponse, err := json.Marshal(response)
if err != nil { if err != nil {
common.SysError("error marshalling stream response: " + err.Error()) common.SysError("error marshalling stream response: " + err.Error())

View File

@ -23,10 +23,11 @@ type AliInput struct {
} }
type AliParameters struct { type AliParameters struct {
TopP float64 `json:"top_p,omitempty"` TopP float64 `json:"top_p,omitempty"`
TopK int `json:"top_k,omitempty"` TopK int `json:"top_k,omitempty"`
Seed uint64 `json:"seed,omitempty"` Seed uint64 `json:"seed,omitempty"`
EnableSearch bool `json:"enable_search,omitempty"` EnableSearch bool `json:"enable_search,omitempty"`
IncrementalOutput bool `json:"incremental_output,omitempty"`
} }
type AliChatRequest struct { type AliChatRequest struct {
@ -43,6 +44,7 @@ type AliOutput struct {
type AliChatResponse struct { type AliChatResponse struct {
Output AliOutput `json:"output"` Output AliOutput `json:"output"`
Usage AliUsage `json:"usage"` Usage AliUsage `json:"usage"`
Model string `json:"model,omitempty"`
AliError AliError
} }

View File

@ -88,13 +88,15 @@ func (p *BaiduProvider) ChatAction(request *types.ChatCompletionRequest, isModel
} }
if request.Stream { if request.Stream {
usage, errWithCode = p.sendStreamRequest(req) usage, errWithCode = p.sendStreamRequest(req, request.Model)
if errWithCode != nil { if errWithCode != nil {
return return
} }
} else { } else {
baiduChatRequest := &BaiduChatResponse{} baiduChatRequest := &BaiduChatResponse{
Model: request.Model,
}
errWithCode = p.SendRequest(req, baiduChatRequest, false) errWithCode = p.SendRequest(req, baiduChatRequest, false)
if errWithCode != nil { if errWithCode != nil {
return return
@ -117,13 +119,13 @@ func (p *BaiduProvider) streamResponseBaidu2OpenAI(baiduResponse *BaiduChatStrea
ID: baiduResponse.Id, ID: baiduResponse.Id,
Object: "chat.completion.chunk", Object: "chat.completion.chunk",
Created: baiduResponse.Created, Created: baiduResponse.Created,
Model: "ernie-bot", Model: baiduResponse.Model,
Choices: []types.ChatCompletionStreamChoice{choice}, Choices: []types.ChatCompletionStreamChoice{choice},
} }
return &response return &response
} }
func (p *BaiduProvider) sendStreamRequest(req *http.Request) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) { func (p *BaiduProvider) sendStreamRequest(req *http.Request, model string) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) {
defer req.Body.Close() defer req.Body.Close()
usage = &types.Usage{} usage = &types.Usage{}
@ -180,6 +182,7 @@ func (p *BaiduProvider) sendStreamRequest(req *http.Request) (usage *types.Usage
usage.PromptTokens = baiduResponse.Usage.PromptTokens usage.PromptTokens = baiduResponse.Usage.PromptTokens
usage.CompletionTokens = baiduResponse.Usage.TotalTokens - baiduResponse.Usage.PromptTokens usage.CompletionTokens = baiduResponse.Usage.TotalTokens - baiduResponse.Usage.PromptTokens
} }
baiduResponse.Model = model
response := p.streamResponseBaidu2OpenAI(&baiduResponse) response := p.streamResponseBaidu2OpenAI(&baiduResponse)
jsonResponse, err := json.Marshal(response) jsonResponse, err := json.Marshal(response)
if err != nil { if err != nil {

View File

@ -32,6 +32,7 @@ type BaiduChatResponse struct {
IsTruncated bool `json:"is_truncated"` IsTruncated bool `json:"is_truncated"`
NeedClearHistory bool `json:"need_clear_history"` NeedClearHistory bool `json:"need_clear_history"`
Usage *types.Usage `json:"usage"` Usage *types.Usage `json:"usage"`
Model string `json:"model,omitempty"`
BaiduError BaiduError
} }

View File

@ -38,6 +38,7 @@ func (claudeResponse *ClaudeResponse) ResponseHandler(resp *http.Response) (Open
Object: "chat.completion", Object: "chat.completion",
Created: common.GetTimestamp(), Created: common.GetTimestamp(),
Choices: []types.ChatCompletionChoice{choice}, Choices: []types.ChatCompletionChoice{choice},
Model: claudeResponse.Model,
} }
completionTokens := common.CountTokenText(claudeResponse.Completion, claudeResponse.Model) completionTokens := common.CountTokenText(claudeResponse.Completion, claudeResponse.Model)

View File

@ -32,7 +32,7 @@ func (p *GeminiProvider) GetFullRequestURL(requestURL string, modelName string)
version = p.Context.GetString("api_version") version = p.Context.GetString("api_version")
} }
return fmt.Sprintf("%s/%s/models/%s:%s?key=%s", baseURL, version, modelName, requestURL, p.Context.GetString("api_key")) return fmt.Sprintf("%s/%s/models/%s:%s", baseURL, version, modelName, requestURL)
} }
@ -40,6 +40,7 @@ func (p *GeminiProvider) GetFullRequestURL(requestURL string, modelName string)
func (p *GeminiProvider) GetRequestHeaders() (headers map[string]string) { func (p *GeminiProvider) GetRequestHeaders() (headers map[string]string) {
headers = make(map[string]string) headers = make(map[string]string)
p.CommonRequestHeaders(headers) p.CommonRequestHeaders(headers)
headers["x-goog-api-key"] = p.Context.GetString("api_key")
return headers return headers
} }

View File

@ -7,11 +7,16 @@ import (
"io" "io"
"net/http" "net/http"
"one-api/common" "one-api/common"
"one-api/common/image"
"one-api/providers/base" "one-api/providers/base"
"one-api/types" "one-api/types"
"strings" "strings"
) )
const (
GeminiVisionMaxImageNum = 16
)
func (response *GeminiChatResponse) ResponseHandler(resp *http.Response) (OpenAIResponse any, errWithCode *types.OpenAIErrorWithStatusCode) { func (response *GeminiChatResponse) ResponseHandler(resp *http.Response) (OpenAIResponse any, errWithCode *types.OpenAIErrorWithStatusCode) {
if len(response.Candidates) == 0 { if len(response.Candidates) == 0 {
return nil, &types.OpenAIErrorWithStatusCode{ return nil, &types.OpenAIErrorWithStatusCode{
@ -29,6 +34,7 @@ func (response *GeminiChatResponse) ResponseHandler(resp *http.Response) (OpenAI
ID: fmt.Sprintf("chatcmpl-%s", common.GetUUID()), ID: fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
Object: "chat.completion", Object: "chat.completion",
Created: common.GetTimestamp(), Created: common.GetTimestamp(),
Model: response.Model,
Choices: make([]types.ChatCompletionChoice, 0, len(response.Candidates)), Choices: make([]types.ChatCompletionChoice, 0, len(response.Candidates)),
} }
for i, candidate := range response.Candidates { for i, candidate := range response.Candidates {
@ -46,7 +52,7 @@ func (response *GeminiChatResponse) ResponseHandler(resp *http.Response) (OpenAI
fullTextResponse.Choices = append(fullTextResponse.Choices, choice) fullTextResponse.Choices = append(fullTextResponse.Choices, choice)
} }
completionTokens := common.CountTokenText(response.GetResponseText(), "gemini-pro") completionTokens := common.CountTokenText(response.GetResponseText(), response.Model)
response.Usage.CompletionTokens = completionTokens response.Usage.CompletionTokens = completionTokens
response.Usage.TotalTokens = response.Usage.PromptTokens + completionTokens response.Usage.TotalTokens = response.Usage.PromptTokens + completionTokens
@ -98,6 +104,31 @@ func (p *GeminiProvider) getChatRequestBody(request *types.ChatCompletionRequest
}, },
}, },
} }
openaiContent := message.ParseContent()
var parts []GeminiPart
imageNum := 0
for _, part := range openaiContent {
if part.Type == types.ContentTypeText {
parts = append(parts, GeminiPart{
Text: part.Text,
})
} else if part.Type == types.ContentTypeImageURL {
imageNum += 1
if imageNum > GeminiVisionMaxImageNum {
continue
}
mimeType, data, _ := image.GetImageFromUrl(part.ImageURL.URL)
parts = append(parts, GeminiPart{
InlineData: &GeminiInlineData{
MimeType: mimeType,
Data: data,
},
})
}
}
content.Parts = parts
// there's no assistant role in gemini and API shall vomit if Role is not user or model // there's no assistant role in gemini and API shall vomit if Role is not user or model
if content.Role == "assistant" { if content.Role == "assistant" {
content.Role = "model" content.Role = "model"
@ -142,7 +173,7 @@ func (p *GeminiProvider) ChatAction(request *types.ChatCompletionRequest, isMode
if request.Stream { if request.Stream {
var responseText string var responseText string
errWithCode, responseText = p.sendStreamRequest(req) errWithCode, responseText = p.sendStreamRequest(req, request.Model)
if errWithCode != nil { if errWithCode != nil {
return return
} }
@ -155,6 +186,7 @@ func (p *GeminiProvider) ChatAction(request *types.ChatCompletionRequest, isMode
} else { } else {
var geminiResponse = &GeminiChatResponse{ var geminiResponse = &GeminiChatResponse{
Model: request.Model,
Usage: &types.Usage{ Usage: &types.Usage{
PromptTokens: promptTokens, PromptTokens: promptTokens,
}, },
@ -170,18 +202,18 @@ func (p *GeminiProvider) ChatAction(request *types.ChatCompletionRequest, isMode
} }
func (p *GeminiProvider) streamResponseClaude2OpenAI(geminiResponse *GeminiChatResponse) *types.ChatCompletionStreamResponse { // func (p *GeminiProvider) streamResponseClaude2OpenAI(geminiResponse *GeminiChatResponse) *types.ChatCompletionStreamResponse {
var choice types.ChatCompletionStreamChoice // var choice types.ChatCompletionStreamChoice
choice.Delta.Content = geminiResponse.GetResponseText() // choice.Delta.Content = geminiResponse.GetResponseText()
choice.FinishReason = &base.StopFinishReason // choice.FinishReason = &base.StopFinishReason
var response types.ChatCompletionStreamResponse // var response types.ChatCompletionStreamResponse
response.Object = "chat.completion.chunk" // response.Object = "chat.completion.chunk"
response.Model = "gemini" // response.Model = "gemini"
response.Choices = []types.ChatCompletionStreamChoice{choice} // response.Choices = []types.ChatCompletionStreamChoice{choice}
return &response // return &response
} // }
func (p *GeminiProvider) sendStreamRequest(req *http.Request) (*types.OpenAIErrorWithStatusCode, string) { func (p *GeminiProvider) sendStreamRequest(req *http.Request, model string) (*types.OpenAIErrorWithStatusCode, string) {
defer req.Body.Close() defer req.Body.Close()
// 发送请求 // 发送请求
@ -235,7 +267,7 @@ func (p *GeminiProvider) sendStreamRequest(req *http.Request) (*types.OpenAIErro
Content string `json:"content"` Content string `json:"content"`
} }
var dummy dummyStruct var dummy dummyStruct
err := json.Unmarshal([]byte(data), &dummy) json.Unmarshal([]byte(data), &dummy)
responseText += dummy.Content responseText += dummy.Content
var choice types.ChatCompletionStreamChoice var choice types.ChatCompletionStreamChoice
choice.Delta.Content = dummy.Content choice.Delta.Content = dummy.Content
@ -243,7 +275,7 @@ func (p *GeminiProvider) sendStreamRequest(req *http.Request) (*types.OpenAIErro
ID: fmt.Sprintf("chatcmpl-%s", common.GetUUID()), ID: fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
Object: "chat.completion.chunk", Object: "chat.completion.chunk",
Created: common.GetTimestamp(), Created: common.GetTimestamp(),
Model: "gemini-pro", Model: model,
Choices: []types.ChatCompletionStreamChoice{choice}, Choices: []types.ChatCompletionStreamChoice{choice},
} }
jsonResponse, err := json.Marshal(response) jsonResponse, err := json.Marshal(response)

View File

@ -46,6 +46,7 @@ type GeminiChatResponse struct {
Candidates []GeminiChatCandidate `json:"candidates"` Candidates []GeminiChatCandidate `json:"candidates"`
PromptFeedback GeminiChatPromptFeedback `json:"promptFeedback"` PromptFeedback GeminiChatPromptFeedback `json:"promptFeedback"`
Usage *types.Usage `json:"usage,omitempty"` Usage *types.Usage `json:"usage,omitempty"`
Model string `json:"model,omitempty"`
} }
type GeminiChatCandidate struct { type GeminiChatCandidate struct {

View File

@ -29,6 +29,7 @@ type PalmProvider struct {
func (p *PalmProvider) GetRequestHeaders() (headers map[string]string) { func (p *PalmProvider) GetRequestHeaders() (headers map[string]string) {
headers = make(map[string]string) headers = make(map[string]string)
p.CommonRequestHeaders(headers) p.CommonRequestHeaders(headers)
headers["x-goog-api-key"] = p.Context.GetString("api_key")
return headers return headers
} }
@ -37,5 +38,5 @@ func (p *PalmProvider) GetRequestHeaders() (headers map[string]string) {
func (p *PalmProvider) GetFullRequestURL(requestURL string, modelName string) string { func (p *PalmProvider) GetFullRequestURL(requestURL string, modelName string) string {
baseURL := strings.TrimSuffix(p.GetBaseURL(), "/") baseURL := strings.TrimSuffix(p.GetBaseURL(), "/")
return fmt.Sprintf("%s%s?key=%s", baseURL, requestURL, p.Context.GetString("api_key")) return fmt.Sprintf("%s%s", baseURL, requestURL)
} }

View File

@ -43,6 +43,7 @@ func (palmResponse *PaLMChatResponse) ResponseHandler(resp *http.Response) (Open
palmResponse.Usage.TotalTokens = palmResponse.Usage.PromptTokens + completionTokens palmResponse.Usage.TotalTokens = palmResponse.Usage.PromptTokens + completionTokens
fullTextResponse.Usage = palmResponse.Usage fullTextResponse.Usage = palmResponse.Usage
fullTextResponse.Model = palmResponse.Model
return fullTextResponse, nil return fullTextResponse, nil
} }

View File

@ -27,6 +27,7 @@ func (TencentResponse *TencentChatResponse) ResponseHandler(resp *http.Response)
Object: "chat.completion", Object: "chat.completion",
Created: common.GetTimestamp(), Created: common.GetTimestamp(),
Usage: TencentResponse.Usage, Usage: TencentResponse.Usage,
Model: TencentResponse.Model,
} }
if len(TencentResponse.Choices) > 0 { if len(TencentResponse.Choices) > 0 {
choice := types.ChatCompletionChoice{ choice := types.ChatCompletionChoice{
@ -100,7 +101,7 @@ func (p *TencentProvider) ChatAction(request *types.ChatCompletionRequest, isMod
if request.Stream { if request.Stream {
var responseText string var responseText string
errWithCode, responseText = p.sendStreamRequest(req) errWithCode, responseText = p.sendStreamRequest(req, request.Model)
if errWithCode != nil { if errWithCode != nil {
return return
} }
@ -112,7 +113,9 @@ func (p *TencentProvider) ChatAction(request *types.ChatCompletionRequest, isMod
usage.TotalTokens = promptTokens + usage.CompletionTokens usage.TotalTokens = promptTokens + usage.CompletionTokens
} else { } else {
tencentResponse := &TencentChatResponse{} tencentResponse := &TencentChatResponse{
Model: request.Model,
}
errWithCode = p.SendRequest(req, tencentResponse, false) errWithCode = p.SendRequest(req, tencentResponse, false)
if errWithCode != nil { if errWithCode != nil {
return return
@ -128,7 +131,7 @@ func (p *TencentProvider) streamResponseTencent2OpenAI(TencentResponse *TencentC
response := types.ChatCompletionStreamResponse{ response := types.ChatCompletionStreamResponse{
Object: "chat.completion.chunk", Object: "chat.completion.chunk",
Created: common.GetTimestamp(), Created: common.GetTimestamp(),
Model: "tencent-hunyuan", Model: TencentResponse.Model,
} }
if len(TencentResponse.Choices) > 0 { if len(TencentResponse.Choices) > 0 {
var choice types.ChatCompletionStreamChoice var choice types.ChatCompletionStreamChoice
@ -141,7 +144,7 @@ func (p *TencentProvider) streamResponseTencent2OpenAI(TencentResponse *TencentC
return &response return &response
} }
func (p *TencentProvider) sendStreamRequest(req *http.Request) (*types.OpenAIErrorWithStatusCode, string) { func (p *TencentProvider) sendStreamRequest(req *http.Request, model string) (*types.OpenAIErrorWithStatusCode, string) {
defer req.Body.Close() defer req.Body.Close()
// 发送请求 // 发送请求
resp, err := common.HttpClient.Do(req) resp, err := common.HttpClient.Do(req)
@ -195,6 +198,7 @@ func (p *TencentProvider) sendStreamRequest(req *http.Request) (*types.OpenAIErr
common.SysError("error unmarshalling stream response: " + err.Error()) common.SysError("error unmarshalling stream response: " + err.Error())
return true return true
} }
TencentResponse.Model = model
response := p.streamResponseTencent2OpenAI(&TencentResponse) response := p.streamResponseTencent2OpenAI(&TencentResponse)
if len(response.Choices) != 0 { if len(response.Choices) != 0 {
responseText += response.Choices[0].Delta.Content responseText += response.Choices[0].Delta.Content

View File

@ -58,4 +58,5 @@ type TencentChatResponse struct {
Error TencentError `json:"error,omitempty"` // 错误信息 注意:此字段可能返回 null表示取不到有效值 Error TencentError `json:"error,omitempty"` // 错误信息 注意:此字段可能返回 null表示取不到有效值
Note string `json:"note,omitempty"` // 注释 Note string `json:"note,omitempty"` // 注释
ReqID string `json:"req_id,omitempty"` // 唯一请求 Id每次请求都会返回。用于反馈接口入参 ReqID string `json:"req_id,omitempty"` // 唯一请求 Id每次请求都会返回。用于反馈接口入参
Model string `json:"model,omitempty"` // 模型名称
} }

View File

@ -28,6 +28,7 @@ func (zhipuResponse *ZhipuResponse) ResponseHandler(resp *http.Response) (OpenAI
ID: zhipuResponse.Data.TaskId, ID: zhipuResponse.Data.TaskId,
Object: "chat.completion", Object: "chat.completion",
Created: common.GetTimestamp(), Created: common.GetTimestamp(),
Model: zhipuResponse.Model,
Choices: make([]types.ChatCompletionChoice, 0, len(zhipuResponse.Data.Choices)), Choices: make([]types.ChatCompletionChoice, 0, len(zhipuResponse.Data.Choices)),
Usage: &zhipuResponse.Data.Usage, Usage: &zhipuResponse.Data.Usage,
} }
@ -94,13 +95,15 @@ func (p *ZhipuProvider) ChatAction(request *types.ChatCompletionRequest, isModel
} }
if request.Stream { if request.Stream {
errWithCode, usage = p.sendStreamRequest(req) errWithCode, usage = p.sendStreamRequest(req, request.Model)
if errWithCode != nil { if errWithCode != nil {
return return
} }
} else { } else {
zhipuResponse := &ZhipuResponse{} zhipuResponse := &ZhipuResponse{
Model: request.Model,
}
errWithCode = p.SendRequest(req, zhipuResponse, false) errWithCode = p.SendRequest(req, zhipuResponse, false)
if errWithCode != nil { if errWithCode != nil {
return return
@ -132,13 +135,13 @@ func (p *ZhipuProvider) streamMetaResponseZhipu2OpenAI(zhipuResponse *ZhipuStrea
ID: zhipuResponse.RequestId, ID: zhipuResponse.RequestId,
Object: "chat.completion.chunk", Object: "chat.completion.chunk",
Created: common.GetTimestamp(), Created: common.GetTimestamp(),
Model: "chatglm", Model: zhipuResponse.Model,
Choices: []types.ChatCompletionStreamChoice{choice}, Choices: []types.ChatCompletionStreamChoice{choice},
} }
return &response, &zhipuResponse.Usage return &response, &zhipuResponse.Usage
} }
func (p *ZhipuProvider) sendStreamRequest(req *http.Request) (*types.OpenAIErrorWithStatusCode, *types.Usage) { func (p *ZhipuProvider) sendStreamRequest(req *http.Request, model string) (*types.OpenAIErrorWithStatusCode, *types.Usage) {
defer req.Body.Close() defer req.Body.Close()
// 发送请求 // 发送请求
@ -159,7 +162,7 @@ func (p *ZhipuProvider) sendStreamRequest(req *http.Request) (*types.OpenAIError
if atEOF && len(data) == 0 { if atEOF && len(data) == 0 {
return 0, nil, nil return 0, nil, nil
} }
if i := strings.Index(string(data), "\n\n"); i >= 0 && strings.Index(string(data), ":") >= 0 { if i := strings.Index(string(data), "\n\n"); i >= 0 && strings.Contains(string(data), ":") {
return i + 2, data[0:i], nil return i + 2, data[0:i], nil
} }
if atEOF { if atEOF {
@ -195,6 +198,7 @@ func (p *ZhipuProvider) sendStreamRequest(req *http.Request) (*types.OpenAIError
select { select {
case data := <-dataChan: case data := <-dataChan:
response := p.streamResponseZhipu2OpenAI(data) response := p.streamResponseZhipu2OpenAI(data)
response.Model = model
jsonResponse, err := json.Marshal(response) jsonResponse, err := json.Marshal(response)
if err != nil { if err != nil {
common.SysError("error marshalling stream response: " + err.Error()) common.SysError("error marshalling stream response: " + err.Error())
@ -209,6 +213,7 @@ func (p *ZhipuProvider) sendStreamRequest(req *http.Request) (*types.OpenAIError
common.SysError("error unmarshalling stream response: " + err.Error()) common.SysError("error unmarshalling stream response: " + err.Error())
return true return true
} }
zhipuResponse.Model = model
response, zhipuUsage := p.streamMetaResponseZhipu2OpenAI(&zhipuResponse) response, zhipuUsage := p.streamMetaResponseZhipu2OpenAI(&zhipuResponse)
jsonResponse, err := json.Marshal(response) jsonResponse, err := json.Marshal(response)
if err != nil { if err != nil {

View File

@ -31,6 +31,7 @@ type ZhipuResponse struct {
Msg string `json:"msg"` Msg string `json:"msg"`
Success bool `json:"success"` Success bool `json:"success"`
Data ZhipuResponseData `json:"data"` Data ZhipuResponseData `json:"data"`
Model string `json:"model,omitempty"`
} }
type ZhipuStreamMetaResponse struct { type ZhipuStreamMetaResponse struct {
@ -38,6 +39,7 @@ type ZhipuStreamMetaResponse struct {
TaskId string `json:"task_id"` TaskId string `json:"task_id"`
TaskStatus string `json:"task_status"` TaskStatus string `json:"task_status"`
types.Usage `json:"usage"` types.Usage `json:"usage"`
Model string `json:"model,omitempty"`
} }
type zhipuTokenData struct { type zhipuTokenData struct {

View File

@ -1,3 +1,9 @@
[//]: # (请按照以下格式关联 issue)
[//]: # (请在提交 PR 前确认所提交的功能可用,附上截图即可,这将有助于项目维护者 review & merge 该 PR谢谢)
[//]: # (项目维护者一般仅在周末处理 PR因此如若未能及时回复希望能理解)
[//]: # (开发者交流群910657413)
[//]: # (请在提交 PR 之前删除上面的注释)
close #issue_number close #issue_number
我已确认该 PR 已自测通过,相关截图如下: 我已确认该 PR 已自测通过,相关截图如下:

View File

@ -1,5 +1,10 @@
package types package types
const (
ContentTypeText = "text"
ContentTypeImageURL = "image_url"
)
type ChatCompletionMessage struct { type ChatCompletionMessage struct {
Role string `json:"role"` Role string `json:"role"`
Content any `json:"content"` Content any `json:"content"`
@ -33,6 +38,47 @@ func (m ChatCompletionMessage) StringContent() string {
return "" return ""
} }
func (m ChatCompletionMessage) ParseContent() []ChatMessagePart {
var contentList []ChatMessagePart
content, ok := m.Content.(string)
if ok {
contentList = append(contentList, ChatMessagePart{
Type: ContentTypeText,
Text: content,
})
return contentList
}
anyList, ok := m.Content.([]any)
if ok {
for _, contentItem := range anyList {
contentMap, ok := contentItem.(map[string]any)
if !ok {
continue
}
switch contentMap["type"] {
case ContentTypeText:
if subStr, ok := contentMap["text"].(string); ok {
contentList = append(contentList, ChatMessagePart{
Type: ContentTypeText,
Text: subStr,
})
}
case ContentTypeImageURL:
if subObj, ok := contentMap["image_url"].(map[string]any); ok {
contentList = append(contentList, ChatMessagePart{
Type: ContentTypeImageURL,
ImageURL: &ChatMessageImageURL{
URL: subObj["url"].(string),
},
})
}
}
}
return contentList
}
return nil
}
type ChatMessageImageURL struct { type ChatMessageImageURL struct {
URL string `json:"url,omitempty"` URL string `json:"url,omitempty"`
Detail string `json:"detail,omitempty"` Detail string `json:"detail,omitempty"`

View File

@ -71,7 +71,17 @@ const typeConfig = {
other: '插件参数' other: '插件参数'
}, },
input: { input: {
models: ['qwen-turbo', 'qwen-plus', 'qwen-max', 'qwen-max-longcontext', 'text-embedding-v1'] models: [
'qwen-turbo',
'qwen-plus',
'qwen-max',
'qwen-max-longcontext',
'text-embedding-v1',
'qwen-turbo-internet',
'qwen-plus-internet',
'qwen-max-internet',
'qwen-max-longcontext-internet'
]
}, },
prompt: { prompt: {
other: '请输入插件参数,即 X-DashScope-Plugin 请求头的取值' other: '请输入插件参数,即 X-DashScope-Plugin 请求头的取值'