Merge e187fe0490
into fdd7bf41c0
This commit is contained in:
commit
1994861a46
@ -14,6 +14,7 @@ const (
|
|||||||
OriginalModel = "original_model"
|
OriginalModel = "original_model"
|
||||||
Group = "group"
|
Group = "group"
|
||||||
ModelMapping = "model_mapping"
|
ModelMapping = "model_mapping"
|
||||||
|
ParamsOverride = "params_override"
|
||||||
ChannelName = "channel_name"
|
ChannelName = "channel_name"
|
||||||
TokenId = "token_id"
|
TokenId = "token_id"
|
||||||
TokenName = "token_name"
|
TokenName = "token_name"
|
||||||
|
@ -62,6 +62,7 @@ func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, mode
|
|||||||
c.Set(ctxkey.ChannelId, channel.Id)
|
c.Set(ctxkey.ChannelId, channel.Id)
|
||||||
c.Set(ctxkey.ChannelName, channel.Name)
|
c.Set(ctxkey.ChannelName, channel.Name)
|
||||||
c.Set(ctxkey.ModelMapping, channel.GetModelMapping())
|
c.Set(ctxkey.ModelMapping, channel.GetModelMapping())
|
||||||
|
c.Set(ctxkey.ParamsOverride, channel.GetParamsOverride())
|
||||||
c.Set(ctxkey.OriginalModel, modelName) // for retry
|
c.Set(ctxkey.OriginalModel, modelName) // for retry
|
||||||
c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key))
|
c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key))
|
||||||
c.Set(ctxkey.BaseURL, channel.GetBaseURL())
|
c.Set(ctxkey.BaseURL, channel.GetBaseURL())
|
||||||
|
@ -35,6 +35,7 @@ type Channel struct {
|
|||||||
Group string `json:"group" gorm:"type:varchar(32);default:'default'"`
|
Group string `json:"group" gorm:"type:varchar(32);default:'default'"`
|
||||||
UsedQuota int64 `json:"used_quota" gorm:"bigint;default:0"`
|
UsedQuota int64 `json:"used_quota" gorm:"bigint;default:0"`
|
||||||
ModelMapping *string `json:"model_mapping" gorm:"type:varchar(1024);default:''"`
|
ModelMapping *string `json:"model_mapping" gorm:"type:varchar(1024);default:''"`
|
||||||
|
ParamsOverride *string `json:"default_params_override" gorm:"type:text;default:''"`
|
||||||
Priority *int64 `json:"priority" gorm:"bigint;default:0"`
|
Priority *int64 `json:"priority" gorm:"bigint;default:0"`
|
||||||
Config string `json:"config"`
|
Config string `json:"config"`
|
||||||
}
|
}
|
||||||
@ -123,6 +124,20 @@ func (channel *Channel) GetModelMapping() map[string]string {
|
|||||||
return modelMapping
|
return modelMapping
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (channel *Channel) GetParamsOverride() map[string]map[string]interface{} {
|
||||||
|
if channel.ParamsOverride == nil || *channel.ParamsOverride == "" || *channel.ParamsOverride == "{}" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
paramsOverride := make(map[string]map[string]interface{})
|
||||||
|
err := json.Unmarshal([]byte(*channel.ParamsOverride), ¶msOverride)
|
||||||
|
if err != nil {
|
||||||
|
logger.SysError(fmt.Sprintf("failed to unmarshal params override for channel %d, error: %s", channel.Id, err.Error()))
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return paramsOverride
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
func (channel *Channel) Insert() error {
|
func (channel *Channel) Insert() error {
|
||||||
var err error
|
var err error
|
||||||
err = DB.Create(channel).Error
|
err = DB.Create(channel).Error
|
||||||
|
@ -6,6 +6,8 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"io/ioutil"
|
||||||
|
"context"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/songquanpeng/one-api/common/logger"
|
"github.com/songquanpeng/one-api/common/logger"
|
||||||
@ -23,7 +25,18 @@ import (
|
|||||||
func RelayTextHelper(c *gin.Context) *model.ErrorWithStatusCode {
|
func RelayTextHelper(c *gin.Context) *model.ErrorWithStatusCode {
|
||||||
ctx := c.Request.Context()
|
ctx := c.Request.Context()
|
||||||
meta := meta.GetByContext(c)
|
meta := meta.GetByContext(c)
|
||||||
// get & validate textRequest
|
|
||||||
|
// Read the original request body
|
||||||
|
bodyBytes, err := ioutil.ReadAll(c.Request.Body)
|
||||||
|
if err != nil {
|
||||||
|
logger.Errorf(ctx, "Failed to read request body: %s", err.Error())
|
||||||
|
return openai.ErrorWrapper(err, "invalid_request_body", http.StatusBadRequest)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Restore the request body for `getAndValidateTextRequest`
|
||||||
|
c.Request.Body = ioutil.NopCloser(bytes.NewBuffer(bodyBytes))
|
||||||
|
|
||||||
|
// Call `getAndValidateTextRequest`
|
||||||
textRequest, err := getAndValidateTextRequest(c, meta.Mode)
|
textRequest, err := getAndValidateTextRequest(c, meta.Mode)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Errorf(ctx, "getAndValidateTextRequest failed: %s", err.Error())
|
logger.Errorf(ctx, "getAndValidateTextRequest failed: %s", err.Error())
|
||||||
@ -31,6 +44,16 @@ func RelayTextHelper(c *gin.Context) *model.ErrorWithStatusCode {
|
|||||||
}
|
}
|
||||||
meta.IsStream = textRequest.Stream
|
meta.IsStream = textRequest.Stream
|
||||||
|
|
||||||
|
// Parse the request body into a map
|
||||||
|
var rawRequest map[string]interface{}
|
||||||
|
if err := json.Unmarshal(bodyBytes, &rawRequest); err != nil {
|
||||||
|
logger.Errorf(ctx, "Failed to parse request body into map: %s", err.Error())
|
||||||
|
return openai.ErrorWrapper(err, "invalid_json", http.StatusBadRequest)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Apply parameter overrides
|
||||||
|
applyParameterOverrides(ctx, meta, textRequest, rawRequest)
|
||||||
|
|
||||||
// map model name
|
// map model name
|
||||||
meta.OriginModelName = textRequest.Model
|
meta.OriginModelName = textRequest.Model
|
||||||
textRequest.Model, _ = getMappedModelName(textRequest.Model, meta.ModelMapping)
|
textRequest.Model, _ = getMappedModelName(textRequest.Model, meta.ModelMapping)
|
||||||
@ -105,3 +128,70 @@ func getRequestBody(c *gin.Context, meta *meta.Meta, textRequest *model.GeneralO
|
|||||||
requestBody = bytes.NewBuffer(jsonData)
|
requestBody = bytes.NewBuffer(jsonData)
|
||||||
return requestBody, nil
|
return requestBody, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func applyParameterOverrides(ctx context.Context, meta *meta.Meta, textRequest *model.GeneralOpenAIRequest, rawRequest map[string]interface{}) {
|
||||||
|
if meta.ParamsOverride != nil {
|
||||||
|
modelName := meta.OriginModelName
|
||||||
|
if overrideParams, exists := meta.ParamsOverride[modelName]; exists {
|
||||||
|
logger.Infof(ctx, "Applying parameter overrides for model %s on channel %d", modelName, meta.ChannelId)
|
||||||
|
for key, value := range overrideParams {
|
||||||
|
if _, userSpecified := rawRequest[key]; !userSpecified {
|
||||||
|
// Apply the override since the user didn't specify this parameter
|
||||||
|
switch key {
|
||||||
|
case "temperature":
|
||||||
|
if v, ok := value.(float64); ok {
|
||||||
|
textRequest.Temperature = v
|
||||||
|
} else if v, ok := value.(int); ok {
|
||||||
|
textRequest.Temperature = float64(v)
|
||||||
|
}
|
||||||
|
case "max_tokens":
|
||||||
|
if v, ok := value.(float64); ok {
|
||||||
|
textRequest.MaxTokens = int(v)
|
||||||
|
} else if v, ok := value.(int); ok {
|
||||||
|
textRequest.MaxTokens = v
|
||||||
|
}
|
||||||
|
case "top_p":
|
||||||
|
if v, ok := value.(float64); ok {
|
||||||
|
textRequest.TopP = v
|
||||||
|
} else if v, ok := value.(int); ok {
|
||||||
|
textRequest.TopP = float64(v)
|
||||||
|
}
|
||||||
|
case "frequency_penalty":
|
||||||
|
if v, ok := value.(float64); ok {
|
||||||
|
textRequest.FrequencyPenalty = v
|
||||||
|
} else if v, ok := value.(int); ok {
|
||||||
|
textRequest.FrequencyPenalty = float64(v)
|
||||||
|
}
|
||||||
|
case "presence_penalty":
|
||||||
|
if v, ok := value.(float64); ok {
|
||||||
|
textRequest.PresencePenalty = v
|
||||||
|
} else if v, ok := value.(int); ok {
|
||||||
|
textRequest.PresencePenalty = float64(v)
|
||||||
|
}
|
||||||
|
case "stop":
|
||||||
|
textRequest.Stop = value
|
||||||
|
case "n":
|
||||||
|
if v, ok := value.(float64); ok {
|
||||||
|
textRequest.N = int(v)
|
||||||
|
} else if v, ok := value.(int); ok {
|
||||||
|
textRequest.N = v
|
||||||
|
}
|
||||||
|
case "stream":
|
||||||
|
if v, ok := value.(bool); ok {
|
||||||
|
textRequest.Stream = v
|
||||||
|
}
|
||||||
|
case "num_ctx":
|
||||||
|
if v, ok := value.(float64); ok {
|
||||||
|
textRequest.NumCtx = int(v)
|
||||||
|
} else if v, ok := value.(int); ok {
|
||||||
|
textRequest.NumCtx = v
|
||||||
|
}
|
||||||
|
// Handle other parameters as needed
|
||||||
|
default:
|
||||||
|
logger.Warnf(ctx, "Unknown parameter override key: %s", key)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
@ -18,6 +18,7 @@ type Meta struct {
|
|||||||
UserId int
|
UserId int
|
||||||
Group string
|
Group string
|
||||||
ModelMapping map[string]string
|
ModelMapping map[string]string
|
||||||
|
ParamsOverride map[string]map[string]interface{}
|
||||||
// BaseURL is the proxy url set in the channel config
|
// BaseURL is the proxy url set in the channel config
|
||||||
BaseURL string
|
BaseURL string
|
||||||
APIKey string
|
APIKey string
|
||||||
@ -46,6 +47,11 @@ func GetByContext(c *gin.Context) *Meta {
|
|||||||
BaseURL: c.GetString(ctxkey.BaseURL),
|
BaseURL: c.GetString(ctxkey.BaseURL),
|
||||||
APIKey: strings.TrimPrefix(c.Request.Header.Get("Authorization"), "Bearer "),
|
APIKey: strings.TrimPrefix(c.Request.Header.Get("Authorization"), "Bearer "),
|
||||||
RequestURLPath: c.Request.URL.String(),
|
RequestURLPath: c.Request.URL.String(),
|
||||||
|
}
|
||||||
|
// Retrieve ParamsOverride
|
||||||
|
paramsOverride, exists := c.Get(ctxkey.ParamsOverride)
|
||||||
|
if exists && paramsOverride != nil {
|
||||||
|
meta.ParamsOverride = paramsOverride.(map[string]map[string]interface{})
|
||||||
}
|
}
|
||||||
cfg, ok := c.Get(ctxkey.Config)
|
cfg, ok := c.Get(ctxkey.Config)
|
||||||
if ok {
|
if ok {
|
||||||
|
@ -178,6 +178,10 @@ const EditChannel = () => {
|
|||||||
showInfo('模型映射必须是合法的 JSON 格式!');
|
showInfo('模型映射必须是合法的 JSON 格式!');
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
if (inputs.default_params_override !== '' && !verifyJSON(inputs.default_params_override)) {
|
||||||
|
showInfo('默认参数Override必须是合法的 JSON 格式!');
|
||||||
|
return;
|
||||||
|
}
|
||||||
let localInputs = {...inputs};
|
let localInputs = {...inputs};
|
||||||
if (localInputs.base_url && localInputs.base_url.endsWith('/')) {
|
if (localInputs.base_url && localInputs.base_url.endsWith('/')) {
|
||||||
localInputs.base_url = localInputs.base_url.slice(0, localInputs.base_url.length - 1);
|
localInputs.base_url = localInputs.base_url.slice(0, localInputs.base_url.length - 1);
|
||||||
@ -439,6 +443,21 @@ const EditChannel = () => {
|
|||||||
</Form.Field>
|
</Form.Field>
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
{
|
||||||
|
inputs.type !== 43 && (
|
||||||
|
<Form.Field>
|
||||||
|
<Form.TextArea
|
||||||
|
label='默认参数Override'
|
||||||
|
placeholder={`此项可选,用于修改请求体中的默认参数,为一个 JSON 字符串,键为请求中模型名称,值为要替换的默认参数,例如:\n${JSON.stringify({ 'llama3:70b': { 'num_ctx': 11520, 'temperature': 0.2 }, 'qwen2:72b': { 'num_ctx': 11520, 'temperature': 0.8 } }, null, 2)}`}
|
||||||
|
name='default_params_override'
|
||||||
|
onChange={handleInputChange}
|
||||||
|
value={inputs.default_params_override}
|
||||||
|
style={{ minHeight: 150, fontFamily: 'JetBrains Mono, Consolas' }}
|
||||||
|
autoComplete='new-password'
|
||||||
|
/>
|
||||||
|
</Form.Field>
|
||||||
|
)
|
||||||
|
}
|
||||||
{
|
{
|
||||||
inputs.type === 33 && (
|
inputs.type === 33 && (
|
||||||
<Form.Field>
|
<Form.Field>
|
||||||
|
Loading…
Reference in New Issue
Block a user