Add: support update AIProxy balance
This commit is contained in:
parent
57b213a035
commit
0fb110a479
@ -4,13 +4,14 @@ import (
|
|||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/gin-gonic/gin"
|
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
"one-api/model"
|
"one-api/model"
|
||||||
"strconv"
|
"strconv"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
|
|
||||||
// https://github.com/songquanpeng/one-api/issues/79
|
// https://github.com/songquanpeng/one-api/issues/79
|
||||||
@ -44,12 +45,31 @@ type OpenAISBUsageResponse struct {
|
|||||||
} `json:"data"`
|
} `json:"data"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetResponseBody(method, url string, channel *model.Channel) ([]byte, error) {
|
type AIProxyUserOverviewResponse struct {
|
||||||
|
Success bool `json:"success"`
|
||||||
|
Message string `json:"message"`
|
||||||
|
ErrorCode int `json:"error_code"`
|
||||||
|
Data struct {
|
||||||
|
TotalPoints float64 `json:"totalPoints"`
|
||||||
|
} `json:"data"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetAuthHeader get auth header
|
||||||
|
func GetAuthHeader(token string) http.Header {
|
||||||
|
h := http.Header{}
|
||||||
|
h.Add("Authorization", fmt.Sprintf("Bearer %s", token))
|
||||||
|
return h
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetResponseBody(method, url string, channel *model.Channel, headers http.Header) ([]byte, error) {
|
||||||
client := &http.Client{}
|
client := &http.Client{}
|
||||||
req, err := http.NewRequest(method, url, nil)
|
req, err := http.NewRequest(method, url, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
for k := range headers {
|
||||||
|
req.Header.Add(k, headers.Get(k))
|
||||||
|
}
|
||||||
auth := fmt.Sprintf("Bearer %s", channel.Key)
|
auth := fmt.Sprintf("Bearer %s", channel.Key)
|
||||||
req.Header.Add("Authorization", auth)
|
req.Header.Add("Authorization", auth)
|
||||||
res, err := client.Do(req)
|
res, err := client.Do(req)
|
||||||
@ -69,7 +89,7 @@ func GetResponseBody(method, url string, channel *model.Channel) ([]byte, error)
|
|||||||
|
|
||||||
func updateChannelOpenAISBBalance(channel *model.Channel) (float64, error) {
|
func updateChannelOpenAISBBalance(channel *model.Channel) (float64, error) {
|
||||||
url := fmt.Sprintf("https://api.openai-sb.com/sb-api/user/status?api_key=%s", channel.Key)
|
url := fmt.Sprintf("https://api.openai-sb.com/sb-api/user/status?api_key=%s", channel.Key)
|
||||||
body, err := GetResponseBody("GET", url, channel)
|
body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
@ -89,6 +109,26 @@ func updateChannelOpenAISBBalance(channel *model.Channel) (float64, error) {
|
|||||||
return balance, nil
|
return balance, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func updateChannelAIProxyBalance(channel *model.Channel) (float64, error) {
|
||||||
|
url := "https://aiproxy.io/api/report/getUserOverview"
|
||||||
|
headers := http.Header{}
|
||||||
|
headers.Add("Api-Key", channel.Key)
|
||||||
|
body, err := GetResponseBody("GET", url, channel, headers)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
response := AIProxyUserOverviewResponse{}
|
||||||
|
err = json.Unmarshal(body, &response)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
if !response.Success {
|
||||||
|
return 0, fmt.Errorf("code: %d, message: %s", response.ErrorCode, response.Message)
|
||||||
|
}
|
||||||
|
channel.UpdateBalance(response.Data.TotalPoints)
|
||||||
|
return response.Data.TotalPoints, nil
|
||||||
|
}
|
||||||
|
|
||||||
func updateChannelBalance(channel *model.Channel) (float64, error) {
|
func updateChannelBalance(channel *model.Channel) (float64, error) {
|
||||||
baseURL := common.ChannelBaseURLs[channel.Type]
|
baseURL := common.ChannelBaseURLs[channel.Type]
|
||||||
switch channel.Type {
|
switch channel.Type {
|
||||||
@ -102,12 +142,14 @@ func updateChannelBalance(channel *model.Channel) (float64, error) {
|
|||||||
baseURL = channel.BaseURL
|
baseURL = channel.BaseURL
|
||||||
case common.ChannelTypeOpenAISB:
|
case common.ChannelTypeOpenAISB:
|
||||||
return updateChannelOpenAISBBalance(channel)
|
return updateChannelOpenAISBBalance(channel)
|
||||||
|
case common.ChannelTypeAIProxy:
|
||||||
|
return updateChannelAIProxyBalance(channel)
|
||||||
default:
|
default:
|
||||||
return 0, errors.New("尚未实现")
|
return 0, errors.New("尚未实现")
|
||||||
}
|
}
|
||||||
url := fmt.Sprintf("%s/v1/dashboard/billing/subscription", baseURL)
|
url := fmt.Sprintf("%s/v1/dashboard/billing/subscription", baseURL)
|
||||||
|
|
||||||
body, err := GetResponseBody("GET", url, channel)
|
body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
@ -123,7 +165,7 @@ func updateChannelBalance(channel *model.Channel) (float64, error) {
|
|||||||
startDate = now.AddDate(0, 0, -100).Format("2006-01-02")
|
startDate = now.AddDate(0, 0, -100).Format("2006-01-02")
|
||||||
}
|
}
|
||||||
url = fmt.Sprintf("%s/v1/dashboard/billing/usage?start_date=%s&end_date=%s", baseURL, startDate, endDate)
|
url = fmt.Sprintf("%s/v1/dashboard/billing/usage?start_date=%s&end_date=%s", baseURL, startDate, endDate)
|
||||||
body, err = GetResponseBody("GET", url, channel)
|
body, err = GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user