diff --git a/controller/channel-billing.go b/controller/channel-billing.go
index 8b6d787b..b1926545 100644
--- a/controller/channel-billing.go
+++ b/controller/channel-billing.go
@@ -37,13 +37,14 @@ type OpenAIUsageResponse struct {
TotalUsage float64 `json:"total_usage"` // unit: 0.01 dollar
}
-type OpenaiSbUsageResponse struct {
- Data struct {
+type OpenAISBUsageResponse struct {
+ Msg string `json:"msg"`
+ Data *struct {
Credit string `json:"credit"`
} `json:"data"`
}
-func GetRequestBody(method, url string, channel *model.Channel) ([]byte, error) {
+func GetResponseBody(method, url string, channel *model.Channel) ([]byte, error) {
client := &http.Client{}
req, err := http.NewRequest(method, url, nil)
if err != nil {
@@ -66,19 +67,21 @@ func GetRequestBody(method, url string, channel *model.Channel) ([]byte, error)
return body, nil
}
-func getOpenaiSBCredit(channel *model.Channel) (float64, error) {
+func updateChannelOpenAISBBalance(channel *model.Channel) (float64, error) {
url := fmt.Sprintf("https://api.openai-sb.com/sb-api/user/status?api_key=%s", channel.Key)
- body, err := GetRequestBody("GET", url, channel)
+ body, err := GetResponseBody("GET", url, channel)
if err != nil {
return 0, err
}
- subscription := OpenaiSbUsageResponse{}
- err = json.Unmarshal(body, &subscription)
+ response := OpenAISBUsageResponse{}
+ err = json.Unmarshal(body, &response)
if err != nil {
return 0, err
}
-
- balance, err := strconv.ParseFloat(subscription.Data.Credit, 64)
+ if response.Data == nil {
+ return 0, errors.New(response.Msg)
+ }
+ balance, err := strconv.ParseFloat(response.Data.Credit, 64)
if err != nil {
return 0, err
}
@@ -98,13 +101,13 @@ func updateChannelBalance(channel *model.Channel) (float64, error) {
case common.ChannelTypeCustom:
baseURL = channel.BaseURL
case common.ChannelTypeOpenAISB:
- return getOpenaiSBCredit(channel)
+ return updateChannelOpenAISBBalance(channel)
default:
return 0, errors.New("尚未实现")
}
url := fmt.Sprintf("%s/v1/dashboard/billing/subscription", baseURL)
- body, err := GetRequestBody("GET", url, channel)
+ body, err := GetResponseBody("GET", url, channel)
if err != nil {
return 0, err
}
@@ -120,7 +123,7 @@ func updateChannelBalance(channel *model.Channel) (float64, error) {
startDate = now.AddDate(0, 0, -100).Format("2006-01-02")
}
url = fmt.Sprintf("%s/v1/dashboard/billing/usage?start_date=%s&end_date=%s", baseURL, startDate, endDate)
- body, err = GetRequestBody("GET", url, channel)
+ body, err = GetResponseBody("GET", url, channel)
if err != nil {
return 0, err
}
diff --git a/web/src/components/ChannelsTable.js b/web/src/components/ChannelsTable.js
index be0bba16..999027fc 100644
--- a/web/src/components/ChannelsTable.js
+++ b/web/src/components/ChannelsTable.js
@@ -27,6 +27,13 @@ function renderType(type) {
return ;
}
+function renderBalance(type, balance) {
+ if (type === 5) {
+ return {balance.toFixed(2)}
+ }
+ return ${balance.toFixed(2)}
+}
+
const ChannelsTable = () => {
const [channels, setChannels] = useState([]);
const [loading, setLoading] = useState(true);
@@ -336,7 +343,7 @@ const ChannelsTable = () => {
${channel.balance.toFixed(2)}}
+ trigger={renderBalance(channel.type, channel.balance)}
basic
/>