feat:baidu channel support apiKey and secretKey

添加百度文心渠道时支持填写secretKey|apiKey或者accessToken,支持自动刷新accessToken
This commit is contained in:
igophper 2023-08-12 21:07:31 +08:00
parent 7e2bca7e9c
commit 4c384e6143
3 changed files with 107 additions and 2 deletions

View File

@ -2,12 +2,17 @@ package controller
import ( import (
"bufio" "bufio"
"crypto/md5"
"encoding/json" "encoding/json"
"errors"
"fmt"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"io" "io"
"net/http" "net/http"
"one-api/common" "one-api/common"
"strings" "strings"
"sync"
"time"
) )
// https://cloud.baidu.com/doc/WENXINWORKSHOP/s/flfmc9do2 // https://cloud.baidu.com/doc/WENXINWORKSHOP/s/flfmc9do2
@ -73,6 +78,22 @@ type BaiduEmbeddingResponse struct {
BaiduError BaiduError
} }
type BaiduAccessToken struct {
RefreshToken string `json:"refresh_token"`
ExpiresIn int `json:"expires_in"`
SessionKey string `json:"session_key"`
AccessToken string `json:"access_token"`
Scope string `json:"scope"`
SessionSecret string `json:"session_secret"`
Error string `json:"error,omitempty"`
ErrorDescription string `json:"error_description,omitempty"`
ExpiresAt time.Time `json:"expires_at,omitempty"`
SecretKey string `json:"secret_key,omitempty"`
ApiKey string `json:"api_key,omitempty"`
}
var baiduAccessTokens sync.Map
func requestOpenAI2Baidu(request GeneralOpenAIRequest) *BaiduChatRequest { func requestOpenAI2Baidu(request GeneralOpenAIRequest) *BaiduChatRequest {
messages := make([]BaiduMessage, 0, len(request.Messages)) messages := make([]BaiduMessage, 0, len(request.Messages))
for _, message := range request.Messages { for _, message := range request.Messages {
@ -299,3 +320,83 @@ func baiduEmbeddingHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWit
_, err = c.Writer.Write(jsonResponse) _, err = c.Writer.Write(jsonResponse)
return nil, &fullTextResponse.Usage return nil, &fullTextResponse.Usage
} }
func getBaiduAccessToken(apiKey string) (string, error) {
var accessToken BaiduAccessToken
if val, ok := baiduAccessTokens.Load(md5.Sum([]byte(apiKey))); ok {
if accessToken, ok = val.(BaiduAccessToken); ok {
// 提前1小时刷新
if time.Now().Add(time.Hour).After(accessToken.ExpiresAt) {
go refreshBaiduAccessToken(&accessToken)
return accessToken.AccessToken, nil
}
return accessToken.AccessToken, nil
}
}
splits := strings.Split(apiKey, "|")
if len(splits) == 1 {
accessToken.AccessToken = apiKey
accessToken.ExpiresAt = time.Now().Add(30 * 24 * time.Hour)
return apiKey, nil
}
var token string
var err error
if token, err = initBaiduAccessToken(splits[0], splits[1], ""); err != nil {
return "", err
}
return token, nil
}
func refreshBaiduAccessToken(accessToken *BaiduAccessToken) error {
if accessToken.RefreshToken == "" {
return nil
}
_, err := initBaiduAccessToken(accessToken.SecretKey, accessToken.ApiKey, accessToken.RefreshToken)
return err
}
func initBaiduAccessToken(secretKey, apiKey, refreshToken string) (string, error) {
url := "https://aip.baidubce.com/oauth/2.0/token?grant_type=client_credentials&client_id=" + apiKey + "&client_secret=" + secretKey
if refreshToken != "" {
url += "&refresh_token=" + refreshToken
}
req, err := http.NewRequest("POST", url, nil)
if err != nil {
return "", errors.New(fmt.Sprintf("initBaiduAccessToken err: %s", err.Error()))
}
req.Header.Add("Content-Type", "application/json")
req.Header.Add("Accept", "application/json")
client := http.Client{
Timeout: 5 * time.Second,
}
res, err := client.Do(req)
if err != nil {
return "", errors.New(fmt.Sprintf("initBaiduAccessToken request err: %s", err.Error()))
}
defer res.Body.Close()
var accessToken BaiduAccessToken
err = json.NewDecoder(res.Body).Decode(&accessToken)
if err != nil {
return "", errors.New(fmt.Sprintf("initBaiduAccessToken decode access token err: %s", err.Error()))
}
if accessToken.Error != "" {
return "", errors.New(accessToken.Error + ": " + accessToken.ErrorDescription)
}
if accessToken.AccessToken == "" {
return "", errors.New("initBaiduAccessToken get access token empty")
}
accessToken.ExpiresAt = time.Now().Add(time.Duration(accessToken.ExpiresIn) * time.Second)
accessToken.SecretKey = secretKey
accessToken.ApiKey = apiKey
baiduAccessTokens.Store(md5.Sum([]byte(secretKey+"|"+apiKey)), accessToken)
return accessToken.AccessToken, nil
}

View File

@ -146,7 +146,11 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
} }
apiKey := c.Request.Header.Get("Authorization") apiKey := c.Request.Header.Get("Authorization")
apiKey = strings.TrimPrefix(apiKey, "Bearer ") apiKey = strings.TrimPrefix(apiKey, "Bearer ")
fullRequestURL += "?access_token=" + apiKey // TODO: access token expire in 30 days var err error
if apiKey, err = getBaiduAccessToken(apiKey); err != nil {
return errorWrapper(err, "invalid_auth", http.StatusBadRequest)
}
fullRequestURL += "?access_token=" + apiKey
case APITypePaLM: case APITypePaLM:
fullRequestURL = "https://generativelanguage.googleapis.com/v1beta2/models/chat-bison-001:generateMessage" fullRequestURL = "https://generativelanguage.googleapis.com/v1beta2/models/chat-bison-001:generateMessage"
if baseURL != "" { if baseURL != "" {

View File

@ -355,7 +355,7 @@ const EditChannel = () => {
label='密钥' label='密钥'
name='key' name='key'
required required
placeholder={inputs.type === 15 ? '请输入 access token当前版本暂不支持自动刷新请每 30 天更新一次' : (inputs.type === 18 ? '按照如下格式输入APPID|APISecret|APIKey' : '请输入渠道对应的鉴权密钥')} placeholder={inputs.type === 15 ? '按照如下格式输入APISecret|APIKey 或者直接输入access token' : (inputs.type === 18 ? '按照如下格式输入APPID|APISecret|APIKey' : '请输入渠道对应的鉴权密钥')}
onChange={handleInputChange} onChange={handleInputChange}
value={inputs.key} value={inputs.key}
autoComplete='new-password' autoComplete='new-password'