Merge 563b6e123f
into 9fc5f427dc
This commit is contained in:
commit
1ca677e847
3
.gitignore
vendored
3
.gitignore
vendored
@ -8,4 +8,5 @@ build
|
||||
logs
|
||||
data
|
||||
/web/node_modules
|
||||
cmd.md
|
||||
cmd.md
|
||||
vendor/*
|
||||
|
27
common/audit/audit.go
Normal file
27
common/audit/audit.go
Normal file
@ -0,0 +1,27 @@
|
||||
package audit
|
||||
|
||||
import (
|
||||
"github.com/sirupsen/logrus"
|
||||
"gopkg.in/natefinch/lumberjack.v2"
|
||||
)
|
||||
|
||||
var (
|
||||
loger *lumberjack.Logger
|
||||
logger *logrus.Logger
|
||||
)
|
||||
|
||||
func init() {
|
||||
loger = &lumberjack.Logger{
|
||||
Filename: "logs/audit.log",
|
||||
MaxSize: 50, // megabytes
|
||||
MaxBackups: 300,
|
||||
MaxAge: 90, // days
|
||||
}
|
||||
logger = logrus.New()
|
||||
logger.SetOutput(loger)
|
||||
logger.SetFormatter(&logrus.JSONFormatter{})
|
||||
}
|
||||
|
||||
func Logger() *logrus.Logger {
|
||||
return logger
|
||||
}
|
79
common/audit/response.go
Normal file
79
common/audit/response.go
Normal file
@ -0,0 +1,79 @@
|
||||
package audit
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/base64"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
type AuditLogger struct {
|
||||
gin.ResponseWriter
|
||||
buf *bytes.Buffer
|
||||
}
|
||||
|
||||
func (l *AuditLogger) Write(p []byte) (int, error) {
|
||||
l.buf.Write(p)
|
||||
return l.ResponseWriter.Write(p)
|
||||
}
|
||||
|
||||
func CaptureResponseBody(c *gin.Context) *bytes.Buffer {
|
||||
al := &AuditLogger{
|
||||
ResponseWriter: c.Writer,
|
||||
buf: &bytes.Buffer{},
|
||||
}
|
||||
c.Writer = al
|
||||
return al.buf
|
||||
}
|
||||
|
||||
func B64encode(data []byte) string {
|
||||
return base64.StdEncoding.EncodeToString(data)
|
||||
}
|
||||
|
||||
type AuditReadCloser struct {
|
||||
Reader io.Reader
|
||||
Closer io.Closer
|
||||
Buffer *bytes.Buffer
|
||||
}
|
||||
|
||||
func (arc *AuditReadCloser) Read(p []byte) (int, error) {
|
||||
n, err := arc.Reader.Read(p)
|
||||
if n > 0 {
|
||||
arc.Buffer.Write(p[:n])
|
||||
}
|
||||
return n, err
|
||||
}
|
||||
|
||||
func (arc *AuditReadCloser) Close() error {
|
||||
return arc.Closer.Close()
|
||||
}
|
||||
|
||||
func CaptureHTTPResponseBody(resp *http.Response) *bytes.Buffer {
|
||||
buf := &bytes.Buffer{}
|
||||
arc := &AuditReadCloser{
|
||||
Reader: resp.Body,
|
||||
Closer: resp.Body,
|
||||
Buffer: buf,
|
||||
}
|
||||
resp.Body = arc
|
||||
return buf
|
||||
}
|
||||
|
||||
func ParseOPENAIStreamResponse(buf *bytes.Buffer) string {
|
||||
lines := strings.Split(buf.String(), "\n")
|
||||
bts := []string{}
|
||||
for _, line := range lines {
|
||||
line = strings.TrimSpace(line)
|
||||
line = strings.Trim(line, "\n")
|
||||
if strings.HasPrefix(string(line), "data:") {
|
||||
line = line[5:]
|
||||
}
|
||||
content := gjson.Get(line, "choices.0.delta.content").String()
|
||||
bts = append(bts, content)
|
||||
}
|
||||
return strings.Join(bts, "")
|
||||
}
|
@ -1,13 +1,14 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"github.com/songquanpeng/one-api/common/env"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/songquanpeng/one-api/common/env"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
@ -55,6 +56,8 @@ var EmailDomainWhitelist = []string{
|
||||
var DebugEnabled = strings.ToLower(os.Getenv("DEBUG")) == "true"
|
||||
var DebugSQLEnabled = strings.ToLower(os.Getenv("DEBUG_SQL")) == "true"
|
||||
var MemoryCacheEnabled = strings.ToLower(os.Getenv("MEMORY_CACHE_ENABLED")) == "true"
|
||||
var ClientAuditEnabled = env.Bool("CLIENT_AUDIT_ENABLED", false)
|
||||
var UpstreamAuditEnabled = env.Bool("UPSTREAM_AUDIT_ENABLED", false)
|
||||
|
||||
var LogConsumeEnabled = true
|
||||
|
||||
@ -135,6 +138,7 @@ var (
|
||||
|
||||
var RateLimitKeyExpirationDuration = 20 * time.Minute
|
||||
|
||||
var EnableBilling = env.Bool("ENABLE_BILLING", true)
|
||||
var EnableMetric = env.Bool("ENABLE_METRIC", false)
|
||||
var MetricQueueSize = env.Int("METRIC_QUEUE_SIZE", 10)
|
||||
var MetricSuccessRateThreshold = env.Float64("METRIC_SUCCESS_RATE_THRESHOLD", 0.8)
|
||||
|
@ -9,6 +9,7 @@ import (
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/songquanpeng/one-api/common"
|
||||
"github.com/songquanpeng/one-api/common/audit"
|
||||
"github.com/songquanpeng/one-api/common/config"
|
||||
"github.com/songquanpeng/one-api/common/ctxkey"
|
||||
"github.com/songquanpeng/one-api/common/helper"
|
||||
@ -17,48 +18,97 @@ import (
|
||||
dbmodel "github.com/songquanpeng/one-api/model"
|
||||
"github.com/songquanpeng/one-api/monitor"
|
||||
"github.com/songquanpeng/one-api/relay/controller"
|
||||
"github.com/songquanpeng/one-api/relay/meta"
|
||||
"github.com/songquanpeng/one-api/relay/model"
|
||||
"github.com/songquanpeng/one-api/relay/relaymode"
|
||||
)
|
||||
|
||||
// https://platform.openai.com/docs/api-reference/chat
|
||||
|
||||
func relayHelper(c *gin.Context, relayMode int) *model.ErrorWithStatusCode {
|
||||
type Options struct {
|
||||
Debug bool
|
||||
EnableMonitor bool
|
||||
EnableBilling bool
|
||||
}
|
||||
|
||||
type RelayController struct {
|
||||
opts Options
|
||||
controller.RelayInstance
|
||||
monitor.MonitorInstance
|
||||
}
|
||||
|
||||
func NewRelayController(opts Options) *RelayController {
|
||||
ctrl := &RelayController{
|
||||
opts: opts,
|
||||
}
|
||||
ctrl.RelayInstance = controller.NewRelayInstance(controller.Options{
|
||||
EnableBilling: opts.EnableBilling,
|
||||
})
|
||||
if opts.EnableMonitor {
|
||||
ctrl.MonitorInstance = monitor.NewMonitorInstance()
|
||||
}
|
||||
return ctrl
|
||||
}
|
||||
|
||||
func (ctrl *RelayController) relayHelper(c *gin.Context, relayMode int) *model.ErrorWithStatusCode {
|
||||
if config.ClientAuditEnabled {
|
||||
buf := audit.CaptureResponseBody(c)
|
||||
m := meta.GetByContext(c)
|
||||
defer func() {
|
||||
audit.Logger().
|
||||
WithField("raw", audit.B64encode(buf.Bytes())).
|
||||
WithField("parsed", audit.ParseOPENAIStreamResponse(buf)).
|
||||
WithField("requestid", c.GetString(helper.RequestIdKey)).
|
||||
WithFields(m.ToLogrusFields()).
|
||||
Info("client response")
|
||||
}()
|
||||
}
|
||||
var err *model.ErrorWithStatusCode
|
||||
switch relayMode {
|
||||
case relaymode.ImagesGenerations:
|
||||
err = controller.RelayImageHelper(c, relayMode)
|
||||
err = ctrl.RelayImageHelper(c, relayMode)
|
||||
case relaymode.AudioSpeech:
|
||||
fallthrough
|
||||
case relaymode.AudioTranslation:
|
||||
fallthrough
|
||||
case relaymode.AudioTranscription:
|
||||
err = controller.RelayAudioHelper(c, relayMode)
|
||||
err = ctrl.RelayAudioHelper(c, relayMode)
|
||||
default:
|
||||
err = controller.RelayTextHelper(c)
|
||||
err = ctrl.RelayTextHelper(c)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func Relay(c *gin.Context) {
|
||||
func (ctrl *RelayController) Relay(c *gin.Context) {
|
||||
ctx := c.Request.Context()
|
||||
relayMode := relaymode.GetByPath(c.Request.URL.Path)
|
||||
if config.DebugEnabled {
|
||||
requestBody, _ := common.GetRequestBody(c)
|
||||
logger.Debugf(ctx, "request body: %s", string(requestBody))
|
||||
}
|
||||
if config.ClientAuditEnabled {
|
||||
requestBody, _ := common.GetRequestBody(c)
|
||||
m := meta.GetByContext(c)
|
||||
audit.Logger().
|
||||
WithField("raw", audit.B64encode(requestBody)).
|
||||
WithField("requestid", c.GetString(helper.RequestIdKey)).
|
||||
WithFields(m.ToLogrusFields()).
|
||||
Info("client request")
|
||||
}
|
||||
channelId := c.GetInt(ctxkey.ChannelId)
|
||||
userId := c.GetInt("id")
|
||||
bizErr := relayHelper(c, relayMode)
|
||||
bizErr := ctrl.relayHelper(c, relayMode)
|
||||
if bizErr == nil {
|
||||
monitor.Emit(channelId, true)
|
||||
if ctrl.MonitorInstance != nil {
|
||||
ctrl.Emit(channelId, true)
|
||||
}
|
||||
return
|
||||
}
|
||||
lastFailedChannelId := channelId
|
||||
channelName := c.GetString(ctxkey.ChannelName)
|
||||
group := c.GetString(ctxkey.Group)
|
||||
originalModel := c.GetString(ctxkey.OriginalModel)
|
||||
go processChannelRelayError(ctx, userId, channelId, channelName, bizErr)
|
||||
userId := c.GetInt(ctxkey.Id)
|
||||
go ctrl.processChannelRelayError(ctx, userId, channelId, channelName, bizErr)
|
||||
requestId := c.GetString(helper.RequestIdKey)
|
||||
retryTimes := config.RetryTimes
|
||||
if !shouldRetry(c, bizErr.StatusCode) {
|
||||
@ -77,15 +127,19 @@ func Relay(c *gin.Context) {
|
||||
}
|
||||
middleware.SetupContextForSelectedChannel(c, channel, originalModel)
|
||||
requestBody, err := common.GetRequestBody(c)
|
||||
if err != nil {
|
||||
logger.Errorf(ctx, "GetRequestBody failed: %+v", err)
|
||||
break
|
||||
}
|
||||
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
|
||||
bizErr = relayHelper(c, relayMode)
|
||||
bizErr = ctrl.relayHelper(c, relayMode)
|
||||
if bizErr == nil {
|
||||
return
|
||||
}
|
||||
channelId := c.GetInt(ctxkey.ChannelId)
|
||||
lastFailedChannelId = channelId
|
||||
channelName := c.GetString(ctxkey.ChannelName)
|
||||
go processChannelRelayError(ctx, userId, channelId, channelName, bizErr)
|
||||
go ctrl.processChannelRelayError(ctx, userId, channelId, channelName, bizErr)
|
||||
}
|
||||
if bizErr != nil {
|
||||
if bizErr.StatusCode == http.StatusTooManyRequests {
|
||||
@ -117,13 +171,16 @@ func shouldRetry(c *gin.Context, statusCode int) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func processChannelRelayError(ctx context.Context, userId int, channelId int, channelName string, err *model.ErrorWithStatusCode) {
|
||||
func (ctrl *RelayController) processChannelRelayError(ctx context.Context, userId int, channelId int, channelName string, err *model.ErrorWithStatusCode) {
|
||||
if ctrl.MonitorInstance == nil {
|
||||
return
|
||||
}
|
||||
logger.Errorf(ctx, "relay error (channel id %d, user id: %d): %s", channelId, userId, err.Message)
|
||||
// https://platform.openai.com/docs/guides/error-codes/api-errors
|
||||
if monitor.ShouldDisableChannel(&err.Error, err.StatusCode) {
|
||||
monitor.DisableChannel(channelId, channelName, err.Message)
|
||||
if ctrl.ShouldDisableChannel(&err.Error, err.StatusCode) {
|
||||
ctrl.DisableChannel(channelId, channelName, err.Message)
|
||||
} else {
|
||||
monitor.Emit(channelId, false)
|
||||
ctrl.Emit(channelId, false)
|
||||
}
|
||||
}
|
||||
|
||||
|
5
go.mod
5
go.mod
@ -20,10 +20,13 @@ require (
|
||||
github.com/jinzhu/copier v0.4.0
|
||||
github.com/pkg/errors v0.9.1
|
||||
github.com/pkoukk/tiktoken-go v0.1.7
|
||||
github.com/sirupsen/logrus v1.8.1
|
||||
github.com/smartystreets/goconvey v1.8.1
|
||||
github.com/stretchr/testify v1.9.0
|
||||
github.com/tidwall/gjson v1.17.1
|
||||
golang.org/x/crypto v0.23.0
|
||||
golang.org/x/image v0.16.0
|
||||
gopkg.in/natefinch/lumberjack.v2 v2.2.1
|
||||
gorm.io/driver/mysql v1.5.6
|
||||
gorm.io/driver/postgres v1.5.7
|
||||
gorm.io/driver/sqlite v1.5.5
|
||||
@ -73,6 +76,8 @@ require (
|
||||
github.com/pelletier/go-toml/v2 v2.2.2 // indirect
|
||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||
github.com/smarty/assertions v1.15.0 // indirect
|
||||
github.com/tidwall/match v1.1.1 // indirect
|
||||
github.com/tidwall/pretty v1.2.0 // indirect
|
||||
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
|
||||
github.com/ugorji/go/codec v1.2.12 // indirect
|
||||
golang.org/x/arch v0.8.0 // indirect
|
||||
|
61
go.sum
61
go.sum
@ -1,40 +1,25 @@
|
||||
filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA=
|
||||
filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4=
|
||||
github.com/aws/aws-sdk-go-v2 v1.26.1 h1:5554eUqIYVWpU0YmeeYZ0wU64H2VLBs8TlhRB2L+EkA=
|
||||
github.com/aws/aws-sdk-go-v2 v1.26.1/go.mod h1:ffIFB97e2yNsv4aTSGkqtHnppsIJzw7G7BReUZ3jCXM=
|
||||
github.com/aws/aws-sdk-go-v2 v1.27.0 h1:7bZWKoXhzI+mMR/HjdMx8ZCC5+6fY0lS5tr0bbgiLlo=
|
||||
github.com/aws/aws-sdk-go-v2 v1.27.0/go.mod h1:ffIFB97e2yNsv4aTSGkqtHnppsIJzw7G7BReUZ3jCXM=
|
||||
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.2 h1:x6xsQXGSmW6frevwDA+vi/wqhp1ct18mVXYN08/93to=
|
||||
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.2/go.mod h1:lPprDr1e6cJdyYeGXnRaJoP4Md+cDBvi2eOj00BlGmg=
|
||||
github.com/aws/aws-sdk-go-v2/credentials v1.17.11 h1:YuIB1dJNf1Re822rriUOTxopaHHvIq0l/pX3fwO+Tzs=
|
||||
github.com/aws/aws-sdk-go-v2/credentials v1.17.11/go.mod h1:AQtFPsDH9bI2O+71anW6EKL+NcD7LG3dpKGMV4SShgo=
|
||||
github.com/aws/aws-sdk-go-v2/credentials v1.17.15 h1:YDexlvDRCA8ems2T5IP1xkMtOZ1uLJOCJdTr0igs5zo=
|
||||
github.com/aws/aws-sdk-go-v2/credentials v1.17.15/go.mod h1:vxHggqW6hFNaeNC0WyXS3VdyjcV0a4KMUY4dKJ96buU=
|
||||
github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.5 h1:aw39xVGeRWlWx9EzGVnhOR4yOjQDHPQ6o6NmBlscyQg=
|
||||
github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.5/go.mod h1:FSaRudD0dXiMPK2UjknVwwTYyZMRsHv3TtkabsZih5I=
|
||||
github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.7 h1:lf/8VTF2cM+N4SLzaYJERKEWAXq8MOMpZfU6wEPWsPk=
|
||||
github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.7/go.mod h1:4SjkU7QiqK2M9oozyMzfZ/23LmUY+h3oFqhdeP5OMiI=
|
||||
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.5 h1:PG1F3OD1szkuQPzDw3CIQsRIrtTlUC3lP84taWzHlq0=
|
||||
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.5/go.mod h1:jU1li6RFryMz+so64PpKtudI+QzbKoIEivqdf6LNpOc=
|
||||
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.7 h1:4OYVp0705xu8yjdyoWix0r9wPIRXnIzzOoUpQVHIJ/g=
|
||||
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.7/go.mod h1:vd7ESTEvI76T2Na050gODNmNU7+OyKrIKroYTu4ABiI=
|
||||
github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.7.4 h1:JgHnonzbnA3pbqj76wYsSZIZZQYBxkmMEjvL6GHy8XU=
|
||||
github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.7.4/go.mod h1:nZspkhg+9p8iApLFoyAqfyuMP0F38acy2Hm3r5r95Cg=
|
||||
github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.8.3 h1:Fihjyd6DeNjcawBEGLH9dkIEUi6AdhucDKPE9nJ4QiY=
|
||||
github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.8.3/go.mod h1:opvUj3ismqSCxYc+m4WIjPL0ewZGtvp0ess7cKvBPOQ=
|
||||
github.com/aws/smithy-go v1.20.2 h1:tbp628ireGtzcHDDmLT/6ADHidqnwgF57XOXZe6tp4Q=
|
||||
github.com/aws/smithy-go v1.20.2/go.mod h1:krry+ya/rV9RDcV/Q16kpu6ypI4K2czasz0NC3qS14E=
|
||||
github.com/bytedance/sonic v1.11.5 h1:G00FYjjqll5iQ1PYXynbg/hyzqBqavH8Mo9/oTopd9k=
|
||||
github.com/bytedance/sonic v1.11.5/go.mod h1:X2PC2giUdj/Cv2lliWFLk6c/DUQok5rViJSemeB0wDw=
|
||||
github.com/bytedance/sonic v1.11.6 h1:oUp34TzMlL+OY1OUWxHqsdkgC/Zfc85zGqw9siXjrc0=
|
||||
github.com/bytedance/sonic v1.11.6/go.mod h1:LysEHSvpvDySVdC2f87zGWf6CIKJcAvqab1ZaiQtds4=
|
||||
github.com/bytedance/sonic/loader v0.1.0/go.mod h1:UmRT+IRTGKz/DAkzcEGzyVqQFJ7H9BqwBO3pm9H/+HY=
|
||||
github.com/bytedance/sonic/loader v0.1.1 h1:c+e5Pt1k/cy5wMveRDyk2X4B9hF4g7an8N3zCYjJFNM=
|
||||
github.com/bytedance/sonic/loader v0.1.1/go.mod h1:ncP89zfokxS5LZrJxl5z0UJcsk4M4yY2JpfqGeCtNLU=
|
||||
github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
|
||||
github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
|
||||
github.com/cloudwego/base64x v0.1.3 h1:b5J/l8xolB7dyDTTmhJP2oTs5LdrjyrUFuNxdfq5hAg=
|
||||
github.com/cloudwego/base64x v0.1.3/go.mod h1:1+1K5BUHIQzyapgpF7LwvOGAEDicKtt1umPV+aN8pi8=
|
||||
github.com/cloudwego/base64x v0.1.4 h1:jwCgWpFanWmN8xoIUHa2rtzmkd5J2plF/dnLS6Xd/0Y=
|
||||
github.com/cloudwego/base64x v0.1.4/go.mod h1:0zlkT4Wn5C6NdauXdJRhSKRlJvmclQ1hhJgA0rcu/8w=
|
||||
github.com/cloudwego/iasm v0.2.0 h1:1KNIy1I1H9hNNFEEH3DVnI4UujN+1zjpuk6gwHLTssg=
|
||||
@ -51,26 +36,16 @@ github.com/fsnotify/fsnotify v1.7.0 h1:8JEhPFa5W2WU7YfeZzPNqzMP6Lwt7L2715Ggo0nos
|
||||
github.com/fsnotify/fsnotify v1.7.0/go.mod h1:40Bi/Hjc2AVfZrqy+aj+yEI+/bRxZnMJyTJwOpGvigM=
|
||||
github.com/gabriel-vasile/mimetype v1.4.3 h1:in2uUcidCuFcDKtdcBxlR0rJ1+fsokWf+uqxgUFjbI0=
|
||||
github.com/gabriel-vasile/mimetype v1.4.3/go.mod h1:d8uq/6HKRL6CGdk+aubisF/M5GcPfT7nKyLpA0lbSSk=
|
||||
github.com/gin-contrib/cors v1.7.1 h1:s9SIppU/rk8enVvkzwiC2VK3UZ/0NNGsWfUKvV55rqs=
|
||||
github.com/gin-contrib/cors v1.7.1/go.mod h1:n/Zj7B4xyrgk/cX1WCX2dkzFfaNm/xJb6oIUk7WTtps=
|
||||
github.com/gin-contrib/cors v1.7.2 h1:oLDHxdg8W/XDoN/8zamqk/Drgt4oVZDvaV0YmvVICQw=
|
||||
github.com/gin-contrib/cors v1.7.2/go.mod h1:SUJVARKgQ40dmrzgXEVxj2m7Ig1v1qIboQkPDTQ9t2E=
|
||||
github.com/gin-contrib/gzip v1.0.0 h1:UKN586Po/92IDX6ie5CWLgMI81obiIp5nSP85T3wlTk=
|
||||
github.com/gin-contrib/gzip v1.0.0/go.mod h1:CtG7tQrPB3vIBo6Gat9FVUsis+1emjvQqd66ME5TdnE=
|
||||
github.com/gin-contrib/gzip v1.0.1 h1:HQ8ENHODeLY7a4g1Au/46Z92bdGFl74OhxcZble9WJE=
|
||||
github.com/gin-contrib/gzip v1.0.1/go.mod h1:njt428fdUNRvjuJf16tZMYZ2Yl+WQB53X5wmhDwXvC4=
|
||||
github.com/gin-contrib/sessions v1.0.0 h1:r5GLta4Oy5xo9rAwMHx8B4wLpeRGHMdz9NafzJAdP8Y=
|
||||
github.com/gin-contrib/sessions v1.0.0/go.mod h1:DN0f4bvpqMQElDdi+gNGScrP2QEI04IErRyMFyorUOI=
|
||||
github.com/gin-contrib/sessions v1.0.1 h1:3hsJyNs7v7N8OtelFmYXFrulAf6zSR7nW/putcPEHxI=
|
||||
github.com/gin-contrib/sessions v1.0.1/go.mod h1:ouxSFM24/OgIud5MJYQJLpy6AwxQ5EYO9yLhbtObGkM=
|
||||
github.com/gin-contrib/sse v0.1.0 h1:Y/yl/+YNO8GZSjAhjMsSuLt29uWRFHdHYUb5lYOV9qE=
|
||||
github.com/gin-contrib/sse v0.1.0/go.mod h1:RHrZQHXnP2xjPF+u1gW/2HnVO7nvIa9PG3Gm+fLHvGI=
|
||||
github.com/gin-contrib/static v1.1.1 h1:XEvBd4DDLG1HBlyPBQU1XO8NlTpw6mgdqcPteetYA5k=
|
||||
github.com/gin-contrib/static v1.1.1/go.mod h1:yRGmar7+JYvbMLRPIi4H5TVVSBwULfT9vetnVD0IO74=
|
||||
github.com/gin-contrib/static v1.1.2 h1:c3kT4bFkUJn2aoRU3s6XnMjJT8J6nNWJkR0NglqmlZ4=
|
||||
github.com/gin-contrib/static v1.1.2/go.mod h1:Fw90ozjHCmZBWbgrsqrDvO28YbhKEKzKp8GixhR4yLw=
|
||||
github.com/gin-gonic/gin v1.9.1 h1:4idEAncQnU5cB7BeOkPtxjfCSye0AAm1R0RVIqJ+Jmg=
|
||||
github.com/gin-gonic/gin v1.9.1/go.mod h1:hPrL7YrpYKXt5YId3A/Tnip5kqbEAP+KLuI3SUcPTeU=
|
||||
github.com/gin-gonic/gin v1.10.0 h1:nTuyha1TYqgedzytsKYqna+DfLos46nTv2ygFy86HFU=
|
||||
github.com/gin-gonic/gin v1.10.0/go.mod h1:4PMNQiOhvDRa013RKVbsiNwoyezlm2rm0uX/T7kzp5Y=
|
||||
github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s=
|
||||
@ -78,8 +53,6 @@ github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/o
|
||||
github.com/go-playground/locales v0.14.1/go.mod h1:hxrqLVvrK65+Rwrd5Fc6F2O76J/NuW9t0sjnWqG1slY=
|
||||
github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJnYK9S473LQFuzCbDbfSFY=
|
||||
github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY=
|
||||
github.com/go-playground/validator/v10 v10.19.0 h1:ol+5Fu+cSq9JD7SoSqe04GMI92cbn0+wvQ3bZ8b/AU4=
|
||||
github.com/go-playground/validator/v10 v10.19.0/go.mod h1:dbuPbCMFw/DrkbEynArYaCwl3amGuJotoKCe95atGMM=
|
||||
github.com/go-playground/validator/v10 v10.20.0 h1:K9ISHbSaI0lyB2eWMPJo+kOS/FBExVwjEviJTixqxL8=
|
||||
github.com/go-playground/validator/v10 v10.20.0/go.mod h1:dbuPbCMFw/DrkbEynArYaCwl3amGuJotoKCe95atGMM=
|
||||
github.com/go-redis/redis/v8 v8.11.5 h1:AcZZR7igkdvfVmQTPnu9WE37LRrO/YrBH5zWyjDC0oI=
|
||||
@ -87,8 +60,6 @@ github.com/go-redis/redis/v8 v8.11.5/go.mod h1:gREzHqY1hg6oD9ngVRbLStwAWKhA0FEgq
|
||||
github.com/go-sql-driver/mysql v1.7.0/go.mod h1:OXbVy3sEdcQ2Doequ6Z5BW6fXNQTmx+9S1MCJN5yJMI=
|
||||
github.com/go-sql-driver/mysql v1.8.1 h1:LedoTUt/eveggdHS9qUFC1EFSa8bU2+1pZjSRpvNJ1Y=
|
||||
github.com/go-sql-driver/mysql v1.8.1/go.mod h1:wEBSXgmK//2ZFJyE+qWnIsVGmvmEKlqwuVSjsCm7DZg=
|
||||
github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU=
|
||||
github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I=
|
||||
github.com/goccy/go-json v0.10.3 h1:KZ5WoDbxAIgm2HNbYckL0se1fHD6rz5j4ywS6ebzDqA=
|
||||
github.com/goccy/go-json v0.10.3/go.mod h1:oq7eo15ShAhp70Anwd5lgX2pLfOS3QCiwU/PULtXL6M=
|
||||
github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keLg81eXfW3O+oY=
|
||||
@ -147,19 +118,17 @@ github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjY
|
||||
github.com/nxadm/tail v1.4.8 h1:nPr65rt6Y5JFSKQO7qToXr7pePgD6Gwiw05lkbyAQTE=
|
||||
github.com/onsi/ginkgo v1.16.5 h1:8xi0RTUf59SOSfEtZMvwTvXYMzG4gV23XVHOZiXNtnE=
|
||||
github.com/onsi/gomega v1.18.1 h1:M1GfJqGRrBrrGGsbxzV5dqM2U2ApXefZCQpkukxYRLE=
|
||||
github.com/pelletier/go-toml/v2 v2.2.1 h1:9TA9+T8+8CUCO2+WYnDLCgrYi9+omqKXyjDtosvtEhg=
|
||||
github.com/pelletier/go-toml/v2 v2.2.1/go.mod h1:1t835xjRzz80PqgE6HHgN2JOsmgYu/h4qDAS4n929Rs=
|
||||
github.com/pelletier/go-toml/v2 v2.2.2 h1:aYUidT7k73Pcl9nb2gScu7NSrKCSHIDE89b3+6Wq+LM=
|
||||
github.com/pelletier/go-toml/v2 v2.2.2/go.mod h1:1t835xjRzz80PqgE6HHgN2JOsmgYu/h4qDAS4n929Rs=
|
||||
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
|
||||
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||
github.com/pkoukk/tiktoken-go v0.1.6 h1:JF0TlJzhTbrI30wCvFuiw6FzP2+/bR+FIxUdgEAcUsw=
|
||||
github.com/pkoukk/tiktoken-go v0.1.6/go.mod h1:9NiV+i9mJKGj1rYOT+njbv+ZwA/zJxYdewGl6qVatpg=
|
||||
github.com/pkoukk/tiktoken-go v0.1.7 h1:qOBHXX4PHtvIvmOtyg1EeKlwFRiMKAcoMp4Q+bLQDmw=
|
||||
github.com/pkoukk/tiktoken-go v0.1.7/go.mod h1:9NiV+i9mJKGj1rYOT+njbv+ZwA/zJxYdewGl6qVatpg=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/rogpeppe/go-internal v1.8.0 h1:FCbCCtXNOY3UtUuHUYaghJg4y7Fd14rXifAYUAtL9R8=
|
||||
github.com/sirupsen/logrus v1.8.1 h1:dJKuHgqk1NNQlqoA6BTlM1Wf9DOH3NBjQyu0h9+AZZE=
|
||||
github.com/sirupsen/logrus v1.8.1/go.mod h1:yWOB1SBYBC5VeMP7gHvWumXLIWorT60ONWic61uBYv0=
|
||||
github.com/smarty/assertions v1.15.0 h1:cR//PqUBUiQRakZWqBiFFQ9wb8emQGDb0HeGdqGByCY=
|
||||
github.com/smarty/assertions v1.15.0/go.mod h1:yABtdzeQs6l1brC900WlRNwj6ZR55d7B+E8C6HtKdec=
|
||||
github.com/smartystreets/goconvey v1.8.1 h1:qGjIddxOk4grTu9JPOU31tVfq3cNdBlNa5sSznIX1xY=
|
||||
@ -168,6 +137,7 @@ github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+
|
||||
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
|
||||
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
|
||||
github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA=
|
||||
github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
|
||||
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
|
||||
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||
@ -176,46 +146,41 @@ github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o
|
||||
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
|
||||
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
|
||||
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
|
||||
github.com/tidwall/gjson v1.17.1 h1:wlYEnwqAHgzmhNUFfw7Xalt2JzQvsMx2Se4PcoFCT/U=
|
||||
github.com/tidwall/gjson v1.17.1/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
|
||||
github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA=
|
||||
github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM=
|
||||
github.com/tidwall/pretty v1.2.0 h1:RWIZEg2iJ8/g6fDDYzMpobmaoGh5OLl4AXtGUGPcqCs=
|
||||
github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
|
||||
github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI=
|
||||
github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08=
|
||||
github.com/ugorji/go/codec v1.2.12 h1:9LC83zGrHhuUA9l16C9AHXAqEV/2wBQ4nkvumAE65EE=
|
||||
github.com/ugorji/go/codec v1.2.12/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg=
|
||||
golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
|
||||
golang.org/x/arch v0.7.0 h1:pskyeJh/3AmoQ8CPE95vxHLqp1G1GfGNXTmcl9NEKTc=
|
||||
golang.org/x/arch v0.7.0/go.mod h1:FEVrYAQjsQXMVJ1nsMoVVXPZg6p2JE2mx8psSWTDQys=
|
||||
golang.org/x/arch v0.8.0 h1:3wRIsP3pM4yUptoR96otTUOXI367OS0+c9eeRi9doIc=
|
||||
golang.org/x/arch v0.8.0/go.mod h1:FEVrYAQjsQXMVJ1nsMoVVXPZg6p2JE2mx8psSWTDQys=
|
||||
golang.org/x/crypto v0.22.0 h1:g1v0xeRhjcugydODzvb3mEM9SQ0HGp9s/nh3COQ/C30=
|
||||
golang.org/x/crypto v0.22.0/go.mod h1:vr6Su+7cTlO45qkww3VDJlzDn0ctJvRgYbC2NvXHt+M=
|
||||
golang.org/x/crypto v0.23.0 h1:dIJU/v2J8Mdglj/8rJ6UUOM3Zc9zLZxVZwwxMooUSAI=
|
||||
golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8=
|
||||
golang.org/x/image v0.15.0 h1:kOELfmgrmJlw4Cdb7g/QGuB3CvDrXbqEIww/pNtNBm8=
|
||||
golang.org/x/image v0.15.0/go.mod h1:HUYqC05R2ZcZ3ejNQsIHQDQiwWM4JBqmm6MKANTp4LE=
|
||||
golang.org/x/image v0.16.0 h1:9kloLAKhUufZhA12l5fwnx2NZW39/we1UhBesW433jw=
|
||||
golang.org/x/image v0.16.0/go.mod h1:ugSZItdV4nOxyqp56HmXwH0Ry0nBCpjnZdpDaIHdoPs=
|
||||
golang.org/x/net v0.24.0 h1:1PcaxkF854Fu3+lvBIx5SYn9wRlBzzcnHZSiaFFAb0w=
|
||||
golang.org/x/net v0.24.0/go.mod h1:2Q7sJY5mzlzWjKtYUEXSlBWCdyaioyXzRB2RtU8KVE8=
|
||||
golang.org/x/net v0.25.0 h1:d/OCCoBEUq33pjydKrGQhw7IlUPI2Oylr+8qLx49kac=
|
||||
golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM=
|
||||
golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M=
|
||||
golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
|
||||
golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.19.0 h1:q5f1RH2jigJ1MoAWp2KTp3gm5zAGFUTarQZ5U386+4o=
|
||||
golang.org/x/sys v0.19.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||
golang.org/x/sys v0.20.0 h1:Od9JTbYCk261bKm4M/mw7AklTlFYIa0bIp9BgSm1S8Y=
|
||||
golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||
golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ=
|
||||
golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
|
||||
golang.org/x/text v0.15.0 h1:h1V/4gjBv8v9cjcR6+AR5+/cIYK5N/WAgiv4xlsEtAk=
|
||||
golang.org/x/text v0.15.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
|
||||
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 h1:go1bK/D/BFZV2I8cIQd1NKEZ+0owSTG1fDTci4IqFcE=
|
||||
google.golang.org/protobuf v1.33.0 h1:uNO2rsAINq/JlFpSdYEKIZ0uKD/R9cpdv0T+yoGwGmI=
|
||||
google.golang.org/protobuf v1.33.0/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos=
|
||||
google.golang.org/protobuf v1.34.1 h1:9ddQBjfCyZPOHPUiPxpYESBLc+T8P3E+Vo4IbKZgFWg=
|
||||
google.golang.org/protobuf v1.34.1/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
|
||||
gopkg.in/natefinch/lumberjack.v2 v2.2.1 h1:bBRl1b0OH9s/DuPhuXpNl+VtCaJXFZ5/uEFST95x9zc=
|
||||
gopkg.in/natefinch/lumberjack.v2 v2.2.1/go.mod h1:YD8tP3GAjkrDg1eZH7EGmyESg/lsYskCTPBJVb9jqSc=
|
||||
gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 h1:uRGJdciOHaEIrze2W8Q3AKkepLTh2hOroT7a+7czfdQ=
|
||||
gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY=
|
||||
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
@ -228,8 +193,6 @@ gorm.io/driver/postgres v1.5.7/go.mod h1:3e019WlBaYI5o5LIdNV+LyxCMNtLOQETBXL2h4c
|
||||
gorm.io/driver/sqlite v1.5.5 h1:7MDMtUZhV065SilG62E0MquljeArQZNfJnjd9i9gx3E=
|
||||
gorm.io/driver/sqlite v1.5.5/go.mod h1:6NgQ7sQWAIFsPrJJl1lSNSu2TABh0ZZ/zm5fosATavE=
|
||||
gorm.io/gorm v1.25.7/go.mod h1:hbnx/Oo0ChWMn1BIhpy1oYozzpM15i4YPuHDmfYtwg8=
|
||||
gorm.io/gorm v1.25.9 h1:wct0gxZIELDk8+ZqF/MVnHLkA1rvYlBWUMv2EdsK1g8=
|
||||
gorm.io/gorm v1.25.9/go.mod h1:hbnx/Oo0ChWMn1BIhpy1oYozzpM15i4YPuHDmfYtwg8=
|
||||
gorm.io/gorm v1.25.10 h1:dQpO+33KalOA+aFYGlK+EfxcI5MbO7EP2yYygwh9h+s=
|
||||
gorm.io/gorm v1.25.10/go.mod h1:hbnx/Oo0ChWMn1BIhpy1oYozzpM15i4YPuHDmfYtwg8=
|
||||
nullprogram.com/x/optparse v1.0.0/go.mod h1:KdyPE+Igbe0jQUrVfMqDMeJQIJZEuyV7pjYmp6pbG50=
|
||||
|
32
monitor/monitor.go
Normal file
32
monitor/monitor.go
Normal file
@ -0,0 +1,32 @@
|
||||
package monitor
|
||||
|
||||
import "github.com/songquanpeng/one-api/relay/model"
|
||||
|
||||
type MonitorInstance interface {
|
||||
Emit(ChannelId int, success bool)
|
||||
ShouldDisableChannel(err *model.Error, statusCode int) bool
|
||||
DisableChannel(channelId int, channelName string, reason string)
|
||||
}
|
||||
|
||||
type defaultMonitor struct {
|
||||
}
|
||||
|
||||
func NewMonitorInstance() MonitorInstance {
|
||||
return &defaultMonitor{}
|
||||
}
|
||||
|
||||
func (m *defaultMonitor) Emit(channelId int, success bool) {
|
||||
if success {
|
||||
metricSuccessChan <- channelId
|
||||
} else {
|
||||
metricFailChan <- channelId
|
||||
}
|
||||
}
|
||||
|
||||
func (m *defaultMonitor) ShouldDisableChannel(err *model.Error, statusCode int) bool {
|
||||
return ShouldDisableChannel(err, statusCode)
|
||||
}
|
||||
|
||||
func (m *defaultMonitor) DisableChannel(channelId int, channelName string, reason string) {
|
||||
DisableChannel(channelId, channelName, reason)
|
||||
}
|
160
relay/billing/billing-instance.go
Normal file
160
relay/billing/billing-instance.go
Normal file
@ -0,0 +1,160 @@
|
||||
package billing
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"github.com/songquanpeng/one-api/common/logger"
|
||||
"github.com/songquanpeng/one-api/model"
|
||||
"github.com/songquanpeng/one-api/relay/adaptor/openai"
|
||||
billingratio "github.com/songquanpeng/one-api/relay/billing/ratio"
|
||||
relaymodel "github.com/songquanpeng/one-api/relay/model"
|
||||
)
|
||||
|
||||
// Bookkeeper 记账员逻辑,用于处理用户的配额消费
|
||||
// 预扣配额检测逻辑:
|
||||
//
|
||||
// 开始请求前,根据不同的请求类型,预先计算需要消费的配额,根据请求用户和token的配额余量来判断是否有足够的配额来满足这个请求
|
||||
// 如果余量配额不能满足这个请求,直接返回错误, 如果余量配额可以满足这个请求,那么预先消费这个配额,然后开始请求, 如果余量远远超过这个请求,那么不需要预先消费配额
|
||||
// 由于预先计算的配额不是实际消费的配额,所以需要在请求结束后,根据实际消费的配额来更新用户和token的配额,退费或者扣费。
|
||||
type Bookkeeper interface {
|
||||
// 获取模型的费率
|
||||
ModelRatio(model string) float64
|
||||
// 获取组的费率
|
||||
GroupRation(group string) float64
|
||||
// 获取模型的补全费率
|
||||
ModelCompletionRatio(model string) float64
|
||||
// 根据消费记录,扣除用户,token 的配额
|
||||
Consume(ctx context.Context, consumeLog *ConsumeLog)
|
||||
// 预消费配额, 当用户配额不足时,预消费配额, 预消费成功返回预消费的配额,失败返回错误, 如果预消费的配额为0,表示用户有足够的配额
|
||||
PreConsumeQuota(ctx context.Context, preConsumedQuota int64, userId, tokenId int) (int64, *relaymodel.ErrorWithStatusCode)
|
||||
// 退回预消费的配额, 这通常在调用上游api失败的时候执行
|
||||
RefundQuota(ctx context.Context, preConsumedQuota int64, tokenId int)
|
||||
|
||||
// 检测用户是否有足够的配额
|
||||
// UserHasEnoughQuota(ctx context.Context, userID int, quota int64) bool
|
||||
// 检测用户是否有远远超过需求的配额, 如果用户的配额远远超过需求,那么不需要预消费配额
|
||||
// UserHasMuchMoreQuota(ctx context.Context, userID int, quota int64) bool
|
||||
}
|
||||
|
||||
type defaultBookkeeper struct {
|
||||
}
|
||||
|
||||
func NewBookkeeper() Bookkeeper {
|
||||
return &defaultBookkeeper{}
|
||||
}
|
||||
|
||||
func (b *defaultBookkeeper) ModelRatio(model string) float64 {
|
||||
return billingratio.GetModelRatio(model)
|
||||
}
|
||||
|
||||
func (b *defaultBookkeeper) GroupRation(group string) float64 {
|
||||
return billingratio.GetGroupRatio(group)
|
||||
}
|
||||
|
||||
func (b *defaultBookkeeper) ModelCompletionRatio(model string) float64 {
|
||||
return billingratio.GetCompletionRatio(model)
|
||||
}
|
||||
|
||||
func (b *defaultBookkeeper) Ratio(group, model string) float64 {
|
||||
modelRatio := billingratio.GetModelRatio(model)
|
||||
groupRatio := billingratio.GetGroupRatio(group)
|
||||
return modelRatio * groupRatio
|
||||
}
|
||||
|
||||
// ConsumeLog 消费记录实体
|
||||
type ConsumeLog struct {
|
||||
UserId int
|
||||
ChannelId int
|
||||
PromptTokens int
|
||||
CompletionTokens int
|
||||
ModelName string
|
||||
TokenId int
|
||||
TokenName string
|
||||
Quota int64
|
||||
Content string
|
||||
PreConsumedQuota int64
|
||||
}
|
||||
|
||||
func (b *defaultBookkeeper) UserHasEnoughQuota(ctx context.Context, userID int, quota int64) bool {
|
||||
userQuota, err := model.CacheGetUserQuota(ctx, userID)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
return userQuota >= quota
|
||||
}
|
||||
|
||||
func (b *defaultBookkeeper) UserHasMuchMoreQuota(ctx context.Context, userID int, quota int64) bool {
|
||||
userQuota, err := model.CacheGetUserQuota(ctx, userID)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
return userQuota > 100*quota
|
||||
}
|
||||
|
||||
func (b *defaultBookkeeper) Consume(ctx context.Context, consumeLog *ConsumeLog) {
|
||||
// 更新 access_token 的配额
|
||||
quotaDelta := consumeLog.Quota - consumeLog.PreConsumedQuota
|
||||
err := model.PostConsumeTokenQuota(consumeLog.TokenId, quotaDelta)
|
||||
if err != nil {
|
||||
logger.SysError("error consuming token remain quota: " + err.Error())
|
||||
}
|
||||
err = model.CacheUpdateUserQuota(ctx, consumeLog.UserId)
|
||||
if err != nil {
|
||||
logger.SysError("error update user quota cache: " + err.Error())
|
||||
}
|
||||
// 更新用户的配额
|
||||
model.UpdateUserUsedQuotaAndRequestCount(consumeLog.UserId, consumeLog.Quota)
|
||||
// 更新渠道的配额
|
||||
model.UpdateChannelUsedQuota(consumeLog.ChannelId, consumeLog.Quota)
|
||||
// 记录消费日志
|
||||
model.RecordConsumeLog(
|
||||
ctx,
|
||||
consumeLog.UserId,
|
||||
consumeLog.ChannelId,
|
||||
consumeLog.PromptTokens,
|
||||
consumeLog.CompletionTokens,
|
||||
consumeLog.ModelName,
|
||||
consumeLog.TokenName,
|
||||
consumeLog.Quota,
|
||||
consumeLog.Content,
|
||||
)
|
||||
}
|
||||
|
||||
func (b *defaultBookkeeper) PreConsumeQuota(ctx context.Context, preConsumedQuota int64, userId, tokenId int) (int64, *relaymodel.ErrorWithStatusCode) {
|
||||
userQuota, err := model.CacheGetUserQuota(ctx, userId)
|
||||
if err != nil {
|
||||
return preConsumedQuota, openai.ErrorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError)
|
||||
}
|
||||
if userQuota-preConsumedQuota < 0 {
|
||||
return preConsumedQuota, openai.ErrorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
|
||||
}
|
||||
err = model.CacheDecreaseUserQuota(userId, preConsumedQuota)
|
||||
if err != nil {
|
||||
return preConsumedQuota, openai.ErrorWrapper(err, "decrease_user_quota_failed", http.StatusInternalServerError)
|
||||
}
|
||||
if userQuota > 100*preConsumedQuota {
|
||||
// in this case, we do not pre-consume quota
|
||||
// because the user has enough quota
|
||||
preConsumedQuota = 0
|
||||
logger.Info(ctx, fmt.Sprintf("user %d has enough quota %d, trusted and no need to pre-consume", userId, userQuota))
|
||||
}
|
||||
if preConsumedQuota > 0 {
|
||||
err := model.PreConsumeTokenQuota(tokenId, preConsumedQuota)
|
||||
if err != nil {
|
||||
return preConsumedQuota, openai.ErrorWrapper(err, "pre_consume_token_quota_failed", http.StatusForbidden)
|
||||
}
|
||||
}
|
||||
return preConsumedQuota, nil
|
||||
}
|
||||
|
||||
func (b *defaultBookkeeper) RefundQuota(ctx context.Context, preConsumedQuota int64, tokenId int) {
|
||||
if preConsumedQuota != 0 {
|
||||
err := model.PostConsumeTokenQuota(tokenId, -preConsumedQuota)
|
||||
if err != nil {
|
||||
logger.Error(ctx, "error return pre-consumed quota: "+err.Error())
|
||||
}
|
||||
}
|
||||
}
|
@ -3,10 +3,12 @@ package billing
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/songquanpeng/one-api/common/logger"
|
||||
"github.com/songquanpeng/one-api/model"
|
||||
)
|
||||
|
||||
// ReturnPreConsumedQuota 在请求失败的时候,退回预消费的配额
|
||||
func ReturnPreConsumedQuota(ctx context.Context, preConsumedQuota int64, tokenId int) {
|
||||
if preConsumedQuota != 0 {
|
||||
go func(ctx context.Context) {
|
||||
|
@ -7,36 +7,35 @@ import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/songquanpeng/one-api/common"
|
||||
"github.com/songquanpeng/one-api/common/audit"
|
||||
"github.com/songquanpeng/one-api/common/client"
|
||||
"github.com/songquanpeng/one-api/common/config"
|
||||
"github.com/songquanpeng/one-api/common/ctxkey"
|
||||
"github.com/songquanpeng/one-api/common/logger"
|
||||
"github.com/songquanpeng/one-api/model"
|
||||
"github.com/songquanpeng/one-api/common/helper"
|
||||
"github.com/songquanpeng/one-api/relay/adaptor/openai"
|
||||
"github.com/songquanpeng/one-api/relay/billing"
|
||||
billingratio "github.com/songquanpeng/one-api/relay/billing/ratio"
|
||||
"github.com/songquanpeng/one-api/relay/channeltype"
|
||||
"github.com/songquanpeng/one-api/relay/meta"
|
||||
relaymodel "github.com/songquanpeng/one-api/relay/model"
|
||||
"github.com/songquanpeng/one-api/relay/relaymode"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func RelayAudioHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatusCode {
|
||||
ctx := c.Request.Context()
|
||||
func (rl *defaultRelay) RelayAudioHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatusCode {
|
||||
meta := meta.GetByContext(c)
|
||||
audioModel := "whisper-1"
|
||||
|
||||
tokenId := c.GetInt(ctxkey.TokenId)
|
||||
channelType := c.GetInt(ctxkey.Channel)
|
||||
channelId := c.GetInt(ctxkey.ChannelId)
|
||||
// channelId := c.GetInt(ctxkey.ChannelId)
|
||||
userId := c.GetInt(ctxkey.Id)
|
||||
group := c.GetString(ctxkey.Group)
|
||||
tokenName := c.GetString(ctxkey.TokenName)
|
||||
// tokenName := c.GetString(ctxkey.TokenName)
|
||||
|
||||
var ttsRequest openai.TextToSpeechRequest
|
||||
if relayMode == relaymode.AudioSpeech {
|
||||
@ -53,58 +52,45 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus
|
||||
}
|
||||
}
|
||||
|
||||
modelRatio := billingratio.GetModelRatio(audioModel)
|
||||
groupRatio := billingratio.GetGroupRatio(group)
|
||||
ratio := modelRatio * groupRatio
|
||||
var quota int64
|
||||
var preConsumedQuota int64
|
||||
switch relayMode {
|
||||
case relaymode.AudioSpeech:
|
||||
preConsumedQuota = int64(float64(len(ttsRequest.Input)) * ratio)
|
||||
quota = preConsumedQuota
|
||||
default:
|
||||
preConsumedQuota = int64(float64(config.PreConsumedQuota) * ratio)
|
||||
}
|
||||
userQuota, err := model.CacheGetUserQuota(ctx, userId)
|
||||
if err != nil {
|
||||
return openai.ErrorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError)
|
||||
var (
|
||||
modelRatio float64
|
||||
groupRatio float64
|
||||
ratio float64
|
||||
quota int64
|
||||
preConsumeQuota int64
|
||||
preConsumedQuota int64
|
||||
bizErr *relaymodel.ErrorWithStatusCode
|
||||
)
|
||||
|
||||
if rl.Bookkeeper != nil {
|
||||
modelRatio = rl.ModelRatio(audioModel)
|
||||
groupRatio = rl.GroupRation(group)
|
||||
ratio = modelRatio * groupRatio
|
||||
}
|
||||
|
||||
// Check if user quota is enough
|
||||
if userQuota-preConsumedQuota < 0 {
|
||||
return openai.ErrorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
|
||||
switch relayMode {
|
||||
// speech 类型,消费的配额直接根据输入的文本长度计算
|
||||
case relaymode.AudioSpeech:
|
||||
preConsumeQuota = int64(float64(len(ttsRequest.Input)) * ratio)
|
||||
quota = preConsumeQuota
|
||||
// 其他类型,假设消费的配额是预设的配额的 ratio 倍
|
||||
default:
|
||||
preConsumeQuota = int64(float64(config.PreConsumedQuota) * ratio)
|
||||
}
|
||||
err = model.CacheDecreaseUserQuota(userId, preConsumedQuota)
|
||||
if err != nil {
|
||||
return openai.ErrorWrapper(err, "decrease_user_quota_failed", http.StatusInternalServerError)
|
||||
}
|
||||
if userQuota > 100*preConsumedQuota {
|
||||
// in this case, we do not pre-consume quota
|
||||
// because the user has enough quota
|
||||
preConsumedQuota = 0
|
||||
}
|
||||
if preConsumedQuota > 0 {
|
||||
err := model.PreConsumeTokenQuota(tokenId, preConsumedQuota)
|
||||
if err != nil {
|
||||
return openai.ErrorWrapper(err, "pre_consume_token_quota_failed", http.StatusForbidden)
|
||||
if rl.Bookkeeper != nil {
|
||||
preConsumedQuota, bizErr = rl.PreConsumeQuota(c, preConsumeQuota, userId, tokenId)
|
||||
if bizErr != nil {
|
||||
return bizErr
|
||||
}
|
||||
}
|
||||
|
||||
succeed := false
|
||||
defer func() {
|
||||
if succeed {
|
||||
return
|
||||
}
|
||||
if preConsumedQuota > 0 {
|
||||
// we need to roll back the pre-consumed quota
|
||||
defer func(ctx context.Context) {
|
||||
go func() {
|
||||
// negative means add quota back for token & user
|
||||
err := model.PostConsumeTokenQuota(tokenId, -preConsumedQuota)
|
||||
if err != nil {
|
||||
logger.Error(ctx, fmt.Sprintf("error rollback pre-consumed quota: %s", err.Error()))
|
||||
}
|
||||
}()
|
||||
}(c.Request.Context())
|
||||
if rl.Bookkeeper != nil {
|
||||
rl.Bookkeeper.RefundQuota(c.Request.Context(), preConsumedQuota, tokenId)
|
||||
}
|
||||
}()
|
||||
|
||||
@ -140,8 +126,7 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus
|
||||
}
|
||||
|
||||
requestBody := &bytes.Buffer{}
|
||||
_, err = io.Copy(requestBody, c.Request.Body)
|
||||
if err != nil {
|
||||
if _, err := io.Copy(requestBody, c.Request.Body); err != nil {
|
||||
return openai.ErrorWrapper(err, "new_request_body_failed", http.StatusInternalServerError)
|
||||
}
|
||||
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody.Bytes()))
|
||||
@ -151,6 +136,14 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus
|
||||
if err != nil {
|
||||
return openai.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError)
|
||||
}
|
||||
if config.UpstreamAuditEnabled {
|
||||
audit.Logger().
|
||||
WithField("stage", "upstream request").
|
||||
WithField("raw", audit.B64encode(requestBody.Bytes())).
|
||||
WithField("requestid", c.GetString(helper.RequestIdKey)).
|
||||
WithFields(meta.ToLogrusFields()).
|
||||
Info("upstream request")
|
||||
}
|
||||
|
||||
if (relayMode == relaymode.AudioTranscription || relayMode == relaymode.AudioSpeech) && channelType == channeltype.Azure {
|
||||
// https://learn.microsoft.com/en-us/azure/ai-services/openai/whisper-quickstart?tabs=command-line#rest-api
|
||||
@ -168,6 +161,17 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus
|
||||
if err != nil {
|
||||
return openai.ErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
|
||||
}
|
||||
if config.UpstreamAuditEnabled {
|
||||
buf := audit.CaptureHTTPResponseBody(resp)
|
||||
defer func() {
|
||||
audit.Logger().
|
||||
WithField("stage", "upstream response").
|
||||
WithField("raw", audit.B64encode(buf.Bytes())).
|
||||
WithField("requestid", c.GetString(helper.RequestIdKey)).
|
||||
WithFields(meta.ToLogrusFields()).
|
||||
Info("upstream response")
|
||||
}()
|
||||
}
|
||||
|
||||
err = req.Body.Close()
|
||||
if err != nil {
|
||||
@ -220,9 +224,28 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus
|
||||
return RelayErrorHandler(resp)
|
||||
}
|
||||
succeed = true
|
||||
quotaDelta := quota - preConsumedQuota
|
||||
// quotaDelta := quota - preConsumedQuota
|
||||
defer func(ctx context.Context) {
|
||||
go billing.PostConsumeQuota(ctx, tokenId, quotaDelta, quota, userId, channelId, modelRatio, groupRatio, audioModel, tokenName)
|
||||
// go billing.PostConsumeQuota(ctx, tokenId, quotaDelta, quota, userId, channelId, modelRatio, groupRatio, audioModel, tokenName)
|
||||
// post-consume quota
|
||||
if rl.Bookkeeper != nil {
|
||||
// go postConsumeQuota(c, usage, meta, textRequest, ratio, preConsumedQuota, modelRatio, groupRatio)
|
||||
|
||||
logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
|
||||
consumeLog := &billing.ConsumeLog{
|
||||
UserId: meta.UserId,
|
||||
ChannelId: meta.ChannelId,
|
||||
ModelName: audioModel,
|
||||
TokenName: c.GetString(ctxkey.TokenName),
|
||||
TokenId: meta.TokenId,
|
||||
Quota: quota,
|
||||
Content: logContent,
|
||||
PromptTokens: int(preConsumeQuota),
|
||||
CompletionTokens: 0,
|
||||
PreConsumedQuota: preConsumedQuota,
|
||||
}
|
||||
rl.Bookkeeper.Consume(c, consumeLog)
|
||||
}
|
||||
}(c.Request.Context())
|
||||
|
||||
for k, v := range resp.Header {
|
||||
|
30
relay/controller/controller.go
Normal file
30
relay/controller/controller.go
Normal file
@ -0,0 +1,30 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/songquanpeng/one-api/relay/billing"
|
||||
"github.com/songquanpeng/one-api/relay/model"
|
||||
)
|
||||
|
||||
type Options struct {
|
||||
EnableBilling bool
|
||||
}
|
||||
|
||||
// RelayInstance is the interface for relay controller
|
||||
type RelayInstance interface {
|
||||
RelayTextHelper(c *gin.Context) *model.ErrorWithStatusCode
|
||||
RelayImageHelper(c *gin.Context, relayMode int) *model.ErrorWithStatusCode
|
||||
RelayAudioHelper(c *gin.Context, relayMode int) *model.ErrorWithStatusCode
|
||||
}
|
||||
|
||||
type defaultRelay struct {
|
||||
billing.Bookkeeper
|
||||
}
|
||||
|
||||
func NewRelayInstance(opts Options) RelayInstance {
|
||||
relay := &defaultRelay{}
|
||||
if opts.EnableBilling {
|
||||
relay.Bookkeeper = billing.NewBookkeeper()
|
||||
}
|
||||
return relay
|
||||
}
|
@ -4,6 +4,10 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/songquanpeng/one-api/common"
|
||||
"github.com/songquanpeng/one-api/common/config"
|
||||
@ -16,9 +20,6 @@ import (
|
||||
"github.com/songquanpeng/one-api/relay/meta"
|
||||
relaymodel "github.com/songquanpeng/one-api/relay/model"
|
||||
"github.com/songquanpeng/one-api/relay/relaymode"
|
||||
"math"
|
||||
"net/http"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func getAndValidateTextRequest(c *gin.Context, relayMode int) (*relaymodel.GeneralOpenAIRequest, error) {
|
||||
@ -208,10 +209,7 @@ func getMappedModelName(modelName string, mapping map[string]string) (string, bo
|
||||
|
||||
func isErrorHappened(meta *meta.Meta, resp *http.Response) bool {
|
||||
if resp == nil {
|
||||
if meta.ChannelType == channeltype.AwsClaude {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
return meta.ChannelType != channeltype.Azure
|
||||
}
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return true
|
||||
|
@ -4,20 +4,23 @@ import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/songquanpeng/one-api/common/audit"
|
||||
"github.com/songquanpeng/one-api/common/config"
|
||||
"github.com/songquanpeng/one-api/common/ctxkey"
|
||||
"github.com/songquanpeng/one-api/common/helper"
|
||||
"github.com/songquanpeng/one-api/common/logger"
|
||||
"github.com/songquanpeng/one-api/model"
|
||||
"github.com/songquanpeng/one-api/relay"
|
||||
"github.com/songquanpeng/one-api/relay/adaptor/openai"
|
||||
"github.com/songquanpeng/one-api/relay/billing"
|
||||
billingratio "github.com/songquanpeng/one-api/relay/billing/ratio"
|
||||
"github.com/songquanpeng/one-api/relay/channeltype"
|
||||
"github.com/songquanpeng/one-api/relay/meta"
|
||||
relaymodel "github.com/songquanpeng/one-api/relay/model"
|
||||
"io"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
func isWithinRange(element string, value int) bool {
|
||||
@ -29,33 +32,40 @@ func isWithinRange(element string, value int) bool {
|
||||
return value >= min && value <= max
|
||||
}
|
||||
|
||||
func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatusCode {
|
||||
ctx := c.Request.Context()
|
||||
func (rl *defaultRelay) RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatusCode {
|
||||
meta := meta.GetByContext(c)
|
||||
imageRequest, err := getImageRequest(c, meta.Mode)
|
||||
if err != nil {
|
||||
logger.Errorf(ctx, "getImageRequest failed: %s", err.Error())
|
||||
logger.Errorf(c, "getImageRequest failed: %s", err.Error())
|
||||
return openai.ErrorWrapper(err, "invalid_image_request", http.StatusBadRequest)
|
||||
}
|
||||
|
||||
// map model name
|
||||
var isModelMapped bool
|
||||
var (
|
||||
isModelMapped bool
|
||||
preConsumeQuota int64
|
||||
preConsumedQuota int64
|
||||
imageCostRatio float64
|
||||
bizErr *relaymodel.ErrorWithStatusCode
|
||||
)
|
||||
meta.OriginModelName = imageRequest.Model
|
||||
imageRequest.Model, isModelMapped = getMappedModelName(imageRequest.Model, meta.ModelMapping)
|
||||
meta.ActualModelName = imageRequest.Model
|
||||
|
||||
// model validation
|
||||
bizErr := validateImageRequest(imageRequest, meta)
|
||||
bizErr = validateImageRequest(imageRequest, meta)
|
||||
if bizErr != nil {
|
||||
return bizErr
|
||||
}
|
||||
|
||||
imageCostRatio, err := getImageCostRatio(imageRequest)
|
||||
if err != nil {
|
||||
return openai.ErrorWrapper(err, "get_image_cost_ratio_failed", http.StatusInternalServerError)
|
||||
if rl.Bookkeeper != nil {
|
||||
imageCostRatio, err = getImageCostRatio(imageRequest)
|
||||
if err != nil {
|
||||
return openai.ErrorWrapper(err, "get_image_cost_ratio_failed", http.StatusInternalServerError)
|
||||
}
|
||||
}
|
||||
|
||||
imageModel := imageRequest.Model
|
||||
originModel := imageRequest.Model
|
||||
// Convert the original image model
|
||||
imageRequest.Model, _ = getMappedModelName(imageRequest.Model, billingratio.ImageOriginModelName)
|
||||
c.Set("response_format", imageRequest.ResponseFormat)
|
||||
@ -94,51 +104,89 @@ func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus
|
||||
requestBody = bytes.NewBuffer(jsonStr)
|
||||
}
|
||||
|
||||
modelRatio := billingratio.GetModelRatio(imageModel)
|
||||
groupRatio := billingratio.GetGroupRatio(meta.Group)
|
||||
ratio := modelRatio * groupRatio
|
||||
userQuota, err := model.CacheGetUserQuota(ctx, meta.UserId)
|
||||
|
||||
quota := int64(ratio*imageCostRatio*1000) * int64(imageRequest.N)
|
||||
|
||||
if userQuota-quota < 0 {
|
||||
return openai.ErrorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
|
||||
if rl.Bookkeeper != nil {
|
||||
modelRatio := rl.ModelRatio(originModel)
|
||||
groupRatio := rl.GroupRation(meta.Group)
|
||||
ratio := modelRatio * groupRatio
|
||||
preConsumeQuota = int64(ratio*imageCostRatio*1000) * int64(imageRequest.N)
|
||||
preConsumedQuota, bizErr = rl.PreConsumeQuota(c, preConsumeQuota, meta.UserId, meta.TokenId)
|
||||
if bizErr != nil {
|
||||
logger.Warnf(c, "preConsumeQuota failed: %+v", *bizErr)
|
||||
return bizErr
|
||||
}
|
||||
}
|
||||
|
||||
refund := func() {
|
||||
if rl.Bookkeeper != nil && preConsumedQuota > 0 {
|
||||
rl.RefundQuota(c, preConsumedQuota, meta.TokenId)
|
||||
}
|
||||
}
|
||||
if config.UpstreamAuditEnabled {
|
||||
buf := bytes.Buffer{}
|
||||
requestBody = io.TeeReader(requestBody, &buf)
|
||||
defer func() {
|
||||
audit.Logger().
|
||||
WithField("stage", "upstream request").
|
||||
WithField("raw", audit.B64encode(buf.Bytes())).
|
||||
WithField("requestid", c.GetString(helper.RequestIdKey)).
|
||||
WithFields(meta.ToLogrusFields()).
|
||||
Info("upstream request")
|
||||
}()
|
||||
}
|
||||
// do request
|
||||
resp, err := adaptor.DoRequest(c, meta, requestBody)
|
||||
if err != nil {
|
||||
logger.Errorf(ctx, "DoRequest failed: %s", err.Error())
|
||||
logger.Errorf(c, "DoRequest failed: %s", err.Error())
|
||||
refund()
|
||||
return openai.ErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
|
||||
}
|
||||
if config.UpstreamAuditEnabled {
|
||||
buf := audit.CaptureHTTPResponseBody(resp)
|
||||
defer func() {
|
||||
audit.Logger().
|
||||
WithField("stage", "upstream response").
|
||||
WithField("raw", audit.B64encode(buf.Bytes())).
|
||||
WithField("requestid", c.GetString(helper.RequestIdKey)).
|
||||
WithFields(meta.ToLogrusFields()).
|
||||
Info("upstream response")
|
||||
}()
|
||||
}
|
||||
|
||||
defer func(ctx context.Context) {
|
||||
if resp != nil && resp.StatusCode != http.StatusOK {
|
||||
return
|
||||
}
|
||||
if rl.Bookkeeper == nil {
|
||||
return
|
||||
}
|
||||
modelRatio := rl.ModelRatio(originModel)
|
||||
groupRatio := rl.GroupRation(meta.Group)
|
||||
ratio := modelRatio * groupRatio
|
||||
consumedQuota := int64(ratio*imageCostRatio*1000) * int64(imageRequest.N)
|
||||
|
||||
err := model.PostConsumeTokenQuota(meta.TokenId, quota)
|
||||
if err != nil {
|
||||
logger.SysError("error consuming token remain quota: " + err.Error())
|
||||
}
|
||||
err = model.CacheUpdateUserQuota(ctx, meta.UserId)
|
||||
if err != nil {
|
||||
logger.SysError("error update user quota cache: " + err.Error())
|
||||
}
|
||||
if quota != 0 {
|
||||
if consumedQuota != 0 {
|
||||
tokenName := c.GetString(ctxkey.TokenName)
|
||||
logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
|
||||
model.RecordConsumeLog(ctx, meta.UserId, meta.ChannelId, 0, 0, imageRequest.Model, tokenName, quota, logContent)
|
||||
model.UpdateUserUsedQuotaAndRequestCount(meta.UserId, quota)
|
||||
channelId := c.GetInt(ctxkey.ChannelId)
|
||||
model.UpdateChannelUsedQuota(channelId, quota)
|
||||
consumeLog := &billing.ConsumeLog{
|
||||
UserId: meta.UserId,
|
||||
ChannelId: meta.ChannelId,
|
||||
ModelName: imageRequest.Model,
|
||||
TokenName: tokenName,
|
||||
TokenId: meta.TokenId,
|
||||
Quota: consumedQuota,
|
||||
Content: logContent,
|
||||
PromptTokens: 0,
|
||||
CompletionTokens: 0,
|
||||
PreConsumedQuota: preConsumedQuota,
|
||||
}
|
||||
rl.Bookkeeper.Consume(c, consumeLog)
|
||||
}
|
||||
}(c.Request.Context())
|
||||
}(c)
|
||||
|
||||
// do response
|
||||
_, respErr := adaptor.DoResponse(c, resp, meta)
|
||||
if respErr != nil {
|
||||
logger.Errorf(ctx, "respErr is not nil: %+v", respErr)
|
||||
logger.Errorf(c, "respErr is not nil: %+v", respErr)
|
||||
return respErr
|
||||
}
|
||||
|
||||
|
@ -4,27 +4,30 @@ import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/songquanpeng/one-api/common/audit"
|
||||
"github.com/songquanpeng/one-api/common/config"
|
||||
"github.com/songquanpeng/one-api/common/ctxkey"
|
||||
"github.com/songquanpeng/one-api/common/helper"
|
||||
"github.com/songquanpeng/one-api/common/logger"
|
||||
"github.com/songquanpeng/one-api/relay"
|
||||
"github.com/songquanpeng/one-api/relay/adaptor/openai"
|
||||
"github.com/songquanpeng/one-api/relay/apitype"
|
||||
"github.com/songquanpeng/one-api/relay/billing"
|
||||
billingratio "github.com/songquanpeng/one-api/relay/billing/ratio"
|
||||
"github.com/songquanpeng/one-api/relay/channeltype"
|
||||
"github.com/songquanpeng/one-api/relay/meta"
|
||||
"github.com/songquanpeng/one-api/relay/model"
|
||||
"io"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
func RelayTextHelper(c *gin.Context) *model.ErrorWithStatusCode {
|
||||
ctx := c.Request.Context()
|
||||
func (rl *defaultRelay) RelayTextHelper(c *gin.Context) *model.ErrorWithStatusCode {
|
||||
meta := meta.GetByContext(c)
|
||||
// get & validate textRequest
|
||||
textRequest, err := getAndValidateTextRequest(c, meta.Mode)
|
||||
if err != nil {
|
||||
logger.Errorf(ctx, "getAndValidateTextRequest failed: %s", err.Error())
|
||||
logger.Errorf(c, "getAndValidateTextRequest failed: %s", err.Error())
|
||||
return openai.ErrorWrapper(err, "invalid_text_request", http.StatusBadRequest)
|
||||
}
|
||||
meta.IsStream = textRequest.Stream
|
||||
@ -35,18 +38,26 @@ func RelayTextHelper(c *gin.Context) *model.ErrorWithStatusCode {
|
||||
textRequest.Model, isModelMapped = getMappedModelName(textRequest.Model, meta.ModelMapping)
|
||||
meta.ActualModelName = textRequest.Model
|
||||
// get model ratio & group ratio
|
||||
modelRatio := billingratio.GetModelRatio(textRequest.Model)
|
||||
groupRatio := billingratio.GetGroupRatio(meta.Group)
|
||||
ratio := modelRatio * groupRatio
|
||||
// pre-consume quota
|
||||
promptTokens := getPromptTokens(textRequest, meta.Mode)
|
||||
meta.PromptTokens = promptTokens
|
||||
preConsumedQuota, bizErr := preConsumeQuota(ctx, textRequest, promptTokens, ratio, meta)
|
||||
if bizErr != nil {
|
||||
logger.Warnf(ctx, "preConsumeQuota failed: %+v", *bizErr)
|
||||
return bizErr
|
||||
var (
|
||||
preConsumedQuota int64
|
||||
modelRatio float64
|
||||
groupRatio float64
|
||||
ratio float64
|
||||
)
|
||||
if rl.Bookkeeper != nil {
|
||||
modelRatio = rl.ModelRatio(textRequest.Model)
|
||||
groupRatio = rl.GroupRation(meta.Group)
|
||||
ratio = modelRatio * groupRatio
|
||||
// pre-consume quota
|
||||
meta.PromptTokens = getPromptTokens(textRequest, meta.Mode)
|
||||
preConsumeQuota := getPreConsumedQuota(textRequest, meta.PromptTokens, ratio)
|
||||
consumedQuota, bizErr := rl.PreConsumeQuota(c, preConsumeQuota, meta.UserId, meta.TokenId)
|
||||
if bizErr != nil {
|
||||
logger.Warnf(c, "preConsumeQuota failed: %+v", *bizErr)
|
||||
return bizErr
|
||||
}
|
||||
preConsumedQuota = consumedQuota
|
||||
}
|
||||
|
||||
adaptor := relay.GetAdaptor(meta.APIType)
|
||||
if adaptor == nil {
|
||||
return openai.ErrorWrapper(fmt.Errorf("invalid api type: %d", meta.APIType), "invalid_api_type", http.StatusBadRequest)
|
||||
@ -76,29 +87,75 @@ func RelayTextHelper(c *gin.Context) *model.ErrorWithStatusCode {
|
||||
if err != nil {
|
||||
return openai.ErrorWrapper(err, "json_marshal_failed", http.StatusInternalServerError)
|
||||
}
|
||||
logger.Debugf(ctx, "converted request: \n%s", string(jsonData))
|
||||
logger.Debugf(c, "converted request: \n%s", string(jsonData))
|
||||
requestBody = bytes.NewBuffer(jsonData)
|
||||
}
|
||||
|
||||
if config.UpstreamAuditEnabled {
|
||||
buf := bytes.Buffer{}
|
||||
requestBody = io.TeeReader(requestBody, &buf)
|
||||
defer func() {
|
||||
audit.Logger().
|
||||
WithField("stage", "upstream request").
|
||||
WithField("raw", audit.B64encode(buf.Bytes())).
|
||||
WithField("requestid", c.GetString(helper.RequestIdKey)).
|
||||
WithFields(meta.ToLogrusFields()).
|
||||
Info("upstream request")
|
||||
}()
|
||||
}
|
||||
|
||||
// do request
|
||||
resp, err := adaptor.DoRequest(c, meta, requestBody)
|
||||
if err != nil {
|
||||
logger.Errorf(ctx, "DoRequest failed: %s", err.Error())
|
||||
logger.Errorf(c, "DoRequest failed: %s", err.Error())
|
||||
return openai.ErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
|
||||
}
|
||||
if config.UpstreamAuditEnabled {
|
||||
buf := audit.CaptureHTTPResponseBody(resp)
|
||||
defer func() {
|
||||
audit.Logger().
|
||||
WithField("stage", "upstream response").
|
||||
WithField("raw", audit.B64encode(buf.Bytes())).
|
||||
WithField("requestid", c.GetString(helper.RequestIdKey)).
|
||||
WithFields(meta.ToLogrusFields()).
|
||||
Info("upstream response")
|
||||
}()
|
||||
}
|
||||
refund := func() {
|
||||
if rl.Bookkeeper != nil && preConsumedQuota > 0 {
|
||||
rl.RefundQuota(c, preConsumedQuota, meta.TokenId)
|
||||
}
|
||||
}
|
||||
if isErrorHappened(meta, resp) {
|
||||
billing.ReturnPreConsumedQuota(ctx, preConsumedQuota, meta.TokenId)
|
||||
refund()
|
||||
return RelayErrorHandler(resp)
|
||||
}
|
||||
|
||||
// do response
|
||||
usage, respErr := adaptor.DoResponse(c, resp, meta)
|
||||
if respErr != nil {
|
||||
logger.Errorf(ctx, "respErr is not nil: %+v", respErr)
|
||||
billing.ReturnPreConsumedQuota(ctx, preConsumedQuota, meta.TokenId)
|
||||
logger.Errorf(c, "respErr is not nil: %+v", respErr)
|
||||
refund()
|
||||
return respErr
|
||||
}
|
||||
// post-consume quota
|
||||
go postConsumeQuota(ctx, usage, meta, textRequest, ratio, preConsumedQuota, modelRatio, groupRatio)
|
||||
if rl.Bookkeeper != nil {
|
||||
// go postConsumeQuota(c, usage, meta, textRequest, ratio, preConsumedQuota, modelRatio, groupRatio)
|
||||
completionRatio := rl.ModelCompletionRatio(textRequest.Model)
|
||||
logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f,补全倍率 %.2f", modelRatio, groupRatio, completionRatio)
|
||||
consumeLog := &billing.ConsumeLog{
|
||||
UserId: meta.UserId,
|
||||
ChannelId: meta.ChannelId,
|
||||
ModelName: textRequest.Model,
|
||||
TokenName: c.GetString(ctxkey.TokenName),
|
||||
TokenId: meta.TokenId,
|
||||
Quota: usage.Quota(completionRatio, ratio),
|
||||
Content: logContent,
|
||||
PromptTokens: usage.PromptTokens,
|
||||
CompletionTokens: usage.CompletionTokens,
|
||||
PreConsumedQuota: preConsumedQuota,
|
||||
}
|
||||
rl.Bookkeeper.Consume(c, consumeLog)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
@ -1,12 +1,13 @@
|
||||
package meta
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/songquanpeng/one-api/common/ctxkey"
|
||||
"github.com/songquanpeng/one-api/model"
|
||||
"github.com/songquanpeng/one-api/relay/channeltype"
|
||||
"github.com/songquanpeng/one-api/relay/relaymode"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type Meta struct {
|
||||
@ -29,6 +30,29 @@ type Meta struct {
|
||||
PromptTokens int // only for DoResponse
|
||||
}
|
||||
|
||||
func (m *Meta) ToLogrusFields() map[string]interface{} {
|
||||
return map[string]interface{}{
|
||||
"mode": m.Mode,
|
||||
"channel_type": m.ChannelType,
|
||||
"channel_id": m.ChannelId,
|
||||
"token_id": m.TokenId,
|
||||
"token_name": m.TokenName,
|
||||
"user_id": m.UserId,
|
||||
"group": m.Group,
|
||||
"model_mapping": m.ModelMapping,
|
||||
"base_url": m.BaseURL,
|
||||
"api_key": m.APIKey,
|
||||
"api_type": m.APIType,
|
||||
"config": m.Config,
|
||||
"is_stream": m.IsStream,
|
||||
"origin_model_name": m.OriginModelName,
|
||||
"actual_model_name": m.ActualModelName,
|
||||
"request_url_path": m.RequestURLPath,
|
||||
"prompt_tokens": m.PromptTokens,
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func GetByContext(c *gin.Context) *Meta {
|
||||
meta := Meta{
|
||||
Mode: relaymode.GetByPath(c.Request.URL.Path),
|
||||
|
@ -1,11 +1,21 @@
|
||||
package model
|
||||
|
||||
import "math"
|
||||
|
||||
type Usage struct {
|
||||
PromptTokens int `json:"prompt_tokens"`
|
||||
CompletionTokens int `json:"completion_tokens"`
|
||||
TotalTokens int `json:"total_tokens"`
|
||||
}
|
||||
|
||||
func (u *Usage) Quota(completionRatio, finalRatio float64) int64 {
|
||||
quota := int64(math.Ceil((float64(u.PromptTokens) + float64(u.CompletionTokens)*completionRatio) * finalRatio))
|
||||
if finalRatio != 0 && quota <= 0 {
|
||||
quota = 1
|
||||
}
|
||||
return quota
|
||||
}
|
||||
|
||||
type Error struct {
|
||||
Message string `json:"message"`
|
||||
Type string `json:"type"`
|
||||
|
@ -1,6 +1,7 @@
|
||||
package router
|
||||
|
||||
import (
|
||||
"github.com/songquanpeng/one-api/common/config"
|
||||
"github.com/songquanpeng/one-api/controller"
|
||||
"github.com/songquanpeng/one-api/middleware"
|
||||
|
||||
@ -17,19 +18,26 @@ func SetRelayRouter(router *gin.Engine) {
|
||||
modelsRouter.GET("/:model", controller.RetrieveModel)
|
||||
}
|
||||
relayV1Router := router.Group("/v1")
|
||||
opt := controller.Options{
|
||||
EnableMonitor: config.EnableMetric,
|
||||
EnableBilling: config.EnableBilling,
|
||||
Debug: config.DebugEnabled,
|
||||
}
|
||||
ctrl := controller.NewRelayController(opt)
|
||||
|
||||
relayV1Router.Use(middleware.RelayPanicRecover(), middleware.TokenAuth(), middleware.Distribute())
|
||||
{
|
||||
relayV1Router.POST("/completions", controller.Relay)
|
||||
relayV1Router.POST("/chat/completions", controller.Relay)
|
||||
relayV1Router.POST("/edits", controller.Relay)
|
||||
relayV1Router.POST("/images/generations", controller.Relay)
|
||||
relayV1Router.POST("/completions", ctrl.Relay)
|
||||
relayV1Router.POST("/chat/completions", ctrl.Relay)
|
||||
relayV1Router.POST("/edits", ctrl.Relay)
|
||||
relayV1Router.POST("/images/generations", ctrl.Relay)
|
||||
relayV1Router.POST("/images/edits", controller.RelayNotImplemented)
|
||||
relayV1Router.POST("/images/variations", controller.RelayNotImplemented)
|
||||
relayV1Router.POST("/embeddings", controller.Relay)
|
||||
relayV1Router.POST("/engines/:model/embeddings", controller.Relay)
|
||||
relayV1Router.POST("/audio/transcriptions", controller.Relay)
|
||||
relayV1Router.POST("/audio/translations", controller.Relay)
|
||||
relayV1Router.POST("/audio/speech", controller.Relay)
|
||||
relayV1Router.POST("/embeddings", ctrl.Relay)
|
||||
relayV1Router.POST("/engines/:model/embeddings", ctrl.Relay)
|
||||
relayV1Router.POST("/audio/transcriptions", ctrl.Relay)
|
||||
relayV1Router.POST("/audio/translations", ctrl.Relay)
|
||||
relayV1Router.POST("/audio/speech", ctrl.Relay)
|
||||
relayV1Router.GET("/files", controller.RelayNotImplemented)
|
||||
relayV1Router.POST("/files", controller.RelayNotImplemented)
|
||||
relayV1Router.DELETE("/files/:id", controller.RelayNotImplemented)
|
||||
@ -41,7 +49,7 @@ func SetRelayRouter(router *gin.Engine) {
|
||||
relayV1Router.POST("/fine_tuning/jobs/:id/cancel", controller.RelayNotImplemented)
|
||||
relayV1Router.GET("/fine_tuning/jobs/:id/events", controller.RelayNotImplemented)
|
||||
relayV1Router.DELETE("/models/:model", controller.RelayNotImplemented)
|
||||
relayV1Router.POST("/moderations", controller.Relay)
|
||||
relayV1Router.POST("/moderations", ctrl.Relay)
|
||||
relayV1Router.POST("/assistants", controller.RelayNotImplemented)
|
||||
relayV1Router.GET("/assistants/:id", controller.RelayNotImplemented)
|
||||
relayV1Router.POST("/assistants/:id", controller.RelayNotImplemented)
|
||||
|
Loading…
Reference in New Issue
Block a user