From d6536d29070deb82ae8e9fc2888d70fa2aa23506 Mon Sep 17 00:00:00 2001 From: "Laisky.Cai" Date: Mon, 15 Jul 2024 06:28:03 +0000 Subject: [PATCH] fix: update GetAdaptor function to use the actual model name The GetAdaptor function in the Adaptor struct has been updated to use the actual model name instead of the origin model name. This change ensures that the correct adaptor is retrieved for processing the response. --- common/config/config.go | 1 - relay/adaptor/vertexai/adaptor.go | 4 ++-- relay/adaptor/vertexai/claude/adapter.go | 14 ++++++------ relay/adaptor/vertexai/claude/model.go | 22 +++++++++---------- relay/adaptor/vertexai/registry.go | 1 - relay/adaptor/vertexai/token.go | 1 - relay/channeltype/url.go | 2 +- relay/meta/relay_meta.go | 28 +++++++++++++----------- 8 files changed, 36 insertions(+), 37 deletions(-) diff --git a/common/config/config.go b/common/config/config.go index 9b55e413..11da0b96 100644 --- a/common/config/config.go +++ b/common/config/config.go @@ -147,7 +147,6 @@ var InitialRootAccessToken = os.Getenv("INITIAL_ROOT_ACCESS_TOKEN") var GeminiVersion = env.String("GEMINI_VERSION", "v1") - var OnlyOneLogFile = env.Bool("ONLY_ONE_LOG_FILE", false) var RelayProxy = env.String("RELAY_PROXY", "") diff --git a/relay/adaptor/vertexai/adaptor.go b/relay/adaptor/vertexai/adaptor.go index f1b68b70..cf5d4c48 100644 --- a/relay/adaptor/vertexai/adaptor.go +++ b/relay/adaptor/vertexai/adaptor.go @@ -19,7 +19,7 @@ var _ adaptor.Adaptor = new(Adaptor) const channelName = "vertexai" -type Adaptor struct {} +type Adaptor struct{} func (a *Adaptor) Init(meta *meta.Meta) { } @@ -38,7 +38,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G } func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) { - adaptor := GetAdaptor(meta.OriginModelName) + adaptor := GetAdaptor(meta.ActualModelName) if adaptor == nil { return nil, &relaymodel.ErrorWithStatusCode{ StatusCode: http.StatusInternalServerError, diff --git a/relay/adaptor/vertexai/claude/adapter.go b/relay/adaptor/vertexai/claude/adapter.go index 1e37f0ff..aab39864 100644 --- a/relay/adaptor/vertexai/claude/adapter.go +++ b/relay/adaptor/vertexai/claude/adapter.go @@ -30,13 +30,13 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G req := Request{ AnthropicVersion: anthropicVersion, // Model: claudeReq.Model, - Messages: claudeReq.Messages, - MaxTokens: claudeReq.MaxTokens, - Temperature: claudeReq.Temperature, - TopP: claudeReq.TopP, - TopK: claudeReq.TopK, - Stream: claudeReq.Stream, - Tools: claudeReq.Tools, + Messages: claudeReq.Messages, + MaxTokens: claudeReq.MaxTokens, + Temperature: claudeReq.Temperature, + TopP: claudeReq.TopP, + TopK: claudeReq.TopK, + Stream: claudeReq.Stream, + Tools: claudeReq.Tools, } c.Set(ctxkey.RequestModel, request.Model) diff --git a/relay/adaptor/vertexai/claude/model.go b/relay/adaptor/vertexai/claude/model.go index 2f13f598..e1bd5dd4 100644 --- a/relay/adaptor/vertexai/claude/model.go +++ b/relay/adaptor/vertexai/claude/model.go @@ -4,16 +4,16 @@ import "github.com/songquanpeng/one-api/relay/adaptor/anthropic" type Request struct { // AnthropicVersion must be "vertex-2023-10-16" - AnthropicVersion string `json:"anthropic_version"` + AnthropicVersion string `json:"anthropic_version"` // Model string `json:"model"` - Messages []anthropic.Message `json:"messages"` - System string `json:"system,omitempty"` - MaxTokens int `json:"max_tokens,omitempty"` - StopSequences []string `json:"stop_sequences,omitempty"` - Stream bool `json:"stream,omitempty"` - Temperature float64 `json:"temperature,omitempty"` - TopP float64 `json:"top_p,omitempty"` - TopK int `json:"top_k,omitempty"` - Tools []anthropic.Tool `json:"tools,omitempty"` - ToolChoice any `json:"tool_choice,omitempty"` + Messages []anthropic.Message `json:"messages"` + System string `json:"system,omitempty"` + MaxTokens int `json:"max_tokens,omitempty"` + StopSequences []string `json:"stop_sequences,omitempty"` + Stream bool `json:"stream,omitempty"` + Temperature float64 `json:"temperature,omitempty"` + TopP float64 `json:"top_p,omitempty"` + TopK int `json:"top_k,omitempty"` + Tools []anthropic.Tool `json:"tools,omitempty"` + ToolChoice any `json:"tool_choice,omitempty"` } diff --git a/relay/adaptor/vertexai/registry.go b/relay/adaptor/vertexai/registry.go index f9547ebf..41099f02 100644 --- a/relay/adaptor/vertexai/registry.go +++ b/relay/adaptor/vertexai/registry.go @@ -32,7 +32,6 @@ func init() { } } - type innerAIAdapter interface { ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) diff --git a/relay/adaptor/vertexai/token.go b/relay/adaptor/vertexai/token.go index e5fa7b48..0a5e0aad 100644 --- a/relay/adaptor/vertexai/token.go +++ b/relay/adaptor/vertexai/token.go @@ -26,7 +26,6 @@ type ApplicationDefaultCredentials struct { UniverseDomain string `json:"universe_domain"` } - var Cache = cache.New(50*time.Minute, 55*time.Minute) const defaultScope = "https://www.googleapis.com/auth/cloud-platform" diff --git a/relay/channeltype/url.go b/relay/channeltype/url.go index 47241063..20a24ab0 100644 --- a/relay/channeltype/url.go +++ b/relay/channeltype/url.go @@ -43,7 +43,7 @@ var ChannelBaseURLs = []string{ "https://api.together.xyz", // 39 "https://ark.cn-beijing.volces.com", // 40 "https://api.novita.ai/v3/openai", // 41 - "", // 42 + "", // 42 } func init() { diff --git a/relay/meta/relay_meta.go b/relay/meta/relay_meta.go index 9714ebb5..04977db5 100644 --- a/relay/meta/relay_meta.go +++ b/relay/meta/relay_meta.go @@ -10,20 +10,22 @@ import ( ) type Meta struct { - Mode int - ChannelType int - ChannelId int - TokenId int - TokenName string - UserId int - Group string - ModelMapping map[string]string - BaseURL string - APIKey string - APIType int - Config model.ChannelConfig - IsStream bool + Mode int + ChannelType int + ChannelId int + TokenId int + TokenName string + UserId int + Group string + ModelMapping map[string]string + BaseURL string + APIKey string + APIType int + Config model.ChannelConfig + IsStream bool + // OriginModelName is the model name from the raw user request OriginModelName string + // ActualModelName is the model name after mapping ActualModelName string RequestURLPath string PromptTokens int // only for DoResponse