diff --git a/common/constants.go b/common/constants.go index 7b09152b..99e94f42 100644 --- a/common/constants.go +++ b/common/constants.go @@ -173,6 +173,7 @@ const ( ChannelTypeLingyi = 33 ChannelTypeMidjourney = 34 ChannelTypeCloudflareAI = 35 + ChannelTypeCohere = 36 ) var ChannelBaseURLs = []string{ @@ -212,6 +213,7 @@ var ChannelBaseURLs = []string{ "https://api.lingyiwanwu.com", //33 "", //34 "", //35 + "https://api.cohere.ai/v1", //36 } const ( diff --git a/model/price.go b/model/price.go index dc44820f..4937bc89 100644 --- a/model/price.go +++ b/model/price.go @@ -295,6 +295,10 @@ func GetDefaultPrice() []*Price { "@hf/google/gemma-7b-it": {[]float64{0, 0}, common.ChannelTypeCloudflareAI}, "@hf/thebloke/llama-2-13b-chat-awq": {[]float64{0, 0}, common.ChannelTypeCloudflareAI}, "@cf/openai/whisper": {[]float64{0, 0}, common.ChannelTypeCloudflareAI}, + //$0.50 /1M TOKENS $1.50/1M TOKENS + "command-r": {[]float64{0.25, 0.75}, common.ChannelTypeCohere}, + //$3 /1M TOKENS $15/1M TOKENS + "command-r-plus": {[]float64{1.5, 7.5}, common.ChannelTypeCohere}, } var prices []*Price diff --git a/providers/cohere/base.go b/providers/cohere/base.go new file mode 100644 index 00000000..352657a5 --- /dev/null +++ b/providers/cohere/base.go @@ -0,0 +1,84 @@ +package cohere + +import ( + "encoding/json" + "fmt" + "net/http" + "one-api/common/requester" + "one-api/model" + "one-api/providers/base" + "one-api/types" + "strings" +) + +type CohereProviderFactory struct{} + +// 创建 CohereProvider +func (f CohereProviderFactory) Create(channel *model.Channel) base.ProviderInterface { + return &CohereProvider{ + BaseProvider: base.BaseProvider{ + Config: getConfig(), + Channel: channel, + Requester: requester.NewHTTPRequester(*channel.Proxy, requestErrorHandle), + }, + } +} + +type CohereProvider struct { + base.BaseProvider +} + +func getConfig() base.ProviderConfig { + return base.ProviderConfig{ + BaseURL: "https://api.cohere.ai/v1", + ChatCompletions: "/chat", + } +} + +// 请求错误处理 +func requestErrorHandle(resp *http.Response) *types.OpenAIError { + CohereError := &CohereError{} + err := json.NewDecoder(resp.Body).Decode(CohereError) + if err != nil { + return nil + } + + return errorHandle(CohereError) +} + +// 错误处理 +func errorHandle(CohereError *CohereError) *types.OpenAIError { + if CohereError.Message == "" { + return nil + } + return &types.OpenAIError{ + Message: CohereError.Message, + Type: "Cohere error", + } +} + +// 获取请求头 +func (p *CohereProvider) GetRequestHeaders() (headers map[string]string) { + headers = make(map[string]string) + p.CommonRequestHeaders(headers) + headers["Authorization"] = fmt.Sprintf("Bearer %s", p.Channel.Key) + + return headers +} + +func (p *CohereProvider) GetFullRequestURL(requestURL string) string { + baseURL := strings.TrimSuffix(p.GetBaseURL(), "/") + + return fmt.Sprintf("%s%s", baseURL, requestURL) +} + +func convertRole(role string) string { + switch role { + case types.ChatMessageRoleSystem: + return "SYSTEM" + case types.ChatMessageRoleAssistant: + return "CHATBOT" + default: + return "USER" + } +} diff --git a/providers/cohere/chat.go b/providers/cohere/chat.go new file mode 100644 index 00000000..14eb8bfe --- /dev/null +++ b/providers/cohere/chat.go @@ -0,0 +1,208 @@ +package cohere + +import ( + "encoding/json" + "fmt" + "net/http" + "one-api/common" + "one-api/common/requester" + "one-api/providers/base" + "one-api/types" + "strings" + + "github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream" +) + +type CohereStreamHandler struct { + Usage *types.Usage + Request *types.ChatCompletionRequest +} + +func (p *CohereProvider) CreateChatCompletion(request *types.ChatCompletionRequest) (*types.ChatCompletionResponse, *types.OpenAIErrorWithStatusCode) { + req, errWithCode := p.getChatRequest(request) + if errWithCode != nil { + return nil, errWithCode + } + defer req.Body.Close() + + cohereResponse := &CohereResponse{} + // 发送请求 + _, errWithCode = p.Requester.SendRequest(req, cohereResponse, false) + if errWithCode != nil { + return nil, errWithCode + } + + return ConvertToChatOpenai(p, cohereResponse, request) +} + +func (p *CohereProvider) CreateChatCompletionStream(request *types.ChatCompletionRequest) (requester.StreamReaderInterface[string], *types.OpenAIErrorWithStatusCode) { + req, errWithCode := p.getChatRequest(request) + if errWithCode != nil { + return nil, errWithCode + } + defer req.Body.Close() + + // 发送请求 + resp, errWithCode := p.Requester.SendRequestRaw(req) + if errWithCode != nil { + return nil, errWithCode + } + + chatHandler := &CohereStreamHandler{ + Usage: p.Usage, + Request: request, + } + + eventstream.NewDecoder() + + return requester.RequestStream(p.Requester, resp, chatHandler.HandlerStream) +} + +func (p *CohereProvider) getChatRequest(request *types.ChatCompletionRequest) (*http.Request, *types.OpenAIErrorWithStatusCode) { + url, errWithCode := p.GetSupportedAPIUri(common.RelayModeChatCompletions) + if errWithCode != nil { + return nil, errWithCode + } + + // 获取请求地址 + fullRequestURL := p.GetFullRequestURL(url) + if fullRequestURL == "" { + return nil, common.ErrorWrapper(nil, "invalid_cohere_config", http.StatusInternalServerError) + } + + headers := p.GetRequestHeaders() + if request.Stream { + headers["Accept"] = "text/event-stream" + } + + cohereRequest, errWithCode := ConvertFromChatOpenai(request) + if errWithCode != nil { + return nil, errWithCode + } + + // 创建请求 + req, err := p.Requester.NewRequest(http.MethodPost, fullRequestURL, p.Requester.WithBody(cohereRequest), p.Requester.WithHeader(headers)) + if err != nil { + return nil, common.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError) + } + + return req, nil +} + +func ConvertFromChatOpenai(request *types.ChatCompletionRequest) (*CohereRequest, *types.OpenAIErrorWithStatusCode) { + cohereRequest := CohereRequest{ + Model: request.Model, + MaxTokens: request.MaxTokens, + Temperature: request.Temperature, + Stream: request.Stream, + P: request.TopP, + K: request.N, + Seed: request.Seed, + StopSequences: request.Stop, + FrequencyPenalty: request.FrequencyPenalty, + PresencePenalty: request.PresencePenalty, + } + + msgLen := len(request.Messages) - 1 + + for index, message := range request.Messages { + if index == msgLen { + cohereRequest.Message = message.StringContent() + } else { + cohereRequest.ChatHistory = append(cohereRequest.ChatHistory, ChatHistory{ + Role: convertRole(message.Role), + Message: message.StringContent(), + }) + } + + } + + return &cohereRequest, nil +} + +func ConvertToChatOpenai(provider base.ProviderInterface, response *CohereResponse, request *types.ChatCompletionRequest) (openaiResponse *types.ChatCompletionResponse, errWithCode *types.OpenAIErrorWithStatusCode) { + error := errorHandle(&response.CohereError) + if error != nil { + errWithCode = &types.OpenAIErrorWithStatusCode{ + OpenAIError: *error, + StatusCode: http.StatusBadRequest, + } + return + } + + choice := types.ChatCompletionChoice{ + Index: 0, + Message: types.ChatCompletionMessage{ + Role: types.ChatMessageRoleAssistant, + Content: response.Text, + }, + FinishReason: types.FinishReasonStop, + } + openaiResponse = &types.ChatCompletionResponse{ + ID: response.GenerationID, + Object: "chat.completion", + Created: common.GetTimestamp(), + Choices: []types.ChatCompletionChoice{choice}, + Model: request.Model, + Usage: &types.Usage{ + PromptTokens: response.Meta.BilledUnits.InputTokens, + }, + } + + openaiResponse.Usage.CompletionTokens = response.Meta.BilledUnits.OutputTokens + response.Meta.Tokens.SearchUnits + response.Meta.Tokens.Classifications + openaiResponse.Usage.TotalTokens = openaiResponse.Usage.PromptTokens + openaiResponse.Usage.CompletionTokens + + usage := provider.GetUsage() + *usage = *openaiResponse.Usage + + return openaiResponse, nil +} + +// 转换为OpenAI聊天流式请求体 +func (h *CohereStreamHandler) HandlerStream(rawLine *[]byte, dataChan chan string, errChan chan error) { + // 如果rawLine 前缀不为data:,则直接返回 + if !strings.HasPrefix(string(*rawLine), "{") { + *rawLine = nil + return + } + + var cohereResponse CohereStreamResponse + err := json.Unmarshal(*rawLine, &cohereResponse) + if err != nil { + errChan <- common.ErrorToOpenAIError(err) + return + } + + if cohereResponse.EventType != "text-generation" && cohereResponse.EventType != "stream-end" { + *rawLine = nil + return + } + + h.convertToOpenaiStream(&cohereResponse, dataChan) +} + +func (h *CohereStreamHandler) convertToOpenaiStream(cohereResponse *CohereStreamResponse, dataChan chan string) { + choice := types.ChatCompletionStreamChoice{ + Index: 0, + } + + if cohereResponse.EventType == "stream-end" { + choice.FinishReason = types.FinishReasonStop + } else { + choice.Delta = types.ChatCompletionStreamChoiceDelta{ + Role: types.ChatMessageRoleAssistant, + Content: cohereResponse.Text, + } + } + + chatCompletion := types.ChatCompletionStreamResponse{ + ID: fmt.Sprintf("chatcmpl-%s", common.GetUUID()), + Object: "chat.completion.chunk", + Created: common.GetTimestamp(), + Model: h.Request.Model, + Choices: []types.ChatCompletionStreamChoice{choice}, + } + + responseBody, _ := json.Marshal(chatCompletion) + dataChan <- string(responseBody) +} diff --git a/providers/cohere/type.go b/providers/cohere/type.go new file mode 100644 index 00000000..c8fa9852 --- /dev/null +++ b/providers/cohere/type.go @@ -0,0 +1,85 @@ +package cohere + +import "one-api/types" + +type ChatHistory struct { + Role string `json:"role"` + Message string `json:"message"` +} + +type CohereConnector struct { + ID string `json:"id"` + UserAccessToken string `json:"user_access_token,omitempty"` + ContinueOnFailure bool `json:"continue_on_failure,omitempty"` + Options any `json:"options,omitempty"` +} + +type CohereRequest struct { + Message string `json:"message"` + Model string `json:"model,omitempty"` + Stream bool `json:"stream,omitempty"` + Preamble string `json:"preamble,omitempty"` + ChatHistory []ChatHistory `json:"chat_history,omitempty"` + ConversationId string `json:"conversation_id,omitempty"` + PromptTruncation string `json:"prompt_truncation,omitempty"` + Connectors []CohereConnector `json:"connectors,omitempty"` + Temperature float64 `json:"temperature,omitempty"` + MaxTokens int `json:"max_tokens,omitempty"` + MaxInputTokens int `json:"max_input_tokens,omitempty"` + K int `json:"k,omitempty"` + P float64 `json:"p,omitempty"` + Seed *int `json:"seed,omitempty"` + StopSequences any `json:"stop_sequences,omitempty"` + FrequencyPenalty float64 `json:"frequency_penalty,omitempty"` + PresencePenalty float64 `json:"presence_penalty,omitempty"` + Tools []*types.ChatCompletionFunction `json:"tools,omitempty"` + ToolResults any `json:"tool_results,omitempty"` + // SearchQueriesOnly bool `json:"search_queries_only,omitempty"` +} + +type APIVersion struct { + Version string `json:"version"` +} + +type Tokens struct { + InputTokens int `json:"input_tokens"` + OutputTokens int `json:"output_tokens"` + SearchUnits int `json:"search_units,omitempty"` + Classifications int `json:"classifications,omitempty"` +} + +type Meta struct { + APIVersion APIVersion `json:"api_version"` + BilledUnits Tokens `json:"billed_units"` + Tokens Tokens `json:"tokens"` +} + +type CohereToolCall struct { + Name string `json:"name,omitempty"` + Parameters any `json:"parameters,omitempty"` +} + +type CohereResponse struct { + Text string `json:"text,omitempty"` + ResponseID string `json:"response_id,omitempty"` + GenerationID string `json:"generation_id,omitempty"` + ChatHistory []ChatHistory `json:"chat_history,omitempty"` + FinishReason string `json:"finish_reason,omitempty"` + ToolCalls []CohereToolCall `json:"tool_calls,omitempty"` + Meta Meta `json:"meta,omitempty"` + CohereError +} + +type CohereError struct { + Message string `json:"message,omitempty"` +} + +type CohereStreamResponse struct { + IsFinished bool `json:"is_finished"` + EventType string `json:"event_type"` + GenerationID string `json:"generation_id,omitempty"` + Text string `json:"text,omitempty"` + Response CohereResponse `json:"response,omitempty"` + FinishReason string `json:"finish_reason,omitempty"` + ToolCalls []CohereToolCall `json:"tool_calls,omitempty"` +} diff --git a/providers/providers.go b/providers/providers.go index 875ed859..149460c1 100644 --- a/providers/providers.go +++ b/providers/providers.go @@ -12,6 +12,7 @@ import ( "one-api/providers/bedrock" "one-api/providers/claude" "one-api/providers/cloudflareAI" + "one-api/providers/cohere" "one-api/providers/deepseek" "one-api/providers/gemini" "one-api/providers/groq" @@ -56,6 +57,7 @@ func init() { providerFactories[common.ChannelTypeBedrock] = bedrock.BedrockProviderFactory{} providerFactories[common.ChannelTypeMidjourney] = midjourney.MidjourneyProviderFactory{} providerFactories[common.ChannelTypeCloudflareAI] = cloudflareAI.CloudflareAIProviderFactory{} + providerFactories[common.ChannelTypeCohere] = cohere.CohereProviderFactory{} } diff --git a/relay/util/type.go b/relay/util/type.go index d644543c..6606dfad 100644 --- a/relay/util/type.go +++ b/relay/util/type.go @@ -26,5 +26,6 @@ func init() { common.ChannelTypeLingyi: "Lingyiwanwu", common.ChannelTypeMidjourney: "Midjourney", common.ChannelTypeCloudflareAI: "Cloudflare AI", + common.ChannelTypeCohere: "Cohere", } } diff --git a/web/src/constants/ChannelConstants.js b/web/src/constants/ChannelConstants.js index e66617d4..818df56c 100644 --- a/web/src/constants/ChannelConstants.js +++ b/web/src/constants/ChannelConstants.js @@ -146,6 +146,13 @@ export const CHANNEL_OPTIONS = { color: 'orange', url: '' }, + 36: { + key: 36, + text: 'Cohere', + value: 36, + color: 'default', + url: '' + }, 24: { key: 24, text: 'Azure Speech', diff --git a/web/src/views/Channel/type/Config.js b/web/src/views/Channel/type/Config.js index bd60f4b8..8f9dbf91 100644 --- a/web/src/views/Channel/type/Config.js +++ b/web/src/views/Channel/type/Config.js @@ -283,6 +283,13 @@ const typeConfig = { base_url: '' }, modelGroup: 'Cloudflare AI' + }, + 36: { + input: { + models: ['command-r', 'command-r-plus'], + test_model: 'command-r' + }, + modelGroup: 'Cohere' } };