feat: Support Cohere

This commit is contained in:
Martial BE 2024-04-18 01:39:18 +08:00
parent 7e206c1e7e
commit 95854f5912
No known key found for this signature in database
GPG Key ID: D06C32DF0EDB9084
9 changed files with 400 additions and 0 deletions

View File

@ -173,6 +173,7 @@ const (
ChannelTypeLingyi = 33
ChannelTypeMidjourney = 34
ChannelTypeCloudflareAI = 35
ChannelTypeCohere = 36
)
var ChannelBaseURLs = []string{
@ -212,6 +213,7 @@ var ChannelBaseURLs = []string{
"https://api.lingyiwanwu.com", //33
"", //34
"", //35
"https://api.cohere.ai/v1", //36
}
const (

View File

@ -295,6 +295,10 @@ func GetDefaultPrice() []*Price {
"@hf/google/gemma-7b-it": {[]float64{0, 0}, common.ChannelTypeCloudflareAI},
"@hf/thebloke/llama-2-13b-chat-awq": {[]float64{0, 0}, common.ChannelTypeCloudflareAI},
"@cf/openai/whisper": {[]float64{0, 0}, common.ChannelTypeCloudflareAI},
//$0.50 /1M TOKENS $1.50/1M TOKENS
"command-r": {[]float64{0.25, 0.75}, common.ChannelTypeCohere},
//$3 /1M TOKENS $15/1M TOKENS
"command-r-plus": {[]float64{1.5, 7.5}, common.ChannelTypeCohere},
}
var prices []*Price

84
providers/cohere/base.go Normal file
View File

@ -0,0 +1,84 @@
package cohere
import (
"encoding/json"
"fmt"
"net/http"
"one-api/common/requester"
"one-api/model"
"one-api/providers/base"
"one-api/types"
"strings"
)
type CohereProviderFactory struct{}

// Create builds a CohereProvider bound to the given channel, wiring in
// the static Cohere config and an HTTP requester that decodes Cohere
// error bodies.
func (f CohereProviderFactory) Create(channel *model.Channel) base.ProviderInterface {
	provider := &CohereProvider{}
	provider.Config = getConfig()
	provider.Channel = channel
	provider.Requester = requester.NewHTTPRequester(*channel.Proxy, requestErrorHandle)
	return provider
}
// CohereProvider adapts the Cohere chat API to the provider interface.
// All shared plumbing (config, channel, requester, usage) comes from the
// embedded BaseProvider.
type CohereProvider struct {
	base.BaseProvider
}
// getConfig returns the static provider configuration for Cohere:
// the default API base URL and the chat-completions endpoint path.
func getConfig() base.ProviderConfig {
	var cfg base.ProviderConfig
	cfg.BaseURL = "https://api.cohere.ai/v1"
	cfg.ChatCompletions = "/chat"
	return cfg
}
// requestErrorHandle decodes an error HTTP response body as a Cohere
// error payload and converts it to an OpenAI-style error. It returns
// nil when the body cannot be decoded, leaving generic error handling
// to the caller.
func requestErrorHandle(resp *http.Response) *types.OpenAIError {
	var cohereErr CohereError
	if err := json.NewDecoder(resp.Body).Decode(&cohereErr); err != nil {
		return nil
	}
	return errorHandle(&cohereErr)
}
// errorHandle maps a Cohere error payload to an OpenAI-style error.
// An empty message means "no error" and yields nil.
func errorHandle(cohereErr *CohereError) *types.OpenAIError {
	if cohereErr.Message == "" {
		return nil
	}
	return &types.OpenAIError{
		Message: cohereErr.Message,
		Type:    "Cohere error",
	}
}
// GetRequestHeaders returns the HTTP headers for Cohere requests: the
// shared common headers plus the channel's API key as a Bearer token.
func (p *CohereProvider) GetRequestHeaders() (headers map[string]string) {
	headers = map[string]string{}
	p.CommonRequestHeaders(headers)
	headers["Authorization"] = "Bearer " + p.Channel.Key
	return headers
}
// GetFullRequestURL joins the configured base URL (stripped of any
// trailing slash) with the given endpoint path.
func (p *CohereProvider) GetFullRequestURL(requestURL string) string {
	trimmed := strings.TrimSuffix(p.GetBaseURL(), "/")
	return fmt.Sprintf("%s%s", trimmed, requestURL)
}
// convertRole maps an OpenAI chat role to Cohere's role name.
// Anything other than system/assistant (including "user") maps to USER.
func convertRole(role string) string {
	if role == types.ChatMessageRoleSystem {
		return "SYSTEM"
	}
	if role == types.ChatMessageRoleAssistant {
		return "CHATBOT"
	}
	return "USER"
}

208
providers/cohere/chat.go Normal file
View File

@ -0,0 +1,208 @@
package cohere
import (
"encoding/json"
"fmt"
"net/http"
"one-api/common"
"one-api/common/requester"
"one-api/providers/base"
"one-api/types"
"strings"
"github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream"
)
// CohereStreamHandler carries the per-request state needed while
// translating a Cohere event stream into OpenAI-style chunks.
type CohereStreamHandler struct {
	Usage   *types.Usage                 // provider usage accumulator
	Request *types.ChatCompletionRequest // original request (source of the model name)
}
// CreateChatCompletion performs a non-streaming chat request against the
// Cohere /chat endpoint and converts the reply to the OpenAI format.
func (p *CohereProvider) CreateChatCompletion(request *types.ChatCompletionRequest) (*types.ChatCompletionResponse, *types.OpenAIErrorWithStatusCode) {
	req, errWithCode := p.getChatRequest(request)
	if errWithCode != nil {
		return nil, errWithCode
	}
	defer req.Body.Close()

	// Send the request and decode the JSON body into a Cohere response.
	res := &CohereResponse{}
	if _, errWithCode = p.Requester.SendRequest(req, res, false); errWithCode != nil {
		return nil, errWithCode
	}

	return ConvertToChatOpenai(p, res, request)
}
// CreateChatCompletionStream performs a streaming chat request and
// returns a stream reader that yields OpenAI-style SSE chunks.
func (p *CohereProvider) CreateChatCompletionStream(request *types.ChatCompletionRequest) (requester.StreamReaderInterface[string], *types.OpenAIErrorWithStatusCode) {
	req, errWithCode := p.getChatRequest(request)
	if errWithCode != nil {
		return nil, errWithCode
	}
	defer req.Body.Close()

	resp, errWithCode := p.Requester.SendRequestRaw(req)
	if errWithCode != nil {
		return nil, errWithCode
	}

	handler := &CohereStreamHandler{
		Usage:   p.Usage,
		Request: request,
	}
	// NOTE(review): this call allocates a decoder and immediately discards
	// it — it has no effect and looks like leftover scaffolding. Removing it
	// would also require dropping the aws eventstream import from this file.
	eventstream.NewDecoder()
	return requester.RequestStream(p.Requester, resp, handler.HandlerStream)
}
// getChatRequest builds the signed HTTP request for the Cohere chat
// endpoint from an OpenAI-style chat completion request.
func (p *CohereProvider) getChatRequest(request *types.ChatCompletionRequest) (*http.Request, *types.OpenAIErrorWithStatusCode) {
	url, errWithCode := p.GetSupportedAPIUri(common.RelayModeChatCompletions)
	if errWithCode != nil {
		return nil, errWithCode
	}

	// Resolve the full endpoint URL for this channel.
	fullRequestURL := p.GetFullRequestURL(url)
	if fullRequestURL == "" {
		return nil, common.ErrorWrapper(nil, "invalid_cohere_config", http.StatusInternalServerError)
	}

	headers := p.GetRequestHeaders()
	if request.Stream {
		// Ask the server for server-sent events when streaming.
		headers["Accept"] = "text/event-stream"
	}

	cohereRequest, errWithCode := ConvertFromChatOpenai(request)
	if errWithCode != nil {
		return nil, errWithCode
	}

	req, err := p.Requester.NewRequest(
		http.MethodPost,
		fullRequestURL,
		p.Requester.WithBody(cohereRequest),
		p.Requester.WithHeader(headers),
	)
	if err != nil {
		return nil, common.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError)
	}
	return req, nil
}
// ConvertFromChatOpenai translates an OpenAI chat request into a Cohere
// /chat request: the final message becomes Message and every earlier
// message becomes a ChatHistory entry with a Cohere role name.
func ConvertFromChatOpenai(request *types.ChatCompletionRequest) (*CohereRequest, *types.OpenAIErrorWithStatusCode) {
	cohereRequest := CohereRequest{
		Model:       request.Model,
		MaxTokens:   request.MaxTokens,
		Temperature: request.Temperature,
		Stream:      request.Stream,
		P:           request.TopP,
		// NOTE(review): OpenAI's N is "number of choices" while Cohere's
		// k is a sampling parameter — confirm this mapping is intentional.
		K:                request.N,
		Seed:             request.Seed,
		StopSequences:    request.Stop,
		FrequencyPenalty: request.FrequencyPenalty,
		PresencePenalty:  request.PresencePenalty,
	}

	last := len(request.Messages) - 1
	for i, msg := range request.Messages {
		if i == last {
			cohereRequest.Message = msg.StringContent()
			continue
		}
		cohereRequest.ChatHistory = append(cohereRequest.ChatHistory, ChatHistory{
			Role:    convertRole(msg.Role),
			Message: msg.StringContent(),
		})
	}
	return &cohereRequest, nil
}
// ConvertToChatOpenai converts a non-streaming Cohere response into an
// OpenAI-style chat completion response and records token usage on the
// provider. Returns an error-with-status when the response body carries
// a Cohere API error.
func ConvertToChatOpenai(provider base.ProviderInterface, response *CohereResponse, request *types.ChatCompletionRequest) (openaiResponse *types.ChatCompletionResponse, errWithCode *types.OpenAIErrorWithStatusCode) {
	// Renamed from `error`, which shadowed the builtin error type.
	apiErr := errorHandle(&response.CohereError)
	if apiErr != nil {
		errWithCode = &types.OpenAIErrorWithStatusCode{
			OpenAIError: *apiErr,
			StatusCode:  http.StatusBadRequest,
		}
		return
	}

	// Cohere returns a single completion; map it to choice index 0.
	choice := types.ChatCompletionChoice{
		Index: 0,
		Message: types.ChatCompletionMessage{
			Role:    types.ChatMessageRoleAssistant,
			Content: response.Text,
		},
		FinishReason: types.FinishReasonStop,
	}
	openaiResponse = &types.ChatCompletionResponse{
		ID:      response.GenerationID,
		Object:  "chat.completion",
		Created: common.GetTimestamp(),
		Choices: []types.ChatCompletionChoice{choice},
		Model:   request.Model,
		Usage: &types.Usage{
			PromptTokens: response.Meta.BilledUnits.InputTokens,
		},
	}
	// NOTE(review): this mixes BilledUnits.OutputTokens with
	// Meta.Tokens.SearchUnits/Classifications — confirm the intended
	// source fields against the Cohere billed_units schema.
	openaiResponse.Usage.CompletionTokens = response.Meta.BilledUnits.OutputTokens + response.Meta.Tokens.SearchUnits + response.Meta.Tokens.Classifications
	openaiResponse.Usage.TotalTokens = openaiResponse.Usage.PromptTokens + openaiResponse.Usage.CompletionTokens

	// Mirror the computed usage onto the provider for billing.
	usage := provider.GetUsage()
	*usage = *openaiResponse.Usage

	return openaiResponse, nil
}
// HandlerStream filters one raw line from the Cohere stream and, when it
// is a relevant JSON event, forwards it to dataChan as an OpenAI-style
// chunk.
func (h *CohereStreamHandler) HandlerStream(rawLine *[]byte, dataChan chan string, errChan chan error) {
	// Cohere stream lines are bare JSON objects; skip anything that does
	// not start with '{'. (The original comment mentioned a "data:"
	// prefix, but the check has always been for '{'.)
	if !strings.HasPrefix(string(*rawLine), "{") {
		*rawLine = nil
		return
	}

	var cohereResponse CohereStreamResponse
	if err := json.Unmarshal(*rawLine, &cohereResponse); err != nil {
		errChan <- common.ErrorToOpenAIError(err)
		return
	}

	// Only text deltas and the terminating event are forwarded.
	switch cohereResponse.EventType {
	case "text-generation", "stream-end":
		h.convertToOpenaiStream(&cohereResponse, dataChan)
	default:
		*rawLine = nil
	}
}
// convertToOpenaiStream wraps a single Cohere stream event in an OpenAI
// chat.completion.chunk and pushes its JSON encoding onto dataChan.
func (h *CohereStreamHandler) convertToOpenaiStream(cohereResponse *CohereStreamResponse, dataChan chan string) {
	choice := types.ChatCompletionStreamChoice{Index: 0}
	if cohereResponse.EventType == "stream-end" {
		// Terminating event carries no content, only the finish reason.
		choice.FinishReason = types.FinishReasonStop
	} else {
		choice.Delta = types.ChatCompletionStreamChoiceDelta{
			Role:    types.ChatMessageRoleAssistant,
			Content: cohereResponse.Text,
		}
	}

	chunk := types.ChatCompletionStreamResponse{
		ID:      fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
		Object:  "chat.completion.chunk",
		Created: common.GetTimestamp(),
		Model:   h.Request.Model,
		Choices: []types.ChatCompletionStreamChoice{choice},
	}

	body, _ := json.Marshal(chunk)
	dataChan <- string(body)
}

85
providers/cohere/type.go Normal file
View File

@ -0,0 +1,85 @@
package cohere
import "one-api/types"
// ChatHistory is one prior turn of a Cohere conversation.
type ChatHistory struct {
	Role    string `json:"role"`    // "USER", "CHATBOT" or "SYSTEM" (see convertRole)
	Message string `json:"message"` // plain-text content of the turn
}

// CohereConnector identifies an external connector to enable for a
// chat request. Not populated by this provider's request conversion.
type CohereConnector struct {
	ID                string `json:"id"`
	UserAccessToken   string `json:"user_access_token,omitempty"`
	ContinueOnFailure bool   `json:"continue_on_failure,omitempty"`
	Options           any    `json:"options,omitempty"`
}
// CohereRequest is the JSON body sent to Cohere's /chat endpoint.
// Only Message is always set; the remaining fields are optional and
// omitted from the payload when zero-valued.
type CohereRequest struct {
	Message          string            `json:"message"`      // latest user message (last entry of the OpenAI request)
	Model            string            `json:"model,omitempty"`
	Stream           bool              `json:"stream,omitempty"`
	Preamble         string            `json:"preamble,omitempty"`
	ChatHistory      []ChatHistory     `json:"chat_history,omitempty"` // earlier turns, oldest first
	ConversationId   string            `json:"conversation_id,omitempty"`
	PromptTruncation string            `json:"prompt_truncation,omitempty"`
	Connectors       []CohereConnector `json:"connectors,omitempty"`
	Temperature      float64           `json:"temperature,omitempty"`
	MaxTokens        int               `json:"max_tokens,omitempty"`
	MaxInputTokens   int               `json:"max_input_tokens,omitempty"`
	K                int               `json:"k,omitempty"` // populated from OpenAI N — presumably top-k sampling; confirm
	P                float64           `json:"p,omitempty"` // populated from OpenAI TopP
	Seed             *int              `json:"seed,omitempty"`
	StopSequences    any               `json:"stop_sequences,omitempty"` // passed through from OpenAI Stop
	FrequencyPenalty float64           `json:"frequency_penalty,omitempty"`
	PresencePenalty  float64           `json:"presence_penalty,omitempty"`
	Tools            []*types.ChatCompletionFunction `json:"tools,omitempty"`
	ToolResults      any                             `json:"tool_results,omitempty"`
	// SearchQueriesOnly bool `json:"search_queries_only,omitempty"`
}
// APIVersion reports the API version a Cohere response was served with.
type APIVersion struct {
	Version string `json:"version"`
}

// Tokens holds Cohere token/unit counts; used for both the billed_units
// and tokens sections of the response metadata.
type Tokens struct {
	InputTokens     int `json:"input_tokens"`
	OutputTokens    int `json:"output_tokens"`
	SearchUnits     int `json:"search_units,omitempty"`
	Classifications int `json:"classifications,omitempty"`
}

// Meta is the metadata section of a Cohere response.
type Meta struct {
	APIVersion  APIVersion `json:"api_version"`
	BilledUnits Tokens     `json:"billed_units"` // counts used for usage/billing conversion
	Tokens      Tokens     `json:"tokens"`
}
// CohereToolCall is a tool invocation reported by the model.
type CohereToolCall struct {
	Name       string `json:"name,omitempty"`
	Parameters any    `json:"parameters,omitempty"`
}

// CohereResponse is the body of a non-streaming /chat reply. The
// embedded CohereError lets error payloads decode into the same struct
// (checked via errorHandle).
type CohereResponse struct {
	Text         string           `json:"text,omitempty"` // generated completion text
	ResponseID   string           `json:"response_id,omitempty"`
	GenerationID string           `json:"generation_id,omitempty"` // reused as the OpenAI response ID
	ChatHistory  []ChatHistory    `json:"chat_history,omitempty"`
	FinishReason string           `json:"finish_reason,omitempty"`
	ToolCalls    []CohereToolCall `json:"tool_calls,omitempty"`
	// NOTE(review): `omitempty` has no effect on struct-typed fields.
	Meta Meta `json:"meta,omitempty"`
	CohereError
}

// CohereError is the error payload Cohere returns on failed requests.
type CohereError struct {
	Message string `json:"message,omitempty"`
}
// CohereStreamResponse is one event of a streaming /chat reply. Only
// "text-generation" and "stream-end" events are forwarded by
// HandlerStream; others are dropped.
type CohereStreamResponse struct {
	IsFinished   bool   `json:"is_finished"`
	EventType    string `json:"event_type"` // e.g. "text-generation", "stream-end"
	GenerationID string `json:"generation_id,omitempty"`
	Text         string `json:"text,omitempty"` // delta text for text-generation events
	// NOTE(review): `omitempty` has no effect on struct-typed fields.
	Response     CohereResponse   `json:"response,omitempty"` // full response attached to stream-end
	FinishReason string           `json:"finish_reason,omitempty"`
	ToolCalls    []CohereToolCall `json:"tool_calls,omitempty"`
}

View File

@ -12,6 +12,7 @@ import (
"one-api/providers/bedrock"
"one-api/providers/claude"
"one-api/providers/cloudflareAI"
"one-api/providers/cohere"
"one-api/providers/deepseek"
"one-api/providers/gemini"
"one-api/providers/groq"
@ -56,6 +57,7 @@ func init() {
providerFactories[common.ChannelTypeBedrock] = bedrock.BedrockProviderFactory{}
providerFactories[common.ChannelTypeMidjourney] = midjourney.MidjourneyProviderFactory{}
providerFactories[common.ChannelTypeCloudflareAI] = cloudflareAI.CloudflareAIProviderFactory{}
providerFactories[common.ChannelTypeCohere] = cohere.CohereProviderFactory{}
}

View File

@ -26,5 +26,6 @@ func init() {
common.ChannelTypeLingyi: "Lingyiwanwu",
common.ChannelTypeMidjourney: "Midjourney",
common.ChannelTypeCloudflareAI: "Cloudflare AI",
common.ChannelTypeCohere: "Cohere",
}
}

View File

@ -146,6 +146,13 @@ export const CHANNEL_OPTIONS = {
color: 'orange',
url: ''
},
36: {
key: 36,
text: 'Cohere',
value: 36,
color: 'default',
url: ''
},
24: {
key: 24,
text: 'Azure Speech',

View File

@ -283,6 +283,13 @@ const typeConfig = {
base_url: ''
},
modelGroup: 'Cloudflare AI'
},
36: {
input: {
models: ['command-r', 'command-r-plus'],
test_model: 'command-r'
},
modelGroup: 'Cohere'
}
};