🐛 fix: 修复余额的问题
This commit is contained in:
parent
58fc40a744
commit
c97c8a0f65
3
.gitignore
vendored
3
.gitignore
vendored
@ -7,4 +7,5 @@ build
|
|||||||
*.db-journal
|
*.db-journal
|
||||||
logs
|
logs
|
||||||
data
|
data
|
||||||
tmp/
|
tmp/
|
||||||
|
.env
|
@ -6,6 +6,8 @@ import (
|
|||||||
"log"
|
"log"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
|
|
||||||
|
"github.com/joho/godotenv"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
@ -23,6 +25,11 @@ func printHelp() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
|
// 加载.env文件
|
||||||
|
err := godotenv.Load()
|
||||||
|
if err != nil {
|
||||||
|
SysLog("failed to load .env file: " + err.Error())
|
||||||
|
}
|
||||||
flag.Parse()
|
flag.Parse()
|
||||||
|
|
||||||
if *PrintVersion {
|
if *PrintVersion {
|
||||||
|
@ -3,6 +3,7 @@ package controller
|
|||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
"one-api/model"
|
"one-api/model"
|
||||||
"one-api/providers"
|
"one-api/providers"
|
||||||
@ -46,7 +47,18 @@ type OpenAIUsageResponse struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func updateChannelBalance(channel *model.Channel) (float64, error) {
|
func updateChannelBalance(channel *model.Channel) (float64, error) {
|
||||||
provider := providers.GetProvider(channel.Type, nil)
|
req, err := http.NewRequest("POST", "/balance", nil)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(w)
|
||||||
|
c.Request = req
|
||||||
|
|
||||||
|
setChannelToContext(c, channel)
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
|
||||||
|
provider := providers.GetProvider(channel.Type, c)
|
||||||
if provider == nil {
|
if provider == nil {
|
||||||
return 0, errors.New("provider not found")
|
return 0, errors.New("provider not found")
|
||||||
}
|
}
|
||||||
@ -56,7 +68,7 @@ func updateChannelBalance(channel *model.Channel) (float64, error) {
|
|||||||
return 0, errors.New("provider not implemented")
|
return 0, errors.New("provider not implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
return balanceProvider.BalanceAction(channel)
|
return balanceProvider.Balance(channel)
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -30,32 +30,26 @@ func testChannel(channel *model.Channel, request types.ChatCompletionRequest) (e
|
|||||||
c.Request = req
|
c.Request = req
|
||||||
|
|
||||||
setChannelToContext(c, channel)
|
setChannelToContext(c, channel)
|
||||||
|
// 创建映射
|
||||||
switch channel.Type {
|
channelTypeToModel := map[int]string{
|
||||||
case common.ChannelTypePaLM:
|
common.ChannelTypePaLM: "PaLM-2",
|
||||||
request.Model = "PaLM-2"
|
common.ChannelTypeAnthropic: "claude-2",
|
||||||
case common.ChannelTypeAnthropic:
|
common.ChannelTypeBaidu: "ERNIE-Bot",
|
||||||
request.Model = "claude-2"
|
common.ChannelTypeZhipu: "chatglm_lite",
|
||||||
case common.ChannelTypeBaidu:
|
common.ChannelTypeAli: "qwen-turbo",
|
||||||
request.Model = "ERNIE-Bot"
|
common.ChannelType360: "360GPT_S2_V9",
|
||||||
case common.ChannelTypeZhipu:
|
common.ChannelTypeXunfei: "SparkDesk",
|
||||||
request.Model = "chatglm_lite"
|
common.ChannelTypeTencent: "hunyuan",
|
||||||
case common.ChannelTypeAli:
|
common.ChannelTypeAzure: "gpt-3.5-turbo",
|
||||||
request.Model = "qwen-turbo"
|
|
||||||
case common.ChannelType360:
|
|
||||||
request.Model = "360GPT_S2_V9"
|
|
||||||
case common.ChannelTypeXunfei:
|
|
||||||
request.Model = "SparkDesk"
|
|
||||||
c.Set("api_version", channel.Other)
|
|
||||||
case common.ChannelTypeTencent:
|
|
||||||
request.Model = "hunyuan"
|
|
||||||
case common.ChannelTypeAzure:
|
|
||||||
request.Model = "gpt-3.5-turbo"
|
|
||||||
c.Set("api_version", channel.Other)
|
|
||||||
default:
|
|
||||||
request.Model = "gpt-3.5-turbo"
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 从映射中获取模型名称
|
||||||
|
model, ok := channelTypeToModel[channel.Type]
|
||||||
|
if !ok {
|
||||||
|
model = "gpt-3.5-turbo" // 默认值
|
||||||
|
}
|
||||||
|
request.Model = model
|
||||||
|
|
||||||
provider := providers.GetProvider(channel.Type, c)
|
provider := providers.GetProvider(channel.Type, c)
|
||||||
if provider == nil {
|
if provider == nil {
|
||||||
return errors.New("channel not implemented"), nil
|
return errors.New("channel not implemented"), nil
|
||||||
@ -65,18 +59,16 @@ func testChannel(channel *model.Channel, request types.ChatCompletionRequest) (e
|
|||||||
return errors.New("channel not implemented"), nil
|
return errors.New("channel not implemented"), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
isModelMapped := false
|
|
||||||
modelMap, err := parseModelMapping(channel.GetModelMapping())
|
modelMap, err := parseModelMapping(channel.GetModelMapping())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err, nil
|
return err, nil
|
||||||
}
|
}
|
||||||
if modelMap != nil && modelMap[request.Model] != "" {
|
if modelMap != nil && modelMap[request.Model] != "" {
|
||||||
request.Model = modelMap[request.Model]
|
request.Model = modelMap[request.Model]
|
||||||
isModelMapped = true
|
|
||||||
}
|
}
|
||||||
|
|
||||||
promptTokens := common.CountTokenMessages(request.Messages, request.Model)
|
promptTokens := common.CountTokenMessages(request.Messages, request.Model)
|
||||||
_, openAIErrorWithStatusCode := chatProvider.ChatAction(&request, isModelMapped, promptTokens)
|
_, openAIErrorWithStatusCode := chatProvider.ChatAction(&request, true, promptTokens)
|
||||||
if openAIErrorWithStatusCode != nil {
|
if openAIErrorWithStatusCode != nil {
|
||||||
return nil, &openAIErrorWithStatusCode.OpenAIError
|
return nil, &openAIErrorWithStatusCode.OpenAIError
|
||||||
}
|
}
|
||||||
|
1
go.mod
1
go.mod
@ -42,6 +42,7 @@ require (
|
|||||||
github.com/jackc/pgx/v5 v5.3.1 // indirect
|
github.com/jackc/pgx/v5 v5.3.1 // indirect
|
||||||
github.com/jinzhu/inflection v1.0.0 // indirect
|
github.com/jinzhu/inflection v1.0.0 // indirect
|
||||||
github.com/jinzhu/now v1.1.5 // indirect
|
github.com/jinzhu/now v1.1.5 // indirect
|
||||||
|
github.com/joho/godotenv v1.5.1 // indirect
|
||||||
github.com/json-iterator/go v1.1.12 // indirect
|
github.com/json-iterator/go v1.1.12 // indirect
|
||||||
github.com/klauspost/cpuid/v2 v2.2.4 // indirect
|
github.com/klauspost/cpuid/v2 v2.2.4 // indirect
|
||||||
github.com/leodido/go-urn v1.2.4 // indirect
|
github.com/leodido/go-urn v1.2.4 // indirect
|
||||||
|
2
go.sum
2
go.sum
@ -80,6 +80,8 @@ github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkr
|
|||||||
github.com/jinzhu/now v1.1.4/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8=
|
github.com/jinzhu/now v1.1.4/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8=
|
||||||
github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ=
|
github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ=
|
||||||
github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8=
|
github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8=
|
||||||
|
github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0=
|
||||||
|
github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4=
|
||||||
github.com/json-iterator/go v1.1.9/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4=
|
github.com/json-iterator/go v1.1.9/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4=
|
||||||
github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=
|
github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=
|
||||||
github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo=
|
github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo=
|
||||||
|
30
providers/aigc2d/balance.go
Normal file
30
providers/aigc2d/balance.go
Normal file
@ -0,0 +1,30 @@
|
|||||||
|
package aigc2d
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"one-api/common"
|
||||||
|
"one-api/model"
|
||||||
|
"one-api/providers/base"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (p *Aigc2dProvider) Balance(channel *model.Channel) (float64, error) {
|
||||||
|
fullRequestURL := p.GetFullRequestURL("/dashboard/billing/credit_grants", "")
|
||||||
|
headers := p.GetRequestHeaders()
|
||||||
|
|
||||||
|
client := common.NewClient()
|
||||||
|
req, err := client.NewRequest("GET", fullRequestURL, common.WithHeader(headers))
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// 发送请求
|
||||||
|
var response base.BalanceResponse
|
||||||
|
_, errWithCode := common.SendRequest(req, &response, false)
|
||||||
|
if errWithCode != nil {
|
||||||
|
return 0, errors.New(errWithCode.OpenAIError.Message)
|
||||||
|
}
|
||||||
|
|
||||||
|
channel.UpdateBalance(response.TotalAvailable)
|
||||||
|
|
||||||
|
return response.TotalAvailable, nil
|
||||||
|
}
|
20
providers/aigc2d/base.go
Normal file
20
providers/aigc2d/base.go
Normal file
@ -0,0 +1,20 @@
|
|||||||
|
package aigc2d
|
||||||
|
|
||||||
|
import (
|
||||||
|
"one-api/providers/base"
|
||||||
|
"one-api/providers/openai"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Aigc2dProviderFactory struct{}
|
||||||
|
|
||||||
|
func (f Aigc2dProviderFactory) Create(c *gin.Context) base.ProviderInterface {
|
||||||
|
return &Aigc2dProvider{
|
||||||
|
OpenAIProvider: openai.CreateOpenAIProvider(c, "https://api.aigc2d.com"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type Aigc2dProvider struct {
|
||||||
|
*openai.OpenAIProvider
|
||||||
|
}
|
@ -4,10 +4,11 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
"one-api/model"
|
"one-api/model"
|
||||||
|
"one-api/providers/base"
|
||||||
)
|
)
|
||||||
|
|
||||||
func (p *Api2dProvider) Balance(channel *model.Channel) (float64, error) {
|
func (p *Api2dProvider) Balance(channel *model.Channel) (float64, error) {
|
||||||
fullRequestURL := "https://api.aigc2d.com/dashboard/billing/credit_grants"
|
fullRequestURL := p.GetFullRequestURL("/dashboard/billing/credit_grants", "")
|
||||||
headers := p.GetRequestHeaders()
|
headers := p.GetRequestHeaders()
|
||||||
|
|
||||||
client := common.NewClient()
|
client := common.NewClient()
|
||||||
@ -17,7 +18,7 @@ func (p *Api2dProvider) Balance(channel *model.Channel) (float64, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 发送请求
|
// 发送请求
|
||||||
var response APGC2DGPTUsageResponse
|
var response base.BalanceResponse
|
||||||
_, errWithCode := common.SendRequest(req, &response, false)
|
_, errWithCode := common.SendRequest(req, &response, false)
|
||||||
if errWithCode != nil {
|
if errWithCode != nil {
|
||||||
return 0, errors.New(errWithCode.OpenAIError.Message)
|
return 0, errors.New(errWithCode.OpenAIError.Message)
|
||||||
|
30
providers/api2gpt/balance.go
Normal file
30
providers/api2gpt/balance.go
Normal file
@ -0,0 +1,30 @@
|
|||||||
|
package api2gpt
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"one-api/common"
|
||||||
|
"one-api/model"
|
||||||
|
"one-api/providers/base"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (p *Api2gptProvider) Balance(channel *model.Channel) (float64, error) {
|
||||||
|
fullRequestURL := p.GetFullRequestURL("/dashboard/billing/credit_grants", "")
|
||||||
|
headers := p.GetRequestHeaders()
|
||||||
|
|
||||||
|
client := common.NewClient()
|
||||||
|
req, err := client.NewRequest("GET", fullRequestURL, common.WithHeader(headers))
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// 发送请求
|
||||||
|
var response base.BalanceResponse
|
||||||
|
_, errWithCode := common.SendRequest(req, &response, false)
|
||||||
|
if errWithCode != nil {
|
||||||
|
return 0, errors.New(errWithCode.OpenAIError.Message)
|
||||||
|
}
|
||||||
|
|
||||||
|
channel.UpdateBalance(response.TotalAvailable)
|
||||||
|
|
||||||
|
return response.TotalRemaining, nil
|
||||||
|
}
|
20
providers/api2gpt/base.go
Normal file
20
providers/api2gpt/base.go
Normal file
@ -0,0 +1,20 @@
|
|||||||
|
package api2gpt
|
||||||
|
|
||||||
|
import (
|
||||||
|
"one-api/providers/base"
|
||||||
|
"one-api/providers/openai"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Api2gptProviderFactory struct{}
|
||||||
|
|
||||||
|
func (f Api2gptProviderFactory) Create(c *gin.Context) base.ProviderInterface {
|
||||||
|
return &Api2gptProvider{
|
||||||
|
OpenAIProvider: openai.CreateOpenAIProvider(c, "https://api.api2gpt.com"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type Api2gptProvider struct {
|
||||||
|
*openai.OpenAIProvider
|
||||||
|
}
|
@ -75,7 +75,7 @@ type ImageVariationsInterface interface {
|
|||||||
|
|
||||||
// 余额接口
|
// 余额接口
|
||||||
type BalanceInterface interface {
|
type BalanceInterface interface {
|
||||||
BalanceAction(channel *model.Channel) (float64, error)
|
Balance(channel *model.Channel) (float64, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
type ProviderResponseHandler interface {
|
type ProviderResponseHandler interface {
|
||||||
|
@ -1,9 +1,9 @@
|
|||||||
package api2d
|
package base
|
||||||
|
|
||||||
type APGC2DGPTUsageResponse struct {
|
type BalanceResponse struct {
|
||||||
//Grants interface{} `json:"grants"`
|
|
||||||
Object string `json:"object"`
|
Object string `json:"object"`
|
||||||
TotalAvailable float64 `json:"total_available"`
|
|
||||||
TotalGranted float64 `json:"total_granted"`
|
TotalGranted float64 `json:"total_granted"`
|
||||||
TotalUsed float64 `json:"total_used"`
|
TotalUsed float64 `json:"total_used"`
|
||||||
|
TotalRemaining float64 `json:"total_remaining"`
|
||||||
|
TotalAvailable float64 `json:"total_available"`
|
||||||
}
|
}
|
@ -2,9 +2,11 @@ package providers
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
|
"one-api/providers/aigc2d"
|
||||||
"one-api/providers/aiproxy"
|
"one-api/providers/aiproxy"
|
||||||
"one-api/providers/ali"
|
"one-api/providers/ali"
|
||||||
"one-api/providers/api2d"
|
"one-api/providers/api2d"
|
||||||
|
"one-api/providers/api2gpt"
|
||||||
"one-api/providers/azure"
|
"one-api/providers/azure"
|
||||||
"one-api/providers/baidu"
|
"one-api/providers/baidu"
|
||||||
"one-api/providers/base"
|
"one-api/providers/base"
|
||||||
@ -43,6 +45,9 @@ func init() {
|
|||||||
providerFactories[common.ChannelTypeAPI2D] = api2d.Api2dProviderFactory{}
|
providerFactories[common.ChannelTypeAPI2D] = api2d.Api2dProviderFactory{}
|
||||||
providerFactories[common.ChannelTypeCloseAI] = closeai.CloseaiProviderFactory{}
|
providerFactories[common.ChannelTypeCloseAI] = closeai.CloseaiProviderFactory{}
|
||||||
providerFactories[common.ChannelTypeOpenAISB] = openaisb.OpenaiSBProviderFactory{}
|
providerFactories[common.ChannelTypeOpenAISB] = openaisb.OpenaiSBProviderFactory{}
|
||||||
|
providerFactories[common.ChannelTypeAIGC2D] = aigc2d.Aigc2dProviderFactory{}
|
||||||
|
providerFactories[common.ChannelTypeAPI2GPT] = api2gpt.Api2gptProviderFactory{}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// 获取供应商
|
// 获取供应商
|
||||||
|
Loading…
Reference in New Issue
Block a user