diff --git a/controller/relay-ali.go b/controller/relay-ali.go index 9dca9a89..50dc743c 100644 --- a/controller/relay-ali.go +++ b/controller/relay-ali.go @@ -35,6 +35,29 @@ type AliChatRequest struct { Parameters AliParameters `json:"parameters,omitempty"` } +type AliEmbeddingRequest struct { + Model string `json:"model"` + Input struct { + Texts []string `json:"texts"` + } `json:"input"` + Parameters *struct { + TextType string `json:"text_type,omitempty"` + } `json:"parameters,omitempty"` +} + +type AliEmbedding struct { + Embedding []float64 `json:"embedding"` + TextIndex int `json:"text_index"` +} + +type AliEmbeddingResponse struct { + Output struct { + Embeddings []AliEmbedding `json:"embeddings"` + } `json:"output"` + Usage AliUsage `json:"usage"` + AliError +} + type AliError struct { Code string `json:"code"` Message string `json:"message"` @@ -44,6 +67,7 @@ type AliError struct { type AliUsage struct { InputTokens int `json:"input_tokens"` OutputTokens int `json:"output_tokens"` + TotalTokens int `json:"total_tokens"` } type AliOutput struct { @@ -95,6 +119,70 @@ func requestOpenAI2Ali(request GeneralOpenAIRequest) *AliChatRequest { } } +func embeddingRequestOpenAI2Ali(request GeneralOpenAIRequest) *AliEmbeddingRequest { + return &AliEmbeddingRequest{ + Model: "text-embedding-v1", + Input: struct { + Texts []string `json:"texts"` + }{ + Texts: request.ParseInput(), + }, + } +} + +func aliEmbeddingHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) { + var aliResponse AliEmbeddingResponse + err := json.NewDecoder(resp.Body).Decode(&aliResponse) + if err != nil { + return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil + } + + err = resp.Body.Close() + if err != nil { + return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil + } + + if aliResponse.Code != "" { + return &OpenAIErrorWithStatusCode{ + OpenAIError: OpenAIError{ + Message: aliResponse.Message, + Type: aliResponse.Code, + Param: aliResponse.RequestId, + Code: aliResponse.Code, + }, + StatusCode: resp.StatusCode, + }, nil + } + + fullTextResponse := embeddingResponseAli2OpenAI(&aliResponse) + jsonResponse, err := json.Marshal(fullTextResponse) + if err != nil { + return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil + } + c.Writer.Header().Set("Content-Type", "application/json") + c.Writer.WriteHeader(resp.StatusCode) + _, err = c.Writer.Write(jsonResponse) + return nil, &fullTextResponse.Usage +} + +func embeddingResponseAli2OpenAI(response *AliEmbeddingResponse) *OpenAIEmbeddingResponse { + openAIEmbeddingResponse := OpenAIEmbeddingResponse{ + Object: "list", + Data: make([]OpenAIEmbeddingResponseItem, 0, len(response.Output.Embeddings)), + Model: "text-embedding-v1", + Usage: Usage{TotalTokens: response.Usage.TotalTokens}, + } + + for _, item := range response.Output.Embeddings { + openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, OpenAIEmbeddingResponseItem{ + Object: `embedding`, + Index: item.TextIndex, + Embedding: item.Embedding, + }) + } + return &openAIEmbeddingResponse +} + func responseAli2OpenAI(response *AliChatResponse) *OpenAITextResponse { choice := OpenAITextResponseChoice{ Index: 0, diff --git a/controller/relay-baidu.go b/controller/relay-baidu.go index 39f31a9a..ed08ac04 100644 --- a/controller/relay-baidu.go +++ b/controller/relay-baidu.go @@ -144,20 +144,9 @@ func streamResponseBaidu2OpenAI(baiduResponse *BaiduChatStreamResponse) *ChatCom } func embeddingRequestOpenAI2Baidu(request GeneralOpenAIRequest) *BaiduEmbeddingRequest { - baiduEmbeddingRequest := BaiduEmbeddingRequest{ - Input: nil, + return &BaiduEmbeddingRequest{ + Input: request.ParseInput(), } - switch request.Input.(type) { - case string: - baiduEmbeddingRequest.Input = []string{request.Input.(string)} - case []any: - for _, item := range request.Input.([]any) { - if str, ok := item.(string); ok { - baiduEmbeddingRequest.Input = append(baiduEmbeddingRequest.Input, str) - } - } - } - return &baiduEmbeddingRequest } func embeddingResponseBaidu2OpenAI(response *BaiduEmbeddingResponse) *OpenAIEmbeddingResponse { diff --git a/controller/relay-text.go b/controller/relay-text.go index 624b9d01..708f94cb 100644 --- a/controller/relay-text.go +++ b/controller/relay-text.go @@ -171,6 +171,9 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { fullRequestURL = fmt.Sprintf("https://open.bigmodel.cn/api/paas/v3/model-api/%s/%s", textRequest.Model, method) case APITypeAli: fullRequestURL = "https://dashscope.aliyuncs.com/api/v1/services/aigc/text-generation/generation" + if relayMode == RelayModeEmbeddings { + fullRequestURL = "https://dashscope.aliyuncs.com/api/v1/services/embeddings/text-embedding/text-embedding" + } } var promptTokens int var completionTokens int @@ -257,8 +260,16 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { } requestBody = bytes.NewBuffer(jsonStr) case APITypeAli: - aliRequest := requestOpenAI2Ali(textRequest) - jsonStr, err := json.Marshal(aliRequest) + var jsonStr []byte + var err error + switch relayMode { + case RelayModeEmbeddings: + aliEmbeddingRequest := embeddingRequestOpenAI2Ali(textRequest) + jsonStr, err = json.Marshal(aliEmbeddingRequest) + default: + aliRequest := requestOpenAI2Ali(textRequest) + jsonStr, err = json.Marshal(aliRequest) + } if err != nil { return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError) } @@ -488,7 +499,14 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { } return nil } else { - err, usage := aliHandler(c, resp) + var err *OpenAIErrorWithStatusCode + var usage *Usage + switch relayMode { + case RelayModeEmbeddings: + err, usage = aliEmbeddingHandler(c, resp) + default: + err, usage = aliHandler(c, resp) + } if err != nil { return err } diff --git a/controller/relay.go b/controller/relay.go index 056d42d3..d20663f6 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -44,6 +44,25 @@ type GeneralOpenAIRequest struct { Functions any `json:"functions,omitempty"` } +func (r GeneralOpenAIRequest) ParseInput() []string { + if r.Input == nil { + return nil + } + var input []string + switch r.Input.(type) { + case string: + input = []string{r.Input.(string)} + case []any: + input = make([]string, 0, len(r.Input.([]any))) + for _, item := range r.Input.([]any) { + if str, ok := item.(string); ok { + input = append(input, str) + } + } + } + return input +} + type ChatRequest struct { Model string `json:"model"` Messages []Message `json:"messages"`