From 5e08cc87197c04696849c2ac6f337ec0a6cafd8b Mon Sep 17 00:00:00 2001 From: MartialBE Date: Sat, 2 Dec 2023 17:51:28 +0800 Subject: [PATCH] =?UTF-8?q?=E2=9C=A8=20=E6=B7=BB=E5=8A=A0=E4=BD=99?= =?UTF-8?q?=E9=A2=9D=E6=9F=A5=E8=AF=A2=E6=96=B9=E6=B3=95?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- providers/aiproxy/balance.go | 35 +++++++++++++++++++++++++++++++++++ providers/aiproxy/base.go | 18 ++++++++++++++++++ providers/aiproxy/type.go | 10 ++++++++++ providers/api2d/balance.go | 29 +++++++++++++++++++++++++++++ providers/api2d/type.go | 9 +++++++++ providers/closeai/balance.go | 6 ++---- providers/openaisb/balance.go | 2 +- providers/providers.go | 12 ++++++++++++ 8 files changed, 116 insertions(+), 5 deletions(-) create mode 100644 providers/aiproxy/balance.go create mode 100644 providers/aiproxy/base.go create mode 100644 providers/aiproxy/type.go create mode 100644 providers/api2d/balance.go create mode 100644 providers/api2d/type.go diff --git a/providers/aiproxy/balance.go b/providers/aiproxy/balance.go new file mode 100644 index 00000000..82a1653e --- /dev/null +++ b/providers/aiproxy/balance.go @@ -0,0 +1,35 @@ +package aiproxy + +import ( + "errors" + "fmt" + "one-api/common" + "one-api/model" +) + +func (p *AIProxyProvider) Balance(channel *model.Channel) (float64, error) { + fullRequestURL := "https://aiproxy.io/api/report/getUserOverview" + headers := make(map[string]string) + headers["Api-Key"] = channel.Key + + client := common.NewClient() + req, err := client.NewRequest("GET", fullRequestURL, common.WithHeader(headers)) + if err != nil { + return 0, err + } + + // 发送请求 + var response AIProxyUserOverviewResponse + _, errWithCode := common.SendRequest(req, &response, false) + if errWithCode != nil { + return 0, errors.New(errWithCode.OpenAIError.Message) + } + + if !response.Success { + return 0, fmt.Errorf("code: %d, message: %s", response.ErrorCode, response.Message) + } + + channel.UpdateBalance(response.Data.TotalPoints) + + return response.Data.TotalPoints, nil +} diff --git a/providers/aiproxy/base.go b/providers/aiproxy/base.go new file mode 100644 index 00000000..9acc859e --- /dev/null +++ b/providers/aiproxy/base.go @@ -0,0 +1,18 @@ +package aiproxy + +import ( + "one-api/providers/openai" + + "github.com/gin-gonic/gin" +) + +type AIProxyProvider struct { + *openai.OpenAIProvider +} + +// 创建 CreateAIProxyProvider +func CreateAIProxyProvider(c *gin.Context) *AIProxyProvider { + return &AIProxyProvider{ + OpenAIProvider: openai.CreateOpenAIProvider(c, "https://api.aiproxy.io"), + } +} diff --git a/providers/aiproxy/type.go b/providers/aiproxy/type.go new file mode 100644 index 00000000..6aadf8f8 --- /dev/null +++ b/providers/aiproxy/type.go @@ -0,0 +1,10 @@ +package aiproxy + +type AIProxyUserOverviewResponse struct { + Success bool `json:"success"` + Message string `json:"message"` + ErrorCode int `json:"error_code"` + Data struct { + TotalPoints float64 `json:"totalPoints"` + } `json:"data"` +} diff --git a/providers/api2d/balance.go b/providers/api2d/balance.go new file mode 100644 index 00000000..080a728f --- /dev/null +++ b/providers/api2d/balance.go @@ -0,0 +1,29 @@ +package api2d + +import ( + "errors" + "one-api/common" + "one-api/model" +) + +func (p *Api2dProvider) Balance(channel *model.Channel) (float64, error) { + fullRequestURL := "https://api.aigc2d.com/dashboard/billing/credit_grants" + headers := p.GetRequestHeaders() + + client := common.NewClient() + req, err := client.NewRequest("GET", fullRequestURL, common.WithHeader(headers)) + if err != nil { + return 0, err + } + + // 发送请求 + var response APGC2DGPTUsageResponse + _, errWithCode := common.SendRequest(req, &response, false) + if errWithCode != nil { + return 0, errors.New(errWithCode.OpenAIError.Message) + } + + channel.UpdateBalance(response.TotalAvailable) + + return response.TotalAvailable, nil +} diff --git a/providers/api2d/type.go b/providers/api2d/type.go new file mode 100644 index 00000000..988bb7a6 --- /dev/null +++ b/providers/api2d/type.go @@ -0,0 +1,9 @@ +package api2d + +type APGC2DGPTUsageResponse struct { + //Grants interface{} `json:"grants"` + Object string `json:"object"` + TotalAvailable float64 `json:"total_available"` + TotalGranted float64 `json:"total_granted"` + TotalUsed float64 `json:"total_used"` +} diff --git a/providers/closeai/balance.go b/providers/closeai/balance.go index ae649766..80665df2 100644 --- a/providers/closeai/balance.go +++ b/providers/closeai/balance.go @@ -2,18 +2,16 @@ package closeai import ( "errors" - "fmt" "one-api/common" "one-api/model" ) func (p *CloseaiProxyProvider) Balance(channel *model.Channel) (float64, error) { - fullRequestURL := p.GetFullRequestURL("/sb-api/user/status", "") - fullRequestURL = fmt.Sprintf("%s?api_key=%s", fullRequestURL, channel.Key) + fullRequestURL := p.GetFullRequestURL("/dashboard/billing/credit_grants", "") headers := p.GetRequestHeaders() client := common.NewClient() - req, err := client.NewRequest("GET", fullRequestURL, common.WithBody(nil), common.WithHeader(headers)) + req, err := client.NewRequest("GET", fullRequestURL, common.WithHeader(headers)) if err != nil { return 0, err } diff --git a/providers/openaisb/balance.go b/providers/openaisb/balance.go index 8a789d44..f03bef97 100644 --- a/providers/openaisb/balance.go +++ b/providers/openaisb/balance.go @@ -14,7 +14,7 @@ func (p *OpenaiSBProvider) Balance(channel *model.Channel) (float64, error) { headers := p.GetRequestHeaders() client := common.NewClient() - req, err := client.NewRequest("GET", fullRequestURL, common.WithBody(nil), common.WithHeader(headers)) + req, err := client.NewRequest("GET", fullRequestURL, common.WithHeader(headers)) if err != nil { return 0, err } diff --git a/providers/providers.go b/providers/providers.go index 9eeaa34e..cee972b8 100644 --- a/providers/providers.go +++ b/providers/providers.go @@ -2,12 +2,16 @@ package providers import ( "one-api/common" + "one-api/providers/aiproxy" "one-api/providers/ali" + "one-api/providers/api2d" "one-api/providers/azure" "one-api/providers/baidu" "one-api/providers/base" "one-api/providers/claude" + "one-api/providers/closeai" "one-api/providers/openai" + "one-api/providers/openaisb" "one-api/providers/palm" "one-api/providers/tencent" "one-api/providers/xunfei" @@ -36,6 +40,14 @@ func GetProvider(channelType int, c *gin.Context) base.ProviderInterface { return zhipu.CreateZhipuProvider(c) case common.ChannelTypeXunfei: return xunfei.CreateXunfeiProvider(c) + case common.ChannelTypeAIProxy: + return aiproxy.CreateAIProxyProvider(c) + case common.ChannelTypeAPI2D: + return api2d.CreateApi2dProvider(c) + case common.ChannelTypeCloseAI: + return closeai.CreateCloseaiProxyProvider(c) + case common.ChannelTypeOpenAISB: + return openaisb.CreateOpenaiSBProvider(c) default: baseURL := common.ChannelBaseURLs[channelType] if c.GetString("base_url") != "" {