From 143b61cbf4870a17bdd98976d7c8ffd8ab3accc2 Mon Sep 17 00:00:00 2001 From: JustSong Date: Sat, 12 Aug 2023 23:40:02 +0800 Subject: [PATCH] chore: update implementation --- controller/relay-baidu.go | 97 +++++++++------------------- controller/relay-text.go | 10 ++- i18n/en.json | 3 +- web/src/pages/Channel/EditChannel.js | 2 +- 4 files changed, 41 insertions(+), 71 deletions(-) diff --git a/controller/relay-baidu.go b/controller/relay-baidu.go index d0e90510..c9c6d45e 100644 --- a/controller/relay-baidu.go +++ b/controller/relay-baidu.go @@ -2,7 +2,6 @@ package controller import ( "bufio" - "crypto/md5" "encoding/json" "errors" "fmt" @@ -18,12 +17,8 @@ import ( // https://cloud.baidu.com/doc/WENXINWORKSHOP/s/flfmc9do2 type BaiduTokenResponse struct { - RefreshToken string `json:"refresh_token"` - ExpiresIn int `json:"expires_in"` - SessionKey string `json:"session_key"` - AccessToken string `json:"access_token"` - Scope string `json:"scope"` - SessionSecret string `json:"session_secret"` + ExpiresIn int `json:"expires_in"` + AccessToken string `json:"access_token"` } type BaiduMessage struct { @@ -79,20 +74,14 @@ type BaiduEmbeddingResponse struct { } type BaiduAccessToken struct { - RefreshToken string `json:"refresh_token"` - ExpiresIn int `json:"expires_in"` - SessionKey string `json:"session_key"` AccessToken string `json:"access_token"` - Scope string `json:"scope"` - SessionSecret string `json:"session_secret"` Error string `json:"error,omitempty"` ErrorDescription string `json:"error_description,omitempty"` - ExpiresAt time.Time `json:"expires_at,omitempty"` - SecretKey string `json:"secret_key,omitempty"` - ApiKey string `json:"api_key,omitempty"` + ExpiresIn int64 `json:"expires_in,omitempty"` + ExpiresAt time.Time `json:"-"` } -var baiduAccessTokens sync.Map +var baiduTokenStore sync.Map func requestOpenAI2Baidu(request GeneralOpenAIRequest) *BaiduChatRequest { messages := make([]BaiduMessage, 0, len(request.Messages)) @@ -322,82 +311,58 @@ func baiduEmbeddingHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWit } func getBaiduAccessToken(apiKey string) (string, error) { - var accessToken BaiduAccessToken - md5Key := md5.Sum([]byte(apiKey)) - if val, ok := baiduAccessTokens.Load(md5Key); ok { + if val, ok := baiduTokenStore.Load(apiKey); ok { + var accessToken BaiduAccessToken if accessToken, ok = val.(BaiduAccessToken); ok { - // 提前1小时刷新 + // soon this will expire if time.Now().Add(time.Hour).After(accessToken.ExpiresAt) { - go refreshBaiduAccessToken(&accessToken) + go func() { + _, _ = getBaiduAccessTokenHelper(apiKey) + }() } return accessToken.AccessToken, nil } } - - splits := strings.Split(apiKey, "|") - if len(splits) == 1 { - accessToken.AccessToken = apiKey - accessToken.ExpiresAt = time.Now().Add(30 * 24 * time.Hour) - baiduAccessTokens.Store(md5Key, accessToken) - return apiKey, nil - } - - var token string - var err error - if token, err = initBaiduAccessToken(splits[0], splits[1], ""); err != nil { + accessToken, err := getBaiduAccessTokenHelper(apiKey) + if err != nil { return "", err } - - return token, nil + if accessToken == nil { + return "", errors.New("getBaiduAccessToken return a nil token") + } + return (*accessToken).AccessToken, nil } -func refreshBaiduAccessToken(accessToken *BaiduAccessToken) error { - if accessToken.RefreshToken == "" { - return nil +func getBaiduAccessTokenHelper(apiKey string) (*BaiduAccessToken, error) { + parts := strings.Split(apiKey, "|") + if len(parts) != 2 { + return nil, errors.New("invalid baidu apikey") } - _, err := initBaiduAccessToken(accessToken.SecretKey, accessToken.ApiKey, accessToken.RefreshToken) - return err -} - -func initBaiduAccessToken(secretKey, apiKey, refreshToken string) (string, error) { - url := "https://aip.baidubce.com/oauth/2.0/token?grant_type=client_credentials&client_id=" + apiKey + "&client_secret=" + secretKey - if refreshToken != "" { - url += "&refresh_token=" + refreshToken - } - req, err := http.NewRequest("POST", url, nil) + req, err := http.NewRequest("POST", fmt.Sprintf("https://aip.baidubce.com/oauth/2.0/token?grant_type=client_credentials&client_id=%s&client_secret=%s", + parts[0], parts[1]), nil) if err != nil { - return "", errors.New(fmt.Sprintf("initBaiduAccessToken err: %s", err.Error())) + return nil, err } - req.Header.Add("Content-Type", "application/json") req.Header.Add("Accept", "application/json") - - client := http.Client{ - Timeout: 5 * time.Second, - } - - res, err := client.Do(req) + res, err := impatientHTTPClient.Do(req) if err != nil { - return "", errors.New(fmt.Sprintf("initBaiduAccessToken request err: %s", err.Error())) + return nil, err } defer res.Body.Close() var accessToken BaiduAccessToken err = json.NewDecoder(res.Body).Decode(&accessToken) if err != nil { - return "", errors.New(fmt.Sprintf("initBaiduAccessToken decode access token err: %s", err.Error())) + return nil, err } if accessToken.Error != "" { - return "", errors.New(accessToken.Error + ": " + accessToken.ErrorDescription) + return nil, errors.New(accessToken.Error + ": " + accessToken.ErrorDescription) } - if accessToken.AccessToken == "" { - return "", errors.New("initBaiduAccessToken get access token empty") + return nil, errors.New("getBaiduAccessTokenHelper get empty access token") } - accessToken.ExpiresAt = time.Now().Add(time.Duration(accessToken.ExpiresIn) * time.Second) - accessToken.SecretKey = secretKey - accessToken.ApiKey = apiKey - baiduAccessTokens.Store(md5.Sum([]byte(secretKey+"|"+apiKey)), accessToken) - return accessToken.AccessToken, nil + baiduTokenStore.Store(apiKey, accessToken) + return &accessToken, nil } diff --git a/controller/relay-text.go b/controller/relay-text.go index 3960a7a0..5c42bbcd 100644 --- a/controller/relay-text.go +++ b/controller/relay-text.go @@ -5,13 +5,13 @@ import ( "encoding/json" "errors" "fmt" + "github.com/gin-gonic/gin" "io" "net/http" "one-api/common" "one-api/model" "strings" - - "github.com/gin-gonic/gin" + "time" ) const ( @@ -25,9 +25,13 @@ const ( ) var httpClient *http.Client +var impatientHTTPClient *http.Client func init() { httpClient = &http.Client{} + impatientHTTPClient = &http.Client{ + Timeout: 5 * time.Second, + } } func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { @@ -148,7 +152,7 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { apiKey = strings.TrimPrefix(apiKey, "Bearer ") var err error if apiKey, err = getBaiduAccessToken(apiKey); err != nil { - return errorWrapper(err, "invalid_auth", http.StatusBadRequest) + return errorWrapper(err, "invalid_baidu_config", http.StatusInternalServerError) } fullRequestURL += "?access_token=" + apiKey case APITypePaLM: diff --git a/i18n/en.json b/i18n/en.json index 67ce8a56..a9402419 100644 --- a/i18n/en.json +++ b/i18n/en.json @@ -519,5 +519,6 @@ "令牌创建成功,请在列表页面点击复制获取令牌!": "Token created successfully, please click copy on the list page to get the token!", "代理": "Proxy", "此项可选,用于通过代理站来进行 API 调用,请输入代理站地址,格式为:https://domain.com": "This is optional, used to make API calls through the proxy site, please enter the proxy site address, the format is: https://domain.com", - "取消密码登录将导致所有未绑定其他登录方式的用户(包括管理员)无法通过密码登录,确认取消?": "Canceling password login will cause all users (including administrators) who have not bound other login methods to be unable to log in via password, confirm cancel?" + "取消密码登录将导致所有未绑定其他登录方式的用户(包括管理员)无法通过密码登录,确认取消?": "Canceling password login will cause all users (including administrators) who have not bound other login methods to be unable to log in via password, confirm cancel?", + "按照如下格式输入:": "Enter in the following format:" } diff --git a/web/src/pages/Channel/EditChannel.js b/web/src/pages/Channel/EditChannel.js index bb736153..b5fb524e 100644 --- a/web/src/pages/Channel/EditChannel.js +++ b/web/src/pages/Channel/EditChannel.js @@ -355,7 +355,7 @@ const EditChannel = () => { label='密钥' name='key' required - placeholder={inputs.type === 15 ? '按照如下格式输入:SecretKey|APIKey 或者直接输入access token' : (inputs.type === 18 ? '按照如下格式输入:APPID|APISecret|APIKey' : '请输入渠道对应的鉴权密钥')} + placeholder={inputs.type === 15 ? '按照如下格式输入:APIKey|SecretKey' : (inputs.type === 18 ? '按照如下格式输入:APPID|APISecret|APIKey' : '请输入渠道对应的鉴权密钥')} onChange={handleInputChange} value={inputs.key} autoComplete='new-password'