diff --git a/controller/channel-billing.go b/controller/channel-billing.go index 8b6d787b..b1926545 100644 --- a/controller/channel-billing.go +++ b/controller/channel-billing.go @@ -37,13 +37,14 @@ type OpenAIUsageResponse struct { TotalUsage float64 `json:"total_usage"` // unit: 0.01 dollar } -type OpenaiSbUsageResponse struct { - Data struct { +type OpenAISBUsageResponse struct { + Msg string `json:"msg"` + Data *struct { Credit string `json:"credit"` } `json:"data"` } -func GetRequestBody(method, url string, channel *model.Channel) ([]byte, error) { +func GetResponseBody(method, url string, channel *model.Channel) ([]byte, error) { client := &http.Client{} req, err := http.NewRequest(method, url, nil) if err != nil { @@ -66,19 +67,21 @@ func GetRequestBody(method, url string, channel *model.Channel) ([]byte, error) return body, nil } -func getOpenaiSBCredit(channel *model.Channel) (float64, error) { +func updateChannelOpenAISBBalance(channel *model.Channel) (float64, error) { url := fmt.Sprintf("https://api.openai-sb.com/sb-api/user/status?api_key=%s", channel.Key) - body, err := GetRequestBody("GET", url, channel) + body, err := GetResponseBody("GET", url, channel) if err != nil { return 0, err } - subscription := OpenaiSbUsageResponse{} - err = json.Unmarshal(body, &subscription) + response := OpenAISBUsageResponse{} + err = json.Unmarshal(body, &response) if err != nil { return 0, err } - - balance, err := strconv.ParseFloat(subscription.Data.Credit, 64) + if response.Data == nil { + return 0, errors.New(response.Msg) + } + balance, err := strconv.ParseFloat(response.Data.Credit, 64) if err != nil { return 0, err } @@ -98,13 +101,13 @@ func updateChannelBalance(channel *model.Channel) (float64, error) { case common.ChannelTypeCustom: baseURL = channel.BaseURL case common.ChannelTypeOpenAISB: - return getOpenaiSBCredit(channel) + return updateChannelOpenAISBBalance(channel) default: return 0, errors.New("尚未实现") } url := fmt.Sprintf("%s/v1/dashboard/billing/subscription", baseURL) - body, err := GetRequestBody("GET", url, channel) + body, err := GetResponseBody("GET", url, channel) if err != nil { return 0, err } @@ -120,7 +123,7 @@ func updateChannelBalance(channel *model.Channel) (float64, error) { startDate = now.AddDate(0, 0, -100).Format("2006-01-02") } url = fmt.Sprintf("%s/v1/dashboard/billing/usage?start_date=%s&end_date=%s", baseURL, startDate, endDate) - body, err = GetRequestBody("GET", url, channel) + body, err = GetResponseBody("GET", url, channel) if err != nil { return 0, err } diff --git a/web/src/components/ChannelsTable.js b/web/src/components/ChannelsTable.js index be0bba16..999027fc 100644 --- a/web/src/components/ChannelsTable.js +++ b/web/src/components/ChannelsTable.js @@ -27,6 +27,13 @@ function renderType(type) { return ; } +function renderBalance(type, balance) { + if (type === 5) { + return {balance.toFixed(2)} + } + return ${balance.toFixed(2)} +} + const ChannelsTable = () => { const [channels, setChannels] = useState([]); const [loading, setLoading] = useState(true); @@ -336,7 +343,7 @@ const ChannelsTable = () => { ${channel.balance.toFixed(2)}} + trigger={renderBalance(channel.type, channel.balance)} basic />