feat: support mistral (#94)

Buer 2024-03-10 01:53:33 +08:00 committed by GitHub
parent d8d880bf85
commit 6329db1a49
10 changed files with 340 additions and 0 deletions


@@ -196,6 +196,7 @@ const (
ChannelTypeMiniMax = 27
ChannelTypeDeepseek = 28
ChannelTypeMoonshot = 29
ChannelTypeMistral = 30
)
var ChannelBaseURLs = []string{
@@ -229,6 +230,7 @@ var ChannelBaseURLs = []string{
"https://api.minimax.chat/v1", //27
"https://api.deepseek.com", //28
"https://api.moonshot.cn", //29
"https://api.mistral.ai", //30
}
const (


@@ -163,6 +163,13 @@ func init() {
"moonshot-v1-8k": {[]float64{0.8572, 0.8572}, ChannelTypeMoonshot}, // ¥0.012 / 1K tokens
"moonshot-v1-32k": {[]float64{1.7143, 1.7143}, ChannelTypeMoonshot}, // ¥0.024 / 1K tokens
"moonshot-v1-128k": {[]float64{4.2857, 4.2857}, ChannelTypeMoonshot}, // ¥0.06 / 1K tokens
"open-mistral-7b": {[]float64{0.125, 0.125}, ChannelTypeMistral}, // 0.25$ / 1M tokens 0.25$ / 1M tokens 0.00025$ / 1k tokens
"open-mixtral-8x7b": {[]float64{0.35, 0.35}, ChannelTypeMistral}, // 0.7$ / 1M tokens 0.7$ / 1M tokens 0.0007$ / 1k tokens
"mistral-small-latest": {[]float64{1, 3}, ChannelTypeMistral}, // 2$ / 1M tokens 6$ / 1M tokens 0.002$ / 1k tokens
"mistral-medium-latest": {[]float64{1.35, 4.05}, ChannelTypeMistral}, // 2.7$ / 1M tokens 8.1$ / 1M tokens 0.0027$ / 1k tokens
"mistral-large-latest": {[]float64{4, 12}, ChannelTypeMistral}, // 8$ / 1M tokens 24$ / 1M tokens 0.008$ / 1k tokens
"mistral-embed": {[]float64{0.05, 0.05}, ChannelTypeMistral}, // 0.1$ / 1M tokens 0.1$ / 1M tokens 0.0001$ / 1k tokens
}
ModelRatio = make(map[string][]float64)
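All six ratios above are consistent with the project's apparent convention of pricing relative to $0.002 per 1K tokens ($2 per 1M): each value is simply the USD price per 1M tokens divided by 2. A minimal sketch of that inferred convention (the helper name is ours, not part of the diff):

// Inferred convention: ratio = USD price per 1M tokens / 2,
// i.e. relative to a $0.002 / 1K-token baseline.
func toRatio(usdPerMillionTokens float64) float64 {
	return usdPerMillionTokens / 2
}

// toRatio(0.25) == 0.125 -> open-mistral-7b ($0.25 / 1M)
// toRatio(24) == 12      -> mistral-large-latest output ($24 / 1M)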


@@ -58,6 +58,7 @@ func init() {
common.ChannelTypeMiniMax: "MiniMax",
common.ChannelTypeDeepseek: "Deepseek",
common.ChannelTypeMoonshot: "Moonshot",
common.ChannelTypeMistral: "Mistral",
}
}

providers/mistral/base.go Normal file

@@ -0,0 +1,77 @@
package mistral
import (
"encoding/json"
"fmt"
"net/http"
"one-api/common/requester"
"one-api/model"
"one-api/types"
"one-api/providers/base"
)
type MistralProviderFactory struct{}
type MistralProvider struct {
base.BaseProvider
}
// Create a MistralProvider for the given channel
func (f MistralProviderFactory) Create(channel *model.Channel) base.ProviderInterface {
return CreateMistralProvider(channel, "https://api.mistral.ai")
}
// Build a MistralProvider against the given base URL
func CreateMistralProvider(channel *model.Channel, baseURL string) *MistralProvider {
config := getMistralConfig(baseURL)
return &MistralProvider{
BaseProvider: base.BaseProvider{
Config: config,
Channel: channel,
Requester: requester.NewHTTPRequester(*channel.Proxy, RequestErrorHandle),
},
}
}
func getMistralConfig(baseURL string) base.ProviderConfig {
return base.ProviderConfig{
BaseURL: baseURL,
ChatCompletions: "/v1/chat/completions",
Embeddings: "/v1/embeddings",
}
}
// Handle HTTP error responses from the Mistral API
func RequestErrorHandle(resp *http.Response) *types.OpenAIError {
errorResponse := &MistralError{}
err := json.NewDecoder(resp.Body).Decode(errorResponse)
if err != nil {
return nil
}
return errorHandle(errorResponse)
}
// Convert a Mistral error payload into an OpenAI-style error
func errorHandle(mistralError *MistralError) *types.OpenAIError {
if mistralError.Object != "error" {
return nil
}
// Guard against an empty detail list to avoid an index-out-of-range panic
message := mistralError.Type
if len(mistralError.Message.Detail) > 0 {
message = mistralError.Message.Detail[0].errorMsg()
}
return &types.OpenAIError{
Message: message,
Type: mistralError.Type,
}
}
// Build the request headers
func (p *MistralProvider) GetRequestHeaders() (headers map[string]string) {
headers = make(map[string]string)
p.CommonRequestHeaders(headers)
headers["Authorization"] = fmt.Sprintf("Bearer %s", p.Channel.Key)
return headers
}
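For illustration, here is a hypothetical error payload shaped to match the MistralError structs defined in type.go, together with the OpenAI-style error the two handlers above would derive from it (the JSON is an assumption, not captured from Mistral's API):

// Hypothetical response body (shape assumed from the type.go structs):
//   {"object": "error", "type": "invalid_request_error",
//    "message": {"detail": [{"type": "missing",
//                            "loc": ["body", "messages"],
//                            "msg": "Field required"}]}}
//
// RequestErrorHandle decodes it, then errorHandle maps it to:
//   &types.OpenAIError{
//       Message: `Loc: ["body","messages"], Msg: Field required`,
//       Type:    "invalid_request_error",
//   }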

providers/mistral/chat.go Normal file

@@ -0,0 +1,137 @@
package mistral
import (
"encoding/json"
"io"
"net/http"
"one-api/common"
"one-api/common/requester"
"one-api/types"
"strings"
)
type mistralStreamHandler struct {
Usage *types.Usage
Request *types.ChatCompletionRequest
}
func (p *MistralProvider) CreateChatCompletion(request *types.ChatCompletionRequest) (*types.ChatCompletionResponse, *types.OpenAIErrorWithStatusCode) {
req, errWithCode := p.getChatRequest(request)
if errWithCode != nil {
return nil, errWithCode
}
defer req.Body.Close()
response := &types.ChatCompletionResponse{}
// Send the request
_, errWithCode = p.Requester.SendRequest(req, response, false)
if errWithCode != nil {
return nil, errWithCode
}
*p.Usage = *response.Usage
return response, nil
}
func (p *MistralProvider) CreateChatCompletionStream(request *types.ChatCompletionRequest) (requester.StreamReaderInterface[string], *types.OpenAIErrorWithStatusCode) {
req, errWithCode := p.getChatRequest(request)
if errWithCode != nil {
return nil, errWithCode
}
defer req.Body.Close()
// Send the request
resp, errWithCode := p.Requester.SendRequestRaw(req)
if errWithCode != nil {
return nil, errWithCode
}
chatHandler := &mistralStreamHandler{
Usage: p.Usage,
Request: request,
}
return requester.RequestStream[string](p.Requester, resp, chatHandler.handlerStream)
}
func (p *MistralProvider) getChatRequest(request *types.ChatCompletionRequest) (*http.Request, *types.OpenAIErrorWithStatusCode) {
url, errWithCode := p.GetSupportedAPIUri(common.RelayModeChatCompletions)
if errWithCode != nil {
return nil, errWithCode
}
// Build the full request URL
fullRequestURL := p.GetFullRequestURL(url, request.Model)
// Build the request headers
headers := p.GetRequestHeaders()
mistralRequest := convertFromChatOpenai(request)
// Create the HTTP request
req, err := p.Requester.NewRequest(http.MethodPost, fullRequestURL, p.Requester.WithBody(mistralRequest), p.Requester.WithHeader(headers))
if err != nil {
return nil, common.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError)
}
return req, nil
}
func convertFromChatOpenai(request *types.ChatCompletionRequest) *MistralChatCompletionRequest {
mistralRequest := &MistralChatCompletionRequest{
Model: request.Model,
Messages: make([]types.ChatCompletionMessage, 0, len(request.Messages)),
Temperature: request.Temperature,
MaxTokens: request.MaxTokens,
TopP: request.TopP,
N: request.N,
Stream: request.Stream,
Seed: request.Seed,
}
for _, message := range request.Messages {
mistralRequest.Messages = append(mistralRequest.Messages, types.ChatCompletionMessage{
Role: message.Role,
Content: message.StringContent(),
})
}
if request.Tools != nil {
mistralRequest.Tools = request.Tools
mistralRequest.ToolChoice = "auto"
}
return mistralRequest
}
// Handle a Mistral SSE chunk and re-emit it as an OpenAI-style stream response
func (h *mistralStreamHandler) handlerStream(rawLine *[]byte, dataChan chan string, errChan chan error) {
if !strings.HasPrefix(string(*rawLine), "data: ") {
*rawLine = nil
return
}
*rawLine = (*rawLine)[6:]
if string(*rawLine) == "[DONE]" {
errChan <- io.EOF
*rawLine = requester.StreamClosed
return
}
mistralResponse := &ChatCompletionStreamResponse{}
err := json.Unmarshal(*rawLine, mistralResponse)
if err != nil {
errChan <- common.ErrorToOpenAIError(err)
return
}
if mistralResponse.Usage != nil {
*h.Usage = *mistralResponse.Usage
}
responseBody, _ := json.Marshal(mistralResponse.ChatCompletionStreamResponse)
dataChan <- string(responseBody)
}
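To make the handler's branches concrete, here is how a representative SSE line would flow through handlerStream (the chunk payload is illustrative, not captured from the API):

// Incoming line (illustrative):
//   data: {"id":"cmpl-1","object":"chat.completion.chunk","choices":[],"usage":null}
//
// 1. The "data: " prefix matches, so the first six bytes are stripped.
// 2. The remainder is not "[DONE]", so it is unmarshalled into
//    ChatCompletionStreamResponse.
// 3. usage is null here, so h.Usage is left untouched; a final chunk that does
//    carry usage overwrites it, which is how token accounting reaches the caller.
// 4. The embedded OpenAI-style chunk is re-marshalled and pushed onto dataChan.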


@@ -0,0 +1,39 @@
package mistral
import (
"net/http"
"one-api/common"
"one-api/types"
)
func (p *MistralProvider) CreateEmbeddings(request *types.EmbeddingRequest) (*types.EmbeddingResponse, *types.OpenAIErrorWithStatusCode) {
url, errWithCode := p.GetSupportedAPIUri(common.RelayModeEmbeddings)
if errWithCode != nil {
return nil, errWithCode
}
// Build the full request URL
fullRequestURL := p.GetFullRequestURL(url, request.Model)
if fullRequestURL == "" {
return nil, common.ErrorWrapper(nil, "invalid_mistral_config", http.StatusInternalServerError)
}
// Build the request headers
headers := p.GetRequestHeaders()
// Create the HTTP request
req, err := p.Requester.NewRequest(http.MethodPost, fullRequestURL, p.Requester.WithBody(request), p.Requester.WithHeader(headers))
if err != nil {
return nil, common.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError)
}
defer req.Body.Close()
mistralResponse := &types.EmbeddingResponse{}
// Send the request
_, errWithCode = p.Requester.SendRequest(req, mistralResponse, false)
if errWithCode != nil {
return nil, errWithCode
}
return mistralResponse, nil
}
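A minimal usage sketch, assuming channel is an already-configured *model.Channel (its Proxy field must be non-nil, since CreateMistralProvider dereferences it) and that types.EmbeddingRequest follows the usual OpenAI schema with Model and Input fields:

provider := CreateMistralProvider(channel, "https://api.mistral.ai")
response, errWithCode := provider.CreateEmbeddings(&types.EmbeddingRequest{
	Model: "mistral-embed", // billed at ratio 0.05 per the pricing table above
	Input: "Hello, world",
})
if errWithCode != nil {
	// errWithCode wraps an OpenAI-style error plus an HTTP status code
}
_ = response // OpenAI-compatible embedding response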

providers/mistral/type.go Normal file

@@ -0,0 +1,55 @@
package mistral
import (
"encoding/json"
"one-api/types"
)
type MistralChatCompletionRequest struct {
Model string `json:"model" binding:"required"`
Messages []types.ChatCompletionMessage `json:"messages" binding:"required"`
Temperature float64 `json:"temperature,omitempty"` // 0-1
MaxTokens int `json:"max_tokens,omitempty"`
TopP float64 `json:"top_p,omitempty"` // 0-1
N int `json:"n,omitempty"`
Stream bool `json:"stream,omitempty"`
Tools []*types.ChatCompletionTool `json:"tools,omitempty"`
ToolChoice string `json:"tool_choice,omitempty"`
Seed *int `json:"seed,omitempty"`
SafePrompt bool `json:"safe_prompt,omitempty"`
}
type MistralError struct {
Object string `json:"object"`
Type string `json:"type,omitempty"`
Message MistralErrorMessages `json:"message,omitempty"`
}
type MistralErrorMessages struct {
Detail []MistralErrorDetail `json:"detail,omitempty"`
}
type MistralErrorDetail struct {
Type string `json:"type"`
Loc any `json:"loc"`
Msg string `json:"msg"`
Input string `json:"input"`
Ctx any `json:"ctx"`
}
// errorMsg serializes Loc and joins it with Msg into one readable string
func (m *MistralErrorDetail) errorMsg() string {
locStr, _ := json.Marshal(m.Loc)
return "Loc: " + string(locStr) + ", Msg: " + m.Msg
}
type ChatCompletionStreamResponse struct {
types.ChatCompletionStreamResponse
Usage *types.Usage `json:"usage,omitempty"`
}


@@ -18,6 +18,7 @@ import (
"one-api/providers/deepseek"
"one-api/providers/gemini"
"one-api/providers/minimax"
"one-api/providers/mistral"
"one-api/providers/openai"
"one-api/providers/openaisb"
"one-api/providers/palm"
@@ -58,6 +59,7 @@ func init() {
providerFactories[common.ChannelTypeBaichuan] = baichuan.BaichuanProviderFactory{}
providerFactories[common.ChannelTypeMiniMax] = minimax.MiniMaxProviderFactory{}
providerFactories[common.ChannelTypeDeepseek] = deepseek.DeepseekProviderFactory{}
providerFactories[common.ChannelTypeMistral] = mistral.MistralProviderFactory{}
}
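At relay time the registry above is presumably consulted by channel type. A sketch of that dispatch with the lookup written out by hand (only providerFactories and the factories' Create methods appear in this diff; using channel.Type as the key is an assumption):

// Hypothetical dispatch; channel.Type is assumed to hold a ChannelType* constant.
if factory, ok := providerFactories[channel.Type]; ok {
	provider := factory.Create(channel) // mistral.MistralProviderFactory for type 30
	_ = provider                        // provider satisfies base.ProviderInterface
}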


@@ -89,6 +89,12 @@ export const CHANNEL_OPTIONS = {
value: 29,
color: 'default'
},
30: {
key: 30,
text: 'Mistral',
value: 30,
color: 'orange'
},
24: {
key: 24,
text: 'Azure Speech',


@@ -181,6 +181,20 @@ const typeConfig = {
test_model: 'moonshot-v1-8k'
},
modelGroup: 'Moonshot'
},
30: {
input: {
models: [
'open-mistral-7b',
'open-mixtral-8x7b',
'mistral-small-latest',
'mistral-medium-latest',
'mistral-large-latest',
'mistral-embed'
],
test_model: 'open-mistral-7b'
},
modelGroup: 'Mistral'
}
};