From e92da7928b4029c00666ae54b5d83af70860f6ce Mon Sep 17 00:00:00 2001 From: JustSong Date: Fri, 28 Jul 2023 23:45:08 +0800 Subject: [PATCH 01/14] feat: support ali's llm (close #326) --- README.md | 1 + common/constants.go | 36 ++-- common/model-ratio.go | 2 + controller/model.go | 18 ++ controller/relay-ali.go | 240 +++++++++++++++++++++++++ controller/relay-text.go | 40 ++++- web/src/constants/channel.constants.js | 3 +- 7 files changed, 321 insertions(+), 19 deletions(-) create mode 100644 controller/relay-ali.go diff --git a/README.md b/README.md index 93fb8247..e5579c3d 100644 --- a/README.md +++ b/README.md @@ -63,6 +63,7 @@ _✨ 通过标准的 OpenAI API 格式访问所有的大模型,开箱即用 + [x] [Anthropic Claude 系列模型](https://anthropic.com) + [x] [Google PaLM2 系列模型](https://developers.generativeai.google) + [x] [百度文心一言系列模型](https://cloud.baidu.com/doc/WENXINWORKSHOP/index.html) + + [x] [阿里通义千问系列模型](https://help.aliyun.com/document_detail/2400395.html) + [x] [智谱 ChatGLM 系列模型](https://bigmodel.cn) 2. 支持配置镜像以及众多第三方代理服务: + [x] [API Distribute](https://api.gptjk.top/register?aff=QGxj) diff --git a/common/constants.go b/common/constants.go index 81f98163..5dbfa71c 100644 --- a/common/constants.go +++ b/common/constants.go @@ -156,24 +156,26 @@ const ( ChannelTypeAnthropic = 14 ChannelTypeBaidu = 15 ChannelTypeZhipu = 16 + ChannelTypeAli = 17 ) var ChannelBaseURLs = []string{ - "", // 0 - "https://api.openai.com", // 1 - "https://oa.api2d.net", // 2 - "", // 3 - "https://api.closeai-proxy.xyz", // 4 - "https://api.openai-sb.com", // 5 - "https://api.openaimax.com", // 6 - "https://api.ohmygpt.com", // 7 - "", // 8 - "https://api.caipacity.com", // 9 - "https://api.aiproxy.io", // 10 - "", // 11 - "https://api.api2gpt.com", // 12 - "https://api.aigc2d.com", // 13 - "https://api.anthropic.com", // 14 - "https://aip.baidubce.com", // 15 - "https://open.bigmodel.cn", // 16 + "", // 0 + "https://api.openai.com", // 1 + "https://oa.api2d.net", // 2 + "", // 3 + "https://api.closeai-proxy.xyz", // 4 + "https://api.openai-sb.com", // 5 + "https://api.openaimax.com", // 6 + "https://api.ohmygpt.com", // 7 + "", // 8 + "https://api.caipacity.com", // 9 + "https://api.aiproxy.io", // 10 + "", // 11 + "https://api.api2gpt.com", // 12 + "https://api.aigc2d.com", // 13 + "https://api.anthropic.com", // 14 + "https://aip.baidubce.com", // 15 + "https://open.bigmodel.cn", // 16 + "https://dashscope.aliyuncs.com", // 17 } diff --git a/common/model-ratio.go b/common/model-ratio.go index 7f991777..ba6d7245 100644 --- a/common/model-ratio.go +++ b/common/model-ratio.go @@ -46,6 +46,8 @@ var ModelRatio = map[string]float64{ "chatglm_pro": 0.7143, // ¥0.01 / 1k tokens "chatglm_std": 0.3572, // ¥0.005 / 1k tokens "chatglm_lite": 0.1429, // ¥0.002 / 1k tokens + "qwen-v1": 0.8572, // TBD: https://help.aliyun.com/document_detail/2399482.html?spm=a2c4g.2399482.0.0.1ad347feilAgag + "qwen-plus-v1": 0.5715, // Same as above } func ModelRatio2JSONString() string { diff --git a/controller/model.go b/controller/model.go index b469271f..f8096f75 100644 --- a/controller/model.go +++ b/controller/model.go @@ -324,6 +324,24 @@ func init() { Root: "chatglm_lite", Parent: nil, }, + { + Id: "qwen-v1", + Object: "model", + Created: 1677649963, + OwnedBy: "ali", + Permission: permission, + Root: "qwen-v1", + Parent: nil, + }, + { + Id: "qwen-plus-v1", + Object: "model", + Created: 1677649963, + OwnedBy: "ali", + Permission: permission, + Root: "qwen-plus-v1", + Parent: nil, + }, } openAIModelsMap = make(map[string]OpenAIModels) for _, model := range 
openAIModels {
diff --git a/controller/relay-ali.go b/controller/relay-ali.go
new file mode 100644
index 00000000..e8437c27
--- /dev/null
+++ b/controller/relay-ali.go
@@ -0,0 +1,240 @@
+package controller
+
+import (
+	"bufio"
+	"encoding/json"
+	"github.com/gin-gonic/gin"
+	"io"
+	"net/http"
+	"one-api/common"
+	"strings"
+)
+
+// https://help.aliyun.com/document_detail/613695.html?spm=a2c4g.2399480.0.0.1adb778fAdzP9w#341800c0f8w0r
+
+type AliMessage struct {
+	User string `json:"user"`
+	Bot  string `json:"bot"`
+}
+
+type AliInput struct {
+	Prompt  string       `json:"prompt"`
+	History []AliMessage `json:"history"`
+}
+
+type AliParameters struct {
+	TopP         float64 `json:"top_p,omitempty"`
+	TopK         int     `json:"top_k,omitempty"`
+	Seed         uint64  `json:"seed,omitempty"`
+	EnableSearch bool    `json:"enable_search,omitempty"`
+}
+
+type AliChatRequest struct {
+	Model      string        `json:"model"`
+	Input      AliInput      `json:"input"`
+	Parameters AliParameters `json:"parameters,omitempty"`
+}
+
+type AliError struct {
+	Code      string `json:"code"`
+	Message   string `json:"message"`
+	RequestId string `json:"request_id"`
+}
+
+type AliUsage struct {
+	InputTokens  int `json:"input_tokens"`
+	OutputTokens int `json:"output_tokens"`
+}
+
+type AliOutput struct {
+	Text         string `json:"text"`
+	FinishReason string `json:"finish_reason"`
+}
+
+type AliChatResponse struct {
+	Output AliOutput `json:"output"`
+	Usage  AliUsage  `json:"usage"`
+	AliError
+}
+
+func requestOpenAI2Ali(request GeneralOpenAIRequest) *AliChatRequest {
+	messages := make([]AliMessage, 0, len(request.Messages))
+	prompt := ""
+	for i := 0; i < len(request.Messages); i++ {
+		message := request.Messages[i]
+		if message.Role == "system" {
+			messages = append(messages, AliMessage{
+				User: message.Content,
+				Bot:  "Okay",
+			})
+			continue
+		} else {
+			if i == len(request.Messages)-1 {
+				prompt = message.Content
+				break
+			}
+			messages = append(messages, AliMessage{
+				User: message.Content,
+				Bot:  request.Messages[i+1].Content,
+			})
+			i++
+		}
+	}
+	return &AliChatRequest{
+		Model: request.Model,
+		Input: AliInput{
+			Prompt:  prompt,
+			History: messages,
+		},
+		//Parameters: AliParameters{ // ChatGPT's parameters are not compatible with Ali's
+		//	TopP: request.TopP,
+		//	TopK: 50,
+		//	//Seed: 0,
+		//	//EnableSearch: false,
+		//},
+	}
+}
+
+func responseAli2OpenAI(response *AliChatResponse) *OpenAITextResponse {
+	choice := OpenAITextResponseChoice{
+		Index: 0,
+		Message: Message{
+			Role:    "assistant",
+			Content: response.Output.Text,
+		},
+		FinishReason: response.Output.FinishReason,
+	}
+	fullTextResponse := OpenAITextResponse{
+		Id:      response.RequestId,
+		Object:  "chat.completion",
+		Created: common.GetTimestamp(),
+		Choices: []OpenAITextResponseChoice{choice},
+		Usage: Usage{
+			PromptTokens:     response.Usage.InputTokens,
+			CompletionTokens: response.Usage.OutputTokens,
+			TotalTokens:      response.Usage.InputTokens + response.Usage.OutputTokens,
+		},
+	}
+	return &fullTextResponse
+}
+
+func streamResponseAli2OpenAI(aliResponse *AliChatResponse) *ChatCompletionsStreamResponse {
+	var choice ChatCompletionsStreamResponseChoice
+	choice.Delta.Content = aliResponse.Output.Text
+	choice.FinishReason = aliResponse.Output.FinishReason
+	response := ChatCompletionsStreamResponse{
+		Id:      aliResponse.RequestId,
+		Object:  "chat.completion.chunk",
+		Created: common.GetTimestamp(),
+		Model:   "qwen", // DashScope responses do not echo a model name; label the chunk generically
+		Choices: []ChatCompletionsStreamResponseChoice{choice},
+	}
+	return &response
+}
+
+func aliStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
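+	// DashScope's SSE events carry the full accumulated answer text rather than
+	// a delta, so this handler remembers lastResponseText and forwards only the
+	// new suffix in each OpenAI-style chat.completion.chunk it renders.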
+	var usage Usage
+	scanner := bufio.NewScanner(resp.Body)
+	scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
+		if atEOF && len(data) == 0 {
+			return 0, nil, nil
+		}
+		if i := strings.Index(string(data), "\n"); i >= 0 {
+			return i + 1, data[0:i], nil
+		}
+		if atEOF {
+			return len(data), data, nil
+		}
+		return 0, nil, nil
+	})
+	dataChan := make(chan string)
+	stopChan := make(chan bool)
+	go func() {
+		for scanner.Scan() {
+			data := scanner.Text()
+			if len(data) < 5 { // ignore blank line or wrong format
+				continue
+			}
+			if data[:5] != "data:" {
+				continue
+			}
+			data = data[5:]
+			dataChan <- data
+		}
+		stopChan <- true
+	}()
+	c.Writer.Header().Set("Content-Type", "text/event-stream")
+	c.Writer.Header().Set("Cache-Control", "no-cache")
+	c.Writer.Header().Set("Connection", "keep-alive")
+	c.Writer.Header().Set("Transfer-Encoding", "chunked")
+	c.Writer.Header().Set("X-Accel-Buffering", "no")
+	lastResponseText := ""
+	c.Stream(func(w io.Writer) bool {
+		select {
+		case data := <-dataChan:
+			var aliResponse AliChatResponse
+			err := json.Unmarshal([]byte(data), &aliResponse)
+			if err != nil {
+				common.SysError("error unmarshalling stream response: " + err.Error())
+				return true
+			}
+			usage.PromptTokens += aliResponse.Usage.InputTokens
+			usage.CompletionTokens += aliResponse.Usage.OutputTokens
+			usage.TotalTokens += aliResponse.Usage.InputTokens + aliResponse.Usage.OutputTokens
+			response := streamResponseAli2OpenAI(&aliResponse)
+			response.Choices[0].Delta.Content = strings.TrimPrefix(response.Choices[0].Delta.Content, lastResponseText)
+			lastResponseText = aliResponse.Output.Text
+			jsonResponse, err := json.Marshal(response)
+			if err != nil {
+				common.SysError("error marshalling stream response: " + err.Error())
+				return true
+			}
+			c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
+			return true
+		case <-stopChan:
+			c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
+			return false
+		}
+	})
+	err := resp.Body.Close()
+	if err != nil {
+		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
+	}
+	return nil, &usage
+}
+
+func aliHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
+	var aliResponse AliChatResponse
+	responseBody, err := io.ReadAll(resp.Body)
+	if err != nil {
+		return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
+	}
+	err = resp.Body.Close()
+	if err != nil {
+		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
+	}
+	err = json.Unmarshal(responseBody, &aliResponse)
+	if err != nil {
+		return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
+	}
+	if aliResponse.Code != "" {
+		return &OpenAIErrorWithStatusCode{
+			OpenAIError: OpenAIError{
+				Message: aliResponse.Message,
+				Type:    aliResponse.Code,
+				Param:   aliResponse.RequestId,
+				Code:    aliResponse.Code,
+			},
+			StatusCode: resp.StatusCode,
+		}, nil
+	}
+	fullTextResponse := responseAli2OpenAI(&aliResponse)
+	jsonResponse, err := json.Marshal(fullTextResponse)
+	if err != nil {
+		return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
+	}
+	c.Writer.Header().Set("Content-Type", "application/json")
+	c.Writer.WriteHeader(resp.StatusCode)
+	_, err = c.Writer.Write(jsonResponse)
+	return nil, &fullTextResponse.Usage
+}
diff --git a/controller/relay-text.go b/controller/relay-text.go
index 52e10f2b..e58c810b 100644
--- a/controller/relay-text.go
+++
b/controller/relay-text.go @@ -20,6 +20,7 @@ const ( APITypePaLM APITypeBaidu APITypeZhipu + APITypeAli ) var httpClient *http.Client @@ -94,6 +95,9 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { apiType = APITypePaLM case common.ChannelTypeZhipu: apiType = APITypeZhipu + case common.ChannelTypeAli: + apiType = APITypeAli + } baseURL := common.ChannelBaseURLs[channelType] requestURL := c.Request.URL.String() @@ -153,6 +157,8 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { method = "sse-invoke" } fullRequestURL = fmt.Sprintf("https://open.bigmodel.cn/api/paas/v3/model-api/%s/%s", textRequest.Model, method) + case APITypeAli: + fullRequestURL = "https://dashscope.aliyuncs.com/api/v1/services/aigc/text-generation/generation" } var promptTokens int var completionTokens int @@ -226,6 +232,13 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError) } requestBody = bytes.NewBuffer(jsonStr) + case APITypeAli: + aliRequest := requestOpenAI2Ali(textRequest) + jsonStr, err := json.Marshal(aliRequest) + if err != nil { + return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError) + } + requestBody = bytes.NewBuffer(jsonStr) } req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody) if err != nil { @@ -250,6 +263,11 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { case APITypeZhipu: token := getZhipuToken(apiKey) req.Header.Set("Authorization", token) + case APITypeAli: + req.Header.Set("Authorization", "Bearer "+apiKey) + if textRequest.Stream { + req.Header.Set("X-DashScope-SSE", "enable") + } } req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type")) req.Header.Set("Accept", c.Request.Header.Get("Accept")) @@ -280,7 +298,7 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { if strings.HasPrefix(textRequest.Model, "gpt-4") { completionRatio = 2 } - if isStream && apiType != APITypeBaidu && apiType != APITypeZhipu { + if isStream && apiType != APITypeBaidu && apiType != APITypeZhipu && apiType != APITypeAli { completionTokens = countTokenText(streamResponseText, textRequest.Model) } else { promptTokens = textResponse.Usage.PromptTokens @@ -415,6 +433,26 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { } return nil } + case APITypeAli: + if isStream { + err, usage := aliStreamHandler(c, resp) + if err != nil { + return err + } + if usage != nil { + textResponse.Usage = *usage + } + return nil + } else { + err, usage := aliHandler(c, resp) + if err != nil { + return err + } + if usage != nil { + textResponse.Usage = *usage + } + return nil + } default: return errorWrapper(errors.New("unknown api type"), "unknown_api_type", http.StatusInternalServerError) } diff --git a/web/src/constants/channel.constants.js b/web/src/constants/channel.constants.js index 54d7716f..16df9894 100644 --- a/web/src/constants/channel.constants.js +++ b/web/src/constants/channel.constants.js @@ -4,6 +4,7 @@ export const CHANNEL_OPTIONS = [ { key: 3, text: 'Azure OpenAI', value: 3, color: 'olive' }, { key: 11, text: 'Google PaLM2', value: 11, color: 'orange' }, { key: 15, text: '百度文心千帆', value: 15, color: 'blue' }, + { key: 17, text: '阿里通义千问', value: 17, color: 'orange' }, { key: 16, text: '智谱 ChatGLM', value: 16, color: 'violet' }, { key: 8, text: '自定义渠道', value: 8, color: 'pink' }, { 
key: 2, text: '代理:API2D', value: 2, color: 'blue' },
@@ -14,5 +15,5 @@ export const CHANNEL_OPTIONS = [
   { key: 6, text: '代理:OpenAI Max', value: 6, color: 'violet' },
   { key: 9, text: '代理:AI.LS', value: 9, color: 'yellow' },
   { key: 12, text: '代理:API2GPT', value: 12, color: 'blue' },
-  { key: 13, text: '代理:AIGC2D', value: 13, color: 'purple' }
+  { key: 13, text: '代理:AIGC2D', value: 13, color: 'purple' },
 ];
\ No newline at end of file

From d1335ebc01b6080bcef2d2fbd5bfaacc38dff5c3 Mon Sep 17 00:00:00 2001
From: JustSong
Date: Fri, 28 Jul 2023 23:47:36 +0800
Subject: [PATCH 02/14] docs: update README

---
 README.en.md | 16 +++++++---------
 README.md    |  1 -
 2 files changed, 7 insertions(+), 10 deletions(-)

diff --git a/README.en.md b/README.en.md
index f635a798..1c5968bc 100644
--- a/README.en.md
+++ b/README.en.md
@@ -57,15 +57,13 @@ _✨ Access all LLM through the standard OpenAI API format, easy to deploy & use
 > **Note**: The latest image pulled from Docker may be an `alpha` release. Specify the version manually if you require stability.
 
 ## Features
-1. Supports multiple API access channels:
-   + [x] Official OpenAI channel (support proxy configuration)
-   + [x] **Azure OpenAI API**
-   + [x] [API Distribute](https://api.gptjk.top/register?aff=QGxj)
-   + [x] [OpenAI-SB](https://openai-sb.com)
-   + [x] [API2D](https://api2d.com/r/197971)
-   + [x] [OhMyGPT](https://aigptx.top?aff=uFpUl2Kf)
-   + [x] [AI Proxy](https://aiproxy.io/?i=OneAPI) (invitation code: `OneAPI`)
-   + [x] Custom channel: Various third-party proxy services not included in the list
+1. Support for multiple large models:
+   + [x] [OpenAI ChatGPT Series Models](https://platform.openai.com/docs/guides/gpt/chat-completions-api) (Supports [Azure OpenAI API](https://learn.microsoft.com/en-us/azure/ai-services/openai/reference))
+   + [x] [Anthropic Claude Series Models](https://anthropic.com)
+   + [x] [Google PaLM2 Series Models](https://developers.generativeai.google)
+   + [x] [Baidu Wenxin Yiyan Series Models](https://cloud.baidu.com/doc/WENXINWORKSHOP/index.html)
+   + [x] [Alibaba Tongyi Qianwen Series Models](https://help.aliyun.com/document_detail/2400395.html)
+   + [x] [Zhipu ChatGLM Series Models](https://bigmodel.cn)
 2. Supports access to multiple channels through **load balancing**.
 3. Supports **stream mode** that enables typewriter-like effect through stream transmission.
 4. Supports **multi-machine deployment**. [See here](#multi-machine-deployment) for more details.
diff --git a/README.md b/README.md
index e5579c3d..070f90bf 100644
--- a/README.md
+++ b/README.md
@@ -66,7 +66,6 @@ _✨ 通过标准的 OpenAI API 格式访问所有的大模型,开箱即用
    + [x] [阿里通义千问系列模型](https://help.aliyun.com/document_detail/2400395.html)
    + [x] [智谱 ChatGLM 系列模型](https://bigmodel.cn)
 2.
支持配置镜像以及众多第三方代理服务:
-   + [x] [API Distribute](https://api.gptjk.top/register?aff=QGxj)
    + [x] [OpenAI-SB](https://openai-sb.com)
    + [x] [API2D](https://api2d.com/r/197971)
    + [x] [OhMyGPT](https://aigptx.top?aff=uFpUl2Kf)

From 130e6bfd83eef5290410fbbd8df823e07c65c777 Mon Sep 17 00:00:00 2001
From: JustSong
Date: Sat, 29 Jul 2023 12:15:07 +0800
Subject: [PATCH 03/14] feat: support baidu's embedding model (close #324)

---
 common/model-ratio.go     |  1 +
 controller/model.go       |  9 +++++
 controller/relay-baidu.go | 85 +++++++++++++++++++++++++++++++++++++++
 controller/relay-text.go  | 25 ++++++++++--
 controller/relay.go       | 13 ++++++
 5 files changed, 129 insertions(+), 4 deletions(-)

diff --git a/common/model-ratio.go b/common/model-ratio.go
index ba6d7245..123451f7 100644
--- a/common/model-ratio.go
+++ b/common/model-ratio.go
@@ -42,6 +42,7 @@ var ModelRatio = map[string]float64{
 	"claude-2":        30,
 	"ERNIE-Bot":       0.8572, // ¥0.012 / 1k tokens
 	"ERNIE-Bot-turbo": 0.5715, // ¥0.008 / 1k tokens
+	"Embedding-V1":    0.1429, // ¥0.002 / 1k tokens
 	"PaLM-2":          1,
 	"chatglm_pro":     0.7143, // ¥0.01 / 1k tokens
 	"chatglm_std":     0.3572, // ¥0.005 / 1k tokens
diff --git a/controller/model.go b/controller/model.go
index f8096f75..123b0a2f 100644
--- a/controller/model.go
+++ b/controller/model.go
@@ -288,6 +288,15 @@ func init() {
 			Root:       "ERNIE-Bot-turbo",
 			Parent:     nil,
 		},
+		{
+			Id:         "Embedding-V1",
+			Object:     "model",
+			Created:    1677649963,
+			OwnedBy:    "baidu",
+			Permission: permission,
+			Root:       "Embedding-V1",
+			Parent:     nil,
+		},
 		{
 			Id:         "PaLM-2",
 			Object:     "model",
diff --git a/controller/relay-baidu.go b/controller/relay-baidu.go
index 4267757d..7960e8ee 100644
--- a/controller/relay-baidu.go
+++ b/controller/relay-baidu.go
@@ -54,6 +54,25 @@ type BaiduChatStreamResponse struct {
 	IsEnd bool `json:"is_end"`
 }
 
+type BaiduEmbeddingRequest struct {
+	Input []string `json:"input"`
+}
+
+type BaiduEmbeddingData struct {
+	Object    string    `json:"object"`
+	Embedding []float64 `json:"embedding"`
+	Index     int       `json:"index"`
+}
+
+type BaiduEmbeddingResponse struct {
+	Id      string               `json:"id"`
+	Object  string               `json:"object"`
+	Created int64                `json:"created"`
+	Data    []BaiduEmbeddingData `json:"data"`
+	Usage   Usage                `json:"usage"`
+	BaiduError
+}
+
 func requestOpenAI2Baidu(request GeneralOpenAIRequest) *BaiduChatRequest {
 	messages := make([]BaiduMessage, 0, len(request.Messages))
 	for _, message := range request.Messages {
@@ -112,6 +131,36 @@ func streamResponseBaidu2OpenAI(baiduResponse *BaiduChatStreamResponse) *ChatCom
 	return &response
 }
 
+func embeddingRequestOpenAI2Baidu(request GeneralOpenAIRequest) *BaiduEmbeddingRequest {
+	baiduEmbeddingRequest := BaiduEmbeddingRequest{
+		Input: nil,
+	}
+	switch input := request.Input.(type) {
+	case string:
+		baiduEmbeddingRequest.Input = []string{input}
+	case []string:
+		baiduEmbeddingRequest.Input = input
+	case []interface{}:
+		// JSON arrays unmarshal into []interface{}, so the []string case above
+		// never matches a decoded request body on its own
+		for _, item := range input {
+			if str, ok := item.(string); ok {
+				baiduEmbeddingRequest.Input = append(baiduEmbeddingRequest.Input, str)
+			}
+		}
+	}
+	return &baiduEmbeddingRequest
+}
+
+func embeddingResponseBaidu2OpenAI(response *BaiduEmbeddingResponse) *OpenAIEmbeddingResponse {
+	openAIEmbeddingResponse := OpenAIEmbeddingResponse{
+		Object: "list",
+		Data:   make([]OpenAIEmbeddingResponseItem, 0, len(response.Data)),
+		Model:  "baidu-embedding",
+		Usage:  response.Usage,
+	}
+	for _, item := range response.Data {
+		openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, OpenAIEmbeddingResponseItem{
+			Object:    item.Object,
+			Index:     item.Index,
+			Embedding: item.Embedding,
+		})
+	}
+	return &openAIEmbeddingResponse
+}
+
 func baiduStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
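+	// baiduStreamHandler converts Baidu's is_end-terminated SSE stream into
+	// OpenAI-style chat.completion.chunk events and accumulates token usage
+	// for billing.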
var usage Usage
 	scanner := bufio.NewScanner(resp.Body)
@@ -212,3 +261,39 @@ func baiduHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCo
 	_, err = c.Writer.Write(jsonResponse)
 	return nil, &fullTextResponse.Usage
 }
+
+func baiduEmbeddingHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
+	var baiduResponse BaiduEmbeddingResponse
+	responseBody, err := io.ReadAll(resp.Body)
+	if err != nil {
+		return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
+	}
+	err = resp.Body.Close()
+	if err != nil {
+		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
+	}
+	err = json.Unmarshal(responseBody, &baiduResponse)
+	if err != nil {
+		return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
+	}
+	if baiduResponse.ErrorMsg != "" {
+		return &OpenAIErrorWithStatusCode{
+			OpenAIError: OpenAIError{
+				Message: baiduResponse.ErrorMsg,
+				Type:    "baidu_error",
+				Param:   "",
+				Code:    baiduResponse.ErrorCode,
+			},
+			StatusCode: resp.StatusCode,
+		}, nil
+	}
+	fullTextResponse := embeddingResponseBaidu2OpenAI(&baiduResponse)
+	jsonResponse, err := json.Marshal(fullTextResponse)
+	if err != nil {
+		return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
+	}
+	c.Writer.Header().Set("Content-Type", "application/json")
+	c.Writer.WriteHeader(resp.StatusCode)
+	_, err = c.Writer.Write(jsonResponse)
+	return nil, &fullTextResponse.Usage
+}
diff --git a/controller/relay-text.go b/controller/relay-text.go
index e58c810b..7d3fe1de 100644
--- a/controller/relay-text.go
+++ b/controller/relay-text.go
@@ -139,6 +139,8 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 			fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant"
 		case "BLOOMZ-7B":
 			fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/bloomz_7b1"
+		case "Embedding-V1":
+			fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings/embedding-v1"
 		}
 		apiKey := c.Request.Header.Get("Authorization")
 		apiKey = strings.TrimPrefix(apiKey, "Bearer ")
@@ -212,12 +214,20 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 		}
 		requestBody = bytes.NewBuffer(jsonStr)
 	case APITypeBaidu:
-		baiduRequest := requestOpenAI2Baidu(textRequest)
-		jsonStr, err := json.Marshal(baiduRequest)
+		var jsonData []byte
+		var err error
+		switch relayMode {
+		case RelayModeEmbeddings:
+			baiduEmbeddingRequest := embeddingRequestOpenAI2Baidu(textRequest)
+			jsonData, err = json.Marshal(baiduEmbeddingRequest)
+		default:
+			baiduRequest := requestOpenAI2Baidu(textRequest)
+			jsonData, err = json.Marshal(baiduRequest)
+		}
 		if err != nil {
 			return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
 		}
-		requestBody = bytes.NewBuffer(jsonStr)
+		requestBody = bytes.NewBuffer(jsonData)
 	case APITypePaLM:
 		palmRequest := requestOpenAI2PaLM(textRequest)
 		jsonStr, err := json.Marshal(palmRequest)
@@ -386,7 +396,14 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 			}
 			return nil
 		} else {
-			err, usage := baiduHandler(c, resp)
+			var err *OpenAIErrorWithStatusCode
+			var usage *Usage
+			switch relayMode {
+			case RelayModeEmbeddings:
+				err, usage = baiduEmbeddingHandler(c, resp)
+			default:
+				err, usage = baiduHandler(c, resp)
+			}
 			if err != nil {
 				return err
 			}
diff --git a/controller/relay.go b/controller/relay.go
index 9cfa5c4f..609ae2eb 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -99,6 +99,19 @@ type OpenAITextResponse struct { Usage `json:"usage"` } +type OpenAIEmbeddingResponseItem struct { + Object string `json:"object"` + Index int `json:"index"` + Embedding []float64 `json:"embedding"` +} + +type OpenAIEmbeddingResponse struct { + Object string `json:"object"` + Data []OpenAIEmbeddingResponseItem `json:"data"` + Model string `json:"model"` + Usage `json:"usage"` +} + type ImageResponse struct { Created int `json:"created"` Data []struct { From f31d400b6fee13880a353ed3b1b1e8aa8b6cc124 Mon Sep 17 00:00:00 2001 From: JustSong Date: Sat, 29 Jul 2023 12:24:23 +0800 Subject: [PATCH 04/14] chore: automatically add related models when switch type --- web/src/pages/Channel/EditChannel.js | 33 +++++++++++++++++++++++----- 1 file changed, 27 insertions(+), 6 deletions(-) diff --git a/web/src/pages/Channel/EditChannel.js b/web/src/pages/Channel/EditChannel.js index 7833c7f3..4d810014 100644 --- a/web/src/pages/Channel/EditChannel.js +++ b/web/src/pages/Channel/EditChannel.js @@ -35,6 +35,27 @@ const EditChannel = () => { const [customModel, setCustomModel] = useState(''); const handleInputChange = (e, { name, value }) => { setInputs((inputs) => ({ ...inputs, [name]: value })); + if (name === 'type' && inputs.models.length === 0) { + let localModels = []; + switch (value) { + case 14: + localModels = ['claude-instant-1', 'claude-2']; + break; + case 11: + localModels = ['PaLM-2']; + break; + case 15: + localModels = ['ERNIE-Bot', 'ERNIE-Bot-turbo', 'Embedding-V1']; + break; + case 17: + localModels = ['qwen-v1', 'qwen-plus-v1']; + break; + case 16: + localModels = ['chatglm_pro', 'chatglm_std', 'chatglm_lite']; + break; + } + setInputs((inputs) => ({ ...inputs, models: localModels })); + } }; const loadChannel = async () => { @@ -270,8 +291,8 @@ const EditChannel = () => { }}>清除所有模型 { - if (customModel.trim() === "") return; + + From 50dec03ff39f9a51fea4737484968172af7c23e3 Mon Sep 17 00:00:00 2001 From: JustSong Date: Sat, 29 Jul 2023 19:16:42 +0800 Subject: [PATCH 05/14] fix: fix model mapping cannot be deleted --- web/src/pages/Channel/EditChannel.js | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/web/src/pages/Channel/EditChannel.js b/web/src/pages/Channel/EditChannel.js index 4d810014..bb0567de 100644 --- a/web/src/pages/Channel/EditChannel.js +++ b/web/src/pages/Channel/EditChannel.js @@ -153,7 +153,10 @@ const EditChannel = () => { localInputs.base_url = localInputs.base_url.slice(0, localInputs.base_url.length - 1); } if (localInputs.type === 3 && localInputs.other === '') { - localInputs.other = '2023-03-15-preview'; + localInputs.other = '2023-06-01-preview'; + } + if (localInputs.model_mapping === '') { + localInputs.model_mapping = '{}'; } let res; localInputs.models = localInputs.models.join(','); @@ -213,7 +216,7 @@ const EditChannel = () => { { Date: Sat, 29 Jul 2023 19:17:26 +0800 Subject: [PATCH 06/14] fix: fix model mapping cannot be deleted --- controller/relay-text.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/controller/relay-text.go b/controller/relay-text.go index 7d3fe1de..79dca606 100644 --- a/controller/relay-text.go +++ b/controller/relay-text.go @@ -74,7 +74,7 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { // map model name modelMapping := c.GetString("model_mapping") isModelMapped := false - if modelMapping != "" { + if modelMapping != "" && modelMapping != "{}" { modelMap 
:= make(map[string]string) err := json.Unmarshal([]byte(modelMapping), &modelMap) if err != nil { From b8cb86c2c16946bdffd8a07913af0060fa854b08 Mon Sep 17 00:00:00 2001 From: JustSong Date: Sat, 29 Jul 2023 19:32:06 +0800 Subject: [PATCH 07/14] chore: adjust ui --- web/src/components/ChannelsTable.js | 29 +++++++++++++++------------- web/src/components/UsersTable.js | 2 +- web/src/constants/toast.constants.js | 2 +- web/src/pages/Token/EditToken.js | 2 +- 4 files changed, 19 insertions(+), 16 deletions(-) diff --git a/web/src/components/ChannelsTable.js b/web/src/components/ChannelsTable.js index 4ea6965d..0459619a 100644 --- a/web/src/components/ChannelsTable.js +++ b/web/src/components/ChannelsTable.js @@ -363,9 +363,12 @@ const ChannelsTable = () => { { + updateChannelBalance(channel.id, channel.name, idx); + }} style={{ cursor: 'pointer' }}> + {renderBalance(channel.type, channel.balance)} + } + content="点击更新" basic /> @@ -380,16 +383,16 @@ const ChannelsTable = () => { > 测试 - + {/* {*/} + {/* updateChannelBalance(channel.id, channel.name, idx);*/} + {/* }}*/} + {/*>*/} + {/* 更新余额*/} + {/**/} diff --git a/web/src/components/UsersTable.js b/web/src/components/UsersTable.js index 08ba961a..f8fb0a75 100644 --- a/web/src/components/UsersTable.js +++ b/web/src/components/UsersTable.js @@ -227,7 +227,7 @@ const UsersTable = () => { content={user.email ? user.email : '未绑定邮箱地址'} key={user.username} header={user.display_name ? user.display_name : user.username} - trigger={{renderText(user.username, 10)}} + trigger={{renderText(user.username, 15)}} hoverable /> diff --git a/web/src/constants/toast.constants.js b/web/src/constants/toast.constants.js index 8b212350..50684722 100644 --- a/web/src/constants/toast.constants.js +++ b/web/src/constants/toast.constants.js @@ -1,5 +1,5 @@ export const toastConstants = { - SUCCESS_TIMEOUT: 500, + SUCCESS_TIMEOUT: 1500, INFO_TIMEOUT: 3000, ERROR_TIMEOUT: 5000, WARNING_TIMEOUT: 10000, diff --git a/web/src/pages/Token/EditToken.js b/web/src/pages/Token/EditToken.js index a4b6044f..1f85520b 100644 --- a/web/src/pages/Token/EditToken.js +++ b/web/src/pages/Token/EditToken.js @@ -83,7 +83,7 @@ const EditToken = () => { if (isEdit) { showSuccess('令牌更新成功!'); } else { - showSuccess('令牌创建成功!'); + showSuccess('令牌创建成功,请在列表页面点击复制获取令牌!'); setInputs(originInputs); } } else { From 3e81d8af45077751e93d127edef6dcc592b7fc5e Mon Sep 17 00:00:00 2001 From: JustSong Date: Sat, 29 Jul 2023 19:50:29 +0800 Subject: [PATCH 08/14] chore: update i18n --- i18n/en.json | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/i18n/en.json b/i18n/en.json index 3ef1b010..3c430a7e 100644 --- a/i18n/en.json +++ b/i18n/en.json @@ -503,5 +503,12 @@ "请输入 AZURE_OPENAI_ENDPOINT": "Please enter AZURE_OPENAI_ENDPOINT", "请输入自定义渠道的 Base URL": "Please enter the Base URL of the custom channel", "Homepage URL 填": "Fill in the Homepage URL", - "Authorization callback URL 填": "Fill in the Authorization callback URL" + "Authorization callback URL 填": "Fill in the Authorization callback URL", + "请为通道命名": "Please name the channel", + "此项可选,用于修改请求体中的模型名称,为一个 JSON 字符串,键为请求中模型名称,值为要替换的模型名称,例如:": "This is optional, used to modify the model name in the request body, it's a JSON string, the key is the model name in the request, and the value is the model name to be replaced, for example:", + "模型重定向": "Model redirection", + "请输入渠道对应的鉴权密钥": "Please enter the authentication key corresponding to the channel", + "注意,": "Note that, ", + ",图片演示。": "related image demo.", + "令牌创建成功,请在列表页面点击复制获取令牌!": 
"Token created successfully, please click copy on the list page to get the token!" } From 8a866078b2d74af49f42a12565356e86c879aa8e Mon Sep 17 00:00:00 2001 From: JustSong Date: Sat, 29 Jul 2023 21:55:57 +0800 Subject: [PATCH 09/14] feat: support xunfei's llm (close #206) --- README.md | 1 + common/constants.go | 2 + common/model-ratio.go | 1 + controller/model.go | 9 + controller/relay-text.go | 113 ++++++---- controller/relay-xunfei.go | 274 +++++++++++++++++++++++++ go.mod | 1 + go.sum | 2 + web/src/constants/channel.constants.js | 3 +- 9 files changed, 363 insertions(+), 43 deletions(-) create mode 100644 controller/relay-xunfei.go diff --git a/README.md b/README.md index 070f90bf..e01ea7d9 100644 --- a/README.md +++ b/README.md @@ -64,6 +64,7 @@ _✨ 通过标准的 OpenAI API 格式访问所有的大模型,开箱即用 + [x] [Google PaLM2 系列模型](https://developers.generativeai.google) + [x] [百度文心一言系列模型](https://cloud.baidu.com/doc/WENXINWORKSHOP/index.html) + [x] [阿里通义千问系列模型](https://help.aliyun.com/document_detail/2400395.html) + + [x] [讯飞星火认知大模型](https://www.xfyun.cn/doc/spark/Web.html) + [x] [智谱 ChatGLM 系列模型](https://bigmodel.cn) 2. 支持配置镜像以及众多第三方代理服务: + [x] [OpenAI-SB](https://openai-sb.com) diff --git a/common/constants.go b/common/constants.go index 5dbfa71c..c4bb6671 100644 --- a/common/constants.go +++ b/common/constants.go @@ -157,6 +157,7 @@ const ( ChannelTypeBaidu = 15 ChannelTypeZhipu = 16 ChannelTypeAli = 17 + ChannelTypeXunfei = 18 ) var ChannelBaseURLs = []string{ @@ -178,4 +179,5 @@ var ChannelBaseURLs = []string{ "https://aip.baidubce.com", // 15 "https://open.bigmodel.cn", // 16 "https://dashscope.aliyuncs.com", // 17 + "", // 18 } diff --git a/common/model-ratio.go b/common/model-ratio.go index 123451f7..5865b4dc 100644 --- a/common/model-ratio.go +++ b/common/model-ratio.go @@ -49,6 +49,7 @@ var ModelRatio = map[string]float64{ "chatglm_lite": 0.1429, // ¥0.002 / 1k tokens "qwen-v1": 0.8572, // TBD: https://help.aliyun.com/document_detail/2399482.html?spm=a2c4g.2399482.0.0.1ad347feilAgag "qwen-plus-v1": 0.5715, // Same as above + "SparkDesk": 0.8572, // TBD } func ModelRatio2JSONString() string { diff --git a/controller/model.go b/controller/model.go index 123b0a2f..c68aa50c 100644 --- a/controller/model.go +++ b/controller/model.go @@ -351,6 +351,15 @@ func init() { Root: "qwen-plus-v1", Parent: nil, }, + { + Id: "SparkDesk", + Object: "model", + Created: 1677649963, + OwnedBy: "xunfei", + Permission: permission, + Root: "SparkDesk", + Parent: nil, + }, } openAIModelsMap = make(map[string]OpenAIModels) for _, model := range openAIModels { diff --git a/controller/relay-text.go b/controller/relay-text.go index 79dca606..48e7176a 100644 --- a/controller/relay-text.go +++ b/controller/relay-text.go @@ -21,6 +21,7 @@ const ( APITypeBaidu APITypeZhipu APITypeAli + APITypeXunfei ) var httpClient *http.Client @@ -97,7 +98,8 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { apiType = APITypeZhipu case common.ChannelTypeAli: apiType = APITypeAli - + case common.ChannelTypeXunfei: + apiType = APITypeXunfei } baseURL := common.ChannelBaseURLs[channelType] requestURL := c.Request.URL.String() @@ -250,52 +252,60 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { } requestBody = bytes.NewBuffer(jsonStr) } - req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody) - if err != nil { - return errorWrapper(err, "new_request_failed", http.StatusInternalServerError) - } - apiKey := c.Request.Header.Get("Authorization") - apiKey = 
strings.TrimPrefix(apiKey, "Bearer ")
-	switch apiType {
-	case APITypeOpenAI:
-		if channelType == common.ChannelTypeAzure {
-			req.Header.Set("api-key", apiKey)
-		} else {
-			req.Header.Set("Authorization", c.Request.Header.Get("Authorization"))
+
+	var req *http.Request
+	var resp *http.Response
+	isStream := textRequest.Stream
+
+	if apiType != APITypeXunfei { // Xunfei relays over WebSocket, so skip building an HTTP request for it
+		req, err = http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
+		if err != nil {
+			return errorWrapper(err, "new_request_failed", http.StatusInternalServerError)
 		}
-	case APITypeClaude:
-		req.Header.Set("x-api-key", apiKey)
-		anthropicVersion := c.Request.Header.Get("anthropic-version")
-		if anthropicVersion == "" {
-			anthropicVersion = "2023-06-01"
+		apiKey := c.Request.Header.Get("Authorization")
+		apiKey = strings.TrimPrefix(apiKey, "Bearer ")
+		switch apiType {
+		case APITypeOpenAI:
+			if channelType == common.ChannelTypeAzure {
+				req.Header.Set("api-key", apiKey)
+			} else {
+				req.Header.Set("Authorization", c.Request.Header.Get("Authorization"))
+			}
+		case APITypeClaude:
+			req.Header.Set("x-api-key", apiKey)
+			anthropicVersion := c.Request.Header.Get("anthropic-version")
+			if anthropicVersion == "" {
+				anthropicVersion = "2023-06-01"
+			}
+			req.Header.Set("anthropic-version", anthropicVersion)
+		case APITypeZhipu:
+			token := getZhipuToken(apiKey)
+			req.Header.Set("Authorization", token)
+		case APITypeAli:
+			req.Header.Set("Authorization", "Bearer "+apiKey)
+			if textRequest.Stream {
+				req.Header.Set("X-DashScope-SSE", "enable")
+			}
 		}
-		req.Header.Set("anthropic-version", anthropicVersion)
-	case APITypeZhipu:
-		token := getZhipuToken(apiKey)
-		req.Header.Set("Authorization", token)
-	case APITypeAli:
-		req.Header.Set("Authorization", "Bearer "+apiKey)
-		if textRequest.Stream {
-			req.Header.Set("X-DashScope-SSE", "enable")
+		req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
+		req.Header.Set("Accept", c.Request.Header.Get("Accept"))
+		//req.Header.Set("Connection", c.Request.Header.Get("Connection"))
+		resp, err = httpClient.Do(req)
+		if err != nil {
+			return errorWrapper(err, "do_request_failed", http.StatusInternalServerError)
 		}
+		err = req.Body.Close()
+		if err != nil {
+			return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError)
+		}
+		err = c.Request.Body.Close()
+		if err != nil {
+			return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError)
+		}
+		isStream = strings.HasPrefix(resp.Header.Get("Content-Type"), "text/event-stream")
 	}
-	req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
-	req.Header.Set("Accept", c.Request.Header.Get("Accept"))
-	//req.Header.Set("Connection", c.Request.Header.Get("Connection"))
-	resp, err := httpClient.Do(req)
-	if err != nil {
-		return errorWrapper(err, "do_request_failed", http.StatusInternalServerError)
-	}
-	err = req.Body.Close()
-	if err != nil {
-		return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError)
-	}
-	err = c.Request.Body.Close()
-	if err != nil {
-		return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError)
-	}
+
+	var textResponse TextResponse
-	isStream := strings.HasPrefix(resp.Header.Get("Content-Type"), "text/event-stream")
 	var streamResponseText string
 
 	defer func() {
@@ -470,6 +480,25 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 			}
 			return nil
 		}
+	case APITypeXunfei:
+		if isStream {
+			auth := c.Request.Header.Get("Authorization")
+			auth =
strings.TrimPrefix(auth, "Bearer ") + splits := strings.Split(auth, "|") + if len(splits) != 3 { + return errorWrapper(errors.New("invalid auth"), "invalid_auth", http.StatusBadRequest) + } + err, usage := xunfeiStreamHandler(c, textRequest, splits[0], splits[1], splits[2]) + if err != nil { + return err + } + if usage != nil { + textResponse.Usage = *usage + } + return nil + } else { + return errorWrapper(errors.New("xunfei api does not support non-stream mode"), "invalid_api_type", http.StatusBadRequest) + } default: return errorWrapper(errors.New("unknown api type"), "unknown_api_type", http.StatusInternalServerError) } diff --git a/controller/relay-xunfei.go b/controller/relay-xunfei.go new file mode 100644 index 00000000..9343f216 --- /dev/null +++ b/controller/relay-xunfei.go @@ -0,0 +1,274 @@ +package controller + +import ( + "crypto/hmac" + "crypto/sha256" + "encoding/base64" + "encoding/json" + "fmt" + "github.com/gin-gonic/gin" + "github.com/gorilla/websocket" + "io" + "net/http" + "net/url" + "one-api/common" + "strings" + "time" +) + +// https://console.xfyun.cn/services/cbm +// https://www.xfyun.cn/doc/spark/Web.html + +type XunfeiMessage struct { + Role string `json:"role"` + Content string `json:"content"` +} + +type XunfeiChatRequest struct { + Header struct { + AppId string `json:"app_id"` + } `json:"header"` + Parameter struct { + Chat struct { + Domain string `json:"domain,omitempty"` + Temperature float64 `json:"temperature,omitempty"` + TopK int `json:"top_k,omitempty"` + MaxTokens int `json:"max_tokens,omitempty"` + Auditing bool `json:"auditing,omitempty"` + } `json:"chat"` + } `json:"parameter"` + Payload struct { + Message struct { + Text []XunfeiMessage `json:"text"` + } `json:"message"` + } `json:"payload"` +} + +type XunfeiChatResponseTextItem struct { + Content string `json:"content"` + Role string `json:"role"` + Index int `json:"index"` +} + +type XunfeiChatResponse struct { + Header struct { + Code int `json:"code"` + Message string `json:"message"` + Sid string `json:"sid"` + Status int `json:"status"` + } `json:"header"` + Payload struct { + Choices struct { + Status int `json:"status"` + Seq int `json:"seq"` + Text []XunfeiChatResponseTextItem `json:"text"` + } `json:"choices"` + } `json:"payload"` + Usage struct { + //Text struct { + // QuestionTokens string `json:"question_tokens"` + // PromptTokens string `json:"prompt_tokens"` + // CompletionTokens string `json:"completion_tokens"` + // TotalTokens string `json:"total_tokens"` + //} `json:"text"` + Text Usage `json:"text"` + } `json:"usage"` +} + +func requestOpenAI2Xunfei(request GeneralOpenAIRequest, xunfeiAppId string) *XunfeiChatRequest { + messages := make([]XunfeiMessage, 0, len(request.Messages)) + for _, message := range request.Messages { + if message.Role == "system" { + messages = append(messages, XunfeiMessage{ + Role: "user", + Content: message.Content, + }) + messages = append(messages, XunfeiMessage{ + Role: "assistant", + Content: "Okay", + }) + } else { + messages = append(messages, XunfeiMessage{ + Role: message.Role, + Content: message.Content, + }) + } + } + xunfeiRequest := XunfeiChatRequest{} + xunfeiRequest.Header.AppId = xunfeiAppId + xunfeiRequest.Parameter.Chat.Domain = "general" + xunfeiRequest.Parameter.Chat.Temperature = request.Temperature + xunfeiRequest.Parameter.Chat.TopK = request.N + xunfeiRequest.Parameter.Chat.MaxTokens = request.MaxTokens + xunfeiRequest.Payload.Message.Text = messages + return &xunfeiRequest +} + +func responseXunfei2OpenAI(response 
*XunfeiChatResponse) *OpenAITextResponse { + if len(response.Payload.Choices.Text) == 0 { + response.Payload.Choices.Text = []XunfeiChatResponseTextItem{ + { + Content: "", + }, + } + } + choice := OpenAITextResponseChoice{ + Index: 0, + Message: Message{ + Role: "assistant", + Content: response.Payload.Choices.Text[0].Content, + }, + } + fullTextResponse := OpenAITextResponse{ + Object: "chat.completion", + Created: common.GetTimestamp(), + Choices: []OpenAITextResponseChoice{choice}, + Usage: response.Usage.Text, + } + return &fullTextResponse +} + +func streamResponseXunfei2OpenAI(xunfeiResponse *XunfeiChatResponse) *ChatCompletionsStreamResponse { + if len(xunfeiResponse.Payload.Choices.Text) == 0 { + xunfeiResponse.Payload.Choices.Text = []XunfeiChatResponseTextItem{ + { + Content: "", + }, + } + } + var choice ChatCompletionsStreamResponseChoice + choice.Delta.Content = xunfeiResponse.Payload.Choices.Text[0].Content + response := ChatCompletionsStreamResponse{ + Object: "chat.completion.chunk", + Created: common.GetTimestamp(), + Model: "SparkDesk", + Choices: []ChatCompletionsStreamResponseChoice{choice}, + } + return &response +} + +func buildXunfeiAuthUrl(hostUrl string, apiKey, apiSecret string) string { + HmacWithShaToBase64 := func(algorithm, data, key string) string { + mac := hmac.New(sha256.New, []byte(key)) + mac.Write([]byte(data)) + encodeData := mac.Sum(nil) + return base64.StdEncoding.EncodeToString(encodeData) + } + ul, err := url.Parse(hostUrl) + if err != nil { + fmt.Println(err) + } + date := time.Now().UTC().Format(time.RFC1123) + signString := []string{"host: " + ul.Host, "date: " + date, "GET " + ul.Path + " HTTP/1.1"} + sign := strings.Join(signString, "\n") + sha := HmacWithShaToBase64("hmac-sha256", sign, apiSecret) + authUrl := fmt.Sprintf("hmac username=\"%s\", algorithm=\"%s\", headers=\"%s\", signature=\"%s\"", apiKey, + "hmac-sha256", "host date request-line", sha) + authorization := base64.StdEncoding.EncodeToString([]byte(authUrl)) + v := url.Values{} + v.Add("host", ul.Host) + v.Add("date", date) + v.Add("authorization", authorization) + callUrl := hostUrl + "?" 
+ v.Encode() + return callUrl +} + +func xunfeiStreamHandler(c *gin.Context, textRequest GeneralOpenAIRequest, appId string, apiKey string, apiSecret string) (*OpenAIErrorWithStatusCode, *Usage) { + var usage Usage + d := websocket.Dialer{ + HandshakeTimeout: 5 * time.Second, + } + hostUrl := "wss://aichat.xf-yun.com/v1/chat" + conn, resp, err := d.Dial(buildXunfeiAuthUrl(hostUrl, apiKey, apiSecret), nil) + if err != nil || resp.StatusCode != 101 { + return errorWrapper(err, "dial_failed", http.StatusInternalServerError), nil + } + data := requestOpenAI2Xunfei(textRequest, appId) + err = conn.WriteJSON(data) + if err != nil { + return errorWrapper(err, "write_json_failed", http.StatusInternalServerError), nil + } + dataChan := make(chan XunfeiChatResponse) + stopChan := make(chan bool) + go func() { + for { + _, msg, err := conn.ReadMessage() + if err != nil { + common.SysError("error reading stream response: " + err.Error()) + break + } + var response XunfeiChatResponse + err = json.Unmarshal(msg, &response) + if err != nil { + common.SysError("error unmarshalling stream response: " + err.Error()) + break + } + dataChan <- response + if response.Payload.Choices.Status == 2 { + break + } + } + stopChan <- true + }() + c.Writer.Header().Set("Content-Type", "text/event-stream") + c.Writer.Header().Set("Cache-Control", "no-cache") + c.Writer.Header().Set("Connection", "keep-alive") + c.Writer.Header().Set("Transfer-Encoding", "chunked") + c.Writer.Header().Set("X-Accel-Buffering", "no") + c.Stream(func(w io.Writer) bool { + select { + case xunfeiResponse := <-dataChan: + usage.PromptTokens += xunfeiResponse.Usage.Text.PromptTokens + usage.CompletionTokens += xunfeiResponse.Usage.Text.CompletionTokens + usage.TotalTokens += xunfeiResponse.Usage.Text.TotalTokens + response := streamResponseXunfei2OpenAI(&xunfeiResponse) + jsonResponse, err := json.Marshal(response) + if err != nil { + common.SysError("error marshalling stream response: " + err.Error()) + return true + } + c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)}) + return true + case <-stopChan: + c.Render(-1, common.CustomEvent{Data: "data: [DONE]"}) + return false + } + }) + return nil, &usage +} + +func xunfeiHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) { + var xunfeiResponse XunfeiChatResponse + responseBody, err := io.ReadAll(resp.Body) + if err != nil { + return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil + } + err = resp.Body.Close() + if err != nil { + return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil + } + err = json.Unmarshal(responseBody, &xunfeiResponse) + if err != nil { + return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil + } + if xunfeiResponse.Header.Code != 0 { + return &OpenAIErrorWithStatusCode{ + OpenAIError: OpenAIError{ + Message: xunfeiResponse.Header.Message, + Type: "xunfei_error", + Param: "", + Code: xunfeiResponse.Header.Code, + }, + StatusCode: resp.StatusCode, + }, nil + } + fullTextResponse := responseXunfei2OpenAI(&xunfeiResponse) + jsonResponse, err := json.Marshal(fullTextResponse) + if err != nil { + return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil + } + c.Writer.Header().Set("Content-Type", "application/json") + c.Writer.WriteHeader(resp.StatusCode) + _, err = c.Writer.Write(jsonResponse) + return nil, &fullTextResponse.Usage +} diff --git a/go.mod b/go.mod index 
2e0cf017..1d061520 100644 --- a/go.mod +++ b/go.mod @@ -13,6 +13,7 @@ require ( github.com/go-redis/redis/v8 v8.11.5 github.com/golang-jwt/jwt v3.2.2+incompatible github.com/google/uuid v1.3.0 + github.com/gorilla/websocket v1.5.0 github.com/pkoukk/tiktoken-go v0.1.1 golang.org/x/crypto v0.9.0 gorm.io/driver/mysql v1.4.3 diff --git a/go.sum b/go.sum index 7287206a..c6e4423c 100644 --- a/go.sum +++ b/go.sum @@ -67,6 +67,8 @@ github.com/gorilla/securecookie v1.1.1 h1:miw7JPhV+b/lAHSXz4qd/nN9jRiAFV5FwjeKyC github.com/gorilla/securecookie v1.1.1/go.mod h1:ra0sb63/xPlUeL+yeDciTfxMRAA+MP+HVt/4epWDjd4= github.com/gorilla/sessions v1.2.1 h1:DHd3rPN5lE3Ts3D8rKkQ8x/0kqfeNmBAaiSi+o7FsgI= github.com/gorilla/sessions v1.2.1/go.mod h1:dk2InVEVJ0sfLlnXv9EAgkf6ecYs/i80K/zI+bUmuGM= +github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc= +github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E= github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= github.com/jinzhu/now v1.1.4/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= diff --git a/web/src/constants/channel.constants.js b/web/src/constants/channel.constants.js index 16df9894..f51c6c44 100644 --- a/web/src/constants/channel.constants.js +++ b/web/src/constants/channel.constants.js @@ -5,6 +5,7 @@ export const CHANNEL_OPTIONS = [ { key: 11, text: 'Google PaLM2', value: 11, color: 'orange' }, { key: 15, text: '百度文心千帆', value: 15, color: 'blue' }, { key: 17, text: '阿里通义千问', value: 17, color: 'orange' }, + { key: 18, text: '讯飞星火认知大模型', value: 18, color: 'blue' }, { key: 16, text: '智谱 ChatGLM', value: 16, color: 'violet' }, { key: 8, text: '自定义渠道', value: 8, color: 'pink' }, { key: 2, text: '代理:API2D', value: 2, color: 'blue' }, @@ -15,5 +16,5 @@ export const CHANNEL_OPTIONS = [ { key: 6, text: '代理:OpenAI Max', value: 6, color: 'violet' }, { key: 9, text: '代理:AI.LS', value: 9, color: 'yellow' }, { key: 12, text: '代理:API2GPT', value: 12, color: 'blue' }, - { key: 13, text: '代理:AIGC2D', value: 13, color: 'purple' }, + { key: 13, text: '代理:AIGC2D', value: 13, color: 'purple' } ]; \ No newline at end of file From ce9c8024a6b18bb9b1950da20d1fc6b5f6c2b3a7 Mon Sep 17 00:00:00 2001 From: JustSong Date: Sat, 29 Jul 2023 22:05:15 +0800 Subject: [PATCH 10/14] chore: update prompt for xunfei --- controller/relay-xunfei.go | 2 +- web/src/pages/Channel/EditChannel.js | 5 ++++- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/controller/relay-xunfei.go b/controller/relay-xunfei.go index 9343f216..cd55df89 100644 --- a/controller/relay-xunfei.go +++ b/controller/relay-xunfei.go @@ -173,7 +173,7 @@ func buildXunfeiAuthUrl(hostUrl string, apiKey, apiSecret string) string { return callUrl } -func xunfeiStreamHandler(c *gin.Context, textRequest GeneralOpenAIRequest, appId string, apiKey string, apiSecret string) (*OpenAIErrorWithStatusCode, *Usage) { +func xunfeiStreamHandler(c *gin.Context, textRequest GeneralOpenAIRequest, appId string, apiSecret string, apiKey string) (*OpenAIErrorWithStatusCode, *Usage) { var usage Usage d := websocket.Dialer{ HandshakeTimeout: 5 * time.Second, diff --git a/web/src/pages/Channel/EditChannel.js b/web/src/pages/Channel/EditChannel.js index bb0567de..6974315e 100644 --- a/web/src/pages/Channel/EditChannel.js +++ b/web/src/pages/Channel/EditChannel.js @@ -53,6 +53,9 @@ const EditChannel = () => { case 16: localModels = ['chatglm_pro', 
'chatglm_std', 'chatglm_lite']; break; + case 18: + localModels = ['SparkDesk']; + break; } setInputs((inputs) => ({ ...inputs, models: localModels })); } @@ -347,7 +350,7 @@ const EditChannel = () => { label='密钥' name='key' required - placeholder={inputs.type === 15 ? '请输入 access token,当前版本暂不支持自动刷新,请每 30 天更新一次' : '请输入渠道对应的鉴权密钥'} + placeholder={inputs.type === 15 ? '请输入 access token,当前版本暂不支持自动刷新,请每 30 天更新一次' : (inputs.type === 18 ? '按照如下格式输入:APPID|APISecret|APIKey' : '请输入渠道对应的鉴权密钥')} onChange={handleInputChange} value={inputs.key} autoComplete='new-password' From b7d0616ae0a9b520700e9337c43b4de64600c9f7 Mon Sep 17 00:00:00 2001 From: JustSong Date: Sat, 29 Jul 2023 22:09:10 +0800 Subject: [PATCH 11/14] chore: update title for xunfei --- web/src/constants/channel.constants.js | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/web/src/constants/channel.constants.js b/web/src/constants/channel.constants.js index f51c6c44..a17ef374 100644 --- a/web/src/constants/channel.constants.js +++ b/web/src/constants/channel.constants.js @@ -5,7 +5,7 @@ export const CHANNEL_OPTIONS = [ { key: 11, text: 'Google PaLM2', value: 11, color: 'orange' }, { key: 15, text: '百度文心千帆', value: 15, color: 'blue' }, { key: 17, text: '阿里通义千问', value: 17, color: 'orange' }, - { key: 18, text: '讯飞星火认知大模型', value: 18, color: 'blue' }, + { key: 18, text: '讯飞星火认知', value: 18, color: 'blue' }, { key: 16, text: '智谱 ChatGLM', value: 16, color: 'violet' }, { key: 8, text: '自定义渠道', value: 8, color: 'pink' }, { key: 2, text: '代理:API2D', value: 2, color: 'blue' }, From fe8f216dd9da50050eb28fb469b7c823a9b7dec1 Mon Sep 17 00:00:00 2001 From: JustSong Date: Sat, 29 Jul 2023 22:32:05 +0800 Subject: [PATCH 12/14] refactor: update billing related code --- controller/relay-text.go | 28 ++++++++++++++-------------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/controller/relay-text.go b/controller/relay-text.go index 48e7176a..ceac4103 100644 --- a/controller/relay-text.go +++ b/controller/relay-text.go @@ -306,7 +306,6 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { } var textResponse TextResponse - var streamResponseText string defer func() { if consumeQuota { @@ -318,16 +317,10 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { if strings.HasPrefix(textRequest.Model, "gpt-4") { completionRatio = 2 } - if isStream && apiType != APITypeBaidu && apiType != APITypeZhipu && apiType != APITypeAli { - completionTokens = countTokenText(streamResponseText, textRequest.Model) - } else { - promptTokens = textResponse.Usage.PromptTokens - completionTokens = textResponse.Usage.CompletionTokens - if apiType == APITypeZhipu { - // zhipu's API does not return prompt tokens & completion tokens - promptTokens = textResponse.Usage.TotalTokens - } - } + + promptTokens = textResponse.Usage.PromptTokens + completionTokens = textResponse.Usage.CompletionTokens + quota = promptTokens + int(float64(completionTokens)*completionRatio) quota = int(float64(quota) * ratio) if ratio != 0 && quota <= 0 { @@ -365,7 +358,8 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { if err != nil { return err } - streamResponseText = responseText + textResponse.Usage.PromptTokens = promptTokens + textResponse.Usage.CompletionTokens = countTokenText(responseText, textRequest.Model) return nil } else { err, usage := openaiHandler(c, resp, consumeQuota) @@ -383,7 +377,8 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode 
{ if err != nil { return err } - streamResponseText = responseText + textResponse.Usage.PromptTokens = promptTokens + textResponse.Usage.CompletionTokens = countTokenText(responseText, textRequest.Model) return nil } else { err, usage := claudeHandler(c, resp, promptTokens, textRequest.Model) @@ -428,7 +423,8 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { if err != nil { return err } - streamResponseText = responseText + textResponse.Usage.PromptTokens = promptTokens + textResponse.Usage.CompletionTokens = countTokenText(responseText, textRequest.Model) return nil } else { err, usage := palmHandler(c, resp, promptTokens, textRequest.Model) @@ -449,6 +445,8 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { if usage != nil { textResponse.Usage = *usage } + // zhipu's API does not return prompt tokens & completion tokens + textResponse.Usage.PromptTokens = textResponse.Usage.TotalTokens return nil } else { err, usage := zhipuHandler(c, resp) @@ -458,6 +456,8 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { if usage != nil { textResponse.Usage = *usage } + // zhipu's API does not return prompt tokens & completion tokens + textResponse.Usage.PromptTokens = textResponse.Usage.TotalTokens return nil } case APITypeAli: From 065147b440b075321cfaecfca356dae5d36e592a Mon Sep 17 00:00:00 2001 From: JustSong Date: Sat, 29 Jul 2023 23:52:18 +0800 Subject: [PATCH 13/14] fix: close connection when response ended --- controller/relay-xunfei.go | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/controller/relay-xunfei.go b/controller/relay-xunfei.go index cd55df89..c6d78a84 100644 --- a/controller/relay-xunfei.go +++ b/controller/relay-xunfei.go @@ -205,6 +205,10 @@ func xunfeiStreamHandler(c *gin.Context, textRequest GeneralOpenAIRequest, appId } dataChan <- response if response.Payload.Choices.Status == 2 { + err := conn.Close() + if err != nil { + common.SysError("error closing websocket connection: " + err.Error()) + } break } } From ec88c0c24092b442ed9aef4d2c10c5d16eabb9ea Mon Sep 17 00:00:00 2001 From: JustSong Date: Sat, 29 Jul 2023 23:54:09 +0800 Subject: [PATCH 14/14] fix: prompt user that channel test is unavailable --- controller/channel-test.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/controller/channel-test.go b/controller/channel-test.go index be658fa8..8465d51d 100644 --- a/controller/channel-test.go +++ b/controller/channel-test.go @@ -23,6 +23,8 @@ func testChannel(channel *model.Channel, request ChatRequest) (error, *OpenAIErr case common.ChannelTypeBaidu: fallthrough case common.ChannelTypeZhipu: + fallthrough + case common.ChannelTypeXunfei: return errors.New("该渠道类型当前版本不支持测试,请手动测试"), nil case common.ChannelTypeAzure: request.Model = "gpt-35-turbo"
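
A usage note on the channels added in this series: clients keep calling the unified OpenAI-style `/v1/chat/completions` endpoint, and `relayTextHelper` rewrites the request for the DashScope, Baidu, Zhipu, or Xunfei upstream behind the selected channel. For a Xunfei channel the key is entered as `APPID|APISecret|APIKey` (see PATCH 10), and per PATCH 09 only streaming responses are supported for it. Below is a minimal sketch of a client call against a qwen channel; the base URL, port, and token are placeholder assumptions for a local deployment, not values fixed by these patches:

```go
package main

import (
	"bytes"
	"fmt"
	"io"
	"net/http"
)

func main() {
	// Assumed values for illustration only: a one-api instance listening on
	// localhost:3000 and a token bound to a channel that serves qwen-v1.
	endpoint := "http://localhost:3000/v1/chat/completions"
	token := "sk-your-one-api-token"

	payload := []byte(`{"model":"qwen-v1","messages":[{"role":"user","content":"Hello"}]}`)
	req, err := http.NewRequest(http.MethodPost, endpoint, bytes.NewReader(payload))
	if err != nil {
		panic(err)
	}
	req.Header.Set("Content-Type", "application/json")
	// relay-text.go strips the "Bearer " prefix before forwarding upstream.
	req.Header.Set("Authorization", "Bearer "+token)

	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()

	body, err := io.ReadAll(resp.Body)
	if err != nil {
		panic(err)
	}
	// Prints an OpenAI-format chat.completion object regardless of the
	// upstream vendor behind the channel.
	fmt.Println(resp.Status, string(body))
}
```

With `"stream": true` added to the payload, the same call returns the OpenAI-style `chat.completion.chunk` SSE events produced by the stream handlers above.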