From b4b4acc288d7c1d9b942a85abe0f2b505ecb62a7 Mon Sep 17 00:00:00 2001 From: JustSong Date: Tue, 3 Oct 2023 14:19:03 +0800 Subject: [PATCH] feat: support Tencent's model (close #519) --- README.md | 1 + common/constants.go | 48 +++-- common/model-ratio.go | 1 + controller/model.go | 9 + controller/option.go | 2 +- controller/relay-tencent.go | 287 +++++++++++++++++++++++++ controller/relay-text.go | 43 ++++ middleware/distributor.go | 4 +- web/src/constants/channel.constants.js | 1 + web/src/pages/Channel/EditChannel.js | 5 + 10 files changed, 375 insertions(+), 26 deletions(-) create mode 100644 controller/relay-tencent.go diff --git a/README.md b/README.md index 98e38b8e..cb641947 100644 --- a/README.md +++ b/README.md @@ -72,6 +72,7 @@ _✨ 通过标准的 OpenAI API 格式访问所有的大模型,开箱即用 + [x] [讯飞星火认知大模型](https://www.xfyun.cn/doc/spark/Web.html) + [x] [智谱 ChatGLM 系列模型](https://bigmodel.cn) + [x] [360 智脑](https://ai.360.cn) + + [x] [腾讯混元大模型](https://cloud.tencent.com/document/product/1729) 2. 支持配置镜像以及众多第三方代理服务: + [x] [OpenAI-SB](https://openai-sb.com) + [x] [CloseAI](https://console.closeai-asia.com/r/2412) diff --git a/common/constants.go b/common/constants.go index d07dcc8a..a0361c35 100644 --- a/common/constants.go +++ b/common/constants.go @@ -186,30 +186,32 @@ const ( ChannelTypeOpenRouter = 20 ChannelTypeAIProxyLibrary = 21 ChannelTypeFastGPT = 22 + ChannelTypeTencent = 23 ) var ChannelBaseURLs = []string{ - "", // 0 - "https://api.openai.com", // 1 - "https://oa.api2d.net", // 2 - "", // 3 - "https://api.closeai-proxy.xyz", // 4 - "https://api.openai-sb.com", // 5 - "https://api.openaimax.com", // 6 - "https://api.ohmygpt.com", // 7 - "", // 8 - "https://api.caipacity.com", // 9 - "https://api.aiproxy.io", // 10 - "", // 11 - "https://api.api2gpt.com", // 12 - "https://api.aigc2d.com", // 13 - "https://api.anthropic.com", // 14 - "https://aip.baidubce.com", // 15 - "https://open.bigmodel.cn", // 16 - "https://dashscope.aliyuncs.com", // 17 - "", // 18 - "https://ai.360.cn", // 19 - "https://openrouter.ai/api", // 20 - "https://api.aiproxy.io", // 21 - "https://fastgpt.run/api/openapi", // 22 + "", // 0 + "https://api.openai.com", // 1 + "https://oa.api2d.net", // 2 + "", // 3 + "https://api.closeai-proxy.xyz", // 4 + "https://api.openai-sb.com", // 5 + "https://api.openaimax.com", // 6 + "https://api.ohmygpt.com", // 7 + "", // 8 + "https://api.caipacity.com", // 9 + "https://api.aiproxy.io", // 10 + "", // 11 + "https://api.api2gpt.com", // 12 + "https://api.aigc2d.com", // 13 + "https://api.anthropic.com", // 14 + "https://aip.baidubce.com", // 15 + "https://open.bigmodel.cn", // 16 + "https://dashscope.aliyuncs.com", // 17 + "", // 18 + "https://ai.360.cn", // 19 + "https://openrouter.ai/api", // 20 + "https://api.aiproxy.io", // 21 + "https://fastgpt.run/api/openapi", // 22 + "https://hunyuan.cloud.tencent.com", //23 } diff --git a/common/model-ratio.go b/common/model-ratio.go index 8156afc3..f1ce99a5 100644 --- a/common/model-ratio.go +++ b/common/model-ratio.go @@ -59,6 +59,7 @@ var ModelRatio = map[string]float64{ "embedding-bert-512-v1": 0.0715, // ¥0.001 / 1k tokens "embedding_s1_v1": 0.0715, // ¥0.001 / 1k tokens "semantic_similarity_s1_v1": 0.0715, // ¥0.001 / 1k tokens + "hunyuan": 7.143, // ¥0.1 / 1k tokens // https://cloud.tencent.com/document/product/1729/97731#e0e6be58-60c8-469f-bdeb-6c264ce3b4d0 } func ModelRatio2JSONString() string { diff --git a/controller/model.go b/controller/model.go index 593bdaba..e9b64514 100644 --- a/controller/model.go +++ b/controller/model.go @@ -423,6 +423,15 @@ func init() { Root: "semantic_similarity_s1_v1", Parent: nil, }, + { + Id: "hunyuan", + Object: "model", + Created: 1677649963, + OwnedBy: "tencent", + Permission: permission, + Root: "hunyuan", + Parent: nil, + }, } openAIModelsMap = make(map[string]OpenAIModels) for _, model := range openAIModels { diff --git a/controller/option.go b/controller/option.go index 9cf4ff1b..bbf83578 100644 --- a/controller/option.go +++ b/controller/option.go @@ -46,7 +46,7 @@ func UpdateOption(c *gin.Context) { if option.Value == "true" && common.GitHubClientId == "" { c.JSON(http.StatusOK, gin.H{ "success": false, - "message": "无法启用 GitHub OAuth,请先填入 GitHub Client ID 以及 GitHub Client Secret!", + "message": "无法启用 GitHub OAuth,请先填入 GitHub Client Id 以及 GitHub Client Secret!", }) return } diff --git a/controller/relay-tencent.go b/controller/relay-tencent.go new file mode 100644 index 00000000..024468bc --- /dev/null +++ b/controller/relay-tencent.go @@ -0,0 +1,287 @@ +package controller + +import ( + "bufio" + "crypto/hmac" + "crypto/sha1" + "encoding/base64" + "encoding/json" + "errors" + "fmt" + "github.com/gin-gonic/gin" + "io" + "net/http" + "one-api/common" + "sort" + "strconv" + "strings" +) + +// https://cloud.tencent.com/document/product/1729/97732 + +type TencentMessage struct { + Role string `json:"role"` + Content string `json:"content"` +} + +type TencentChatRequest struct { + AppId int64 `json:"app_id"` // 腾讯云账号的 APPID + SecretId string `json:"secret_id"` // 官网 SecretId + // Timestamp当前 UNIX 时间戳,单位为秒,可记录发起 API 请求的时间。 + // 例如1529223702,如果与当前时间相差过大,会引起签名过期错误 + Timestamp int64 `json:"timestamp"` + // Expired 签名的有效期,是一个符合 UNIX Epoch 时间戳规范的数值, + // 单位为秒;Expired 必须大于 Timestamp 且 Expired-Timestamp 小于90天 + Expired int64 `json:"expired"` + QueryID string `json:"query_id"` //请求 Id,用于问题排查 + // Temperature 较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定 + // 默认 1.0,取值区间为[0.0,2.0],非必要不建议使用,不合理的取值会影响效果 + // 建议该参数和 top_p 只设置1个,不要同时更改 top_p + Temperature float64 `json:"temperature"` + // TopP 影响输出文本的多样性,取值越大,生成文本的多样性越强 + // 默认1.0,取值区间为[0.0, 1.0],非必要不建议使用, 不合理的取值会影响效果 + // 建议该参数和 temperature 只设置1个,不要同时更改 + TopP float64 `json:"top_p"` + // Stream 0:同步,1:流式 (默认,协议:SSE) + // 同步请求超时:60s,如果内容较长建议使用流式 + Stream int `json:"stream"` + // Messages 会话内容, 长度最多为40, 按对话时间从旧到新在数组中排列 + // 输入 content 总数最大支持 3000 token。 + Messages []TencentMessage `json:"messages"` +} + +type TencentError struct { + Code int `json:"code"` + Message string `json:"message"` +} + +type TencentUsage struct { + InputTokens int `json:"input_tokens"` + OutputTokens int `json:"output_tokens"` + TotalTokens int `json:"total_tokens"` +} + +type TencentResponseChoices struct { + FinishReason string `json:"finish_reason,omitempty"` // 流式结束标志位,为 stop 则表示尾包 + Messages TencentMessage `json:"messages,omitempty"` // 内容,同步模式返回内容,流模式为 null 输出 content 内容总数最多支持 1024token。 + Delta TencentMessage `json:"delta,omitempty"` // 内容,流模式返回内容,同步模式为 null 输出 content 内容总数最多支持 1024token。 +} + +type TencentChatResponse struct { + Choices []TencentResponseChoices `json:"choices,omitempty"` // 结果 + Created string `json:"created,omitempty"` // unix 时间戳的字符串 + Id string `json:"id,omitempty"` // 会话 id + Usage Usage `json:"usage,omitempty"` // token 数量 + Error TencentError `json:"error,omitempty"` // 错误信息 注意:此字段可能返回 null,表示取不到有效值 + Note string `json:"note,omitempty"` // 注释 + ReqID string `json:"req_id,omitempty"` // 唯一请求 Id,每次请求都会返回。用于反馈接口入参 +} + +func requestOpenAI2Tencent(request GeneralOpenAIRequest) *TencentChatRequest { + messages := make([]TencentMessage, 0, len(request.Messages)) + for i := 0; i < len(request.Messages); i++ { + message := request.Messages[i] + if message.Role == "system" { + messages = append(messages, TencentMessage{ + Role: "user", + Content: message.Content, + }) + messages = append(messages, TencentMessage{ + Role: "assistant", + Content: "Okay", + }) + continue + } + messages = append(messages, TencentMessage{ + Content: message.Content, + Role: message.Role, + }) + } + stream := 0 + if request.Stream { + stream = 1 + } + return &TencentChatRequest{ + Timestamp: common.GetTimestamp(), + Expired: common.GetTimestamp() + 24*60*60, + QueryID: common.GetUUID(), + Temperature: request.Temperature, + TopP: request.TopP, + Stream: stream, + Messages: messages, + } +} + +func responseTencent2OpenAI(response *TencentChatResponse) *OpenAITextResponse { + fullTextResponse := OpenAITextResponse{ + Object: "chat.completion", + Created: common.GetTimestamp(), + Usage: response.Usage, + } + if len(response.Choices) > 0 { + choice := OpenAITextResponseChoice{ + Index: 0, + Message: Message{ + Role: "assistant", + Content: response.Choices[0].Messages.Content, + }, + FinishReason: response.Choices[0].FinishReason, + } + fullTextResponse.Choices = append(fullTextResponse.Choices, choice) + } + return &fullTextResponse +} + +func streamResponseTencent2OpenAI(TencentResponse *TencentChatResponse) *ChatCompletionsStreamResponse { + response := ChatCompletionsStreamResponse{ + Object: "chat.completion.chunk", + Created: common.GetTimestamp(), + Model: "tencent-hunyuan", + } + if len(TencentResponse.Choices) > 0 { + var choice ChatCompletionsStreamResponseChoice + choice.Delta.Content = TencentResponse.Choices[0].Delta.Content + if TencentResponse.Choices[0].FinishReason == "stop" { + choice.FinishReason = &stopFinishReason + } + response.Choices = append(response.Choices, choice) + } + return &response +} + +func tencentStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, string) { + var responseText string + scanner := bufio.NewScanner(resp.Body) + scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { + if atEOF && len(data) == 0 { + return 0, nil, nil + } + if i := strings.Index(string(data), "\n"); i >= 0 { + return i + 1, data[0:i], nil + } + if atEOF { + return len(data), data, nil + } + return 0, nil, nil + }) + dataChan := make(chan string) + stopChan := make(chan bool) + go func() { + for scanner.Scan() { + data := scanner.Text() + if len(data) < 5 { // ignore blank line or wrong format + continue + } + if data[:5] != "data:" { + continue + } + data = data[5:] + dataChan <- data + } + stopChan <- true + }() + setEventStreamHeaders(c) + c.Stream(func(w io.Writer) bool { + select { + case data := <-dataChan: + var TencentResponse TencentChatResponse + err := json.Unmarshal([]byte(data), &TencentResponse) + if err != nil { + common.SysError("error unmarshalling stream response: " + err.Error()) + return true + } + response := streamResponseTencent2OpenAI(&TencentResponse) + if len(response.Choices) != 0 { + responseText += response.Choices[0].Delta.Content + } + jsonResponse, err := json.Marshal(response) + if err != nil { + common.SysError("error marshalling stream response: " + err.Error()) + return true + } + c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)}) + return true + case <-stopChan: + c.Render(-1, common.CustomEvent{Data: "data: [DONE]"}) + return false + } + }) + err := resp.Body.Close() + if err != nil { + return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), "" + } + return nil, responseText +} + +func tencentHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) { + var TencentResponse TencentChatResponse + responseBody, err := io.ReadAll(resp.Body) + if err != nil { + return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil + } + err = resp.Body.Close() + if err != nil { + return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil + } + err = json.Unmarshal(responseBody, &TencentResponse) + if err != nil { + return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil + } + if TencentResponse.Error.Code != 0 { + return &OpenAIErrorWithStatusCode{ + OpenAIError: OpenAIError{ + Message: TencentResponse.Error.Message, + Code: TencentResponse.Error.Code, + }, + StatusCode: resp.StatusCode, + }, nil + } + fullTextResponse := responseTencent2OpenAI(&TencentResponse) + jsonResponse, err := json.Marshal(fullTextResponse) + if err != nil { + return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil + } + c.Writer.Header().Set("Content-Type", "application/json") + c.Writer.WriteHeader(resp.StatusCode) + _, err = c.Writer.Write(jsonResponse) + return nil, &fullTextResponse.Usage +} + +func parseTencentConfig(config string) (appId int64, secretId string, secretKey string, err error) { + parts := strings.Split(config, "|") + if len(parts) != 3 { + err = errors.New("invalid tencent config") + return + } + appId, err = strconv.ParseInt(parts[0], 10, 64) + secretId = parts[1] + secretKey = parts[2] + return +} + +func getTencentSign(req TencentChatRequest, secretKey string) string { + params := make([]string, 0) + params = append(params, "app_id="+strconv.FormatInt(req.AppId, 10)) + params = append(params, "secret_id="+req.SecretId) + params = append(params, "timestamp="+strconv.FormatInt(req.Timestamp, 10)) + params = append(params, "query_id="+req.QueryID) + params = append(params, "temperature="+strconv.FormatFloat(req.Temperature, 'f', -1, 64)) + params = append(params, "top_p="+strconv.FormatFloat(req.TopP, 'f', -1, 64)) + params = append(params, "stream="+strconv.Itoa(req.Stream)) + params = append(params, "expired="+strconv.FormatInt(req.Expired, 10)) + + var messageStr string + for _, msg := range req.Messages { + messageStr += fmt.Sprintf(`{"role":"%s","content":"%s"},`, msg.Role, msg.Content) + } + messageStr = strings.TrimSuffix(messageStr, ",") + params = append(params, "messages=["+messageStr+"]") + + sort.Sort(sort.StringSlice(params)) + url := "hunyuan.cloud.tencent.com/hyllm/v1/chat/completions?" + strings.Join(params, "&") + mac := hmac.New(sha1.New, []byte(secretKey)) + signURL := url + mac.Write([]byte(signURL)) + sign := mac.Sum([]byte(nil)) + return base64.StdEncoding.EncodeToString(sign) +} diff --git a/controller/relay-text.go b/controller/relay-text.go index 3041e3a9..fba49a7f 100644 --- a/controller/relay-text.go +++ b/controller/relay-text.go @@ -24,6 +24,7 @@ const ( APITypeAli APITypeXunfei APITypeAIProxyLibrary + APITypeTencent ) var httpClient *http.Client @@ -109,6 +110,8 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { apiType = APITypeXunfei case common.ChannelTypeAIProxyLibrary: apiType = APITypeAIProxyLibrary + case common.ChannelTypeTencent: + apiType = APITypeTencent } baseURL := common.ChannelBaseURLs[channelType] requestURL := c.Request.URL.String() @@ -179,6 +182,8 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { if relayMode == RelayModeEmbeddings { fullRequestURL = "https://dashscope.aliyuncs.com/api/v1/services/embeddings/text-embedding/text-embedding" } + case APITypeTencent: + fullRequestURL = "https://hunyuan.cloud.tencent.com/hyllm/v1/chat/completions" case APITypeAIProxyLibrary: fullRequestURL = fmt.Sprintf("%s/api/library/ask", baseURL) } @@ -285,6 +290,23 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError) } requestBody = bytes.NewBuffer(jsonStr) + case APITypeTencent: + apiKey := c.Request.Header.Get("Authorization") + apiKey = strings.TrimPrefix(apiKey, "Bearer ") + appId, secretId, secretKey, err := parseTencentConfig(apiKey) + if err != nil { + return errorWrapper(err, "invalid_tencent_config", http.StatusInternalServerError) + } + tencentRequest := requestOpenAI2Tencent(textRequest) + tencentRequest.AppId = appId + tencentRequest.SecretId = secretId + jsonStr, err := json.Marshal(tencentRequest) + if err != nil { + return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError) + } + sign := getTencentSign(*tencentRequest, secretKey) + c.Request.Header.Set("Authorization", sign) + requestBody = bytes.NewBuffer(jsonStr) case APITypeAIProxyLibrary: aiProxyLibraryRequest := requestOpenAI2AIProxyLibrary(textRequest) aiProxyLibraryRequest.LibraryId = c.GetString("library_id") @@ -332,6 +354,8 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { if textRequest.Stream { req.Header.Set("X-DashScope-SSE", "enable") } + case APITypeTencent: + req.Header.Set("Authorization", apiKey) default: req.Header.Set("Authorization", "Bearer "+apiKey) } @@ -584,6 +608,25 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { } return nil } + case APITypeTencent: + if isStream { + err, responseText := tencentStreamHandler(c, resp) + if err != nil { + return err + } + textResponse.Usage.PromptTokens = promptTokens + textResponse.Usage.CompletionTokens = countTokenText(responseText, textRequest.Model) + return nil + } else { + err, usage := tencentHandler(c, resp) + if err != nil { + return err + } + if usage != nil { + textResponse.Usage = *usage + } + return nil + } default: return errorWrapper(errors.New("unknown api type"), "unknown_api_type", http.StatusInternalServerError) } diff --git a/middleware/distributor.go b/middleware/distributor.go index 9ded3231..d80945fc 100644 --- a/middleware/distributor.go +++ b/middleware/distributor.go @@ -25,12 +25,12 @@ func Distribute() func(c *gin.Context) { if ok { id, err := strconv.Atoi(channelId.(string)) if err != nil { - abortWithMessage(c, http.StatusBadRequest, "无效的渠道 ID") + abortWithMessage(c, http.StatusBadRequest, "无效的渠道 Id") return } channel, err = model.GetChannelById(id, true) if err != nil { - abortWithMessage(c, http.StatusBadRequest, "无效的渠道 ID") + abortWithMessage(c, http.StatusBadRequest, "无效的渠道 Id") return } if channel.Status != common.ChannelStatusEnabled { diff --git a/web/src/constants/channel.constants.js b/web/src/constants/channel.constants.js index e42afc6e..76407745 100644 --- a/web/src/constants/channel.constants.js +++ b/web/src/constants/channel.constants.js @@ -8,6 +8,7 @@ export const CHANNEL_OPTIONS = [ { key: 18, text: '讯飞星火认知', value: 18, color: 'blue' }, { key: 16, text: '智谱 ChatGLM', value: 16, color: 'violet' }, { key: 19, text: '360 智脑', value: 19, color: 'blue' }, + { key: 23, text: '腾讯混元', value: 23, color: 'teal' }, { key: 8, text: '自定义渠道', value: 8, color: 'pink' }, { key: 22, text: '知识库:FastGPT', value: 22, color: 'blue' }, { key: 21, text: '知识库:AI Proxy', value: 21, color: 'purple' }, diff --git a/web/src/pages/Channel/EditChannel.js b/web/src/pages/Channel/EditChannel.js index e11912bd..82077c1f 100644 --- a/web/src/pages/Channel/EditChannel.js +++ b/web/src/pages/Channel/EditChannel.js @@ -19,6 +19,8 @@ function type2secretPrompt(type) { return '按照如下格式输入:APPID|APISecret|APIKey'; case 22: return '按照如下格式输入:APIKey-AppId,例如:fastgpt-0sp2gtvfdgyi4k30jwlgwf1i-64f335d84283f05518e9e041'; + case 23: + return '按照如下格式输入:AppId|SecretId|SecretKey'; default: return '请输入渠道对应的鉴权密钥'; } @@ -78,6 +80,9 @@ const EditChannel = () => { case 19: localModels = ['360GPT_S2_V9', 'embedding-bert-512-v1', 'embedding_s1_v1', 'semantic_similarity_s1_v1']; break; + case 23: + localModels = ['hunyuan']; + break; } setInputs((inputs) => ({ ...inputs, models: localModels })); }