diff --git a/common/constants.go b/common/constants.go index d6f9ddb2..63566ecc 100644 --- a/common/constants.go +++ b/common/constants.go @@ -196,6 +196,7 @@ const ( ChannelTypeMiniMax = 27 ChannelTypeDeepseek = 28 ChannelTypeMoonshot = 29 + ChannelTypeMistral = 30 ) var ChannelBaseURLs = []string{ @@ -229,6 +230,7 @@ var ChannelBaseURLs = []string{ "https://api.minimax.chat/v1", //27 "https://api.deepseek.com", //28 "https://api.moonshot.cn", //29 + "https://api.mistral.ai", //30 } const ( diff --git a/common/model-ratio.go b/common/model-ratio.go index b6334059..07916eef 100644 --- a/common/model-ratio.go +++ b/common/model-ratio.go @@ -163,6 +163,13 @@ func init() { "moonshot-v1-8k": {[]float64{0.8572, 0.8572}, ChannelTypeMoonshot}, // ¥0.012 / 1K tokens "moonshot-v1-32k": {[]float64{1.7143, 1.7143}, ChannelTypeMoonshot}, // ¥0.024 / 1K tokens "moonshot-v1-128k": {[]float64{4.2857, 4.2857}, ChannelTypeMoonshot}, // ¥0.06 / 1K tokens + + "open-mistral-7b": {[]float64{0.125, 0.125}, ChannelTypeMistral}, // 0.25$ / 1M tokens 0.25$ / 1M tokens 0.00025$ / 1k tokens + "open-mixtral-8x7b": {[]float64{0.35, 0.35}, ChannelTypeMistral}, // 0.7$ / 1M tokens 0.7$ / 1M tokens 0.0007$ / 1k tokens + "mistral-small-latest": {[]float64{1, 3}, ChannelTypeMistral}, // 2$ / 1M tokens 6$ / 1M tokens 0.002$ / 1k tokens + "mistral-medium-latest": {[]float64{1.35, 4.05}, ChannelTypeMistral}, // 2.7$ / 1M tokens 8.1$ / 1M tokens 0.0027$ / 1k tokens + "mistral-large-latest": {[]float64{4, 12}, ChannelTypeMistral}, // 8$ / 1M tokens 24$ / 1M tokens 0.008$ / 1k tokens + "mistral-embed": {[]float64{0.05, 0.05}, ChannelTypeMistral}, // 0.1$ / 1M tokens 0.1$ / 1M tokens 0.0001$ / 1k tokens } ModelRatio = make(map[string][]float64) diff --git a/controller/model.go b/controller/model.go index 2803bafe..b345e58f 100644 --- a/controller/model.go +++ b/controller/model.go @@ -58,6 +58,7 @@ func init() { common.ChannelTypeMiniMax: "MiniMax", common.ChannelTypeDeepseek: "Deepseek", 
common.ChannelTypeMoonshot: "Moonshot", + common.ChannelTypeMistral: "Mistral", } } diff --git a/providers/mistral/base.go b/providers/mistral/base.go new file mode 100644 index 00000000..66a8c234 --- /dev/null +++ b/providers/mistral/base.go @@ -0,0 +1,81 @@ +package mistral + +import ( + "encoding/json" + "fmt" + "net/http" + "one-api/common/requester" + "one-api/model" + "one-api/types" + + "one-api/providers/base" +) + +type MistralProviderFactory struct{} + +type MistralProvider struct { + base.BaseProvider +} + +// 创建 MistralProvider +func (f MistralProviderFactory) Create(channel *model.Channel) base.ProviderInterface { + MistralProvider := CreateMistralProvider(channel, "https://api.mistral.ai") + return MistralProvider +} + +// 创建 MistralProvider +func CreateMistralProvider(channel *model.Channel, baseURL string) *MistralProvider { + config := getMistralConfig(baseURL) + + return &MistralProvider{ + BaseProvider: base.BaseProvider{ + Config: config, + Channel: channel, + Requester: requester.NewHTTPRequester(*channel.Proxy, RequestErrorHandle), + }, + } +} + +func getMistralConfig(baseURL string) base.ProviderConfig { + return base.ProviderConfig{ + BaseURL: baseURL, + ChatCompletions: "/v1/chat/completions", + Embeddings: "/v1/embeddings", + } +} + +// 请求错误处理 +func RequestErrorHandle(resp *http.Response) *types.OpenAIError { + errorResponse := &MistralError{} + err := json.NewDecoder(resp.Body).Decode(errorResponse) + if err != nil { + return nil + } + + return errorHandle(errorResponse) +} + +// 错误处理 +func errorHandle(MistralError *MistralError) *types.OpenAIError { + if MistralError.Object != "error" { + return nil + } + message := MistralError.Type + if len(MistralError.Message.Detail) > 0 { + message = MistralError.Message.Detail[0].errorMsg() + } + return &types.OpenAIError{ + Message: message, + Type: MistralError.Type, + } +} + +// 获取请求头 +func (p *MistralProvider) GetRequestHeaders() (headers map[string]string) { + headers = make(map[string]string) + p.CommonRequestHeaders(headers) + + headers["Authorization"] = fmt.Sprintf("Bearer %s",
p.Channel.Key) + + return headers +} diff --git a/providers/mistral/chat.go b/providers/mistral/chat.go new file mode 100644 index 00000000..63b62427 --- /dev/null +++ b/providers/mistral/chat.go @@ -0,0 +1,139 @@ +package mistral + +import ( + "encoding/json" + "io" + "net/http" + "one-api/common" + "one-api/common/requester" + "one-api/types" + "strings" +) + +type mistralStreamHandler struct { + Usage *types.Usage + Request *types.ChatCompletionRequest +} + +func (p *MistralProvider) CreateChatCompletion(request *types.ChatCompletionRequest) (*types.ChatCompletionResponse, *types.OpenAIErrorWithStatusCode) { + req, errWithCode := p.getChatRequest(request) + if errWithCode != nil { + return nil, errWithCode + } + defer req.Body.Close() + + response := &types.ChatCompletionResponse{} + // 发送请求 + _, errWithCode = p.Requester.SendRequest(req, response, false) + if errWithCode != nil { + return nil, errWithCode + } + + if response.Usage != nil { + *p.Usage = *response.Usage + } + + return response, nil +} + +func (p *MistralProvider) CreateChatCompletionStream(request *types.ChatCompletionRequest) (requester.StreamReaderInterface[string], *types.OpenAIErrorWithStatusCode) { + req, errWithCode := p.getChatRequest(request) + if errWithCode != nil { + return nil, errWithCode + } + defer req.Body.Close() + + // 发送请求 + resp, errWithCode := p.Requester.SendRequestRaw(req) + if errWithCode != nil { + return nil, errWithCode + } + + chatHandler := &mistralStreamHandler{ + Usage: p.Usage, + Request: request, + } + + return requester.RequestStream[string](p.Requester, resp, chatHandler.handlerStream) } + +func (p *MistralProvider) getChatRequest(request *types.ChatCompletionRequest) (*http.Request, *types.OpenAIErrorWithStatusCode) { + url, errWithCode := p.GetSupportedAPIUri(common.RelayModeChatCompletions) + if errWithCode != nil { + return nil, errWithCode + } + + // 获取请求地址 + fullRequestURL := p.GetFullRequestURL(url, request.Model) + + // 获取请求头 + headers := p.GetRequestHeaders() + + mistralRequest :=
convertFromChatOpenai(request) + + // 创建请求 + req, err := p.Requester.NewRequest(http.MethodPost, fullRequestURL, p.Requester.WithBody(mistralRequest), p.Requester.WithHeader(headers)) + if err != nil { + return nil, common.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError) + } + + return req, nil +} + +func convertFromChatOpenai(request *types.ChatCompletionRequest) *MistralChatCompletionRequest { + mistralRequest := &MistralChatCompletionRequest{ + Model: request.Model, + Messages: make([]types.ChatCompletionMessage, 0, len(request.Messages)), + Temperature: request.Temperature, + MaxTokens: request.MaxTokens, + TopP: request.TopP, + N: request.N, + Stream: request.Stream, + Seed: request.Seed, + } + + for _, message := range request.Messages { + mistralRequest.Messages = append(mistralRequest.Messages, types.ChatCompletionMessage{ + Role: message.Role, + Content: message.StringContent(), + }) + } + + if request.Tools != nil { + mistralRequest.Tools = request.Tools + mistralRequest.ToolChoice = "auto" + } + + return mistralRequest +} + +// 转换为OpenAI聊天流式请求体 +func (h *mistralStreamHandler) handlerStream(rawLine *[]byte, dataChan chan string, errChan chan error) { + if !strings.HasPrefix(string(*rawLine), "data: ") { + *rawLine = nil + return + } + + *rawLine = (*rawLine)[6:] + + if string(*rawLine) == "[DONE]" { + errChan <- io.EOF + *rawLine = requester.StreamClosed + return + } + + mistralResponse := &ChatCompletionStreamResponse{} + err := json.Unmarshal(*rawLine, mistralResponse) + if err != nil { + errChan <- common.ErrorToOpenAIError(err) + return + } + + if mistralResponse.Usage != nil { + *h.Usage = *mistralResponse.Usage + } + + responseBody, _ := json.Marshal(mistralResponse.ChatCompletionStreamResponse) + dataChan <- string(responseBody) + +} diff --git a/providers/mistral/embeddings.go b/providers/mistral/embeddings.go new file mode 100644 index 00000000..8bcc9fbb --- /dev/null +++ b/providers/mistral/embeddings.go @@ -0,0 +1,39 @@ 
+package mistral + +import ( + "net/http" + "one-api/common" + "one-api/types" +) + +func (p *MistralProvider) CreateEmbeddings(request *types.EmbeddingRequest) (*types.EmbeddingResponse, *types.OpenAIErrorWithStatusCode) { + url, errWithCode := p.GetSupportedAPIUri(common.RelayModeEmbeddings) + if errWithCode != nil { + return nil, errWithCode + } + // 获取请求地址 + fullRequestURL := p.GetFullRequestURL(url, request.Model) + if fullRequestURL == "" { + return nil, common.ErrorWrapper(nil, "invalid_mistral_config", http.StatusInternalServerError) + } + + // 获取请求头 + headers := p.GetRequestHeaders() + + // 创建请求 + req, err := p.Requester.NewRequest(http.MethodPost, fullRequestURL, p.Requester.WithBody(request), p.Requester.WithHeader(headers)) + if err != nil { + return nil, common.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError) + } + defer req.Body.Close() + + mistralResponse := &types.EmbeddingResponse{} + + // 发送请求 + _, errWithCode = p.Requester.SendRequest(req, mistralResponse, false) + if errWithCode != nil { + return nil, errWithCode + } + + return mistralResponse, nil +} diff --git a/providers/mistral/type.go b/providers/mistral/type.go new file mode 100644 index 00000000..17fff033 --- /dev/null +++ b/providers/mistral/type.go @@ -0,0 +1,55 @@ +package mistral + +import ( + "encoding/json" + "one-api/types" +) + +type MistralChatCompletionRequest struct { + Model string `json:"model" binding:"required"` + Messages []types.ChatCompletionMessage `json:"messages" binding:"required"` + Temperature float64 `json:"temperature,omitempty"` // 0-1 + MaxTokens int `json:"max_tokens,omitempty"` + TopP float64 `json:"top_p,omitempty"` // 0-1 + N int `json:"n,omitempty"` + Stream bool `json:"stream,omitempty"` + Tools []*types.ChatCompletionTool `json:"tools,omitempty"` + ToolChoice string `json:"tool_choice,omitempty"` + Seed *int `json:"seed,omitempty"` + SafePrompt bool `json:"safe_prompt,omitempty"` +} + +type MistralError struct { + Object string 
`json:"object"` + Type string `json:"type,omitempty"` + Message MistralErrorMessages `json:"message,omitempty"` +} + +type MistralErrorMessages struct { + Detail []MistralErrorDetail `json:"detail,omitempty"` +} + +type MistralErrorDetail struct { + Type string `json:"type"` + Loc any `json:"loc"` + Msg string `json:"msg"` + Input string `json:"input"` + Ctx any `json:"ctx"` +} + +func (m *MistralErrorDetail) errorMsg() string { + // 循环Loc,拼接成字符串 + // 返回字符串 + var errMsg string + + locStr, _ := json.Marshal(m.Loc) + + errMsg += "Loc:" + string(locStr) + "Msg:" + m.Msg + + return errMsg +} + +type ChatCompletionStreamResponse struct { + types.ChatCompletionStreamResponse + Usage *types.Usage `json:"usage,omitempty"` +} diff --git a/providers/providers.go b/providers/providers.go index 0bce6e6a..5407cb8c 100644 --- a/providers/providers.go +++ b/providers/providers.go @@ -18,6 +18,7 @@ import ( "one-api/providers/deepseek" "one-api/providers/gemini" "one-api/providers/minimax" + "one-api/providers/mistral" "one-api/providers/openai" "one-api/providers/openaisb" "one-api/providers/palm" @@ -58,6 +59,7 @@ func init() { providerFactories[common.ChannelTypeBaichuan] = baichuan.BaichuanProviderFactory{} providerFactories[common.ChannelTypeMiniMax] = minimax.MiniMaxProviderFactory{} providerFactories[common.ChannelTypeDeepseek] = deepseek.DeepseekProviderFactory{} + providerFactories[common.ChannelTypeMistral] = mistral.MistralProviderFactory{} } diff --git a/web/src/constants/ChannelConstants.js b/web/src/constants/ChannelConstants.js index 16831428..04074c29 100644 --- a/web/src/constants/ChannelConstants.js +++ b/web/src/constants/ChannelConstants.js @@ -89,6 +89,12 @@ export const CHANNEL_OPTIONS = { value: 29, color: 'default' }, + 30: { + key: 30, + text: 'Mistral', + value: 30, + color: 'orange' + }, 24: { key: 24, text: 'Azure Speech', diff --git a/web/src/views/Channel/type/Config.js b/web/src/views/Channel/type/Config.js index 6d31a9fc..0454db72 100644 --- 
a/web/src/views/Channel/type/Config.js +++ b/web/src/views/Channel/type/Config.js @@ -181,6 +181,20 @@ const typeConfig = { test_model: 'moonshot-v1-8k' }, modelGroup: 'Moonshot' + }, + 30: { + input: { + models: [ + 'open-mistral-7b', + 'open-mixtral-8x7b', + 'mistral-small-latest', + 'mistral-medium-latest', + 'mistral-large-latest', + 'mistral-embed' + ], + test_model: 'open-mistral-7b' + }, + modelGroup: 'Mistral' } };