diff --git a/controller/channel-billing.go b/controller/channel-billing.go index de3fc5f9..8b6d787b 100644 --- a/controller/channel-billing.go +++ b/controller/channel-billing.go @@ -37,6 +37,55 @@ type OpenAIUsageResponse struct { TotalUsage float64 `json:"total_usage"` // unit: 0.01 dollar } +type OpenaiSbUsageResponse struct { + Data struct { + Credit string `json:"credit"` + } `json:"data"` +} + +func GetRequestBody(method, url string, channel *model.Channel) ([]byte, error) { + client := &http.Client{} + req, err := http.NewRequest(method, url, nil) + if err != nil { + return nil, err + } + auth := fmt.Sprintf("Bearer %s", channel.Key) + req.Header.Add("Authorization", auth) + res, err := client.Do(req) + if err != nil { + return nil, err + } + body, err := io.ReadAll(res.Body) + if err != nil { + return nil, err + } + err = res.Body.Close() + if err != nil { + return nil, err + } + return body, nil +} + +func getOpenaiSBCredit(channel *model.Channel) (float64, error) { + url := fmt.Sprintf("https://api.openai-sb.com/sb-api/user/status?api_key=%s", channel.Key) + body, err := GetRequestBody("GET", url, channel) + if err != nil { + return 0, err + } + subscription := OpenaiSbUsageResponse{} + err = json.Unmarshal(body, &subscription) + if err != nil { + return 0, err + } + + balance, err := strconv.ParseFloat(subscription.Data.Credit, 64) + if err != nil { + return 0, err + } + channel.UpdateBalance(balance) + return balance, nil +} + func updateChannelBalance(channel *model.Channel) (float64, error) { baseURL := common.ChannelBaseURLs[channel.Type] switch channel.Type { @@ -48,27 +97,14 @@ func updateChannelBalance(channel *model.Channel) (float64, error) { return 0, errors.New("尚未实现") case common.ChannelTypeCustom: baseURL = channel.BaseURL + case common.ChannelTypeOpenAISB: + return getOpenaiSBCredit(channel) default: return 0, errors.New("尚未实现") } url := fmt.Sprintf("%s/v1/dashboard/billing/subscription", baseURL) - client := &http.Client{} - req, err := http.NewRequest("GET", url, nil) - if err != nil { - return 0, err - } - auth := fmt.Sprintf("Bearer %s", channel.Key) - req.Header.Add("Authorization", auth) - res, err := client.Do(req) - if err != nil { - return 0, err - } - body, err := io.ReadAll(res.Body) - if err != nil { - return 0, err - } - err = res.Body.Close() + body, err := GetRequestBody("GET", url, channel) if err != nil { return 0, err } @@ -84,20 +120,7 @@ func updateChannelBalance(channel *model.Channel) (float64, error) { startDate = now.AddDate(0, 0, -100).Format("2006-01-02") } url = fmt.Sprintf("%s/v1/dashboard/billing/usage?start_date=%s&end_date=%s", baseURL, startDate, endDate) - req, err = http.NewRequest("GET", url, nil) - if err != nil { - return 0, err - } - req.Header.Add("Authorization", auth) - res, err = client.Do(req) - if err != nil { - return 0, err - } - body, err = io.ReadAll(res.Body) - if err != nil { - return 0, err - } - err = res.Body.Close() + body, err = GetRequestBody("GET", url, channel) if err != nil { return 0, err }