chore: update implementation
This commit is contained in:
parent
c92a48b48b
commit
143b61cbf4
@ -2,7 +2,6 @@ package controller
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bufio"
|
"bufio"
|
||||||
"crypto/md5"
|
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
@ -18,12 +17,8 @@ import (
|
|||||||
// https://cloud.baidu.com/doc/WENXINWORKSHOP/s/flfmc9do2
|
// https://cloud.baidu.com/doc/WENXINWORKSHOP/s/flfmc9do2
|
||||||
|
|
||||||
type BaiduTokenResponse struct {
|
type BaiduTokenResponse struct {
|
||||||
RefreshToken string `json:"refresh_token"`
|
|
||||||
ExpiresIn int `json:"expires_in"`
|
ExpiresIn int `json:"expires_in"`
|
||||||
SessionKey string `json:"session_key"`
|
|
||||||
AccessToken string `json:"access_token"`
|
AccessToken string `json:"access_token"`
|
||||||
Scope string `json:"scope"`
|
|
||||||
SessionSecret string `json:"session_secret"`
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type BaiduMessage struct {
|
type BaiduMessage struct {
|
||||||
@ -79,20 +74,14 @@ type BaiduEmbeddingResponse struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type BaiduAccessToken struct {
|
type BaiduAccessToken struct {
|
||||||
RefreshToken string `json:"refresh_token"`
|
|
||||||
ExpiresIn int `json:"expires_in"`
|
|
||||||
SessionKey string `json:"session_key"`
|
|
||||||
AccessToken string `json:"access_token"`
|
AccessToken string `json:"access_token"`
|
||||||
Scope string `json:"scope"`
|
|
||||||
SessionSecret string `json:"session_secret"`
|
|
||||||
Error string `json:"error,omitempty"`
|
Error string `json:"error,omitempty"`
|
||||||
ErrorDescription string `json:"error_description,omitempty"`
|
ErrorDescription string `json:"error_description,omitempty"`
|
||||||
ExpiresAt time.Time `json:"expires_at,omitempty"`
|
ExpiresIn int64 `json:"expires_in,omitempty"`
|
||||||
SecretKey string `json:"secret_key,omitempty"`
|
ExpiresAt time.Time `json:"-"`
|
||||||
ApiKey string `json:"api_key,omitempty"`
|
|
||||||
}
|
}
|
||||||
|
|
||||||
var baiduAccessTokens sync.Map
|
var baiduTokenStore sync.Map
|
||||||
|
|
||||||
func requestOpenAI2Baidu(request GeneralOpenAIRequest) *BaiduChatRequest {
|
func requestOpenAI2Baidu(request GeneralOpenAIRequest) *BaiduChatRequest {
|
||||||
messages := make([]BaiduMessage, 0, len(request.Messages))
|
messages := make([]BaiduMessage, 0, len(request.Messages))
|
||||||
@ -322,82 +311,58 @@ func baiduEmbeddingHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWit
|
|||||||
}
|
}
|
||||||
|
|
||||||
func getBaiduAccessToken(apiKey string) (string, error) {
|
func getBaiduAccessToken(apiKey string) (string, error) {
|
||||||
|
if val, ok := baiduTokenStore.Load(apiKey); ok {
|
||||||
var accessToken BaiduAccessToken
|
var accessToken BaiduAccessToken
|
||||||
md5Key := md5.Sum([]byte(apiKey))
|
|
||||||
if val, ok := baiduAccessTokens.Load(md5Key); ok {
|
|
||||||
if accessToken, ok = val.(BaiduAccessToken); ok {
|
if accessToken, ok = val.(BaiduAccessToken); ok {
|
||||||
// 提前1小时刷新
|
// soon this will expire
|
||||||
if time.Now().Add(time.Hour).After(accessToken.ExpiresAt) {
|
if time.Now().Add(time.Hour).After(accessToken.ExpiresAt) {
|
||||||
go refreshBaiduAccessToken(&accessToken)
|
go func() {
|
||||||
|
_, _ = getBaiduAccessTokenHelper(apiKey)
|
||||||
|
}()
|
||||||
}
|
}
|
||||||
return accessToken.AccessToken, nil
|
return accessToken.AccessToken, nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
accessToken, err := getBaiduAccessTokenHelper(apiKey)
|
||||||
splits := strings.Split(apiKey, "|")
|
if err != nil {
|
||||||
if len(splits) == 1 {
|
|
||||||
accessToken.AccessToken = apiKey
|
|
||||||
accessToken.ExpiresAt = time.Now().Add(30 * 24 * time.Hour)
|
|
||||||
baiduAccessTokens.Store(md5Key, accessToken)
|
|
||||||
return apiKey, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
var token string
|
|
||||||
var err error
|
|
||||||
if token, err = initBaiduAccessToken(splits[0], splits[1], ""); err != nil {
|
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
|
if accessToken == nil {
|
||||||
return token, nil
|
return "", errors.New("getBaiduAccessToken return a nil token")
|
||||||
|
}
|
||||||
|
return (*accessToken).AccessToken, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func refreshBaiduAccessToken(accessToken *BaiduAccessToken) error {
|
func getBaiduAccessTokenHelper(apiKey string) (*BaiduAccessToken, error) {
|
||||||
if accessToken.RefreshToken == "" {
|
parts := strings.Split(apiKey, "|")
|
||||||
return nil
|
if len(parts) != 2 {
|
||||||
|
return nil, errors.New("invalid baidu apikey")
|
||||||
}
|
}
|
||||||
_, err := initBaiduAccessToken(accessToken.SecretKey, accessToken.ApiKey, accessToken.RefreshToken)
|
req, err := http.NewRequest("POST", fmt.Sprintf("https://aip.baidubce.com/oauth/2.0/token?grant_type=client_credentials&client_id=%s&client_secret=%s",
|
||||||
return err
|
parts[0], parts[1]), nil)
|
||||||
}
|
|
||||||
|
|
||||||
func initBaiduAccessToken(secretKey, apiKey, refreshToken string) (string, error) {
|
|
||||||
url := "https://aip.baidubce.com/oauth/2.0/token?grant_type=client_credentials&client_id=" + apiKey + "&client_secret=" + secretKey
|
|
||||||
if refreshToken != "" {
|
|
||||||
url += "&refresh_token=" + refreshToken
|
|
||||||
}
|
|
||||||
req, err := http.NewRequest("POST", url, nil)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", errors.New(fmt.Sprintf("initBaiduAccessToken err: %s", err.Error()))
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
req.Header.Add("Content-Type", "application/json")
|
req.Header.Add("Content-Type", "application/json")
|
||||||
req.Header.Add("Accept", "application/json")
|
req.Header.Add("Accept", "application/json")
|
||||||
|
res, err := impatientHTTPClient.Do(req)
|
||||||
client := http.Client{
|
|
||||||
Timeout: 5 * time.Second,
|
|
||||||
}
|
|
||||||
|
|
||||||
res, err := client.Do(req)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", errors.New(fmt.Sprintf("initBaiduAccessToken request err: %s", err.Error()))
|
return nil, err
|
||||||
}
|
}
|
||||||
defer res.Body.Close()
|
defer res.Body.Close()
|
||||||
|
|
||||||
var accessToken BaiduAccessToken
|
var accessToken BaiduAccessToken
|
||||||
err = json.NewDecoder(res.Body).Decode(&accessToken)
|
err = json.NewDecoder(res.Body).Decode(&accessToken)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", errors.New(fmt.Sprintf("initBaiduAccessToken decode access token err: %s", err.Error()))
|
return nil, err
|
||||||
}
|
}
|
||||||
if accessToken.Error != "" {
|
if accessToken.Error != "" {
|
||||||
return "", errors.New(accessToken.Error + ": " + accessToken.ErrorDescription)
|
return nil, errors.New(accessToken.Error + ": " + accessToken.ErrorDescription)
|
||||||
}
|
}
|
||||||
|
|
||||||
if accessToken.AccessToken == "" {
|
if accessToken.AccessToken == "" {
|
||||||
return "", errors.New("initBaiduAccessToken get access token empty")
|
return nil, errors.New("getBaiduAccessTokenHelper get empty access token")
|
||||||
}
|
}
|
||||||
|
|
||||||
accessToken.ExpiresAt = time.Now().Add(time.Duration(accessToken.ExpiresIn) * time.Second)
|
accessToken.ExpiresAt = time.Now().Add(time.Duration(accessToken.ExpiresIn) * time.Second)
|
||||||
accessToken.SecretKey = secretKey
|
baiduTokenStore.Store(apiKey, accessToken)
|
||||||
accessToken.ApiKey = apiKey
|
return &accessToken, nil
|
||||||
baiduAccessTokens.Store(md5.Sum([]byte(secretKey+"|"+apiKey)), accessToken)
|
|
||||||
return accessToken.AccessToken, nil
|
|
||||||
}
|
}
|
||||||
|
@ -5,13 +5,13 @@ import (
|
|||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
"one-api/model"
|
"one-api/model"
|
||||||
"strings"
|
"strings"
|
||||||
|
"time"
|
||||||
"github.com/gin-gonic/gin"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@ -25,9 +25,13 @@ const (
|
|||||||
)
|
)
|
||||||
|
|
||||||
var httpClient *http.Client
|
var httpClient *http.Client
|
||||||
|
var impatientHTTPClient *http.Client
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
httpClient = &http.Client{}
|
httpClient = &http.Client{}
|
||||||
|
impatientHTTPClient = &http.Client{
|
||||||
|
Timeout: 5 * time.Second,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
|
func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
|
||||||
@ -148,7 +152,7 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
|
|||||||
apiKey = strings.TrimPrefix(apiKey, "Bearer ")
|
apiKey = strings.TrimPrefix(apiKey, "Bearer ")
|
||||||
var err error
|
var err error
|
||||||
if apiKey, err = getBaiduAccessToken(apiKey); err != nil {
|
if apiKey, err = getBaiduAccessToken(apiKey); err != nil {
|
||||||
return errorWrapper(err, "invalid_auth", http.StatusBadRequest)
|
return errorWrapper(err, "invalid_baidu_config", http.StatusInternalServerError)
|
||||||
}
|
}
|
||||||
fullRequestURL += "?access_token=" + apiKey
|
fullRequestURL += "?access_token=" + apiKey
|
||||||
case APITypePaLM:
|
case APITypePaLM:
|
||||||
|
@ -519,5 +519,6 @@
|
|||||||
"令牌创建成功,请在列表页面点击复制获取令牌!": "Token created successfully, please click copy on the list page to get the token!",
|
"令牌创建成功,请在列表页面点击复制获取令牌!": "Token created successfully, please click copy on the list page to get the token!",
|
||||||
"代理": "Proxy",
|
"代理": "Proxy",
|
||||||
"此项可选,用于通过代理站来进行 API 调用,请输入代理站地址,格式为:https://domain.com": "This is optional, used to make API calls through the proxy site, please enter the proxy site address, the format is: https://domain.com",
|
"此项可选,用于通过代理站来进行 API 调用,请输入代理站地址,格式为:https://domain.com": "This is optional, used to make API calls through the proxy site, please enter the proxy site address, the format is: https://domain.com",
|
||||||
"取消密码登录将导致所有未绑定其他登录方式的用户(包括管理员)无法通过密码登录,确认取消?": "Canceling password login will cause all users (including administrators) who have not bound other login methods to be unable to log in via password, confirm cancel?"
|
"取消密码登录将导致所有未绑定其他登录方式的用户(包括管理员)无法通过密码登录,确认取消?": "Canceling password login will cause all users (including administrators) who have not bound other login methods to be unable to log in via password, confirm cancel?",
|
||||||
|
"按照如下格式输入:": "Enter in the following format:"
|
||||||
}
|
}
|
||||||
|
@ -355,7 +355,7 @@ const EditChannel = () => {
|
|||||||
label='密钥'
|
label='密钥'
|
||||||
name='key'
|
name='key'
|
||||||
required
|
required
|
||||||
placeholder={inputs.type === 15 ? '按照如下格式输入:SecretKey|APIKey 或者直接输入access token' : (inputs.type === 18 ? '按照如下格式输入:APPID|APISecret|APIKey' : '请输入渠道对应的鉴权密钥')}
|
placeholder={inputs.type === 15 ? '按照如下格式输入:APIKey|SecretKey' : (inputs.type === 18 ? '按照如下格式输入:APPID|APISecret|APIKey' : '请输入渠道对应的鉴权密钥')}
|
||||||
onChange={handleInputChange}
|
onChange={handleInputChange}
|
||||||
value={inputs.key}
|
value={inputs.key}
|
||||||
autoComplete='new-password'
|
autoComplete='new-password'
|
||||||
|
Loading…
Reference in New Issue
Block a user