aliqwen: add memory

This commit is contained in:
forry 2024-02-25 00:28:19 +08:00
parent a5c72fb59f
commit fc9601f0f9
19 changed files with 186 additions and 8 deletions

View File

@ -125,3 +125,4 @@ var (
)
var RateLimitKeyExpirationDuration = 20 * time.Minute
var MemoryMaxNum = helper.GetOrDefaultEnvInt("MEMORY_MAX_NUM", 40)

View File

@ -67,3 +67,23 @@ func RedisDecrease(key string, value int64) error {
ctx := context.Background()
return RDB.DecrBy(ctx, key, value).Err()
}
// RedisLRange returns the list elements stored at key between indices s and e
// (inclusive, Redis LRANGE semantics). Redis errors are discarded; a nil/empty
// slice is returned on failure.
func RedisLRange(key string, s, e int64) []string {
	return RDB.LRange(context.Background(), key, s, e).Val()
}
// RedisLLen returns the length of the list stored at key (LLEN).
// Redis errors are discarded; 0 is returned on failure.
func RedisLLen(key string) int64 {
	return RDB.LLen(context.Background(), key).Val()
}
// RedisLPush prepends v to the list stored at key (LPUSH) and returns any
// Redis error.
func RedisLPush(key, v string) error {
	return RDB.LPush(context.Background(), key, v).Err()
}
// RedisRPop removes and returns the last element of the list stored at key
// (RPOP). Redis errors are discarded; "" is returned on failure.
// NOTE(review): the v parameter is never used and the error is swallowed —
// consider dropping v and returning (string, error); check callers first.
func RedisRPop(key, v string) string {
	ctx := context.Background()
	return RDB.RPop(ctx, key).Val()
}

View File

@ -55,7 +55,7 @@ func chooseDB() (*gorm.DB, error) {
})
}
// Use MySQL
logger.SysLog("using MySQL as database")
logger.SysLog("using MySQL as database" + dsn)
return gorm.Open(mysql.Open(dsn), &gorm.Config{
PrepareStmt: true, // precompile SQL
})

View File

@ -6,6 +6,7 @@ import (
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/relay/channel"
"github.com/songquanpeng/one-api/relay/channel/openai"
"github.com/songquanpeng/one-api/relay/model"
"github.com/songquanpeng/one-api/relay/util"
"io"
@ -13,6 +14,7 @@ import (
)
type Adaptor struct {
textResponse *openai.TextResponse
}
func (a *Adaptor) Init(meta *util.RelayMeta) {
@ -58,3 +60,10 @@ func (a *Adaptor) GetModelList() []string {
// GetChannelName returns the canonical identifier for this channel adaptor.
func (a *Adaptor) GetChannelName() string {
	return "aiproxy"
}
// GetLastTextResp returns the assistant text of the last captured non-stream
// response, or "" when no response was recorded or its content is not a
// plain string.
func (a *Adaptor) GetLastTextResp() string {
	if a.textResponse == nil || len(a.textResponse.Choices) == 0 {
		return ""
	}
	// Comma-ok assertion: Content is an interface that may hold non-string
	// payloads; the original unchecked assertion would panic in that case.
	if s, ok := a.textResponse.Choices[0].Content.(string); ok {
		return s
	}
	return ""
}

View File

@ -6,9 +6,11 @@ import (
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/relay/channel"
"github.com/songquanpeng/one-api/relay/channel/openai"
"github.com/songquanpeng/one-api/relay/constant"
"github.com/songquanpeng/one-api/relay/model"
"github.com/songquanpeng/one-api/relay/util"
"io"
"net/http"
)
@ -16,6 +18,7 @@ import (
// https://help.aliyun.com/zh/dashscope/developer-reference/api-details
type Adaptor struct {
textResponse *openai.TextResponse
}
func (a *Adaptor) Init(meta *util.RelayMeta) {
@ -68,7 +71,8 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.Rel
case constant.RelayModeEmbeddings:
err, usage = EmbeddingHandler(c, resp)
default:
err, usage = Handler(c, resp)
err, a.textResponse = Handler(c, resp)
usage = &(a.textResponse.Usage)
}
}
return
@ -81,3 +85,10 @@ func (a *Adaptor) GetModelList() []string {
// GetChannelName returns the canonical identifier for this channel adaptor.
func (a *Adaptor) GetChannelName() string {
	return "ali"
}
// GetLastTextResp returns the assistant text of the last captured non-stream
// response, or "" when no response was recorded or its content is not a
// plain string.
func (a *Adaptor) GetLastTextResp() string {
	if a.textResponse == nil || len(a.textResponse.Choices) == 0 {
		return ""
	}
	// Comma-ok assertion: Content is an interface that may hold non-string
	// payloads; the original unchecked assertion would panic in that case.
	if s, ok := a.textResponse.Choices[0].Content.(string); ok {
		return s
	}
	return ""
}

View File

@ -219,7 +219,7 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC
return nil, &usage
}
func Handler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
func Handler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *openai.TextResponse) {
var aliResponse ChatResponse
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
@ -233,6 +233,7 @@ func Handler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *
if err != nil {
return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
if aliResponse.Code != "" {
return &model.ErrorWithStatusCode{
Error: model.Error{
@ -253,5 +254,5 @@ func Handler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
_, err = c.Writer.Write(jsonResponse)
return nil, &fullTextResponse.Usage
return nil, fullTextResponse
}

View File

@ -13,6 +13,7 @@ import (
)
type Adaptor struct {
textResponse *openai.TextResponse
}
func (a *Adaptor) Init(meta *util.RelayMeta) {
@ -63,3 +64,10 @@ func (a *Adaptor) GetModelList() []string {
// GetChannelName returns the canonical identifier for this channel adaptor.
// NOTE(review): "authropic" looks like a typo for "anthropic" — before fixing,
// confirm nothing (logs, dashboards, config) matches on this exact string.
func (a *Adaptor) GetChannelName() string {
	return "authropic"
}
// GetLastTextResp returns the assistant text of the last captured non-stream
// response, or "" when no response was recorded or its content is not a
// plain string.
func (a *Adaptor) GetLastTextResp() string {
	if a.textResponse == nil || len(a.textResponse.Choices) == 0 {
		return ""
	}
	// Comma-ok assertion: Content is an interface that may hold non-string
	// payloads; the original unchecked assertion would panic in that case.
	if s, ok := a.textResponse.Choices[0].Content.(string); ok {
		return s
	}
	return ""
}

View File

@ -4,6 +4,7 @@ import (
"errors"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/relay/channel"
"github.com/songquanpeng/one-api/relay/channel/openai"
"github.com/songquanpeng/one-api/relay/constant"
"github.com/songquanpeng/one-api/relay/model"
"github.com/songquanpeng/one-api/relay/util"
@ -12,6 +13,7 @@ import (
)
type Adaptor struct {
textResponse *openai.TextResponse
}
func (a *Adaptor) Init(meta *util.RelayMeta) {
@ -91,3 +93,10 @@ func (a *Adaptor) GetModelList() []string {
// GetChannelName returns the canonical identifier for this channel adaptor.
func (a *Adaptor) GetChannelName() string {
	return "baidu"
}
// GetLastTextResp returns the assistant text of the last captured non-stream
// response, or "" when no response was recorded or its content is not a
// plain string.
func (a *Adaptor) GetLastTextResp() string {
	if a.textResponse == nil || len(a.textResponse.Choices) == 0 {
		return ""
	}
	// Comma-ok assertion: Content is an interface that may hold non-string
	// payloads; the original unchecked assertion would panic in that case.
	if s, ok := a.textResponse.Choices[0].Content.(string); ok {
		return s
	}
	return ""
}

View File

@ -14,6 +14,7 @@ import (
)
type Adaptor struct {
textResponse *openai.TextResponse
}
func (a *Adaptor) Init(meta *util.RelayMeta) {
@ -64,3 +65,10 @@ func (a *Adaptor) GetModelList() []string {
// GetChannelName returns the canonical identifier for this channel adaptor.
func (a *Adaptor) GetChannelName() string {
	return "google gemini"
}
// GetLastTextResp returns the assistant text of the last captured non-stream
// response, or "" when no response was recorded or its content is not a
// plain string.
func (a *Adaptor) GetLastTextResp() string {
	if a.textResponse == nil || len(a.textResponse.Choices) == 0 {
		return ""
	}
	// Comma-ok assertion: Content is an interface that may hold non-string
	// payloads; the original unchecked assertion would panic in that case.
	if s, ok := a.textResponse.Choices[0].Content.(string); ok {
		return s
	}
	return ""
}

View File

@ -17,4 +17,5 @@ type Adaptor interface {
DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode)
GetModelList() []string
GetChannelName() string
GetLastTextResp() string
}

View File

@ -16,7 +16,8 @@ import (
)
type Adaptor struct {
ChannelType int
ChannelType int
textResponse *TextResponse
}
func (a *Adaptor) Init(meta *util.RelayMeta) {
@ -68,6 +69,7 @@ func (a *Adaptor) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
if meta.IsStream {
var responseText string
err, responseText = StreamHandler(c, resp, meta.Mode)
@ -101,3 +103,10 @@ func (a *Adaptor) GetChannelName() string {
return "openai"
}
}
// GetLastTextResp returns the assistant text of the last captured non-stream
// response, or "" when no response was recorded or its content is not a
// plain string.
func (a *Adaptor) GetLastTextResp() string {
	if a.textResponse == nil || len(a.textResponse.Choices) == 0 {
		return ""
	}
	// Comma-ok assertion: Content is an interface that may hold non-string
	// payloads; the original unchecked assertion would panic in that case.
	if s, ok := a.textResponse.Choices[0].Content.(string); ok {
		return s
	}
	return ""
}

View File

@ -13,6 +13,7 @@ import (
)
type Adaptor struct {
textResponse *openai.TextResponse
}
func (a *Adaptor) Init(meta *util.RelayMeta) {
@ -58,3 +59,10 @@ func (a *Adaptor) GetModelList() []string {
// GetChannelName returns the canonical identifier for this channel adaptor.
func (a *Adaptor) GetChannelName() string {
	return "google palm"
}
// GetLastTextResp returns the assistant text of the last captured non-stream
// response, or "" when no response was recorded or its content is not a
// plain string.
func (a *Adaptor) GetLastTextResp() string {
	if a.textResponse == nil || len(a.textResponse.Choices) == 0 {
		return ""
	}
	// Comma-ok assertion: Content is an interface that may hold non-string
	// payloads; the original unchecked assertion would panic in that case.
	if s, ok := a.textResponse.Choices[0].Content.(string); ok {
		return s
	}
	return ""
}

View File

@ -16,7 +16,8 @@ import (
// https://cloud.tencent.com/document/api/1729/101837
type Adaptor struct {
Sign string
Sign string
textResponse *openai.TextResponse
}
func (a *Adaptor) Init(meta *util.RelayMeta) {

View File

@ -236,3 +236,10 @@ func GetSign(req ChatRequest, secretKey string) string {
sign := mac.Sum([]byte(nil))
return base64.StdEncoding.EncodeToString(sign)
}
// GetLastTextResp returns the assistant text of the last captured non-stream
// response, or "" when no response was recorded or its content is not a
// plain string.
func (a *Adaptor) GetLastTextResp() string {
	if a.textResponse == nil || len(a.textResponse.Choices) == 0 {
		return ""
	}
	// Comma-ok assertion: Content is an interface that may hold non-string
	// payloads; the original unchecked assertion would panic in that case.
	if s, ok := a.textResponse.Choices[0].Content.(string); ok {
		return s
	}
	return ""
}

View File

@ -13,7 +13,8 @@ import (
)
type Adaptor struct {
request *model.GeneralOpenAIRequest
request *model.GeneralOpenAIRequest
textResponse *openai.TextResponse
}
func (a *Adaptor) Init(meta *util.RelayMeta) {
@ -68,3 +69,10 @@ func (a *Adaptor) GetModelList() []string {
func (a *Adaptor) GetChannelName() string {
return "xunfei"
}
// GetLastTextResp returns the assistant text of the last captured non-stream
// response, or "" when no response was recorded or its content is not a
// plain string.
func (a *Adaptor) GetLastTextResp() string {
	if a.textResponse == nil || len(a.textResponse.Choices) == 0 {
		return ""
	}
	// Comma-ok assertion: Content is an interface that may hold non-string
	// payloads; the original unchecked assertion would panic in that case.
	if s, ok := a.textResponse.Choices[0].Content.(string); ok {
		return s
	}
	return ""
}

View File

@ -5,6 +5,7 @@ import (
"fmt"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/relay/channel"
"github.com/songquanpeng/one-api/relay/channel/openai"
"github.com/songquanpeng/one-api/relay/model"
"github.com/songquanpeng/one-api/relay/util"
"io"
@ -12,6 +13,7 @@ import (
)
type Adaptor struct {
textResponse *openai.TextResponse
}
func (a *Adaptor) Init(meta *util.RelayMeta) {
@ -60,3 +62,10 @@ func (a *Adaptor) GetModelList() []string {
// GetChannelName returns the canonical identifier for this channel adaptor.
func (a *Adaptor) GetChannelName() string {
	return "zhipu"
}
// GetLastTextResp returns the assistant text of the last captured non-stream
// response, or "" when no response was recorded or its content is not a
// plain string.
func (a *Adaptor) GetLastTextResp() string {
	if a.textResponse == nil || len(a.textResponse.Choices) == 0 {
		return ""
	}
	// Comma-ok assertion: Content is an interface that may hold non-string
	// payloads; the original unchecked assertion would panic in that case.
	if s, ok := a.textResponse.Choices[0].Content.(string); ok {
		return s
	}
	return ""
}

View File

@ -6,6 +6,7 @@ import (
"fmt"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/relay/channel/openai"
"github.com/songquanpeng/one-api/relay/constant"
@ -26,6 +27,17 @@ func RelayTextHelper(c *gin.Context) *model.ErrorWithStatusCode {
logger.Errorf(ctx, "getAndValidateTextRequest failed: %s", err.Error())
return openai.ErrorWrapper(err, "invalid_text_request", http.StatusBadRequest)
}
memory := GetMemory(meta.UserId, meta.TokenName)
reqMsg := textRequest.Messages[0]
//fmt.Printf("----req msg %v \n", textRequest.Messages)
//fmt.Printf("----memory%v \n", memory)
//fmt.Println("-------------------------------------------------------------")
memory = append(memory, textRequest.Messages...)
textRequest.Messages = memory
meta.IsStream = textRequest.Stream
// map model name
@ -95,7 +107,56 @@ func RelayTextHelper(c *gin.Context) *model.ErrorWithStatusCode {
util.ReturnPreConsumedQuota(ctx, preConsumedQuota, meta.TokenId)
return respErr
}
if !meta.IsStream {
//非流式,保存历史记录
respMsg := adaptor.GetLastTextResp()
SaveMemory(meta.UserId, meta.TokenName, respMsg, reqMsg)
}
// post-consume quota
go postConsumeQuota(ctx, usage, meta, textRequest, ratio, preConsumedQuota, modelRatio, groupRatio)
return nil
}
// SaveMemory persists one request/response exchange for the (user, token)
// pair by LPushing a JSON-encoded message pair onto a Redis list (newest
// first). Storage is best-effort: Marshal/Redis failures are dropped rather
// than failing the relay, preserving the original behavior.
// NOTE(review): the list is never trimmed (no LTRIM), so it grows without
// bound; MEMORY_MAX_NUM only caps how much GetMemory reads back — confirm
// whether a trim should be added here.
func SaveMemory(userId int, tokenName, resp string, req model.Message) {
	if resp == "" {
		return // nothing to remember
	}
	msgs := []model.Message{
		req,
		{Role: "assistant", Content: resp},
	}
	v, err := json.Marshal(&msgs)
	if err != nil {
		// Original code ignored this error and would have pushed garbage;
		// skip the push instead.
		return
	}
	key := fmt.Sprintf("one_api_memory:%d:%s", userId, tokenName)
	_ = common.RedisLPush(key, string(v)) // best-effort; memory is an optional feature
}
// GetMemory loads up to MemoryMaxNum stored exchanges for the (user, token)
// pair and returns them oldest-first, ready to be prepended to a new request.
// Corrupt entries and non-string message contents are skipped.
func GetMemory(userId int, tokenName string) []model.Message {
	key := fmt.Sprintf("one_api_memory:%d:%s", userId, tokenName)
	ss := common.RedisLRange(key, 0, int64(config.MemoryMaxNum))
	var memory []model.Message
	// Entries were LPushed newest-first, so walk backwards to restore
	// chronological order. Using a counted for-loop also fixes the original
	// bug where `continue` skipped the manual i-- and spun forever on the
	// first entry that failed to unmarshal.
	for i := len(ss) - 1; i >= 0; i-- {
		var msgItem []model.Message
		if err := json.Unmarshal([]byte(ss[i]), &msgItem); err != nil {
			continue // skip corrupt entries instead of looping on them
		}
		for _, m := range msgItem {
			// Comma-ok assertion avoids the panic the original unchecked
			// m.Content.(string) would cause on non-string content.
			if s, ok := m.Content.(string); ok && s != "" {
				memory = append(memory, m)
			}
		}
	}
	return memory
}

View File

@ -1,6 +1,8 @@
package util
import (
"encoding/json"
"fmt"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/relay/constant"
@ -51,5 +53,7 @@ func GetRelayMeta(c *gin.Context) *RelayMeta {
meta.BaseURL = common.ChannelBaseURLs[meta.ChannelType]
}
meta.APIType = constant.ChannelType2APIType(meta.ChannelType)
ss, _ := json.Marshal(&meta)
fmt.Println("RelayMeta:>>" + string(ss))
return &meta
}

View File

@ -1,2 +1,5 @@
set SQL_DSN="root:123456@tcp(localhost:3306)/oneapi"
set GIN_MODE=debug
set SQL_DSN=root:123456@tcp(localhost:3306)/oneapi
set REDIS_CONN_STRING=redis://:jifeng123Redis@www.jifeng.online:8867/3
set SYNC_FREQUENCY=1800
one-api.exe --port 3000 --log-dir ./logs