diff --git a/controller/relay-baidu.go b/controller/relay-baidu.go index e82f4904..4267757d 100644 --- a/controller/relay-baidu.go +++ b/controller/relay-baidu.go @@ -57,10 +57,21 @@ type BaiduChatStreamResponse struct { func requestOpenAI2Baidu(request GeneralOpenAIRequest) *BaiduChatRequest { messages := make([]BaiduMessage, 0, len(request.Messages)) for _, message := range request.Messages { - messages = append(messages, BaiduMessage{ - Role: message.Role, - Content: message.Content, - }) + if message.Role == "system" { + messages = append(messages, BaiduMessage{ + Role: "user", + Content: message.Content, + }) + messages = append(messages, BaiduMessage{ + Role: "assistant", + Content: "Okay", + }) + } else { + messages = append(messages, BaiduMessage{ + Role: message.Role, + Content: message.Content, + }) + } } return &BaiduChatRequest{ Messages: messages, diff --git a/controller/relay-claude.go b/controller/relay-claude.go index 22f41cef..1d67fa7b 100644 --- a/controller/relay-claude.go +++ b/controller/relay-claude.go @@ -69,11 +69,11 @@ func requestOpenAI2Claude(textRequest GeneralOpenAIRequest) *ClaudeRequest { prompt += fmt.Sprintf("\n\nHuman: %s", message.Content) } else if message.Role == "assistant" { prompt += fmt.Sprintf("\n\nAssistant: %s", message.Content) - } else { - // ignore other roles + } else if message.Role == "system" { + prompt += fmt.Sprintf("\n\nSystem: %s", message.Content) } - prompt += "\n\nAssistant:" } + prompt += "\n\nAssistant:" claudeRequest.Prompt = prompt return &claudeRequest } diff --git a/controller/relay-text.go b/controller/relay-text.go index 8dfdf6e1..52e10f2b 100644 --- a/controller/relay-text.go +++ b/controller/relay-text.go @@ -85,13 +85,14 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { } } apiType := APITypeOpenAI - if strings.HasPrefix(textRequest.Model, "claude") { + switch channelType { + case common.ChannelTypeAnthropic: apiType = APITypeClaude - } else if strings.HasPrefix(textRequest.Model, "ERNIE") { + case common.ChannelTypeBaidu: apiType = APITypeBaidu - } else if strings.HasPrefix(textRequest.Model, "PaLM") { + case common.ChannelTypePaLM: apiType = APITypePaLM - } else if strings.HasPrefix(textRequest.Model, "chatglm_") { + case common.ChannelTypeZhipu: apiType = APITypeZhipu } baseURL := common.ChannelBaseURLs[channelType] @@ -140,6 +141,9 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { fullRequestURL += "?access_token=" + apiKey // TODO: access token expire in 30 days case APITypePaLM: fullRequestURL = "https://generativelanguage.googleapis.com/v1beta2/models/chat-bison-001:generateMessage" + if baseURL != "" { + fullRequestURL = fmt.Sprintf("%s/v1beta2/models/chat-bison-001:generateMessage", baseURL) + } apiKey := c.Request.Header.Get("Authorization") apiKey = strings.TrimPrefix(apiKey, "Bearer ") fullRequestURL += "?key=" + apiKey diff --git a/controller/relay-zhipu.go b/controller/relay-zhipu.go index 349f4742..33d141c7 100644 --- a/controller/relay-zhipu.go +++ b/controller/relay-zhipu.go @@ -111,10 +111,21 @@ func getZhipuToken(apikey string) string { func requestOpenAI2Zhipu(request GeneralOpenAIRequest) *ZhipuRequest { messages := make([]ZhipuMessage, 0, len(request.Messages)) for _, message := range request.Messages { - messages = append(messages, ZhipuMessage{ - Role: message.Role, - Content: message.Content, - }) + if message.Role == "system" { + messages = append(messages, ZhipuMessage{ + Role: "system", + Content: message.Content, + }) + messages = append(messages, ZhipuMessage{ + Role: "user", + Content: "Okay", + }) + } else { + messages = append(messages, ZhipuMessage{ + Role: message.Role, + Content: message.Content, + }) + } } return &ZhipuRequest{ Prompt: messages, diff --git a/router/relay-router.go b/router/relay-router.go index 0c8e9415..c3c84d8b 100644 --- a/router/relay-router.go +++ b/router/relay-router.go @@ -12,7 +12,7 @@ func SetRelayRouter(router *gin.Engine) { modelsRouter := router.Group("/v1/models") modelsRouter.Use(middleware.TokenAuth()) { - modelsRouter.GET("/", controller.ListModels) + modelsRouter.GET("", controller.ListModels) modelsRouter.GET("/:model", controller.RetrieveModel) } relayV1Router := router.Group("/v1")