aliqwen: add memory

This commit is contained in:
forry 2024-02-25 00:28:19 +08:00
parent a5c72fb59f
commit fc9601f0f9
19 changed files with 186 additions and 8 deletions

View File

@ -125,3 +125,4 @@ var (
) )
var RateLimitKeyExpirationDuration = 20 * time.Minute var RateLimitKeyExpirationDuration = 20 * time.Minute
var MemoryMaxNum = helper.GetOrDefaultEnvInt("MEMORY_MAX_NUM", 40)

View File

@ -67,3 +67,23 @@ func RedisDecrease(key string, value int64) error {
ctx := context.Background() ctx := context.Background()
return RDB.DecrBy(ctx, key, value).Err() return RDB.DecrBy(ctx, key, value).Err()
} }
func RedisLRange(key string, s, e int64) []string {
ctx := context.Background()
return RDB.LRange(ctx, key, s, e).Val()
}
func RedisLLen(key string) int64 {
ctx := context.Background()
return RDB.LLen(ctx, key).Val()
}
func RedisLPush(key, v string) error {
ctx := context.Background()
return RDB.LPush(ctx, key, v).Err()
}
func RedisRPop(key, v string) string {
ctx := context.Background()
return RDB.RPop(ctx, key).Val()
}

View File

@ -55,7 +55,7 @@ func chooseDB() (*gorm.DB, error) {
}) })
} }
// Use MySQL // Use MySQL
logger.SysLog("using MySQL as database") logger.SysLog("using MySQL as database" + dsn)
return gorm.Open(mysql.Open(dsn), &gorm.Config{ return gorm.Open(mysql.Open(dsn), &gorm.Config{
PrepareStmt: true, // precompile SQL PrepareStmt: true, // precompile SQL
}) })

View File

@ -6,6 +6,7 @@ import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/relay/channel" "github.com/songquanpeng/one-api/relay/channel"
"github.com/songquanpeng/one-api/relay/channel/openai"
"github.com/songquanpeng/one-api/relay/model" "github.com/songquanpeng/one-api/relay/model"
"github.com/songquanpeng/one-api/relay/util" "github.com/songquanpeng/one-api/relay/util"
"io" "io"
@ -13,6 +14,7 @@ import (
) )
type Adaptor struct { type Adaptor struct {
textResponse *openai.TextResponse
} }
func (a *Adaptor) Init(meta *util.RelayMeta) { func (a *Adaptor) Init(meta *util.RelayMeta) {
@ -58,3 +60,10 @@ func (a *Adaptor) GetModelList() []string {
func (a *Adaptor) GetChannelName() string { func (a *Adaptor) GetChannelName() string {
return "aiproxy" return "aiproxy"
} }
func (a *Adaptor) GetLastTextResp() string {
if a.textResponse != nil && len(a.textResponse.Choices) > 0 && a.textResponse.Choices[0].Content != nil {
return a.textResponse.Choices[0].Content.(string)
}
return ""
}

View File

@ -6,9 +6,11 @@ import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/relay/channel" "github.com/songquanpeng/one-api/relay/channel"
"github.com/songquanpeng/one-api/relay/channel/openai"
"github.com/songquanpeng/one-api/relay/constant" "github.com/songquanpeng/one-api/relay/constant"
"github.com/songquanpeng/one-api/relay/model" "github.com/songquanpeng/one-api/relay/model"
"github.com/songquanpeng/one-api/relay/util" "github.com/songquanpeng/one-api/relay/util"
"io" "io"
"net/http" "net/http"
) )
@ -16,6 +18,7 @@ import (
// https://help.aliyun.com/zh/dashscope/developer-reference/api-details // https://help.aliyun.com/zh/dashscope/developer-reference/api-details
type Adaptor struct { type Adaptor struct {
textResponse *openai.TextResponse
} }
func (a *Adaptor) Init(meta *util.RelayMeta) { func (a *Adaptor) Init(meta *util.RelayMeta) {
@ -68,7 +71,8 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.Rel
case constant.RelayModeEmbeddings: case constant.RelayModeEmbeddings:
err, usage = EmbeddingHandler(c, resp) err, usage = EmbeddingHandler(c, resp)
default: default:
err, usage = Handler(c, resp) err, a.textResponse = Handler(c, resp)
usage = &(a.textResponse.Usage)
} }
} }
return return
@ -81,3 +85,10 @@ func (a *Adaptor) GetModelList() []string {
func (a *Adaptor) GetChannelName() string { func (a *Adaptor) GetChannelName() string {
return "ali" return "ali"
} }
func (a *Adaptor) GetLastTextResp() string {
if a.textResponse != nil && len(a.textResponse.Choices) > 0 && a.textResponse.Choices[0].Content != nil {
return a.textResponse.Choices[0].Content.(string)
}
return ""
}

View File

@ -219,7 +219,7 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC
return nil, &usage return nil, &usage
} }
func Handler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) { func Handler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *openai.TextResponse) {
var aliResponse ChatResponse var aliResponse ChatResponse
responseBody, err := io.ReadAll(resp.Body) responseBody, err := io.ReadAll(resp.Body)
if err != nil { if err != nil {
@ -233,6 +233,7 @@ func Handler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *
if err != nil { if err != nil {
return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
} }
if aliResponse.Code != "" { if aliResponse.Code != "" {
return &model.ErrorWithStatusCode{ return &model.ErrorWithStatusCode{
Error: model.Error{ Error: model.Error{
@ -253,5 +254,5 @@ func Handler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *
c.Writer.Header().Set("Content-Type", "application/json") c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode) c.Writer.WriteHeader(resp.StatusCode)
_, err = c.Writer.Write(jsonResponse) _, err = c.Writer.Write(jsonResponse)
return nil, &fullTextResponse.Usage return nil, fullTextResponse
} }

View File

@ -13,6 +13,7 @@ import (
) )
type Adaptor struct { type Adaptor struct {
textResponse *openai.TextResponse
} }
func (a *Adaptor) Init(meta *util.RelayMeta) { func (a *Adaptor) Init(meta *util.RelayMeta) {
@ -63,3 +64,10 @@ func (a *Adaptor) GetModelList() []string {
func (a *Adaptor) GetChannelName() string { func (a *Adaptor) GetChannelName() string {
return "authropic" return "authropic"
} }
func (a *Adaptor) GetLastTextResp() string {
if a.textResponse != nil && len(a.textResponse.Choices) > 0 && a.textResponse.Choices[0].Content != nil {
return a.textResponse.Choices[0].Content.(string)
}
return ""
}

View File

@ -4,6 +4,7 @@ import (
"errors" "errors"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/relay/channel" "github.com/songquanpeng/one-api/relay/channel"
"github.com/songquanpeng/one-api/relay/channel/openai"
"github.com/songquanpeng/one-api/relay/constant" "github.com/songquanpeng/one-api/relay/constant"
"github.com/songquanpeng/one-api/relay/model" "github.com/songquanpeng/one-api/relay/model"
"github.com/songquanpeng/one-api/relay/util" "github.com/songquanpeng/one-api/relay/util"
@ -12,6 +13,7 @@ import (
) )
type Adaptor struct { type Adaptor struct {
textResponse *openai.TextResponse
} }
func (a *Adaptor) Init(meta *util.RelayMeta) { func (a *Adaptor) Init(meta *util.RelayMeta) {
@ -91,3 +93,10 @@ func (a *Adaptor) GetModelList() []string {
func (a *Adaptor) GetChannelName() string { func (a *Adaptor) GetChannelName() string {
return "baidu" return "baidu"
} }
func (a *Adaptor) GetLastTextResp() string {
if a.textResponse != nil && len(a.textResponse.Choices) > 0 && a.textResponse.Choices[0].Content != nil {
return a.textResponse.Choices[0].Content.(string)
}
return ""
}

View File

@ -14,6 +14,7 @@ import (
) )
type Adaptor struct { type Adaptor struct {
textResponse *openai.TextResponse
} }
func (a *Adaptor) Init(meta *util.RelayMeta) { func (a *Adaptor) Init(meta *util.RelayMeta) {
@ -64,3 +65,10 @@ func (a *Adaptor) GetModelList() []string {
func (a *Adaptor) GetChannelName() string { func (a *Adaptor) GetChannelName() string {
return "google gemini" return "google gemini"
} }
func (a *Adaptor) GetLastTextResp() string {
if a.textResponse != nil && len(a.textResponse.Choices) > 0 && a.textResponse.Choices[0].Content != nil {
return a.textResponse.Choices[0].Content.(string)
}
return ""
}

View File

@ -17,4 +17,5 @@ type Adaptor interface {
DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode) DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode)
GetModelList() []string GetModelList() []string
GetChannelName() string GetChannelName() string
GetLastTextResp() string
} }

View File

@ -17,6 +17,7 @@ import (
type Adaptor struct { type Adaptor struct {
ChannelType int ChannelType int
textResponse *TextResponse
} }
func (a *Adaptor) Init(meta *util.RelayMeta) { func (a *Adaptor) Init(meta *util.RelayMeta) {
@ -68,6 +69,7 @@ func (a *Adaptor) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io
} }
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode) { func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
if meta.IsStream { if meta.IsStream {
var responseText string var responseText string
err, responseText = StreamHandler(c, resp, meta.Mode) err, responseText = StreamHandler(c, resp, meta.Mode)
@ -101,3 +103,10 @@ func (a *Adaptor) GetChannelName() string {
return "openai" return "openai"
} }
} }
func (a *Adaptor) GetLastTextResp() string {
if a.textResponse != nil && len(a.textResponse.Choices) > 0 && a.textResponse.Choices[0].Content != nil {
return a.textResponse.Choices[0].Content.(string)
}
return ""
}

View File

@ -13,6 +13,7 @@ import (
) )
type Adaptor struct { type Adaptor struct {
textResponse *openai.TextResponse
} }
func (a *Adaptor) Init(meta *util.RelayMeta) { func (a *Adaptor) Init(meta *util.RelayMeta) {
@ -58,3 +59,10 @@ func (a *Adaptor) GetModelList() []string {
func (a *Adaptor) GetChannelName() string { func (a *Adaptor) GetChannelName() string {
return "google palm" return "google palm"
} }
func (a *Adaptor) GetLastTextResp() string {
if a.textResponse != nil && len(a.textResponse.Choices) > 0 && a.textResponse.Choices[0].Content != nil {
return a.textResponse.Choices[0].Content.(string)
}
return ""
}

View File

@ -17,6 +17,7 @@ import (
type Adaptor struct { type Adaptor struct {
Sign string Sign string
textResponse *openai.TextResponse
} }
func (a *Adaptor) Init(meta *util.RelayMeta) { func (a *Adaptor) Init(meta *util.RelayMeta) {

View File

@ -236,3 +236,10 @@ func GetSign(req ChatRequest, secretKey string) string {
sign := mac.Sum([]byte(nil)) sign := mac.Sum([]byte(nil))
return base64.StdEncoding.EncodeToString(sign) return base64.StdEncoding.EncodeToString(sign)
} }
func (a *Adaptor) GetLastTextResp() string {
if a.textResponse != nil && len(a.textResponse.Choices) > 0 && a.textResponse.Choices[0].Content != nil {
return a.textResponse.Choices[0].Content.(string)
}
return ""
}

View File

@ -14,6 +14,7 @@ import (
type Adaptor struct { type Adaptor struct {
request *model.GeneralOpenAIRequest request *model.GeneralOpenAIRequest
textResponse *openai.TextResponse
} }
func (a *Adaptor) Init(meta *util.RelayMeta) { func (a *Adaptor) Init(meta *util.RelayMeta) {
@ -68,3 +69,10 @@ func (a *Adaptor) GetModelList() []string {
func (a *Adaptor) GetChannelName() string { func (a *Adaptor) GetChannelName() string {
return "xunfei" return "xunfei"
} }
func (a *Adaptor) GetLastTextResp() string {
if a.textResponse != nil && len(a.textResponse.Choices) > 0 && a.textResponse.Choices[0].Content != nil {
return a.textResponse.Choices[0].Content.(string)
}
return ""
}

View File

@ -5,6 +5,7 @@ import (
"fmt" "fmt"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/relay/channel" "github.com/songquanpeng/one-api/relay/channel"
"github.com/songquanpeng/one-api/relay/channel/openai"
"github.com/songquanpeng/one-api/relay/model" "github.com/songquanpeng/one-api/relay/model"
"github.com/songquanpeng/one-api/relay/util" "github.com/songquanpeng/one-api/relay/util"
"io" "io"
@ -12,6 +13,7 @@ import (
) )
type Adaptor struct { type Adaptor struct {
textResponse *openai.TextResponse
} }
func (a *Adaptor) Init(meta *util.RelayMeta) { func (a *Adaptor) Init(meta *util.RelayMeta) {
@ -60,3 +62,10 @@ func (a *Adaptor) GetModelList() []string {
func (a *Adaptor) GetChannelName() string { func (a *Adaptor) GetChannelName() string {
return "zhipu" return "zhipu"
} }
func (a *Adaptor) GetLastTextResp() string {
if a.textResponse != nil && len(a.textResponse.Choices) > 0 && a.textResponse.Choices[0].Content != nil {
return a.textResponse.Choices[0].Content.(string)
}
return ""
}

View File

@ -6,6 +6,7 @@ import (
"fmt" "fmt"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/logger" "github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/relay/channel/openai" "github.com/songquanpeng/one-api/relay/channel/openai"
"github.com/songquanpeng/one-api/relay/constant" "github.com/songquanpeng/one-api/relay/constant"
@ -26,6 +27,17 @@ func RelayTextHelper(c *gin.Context) *model.ErrorWithStatusCode {
logger.Errorf(ctx, "getAndValidateTextRequest failed: %s", err.Error()) logger.Errorf(ctx, "getAndValidateTextRequest failed: %s", err.Error())
return openai.ErrorWrapper(err, "invalid_text_request", http.StatusBadRequest) return openai.ErrorWrapper(err, "invalid_text_request", http.StatusBadRequest)
} }
memory := GetMemory(meta.UserId, meta.TokenName)
reqMsg := textRequest.Messages[0]
//fmt.Printf("----req msg %v \n", textRequest.Messages)
//fmt.Printf("----memory%v \n", memory)
//fmt.Println("-------------------------------------------------------------")
memory = append(memory, textRequest.Messages...)
textRequest.Messages = memory
meta.IsStream = textRequest.Stream meta.IsStream = textRequest.Stream
// map model name // map model name
@ -95,7 +107,56 @@ func RelayTextHelper(c *gin.Context) *model.ErrorWithStatusCode {
util.ReturnPreConsumedQuota(ctx, preConsumedQuota, meta.TokenId) util.ReturnPreConsumedQuota(ctx, preConsumedQuota, meta.TokenId)
return respErr return respErr
} }
if !meta.IsStream {
//非流式,保存历史记录
respMsg := adaptor.GetLastTextResp()
SaveMemory(meta.UserId, meta.TokenName, respMsg, reqMsg)
}
// post-consume quota // post-consume quota
go postConsumeQuota(ctx, usage, meta, textRequest, ratio, preConsumedQuota, modelRatio, groupRatio) go postConsumeQuota(ctx, usage, meta, textRequest, ratio, preConsumedQuota, modelRatio, groupRatio)
return nil return nil
} }
func SaveMemory(userId int, tokenName, resp string, req model.Message) {
if len(resp) < 1 {
return
}
msgs := []model.Message{}
msgs = append(msgs, req)
msgs = append(msgs, model.Message{Role: "assistant", Content: resp})
v, _ := json.Marshal(&msgs)
key := fmt.Sprintf("one_api_memory:%d:%s", userId, tokenName)
common.RedisLPush(key, string(v))
}
func GetMemory(userId int, tokenName string) []model.Message {
key := fmt.Sprintf("one_api_memory:%d:%s", userId, tokenName)
ss := common.RedisLRange(key, 0, int64(config.MemoryMaxNum))
var memory []model.Message
i := len(ss) - 1
for i >= 0 {
s := ss[i]
var msgItem []model.Message
if e := json.Unmarshal([]byte(s), &msgItem); e != nil {
continue
}
for _, v := range msgItem {
if v.Content != nil && len(v.Content.(string)) > 0 {
memory = append(memory, v)
}
}
i--
}
return memory
}

View File

@ -1,6 +1,8 @@
package util package util
import ( import (
"encoding/json"
"fmt"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/relay/constant" "github.com/songquanpeng/one-api/relay/constant"
@ -51,5 +53,7 @@ func GetRelayMeta(c *gin.Context) *RelayMeta {
meta.BaseURL = common.ChannelBaseURLs[meta.ChannelType] meta.BaseURL = common.ChannelBaseURLs[meta.ChannelType]
} }
meta.APIType = constant.ChannelType2APIType(meta.ChannelType) meta.APIType = constant.ChannelType2APIType(meta.ChannelType)
ss, _ := json.Marshal(&meta)
fmt.Println("RelayMeta:>>" + string(ss))
return &meta return &meta
} }

View File

@ -1,2 +1,5 @@
set SQL_DSN="root:123456@tcp(localhost:3306)/oneapi" set GIN_MODE=debug
set SQL_DSN=root:123456@tcp(localhost:3306)/oneapi
set REDIS_CONN_STRING=redis://:jifeng123Redis@www.jifeng.online:8867/3
set SYNC_FREQUENCY=1800
one-api.exe --port 3000 --log-dir ./logs one-api.exe --port 3000 --log-dir ./logs