diff --git a/controller/channel-billing.go b/controller/channel-billing.go
index b1926545..1ff7ff42 100644
--- a/controller/channel-billing.go
+++ b/controller/channel-billing.go
@@ -4,13 +4,14 @@ import (
"encoding/json"
"errors"
"fmt"
- "github.com/gin-gonic/gin"
"io"
"net/http"
"one-api/common"
"one-api/model"
"strconv"
"time"
+
+ "github.com/gin-gonic/gin"
)
// https://github.com/songquanpeng/one-api/issues/79
@@ -44,14 +45,31 @@ type OpenAISBUsageResponse struct {
} `json:"data"`
}
-func GetResponseBody(method, url string, channel *model.Channel) ([]byte, error) {
+type AIProxyUserOverviewResponse struct {
+ Success bool `json:"success"`
+ Message string `json:"message"`
+ ErrorCode int `json:"error_code"`
+ Data struct {
+ TotalPoints float64 `json:"totalPoints"`
+ } `json:"data"`
+}
+
+// GetAuthHeader get auth header
+func GetAuthHeader(token string) http.Header {
+ h := http.Header{}
+ h.Add("Authorization", fmt.Sprintf("Bearer %s", token))
+ return h
+}
+
+func GetResponseBody(method, url string, channel *model.Channel, headers http.Header) ([]byte, error) {
client := &http.Client{}
req, err := http.NewRequest(method, url, nil)
if err != nil {
return nil, err
}
- auth := fmt.Sprintf("Bearer %s", channel.Key)
- req.Header.Add("Authorization", auth)
+ for k := range headers {
+ req.Header.Add(k, headers.Get(k))
+ }
res, err := client.Do(req)
if err != nil {
return nil, err
@@ -69,7 +87,7 @@ func GetResponseBody(method, url string, channel *model.Channel) ([]byte, error)
func updateChannelOpenAISBBalance(channel *model.Channel) (float64, error) {
url := fmt.Sprintf("https://api.openai-sb.com/sb-api/user/status?api_key=%s", channel.Key)
- body, err := GetResponseBody("GET", url, channel)
+ body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key))
if err != nil {
return 0, err
}
@@ -89,6 +107,26 @@ func updateChannelOpenAISBBalance(channel *model.Channel) (float64, error) {
return balance, nil
}
+func updateChannelAIProxyBalance(channel *model.Channel) (float64, error) {
+ url := "https://aiproxy.io/api/report/getUserOverview"
+ headers := http.Header{}
+ headers.Add("Api-Key", channel.Key)
+ body, err := GetResponseBody("GET", url, channel, headers)
+ if err != nil {
+ return 0, err
+ }
+ response := AIProxyUserOverviewResponse{}
+ err = json.Unmarshal(body, &response)
+ if err != nil {
+ return 0, err
+ }
+ if !response.Success {
+ return 0, fmt.Errorf("code: %d, message: %s", response.ErrorCode, response.Message)
+ }
+ channel.UpdateBalance(response.Data.TotalPoints)
+ return response.Data.TotalPoints, nil
+}
+
func updateChannelBalance(channel *model.Channel) (float64, error) {
baseURL := common.ChannelBaseURLs[channel.Type]
switch channel.Type {
@@ -102,12 +140,14 @@ func updateChannelBalance(channel *model.Channel) (float64, error) {
baseURL = channel.BaseURL
case common.ChannelTypeOpenAISB:
return updateChannelOpenAISBBalance(channel)
+ case common.ChannelTypeAIProxy:
+ return updateChannelAIProxyBalance(channel)
default:
return 0, errors.New("尚未实现")
}
url := fmt.Sprintf("%s/v1/dashboard/billing/subscription", baseURL)
- body, err := GetResponseBody("GET", url, channel)
+ body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key))
if err != nil {
return 0, err
}
@@ -123,7 +163,7 @@ func updateChannelBalance(channel *model.Channel) (float64, error) {
startDate = now.AddDate(0, 0, -100).Format("2006-01-02")
}
url = fmt.Sprintf("%s/v1/dashboard/billing/usage?start_date=%s&end_date=%s", baseURL, startDate, endDate)
- body, err = GetResponseBody("GET", url, channel)
+ body, err = GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key))
if err != nil {
return 0, err
}
diff --git a/web/src/components/ChannelsTable.js b/web/src/components/ChannelsTable.js
index 48fd521d..f5f25ae9 100644
--- a/web/src/components/ChannelsTable.js
+++ b/web/src/components/ChannelsTable.js
@@ -4,7 +4,7 @@ import { Link } from 'react-router-dom';
import { API, showError, showInfo, showSuccess, timestamp2string } from '../helpers';
import { CHANNEL_OPTIONS, ITEMS_PER_PAGE } from '../constants';
-import { renderGroup } from '../helpers/render';
+import { renderGroup, renderNumber } from '../helpers/render';
function renderTimestamp(timestamp) {
return (
@@ -28,10 +28,17 @@ function renderType(type) {
}
function renderBalance(type, balance) {
- if (type === 5) {
- return ¥{(balance / 10000).toFixed(2)}
+ switch (type) {
+ case 1: // OpenAI
+ case 8: // 自定义
+ return ${balance.toFixed(2)};
+ case 5: // OpenAI-SB
+ return ¥{(balance / 10000).toFixed(2)};
+ case 10: // AI Proxy
+ return {renderNumber(balance)};
+ default:
+ return 不支持;
}
- return ${balance.toFixed(2)}
}
const ChannelsTable = () => {
@@ -422,7 +429,8 @@ const ChannelsTable = () => {
-
+