diff --git a/controller/channel-billing.go b/controller/channel-billing.go new file mode 100644 index 00000000..9166a964 --- /dev/null +++ b/controller/channel-billing.go @@ -0,0 +1,158 @@ +package controller + +import ( + "encoding/json" + "errors" + "fmt" + "github.com/gin-gonic/gin" + "io" + "net/http" + "one-api/common" + "one-api/model" + "strconv" + "time" +) + +type OpenAISubscriptionResponse struct { + HasPaymentMethod bool `json:"has_payment_method"` + HardLimitUSD float64 `json:"hard_limit_usd"` +} + +type OpenAIUsageResponse struct { + TotalUsage float64 `json:"total_usage"` // unit: 0.01 dollar +} + +func updateChannelBalance(channel *model.Channel) (float64, error) { + baseURL := common.ChannelBaseURLs[channel.Type] + switch channel.Type { + case common.ChannelTypeAzure: + return 0, errors.New("尚未实现") + } + url := fmt.Sprintf("%s/v1/dashboard/billing/subscription", baseURL) + + client := &http.Client{} + req, err := http.NewRequest("GET", url, nil) + if err != nil { + return 0, err + } + auth := fmt.Sprintf("Bearer %s", channel.Key) + req.Header.Add("Authorization", auth) + res, err := client.Do(req) + if err != nil { + return 0, err + } + body, err := io.ReadAll(res.Body) + if err != nil { + return 0, err + } + err = res.Body.Close() + if err != nil { + return 0, err + } + subscription := OpenAISubscriptionResponse{} + err = json.Unmarshal(body, &subscription) + if err != nil { + return 0, err + } + now := time.Now() + startDate := fmt.Sprintf("%s-01", now.Format("2006-01")) + //endDate := now.Format("2006-01-02") + url = fmt.Sprintf("%s/v1/dashboard/billing/usage?start_date=%s&end_date=%s", baseURL, startDate, "2023-06-01") + req, err = http.NewRequest("GET", url, nil) + if err != nil { + return 0, err + } + req.Header.Add("Authorization", auth) + res, err = client.Do(req) + if err != nil { + return 0, err + } + body, err = io.ReadAll(res.Body) + if err != nil { + return 0, err + } + err = res.Body.Close() + if err != nil { + return 0, err + } + usage := OpenAIUsageResponse{} + err = json.Unmarshal(body, &usage) + if err != nil { + return 0, err + } + balance := subscription.HardLimitUSD - usage.TotalUsage/100 + channel.UpdateBalance(balance) + return balance, nil +} + +func UpdateChannelBalance(c *gin.Context) { + id, err := strconv.Atoi(c.Param("id")) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + channel, err := model.GetChannelById(id, true) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + balance, err := updateChannelBalance(channel) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "", + "balance": balance, + }) + return +} + +func updateAllChannelsBalance() error { + channels, err := model.GetAllChannels(0, 0, true) + if err != nil { + return err + } + for _, channel := range channels { + if channel.Status != common.ChannelStatusEnabled { + continue + } + balance, err := updateChannelBalance(channel) + if err != nil { + continue + } else { + // err is nil & balance <= 0 means quota is used up + if balance <= 0 { + disableChannel(channel.Id, channel.Name, "余额不足") + } + } + } + return nil +} + +func UpdateAllChannelsBalance(c *gin.Context) { + // TODO: make it async + err := updateAllChannelsBalance() + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "", + }) + return +} diff --git a/model/channel.go b/model/channel.go index 0335207b..35d65827 100644 --- a/model/channel.go +++ b/model/channel.go @@ -6,17 +6,19 @@ import ( ) type Channel struct { - Id int `json:"id"` - Type int `json:"type" gorm:"default:0"` - Key string `json:"key" gorm:"not null"` - Status int `json:"status" gorm:"default:1"` - Name string `json:"name" gorm:"index"` - Weight int `json:"weight"` - CreatedTime int64 `json:"created_time" gorm:"bigint"` - TestTime int64 `json:"test_time" gorm:"bigint"` - ResponseTime int `json:"response_time"` // in milliseconds - BaseURL string `json:"base_url" gorm:"column:base_url"` - Other string `json:"other"` + Id int `json:"id"` + Type int `json:"type" gorm:"default:0"` + Key string `json:"key" gorm:"not null"` + Status int `json:"status" gorm:"default:1"` + Name string `json:"name" gorm:"index"` + Weight int `json:"weight"` + CreatedTime int64 `json:"created_time" gorm:"bigint"` + TestTime int64 `json:"test_time" gorm:"bigint"` + ResponseTime int `json:"response_time"` // in milliseconds + BaseURL string `json:"base_url" gorm:"column:base_url"` + Other string `json:"other"` + Balance float64 `json:"balance"` // in USD + BalanceUpdatedTime int64 `json:"balance_updated_time" gorm:"bigint"` } func GetAllChannels(startIdx int, num int, selectAll bool) ([]*Channel, error) { @@ -86,6 +88,16 @@ func (channel *Channel) UpdateResponseTime(responseTime int64) { } } +func (channel *Channel) UpdateBalance(balance float64) { + err := DB.Model(channel).Select("balance_updated_time", "balance").Updates(Channel{ + BalanceUpdatedTime: common.GetTimestamp(), + Balance: balance, + }).Error + if err != nil { + common.SysError("failed to update balance: " + err.Error()) + } +} + func (channel *Channel) Delete() error { var err error err = DB.Delete(channel).Error diff --git a/router/api-router.go b/router/api-router.go index 5cd86e3e..9ca2226a 100644 --- a/router/api-router.go +++ b/router/api-router.go @@ -66,6 +66,8 @@ func SetApiRouter(router *gin.Engine) { channelRoute.GET("/:id", controller.GetChannel) channelRoute.GET("/test", controller.TestAllChannels) channelRoute.GET("/test/:id", controller.TestChannel) + channelRoute.GET("/update_balance", controller.UpdateAllChannelsBalance) + channelRoute.GET("/update_balance/:id", controller.UpdateChannelBalance) channelRoute.POST("/", controller.AddChannel) channelRoute.PUT("/", controller.UpdateChannel) channelRoute.DELETE("/:id", controller.DeleteChannel) diff --git a/web/src/components/ChannelsTable.js b/web/src/components/ChannelsTable.js index f0f33e96..25183621 100644 --- a/web/src/components/ChannelsTable.js +++ b/web/src/components/ChannelsTable.js @@ -32,6 +32,7 @@ const ChannelsTable = () => { const [activePage, setActivePage] = useState(1); const [searchKeyword, setSearchKeyword] = useState(''); const [searching, setSearching] = useState(false); + const [updatingBalance, setUpdatingBalance] = useState(false); const loadChannels = async (startIdx) => { const res = await API.get(`/api/channel/?p=${startIdx}`); @@ -63,7 +64,7 @@ const ChannelsTable = () => { const refresh = async () => { setLoading(true); await loadChannels(0); - } + }; useEffect(() => { loadChannels(0) @@ -127,7 +128,7 @@ const ChannelsTable = () => { const renderResponseTime = (responseTime) => { let time = responseTime / 1000; - time = time.toFixed(2) + " 秒"; + time = time.toFixed(2) + ' 秒'; if (responseTime === 0) { return ; } else if (responseTime <= 1000) { @@ -179,11 +180,38 @@ const ChannelsTable = () => { const res = await API.get(`/api/channel/test`); const { success, message } = res.data; if (success) { - showInfo("已成功开始测试所有已启用通道,请刷新页面查看结果。"); + showInfo('已成功开始测试所有已启用通道,请刷新页面查看结果。'); } else { showError(message); } - } + }; + + const updateChannelBalance = async (id, name, idx) => { + const res = await API.get(`/api/channel/update_balance/${id}/`); + const { success, message, balance } = res.data; + if (success) { + let newChannels = [...channels]; + let realIdx = (activePage - 1) * ITEMS_PER_PAGE + idx; + newChannels[realIdx].balance = balance; + newChannels[realIdx].balance_updated_time = Date.now() / 1000; + setChannels(newChannels); + showInfo(`通道 ${name} 余额更新成功!`); + } else { + showError(message); + } + }; + + const updateAllChannelsBalance = async () => { + setUpdatingBalance(true); + const res = await API.get(`/api/channel/update_balance`); + const { success, message } = res.data; + if (success) { + showInfo('已更新完毕所有已启用通道余额!'); + } else { + showError(message); + } + setUpdatingBalance(false); + }; const handleKeywordChange = async (e, { value }) => { setSearchKeyword(value.trim()); @@ -263,10 +291,10 @@ const ChannelsTable = () => {