diff --git a/common/model-ratio.go b/common/model-ratio.go index eb4a07f0..67aa6147 100644 --- a/common/model-ratio.go +++ b/common/model-ratio.go @@ -138,6 +138,35 @@ func UpdateModelRatioByJSONString(jsonStr string) error { return json.Unmarshal([]byte(jsonStr), &ModelRatio) } +func MergeModelRatioByJSONString(jsonStr string) (newJsonStr string, err error) { + inputModelRatio := make(map[string]float64) + err = json.Unmarshal([]byte(jsonStr), &inputModelRatio) + if err != nil { + return + } + + isNew := false + // 与现有的ModelRatio进行比较,如果有新增的模型,需要添加 + for key, value := range ModelRatio { + if _, ok := inputModelRatio[key]; !ok { + isNew = true + inputModelRatio[key] = value + } + } + + if !isNew { + return + } + + var jsonBytes []byte + jsonBytes, err = json.Marshal(inputModelRatio) + if err != nil { + SysError("error marshalling model ratio: " + err.Error()) + } + newJsonStr = string(jsonBytes) + return +} + func GetModelRatio(name string) float64 { if strings.HasPrefix(name, "qwen-") && strings.HasSuffix(name, "-internet") { name = strings.TrimSuffix(name, "-internet") diff --git a/model/option.go b/model/option.go index bb8b709c..730a46fb 100644 --- a/model/option.go +++ b/model/option.go @@ -19,6 +19,11 @@ func AllOption() ([]*Option, error) { return options, err } +func GetOption(key string) (option Option, err error) { + err = DB.First(&option, Option{Key: key}).Error + return +} + func InitOptionMap() { common.OptionMapRWMutex.Lock() common.OptionMap = make(map[string]string) @@ -73,9 +78,28 @@ func InitOptionMap() { common.OptionMap["QuotaPerUnit"] = strconv.FormatFloat(common.QuotaPerUnit, 'f', -1, 64) common.OptionMap["RetryTimes"] = strconv.Itoa(common.RetryTimes) common.OptionMapRWMutex.Unlock() + initModelRatio() loadOptionsFromDatabase() } +func initModelRatio() { + // 查询数据库中的ModelRatio + option, err := GetOption("ModelRatio") + if err != nil || option.Value == "" { + return + } + + newModelRatio, err := common.MergeModelRatioByJSONString(option.Value) + if err != nil || newModelRatio == "" { + return + } + + // 更新数据库中的ModelRatio + common.SysLog("update ModelRatio") + UpdateOption("ModelRatio", newModelRatio) + +} + func loadOptionsFromDatabase() { options, _ := AllOption() for _, option := range options {