From 893c9adeb996a02d047656e168e010372c999b7e Mon Sep 17 00:00:00 2001
From: nongqiqin
Date: Tue, 30 Apr 2024 12:29:41 +0800
Subject: [PATCH] feat: support /v1/rerank router

---
 controller/relay.go           |  2 ++
 relay/adaptor/openai/model.go | 32 +++++++++++++++++
 relay/controller/helper.go    | 34 ++++++++++++++++++
 relay/controller/rerank.go    | 64 +++++++++++++++++++++++++++++++++++
 relay/model/rerank.go         | 13 +++++++
 relay/relaymode/define.go     |  1 +
 relay/relaymode/helper.go     |  3 ++
 router/relay.go               |  1 +
 8 files changed, 150 insertions(+)
 create mode 100644 relay/controller/rerank.go
 create mode 100644 relay/model/rerank.go

diff --git a/controller/relay.go b/controller/relay.go
index aba4cd94..376c4200 100644
--- a/controller/relay.go
+++ b/controller/relay.go
@@ -33,6 +33,8 @@ func relayHelper(c *gin.Context, relayMode int) *model.ErrorWithStatusCode {
 		fallthrough
 	case relaymode.AudioTranscription:
 		err = controller.RelayAudioHelper(c, relayMode)
+	case relaymode.Rerank:
+		err = controller.RerankHelper(c, relayMode)
 	default:
 		err = controller.RelayTextHelper(c)
 	}
diff --git a/relay/adaptor/openai/model.go b/relay/adaptor/openai/model.go
index 4c974de4..704538cc 100644
--- a/relay/adaptor/openai/model.go
+++ b/relay/adaptor/openai/model.go
@@ -143,3 +143,35 @@ type CompletionsStreamResponse struct {
 		FinishReason string `json:"finish_reason"`
 	} `json:"choices"`
 }
+
+// Document wraps the text of a single document.
+type Document struct {
+	Text string `json:"text"`
+}
+
+// DocumentResult is one element of RerankResponse.Results.
+type DocumentResult struct {
+	Index    int       `json:"index"`
+	Score    float64   `json:"score"`
+	Document *Document `json:"document,omitempty"`
+}
+
+// RerankRequest is a rerank request body whose documents are Document objects.
+// NOTE(review): relaymodel.RerankRequest declares Documents as []string and
+// ReturnDocuments as *bool; confirm the two shapes are meant to differ.
+type RerankRequest struct {
+	Model           string     `json:"model"`
+	Query           string     `json:"query"`
+	Documents       []Document `json:"documents"`
+	TopN            *int       `json:"top_n,omitempty"`
+	MaxChunksPerDoc *int       `json:"max_chunks_per_doc,omitempty"`
+	ReturnDocuments bool       `json:"return_documents"`
+}
+
+// RerankResponse is a rerank response body: an optional ID, the result list
+// and an optional error string.
+type RerankResponse struct {
+	ID      string           `json:"id,omitempty"`
+	Results []DocumentResult `json:"results"`
+	Error   *string          `json:"error,omitempty"`
+}
diff --git a/relay/controller/helper.go b/relay/controller/helper.go
index dccff486..2a39f0de 100644
--- a/relay/controller/helper.go
+++ b/relay/controller/helper.go
@@ -57,7 +57,41 @@ func getImageRequest(c *gin.Context, relayMode int) (*relaymodel.ImageRequest, e
 	}
 	return imageRequest, nil
 }
+
+// getRerankRequest decodes the request body into a RerankRequest, rejects it
+// when model, query or documents are missing, and fills in defaults for the
+// optional TopN and ReturnDocuments fields. relayMode is not read.
+func getRerankRequest(c *gin.Context, relayMode int) (*relaymodel.RerankRequest, error) {
+	rerankRequest := &relaymodel.RerankRequest{}
+	if err := common.UnmarshalBodyReusable(c, rerankRequest); err != nil {
+		return nil, err
+	}
+
+	// Check the required fields first so invalid requests fail before any
+	// default is applied.
+	if rerankRequest.Model == "" {
+		return nil, errors.New("model parameter must be provided")
+	}
+	if rerankRequest.Query == "" {
+		return nil, errors.New("query must not be empty")
+	}
+	if len(rerankRequest.Documents) == 0 {
+		return nil, errors.New("document list must not be empty")
+	}
+
+	// Fill in defaults for omitted optional fields. MaxChunksPerDoc gets no
+	// default and stays nil when the client does not send it.
+	if rerankRequest.TopN == nil {
+		defaultTopN := 10 // Default to returning top 10 results
+		rerankRequest.TopN = &defaultTopN
+	}
+	if rerankRequest.ReturnDocuments == nil {
+		defaultReturnDocs := true // Default to returning documents
+		rerankRequest.ReturnDocuments = &defaultReturnDocs
+	}
+	return rerankRequest, nil
+}
 
 func isValidImageSize(model string, size string) bool {
 	if model == "cogview-3" {
 		return true
diff --git a/relay/controller/rerank.go b/relay/controller/rerank.go
new file mode 100644
index 00000000..782713cc
--- /dev/null
+++ b/relay/controller/rerank.go
@@ -0,0 +1,64 @@
+package controller
+
+import (
+	"bytes"
+	"encoding/json"
+	"fmt"
+	"github.com/gin-gonic/gin"
+	"github.com/songquanpeng/one-api/common/logger"
+	"github.com/songquanpeng/one-api/relay"
+	"github.com/songquanpeng/one-api/relay/adaptor/openai"
+	"github.com/songquanpeng/one-api/relay/meta"
+	relaymodel "github.com/songquanpeng/one-api/relay/model"
+	"net/http"
+)
+
+// RerankHelper validates the rerank request carried by c, applies the model
+// mapping from the request meta, sends the request through the adaptor for
+// meta.APIType and hands the upstream response to that adaptor's DoResponse.
+// It returns nil on success and an ErrorWithStatusCode otherwise. relayMode
+// is not read; the mode stored in the request meta is used instead.
+func RerankHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatusCode {
+	ctx := c.Request.Context()
+	meta := meta.GetByContext(c)
+	rerankRequest, err := getRerankRequest(c, meta.Mode)
+	if err != nil {
+		logger.Errorf(ctx, "getRerankRequest failed: %s", err.Error())
+		return openai.ErrorWrapper(err, "invalid_rerank_request", http.StatusBadRequest)
+	}
+
+	// Map model name
+	meta.OriginModelName = rerankRequest.Model
+	rerankRequest.Model, _ = getMappedModelName(rerankRequest.Model, meta.ModelMapping)
+	meta.ActualModelName = rerankRequest.Model
+
+	// Always forward the validated request instead of the raw client body so
+	// the defaults filled in by getRerankRequest reach the upstream even when
+	// the model name is not remapped. Fields that relaymodel.RerankRequest
+	// does not declare are therefore not forwarded.
+	jsonStr, err := json.Marshal(rerankRequest)
+	if err != nil {
+		return openai.ErrorWrapper(err, "marshal_rerank_request_failed", http.StatusInternalServerError)
+	}
+	requestBody := bytes.NewBuffer(jsonStr)
+
+	adaptor := relay.GetAdaptor(meta.APIType)
+	if adaptor == nil {
+		return openai.ErrorWrapper(fmt.Errorf("invalid api type: %d", meta.APIType), "invalid_api_type", http.StatusBadRequest)
+	}
+
+	resp, err := adaptor.DoRequest(c, meta, requestBody)
+	if err != nil {
+		logger.Errorf(ctx, "DoRequest failed: %s", err.Error())
+		return openai.ErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
+	}
+
+	// do response
+	_, respErr := adaptor.DoResponse(c, resp, meta)
+	if respErr != nil {
+		logger.Errorf(ctx, "respErr is not nil: %+v", respErr)
+		return respErr
+	}
+
+	return nil
+}
diff --git a/relay/model/rerank.go b/relay/model/rerank.go
new file mode 100644
index 00000000..01a4153e
--- /dev/null
+++ b/relay/model/rerank.go
@@ -0,0 +1,13 @@
+package model
+
+// RerankRequest is the client-facing JSON body accepted on /v1/rerank.
+// TopN, MaxChunksPerDoc and ReturnDocuments are pointers so that an omitted
+// field can be told apart from an explicit zero value.
+type RerankRequest struct {
+	Model           string   `json:"model"`
+	Documents       []string `json:"documents"`
+	Query           string   `json:"query"`
+	TopN            *int     `json:"top_n,omitempty"`
+	MaxChunksPerDoc *int     `json:"max_chunks_per_doc,omitempty"`
+	ReturnDocuments *bool    `json:"return_documents,omitempty"`
+}
diff --git a/relay/relaymode/define.go b/relay/relaymode/define.go
index 96d09438..88b2086a 100644
--- a/relay/relaymode/define.go
+++ b/relay/relaymode/define.go
@@ -11,4 +11,5 @@ const (
 	AudioSpeech
 	AudioTranscription
 	AudioTranslation
+	Rerank
 )
diff --git a/relay/relaymode/helper.go b/relay/relaymode/helper.go
index 926dd42e..7666c269 100644
--- a/relay/relaymode/helper.go
+++ b/relay/relaymode/helper.go
@@ -24,6 +24,9 @@ func GetByPath(path string) int {
 		relayMode = AudioTranscription
 	} else if strings.HasPrefix(path, "/v1/audio/translations") {
 		relayMode = AudioTranslation
+	} else if strings.HasPrefix(path, "/v1/rerank") {
+		relayMode = Rerank
 	}
+
 	return relayMode
 }
diff --git a/router/relay.go b/router/relay.go
index 65072c86..43a449ac 100644
--- a/router/relay.go
+++ b/router/relay.go
@@ -30,6 +30,7 @@ func SetRelayRouter(router *gin.Engine) {
 		relayV1Router.POST("/audio/transcriptions", controller.Relay)
 		relayV1Router.POST("/audio/translations", controller.Relay)
 		relayV1Router.POST("/audio/speech", controller.Relay)
+		relayV1Router.POST("/rerank", controller.Relay)
 		relayV1Router.GET("/files", controller.RelayNotImplemented)
 		relayV1Router.POST("/files", controller.RelayNotImplemented)
 		relayV1Router.DELETE("/files/:id", controller.RelayNotImplemented)