// Patch summary (leading commentary; text before the first "diff --git" header).
//
// Files touched: common/image/image.go, common/image/image_test.go,
// providers/gemini/chat.go.
//
// common/image/image.go
//   - Adds the "errors" import.
//   - GetImageFromUrl gains a branch for inputs starting with "data:image/":
//     the input is matched against `data:image/([^;]+);base64,(.*)`; on a match
//     with a non-empty payload it returns mimeType "image/<subtype>" and the
//     payload string exactly as captured. The visible lines never base64-decode
//     or validate the payload. Otherwise it returns the error
//     "image base64 decode failed".
//   - When IsImageUrl reports false with a nil error, the function now sets the
//     error "Invalid image link" before returning.
//
// common/image/image_test.go
//   - Adds TestGetImageFromUrl. For each entry in `cases` it downloads c.url,
//     base64-encodes the body, and compares that against GetImageFromUrl for
//     both the plain URL and the equivalent data URL. It then expects errors for
//     a README.md URL and for "data:image/text;base64," (empty payload).
//
// providers/gemini/chat.go
//   - getChatRequestBody now returns (requestBody, errWithCode). An error from
//     image.GetImageFromUrl is wrapped with common.ErrorWrapper using the code
//     "image_url_invalid" and http.StatusBadRequest instead of being discarded.
//   - ChatAction returns early when getChatRequestBody reports an error.
//
// NOTE(review): this file is stored with its line breaks collapsed; the hunks
// would need their original newlines restored before a patch tool could apply
// them.
// NOTE(review): the regexp is compiled inside GetImageFromUrl on every data-URL
// call; a package-level compiled pattern would avoid the repeated work.
// NOTE(review): the pattern is unanchored and `.` does not match a newline by
// default, so a payload containing a line break would be captured only up to
// that break. Confirm callers never pass wrapped base64.
// NOTE(review): "Invalid image link" is capitalized while the other new error
// string is lowercase; Go convention is lowercase error strings.
// NOTE(review): the new test issues real http.Get requests, so it presumably
// needs network access. If `assert` is testify's non-fatal assert (TODO
// confirm), a failed http.Get would reach resp.Body.Close() with a nil resp.
diff --git a/common/image/image.go b/common/image/image.go index a602936a..f4998bc9 100644 --- a/common/image/image.go +++ b/common/image/image.go @@ -3,6 +3,7 @@ package image import ( "bytes" "encoding/base64" + "errors" "image" _ "image/gif" _ "image/jpeg" @@ -44,8 +45,26 @@ func GetImageSizeFromUrl(url string) (width int, height int, err error) { } func GetImageFromUrl(url string) (mimeType string, data string, err error) { + + if strings.HasPrefix(url, "data:image/") { + dataURLPattern := regexp.MustCompile(`data:image/([^;]+);base64,(.*)`) + + matches := dataURLPattern.FindStringSubmatch(url) + if len(matches) == 3 && matches[2] != "" { + mimeType = "image/" + matches[1] + data = matches[2] + return + } + + err = errors.New("image base64 decode failed") + return + } + isImage, err := IsImageUrl(url) if !isImage { + if err == nil { + err = errors.New("Invalid image link") + } return } resp, err := http.Get(url) diff --git a/common/image/image_test.go b/common/image/image_test.go index 8e47b109..2445b559 100644 --- a/common/image/image_test.go +++ b/common/image/image_test.go @@ -169,3 +169,34 @@ func TestGetImageSizeFromBase64(t *testing.T) { }) } } + +func TestGetImageFromUrl(t *testing.T) { + for i, c := range cases { + t.Run("Decode:"+strconv.Itoa(i), func(t *testing.T) { + resp, err := http.Get(c.url) + assert.NoError(t, err) + defer resp.Body.Close() + data, err := io.ReadAll(resp.Body) + assert.NoError(t, err) + encoded := base64.StdEncoding.EncodeToString(data) + + mimeType, base64Data, err := img.GetImageFromUrl(c.url) + assert.NoError(t, err) + assert.Equal(t, encoded, base64Data) + assert.Equal(t, "image/"+c.format, mimeType) + + encodedBase64 := "data:image/" + c.format + ";base64," + encoded + mimeType, base64Data, err = img.GetImageFromUrl(encodedBase64) + assert.NoError(t, err) + assert.Equal(t, encoded, base64Data) + assert.Equal(t, "image/"+c.format, mimeType) + }) + } + + url := 
"https://raw.githubusercontent.com/songquanpeng/one-api/main/README.md" + _, _, err := img.GetImageFromUrl(url) + assert.Error(t, err) + encodedBase64 := "data:image/text;base64," + _, _, err = img.GetImageFromUrl(encodedBase64) + assert.Error(t, err) +} diff --git a/providers/gemini/chat.go b/providers/gemini/chat.go index c9c30244..72d218b2 100644 --- a/providers/gemini/chat.go +++ b/providers/gemini/chat.go @@ -60,7 +60,7 @@ func (response *GeminiChatResponse) ResponseHandler(resp *http.Response) (OpenAI } // Setting safety to the lowest possible values since Gemini is already powerless enough -func (p *GeminiProvider) getChatRequestBody(request *types.ChatCompletionRequest) (requestBody *GeminiChatRequest) { +func (p *GeminiProvider) getChatRequestBody(request *types.ChatCompletionRequest) (requestBody *GeminiChatRequest, errWithCode *types.OpenAIErrorWithStatusCode) { geminiRequest := GeminiChatRequest{ Contents: make([]GeminiChatContent, 0, len(request.Messages)), //SafetySettings: []GeminiChatSafetySettings{ @@ -118,7 +118,10 @@ func (p *GeminiProvider) getChatRequestBody(request *types.ChatCompletionRequest if imageNum > GeminiVisionMaxImageNum { continue } - mimeType, data, _ := image.GetImageFromUrl(part.ImageURL.URL) + mimeType, data, err := image.GetImageFromUrl(part.ImageURL.URL) + if err != nil { + return nil, common.ErrorWrapper(err, "image_url_invalid", http.StatusBadRequest) + } parts = append(parts, GeminiPart{ InlineData: &GeminiInlineData{ MimeType: mimeType, @@ -154,11 +157,14 @@ func (p *GeminiProvider) getChatRequestBody(request *types.ChatCompletionRequest } } - return &geminiRequest + return &geminiRequest, nil } func (p *GeminiProvider) ChatAction(request *types.ChatCompletionRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) { - requestBody := p.getChatRequestBody(request) + requestBody, errWithCode := p.getChatRequestBody(request) + if errWithCode != nil { + return + } 
fullRequestURL := p.GetFullRequestURL("generateContent", request.Model) headers := p.GetRequestHeaders() if request.Stream {