diff --git a/controller/relay-baidu.go b/controller/relay-baidu.go index 118e87a6..d66391bc 100644 --- a/controller/relay-baidu.go +++ b/controller/relay-baidu.go @@ -3,22 +3,22 @@ package controller import ( "bufio" "encoding/json" + "errors" + "fmt" "github.com/gin-gonic/gin" "io" "net/http" "one-api/common" "strings" + "sync" + "time" ) // https://cloud.baidu.com/doc/WENXINWORKSHOP/s/flfmc9do2 type BaiduTokenResponse struct { - RefreshToken string `json:"refresh_token"` - ExpiresIn int `json:"expires_in"` - SessionKey string `json:"session_key"` - AccessToken string `json:"access_token"` - Scope string `json:"scope"` - SessionSecret string `json:"session_secret"` + ExpiresIn int `json:"expires_in"` + AccessToken string `json:"access_token"` } type BaiduMessage struct { @@ -73,6 +73,16 @@ type BaiduEmbeddingResponse struct { BaiduError } +type BaiduAccessToken struct { + AccessToken string `json:"access_token"` + Error string `json:"error,omitempty"` + ErrorDescription string `json:"error_description,omitempty"` + ExpiresIn int64 `json:"expires_in,omitempty"` + ExpiresAt time.Time `json:"-"` +} + +var baiduTokenStore sync.Map + func requestOpenAI2Baidu(request GeneralOpenAIRequest) *BaiduChatRequest { messages := make([]BaiduMessage, 0, len(request.Messages)) for _, message := range request.Messages { @@ -295,3 +305,60 @@ func baiduEmbeddingHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWit _, err = c.Writer.Write(jsonResponse) return nil, &fullTextResponse.Usage } + +func getBaiduAccessToken(apiKey string) (string, error) { + if val, ok := baiduTokenStore.Load(apiKey); ok { + var accessToken BaiduAccessToken + if accessToken, ok = val.(BaiduAccessToken); ok { + // soon this will expire + if time.Now().Add(time.Hour).After(accessToken.ExpiresAt) { + go func() { + _, _ = getBaiduAccessTokenHelper(apiKey) + }() + } + return accessToken.AccessToken, nil + } + } + accessToken, err := getBaiduAccessTokenHelper(apiKey) + if err != nil { + return "", err + } + if accessToken == nil { + return "", errors.New("getBaiduAccessToken return a nil token") + } + return (*accessToken).AccessToken, nil +} + +func getBaiduAccessTokenHelper(apiKey string) (*BaiduAccessToken, error) { + parts := strings.Split(apiKey, "|") + if len(parts) != 2 { + return nil, errors.New("invalid baidu apikey") + } + req, err := http.NewRequest("POST", fmt.Sprintf("https://aip.baidubce.com/oauth/2.0/token?grant_type=client_credentials&client_id=%s&client_secret=%s", + parts[0], parts[1]), nil) + if err != nil { + return nil, err + } + req.Header.Add("Content-Type", "application/json") + req.Header.Add("Accept", "application/json") + res, err := impatientHTTPClient.Do(req) + if err != nil { + return nil, err + } + defer res.Body.Close() + + var accessToken BaiduAccessToken + err = json.NewDecoder(res.Body).Decode(&accessToken) + if err != nil { + return nil, err + } + if accessToken.Error != "" { + return nil, errors.New(accessToken.Error + ": " + accessToken.ErrorDescription) + } + if accessToken.AccessToken == "" { + return nil, errors.New("getBaiduAccessTokenHelper get empty access token") + } + accessToken.ExpiresAt = time.Now().Add(time.Duration(accessToken.ExpiresIn) * time.Second) + baiduTokenStore.Store(apiKey, accessToken) + return &accessToken, nil +} diff --git a/controller/relay-text.go b/controller/relay-text.go index 1bb463fa..e8dab514 100644 --- a/controller/relay-text.go +++ b/controller/relay-text.go @@ -11,6 +11,7 @@ import ( "one-api/common" "one-api/model" "strings" + "time" ) const ( @@ -24,9 +25,13 @@ const ( ) var httpClient *http.Client +var impatientHTTPClient *http.Client func init() { httpClient = &http.Client{} + impatientHTTPClient = &http.Client{ + Timeout: 5 * time.Second, + } } func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { @@ -145,7 +150,11 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { } apiKey := c.Request.Header.Get("Authorization") apiKey = strings.TrimPrefix(apiKey, "Bearer ") - fullRequestURL += "?access_token=" + apiKey // TODO: access token expire in 30 days + var err error + if apiKey, err = getBaiduAccessToken(apiKey); err != nil { + return errorWrapper(err, "invalid_baidu_config", http.StatusInternalServerError) + } + fullRequestURL += "?access_token=" + apiKey case APITypePaLM: fullRequestURL = "https://generativelanguage.googleapis.com/v1beta2/models/chat-bison-001:generateMessage" if baseURL != "" { diff --git a/i18n/en.json b/i18n/en.json index 67ce8a56..a9402419 100644 --- a/i18n/en.json +++ b/i18n/en.json @@ -519,5 +519,6 @@ "令牌创建成功,请在列表页面点击复制获取令牌!": "Token created successfully, please click copy on the list page to get the token!", "代理": "Proxy", "此项可选,用于通过代理站来进行 API 调用,请输入代理站地址,格式为:https://domain.com": "This is optional, used to make API calls through the proxy site, please enter the proxy site address, the format is: https://domain.com", - "取消密码登录将导致所有未绑定其他登录方式的用户(包括管理员)无法通过密码登录,确认取消?": "Canceling password login will cause all users (including administrators) who have not bound other login methods to be unable to log in via password, confirm cancel?" + "取消密码登录将导致所有未绑定其他登录方式的用户(包括管理员)无法通过密码登录,确认取消?": "Canceling password login will cause all users (including administrators) who have not bound other login methods to be unable to log in via password, confirm cancel?", + "按照如下格式输入:": "Enter in the following format:" } diff --git a/web/src/pages/Channel/EditChannel.js b/web/src/pages/Channel/EditChannel.js index 0d7a4a01..b5fb524e 100644 --- a/web/src/pages/Channel/EditChannel.js +++ b/web/src/pages/Channel/EditChannel.js @@ -355,7 +355,7 @@ const EditChannel = () => { label='密钥' name='key' required - placeholder={inputs.type === 15 ? '请输入 access token,当前版本暂不支持自动刷新,请每 30 天更新一次' : (inputs.type === 18 ? '按照如下格式输入:APPID|APISecret|APIKey' : '请输入渠道对应的鉴权密钥')} + placeholder={inputs.type === 15 ? '按照如下格式输入:APIKey|SecretKey' : (inputs.type === 18 ? '按照如下格式输入:APPID|APISecret|APIKey' : '请输入渠道对应的鉴权密钥')} onChange={handleInputChange} value={inputs.key} autoComplete='new-password'