diff --git a/controller/relay.go b/controller/relay.go index 4dc4a9b8..5c8a10e0 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -1,11 +1,14 @@ package controller import ( + "bufio" + "bytes" "fmt" "github.com/gin-gonic/gin" "io" "net/http" "one-api/common" + "strings" ) func Relay(c *gin.Context) { @@ -30,18 +33,8 @@ func Relay(c *gin.Context) { //req.Header.Del("Accept-Encoding") req.Header.Set("Authorization", c.Request.Header.Get("Authorization")) req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type")) - acceptHeader := c.Request.Header.Get("Accept") - if acceptHeader != "" { - req.Header.Set("Accept", acceptHeader) - } - connectionHeader := c.Request.Header.Get("Connection") - if connectionHeader != "" { - req.Header.Set("Connection", connectionHeader) - } - lastEventIDHeader := c.Request.Header.Get("Last-Event-ID") - if lastEventIDHeader != "" { - req.Header.Set("Last-Event-ID", lastEventIDHeader) - } + req.Header.Set("Accept", c.Request.Header.Get("Accept")) + req.Header.Set("Connection", c.Request.Header.Get("Connection")) client := &http.Client{} resp, err := client.Do(req) @@ -54,17 +47,57 @@ func Relay(c *gin.Context) { }) return } + defer resp.Body.Close() + scanner := bufio.NewScanner(resp.Body) + scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { + if atEOF && len(data) == 0 { + return 0, nil, nil + } + + if i := strings.Index(string(data), "\n\n"); i >= 0 { + return i + 2, data[0:i], nil + } + + if atEOF { + return len(data), data, nil + } + + return 0, nil, nil + }) + dataChan := make(chan string) + stopChan := make(chan bool) + go func() { + for scanner.Scan() { + data := scanner.Text() + dataChan <- data + } + stopChan <- true + }() for k, v := range resp.Header { c.Writer.Header().Set(k, v[0]) } - _, err = io.Copy(c.Writer, resp.Body) - if err != nil { - c.JSON(http.StatusOK, gin.H{ - "error": gin.H{ - "message": err.Error(), - "type": "one_api_error", - }, - }) - return - } + c.Stream(func(w io.Writer) bool { + select { + case data := <-dataChan: + //fmt.Println(data) + //c.Data(http.StatusOK, "text/event-stream", []byte(data)) + //c.Render(-1, common.Event{Data: data}) + //c.SSEvent("", data) + //w.Write([]byte(data)) + //w.(http.Flusher).Flush() + //c.Writer.Write(append([]byte(data), []byte("\n\n")...)) + outputBytes := bytes.NewBufferString(data) + w.Write(outputBytes.Bytes()) + if strings.HasPrefix(data, "data: ") { + w.Write([]byte("\n\n")) + } + //w.Write(append(outputBytes.Bytes(), []byte("\n\n")...)) + w.(http.Flusher).Flush() + //fmt.Println(data) + return true + case <-stopChan: + return false + } + }) + return }