✨ 完善余额查询
This commit is contained in:
parent
da87fca2a2
commit
58fc40a744
@ -1,13 +1,12 @@
|
|||||||
package controller
|
package controller
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
|
||||||
"io"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
"one-api/model"
|
"one-api/model"
|
||||||
|
"one-api/providers"
|
||||||
|
providersBase "one-api/providers/base"
|
||||||
"strconv"
|
"strconv"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@ -46,217 +45,19 @@ type OpenAIUsageResponse struct {
|
|||||||
TotalUsage float64 `json:"total_usage"` // unit: 0.01 dollar
|
TotalUsage float64 `json:"total_usage"` // unit: 0.01 dollar
|
||||||
}
|
}
|
||||||
|
|
||||||
type OpenAISBUsageResponse struct {
|
|
||||||
Msg string `json:"msg"`
|
|
||||||
Data *struct {
|
|
||||||
Credit string `json:"credit"`
|
|
||||||
} `json:"data"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type AIProxyUserOverviewResponse struct {
|
|
||||||
Success bool `json:"success"`
|
|
||||||
Message string `json:"message"`
|
|
||||||
ErrorCode int `json:"error_code"`
|
|
||||||
Data struct {
|
|
||||||
TotalPoints float64 `json:"totalPoints"`
|
|
||||||
} `json:"data"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type API2GPTUsageResponse struct {
|
|
||||||
Object string `json:"object"`
|
|
||||||
TotalGranted float64 `json:"total_granted"`
|
|
||||||
TotalUsed float64 `json:"total_used"`
|
|
||||||
TotalRemaining float64 `json:"total_remaining"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type APGC2DGPTUsageResponse struct {
|
|
||||||
//Grants interface{} `json:"grants"`
|
|
||||||
Object string `json:"object"`
|
|
||||||
TotalAvailable float64 `json:"total_available"`
|
|
||||||
TotalGranted float64 `json:"total_granted"`
|
|
||||||
TotalUsed float64 `json:"total_used"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetAuthHeader get auth header
|
|
||||||
func GetAuthHeader(token string) http.Header {
|
|
||||||
h := http.Header{}
|
|
||||||
h.Add("Authorization", fmt.Sprintf("Bearer %s", token))
|
|
||||||
return h
|
|
||||||
}
|
|
||||||
|
|
||||||
func GetResponseBody(method, url string, channel *model.Channel, headers http.Header) ([]byte, error) {
|
|
||||||
req, err := http.NewRequest(method, url, nil)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
for k := range headers {
|
|
||||||
req.Header.Add(k, headers.Get(k))
|
|
||||||
}
|
|
||||||
res, err := common.HttpClient.Do(req)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if res.StatusCode != http.StatusOK {
|
|
||||||
return nil, fmt.Errorf("status code: %d", res.StatusCode)
|
|
||||||
}
|
|
||||||
body, err := io.ReadAll(res.Body)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
err = res.Body.Close()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return body, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func updateChannelCloseAIBalance(channel *model.Channel) (float64, error) {
|
|
||||||
url := fmt.Sprintf("%s/dashboard/billing/credit_grants", channel.GetBaseURL())
|
|
||||||
body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key))
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
response := OpenAICreditGrants{}
|
|
||||||
err = json.Unmarshal(body, &response)
|
|
||||||
if err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
channel.UpdateBalance(response.TotalAvailable)
|
|
||||||
return response.TotalAvailable, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func updateChannelOpenAISBBalance(channel *model.Channel) (float64, error) {
|
|
||||||
url := fmt.Sprintf("https://api.openai-sb.com/sb-api/user/status?api_key=%s", channel.Key)
|
|
||||||
body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key))
|
|
||||||
if err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
response := OpenAISBUsageResponse{}
|
|
||||||
err = json.Unmarshal(body, &response)
|
|
||||||
if err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
if response.Data == nil {
|
|
||||||
return 0, errors.New(response.Msg)
|
|
||||||
}
|
|
||||||
balance, err := strconv.ParseFloat(response.Data.Credit, 64)
|
|
||||||
if err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
channel.UpdateBalance(balance)
|
|
||||||
return balance, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func updateChannelAIProxyBalance(channel *model.Channel) (float64, error) {
|
|
||||||
url := "https://aiproxy.io/api/report/getUserOverview"
|
|
||||||
headers := http.Header{}
|
|
||||||
headers.Add("Api-Key", channel.Key)
|
|
||||||
body, err := GetResponseBody("GET", url, channel, headers)
|
|
||||||
if err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
response := AIProxyUserOverviewResponse{}
|
|
||||||
err = json.Unmarshal(body, &response)
|
|
||||||
if err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
if !response.Success {
|
|
||||||
return 0, fmt.Errorf("code: %d, message: %s", response.ErrorCode, response.Message)
|
|
||||||
}
|
|
||||||
channel.UpdateBalance(response.Data.TotalPoints)
|
|
||||||
return response.Data.TotalPoints, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func updateChannelAPI2GPTBalance(channel *model.Channel) (float64, error) {
|
|
||||||
url := "https://api.api2gpt.com/dashboard/billing/credit_grants"
|
|
||||||
body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key))
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
response := API2GPTUsageResponse{}
|
|
||||||
err = json.Unmarshal(body, &response)
|
|
||||||
if err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
channel.UpdateBalance(response.TotalRemaining)
|
|
||||||
return response.TotalRemaining, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func updateChannelAIGC2DBalance(channel *model.Channel) (float64, error) {
|
|
||||||
url := "https://api.aigc2d.com/dashboard/billing/credit_grants"
|
|
||||||
body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key))
|
|
||||||
if err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
response := APGC2DGPTUsageResponse{}
|
|
||||||
err = json.Unmarshal(body, &response)
|
|
||||||
if err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
channel.UpdateBalance(response.TotalAvailable)
|
|
||||||
return response.TotalAvailable, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func updateChannelBalance(channel *model.Channel) (float64, error) {
|
func updateChannelBalance(channel *model.Channel) (float64, error) {
|
||||||
baseURL := common.ChannelBaseURLs[channel.Type]
|
provider := providers.GetProvider(channel.Type, nil)
|
||||||
if channel.GetBaseURL() == "" {
|
if provider == nil {
|
||||||
channel.BaseURL = &baseURL
|
return 0, errors.New("provider not found")
|
||||||
}
|
}
|
||||||
|
|
||||||
switch channel.Type {
|
balanceProvider, ok := provider.(providersBase.BalanceInterface)
|
||||||
case common.ChannelTypeOpenAI:
|
if !ok {
|
||||||
if channel.GetBaseURL() != "" {
|
return 0, errors.New("provider not implemented")
|
||||||
baseURL = channel.GetBaseURL()
|
|
||||||
}
|
|
||||||
case common.ChannelTypeAzure:
|
|
||||||
return 0, errors.New("尚未实现")
|
|
||||||
case common.ChannelTypeCustom:
|
|
||||||
baseURL = channel.GetBaseURL()
|
|
||||||
case common.ChannelTypeCloseAI:
|
|
||||||
return updateChannelCloseAIBalance(channel)
|
|
||||||
case common.ChannelTypeOpenAISB:
|
|
||||||
return updateChannelOpenAISBBalance(channel)
|
|
||||||
case common.ChannelTypeAIProxy:
|
|
||||||
return updateChannelAIProxyBalance(channel)
|
|
||||||
case common.ChannelTypeAPI2GPT:
|
|
||||||
return updateChannelAPI2GPTBalance(channel)
|
|
||||||
case common.ChannelTypeAIGC2D:
|
|
||||||
return updateChannelAIGC2DBalance(channel)
|
|
||||||
default:
|
|
||||||
return 0, errors.New("尚未实现")
|
|
||||||
}
|
}
|
||||||
url := fmt.Sprintf("%s/v1/dashboard/billing/subscription", baseURL)
|
|
||||||
|
|
||||||
body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key))
|
return balanceProvider.BalanceAction(channel)
|
||||||
if err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
subscription := OpenAISubscriptionResponse{}
|
|
||||||
err = json.Unmarshal(body, &subscription)
|
|
||||||
if err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
now := time.Now()
|
|
||||||
startDate := fmt.Sprintf("%s-01", now.Format("2006-01"))
|
|
||||||
endDate := now.Format("2006-01-02")
|
|
||||||
if !subscription.HasPaymentMethod {
|
|
||||||
startDate = now.AddDate(0, 0, -100).Format("2006-01-02")
|
|
||||||
}
|
|
||||||
url = fmt.Sprintf("%s/v1/dashboard/billing/usage?start_date=%s&end_date=%s", baseURL, startDate, endDate)
|
|
||||||
body, err = GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key))
|
|
||||||
if err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
usage := OpenAIUsageResponse{}
|
|
||||||
err = json.Unmarshal(body, &usage)
|
|
||||||
if err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
balance := subscription.HardLimitUSD - usage.TotalUsage/100
|
|
||||||
channel.UpdateBalance(balance)
|
|
||||||
return balance, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func UpdateChannelBalance(c *gin.Context) {
|
func UpdateChannelBalance(c *gin.Context) {
|
||||||
|
46
providers/openai/balance.go
Normal file
46
providers/openai/balance.go
Normal file
@ -0,0 +1,46 @@
|
|||||||
|
package openai
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"one-api/common"
|
||||||
|
"one-api/model"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (p *OpenAIProvider) Balance(channel *model.Channel) (float64, error) {
|
||||||
|
fullRequestURL := p.GetFullRequestURL("/v1/dashboard/billing/subscription", "")
|
||||||
|
headers := p.GetRequestHeaders()
|
||||||
|
|
||||||
|
client := common.NewClient()
|
||||||
|
req, err := client.NewRequest("GET", fullRequestURL, common.WithHeader(headers))
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// 发送请求
|
||||||
|
var subscription OpenAISubscriptionResponse
|
||||||
|
_, errWithCode := common.SendRequest(req, &subscription, false)
|
||||||
|
if errWithCode != nil {
|
||||||
|
return 0, errors.New(errWithCode.OpenAIError.Message)
|
||||||
|
}
|
||||||
|
|
||||||
|
now := time.Now()
|
||||||
|
startDate := fmt.Sprintf("%s-01", now.Format("2006-01"))
|
||||||
|
endDate := now.Format("2006-01-02")
|
||||||
|
if !subscription.HasPaymentMethod {
|
||||||
|
startDate = now.AddDate(0, 0, -100).Format("2006-01-02")
|
||||||
|
}
|
||||||
|
|
||||||
|
fullRequestURL = p.GetFullRequestURL(fmt.Sprintf("/v1/dashboard/billing/usage?start_date=%s&end_date=%s", startDate, endDate), "")
|
||||||
|
req, err = client.NewRequest("GET", fullRequestURL, common.WithHeader(headers))
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
usage := OpenAIUsageResponse{}
|
||||||
|
_, errWithCode = common.SendRequest(req, &usage, false)
|
||||||
|
|
||||||
|
balance := subscription.HardLimitUSD - usage.TotalUsage/100
|
||||||
|
channel.UpdateBalance(balance)
|
||||||
|
return balance, nil
|
||||||
|
}
|
@ -42,3 +42,18 @@ type OpenAIProviderImageResponseResponse struct {
|
|||||||
types.ImageResponse
|
types.ImageResponse
|
||||||
types.OpenAIErrorResponse
|
types.OpenAIErrorResponse
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type OpenAISubscriptionResponse struct {
|
||||||
|
Object string `json:"object"`
|
||||||
|
HasPaymentMethod bool `json:"has_payment_method"`
|
||||||
|
SoftLimitUSD float64 `json:"soft_limit_usd"`
|
||||||
|
HardLimitUSD float64 `json:"hard_limit_usd"`
|
||||||
|
SystemHardLimitUSD float64 `json:"system_hard_limit_usd"`
|
||||||
|
AccessUntil int64 `json:"access_until"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type OpenAIUsageResponse struct {
|
||||||
|
Object string `json:"object"`
|
||||||
|
//DailyCosts []OpenAIUsageDailyCost `json:"daily_costs"`
|
||||||
|
TotalUsage float64 `json:"total_usage"` // unit: 0.01 dollar
|
||||||
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user