diff --git a/common/constants.go b/common/constants.go
index 4b9df311..fdafcd2e 100644
--- a/common/constants.go
+++ b/common/constants.go
@@ -173,6 +173,7 @@ const (
     ChannelTypeZhipu   = 16
     ChannelTypeAli     = 17
     ChannelTypeXunfei  = 18
+    ChannelTypeMiniMax = 19
 )
 
 var ChannelBaseURLs = []string{
@@ -195,4 +196,5 @@ var ChannelBaseURLs = []string{
     "https://open.bigmodel.cn",       // 16
     "https://dashscope.aliyuncs.com", // 17
     "",                               // 18
+    "",                               // 19
 }
diff --git a/common/model-ratio.go b/common/model-ratio.go
index 5865b4dc..787690b2 100644
--- a/common/model-ratio.go
+++ b/common/model-ratio.go
@@ -50,6 +50,9 @@ var ModelRatio = map[string]float64{
     "qwen-v1":      0.8572, // TBD: https://help.aliyun.com/document_detail/2399482.html?spm=a2c4g.2399482.0.0.1ad347feilAgag
     "qwen-plus-v1": 0.5715, // Same as above
     "SparkDesk":    0.8572, // TBD
+    "abab5.5-chat": 1.0715, // ¥0.014 / 1k tokens
+    "abab5-chat":   1.0715, // ¥0.014 / 1k tokens
+    "embo-01":      0.75,   // TBD: https://api.minimax.chat/document/price?id=6433f32294878d408fc8293e
 }
 
 func ModelRatio2JSONString() string {
diff --git a/common/validate.go b/common/validate.go
index b3c78591..4e8884b3 100644
--- a/common/validate.go
+++ b/common/validate.go
@@ -1,9 +1,44 @@
 package common
 
-import "github.com/go-playground/validator/v10"
+import (
+    "github.com/go-playground/validator/v10"
+    "reflect"
+)
 
 var Validate *validator.Validate
 
 func init() {
     Validate = validator.New()
+    _ = Validate.RegisterValidation("ValidateEmbeddingInput", validateEmbeddingInput)
+}
+
+// validateEmbeddingInput accepts a string, or a non-empty slice/array whose
+// elements are all strings (including string values held behind interface or
+// pointer elements); everything else is rejected.
+func validateEmbeddingInput(fl validator.FieldLevel) bool {
+    v := fl.Field()
+    var check func(v reflect.Value, mustBe reflect.Kind) bool
+    check = func(v reflect.Value, mustBe reflect.Kind) bool {
+        if mustBe != reflect.Invalid && v.Kind() != mustBe {
+            return false
+        }
+        switch v.Kind() {
+        case reflect.String:
+            return true
+        case reflect.Array, reflect.Slice:
+            if v.Len() == 0 {
+                return false
+            }
+            for i := 0; i < v.Len(); i++ {
+                checkResult := check(v.Index(i), reflect.String)
+                if v.Index(i).Kind() == reflect.Interface || v.Index(i).Kind() == reflect.Ptr {
+                    checkResult = checkResult || check(v.Index(i).Elem(), reflect.String)
+                }
+                if !checkResult {
+                    return false
+                }
+            }
+        default:
+            return false
+        }
+        return true
+    }
+    return check(v, reflect.Invalid)
+}
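For reference, a standalone sketch of what the new `ValidateEmbeddingInput` rule accepts. The `check` logic mirrors the hunk above; the sample inputs and the `main` harness are illustrative only.

```go
package main

import (
	"fmt"
	"reflect"
)

// check mirrors the recursive validator: a valid embedding input is a string,
// or a non-empty slice/array whose elements are all strings (including
// strings wrapped in interface or pointer elements).
func check(v reflect.Value, mustBe reflect.Kind) bool {
	if mustBe != reflect.Invalid && v.Kind() != mustBe {
		return false
	}
	switch v.Kind() {
	case reflect.String:
		return true
	case reflect.Array, reflect.Slice:
		if v.Len() == 0 {
			return false
		}
		for i := 0; i < v.Len(); i++ {
			ok := check(v.Index(i), reflect.String)
			if k := v.Index(i).Kind(); k == reflect.Interface || k == reflect.Ptr {
				ok = ok || check(v.Index(i).Elem(), reflect.String)
			}
			if !ok {
				return false
			}
		}
		return true
	default:
		return false
	}
}

func main() {
	inputs := []any{
		"a single string",  // valid
		[]string{"a", "b"}, // valid
		[]any{"a", "b"},    // valid: interface elements are unwrapped
		[]any{"a", 42},     // invalid: non-string element
		[]string{},         // invalid: empty slice
		12345,              // invalid: neither string nor slice
	}
	for _, input := range inputs {
		fmt.Printf("%-20v -> %v\n", input, check(reflect.ValueOf(input), reflect.Invalid))
	}
}
```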
diff --git a/controller/model.go b/controller/model.go
index c68aa50c..e6900181 100644
--- a/controller/model.go
+++ b/controller/model.go
@@ -360,6 +360,33 @@ func init() {
             Root:       "SparkDesk",
             Parent:     nil,
         },
+        {
+            Id:         "abab5.5-chat",
+            Object:     "model",
+            Created:    1677649963,
+            OwnedBy:    "minimax",
+            Permission: permission,
+            Root:       "abab5.5-chat",
+            Parent:     nil,
+        },
+        {
+            Id:         "abab5-chat",
+            Object:     "model",
+            Created:    1677649963,
+            OwnedBy:    "minimax",
+            Permission: permission,
+            Root:       "abab5-chat",
+            Parent:     nil,
+        },
+        {
+            Id:         "embo-01",
+            Object:     "model",
+            Created:    1677649963,
+            OwnedBy:    "minimax",
+            Permission: permission,
+            Root:       "embo-01",
+            Parent:     nil,
+        },
     }
     openAIModelsMap = make(map[string]OpenAIModels)
     for _, model := range openAIModels {
diff --git a/controller/relay-minimax.go b/controller/relay-minimax.go
new file mode 100644
index 00000000..6d57ded9
--- /dev/null
+++ b/controller/relay-minimax.go
@@ -0,0 +1,326 @@
+package controller
+
+import (
+    "bufio"
+    "encoding/json"
+    "github.com/gin-gonic/gin"
+    "io"
+    "net/http"
+    "one-api/common"
+    "reflect"
+    "strings"
+)
+
+// https://api.minimax.chat/document/guides/chat?id=6433f37294878d408fc82953
+
+type MinimaxError struct {
+    StatusCode int    `json:"status_code"`
+    StatusMsg  string `json:"status_msg"`
+}
+
+type MinimaxChatMessage struct {
+    SenderType string `json:"sender_type,omitempty"` // USER or BOT
+    Text       string `json:"text,omitempty"`
+}
+
+type MinimaxChatRequest struct {
+    Model       string               `json:"model,omitempty"`
+    Stream      bool                 `json:"stream,omitempty"`
+    Prompt      string               `json:"prompt,omitempty"`
+    Messages    []MinimaxChatMessage `json:"messages,omitempty"`
+    Temperature float64              `json:"temperature,omitempty"`
+    TopP        float64              `json:"top_p,omitempty"`
+}
+
+type MinimaxChoice struct {
+    Text         string `json:"text"`
+    Index        int    `json:"index"`
+    FinishReason string `json:"finish_reason"`
+}
+
+type MinimaxStreamChoice struct {
+    Delta        string `json:"delta"`
+    FinishReason string `json:"finish_reason"`
+}
+
+type MinimaxChatResponse struct {
+    Id       string          `json:"id"`
+    Created  int64           `json:"created"`
+    Choices  []MinimaxChoice `json:"choices"`
+    Usage    `json:"usage"`
+    BaseResp MinimaxError `json:"base_resp"`
+}
+
+type MinimaxChatStreamResponse struct {
+    Id       string                `json:"id"`
+    Created  int64                 `json:"created"`
+    Choices  []MinimaxStreamChoice `json:"choices"`
+    Usage    `json:"usage"`
+    BaseResp MinimaxError `json:"base_resp"`
+}
+
+type MinimaxEmbeddingRequest struct {
+    Model string   `json:"model,omitempty"`
+    Texts []string `json:"texts,omitempty"` // upper bound: 4096 tokens
+    // Type must be one of "db" or "query". Since an embedding request by
+    // default "creates an embedding vector representing the input text",
+    // we use "db" to generate the texts' embedding vectors; "query" input
+    // may be supported later.
+    // Refer: https://api.minimax.chat/document/guides/embeddings?id=6464722084cdc277dfaa966a#%E6%8E%A5%E5%8F%A3%E5%8F%82%E6%95%B0%E8%AF%B4%E6%98%8E
+    Type string `json:"type,omitempty"`
+}
+
+type MinimaxEmbeddingResponse struct {
+    Vectors  [][]float64  `json:"vectors"`
+    BaseResp MinimaxError `json:"base_resp"`
+}
+
+func openAIMsgRoleToMinimaxMsgRole(input string) string {
+    if input == "user" {
+        return "USER"
+    }
+    return "BOT"
+}
+
+func requestOpenAI2Minimax(request GeneralOpenAIRequest) *MinimaxChatRequest {
+    messages := make([]MinimaxChatMessage, 0, len(request.Messages))
+    prompt := ""
+    for _, message := range request.Messages {
+        if message.Role == "system" {
+            // System messages have no MiniMax equivalent; fold them into
+            // the request-level prompt.
+            prompt += message.Content
+        } else {
+            messages = append(messages, MinimaxChatMessage{
+                SenderType: openAIMsgRoleToMinimaxMsgRole(message.Role),
+                Text:       message.Content,
+            })
+        }
+    }
+    return &MinimaxChatRequest{
+        Model:       request.Model,
+        Stream:      request.Stream,
+        Messages:    messages,
+        Prompt:      prompt,
+        Temperature: request.Temperature,
+        TopP:        request.TopP,
+    }
+}
+
+func responseMinimaxChat2OpenAI(response *MinimaxChatResponse) *OpenAITextResponse {
+    ans := OpenAITextResponse{
+        Id:      response.Id,
+        Object:  "",
+        Created: response.Created,
+        Choices: make([]OpenAITextResponseChoice, 0, len(response.Choices)),
+        Usage:   response.Usage,
+    }
+    for _, choice := range response.Choices {
+        ans.Choices = append(ans.Choices, OpenAITextResponseChoice{
+            Index: choice.Index,
+            Message: Message{
+                Role:    "assistant",
+                Content: choice.Text,
+            },
+            FinishReason: choice.FinishReason,
+        })
+    }
+    return &ans
+}
+
+func streamResponseMinimaxChat2OpenAI(response *MinimaxChatStreamResponse) *ChatCompletionsStreamResponse {
+    ans := ChatCompletionsStreamResponse{
+        Id:      response.Id,
+        Object:  "chat.completion.chunk",
+        Created: response.Created,
+        Model:   "abab", // e.g. "abab5.5-chat" or "abab5-chat"
+        Choices: make([]ChatCompletionsStreamResponseChoice, 0, len(response.Choices)),
+    }
+    for i := range response.Choices {
+        choice := response.Choices[i]
+        ans.Choices = append(ans.Choices, ChatCompletionsStreamResponseChoice{
+            Delta: struct {
+                Content string `json:"content"`
+            }{
+                Content: choice.Delta,
+            },
+            FinishReason: &choice.FinishReason,
+        })
+    }
+    return &ans
+}
+
+func embeddingRequestOpenAI2Minimax(request GeneralOpenAIRequest) *MinimaxEmbeddingRequest {
+    texts := make([]string, 0, 100)
+    v := reflect.ValueOf(request.Input)
+    switch v.Kind() {
+    case reflect.String:
+        texts = []string{v.Interface().(string)}
+    case reflect.Array, reflect.Slice:
+        for i := 0; i < v.Len(); i++ {
+            texts = append(texts, v.Index(i).Interface().(string))
+        }
+    }
+    ans := MinimaxEmbeddingRequest{
+        Model: request.Model,
+        Texts: texts,
+        Type:  "db",
+    }
+    return &ans
+}
+
+func embeddingResponseMinimax2OpenAI(response *MinimaxEmbeddingResponse) *OpenAIEmbeddingResponse {
+    ans := OpenAIEmbeddingResponse{
+        Object: "list",
+        Data:   make([]OpenAIEmbeddingResponseItem, 0, len(response.Vectors)),
+        Model:  "minimax-embedding",
+    }
+    for i, vector := range response.Vectors {
+        ans.Data = append(ans.Data, OpenAIEmbeddingResponseItem{
+            Object:    "embedding",
+            Index:     i,
+            Embedding: vector,
+        })
+    }
+    return &ans
+}
+
+func minimaxHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
+    var minimaxChatRsp MinimaxChatResponse
+    responseBody, err := io.ReadAll(resp.Body)
+    if err != nil {
+        return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
+    }
+    err = resp.Body.Close()
+    if err != nil {
+        return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
+    }
+    err = json.Unmarshal(responseBody, &minimaxChatRsp)
+    if err != nil {
+        return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
+    }
+    if minimaxChatRsp.BaseResp.StatusMsg != "success" {
+        return &OpenAIErrorWithStatusCode{
+            OpenAIError: OpenAIError{
+                Message: minimaxChatRsp.BaseResp.StatusMsg,
+                Type:    "minimax_error",
+                Param:   "",
+                Code:    minimaxChatRsp.BaseResp.StatusCode,
+            },
+            StatusCode: resp.StatusCode,
+        }, nil
+    }
+    fullTextResponse := responseMinimaxChat2OpenAI(&minimaxChatRsp)
+    jsonResponse, err := json.Marshal(fullTextResponse)
+    if err != nil {
+        return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
+    }
+    c.Writer.Header().Set("Content-Type", "application/json")
+    c.Writer.WriteHeader(resp.StatusCode)
+    _, err = c.Writer.Write(jsonResponse)
+    return nil, &fullTextResponse.Usage
+}
+
+func minimaxStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
+    var usage Usage
+    scanner := bufio.NewScanner(resp.Body)
+    scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
+        if atEOF && len(data) == 0 {
+            return 0, nil, nil
+        }
+        if i := strings.Index(string(data), "\n"); i >= 0 {
+            return i + 1, data[0:i], nil
+        }
+        if atEOF {
+            return len(data), data, nil
+        }
+        return 0, nil, nil
+    })
+    dataChan := make(chan string, 100)
+    go func() {
+        for scanner.Scan() {
+            data := scanner.Text()
+            if len(data) < 6 { // ignore blank lines and malformed events
+                continue
+            }
+            if data[:6] != "data: " {
+                continue
+            }
+            data = data[6:]
+            dataChan <- data
+        }
+        close(dataChan)
+    }()
+
+    c.Writer.Header().Set("Content-Type", "text/event-stream")
+    c.Writer.Header().Set("Cache-Control", "no-cache")
+    c.Writer.Header().Set("Connection", "keep-alive")
+    c.Writer.Header().Set("Transfer-Encoding", "chunked")
+    c.Writer.Header().Set("X-Accel-Buffering", "no")
+    c.Stream(func(w io.Writer) bool {
+        if data, ok := <-dataChan; ok {
+            var minimaxChatStreamRsp MinimaxChatStreamResponse
+            err := json.Unmarshal([]byte(data), &minimaxChatStreamRsp)
+            if err != nil {
+                common.SysError("error unmarshalling stream response: " + err.Error())
+                return true
+            }
+            usage.TotalTokens += minimaxChatStreamRsp.TotalTokens
+            response := streamResponseMinimaxChat2OpenAI(&minimaxChatStreamRsp)
+            jsonResponse, err := json.Marshal(response)
+            if err != nil {
+                common.SysError("error marshalling stream response: " + err.Error())
+                return true
+            }
+            c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
+            return true
+        }
+        return false
+    })
+    err := resp.Body.Close()
+    if err != nil {
+        return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
+    }
+    return nil, &usage
+}
+
+func minimaxEmbeddingHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*OpenAIErrorWithStatusCode, *Usage) {
+    var minimaxEmbeddingRsp MinimaxEmbeddingResponse
+    responseBody, err := io.ReadAll(resp.Body)
+    if err != nil {
+        return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
+    }
+    err = resp.Body.Close()
+    if err != nil {
+        return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
+    }
+    err = json.Unmarshal(responseBody, &minimaxEmbeddingRsp)
+    if err != nil {
+        return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
+    }
+
+    if minimaxEmbeddingRsp.BaseResp.StatusMsg != "success" {
+        return &OpenAIErrorWithStatusCode{
+            OpenAIError: OpenAIError{
+                Message: minimaxEmbeddingRsp.BaseResp.StatusMsg,
+                Type:    "minimax_error",
+                Param:   "",
+                Code:    minimaxEmbeddingRsp.BaseResp.StatusCode,
+            },
+            StatusCode: resp.StatusCode,
+        }, nil
+    }
+    fullTextResponse := embeddingResponseMinimax2OpenAI(&minimaxEmbeddingRsp)
+    fullTextResponse.Usage = Usage{
+        PromptTokens:     promptTokens,
+        CompletionTokens: 0,
+        TotalTokens:      promptTokens,
+    }
+    jsonResponse, err := json.Marshal(fullTextResponse)
+    if err != nil {
+        return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
+    }
+    c.Writer.Header().Set("Content-Type", "application/json")
+    c.Writer.WriteHeader(resp.StatusCode)
+    _, err = c.Writer.Write(jsonResponse)
+    return nil, &fullTextResponse.Usage
+}
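A minimal sketch of the role mapping performed by `requestOpenAI2Minimax` above: OpenAI system messages have no MiniMax counterpart, so they are folded into the request-level `prompt`, while user/assistant turns become `USER`/`BOT` messages. The `Message` and `MinimaxChatMessage` structs here are simplified stand-ins for the project's own types.

```go
package main

import (
	"encoding/json"
	"fmt"
)

type Message struct {
	Role    string `json:"role"`
	Content string `json:"content"`
}

type MinimaxChatMessage struct {
	SenderType string `json:"sender_type,omitempty"`
	Text       string `json:"text,omitempty"`
}

func main() {
	input := []Message{
		{Role: "system", Content: "You are a helpful assistant."},
		{Role: "user", Content: "Hello"},
		{Role: "assistant", Content: "Hi, how can I help?"},
	}
	prompt := ""
	var messages []MinimaxChatMessage
	for _, m := range input {
		if m.Role == "system" {
			// Concatenated into the request-level prompt.
			prompt += m.Content
			continue
		}
		senderType := "BOT"
		if m.Role == "user" {
			senderType = "USER"
		}
		messages = append(messages, MinimaxChatMessage{SenderType: senderType, Text: m.Content})
	}
	out, _ := json.MarshalIndent(map[string]any{"prompt": prompt, "messages": messages}, "", "  ")
	fmt.Println(string(out))
}
```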
diff --git a/controller/relay-text.go b/controller/relay-text.go
index 1bb463fa..a3e28ff7 100644
--- a/controller/relay-text.go
+++ b/controller/relay-text.go
@@ -21,6 +21,7 @@ const (
     APITypeZhipu
     APITypeAli
     APITypeXunfei
+    APITypeMiniMax
 )
 
 var httpClient *http.Client
@@ -62,6 +63,12 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
             return errorWrapper(errors.New("field messages is required"), "required_field_missing", http.StatusBadRequest)
         }
     case RelayModeEmbeddings:
+        if textRequest.Input == nil {
+            return errorWrapper(errors.New("embedding input is nil"), "error_field_input", http.StatusBadRequest)
+        }
+        if err := common.Validate.Struct(textRequest); err != nil {
+            return errorWrapper(err, "error_field_input", http.StatusBadRequest)
+        }
     case RelayModeModerations:
         if textRequest.Input == "" {
             return errorWrapper(errors.New("field input is required"), "required_field_missing", http.StatusBadRequest)
@@ -99,6 +106,8 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
         apiType = APITypeAli
     case common.ChannelTypeXunfei:
         apiType = APITypeXunfei
+    case common.ChannelTypeMiniMax:
+        apiType = APITypeMiniMax
     }
     baseURL := common.ChannelBaseURLs[channelType]
     requestURL := c.Request.URL.String()
@@ -162,6 +171,12 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
         fullRequestURL = fmt.Sprintf("https://open.bigmodel.cn/api/paas/v3/model-api/%s/%s", textRequest.Model, method)
     case APITypeAli:
         fullRequestURL = "https://dashscope.aliyuncs.com/api/v1/services/aigc/text-generation/generation"
+    case APITypeMiniMax:
+        groupId := c.GetString("group_id")
+        fullRequestURL = fmt.Sprintf("https://api.minimax.chat/v1/text/chatcompletion?GroupId=%s", groupId)
+        if relayMode == RelayModeEmbeddings {
+            fullRequestURL = fmt.Sprintf("https://api.minimax.chat/v1/embeddings?GroupId=%s", groupId)
+        }
     }
     var promptTokens int
     var completionTokens int
@@ -172,6 +187,8 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
         promptTokens = countTokenInput(textRequest.Prompt, textRequest.Model)
     case RelayModeModerations:
         promptTokens = countTokenInput(textRequest.Input, textRequest.Model)
+    case RelayModeEmbeddings:
+        promptTokens = countTokenEmbeddingInput(textRequest.Input, textRequest.Model)
     }
     preConsumedTokens := common.PreConsumedQuota
     if textRequest.MaxTokens != 0 {
@@ -250,6 +267,20 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
             return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
         }
         requestBody = bytes.NewBuffer(jsonStr)
+    case APITypeMiniMax:
+        var jsonData []byte
+        var err error
+        if relayMode == RelayModeEmbeddings {
+            minimaxRequest := embeddingRequestOpenAI2Minimax(textRequest)
+            jsonData, err = json.Marshal(minimaxRequest)
+        } else {
+            minimaxRequest := requestOpenAI2Minimax(textRequest)
+            jsonData, err = json.Marshal(minimaxRequest)
+        }
+        if err != nil {
+            return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
+        }
+        requestBody = bytes.NewBuffer(jsonData)
     }
 
     var req *http.Request
@@ -285,6 +316,8 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
         if textRequest.Stream {
             req.Header.Set("X-DashScope-SSE", "enable")
         }
+    case APITypeMiniMax:
+        req.Header.Set("Authorization", "Bearer "+apiKey)
     }
     req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
     req.Header.Set("Accept", c.Request.Header.Get("Accept"))
@@ -502,6 +535,37 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
         } else {
             return errorWrapper(errors.New("xunfei api does not support non-stream mode"), "invalid_api_type", http.StatusBadRequest)
         }
+    case APITypeMiniMax:
+        if isStream {
+            err, usage := minimaxStreamHandler(c, resp)
+            if err != nil {
+                return err
+            }
+            if usage != nil {
+                textResponse.Usage = *usage
+            }
+            // MiniMax's API does not return prompt tokens & completion tokens
+            textResponse.Usage.PromptTokens = textResponse.Usage.TotalTokens
+            return nil
+        } else {
+            var err *OpenAIErrorWithStatusCode
+            var usage *Usage
+            switch relayMode {
+            case RelayModeEmbeddings:
+                err, usage = minimaxEmbeddingHandler(c, resp, promptTokens, textRequest.Model)
+            default:
+                err, usage = minimaxHandler(c, resp)
+            }
+            if err != nil {
+                return err
+            }
+            if usage != nil {
+                textResponse.Usage = *usage
+            }
+            // MiniMax's API does not return prompt tokens & completion tokens
+            textResponse.Usage.PromptTokens = textResponse.Usage.TotalTokens
+            return nil
+        }
     default:
         return errorWrapper(errors.New("unknown api type"), "unknown_api_type", http.StatusInternalServerError)
     }
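The stream path depends on MiniMax's SSE framing. Below is a small sketch of the `data: `-prefix filtering that `minimaxStreamHandler` applies; the handler uses a custom split function, but the effect is plain line splitting, so `bufio.ScanLines` stands in here. The two-event body is hypothetical.

```go
package main

import (
	"bufio"
	"fmt"
	"strings"
)

func main() {
	// Hypothetical SSE body as MiniMax would stream it: one JSON payload
	// per "data: " line, separated by blank lines.
	body := "data: {\"choices\":[{\"delta\":\"Hel\"}]}\n\n" +
		"data: {\"choices\":[{\"delta\":\"lo\",\"finish_reason\":\"stop\"}]}\n\n"
	scanner := bufio.NewScanner(strings.NewReader(body))
	scanner.Split(bufio.ScanLines)
	for scanner.Scan() {
		line := scanner.Text()
		// Same filter as minimaxStreamHandler: skip blank or non-data lines.
		if len(line) < 6 || line[:6] != "data: " {
			continue
		}
		fmt.Println("payload:", line[6:])
	}
}
```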
diff --git a/controller/relay-utils.go b/controller/relay-utils.go
index 3695e119..5622d3e0 100644
--- a/controller/relay-utils.go
+++ b/controller/relay-utils.go
@@ -4,6 +4,7 @@ import (
     "fmt"
     "github.com/pkoukk/tiktoken-go"
     "one-api/common"
+    "reflect"
 )
 
 var stopFinishReason = "stop"
@@ -63,6 +64,22 @@ func countTokenMessages(messages []Message, model string) int {
     return tokenNum
 }
 
+func countTokenEmbeddingInput(input any, model string) int {
+    tokenEncoder := getTokenEncoder(model)
+    v := reflect.ValueOf(input)
+    switch v.Kind() {
+    case reflect.String:
+        return getTokenNum(tokenEncoder, input.(string))
+    case reflect.Array, reflect.Slice:
+        ans := 0
+        for i := 0; i < v.Len(); i++ {
+            ans += countTokenEmbeddingInput(v.Index(i).Interface(), model)
+        }
+        return ans
+    }
+    return 0
+}
+
 func countTokenInput(input any, model string) int {
     switch input.(type) {
     case string:
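A sketch of how `countTokenEmbeddingInput` recurses over string and `[]string` inputs. It assumes the `github.com/pkoukk/tiktoken-go` dependency the repo already imports, and resolves the encoder via `tiktoken.EncodingForModel` instead of the project's cached `getTokenEncoder`; since tiktoken has no encoder for MiniMax models, the counts for `embo-01` are an approximation in any case.

```go
package main

import (
	"fmt"
	"reflect"

	"github.com/pkoukk/tiktoken-go"
)

// countTokens mirrors countTokenEmbeddingInput: a string is encoded directly,
// a slice is summed element by element, anything else counts as zero.
func countTokens(enc *tiktoken.Tiktoken, input any) int {
	v := reflect.ValueOf(input)
	switch v.Kind() {
	case reflect.String:
		return len(enc.Encode(input.(string), nil, nil))
	case reflect.Array, reflect.Slice:
		total := 0
		for i := 0; i < v.Len(); i++ {
			total += countTokens(enc, v.Index(i).Interface())
		}
		return total
	}
	return 0
}

func main() {
	// text-embedding-ada-002 stands in for a model tiktoken knows about.
	enc, err := tiktoken.EncodingForModel("text-embedding-ada-002")
	if err != nil {
		panic(err)
	}
	fmt.Println(countTokens(enc, "hello world"))
	fmt.Println(countTokens(enc, []string{"hello", "world"}))
}
```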
diff --git a/controller/relay.go b/controller/relay.go
index 86f16c45..e71c975a 100644
--- a/controller/relay.go
+++ b/controller/relay.go
@@ -37,7 +37,7 @@ type GeneralOpenAIRequest struct {
     Temperature float64 `json:"temperature,omitempty"`
     TopP        float64 `json:"top_p,omitempty"`
     N           int     `json:"n,omitempty"`
-    Input       any     `json:"input,omitempty"`
+    Input       any     `json:"input,omitempty" validate:"omitempty,ValidateEmbeddingInput"`
     Instruction string  `json:"instruction,omitempty"`
     Size        string  `json:"size,omitempty"`
 }
diff --git a/middleware/distributor.go b/middleware/distributor.go
index 91c00e1a..bcafe01f 100644
--- a/middleware/distributor.go
+++ b/middleware/distributor.go
@@ -110,6 +110,9 @@ func Distribute() func(c *gin.Context) {
         if channel.Type == common.ChannelTypeAzure {
             c.Set("api_version", channel.Other)
         }
+        if channel.Type == common.ChannelTypeMiniMax {
+            c.Set("group_id", channel.Other)
+        }
         c.Next()
     }
 }
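Putting the wiring together: `Distribute()` copies the channel's `Other` field into the request context as `group_id`, and `relayTextHelper` appends it as MiniMax's `GroupId` query parameter while the channel key becomes the bearer token. A sketch with placeholder credentials:

```go
package main

import "fmt"

func main() {
	groupId := "1234567890"    // stored in channel.Other, exposed as "group_id" by Distribute()
	apiKey := "sk-placeholder" // the channel's MiniMax API key

	chatURL := fmt.Sprintf("https://api.minimax.chat/v1/text/chatcompletion?GroupId=%s", groupId)
	embedURL := fmt.Sprintf("https://api.minimax.chat/v1/embeddings?GroupId=%s", groupId)

	fmt.Println("POST", chatURL)
	fmt.Println("POST", embedURL)
	fmt.Println("Authorization: Bearer " + apiKey)
}
```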
api type"), "unknown_api_type", http.StatusInternalServerError) } diff --git a/controller/relay-utils.go b/controller/relay-utils.go index 3695e119..5622d3e0 100644 --- a/controller/relay-utils.go +++ b/controller/relay-utils.go @@ -4,6 +4,7 @@ import ( "fmt" "github.com/pkoukk/tiktoken-go" "one-api/common" + "reflect" ) var stopFinishReason = "stop" @@ -63,6 +64,22 @@ func countTokenMessages(messages []Message, model string) int { return tokenNum } +func countTokenEmbeddingInput(input any, model string) int { + tokenEncoder := getTokenEncoder(model) + v := reflect.ValueOf(input) + switch v.Kind() { + case reflect.String: + return getTokenNum(tokenEncoder, input.(string)) + case reflect.Array, reflect.Slice: + ans := 0 + for i := 0; i < v.Len(); i++ { + ans += countTokenEmbeddingInput(v.Index(i).Interface(), model) + } + return ans + } + return 0 +} + func countTokenInput(input any, model string) int { switch input.(type) { case string: diff --git a/controller/relay.go b/controller/relay.go index 86f16c45..e71c975a 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -37,7 +37,7 @@ type GeneralOpenAIRequest struct { Temperature float64 `json:"temperature,omitempty"` TopP float64 `json:"top_p,omitempty"` N int `json:"n,omitempty"` - Input any `json:"input,omitempty"` + Input any `json:"input,omitempty" validate:"omitempty,ValidateEmbeddingInput"` Instruction string `json:"instruction,omitempty"` Size string `json:"size,omitempty"` } diff --git a/middleware/distributor.go b/middleware/distributor.go index 91c00e1a..bcafe01f 100644 --- a/middleware/distributor.go +++ b/middleware/distributor.go @@ -110,6 +110,9 @@ func Distribute() func(c *gin.Context) { if channel.Type == common.ChannelTypeAzure { c.Set("api_version", channel.Other) } + if channel.Type == common.ChannelTypeMiniMax { + c.Set("group_id", channel.Other) + } c.Next() } } diff --git a/web/src/constants/channel.constants.js b/web/src/constants/channel.constants.js index a17ef374..13b85556 100644 --- a/web/src/constants/channel.constants.js +++ b/web/src/constants/channel.constants.js @@ -7,6 +7,7 @@ export const CHANNEL_OPTIONS = [ { key: 17, text: '阿里通义千问', value: 17, color: 'orange' }, { key: 18, text: '讯飞星火认知', value: 18, color: 'blue' }, { key: 16, text: '智谱 ChatGLM', value: 16, color: 'violet' }, + { key: 19, text: 'MiniMax', value: 19, color: 'rose' }, { key: 8, text: '自定义渠道', value: 8, color: 'pink' }, { key: 2, text: '代理:API2D', value: 2, color: 'blue' }, { key: 5, text: '代理:OpenAI-SB', value: 5, color: 'brown' }, diff --git a/web/src/pages/Channel/EditChannel.js b/web/src/pages/Channel/EditChannel.js index 0d7a4a01..40edc124 100644 --- a/web/src/pages/Channel/EditChannel.js +++ b/web/src/pages/Channel/EditChannel.js @@ -61,6 +61,9 @@ const EditChannel = () => { case 18: localModels = ['SparkDesk']; break; + case 19: + localModels = ["abab5.5-chat", "abab5-chat", "embo-01"]; + break; } setInputs((inputs) => ({ ...inputs, models: localModels })); } @@ -247,6 +250,20 @@ const EditChannel = () => { ) } + { + inputs.type === 19 && ( + + + + ) + }