From a763681c2edebca610db7ae91d11d0e9a559a08c Mon Sep 17 00:00:00 2001 From: Buer <42402987+MartialBE@users.noreply.github.com> Date: Sun, 24 Dec 2023 15:35:56 +0800 Subject: [PATCH 1/7] fix: fix base64 image parse error (#858) --- common/image/image.go | 25 +++++++++++++++++++++---- common/image/image_test.go | 17 +++++++++++++++++ 2 files changed, 38 insertions(+), 4 deletions(-) diff --git a/common/image/image.go b/common/image/image.go index cbb656ad..93da6a06 100644 --- a/common/image/image.go +++ b/common/image/image.go @@ -1,6 +1,8 @@ package image import ( + "bytes" + "encoding/base64" "image" _ "image/gif" _ "image/jpeg" @@ -8,6 +10,7 @@ import ( "net/http" "regexp" "strings" + "sync" _ "golang.org/x/image/webp" ) @@ -29,13 +32,27 @@ var ( reg = regexp.MustCompile(`data:image/([^;]+);base64,`) ) +var readerPool = sync.Pool{ + New: func() interface{} { + return &bytes.Reader{} + }, +} + func GetImageSizeFromBase64(encoded string) (width int, height int, err error) { - encoded = strings.TrimPrefix(encoded, "data:image/png;base64,") - base64 := strings.NewReader(reg.ReplaceAllString(encoded, "")) - img, _, err := image.DecodeConfig(base64) + decoded, err := base64.StdEncoding.DecodeString(reg.ReplaceAllString(encoded, "")) if err != nil { - return + return 0, 0, err } + + reader := readerPool.Get().(*bytes.Reader) + defer readerPool.Put(reader) + reader.Reset(decoded) + + img, _, err := image.DecodeConfig(reader) + if err != nil { + return 0, 0, err + } + return img.Width, img.Height, nil } diff --git a/common/image/image_test.go b/common/image/image_test.go index 366eda6e..8e47b109 100644 --- a/common/image/image_test.go +++ b/common/image/image_test.go @@ -152,3 +152,20 @@ func TestGetImageSize(t *testing.T) { }) } } + +func TestGetImageSizeFromBase64(t *testing.T) { + for i, c := range cases { + t.Run("Decode:"+strconv.Itoa(i), func(t *testing.T) { + resp, err := http.Get(c.url) + assert.NoError(t, err) + defer resp.Body.Close() + data, err := io.ReadAll(resp.Body) + assert.NoError(t, err) + encoded := base64.StdEncoding.EncodeToString(data) + width, height, err := img.GetImageSizeFromBase64(encoded) + assert.NoError(t, err) + assert.Equal(t, c.width, width) + assert.Equal(t, c.height, height) + }) + } +} From ee9e746520e57520fd2b7c7f1bc85d4eaed077fd Mon Sep 17 00:00:00 2001 From: moondie <528893699@qq.com> Date: Sun, 24 Dec 2023 16:17:21 +0800 Subject: [PATCH 2/7] feat: update ali stream implementation & enable internet search (#856) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Update relay-ali.go: 改进stream模式,添加联网搜索能力 通义千问支持stream的增量模式,不需要每次去掉上次的前缀;实测qwen-max联网模式效果不错,添加了联网模式。如果别的模型有问题可以改为单独给qwen-max开放 * 删除"stream参数" 刚发现原来阿里api没有这个参数,上次误加了。 * refactor: only enable search when specified * fix: remove custom suffix when get model ratio --------- Co-authored-by: JustSong --- common/model-ratio.go | 3 +++ controller/relay-ali.go | 37 +++++++++++++++++----------- web/src/pages/Channel/EditChannel.js | 7 ++++++ 3 files changed, 32 insertions(+), 15 deletions(-) diff --git a/common/model-ratio.go b/common/model-ratio.go index d1c96d96..d6b51f84 100644 --- a/common/model-ratio.go +++ b/common/model-ratio.go @@ -115,6 +115,9 @@ func UpdateModelRatioByJSONString(jsonStr string) error { } func GetModelRatio(name string) float64 { + if strings.HasPrefix(name, "qwen-") && strings.HasSuffix(name, "-internet") { + name = strings.TrimSuffix(name, "-internet") + } ratio, ok := ModelRatio[name] if !ok { SysError("model ratio not found: " + name) diff --git a/controller/relay-ali.go b/controller/relay-ali.go index 65626f6a..7968bfb6 100644 --- a/controller/relay-ali.go +++ b/controller/relay-ali.go @@ -23,10 +23,11 @@ type AliInput struct { } type AliParameters struct { - TopP float64 `json:"top_p,omitempty"` - TopK int `json:"top_k,omitempty"` - Seed uint64 `json:"seed,omitempty"` - EnableSearch bool `json:"enable_search,omitempty"` + TopP float64 `json:"top_p,omitempty"` + TopK int `json:"top_k,omitempty"` + Seed uint64 `json:"seed,omitempty"` + EnableSearch bool `json:"enable_search,omitempty"` + IncrementalOutput bool `json:"incremental_output,omitempty"` } type AliChatRequest struct { @@ -81,6 +82,8 @@ type AliChatResponse struct { AliError } +const AliEnableSearchModelSuffix = "-internet" + func requestOpenAI2Ali(request GeneralOpenAIRequest) *AliChatRequest { messages := make([]AliMessage, 0, len(request.Messages)) for i := 0; i < len(request.Messages); i++ { @@ -90,17 +93,21 @@ func requestOpenAI2Ali(request GeneralOpenAIRequest) *AliChatRequest { Role: strings.ToLower(message.Role), }) } + enableSearch := false + aliModel := request.Model + if strings.HasSuffix(aliModel, AliEnableSearchModelSuffix) { + enableSearch = true + aliModel = strings.TrimSuffix(aliModel, AliEnableSearchModelSuffix) + } return &AliChatRequest{ - Model: request.Model, + Model: aliModel, Input: AliInput{ Messages: messages, }, - //Parameters: AliParameters{ // ChatGPT's parameters are not compatible with Ali's - // TopP: request.TopP, - // TopK: 50, - // //Seed: 0, - // //EnableSearch: false, - //}, + Parameters: AliParameters{ + EnableSearch: enableSearch, + IncrementalOutput: request.Stream, + }, } } @@ -202,7 +209,7 @@ func streamResponseAli2OpenAI(aliResponse *AliChatResponse) *ChatCompletionsStre Id: aliResponse.RequestId, Object: "chat.completion.chunk", Created: common.GetTimestamp(), - Model: "ernie-bot", + Model: "qwen", Choices: []ChatCompletionsStreamResponseChoice{choice}, } return &response @@ -240,7 +247,7 @@ func aliStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStat stopChan <- true }() setEventStreamHeaders(c) - lastResponseText := "" + //lastResponseText := "" c.Stream(func(w io.Writer) bool { select { case data := <-dataChan: @@ -256,8 +263,8 @@ func aliStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStat usage.TotalTokens = aliResponse.Usage.InputTokens + aliResponse.Usage.OutputTokens } response := streamResponseAli2OpenAI(&aliResponse) - response.Choices[0].Delta.Content = strings.TrimPrefix(response.Choices[0].Delta.Content, lastResponseText) - lastResponseText = aliResponse.Output.Text + //response.Choices[0].Delta.Content = strings.TrimPrefix(response.Choices[0].Delta.Content, lastResponseText) + //lastResponseText = aliResponse.Output.Text jsonResponse, err := json.Marshal(response) if err != nil { common.SysError("error marshalling stream response: " + err.Error()) diff --git a/web/src/pages/Channel/EditChannel.js b/web/src/pages/Channel/EditChannel.js index 364da69d..b1c7ae62 100644 --- a/web/src/pages/Channel/EditChannel.js +++ b/web/src/pages/Channel/EditChannel.js @@ -70,6 +70,13 @@ const EditChannel = () => { break; case 17: localModels = ['qwen-turbo', 'qwen-plus', 'qwen-max', 'qwen-max-longcontext', 'text-embedding-v1']; + let withInternetVersion = []; + for (let i = 0; i < localModels.length; i++) { + if (localModels[i].startsWith('qwen-')) { + withInternetVersion.push(localModels[i] + '-internet'); + } + } + localModels = [...localModels, ...withInternetVersion]; break; case 16: localModels = ['chatglm_turbo', 'chatglm_pro', 'chatglm_std', 'chatglm_lite']; From 0699ecd0af9a12e748f1174c3814f0f5c3f49592 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Sun, 24 Dec 2023 16:29:48 +0800 Subject: [PATCH 3/7] chore(deps): bump golang.org/x/crypto from 0.14.0 to 0.17.0 (#840) Bumps [golang.org/x/crypto](https://github.com/golang/crypto) from 0.14.0 to 0.17.0. - [Commits](https://github.com/golang/crypto/compare/v0.14.0...v0.17.0) --- updated-dependencies: - dependency-name: golang.org/x/crypto dependency-type: direct:production ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- go.mod | 4 ++-- go.sum | 8 ++++---- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/go.mod b/go.mod index 1fe5eabc..68dd5eb6 100644 --- a/go.mod +++ b/go.mod @@ -16,7 +16,7 @@ require ( github.com/gorilla/websocket v1.5.0 github.com/pkoukk/tiktoken-go v0.1.5 github.com/stretchr/testify v1.8.3 - golang.org/x/crypto v0.14.0 + golang.org/x/crypto v0.17.0 golang.org/x/image v0.14.0 gorm.io/driver/mysql v1.4.3 gorm.io/driver/postgres v1.5.2 @@ -58,7 +58,7 @@ require ( github.com/ugorji/go/codec v1.2.11 // indirect golang.org/x/arch v0.3.0 // indirect golang.org/x/net v0.17.0 // indirect - golang.org/x/sys v0.13.0 // indirect + golang.org/x/sys v0.15.0 // indirect golang.org/x/text v0.14.0 // indirect google.golang.org/protobuf v1.30.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect diff --git a/go.sum b/go.sum index fb252aa7..21bcddc6 100644 --- a/go.sum +++ b/go.sum @@ -150,8 +150,8 @@ golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUu golang.org/x/arch v0.3.0 h1:02VY4/ZcO/gBOH6PUaoiptASxtXU10jazRCP865E97k= golang.org/x/arch v0.3.0/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8= golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= -golang.org/x/crypto v0.14.0 h1:wBqGXzWJW6m1XrIKlAH0Hs1JJ7+9KBwnIO8v66Q9cHc= -golang.org/x/crypto v0.14.0/go.mod h1:MVFd36DqK4CsrnJYDkBA3VC4m2GkXAM0PvzMCn4JQf4= +golang.org/x/crypto v0.17.0 h1:r8bRNjWL3GshPW3gkd+RpvzWrZAwPS49OmTGZ/uhM4k= +golang.org/x/crypto v0.17.0/go.mod h1:gCAAfMLgwOJRpTjQ2zCCt2OcSfYMTeZVSRtQlPC7Nq4= golang.org/x/image v0.14.0 h1:tNgSxAFe3jC4uYqvZdTr84SZoM1KfwdC9SKIFrLjFn4= golang.org/x/image v0.14.0/go.mod h1:HUYqC05R2ZcZ3ejNQsIHQDQiwWM4JBqmm6MKANTp4LE= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= @@ -164,8 +164,8 @@ golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.0.0-20210806184541-e5e7981a1069/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220704084225-05e143d24a9e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.13.0 h1:Af8nKPmuFypiUBjVoU9V20FiaFXOcuZI21p0ycVYYGE= -golang.org/x/sys v0.13.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.15.0 h1:h48lPFYpsTvQJZF4EKyI4aLHaev3CxivZmv7yZig9pc= +golang.org/x/sys v0.15.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= From 40ceb29e540340572035cdd4348e80433e4d5ab2 Mon Sep 17 00:00:00 2001 From: Bryan Date: Sun, 24 Dec 2023 16:42:00 +0800 Subject: [PATCH 4/7] fix: fix SearchUsers not working if using PostgreSQL (#778) * fix SearchUsers * refactor: using UsingPostgreSQL as condition --------- Co-authored-by: JustSong --- model/user.go | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/model/user.go b/model/user.go index 7844eb6a..e738b1ba 100644 --- a/model/user.go +++ b/model/user.go @@ -42,7 +42,11 @@ func GetAllUsers(startIdx int, num int) (users []*User, err error) { } func SearchUsers(keyword string) (users []*User, err error) { - err = DB.Omit("password").Where("id = ? or username LIKE ? or email LIKE ? or display_name LIKE ?", keyword, keyword+"%", keyword+"%", keyword+"%").Find(&users).Error + if !common.UsingPostgreSQL { + err = DB.Omit("password").Where("id = ? or username LIKE ? or email LIKE ? or display_name LIKE ?", keyword, keyword+"%", keyword+"%", keyword+"%").Find(&users).Error + } else { + err = DB.Omit("password").Where("username LIKE ? or email LIKE ? or display_name LIKE ?", keyword+"%", keyword+"%", keyword+"%").Find(&users).Error + } return users, err } From f3c07e14511c563b7d2cbe66432338e2ed545586 Mon Sep 17 00:00:00 2001 From: "Laisky.Cai" Date: Sun, 24 Dec 2023 16:58:31 +0800 Subject: [PATCH 5/7] fix: openai response should contains `model` (#841) * fix: openai response should contains `model` - Update model attributes in `claudeHandler` for `relay-claude.go` - Implement model type for fullTextResponse in `relay-gemini.go` - Add new `Model` field to `OpenAITextResponse` struct in `relay.go` * chore: set model name response for models --------- Co-authored-by: JustSong --- controller/relay-ali.go | 1 + controller/relay-baidu.go | 1 + controller/relay-claude.go | 1 + controller/relay-gemini.go | 1 + controller/relay-palm.go | 1 + controller/relay-tencent.go | 1 + controller/relay-zhipu.go | 1 + controller/relay.go | 1 + 8 files changed, 8 insertions(+) diff --git a/controller/relay-ali.go b/controller/relay-ali.go index 7968bfb6..df1cc084 100644 --- a/controller/relay-ali.go +++ b/controller/relay-ali.go @@ -310,6 +310,7 @@ func aliHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode }, nil } fullTextResponse := responseAli2OpenAI(&aliResponse) + fullTextResponse.Model = "qwen" jsonResponse, err := json.Marshal(fullTextResponse) if err != nil { return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil diff --git a/controller/relay-baidu.go b/controller/relay-baidu.go index c75ec09a..dca30da1 100644 --- a/controller/relay-baidu.go +++ b/controller/relay-baidu.go @@ -255,6 +255,7 @@ func baiduHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCo }, nil } fullTextResponse := responseBaidu2OpenAI(&baiduResponse) + fullTextResponse.Model = "ernie-bot" jsonResponse, err := json.Marshal(fullTextResponse) if err != nil { return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil diff --git a/controller/relay-claude.go b/controller/relay-claude.go index 1b72b47d..ca7a701a 100644 --- a/controller/relay-claude.go +++ b/controller/relay-claude.go @@ -204,6 +204,7 @@ func claudeHandler(c *gin.Context, resp *http.Response, promptTokens int, model }, nil } fullTextResponse := responseClaude2OpenAI(&claudeResponse) + fullTextResponse.Model = model completionTokens := countTokenText(claudeResponse.Completion, model) usage := Usage{ PromptTokens: promptTokens, diff --git a/controller/relay-gemini.go b/controller/relay-gemini.go index 2458458e..523018de 100644 --- a/controller/relay-gemini.go +++ b/controller/relay-gemini.go @@ -287,6 +287,7 @@ func geminiChatHandler(c *gin.Context, resp *http.Response, promptTokens int, mo }, nil } fullTextResponse := responseGeminiChat2OpenAI(&geminiResponse) + fullTextResponse.Model = model completionTokens := countTokenText(geminiResponse.GetResponseText(), model) usage := Usage{ PromptTokens: promptTokens, diff --git a/controller/relay-palm.go b/controller/relay-palm.go index 2bd0bcd8..0c1c8af6 100644 --- a/controller/relay-palm.go +++ b/controller/relay-palm.go @@ -187,6 +187,7 @@ func palmHandler(c *gin.Context, resp *http.Response, promptTokens int, model st }, nil } fullTextResponse := responsePaLM2OpenAI(&palmResponse) + fullTextResponse.Model = model completionTokens := countTokenText(palmResponse.Candidates[0].Content, model) usage := Usage{ PromptTokens: promptTokens, diff --git a/controller/relay-tencent.go b/controller/relay-tencent.go index f66bf38f..5930ae89 100644 --- a/controller/relay-tencent.go +++ b/controller/relay-tencent.go @@ -237,6 +237,7 @@ func tencentHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatus }, nil } fullTextResponse := responseTencent2OpenAI(&TencentResponse) + fullTextResponse.Model = "hunyuan" jsonResponse, err := json.Marshal(fullTextResponse) if err != nil { return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil diff --git a/controller/relay-zhipu.go b/controller/relay-zhipu.go index 2e345ab5..cb5a78cf 100644 --- a/controller/relay-zhipu.go +++ b/controller/relay-zhipu.go @@ -290,6 +290,7 @@ func zhipuHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCo }, nil } fullTextResponse := responseZhipu2OpenAI(&zhipuResponse) + fullTextResponse.Model = "chatglm" jsonResponse, err := json.Marshal(fullTextResponse) if err != nil { return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil diff --git a/controller/relay.go b/controller/relay.go index 15021997..b7906d08 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -206,6 +206,7 @@ type OpenAITextResponseChoice struct { type OpenAITextResponse struct { Id string `json:"id"` + Model string `json:"model,omitempty"` Object string `json:"object"` Created int64 `json:"created"` Choices []OpenAITextResponseChoice `json:"choices"` From 1c8922153d6adb94184513e7f0263521f0d29157 Mon Sep 17 00:00:00 2001 From: JustSong Date: Sun, 24 Dec 2023 18:54:32 +0800 Subject: [PATCH 6/7] feat: support gemini-vision-pro --- common/image/image.go | 35 +++++++++++++++++ common/model-ratio.go | 1 + controller/model.go | 9 +++++ controller/relay-gemini.go | 31 +++++++++++++++ controller/relay-text.go | 10 +---- controller/relay.go | 59 +++++++++++++++++++++++++++- middleware/recover.go | 2 + web/src/pages/Channel/EditChannel.js | 2 +- 8 files changed, 139 insertions(+), 10 deletions(-) diff --git a/common/image/image.go b/common/image/image.go index 93da6a06..a602936a 100644 --- a/common/image/image.go +++ b/common/image/image.go @@ -15,7 +15,22 @@ import ( _ "golang.org/x/image/webp" ) +func IsImageUrl(url string) (bool, error) { + resp, err := http.Head(url) + if err != nil { + return false, err + } + if !strings.HasPrefix(resp.Header.Get("Content-Type"), "image/") { + return false, nil + } + return true, nil +} + func GetImageSizeFromUrl(url string) (width int, height int, err error) { + isImage, err := IsImageUrl(url) + if !isImage { + return + } resp, err := http.Get(url) if err != nil { return @@ -28,6 +43,26 @@ func GetImageSizeFromUrl(url string) (width int, height int, err error) { return img.Width, img.Height, nil } +func GetImageFromUrl(url string) (mimeType string, data string, err error) { + isImage, err := IsImageUrl(url) + if !isImage { + return + } + resp, err := http.Get(url) + if err != nil { + return + } + defer resp.Body.Close() + buffer := bytes.NewBuffer(nil) + _, err = buffer.ReadFrom(resp.Body) + if err != nil { + return + } + mimeType = resp.Header.Get("Content-Type") + data = base64.StdEncoding.EncodeToString(buffer.Bytes()) + return +} + var ( reg = regexp.MustCompile(`data:image/([^;]+);base64,`) ) diff --git a/common/model-ratio.go b/common/model-ratio.go index d6b51f84..fa2adaa1 100644 --- a/common/model-ratio.go +++ b/common/model-ratio.go @@ -84,6 +84,7 @@ var ModelRatio = map[string]float64{ "Embedding-V1": 0.1429, // ¥0.002 / 1k tokens "PaLM-2": 1, "gemini-pro": 1, // $0.00025 / 1k characters -> $0.001 / 1k tokens + "gemini-pro-vision": 1, // $0.00025 / 1k characters -> $0.001 / 1k tokens "chatglm_turbo": 0.3572, // ¥0.005 / 1k tokens "chatglm_pro": 0.7143, // ¥0.01 / 1k tokens "chatglm_std": 0.3572, // ¥0.005 / 1k tokens diff --git a/controller/model.go b/controller/model.go index 9ae40f5c..6a759b63 100644 --- a/controller/model.go +++ b/controller/model.go @@ -432,6 +432,15 @@ func init() { Root: "gemini-pro", Parent: nil, }, + { + Id: "gemini-pro-vision", + Object: "model", + Created: 1677649963, + OwnedBy: "google", + Permission: permission, + Root: "gemini-pro-vision", + Parent: nil, + }, { Id: "chatglm_turbo", Object: "model", diff --git a/controller/relay-gemini.go b/controller/relay-gemini.go index 523018de..ec55d4b6 100644 --- a/controller/relay-gemini.go +++ b/controller/relay-gemini.go @@ -7,11 +7,18 @@ import ( "io" "net/http" "one-api/common" + "one-api/common/image" "strings" "github.com/gin-gonic/gin" ) +// https://ai.google.dev/docs/gemini_api_overview?hl=zh-cn + +const ( + GeminiVisionMaxImageNum = 16 +) + type GeminiChatRequest struct { Contents []GeminiChatContent `json:"contents"` SafetySettings []GeminiChatSafetySettings `json:"safety_settings,omitempty"` @@ -97,6 +104,30 @@ func requestOpenAI2Gemini(textRequest GeneralOpenAIRequest) *GeminiChatRequest { }, }, } + openaiContent := message.ParseContent() + var parts []GeminiPart + imageNum := 0 + for _, part := range openaiContent { + if part.Type == ContentTypeText { + parts = append(parts, GeminiPart{ + Text: part.Text, + }) + } else if part.Type == ContentTypeImageURL { + imageNum += 1 + if imageNum > GeminiVisionMaxImageNum { + continue + } + mimeType, data, _ := image.GetImageFromUrl(part.ImageURL.Url) + parts = append(parts, GeminiPart{ + InlineData: &GeminiInlineData{ + MimeType: mimeType, + Data: data, + }, + }) + } + } + content.Parts = parts + // there's no assistant role in gemini and API shall vomit if Role is not user or model if content.Role == "assistant" { content.Role = "model" diff --git a/controller/relay-text.go b/controller/relay-text.go index c49a2abe..64338545 100644 --- a/controller/relay-text.go +++ b/controller/relay-text.go @@ -180,9 +180,6 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { if baseURL != "" { fullRequestURL = fmt.Sprintf("%s/v1beta2/models/chat-bison-001:generateMessage", baseURL) } - apiKey := c.Request.Header.Get("Authorization") - apiKey = strings.TrimPrefix(apiKey, "Bearer ") - fullRequestURL += "?key=" + apiKey case APITypeGemini: requestBaseURL := "https://generativelanguage.googleapis.com" if baseURL != "" { @@ -197,9 +194,6 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { action = "streamGenerateContent" } fullRequestURL = fmt.Sprintf("%s/%s/models/%s:%s", requestBaseURL, version, textRequest.Model, action) - apiKey := c.Request.Header.Get("Authorization") - apiKey = strings.TrimPrefix(apiKey, "Bearer ") - fullRequestURL += "?key=" + apiKey case APITypeZhipu: method := "invoke" if textRequest.Stream { @@ -396,9 +390,9 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { case APITypeTencent: req.Header.Set("Authorization", apiKey) case APITypePaLM: - // do not set Authorization header + req.Header.Set("x-goog-api-key", apiKey) case APITypeGemini: - // do not set Authorization header + req.Header.Set("x-goog-api-key", apiKey) default: req.Header.Set("Authorization", "Bearer "+apiKey) } diff --git a/controller/relay.go b/controller/relay.go index b7906d08..e45fd3eb 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -31,6 +31,22 @@ type ImageContent struct { ImageURL *ImageURL `json:"image_url,omitempty"` } +const ( + ContentTypeText = "text" + ContentTypeImageURL = "image_url" +) + +type OpenAIMessageContent struct { + Type string `json:"type,omitempty"` + Text string `json:"text"` + ImageURL *ImageURL `json:"image_url,omitempty"` +} + +func (m Message) IsStringContent() bool { + _, ok := m.Content.(string) + return ok +} + func (m Message) StringContent() string { content, ok := m.Content.(string) if ok { @@ -44,7 +60,7 @@ func (m Message) StringContent() string { if !ok { continue } - if contentMap["type"] == "text" { + if contentMap["type"] == ContentTypeText { if subStr, ok := contentMap["text"].(string); ok { contentStr += subStr } @@ -55,6 +71,47 @@ func (m Message) StringContent() string { return "" } +func (m Message) ParseContent() []OpenAIMessageContent { + var contentList []OpenAIMessageContent + content, ok := m.Content.(string) + if ok { + contentList = append(contentList, OpenAIMessageContent{ + Type: ContentTypeText, + Text: content, + }) + return contentList + } + anyList, ok := m.Content.([]any) + if ok { + for _, contentItem := range anyList { + contentMap, ok := contentItem.(map[string]any) + if !ok { + continue + } + switch contentMap["type"] { + case ContentTypeText: + if subStr, ok := contentMap["text"].(string); ok { + contentList = append(contentList, OpenAIMessageContent{ + Type: ContentTypeText, + Text: subStr, + }) + } + case ContentTypeImageURL: + if subObj, ok := contentMap["image_url"].(map[string]any); ok { + contentList = append(contentList, OpenAIMessageContent{ + Type: ContentTypeImageURL, + ImageURL: &ImageURL{ + Url: subObj["url"].(string), + }, + }) + } + } + } + return contentList + } + return nil +} + const ( RelayModeUnknown = iota RelayModeChatCompletions diff --git a/middleware/recover.go b/middleware/recover.go index c3a3d748..8338a514 100644 --- a/middleware/recover.go +++ b/middleware/recover.go @@ -5,6 +5,7 @@ import ( "github.com/gin-gonic/gin" "net/http" "one-api/common" + "runtime/debug" ) func RelayPanicRecover() gin.HandlerFunc { @@ -12,6 +13,7 @@ func RelayPanicRecover() gin.HandlerFunc { defer func() { if err := recover(); err != nil { common.SysError(fmt.Sprintf("panic detected: %v", err)) + common.SysError(fmt.Sprintf("stacktrace from panic: %s", string(debug.Stack()))) c.JSON(http.StatusInternalServerError, gin.H{ "error": gin.H{ "message": fmt.Sprintf("Panic detected, error: %v. Please submit a issue here: https://github.com/songquanpeng/one-api", err), diff --git a/web/src/pages/Channel/EditChannel.js b/web/src/pages/Channel/EditChannel.js index b1c7ae62..0d4e114d 100644 --- a/web/src/pages/Channel/EditChannel.js +++ b/web/src/pages/Channel/EditChannel.js @@ -91,7 +91,7 @@ const EditChannel = () => { localModels = ['hunyuan']; break; case 24: - localModels = ['gemini-pro']; + localModels = ['gemini-pro', 'gemini-pro-vision']; break; } setInputs((inputs) => ({ ...inputs, models: localModels })); From f44fbe3fe7e712847930b19a6527a476b2f4a45f Mon Sep 17 00:00:00 2001 From: JustSong Date: Sun, 24 Dec 2023 19:24:59 +0800 Subject: [PATCH 7/7] docs: update pr template --- pull_request_template.md | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/pull_request_template.md b/pull_request_template.md index bbcd969c..a313004f 100644 --- a/pull_request_template.md +++ b/pull_request_template.md @@ -1,3 +1,9 @@ +[//]: # (请按照以下格式关联 issue) +[//]: # (请在提交 PR 前确认所提交的功能可用,附上截图即可,这将有助于项目维护者 review & merge 该 PR,谢谢) +[//]: # (项目维护者一般仅在周末处理 PR,因此如若未能及时回复希望能理解) +[//]: # (开发者交流群:910657413) +[//]: # (请在提交 PR 之前删除上面的注释) + close #issue_number 我已确认该 PR 已自测通过,相关截图如下: \ No newline at end of file