diff --git a/controller/relay-text.go b/controller/relay-text.go index 8dfdf6e1..52e10f2b 100644 --- a/controller/relay-text.go +++ b/controller/relay-text.go @@ -85,13 +85,14 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { } } apiType := APITypeOpenAI - if strings.HasPrefix(textRequest.Model, "claude") { + switch channelType { + case common.ChannelTypeAnthropic: apiType = APITypeClaude - } else if strings.HasPrefix(textRequest.Model, "ERNIE") { + case common.ChannelTypeBaidu: apiType = APITypeBaidu - } else if strings.HasPrefix(textRequest.Model, "PaLM") { + case common.ChannelTypePaLM: apiType = APITypePaLM - } else if strings.HasPrefix(textRequest.Model, "chatglm_") { + case common.ChannelTypeZhipu: apiType = APITypeZhipu } baseURL := common.ChannelBaseURLs[channelType] @@ -140,6 +141,9 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { fullRequestURL += "?access_token=" + apiKey // TODO: access token expire in 30 days case APITypePaLM: fullRequestURL = "https://generativelanguage.googleapis.com/v1beta2/models/chat-bison-001:generateMessage" + if baseURL != "" { + fullRequestURL = fmt.Sprintf("%s/v1beta2/models/chat-bison-001:generateMessage", baseURL) + } apiKey := c.Request.Header.Get("Authorization") apiKey = strings.TrimPrefix(apiKey, "Bearer ") fullRequestURL += "?key=" + apiKey diff --git a/router/relay-router.go b/router/relay-router.go index 0c8e9415..c3c84d8b 100644 --- a/router/relay-router.go +++ b/router/relay-router.go @@ -12,7 +12,7 @@ func SetRelayRouter(router *gin.Engine) { modelsRouter := router.Group("/v1/models") modelsRouter.Use(middleware.TokenAuth()) { - modelsRouter.GET("/", controller.ListModels) + modelsRouter.GET("", controller.ListModels) modelsRouter.GET("/:model", controller.RetrieveModel) } relayV1Router := router.Group("/v1")