feat: support chatbot ui

This commit is contained in:
ckt1031 2023-07-17 15:35:02 +08:00
parent a6ae20ed54
commit 4b9756b257
4 changed files with 128 additions and 22 deletions

View File

@ -158,6 +158,7 @@ const (
// Reserve engineering for public projects // Reserve engineering for public projects
ChannelTypeChatGPTWeb = 14 // Chanzhaoyu/chatgpt-web ChannelTypeChatGPTWeb = 14 // Chanzhaoyu/chatgpt-web
ChannelTypeChatbotUI = 15 // mckaywrigley/chatbot-ui
) )
var ChannelBaseURLs = []string{ var ChannelBaseURLs = []string{
@ -178,4 +179,5 @@ var ChannelBaseURLs = []string{
// Reserve engineering for public projects // Reserve engineering for public projects
"", // 14 // Chanzhaoyu/chatgpt-web "", // 14 // Chanzhaoyu/chatgpt-web
"", // 15 // mckaywrigley/chatbot-ui
} }

View File

@ -6,6 +6,7 @@ import (
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
"log"
"net/http" "net/http"
"one-api/common" "one-api/common"
"one-api/model" "one-api/model"
@ -38,6 +39,10 @@ func testChannel(channel *model.Channel, request ChatRequest) error {
if channel.BaseURL != "" { if channel.BaseURL != "" {
requestURL = channel.BaseURL requestURL = channel.BaseURL
} }
} else if channel.Type == common.ChannelTypeChatbotUI {
if channel.BaseURL != "" {
requestURL = channel.BaseURL
}
} else { } else {
if channel.BaseURL != "" { if channel.BaseURL != "" {
requestURL = channel.BaseURL requestURL = channel.BaseURL
@ -85,7 +90,35 @@ func testChannel(channel *model.Channel, request ChatRequest) error {
// Convert map to json string // Convert map to json string
jsonData, err = json.Marshal(map1) jsonData, err = json.Marshal(map1)
} else if channel.Type == common.ChannelTypeChatbotUI {
// Get system message from Message json, Role == "system"
var systemMessage string
for _, message := range request.Messages {
if message.Role == "system" {
systemMessage = message.Content
break
} }
}
// Construct json data without adding escape character
map1 := make(map[string]interface{})
map1["prompt"] = systemMessage
map1["temperature"] = formatFloat(request.Temperature)
map1["key"] = ""
map1["messages"] = request.Messages
map1["model"] = map[string]interface{}{
"id": request.Model,
}
// Convert map to json string
jsonData, err = json.Marshal(map1)
//Print jsoinData to console
log.Println(string(jsonData))
}
if err != nil { if err != nil {
return err return err
} }
@ -134,7 +167,7 @@ func testChannel(channel *model.Channel, request ChatRequest) error {
scanner := bufio.NewScanner(resp.Body) scanner := bufio.NewScanner(resp.Body)
if channel.Type != common.ChannelTypeChatGPTWeb { if channel.Type != common.ChannelTypeChatGPTWeb && channel.Type != common.ChannelTypeChatbotUI {
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
if atEOF && len(data) == 0 { if atEOF && len(data) == 0 {
return 0, nil, nil return 0, nil, nil
@ -158,7 +191,27 @@ func testChannel(channel *model.Channel, request ChatRequest) error {
continue continue
} }
if channel.Type != common.ChannelTypeChatGPTWeb { if channel.Type == common.ChannelTypeChatGPTWeb {
var chatResponse ChatGptWebChatResponse
err = json.Unmarshal([]byte(data), &chatResponse)
if err != nil {
// Print the body in string
buf := new(bytes.Buffer)
buf.ReadFrom(resp.Body)
common.SysError("error unmarshalling chat response: " + err.Error() + " " + buf.String())
return err
}
// if response role is assistant and contains delta, append the content to streamResponseText
if chatResponse.Role == "assistant" && chatResponse.Detail != nil {
for _, choice := range chatResponse.Detail.Choices {
streamResponseText += choice.Delta.Content
}
}
} else if channel.Type == common.ChannelTypeChatbotUI {
streamResponseText += data
} else if channel.Type != common.ChannelTypeChatGPTWeb {
// If data has event: event content inside, remove it, it can be prefix or inside the data // If data has event: event content inside, remove it, it can be prefix or inside the data
if strings.HasPrefix(data, "event:") || strings.Contains(data, "event:") { if strings.HasPrefix(data, "event:") || strings.Contains(data, "event:") {
// Remove event: event in the front or back // Remove event: event in the front or back
@ -197,24 +250,6 @@ func testChannel(channel *model.Channel, request ChatRequest) error {
} }
} }
} else if channel.Type == common.ChannelTypeChatGPTWeb {
var chatResponse ChatGptWebChatResponse
err = json.Unmarshal([]byte(data), &chatResponse)
if err != nil {
// Print the body in string
buf := new(bytes.Buffer)
buf.ReadFrom(resp.Body)
common.SysError("error unmarshalling chat response: " + err.Error() + " " + buf.String())
return err
}
// if response role is assistant and contains delta, append the content to streamResponseText
if chatResponse.Role == "assistant" && chatResponse.Detail != nil {
for _, choice := range chatResponse.Detail.Choices {
streamResponseText += choice.Delta.Content
}
}
} }
} }

View File

@ -11,7 +11,9 @@ import (
"net/http" "net/http"
"one-api/common" "one-api/common"
"one-api/model" "one-api/model"
"strconv"
"strings" "strings"
"time"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )
@ -121,6 +123,10 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
// remove /v1/chat/completions from request url // remove /v1/chat/completions from request url
requestURL := strings.Split(requestURL, "/v1/chat/completions")[0] requestURL := strings.Split(requestURL, "/v1/chat/completions")[0]
fullRequestURL = fmt.Sprintf("%s%s", baseURL, requestURL) fullRequestURL = fmt.Sprintf("%s%s", baseURL, requestURL)
} else if channelType == common.ChannelTypeChatbotUI {
// remove /v1/chat/completions from request url
requestURL := strings.Split(requestURL, "/v1/chat/completions")[0]
fullRequestURL = fmt.Sprintf("%s%s", baseURL, requestURL)
} else if channelType == common.ChannelTypePaLM { } else if channelType == common.ChannelTypePaLM {
err := relayPaLM(textRequest, c) err := relayPaLM(textRequest, c)
return err return err
@ -241,6 +247,47 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
return errorWrapper(err, "marshal_json_failed", http.StatusInternalServerError) return errorWrapper(err, "marshal_json_failed", http.StatusInternalServerError)
} }
// Convert json string to io.Reader
requestBody = bytes.NewReader(jsonData)
} else if channelType == common.ChannelTypeChatbotUI {
// Get system message from Message json, Role == "system"
var reqBody ChatRequest
// Parse requestBody into systemMessage
err := json.NewDecoder(requestBody).Decode(&reqBody)
if err != nil {
return errorWrapper(err, "decode_request_body_failed", http.StatusInternalServerError)
}
// Get system message from Message json, Role == "system"
var systemMessage string
for _, message := range reqBody.Messages {
if message.Role == "system" {
systemMessage = message.Content
break
}
}
// Construct json data without adding escape character
map1 := make(map[string]interface{})
map1["prompt"] = systemMessage
map1["temperature"] = formatFloat(reqBody.Temperature)
map1["key"] = ""
map1["messages"] = reqBody.Messages
map1["model"] = map[string]interface{}{
"id": reqBody.Model,
}
// Convert map to json string
jsonData, err := json.Marshal(map1)
if err != nil {
return errorWrapper(err, "marshal_json_failed", http.StatusInternalServerError)
}
// Convert json string to io.Reader // Convert json string to io.Reader
requestBody = bytes.NewReader(jsonData) requestBody = bytes.NewReader(jsonData)
} }
@ -348,13 +395,13 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
} }
}() }()
if isStream { if isStream || channelType == common.ChannelTypeChatGPTWeb || channelType == common.ChannelTypeChatbotUI {
dataChan := make(chan string) dataChan := make(chan string)
stopChan := make(chan bool) stopChan := make(chan bool)
scanner := bufio.NewScanner(resp.Body) scanner := bufio.NewScanner(resp.Body)
if channelType != common.ChannelTypeChatGPTWeb { if channelType != common.ChannelTypeChatGPTWeb && channelType != common.ChannelTypeChatbotUI {
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
if atEOF && len(data) == 0 { if atEOF && len(data) == 0 {
return 0, nil, nil return 0, nil, nil
@ -417,6 +464,27 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
dataChan <- "data: " + string(jsonData) dataChan <- "data: " + string(jsonData)
} }
} }
} else if channelType == common.ChannelTypeChatbotUI {
returnObj := map[string]interface{}{
"id": "chatcmpl-" + strconv.Itoa(int(time.Now().UnixNano())),
"object": "text_completion",
"created": time.Now().Unix(),
"model": textRequest.Model,
"choices": []map[string]interface{}{
// set finish_reason to null in json
{
"finish_reason": nil,
"index": 0,
"delta": map[string]interface{}{
"content": data,
},
},
},
}
jsonData, _ := json.Marshal(returnObj)
dataChan <- "data: " + string(jsonData)
} else { } else {
// If data has event: event content inside, remove it, it can be prefix or inside the data // If data has event: event content inside, remove it, it can be prefix or inside the data
if strings.HasPrefix(data, "event:") || strings.Contains(data, "event:") { if strings.HasPrefix(data, "event:") || strings.Contains(data, "event:") {

View File

@ -14,4 +14,5 @@ export const CHANNEL_OPTIONS = [
// //
{ key: 14, text: 'Chanzhaoyu/chatgpt-web', value: 14, color: 'purple' }, { key: 14, text: 'Chanzhaoyu/chatgpt-web', value: 14, color: 'purple' },
{ key: 14, text: 'mckaywrigley/chatbot-ui', value: 15, color: 'orange' },
]; ];