Add support for updating channel balance in OpenAISB

This commit is contained in:
quzard 2023-06-11 18:19:16 +08:00
parent 39481eb6c0
commit 146dc840e6

View File

@ -37,6 +37,55 @@ type OpenAIUsageResponse struct {
TotalUsage float64 `json:"total_usage"` // unit: 0.01 dollar TotalUsage float64 `json:"total_usage"` // unit: 0.01 dollar
} }
type OpenaiSbUsageResponse struct {
Data struct {
Credit string `json:"credit"`
} `json:"data"`
}
func GetRequestBody(method, url string, channel *model.Channel) ([]byte, error) {
client := &http.Client{}
req, err := http.NewRequest(method, url, nil)
if err != nil {
return nil, err
}
auth := fmt.Sprintf("Bearer %s", channel.Key)
req.Header.Add("Authorization", auth)
res, err := client.Do(req)
if err != nil {
return nil, err
}
body, err := io.ReadAll(res.Body)
if err != nil {
return nil, err
}
err = res.Body.Close()
if err != nil {
return nil, err
}
return body, nil
}
func getOpenaiSBCredit(channel *model.Channel) (float64, error) {
url := fmt.Sprintf("https://api.openai-sb.com/sb-api/user/status?api_key=%s", channel.Key)
body, err := GetRequestBody("GET", url, channel)
if err != nil {
return 0, err
}
subscription := OpenaiSbUsageResponse{}
err = json.Unmarshal(body, &subscription)
if err != nil {
return 0, err
}
balance, err := strconv.ParseFloat(subscription.Data.Credit, 64)
if err != nil {
return 0, err
}
channel.UpdateBalance(balance)
return balance, nil
}
func updateChannelBalance(channel *model.Channel) (float64, error) { func updateChannelBalance(channel *model.Channel) (float64, error) {
baseURL := common.ChannelBaseURLs[channel.Type] baseURL := common.ChannelBaseURLs[channel.Type]
switch channel.Type { switch channel.Type {
@ -48,27 +97,14 @@ func updateChannelBalance(channel *model.Channel) (float64, error) {
return 0, errors.New("尚未实现") return 0, errors.New("尚未实现")
case common.ChannelTypeCustom: case common.ChannelTypeCustom:
baseURL = channel.BaseURL baseURL = channel.BaseURL
case common.ChannelTypeOpenAISB:
return getOpenaiSBCredit(channel)
default: default:
return 0, errors.New("尚未实现") return 0, errors.New("尚未实现")
} }
url := fmt.Sprintf("%s/v1/dashboard/billing/subscription", baseURL) url := fmt.Sprintf("%s/v1/dashboard/billing/subscription", baseURL)
client := &http.Client{} body, err := GetRequestBody("GET", url, channel)
req, err := http.NewRequest("GET", url, nil)
if err != nil {
return 0, err
}
auth := fmt.Sprintf("Bearer %s", channel.Key)
req.Header.Add("Authorization", auth)
res, err := client.Do(req)
if err != nil {
return 0, err
}
body, err := io.ReadAll(res.Body)
if err != nil {
return 0, err
}
err = res.Body.Close()
if err != nil { if err != nil {
return 0, err return 0, err
} }
@ -84,20 +120,7 @@ func updateChannelBalance(channel *model.Channel) (float64, error) {
startDate = now.AddDate(0, 0, -100).Format("2006-01-02") startDate = now.AddDate(0, 0, -100).Format("2006-01-02")
} }
url = fmt.Sprintf("%s/v1/dashboard/billing/usage?start_date=%s&end_date=%s", baseURL, startDate, endDate) url = fmt.Sprintf("%s/v1/dashboard/billing/usage?start_date=%s&end_date=%s", baseURL, startDate, endDate)
req, err = http.NewRequest("GET", url, nil) body, err = GetRequestBody("GET", url, channel)
if err != nil {
return 0, err
}
req.Header.Add("Authorization", auth)
res, err = client.Do(req)
if err != nil {
return 0, err
}
body, err = io.ReadAll(res.Body)
if err != nil {
return 0, err
}
err = res.Body.Close()
if err != nil { if err != nil {
return 0, err return 0, err
} }