🐛 Fix base64-encoded image support for gemini-pro-vision

This commit is contained in:
Martial BE 2023-12-28 12:23:39 +08:00
parent 0fa94d3c94
commit 4d43dce64b
No known key found for this signature in database
GPG Key ID: D06C32DF0EDB9084
3 changed files with 60 additions and 4 deletions

View File

@ -3,6 +3,7 @@ package image
import (
"bytes"
"encoding/base64"
"errors"
"image"
_ "image/gif"
_ "image/jpeg"
@ -44,8 +45,26 @@ func GetImageSizeFromUrl(url string) (width int, height int, err error) {
}
func GetImageFromUrl(url string) (mimeType string, data string, err error) {
if strings.HasPrefix(url, "data:image/") {
dataURLPattern := regexp.MustCompile(`data:image/([^;]+);base64,(.*)`)
matches := dataURLPattern.FindStringSubmatch(url)
if len(matches) == 3 && matches[2] != "" {
mimeType = "image/" + matches[1]
data = matches[2]
return
}
err = errors.New("image base64 decode failed")
return
}
isImage, err := IsImageUrl(url)
if !isImage {
if err == nil {
err = errors.New("Invalid image link")
}
return
}
resp, err := http.Get(url)

View File

@ -169,3 +169,34 @@ func TestGetImageSizeFromBase64(t *testing.T) {
})
}
}
// TestGetImageFromUrl verifies img.GetImageFromUrl for both input forms it
// accepts: a remote image URL and a "data:image/<format>;base64,<payload>"
// data URL.
//
// For every entry in cases the test downloads c.url itself, base64-encodes the
// body, and expects GetImageFromUrl to return that same payload together with
// the MIME type "image/"+c.format — first when given the URL, then when given
// the equivalent data URL. Finally it checks that a Markdown README URL and a
// data URL with an empty payload are both rejected with an error.
func TestGetImageFromUrl(t *testing.T) {
	for i, c := range cases {
		t.Run("Decode:"+strconv.Itoa(i), func(t *testing.T) {
			// Build the expected payload independently of the code under test.
			resp, err := http.Get(c.url)
			// assert.NoError only records the failure; bail out explicitly so a
			// failed request does not dereference a nil resp below.
			if !assert.NoError(t, err) {
				return
			}
			defer resp.Body.Close()

			data, err := io.ReadAll(resp.Body)
			if !assert.NoError(t, err) {
				return
			}
			encoded := base64.StdEncoding.EncodeToString(data)

			// Remote URL input.
			mimeType, base64Data, err := img.GetImageFromUrl(c.url)
			assert.NoError(t, err)
			assert.Equal(t, encoded, base64Data)
			assert.Equal(t, "image/"+c.format, mimeType)

			// Data URL input carrying the same bytes.
			encodedBase64 := "data:image/" + c.format + ";base64," + encoded
			mimeType, base64Data, err = img.GetImageFromUrl(encodedBase64)
			assert.NoError(t, err)
			assert.Equal(t, encoded, base64Data)
			assert.Equal(t, "image/"+c.format, mimeType)
		})
	}

	// A URL pointing at a Markdown file must be reported as an error.
	url := "https://raw.githubusercontent.com/songquanpeng/one-api/main/README.md"
	_, _, err := img.GetImageFromUrl(url)
	assert.Error(t, err)

	// A data URL with an empty base64 payload must be reported as an error.
	encodedBase64 := "data:image/text;base64,"
	_, _, err = img.GetImageFromUrl(encodedBase64)
	assert.Error(t, err)
}

View File

@ -60,7 +60,7 @@ func (response *GeminiChatResponse) ResponseHandler(resp *http.Response) (OpenAI
}
// Setting safety to the lowest possible values since Gemini is already powerless enough
func (p *GeminiProvider) getChatRequestBody(request *types.ChatCompletionRequest) (requestBody *GeminiChatRequest) {
func (p *GeminiProvider) getChatRequestBody(request *types.ChatCompletionRequest) (requestBody *GeminiChatRequest, errWithCode *types.OpenAIErrorWithStatusCode) {
geminiRequest := GeminiChatRequest{
Contents: make([]GeminiChatContent, 0, len(request.Messages)),
//SafetySettings: []GeminiChatSafetySettings{
@ -118,7 +118,10 @@ func (p *GeminiProvider) getChatRequestBody(request *types.ChatCompletionRequest
if imageNum > GeminiVisionMaxImageNum {
continue
}
mimeType, data, _ := image.GetImageFromUrl(part.ImageURL.URL)
mimeType, data, err := image.GetImageFromUrl(part.ImageURL.URL)
if err != nil {
return nil, common.ErrorWrapper(err, "image_url_invalid", http.StatusBadRequest)
}
parts = append(parts, GeminiPart{
InlineData: &GeminiInlineData{
MimeType: mimeType,
@ -154,11 +157,14 @@ func (p *GeminiProvider) getChatRequestBody(request *types.ChatCompletionRequest
}
}
return &geminiRequest
return &geminiRequest, nil
}
func (p *GeminiProvider) ChatAction(request *types.ChatCompletionRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) {
requestBody := p.getChatRequestBody(request)
requestBody, errWithCode := p.getChatRequestBody(request)
if errWithCode != nil {
return
}
fullRequestURL := p.GetFullRequestURL("generateContent", request.Model)
headers := p.GetRequestHeaders()
if request.Stream {