chore: update implementation

This commit is contained in:
JustSong 2023-08-12 23:40:02 +08:00
parent c92a48b48b
commit 143b61cbf4
4 changed files with 41 additions and 71 deletions

View File

@ -2,7 +2,6 @@ package controller
import ( import (
"bufio" "bufio"
"crypto/md5"
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
@ -18,12 +17,8 @@ import (
// https://cloud.baidu.com/doc/WENXINWORKSHOP/s/flfmc9do2 // https://cloud.baidu.com/doc/WENXINWORKSHOP/s/flfmc9do2
type BaiduTokenResponse struct { type BaiduTokenResponse struct {
RefreshToken string `json:"refresh_token"`
ExpiresIn int `json:"expires_in"` ExpiresIn int `json:"expires_in"`
SessionKey string `json:"session_key"`
AccessToken string `json:"access_token"` AccessToken string `json:"access_token"`
Scope string `json:"scope"`
SessionSecret string `json:"session_secret"`
} }
type BaiduMessage struct { type BaiduMessage struct {
@ -79,20 +74,14 @@ type BaiduEmbeddingResponse struct {
} }
type BaiduAccessToken struct { type BaiduAccessToken struct {
RefreshToken string `json:"refresh_token"`
ExpiresIn int `json:"expires_in"`
SessionKey string `json:"session_key"`
AccessToken string `json:"access_token"` AccessToken string `json:"access_token"`
Scope string `json:"scope"`
SessionSecret string `json:"session_secret"`
Error string `json:"error,omitempty"` Error string `json:"error,omitempty"`
ErrorDescription string `json:"error_description,omitempty"` ErrorDescription string `json:"error_description,omitempty"`
ExpiresAt time.Time `json:"expires_at,omitempty"` ExpiresIn int64 `json:"expires_in,omitempty"`
SecretKey string `json:"secret_key,omitempty"` ExpiresAt time.Time `json:"-"`
ApiKey string `json:"api_key,omitempty"`
} }
var baiduAccessTokens sync.Map var baiduTokenStore sync.Map
func requestOpenAI2Baidu(request GeneralOpenAIRequest) *BaiduChatRequest { func requestOpenAI2Baidu(request GeneralOpenAIRequest) *BaiduChatRequest {
messages := make([]BaiduMessage, 0, len(request.Messages)) messages := make([]BaiduMessage, 0, len(request.Messages))
@ -322,82 +311,58 @@ func baiduEmbeddingHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWit
} }
func getBaiduAccessToken(apiKey string) (string, error) { func getBaiduAccessToken(apiKey string) (string, error) {
if val, ok := baiduTokenStore.Load(apiKey); ok {
var accessToken BaiduAccessToken var accessToken BaiduAccessToken
md5Key := md5.Sum([]byte(apiKey))
if val, ok := baiduAccessTokens.Load(md5Key); ok {
if accessToken, ok = val.(BaiduAccessToken); ok { if accessToken, ok = val.(BaiduAccessToken); ok {
// 提前1小时刷新 // soon this will expire
if time.Now().Add(time.Hour).After(accessToken.ExpiresAt) { if time.Now().Add(time.Hour).After(accessToken.ExpiresAt) {
go refreshBaiduAccessToken(&accessToken) go func() {
_, _ = getBaiduAccessTokenHelper(apiKey)
}()
} }
return accessToken.AccessToken, nil return accessToken.AccessToken, nil
} }
} }
accessToken, err := getBaiduAccessTokenHelper(apiKey)
splits := strings.Split(apiKey, "|") if err != nil {
if len(splits) == 1 {
accessToken.AccessToken = apiKey
accessToken.ExpiresAt = time.Now().Add(30 * 24 * time.Hour)
baiduAccessTokens.Store(md5Key, accessToken)
return apiKey, nil
}
var token string
var err error
if token, err = initBaiduAccessToken(splits[0], splits[1], ""); err != nil {
return "", err return "", err
} }
if accessToken == nil {
return token, nil return "", errors.New("getBaiduAccessToken return a nil token")
}
return (*accessToken).AccessToken, nil
} }
func refreshBaiduAccessToken(accessToken *BaiduAccessToken) error { func getBaiduAccessTokenHelper(apiKey string) (*BaiduAccessToken, error) {
if accessToken.RefreshToken == "" { parts := strings.Split(apiKey, "|")
return nil if len(parts) != 2 {
return nil, errors.New("invalid baidu apikey")
} }
_, err := initBaiduAccessToken(accessToken.SecretKey, accessToken.ApiKey, accessToken.RefreshToken) req, err := http.NewRequest("POST", fmt.Sprintf("https://aip.baidubce.com/oauth/2.0/token?grant_type=client_credentials&client_id=%s&client_secret=%s",
return err parts[0], parts[1]), nil)
}
func initBaiduAccessToken(secretKey, apiKey, refreshToken string) (string, error) {
url := "https://aip.baidubce.com/oauth/2.0/token?grant_type=client_credentials&client_id=" + apiKey + "&client_secret=" + secretKey
if refreshToken != "" {
url += "&refresh_token=" + refreshToken
}
req, err := http.NewRequest("POST", url, nil)
if err != nil { if err != nil {
return "", errors.New(fmt.Sprintf("initBaiduAccessToken err: %s", err.Error())) return nil, err
} }
req.Header.Add("Content-Type", "application/json") req.Header.Add("Content-Type", "application/json")
req.Header.Add("Accept", "application/json") req.Header.Add("Accept", "application/json")
res, err := impatientHTTPClient.Do(req)
client := http.Client{
Timeout: 5 * time.Second,
}
res, err := client.Do(req)
if err != nil { if err != nil {
return "", errors.New(fmt.Sprintf("initBaiduAccessToken request err: %s", err.Error())) return nil, err
} }
defer res.Body.Close() defer res.Body.Close()
var accessToken BaiduAccessToken var accessToken BaiduAccessToken
err = json.NewDecoder(res.Body).Decode(&accessToken) err = json.NewDecoder(res.Body).Decode(&accessToken)
if err != nil { if err != nil {
return "", errors.New(fmt.Sprintf("initBaiduAccessToken decode access token err: %s", err.Error())) return nil, err
} }
if accessToken.Error != "" { if accessToken.Error != "" {
return "", errors.New(accessToken.Error + ": " + accessToken.ErrorDescription) return nil, errors.New(accessToken.Error + ": " + accessToken.ErrorDescription)
} }
if accessToken.AccessToken == "" { if accessToken.AccessToken == "" {
return "", errors.New("initBaiduAccessToken get access token empty") return nil, errors.New("getBaiduAccessTokenHelper get empty access token")
} }
accessToken.ExpiresAt = time.Now().Add(time.Duration(accessToken.ExpiresIn) * time.Second) accessToken.ExpiresAt = time.Now().Add(time.Duration(accessToken.ExpiresIn) * time.Second)
accessToken.SecretKey = secretKey baiduTokenStore.Store(apiKey, accessToken)
accessToken.ApiKey = apiKey return &accessToken, nil
baiduAccessTokens.Store(md5.Sum([]byte(secretKey+"|"+apiKey)), accessToken)
return accessToken.AccessToken, nil
} }

View File

@ -5,13 +5,13 @@ import (
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
"github.com/gin-gonic/gin"
"io" "io"
"net/http" "net/http"
"one-api/common" "one-api/common"
"one-api/model" "one-api/model"
"strings" "strings"
"time"
"github.com/gin-gonic/gin"
) )
const ( const (
@ -25,9 +25,13 @@ const (
) )
var httpClient *http.Client var httpClient *http.Client
var impatientHTTPClient *http.Client
func init() { func init() {
httpClient = &http.Client{} httpClient = &http.Client{}
impatientHTTPClient = &http.Client{
Timeout: 5 * time.Second,
}
} }
func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
@ -148,7 +152,7 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
apiKey = strings.TrimPrefix(apiKey, "Bearer ") apiKey = strings.TrimPrefix(apiKey, "Bearer ")
var err error var err error
if apiKey, err = getBaiduAccessToken(apiKey); err != nil { if apiKey, err = getBaiduAccessToken(apiKey); err != nil {
return errorWrapper(err, "invalid_auth", http.StatusBadRequest) return errorWrapper(err, "invalid_baidu_config", http.StatusInternalServerError)
} }
fullRequestURL += "?access_token=" + apiKey fullRequestURL += "?access_token=" + apiKey
case APITypePaLM: case APITypePaLM:

View File

@ -519,5 +519,6 @@
"令牌创建成功,请在列表页面点击复制获取令牌!": "Token created successfully, please click copy on the list page to get the token!", "令牌创建成功,请在列表页面点击复制获取令牌!": "Token created successfully, please click copy on the list page to get the token!",
"代理": "Proxy", "代理": "Proxy",
"此项可选,用于通过代理站来进行 API 调用请输入代理站地址格式为https://domain.com": "This is optional, used to make API calls through the proxy site, please enter the proxy site address, the format is: https://domain.com", "此项可选,用于通过代理站来进行 API 调用请输入代理站地址格式为https://domain.com": "This is optional, used to make API calls through the proxy site, please enter the proxy site address, the format is: https://domain.com",
"取消密码登录将导致所有未绑定其他登录方式的用户(包括管理员)无法通过密码登录,确认取消?": "Canceling password login will cause all users (including administrators) who have not bound other login methods to be unable to log in via password, confirm cancel?" "取消密码登录将导致所有未绑定其他登录方式的用户(包括管理员)无法通过密码登录,确认取消?": "Canceling password login will cause all users (including administrators) who have not bound other login methods to be unable to log in via password, confirm cancel?",
"按照如下格式输入:": "Enter in the following format:"
} }

View File

@ -355,7 +355,7 @@ const EditChannel = () => {
label='密钥' label='密钥'
name='key' name='key'
required required
placeholder={inputs.type === 15 ? '按照如下格式输入:SecretKey|APIKey 或者直接输入access token' : (inputs.type === 18 ? '按照如下格式输入APPID|APISecret|APIKey' : '请输入渠道对应的鉴权密钥')} placeholder={inputs.type === 15 ? '按照如下格式输入:APIKey|SecretKey' : (inputs.type === 18 ? '按照如下格式输入APPID|APISecret|APIKey' : '请输入渠道对应的鉴权密钥')}
onChange={handleInputChange} onChange={handleInputChange}
value={inputs.key} value={inputs.key}
autoComplete='new-password' autoComplete='new-password'