package ollama

import (
	"bufio"
	"context"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"strings"

	"github.com/gin-gonic/gin"
	"github.com/songquanpeng/one-api/common"
	"github.com/songquanpeng/one-api/common/helper"
	"github.com/songquanpeng/one-api/common/logger"
	"github.com/songquanpeng/one-api/common/random"
	"github.com/songquanpeng/one-api/relay/adaptor/openai"
	"github.com/songquanpeng/one-api/relay/constant"
	"github.com/songquanpeng/one-api/relay/model"
)

// ConvertRequest maps an OpenAI-style chat request onto Ollama's chat API,
// copying the sampling options and flattening each message to plain text.
func ConvertRequest(request model.GeneralOpenAIRequest) *ChatRequest {
	ollamaRequest := ChatRequest{
		Model: request.Model,
		Options: &Options{
			Seed:             int(request.Seed),
			Temperature:      request.Temperature,
			TopP:             request.TopP,
			FrequencyPenalty: request.FrequencyPenalty,
			PresencePenalty:  request.PresencePenalty,
		},
		Stream: request.Stream,
	}
	for _, message := range request.Messages {
		ollamaRequest.Messages = append(ollamaRequest.Messages, Message{
			Role:    message.Role,
			Content: message.StringContent(),
		})
	}
	return &ollamaRequest
}

// responseOllama2OpenAI converts a non-streaming Ollama chat response into
// the OpenAI chat completion format, including token usage.
func responseOllama2OpenAI(response *ChatResponse) *openai.TextResponse {
	choice := openai.TextResponseChoice{
		Index: 0,
		Message: model.Message{
			Role:    response.Message.Role,
			Content: response.Message.Content,
		},
	}
	if response.Done {
		choice.FinishReason = "stop"
	}
	fullTextResponse := openai.TextResponse{
		Id:      fmt.Sprintf("chatcmpl-%s", random.GetUUID()),
		Model:   response.Model,
		Object:  "chat.completion",
		Created: helper.GetTimestamp(),
		Choices: []openai.TextResponseChoice{choice},
		Usage: model.Usage{
			PromptTokens:     response.PromptEvalCount,
			CompletionTokens: response.EvalCount,
			TotalTokens:      response.PromptEvalCount + response.EvalCount,
		},
	}
	return &fullTextResponse
}

// streamResponseOllama2OpenAI converts a single Ollama stream chunk into an
// OpenAI chat.completion.chunk event.
func streamResponseOllama2OpenAI(ollamaResponse *ChatResponse) *openai.ChatCompletionsStreamResponse {
	var choice openai.ChatCompletionsStreamResponseChoice
	choice.Delta.Role = ollamaResponse.Message.Role
	choice.Delta.Content = ollamaResponse.Message.Content
	if ollamaResponse.Done {
		choice.FinishReason = &constant.StopFinishReason
	}
	response := openai.ChatCompletionsStreamResponse{
		Id:      fmt.Sprintf("chatcmpl-%s", random.GetUUID()),
		Object:  "chat.completion.chunk",
		Created: helper.GetTimestamp(),
		Model:   ollamaResponse.Model,
		Choices: []openai.ChatCompletionsStreamResponseChoice{choice},
	}
	return &response
}
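// StreamHandler below tokenizes Ollama's streaming body with a custom
// bufio.SplitFunc that cuts on "}\n": Ollama emits one JSON object per line,
// so each token is one object minus its closing brace, which the reader
// goroutine re-appends. The following is a minimal illustrative sketch of
// that tokenization, not part of the adaptor's API; the function name and
// sample payload are hypothetical.
func exampleSplitOllamaStream() {
	body := strings.NewReader(
		`{"model":"llama3","message":{"role":"assistant","content":"Hi"},"done":false}` + "\n" +
			`{"model":"llama3","message":{"role":"assistant","content":""},"done":true}` + "\n")
	scanner := bufio.NewScanner(body)
	scanner.Split(func(data []byte, atEOF bool) (int, []byte, error) {
		if atEOF && len(data) == 0 {
			return 0, nil, nil
		}
		if i := strings.Index(string(data), "}\n"); i >= 0 {
			return i + 2, data[0:i], nil // token is the object without its final "}"
		}
		if atEOF {
			return len(data), data, nil
		}
		return 0, nil, nil
	})
	for scanner.Scan() {
		fmt.Println(scanner.Text() + "}") // each token is a complete JSON object again
	}
}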
// StreamHandler relays an Ollama streaming chat response to the client as
// OpenAI-style SSE events and accumulates token usage from the final chunk.
func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
	var usage model.Usage
	scanner := bufio.NewScanner(resp.Body)
	// Ollama streams one JSON object per line; split on the closing "}\n"
	// so each token is a single object (without its trailing brace).
	scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
		if atEOF && len(data) == 0 {
			return 0, nil, nil
		}
		if i := strings.Index(string(data), "}\n"); i >= 0 {
			return i + 2, data[0:i], nil
		}
		if atEOF {
			return len(data), data, nil
		}
		return 0, nil, nil
	})
	dataChan := make(chan string)
	stopChan := make(chan bool)
	go func() {
		for scanner.Scan() {
			data := strings.TrimPrefix(scanner.Text(), "}")
			dataChan <- data + "}" // restore the closing brace removed by the splitter
		}
		stopChan <- true
	}()
	common.SetEventStreamHeaders(c)
	c.Stream(func(w io.Writer) bool {
		select {
		case data := <-dataChan:
			var ollamaResponse ChatResponse
			err := json.Unmarshal([]byte(data), &ollamaResponse)
			if err != nil {
				logger.SysError("error unmarshalling stream response: " + err.Error())
				return true
			}
			if ollamaResponse.EvalCount != 0 {
				usage.PromptTokens = ollamaResponse.PromptEvalCount
				usage.CompletionTokens = ollamaResponse.EvalCount
				usage.TotalTokens = ollamaResponse.PromptEvalCount + ollamaResponse.EvalCount
			}
			response := streamResponseOllama2OpenAI(&ollamaResponse)
			jsonResponse, err := json.Marshal(response)
			if err != nil {
				logger.SysError("error marshalling stream response: " + err.Error())
				return true
			}
			c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
			return true
		case <-stopChan:
			c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
			return false
		}
	})
	err := resp.Body.Close()
	if err != nil {
		return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
	}
	return nil, &usage
}

// ConvertEmbeddingRequest maps an OpenAI embedding request onto Ollama's
// embedding API, joining multiple inputs into a single prompt.
func ConvertEmbeddingRequest(request model.GeneralOpenAIRequest) *EmbeddingRequest {
	return &EmbeddingRequest{
		Model:  request.Model,
		Prompt: strings.Join(request.ParseInput(), " "),
	}
}

// EmbeddingHandler decodes an Ollama embedding response and writes it back
// to the client in the OpenAI embedding format.
func EmbeddingHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
	var ollamaResponse EmbeddingResponse
	err := json.NewDecoder(resp.Body).Decode(&ollamaResponse)
	if err != nil {
		return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
	}
	err = resp.Body.Close()
	if err != nil {
		return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
	}
	if ollamaResponse.Error != "" {
		return &model.ErrorWithStatusCode{
			Error: model.Error{
				Message: ollamaResponse.Error,
				Type:    "ollama_error",
				Param:   "",
				Code:    "ollama_error",
			},
			StatusCode: resp.StatusCode,
		}, nil
	}
	fullTextResponse := embeddingResponseOllama2OpenAI(&ollamaResponse)
	jsonResponse, err := json.Marshal(fullTextResponse)
	if err != nil {
		return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
	}
	c.Writer.Header().Set("Content-Type", "application/json")
	c.Writer.WriteHeader(resp.StatusCode)
	_, err = c.Writer.Write(jsonResponse)
	if err != nil {
		logger.SysError("error writing embedding response: " + err.Error())
	}
	return nil, &fullTextResponse.Usage
}

// embeddingResponseOllama2OpenAI wraps Ollama's single embedding vector in
// the OpenAI list response format.
func embeddingResponseOllama2OpenAI(response *EmbeddingResponse) *openai.EmbeddingResponse {
	openAIEmbeddingResponse := openai.EmbeddingResponse{
		Object: "list",
		Data:   make([]openai.EmbeddingResponseItem, 0, 1),
		Model:  "text-embedding-v1",
		Usage:  model.Usage{TotalTokens: 0},
	}
	openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, openai.EmbeddingResponseItem{
		Object:    "embedding",
		Index:     0,
		Embedding: response.Embedding,
	})
	return &openAIEmbeddingResponse
}

// Handler relays a non-streaming Ollama chat response to the client in the
// OpenAI chat completion format.
func Handler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
	ctx := context.TODO()
	var ollamaResponse ChatResponse
	responseBody, err := io.ReadAll(resp.Body)
	if err != nil {
		return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
	}
	logger.Debugf(ctx, "ollama response: %s", string(responseBody))
	err = resp.Body.Close()
	if err != nil {
		return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
	}
	err = json.Unmarshal(responseBody, &ollamaResponse)
	if err != nil {
		return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
	}
	if ollamaResponse.Error != "" {
		return &model.ErrorWithStatusCode{
			Error: model.Error{
				Message: ollamaResponse.Error,
				Type:    "ollama_error",
				Param:   "",
				Code:    "ollama_error",
			},
			StatusCode: resp.StatusCode,
		}, nil
	}
	fullTextResponse := responseOllama2OpenAI(&ollamaResponse)
	jsonResponse, err := json.Marshal(fullTextResponse)
	if err != nil {
		return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
	}
	c.Writer.Header().Set("Content-Type", "application/json")
	c.Writer.WriteHeader(resp.StatusCode)
	_, err = c.Writer.Write(jsonResponse)
	if err != nil {
		logger.SysError("error writing response: " + err.Error())
	}
	return nil, &fullTextResponse.Usage
}