diff --git a/controller/channel-billing.go b/controller/channel-billing.go index b1926545..1ff7ff42 100644 --- a/controller/channel-billing.go +++ b/controller/channel-billing.go @@ -4,13 +4,14 @@ import ( "encoding/json" "errors" "fmt" - "github.com/gin-gonic/gin" "io" "net/http" "one-api/common" "one-api/model" "strconv" "time" + + "github.com/gin-gonic/gin" ) // https://github.com/songquanpeng/one-api/issues/79 @@ -44,14 +45,31 @@ type OpenAISBUsageResponse struct { } `json:"data"` } -func GetResponseBody(method, url string, channel *model.Channel) ([]byte, error) { +type AIProxyUserOverviewResponse struct { + Success bool `json:"success"` + Message string `json:"message"` + ErrorCode int `json:"error_code"` + Data struct { + TotalPoints float64 `json:"totalPoints"` + } `json:"data"` +} + +// GetAuthHeader get auth header +func GetAuthHeader(token string) http.Header { + h := http.Header{} + h.Add("Authorization", fmt.Sprintf("Bearer %s", token)) + return h +} + +func GetResponseBody(method, url string, channel *model.Channel, headers http.Header) ([]byte, error) { client := &http.Client{} req, err := http.NewRequest(method, url, nil) if err != nil { return nil, err } - auth := fmt.Sprintf("Bearer %s", channel.Key) - req.Header.Add("Authorization", auth) + for k := range headers { + req.Header.Add(k, headers.Get(k)) + } res, err := client.Do(req) if err != nil { return nil, err @@ -69,7 +87,7 @@ func GetResponseBody(method, url string, channel *model.Channel) ([]byte, error) func updateChannelOpenAISBBalance(channel *model.Channel) (float64, error) { url := fmt.Sprintf("https://api.openai-sb.com/sb-api/user/status?api_key=%s", channel.Key) - body, err := GetResponseBody("GET", url, channel) + body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key)) if err != nil { return 0, err } @@ -89,6 +107,26 @@ func updateChannelOpenAISBBalance(channel *model.Channel) (float64, error) { return balance, nil } +func updateChannelAIProxyBalance(channel *model.Channel) (float64, error) { + url := "https://aiproxy.io/api/report/getUserOverview" + headers := http.Header{} + headers.Add("Api-Key", channel.Key) + body, err := GetResponseBody("GET", url, channel, headers) + if err != nil { + return 0, err + } + response := AIProxyUserOverviewResponse{} + err = json.Unmarshal(body, &response) + if err != nil { + return 0, err + } + if !response.Success { + return 0, fmt.Errorf("code: %d, message: %s", response.ErrorCode, response.Message) + } + channel.UpdateBalance(response.Data.TotalPoints) + return response.Data.TotalPoints, nil +} + func updateChannelBalance(channel *model.Channel) (float64, error) { baseURL := common.ChannelBaseURLs[channel.Type] switch channel.Type { @@ -102,12 +140,14 @@ func updateChannelBalance(channel *model.Channel) (float64, error) { baseURL = channel.BaseURL case common.ChannelTypeOpenAISB: return updateChannelOpenAISBBalance(channel) + case common.ChannelTypeAIProxy: + return updateChannelAIProxyBalance(channel) default: return 0, errors.New("尚未实现") } url := fmt.Sprintf("%s/v1/dashboard/billing/subscription", baseURL) - body, err := GetResponseBody("GET", url, channel) + body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key)) if err != nil { return 0, err } @@ -123,7 +163,7 @@ func updateChannelBalance(channel *model.Channel) (float64, error) { startDate = now.AddDate(0, 0, -100).Format("2006-01-02") } url = fmt.Sprintf("%s/v1/dashboard/billing/usage?start_date=%s&end_date=%s", baseURL, startDate, endDate) - body, err = GetResponseBody("GET", url, channel) + body, err = GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key)) if err != nil { return 0, err } diff --git a/web/src/components/ChannelsTable.js b/web/src/components/ChannelsTable.js index 48fd521d..f5f25ae9 100644 --- a/web/src/components/ChannelsTable.js +++ b/web/src/components/ChannelsTable.js @@ -4,7 +4,7 @@ import { Link } from 'react-router-dom'; import { API, showError, showInfo, showSuccess, timestamp2string } from '../helpers'; import { CHANNEL_OPTIONS, ITEMS_PER_PAGE } from '../constants'; -import { renderGroup } from '../helpers/render'; +import { renderGroup, renderNumber } from '../helpers/render'; function renderTimestamp(timestamp) { return ( @@ -28,10 +28,17 @@ function renderType(type) { } function renderBalance(type, balance) { - if (type === 5) { - return ¥{(balance / 10000).toFixed(2)} + switch (type) { + case 1: // OpenAI + case 8: // 自定义 + return ${balance.toFixed(2)}; + case 5: // OpenAI-SB + return ¥{(balance / 10000).toFixed(2)}; + case 10: // AI Proxy + return {renderNumber(balance)}; + default: + return 不支持; } - return ${balance.toFixed(2)} } const ChannelsTable = () => { @@ -422,7 +429,8 @@ const ChannelsTable = () => { - +