From 4b9756b257b2b990f8abc9f7c0fca054ecd61ff7 Mon Sep 17 00:00:00 2001 From: ckt1031 <65409152+ckt1031@users.noreply.github.com> Date: Mon, 17 Jul 2023 15:35:02 +0800 Subject: [PATCH] feat: support chatbot ui --- common/constants.go | 2 + controller/channel-test.go | 75 ++++++++++++++++++------- controller/relay-text.go | 72 +++++++++++++++++++++++- web/src/constants/channel.constants.jsx | 1 + 4 files changed, 128 insertions(+), 22 deletions(-) diff --git a/common/constants.go b/common/constants.go index 5d1adbd7..c77ed748 100644 --- a/common/constants.go +++ b/common/constants.go @@ -158,6 +158,7 @@ const ( // Reserve engineering for public projects ChannelTypeChatGPTWeb = 14 // Chanzhaoyu/chatgpt-web + ChannelTypeChatbotUI = 15 // mckaywrigley/chatbot-ui ) var ChannelBaseURLs = []string{ @@ -178,4 +179,5 @@ var ChannelBaseURLs = []string{ // Reserve engineering for public projects "", // 14 // Chanzhaoyu/chatgpt-web + "", // 15 // mckaywrigley/chatbot-ui } diff --git a/controller/channel-test.go b/controller/channel-test.go index 00cc2cc1..e9f0dc9c 100644 --- a/controller/channel-test.go +++ b/controller/channel-test.go @@ -6,6 +6,7 @@ import ( "encoding/json" "errors" "fmt" + "log" "net/http" "one-api/common" "one-api/model" @@ -38,6 +39,10 @@ func testChannel(channel *model.Channel, request ChatRequest) error { if channel.BaseURL != "" { requestURL = channel.BaseURL } + } else if channel.Type == common.ChannelTypeChatbotUI { + if channel.BaseURL != "" { + requestURL = channel.BaseURL + } } else { if channel.BaseURL != "" { requestURL = channel.BaseURL @@ -85,7 +90,35 @@ func testChannel(channel *model.Channel, request ChatRequest) error { // Convert map to json string jsonData, err = json.Marshal(map1) + } else if channel.Type == common.ChannelTypeChatbotUI { + // Get system message from Message json, Role == "system" + var systemMessage string + + for _, message := range request.Messages { + if message.Role == "system" { + systemMessage = message.Content + break + } + } + + // Construct json data without adding escape character + map1 := make(map[string]interface{}) + + map1["prompt"] = systemMessage + map1["temperature"] = formatFloat(request.Temperature) + map1["key"] = "" + map1["messages"] = request.Messages + map1["model"] = map[string]interface{}{ + "id": request.Model, + } + + // Convert map to json string + jsonData, err = json.Marshal(map1) + + //Print jsoinData to console + log.Println(string(jsonData)) } + if err != nil { return err } @@ -134,7 +167,7 @@ func testChannel(channel *model.Channel, request ChatRequest) error { scanner := bufio.NewScanner(resp.Body) - if channel.Type != common.ChannelTypeChatGPTWeb { + if channel.Type != common.ChannelTypeChatGPTWeb && channel.Type != common.ChannelTypeChatbotUI { scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { if atEOF && len(data) == 0 { return 0, nil, nil @@ -158,7 +191,27 @@ func testChannel(channel *model.Channel, request ChatRequest) error { continue } - if channel.Type != common.ChannelTypeChatGPTWeb { + if channel.Type == common.ChannelTypeChatGPTWeb { + var chatResponse ChatGptWebChatResponse + err = json.Unmarshal([]byte(data), &chatResponse) + if err != nil { + // Print the body in string + buf := new(bytes.Buffer) + buf.ReadFrom(resp.Body) + common.SysError("error unmarshalling chat response: " + err.Error() + " " + buf.String()) + return err + } + + // if response role is assistant and contains delta, append the content to streamResponseText + if chatResponse.Role == "assistant" && chatResponse.Detail != nil { + for _, choice := range chatResponse.Detail.Choices { + streamResponseText += choice.Delta.Content + } + } + + } else if channel.Type == common.ChannelTypeChatbotUI { + streamResponseText += data + } else if channel.Type != common.ChannelTypeChatGPTWeb { // If data has event: event content inside, remove it, it can be prefix or inside the data if strings.HasPrefix(data, "event:") || strings.Contains(data, "event:") { // Remove event: event in the front or back @@ -197,24 +250,6 @@ func testChannel(channel *model.Channel, request ChatRequest) error { } } - } else if channel.Type == common.ChannelTypeChatGPTWeb { - var chatResponse ChatGptWebChatResponse - err = json.Unmarshal([]byte(data), &chatResponse) - if err != nil { - // Print the body in string - buf := new(bytes.Buffer) - buf.ReadFrom(resp.Body) - common.SysError("error unmarshalling chat response: " + err.Error() + " " + buf.String()) - return err - } - - // if response role is assistant and contains delta, append the content to streamResponseText - if chatResponse.Role == "assistant" && chatResponse.Detail != nil { - for _, choice := range chatResponse.Detail.Choices { - streamResponseText += choice.Delta.Content - } - } - } } diff --git a/controller/relay-text.go b/controller/relay-text.go index e6b54f4f..54973c23 100644 --- a/controller/relay-text.go +++ b/controller/relay-text.go @@ -11,7 +11,9 @@ import ( "net/http" "one-api/common" "one-api/model" + "strconv" "strings" + "time" "github.com/gin-gonic/gin" ) @@ -121,6 +123,10 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { // remove /v1/chat/completions from request url requestURL := strings.Split(requestURL, "/v1/chat/completions")[0] fullRequestURL = fmt.Sprintf("%s%s", baseURL, requestURL) + } else if channelType == common.ChannelTypeChatbotUI { + // remove /v1/chat/completions from request url + requestURL := strings.Split(requestURL, "/v1/chat/completions")[0] + fullRequestURL = fmt.Sprintf("%s%s", baseURL, requestURL) } else if channelType == common.ChannelTypePaLM { err := relayPaLM(textRequest, c) return err @@ -241,6 +247,47 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { return errorWrapper(err, "marshal_json_failed", http.StatusInternalServerError) } + // Convert json string to io.Reader + requestBody = bytes.NewReader(jsonData) + } else if channelType == common.ChannelTypeChatbotUI { + // Get system message from Message json, Role == "system" + var reqBody ChatRequest + + // Parse requestBody into systemMessage + err := json.NewDecoder(requestBody).Decode(&reqBody) + + if err != nil { + return errorWrapper(err, "decode_request_body_failed", http.StatusInternalServerError) + } + + // Get system message from Message json, Role == "system" + var systemMessage string + + for _, message := range reqBody.Messages { + if message.Role == "system" { + systemMessage = message.Content + break + } + } + + // Construct json data without adding escape character + map1 := make(map[string]interface{}) + + map1["prompt"] = systemMessage + map1["temperature"] = formatFloat(reqBody.Temperature) + map1["key"] = "" + map1["messages"] = reqBody.Messages + map1["model"] = map[string]interface{}{ + "id": reqBody.Model, + } + + // Convert map to json string + jsonData, err := json.Marshal(map1) + + if err != nil { + return errorWrapper(err, "marshal_json_failed", http.StatusInternalServerError) + } + // Convert json string to io.Reader requestBody = bytes.NewReader(jsonData) } @@ -348,13 +395,13 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { } }() - if isStream { + if isStream || channelType == common.ChannelTypeChatGPTWeb || channelType == common.ChannelTypeChatbotUI { dataChan := make(chan string) stopChan := make(chan bool) scanner := bufio.NewScanner(resp.Body) - if channelType != common.ChannelTypeChatGPTWeb { + if channelType != common.ChannelTypeChatGPTWeb && channelType != common.ChannelTypeChatbotUI { scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { if atEOF && len(data) == 0 { return 0, nil, nil @@ -417,6 +464,27 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { dataChan <- "data: " + string(jsonData) } } + } else if channelType == common.ChannelTypeChatbotUI { + returnObj := map[string]interface{}{ + "id": "chatcmpl-" + strconv.Itoa(int(time.Now().UnixNano())), + "object": "text_completion", + "created": time.Now().Unix(), + "model": textRequest.Model, + "choices": []map[string]interface{}{ + // set finish_reason to null in json + { + "finish_reason": nil, + "index": 0, + "delta": map[string]interface{}{ + "content": data, + }, + }, + }, + } + + jsonData, _ := json.Marshal(returnObj) + + dataChan <- "data: " + string(jsonData) } else { // If data has event: event content inside, remove it, it can be prefix or inside the data if strings.HasPrefix(data, "event:") || strings.Contains(data, "event:") { diff --git a/web/src/constants/channel.constants.jsx b/web/src/constants/channel.constants.jsx index 784c40d3..1c3fc831 100644 --- a/web/src/constants/channel.constants.jsx +++ b/web/src/constants/channel.constants.jsx @@ -14,4 +14,5 @@ export const CHANNEL_OPTIONS = [ // { key: 14, text: 'Chanzhaoyu/chatgpt-web', value: 14, color: 'purple' }, + { key: 14, text: 'mckaywrigley/chatbot-ui', value: 15, color: 'orange' }, ];