feat: support gemini-pro-vision

This commit is contained in:
JustSong 2023-12-24 18:54:32 +08:00
parent f3c07e1451
commit 1c8922153d
8 changed files with 139 additions and 10 deletions

View File

@ -15,7 +15,22 @@ import (
_ "golang.org/x/image/webp" _ "golang.org/x/image/webp"
) )
// IsImageUrl reports whether url points at an image, judged by issuing a HEAD
// request and checking that the Content-Type response header starts with
// "image/".
//
// It returns (false, err) when the request itself fails, and (false, nil) when
// the server answers with a non-image Content-Type.
func IsImageUrl(url string) (bool, error) {
	resp, err := http.Head(url)
	if err != nil {
		return false, err
	}
	// net/http requires the caller to close the response body, even for HEAD
	// responses that carry no payload.
	defer resp.Body.Close()
	return strings.HasPrefix(resp.Header.Get("Content-Type"), "image/"), nil
}
func GetImageSizeFromUrl(url string) (width int, height int, err error) { func GetImageSizeFromUrl(url string) (width int, height int, err error) {
isImage, err := IsImageUrl(url)
if !isImage {
return
}
resp, err := http.Get(url) resp, err := http.Get(url)
if err != nil { if err != nil {
return return
@ -28,6 +43,26 @@ func GetImageSizeFromUrl(url string) (width int, height int, err error) {
return img.Width, img.Height, nil return img.Width, img.Height, nil
} }
// GetImageFromUrl downloads the resource at url and returns its Content-Type
// header value together with the body encoded as standard base64.
//
// The URL is first probed with IsImageUrl; when that probe says the target is
// not an image, the function returns empty strings along with whatever error
// the probe produced (which may be nil). A failure of the GET request or of
// reading the body is returned as err with empty mimeType and data.
func GetImageFromUrl(url string) (mimeType string, data string, err error) {
	var ok bool
	if ok, err = IsImageUrl(url); !ok {
		return
	}

	var resp *http.Response
	if resp, err = http.Get(url); err != nil {
		return
	}
	defer resp.Body.Close()

	// Read the whole body into memory before encoding it.
	var body bytes.Buffer
	if _, err = body.ReadFrom(resp.Body); err != nil {
		return
	}

	mimeType = resp.Header.Get("Content-Type")
	data = base64.StdEncoding.EncodeToString(body.Bytes())
	return
}
var ( var (
reg = regexp.MustCompile(`data:image/([^;]+);base64,`) reg = regexp.MustCompile(`data:image/([^;]+);base64,`)
) )

View File

@ -84,6 +84,7 @@ var ModelRatio = map[string]float64{
"Embedding-V1": 0.1429, // ¥0.002 / 1k tokens "Embedding-V1": 0.1429, // ¥0.002 / 1k tokens
"PaLM-2": 1, "PaLM-2": 1,
"gemini-pro": 1, // $0.00025 / 1k characters -> $0.001 / 1k tokens "gemini-pro": 1, // $0.00025 / 1k characters -> $0.001 / 1k tokens
"gemini-pro-vision": 1, // $0.00025 / 1k characters -> $0.001 / 1k tokens
"chatglm_turbo": 0.3572, // ¥0.005 / 1k tokens "chatglm_turbo": 0.3572, // ¥0.005 / 1k tokens
"chatglm_pro": 0.7143, // ¥0.01 / 1k tokens "chatglm_pro": 0.7143, // ¥0.01 / 1k tokens
"chatglm_std": 0.3572, // ¥0.005 / 1k tokens "chatglm_std": 0.3572, // ¥0.005 / 1k tokens

View File

@ -432,6 +432,15 @@ func init() {
Root: "gemini-pro", Root: "gemini-pro",
Parent: nil, Parent: nil,
}, },
{
Id: "gemini-pro-vision",
Object: "model",
Created: 1677649963,
OwnedBy: "google",
Permission: permission,
Root: "gemini-pro-vision",
Parent: nil,
},
{ {
Id: "chatglm_turbo", Id: "chatglm_turbo",
Object: "model", Object: "model",

View File

@ -7,11 +7,18 @@ import (
"io" "io"
"net/http" "net/http"
"one-api/common" "one-api/common"
"one-api/common/image"
"strings" "strings"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )
// Reference: https://ai.google.dev/docs/gemini_api_overview?hl=zh-cn
const (
// GeminiVisionMaxImageNum is the cap on image parts converted for Gemini;
// the request conversion skips image parts once this count is exceeded.
// NOTE(review): presumably mirrors the image limit stated in the Gemini
// documentation linked above — confirm against the current docs.
GeminiVisionMaxImageNum = 16
)
type GeminiChatRequest struct { type GeminiChatRequest struct {
Contents []GeminiChatContent `json:"contents"` Contents []GeminiChatContent `json:"contents"`
SafetySettings []GeminiChatSafetySettings `json:"safety_settings,omitempty"` SafetySettings []GeminiChatSafetySettings `json:"safety_settings,omitempty"`
@ -97,6 +104,30 @@ func requestOpenAI2Gemini(textRequest GeneralOpenAIRequest) *GeminiChatRequest {
}, },
}, },
} }
openaiContent := message.ParseContent()
var parts []GeminiPart
imageNum := 0
for _, part := range openaiContent {
if part.Type == ContentTypeText {
parts = append(parts, GeminiPart{
Text: part.Text,
})
} else if part.Type == ContentTypeImageURL {
imageNum += 1
if imageNum > GeminiVisionMaxImageNum {
continue
}
mimeType, data, _ := image.GetImageFromUrl(part.ImageURL.Url)
parts = append(parts, GeminiPart{
InlineData: &GeminiInlineData{
MimeType: mimeType,
Data: data,
},
})
}
}
content.Parts = parts
// there's no assistant role in gemini and API shall vomit if Role is not user or model // there's no assistant role in gemini and API shall vomit if Role is not user or model
if content.Role == "assistant" { if content.Role == "assistant" {
content.Role = "model" content.Role = "model"

View File

@ -180,9 +180,6 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
if baseURL != "" { if baseURL != "" {
fullRequestURL = fmt.Sprintf("%s/v1beta2/models/chat-bison-001:generateMessage", baseURL) fullRequestURL = fmt.Sprintf("%s/v1beta2/models/chat-bison-001:generateMessage", baseURL)
} }
apiKey := c.Request.Header.Get("Authorization")
apiKey = strings.TrimPrefix(apiKey, "Bearer ")
fullRequestURL += "?key=" + apiKey
case APITypeGemini: case APITypeGemini:
requestBaseURL := "https://generativelanguage.googleapis.com" requestBaseURL := "https://generativelanguage.googleapis.com"
if baseURL != "" { if baseURL != "" {
@ -197,9 +194,6 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
action = "streamGenerateContent" action = "streamGenerateContent"
} }
fullRequestURL = fmt.Sprintf("%s/%s/models/%s:%s", requestBaseURL, version, textRequest.Model, action) fullRequestURL = fmt.Sprintf("%s/%s/models/%s:%s", requestBaseURL, version, textRequest.Model, action)
apiKey := c.Request.Header.Get("Authorization")
apiKey = strings.TrimPrefix(apiKey, "Bearer ")
fullRequestURL += "?key=" + apiKey
case APITypeZhipu: case APITypeZhipu:
method := "invoke" method := "invoke"
if textRequest.Stream { if textRequest.Stream {
@ -396,9 +390,9 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
case APITypeTencent: case APITypeTencent:
req.Header.Set("Authorization", apiKey) req.Header.Set("Authorization", apiKey)
case APITypePaLM: case APITypePaLM:
// do not set Authorization header req.Header.Set("x-goog-api-key", apiKey)
case APITypeGemini: case APITypeGemini:
// do not set Authorization header req.Header.Set("x-goog-api-key", apiKey)
default: default:
req.Header.Set("Authorization", "Bearer "+apiKey) req.Header.Set("Authorization", "Bearer "+apiKey)
} }

View File

@ -31,6 +31,22 @@ type ImageContent struct {
ImageURL *ImageURL `json:"image_url,omitempty"` ImageURL *ImageURL `json:"image_url,omitempty"`
} }
// Recognized values of the "type" field of a message content part.
const (
// ContentTypeText marks a part that carries text.
ContentTypeText = "text"
// ContentTypeImageURL marks a part that carries an image URL object.
ContentTypeImageURL = "image_url"
)
// OpenAIMessageContent is a single typed part of a message's content, as
// produced by Message.ParseContent.
type OpenAIMessageContent struct {
// Type is the part kind: ContentTypeText or ContentTypeImageURL.
Type string `json:"type,omitempty"`
// Text holds the text of a text part; it is always serialized (no omitempty).
Text string `json:"text"`
// ImageURL is set for image parts and omitted from JSON when nil.
ImageURL *ImageURL `json:"image_url,omitempty"`
}
// IsStringContent reports whether the message's Content holds a plain string
// rather than any other dynamic type.
func (m Message) IsStringContent() bool {
	switch m.Content.(type) {
	case string:
		return true
	default:
		return false
	}
}
func (m Message) StringContent() string { func (m Message) StringContent() string {
content, ok := m.Content.(string) content, ok := m.Content.(string)
if ok { if ok {
@ -44,7 +60,7 @@ func (m Message) StringContent() string {
if !ok { if !ok {
continue continue
} }
if contentMap["type"] == "text" { if contentMap["type"] == ContentTypeText {
if subStr, ok := contentMap["text"].(string); ok { if subStr, ok := contentMap["text"].(string); ok {
contentStr += subStr contentStr += subStr
} }
@ -55,6 +71,47 @@ func (m Message) StringContent() string {
return "" return ""
} }
// ParseContent normalizes the message's Content into a list of typed parts.
//
// A plain string becomes a single text part. A []any is walked item by item:
// a map with type "text" and a string "text" yields a text part, and a map
// with type "image_url" whose "image_url" object carries a string "url"
// yields an image part. Items that are malformed or of an unrecognized type
// are skipped. Any other Content shape yields nil.
func (m Message) ParseContent() []OpenAIMessageContent {
	var contentList []OpenAIMessageContent

	if content, ok := m.Content.(string); ok {
		contentList = append(contentList, OpenAIMessageContent{
			Type: ContentTypeText,
			Text: content,
		})
		return contentList
	}

	anyList, ok := m.Content.([]any)
	if !ok {
		return nil
	}

	for _, contentItem := range anyList {
		contentMap, ok := contentItem.(map[string]any)
		if !ok {
			continue
		}
		switch contentMap["type"] {
		case ContentTypeText:
			if subStr, ok := contentMap["text"].(string); ok {
				contentList = append(contentList, OpenAIMessageContent{
					Type: ContentTypeText,
					Text: subStr,
				})
			}
		case ContentTypeImageURL:
			subObj, ok := contentMap["image_url"].(map[string]any)
			if !ok {
				continue
			}
			// The content comes from the decoded request body, so "url" may be
			// missing or not a string; use the checked assertion so a malformed
			// part is skipped instead of panicking.
			imageURL, ok := subObj["url"].(string)
			if !ok {
				continue
			}
			contentList = append(contentList, OpenAIMessageContent{
				Type: ContentTypeImageURL,
				ImageURL: &ImageURL{
					Url: imageURL,
				},
			})
		}
	}
	return contentList
}
const ( const (
RelayModeUnknown = iota RelayModeUnknown = iota
RelayModeChatCompletions RelayModeChatCompletions

View File

@ -5,6 +5,7 @@ import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"net/http" "net/http"
"one-api/common" "one-api/common"
"runtime/debug"
) )
func RelayPanicRecover() gin.HandlerFunc { func RelayPanicRecover() gin.HandlerFunc {
@ -12,6 +13,7 @@ func RelayPanicRecover() gin.HandlerFunc {
defer func() { defer func() {
if err := recover(); err != nil { if err := recover(); err != nil {
common.SysError(fmt.Sprintf("panic detected: %v", err)) common.SysError(fmt.Sprintf("panic detected: %v", err))
common.SysError(fmt.Sprintf("stacktrace from panic: %s", string(debug.Stack())))
c.JSON(http.StatusInternalServerError, gin.H{ c.JSON(http.StatusInternalServerError, gin.H{
"error": gin.H{ "error": gin.H{
"message": fmt.Sprintf("Panic detected, error: %v. Please submit a issue here: https://github.com/songquanpeng/one-api", err), "message": fmt.Sprintf("Panic detected, error: %v. Please submit a issue here: https://github.com/songquanpeng/one-api", err),

View File

@ -91,7 +91,7 @@ const EditChannel = () => {
localModels = ['hunyuan']; localModels = ['hunyuan'];
break; break;
case 24: case 24:
localModels = ['gemini-pro']; localModels = ['gemini-pro', 'gemini-pro-vision'];
break; break;
} }
setInputs((inputs) => ({ ...inputs, models: localModels })); setInputs((inputs) => ({ ...inputs, models: localModels }));