From 379074f7d011df8fcbfe3cd6bdfbca8967d3ab91 Mon Sep 17 00:00:00 2001 From: JustSong Date: Sun, 10 Dec 2023 17:22:52 +0800 Subject: [PATCH 01/16] feat: support plugin for ali channel (close #797) --- controller/relay-text.go | 3 +++ middleware/distributor.go | 2 ++ web/src/pages/Channel/EditChannel.js | 14 ++++++++++++++ 3 files changed, 19 insertions(+) diff --git a/controller/relay-text.go b/controller/relay-text.go index a3e233d3..a69c7f8b 100644 --- a/controller/relay-text.go +++ b/controller/relay-text.go @@ -360,6 +360,9 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { if textRequest.Stream { req.Header.Set("X-DashScope-SSE", "enable") } + if c.GetString("plugin") != "" { + req.Header.Set("X-DashScope-Plugin", c.GetString("plugin")) + } case APITypeTencent: req.Header.Set("Authorization", apiKey) case APITypePaLM: diff --git a/middleware/distributor.go b/middleware/distributor.go index c4ddc3a0..8be986c9 100644 --- a/middleware/distributor.go +++ b/middleware/distributor.go @@ -89,6 +89,8 @@ func Distribute() func(c *gin.Context) { c.Set("api_version", channel.Other) case common.ChannelTypeAIProxyLibrary: c.Set("library_id", channel.Other) + case common.ChannelTypeAli: + c.Set("plugin", channel.Other) } c.Next() } diff --git a/web/src/pages/Channel/EditChannel.js b/web/src/pages/Channel/EditChannel.js index bc3886a0..62e8a155 100644 --- a/web/src/pages/Channel/EditChannel.js +++ b/web/src/pages/Channel/EditChannel.js @@ -343,6 +343,20 @@ const EditChannel = () => { ) } + { + inputs.type === 17 && ( + + + + ) + } Date: Sun, 10 Dec 2023 18:39:14 +0800 Subject: [PATCH 02/16] feat: refactor response parsing logic to support multiple formats (#782) * feat: Refactor response parsing logic to support multiple formats The parsing logic for responses in relay.go and relay-audio.go was refactored to support multiple response formats - 'json', 'text', 'srt', 'verbose_json', and 'vtt'. The existing `WhisperResponse` struct was renamed to `WhisperJsonResponse` and a new struct `WhisperVerboseJsonResponse` was added to support the 'verbose_json' format. Additional parsing functions were added to extract text from these new response types. This change was necessary to make the parsing logic more flexible and extendable for different types of responses. * chore: update name --------- Co-authored-by: JustSong --- controller/relay-audio.go | 85 ++++++++++++++++++++++++++++++++++++--- controller/relay.go | 23 ++++++++++- 2 files changed, 101 insertions(+), 7 deletions(-) diff --git a/controller/relay-audio.go b/controller/relay-audio.go index 9e78dadc..2247f4c7 100644 --- a/controller/relay-audio.go +++ b/controller/relay-audio.go @@ -1,6 +1,7 @@ package controller import ( + "bufio" "bytes" "context" "encoding/json" @@ -102,7 +103,13 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/audio/transcriptions?api-version=%s", baseURL, audioModel, apiVersion) } - requestBody := c.Request.Body + requestBody := &bytes.Buffer{} + _, err = io.Copy(requestBody, c.Request.Body) + if err != nil { + return errorWrapper(err, "new_request_body_failed", http.StatusInternalServerError) + } + c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody.Bytes())) + responseFormat := c.DefaultPostForm("response_format", "json") req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody) if err != nil { @@ -144,12 +151,33 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode if err != nil { return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError) } - var whisperResponse WhisperResponse - err = json.Unmarshal(responseBody, &whisperResponse) - if err != nil { - return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError) + + var openAIErr TextResponse + if err = json.Unmarshal(responseBody, &openAIErr); err == nil { + if openAIErr.Error.Message != "" { + return errorWrapper(fmt.Errorf("type %s, code %v, message %s", openAIErr.Error.Type, openAIErr.Error.Code, openAIErr.Error.Message), "request_error", http.StatusInternalServerError) + } } - quota = countTokenText(whisperResponse.Text, audioModel) + + var text string + switch responseFormat { + case "json": + text, err = getTextFromJSON(responseBody) + case "text": + text, err = getTextFromText(responseBody) + case "srt": + text, err = getTextFromSRT(responseBody) + case "verbose_json": + text, err = getTextFromVerboseJSON(responseBody) + case "vtt": + text, err = getTextFromVTT(responseBody) + default: + return errorWrapper(errors.New("unexpected_response_format"), "unexpected_response_format", http.StatusInternalServerError) + } + if err != nil { + return errorWrapper(err, "get_text_from_body_err", http.StatusInternalServerError) + } + quota = countTokenText(text, audioModel) resp.Body = io.NopCloser(bytes.NewBuffer(responseBody)) } if resp.StatusCode != http.StatusOK { @@ -187,3 +215,48 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode } return nil } + +func getTextFromVTT(body []byte) (string, error) { + return getTextFromSRT(body) +} + +func getTextFromVerboseJSON(body []byte) (string, error) { + var whisperResponse WhisperVerboseJSONResponse + if err := json.Unmarshal(body, &whisperResponse); err != nil { + return "", fmt.Errorf("unmarshal_response_body_failed err :%w", err) + } + return whisperResponse.Text, nil +} + +func getTextFromSRT(body []byte) (string, error) { + scanner := bufio.NewScanner(strings.NewReader(string(body))) + var builder strings.Builder + var textLine bool + for scanner.Scan() { + line := scanner.Text() + if textLine { + builder.WriteString(line) + textLine = false + continue + } else if strings.Contains(line, "-->") { + textLine = true + continue + } + } + if err := scanner.Err(); err != nil { + return "", err + } + return builder.String(), nil +} + +func getTextFromText(body []byte) (string, error) { + return strings.TrimSuffix(string(body), "\n"), nil +} + +func getTextFromJSON(body []byte) (string, error) { + var whisperResponse WhisperJSONResponse + if err := json.Unmarshal(body, &whisperResponse); err != nil { + return "", fmt.Errorf("unmarshal_response_body_failed err :%w", err) + } + return whisperResponse.Text, nil +} diff --git a/controller/relay.go b/controller/relay.go index 58ee8381..0e660a68 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -141,10 +141,31 @@ type ImageRequest struct { User string `json:"user,omitempty"` } -type WhisperResponse struct { +type WhisperJSONResponse struct { Text string `json:"text,omitempty"` } +type WhisperVerboseJSONResponse struct { + Task string `json:"task,omitempty"` + Language string `json:"language,omitempty"` + Duration float64 `json:"duration,omitempty"` + Text string `json:"text,omitempty"` + Segments []Segment `json:"segments,omitempty"` +} + +type Segment struct { + Id int `json:"id"` + Seek int `json:"seek"` + Start float64 `json:"start"` + End float64 `json:"end"` + Text string `json:"text"` + Tokens []int `json:"tokens"` + Temperature float64 `json:"temperature"` + AvgLogprob float64 `json:"avg_logprob"` + CompressionRatio float64 `json:"compression_ratio"` + NoSpeechProb float64 `json:"no_speech_prob"` +} + type TextToSpeechRequest struct { Model string `json:"model" binding:"required"` Input string `json:"input" binding:"required"` From 4c5feee0b66e0d42091ef9fa4bd1f5f3f78ec38f Mon Sep 17 00:00:00 2001 From: Qiying Wang <781345688@qq.com> Date: Sun, 10 Dec 2023 19:39:46 +0800 Subject: [PATCH 03/16] feat: add image counter for gpt-4 vision (#795) --- common/image/image.go | 47 +++++++++++ common/image/image_test.go | 154 +++++++++++++++++++++++++++++++++++++ controller/relay-text.go | 4 + controller/relay-utils.go | 105 ++++++++++++++++++++++++- go.mod | 6 +- go.sum | 6 +- 6 files changed, 315 insertions(+), 7 deletions(-) create mode 100644 common/image/image.go create mode 100644 common/image/image_test.go diff --git a/common/image/image.go b/common/image/image.go new file mode 100644 index 00000000..cbb656ad --- /dev/null +++ b/common/image/image.go @@ -0,0 +1,47 @@ +package image + +import ( + "image" + _ "image/gif" + _ "image/jpeg" + _ "image/png" + "net/http" + "regexp" + "strings" + + _ "golang.org/x/image/webp" +) + +func GetImageSizeFromUrl(url string) (width int, height int, err error) { + resp, err := http.Get(url) + if err != nil { + return + } + defer resp.Body.Close() + img, _, err := image.DecodeConfig(resp.Body) + if err != nil { + return + } + return img.Width, img.Height, nil +} + +var ( + reg = regexp.MustCompile(`data:image/([^;]+);base64,`) +) + +func GetImageSizeFromBase64(encoded string) (width int, height int, err error) { + encoded = strings.TrimPrefix(encoded, "data:image/png;base64,") + base64 := strings.NewReader(reg.ReplaceAllString(encoded, "")) + img, _, err := image.DecodeConfig(base64) + if err != nil { + return + } + return img.Width, img.Height, nil +} + +func GetImageSize(image string) (width int, height int, err error) { + if strings.HasPrefix(image, "data:image/") { + return GetImageSizeFromBase64(image) + } + return GetImageSizeFromUrl(image) +} diff --git a/common/image/image_test.go b/common/image/image_test.go new file mode 100644 index 00000000..366eda6e --- /dev/null +++ b/common/image/image_test.go @@ -0,0 +1,154 @@ +package image_test + +import ( + "encoding/base64" + "image" + _ "image/gif" + _ "image/jpeg" + _ "image/png" + "io" + "net/http" + "strconv" + "strings" + "testing" + + img "one-api/common/image" + + "github.com/stretchr/testify/assert" + _ "golang.org/x/image/webp" +) + +type CountingReader struct { + reader io.Reader + BytesRead int +} + +func (r *CountingReader) Read(p []byte) (n int, err error) { + n, err = r.reader.Read(p) + r.BytesRead += n + return n, err +} + +var ( + cases = []struct { + url string + format string + width int + height int + }{ + {"https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg", "jpeg", 2560, 1669}, + {"https://upload.wikimedia.org/wikipedia/commons/9/97/Basshunter_live_performances.png", "png", 4500, 2592}, + {"https://upload.wikimedia.org/wikipedia/commons/c/c6/TO_THE_ONE_SOMETHINGNESS.webp", "webp", 984, 985}, + {"https://upload.wikimedia.org/wikipedia/commons/d/d0/01_Das_Sandberg-Modell.gif", "gif", 1917, 1533}, + {"https://upload.wikimedia.org/wikipedia/commons/6/62/102Cervus.jpg", "jpeg", 270, 230}, + } +) + +func TestDecode(t *testing.T) { + // Bytes read: varies sometimes + // jpeg: 1063892 + // png: 294462 + // webp: 99529 + // gif: 956153 + // jpeg#01: 32805 + for _, c := range cases { + t.Run("Decode:"+c.format, func(t *testing.T) { + resp, err := http.Get(c.url) + assert.NoError(t, err) + defer resp.Body.Close() + reader := &CountingReader{reader: resp.Body} + img, format, err := image.Decode(reader) + assert.NoError(t, err) + size := img.Bounds().Size() + assert.Equal(t, c.format, format) + assert.Equal(t, c.width, size.X) + assert.Equal(t, c.height, size.Y) + t.Logf("Bytes read: %d", reader.BytesRead) + }) + } + + // Bytes read: + // jpeg: 4096 + // png: 4096 + // webp: 4096 + // gif: 4096 + // jpeg#01: 4096 + for _, c := range cases { + t.Run("DecodeConfig:"+c.format, func(t *testing.T) { + resp, err := http.Get(c.url) + assert.NoError(t, err) + defer resp.Body.Close() + reader := &CountingReader{reader: resp.Body} + config, format, err := image.DecodeConfig(reader) + assert.NoError(t, err) + assert.Equal(t, c.format, format) + assert.Equal(t, c.width, config.Width) + assert.Equal(t, c.height, config.Height) + t.Logf("Bytes read: %d", reader.BytesRead) + }) + } +} + +func TestBase64(t *testing.T) { + // Bytes read: + // jpeg: 1063892 + // png: 294462 + // webp: 99072 + // gif: 953856 + // jpeg#01: 32805 + for _, c := range cases { + t.Run("Decode:"+c.format, func(t *testing.T) { + resp, err := http.Get(c.url) + assert.NoError(t, err) + defer resp.Body.Close() + data, err := io.ReadAll(resp.Body) + assert.NoError(t, err) + encoded := base64.StdEncoding.EncodeToString(data) + body := base64.NewDecoder(base64.StdEncoding, strings.NewReader(encoded)) + reader := &CountingReader{reader: body} + img, format, err := image.Decode(reader) + assert.NoError(t, err) + size := img.Bounds().Size() + assert.Equal(t, c.format, format) + assert.Equal(t, c.width, size.X) + assert.Equal(t, c.height, size.Y) + t.Logf("Bytes read: %d", reader.BytesRead) + }) + } + + // Bytes read: + // jpeg: 1536 + // png: 768 + // webp: 768 + // gif: 1536 + // jpeg#01: 3840 + for _, c := range cases { + t.Run("DecodeConfig:"+c.format, func(t *testing.T) { + resp, err := http.Get(c.url) + assert.NoError(t, err) + defer resp.Body.Close() + data, err := io.ReadAll(resp.Body) + assert.NoError(t, err) + encoded := base64.StdEncoding.EncodeToString(data) + body := base64.NewDecoder(base64.StdEncoding, strings.NewReader(encoded)) + reader := &CountingReader{reader: body} + config, format, err := image.DecodeConfig(reader) + assert.NoError(t, err) + assert.Equal(t, c.format, format) + assert.Equal(t, c.width, config.Width) + assert.Equal(t, c.height, config.Height) + t.Logf("Bytes read: %d", reader.BytesRead) + }) + } +} + +func TestGetImageSize(t *testing.T) { + for i, c := range cases { + t.Run("Decode:"+strconv.Itoa(i), func(t *testing.T) { + width, height, err := img.GetImageSize(c.url) + assert.NoError(t, err) + assert.Equal(t, c.width, width) + assert.Equal(t, c.height, height) + }) + } +} diff --git a/controller/relay-text.go b/controller/relay-text.go index a69c7f8b..c3d54059 100644 --- a/controller/relay-text.go +++ b/controller/relay-text.go @@ -410,6 +410,10 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { defer func(ctx context.Context) { // c.Writer.Flush() go func() { + if promptTokens != textResponse.PromptTokens { + common.SysError(fmt.Sprintf("prompt tokens not match, expected %d, actual %d", promptTokens, textResponse.PromptTokens)) + } + quota := 0 completionRatio := common.GetCompletionRatio(textRequest.Model) promptTokens = textResponse.Usage.PromptTokens diff --git a/controller/relay-utils.go b/controller/relay-utils.go index 38408c7f..9deca75a 100644 --- a/controller/relay-utils.go +++ b/controller/relay-utils.go @@ -3,10 +3,13 @@ package controller import ( "context" "encoding/json" + "errors" "fmt" "io" + "math" "net/http" "one-api/common" + "one-api/common/image" "one-api/model" "strconv" "strings" @@ -87,7 +90,33 @@ func countTokenMessages(messages []Message, model string) int { tokenNum := 0 for _, message := range messages { tokenNum += tokensPerMessage - tokenNum += getTokenNum(tokenEncoder, message.StringContent()) + switch v := message.Content.(type) { + case string: + tokenNum += getTokenNum(tokenEncoder, v) + case []any: + for _, it := range v { + m := it.(map[string]any) + switch m["type"] { + case "text": + tokenNum += getTokenNum(tokenEncoder, m["text"].(string)) + case "image_url": + imageUrl, ok := m["image_url"].(map[string]any) + if ok { + url := imageUrl["url"].(string) + detail := "" + if imageUrl["detail"] != nil { + detail = imageUrl["detail"].(string) + } + imageTokens, err := countImageTokens(url, detail) + if err != nil { + common.SysError("error counting image tokens: " + err.Error()) + } else { + tokenNum += imageTokens + } + } + } + } + } tokenNum += getTokenNum(tokenEncoder, message.Role) if message.Name != nil { tokenNum += tokensPerName @@ -98,13 +127,81 @@ func countTokenMessages(messages []Message, model string) int { return tokenNum } +const ( + lowDetailCost = 85 + highDetailCostPerTile = 170 + additionalCost = 85 +) + +// https://platform.openai.com/docs/guides/vision/calculating-costs +// https://github.com/openai/openai-cookbook/blob/05e3f9be4c7a2ae7ecf029a7c32065b024730ebe/examples/How_to_count_tokens_with_tiktoken.ipynb +func countImageTokens(url string, detail string) (_ int, err error) { + var fetchSize = true + var width, height int + // Reference: https://platform.openai.com/docs/guides/vision/low-or-high-fidelity-image-understanding + // detail == "auto" is undocumented on how it works, it just said the model will use the auto setting which will look at the image input size and decide if it should use the low or high setting. + // According to the official guide, "low" disable the high-res model, + // and only receive low-res 512px x 512px version of the image, indicating + // that image is treated as low-res when size is smaller than 512px x 512px, + // then we can assume that image size larger than 512px x 512px is treated + // as high-res. Then we have the following logic: + // if detail == "" || detail == "auto" { + // width, height, err = image.GetImageSize(url) + // if err != nil { + // return 0, err + // } + // fetchSize = false + // // not sure if this is correct + // if width > 512 || height > 512 { + // detail = "high" + // } else { + // detail = "low" + // } + // } + + // However, in my test, it seems to be always the same as "high". + // The following image, which is 125x50, is still treated as high-res, taken + // 255 tokens in the response of non-stream chat completion api. + // https://upload.wikimedia.org/wikipedia/commons/1/10/18_Infantry_Division_Messina.jpg + if detail == "" || detail == "auto" { + // assume by test, not sure if this is correct + detail = "high" + } + switch detail { + case "low": + return lowDetailCost, nil + case "high": + if fetchSize { + width, height, err = image.GetImageSize(url) + if err != nil { + return 0, err + } + } + if width > 2048 || height > 2048 { // max(width, height) > 2048 + ratio := float64(2048) / math.Max(float64(width), float64(height)) + width = int(float64(width) * ratio) + height = int(float64(height) * ratio) + } + if width > 768 && height > 768 { // min(width, height) > 768 + ratio := float64(768) / math.Min(float64(width), float64(height)) + width = int(float64(width) * ratio) + height = int(float64(height) * ratio) + } + numSquares := int(math.Ceil(float64(width)/512) * math.Ceil(float64(height)/512)) + result := numSquares*highDetailCostPerTile + additionalCost + return result, nil + default: + return 0, errors.New("invalid detail option") + } +} + func countTokenInput(input any, model string) int { - switch input.(type) { + switch v := input.(type) { case string: - return countTokenText(input.(string), model) + return countTokenText(v, model) case []string: text := "" - for _, s := range input.([]string) { + for _, s := range v { text += s } return countTokenText(text, model) diff --git a/go.mod b/go.mod index 10b78d68..1fe5eabc 100644 --- a/go.mod +++ b/go.mod @@ -15,7 +15,9 @@ require ( github.com/google/uuid v1.3.0 github.com/gorilla/websocket v1.5.0 github.com/pkoukk/tiktoken-go v0.1.5 + github.com/stretchr/testify v1.8.3 golang.org/x/crypto v0.14.0 + golang.org/x/image v0.14.0 gorm.io/driver/mysql v1.4.3 gorm.io/driver/postgres v1.5.2 gorm.io/driver/sqlite v1.4.3 @@ -26,6 +28,7 @@ require ( github.com/bytedance/sonic v1.9.1 // indirect github.com/cespare/xxhash/v2 v2.1.2 // indirect github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 // indirect + github.com/davecgh/go-spew v1.1.1 // indirect github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect github.com/dlclark/regexp2 v1.10.0 // indirect github.com/gabriel-vasile/mimetype v1.4.2 // indirect @@ -50,12 +53,13 @@ require ( github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect github.com/modern-go/reflect2 v1.0.2 // indirect github.com/pelletier/go-toml/v2 v2.0.8 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect github.com/twitchyliquid64/golang-asm v0.15.1 // indirect github.com/ugorji/go/codec v1.2.11 // indirect golang.org/x/arch v0.3.0 // indirect golang.org/x/net v0.17.0 // indirect golang.org/x/sys v0.13.0 // indirect - golang.org/x/text v0.13.0 // indirect + golang.org/x/text v0.14.0 // indirect google.golang.org/protobuf v1.30.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index 4865bcaa..fb252aa7 100644 --- a/go.sum +++ b/go.sum @@ -152,6 +152,8 @@ golang.org/x/arch v0.3.0/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8= golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/crypto v0.14.0 h1:wBqGXzWJW6m1XrIKlAH0Hs1JJ7+9KBwnIO8v66Q9cHc= golang.org/x/crypto v0.14.0/go.mod h1:MVFd36DqK4CsrnJYDkBA3VC4m2GkXAM0PvzMCn4JQf4= +golang.org/x/image v0.14.0 h1:tNgSxAFe3jC4uYqvZdTr84SZoM1KfwdC9SKIFrLjFn4= +golang.org/x/image v0.14.0/go.mod h1:HUYqC05R2ZcZ3ejNQsIHQDQiwWM4JBqmm6MKANTp4LE= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= golang.org/x/net v0.17.0 h1:pVaXccu2ozPjCXewfr1S7xza/zcXTity9cCdXQYSjIM= golang.org/x/net v0.17.0/go.mod h1:NxSsAGuq816PNPmqtQdLE42eU2Fs7NoRIZrHJAlaCOE= @@ -168,8 +170,8 @@ golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9sn golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= -golang.org/x/text v0.13.0 h1:ablQoSUd0tRdKxZewP80B+BaqeKJuVhuRxj/dkrun3k= -golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= +golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ= +golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= From 2a70744dbf266f25170ae04ca4ae7fe007ef6018 Mon Sep 17 00:00:00 2001 From: JustSong Date: Sun, 10 Dec 2023 19:53:33 +0800 Subject: [PATCH 04/16] feat: add panic recover middleware --- middleware/recover.go | 26 ++++++++++++++++++++++++++ router/relay-router.go | 2 +- 2 files changed, 27 insertions(+), 1 deletion(-) create mode 100644 middleware/recover.go diff --git a/middleware/recover.go b/middleware/recover.go new file mode 100644 index 00000000..c3a3d748 --- /dev/null +++ b/middleware/recover.go @@ -0,0 +1,26 @@ +package middleware + +import ( + "fmt" + "github.com/gin-gonic/gin" + "net/http" + "one-api/common" +) + +func RelayPanicRecover() gin.HandlerFunc { + return func(c *gin.Context) { + defer func() { + if err := recover(); err != nil { + common.SysError(fmt.Sprintf("panic detected: %v", err)) + c.JSON(http.StatusInternalServerError, gin.H{ + "error": gin.H{ + "message": fmt.Sprintf("Panic detected, error: %v. Please submit a issue here: https://github.com/songquanpeng/one-api", err), + "type": "one_api_panic", + }, + }) + c.Abort() + } + }() + c.Next() + } +} diff --git a/router/relay-router.go b/router/relay-router.go index 24edc9a9..56ab9b28 100644 --- a/router/relay-router.go +++ b/router/relay-router.go @@ -17,7 +17,7 @@ func SetRelayRouter(router *gin.Engine) { modelsRouter.GET("/:model", controller.RetrieveModel) } relayV1Router := router.Group("/v1") - relayV1Router.Use(middleware.TokenAuth(), middleware.Distribute()) + relayV1Router.Use(middleware.RelayPanicRecover(), middleware.TokenAuth(), middleware.Distribute()) { relayV1Router.POST("/completions", controller.Relay) relayV1Router.POST("/chat/completions", controller.Relay) From 366b82128f89a328f096da6951cbafebb6b0060f Mon Sep 17 00:00:00 2001 From: JustSong Date: Sun, 10 Dec 2023 20:44:37 +0800 Subject: [PATCH 05/16] fix: remove incorrect logging --- controller/relay-text.go | 4 ---- 1 file changed, 4 deletions(-) diff --git a/controller/relay-text.go b/controller/relay-text.go index c3d54059..a69c7f8b 100644 --- a/controller/relay-text.go +++ b/controller/relay-text.go @@ -410,10 +410,6 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { defer func(ctx context.Context) { // c.Writer.Flush() go func() { - if promptTokens != textResponse.PromptTokens { - common.SysError(fmt.Sprintf("prompt tokens not match, expected %d, actual %d", promptTokens, textResponse.PromptTokens)) - } - quota := 0 completionRatio := common.GetCompletionRatio(textRequest.Model) promptTokens = textResponse.Usage.PromptTokens From 5cf23d8698cfebd720a6722e4b89a839c8a804c0 Mon Sep 17 00:00:00 2001 From: David Zhuang Date: Sat, 16 Dec 2023 23:48:32 -0500 Subject: [PATCH 06/16] feat: add Google Gemini Pro support (#826) * fest: Add Google Gemini Pro, fix #810 * fest: Add tooling to Gemini; Add OpenAI-like system prompt to Gemini * refactor: removing unused if statement * fest: Add dummy model message for system message in gemini model * chore: update implementation --------- Co-authored-by: JustSong --- README.en.md | 2 +- README.ja.md | 2 +- README.md | 2 +- common/constants.go | 2 + common/model-ratio.go | 1 + controller/channel-test.go | 2 + controller/model.go | 9 + controller/relay-gemini.go | 281 +++++++++++++++++++++++++ controller/relay-text.go | 49 +++++ middleware/distributor.go | 2 + web/src/constants/channel.constants.js | 1 + web/src/pages/Channel/EditChannel.js | 3 + 12 files changed, 353 insertions(+), 3 deletions(-) create mode 100644 controller/relay-gemini.go diff --git a/README.en.md b/README.en.md index 9345a219..82dceb5b 100644 --- a/README.en.md +++ b/README.en.md @@ -60,7 +60,7 @@ _✨ Access all LLM through the standard OpenAI API format, easy to deploy & use 1. Support for multiple large models: + [x] [OpenAI ChatGPT Series Models](https://platform.openai.com/docs/guides/gpt/chat-completions-api) (Supports [Azure OpenAI API](https://learn.microsoft.com/en-us/azure/ai-services/openai/reference)) + [x] [Anthropic Claude Series Models](https://anthropic.com) - + [x] [Google PaLM2 Series Models](https://developers.generativeai.google) + + [x] [Google PaLM2 and Gemini Series Models](https://developers.generativeai.google) + [x] [Baidu Wenxin Yiyuan Series Models](https://cloud.baidu.com/doc/WENXINWORKSHOP/index.html) + [x] [Alibaba Tongyi Qianwen Series Models](https://help.aliyun.com/document_detail/2400395.html) + [x] [Zhipu ChatGLM Series Models](https://bigmodel.cn) diff --git a/README.ja.md b/README.ja.md index 6faf9bee..089fc2b5 100644 --- a/README.ja.md +++ b/README.ja.md @@ -60,7 +60,7 @@ _✨ 標準的な OpenAI API フォーマットを通じてすべての LLM に 1. 複数の大型モデルをサポート: + [x] [OpenAI ChatGPT シリーズモデル](https://platform.openai.com/docs/guides/gpt/chat-completions-api) ([Azure OpenAI API](https://learn.microsoft.com/en-us/azure/ai-services/openai/reference) をサポート) + [x] [Anthropic Claude シリーズモデル](https://anthropic.com) - + [x] [Google PaLM2 シリーズモデル](https://developers.generativeai.google) + + [x] [Google PaLM2/Gemini シリーズモデル](https://developers.generativeai.google) + [x] [Baidu Wenxin Yiyuan シリーズモデル](https://cloud.baidu.com/doc/WENXINWORKSHOP/index.html) + [x] [Alibaba Tongyi Qianwen シリーズモデル](https://help.aliyun.com/document_detail/2400395.html) + [x] [Zhipu ChatGLM シリーズモデル](https://bigmodel.cn) diff --git a/README.md b/README.md index 7e6a7b38..8a1d6caf 100644 --- a/README.md +++ b/README.md @@ -66,7 +66,7 @@ _✨ 通过标准的 OpenAI API 格式访问所有的大模型,开箱即用 1. 支持多种大模型: + [x] [OpenAI ChatGPT 系列模型](https://platform.openai.com/docs/guides/gpt/chat-completions-api)(支持 [Azure OpenAI API](https://learn.microsoft.com/en-us/azure/ai-services/openai/reference)) + [x] [Anthropic Claude 系列模型](https://anthropic.com) - + [x] [Google PaLM2 系列模型](https://developers.generativeai.google) + + [x] [Google PaLM2/Gemini 系列模型](https://developers.generativeai.google) + [x] [百度文心一言系列模型](https://cloud.baidu.com/doc/WENXINWORKSHOP/index.html) + [x] [阿里通义千问系列模型](https://help.aliyun.com/document_detail/2400395.html) + [x] [讯飞星火认知大模型](https://www.xfyun.cn/doc/spark/Web.html) diff --git a/common/constants.go b/common/constants.go index f6860f67..60700ec8 100644 --- a/common/constants.go +++ b/common/constants.go @@ -187,6 +187,7 @@ const ( ChannelTypeAIProxyLibrary = 21 ChannelTypeFastGPT = 22 ChannelTypeTencent = 23 + ChannelTypeGemini = 24 ) var ChannelBaseURLs = []string{ @@ -214,4 +215,5 @@ var ChannelBaseURLs = []string{ "https://api.aiproxy.io", // 21 "https://fastgpt.run/api/openapi", // 22 "https://hunyuan.cloud.tencent.com", //23 + "", //24 } diff --git a/common/model-ratio.go b/common/model-ratio.go index ccbc05dd..c054fa5f 100644 --- a/common/model-ratio.go +++ b/common/model-ratio.go @@ -83,6 +83,7 @@ var ModelRatio = map[string]float64{ "ERNIE-Bot-4": 8.572, // ¥0.12 / 1k tokens "Embedding-V1": 0.1429, // ¥0.002 / 1k tokens "PaLM-2": 1, + "gemini-pro": 1, // $0.00025 / 1k characters -> $0.001 / 1k tokens "chatglm_turbo": 0.3572, // ¥0.005 / 1k tokens "chatglm_pro": 0.7143, // ¥0.01 / 1k tokens "chatglm_std": 0.3572, // ¥0.005 / 1k tokens diff --git a/controller/channel-test.go b/controller/channel-test.go index bba9a657..3aaa4897 100644 --- a/controller/channel-test.go +++ b/controller/channel-test.go @@ -20,6 +20,8 @@ func testChannel(channel *model.Channel, request ChatRequest) (err error, openai switch channel.Type { case common.ChannelTypePaLM: fallthrough + case common.ChannelTypeGemini: + fallthrough case common.ChannelTypeAnthropic: fallthrough case common.ChannelTypeBaidu: diff --git a/controller/model.go b/controller/model.go index 8f79524d..5c8aebc0 100644 --- a/controller/model.go +++ b/controller/model.go @@ -423,6 +423,15 @@ func init() { Root: "PaLM-2", Parent: nil, }, + { + Id: "gemini-pro", + Object: "model", + Created: 1677649963, + OwnedBy: "google", + Permission: permission, + Root: "gemini-pro", + Parent: nil, + }, { Id: "chatglm_turbo", Object: "model", diff --git a/controller/relay-gemini.go b/controller/relay-gemini.go new file mode 100644 index 00000000..455e30d8 --- /dev/null +++ b/controller/relay-gemini.go @@ -0,0 +1,281 @@ +package controller + +import ( + "encoding/json" + "fmt" + "io" + "net/http" + "one-api/common" + + "github.com/gin-gonic/gin" +) + +type GeminiChatRequest struct { + Contents []GeminiChatContent `json:"contents"` + SafetySettings []GeminiChatSafetySettings `json:"safety_settings,omitempty"` + GenerationConfig GeminiChatGenerationConfig `json:"generation_config,omitempty"` + Tools []GeminiChatTools `json:"tools,omitempty"` +} + +type GeminiInlineData struct { + MimeType string `json:"mimeType"` + Data string `json:"data"` +} + +type GeminiPart struct { + Text string `json:"text,omitempty"` + InlineData *GeminiInlineData `json:"inlineData,omitempty"` +} + +type GeminiChatContent struct { + Role string `json:"role,omitempty"` + Parts []GeminiPart `json:"parts"` +} + +type GeminiChatSafetySettings struct { + Category string `json:"category"` + Threshold string `json:"threshold"` +} + +type GeminiChatTools struct { + FunctionDeclarations any `json:"functionDeclarations,omitempty"` +} + +type GeminiChatGenerationConfig struct { + Temperature float64 `json:"temperature,omitempty"` + TopP float64 `json:"topP,omitempty"` + TopK float64 `json:"topK,omitempty"` + MaxOutputTokens int `json:"maxOutputTokens,omitempty"` + CandidateCount int `json:"candidateCount,omitempty"` + StopSequences []string `json:"stopSequences,omitempty"` +} + +// Setting safety to the lowest possible values since Gemini is already powerless enough +func requestOpenAI2Gemini(textRequest GeneralOpenAIRequest) *GeminiChatRequest { + geminiRequest := GeminiChatRequest{ + Contents: make([]GeminiChatContent, 0, len(textRequest.Messages)), + //SafetySettings: []GeminiChatSafetySettings{ + // { + // Category: "HARM_CATEGORY_HARASSMENT", + // Threshold: "BLOCK_ONLY_HIGH", + // }, + // { + // Category: "HARM_CATEGORY_HATE_SPEECH", + // Threshold: "BLOCK_ONLY_HIGH", + // }, + // { + // Category: "HARM_CATEGORY_SEXUALLY_EXPLICIT", + // Threshold: "BLOCK_ONLY_HIGH", + // }, + // { + // Category: "HARM_CATEGORY_DANGEROUS_CONTENT", + // Threshold: "BLOCK_ONLY_HIGH", + // }, + //}, + GenerationConfig: GeminiChatGenerationConfig{ + Temperature: textRequest.Temperature, + TopP: textRequest.TopP, + MaxOutputTokens: textRequest.MaxTokens, + }, + } + if textRequest.Functions != nil { + geminiRequest.Tools = []GeminiChatTools{ + { + FunctionDeclarations: textRequest.Functions, + }, + } + } + shouldAddDummyModelMessage := false + for _, message := range textRequest.Messages { + content := GeminiChatContent{ + Role: message.Role, + Parts: []GeminiPart{ + { + Text: message.StringContent(), + }, + }, + } + // there's no assistant role in gemini and API shall vomit if Role is not user or model + if content.Role == "assistant" { + content.Role = "model" + } + // Converting system prompt to prompt from user for the same reason + if content.Role == "system" { + content.Role = "user" + shouldAddDummyModelMessage = true + } + geminiRequest.Contents = append(geminiRequest.Contents, content) + + // If a system message is the last message, we need to add a dummy model message to make gemini happy + if shouldAddDummyModelMessage { + geminiRequest.Contents = append(geminiRequest.Contents, GeminiChatContent{ + Role: "model", + Parts: []GeminiPart{ + { + Text: "ok", + }, + }, + }) + shouldAddDummyModelMessage = false + } + } + + return &geminiRequest +} + +type GeminiChatResponse struct { + Candidates []GeminiChatCandidate `json:"candidates"` + PromptFeedback GeminiChatPromptFeedback `json:"promptFeedback"` +} + +type GeminiChatCandidate struct { + Content GeminiChatContent `json:"content"` + FinishReason string `json:"finishReason"` + Index int64 `json:"index"` + SafetyRatings []GeminiChatSafetyRating `json:"safetyRatings"` +} + +type GeminiChatSafetyRating struct { + Category string `json:"category"` + Probability string `json:"probability"` +} + +type GeminiChatPromptFeedback struct { + SafetyRatings []GeminiChatSafetyRating `json:"safetyRatings"` +} + +func responseGeminiChat2OpenAI(response *GeminiChatResponse) *OpenAITextResponse { + fullTextResponse := OpenAITextResponse{ + Id: fmt.Sprintf("chatcmpl-%s", common.GetUUID()), + Object: "chat.completion", + Created: common.GetTimestamp(), + Choices: make([]OpenAITextResponseChoice, 0, len(response.Candidates)), + } + for i, candidate := range response.Candidates { + choice := OpenAITextResponseChoice{ + Index: i, + Message: Message{ + Role: "assistant", + Content: candidate.Content.Parts[0].Text, + }, + FinishReason: stopFinishReason, + } + fullTextResponse.Choices = append(fullTextResponse.Choices, choice) + } + return &fullTextResponse +} + +func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) *ChatCompletionsStreamResponse { + var choice ChatCompletionsStreamResponseChoice + if len(geminiResponse.Candidates) > 0 && len(geminiResponse.Candidates[0].Content.Parts) > 0 { + choice.Delta.Content = geminiResponse.Candidates[0].Content.Parts[0].Text + } + choice.FinishReason = &stopFinishReason + var response ChatCompletionsStreamResponse + response.Object = "chat.completion.chunk" + response.Model = "gemini" + response.Choices = []ChatCompletionsStreamResponseChoice{choice} + return &response +} + +func geminiChatStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, string) { + responseText := "" + responseId := fmt.Sprintf("chatcmpl-%s", common.GetUUID()) + createdTime := common.GetTimestamp() + dataChan := make(chan string) + stopChan := make(chan bool) + go func() { + responseBody, err := io.ReadAll(resp.Body) + if err != nil { + common.SysError("error reading stream response: " + err.Error()) + stopChan <- true + return + } + err = resp.Body.Close() + if err != nil { + common.SysError("error closing stream response: " + err.Error()) + stopChan <- true + return + } + var geminiResponse GeminiChatResponse + err = json.Unmarshal(responseBody, &geminiResponse) + if err != nil { + common.SysError("error unmarshalling stream response: " + err.Error()) + stopChan <- true + return + } + fullTextResponse := streamResponseGeminiChat2OpenAI(&geminiResponse) + fullTextResponse.Id = responseId + fullTextResponse.Created = createdTime + if len(geminiResponse.Candidates) > 0 && len(geminiResponse.Candidates[0].Content.Parts) > 0 { + responseText += geminiResponse.Candidates[0].Content.Parts[0].Text + } + jsonResponse, err := json.Marshal(fullTextResponse) + if err != nil { + common.SysError("error marshalling stream response: " + err.Error()) + stopChan <- true + return + } + dataChan <- string(jsonResponse) + stopChan <- true + }() + setEventStreamHeaders(c) + c.Stream(func(w io.Writer) bool { + select { + case data := <-dataChan: + c.Render(-1, common.CustomEvent{Data: "data: " + data}) + return true + case <-stopChan: + c.Render(-1, common.CustomEvent{Data: "data: [DONE]"}) + return false + } + }) + err := resp.Body.Close() + if err != nil { + return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), "" + } + return nil, responseText +} + +func geminiChatHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*OpenAIErrorWithStatusCode, *Usage) { + responseBody, err := io.ReadAll(resp.Body) + if err != nil { + return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil + } + err = resp.Body.Close() + if err != nil { + return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil + } + var geminiResponse GeminiChatResponse + err = json.Unmarshal(responseBody, &geminiResponse) + if err != nil { + return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil + } + if len(geminiResponse.Candidates) == 0 { + return &OpenAIErrorWithStatusCode{ + OpenAIError: OpenAIError{ + Message: "No candidates returned", + Type: "server_error", + Param: "", + Code: 500, + }, + StatusCode: resp.StatusCode, + }, nil + } + fullTextResponse := responseGeminiChat2OpenAI(&geminiResponse) + completionTokens := countTokenText(geminiResponse.Candidates[0].Content.Parts[0].Text, model) + usage := Usage{ + PromptTokens: promptTokens, + CompletionTokens: completionTokens, + TotalTokens: promptTokens + completionTokens, + } + fullTextResponse.Usage = usage + jsonResponse, err := json.Marshal(fullTextResponse) + if err != nil { + return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil + } + c.Writer.Header().Set("Content-Type", "application/json") + c.Writer.WriteHeader(resp.StatusCode) + _, err = c.Writer.Write(jsonResponse) + return nil, &usage +} diff --git a/controller/relay-text.go b/controller/relay-text.go index a69c7f8b..211a34b3 100644 --- a/controller/relay-text.go +++ b/controller/relay-text.go @@ -27,6 +27,7 @@ const ( APITypeXunfei APITypeAIProxyLibrary APITypeTencent + APITypeGemini ) var httpClient *http.Client @@ -118,6 +119,8 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { apiType = APITypeAIProxyLibrary case common.ChannelTypeTencent: apiType = APITypeTencent + case common.ChannelTypeGemini: + apiType = APITypeGemini } baseURL := common.ChannelBaseURLs[channelType] requestURL := c.Request.URL.String() @@ -177,6 +180,24 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { apiKey := c.Request.Header.Get("Authorization") apiKey = strings.TrimPrefix(apiKey, "Bearer ") fullRequestURL += "?key=" + apiKey + case APITypeGemini: + requestBaseURL := "https://generativelanguage.googleapis.com" + if baseURL != "" { + requestBaseURL = baseURL + } + version := "v1" + if c.GetString("api_version") != "" { + version = c.GetString("api_version") + } + action := "generateContent" + // actually gemini does not support stream, it's fake + //if textRequest.Stream { + // action = "streamGenerateContent" + //} + fullRequestURL = fmt.Sprintf("%s/%s/models/%s:%s", requestBaseURL, version, textRequest.Model, action) + apiKey := c.Request.Header.Get("Authorization") + apiKey = strings.TrimPrefix(apiKey, "Bearer ") + fullRequestURL += "?key=" + apiKey case APITypeZhipu: method := "invoke" if textRequest.Stream { @@ -274,6 +295,13 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError) } requestBody = bytes.NewBuffer(jsonStr) + case APITypeGemini: + geminiChatRequest := requestOpenAI2Gemini(textRequest) + jsonStr, err := json.Marshal(geminiChatRequest) + if err != nil { + return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError) + } + requestBody = bytes.NewBuffer(jsonStr) case APITypeZhipu: zhipuRequest := requestOpenAI2Zhipu(textRequest) jsonStr, err := json.Marshal(zhipuRequest) @@ -367,6 +395,8 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { req.Header.Set("Authorization", apiKey) case APITypePaLM: // do not set Authorization header + case APITypeGemini: + // do not set Authorization header default: req.Header.Set("Authorization", "Bearer "+apiKey) } @@ -527,6 +557,25 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { } return nil } + case APITypeGemini: + if textRequest.Stream { + err, responseText := geminiChatStreamHandler(c, resp) + if err != nil { + return err + } + textResponse.Usage.PromptTokens = promptTokens + textResponse.Usage.CompletionTokens = countTokenText(responseText, textRequest.Model) + return nil + } else { + err, usage := geminiChatHandler(c, resp, promptTokens, textRequest.Model) + if err != nil { + return err + } + if usage != nil { + textResponse.Usage = *usage + } + return nil + } case APITypeZhipu: if isStream { err, usage := zhipuStreamHandler(c, resp) diff --git a/middleware/distributor.go b/middleware/distributor.go index 8be986c9..81338130 100644 --- a/middleware/distributor.go +++ b/middleware/distributor.go @@ -87,6 +87,8 @@ func Distribute() func(c *gin.Context) { c.Set("api_version", channel.Other) case common.ChannelTypeXunfei: c.Set("api_version", channel.Other) + case common.ChannelTypeGemini: + c.Set("api_version", channel.Other) case common.ChannelTypeAIProxyLibrary: c.Set("library_id", channel.Other) case common.ChannelTypeAli: diff --git a/web/src/constants/channel.constants.js b/web/src/constants/channel.constants.js index 76407745..264dbefb 100644 --- a/web/src/constants/channel.constants.js +++ b/web/src/constants/channel.constants.js @@ -3,6 +3,7 @@ export const CHANNEL_OPTIONS = [ { key: 14, text: 'Anthropic Claude', value: 14, color: 'black' }, { key: 3, text: 'Azure OpenAI', value: 3, color: 'olive' }, { key: 11, text: 'Google PaLM2', value: 11, color: 'orange' }, + { key: 24, text: 'Google Gemini', value: 24, color: 'orange' }, { key: 15, text: '百度文心千帆', value: 15, color: 'blue' }, { key: 17, text: '阿里通义千问', value: 17, color: 'orange' }, { key: 18, text: '讯飞星火认知', value: 18, color: 'blue' }, diff --git a/web/src/pages/Channel/EditChannel.js b/web/src/pages/Channel/EditChannel.js index 62e8a155..114e5933 100644 --- a/web/src/pages/Channel/EditChannel.js +++ b/web/src/pages/Channel/EditChannel.js @@ -83,6 +83,9 @@ const EditChannel = () => { case 23: localModels = ['hunyuan']; break; + case 24: + localModels = ['gemini-pro']; + break; } setInputs((inputs) => ({ ...inputs, models: localModels })); } From 58dee76bf7828718021f924dd741c3754515cd7c Mon Sep 17 00:00:00 2001 From: JustSong Date: Sun, 17 Dec 2023 16:16:18 +0800 Subject: [PATCH 07/16] fix: fix Gemini stream problem --- controller/relay-gemini.go | 81 ++++++++++++++++++++++---------------- controller/relay-text.go | 7 ++-- controller/relay.go | 2 +- 3 files changed, 51 insertions(+), 39 deletions(-) diff --git a/controller/relay-gemini.go b/controller/relay-gemini.go index 455e30d8..4c2daba9 100644 --- a/controller/relay-gemini.go +++ b/controller/relay-gemini.go @@ -1,11 +1,13 @@ package controller import ( + "bufio" "encoding/json" "fmt" "io" "net/http" "one-api/common" + "strings" "github.com/gin-gonic/gin" ) @@ -180,50 +182,61 @@ func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) *ChatCo func geminiChatStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, string) { responseText := "" - responseId := fmt.Sprintf("chatcmpl-%s", common.GetUUID()) - createdTime := common.GetTimestamp() dataChan := make(chan string) stopChan := make(chan bool) + scanner := bufio.NewScanner(resp.Body) + scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { + if atEOF && len(data) == 0 { + return 0, nil, nil + } + if i := strings.Index(string(data), "\n"); i >= 0 { + return i + 1, data[0:i], nil + } + if atEOF { + return len(data), data, nil + } + return 0, nil, nil + }) go func() { - responseBody, err := io.ReadAll(resp.Body) - if err != nil { - common.SysError("error reading stream response: " + err.Error()) - stopChan <- true - return + for scanner.Scan() { + data := scanner.Text() + data = strings.TrimSpace(data) + if !strings.HasPrefix(data, "\"text\": \"") { + continue + } + data = strings.TrimPrefix(data, "\"text\": \"") + data = strings.TrimSuffix(data, "\"") + dataChan <- data } - err = resp.Body.Close() - if err != nil { - common.SysError("error closing stream response: " + err.Error()) - stopChan <- true - return - } - var geminiResponse GeminiChatResponse - err = json.Unmarshal(responseBody, &geminiResponse) - if err != nil { - common.SysError("error unmarshalling stream response: " + err.Error()) - stopChan <- true - return - } - fullTextResponse := streamResponseGeminiChat2OpenAI(&geminiResponse) - fullTextResponse.Id = responseId - fullTextResponse.Created = createdTime - if len(geminiResponse.Candidates) > 0 && len(geminiResponse.Candidates[0].Content.Parts) > 0 { - responseText += geminiResponse.Candidates[0].Content.Parts[0].Text - } - jsonResponse, err := json.Marshal(fullTextResponse) - if err != nil { - common.SysError("error marshalling stream response: " + err.Error()) - stopChan <- true - return - } - dataChan <- string(jsonResponse) stopChan <- true }() setEventStreamHeaders(c) c.Stream(func(w io.Writer) bool { select { case data := <-dataChan: - c.Render(-1, common.CustomEvent{Data: "data: " + data}) + // this is used to prevent annoying \ related format bug + data = fmt.Sprintf("{\"content\": \"%s\"}", data) + type dummyStruct struct { + Content string `json:"content"` + } + var dummy dummyStruct + err := json.Unmarshal([]byte(data), &dummy) + responseText += dummy.Content + var choice ChatCompletionsStreamResponseChoice + choice.Delta.Content = dummy.Content + response := ChatCompletionsStreamResponse{ + Id: fmt.Sprintf("chatcmpl-%s", common.GetUUID()), + Object: "chat.completion.chunk", + Created: common.GetTimestamp(), + Model: "gemini-pro", + Choices: []ChatCompletionsStreamResponseChoice{choice}, + } + jsonResponse, err := json.Marshal(response) + if err != nil { + common.SysError("error marshalling stream response: " + err.Error()) + return true + } + c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)}) return true case <-stopChan: c.Render(-1, common.CustomEvent{Data: "data: [DONE]"}) diff --git a/controller/relay-text.go b/controller/relay-text.go index 211a34b3..b53b0caa 100644 --- a/controller/relay-text.go +++ b/controller/relay-text.go @@ -190,10 +190,9 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { version = c.GetString("api_version") } action := "generateContent" - // actually gemini does not support stream, it's fake - //if textRequest.Stream { - // action = "streamGenerateContent" - //} + if textRequest.Stream { + action = "streamGenerateContent" + } fullRequestURL = fmt.Sprintf("%s/%s/models/%s:%s", requestBaseURL, version, textRequest.Model, action) apiKey := c.Request.Header.Get("Authorization") apiKey = strings.TrimPrefix(apiKey, "Bearer ") diff --git a/controller/relay.go b/controller/relay.go index 0e660a68..15021997 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -236,7 +236,7 @@ type ChatCompletionsStreamResponseChoice struct { Delta struct { Content string `json:"content"` } `json:"delta"` - FinishReason *string `json:"finish_reason"` + FinishReason *string `json:"finish_reason,omitempty"` } type ChatCompletionsStreamResponse struct { From 7069c49bdf4e4ae39ab4944faaa5fd66b121d3fb Mon Sep 17 00:00:00 2001 From: JustSong Date: Sun, 17 Dec 2023 18:06:37 +0800 Subject: [PATCH 08/16] fix: fix xunfei panic error (close #820) --- controller/relay-xunfei.go | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/controller/relay-xunfei.go b/controller/relay-xunfei.go index 00ec8981..904e6d14 100644 --- a/controller/relay-xunfei.go +++ b/controller/relay-xunfei.go @@ -230,7 +230,13 @@ func xunfeiHandler(c *gin.Context, textRequest GeneralOpenAIRequest, appId strin case stop = <-stopChan: } } - + if len(xunfeiResponse.Payload.Choices.Text) == 0 { + xunfeiResponse.Payload.Choices.Text = []XunfeiChatResponseTextItem{ + { + Content: "", + }, + } + } xunfeiResponse.Payload.Choices.Text[0].Content = content response := responseXunfei2OpenAI(&xunfeiResponse) From 6acb9537a921008813eb691e021aa87f316fb82e Mon Sep 17 00:00:00 2001 From: JustSong Date: Sun, 17 Dec 2023 18:33:27 +0800 Subject: [PATCH 09/16] fix: try to return a more meaningful error message (close #817) --- controller/relay-utils.go | 57 ++++++++++++++++++++++++++++++++++++--- 1 file changed, 53 insertions(+), 4 deletions(-) diff --git a/controller/relay-utils.go b/controller/relay-utils.go index 9deca75a..a6a1f0f6 100644 --- a/controller/relay-utils.go +++ b/controller/relay-utils.go @@ -263,11 +263,52 @@ func setEventStreamHeaders(c *gin.Context) { c.Writer.Header().Set("X-Accel-Buffering", "no") } +type GeneralErrorResponse struct { + Error OpenAIError `json:"error"` + Message string `json:"message"` + Msg string `json:"msg"` + Err string `json:"err"` + ErrorMsg string `json:"error_msg"` + Header struct { + Message string `json:"message"` + } `json:"header"` + Response struct { + Error struct { + Message string `json:"message"` + } `json:"error"` + } `json:"response"` +} + +func (e GeneralErrorResponse) ToMessage() string { + if e.Error.Message != "" { + return e.Error.Message + } + if e.Message != "" { + return e.Message + } + if e.Msg != "" { + return e.Msg + } + if e.Err != "" { + return e.Err + } + if e.ErrorMsg != "" { + return e.ErrorMsg + } + if e.Header.Message != "" { + return e.Header.Message + } + if e.Response.Error.Message != "" { + return e.Response.Error.Message + } + return "" +} + func relayErrorHandler(resp *http.Response) (openAIErrorWithStatusCode *OpenAIErrorWithStatusCode) { openAIErrorWithStatusCode = &OpenAIErrorWithStatusCode{ StatusCode: resp.StatusCode, OpenAIError: OpenAIError{ - Message: fmt.Sprintf("bad response status code %d", resp.StatusCode), + Message: "", Type: "upstream_error", Code: "bad_response_status_code", Param: strconv.Itoa(resp.StatusCode), @@ -281,12 +322,20 @@ func relayErrorHandler(resp *http.Response) (openAIErrorWithStatusCode *OpenAIEr if err != nil { return } - var textResponse TextResponse - err = json.Unmarshal(responseBody, &textResponse) + var errResponse GeneralErrorResponse + err = json.Unmarshal(responseBody, &errResponse) if err != nil { return } - openAIErrorWithStatusCode.OpenAIError = textResponse.Error + if errResponse.Error.Message != "" { + // OpenAI format error, so we override the default one + openAIErrorWithStatusCode.OpenAIError = errResponse.Error + } else { + openAIErrorWithStatusCode.OpenAIError.Message = errResponse.ToMessage() + } + if openAIErrorWithStatusCode.OpenAIError.Message == "" { + openAIErrorWithStatusCode.OpenAIError.Message = fmt.Sprintf("bad response status code %d", resp.StatusCode) + } return } From 66f06e5d6f14828c3522fb59bfca35c612685bef Mon Sep 17 00:00:00 2001 From: Ghostz <137054651+ye4293@users.noreply.github.com> Date: Sun, 17 Dec 2023 18:54:08 +0800 Subject: [PATCH 10/16] feat: reset image num to 1 when not given (#821) * Update relay-image.go * fix: reset image num to 1 when not given --------- Co-authored-by: JustSong --- controller/relay-image.go | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/controller/relay-image.go b/controller/relay-image.go index b3248fcc..bf916474 100644 --- a/controller/relay-image.go +++ b/controller/relay-image.go @@ -19,7 +19,6 @@ func isWithinRange(element string, value int) bool { if _, ok := common.DalleGenerationImageAmounts[element]; !ok { return false } - min := common.DalleGenerationImageAmounts[element][0] max := common.DalleGenerationImageAmounts[element][1] @@ -42,6 +41,10 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode return errorWrapper(err, "bind_request_body_failed", http.StatusBadRequest) } + if imageRequest.N == 0 { + imageRequest.N = 1 + } + // Size validation if imageRequest.Size != "" { imageSize = imageRequest.Size From 7d6a169669f1eb2666abb0747964ab1cd9f86f12 Mon Sep 17 00:00:00 2001 From: Calcium-Ion <61247483+Calcium-Ion@users.noreply.github.com> Date: Sun, 17 Dec 2023 19:17:00 +0800 Subject: [PATCH 11/16] feat: able to set sqlite busy_timeout (#818) * add sqlite busy_timeout=3000 * chore: update impl --------- Co-authored-by: JustSong --- README.md | 1 + common/database.go | 1 + model/main.go | 4 +++- 3 files changed, 5 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 8a1d6caf..916e4331 100644 --- a/README.md +++ b/README.md @@ -371,6 +371,7 @@ graph LR + `TIKTOKEN_CACHE_DIR`:默认程序启动时会联网下载一些通用的词元的编码,如:`gpt-3.5-turbo`,在一些网络环境不稳定,或者离线情况,可能会导致启动有问题,可以配置此目录缓存数据,可迁移到离线环境。 + `DATA_GYM_CACHE_DIR`:目前该配置作用与 `TIKTOKEN_CACHE_DIR` 一致,但是优先级没有它高。 15. `RELAY_TIMEOUT`:中继超时设置,单位为秒,默认不设置超时时间。 +16. `SQLITE_BUSY_TIMEOUT`:SQLite 锁等待超时设置,单位为毫秒,默认 `3000`。 ### 命令行参数 1. `--port `: 指定服务器监听的端口号,默认为 `3000`。 diff --git a/common/database.go b/common/database.go index c7e9fd52..76f2cd55 100644 --- a/common/database.go +++ b/common/database.go @@ -4,3 +4,4 @@ var UsingSQLite = false var UsingPostgreSQL = false var SQLitePath = "one-api.db" +var SQLiteBusyTimeout = GetOrDefault("SQLITE_BUSY_TIMEOUT", 3000) diff --git a/model/main.go b/model/main.go index 08182634..bfd6888b 100644 --- a/model/main.go +++ b/model/main.go @@ -1,6 +1,7 @@ package model import ( + "fmt" "gorm.io/driver/mysql" "gorm.io/driver/postgres" "gorm.io/driver/sqlite" @@ -59,7 +60,8 @@ func chooseDB() (*gorm.DB, error) { // Use SQLite common.SysLog("SQL_DSN not set, using SQLite as database") common.UsingSQLite = true - return gorm.Open(sqlite.Open(common.SQLitePath), &gorm.Config{ + config := fmt.Sprintf("?_busy_timeout=%d", common.SQLiteBusyTimeout) + return gorm.Open(sqlite.Open(common.SQLitePath+config), &gorm.Config{ PrepareStmt: true, // precompile SQL }) } From 0fe26cc4bd9d2e85b47ac6ae23fc974295c72f52 Mon Sep 17 00:00:00 2001 From: Oliver Lee Date: Sun, 17 Dec 2023 19:43:23 +0800 Subject: [PATCH 12/16] feat: update ali relay implementation (#830) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * 修改通译千问最新接口:1.删除history参数,改用官方推荐的messages参数 2.整理messages参数顺序,补充必要上下文信息 3.用autogen调试测试通过 * chore: update impl --------- Co-authored-by: JustSong --- common/model-ratio.go | 6 +++-- controller/model.go | 18 +++++++++++++++ controller/relay-ali.go | 33 ++++++++-------------------- web/src/pages/Channel/EditChannel.js | 2 +- 4 files changed, 32 insertions(+), 27 deletions(-) diff --git a/common/model-ratio.go b/common/model-ratio.go index c054fa5f..d1c96d96 100644 --- a/common/model-ratio.go +++ b/common/model-ratio.go @@ -88,8 +88,10 @@ var ModelRatio = map[string]float64{ "chatglm_pro": 0.7143, // ¥0.01 / 1k tokens "chatglm_std": 0.3572, // ¥0.005 / 1k tokens "chatglm_lite": 0.1429, // ¥0.002 / 1k tokens - "qwen-turbo": 0.8572, // ¥0.012 / 1k tokens - "qwen-plus": 10, // ¥0.14 / 1k tokens + "qwen-turbo": 0.5715, // ¥0.008 / 1k tokens // https://help.aliyun.com/zh/dashscope/developer-reference/tongyi-thousand-questions-metering-and-billing + "qwen-plus": 1.4286, // ¥0.02 / 1k tokens + "qwen-max": 1.4286, // ¥0.02 / 1k tokens + "qwen-max-longcontext": 1.4286, // ¥0.02 / 1k tokens "text-embedding-v1": 0.05, // ¥0.0007 / 1k tokens "SparkDesk": 1.2858, // ¥0.018 / 1k tokens "360GPT_S2_V9": 0.8572, // ¥0.012 / 1k tokens diff --git a/controller/model.go b/controller/model.go index 5c8aebc0..9ae40f5c 100644 --- a/controller/model.go +++ b/controller/model.go @@ -486,6 +486,24 @@ func init() { Root: "qwen-plus", Parent: nil, }, + { + Id: "qwen-max", + Object: "model", + Created: 1677649963, + OwnedBy: "ali", + Permission: permission, + Root: "qwen-max", + Parent: nil, + }, + { + Id: "qwen-max-longcontext", + Object: "model", + Created: 1677649963, + OwnedBy: "ali", + Permission: permission, + Root: "qwen-max-longcontext", + Parent: nil, + }, { Id: "text-embedding-v1", Object: "model", diff --git a/controller/relay-ali.go b/controller/relay-ali.go index b41ca327..65626f6a 100644 --- a/controller/relay-ali.go +++ b/controller/relay-ali.go @@ -13,13 +13,13 @@ import ( // https://help.aliyun.com/document_detail/613695.html?spm=a2c4g.2399480.0.0.1adb778fAdzP9w#341800c0f8w0r type AliMessage struct { - User string `json:"user"` - Bot string `json:"bot"` + Content string `json:"content"` + Role string `json:"role"` } type AliInput struct { - Prompt string `json:"prompt"` - History []AliMessage `json:"history"` + //Prompt string `json:"prompt"` + Messages []AliMessage `json:"messages"` } type AliParameters struct { @@ -83,32 +83,17 @@ type AliChatResponse struct { func requestOpenAI2Ali(request GeneralOpenAIRequest) *AliChatRequest { messages := make([]AliMessage, 0, len(request.Messages)) - prompt := "" for i := 0; i < len(request.Messages); i++ { message := request.Messages[i] - if message.Role == "system" { - messages = append(messages, AliMessage{ - User: message.StringContent(), - Bot: "Okay", - }) - continue - } else { - if i == len(request.Messages)-1 { - prompt = message.StringContent() - break - } - messages = append(messages, AliMessage{ - User: message.StringContent(), - Bot: request.Messages[i+1].StringContent(), - }) - i++ - } + messages = append(messages, AliMessage{ + Content: message.StringContent(), + Role: strings.ToLower(message.Role), + }) } return &AliChatRequest{ Model: request.Model, Input: AliInput{ - Prompt: prompt, - History: messages, + Messages: messages, }, //Parameters: AliParameters{ // ChatGPT's parameters are not compatible with Ali's // TopP: request.TopP, diff --git a/web/src/pages/Channel/EditChannel.js b/web/src/pages/Channel/EditChannel.js index 114e5933..364da69d 100644 --- a/web/src/pages/Channel/EditChannel.js +++ b/web/src/pages/Channel/EditChannel.js @@ -69,7 +69,7 @@ const EditChannel = () => { localModels = ['ERNIE-Bot', 'ERNIE-Bot-turbo', 'ERNIE-Bot-4', 'Embedding-V1']; break; case 17: - localModels = ['qwen-turbo', 'qwen-plus', 'text-embedding-v1']; + localModels = ['qwen-turbo', 'qwen-plus', 'qwen-max', 'qwen-max-longcontext', 'text-embedding-v1']; break; case 16: localModels = ['chatglm_turbo', 'chatglm_pro', 'chatglm_std', 'chatglm_lite']; From bc6769826bb827ca323beac6118bfd15e829fee4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?ShinChven=20=E2=9C=A8?= Date: Sun, 17 Dec 2023 19:49:08 +0800 Subject: [PATCH 13/16] feat: add condition to validate n value for non-Azure channels (#775) - Add a condition to validate the n value only for non-Azure channels, ensuring it falls within the acceptable range. - Fix Azure compatibility --- controller/relay-image.go | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/controller/relay-image.go b/controller/relay-image.go index bf916474..7e1fed39 100644 --- a/controller/relay-image.go +++ b/controller/relay-image.go @@ -82,7 +82,10 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode // Number of generated images validation if isWithinRange(imageModel, imageRequest.N) == false { - return errorWrapper(errors.New("invalid value of n"), "n_not_within_range", http.StatusBadRequest) + // channel not azure + if channelType != common.ChannelTypeAzure { + return errorWrapper(errors.New("invalid value of n"), "n_not_within_range", http.StatusBadRequest) + } } // map model name @@ -105,7 +108,7 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode baseURL = c.GetString("base_url") } fullRequestURL := getFullRequestURL(baseURL, requestURL, channelType) - if channelType == common.ChannelTypeAzure && relayMode == RelayModeImagesGenerations { + if channelType == common.ChannelTypeAzure { // https://learn.microsoft.com/en-us/azure/ai-services/openai/dall-e-quickstart?tabs=dalle3%2Ccommand-line&pivots=rest-api apiVersion := GetAPIVersion(c) // https://{resource_name}.openai.azure.com/openai/deployments/dall-e-3/images/generations?api-version=2023-06-01-preview From af378c59afe3510ebc28b7f081a7a8ccfa8de891 Mon Sep 17 00:00:00 2001 From: JustSong Date: Sun, 17 Dec 2023 22:19:16 +0800 Subject: [PATCH 14/16] docs: update readme --- README.md | 2 -- 1 file changed, 2 deletions(-) diff --git a/README.md b/README.md index 916e4331..1a967ace 100644 --- a/README.md +++ b/README.md @@ -77,8 +77,6 @@ _✨ 通过标准的 OpenAI API 格式访问所有的大模型,开箱即用 + [x] [OpenAI-SB](https://openai-sb.com) + [x] [CloseAI](https://referer.shadowai.xyz/r/2412) + [x] [API2D](https://api2d.com/r/197971) - + [x] [OhMyGPT](https://aigptx.top?aff=uFpUl2Kf) - + [x] [AI Proxy](https://aiproxy.io/?i=OneAPI) (邀请码:`OneAPI`) + [x] 自定义渠道:例如各种未收录的第三方代理服务 3. 支持通过**负载均衡**的方式访问多个渠道。 4. 支持 **stream 模式**,可以通过流式传输实现打字机效果。 From 461f5dab561ff09d1bada624058c943daa7b9f99 Mon Sep 17 00:00:00 2001 From: JustSong Date: Sun, 17 Dec 2023 22:25:03 +0800 Subject: [PATCH 15/16] docs: update readme --- README.md | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/README.md b/README.md index 1a967ace..ff9e0bc0 100644 --- a/README.md +++ b/README.md @@ -73,11 +73,7 @@ _✨ 通过标准的 OpenAI API 格式访问所有的大模型,开箱即用 + [x] [智谱 ChatGLM 系列模型](https://bigmodel.cn) + [x] [360 智脑](https://ai.360.cn) + [x] [腾讯混元大模型](https://cloud.tencent.com/document/product/1729) -2. 支持配置镜像以及众多第三方代理服务: - + [x] [OpenAI-SB](https://openai-sb.com) - + [x] [CloseAI](https://referer.shadowai.xyz/r/2412) - + [x] [API2D](https://api2d.com/r/197971) - + [x] 自定义渠道:例如各种未收录的第三方代理服务 +2. 支持配置镜像以及众多[第三方代理服务](https://iamazing.cn/page/openai-api-third-party-services)。 3. 支持通过**负载均衡**的方式访问多个渠道。 4. 支持 **stream 模式**,可以通过流式传输实现打字机效果。 5. 支持**多机部署**,[详见此处](#多机部署)。 From 97030e27f88ac27a9ab4bcc2484f7d8e83d29d04 Mon Sep 17 00:00:00 2001 From: JustSong Date: Sun, 17 Dec 2023 23:30:45 +0800 Subject: [PATCH 16/16] fix: fix gemini panic (close #833) --- controller/relay-gemini.go | 23 +++++++++++++++++------ 1 file changed, 17 insertions(+), 6 deletions(-) diff --git a/controller/relay-gemini.go b/controller/relay-gemini.go index 4c2daba9..2458458e 100644 --- a/controller/relay-gemini.go +++ b/controller/relay-gemini.go @@ -114,7 +114,7 @@ func requestOpenAI2Gemini(textRequest GeneralOpenAIRequest) *GeminiChatRequest { Role: "model", Parts: []GeminiPart{ { - Text: "ok", + Text: "Okay", }, }, }) @@ -130,6 +130,16 @@ type GeminiChatResponse struct { PromptFeedback GeminiChatPromptFeedback `json:"promptFeedback"` } +func (g *GeminiChatResponse) GetResponseText() string { + if g == nil { + return "" + } + if len(g.Candidates) > 0 && len(g.Candidates[0].Content.Parts) > 0 { + return g.Candidates[0].Content.Parts[0].Text + } + return "" +} + type GeminiChatCandidate struct { Content GeminiChatContent `json:"content"` FinishReason string `json:"finishReason"` @@ -158,10 +168,13 @@ func responseGeminiChat2OpenAI(response *GeminiChatResponse) *OpenAITextResponse Index: i, Message: Message{ Role: "assistant", - Content: candidate.Content.Parts[0].Text, + Content: "", }, FinishReason: stopFinishReason, } + if len(candidate.Content.Parts) > 0 { + choice.Message.Content = candidate.Content.Parts[0].Text + } fullTextResponse.Choices = append(fullTextResponse.Choices, choice) } return &fullTextResponse @@ -169,9 +182,7 @@ func responseGeminiChat2OpenAI(response *GeminiChatResponse) *OpenAITextResponse func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) *ChatCompletionsStreamResponse { var choice ChatCompletionsStreamResponseChoice - if len(geminiResponse.Candidates) > 0 && len(geminiResponse.Candidates[0].Content.Parts) > 0 { - choice.Delta.Content = geminiResponse.Candidates[0].Content.Parts[0].Text - } + choice.Delta.Content = geminiResponse.GetResponseText() choice.FinishReason = &stopFinishReason var response ChatCompletionsStreamResponse response.Object = "chat.completion.chunk" @@ -276,7 +287,7 @@ func geminiChatHandler(c *gin.Context, resp *http.Response, promptTokens int, mo }, nil } fullTextResponse := responseGeminiChat2OpenAI(&geminiResponse) - completionTokens := countTokenText(geminiResponse.Candidates[0].Content.Parts[0].Text, model) + completionTokens := countTokenText(geminiResponse.GetResponseText(), model) usage := Usage{ PromptTokens: promptTokens, CompletionTokens: completionTokens,