diff --git a/controller/billing.go b/controller/billing.go index 2ef2d99c..5f9de534 100644 --- a/controller/billing.go +++ b/controller/billing.go @@ -1,25 +1,27 @@ package controller import ( - "github.com/gin-gonic/gin" "one-api/common" "one-api/model" + + "github.com/gin-gonic/gin" ) func GetSubscription(c *gin.Context) { - var remainQuota int - var usedQuota int + var quota int var err error - var token *model.Token + var expirationDate int64 + + tokenId := c.GetInt("token_id") + token, err := model.GetTokenById(tokenId) + + expirationDate = token.ExpiredTime + if common.DisplayTokenStatEnabled { - tokenId := c.GetInt("token_id") - token, err = model.GetTokenById(tokenId) - remainQuota = token.RemainQuota - usedQuota = token.UsedQuota + quota = token.RemainQuota } else { userId := c.GetInt("id") - remainQuota, err = model.GetUserQuota(userId) - usedQuota, err = model.GetUserUsedQuota(userId) + quota, err = model.GetUserQuota(userId) } if err != nil { openAIError := OpenAIError{ @@ -31,7 +33,6 @@ func GetSubscription(c *gin.Context) { }) return } - quota := remainQuota + usedQuota amount := float64(quota) if common.DisplayInCurrencyEnabled { amount /= common.QuotaPerUnit @@ -45,6 +46,7 @@ func GetSubscription(c *gin.Context) { SoftLimitUSD: amount, HardLimitUSD: amount, SystemHardLimitUSD: amount, + AccessUntil: expirationDate, } c.JSON(200, subscription) return diff --git a/controller/channel-billing.go b/controller/channel-billing.go index 31c9a133..894af426 100644 --- a/controller/channel-billing.go +++ b/controller/channel-billing.go @@ -22,6 +22,7 @@ type OpenAISubscriptionResponse struct { SoftLimitUSD float64 `json:"soft_limit_usd"` HardLimitUSD float64 `json:"hard_limit_usd"` SystemHardLimitUSD float64 `json:"system_hard_limit_usd"` + AccessUntil int64 `json:"access_until"` } type OpenAIUsageDailyCost struct { @@ -96,6 +97,9 @@ func GetResponseBody(method, url string, channel *model.Channel, headers http.He if err != nil { return nil, err } + if res.StatusCode != http.StatusOK { + return nil, fmt.Errorf("status code: %d", res.StatusCode) + } body, err := io.ReadAll(res.Body) if err != nil { return nil, err