diff --git a/common/image/image.go b/common/image/image.go index 93da6a06..a602936a 100644 --- a/common/image/image.go +++ b/common/image/image.go @@ -15,7 +15,22 @@ import ( _ "golang.org/x/image/webp" ) +func IsImageUrl(url string) (bool, error) { + resp, err := http.Head(url) + if err != nil { + return false, err + } + if !strings.HasPrefix(resp.Header.Get("Content-Type"), "image/") { + return false, nil + } + return true, nil +} + func GetImageSizeFromUrl(url string) (width int, height int, err error) { + isImage, err := IsImageUrl(url) + if !isImage { + return + } resp, err := http.Get(url) if err != nil { return @@ -28,6 +43,26 @@ func GetImageSizeFromUrl(url string) (width int, height int, err error) { return img.Width, img.Height, nil } +func GetImageFromUrl(url string) (mimeType string, data string, err error) { + isImage, err := IsImageUrl(url) + if !isImage { + return + } + resp, err := http.Get(url) + if err != nil { + return + } + defer resp.Body.Close() + buffer := bytes.NewBuffer(nil) + _, err = buffer.ReadFrom(resp.Body) + if err != nil { + return + } + mimeType = resp.Header.Get("Content-Type") + data = base64.StdEncoding.EncodeToString(buffer.Bytes()) + return +} + var ( reg = regexp.MustCompile(`data:image/([^;]+);base64,`) ) diff --git a/common/image/image_test.go b/common/image/image_test.go index 366eda6e..8e47b109 100644 --- a/common/image/image_test.go +++ b/common/image/image_test.go @@ -152,3 +152,20 @@ func TestGetImageSize(t *testing.T) { }) } } + +func TestGetImageSizeFromBase64(t *testing.T) { + for i, c := range cases { + t.Run("Decode:"+strconv.Itoa(i), func(t *testing.T) { + resp, err := http.Get(c.url) + assert.NoError(t, err) + defer resp.Body.Close() + data, err := io.ReadAll(resp.Body) + assert.NoError(t, err) + encoded := base64.StdEncoding.EncodeToString(data) + width, height, err := img.GetImageSizeFromBase64(encoded) + assert.NoError(t, err) + assert.Equal(t, c.width, width) + assert.Equal(t, c.height, height) + }) + } +} diff --git a/common/model-ratio.go b/common/model-ratio.go index d1c96d96..fa2adaa1 100644 --- a/common/model-ratio.go +++ b/common/model-ratio.go @@ -84,6 +84,7 @@ var ModelRatio = map[string]float64{ "Embedding-V1": 0.1429, // ¥0.002 / 1k tokens "PaLM-2": 1, "gemini-pro": 1, // $0.00025 / 1k characters -> $0.001 / 1k tokens + "gemini-pro-vision": 1, // $0.00025 / 1k characters -> $0.001 / 1k tokens "chatglm_turbo": 0.3572, // ¥0.005 / 1k tokens "chatglm_pro": 0.7143, // ¥0.01 / 1k tokens "chatglm_std": 0.3572, // ¥0.005 / 1k tokens @@ -115,6 +116,9 @@ func UpdateModelRatioByJSONString(jsonStr string) error { } func GetModelRatio(name string) float64 { + if strings.HasPrefix(name, "qwen-") && strings.HasSuffix(name, "-internet") { + name = strings.TrimSuffix(name, "-internet") + } ratio, ok := ModelRatio[name] if !ok { SysError("model ratio not found: " + name) diff --git a/controller/model.go b/controller/model.go index 0a78fd17..31dc06d1 100644 --- a/controller/model.go +++ b/controller/model.go @@ -433,6 +433,15 @@ func init() { Root: "gemini-pro", Parent: nil, }, + { + Id: "gemini-pro-vision", + Object: "model", + Created: 1677649963, + OwnedBy: "google", + Permission: permission, + Root: "gemini-pro-vision", + Parent: nil, + }, { Id: "chatglm_turbo", Object: "model", diff --git a/go.mod b/go.mod index 19b5b72d..81a59a52 100644 --- a/go.mod +++ b/go.mod @@ -16,7 +16,7 @@ require ( github.com/gorilla/websocket v1.5.0 github.com/pkoukk/tiktoken-go v0.1.5 github.com/stretchr/testify v1.8.3 - golang.org/x/crypto v0.14.0 + golang.org/x/crypto v0.17.0 golang.org/x/image v0.14.0 gorm.io/driver/mysql v1.4.3 gorm.io/driver/postgres v1.5.2 @@ -59,7 +59,7 @@ require ( github.com/ugorji/go/codec v1.2.11 // indirect golang.org/x/arch v0.3.0 // indirect golang.org/x/net v0.17.0 // indirect - golang.org/x/sys v0.13.0 // indirect + golang.org/x/sys v0.15.0 // indirect golang.org/x/text v0.14.0 // indirect google.golang.org/protobuf v1.30.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect diff --git a/go.sum b/go.sum index a37fd320..b0c52c8c 100644 --- a/go.sum +++ b/go.sum @@ -152,8 +152,8 @@ golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUu golang.org/x/arch v0.3.0 h1:02VY4/ZcO/gBOH6PUaoiptASxtXU10jazRCP865E97k= golang.org/x/arch v0.3.0/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8= golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= -golang.org/x/crypto v0.14.0 h1:wBqGXzWJW6m1XrIKlAH0Hs1JJ7+9KBwnIO8v66Q9cHc= -golang.org/x/crypto v0.14.0/go.mod h1:MVFd36DqK4CsrnJYDkBA3VC4m2GkXAM0PvzMCn4JQf4= +golang.org/x/crypto v0.17.0 h1:r8bRNjWL3GshPW3gkd+RpvzWrZAwPS49OmTGZ/uhM4k= +golang.org/x/crypto v0.17.0/go.mod h1:gCAAfMLgwOJRpTjQ2zCCt2OcSfYMTeZVSRtQlPC7Nq4= golang.org/x/image v0.14.0 h1:tNgSxAFe3jC4uYqvZdTr84SZoM1KfwdC9SKIFrLjFn4= golang.org/x/image v0.14.0/go.mod h1:HUYqC05R2ZcZ3ejNQsIHQDQiwWM4JBqmm6MKANTp4LE= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= @@ -166,8 +166,8 @@ golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.0.0-20210806184541-e5e7981a1069/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220704084225-05e143d24a9e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.13.0 h1:Af8nKPmuFypiUBjVoU9V20FiaFXOcuZI21p0ycVYYGE= -golang.org/x/sys v0.13.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.15.0 h1:h48lPFYpsTvQJZF4EKyI4aLHaev3CxivZmv7yZig9pc= +golang.org/x/sys v0.15.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= diff --git a/middleware/recover.go b/middleware/recover.go index c3a3d748..8338a514 100644 --- a/middleware/recover.go +++ b/middleware/recover.go @@ -5,6 +5,7 @@ import ( "github.com/gin-gonic/gin" "net/http" "one-api/common" + "runtime/debug" ) func RelayPanicRecover() gin.HandlerFunc { @@ -12,6 +13,7 @@ func RelayPanicRecover() gin.HandlerFunc { defer func() { if err := recover(); err != nil { common.SysError(fmt.Sprintf("panic detected: %v", err)) + common.SysError(fmt.Sprintf("stacktrace from panic: %s", string(debug.Stack()))) c.JSON(http.StatusInternalServerError, gin.H{ "error": gin.H{ "message": fmt.Sprintf("Panic detected, error: %v. Please submit a issue here: https://github.com/songquanpeng/one-api", err), diff --git a/providers/ali/chat.go b/providers/ali/chat.go index 68d8376f..723d8c33 100644 --- a/providers/ali/chat.go +++ b/providers/ali/chat.go @@ -39,6 +39,7 @@ func (aliResponse *AliChatResponse) ResponseHandler(resp *http.Response) (OpenAI ID: aliResponse.RequestId, Object: "chat.completion", Created: common.GetTimestamp(), + Model: aliResponse.Model, Choices: []types.ChatCompletionChoice{choice}, Usage: &types.Usage{ PromptTokens: aliResponse.Usage.InputTokens, @@ -50,6 +51,8 @@ func (aliResponse *AliChatResponse) ResponseHandler(resp *http.Response) (OpenAI return } +const AliEnableSearchModelSuffix = "-internet" + // 获取聊天请求体 func (p *AliProvider) getChatRequestBody(request *types.ChatCompletionRequest) *AliChatRequest { messages := make([]AliMessage, 0, len(request.Messages)) @@ -60,11 +63,23 @@ func (p *AliProvider) getChatRequestBody(request *types.ChatCompletionRequest) * Role: strings.ToLower(message.Role), }) } + + enableSearch := false + aliModel := request.Model + if strings.HasSuffix(aliModel, AliEnableSearchModelSuffix) { + enableSearch = true + aliModel = strings.TrimSuffix(aliModel, AliEnableSearchModelSuffix) + } + return &AliChatRequest{ - Model: request.Model, + Model: aliModel, Input: AliInput{ Messages: messages, }, + Parameters: AliParameters{ + EnableSearch: enableSearch, + IncrementalOutput: request.Stream, + }, } } @@ -86,7 +101,7 @@ func (p *AliProvider) ChatAction(request *types.ChatCompletionRequest, isModelMa } if request.Stream { - usage, errWithCode = p.sendStreamRequest(req) + usage, errWithCode = p.sendStreamRequest(req, request.Model) if errWithCode != nil { return } @@ -100,7 +115,9 @@ func (p *AliProvider) ChatAction(request *types.ChatCompletionRequest, isModelMa } } else { - aliResponse := &AliChatResponse{} + aliResponse := &AliChatResponse{ + Model: request.Model, + } errWithCode = p.SendRequest(req, aliResponse, false) if errWithCode != nil { return @@ -128,14 +145,14 @@ func (p *AliProvider) streamResponseAli2OpenAI(aliResponse *AliChatResponse) *ty ID: aliResponse.RequestId, Object: "chat.completion.chunk", Created: common.GetTimestamp(), - Model: "ernie-bot", + Model: aliResponse.Model, Choices: []types.ChatCompletionStreamChoice{choice}, } return &response } // 发送流请求 -func (p *AliProvider) sendStreamRequest(req *http.Request) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) { +func (p *AliProvider) sendStreamRequest(req *http.Request, model string) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) { defer req.Body.Close() usage = &types.Usage{} @@ -181,7 +198,7 @@ func (p *AliProvider) sendStreamRequest(req *http.Request) (usage *types.Usage, stopChan <- true }() common.SetEventStreamHeaders(p.Context) - lastResponseText := "" + // lastResponseText := "" p.Context.Stream(func(w io.Writer) bool { select { case data := <-dataChan: @@ -196,9 +213,10 @@ func (p *AliProvider) sendStreamRequest(req *http.Request) (usage *types.Usage, usage.CompletionTokens = aliResponse.Usage.OutputTokens usage.TotalTokens = aliResponse.Usage.InputTokens + aliResponse.Usage.OutputTokens } + aliResponse.Model = model response := p.streamResponseAli2OpenAI(&aliResponse) - response.Choices[0].Delta.Content = strings.TrimPrefix(response.Choices[0].Delta.Content, lastResponseText) - lastResponseText = aliResponse.Output.Text + // response.Choices[0].Delta.Content = strings.TrimPrefix(response.Choices[0].Delta.Content, lastResponseText) + // lastResponseText = aliResponse.Output.Text jsonResponse, err := json.Marshal(response) if err != nil { common.SysError("error marshalling stream response: " + err.Error()) diff --git a/providers/ali/type.go b/providers/ali/type.go index da24dcb3..6e85cc2c 100644 --- a/providers/ali/type.go +++ b/providers/ali/type.go @@ -23,10 +23,11 @@ type AliInput struct { } type AliParameters struct { - TopP float64 `json:"top_p,omitempty"` - TopK int `json:"top_k,omitempty"` - Seed uint64 `json:"seed,omitempty"` - EnableSearch bool `json:"enable_search,omitempty"` + TopP float64 `json:"top_p,omitempty"` + TopK int `json:"top_k,omitempty"` + Seed uint64 `json:"seed,omitempty"` + EnableSearch bool `json:"enable_search,omitempty"` + IncrementalOutput bool `json:"incremental_output,omitempty"` } type AliChatRequest struct { @@ -43,6 +44,7 @@ type AliOutput struct { type AliChatResponse struct { Output AliOutput `json:"output"` Usage AliUsage `json:"usage"` + Model string `json:"model,omitempty"` AliError } diff --git a/providers/baidu/chat.go b/providers/baidu/chat.go index 521b6d4d..0c424b66 100644 --- a/providers/baidu/chat.go +++ b/providers/baidu/chat.go @@ -88,13 +88,15 @@ func (p *BaiduProvider) ChatAction(request *types.ChatCompletionRequest, isModel } if request.Stream { - usage, errWithCode = p.sendStreamRequest(req) + usage, errWithCode = p.sendStreamRequest(req, request.Model) if errWithCode != nil { return } } else { - baiduChatRequest := &BaiduChatResponse{} + baiduChatRequest := &BaiduChatResponse{ + Model: request.Model, + } errWithCode = p.SendRequest(req, baiduChatRequest, false) if errWithCode != nil { return @@ -117,13 +119,13 @@ func (p *BaiduProvider) streamResponseBaidu2OpenAI(baiduResponse *BaiduChatStrea ID: baiduResponse.Id, Object: "chat.completion.chunk", Created: baiduResponse.Created, - Model: "ernie-bot", + Model: baiduResponse.Model, Choices: []types.ChatCompletionStreamChoice{choice}, } return &response } -func (p *BaiduProvider) sendStreamRequest(req *http.Request) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) { +func (p *BaiduProvider) sendStreamRequest(req *http.Request, model string) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) { defer req.Body.Close() usage = &types.Usage{} @@ -180,6 +182,7 @@ func (p *BaiduProvider) sendStreamRequest(req *http.Request) (usage *types.Usage usage.PromptTokens = baiduResponse.Usage.PromptTokens usage.CompletionTokens = baiduResponse.Usage.TotalTokens - baiduResponse.Usage.PromptTokens } + baiduResponse.Model = model response := p.streamResponseBaidu2OpenAI(&baiduResponse) jsonResponse, err := json.Marshal(response) if err != nil { diff --git a/providers/baidu/type.go b/providers/baidu/type.go index 30e6f4c4..b2e0f1e8 100644 --- a/providers/baidu/type.go +++ b/providers/baidu/type.go @@ -32,6 +32,7 @@ type BaiduChatResponse struct { IsTruncated bool `json:"is_truncated"` NeedClearHistory bool `json:"need_clear_history"` Usage *types.Usage `json:"usage"` + Model string `json:"model,omitempty"` BaiduError } diff --git a/providers/claude/chat.go b/providers/claude/chat.go index aaf8b3f7..02f309b0 100644 --- a/providers/claude/chat.go +++ b/providers/claude/chat.go @@ -38,6 +38,7 @@ func (claudeResponse *ClaudeResponse) ResponseHandler(resp *http.Response) (Open Object: "chat.completion", Created: common.GetTimestamp(), Choices: []types.ChatCompletionChoice{choice}, + Model: claudeResponse.Model, } completionTokens := common.CountTokenText(claudeResponse.Completion, claudeResponse.Model) diff --git a/providers/gemini/base.go b/providers/gemini/base.go index 0e448cd6..43fb3d7a 100644 --- a/providers/gemini/base.go +++ b/providers/gemini/base.go @@ -32,7 +32,7 @@ func (p *GeminiProvider) GetFullRequestURL(requestURL string, modelName string) version = p.Context.GetString("api_version") } - return fmt.Sprintf("%s/%s/models/%s:%s?key=%s", baseURL, version, modelName, requestURL, p.Context.GetString("api_key")) + return fmt.Sprintf("%s/%s/models/%s:%s", baseURL, version, modelName, requestURL) } @@ -40,6 +40,7 @@ func (p *GeminiProvider) GetFullRequestURL(requestURL string, modelName string) func (p *GeminiProvider) GetRequestHeaders() (headers map[string]string) { headers = make(map[string]string) p.CommonRequestHeaders(headers) + headers["x-goog-api-key"] = p.Context.GetString("api_key") return headers } diff --git a/providers/gemini/chat.go b/providers/gemini/chat.go index 2f893e60..721d4475 100644 --- a/providers/gemini/chat.go +++ b/providers/gemini/chat.go @@ -7,11 +7,16 @@ import ( "io" "net/http" "one-api/common" + "one-api/common/image" "one-api/providers/base" "one-api/types" "strings" ) +const ( + GeminiVisionMaxImageNum = 16 +) + func (response *GeminiChatResponse) ResponseHandler(resp *http.Response) (OpenAIResponse any, errWithCode *types.OpenAIErrorWithStatusCode) { if len(response.Candidates) == 0 { return nil, &types.OpenAIErrorWithStatusCode{ @@ -29,6 +34,7 @@ func (response *GeminiChatResponse) ResponseHandler(resp *http.Response) (OpenAI ID: fmt.Sprintf("chatcmpl-%s", common.GetUUID()), Object: "chat.completion", Created: common.GetTimestamp(), + Model: response.Model, Choices: make([]types.ChatCompletionChoice, 0, len(response.Candidates)), } for i, candidate := range response.Candidates { @@ -46,7 +52,7 @@ func (response *GeminiChatResponse) ResponseHandler(resp *http.Response) (OpenAI fullTextResponse.Choices = append(fullTextResponse.Choices, choice) } - completionTokens := common.CountTokenText(response.GetResponseText(), "gemini-pro") + completionTokens := common.CountTokenText(response.GetResponseText(), response.Model) response.Usage.CompletionTokens = completionTokens response.Usage.TotalTokens = response.Usage.PromptTokens + completionTokens @@ -98,6 +104,31 @@ func (p *GeminiProvider) getChatRequestBody(request *types.ChatCompletionRequest }, }, } + + openaiContent := message.ParseContent() + var parts []GeminiPart + imageNum := 0 + for _, part := range openaiContent { + if part.Type == types.ContentTypeText { + parts = append(parts, GeminiPart{ + Text: part.Text, + }) + } else if part.Type == types.ContentTypeImageURL { + imageNum += 1 + if imageNum > GeminiVisionMaxImageNum { + continue + } + mimeType, data, _ := image.GetImageFromUrl(part.ImageURL.URL) + parts = append(parts, GeminiPart{ + InlineData: &GeminiInlineData{ + MimeType: mimeType, + Data: data, + }, + }) + } + } + content.Parts = parts + // there's no assistant role in gemini and API shall vomit if Role is not user or model if content.Role == "assistant" { content.Role = "model" @@ -142,7 +173,7 @@ func (p *GeminiProvider) ChatAction(request *types.ChatCompletionRequest, isMode if request.Stream { var responseText string - errWithCode, responseText = p.sendStreamRequest(req) + errWithCode, responseText = p.sendStreamRequest(req, request.Model) if errWithCode != nil { return } @@ -155,6 +186,7 @@ func (p *GeminiProvider) ChatAction(request *types.ChatCompletionRequest, isMode } else { var geminiResponse = &GeminiChatResponse{ + Model: request.Model, Usage: &types.Usage{ PromptTokens: promptTokens, }, @@ -170,18 +202,18 @@ func (p *GeminiProvider) ChatAction(request *types.ChatCompletionRequest, isMode } -func (p *GeminiProvider) streamResponseClaude2OpenAI(geminiResponse *GeminiChatResponse) *types.ChatCompletionStreamResponse { - var choice types.ChatCompletionStreamChoice - choice.Delta.Content = geminiResponse.GetResponseText() - choice.FinishReason = &base.StopFinishReason - var response types.ChatCompletionStreamResponse - response.Object = "chat.completion.chunk" - response.Model = "gemini" - response.Choices = []types.ChatCompletionStreamChoice{choice} - return &response -} +// func (p *GeminiProvider) streamResponseClaude2OpenAI(geminiResponse *GeminiChatResponse) *types.ChatCompletionStreamResponse { +// var choice types.ChatCompletionStreamChoice +// choice.Delta.Content = geminiResponse.GetResponseText() +// choice.FinishReason = &base.StopFinishReason +// var response types.ChatCompletionStreamResponse +// response.Object = "chat.completion.chunk" +// response.Model = "gemini" +// response.Choices = []types.ChatCompletionStreamChoice{choice} +// return &response +// } -func (p *GeminiProvider) sendStreamRequest(req *http.Request) (*types.OpenAIErrorWithStatusCode, string) { +func (p *GeminiProvider) sendStreamRequest(req *http.Request, model string) (*types.OpenAIErrorWithStatusCode, string) { defer req.Body.Close() // 发送请求 @@ -235,7 +267,7 @@ func (p *GeminiProvider) sendStreamRequest(req *http.Request) (*types.OpenAIErro Content string `json:"content"` } var dummy dummyStruct - err := json.Unmarshal([]byte(data), &dummy) + json.Unmarshal([]byte(data), &dummy) responseText += dummy.Content var choice types.ChatCompletionStreamChoice choice.Delta.Content = dummy.Content @@ -243,7 +275,7 @@ func (p *GeminiProvider) sendStreamRequest(req *http.Request) (*types.OpenAIErro ID: fmt.Sprintf("chatcmpl-%s", common.GetUUID()), Object: "chat.completion.chunk", Created: common.GetTimestamp(), - Model: "gemini-pro", + Model: model, Choices: []types.ChatCompletionStreamChoice{choice}, } jsonResponse, err := json.Marshal(response) diff --git a/providers/gemini/type.go b/providers/gemini/type.go index 333dfcc7..f6476e11 100644 --- a/providers/gemini/type.go +++ b/providers/gemini/type.go @@ -46,6 +46,7 @@ type GeminiChatResponse struct { Candidates []GeminiChatCandidate `json:"candidates"` PromptFeedback GeminiChatPromptFeedback `json:"promptFeedback"` Usage *types.Usage `json:"usage,omitempty"` + Model string `json:"model,omitempty"` } type GeminiChatCandidate struct { diff --git a/providers/palm/base.go b/providers/palm/base.go index aba4fb72..8f89bb72 100644 --- a/providers/palm/base.go +++ b/providers/palm/base.go @@ -29,6 +29,7 @@ type PalmProvider struct { func (p *PalmProvider) GetRequestHeaders() (headers map[string]string) { headers = make(map[string]string) p.CommonRequestHeaders(headers) + headers["x-goog-api-key"] = p.Context.GetString("api_key") return headers } @@ -37,5 +38,5 @@ func (p *PalmProvider) GetRequestHeaders() (headers map[string]string) { func (p *PalmProvider) GetFullRequestURL(requestURL string, modelName string) string { baseURL := strings.TrimSuffix(p.GetBaseURL(), "/") - return fmt.Sprintf("%s%s?key=%s", baseURL, requestURL, p.Context.GetString("api_key")) + return fmt.Sprintf("%s%s", baseURL, requestURL) } diff --git a/providers/palm/chat.go b/providers/palm/chat.go index 7ec6dd06..67aa52e0 100644 --- a/providers/palm/chat.go +++ b/providers/palm/chat.go @@ -43,6 +43,7 @@ func (palmResponse *PaLMChatResponse) ResponseHandler(resp *http.Response) (Open palmResponse.Usage.TotalTokens = palmResponse.Usage.PromptTokens + completionTokens fullTextResponse.Usage = palmResponse.Usage + fullTextResponse.Model = palmResponse.Model return fullTextResponse, nil } diff --git a/providers/tencent/chat.go b/providers/tencent/chat.go index 10e77e6f..339a5a39 100644 --- a/providers/tencent/chat.go +++ b/providers/tencent/chat.go @@ -27,6 +27,7 @@ func (TencentResponse *TencentChatResponse) ResponseHandler(resp *http.Response) Object: "chat.completion", Created: common.GetTimestamp(), Usage: TencentResponse.Usage, + Model: TencentResponse.Model, } if len(TencentResponse.Choices) > 0 { choice := types.ChatCompletionChoice{ @@ -100,7 +101,7 @@ func (p *TencentProvider) ChatAction(request *types.ChatCompletionRequest, isMod if request.Stream { var responseText string - errWithCode, responseText = p.sendStreamRequest(req) + errWithCode, responseText = p.sendStreamRequest(req, request.Model) if errWithCode != nil { return } @@ -112,7 +113,9 @@ func (p *TencentProvider) ChatAction(request *types.ChatCompletionRequest, isMod usage.TotalTokens = promptTokens + usage.CompletionTokens } else { - tencentResponse := &TencentChatResponse{} + tencentResponse := &TencentChatResponse{ + Model: request.Model, + } errWithCode = p.SendRequest(req, tencentResponse, false) if errWithCode != nil { return @@ -128,7 +131,7 @@ func (p *TencentProvider) streamResponseTencent2OpenAI(TencentResponse *TencentC response := types.ChatCompletionStreamResponse{ Object: "chat.completion.chunk", Created: common.GetTimestamp(), - Model: "tencent-hunyuan", + Model: TencentResponse.Model, } if len(TencentResponse.Choices) > 0 { var choice types.ChatCompletionStreamChoice @@ -141,7 +144,7 @@ func (p *TencentProvider) streamResponseTencent2OpenAI(TencentResponse *TencentC return &response } -func (p *TencentProvider) sendStreamRequest(req *http.Request) (*types.OpenAIErrorWithStatusCode, string) { +func (p *TencentProvider) sendStreamRequest(req *http.Request, model string) (*types.OpenAIErrorWithStatusCode, string) { defer req.Body.Close() // 发送请求 resp, err := common.HttpClient.Do(req) @@ -195,6 +198,7 @@ func (p *TencentProvider) sendStreamRequest(req *http.Request) (*types.OpenAIErr common.SysError("error unmarshalling stream response: " + err.Error()) return true } + TencentResponse.Model = model response := p.streamResponseTencent2OpenAI(&TencentResponse) if len(response.Choices) != 0 { responseText += response.Choices[0].Delta.Content diff --git a/providers/tencent/type.go b/providers/tencent/type.go index 300ba3af..9783b920 100644 --- a/providers/tencent/type.go +++ b/providers/tencent/type.go @@ -58,4 +58,5 @@ type TencentChatResponse struct { Error TencentError `json:"error,omitempty"` // 错误信息 注意:此字段可能返回 null,表示取不到有效值 Note string `json:"note,omitempty"` // 注释 ReqID string `json:"req_id,omitempty"` // 唯一请求 Id,每次请求都会返回。用于反馈接口入参 + Model string `json:"model,omitempty"` // 模型名称 } diff --git a/providers/zhipu/chat.go b/providers/zhipu/chat.go index 807d2ff7..a22f9815 100644 --- a/providers/zhipu/chat.go +++ b/providers/zhipu/chat.go @@ -28,6 +28,7 @@ func (zhipuResponse *ZhipuResponse) ResponseHandler(resp *http.Response) (OpenAI ID: zhipuResponse.Data.TaskId, Object: "chat.completion", Created: common.GetTimestamp(), + Model: zhipuResponse.Model, Choices: make([]types.ChatCompletionChoice, 0, len(zhipuResponse.Data.Choices)), Usage: &zhipuResponse.Data.Usage, } @@ -94,13 +95,15 @@ func (p *ZhipuProvider) ChatAction(request *types.ChatCompletionRequest, isModel } if request.Stream { - errWithCode, usage = p.sendStreamRequest(req) + errWithCode, usage = p.sendStreamRequest(req, request.Model) if errWithCode != nil { return } } else { - zhipuResponse := &ZhipuResponse{} + zhipuResponse := &ZhipuResponse{ + Model: request.Model, + } errWithCode = p.SendRequest(req, zhipuResponse, false) if errWithCode != nil { return @@ -132,13 +135,13 @@ func (p *ZhipuProvider) streamMetaResponseZhipu2OpenAI(zhipuResponse *ZhipuStrea ID: zhipuResponse.RequestId, Object: "chat.completion.chunk", Created: common.GetTimestamp(), - Model: "chatglm", + Model: zhipuResponse.Model, Choices: []types.ChatCompletionStreamChoice{choice}, } return &response, &zhipuResponse.Usage } -func (p *ZhipuProvider) sendStreamRequest(req *http.Request) (*types.OpenAIErrorWithStatusCode, *types.Usage) { +func (p *ZhipuProvider) sendStreamRequest(req *http.Request, model string) (*types.OpenAIErrorWithStatusCode, *types.Usage) { defer req.Body.Close() // 发送请求 @@ -159,7 +162,7 @@ func (p *ZhipuProvider) sendStreamRequest(req *http.Request) (*types.OpenAIError if atEOF && len(data) == 0 { return 0, nil, nil } - if i := strings.Index(string(data), "\n\n"); i >= 0 && strings.Index(string(data), ":") >= 0 { + if i := strings.Index(string(data), "\n\n"); i >= 0 && strings.Contains(string(data), ":") { return i + 2, data[0:i], nil } if atEOF { @@ -195,6 +198,7 @@ func (p *ZhipuProvider) sendStreamRequest(req *http.Request) (*types.OpenAIError select { case data := <-dataChan: response := p.streamResponseZhipu2OpenAI(data) + response.Model = model jsonResponse, err := json.Marshal(response) if err != nil { common.SysError("error marshalling stream response: " + err.Error()) @@ -209,6 +213,7 @@ func (p *ZhipuProvider) sendStreamRequest(req *http.Request) (*types.OpenAIError common.SysError("error unmarshalling stream response: " + err.Error()) return true } + zhipuResponse.Model = model response, zhipuUsage := p.streamMetaResponseZhipu2OpenAI(&zhipuResponse) jsonResponse, err := json.Marshal(response) if err != nil { diff --git a/providers/zhipu/type.go b/providers/zhipu/type.go index 5a5942e7..d811f0c1 100644 --- a/providers/zhipu/type.go +++ b/providers/zhipu/type.go @@ -31,6 +31,7 @@ type ZhipuResponse struct { Msg string `json:"msg"` Success bool `json:"success"` Data ZhipuResponseData `json:"data"` + Model string `json:"model,omitempty"` } type ZhipuStreamMetaResponse struct { @@ -38,6 +39,7 @@ type ZhipuStreamMetaResponse struct { TaskId string `json:"task_id"` TaskStatus string `json:"task_status"` types.Usage `json:"usage"` + Model string `json:"model,omitempty"` } type zhipuTokenData struct { diff --git a/pull_request_template.md b/pull_request_template.md index bbcd969c..a313004f 100644 --- a/pull_request_template.md +++ b/pull_request_template.md @@ -1,3 +1,9 @@ +[//]: # (请按照以下格式关联 issue) +[//]: # (请在提交 PR 前确认所提交的功能可用,附上截图即可,这将有助于项目维护者 review & merge 该 PR,谢谢) +[//]: # (项目维护者一般仅在周末处理 PR,因此如若未能及时回复希望能理解) +[//]: # (开发者交流群:910657413) +[//]: # (请在提交 PR 之前删除上面的注释) + close #issue_number 我已确认该 PR 已自测通过,相关截图如下: \ No newline at end of file diff --git a/types/chat.go b/types/chat.go index c736573c..e130260b 100644 --- a/types/chat.go +++ b/types/chat.go @@ -1,5 +1,10 @@ package types +const ( + ContentTypeText = "text" + ContentTypeImageURL = "image_url" +) + type ChatCompletionMessage struct { Role string `json:"role"` Content any `json:"content"` @@ -33,6 +38,47 @@ func (m ChatCompletionMessage) StringContent() string { return "" } +func (m ChatCompletionMessage) ParseContent() []ChatMessagePart { + var contentList []ChatMessagePart + content, ok := m.Content.(string) + if ok { + contentList = append(contentList, ChatMessagePart{ + Type: ContentTypeText, + Text: content, + }) + return contentList + } + anyList, ok := m.Content.([]any) + if ok { + for _, contentItem := range anyList { + contentMap, ok := contentItem.(map[string]any) + if !ok { + continue + } + switch contentMap["type"] { + case ContentTypeText: + if subStr, ok := contentMap["text"].(string); ok { + contentList = append(contentList, ChatMessagePart{ + Type: ContentTypeText, + Text: subStr, + }) + } + case ContentTypeImageURL: + if subObj, ok := contentMap["image_url"].(map[string]any); ok { + contentList = append(contentList, ChatMessagePart{ + Type: ContentTypeImageURL, + ImageURL: &ChatMessageImageURL{ + URL: subObj["url"].(string), + }, + }) + } + } + } + return contentList + } + return nil +} + type ChatMessageImageURL struct { URL string `json:"url,omitempty"` Detail string `json:"detail,omitempty"` diff --git a/web/src/views/Channel/type/Config.js b/web/src/views/Channel/type/Config.js index a813a96d..555afd31 100644 --- a/web/src/views/Channel/type/Config.js +++ b/web/src/views/Channel/type/Config.js @@ -71,7 +71,17 @@ const typeConfig = { other: '插件参数' }, input: { - models: ['qwen-turbo', 'qwen-plus', 'qwen-max', 'qwen-max-longcontext', 'text-embedding-v1'] + models: [ + 'qwen-turbo', + 'qwen-plus', + 'qwen-max', + 'qwen-max-longcontext', + 'text-embedding-v1', + 'qwen-turbo-internet', + 'qwen-plus-internet', + 'qwen-max-internet', + 'qwen-max-longcontext-internet' + ] }, prompt: { other: '请输入插件参数,即 X-DashScope-Plugin 请求头的取值'