From 04cabfac181aae4d6ed383861fe9a8e37e16a953 Mon Sep 17 00:00:00 2001 From: tangfei-china Date: Mon, 30 Oct 2023 21:05:32 +0800 Subject: [PATCH] feat: You can save chat content while using the OpenAI model. --- controller/relay-text.go | 6 ++++++ model/log-text.go | 31 +++++++++++++++++++++++++++++++ model/main.go | 4 ++++ 3 files changed, 41 insertions(+) create mode 100644 model/log-text.go diff --git a/controller/relay-text.go b/controller/relay-text.go index 25b8bc06..91b41b2a 100644 --- a/controller/relay-text.go +++ b/controller/relay-text.go @@ -198,10 +198,12 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { fullRequestURL = fmt.Sprintf("%s/api/library/ask", baseURL) } var promptTokens int + var prompt string var completionTokens int switch relayMode { case RelayModeChatCompletions: promptTokens = countTokenMessages(textRequest.Messages, textRequest.Model) + prompt = textRequest.Messages[len(textRequest.Messages)-1].Content case RelayModeCompletions: promptTokens = countTokenInput(textRequest.Prompt, textRequest.Model) case RelayModeModerations: @@ -406,6 +408,8 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { var textResponse TextResponse tokenName := c.GetString("token_name") + var response string + defer func(ctx context.Context) { // c.Writer.Flush() go func() { @@ -438,6 +442,7 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { if quota != 0 { logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio) model.RecordConsumeLog(ctx, userId, channelId, promptTokens, completionTokens, textRequest.Model, tokenName, quota, logContent) + model.RecordConsumeText(userId, tokenName, prompt, response) model.UpdateUserUsedQuotaAndRequestCount(userId, quota) model.UpdateChannelUsedQuota(channelId, quota) } @@ -453,6 +458,7 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { } textResponse.Usage.PromptTokens = promptTokens textResponse.Usage.CompletionTokens = countTokenText(responseText, textRequest.Model) + response = responseText return nil } else { err, usage := openaiHandler(c, resp, consumeQuota, promptTokens, textRequest.Model) diff --git a/model/log-text.go b/model/log-text.go new file mode 100644 index 00000000..c9698522 --- /dev/null +++ b/model/log-text.go @@ -0,0 +1,31 @@ +package model + +import ( + "one-api/common" +) + +type LogText struct { + Id int `json:"id"` + UserId int `json:"user_id" gorm:"index"` + CreatedAt int64 `json:"created_at" gorm:"index"` + Username string `json:"username" gorm:"index;default:''"` + TokenName string `json:"token_name" gorm:"index;default:''"` + Prompt string `json:"prompt" gorm:"type:text"` + Completion string `json:"completion" gorm:"type:text"` +} + +func RecordConsumeText(userId int, token string, prompt string, completion string) { + + text := &LogText{ + UserId: userId, + Username: GetUsernameById(userId), + CreatedAt: common.GetTimestamp(), + TokenName: token, + Prompt: prompt, + Completion: completion, + } + err := DB.Create(text).Error + if err != nil { + common.SysError("failed to record text: " + err.Error()) + } +} diff --git a/model/main.go b/model/main.go index 08182634..4cdafaff 100644 --- a/model/main.go +++ b/model/main.go @@ -111,6 +111,10 @@ func InitDB() (err error) { if err != nil { return err } + err = db.AutoMigrate(&LogText{}) + if err != nil { + return err + } common.SysLog("database migrated") err = createRootAccountIfNeed() return err