From 58fc40a744e2cee98c314c689e80b0732da08f3d Mon Sep 17 00:00:00 2001 From: MartialBE Date: Sat, 2 Dec 2023 18:27:38 +0800 Subject: [PATCH] =?UTF-8?q?=E2=9C=A8=20=E5=AE=8C=E5=96=84=E4=BD=99?= =?UTF-8?q?=E9=A2=9D=E6=9F=A5=E8=AF=A2?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- controller/channel-billing.go | 219 ++-------------------------------- providers/openai/balance.go | 46 +++++++ providers/openai/type.go | 15 +++ 3 files changed, 71 insertions(+), 209 deletions(-) create mode 100644 providers/openai/balance.go diff --git a/controller/channel-billing.go b/controller/channel-billing.go index 8f388e6f..a9807417 100644 --- a/controller/channel-billing.go +++ b/controller/channel-billing.go @@ -1,13 +1,12 @@ package controller import ( - "encoding/json" "errors" - "fmt" - "io" "net/http" "one-api/common" "one-api/model" + "one-api/providers" + providersBase "one-api/providers/base" "strconv" "time" @@ -46,217 +45,19 @@ type OpenAIUsageResponse struct { TotalUsage float64 `json:"total_usage"` // unit: 0.01 dollar } -type OpenAISBUsageResponse struct { - Msg string `json:"msg"` - Data *struct { - Credit string `json:"credit"` - } `json:"data"` -} - -type AIProxyUserOverviewResponse struct { - Success bool `json:"success"` - Message string `json:"message"` - ErrorCode int `json:"error_code"` - Data struct { - TotalPoints float64 `json:"totalPoints"` - } `json:"data"` -} - -type API2GPTUsageResponse struct { - Object string `json:"object"` - TotalGranted float64 `json:"total_granted"` - TotalUsed float64 `json:"total_used"` - TotalRemaining float64 `json:"total_remaining"` -} - -type APGC2DGPTUsageResponse struct { - //Grants interface{} `json:"grants"` - Object string `json:"object"` - TotalAvailable float64 `json:"total_available"` - TotalGranted float64 `json:"total_granted"` - TotalUsed float64 `json:"total_used"` -} - -// GetAuthHeader get auth header -func GetAuthHeader(token string) http.Header { - h := http.Header{} - h.Add("Authorization", fmt.Sprintf("Bearer %s", token)) - return h -} - -func GetResponseBody(method, url string, channel *model.Channel, headers http.Header) ([]byte, error) { - req, err := http.NewRequest(method, url, nil) - if err != nil { - return nil, err - } - for k := range headers { - req.Header.Add(k, headers.Get(k)) - } - res, err := common.HttpClient.Do(req) - if err != nil { - return nil, err - } - if res.StatusCode != http.StatusOK { - return nil, fmt.Errorf("status code: %d", res.StatusCode) - } - body, err := io.ReadAll(res.Body) - if err != nil { - return nil, err - } - err = res.Body.Close() - if err != nil { - return nil, err - } - return body, nil -} - -func updateChannelCloseAIBalance(channel *model.Channel) (float64, error) { - url := fmt.Sprintf("%s/dashboard/billing/credit_grants", channel.GetBaseURL()) - body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key)) - - if err != nil { - return 0, err - } - response := OpenAICreditGrants{} - err = json.Unmarshal(body, &response) - if err != nil { - return 0, err - } - channel.UpdateBalance(response.TotalAvailable) - return response.TotalAvailable, nil -} - -func updateChannelOpenAISBBalance(channel *model.Channel) (float64, error) { - url := fmt.Sprintf("https://api.openai-sb.com/sb-api/user/status?api_key=%s", channel.Key) - body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key)) - if err != nil { - return 0, err - } - response := OpenAISBUsageResponse{} - err = json.Unmarshal(body, &response) - if err != nil { - return 0, err - } - if response.Data == nil { - return 0, errors.New(response.Msg) - } - balance, err := strconv.ParseFloat(response.Data.Credit, 64) - if err != nil { - return 0, err - } - channel.UpdateBalance(balance) - return balance, nil -} - -func updateChannelAIProxyBalance(channel *model.Channel) (float64, error) { - url := "https://aiproxy.io/api/report/getUserOverview" - headers := http.Header{} - headers.Add("Api-Key", channel.Key) - body, err := GetResponseBody("GET", url, channel, headers) - if err != nil { - return 0, err - } - response := AIProxyUserOverviewResponse{} - err = json.Unmarshal(body, &response) - if err != nil { - return 0, err - } - if !response.Success { - return 0, fmt.Errorf("code: %d, message: %s", response.ErrorCode, response.Message) - } - channel.UpdateBalance(response.Data.TotalPoints) - return response.Data.TotalPoints, nil -} - -func updateChannelAPI2GPTBalance(channel *model.Channel) (float64, error) { - url := "https://api.api2gpt.com/dashboard/billing/credit_grants" - body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key)) - - if err != nil { - return 0, err - } - response := API2GPTUsageResponse{} - err = json.Unmarshal(body, &response) - if err != nil { - return 0, err - } - channel.UpdateBalance(response.TotalRemaining) - return response.TotalRemaining, nil -} - -func updateChannelAIGC2DBalance(channel *model.Channel) (float64, error) { - url := "https://api.aigc2d.com/dashboard/billing/credit_grants" - body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key)) - if err != nil { - return 0, err - } - response := APGC2DGPTUsageResponse{} - err = json.Unmarshal(body, &response) - if err != nil { - return 0, err - } - channel.UpdateBalance(response.TotalAvailable) - return response.TotalAvailable, nil -} - func updateChannelBalance(channel *model.Channel) (float64, error) { - baseURL := common.ChannelBaseURLs[channel.Type] - if channel.GetBaseURL() == "" { - channel.BaseURL = &baseURL + provider := providers.GetProvider(channel.Type, nil) + if provider == nil { + return 0, errors.New("provider not found") } - switch channel.Type { - case common.ChannelTypeOpenAI: - if channel.GetBaseURL() != "" { - baseURL = channel.GetBaseURL() - } - case common.ChannelTypeAzure: - return 0, errors.New("尚未实现") - case common.ChannelTypeCustom: - baseURL = channel.GetBaseURL() - case common.ChannelTypeCloseAI: - return updateChannelCloseAIBalance(channel) - case common.ChannelTypeOpenAISB: - return updateChannelOpenAISBBalance(channel) - case common.ChannelTypeAIProxy: - return updateChannelAIProxyBalance(channel) - case common.ChannelTypeAPI2GPT: - return updateChannelAPI2GPTBalance(channel) - case common.ChannelTypeAIGC2D: - return updateChannelAIGC2DBalance(channel) - default: - return 0, errors.New("尚未实现") + balanceProvider, ok := provider.(providersBase.BalanceInterface) + if !ok { + return 0, errors.New("provider not implemented") } - url := fmt.Sprintf("%s/v1/dashboard/billing/subscription", baseURL) - body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key)) - if err != nil { - return 0, err - } - subscription := OpenAISubscriptionResponse{} - err = json.Unmarshal(body, &subscription) - if err != nil { - return 0, err - } - now := time.Now() - startDate := fmt.Sprintf("%s-01", now.Format("2006-01")) - endDate := now.Format("2006-01-02") - if !subscription.HasPaymentMethod { - startDate = now.AddDate(0, 0, -100).Format("2006-01-02") - } - url = fmt.Sprintf("%s/v1/dashboard/billing/usage?start_date=%s&end_date=%s", baseURL, startDate, endDate) - body, err = GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key)) - if err != nil { - return 0, err - } - usage := OpenAIUsageResponse{} - err = json.Unmarshal(body, &usage) - if err != nil { - return 0, err - } - balance := subscription.HardLimitUSD - usage.TotalUsage/100 - channel.UpdateBalance(balance) - return balance, nil + return balanceProvider.BalanceAction(channel) + } func UpdateChannelBalance(c *gin.Context) { diff --git a/providers/openai/balance.go b/providers/openai/balance.go new file mode 100644 index 00000000..8a616ea9 --- /dev/null +++ b/providers/openai/balance.go @@ -0,0 +1,46 @@ +package openai + +import ( + "errors" + "fmt" + "one-api/common" + "one-api/model" + "time" +) + +func (p *OpenAIProvider) Balance(channel *model.Channel) (float64, error) { + fullRequestURL := p.GetFullRequestURL("/v1/dashboard/billing/subscription", "") + headers := p.GetRequestHeaders() + + client := common.NewClient() + req, err := client.NewRequest("GET", fullRequestURL, common.WithHeader(headers)) + if err != nil { + return 0, err + } + + // 发送请求 + var subscription OpenAISubscriptionResponse + _, errWithCode := common.SendRequest(req, &subscription, false) + if errWithCode != nil { + return 0, errors.New(errWithCode.OpenAIError.Message) + } + + now := time.Now() + startDate := fmt.Sprintf("%s-01", now.Format("2006-01")) + endDate := now.Format("2006-01-02") + if !subscription.HasPaymentMethod { + startDate = now.AddDate(0, 0, -100).Format("2006-01-02") + } + + fullRequestURL = p.GetFullRequestURL(fmt.Sprintf("/v1/dashboard/billing/usage?start_date=%s&end_date=%s", startDate, endDate), "") + req, err = client.NewRequest("GET", fullRequestURL, common.WithHeader(headers)) + if err != nil { + return 0, err + } + usage := OpenAIUsageResponse{} + _, errWithCode = common.SendRequest(req, &usage, false) + + balance := subscription.HardLimitUSD - usage.TotalUsage/100 + channel.UpdateBalance(balance) + return balance, nil +} diff --git a/providers/openai/type.go b/providers/openai/type.go index 7153739f..b17e513b 100644 --- a/providers/openai/type.go +++ b/providers/openai/type.go @@ -42,3 +42,18 @@ type OpenAIProviderImageResponseResponse struct { types.ImageResponse types.OpenAIErrorResponse } + +type OpenAISubscriptionResponse struct { + Object string `json:"object"` + HasPaymentMethod bool `json:"has_payment_method"` + SoftLimitUSD float64 `json:"soft_limit_usd"` + HardLimitUSD float64 `json:"hard_limit_usd"` + SystemHardLimitUSD float64 `json:"system_hard_limit_usd"` + AccessUntil int64 `json:"access_until"` +} + +type OpenAIUsageResponse struct { + Object string `json:"object"` + //DailyCosts []OpenAIUsageDailyCost `json:"daily_costs"` + TotalUsage float64 `json:"total_usage"` // unit: 0.01 dollar +}