From 05347bc9a153936a781b83cb7f1c60e4afb311ae Mon Sep 17 00:00:00 2001 From: Buer <42402987+MartialBE@users.noreply.github.com> Date: Wed, 13 Mar 2024 15:36:31 +0800 Subject: [PATCH] =?UTF-8?q?=E2=9C=A8=20feat:=20support=20Groq=20(#107)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- common/constants.go | 2 + common/model-ratio.go | 8 +++ controller/model.go | 1 + providers/groq/base.go | 35 ++++++++++++ providers/groq/chat.go | 82 +++++++++++++++++++++++++++ providers/providers.go | 2 + types/chat.go | 3 +- web/src/constants/ChannelConstants.js | 6 ++ web/src/views/Channel/type/Config.js | 7 +++ 9 files changed, 145 insertions(+), 1 deletion(-) create mode 100644 providers/groq/base.go create mode 100644 providers/groq/chat.go diff --git a/common/constants.go b/common/constants.go index 63566ecc..88b007d2 100644 --- a/common/constants.go +++ b/common/constants.go @@ -197,6 +197,7 @@ const ( ChannelTypeDeepseek = 28 ChannelTypeMoonshot = 29 ChannelTypeMistral = 30 + ChannelTypeGroq = 31 ) var ChannelBaseURLs = []string{ @@ -231,6 +232,7 @@ var ChannelBaseURLs = []string{ "https://api.deepseek.com", //28 "https://api.moonshot.cn", //29 "https://api.mistral.ai", //30 + "https://api.groq.com/openai", //31 } const ( diff --git a/common/model-ratio.go b/common/model-ratio.go index 8db6920f..a84e9803 100644 --- a/common/model-ratio.go +++ b/common/model-ratio.go @@ -183,6 +183,14 @@ func init() { "mistral-medium-latest": {[]float64{1.35, 4.05}, ChannelTypeMistral}, // 2.7$ / 1M tokens 8.1$ / 1M tokens 0.0027$ / 1k tokens "mistral-large-latest": {[]float64{4, 12}, ChannelTypeMistral}, // 8$ / 1M tokens 24$ / 1M tokens 0.008$ / 1k tokens "mistral-embed": {[]float64{0.05, 0.05}, ChannelTypeMistral}, // 0.1$ / 1M tokens 0.1$ / 1M tokens 0.0001$ / 1k tokens + + // $0.70/$0.80 /1M Tokens 0.0007$ / 1k tokens + "llama2-70b-4096": {[]float64{0.35, 0.4}, ChannelTypeGroq}, + // $0.10/$0.10 /1M Tokens 0.0001$ / 1k tokens + 
"llama2-7b-2048": {[]float64{0.05, 0.05}, ChannelTypeGroq}, + "gemma-7b-it": {[]float64{0.05, 0.05}, ChannelTypeGroq}, + // $0.27/$0.27 /1M Tokens 0.00027$ / 1k tokens + "mixtral-8x7b-32768": {[]float64{0.135, 0.135}, ChannelTypeGroq}, } ModelRatio = make(map[string][]float64) diff --git a/controller/model.go b/controller/model.go index b345e58f..0fc90618 100644 --- a/controller/model.go +++ b/controller/model.go @@ -59,6 +59,7 @@ func init() { common.ChannelTypeDeepseek: "Deepseek", common.ChannelTypeMoonshot: "Moonshot", common.ChannelTypeMistral: "Mistral", + common.ChannelTypeGroq: "Groq", } } diff --git a/providers/groq/base.go b/providers/groq/base.go new file mode 100644 index 00000000..0152b8fd --- /dev/null +++ b/providers/groq/base.go @@ -0,0 +1,35 @@ +package groq + +import ( + "one-api/common/requester" + "one-api/model" + "one-api/providers/base" + "one-api/providers/openai" +) + +// 定义供应商工厂 +type GroqProviderFactory struct{} + +// 创建 GroqProvider +func (f GroqProviderFactory) Create(channel *model.Channel) base.ProviderInterface { + return &GroqProvider{ + OpenAIProvider: openai.OpenAIProvider{ + BaseProvider: base.BaseProvider{ + Config: getConfig(), + Channel: channel, + Requester: requester.NewHTTPRequester(*channel.Proxy, openai.RequestErrorHandle), + }, + }, + } +} + +func getConfig() base.ProviderConfig { + return base.ProviderConfig{ + BaseURL: "https://api.groq.com/openai", + ChatCompletions: "/v1/chat/completions", + } +} + +type GroqProvider struct { + openai.OpenAIProvider +} diff --git a/providers/groq/chat.go b/providers/groq/chat.go new file mode 100644 index 00000000..26eb008a --- /dev/null +++ b/providers/groq/chat.go @@ -0,0 +1,82 @@ +package groq + +import ( + "net/http" + "one-api/common" + "one-api/common/requester" + "one-api/providers/openai" + "one-api/types" +) + +func (p *GroqProvider) CreateChatCompletion(request *types.ChatCompletionRequest) (openaiResponse *types.ChatCompletionResponse, errWithCode 
*types.OpenAIErrorWithStatusCode) { + p.getChatRequestBody(request) + + req, errWithCode := p.GetRequestTextBody(common.RelayModeChatCompletions, request.Model, request) + if errWithCode != nil { + return nil, errWithCode + } + defer req.Body.Close() + + response := &openai.OpenAIProviderChatResponse{} + // 发送请求 + _, errWithCode = p.Requester.SendRequest(req, response, false) + if errWithCode != nil { + return nil, errWithCode + } + + // 检测是否错误 + openaiErr := openai.ErrorHandle(&response.OpenAIErrorResponse) + if openaiErr != nil { + errWithCode = &types.OpenAIErrorWithStatusCode{ + OpenAIError: *openaiErr, + StatusCode: http.StatusBadRequest, + } + return nil, errWithCode + } + + *p.Usage = *response.Usage + + return &response.ChatCompletionResponse, nil +} + +func (p *GroqProvider) CreateChatCompletionStream(request *types.ChatCompletionRequest) (requester.StreamReaderInterface[string], *types.OpenAIErrorWithStatusCode) { + p.getChatRequestBody(request) + req, errWithCode := p.GetRequestTextBody(common.RelayModeChatCompletions, request.Model, request) + if errWithCode != nil { + return nil, errWithCode + } + defer req.Body.Close() + + // 发送请求 + resp, errWithCode := p.Requester.SendRequestRaw(req) + if errWithCode != nil { + return nil, errWithCode + } + + chatHandler := openai.OpenAIStreamHandler{ + Usage: p.Usage, + ModelName: request.Model, + } + + return requester.RequestStream[string](p.Requester, resp, chatHandler.HandlerChatStream) +} + +// 获取聊天请求体 +func (p *GroqProvider) getChatRequestBody(request *types.ChatCompletionRequest) { + if request.Tools != nil { + request.Tools = nil + } + + if request.ToolChoice != nil { + request.ToolChoice = nil + } + + if request.ResponseFormat != nil { + request.ResponseFormat = nil + } + + if request.N > 1 { + request.N = 1 + } + +} diff --git a/providers/providers.go b/providers/providers.go index 5407cb8c..4c1c11fe 100644 --- a/providers/providers.go +++ b/providers/providers.go @@ -17,6 +17,7 @@ import ( 
"one-api/providers/closeai" "one-api/providers/deepseek" "one-api/providers/gemini" + "one-api/providers/groq" "one-api/providers/minimax" "one-api/providers/mistral" "one-api/providers/openai" @@ -60,6 +61,7 @@ func init() { providerFactories[common.ChannelTypeMiniMax] = minimax.MiniMaxProviderFactory{} providerFactories[common.ChannelTypeDeepseek] = deepseek.DeepseekProviderFactory{} providerFactories[common.ChannelTypeMistral] = mistral.MistralProviderFactory{} + providerFactories[common.ChannelTypeGroq] = groq.GroqProviderFactory{} } diff --git a/types/chat.go b/types/chat.go index 63e2515a..718b47a0 100644 --- a/types/chat.go +++ b/types/chat.go @@ -140,7 +140,7 @@ type ChatCompletionRequest struct { Seed *int `json:"seed,omitempty"` FrequencyPenalty float64 `json:"frequency_penalty,omitempty"` LogitBias any `json:"logit_bias,omitempty"` - LogProbs bool `json:"logprobs,omitempty"` + LogProbs *bool `json:"logprobs,omitempty"` TopLogProbs int `json:"top_logprobs,omitempty"` User string `json:"user,omitempty"` Functions []*ChatCompletionFunction `json:"functions,omitempty"` @@ -172,6 +172,7 @@ type ChatCompletionTool struct { type ChatCompletionChoice struct { Index int `json:"index"` Message ChatCompletionMessage `json:"message"` + LogProbs any `json:"logprobs,omitempty"` FinishReason any `json:"finish_reason,omitempty"` ContentFilterResults any `json:"content_filter_results,omitempty"` FinishDetails any `json:"finish_details,omitempty"` diff --git a/web/src/constants/ChannelConstants.js b/web/src/constants/ChannelConstants.js index 04074c29..f4f70298 100644 --- a/web/src/constants/ChannelConstants.js +++ b/web/src/constants/ChannelConstants.js @@ -95,6 +95,12 @@ export const CHANNEL_OPTIONS = { value: 30, color: 'orange' }, + 31: { + key: 31, + text: 'Groq', + value: 31, + color: 'primary' + }, 24: { key: 24, text: 'Azure Speech', diff --git a/web/src/views/Channel/type/Config.js b/web/src/views/Channel/type/Config.js index 88c4a3c1..813fcf3d 100644 --- 
a/web/src/views/Channel/type/Config.js +++ b/web/src/views/Channel/type/Config.js @@ -195,6 +195,13 @@ const typeConfig = { test_model: 'open-mistral-7b' }, modelGroup: 'Mistral' + }, + 31: { + input: { + models: ['llama2-7b-2048', 'llama2-70b-4096', 'mixtral-8x7b-32768', 'gemma-7b-it'], + test_model: 'llama2-7b-2048' + }, + modelGroup: 'Groq' } };