diff --git a/README.md b/README.md
index 1cb30591..8f6c6bf7 100644
--- a/README.md
+++ b/README.md
@@ -79,6 +79,7 @@ _✨ 通过标准的 OpenAI API 格式访问所有的大模型,开箱即用
    + [ ] [字节云雀大模型](https://www.volcengine.com/product/ark) (WIP)
    + [x] [MINIMAX](https://api.minimax.chat/)
    + [x] [Groq](https://wow.groq.com/)
+   + [x] [Ollama](https://github.com/ollama/ollama)
 2. 支持配置镜像以及众多[第三方代理服务](https://iamazing.cn/page/openai-api-third-party-services)。
 3. 支持通过**负载均衡**的方式访问多个渠道。
 4. 支持 **stream 模式**,可以通过流式传输实现打字机效果。
diff --git a/common/constants.go b/common/constants.go
index de71bc7a..f4f575ba 100644
--- a/common/constants.go
+++ b/common/constants.go
@@ -69,6 +69,7 @@ const (
 	ChannelTypeMinimax
 	ChannelTypeMistral
 	ChannelTypeGroq
+	ChannelTypeOllama
 	ChannelTypeDummy
 )
 
@@ -104,6 +105,7 @@ var ChannelBaseURLs = []string{
 	"https://api.minimax.chat",    // 27
 	"https://api.mistral.ai",      // 28
 	"https://api.groq.com/openai", // 29
+	"http://localhost:11434",      // 30
 }
 
 const (
diff --git a/common/helper/helper.go b/common/helper/helper.go
index 76db5042..db41ac74 100644
--- a/common/helper/helper.go
+++ b/common/helper/helper.go
@@ -185,6 +185,10 @@ func GetTimeString() string {
 	return fmt.Sprintf("%s%d", now.Format("20060102150405"), now.UnixNano()%1e9)
 }
 
+func GenRequestID() string {
+	return GetTimeString() + GetRandomNumberString(8)
+}
+
 func Max(a int, b int) int {
 	if a >= b {
 		return a
diff --git a/common/logger/logger.go b/common/logger/logger.go
index ad0a0bea..957d8a11 100644
--- a/common/logger/logger.go
+++ b/common/logger/logger.go
@@ -5,6 +5,7 @@ import (
 	"fmt"
 	"github.com/gin-gonic/gin"
 	"github.com/songquanpeng/one-api/common/config"
+	"github.com/songquanpeng/one-api/common/helper"
 	"io"
 	"log"
 	"os"
@@ -94,6 +95,9 @@ func logHelper(ctx context.Context, level string, msg string) {
 		writer = gin.DefaultWriter
 	}
 	id := ctx.Value(RequestIdKey)
+	if id == nil {
+		id = helper.GenRequestID()
+	}
 	now := time.Now()
 	_, _ = fmt.Fprintf(writer, "[%s] %v | %s | %s \n", level, now.Format("2006/01/02 - 15:04:05"), id, msg)
 	if !setupLogWorking {
diff --git a/middleware/request-id.go b/middleware/request-id.go
index 234a93d8..a4c49ddb 100644
--- a/middleware/request-id.go
+++ b/middleware/request-id.go
@@ -9,7 +9,7 @@ import (
 
 func RequestId() func(c *gin.Context) {
 	return func(c *gin.Context) {
-		id := helper.GetTimeString() + helper.GetRandomNumberString(8)
+		id := helper.GenRequestID()
 		c.Set(logger.RequestIdKey, id)
 		ctx := context.WithValue(c.Request.Context(), logger.RequestIdKey, id)
 		c.Request = c.Request.WithContext(ctx)
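The practical effect of the `GenRequestID` refactor plus the logger fallback: request-scoped logs no longer print a nil ID when the context was not built by the `RequestId` middleware. A minimal sketch of the observable behavior, assuming the `Info` wrapper in `common/logger` delegates to `logHelper` like the other log levels:

```go
package main

import (
	"context"

	"github.com/songquanpeng/one-api/common/logger"
)

func main() {
	// A context that never passed through middleware.RequestId carries no
	// request ID; with the fallback added above, logHelper now generates one
	// via helper.GenRequestID() instead of printing "<nil>".
	ctx := context.Background()
	logger.Info(ctx, "this line now gets a freshly generated request ID")
}
```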
diff --git a/relay/channel/ollama/adaptor.go b/relay/channel/ollama/adaptor.go
new file mode 100644
index 00000000..06c66101
--- /dev/null
+++ b/relay/channel/ollama/adaptor.go
@@ -0,0 +1,65 @@
+package ollama
+
+import (
+	"errors"
+	"fmt"
+	"github.com/gin-gonic/gin"
+	"github.com/songquanpeng/one-api/relay/channel"
+	"github.com/songquanpeng/one-api/relay/constant"
+	"github.com/songquanpeng/one-api/relay/model"
+	"github.com/songquanpeng/one-api/relay/util"
+	"io"
+	"net/http"
+)
+
+type Adaptor struct {
+}
+
+func (a *Adaptor) Init(meta *util.RelayMeta) {
+
+}
+
+func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) {
+	// https://github.com/ollama/ollama/blob/main/docs/api.md
+	fullRequestURL := fmt.Sprintf("%s/api/chat", meta.BaseURL)
+	return fullRequestURL, nil
+}
+
+func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *util.RelayMeta) error {
+	channel.SetupCommonRequestHeader(c, req, meta)
+	req.Header.Set("Authorization", "Bearer "+meta.APIKey)
+	return nil
+}
+
+func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) {
+	if request == nil {
+		return nil, errors.New("request is nil")
+	}
+	switch relayMode {
+	case constant.RelayModeEmbeddings:
+		return nil, errors.New("not supported")
+	default:
+		return ConvertRequest(*request), nil
+	}
+}
+
+func (a *Adaptor) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error) {
+	return channel.DoRequestHelper(a, c, meta, requestBody)
+}
+
+func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
+	if meta.IsStream {
+		err, usage = StreamHandler(c, resp)
+	} else {
+		err, usage = Handler(c, resp)
+	}
+	return
+}
+
+func (a *Adaptor) GetModelList() []string {
+	return ModelList
+}
+
+func (a *Adaptor) GetChannelName() string {
+	return "ollama"
+}
diff --git a/relay/channel/ollama/constants.go b/relay/channel/ollama/constants.go
new file mode 100644
index 00000000..32f82b2a
--- /dev/null
+++ b/relay/channel/ollama/constants.go
@@ -0,0 +1,5 @@
+package ollama
+
+var ModelList = []string{
+	"qwen:0.5b-chat",
+}
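To make the adaptor's request mapping concrete, here is a hedged sketch of what `ollama.ConvertRequest` (defined in `main.go` below) produces for a plain OpenAI-style request; the output in the comment assumes the struct tags from `model.go` further down and is only approximate:

```go
package main

import (
	"encoding/json"
	"fmt"

	"github.com/songquanpeng/one-api/relay/channel/ollama"
	"github.com/songquanpeng/one-api/relay/model"
)

func main() {
	req := model.GeneralOpenAIRequest{
		Model:       "qwen:0.5b-chat",
		Temperature: 0.7,
		Messages:    []model.Message{{Role: "user", Content: "Hello"}},
	}
	// Zero-valued options (seed, top_p, penalties) are dropped by omitempty.
	body, _ := json.Marshal(ollama.ConvertRequest(req))
	fmt.Println(string(body))
	// Roughly: {"model":"qwen:0.5b-chat","messages":[{"role":"user","content":"Hello"}],
	//           "stream":false,"options":{"temperature":0.7}}
}
```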
"chat.completion.chunk", + Created: helper.GetTimestamp(), + Model: ollamaResponse.Model, + Choices: []openai.ChatCompletionsStreamResponseChoice{choice}, + } + return &response +} + +func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) { + var usage model.Usage + scanner := bufio.NewScanner(resp.Body) + scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { + if atEOF && len(data) == 0 { + return 0, nil, nil + } + if i := strings.Index(string(data), "}\n"); i >= 0 { + return i + 2, data[0:i], nil + } + if atEOF { + return len(data), data, nil + } + return 0, nil, nil + }) + dataChan := make(chan string) + stopChan := make(chan bool) + go func() { + for scanner.Scan() { + data := strings.TrimPrefix(scanner.Text(), "}") + dataChan <- data + "}" + } + stopChan <- true + }() + common.SetEventStreamHeaders(c) + c.Stream(func(w io.Writer) bool { + select { + case data := <-dataChan: + var ollamaResponse ChatResponse + err := json.Unmarshal([]byte(data), &ollamaResponse) + if err != nil { + logger.SysError("error unmarshalling stream response: " + err.Error()) + return true + } + if ollamaResponse.EvalCount != 0 { + usage.PromptTokens = ollamaResponse.PromptEvalCount + usage.CompletionTokens = ollamaResponse.EvalCount + usage.TotalTokens = ollamaResponse.PromptEvalCount + ollamaResponse.EvalCount + } + response := streamResponseOllama2OpenAI(&ollamaResponse) + jsonResponse, err := json.Marshal(response) + if err != nil { + logger.SysError("error marshalling stream response: " + err.Error()) + return true + } + c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)}) + return true + case <-stopChan: + c.Render(-1, common.CustomEvent{Data: "data: [DONE]"}) + return false + } + }) + err := resp.Body.Close() + if err != nil { + return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil + } + return nil, &usage +} + +func Handler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) { + ctx := context.TODO() + var ollamaResponse ChatResponse + responseBody, err := io.ReadAll(resp.Body) + if err != nil { + return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil + } + logger.Debugf(ctx, "ollama response: %s", string(responseBody)) + err = resp.Body.Close() + if err != nil { + return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil + } + err = json.Unmarshal(responseBody, &ollamaResponse) + if err != nil { + return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil + } + if ollamaResponse.Error != "" { + return &model.ErrorWithStatusCode{ + Error: model.Error{ + Message: ollamaResponse.Error, + Type: "ollama_error", + Param: "", + Code: "ollama_error", + }, + StatusCode: resp.StatusCode, + }, nil + } + fullTextResponse := responseOllama2OpenAI(&ollamaResponse) + jsonResponse, err := json.Marshal(fullTextResponse) + if err != nil { + return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil + } + c.Writer.Header().Set("Content-Type", "application/json") + c.Writer.WriteHeader(resp.StatusCode) + _, err = c.Writer.Write(jsonResponse) + return nil, &fullTextResponse.Usage +} diff --git a/relay/channel/ollama/model.go b/relay/channel/ollama/model.go new file mode 100644 index 00000000..a8ef1ffc --- /dev/null +++ b/relay/channel/ollama/model.go @@ -0,0 +1,37 @@ +package ollama + +type 
diff --git a/relay/channel/ollama/model.go b/relay/channel/ollama/model.go
new file mode 100644
index 00000000..a8ef1ffc
--- /dev/null
+++ b/relay/channel/ollama/model.go
@@ -0,0 +1,37 @@
+package ollama
+
+type Options struct {
+	Seed             int     `json:"seed,omitempty"`
+	Temperature      float64 `json:"temperature,omitempty"`
+	TopK             int     `json:"top_k,omitempty"`
+	TopP             float64 `json:"top_p,omitempty"`
+	FrequencyPenalty float64 `json:"frequency_penalty,omitempty"`
+	PresencePenalty  float64 `json:"presence_penalty,omitempty"`
+}
+
+type Message struct {
+	Role    string   `json:"role,omitempty"`
+	Content string   `json:"content,omitempty"`
+	Images  []string `json:"images,omitempty"`
+}
+
+type ChatRequest struct {
+	Model    string    `json:"model,omitempty"`
+	Messages []Message `json:"messages,omitempty"`
+	Stream   bool      `json:"stream"`
+	Options  *Options  `json:"options,omitempty"`
+}
+
+type ChatResponse struct {
+	Model           string  `json:"model,omitempty"`
+	CreatedAt       string  `json:"created_at,omitempty"`
+	Message         Message `json:"message,omitempty"`
+	Response        string  `json:"response,omitempty"` // for stream response
+	Done            bool    `json:"done,omitempty"`
+	TotalDuration   int     `json:"total_duration,omitempty"`
+	LoadDuration    int     `json:"load_duration,omitempty"`
+	PromptEvalCount int     `json:"prompt_eval_count,omitempty"`
+	EvalCount       int     `json:"eval_count,omitempty"`
+	EvalDuration    int     `json:"eval_duration,omitempty"`
+	Error           string  `json:"error,omitempty"`
+}
diff --git a/relay/constant/api_type.go b/relay/constant/api_type.go
index d2184dac..b249f6a2 100644
--- a/relay/constant/api_type.go
+++ b/relay/constant/api_type.go
@@ -15,6 +15,7 @@ const (
 	APITypeAIProxyLibrary
 	APITypeTencent
 	APITypeGemini
+	APITypeOllama
 
 	APITypeDummy // this one is only for count, do not add any channel after this
 )
@@ -40,6 +41,8 @@ func ChannelType2APIType(channelType int) int {
 		apiType = APITypeTencent
 	case common.ChannelTypeGemini:
 		apiType = APITypeGemini
+	case common.ChannelTypeOllama:
+		apiType = APITypeOllama
 	}
 	return apiType
 }
diff --git a/relay/helper/main.go b/relay/helper/main.go
index c2b6e6af..e7342329 100644
--- a/relay/helper/main.go
+++ b/relay/helper/main.go
@@ -7,6 +7,7 @@ import (
 	"github.com/songquanpeng/one-api/relay/channel/anthropic"
 	"github.com/songquanpeng/one-api/relay/channel/baidu"
 	"github.com/songquanpeng/one-api/relay/channel/gemini"
+	"github.com/songquanpeng/one-api/relay/channel/ollama"
 	"github.com/songquanpeng/one-api/relay/channel/openai"
 	"github.com/songquanpeng/one-api/relay/channel/palm"
 	"github.com/songquanpeng/one-api/relay/channel/tencent"
@@ -37,6 +38,8 @@ func GetAdaptor(apiType int) channel.Adaptor {
 		return &xunfei.Adaptor{}
 	case constant.APITypeZhipu:
 		return &zhipu.Adaptor{}
+	case constant.APITypeOllama:
+		return &ollama.Adaptor{}
 	}
 	return nil
 }
diff --git a/web/berry/src/constants/ChannelConstants.js b/web/berry/src/constants/ChannelConstants.js
index 8e9fc97c..c0379381 100644
--- a/web/berry/src/constants/ChannelConstants.js
+++ b/web/berry/src/constants/ChannelConstants.js
@@ -95,6 +95,12 @@ export const CHANNEL_OPTIONS = {
     value: 29,
     color: 'default'
   },
+  30: {
+    key: 30,
+    text: 'Ollama',
+    value: 30,
+    color: 'default'
+  },
   8: {
     key: 8,
     text: '自定义渠道',
diff --git a/web/berry/src/views/Channel/type/Config.js b/web/berry/src/views/Channel/type/Config.js
index 897db189..c42c0253 100644
--- a/web/berry/src/views/Channel/type/Config.js
+++ b/web/berry/src/views/Channel/type/Config.js
@@ -166,6 +166,9 @@ const typeConfig = {
   29: {
     modelGroup: "groq",
   },
+  30: {
+    modelGroup: "ollama",
+  },
 };
 
 export { defaultConfig, typeConfig };
diff --git a/web/default/src/constants/channel.constants.js b/web/default/src/constants/channel.constants.js
index f6db46c3..c8284ef2 100644
--- a/web/default/src/constants/channel.constants.js
+++ b/web/default/src/constants/channel.constants.js
@@ -15,6 +15,7 @@ export const CHANNEL_OPTIONS = [
   { key: 26, text: '百川大模型', value: 26, color: 'orange' },
   { key: 27, text: 'MiniMax', value: 27, color: 'red' },
   { key: 29, text: 'Groq', value: 29, color: 'orange' },
+  { key: 30, text: 'Ollama', value: 30, color: 'black' },
   { key: 8, text: '自定义渠道', value: 8, color: 'pink' },
   { key: 22, text: '知识库:FastGPT', value: 22, color: 'blue' },
   { key: 21, text: '知识库:AI Proxy', value: 21, color: 'purple' },
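End to end, once a channel of the new type 30 points at an Ollama server (default base URL `http://localhost:11434`), existing OpenAI-style clients need no changes. A hedged sketch of a request through one-api's relay; the port, token, and model name are placeholders for a concrete deployment:

```go
package main

import (
	"bytes"
	"fmt"
	"io"
	"net/http"
)

func main() {
	payload := []byte(`{"model":"qwen:0.5b-chat","messages":[{"role":"user","content":"Hello"}]}`)
	req, err := http.NewRequest(http.MethodPost,
		"http://localhost:3000/v1/chat/completions", bytes.NewReader(payload)) // one-api's default port
	if err != nil {
		panic(err)
	}
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("Authorization", "Bearer sk-xxx") // placeholder one-api token
	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()
	body, _ := io.ReadAll(resp.Body)
	fmt.Println(string(body)) // an OpenAI-format chat.completion built from Ollama's reply
}
```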