feat: add support for /v1/engines/text-embedding-ada-002/embeddings (#224, close #222)

This commit is contained in:
玩牛牛 2023-07-15 12:03:23 +08:00 committed by GitHub
parent abc53cb208
commit 81c5901123
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 19 additions and 4 deletions

View File

@ -6,12 +6,13 @@ import (
"encoding/json"
"errors"
"fmt"
"github.com/gin-gonic/gin"
"io"
"net/http"
"one-api/common"
"one-api/model"
"strings"
"github.com/gin-gonic/gin"
)
func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
@ -30,6 +31,9 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
if relayMode == RelayModeModerations && textRequest.Model == "" {
textRequest.Model = "text-moderation-latest"
}
if relayMode == RelayModeEmbeddings && textRequest.Model == "" {
textRequest.Model = c.Param("model")
}
// request validation
if textRequest.Model == "" {
return errorWrapper(errors.New("model is required"), "required_field_missing", http.StatusBadRequest)

View File

@ -2,10 +2,11 @@ package controller
import (
"fmt"
"github.com/gin-gonic/gin"
"net/http"
"one-api/common"
"strings"
"github.com/gin-gonic/gin"
)
type Message struct {
@ -100,6 +101,8 @@ func Relay(c *gin.Context) {
relayMode = RelayModeCompletions
} else if strings.HasPrefix(c.Request.URL.Path, "/v1/embeddings") {
relayMode = RelayModeEmbeddings
} else if strings.HasSuffix(c.Request.URL.Path, "embeddings") {
relayMode = RelayModeEmbeddings
} else if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") {
relayMode = RelayModeModerations
} else if strings.HasPrefix(c.Request.URL.Path, "/v1/images/generations") {

View File

@ -2,12 +2,13 @@ package middleware
import (
"fmt"
"github.com/gin-gonic/gin"
"net/http"
"one-api/common"
"one-api/model"
"strconv"
"strings"
"github.com/gin-gonic/gin"
)
type ModelRequest struct {
@ -73,6 +74,11 @@ func Distribute() func(c *gin.Context) {
modelRequest.Model = "text-moderation-stable"
}
}
if strings.HasSuffix(c.Request.URL.Path, "embeddings") {
if modelRequest.Model == "" {
modelRequest.Model = c.Param("model")
}
}
channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, modelRequest.Model)
if err != nil {
message := "无可用渠道"

View File

@ -1,9 +1,10 @@
package router
import (
"github.com/gin-gonic/gin"
"one-api/controller"
"one-api/middleware"
"github.com/gin-gonic/gin"
)
func SetRelayRouter(router *gin.Engine) {
@ -24,6 +25,7 @@ func SetRelayRouter(router *gin.Engine) {
relayV1Router.POST("/images/edits", controller.RelayNotImplemented)
relayV1Router.POST("/images/variations", controller.RelayNotImplemented)
relayV1Router.POST("/embeddings", controller.Relay)
relayV1Router.POST("/engines/:model/embeddings", controller.Relay)
relayV1Router.POST("/audio/transcriptions", controller.RelayNotImplemented)
relayV1Router.POST("/audio/translations", controller.RelayNotImplemented)
relayV1Router.GET("/files", controller.RelayNotImplemented)