feat: 支持 ollama 的 embedding 接口 (#1221)
* 增加ollama的embedding接口 * chore: fix function name --------- Co-authored-by: JustSong <songquanpeng@foxmail.com>
This commit is contained in:
parent
e96b173abe
commit
c243cd5535
@ -3,13 +3,14 @@ package ollama
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/songquanpeng/one-api/relay/channel"
|
||||
"github.com/songquanpeng/one-api/relay/constant"
|
||||
"github.com/songquanpeng/one-api/relay/model"
|
||||
"github.com/songquanpeng/one-api/relay/util"
|
||||
"io"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
type Adaptor struct {
|
||||
@ -22,6 +23,9 @@ func (a *Adaptor) Init(meta *util.RelayMeta) {
|
||||
func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) {
|
||||
// https://github.com/ollama/ollama/blob/main/docs/api.md
|
||||
fullRequestURL := fmt.Sprintf("%s/api/chat", meta.BaseURL)
|
||||
if meta.Mode == constant.RelayModeEmbeddings {
|
||||
fullRequestURL = fmt.Sprintf("%s/api/embeddings", meta.BaseURL)
|
||||
}
|
||||
return fullRequestURL, nil
|
||||
}
|
||||
|
||||
@ -37,7 +41,8 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G
|
||||
}
|
||||
switch relayMode {
|
||||
case constant.RelayModeEmbeddings:
|
||||
return nil, errors.New("not supported")
|
||||
ollamaEmbeddingRequest := ConvertEmbeddingRequest(*request)
|
||||
return ollamaEmbeddingRequest, nil
|
||||
default:
|
||||
return ConvertRequest(*request), nil
|
||||
}
|
||||
@ -51,7 +56,12 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.Rel
|
||||
if meta.IsStream {
|
||||
err, usage = StreamHandler(c, resp)
|
||||
} else {
|
||||
err, usage = Handler(c, resp)
|
||||
switch meta.Mode {
|
||||
case constant.RelayModeEmbeddings:
|
||||
err, usage = EmbeddingHandler(c, resp)
|
||||
default:
|
||||
err, usage = Handler(c, resp)
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
@ -5,6 +5,10 @@ import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/songquanpeng/one-api/common"
|
||||
"github.com/songquanpeng/one-api/common/helper"
|
||||
@ -12,9 +16,6 @@ import (
|
||||
"github.com/songquanpeng/one-api/relay/channel/openai"
|
||||
"github.com/songquanpeng/one-api/relay/constant"
|
||||
"github.com/songquanpeng/one-api/relay/model"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func ConvertRequest(request model.GeneralOpenAIRequest) *ChatRequest {
|
||||
@ -139,6 +140,64 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC
|
||||
return nil, &usage
|
||||
}
|
||||
|
||||
func ConvertEmbeddingRequest(request model.GeneralOpenAIRequest) *EmbeddingRequest {
|
||||
return &EmbeddingRequest{
|
||||
Model: request.Model,
|
||||
Prompt: strings.Join(request.ParseInput(), " "),
|
||||
}
|
||||
}
|
||||
|
||||
func EmbeddingHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
|
||||
var ollamaResponse EmbeddingResponse
|
||||
err := json.NewDecoder(resp.Body).Decode(&ollamaResponse)
|
||||
if err != nil {
|
||||
return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
|
||||
err = resp.Body.Close()
|
||||
if err != nil {
|
||||
return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
|
||||
if ollamaResponse.Error != "" {
|
||||
return &model.ErrorWithStatusCode{
|
||||
Error: model.Error{
|
||||
Message: ollamaResponse.Error,
|
||||
Type: "ollama_error",
|
||||
Param: "",
|
||||
Code: "ollama_error",
|
||||
},
|
||||
StatusCode: resp.StatusCode,
|
||||
}, nil
|
||||
}
|
||||
|
||||
fullTextResponse := embeddingResponseOllama2OpenAI(&ollamaResponse)
|
||||
jsonResponse, err := json.Marshal(fullTextResponse)
|
||||
if err != nil {
|
||||
return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
c.Writer.Header().Set("Content-Type", "application/json")
|
||||
c.Writer.WriteHeader(resp.StatusCode)
|
||||
_, err = c.Writer.Write(jsonResponse)
|
||||
return nil, &fullTextResponse.Usage
|
||||
}
|
||||
|
||||
func embeddingResponseOllama2OpenAI(response *EmbeddingResponse) *openai.EmbeddingResponse {
|
||||
openAIEmbeddingResponse := openai.EmbeddingResponse{
|
||||
Object: "list",
|
||||
Data: make([]openai.EmbeddingResponseItem, 0, 1),
|
||||
Model: "text-embedding-v1",
|
||||
Usage: model.Usage{TotalTokens: 0},
|
||||
}
|
||||
|
||||
openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, openai.EmbeddingResponseItem{
|
||||
Object: `embedding`,
|
||||
Index: 0,
|
||||
Embedding: response.Embedding,
|
||||
})
|
||||
return &openAIEmbeddingResponse
|
||||
}
|
||||
|
||||
func Handler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
|
||||
ctx := context.TODO()
|
||||
var ollamaResponse ChatResponse
|
||||
|
@ -35,3 +35,13 @@ type ChatResponse struct {
|
||||
EvalDuration int `json:"eval_duration,omitempty"`
|
||||
Error string `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
type EmbeddingRequest struct {
|
||||
Model string `json:"model"`
|
||||
Prompt string `json:"prompt"`
|
||||
}
|
||||
|
||||
type EmbeddingResponse struct {
|
||||
Error string `json:"error,omitempty"`
|
||||
Embedding []float64 `json:"embedding,omitempty"`
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user