From 4c384e61434559a2cd0790cf0dfca70b0aed62b8 Mon Sep 17 00:00:00 2001 From: igophper Date: Sat, 12 Aug 2023 21:07:31 +0800 Subject: [PATCH] feat:baidu channel support apiKey and secretKey MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 添加百度文心渠道时支持填写secretKey|apiKey或者accessToken,支持自动刷新accessToken --- controller/relay-baidu.go | 101 +++++++++++++++++++++++++++ controller/relay-text.go | 6 +- web/src/pages/Channel/EditChannel.js | 2 +- 3 files changed, 107 insertions(+), 2 deletions(-) diff --git a/controller/relay-baidu.go b/controller/relay-baidu.go index 664bbd11..88d51dc8 100644 --- a/controller/relay-baidu.go +++ b/controller/relay-baidu.go @@ -2,12 +2,17 @@ package controller import ( "bufio" + "crypto/md5" "encoding/json" + "errors" + "fmt" "github.com/gin-gonic/gin" "io" "net/http" "one-api/common" "strings" + "sync" + "time" ) // https://cloud.baidu.com/doc/WENXINWORKSHOP/s/flfmc9do2 @@ -73,6 +78,22 @@ type BaiduEmbeddingResponse struct { BaiduError } +type BaiduAccessToken struct { + RefreshToken string `json:"refresh_token"` + ExpiresIn int `json:"expires_in"` + SessionKey string `json:"session_key"` + AccessToken string `json:"access_token"` + Scope string `json:"scope"` + SessionSecret string `json:"session_secret"` + Error string `json:"error,omitempty"` + ErrorDescription string `json:"error_description,omitempty"` + ExpiresAt time.Time `json:"expires_at,omitempty"` + SecretKey string `json:"secret_key,omitempty"` + ApiKey string `json:"api_key,omitempty"` +} + +var baiduAccessTokens sync.Map + func requestOpenAI2Baidu(request GeneralOpenAIRequest) *BaiduChatRequest { messages := make([]BaiduMessage, 0, len(request.Messages)) for _, message := range request.Messages { @@ -299,3 +320,83 @@ func baiduEmbeddingHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWit _, err = c.Writer.Write(jsonResponse) return nil, &fullTextResponse.Usage } + +func getBaiduAccessToken(apiKey string) (string, error) { + var accessToken BaiduAccessToken + if val, ok := baiduAccessTokens.Load(md5.Sum([]byte(apiKey))); ok { + if accessToken, ok = val.(BaiduAccessToken); ok { + // 提前1小时刷新 + if time.Now().Add(time.Hour).After(accessToken.ExpiresAt) { + go refreshBaiduAccessToken(&accessToken) + return accessToken.AccessToken, nil + } + return accessToken.AccessToken, nil + } + } + + splits := strings.Split(apiKey, "|") + if len(splits) == 1 { + accessToken.AccessToken = apiKey + accessToken.ExpiresAt = time.Now().Add(30 * 24 * time.Hour) + return apiKey, nil + } + + var token string + var err error + if token, err = initBaiduAccessToken(splits[0], splits[1], ""); err != nil { + return "", err + } + + return token, nil +} + +func refreshBaiduAccessToken(accessToken *BaiduAccessToken) error { + if accessToken.RefreshToken == "" { + return nil + } + _, err := initBaiduAccessToken(accessToken.SecretKey, accessToken.ApiKey, accessToken.RefreshToken) + return err +} + +func initBaiduAccessToken(secretKey, apiKey, refreshToken string) (string, error) { + url := "https://aip.baidubce.com/oauth/2.0/token?grant_type=client_credentials&client_id=" + apiKey + "&client_secret=" + secretKey + if refreshToken != "" { + url += "&refresh_token=" + refreshToken + } + req, err := http.NewRequest("POST", url, nil) + if err != nil { + return "", errors.New(fmt.Sprintf("initBaiduAccessToken err: %s", err.Error())) + } + + req.Header.Add("Content-Type", "application/json") + req.Header.Add("Accept", "application/json") + + client := http.Client{ + Timeout: 5 * time.Second, + } + + res, err := client.Do(req) + if err != nil { + return "", errors.New(fmt.Sprintf("initBaiduAccessToken request err: %s", err.Error())) + } + defer res.Body.Close() + + var accessToken BaiduAccessToken + err = json.NewDecoder(res.Body).Decode(&accessToken) + if err != nil { + return "", errors.New(fmt.Sprintf("initBaiduAccessToken decode access token err: %s", err.Error())) + } + if accessToken.Error != "" { + return "", errors.New(accessToken.Error + ": " + accessToken.ErrorDescription) + } + + if accessToken.AccessToken == "" { + return "", errors.New("initBaiduAccessToken get access token empty") + } + + accessToken.ExpiresAt = time.Now().Add(time.Duration(accessToken.ExpiresIn) * time.Second) + accessToken.SecretKey = secretKey + accessToken.ApiKey = apiKey + baiduAccessTokens.Store(md5.Sum([]byte(secretKey+"|"+apiKey)), accessToken) + return accessToken.AccessToken, nil +} diff --git a/controller/relay-text.go b/controller/relay-text.go index 65f03bcf..3960a7a0 100644 --- a/controller/relay-text.go +++ b/controller/relay-text.go @@ -146,7 +146,11 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { } apiKey := c.Request.Header.Get("Authorization") apiKey = strings.TrimPrefix(apiKey, "Bearer ") - fullRequestURL += "?access_token=" + apiKey // TODO: access token expire in 30 days + var err error + if apiKey, err = getBaiduAccessToken(apiKey); err != nil { + return errorWrapper(err, "invalid_auth", http.StatusBadRequest) + } + fullRequestURL += "?access_token=" + apiKey case APITypePaLM: fullRequestURL = "https://generativelanguage.googleapis.com/v1beta2/models/chat-bison-001:generateMessage" if baseURL != "" { diff --git a/web/src/pages/Channel/EditChannel.js b/web/src/pages/Channel/EditChannel.js index 0d7a4a01..1e922d1e 100644 --- a/web/src/pages/Channel/EditChannel.js +++ b/web/src/pages/Channel/EditChannel.js @@ -355,7 +355,7 @@ const EditChannel = () => { label='密钥' name='key' required - placeholder={inputs.type === 15 ? '请输入 access token,当前版本暂不支持自动刷新,请每 30 天更新一次' : (inputs.type === 18 ? '按照如下格式输入:APPID|APISecret|APIKey' : '请输入渠道对应的鉴权密钥')} + placeholder={inputs.type === 15 ? '按照如下格式输入:APISecret|APIKey 或者直接输入access token' : (inputs.type === 18 ? '按照如下格式输入:APPID|APISecret|APIKey' : '请输入渠道对应的鉴权密钥')} onChange={handleInputChange} value={inputs.key} autoComplete='new-password'