diff --git a/controller/channel-billing.go b/controller/channel-billing.go index b1926545..66a4338e 100644 --- a/controller/channel-billing.go +++ b/controller/channel-billing.go @@ -4,13 +4,14 @@ import ( "encoding/json" "errors" "fmt" - "github.com/gin-gonic/gin" "io" "net/http" "one-api/common" "one-api/model" "strconv" "time" + + "github.com/gin-gonic/gin" ) // https://github.com/songquanpeng/one-api/issues/79 @@ -44,12 +45,31 @@ type OpenAISBUsageResponse struct { } `json:"data"` } -func GetResponseBody(method, url string, channel *model.Channel) ([]byte, error) { +type AIProxyUserOverviewResponse struct { + Success bool `json:"success"` + Message string `json:"message"` + ErrorCode int `json:"error_code"` + Data struct { + TotalPoints float64 `json:"totalPoints"` + } `json:"data"` +} + +// GetAuthHeader get auth header +func GetAuthHeader(token string) http.Header { + h := http.Header{} + h.Add("Authorization", fmt.Sprintf("Bearer %s", token)) + return h +} + +func GetResponseBody(method, url string, channel *model.Channel, headers http.Header) ([]byte, error) { client := &http.Client{} req, err := http.NewRequest(method, url, nil) if err != nil { return nil, err } + for k := range headers { + req.Header.Add(k, headers.Get(k)) + } auth := fmt.Sprintf("Bearer %s", channel.Key) req.Header.Add("Authorization", auth) res, err := client.Do(req) @@ -69,7 +89,7 @@ func GetResponseBody(method, url string, channel *model.Channel) ([]byte, error) func updateChannelOpenAISBBalance(channel *model.Channel) (float64, error) { url := fmt.Sprintf("https://api.openai-sb.com/sb-api/user/status?api_key=%s", channel.Key) - body, err := GetResponseBody("GET", url, channel) + body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key)) if err != nil { return 0, err } @@ -89,6 +109,26 @@ func updateChannelOpenAISBBalance(channel *model.Channel) (float64, error) { return balance, nil } +func updateChannelAIProxyBalance(channel *model.Channel) (float64, error) { + url := "https://aiproxy.io/api/report/getUserOverview" + headers := http.Header{} + headers.Add("Api-Key", channel.Key) + body, err := GetResponseBody("GET", url, channel, headers) + if err != nil { + return 0, err + } + response := AIProxyUserOverviewResponse{} + err = json.Unmarshal(body, &response) + if err != nil { + return 0, err + } + if !response.Success { + return 0, fmt.Errorf("code: %d, message: %s", response.ErrorCode, response.Message) + } + channel.UpdateBalance(response.Data.TotalPoints) + return response.Data.TotalPoints, nil +} + func updateChannelBalance(channel *model.Channel) (float64, error) { baseURL := common.ChannelBaseURLs[channel.Type] switch channel.Type { @@ -102,12 +142,14 @@ func updateChannelBalance(channel *model.Channel) (float64, error) { baseURL = channel.BaseURL case common.ChannelTypeOpenAISB: return updateChannelOpenAISBBalance(channel) + case common.ChannelTypeAIProxy: + return updateChannelAIProxyBalance(channel) default: return 0, errors.New("尚未实现") } url := fmt.Sprintf("%s/v1/dashboard/billing/subscription", baseURL) - body, err := GetResponseBody("GET", url, channel) + body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key)) if err != nil { return 0, err } @@ -123,7 +165,7 @@ func updateChannelBalance(channel *model.Channel) (float64, error) { startDate = now.AddDate(0, 0, -100).Format("2006-01-02") } url = fmt.Sprintf("%s/v1/dashboard/billing/usage?start_date=%s&end_date=%s", baseURL, startDate, endDate) - body, err = GetResponseBody("GET", url, channel) + body, err = GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key)) if err != nil { return 0, err }