From 8f0799d90952b01c21dee6723c1332f4b18ca2fd Mon Sep 17 00:00:00 2001 From: ckt1031 <65409152+ckt1031@users.noreply.github.com> Date: Sat, 15 Jul 2023 21:03:27 +0800 Subject: [PATCH] feat: support reverse proxy of `Chanzhaoyu/chatgpt-web` --- common/constants.go | 6 + controller/channel-test.go | 144 +++++++++++---- controller/relay-text.go | 242 ++++++++++++++++++------- controller/relay.go | 29 +++ go.mod | 4 +- go.sum | 5 + web/src/constants/channel.constants.js | 5 +- 7 files changed, 328 insertions(+), 107 deletions(-) diff --git a/common/constants.go b/common/constants.go index f29153f2..7d505141 100644 --- a/common/constants.go +++ b/common/constants.go @@ -154,6 +154,9 @@ const ( ChannelTypePaLM = 11 ChannelTypeAPI2GPT = 12 ChannelTypeAIGC2D = 13 + + // Reserve engineering for public projects + ChannelTypeChatGPTWeb = 14 // Chanzhaoyu/chatgpt-web ) var ChannelBaseURLs = []string{ @@ -171,4 +174,7 @@ var ChannelBaseURLs = []string{ "", // 11 "https://api.api2gpt.com", // 12 "https://api.aigc2d.com", // 13 + + // Reserve engineering for public projects + "", // 14 // Chanzhaoyu/chatgpt-web } diff --git a/controller/channel-test.go b/controller/channel-test.go index 54bb0c71..070e6501 100644 --- a/controller/channel-test.go +++ b/controller/channel-test.go @@ -6,6 +6,7 @@ import ( "encoding/json" "errors" "fmt" + "log" "net/http" "one-api/common" "one-api/model" @@ -27,6 +28,11 @@ func testChannel(channel *model.Channel, request ChatRequest) error { requestURL := common.ChannelBaseURLs[channel.Type] if channel.Type == common.ChannelTypeAzure { requestURL = fmt.Sprintf("%s/openai/deployments/%s/chat/completions?api-version=2023-03-15-preview", channel.BaseURL, request.Model) + } else if channel.Type == common.ChannelTypeChatGPTWeb { + if channel.BaseURL != "" { + requestURL = channel.BaseURL + } + requestURL += "/api/chat-process" } else { if channel.BaseURL != "" { requestURL = channel.BaseURL @@ -35,6 +41,41 @@ func testChannel(channel *model.Channel, request ChatRequest) error { } jsonData, err := json.Marshal(request) + + if channel.Type == common.ChannelTypeChatGPTWeb { + // Get system message from Message json, Role == "system" + var systemMessage Message + + for _, message := range request.Messages { + if message.Role == "system" { + systemMessage = message + break + } + } + + var prompt string + + // Get all the Message, Roles from request.Messages, and format it into string by + // ||> role: content + for _, message := range request.Messages { + // Exclude system message + if message.Role == "system" { + continue + } + prompt += "||> " + message.Role + ": " + message.Content + "\n" + } + + // Construct json data without adding escape character + map1 := map[string]string{ + "prompt": prompt, + "systemMessage": systemMessage.Content, + "temperature": strconv.FormatFloat(request.Temperature, 'f', 2, 64), + "top_p": strconv.FormatFloat(request.TopP, 'f', 2, 64), + } + + // Convert map to json string + jsonData, err = json.Marshal(map1) + } if err != nil { return err } @@ -104,52 +145,83 @@ func testChannel(channel *model.Channel, request ChatRequest) error { common.SysError("invalid stream response: " + data) continue } - // If data has event: event content inside, remove it, it can be prefix or inside the data - if strings.HasPrefix(data, "event:") || strings.Contains(data, "event:") { - // Remove event: event in the front or back - data = strings.TrimPrefix(data, "event: event") - data = strings.TrimSuffix(data, "event: event") - // Remove everything, only keep `data: {...}` <--- this is the json - // Find the start and end indices of `data: {...}` substring - startIndex := strings.Index(data, "data:") - endIndex := strings.LastIndex(data, "}") + if channel.Type != common.ChannelTypeChatGPTWeb { + // If data has event: event content inside, remove it, it can be prefix or inside the data + if strings.HasPrefix(data, "event:") || strings.Contains(data, "event:") { + // Remove event: event in the front or back + data = strings.TrimPrefix(data, "event: event") + data = strings.TrimSuffix(data, "event: event") + // Remove everything, only keep `data: {...}` <--- this is the json + // Find the start and end indices of `data: {...}` substring + startIndex := strings.Index(data, "data:") + endIndex := strings.LastIndex(data, "}") - // If both indices are found and end index is greater than start index - if startIndex != -1 && endIndex != -1 && endIndex > startIndex { - // Extract the `data: {...}` substring - data = data[startIndex : endIndex+1] - } + // If both indices are found and end index is greater than start index + if startIndex != -1 && endIndex != -1 && endIndex > startIndex { + // Extract the `data: {...}` substring + data = data[startIndex : endIndex+1] + } - // Trim whitespace and newlines from the modified data string - data = strings.TrimSpace(data) - } - if !strings.HasPrefix(data, "data:") { - continue - } - data = data[6:] - if !strings.HasPrefix(data, "[DONE]") { - var streamResponse ChatCompletionsStreamResponse - err = json.Unmarshal([]byte(data), &streamResponse) - if err != nil { - // Prinnt the body in string - buf := new(bytes.Buffer) - buf.ReadFrom(resp.Body) - common.SysError("error unmarshalling stream response: " + err.Error() + " " + buf.String()) - return err + // Trim whitespace and newlines from the modified data string + data = strings.TrimSpace(data) } - for _, choice := range streamResponse.Choices { - streamResponseText += choice.Delta.Content + if !strings.HasPrefix(data, "data:") { + continue + } + data = data[6:] + if !strings.HasPrefix(data, "[DONE]") { + var streamResponse ChatCompletionsStreamResponse + err = json.Unmarshal([]byte(data), &streamResponse) + if err != nil { + // Prinnt the body in string + buf := new(bytes.Buffer) + buf.ReadFrom(resp.Body) + common.SysError("error unmarshalling stream response: " + err.Error() + " " + buf.String()) + return err + } + for _, choice := range streamResponse.Choices { + streamResponseText += choice.Delta.Content + } + } else { + done = true + break + } + } else if channel.Type == common.ChannelTypeChatGPTWeb { + // data may contain multiple json objects, so we need to split them + // they are "{....}{....}{....}" or "{....}\n{....}\n{....}" or "{....}" + + // remove all spaces and newlines outside of json objects + jsonObjs := strings.Split(data, "\n") // Split the data into multiple JSON objects + for _, jsonObj := range jsonObjs { + if jsonObj == "" { + continue + } + + var chatResponse ChatGptWebChatResponse + err = json.Unmarshal([]byte(jsonObj), &chatResponse) + if err != nil { + // Print the body in string + buf := new(bytes.Buffer) + buf.ReadFrom(resp.Body) + common.SysError("error unmarshalling chat response: " + err.Error() + " " + buf.String()) + return err + } + + // if response role is assistant and contains delta, append the content to streamResponseText + if chatResponse.Role == "assistant" && chatResponse.Detail != nil { + for _, choice := range chatResponse.Detail.Choices { + log.Print(choice.Delta.Content) + streamResponseText += choice.Delta.Content + } + } } - } else { - done = true - break } } defer resp.Body.Close() // Check if streaming is complete and streamResponseText is populated - if streamResponseText == "" || !done { + if streamResponseText == "" || !done && channel.Type != common.ChannelTypeChatGPTWeb { return errors.New("Streaming not complete") } diff --git a/controller/relay-text.go b/controller/relay-text.go index 36b18a1b..0f2472f7 100644 --- a/controller/relay-text.go +++ b/controller/relay-text.go @@ -11,6 +11,7 @@ import ( "net/http" "one-api/common" "one-api/model" + "strconv" "strings" "github.com/gin-gonic/gin" @@ -114,6 +115,12 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { model_ = strings.TrimSuffix(model_, "-0314") model_ = strings.TrimSuffix(model_, "-0613") fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/%s", baseURL, model_, task) + } else if channelType == common.ChannelTypeChatGPTWeb { + // remove /v1/chat/completions from request url + requestURL := strings.Split(requestURL, "/v1/chat/completions")[0] + requestURL += "/api/chat-process" + + fullRequestURL = fmt.Sprintf("%s%s", baseURL, requestURL) } else if channelType == common.ChannelTypePaLM { err := relayPaLM(textRequest, c) return err @@ -182,6 +189,57 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { requestBody = bytes.NewBuffer(bodyBytes) } + + if channelType == common.ChannelTypeChatGPTWeb { + // Get system message from Message json, Role == "system" + var reqBody ChatRequest + var systemMessage Message + + // Parse requestBody into systemMessage + err := json.NewDecoder(requestBody).Decode(&reqBody) + + if err != nil { + return errorWrapper(err, "decode_request_body_failed", http.StatusInternalServerError) + } + + for _, message := range reqBody.Messages { + if message.Role == "system" { + systemMessage = message + break + } + } + + var prompt string + + // Get all the Message, Roles from request.Messages, and format it into string by + // ||> role: content + for _, message := range reqBody.Messages { + // Exclude system message + if message.Role == "system" { + continue + } + prompt += "||> " + message.Role + ": " + message.Content + "\n" + } + + // Construct json data without adding escape character + map1 := map[string]string{ + "prompt": prompt, + "systemMessage": systemMessage.Content, + "temperature": strconv.FormatFloat(reqBody.Temperature, 'f', 2, 64), + "top_p": strconv.FormatFloat(reqBody.TopP, 'f', 2, 64), + } + + // Convert map to json string + jsonData, err := json.Marshal(map1) + + if err != nil { + return errorWrapper(err, "marshal_json_failed", http.StatusInternalServerError) + } + + // Convert json string to io.Reader + requestBody = bytes.NewReader(jsonData) + } + req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody) if err != nil { return errorWrapper(err, "new_request_failed", http.StatusInternalServerError) @@ -235,7 +293,7 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError) } var textResponse TextResponse - isStream := strings.HasPrefix(resp.Header.Get("Content-Type"), "text/event-stream") + isStream := strings.HasPrefix(resp.Header.Get("Content-Type"), "text/event-stream") || strings.HasPrefix(resp.Header.Get("Content-Type"), "application/octet-stream") var streamResponseText string defer func() { @@ -286,82 +344,129 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { }() if isStream { - scanner := bufio.NewScanner(resp.Body) - scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { - if atEOF && len(data) == 0 { - return 0, nil, nil - } - - if i := strings.Index(string(data), "\n\n"); i >= 0 { - return i + 2, data[0:i], nil - } - - if atEOF { - return len(data), data, nil - } - - return 0, nil, nil - }) dataChan := make(chan string) stopChan := make(chan bool) - go func() { - for scanner.Scan() { - data := scanner.Text() - if len(data) < 6 { // must be something wrong! - common.SysError("invalid stream response: " + data) - continue - } - // If data has event: event content inside, remove it, it can be prefix or inside the data - if strings.HasPrefix(data, "event:") || strings.Contains(data, "event:") { - // Remove event: event in the front or back - data = strings.TrimPrefix(data, "event: event") - data = strings.TrimSuffix(data, "event: event") - // Remove everything, only keep `data: {...}` <--- this is the json - // Find the start and end indices of `data: {...}` substring - startIndex := strings.Index(data, "data:") - endIndex := strings.LastIndex(data, "}") - // If both indices are found and end index is greater than start index - if startIndex != -1 && endIndex != -1 && endIndex > startIndex { - // Extract the `data: {...}` substring - data = data[startIndex : endIndex+1] + if channelType == common.ChannelTypeChatGPTWeb { + scanner := bufio.NewScanner(resp.Body) + go func() { + for scanner.Scan() { + var chatResponse ChatGptWebChatResponse + err = json.Unmarshal(scanner.Bytes(), &chatResponse) + + if err != nil { + log.Println("error unmarshal chat response: " + err.Error()) + continue } - // Trim whitespace and newlines from the modified data string - data = strings.TrimSpace(data) - } - if !strings.HasPrefix(data, "data:") { - continue - } - dataChan <- data - data = data[6:] - if !strings.HasPrefix(data, "[DONE]") { - switch relayMode { - case RelayModeChatCompletions: - var streamResponse ChatCompletionsStreamResponse - err = json.Unmarshal([]byte(data), &streamResponse) - if err != nil { - common.SysError("error unmarshalling stream response: " + err.Error()) - return - } - for _, choice := range streamResponse.Choices { + // if response role is assistant and contains delta, append the content to streamResponseText + if chatResponse.Role == "assistant" && chatResponse.Detail != nil { + for _, choice := range chatResponse.Detail.Choices { streamResponseText += choice.Delta.Content - } - case RelayModeCompletions: - var streamResponse CompletionsStreamResponse - err = json.Unmarshal([]byte(data), &streamResponse) - if err != nil { - common.SysError("error unmarshalling stream response: " + err.Error()) - return - } - for _, choice := range streamResponse.Choices { - streamResponseText += choice.Text + + returnObj := map[string]interface{}{ + "id": chatResponse.ID, + "object": chatResponse.Detail.Object, + "created": chatResponse.Detail.Created, + "model": chatResponse.Detail.Model, + "choices": []map[string]interface{}{ + // set finish_reason to null in json + { + "finish_reason": nil, + "index": 0, + "delta": map[string]interface{}{ + "content": choice.Delta.Content, + }, + }, + }, + } + + jsonData, _ := json.Marshal(returnObj) + + dataChan <- "data: " + string(jsonData) } } } - } - stopChan <- true - }() + stopChan <- true + }() + } else { + scanner := bufio.NewScanner(resp.Body) + scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { + if atEOF && len(data) == 0 { + return 0, nil, nil + } + + if i := strings.Index(string(data), "\n\n"); i >= 0 { + return i + 2, data[0:i], nil + } + + if atEOF { + return len(data), data, nil + } + + return 0, nil, nil + }) + go func() { + for scanner.Scan() { + data := scanner.Text() + if len(data) < 6 { // must be something wrong! + common.SysError("invalid stream response: " + data) + continue + } + // If data has event: event content inside, remove it, it can be prefix or inside the data + if strings.HasPrefix(data, "event:") || strings.Contains(data, "event:") { + // Remove event: event in the front or back + data = strings.TrimPrefix(data, "event: event") + data = strings.TrimSuffix(data, "event: event") + // Remove everything, only keep `data: {...}` <--- this is the json + // Find the start and end indices of `data: {...}` substring + startIndex := strings.Index(data, "data:") + endIndex := strings.LastIndex(data, "}") + + // If both indices are found and end index is greater than start index + if startIndex != -1 && endIndex != -1 && endIndex > startIndex { + // Extract the `data: {...}` substring + data = data[startIndex : endIndex+1] + } + + // Trim whitespace and newlines from the modified data string + data = strings.TrimSpace(data) + } + if !strings.HasPrefix(data, "data:") { + continue + } + dataChan <- data + data = data[6:] + if !strings.HasPrefix(data, "[DONE]") { + switch relayMode { + case RelayModeChatCompletions: + var streamResponse ChatCompletionsStreamResponse + err = json.Unmarshal([]byte(data), &streamResponse) + if err != nil { + common.SysError("error unmarshalling stream response: " + err.Error()) + return + } + for _, choice := range streamResponse.Choices { + streamResponseText += choice.Delta.Content + } + case RelayModeCompletions: + var streamResponse CompletionsStreamResponse + err = json.Unmarshal([]byte(data), &streamResponse) + if err != nil { + common.SysError("error unmarshalling stream response: " + err.Error()) + return + } + for _, choice := range streamResponse.Choices { + streamResponseText += choice.Text + } + } + + } + } + stopChan <- true + }() + } + c.Writer.Header().Set("Content-Type", "text/event-stream") c.Writer.Header().Set("Cache-Control", "no-cache") c.Writer.Header().Set("Connection", "keep-alive") @@ -373,6 +478,7 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { if strings.HasPrefix(data, "data: [DONE]") { data = data[:12] } + log.Print(data) c.Render(-1, common.CustomEvent{Data: data}) return true case <-stopChan: diff --git a/controller/relay.go b/controller/relay.go index bef667ec..84d7f7bd 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -46,6 +46,9 @@ type ChatRequest struct { Messages []Message `json:"messages"` MaxTokens *int `json:"max_tokens,omitempty"` Stream bool `json:"stream"` + // -1.0 to 1.0 + Temperature float64 `json:"temperature"` + TopP float64 `json:"top_p"` } type TextRequest struct { @@ -102,6 +105,32 @@ type CompletionsStreamResponse struct { } `json:"choices"` } +type ChatGptWebDetail struct { + ID string `json:"id"` + Object string `json:"object"` + Created int `json:"created"` + Model string `json:"model"` + Choices []ChatGptWebChoice `json:"choices"` +} + +type ChatGptWebChoice struct { + Delta struct { + Content string `json:"content"` + Role string `json:"role"` + } `json:"delta"` + Index int `json:"index"` + Finish_Reason string `json:"finish_reason"` +} + +type ChatGptWebChatResponse struct { + Role string `json:"role"` + ID string `json:"id"` + ParentMessageID string `json:"parentMessageId"` + Text string `json:"text"` + Delta string `json:"delta"` + Detail *ChatGptWebDetail `json:"detail"` +} + func Relay(c *gin.Context) { relayMode := RelayModeUnknown if strings.HasPrefix(c.Request.URL.Path, "/v1/chat/completions") { diff --git a/go.mod b/go.mod index 1c3d25ac..6aaca91e 100644 --- a/go.mod +++ b/go.mod @@ -12,7 +12,7 @@ require ( github.com/go-playground/validator/v10 v10.14.1 github.com/go-redis/redis/v8 v8.11.5 github.com/google/uuid v1.3.0 - github.com/pkoukk/tiktoken-go v0.1.4 + github.com/pkoukk/tiktoken-go v0.1.5 golang.org/x/crypto v0.11.0 gorm.io/driver/mysql v1.5.1 gorm.io/driver/sqlite v1.5.2 @@ -46,7 +46,7 @@ require ( github.com/mattn/go-sqlite3 v2.0.3+incompatible // indirect github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect github.com/modern-go/reflect2 v1.0.2 // indirect - github.com/pelletier/go-toml/v2 v2.0.8 // indirect + github.com/pelletier/go-toml/v2 v2.0.9 // indirect github.com/realTristan/disgoauth v1.0.2 github.com/twitchyliquid64/golang-asm v0.15.1 // indirect github.com/ugorji/go/codec v1.2.11 // indirect diff --git a/go.sum b/go.sum index d8410454..3ae6c1ff 100644 --- a/go.sum +++ b/go.sum @@ -130,11 +130,15 @@ github.com/onsi/gomega v1.18.1 h1:M1GfJqGRrBrrGGsbxzV5dqM2U2ApXefZCQpkukxYRLE= github.com/pelletier/go-toml/v2 v2.0.1/go.mod h1:r9LEWfGN8R5k0VXJ+0BkIe7MYkRdwZOjgMj2KwnJFUo= github.com/pelletier/go-toml/v2 v2.0.8 h1:0ctb6s9mE31h0/lhu+J6OPmVeDxJn+kYnJc2jZR9tGQ= github.com/pelletier/go-toml/v2 v2.0.8/go.mod h1:vuYfssBdrU2XDZ9bYydBu6t+6a6PYNcZljzZR9VXg+4= +github.com/pelletier/go-toml/v2 v2.0.9 h1:uH2qQXheeefCCkuBBSLi7jCiSmj3VRh2+Goq2N7Xxu0= +github.com/pelletier/go-toml/v2 v2.0.9/go.mod h1:tJU2Z3ZkXwnxa4DPO899bsyIoywizdUvyaeZurnPPDc= github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA= github.com/pkoukk/tiktoken-go v0.1.1 h1:jtkYlIECjyM9OW1w4rjPmTohK4arORP9V25y6TM6nXo= github.com/pkoukk/tiktoken-go v0.1.1/go.mod h1:boMWvk9pQCOTx11pgu0DrIdrAKgQzzJKUP6vLXaz7Rw= github.com/pkoukk/tiktoken-go v0.1.4 h1:bniMzWdUvNO6YkRbASo2x5qJf2LAG/TIJojqz+Igm8E= github.com/pkoukk/tiktoken-go v0.1.4/go.mod h1:9NiV+i9mJKGj1rYOT+njbv+ZwA/zJxYdewGl6qVatpg= +github.com/pkoukk/tiktoken-go v0.1.5 h1:hAlT4dCf6Uk50x8E7HQrddhH3EWMKUN+LArExQQsQx4= +github.com/pkoukk/tiktoken-go v0.1.5/go.mod h1:9NiV+i9mJKGj1rYOT+njbv+ZwA/zJxYdewGl6qVatpg= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/ravener/discord-oauth2 v0.0.0-20230514095040-ae65713199b3 h1:x3LgcvujjG+mx8PUMfPmwn3tcu2aA95uCB6ilGGObWk= @@ -157,6 +161,7 @@ github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= github.com/stretchr/testify v1.8.3 h1:RP3t2pwF7cMEbC1dqtB6poj3niw/9gnV4Cjg5oW5gtY= github.com/stretchr/testify v1.8.3/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= +github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI= github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08= github.com/ugorji/go v1.1.7/go.mod h1:kZn38zHttfInRq0xu/PH0az30d+z6vm202qpg1oXVMw= diff --git a/web/src/constants/channel.constants.js b/web/src/constants/channel.constants.js index 7d732223..ad18c464 100644 --- a/web/src/constants/channel.constants.js +++ b/web/src/constants/channel.constants.js @@ -10,5 +10,8 @@ export const CHANNEL_OPTIONS = [ { key: 9, text: 'AI.LS', value: 9, color: 'yellow' }, { key: 10, text: 'AI Proxy', value: 10, color: 'purple' }, { key: 12, text: 'API2GPT', value: 12, color: 'blue' }, - { key: 13, text: 'AIGC2D', value: 13, color: 'purple' } + { key: 13, text: 'AIGC2D', value: 13, color: 'purple' }, + + // + { key: 14, text: 'Chanzhaoyu/chatgpt-web', value: 14, color: 'purple' }, ]; \ No newline at end of file