🐛 fix base 64 encoded format support of gemini-pro-vision
This commit is contained in:
parent
0fa94d3c94
commit
4d43dce64b
@ -3,6 +3,7 @@ package image
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"image"
|
||||
_ "image/gif"
|
||||
_ "image/jpeg"
|
||||
@ -44,8 +45,26 @@ func GetImageSizeFromUrl(url string) (width int, height int, err error) {
|
||||
}
|
||||
|
||||
func GetImageFromUrl(url string) (mimeType string, data string, err error) {
|
||||
|
||||
if strings.HasPrefix(url, "data:image/") {
|
||||
dataURLPattern := regexp.MustCompile(`data:image/([^;]+);base64,(.*)`)
|
||||
|
||||
matches := dataURLPattern.FindStringSubmatch(url)
|
||||
if len(matches) == 3 && matches[2] != "" {
|
||||
mimeType = "image/" + matches[1]
|
||||
data = matches[2]
|
||||
return
|
||||
}
|
||||
|
||||
err = errors.New("image base64 decode failed")
|
||||
return
|
||||
}
|
||||
|
||||
isImage, err := IsImageUrl(url)
|
||||
if !isImage {
|
||||
if err == nil {
|
||||
err = errors.New("Invalid image link")
|
||||
}
|
||||
return
|
||||
}
|
||||
resp, err := http.Get(url)
|
||||
|
@ -169,3 +169,34 @@ func TestGetImageSizeFromBase64(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetImageFromUrl(t *testing.T) {
|
||||
for i, c := range cases {
|
||||
t.Run("Decode:"+strconv.Itoa(i), func(t *testing.T) {
|
||||
resp, err := http.Get(c.url)
|
||||
assert.NoError(t, err)
|
||||
defer resp.Body.Close()
|
||||
data, err := io.ReadAll(resp.Body)
|
||||
assert.NoError(t, err)
|
||||
encoded := base64.StdEncoding.EncodeToString(data)
|
||||
|
||||
mimeType, base64Data, err := img.GetImageFromUrl(c.url)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, encoded, base64Data)
|
||||
assert.Equal(t, "image/"+c.format, mimeType)
|
||||
|
||||
encodedBase64 := "data:image/" + c.format + ";base64," + encoded
|
||||
mimeType, base64Data, err = img.GetImageFromUrl(encodedBase64)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, encoded, base64Data)
|
||||
assert.Equal(t, "image/"+c.format, mimeType)
|
||||
})
|
||||
}
|
||||
|
||||
url := "https://raw.githubusercontent.com/songquanpeng/one-api/main/README.md"
|
||||
_, _, err := img.GetImageFromUrl(url)
|
||||
assert.Error(t, err)
|
||||
encodedBase64 := "data:image/text;base64,"
|
||||
_, _, err = img.GetImageFromUrl(encodedBase64)
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
@ -60,7 +60,7 @@ func (response *GeminiChatResponse) ResponseHandler(resp *http.Response) (OpenAI
|
||||
}
|
||||
|
||||
// Setting safety to the lowest possible values since Gemini is already powerless enough
|
||||
func (p *GeminiProvider) getChatRequestBody(request *types.ChatCompletionRequest) (requestBody *GeminiChatRequest) {
|
||||
func (p *GeminiProvider) getChatRequestBody(request *types.ChatCompletionRequest) (requestBody *GeminiChatRequest, errWithCode *types.OpenAIErrorWithStatusCode) {
|
||||
geminiRequest := GeminiChatRequest{
|
||||
Contents: make([]GeminiChatContent, 0, len(request.Messages)),
|
||||
//SafetySettings: []GeminiChatSafetySettings{
|
||||
@ -118,7 +118,10 @@ func (p *GeminiProvider) getChatRequestBody(request *types.ChatCompletionRequest
|
||||
if imageNum > GeminiVisionMaxImageNum {
|
||||
continue
|
||||
}
|
||||
mimeType, data, _ := image.GetImageFromUrl(part.ImageURL.URL)
|
||||
mimeType, data, err := image.GetImageFromUrl(part.ImageURL.URL)
|
||||
if err != nil {
|
||||
return nil, common.ErrorWrapper(err, "image_url_invalid", http.StatusBadRequest)
|
||||
}
|
||||
parts = append(parts, GeminiPart{
|
||||
InlineData: &GeminiInlineData{
|
||||
MimeType: mimeType,
|
||||
@ -154,11 +157,14 @@ func (p *GeminiProvider) getChatRequestBody(request *types.ChatCompletionRequest
|
||||
}
|
||||
}
|
||||
|
||||
return &geminiRequest
|
||||
return &geminiRequest, nil
|
||||
}
|
||||
|
||||
func (p *GeminiProvider) ChatAction(request *types.ChatCompletionRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) {
|
||||
requestBody := p.getChatRequestBody(request)
|
||||
requestBody, errWithCode := p.getChatRequestBody(request)
|
||||
if errWithCode != nil {
|
||||
return
|
||||
}
|
||||
fullRequestURL := p.GetFullRequestURL("generateContent", request.Model)
|
||||
headers := p.GetRequestHeaders()
|
||||
if request.Stream {
|
||||
|
Loading…
Reference in New Issue
Block a user