349 lines
9.0 KiB
Go
349 lines
9.0 KiB
Go
package controller
|
|
|
|
import (
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"github.com/songquanpeng/one-api/common"
|
|
"github.com/songquanpeng/one-api/common/config"
|
|
"github.com/songquanpeng/one-api/common/logger"
|
|
"github.com/songquanpeng/one-api/model"
|
|
"github.com/songquanpeng/one-api/relay/util"
|
|
"io"
|
|
"net/http"
|
|
"strconv"
|
|
"time"
|
|
|
|
"github.com/gin-gonic/gin"
|
|
)
|
|
|
|
// https://github.com/songquanpeng/one-api/issues/79
|
|
|
|
// OpenAISubscriptionResponse is the JSON document that updateChannelBalance
// decodes from the /v1/dashboard/billing/subscription endpoint.
type OpenAISubscriptionResponse struct {
	Object string `json:"object"`
	// HasPaymentMethod selects the usage window queried by
	// updateChannelBalance: current month when true, last 100 days otherwise.
	HasPaymentMethod bool    `json:"has_payment_method"`
	SoftLimitUSD     float64 `json:"soft_limit_usd"`
	// HardLimitUSD is the limit from which the consumed usage is subtracted
	// to obtain the remaining balance.
	HardLimitUSD       float64 `json:"hard_limit_usd"`
	SystemHardLimitUSD float64 `json:"system_hard_limit_usd"`
	AccessUntil        int64   `json:"access_until"`
}
|
|
|
|
// OpenAIUsageDailyCost describes one day of the usage breakdown. It is the
// element type of the DailyCosts field that is currently commented out in
// OpenAIUsageResponse.
type OpenAIUsageDailyCost struct {
	Timestamp float64 `json:"timestamp"`
	// LineItems carries an explicit snake_case tag like every other field in
	// this file; without it encoding/json would only match a "LineItems" key
	// and a "line_items" array would be silently dropped.
	LineItems []struct {
		Name string  `json:"name"`
		Cost float64 `json:"cost"`
	} `json:"line_items"`
}
|
|
|
|
// OpenAICreditGrants is the JSON document decoded from a
// /dashboard/billing/credit_grants endpoint by updateChannelCloseAIBalance.
type OpenAICreditGrants struct {
	Object       string  `json:"object"`
	TotalGranted float64 `json:"total_granted"`
	TotalUsed    float64 `json:"total_used"`
	// TotalAvailable is the value reported as the channel balance.
	TotalAvailable float64 `json:"total_available"`
}
|
|
|
|
// OpenAIUsageResponse is the JSON document that updateChannelBalance decodes
// from the /v1/dashboard/billing/usage endpoint.
type OpenAIUsageResponse struct {
	Object string `json:"object"`
	//DailyCosts []OpenAIUsageDailyCost `json:"daily_costs"`

	// TotalUsage is expressed in units of 0.01 dollar; callers divide by 100
	// to obtain dollars.
	TotalUsage float64 `json:"total_usage"`
}
|
|
|
|
// OpenAISBUsageResponse is the JSON document decoded by
// updateChannelOpenAISBBalance from the OpenAI-SB user status endpoint.
type OpenAISBUsageResponse struct {
	// Msg is returned to the caller as the error text when Data is absent.
	Msg string `json:"msg"`
	// Data stays nil when the reply carries no "data" object.
	Data *struct {
		// Credit is the balance as a decimal string; it is parsed with
		// strconv.ParseFloat by the caller.
		Credit string `json:"credit"`
	} `json:"data"`
}
|
|
|
|
// AIProxyUserOverviewResponse is the JSON document decoded by
// updateChannelAIProxyBalance from the AIProxy getUserOverview endpoint.
type AIProxyUserOverviewResponse struct {
	// Success false makes the caller report ErrorCode and Message as an error.
	Success   bool   `json:"success"`
	Message   string `json:"message"`
	ErrorCode int    `json:"error_code"`
	Data      struct {
		// TotalPoints is the value reported as the channel balance.
		TotalPoints float64 `json:"totalPoints"`
	} `json:"data"`
}
|
|
|
|
// API2GPTUsageResponse is the JSON document decoded by
// updateChannelAPI2GPTBalance from the API2GPT credit_grants endpoint.
type API2GPTUsageResponse struct {
	Object       string  `json:"object"`
	TotalGranted float64 `json:"total_granted"`
	TotalUsed    float64 `json:"total_used"`
	// TotalRemaining is the value reported as the channel balance.
	TotalRemaining float64 `json:"total_remaining"`
}
|
|
|
|
// APGC2DGPTUsageResponse is the JSON document decoded by
// updateChannelAIGC2DBalance from the AIGC2D credit_grants endpoint.
type APGC2DGPTUsageResponse struct {
	//Grants interface{} `json:"grants"`

	Object string `json:"object"`
	// TotalAvailable is the value reported as the channel balance.
	TotalAvailable float64 `json:"total_available"`
	TotalGranted   float64 `json:"total_granted"`
	TotalUsed      float64 `json:"total_used"`
}
|
|
|
|
// GetAuthHeader builds a header set whose only entry is an Authorization
// header carrying token as a Bearer credential.
func GetAuthHeader(token string) http.Header {
	return http.Header{
		"Authorization": []string{fmt.Sprintf("Bearer %s", token)},
	}
}
|
|
|
|
func GetResponseBody(method, url string, channel *model.Channel, headers http.Header) ([]byte, error) {
|
|
req, err := http.NewRequest(method, url, nil)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
for k := range headers {
|
|
req.Header.Add(k, headers.Get(k))
|
|
}
|
|
res, err := util.HTTPClient.Do(req)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if res.StatusCode != http.StatusOK {
|
|
return nil, fmt.Errorf("status code: %d", res.StatusCode)
|
|
}
|
|
body, err := io.ReadAll(res.Body)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
err = res.Body.Close()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return body, nil
|
|
}
|
|
|
|
func updateChannelCloseAIBalance(channel *model.Channel) (float64, error) {
|
|
url := fmt.Sprintf("%s/dashboard/billing/credit_grants", channel.GetBaseURL())
|
|
body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key))
|
|
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
response := OpenAICreditGrants{}
|
|
err = json.Unmarshal(body, &response)
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
channel.UpdateBalance(response.TotalAvailable)
|
|
return response.TotalAvailable, nil
|
|
}
|
|
|
|
func updateChannelOpenAISBBalance(channel *model.Channel) (float64, error) {
|
|
url := fmt.Sprintf("https://api.openai-sb.com/sb-api/user/status?api_key=%s", channel.Key)
|
|
body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key))
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
response := OpenAISBUsageResponse{}
|
|
err = json.Unmarshal(body, &response)
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
if response.Data == nil {
|
|
return 0, errors.New(response.Msg)
|
|
}
|
|
balance, err := strconv.ParseFloat(response.Data.Credit, 64)
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
channel.UpdateBalance(balance)
|
|
return balance, nil
|
|
}
|
|
|
|
func updateChannelAIProxyBalance(channel *model.Channel) (float64, error) {
|
|
url := "https://aiproxy.io/api/report/getUserOverview"
|
|
headers := http.Header{}
|
|
headers.Add("Api-Key", channel.Key)
|
|
body, err := GetResponseBody("GET", url, channel, headers)
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
response := AIProxyUserOverviewResponse{}
|
|
err = json.Unmarshal(body, &response)
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
if !response.Success {
|
|
return 0, fmt.Errorf("code: %d, message: %s", response.ErrorCode, response.Message)
|
|
}
|
|
channel.UpdateBalance(response.Data.TotalPoints)
|
|
return response.Data.TotalPoints, nil
|
|
}
|
|
|
|
func updateChannelAPI2GPTBalance(channel *model.Channel) (float64, error) {
|
|
url := "https://api.api2gpt.com/dashboard/billing/credit_grants"
|
|
body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key))
|
|
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
response := API2GPTUsageResponse{}
|
|
err = json.Unmarshal(body, &response)
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
channel.UpdateBalance(response.TotalRemaining)
|
|
return response.TotalRemaining, nil
|
|
}
|
|
|
|
func updateChannelAIGC2DBalance(channel *model.Channel) (float64, error) {
|
|
url := "https://api.aigc2d.com/dashboard/billing/credit_grants"
|
|
body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key))
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
response := APGC2DGPTUsageResponse{}
|
|
err = json.Unmarshal(body, &response)
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
channel.UpdateBalance(response.TotalAvailable)
|
|
return response.TotalAvailable, nil
|
|
}
|
|
|
|
func updateChannelBalance(channel *model.Channel) (float64, error) {
|
|
baseURL := common.ChannelBaseURLs[channel.Type]
|
|
if channel.GetBaseURL() == "" {
|
|
channel.BaseURL = &baseURL
|
|
}
|
|
switch channel.Type {
|
|
case common.ChannelTypeOpenAI:
|
|
if channel.GetBaseURL() != "" {
|
|
baseURL = channel.GetBaseURL()
|
|
}
|
|
case common.ChannelTypeAzure:
|
|
return 0, errors.New("尚未实现")
|
|
case common.ChannelTypeCustom:
|
|
baseURL = channel.GetBaseURL()
|
|
case common.ChannelTypeCloseAI:
|
|
return updateChannelCloseAIBalance(channel)
|
|
case common.ChannelTypeOpenAISB:
|
|
return updateChannelOpenAISBBalance(channel)
|
|
case common.ChannelTypeAIProxy:
|
|
return updateChannelAIProxyBalance(channel)
|
|
case common.ChannelTypeAPI2GPT:
|
|
return updateChannelAPI2GPTBalance(channel)
|
|
case common.ChannelTypeAIGC2D:
|
|
return updateChannelAIGC2DBalance(channel)
|
|
default:
|
|
return 0, errors.New("尚未实现")
|
|
}
|
|
url := fmt.Sprintf("%s/v1/dashboard/billing/subscription", baseURL)
|
|
|
|
body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key))
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
subscription := OpenAISubscriptionResponse{}
|
|
err = json.Unmarshal(body, &subscription)
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
now := time.Now()
|
|
startDate := fmt.Sprintf("%s-01", now.Format("2006-01"))
|
|
endDate := now.Format("2006-01-02")
|
|
if !subscription.HasPaymentMethod {
|
|
startDate = now.AddDate(0, 0, -100).Format("2006-01-02")
|
|
}
|
|
url = fmt.Sprintf("%s/v1/dashboard/billing/usage?start_date=%s&end_date=%s", baseURL, startDate, endDate)
|
|
body, err = GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key))
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
usage := OpenAIUsageResponse{}
|
|
err = json.Unmarshal(body, &usage)
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
balance := subscription.HardLimitUSD - usage.TotalUsage/100
|
|
channel.UpdateBalance(balance)
|
|
return balance, nil
|
|
}
|
|
|
|
func UpdateChannelBalance(c *gin.Context) {
|
|
id, err := strconv.Atoi(c.Param("id"))
|
|
if err != nil {
|
|
c.JSON(http.StatusOK, gin.H{
|
|
"success": false,
|
|
"message": err.Error(),
|
|
})
|
|
return
|
|
}
|
|
channel, err := model.GetChannelById(id, true)
|
|
if err != nil {
|
|
c.JSON(http.StatusOK, gin.H{
|
|
"success": false,
|
|
"message": err.Error(),
|
|
})
|
|
return
|
|
}
|
|
balance, err := updateChannelBalance(channel)
|
|
if err != nil {
|
|
c.JSON(http.StatusOK, gin.H{
|
|
"success": false,
|
|
"message": err.Error(),
|
|
})
|
|
return
|
|
}
|
|
c.JSON(http.StatusOK, gin.H{
|
|
"success": true,
|
|
"message": "",
|
|
"balance": balance,
|
|
})
|
|
return
|
|
}
|
|
|
|
func updateAllChannelsBalance() error {
|
|
channels, err := model.GetAllChannels(0, 0, true)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
for _, channel := range channels {
|
|
if channel.Status != common.ChannelStatusEnabled {
|
|
continue
|
|
}
|
|
// TODO: support Azure
|
|
if channel.Type != common.ChannelTypeOpenAI && channel.Type != common.ChannelTypeCustom {
|
|
continue
|
|
}
|
|
balance, err := updateChannelBalance(channel)
|
|
if err != nil {
|
|
continue
|
|
} else {
|
|
// err is nil & balance <= 0 means quota is used up
|
|
if balance <= 0 {
|
|
disableChannel(channel.Id, channel.Name, "余额不足")
|
|
}
|
|
}
|
|
time.Sleep(config.RequestInterval)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func UpdateAllChannelsBalance(c *gin.Context) {
|
|
// TODO: make it async
|
|
err := updateAllChannelsBalance()
|
|
if err != nil {
|
|
c.JSON(http.StatusOK, gin.H{
|
|
"success": false,
|
|
"message": err.Error(),
|
|
})
|
|
return
|
|
}
|
|
c.JSON(http.StatusOK, gin.H{
|
|
"success": true,
|
|
"message": "",
|
|
})
|
|
return
|
|
}
|
|
|
|
func AutomaticallyUpdateChannels(frequency int) {
|
|
for {
|
|
time.Sleep(time.Duration(frequency) * time.Minute)
|
|
logger.SysLog("updating all channels")
|
|
_ = updateAllChannelsBalance()
|
|
logger.SysLog("channels update done")
|
|
}
|
|
}
|