fix: fully support stream mode now (close #3)

This commit is contained in:
JustSong 2023-04-25 21:50:57 +08:00
parent b74a17c963
commit 69ee87c57f
3 changed files with 136 additions and 52 deletions

82
common/custom_event.go Normal file
View File

@ -0,0 +1,82 @@
// Copyright 2014 Manu Martinez-Almeida. All rights reserved.
// Use of this source code is governed by a MIT style
// license that can be found in the LICENSE file.
package common
import (
"fmt"
"io"
"net/http"
"strings"
)
type stringWriter interface {
io.Writer
writeString(string) (int, error)
}
type stringWrapper struct {
io.Writer
}
func (w stringWrapper) writeString(str string) (int, error) {
return w.Writer.Write([]byte(str))
}
func checkWriter(writer io.Writer) stringWriter {
if w, ok := writer.(stringWriter); ok {
return w
} else {
return stringWrapper{writer}
}
}
// Server-Sent Events
// W3C Working Draft 29 October 2009
// http://www.w3.org/TR/2009/WD-eventsource-20091029/
var contentType = []string{"text/event-stream"}
var noCache = []string{"no-cache"}
var fieldReplacer = strings.NewReplacer(
"\n", "\\n",
"\r", "\\r")
var dataReplacer = strings.NewReplacer(
"\n", "\ndata:",
"\r", "\\r")
type CustomEvent struct {
Event string
Id string
Retry uint
Data interface{}
}
func encode(writer io.Writer, event CustomEvent) error {
w := checkWriter(writer)
return writeData(w, event.Data)
}
func writeData(w stringWriter, data interface{}) error {
dataReplacer.WriteString(w, fmt.Sprint(data))
if strings.HasPrefix(data.(string), "data") {
w.writeString("\n\n")
}
return nil
}
func (r CustomEvent) Render(w http.ResponseWriter) error {
r.WriteContentType(w)
return encode(w, r)
}
func (r CustomEvent) WriteContentType(w http.ResponseWriter) {
header := w.Header()
header["Content-Type"] = contentType
if _, exist := header["Cache-Control"]; !exist {
header["Cache-Control"] = noCache
}
}

View File

@ -47,58 +47,60 @@ func Relay(c *gin.Context) {
return return
} }
defer resp.Body.Close() defer resp.Body.Close()
scanner := bufio.NewScanner(resp.Body) isStream := resp.Header.Get("Content-Type") == "text/event-stream"
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { if isStream {
if atEOF && len(data) == 0 { scanner := bufio.NewScanner(resp.Body)
return 0, nil, nil scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
} if atEOF && len(data) == 0 {
return 0, nil, nil
if i := strings.Index(string(data), "\n\n"); i >= 0 {
return i + 2, data[0:i], nil
}
if atEOF {
return len(data), data, nil
}
return 0, nil, nil
})
dataChan := make(chan string)
stopChan := make(chan bool)
go func() {
for scanner.Scan() {
data := scanner.Text()
dataChan <- data
}
stopChan <- true
}()
for k, v := range resp.Header {
c.Writer.Header().Set(k, v[0])
}
c.Writer.Header().Set("Content-Type", "text/event-stream")
c.Writer.Header().Set("Cache-Control", "no-cache")
c.Writer.Header().Set("Connection", "keep-alive")
c.Writer.WriteHeaderNow()
//w := c.Writer
//flusher, _ := w.(http.Flusher)
c.Stream(func(w io.Writer) bool {
select {
case data := <-dataChan:
suffix := ""
if strings.HasPrefix(data, "data: ") {
suffix = "\n\n"
} }
_, err := fmt.Fprintf(w, "%s%s", data, suffix)
if err != nil { if i := strings.Index(string(data), "\n\n"); i >= 0 {
return i + 2, data[0:i], nil
}
if atEOF {
return len(data), data, nil
}
return 0, nil, nil
})
dataChan := make(chan string)
stopChan := make(chan bool)
go func() {
for scanner.Scan() {
data := scanner.Text()
dataChan <- data
}
stopChan <- true
}()
c.Writer.Header().Set("Content-Type", "text/event-stream")
c.Writer.Header().Set("Cache-Control", "no-cache")
c.Writer.Header().Set("Connection", "keep-alive")
c.Writer.Header().Set("Transfer-Encoding", "chunked")
c.Stream(func(w io.Writer) bool {
select {
case data := <-dataChan:
c.Render(-1, common.CustomEvent{Data: data})
return true
case <-stopChan:
return false return false
} }
flusher, _ := w.(http.Flusher) })
flusher.Flush() return
//fmt.Println(data) } else {
return true for k, v := range resp.Header {
case <-stopChan: c.Writer.Header().Set(k, v[0])
return false
} }
}) _, err = io.Copy(c.Writer, resp.Body)
return if err != nil {
c.JSON(http.StatusOK, gin.H{
"error": gin.H{
"message": err.Error(),
"type": "one_api_error",
},
})
return
}
}
} }

View File

@ -2,7 +2,6 @@ package main
import ( import (
"embed" "embed"
"github.com/gin-contrib/gzip"
"github.com/gin-contrib/sessions" "github.com/gin-contrib/sessions"
"github.com/gin-contrib/sessions/cookie" "github.com/gin-contrib/sessions/cookie"
"github.com/gin-contrib/sessions/redis" "github.com/gin-contrib/sessions/redis"
@ -51,7 +50,8 @@ func main() {
// Initialize HTTP server // Initialize HTTP server
server := gin.Default() server := gin.Default()
server.Use(gzip.Gzip(gzip.DefaultCompression)) // This will cause SSE not to work!!!
//server.Use(gzip.Gzip(gzip.DefaultCompression))
server.Use(middleware.CORS()) server.Use(middleware.CORS())
// Initialize session store // Initialize session store