添加余额查询方法

This commit is contained in:
MartialBE 2023-12-02 17:51:28 +08:00
parent d8b13b2c07
commit 5e08cc8719
8 changed files with 116 additions and 5 deletions

View File

@ -0,0 +1,35 @@
package aiproxy
import (
"errors"
"fmt"
"one-api/common"
"one-api/model"
)
func (p *AIProxyProvider) Balance(channel *model.Channel) (float64, error) {
fullRequestURL := "https://aiproxy.io/api/report/getUserOverview"
headers := make(map[string]string)
headers["Api-Key"] = channel.Key
client := common.NewClient()
req, err := client.NewRequest("GET", fullRequestURL, common.WithHeader(headers))
if err != nil {
return 0, err
}
// 发送请求
var response AIProxyUserOverviewResponse
_, errWithCode := common.SendRequest(req, &response, false)
if errWithCode != nil {
return 0, errors.New(errWithCode.OpenAIError.Message)
}
if !response.Success {
return 0, fmt.Errorf("code: %d, message: %s", response.ErrorCode, response.Message)
}
channel.UpdateBalance(response.Data.TotalPoints)
return response.Data.TotalPoints, nil
}

18
providers/aiproxy/base.go Normal file
View File

@ -0,0 +1,18 @@
package aiproxy
import (
"one-api/providers/openai"
"github.com/gin-gonic/gin"
)
type AIProxyProvider struct {
*openai.OpenAIProvider
}
// 创建 CreateAIProxyProvider
func CreateAIProxyProvider(c *gin.Context) *AIProxyProvider {
return &AIProxyProvider{
OpenAIProvider: openai.CreateOpenAIProvider(c, "https://api.aiproxy.io"),
}
}

10
providers/aiproxy/type.go Normal file
View File

@ -0,0 +1,10 @@
package aiproxy
type AIProxyUserOverviewResponse struct {
Success bool `json:"success"`
Message string `json:"message"`
ErrorCode int `json:"error_code"`
Data struct {
TotalPoints float64 `json:"totalPoints"`
} `json:"data"`
}

View File

@ -0,0 +1,29 @@
package api2d
import (
"errors"
"one-api/common"
"one-api/model"
)
func (p *Api2dProvider) Balance(channel *model.Channel) (float64, error) {
fullRequestURL := "https://api.aigc2d.com/dashboard/billing/credit_grants"
headers := p.GetRequestHeaders()
client := common.NewClient()
req, err := client.NewRequest("GET", fullRequestURL, common.WithHeader(headers))
if err != nil {
return 0, err
}
// 发送请求
var response APGC2DGPTUsageResponse
_, errWithCode := common.SendRequest(req, &response, false)
if errWithCode != nil {
return 0, errors.New(errWithCode.OpenAIError.Message)
}
channel.UpdateBalance(response.TotalAvailable)
return response.TotalAvailable, nil
}

9
providers/api2d/type.go Normal file
View File

@ -0,0 +1,9 @@
package api2d
type APGC2DGPTUsageResponse struct {
//Grants interface{} `json:"grants"`
Object string `json:"object"`
TotalAvailable float64 `json:"total_available"`
TotalGranted float64 `json:"total_granted"`
TotalUsed float64 `json:"total_used"`
}

View File

@ -2,18 +2,16 @@ package closeai
import (
"errors"
"fmt"
"one-api/common"
"one-api/model"
)
func (p *CloseaiProxyProvider) Balance(channel *model.Channel) (float64, error) {
fullRequestURL := p.GetFullRequestURL("/sb-api/user/status", "")
fullRequestURL = fmt.Sprintf("%s?api_key=%s", fullRequestURL, channel.Key)
fullRequestURL := p.GetFullRequestURL("/dashboard/billing/credit_grants", "")
headers := p.GetRequestHeaders()
client := common.NewClient()
req, err := client.NewRequest("GET", fullRequestURL, common.WithBody(nil), common.WithHeader(headers))
req, err := client.NewRequest("GET", fullRequestURL, common.WithHeader(headers))
if err != nil {
return 0, err
}

View File

@ -14,7 +14,7 @@ func (p *OpenaiSBProvider) Balance(channel *model.Channel) (float64, error) {
headers := p.GetRequestHeaders()
client := common.NewClient()
req, err := client.NewRequest("GET", fullRequestURL, common.WithBody(nil), common.WithHeader(headers))
req, err := client.NewRequest("GET", fullRequestURL, common.WithHeader(headers))
if err != nil {
return 0, err
}

View File

@ -2,12 +2,16 @@ package providers
import (
"one-api/common"
"one-api/providers/aiproxy"
"one-api/providers/ali"
"one-api/providers/api2d"
"one-api/providers/azure"
"one-api/providers/baidu"
"one-api/providers/base"
"one-api/providers/claude"
"one-api/providers/closeai"
"one-api/providers/openai"
"one-api/providers/openaisb"
"one-api/providers/palm"
"one-api/providers/tencent"
"one-api/providers/xunfei"
@ -36,6 +40,14 @@ func GetProvider(channelType int, c *gin.Context) base.ProviderInterface {
return zhipu.CreateZhipuProvider(c)
case common.ChannelTypeXunfei:
return xunfei.CreateXunfeiProvider(c)
case common.ChannelTypeAIProxy:
return aiproxy.CreateAIProxyProvider(c)
case common.ChannelTypeAPI2D:
return api2d.CreateApi2dProvider(c)
case common.ChannelTypeCloseAI:
return closeai.CreateCloseaiProxyProvider(c)
case common.ChannelTypeOpenAISB:
return openaisb.CreateOpenaiSBProvider(c)
default:
baseURL := common.ChannelBaseURLs[channelType]
if c.GetString("base_url") != "" {