🐛 fix base64-encoded image support for gemini-pro-vision
This commit is contained in:
parent
0fa94d3c94
commit
4d43dce64b
@ -3,6 +3,7 @@ package image
|
|||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
|
"errors"
|
||||||
"image"
|
"image"
|
||||||
_ "image/gif"
|
_ "image/gif"
|
||||||
_ "image/jpeg"
|
_ "image/jpeg"
|
||||||
@ -44,8 +45,26 @@ func GetImageSizeFromUrl(url string) (width int, height int, err error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func GetImageFromUrl(url string) (mimeType string, data string, err error) {
|
func GetImageFromUrl(url string) (mimeType string, data string, err error) {
|
||||||
|
|
||||||
|
if strings.HasPrefix(url, "data:image/") {
|
||||||
|
dataURLPattern := regexp.MustCompile(`data:image/([^;]+);base64,(.*)`)
|
||||||
|
|
||||||
|
matches := dataURLPattern.FindStringSubmatch(url)
|
||||||
|
if len(matches) == 3 && matches[2] != "" {
|
||||||
|
mimeType = "image/" + matches[1]
|
||||||
|
data = matches[2]
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
err = errors.New("image base64 decode failed")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
isImage, err := IsImageUrl(url)
|
isImage, err := IsImageUrl(url)
|
||||||
if !isImage {
|
if !isImage {
|
||||||
|
if err == nil {
|
||||||
|
err = errors.New("Invalid image link")
|
||||||
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
resp, err := http.Get(url)
|
resp, err := http.Get(url)
|
||||||
|
@ -169,3 +169,34 @@ func TestGetImageSizeFromBase64(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestGetImageFromUrl(t *testing.T) {
|
||||||
|
for i, c := range cases {
|
||||||
|
t.Run("Decode:"+strconv.Itoa(i), func(t *testing.T) {
|
||||||
|
resp, err := http.Get(c.url)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
defer resp.Body.Close()
|
||||||
|
data, err := io.ReadAll(resp.Body)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
encoded := base64.StdEncoding.EncodeToString(data)
|
||||||
|
|
||||||
|
mimeType, base64Data, err := img.GetImageFromUrl(c.url)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, encoded, base64Data)
|
||||||
|
assert.Equal(t, "image/"+c.format, mimeType)
|
||||||
|
|
||||||
|
encodedBase64 := "data:image/" + c.format + ";base64," + encoded
|
||||||
|
mimeType, base64Data, err = img.GetImageFromUrl(encodedBase64)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, encoded, base64Data)
|
||||||
|
assert.Equal(t, "image/"+c.format, mimeType)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
url := "https://raw.githubusercontent.com/songquanpeng/one-api/main/README.md"
|
||||||
|
_, _, err := img.GetImageFromUrl(url)
|
||||||
|
assert.Error(t, err)
|
||||||
|
encodedBase64 := "data:image/text;base64,"
|
||||||
|
_, _, err = img.GetImageFromUrl(encodedBase64)
|
||||||
|
assert.Error(t, err)
|
||||||
|
}
|
||||||
|
@ -60,7 +60,7 @@ func (response *GeminiChatResponse) ResponseHandler(resp *http.Response) (OpenAI
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Setting safety to the lowest possible values since Gemini is already powerless enough
|
// Setting safety to the lowest possible values since Gemini is already powerless enough
|
||||||
func (p *GeminiProvider) getChatRequestBody(request *types.ChatCompletionRequest) (requestBody *GeminiChatRequest) {
|
func (p *GeminiProvider) getChatRequestBody(request *types.ChatCompletionRequest) (requestBody *GeminiChatRequest, errWithCode *types.OpenAIErrorWithStatusCode) {
|
||||||
geminiRequest := GeminiChatRequest{
|
geminiRequest := GeminiChatRequest{
|
||||||
Contents: make([]GeminiChatContent, 0, len(request.Messages)),
|
Contents: make([]GeminiChatContent, 0, len(request.Messages)),
|
||||||
//SafetySettings: []GeminiChatSafetySettings{
|
//SafetySettings: []GeminiChatSafetySettings{
|
||||||
@ -118,7 +118,10 @@ func (p *GeminiProvider) getChatRequestBody(request *types.ChatCompletionRequest
|
|||||||
if imageNum > GeminiVisionMaxImageNum {
|
if imageNum > GeminiVisionMaxImageNum {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
mimeType, data, _ := image.GetImageFromUrl(part.ImageURL.URL)
|
mimeType, data, err := image.GetImageFromUrl(part.ImageURL.URL)
|
||||||
|
if err != nil {
|
||||||
|
return nil, common.ErrorWrapper(err, "image_url_invalid", http.StatusBadRequest)
|
||||||
|
}
|
||||||
parts = append(parts, GeminiPart{
|
parts = append(parts, GeminiPart{
|
||||||
InlineData: &GeminiInlineData{
|
InlineData: &GeminiInlineData{
|
||||||
MimeType: mimeType,
|
MimeType: mimeType,
|
||||||
@ -154,11 +157,14 @@ func (p *GeminiProvider) getChatRequestBody(request *types.ChatCompletionRequest
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return &geminiRequest
|
return &geminiRequest, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *GeminiProvider) ChatAction(request *types.ChatCompletionRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) {
|
func (p *GeminiProvider) ChatAction(request *types.ChatCompletionRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) {
|
||||||
requestBody := p.getChatRequestBody(request)
|
requestBody, errWithCode := p.getChatRequestBody(request)
|
||||||
|
if errWithCode != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
fullRequestURL := p.GetFullRequestURL("generateContent", request.Model)
|
fullRequestURL := p.GetFullRequestURL("generateContent", request.Model)
|
||||||
headers := p.GetRequestHeaders()
|
headers := p.GetRequestHeaders()
|
||||||
if request.Stream {
|
if request.Stream {
|
||||||
|
Loading…
Reference in New Issue
Block a user