diff --git a/common/custom_event.go b/common/custom_event.go new file mode 100644 index 00000000..69da4bc4 --- /dev/null +++ b/common/custom_event.go @@ -0,0 +1,82 @@ +// Copyright 2014 Manu Martinez-Almeida. All rights reserved. +// Use of this source code is governed by a MIT style +// license that can be found in the LICENSE file. + +package common + +import ( + "fmt" + "io" + "net/http" + "strings" +) + +type stringWriter interface { + io.Writer + writeString(string) (int, error) +} + +type stringWrapper struct { + io.Writer +} + +func (w stringWrapper) writeString(str string) (int, error) { + return w.Writer.Write([]byte(str)) +} + +func checkWriter(writer io.Writer) stringWriter { + if w, ok := writer.(stringWriter); ok { + return w + } else { + return stringWrapper{writer} + } +} + +// Server-Sent Events +// W3C Working Draft 29 October 2009 +// http://www.w3.org/TR/2009/WD-eventsource-20091029/ + +var contentType = []string{"text/event-stream"} +var noCache = []string{"no-cache"} + +var fieldReplacer = strings.NewReplacer( + "\n", "\\n", + "\r", "\\r") + +var dataReplacer = strings.NewReplacer( + "\n", "\ndata:", + "\r", "\\r") + +type CustomEvent struct { + Event string + Id string + Retry uint + Data interface{} +} + +func encode(writer io.Writer, event CustomEvent) error { + w := checkWriter(writer) + return writeData(w, event.Data) +} + +func writeData(w stringWriter, data interface{}) error { + dataReplacer.WriteString(w, fmt.Sprint(data)) + if strings.HasPrefix(data.(string), "data") { + w.writeString("\n\n") + } + return nil +} + +func (r CustomEvent) Render(w http.ResponseWriter) error { + r.WriteContentType(w) + return encode(w, r) +} + +func (r CustomEvent) WriteContentType(w http.ResponseWriter) { + header := w.Header() + header["Content-Type"] = contentType + + if _, exist := header["Cache-Control"]; !exist { + header["Cache-Control"] = noCache + } +} diff --git a/controller/relay.go b/controller/relay.go index f2d4003e..6118ca80 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -47,58 +47,60 @@ func Relay(c *gin.Context) { return } defer resp.Body.Close() - scanner := bufio.NewScanner(resp.Body) - scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { - if atEOF && len(data) == 0 { - return 0, nil, nil - } - - if i := strings.Index(string(data), "\n\n"); i >= 0 { - return i + 2, data[0:i], nil - } - - if atEOF { - return len(data), data, nil - } - - return 0, nil, nil - }) - dataChan := make(chan string) - stopChan := make(chan bool) - go func() { - for scanner.Scan() { - data := scanner.Text() - dataChan <- data - } - stopChan <- true - }() - for k, v := range resp.Header { - c.Writer.Header().Set(k, v[0]) - } - c.Writer.Header().Set("Content-Type", "text/event-stream") - c.Writer.Header().Set("Cache-Control", "no-cache") - c.Writer.Header().Set("Connection", "keep-alive") - c.Writer.WriteHeaderNow() - //w := c.Writer - //flusher, _ := w.(http.Flusher) - c.Stream(func(w io.Writer) bool { - select { - case data := <-dataChan: - suffix := "" - if strings.HasPrefix(data, "data: ") { - suffix = "\n\n" + isStream := resp.Header.Get("Content-Type") == "text/event-stream" + if isStream { + scanner := bufio.NewScanner(resp.Body) + scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { + if atEOF && len(data) == 0 { + return 0, nil, nil } - _, err := fmt.Fprintf(w, "%s%s", data, suffix) - if err != nil { + + if i := strings.Index(string(data), "\n\n"); i >= 0 { + return i + 2, data[0:i], nil + } + + if atEOF { + return len(data), data, nil + } + + return 0, nil, nil + }) + dataChan := make(chan string) + stopChan := make(chan bool) + go func() { + for scanner.Scan() { + data := scanner.Text() + dataChan <- data + } + stopChan <- true + }() + c.Writer.Header().Set("Content-Type", "text/event-stream") + c.Writer.Header().Set("Cache-Control", "no-cache") + c.Writer.Header().Set("Connection", "keep-alive") + c.Writer.Header().Set("Transfer-Encoding", "chunked") + c.Stream(func(w io.Writer) bool { + select { + case data := <-dataChan: + c.Render(-1, common.CustomEvent{Data: data}) + return true + case <-stopChan: return false } - flusher, _ := w.(http.Flusher) - flusher.Flush() - //fmt.Println(data) - return true - case <-stopChan: - return false + }) + return + } else { + for k, v := range resp.Header { + c.Writer.Header().Set(k, v[0]) } - }) - return + _, err = io.Copy(c.Writer, resp.Body) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "error": gin.H{ + "message": err.Error(), + "type": "one_api_error", + }, + }) + return + } + } } diff --git a/main.go b/main.go index 21db969b..0a04ca6c 100644 --- a/main.go +++ b/main.go @@ -2,7 +2,6 @@ package main import ( "embed" - "github.com/gin-contrib/gzip" "github.com/gin-contrib/sessions" "github.com/gin-contrib/sessions/cookie" "github.com/gin-contrib/sessions/redis" @@ -51,7 +50,8 @@ func main() { // Initialize HTTP server server := gin.Default() - server.Use(gzip.Gzip(gzip.DefaultCompression)) + // This will cause SSE not to work!!! + //server.Use(gzip.Gzip(gzip.DefaultCompression)) server.Use(middleware.CORS()) // Initialize session store