aliqwen: add memory
This commit is contained in:
parent
a5c72fb59f
commit
fc9601f0f9
@ -125,3 +125,4 @@ var (
|
||||
)
|
||||
|
||||
var RateLimitKeyExpirationDuration = 20 * time.Minute
|
||||
var MemoryMaxNum = helper.GetOrDefaultEnvInt("MEMORY_MAX_NUM", 40)
|
||||
|
@ -67,3 +67,23 @@ func RedisDecrease(key string, value int64) error {
|
||||
ctx := context.Background()
|
||||
return RDB.DecrBy(ctx, key, value).Err()
|
||||
}
|
||||
|
||||
func RedisLRange(key string, s, e int64) []string {
|
||||
ctx := context.Background()
|
||||
return RDB.LRange(ctx, key, s, e).Val()
|
||||
}
|
||||
|
||||
func RedisLLen(key string) int64 {
|
||||
ctx := context.Background()
|
||||
return RDB.LLen(ctx, key).Val()
|
||||
}
|
||||
|
||||
func RedisLPush(key, v string) error {
|
||||
ctx := context.Background()
|
||||
return RDB.LPush(ctx, key, v).Err()
|
||||
}
|
||||
|
||||
func RedisRPop(key, v string) string {
|
||||
ctx := context.Background()
|
||||
return RDB.RPop(ctx, key).Val()
|
||||
}
|
||||
|
@ -55,7 +55,7 @@ func chooseDB() (*gorm.DB, error) {
|
||||
})
|
||||
}
|
||||
// Use MySQL
|
||||
logger.SysLog("using MySQL as database")
|
||||
logger.SysLog("using MySQL as database" + dsn)
|
||||
return gorm.Open(mysql.Open(dsn), &gorm.Config{
|
||||
PrepareStmt: true, // precompile SQL
|
||||
})
|
||||
|
@ -6,6 +6,7 @@ import (
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/songquanpeng/one-api/common"
|
||||
"github.com/songquanpeng/one-api/relay/channel"
|
||||
"github.com/songquanpeng/one-api/relay/channel/openai"
|
||||
"github.com/songquanpeng/one-api/relay/model"
|
||||
"github.com/songquanpeng/one-api/relay/util"
|
||||
"io"
|
||||
@ -13,6 +14,7 @@ import (
|
||||
)
|
||||
|
||||
type Adaptor struct {
|
||||
textResponse *openai.TextResponse
|
||||
}
|
||||
|
||||
func (a *Adaptor) Init(meta *util.RelayMeta) {
|
||||
@ -58,3 +60,10 @@ func (a *Adaptor) GetModelList() []string {
|
||||
func (a *Adaptor) GetChannelName() string {
|
||||
return "aiproxy"
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetLastTextResp() string {
|
||||
if a.textResponse != nil && len(a.textResponse.Choices) > 0 && a.textResponse.Choices[0].Content != nil {
|
||||
return a.textResponse.Choices[0].Content.(string)
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
@ -6,9 +6,11 @@ import (
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/songquanpeng/one-api/common"
|
||||
"github.com/songquanpeng/one-api/relay/channel"
|
||||
"github.com/songquanpeng/one-api/relay/channel/openai"
|
||||
"github.com/songquanpeng/one-api/relay/constant"
|
||||
"github.com/songquanpeng/one-api/relay/model"
|
||||
"github.com/songquanpeng/one-api/relay/util"
|
||||
|
||||
"io"
|
||||
"net/http"
|
||||
)
|
||||
@ -16,6 +18,7 @@ import (
|
||||
// https://help.aliyun.com/zh/dashscope/developer-reference/api-details
|
||||
|
||||
type Adaptor struct {
|
||||
textResponse *openai.TextResponse
|
||||
}
|
||||
|
||||
func (a *Adaptor) Init(meta *util.RelayMeta) {
|
||||
@ -68,7 +71,8 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.Rel
|
||||
case constant.RelayModeEmbeddings:
|
||||
err, usage = EmbeddingHandler(c, resp)
|
||||
default:
|
||||
err, usage = Handler(c, resp)
|
||||
err, a.textResponse = Handler(c, resp)
|
||||
usage = &(a.textResponse.Usage)
|
||||
}
|
||||
}
|
||||
return
|
||||
@ -81,3 +85,10 @@ func (a *Adaptor) GetModelList() []string {
|
||||
func (a *Adaptor) GetChannelName() string {
|
||||
return "ali"
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetLastTextResp() string {
|
||||
if a.textResponse != nil && len(a.textResponse.Choices) > 0 && a.textResponse.Choices[0].Content != nil {
|
||||
return a.textResponse.Choices[0].Content.(string)
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
@ -219,7 +219,7 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC
|
||||
return nil, &usage
|
||||
}
|
||||
|
||||
func Handler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
|
||||
func Handler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *openai.TextResponse) {
|
||||
var aliResponse ChatResponse
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
@ -233,6 +233,7 @@ func Handler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *
|
||||
if err != nil {
|
||||
return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
|
||||
if aliResponse.Code != "" {
|
||||
return &model.ErrorWithStatusCode{
|
||||
Error: model.Error{
|
||||
@ -253,5 +254,5 @@ func Handler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *
|
||||
c.Writer.Header().Set("Content-Type", "application/json")
|
||||
c.Writer.WriteHeader(resp.StatusCode)
|
||||
_, err = c.Writer.Write(jsonResponse)
|
||||
return nil, &fullTextResponse.Usage
|
||||
return nil, fullTextResponse
|
||||
}
|
||||
|
@ -13,6 +13,7 @@ import (
|
||||
)
|
||||
|
||||
type Adaptor struct {
|
||||
textResponse *openai.TextResponse
|
||||
}
|
||||
|
||||
func (a *Adaptor) Init(meta *util.RelayMeta) {
|
||||
@ -63,3 +64,10 @@ func (a *Adaptor) GetModelList() []string {
|
||||
func (a *Adaptor) GetChannelName() string {
|
||||
return "authropic"
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetLastTextResp() string {
|
||||
if a.textResponse != nil && len(a.textResponse.Choices) > 0 && a.textResponse.Choices[0].Content != nil {
|
||||
return a.textResponse.Choices[0].Content.(string)
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
@ -4,6 +4,7 @@ import (
|
||||
"errors"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/songquanpeng/one-api/relay/channel"
|
||||
"github.com/songquanpeng/one-api/relay/channel/openai"
|
||||
"github.com/songquanpeng/one-api/relay/constant"
|
||||
"github.com/songquanpeng/one-api/relay/model"
|
||||
"github.com/songquanpeng/one-api/relay/util"
|
||||
@ -12,6 +13,7 @@ import (
|
||||
)
|
||||
|
||||
type Adaptor struct {
|
||||
textResponse *openai.TextResponse
|
||||
}
|
||||
|
||||
func (a *Adaptor) Init(meta *util.RelayMeta) {
|
||||
@ -91,3 +93,10 @@ func (a *Adaptor) GetModelList() []string {
|
||||
func (a *Adaptor) GetChannelName() string {
|
||||
return "baidu"
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetLastTextResp() string {
|
||||
if a.textResponse != nil && len(a.textResponse.Choices) > 0 && a.textResponse.Choices[0].Content != nil {
|
||||
return a.textResponse.Choices[0].Content.(string)
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
@ -14,6 +14,7 @@ import (
|
||||
)
|
||||
|
||||
type Adaptor struct {
|
||||
textResponse *openai.TextResponse
|
||||
}
|
||||
|
||||
func (a *Adaptor) Init(meta *util.RelayMeta) {
|
||||
@ -64,3 +65,10 @@ func (a *Adaptor) GetModelList() []string {
|
||||
func (a *Adaptor) GetChannelName() string {
|
||||
return "google gemini"
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetLastTextResp() string {
|
||||
if a.textResponse != nil && len(a.textResponse.Choices) > 0 && a.textResponse.Choices[0].Content != nil {
|
||||
return a.textResponse.Choices[0].Content.(string)
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
@ -17,4 +17,5 @@ type Adaptor interface {
|
||||
DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode)
|
||||
GetModelList() []string
|
||||
GetChannelName() string
|
||||
GetLastTextResp() string
|
||||
}
|
||||
|
@ -16,7 +16,8 @@ import (
|
||||
)
|
||||
|
||||
type Adaptor struct {
|
||||
ChannelType int
|
||||
ChannelType int
|
||||
textResponse *TextResponse
|
||||
}
|
||||
|
||||
func (a *Adaptor) Init(meta *util.RelayMeta) {
|
||||
@ -68,6 +69,7 @@ func (a *Adaptor) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
|
||||
|
||||
if meta.IsStream {
|
||||
var responseText string
|
||||
err, responseText = StreamHandler(c, resp, meta.Mode)
|
||||
@ -101,3 +103,10 @@ func (a *Adaptor) GetChannelName() string {
|
||||
return "openai"
|
||||
}
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetLastTextResp() string {
|
||||
if a.textResponse != nil && len(a.textResponse.Choices) > 0 && a.textResponse.Choices[0].Content != nil {
|
||||
return a.textResponse.Choices[0].Content.(string)
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
@ -13,6 +13,7 @@ import (
|
||||
)
|
||||
|
||||
type Adaptor struct {
|
||||
textResponse *openai.TextResponse
|
||||
}
|
||||
|
||||
func (a *Adaptor) Init(meta *util.RelayMeta) {
|
||||
@ -58,3 +59,10 @@ func (a *Adaptor) GetModelList() []string {
|
||||
func (a *Adaptor) GetChannelName() string {
|
||||
return "google palm"
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetLastTextResp() string {
|
||||
if a.textResponse != nil && len(a.textResponse.Choices) > 0 && a.textResponse.Choices[0].Content != nil {
|
||||
return a.textResponse.Choices[0].Content.(string)
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
@ -16,7 +16,8 @@ import (
|
||||
// https://cloud.tencent.com/document/api/1729/101837
|
||||
|
||||
type Adaptor struct {
|
||||
Sign string
|
||||
Sign string
|
||||
textResponse *openai.TextResponse
|
||||
}
|
||||
|
||||
func (a *Adaptor) Init(meta *util.RelayMeta) {
|
||||
|
@ -236,3 +236,10 @@ func GetSign(req ChatRequest, secretKey string) string {
|
||||
sign := mac.Sum([]byte(nil))
|
||||
return base64.StdEncoding.EncodeToString(sign)
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetLastTextResp() string {
|
||||
if a.textResponse != nil && len(a.textResponse.Choices) > 0 && a.textResponse.Choices[0].Content != nil {
|
||||
return a.textResponse.Choices[0].Content.(string)
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
@ -13,7 +13,8 @@ import (
|
||||
)
|
||||
|
||||
type Adaptor struct {
|
||||
request *model.GeneralOpenAIRequest
|
||||
request *model.GeneralOpenAIRequest
|
||||
textResponse *openai.TextResponse
|
||||
}
|
||||
|
||||
func (a *Adaptor) Init(meta *util.RelayMeta) {
|
||||
@ -68,3 +69,10 @@ func (a *Adaptor) GetModelList() []string {
|
||||
func (a *Adaptor) GetChannelName() string {
|
||||
return "xunfei"
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetLastTextResp() string {
|
||||
if a.textResponse != nil && len(a.textResponse.Choices) > 0 && a.textResponse.Choices[0].Content != nil {
|
||||
return a.textResponse.Choices[0].Content.(string)
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
@ -5,6 +5,7 @@ import (
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/songquanpeng/one-api/relay/channel"
|
||||
"github.com/songquanpeng/one-api/relay/channel/openai"
|
||||
"github.com/songquanpeng/one-api/relay/model"
|
||||
"github.com/songquanpeng/one-api/relay/util"
|
||||
"io"
|
||||
@ -12,6 +13,7 @@ import (
|
||||
)
|
||||
|
||||
type Adaptor struct {
|
||||
textResponse *openai.TextResponse
|
||||
}
|
||||
|
||||
func (a *Adaptor) Init(meta *util.RelayMeta) {
|
||||
@ -60,3 +62,10 @@ func (a *Adaptor) GetModelList() []string {
|
||||
func (a *Adaptor) GetChannelName() string {
|
||||
return "zhipu"
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetLastTextResp() string {
|
||||
if a.textResponse != nil && len(a.textResponse.Choices) > 0 && a.textResponse.Choices[0].Content != nil {
|
||||
return a.textResponse.Choices[0].Content.(string)
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
@ -6,6 +6,7 @@ import (
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/songquanpeng/one-api/common"
|
||||
"github.com/songquanpeng/one-api/common/config"
|
||||
"github.com/songquanpeng/one-api/common/logger"
|
||||
"github.com/songquanpeng/one-api/relay/channel/openai"
|
||||
"github.com/songquanpeng/one-api/relay/constant"
|
||||
@ -26,6 +27,17 @@ func RelayTextHelper(c *gin.Context) *model.ErrorWithStatusCode {
|
||||
logger.Errorf(ctx, "getAndValidateTextRequest failed: %s", err.Error())
|
||||
return openai.ErrorWrapper(err, "invalid_text_request", http.StatusBadRequest)
|
||||
}
|
||||
|
||||
memory := GetMemory(meta.UserId, meta.TokenName)
|
||||
reqMsg := textRequest.Messages[0]
|
||||
|
||||
//fmt.Printf("----req msg %v \n", textRequest.Messages)
|
||||
//fmt.Printf("----memory%v \n", memory)
|
||||
//fmt.Println("-------------------------------------------------------------")
|
||||
|
||||
memory = append(memory, textRequest.Messages...)
|
||||
textRequest.Messages = memory
|
||||
|
||||
meta.IsStream = textRequest.Stream
|
||||
|
||||
// map model name
|
||||
@ -95,7 +107,56 @@ func RelayTextHelper(c *gin.Context) *model.ErrorWithStatusCode {
|
||||
util.ReturnPreConsumedQuota(ctx, preConsumedQuota, meta.TokenId)
|
||||
return respErr
|
||||
}
|
||||
|
||||
if !meta.IsStream {
|
||||
//非流式,保存历史记录
|
||||
respMsg := adaptor.GetLastTextResp()
|
||||
SaveMemory(meta.UserId, meta.TokenName, respMsg, reqMsg)
|
||||
|
||||
}
|
||||
// post-consume quota
|
||||
go postConsumeQuota(ctx, usage, meta, textRequest, ratio, preConsumedQuota, modelRatio, groupRatio)
|
||||
return nil
|
||||
}
|
||||
|
||||
func SaveMemory(userId int, tokenName, resp string, req model.Message) {
|
||||
if len(resp) < 1 {
|
||||
return
|
||||
}
|
||||
|
||||
msgs := []model.Message{}
|
||||
msgs = append(msgs, req)
|
||||
msgs = append(msgs, model.Message{Role: "assistant", Content: resp})
|
||||
|
||||
v, _ := json.Marshal(&msgs)
|
||||
|
||||
key := fmt.Sprintf("one_api_memory:%d:%s", userId, tokenName)
|
||||
|
||||
common.RedisLPush(key, string(v))
|
||||
}
|
||||
|
||||
func GetMemory(userId int, tokenName string) []model.Message {
|
||||
|
||||
key := fmt.Sprintf("one_api_memory:%d:%s", userId, tokenName)
|
||||
|
||||
ss := common.RedisLRange(key, 0, int64(config.MemoryMaxNum))
|
||||
var memory []model.Message
|
||||
|
||||
i := len(ss) - 1
|
||||
|
||||
for i >= 0 {
|
||||
s := ss[i]
|
||||
var msgItem []model.Message
|
||||
if e := json.Unmarshal([]byte(s), &msgItem); e != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
for _, v := range msgItem {
|
||||
if v.Content != nil && len(v.Content.(string)) > 0 {
|
||||
memory = append(memory, v)
|
||||
}
|
||||
}
|
||||
i--
|
||||
}
|
||||
return memory
|
||||
}
|
||||
|
@ -1,6 +1,8 @@
|
||||
package util
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/songquanpeng/one-api/common"
|
||||
"github.com/songquanpeng/one-api/relay/constant"
|
||||
@ -51,5 +53,7 @@ func GetRelayMeta(c *gin.Context) *RelayMeta {
|
||||
meta.BaseURL = common.ChannelBaseURLs[meta.ChannelType]
|
||||
}
|
||||
meta.APIType = constant.ChannelType2APIType(meta.ChannelType)
|
||||
ss, _ := json.Marshal(&meta)
|
||||
fmt.Println("RelayMeta:>>" + string(ss))
|
||||
return &meta
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user