From fc9601f0f9963f64d2030b4053feee64636d22ae Mon Sep 17 00:00:00 2001 From: forry Date: Sun, 25 Feb 2024 00:28:19 +0800 Subject: [PATCH] aliqwen: add memory --- common/config/config.go | 1 + common/redis.go | 20 ++++++++++ model/main.go | 2 +- relay/channel/aiproxy/adaptor.go | 9 +++++ relay/channel/ali/adaptor.go | 13 ++++++- relay/channel/ali/main.go | 5 ++- relay/channel/anthropic/adaptor.go | 8 ++++ relay/channel/baidu/adaptor.go | 9 +++++ relay/channel/gemini/adaptor.go | 8 ++++ relay/channel/interface.go | 1 + relay/channel/openai/adaptor.go | 11 +++++- relay/channel/palm/adaptor.go | 8 ++++ relay/channel/tencent/adaptor.go | 3 +- relay/channel/tencent/main.go | 7 ++++ relay/channel/xunfei/adaptor.go | 10 ++++- relay/channel/zhipu/adaptor.go | 9 +++++ relay/controller/text.go | 61 ++++++++++++++++++++++++++++++ relay/util/relay_meta.go | 4 ++ start.bat | 5 ++- 19 files changed, 186 insertions(+), 8 deletions(-) diff --git a/common/config/config.go b/common/config/config.go index dd0236b4..8c8b04c1 100644 --- a/common/config/config.go +++ b/common/config/config.go @@ -125,3 +125,4 @@ var ( ) var RateLimitKeyExpirationDuration = 20 * time.Minute +var MemoryMaxNum = helper.GetOrDefaultEnvInt("MEMORY_MAX_NUM", 40) diff --git a/common/redis.go b/common/redis.go index f3205567..958f4398 100644 --- a/common/redis.go +++ b/common/redis.go @@ -67,3 +67,23 @@ func RedisDecrease(key string, value int64) error { ctx := context.Background() return RDB.DecrBy(ctx, key, value).Err() } + +func RedisLRange(key string, s, e int64) []string { + ctx := context.Background() + return RDB.LRange(ctx, key, s, e).Val() +} + +func RedisLLen(key string) int64 { + ctx := context.Background() + return RDB.LLen(ctx, key).Val() +} + +func RedisLPush(key, v string) error { + ctx := context.Background() + return RDB.LPush(ctx, key, v).Err() +} + +func RedisRPop(key, v string) string { + ctx := context.Background() + return RDB.RPop(ctx, key).Val() +} diff --git a/model/main.go b/model/main.go index 18ed01d0..1c03331e 100644 --- a/model/main.go +++ b/model/main.go @@ -55,7 +55,7 @@ func chooseDB() (*gorm.DB, error) { }) } // Use MySQL - logger.SysLog("using MySQL as database") + logger.SysLog("using MySQL as database" + dsn) return gorm.Open(mysql.Open(dsn), &gorm.Config{ PrepareStmt: true, // precompile SQL }) diff --git a/relay/channel/aiproxy/adaptor.go b/relay/channel/aiproxy/adaptor.go index 2b4e3022..fd73aad8 100644 --- a/relay/channel/aiproxy/adaptor.go +++ b/relay/channel/aiproxy/adaptor.go @@ -6,6 +6,7 @@ import ( "github.com/gin-gonic/gin" "github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/relay/channel" + "github.com/songquanpeng/one-api/relay/channel/openai" "github.com/songquanpeng/one-api/relay/model" "github.com/songquanpeng/one-api/relay/util" "io" @@ -13,6 +14,7 @@ import ( ) type Adaptor struct { + textResponse *openai.TextResponse } func (a *Adaptor) Init(meta *util.RelayMeta) { @@ -58,3 +60,10 @@ func (a *Adaptor) GetModelList() []string { func (a *Adaptor) GetChannelName() string { return "aiproxy" } + +func (a *Adaptor) GetLastTextResp() string { + if a.textResponse != nil && len(a.textResponse.Choices) > 0 && a.textResponse.Choices[0].Content != nil { + return a.textResponse.Choices[0].Content.(string) + } + return "" +} diff --git a/relay/channel/ali/adaptor.go b/relay/channel/ali/adaptor.go index 6c6f433e..915ed7f9 100644 --- a/relay/channel/ali/adaptor.go +++ b/relay/channel/ali/adaptor.go @@ -6,9 +6,11 @@ import ( "github.com/gin-gonic/gin" "github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/relay/channel" + "github.com/songquanpeng/one-api/relay/channel/openai" "github.com/songquanpeng/one-api/relay/constant" "github.com/songquanpeng/one-api/relay/model" "github.com/songquanpeng/one-api/relay/util" + "io" "net/http" ) @@ -16,6 +18,7 @@ import ( // https://help.aliyun.com/zh/dashscope/developer-reference/api-details type Adaptor struct { + textResponse *openai.TextResponse } func (a *Adaptor) Init(meta *util.RelayMeta) { @@ -68,7 +71,8 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.Rel case constant.RelayModeEmbeddings: err, usage = EmbeddingHandler(c, resp) default: - err, usage = Handler(c, resp) + err, a.textResponse = Handler(c, resp) + usage = &(a.textResponse.Usage) } } return @@ -81,3 +85,10 @@ func (a *Adaptor) GetModelList() []string { func (a *Adaptor) GetChannelName() string { return "ali" } + +func (a *Adaptor) GetLastTextResp() string { + if a.textResponse != nil && len(a.textResponse.Choices) > 0 && a.textResponse.Choices[0].Content != nil { + return a.textResponse.Choices[0].Content.(string) + } + return "" +} diff --git a/relay/channel/ali/main.go b/relay/channel/ali/main.go index b9625584..41bca38e 100644 --- a/relay/channel/ali/main.go +++ b/relay/channel/ali/main.go @@ -219,7 +219,7 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC return nil, &usage } -func Handler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) { +func Handler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *openai.TextResponse) { var aliResponse ChatResponse responseBody, err := io.ReadAll(resp.Body) if err != nil { @@ -233,6 +233,7 @@ func Handler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, * if err != nil { return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil } + if aliResponse.Code != "" { return &model.ErrorWithStatusCode{ Error: model.Error{ @@ -253,5 +254,5 @@ func Handler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, * c.Writer.Header().Set("Content-Type", "application/json") c.Writer.WriteHeader(resp.StatusCode) _, err = c.Writer.Write(jsonResponse) - return nil, &fullTextResponse.Usage + return nil, fullTextResponse } diff --git a/relay/channel/anthropic/adaptor.go b/relay/channel/anthropic/adaptor.go index 4b873715..b7a8be43 100644 --- a/relay/channel/anthropic/adaptor.go +++ b/relay/channel/anthropic/adaptor.go @@ -13,6 +13,7 @@ import ( ) type Adaptor struct { + textResponse *openai.TextResponse } func (a *Adaptor) Init(meta *util.RelayMeta) { @@ -63,3 +64,10 @@ func (a *Adaptor) GetModelList() []string { func (a *Adaptor) GetChannelName() string { return "authropic" } + +func (a *Adaptor) GetLastTextResp() string { + if a.textResponse != nil && len(a.textResponse.Choices) > 0 && a.textResponse.Choices[0].Content != nil { + return a.textResponse.Choices[0].Content.(string) + } + return "" +} diff --git a/relay/channel/baidu/adaptor.go b/relay/channel/baidu/adaptor.go index d2d06ce0..d79781d1 100644 --- a/relay/channel/baidu/adaptor.go +++ b/relay/channel/baidu/adaptor.go @@ -4,6 +4,7 @@ import ( "errors" "github.com/gin-gonic/gin" "github.com/songquanpeng/one-api/relay/channel" + "github.com/songquanpeng/one-api/relay/channel/openai" "github.com/songquanpeng/one-api/relay/constant" "github.com/songquanpeng/one-api/relay/model" "github.com/songquanpeng/one-api/relay/util" @@ -12,6 +13,7 @@ import ( ) type Adaptor struct { + textResponse *openai.TextResponse } func (a *Adaptor) Init(meta *util.RelayMeta) { @@ -91,3 +93,10 @@ func (a *Adaptor) GetModelList() []string { func (a *Adaptor) GetChannelName() string { return "baidu" } + +func (a *Adaptor) GetLastTextResp() string { + if a.textResponse != nil && len(a.textResponse.Choices) > 0 && a.textResponse.Choices[0].Content != nil { + return a.textResponse.Choices[0].Content.(string) + } + return "" +} diff --git a/relay/channel/gemini/adaptor.go b/relay/channel/gemini/adaptor.go index f3305e5d..91fd9cc5 100644 --- a/relay/channel/gemini/adaptor.go +++ b/relay/channel/gemini/adaptor.go @@ -14,6 +14,7 @@ import ( ) type Adaptor struct { + textResponse *openai.TextResponse } func (a *Adaptor) Init(meta *util.RelayMeta) { @@ -64,3 +65,10 @@ func (a *Adaptor) GetModelList() []string { func (a *Adaptor) GetChannelName() string { return "google gemini" } + +func (a *Adaptor) GetLastTextResp() string { + if a.textResponse != nil && len(a.textResponse.Choices) > 0 && a.textResponse.Choices[0].Content != nil { + return a.textResponse.Choices[0].Content.(string) + } + return "" +} diff --git a/relay/channel/interface.go b/relay/channel/interface.go index e25db83f..8975aa70 100644 --- a/relay/channel/interface.go +++ b/relay/channel/interface.go @@ -17,4 +17,5 @@ type Adaptor interface { DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode) GetModelList() []string GetChannelName() string + GetLastTextResp() string } diff --git a/relay/channel/openai/adaptor.go b/relay/channel/openai/adaptor.go index 1313e317..d5f9c829 100644 --- a/relay/channel/openai/adaptor.go +++ b/relay/channel/openai/adaptor.go @@ -16,7 +16,8 @@ import ( ) type Adaptor struct { - ChannelType int + ChannelType int + textResponse *TextResponse } func (a *Adaptor) Init(meta *util.RelayMeta) { @@ -68,6 +69,7 @@ func (a *Adaptor) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io } func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode) { + if meta.IsStream { var responseText string err, responseText = StreamHandler(c, resp, meta.Mode) @@ -101,3 +103,10 @@ func (a *Adaptor) GetChannelName() string { return "openai" } } + +func (a *Adaptor) GetLastTextResp() string { + if a.textResponse != nil && len(a.textResponse.Choices) > 0 && a.textResponse.Choices[0].Content != nil { + return a.textResponse.Choices[0].Content.(string) + } + return "" +} diff --git a/relay/channel/palm/adaptor.go b/relay/channel/palm/adaptor.go index efd0620c..18b8513a 100644 --- a/relay/channel/palm/adaptor.go +++ b/relay/channel/palm/adaptor.go @@ -13,6 +13,7 @@ import ( ) type Adaptor struct { + textResponse *openai.TextResponse } func (a *Adaptor) Init(meta *util.RelayMeta) { @@ -58,3 +59,10 @@ func (a *Adaptor) GetModelList() []string { func (a *Adaptor) GetChannelName() string { return "google palm" } + +func (a *Adaptor) GetLastTextResp() string { + if a.textResponse != nil && len(a.textResponse.Choices) > 0 && a.textResponse.Choices[0].Content != nil { + return a.textResponse.Choices[0].Content.(string) + } + return "" +} diff --git a/relay/channel/tencent/adaptor.go b/relay/channel/tencent/adaptor.go index f348674e..0abe4943 100644 --- a/relay/channel/tencent/adaptor.go +++ b/relay/channel/tencent/adaptor.go @@ -16,7 +16,8 @@ import ( // https://cloud.tencent.com/document/api/1729/101837 type Adaptor struct { - Sign string + Sign string + textResponse *openai.TextResponse } func (a *Adaptor) Init(meta *util.RelayMeta) { diff --git a/relay/channel/tencent/main.go b/relay/channel/tencent/main.go index 05edac20..dc5c3a43 100644 --- a/relay/channel/tencent/main.go +++ b/relay/channel/tencent/main.go @@ -236,3 +236,10 @@ func GetSign(req ChatRequest, secretKey string) string { sign := mac.Sum([]byte(nil)) return base64.StdEncoding.EncodeToString(sign) } + +func (a *Adaptor) GetLastTextResp() string { + if a.textResponse != nil && len(a.textResponse.Choices) > 0 && a.textResponse.Choices[0].Content != nil { + return a.textResponse.Choices[0].Content.(string) + } + return "" +} diff --git a/relay/channel/xunfei/adaptor.go b/relay/channel/xunfei/adaptor.go index 92d9d7d6..9772db82 100644 --- a/relay/channel/xunfei/adaptor.go +++ b/relay/channel/xunfei/adaptor.go @@ -13,7 +13,8 @@ import ( ) type Adaptor struct { - request *model.GeneralOpenAIRequest + request *model.GeneralOpenAIRequest + textResponse *openai.TextResponse } func (a *Adaptor) Init(meta *util.RelayMeta) { @@ -68,3 +69,10 @@ func (a *Adaptor) GetModelList() []string { func (a *Adaptor) GetChannelName() string { return "xunfei" } + +func (a *Adaptor) GetLastTextResp() string { + if a.textResponse != nil && len(a.textResponse.Choices) > 0 && a.textResponse.Choices[0].Content != nil { + return a.textResponse.Choices[0].Content.(string) + } + return "" +} diff --git a/relay/channel/zhipu/adaptor.go b/relay/channel/zhipu/adaptor.go index 7a822853..b0404cd1 100644 --- a/relay/channel/zhipu/adaptor.go +++ b/relay/channel/zhipu/adaptor.go @@ -5,6 +5,7 @@ import ( "fmt" "github.com/gin-gonic/gin" "github.com/songquanpeng/one-api/relay/channel" + "github.com/songquanpeng/one-api/relay/channel/openai" "github.com/songquanpeng/one-api/relay/model" "github.com/songquanpeng/one-api/relay/util" "io" @@ -12,6 +13,7 @@ import ( ) type Adaptor struct { + textResponse *openai.TextResponse } func (a *Adaptor) Init(meta *util.RelayMeta) { @@ -60,3 +62,10 @@ func (a *Adaptor) GetModelList() []string { func (a *Adaptor) GetChannelName() string { return "zhipu" } + +func (a *Adaptor) GetLastTextResp() string { + if a.textResponse != nil && len(a.textResponse.Choices) > 0 && a.textResponse.Choices[0].Content != nil { + return a.textResponse.Choices[0].Content.(string) + } + return "" +} diff --git a/relay/controller/text.go b/relay/controller/text.go index cc460511..6dd41c8b 100644 --- a/relay/controller/text.go +++ b/relay/controller/text.go @@ -6,6 +6,7 @@ import ( "fmt" "github.com/gin-gonic/gin" "github.com/songquanpeng/one-api/common" + "github.com/songquanpeng/one-api/common/config" "github.com/songquanpeng/one-api/common/logger" "github.com/songquanpeng/one-api/relay/channel/openai" "github.com/songquanpeng/one-api/relay/constant" @@ -26,6 +27,17 @@ func RelayTextHelper(c *gin.Context) *model.ErrorWithStatusCode { logger.Errorf(ctx, "getAndValidateTextRequest failed: %s", err.Error()) return openai.ErrorWrapper(err, "invalid_text_request", http.StatusBadRequest) } + + memory := GetMemory(meta.UserId, meta.TokenName) + reqMsg := textRequest.Messages[0] + + //fmt.Printf("----req msg %v \n", textRequest.Messages) + //fmt.Printf("----memory%v \n", memory) + //fmt.Println("-------------------------------------------------------------") + + memory = append(memory, textRequest.Messages...) + textRequest.Messages = memory + meta.IsStream = textRequest.Stream // map model name @@ -95,7 +107,56 @@ func RelayTextHelper(c *gin.Context) *model.ErrorWithStatusCode { util.ReturnPreConsumedQuota(ctx, preConsumedQuota, meta.TokenId) return respErr } + + if !meta.IsStream { + //非流式,保存历史记录 + respMsg := adaptor.GetLastTextResp() + SaveMemory(meta.UserId, meta.TokenName, respMsg, reqMsg) + + } // post-consume quota go postConsumeQuota(ctx, usage, meta, textRequest, ratio, preConsumedQuota, modelRatio, groupRatio) return nil } + +func SaveMemory(userId int, tokenName, resp string, req model.Message) { + if len(resp) < 1 { + return + } + + msgs := []model.Message{} + msgs = append(msgs, req) + msgs = append(msgs, model.Message{Role: "assistant", Content: resp}) + + v, _ := json.Marshal(&msgs) + + key := fmt.Sprintf("one_api_memory:%d:%s", userId, tokenName) + + common.RedisLPush(key, string(v)) +} + +func GetMemory(userId int, tokenName string) []model.Message { + + key := fmt.Sprintf("one_api_memory:%d:%s", userId, tokenName) + + ss := common.RedisLRange(key, 0, int64(config.MemoryMaxNum)) + var memory []model.Message + + i := len(ss) - 1 + + for i >= 0 { + s := ss[i] + var msgItem []model.Message + if e := json.Unmarshal([]byte(s), &msgItem); e != nil { + continue + } + + for _, v := range msgItem { + if v.Content != nil && len(v.Content.(string)) > 0 { + memory = append(memory, v) + } + } + i-- + } + return memory +} diff --git a/relay/util/relay_meta.go b/relay/util/relay_meta.go index 31b9d2b4..9d1b70d6 100644 --- a/relay/util/relay_meta.go +++ b/relay/util/relay_meta.go @@ -1,6 +1,8 @@ package util import ( + "encoding/json" + "fmt" "github.com/gin-gonic/gin" "github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/relay/constant" @@ -51,5 +53,7 @@ func GetRelayMeta(c *gin.Context) *RelayMeta { meta.BaseURL = common.ChannelBaseURLs[meta.ChannelType] } meta.APIType = constant.ChannelType2APIType(meta.ChannelType) + ss, _ := json.Marshal(&meta) + fmt.Println("RelayMeta:>>" + string(ss)) return &meta } diff --git a/start.bat b/start.bat index e039b003..1f9f921c 100644 --- a/start.bat +++ b/start.bat @@ -1,2 +1,5 @@ -set SQL_DSN="root:123456@tcp(localhost:3306)/oneapi" +set GIN_MODE=debug +set SQL_DSN=root:123456@tcp(localhost:3306)/oneapi +set REDIS_CONN_STRING=redis://:jifeng123Redis@www.jifeng.online:8867/3 +set SYNC_FREQUENCY=1800 one-api.exe --port 3000 --log-dir ./logs \ No newline at end of file