From 81c5901123b74553cfb004388d51779936c1afdb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8E=A9=E7=89=9B=E7=89=9B?= Date: Sat, 15 Jul 2023 12:03:23 +0800 Subject: [PATCH] feat: add support for /v1/engines/text-embedding-ada-002/embeddings (#224, close #222) --- controller/relay-text.go | 6 +++++- controller/relay.go | 5 ++++- middleware/distributor.go | 8 +++++++- router/relay-router.go | 4 +++- 4 files changed, 19 insertions(+), 4 deletions(-) diff --git a/controller/relay-text.go b/controller/relay-text.go index eab71a95..a26355e3 100644 --- a/controller/relay-text.go +++ b/controller/relay-text.go @@ -6,12 +6,13 @@ import ( "encoding/json" "errors" "fmt" - "github.com/gin-gonic/gin" "io" "net/http" "one-api/common" "one-api/model" "strings" + + "github.com/gin-gonic/gin" ) func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { @@ -30,6 +31,9 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { if relayMode == RelayModeModerations && textRequest.Model == "" { textRequest.Model = "text-moderation-latest" } + if relayMode == RelayModeEmbeddings && textRequest.Model == "" { + textRequest.Model = c.Param("model") + } // request validation if textRequest.Model == "" { return errorWrapper(errors.New("model is required"), "required_field_missing", http.StatusBadRequest) diff --git a/controller/relay.go b/controller/relay.go index c8bd929c..2f562799 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -2,10 +2,11 @@ package controller import ( "fmt" - "github.com/gin-gonic/gin" "net/http" "one-api/common" "strings" + + "github.com/gin-gonic/gin" ) type Message struct { @@ -100,6 +101,8 @@ func Relay(c *gin.Context) { relayMode = RelayModeCompletions } else if strings.HasPrefix(c.Request.URL.Path, "/v1/embeddings") { relayMode = RelayModeEmbeddings + } else if strings.HasSuffix(c.Request.URL.Path, "embeddings") { + relayMode = RelayModeEmbeddings } else if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") { relayMode = RelayModeModerations } else if strings.HasPrefix(c.Request.URL.Path, "/v1/images/generations") { diff --git a/middleware/distributor.go b/middleware/distributor.go index 314677c7..cb419d6d 100644 --- a/middleware/distributor.go +++ b/middleware/distributor.go @@ -2,12 +2,13 @@ package middleware import ( "fmt" - "github.com/gin-gonic/gin" "net/http" "one-api/common" "one-api/model" "strconv" "strings" + + "github.com/gin-gonic/gin" ) type ModelRequest struct { @@ -73,6 +74,11 @@ func Distribute() func(c *gin.Context) { modelRequest.Model = "text-moderation-stable" } } + if strings.HasSuffix(c.Request.URL.Path, "embeddings") { + if modelRequest.Model == "" { + modelRequest.Model = c.Param("model") + } + } channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, modelRequest.Model) if err != nil { message := "无可用渠道" diff --git a/router/relay-router.go b/router/relay-router.go index cbdfef11..cef5c7cc 100644 --- a/router/relay-router.go +++ b/router/relay-router.go @@ -1,9 +1,10 @@ package router import ( - "github.com/gin-gonic/gin" "one-api/controller" "one-api/middleware" + + "github.com/gin-gonic/gin" ) func SetRelayRouter(router *gin.Engine) { @@ -24,6 +25,7 @@ func SetRelayRouter(router *gin.Engine) { relayV1Router.POST("/images/edits", controller.RelayNotImplemented) relayV1Router.POST("/images/variations", controller.RelayNotImplemented) relayV1Router.POST("/embeddings", controller.Relay) + relayV1Router.POST("/engines/:model/embeddings", controller.Relay) relayV1Router.POST("/audio/transcriptions", controller.RelayNotImplemented) relayV1Router.POST("/audio/translations", controller.RelayNotImplemented) relayV1Router.GET("/files", controller.RelayNotImplemented)