diff --git a/controller/relay-gemini.go b/controller/relay-gemini.go index 455e30d8..4c2daba9 100644 --- a/controller/relay-gemini.go +++ b/controller/relay-gemini.go @@ -1,11 +1,13 @@ package controller import ( + "bufio" "encoding/json" "fmt" "io" "net/http" "one-api/common" + "strings" "github.com/gin-gonic/gin" ) @@ -180,50 +182,61 @@ func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) *ChatCo func geminiChatStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, string) { responseText := "" - responseId := fmt.Sprintf("chatcmpl-%s", common.GetUUID()) - createdTime := common.GetTimestamp() dataChan := make(chan string) stopChan := make(chan bool) + scanner := bufio.NewScanner(resp.Body) + scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { + if atEOF && len(data) == 0 { + return 0, nil, nil + } + if i := strings.Index(string(data), "\n"); i >= 0 { + return i + 1, data[0:i], nil + } + if atEOF { + return len(data), data, nil + } + return 0, nil, nil + }) go func() { - responseBody, err := io.ReadAll(resp.Body) - if err != nil { - common.SysError("error reading stream response: " + err.Error()) - stopChan <- true - return + for scanner.Scan() { + data := scanner.Text() + data = strings.TrimSpace(data) + if !strings.HasPrefix(data, "\"text\": \"") { + continue + } + data = strings.TrimPrefix(data, "\"text\": \"") + data = strings.TrimSuffix(data, "\"") + dataChan <- data } - err = resp.Body.Close() - if err != nil { - common.SysError("error closing stream response: " + err.Error()) - stopChan <- true - return - } - var geminiResponse GeminiChatResponse - err = json.Unmarshal(responseBody, &geminiResponse) - if err != nil { - common.SysError("error unmarshalling stream response: " + err.Error()) - stopChan <- true - return - } - fullTextResponse := streamResponseGeminiChat2OpenAI(&geminiResponse) - fullTextResponse.Id = responseId - fullTextResponse.Created = createdTime - if len(geminiResponse.Candidates) > 0 && len(geminiResponse.Candidates[0].Content.Parts) > 0 { - responseText += geminiResponse.Candidates[0].Content.Parts[0].Text - } - jsonResponse, err := json.Marshal(fullTextResponse) - if err != nil { - common.SysError("error marshalling stream response: " + err.Error()) - stopChan <- true - return - } - dataChan <- string(jsonResponse) stopChan <- true }() setEventStreamHeaders(c) c.Stream(func(w io.Writer) bool { select { case data := <-dataChan: - c.Render(-1, common.CustomEvent{Data: "data: " + data}) + // this is used to prevent annoying \ related format bug + data = fmt.Sprintf("{\"content\": \"%s\"}", data) + type dummyStruct struct { + Content string `json:"content"` + } + var dummy dummyStruct + err := json.Unmarshal([]byte(data), &dummy) + responseText += dummy.Content + var choice ChatCompletionsStreamResponseChoice + choice.Delta.Content = dummy.Content + response := ChatCompletionsStreamResponse{ + Id: fmt.Sprintf("chatcmpl-%s", common.GetUUID()), + Object: "chat.completion.chunk", + Created: common.GetTimestamp(), + Model: "gemini-pro", + Choices: []ChatCompletionsStreamResponseChoice{choice}, + } + jsonResponse, err := json.Marshal(response) + if err != nil { + common.SysError("error marshalling stream response: " + err.Error()) + return true + } + c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)}) return true case <-stopChan: c.Render(-1, common.CustomEvent{Data: "data: [DONE]"}) diff --git a/controller/relay-text.go b/controller/relay-text.go index 211a34b3..b53b0caa 100644 --- a/controller/relay-text.go +++ b/controller/relay-text.go @@ -190,10 +190,9 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { version = c.GetString("api_version") } action := "generateContent" - // actually gemini does not support stream, it's fake - //if textRequest.Stream { - // action = "streamGenerateContent" - //} + if textRequest.Stream { + action = "streamGenerateContent" + } fullRequestURL = fmt.Sprintf("%s/%s/models/%s:%s", requestBaseURL, version, textRequest.Model, action) apiKey := c.Request.Header.Get("Authorization") apiKey = strings.TrimPrefix(apiKey, "Bearer ") diff --git a/controller/relay.go b/controller/relay.go index 0e660a68..15021997 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -236,7 +236,7 @@ type ChatCompletionsStreamResponseChoice struct { Delta struct { Content string `json:"content"` } `json:"delta"` - FinishReason *string `json:"finish_reason"` + FinishReason *string `json:"finish_reason,omitempty"` } type ChatCompletionsStreamResponse struct {