feat: support channel remain quota query (close #79)

This commit is contained in:
JustSong 2023-05-21 16:09:54 +08:00
parent bcca0cc0bc
commit 171b818504
4 changed files with 244 additions and 19 deletions

View File

@ -0,0 +1,158 @@
package controller
import (
"encoding/json"
"errors"
"fmt"
"github.com/gin-gonic/gin"
"io"
"net/http"
"one-api/common"
"one-api/model"
"strconv"
"time"
)
type OpenAISubscriptionResponse struct {
HasPaymentMethod bool `json:"has_payment_method"`
HardLimitUSD float64 `json:"hard_limit_usd"`
}
type OpenAIUsageResponse struct {
TotalUsage float64 `json:"total_usage"` // unit: 0.01 dollar
}
func updateChannelBalance(channel *model.Channel) (float64, error) {
baseURL := common.ChannelBaseURLs[channel.Type]
switch channel.Type {
case common.ChannelTypeAzure:
return 0, errors.New("尚未实现")
}
url := fmt.Sprintf("%s/v1/dashboard/billing/subscription", baseURL)
client := &http.Client{}
req, err := http.NewRequest("GET", url, nil)
if err != nil {
return 0, err
}
auth := fmt.Sprintf("Bearer %s", channel.Key)
req.Header.Add("Authorization", auth)
res, err := client.Do(req)
if err != nil {
return 0, err
}
body, err := io.ReadAll(res.Body)
if err != nil {
return 0, err
}
err = res.Body.Close()
if err != nil {
return 0, err
}
subscription := OpenAISubscriptionResponse{}
err = json.Unmarshal(body, &subscription)
if err != nil {
return 0, err
}
now := time.Now()
startDate := fmt.Sprintf("%s-01", now.Format("2006-01"))
//endDate := now.Format("2006-01-02")
url = fmt.Sprintf("%s/v1/dashboard/billing/usage?start_date=%s&end_date=%s", baseURL, startDate, "2023-06-01")
req, err = http.NewRequest("GET", url, nil)
if err != nil {
return 0, err
}
req.Header.Add("Authorization", auth)
res, err = client.Do(req)
if err != nil {
return 0, err
}
body, err = io.ReadAll(res.Body)
if err != nil {
return 0, err
}
err = res.Body.Close()
if err != nil {
return 0, err
}
usage := OpenAIUsageResponse{}
err = json.Unmarshal(body, &usage)
if err != nil {
return 0, err
}
balance := subscription.HardLimitUSD - usage.TotalUsage/100
channel.UpdateBalance(balance)
return balance, nil
}
func UpdateChannelBalance(c *gin.Context) {
id, err := strconv.Atoi(c.Param("id"))
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
channel, err := model.GetChannelById(id, true)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
balance, err := updateChannelBalance(channel)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
"balance": balance,
})
return
}
func updateAllChannelsBalance() error {
channels, err := model.GetAllChannels(0, 0, true)
if err != nil {
return err
}
for _, channel := range channels {
if channel.Status != common.ChannelStatusEnabled {
continue
}
balance, err := updateChannelBalance(channel)
if err != nil {
continue
} else {
// err is nil & balance <= 0 means quota is used up
if balance <= 0 {
disableChannel(channel.Id, channel.Name, "余额不足")
}
}
}
return nil
}
func UpdateAllChannelsBalance(c *gin.Context) {
// TODO: make it async
err := updateAllChannelsBalance()
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
})
return
}

View File

@ -17,6 +17,8 @@ type Channel struct {
ResponseTime int `json:"response_time"` // in milliseconds ResponseTime int `json:"response_time"` // in milliseconds
BaseURL string `json:"base_url" gorm:"column:base_url"` BaseURL string `json:"base_url" gorm:"column:base_url"`
Other string `json:"other"` Other string `json:"other"`
Balance float64 `json:"balance"` // in USD
BalanceUpdatedTime int64 `json:"balance_updated_time" gorm:"bigint"`
} }
func GetAllChannels(startIdx int, num int, selectAll bool) ([]*Channel, error) { func GetAllChannels(startIdx int, num int, selectAll bool) ([]*Channel, error) {
@ -86,6 +88,16 @@ func (channel *Channel) UpdateResponseTime(responseTime int64) {
} }
} }
func (channel *Channel) UpdateBalance(balance float64) {
err := DB.Model(channel).Select("balance_updated_time", "balance").Updates(Channel{
BalanceUpdatedTime: common.GetTimestamp(),
Balance: balance,
}).Error
if err != nil {
common.SysError("failed to update balance: " + err.Error())
}
}
func (channel *Channel) Delete() error { func (channel *Channel) Delete() error {
var err error var err error
err = DB.Delete(channel).Error err = DB.Delete(channel).Error

View File

@ -66,6 +66,8 @@ func SetApiRouter(router *gin.Engine) {
channelRoute.GET("/:id", controller.GetChannel) channelRoute.GET("/:id", controller.GetChannel)
channelRoute.GET("/test", controller.TestAllChannels) channelRoute.GET("/test", controller.TestAllChannels)
channelRoute.GET("/test/:id", controller.TestChannel) channelRoute.GET("/test/:id", controller.TestChannel)
channelRoute.GET("/update_balance", controller.UpdateAllChannelsBalance)
channelRoute.GET("/update_balance/:id", controller.UpdateChannelBalance)
channelRoute.POST("/", controller.AddChannel) channelRoute.POST("/", controller.AddChannel)
channelRoute.PUT("/", controller.UpdateChannel) channelRoute.PUT("/", controller.UpdateChannel)
channelRoute.DELETE("/:id", controller.DeleteChannel) channelRoute.DELETE("/:id", controller.DeleteChannel)

View File

@ -32,6 +32,7 @@ const ChannelsTable = () => {
const [activePage, setActivePage] = useState(1); const [activePage, setActivePage] = useState(1);
const [searchKeyword, setSearchKeyword] = useState(''); const [searchKeyword, setSearchKeyword] = useState('');
const [searching, setSearching] = useState(false); const [searching, setSearching] = useState(false);
const [updatingBalance, setUpdatingBalance] = useState(false);
const loadChannels = async (startIdx) => { const loadChannels = async (startIdx) => {
const res = await API.get(`/api/channel/?p=${startIdx}`); const res = await API.get(`/api/channel/?p=${startIdx}`);
@ -63,7 +64,7 @@ const ChannelsTable = () => {
const refresh = async () => { const refresh = async () => {
setLoading(true); setLoading(true);
await loadChannels(0); await loadChannels(0);
} };
useEffect(() => { useEffect(() => {
loadChannels(0) loadChannels(0)
@ -127,7 +128,7 @@ const ChannelsTable = () => {
const renderResponseTime = (responseTime) => { const renderResponseTime = (responseTime) => {
let time = responseTime / 1000; let time = responseTime / 1000;
time = time.toFixed(2) + " 秒"; time = time.toFixed(2) + ' 秒';
if (responseTime === 0) { if (responseTime === 0) {
return <Label basic color='grey'>未测试</Label>; return <Label basic color='grey'>未测试</Label>;
} else if (responseTime <= 1000) { } else if (responseTime <= 1000) {
@ -179,11 +180,38 @@ const ChannelsTable = () => {
const res = await API.get(`/api/channel/test`); const res = await API.get(`/api/channel/test`);
const { success, message } = res.data; const { success, message } = res.data;
if (success) { if (success) {
showInfo("已成功开始测试所有已启用通道,请刷新页面查看结果。"); showInfo('已成功开始测试所有已启用通道,请刷新页面查看结果。');
} else { } else {
showError(message); showError(message);
} }
};
const updateChannelBalance = async (id, name, idx) => {
const res = await API.get(`/api/channel/update_balance/${id}/`);
const { success, message, balance } = res.data;
if (success) {
let newChannels = [...channels];
let realIdx = (activePage - 1) * ITEMS_PER_PAGE + idx;
newChannels[realIdx].balance = balance;
newChannels[realIdx].balance_updated_time = Date.now() / 1000;
setChannels(newChannels);
showInfo(`通道 ${name} 余额更新成功!`);
} else {
showError(message);
} }
};
const updateAllChannelsBalance = async () => {
setUpdatingBalance(true);
const res = await API.get(`/api/channel/update_balance`);
const { success, message } = res.data;
if (success) {
showInfo('已更新完毕所有已启用通道余额!');
} else {
showError(message);
}
setUpdatingBalance(false);
};
const handleKeywordChange = async (e, { value }) => { const handleKeywordChange = async (e, { value }) => {
setSearchKeyword(value.trim()); setSearchKeyword(value.trim());
@ -263,10 +291,10 @@ const ChannelsTable = () => {
<Table.HeaderCell <Table.HeaderCell
style={{ cursor: 'pointer' }} style={{ cursor: 'pointer' }}
onClick={() => { onClick={() => {
sortChannel('test_time'); sortChannel('balance');
}} }}
> >
测试时间 余额
</Table.HeaderCell> </Table.HeaderCell>
<Table.HeaderCell>操作</Table.HeaderCell> <Table.HeaderCell>操作</Table.HeaderCell>
</Table.Row> </Table.Row>
@ -286,8 +314,22 @@ const ChannelsTable = () => {
<Table.Cell>{channel.name ? channel.name : '无'}</Table.Cell> <Table.Cell>{channel.name ? channel.name : '无'}</Table.Cell>
<Table.Cell>{renderType(channel.type)}</Table.Cell> <Table.Cell>{renderType(channel.type)}</Table.Cell>
<Table.Cell>{renderStatus(channel.status)}</Table.Cell> <Table.Cell>{renderStatus(channel.status)}</Table.Cell>
<Table.Cell>{renderResponseTime(channel.response_time)}</Table.Cell> <Table.Cell>
<Table.Cell>{channel.test_time ? renderTimestamp(channel.test_time) : "未测试"}</Table.Cell> <Popup
content={channel.test_time ? renderTimestamp(channel.test_time) : '未测试'}
key={channel.id}
trigger={renderResponseTime(channel.response_time)}
basic
/>
</Table.Cell>
<Table.Cell>
<Popup
content={channel.balance_updated_time ? renderTimestamp(channel.balance_updated_time) : '未更新'}
key={channel.id}
trigger={<span>${channel.balance.toFixed(2)}</span>}
basic
/>
</Table.Cell>
<Table.Cell> <Table.Cell>
<div> <div>
<Button <Button
@ -299,6 +341,16 @@ const ChannelsTable = () => {
> >
测试 测试
</Button> </Button>
<Button
size={'small'}
positive
loading={updatingBalance}
onClick={() => {
updateChannelBalance(channel.id, channel.name, idx);
}}
>
更新余额
</Button>
<Popup <Popup
trigger={ trigger={
<Button size='small' negative> <Button size='small' negative>
@ -353,6 +405,7 @@ const ChannelsTable = () => {
<Button size='small' loading={loading} onClick={testAllChannels}> <Button size='small' loading={loading} onClick={testAllChannels}>
测试所有已启用通道 测试所有已启用通道
</Button> </Button>
<Button size='small' onClick={updateAllChannelsBalance} loading={updatingBalance}>更新所有已启用通道余额</Button>
<Pagination <Pagination
floated='right' floated='right'
activePage={activePage} activePage={activePage}