From 563b6e123f48fdb95eb5a17b163aa9ce7e8fdcb7 Mon Sep 17 00:00:00 2001 From: "yu.deng" Date: Thu, 13 Jun 2024 18:03:57 +0800 Subject: [PATCH] =?UTF-8?q?=E6=B7=BB=E5=8A=A0=E5=AE=A1=E8=AE=A1=E5=BC=80?= =?UTF-8?q?=E5=85=B3?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- common/audit/audit.go | 27 +++++++++++++ common/audit/response.go | 79 +++++++++++++++++++++++++++++++++++++++ common/config/config.go | 2 + controller/relay.go | 23 ++++++++++++ go.mod | 5 +++ go.sum | 12 ++++++ relay/controller/audio.go | 21 +++++++++++ relay/controller/image.go | 26 +++++++++++++ relay/controller/text.go | 27 +++++++++++++ relay/meta/relay_meta.go | 23 ++++++++++++ 10 files changed, 245 insertions(+) create mode 100644 common/audit/audit.go create mode 100644 common/audit/response.go diff --git a/common/audit/audit.go b/common/audit/audit.go new file mode 100644 index 00000000..2af7d5f8 --- /dev/null +++ b/common/audit/audit.go @@ -0,0 +1,27 @@ +package audit + +import ( + "github.com/sirupsen/logrus" + "gopkg.in/natefinch/lumberjack.v2" +) + +var ( + loger *lumberjack.Logger + logger *logrus.Logger +) + +func init() { + loger = &lumberjack.Logger{ + Filename: "logs/audit.log", + MaxSize: 50, // megabytes + MaxBackups: 300, + MaxAge: 90, // days + } + logger = logrus.New() + logger.SetOutput(loger) + logger.SetFormatter(&logrus.JSONFormatter{}) +} + +func Logger() *logrus.Logger { + return logger +} diff --git a/common/audit/response.go b/common/audit/response.go new file mode 100644 index 00000000..c3cd64fc --- /dev/null +++ b/common/audit/response.go @@ -0,0 +1,79 @@ +package audit + +import ( + "bytes" + "encoding/base64" + "io" + "net/http" + "strings" + + "github.com/gin-gonic/gin" + "github.com/tidwall/gjson" +) + +type AuditLogger struct { + gin.ResponseWriter + buf *bytes.Buffer +} + +func (l *AuditLogger) Write(p []byte) (int, error) { + l.buf.Write(p) + return l.ResponseWriter.Write(p) +} + +func CaptureResponseBody(c *gin.Context) *bytes.Buffer { + al := &AuditLogger{ + ResponseWriter: c.Writer, + buf: &bytes.Buffer{}, + } + c.Writer = al + return al.buf +} + +func B64encode(data []byte) string { + return base64.StdEncoding.EncodeToString(data) +} + +type AuditReadCloser struct { + Reader io.Reader + Closer io.Closer + Buffer *bytes.Buffer +} + +func (arc *AuditReadCloser) Read(p []byte) (int, error) { + n, err := arc.Reader.Read(p) + if n > 0 { + arc.Buffer.Write(p[:n]) + } + return n, err +} + +func (arc *AuditReadCloser) Close() error { + return arc.Closer.Close() +} + +func CaptureHTTPResponseBody(resp *http.Response) *bytes.Buffer { + buf := &bytes.Buffer{} + arc := &AuditReadCloser{ + Reader: resp.Body, + Closer: resp.Body, + Buffer: buf, + } + resp.Body = arc + return buf +} + +func ParseOPENAIStreamResponse(buf *bytes.Buffer) string { + lines := strings.Split(buf.String(), "\n") + bts := []string{} + for _, line := range lines { + line = strings.TrimSpace(line) + line = strings.Trim(line, "\n") + if strings.HasPrefix(string(line), "data:") { + line = line[5:] + } + content := gjson.Get(line, "choices.0.delta.content").String() + bts = append(bts, content) + } + return strings.Join(bts, "") +} diff --git a/common/config/config.go b/common/config/config.go index c4edecdd..9423f37a 100644 --- a/common/config/config.go +++ b/common/config/config.go @@ -56,6 +56,8 @@ var EmailDomainWhitelist = []string{ var DebugEnabled = strings.ToLower(os.Getenv("DEBUG")) == "true" var DebugSQLEnabled = strings.ToLower(os.Getenv("DEBUG_SQL")) == "true" var MemoryCacheEnabled = strings.ToLower(os.Getenv("MEMORY_CACHE_ENABLED")) == "true" +var ClientAuditEnabled = env.Bool("CLIENT_AUDIT_ENABLED", false) +var UpstreamAuditEnabled = env.Bool("UPSTREAM_AUDIT_ENABLED", false) var LogConsumeEnabled = true diff --git a/controller/relay.go b/controller/relay.go index 048e694c..77cf9df6 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -9,6 +9,7 @@ import ( "github.com/gin-gonic/gin" "github.com/songquanpeng/one-api/common" + "github.com/songquanpeng/one-api/common/audit" "github.com/songquanpeng/one-api/common/config" "github.com/songquanpeng/one-api/common/ctxkey" "github.com/songquanpeng/one-api/common/helper" @@ -17,6 +18,7 @@ import ( dbmodel "github.com/songquanpeng/one-api/model" "github.com/songquanpeng/one-api/monitor" "github.com/songquanpeng/one-api/relay/controller" + "github.com/songquanpeng/one-api/relay/meta" "github.com/songquanpeng/one-api/relay/model" "github.com/songquanpeng/one-api/relay/relaymode" ) @@ -49,6 +51,18 @@ func NewRelayController(opts Options) *RelayController { } func (ctrl *RelayController) relayHelper(c *gin.Context, relayMode int) *model.ErrorWithStatusCode { + if config.ClientAuditEnabled { + buf := audit.CaptureResponseBody(c) + m := meta.GetByContext(c) + defer func() { + audit.Logger(). + WithField("raw", audit.B64encode(buf.Bytes())). + WithField("parsed", audit.ParseOPENAIStreamResponse(buf)). + WithField("requestid", c.GetString(helper.RequestIdKey)). + WithFields(m.ToLogrusFields()). + Info("client response") + }() + } var err *model.ErrorWithStatusCode switch relayMode { case relaymode.ImagesGenerations: @@ -72,6 +86,15 @@ func (ctrl *RelayController) Relay(c *gin.Context) { requestBody, _ := common.GetRequestBody(c) logger.Debugf(ctx, "request body: %s", string(requestBody)) } + if config.ClientAuditEnabled { + requestBody, _ := common.GetRequestBody(c) + m := meta.GetByContext(c) + audit.Logger(). + WithField("raw", audit.B64encode(requestBody)). + WithField("requestid", c.GetString(helper.RequestIdKey)). + WithFields(m.ToLogrusFields()). + Info("client request") + } channelId := c.GetInt(ctxkey.ChannelId) bizErr := ctrl.relayHelper(c, relayMode) if bizErr == nil { diff --git a/go.mod b/go.mod index 1ed937ae..8fb66d83 100644 --- a/go.mod +++ b/go.mod @@ -20,10 +20,13 @@ require ( github.com/jinzhu/copier v0.4.0 github.com/pkg/errors v0.9.1 github.com/pkoukk/tiktoken-go v0.1.7 + github.com/sirupsen/logrus v1.8.1 github.com/smartystreets/goconvey v1.8.1 github.com/stretchr/testify v1.9.0 + github.com/tidwall/gjson v1.17.1 golang.org/x/crypto v0.23.0 golang.org/x/image v0.16.0 + gopkg.in/natefinch/lumberjack.v2 v2.2.1 gorm.io/driver/mysql v1.5.6 gorm.io/driver/postgres v1.5.7 gorm.io/driver/sqlite v1.5.5 @@ -73,6 +76,8 @@ require ( github.com/pelletier/go-toml/v2 v2.2.2 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/smarty/assertions v1.15.0 // indirect + github.com/tidwall/match v1.1.1 // indirect + github.com/tidwall/pretty v1.2.0 // indirect github.com/twitchyliquid64/golang-asm v0.15.1 // indirect github.com/ugorji/go/codec v1.2.12 // indirect golang.org/x/arch v0.8.0 // indirect diff --git a/go.sum b/go.sum index 117fb587..56ecf850 100644 --- a/go.sum +++ b/go.sum @@ -127,6 +127,8 @@ github.com/pkoukk/tiktoken-go v0.1.7/go.mod h1:9NiV+i9mJKGj1rYOT+njbv+ZwA/zJxYde github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/rogpeppe/go-internal v1.8.0 h1:FCbCCtXNOY3UtUuHUYaghJg4y7Fd14rXifAYUAtL9R8= +github.com/sirupsen/logrus v1.8.1 h1:dJKuHgqk1NNQlqoA6BTlM1Wf9DOH3NBjQyu0h9+AZZE= +github.com/sirupsen/logrus v1.8.1/go.mod h1:yWOB1SBYBC5VeMP7gHvWumXLIWorT60ONWic61uBYv0= github.com/smarty/assertions v1.15.0 h1:cR//PqUBUiQRakZWqBiFFQ9wb8emQGDb0HeGdqGByCY= github.com/smarty/assertions v1.15.0/go.mod h1:yABtdzeQs6l1brC900WlRNwj6ZR55d7B+E8C6HtKdec= github.com/smartystreets/goconvey v1.8.1 h1:qGjIddxOk4grTu9JPOU31tVfq3cNdBlNa5sSznIX1xY= @@ -135,6 +137,7 @@ github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+ github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA= +github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= @@ -143,6 +146,12 @@ github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +github.com/tidwall/gjson v1.17.1 h1:wlYEnwqAHgzmhNUFfw7Xalt2JzQvsMx2Se4PcoFCT/U= +github.com/tidwall/gjson v1.17.1/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= +github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA= +github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM= +github.com/tidwall/pretty v1.2.0 h1:RWIZEg2iJ8/g6fDDYzMpobmaoGh5OLl4AXtGUGPcqCs= +github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI= github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08= github.com/ugorji/go/codec v1.2.12 h1:9LC83zGrHhuUA9l16C9AHXAqEV/2wBQ4nkvumAE65EE= @@ -158,6 +167,7 @@ golang.org/x/net v0.25.0 h1:d/OCCoBEUq33pjydKrGQhw7IlUPI2Oylr+8qLx49kac= golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM= golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M= golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= +golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.20.0 h1:Od9JTbYCk261bKm4M/mw7AklTlFYIa0bIp9BgSm1S8Y= @@ -169,6 +179,8 @@ google.golang.org/protobuf v1.34.1 h1:9ddQBjfCyZPOHPUiPxpYESBLc+T8P3E+Vo4IbKZgFW google.golang.org/protobuf v1.34.1/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= +gopkg.in/natefinch/lumberjack.v2 v2.2.1 h1:bBRl1b0OH9s/DuPhuXpNl+VtCaJXFZ5/uEFST95x9zc= +gopkg.in/natefinch/lumberjack.v2 v2.2.1/go.mod h1:YD8tP3GAjkrDg1eZH7EGmyESg/lsYskCTPBJVb9jqSc= gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 h1:uRGJdciOHaEIrze2W8Q3AKkepLTh2hOroT7a+7czfdQ= gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/relay/controller/audio.go b/relay/controller/audio.go index 613700a4..f500a95b 100644 --- a/relay/controller/audio.go +++ b/relay/controller/audio.go @@ -13,9 +13,11 @@ import ( "github.com/gin-gonic/gin" "github.com/songquanpeng/one-api/common" + "github.com/songquanpeng/one-api/common/audit" "github.com/songquanpeng/one-api/common/client" "github.com/songquanpeng/one-api/common/config" "github.com/songquanpeng/one-api/common/ctxkey" + "github.com/songquanpeng/one-api/common/helper" "github.com/songquanpeng/one-api/relay/adaptor/openai" "github.com/songquanpeng/one-api/relay/billing" "github.com/songquanpeng/one-api/relay/channeltype" @@ -134,6 +136,14 @@ func (rl *defaultRelay) RelayAudioHelper(c *gin.Context, relayMode int) *relaymo if err != nil { return openai.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError) } + if config.UpstreamAuditEnabled { + audit.Logger(). + WithField("stage", "upstream request"). + WithField("raw", audit.B64encode(requestBody.Bytes())). + WithField("requestid", c.GetString(helper.RequestIdKey)). + WithFields(meta.ToLogrusFields()). + Info("upstream request") + } if (relayMode == relaymode.AudioTranscription || relayMode == relaymode.AudioSpeech) && channelType == channeltype.Azure { // https://learn.microsoft.com/en-us/azure/ai-services/openai/whisper-quickstart?tabs=command-line#rest-api @@ -151,6 +161,17 @@ func (rl *defaultRelay) RelayAudioHelper(c *gin.Context, relayMode int) *relaymo if err != nil { return openai.ErrorWrapper(err, "do_request_failed", http.StatusInternalServerError) } + if config.UpstreamAuditEnabled { + buf := audit.CaptureHTTPResponseBody(resp) + defer func() { + audit.Logger(). + WithField("stage", "upstream response"). + WithField("raw", audit.B64encode(buf.Bytes())). + WithField("requestid", c.GetString(helper.RequestIdKey)). + WithFields(meta.ToLogrusFields()). + Info("upstream response") + }() + } err = req.Body.Close() if err != nil { diff --git a/relay/controller/image.go b/relay/controller/image.go index 5113fcab..eab3651b 100644 --- a/relay/controller/image.go +++ b/relay/controller/image.go @@ -9,7 +9,10 @@ import ( "net/http" "github.com/gin-gonic/gin" + "github.com/songquanpeng/one-api/common/audit" + "github.com/songquanpeng/one-api/common/config" "github.com/songquanpeng/one-api/common/ctxkey" + "github.com/songquanpeng/one-api/common/helper" "github.com/songquanpeng/one-api/common/logger" "github.com/songquanpeng/one-api/relay" "github.com/songquanpeng/one-api/relay/adaptor/openai" @@ -118,6 +121,18 @@ func (rl *defaultRelay) RelayImageHelper(c *gin.Context, relayMode int) *relaymo rl.RefundQuota(c, preConsumedQuota, meta.TokenId) } } + if config.UpstreamAuditEnabled { + buf := bytes.Buffer{} + requestBody = io.TeeReader(requestBody, &buf) + defer func() { + audit.Logger(). + WithField("stage", "upstream request"). + WithField("raw", audit.B64encode(buf.Bytes())). + WithField("requestid", c.GetString(helper.RequestIdKey)). + WithFields(meta.ToLogrusFields()). + Info("upstream request") + }() + } // do request resp, err := adaptor.DoRequest(c, meta, requestBody) if err != nil { @@ -125,6 +140,17 @@ func (rl *defaultRelay) RelayImageHelper(c *gin.Context, relayMode int) *relaymo refund() return openai.ErrorWrapper(err, "do_request_failed", http.StatusInternalServerError) } + if config.UpstreamAuditEnabled { + buf := audit.CaptureHTTPResponseBody(resp) + defer func() { + audit.Logger(). + WithField("stage", "upstream response"). + WithField("raw", audit.B64encode(buf.Bytes())). + WithField("requestid", c.GetString(helper.RequestIdKey)). + WithFields(meta.ToLogrusFields()). + Info("upstream response") + }() + } defer func(ctx context.Context) { if resp != nil && resp.StatusCode != http.StatusOK { diff --git a/relay/controller/text.go b/relay/controller/text.go index 01d482e7..272d6195 100644 --- a/relay/controller/text.go +++ b/relay/controller/text.go @@ -8,7 +8,10 @@ import ( "net/http" "github.com/gin-gonic/gin" + "github.com/songquanpeng/one-api/common/audit" + "github.com/songquanpeng/one-api/common/config" "github.com/songquanpeng/one-api/common/ctxkey" + "github.com/songquanpeng/one-api/common/helper" "github.com/songquanpeng/one-api/common/logger" "github.com/songquanpeng/one-api/relay" "github.com/songquanpeng/one-api/relay/adaptor/openai" @@ -88,12 +91,36 @@ func (rl *defaultRelay) RelayTextHelper(c *gin.Context) *model.ErrorWithStatusCo requestBody = bytes.NewBuffer(jsonData) } + if config.UpstreamAuditEnabled { + buf := bytes.Buffer{} + requestBody = io.TeeReader(requestBody, &buf) + defer func() { + audit.Logger(). + WithField("stage", "upstream request"). + WithField("raw", audit.B64encode(buf.Bytes())). + WithField("requestid", c.GetString(helper.RequestIdKey)). + WithFields(meta.ToLogrusFields()). + Info("upstream request") + }() + } + // do request resp, err := adaptor.DoRequest(c, meta, requestBody) if err != nil { logger.Errorf(c, "DoRequest failed: %s", err.Error()) return openai.ErrorWrapper(err, "do_request_failed", http.StatusInternalServerError) } + if config.UpstreamAuditEnabled { + buf := audit.CaptureHTTPResponseBody(resp) + defer func() { + audit.Logger(). + WithField("stage", "upstream response"). + WithField("raw", audit.B64encode(buf.Bytes())). + WithField("requestid", c.GetString(helper.RequestIdKey)). + WithFields(meta.ToLogrusFields()). + Info("upstream response") + }() + } refund := func() { if rl.Bookkeeper != nil && preConsumedQuota > 0 { rl.RefundQuota(c, preConsumedQuota, meta.TokenId) diff --git a/relay/meta/relay_meta.go b/relay/meta/relay_meta.go index 23272f1f..032c2a11 100644 --- a/relay/meta/relay_meta.go +++ b/relay/meta/relay_meta.go @@ -30,6 +30,29 @@ type Meta struct { PromptTokens int // only for DoResponse } +func (m *Meta) ToLogrusFields() map[string]interface{} { + return map[string]interface{}{ + "mode": m.Mode, + "channel_type": m.ChannelType, + "channel_id": m.ChannelId, + "token_id": m.TokenId, + "token_name": m.TokenName, + "user_id": m.UserId, + "group": m.Group, + "model_mapping": m.ModelMapping, + "base_url": m.BaseURL, + "api_key": m.APIKey, + "api_type": m.APIType, + "config": m.Config, + "is_stream": m.IsStream, + "origin_model_name": m.OriginModelName, + "actual_model_name": m.ActualModelName, + "request_url_path": m.RequestURLPath, + "prompt_tokens": m.PromptTokens, + } + +} + func GetByContext(c *gin.Context) *Meta { meta := Meta{ Mode: relaymode.GetByPath(c.Request.URL.Path),