From c243cd553537bf0a544242a52708aaab48e34997 Mon Sep 17 00:00:00 2001 From: xietong Date: Sun, 24 Mar 2024 21:51:31 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E6=94=AF=E6=8C=81=20ollama=20=E7=9A=84?= =?UTF-8?q?=20embedding=20=E6=8E=A5=E5=8F=A3=20(#1221)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * 增加ollama的embedding接口 * chore: fix function name --------- Co-authored-by: JustSong --- relay/channel/ollama/adaptor.go | 18 +++++++-- relay/channel/ollama/main.go | 65 +++++++++++++++++++++++++++++++-- relay/channel/ollama/model.go | 10 +++++ 3 files changed, 86 insertions(+), 7 deletions(-) diff --git a/relay/channel/ollama/adaptor.go b/relay/channel/ollama/adaptor.go index 06c66101..e2ae7d2b 100644 --- a/relay/channel/ollama/adaptor.go +++ b/relay/channel/ollama/adaptor.go @@ -3,13 +3,14 @@ package ollama import ( "errors" "fmt" + "io" + "net/http" + "github.com/gin-gonic/gin" "github.com/songquanpeng/one-api/relay/channel" "github.com/songquanpeng/one-api/relay/constant" "github.com/songquanpeng/one-api/relay/model" "github.com/songquanpeng/one-api/relay/util" - "io" - "net/http" ) type Adaptor struct { @@ -22,6 +23,9 @@ func (a *Adaptor) Init(meta *util.RelayMeta) { func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) { // https://github.com/ollama/ollama/blob/main/docs/api.md fullRequestURL := fmt.Sprintf("%s/api/chat", meta.BaseURL) + if meta.Mode == constant.RelayModeEmbeddings { + fullRequestURL = fmt.Sprintf("%s/api/embeddings", meta.BaseURL) + } return fullRequestURL, nil } @@ -37,7 +41,8 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G } switch relayMode { case constant.RelayModeEmbeddings: - return nil, errors.New("not supported") + ollamaEmbeddingRequest := ConvertEmbeddingRequest(*request) + return ollamaEmbeddingRequest, nil default: return ConvertRequest(*request), nil } @@ -51,7 +56,12 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.Rel if meta.IsStream { err, usage = StreamHandler(c, resp) } else { - err, usage = Handler(c, resp) + switch meta.Mode { + case constant.RelayModeEmbeddings: + err, usage = EmbeddingHandler(c, resp) + default: + err, usage = Handler(c, resp) + } } return } diff --git a/relay/channel/ollama/main.go b/relay/channel/ollama/main.go index 7ec646a3..821a335b 100644 --- a/relay/channel/ollama/main.go +++ b/relay/channel/ollama/main.go @@ -5,6 +5,10 @@ import ( "context" "encoding/json" "fmt" + "io" + "net/http" + "strings" + "github.com/gin-gonic/gin" "github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common/helper" @@ -12,9 +16,6 @@ import ( "github.com/songquanpeng/one-api/relay/channel/openai" "github.com/songquanpeng/one-api/relay/constant" "github.com/songquanpeng/one-api/relay/model" - "io" - "net/http" - "strings" ) func ConvertRequest(request model.GeneralOpenAIRequest) *ChatRequest { @@ -139,6 +140,64 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC return nil, &usage } +func ConvertEmbeddingRequest(request model.GeneralOpenAIRequest) *EmbeddingRequest { + return &EmbeddingRequest{ + Model: request.Model, + Prompt: strings.Join(request.ParseInput(), " "), + } +} + +func EmbeddingHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) { + var ollamaResponse EmbeddingResponse + err := json.NewDecoder(resp.Body).Decode(&ollamaResponse) + if err != nil { + return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil + } + + err = resp.Body.Close() + if err != nil { + return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil + } + + if ollamaResponse.Error != "" { + return &model.ErrorWithStatusCode{ + Error: model.Error{ + Message: ollamaResponse.Error, + Type: "ollama_error", + Param: "", + Code: "ollama_error", + }, + StatusCode: resp.StatusCode, + }, nil + } + + fullTextResponse := embeddingResponseOllama2OpenAI(&ollamaResponse) + jsonResponse, err := json.Marshal(fullTextResponse) + if err != nil { + return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil + } + c.Writer.Header().Set("Content-Type", "application/json") + c.Writer.WriteHeader(resp.StatusCode) + _, err = c.Writer.Write(jsonResponse) + return nil, &fullTextResponse.Usage +} + +func embeddingResponseOllama2OpenAI(response *EmbeddingResponse) *openai.EmbeddingResponse { + openAIEmbeddingResponse := openai.EmbeddingResponse{ + Object: "list", + Data: make([]openai.EmbeddingResponseItem, 0, 1), + Model: "text-embedding-v1", + Usage: model.Usage{TotalTokens: 0}, + } + + openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, openai.EmbeddingResponseItem{ + Object: `embedding`, + Index: 0, + Embedding: response.Embedding, + }) + return &openAIEmbeddingResponse +} + func Handler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) { ctx := context.TODO() var ollamaResponse ChatResponse diff --git a/relay/channel/ollama/model.go b/relay/channel/ollama/model.go index a8ef1ffc..8baf56a0 100644 --- a/relay/channel/ollama/model.go +++ b/relay/channel/ollama/model.go @@ -35,3 +35,13 @@ type ChatResponse struct { EvalDuration int `json:"eval_duration,omitempty"` Error string `json:"error,omitempty"` } + +type EmbeddingRequest struct { + Model string `json:"model"` + Prompt string `json:"prompt"` +} + +type EmbeddingResponse struct { + Error string `json:"error,omitempty"` + Embedding []float64 `json:"embedding,omitempty"` +}