From f564245e347da5dca078de630ef4a6d86fa60096 Mon Sep 17 00:00:00 2001 From: "Laisky.Cai" Date: Thu, 29 Feb 2024 08:58:54 +0000 Subject: [PATCH] fix: #1054 Add model mapping to abilities - Update Abilities model to include model mapping key - Parse model mapping in Channel model and filter models by model mapping in Update function --- common/image/image.go | 2 +- controller/relay.go | 2 +- model/ability.go | 15 +++++++++++++++ model/channel.go | 21 +++++++++++++++++++-- 4 files changed, 36 insertions(+), 4 deletions(-) diff --git a/common/image/image.go b/common/image/image.go index de8fefd3..12f0adff 100644 --- a/common/image/image.go +++ b/common/image/image.go @@ -16,7 +16,7 @@ import ( ) // Regex to match data URL pattern -var dataURLPattern = regexp.MustCompile(`data:image/([^;]+);base64,(.*)`) +var dataURLPattern = regexp.MustCompile(`data:image/([^;]+);base64,(.*)`) func IsImageUrl(url string) (bool, error) { resp, err := http.Head(url) diff --git a/controller/relay.go b/controller/relay.go index 499e8ddc..9ace90ed 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -60,7 +60,7 @@ func Relay(c *gin.Context) { for i := retryTimes; i > 0; i-- { channel, err := dbmodel.CacheGetRandomSatisfiedChannel(group, originalModel) if err != nil { - logger.Errorf(ctx, "CacheGetRandomSatisfiedChannel failed: %w", err) + logger.Errorf(ctx, "CacheGetRandomSatisfiedChannel failed: %v", err) break } logger.Infof(ctx, "using channel #%d to retry (remain times %d)", channel.Id, i) diff --git a/model/ability.go b/model/ability.go index 7127abc3..5d65612b 100644 --- a/model/ability.go +++ b/model/ability.go @@ -55,6 +55,21 @@ func (channel *Channel) AddAbilities() error { abilities = append(abilities, ability) } } + + // add keys of model mapping to abilities + for model := range channel.GetModelMapping() { + for _, group := range groups_ { + ability := Ability{ + Group: group, + Model: model, + ChannelId: channel.Id, + Enabled: channel.Status == common.ChannelStatusEnabled, + Priority: channel.Priority, + } + abilities = append(abilities, ability) + } + } + return DB.Create(&abilities).Error } diff --git a/model/channel.go b/model/channel.go index 19af2263..dc4d338b 100644 --- a/model/channel.go +++ b/model/channel.go @@ -3,6 +3,8 @@ package model import ( "encoding/json" "fmt" + "strings" + "github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common/config" "github.com/songquanpeng/one-api/common/helper" @@ -115,8 +117,23 @@ func (channel *Channel) Insert() error { return err } -func (channel *Channel) Update() error { - var err error +func (channel *Channel) Update() (err error) { + // https://github.com/songquanpeng/one-api/issues/1054 + // for compatability, filter models by model-mapping. + mapping := channel.GetModelMapping() + if len(mapping) != 0 { + models := strings.Split(channel.Models, ",") + var filteredModels []string + for _, model := range models { + if _, ok := mapping[model]; !ok { + filteredModels = append(filteredModels, model) + } + } + + channel.Models = strings.Join(filteredModels, ",") + } + + // update err = DB.Model(channel).Updates(channel).Error if err != nil { return err