From d34fa60f797474a948b545d23fd513c6bf2d71ac Mon Sep 17 00:00:00 2001
From: shxgreen
Date: Mon, 4 Dec 2023 14:59:59 +0800
Subject: [PATCH] =?UTF-8?q?=E5=A2=9E=E5=8A=A0Cloudflare=20Workers=20Ai?=
 =?UTF-8?q?=E6=94=AF=E6=8C=81?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 Dockerfile-cn                          |  35 ++++
 common/constants.go                    |   2 +
 common/model-ratio.go                  |   3 +
 controller/channel-test.go             |  44 ++++-
 controller/model.go                    |  27 +++
 controller/relay-cloudflare.go         | 293 +++++++++++++++++++++++++
 controller/relay-text.go               |  58 ++++++
 web/src/constants/channel.constants.js |   3 +-
 web/src/pages/Channel/EditChannel.js   |   3 +
 9 files changed, 464 insertions(+), 4 deletions(-)
 create mode 100644 Dockerfile-cn
 create mode 100644 controller/relay-cloudflare.go

diff --git a/Dockerfile-cn b/Dockerfile-cn
new file mode 100644
index 00000000..3431bdb3
--- /dev/null
+++ b/Dockerfile-cn
@@ -0,0 +1,35 @@
+FROM node:16 as builder
+
+WORKDIR /build
+COPY web/package.json .
+RUN npm config set registry https://registry.npm.taobao.org/ && npm install
+COPY ./web .
+COPY ./VERSION .
+RUN DISABLE_ESLINT_PLUGIN='true' REACT_APP_VERSION=$(cat VERSION) npm run build
+
+FROM golang AS builder2
+
+ENV GO111MODULE=on \
+    CGO_ENABLED=1 \
+    GOOS=linux \
+    GOPROXY=https://goproxy.cn
+
+WORKDIR /build
+ADD go.mod go.sum ./
+RUN go mod download
+COPY . .
+COPY --from=builder /build/build ./web/build
+RUN go build -ldflags "-s -w -X 'one-api/common.Version=$(cat VERSION)' -extldflags '-static'" -o one-api
+
+FROM alpine
+
+RUN sed -i 's/dl-cdn.alpinelinux.org/mirrors.tuna.tsinghua.edu.cn/g' /etc/apk/repositories \
+    && apk update \
+    && apk upgrade \
+    && apk add --no-cache ca-certificates tzdata \
+    && update-ca-certificates 2>/dev/null || true
+
+COPY --from=builder2 /build/one-api /
+EXPOSE 3000
+WORKDIR /data
+ENTRYPOINT ["/one-api"]
diff --git a/common/constants.go b/common/constants.go
index f6860f67..b8d9493a 100644
--- a/common/constants.go
+++ b/common/constants.go
@@ -187,6 +187,7 @@ const (
 	ChannelTypeAIProxyLibrary = 21
 	ChannelTypeFastGPT        = 22
 	ChannelTypeTencent        = 23
+	ChannelTypeCloudflare     = 24
 )
 
 var ChannelBaseURLs = []string{
@@ -214,4 +215,5 @@ var ChannelBaseURLs = []string{
 	"https://api.aiproxy.io",            // 21
 	"https://fastgpt.run/api/openapi",   // 22
 	"https://hunyuan.cloud.tencent.com", //23
+	"https://api.cloudflare.com",        // 24
 }
diff --git a/common/model-ratio.go b/common/model-ratio.go
index ccbc05dd..fbd6c39e 100644
--- a/common/model-ratio.go
+++ b/common/model-ratio.go
@@ -96,6 +96,9 @@ var ModelRatio = map[string]float64{
 	"embedding_s1_v1":           0.0715, // ¥0.001 / 1k tokens
 	"semantic_similarity_s1_v1": 0.0715, // ¥0.001 / 1k tokens
 	"hunyuan":                   7.143,  // ¥0.1 / 1k tokens // https://cloud.tencent.com/document/product/1729/97731#e0e6be58-60c8-469f-bdeb-6c264ce3b4d0
+	"llama-2-7b-chat-fp16":      0.05,   // CF New Add Test Token
+	"llama-2-7b-chat-int8":      0.05,   // CF New Add Test Token
+	"mistral-7b-instruct-v0.1":  0.05,   // CF New Add Test Token
 }
 
 func ModelRatio2JSONString() string {
diff --git a/controller/channel-test.go b/controller/channel-test.go
index bba9a657..c5617ad6 100644
--- a/controller/channel-test.go
+++ b/controller/channel-test.go
@@ -10,6 +10,7 @@ import (
 	"one-api/common"
 	"one-api/model"
 	"strconv"
+	"strings"
 	"sync"
 	"time"
 
@@ -39,12 +40,28 @@ func testChannel(channel *model.Channel, request ChatRequest) (err error, openai
 				err = errors.New("请确保已在 Azure 上创建了 gpt-35-turbo 模型,并且 apiVersion 已正确填写!")
 			}
 		}()
+	case common.ChannelTypeCloudflare:
+		request.Model = "llama-2-7b-chat-fp16"
 	default:
 		request.Model = "gpt-3.5-turbo"
 	}
 	requestURL := common.ChannelBaseURLs[channel.Type]
 	if channel.Type == common.ChannelTypeAzure {
 		requestURL = getFullRequestURL(channel.GetBaseURL(), fmt.Sprintf("/openai/deployments/%s/chat/completions?api-version=2023-03-15-preview", request.Model), channel.Type)
+	} else if channel.Type == common.ChannelTypeCloudflare { // CF New Add
+		apiKey := channel.Key
+		accountID, err := getCloudflareAccountID(apiKey)
+		if err != nil {
+			return err, nil
+		}
+		baseURL := channel.GetBaseURL()
+		if !(strings.HasPrefix(baseURL, "https://gateway.ai.cloudflare.com") && strings.HasSuffix(baseURL, "/workers-ai")) {
+			// Cloudflare Ai Gateway on workers-ai URL: https://gateway.ai.cloudflare.com/v1/[ACCOUNT_ID]/cftest/workers-ai
+			baseURL = fmt.Sprintf("https://api.cloudflare.com/client/v4/accounts/%s/ai/run", accountID)
+		}
+		requestURL = fmt.Sprintf("%s/@cf/meta/llama-2-7b-chat-fp16", baseURL)
+		request.Messages[0].Content = "Hello What's Your Name"
+		request.MaxTokens = 256
 	} else {
 		if baseURL := channel.GetBaseURL(); len(baseURL) > 0 {
 			requestURL = baseURL
@@ -62,6 +79,13 @@ func testChannel(channel *model.Channel, request ChatRequest) (err error, openai
 	}
 	if channel.Type == common.ChannelTypeAzure {
 		req.Header.Set("api-key", channel.Key)
+	} else if channel.Type == common.ChannelTypeCloudflare { // CF New Add
+		apiKey := channel.Key
+		API_Token, err := getCloudflareAPI_Token(apiKey)
+		if err != nil {
+			return err, nil
+		}
+		req.Header.Set("Authorization", "Bearer "+API_Token)
 	} else {
 		req.Header.Set("Authorization", "Bearer "+channel.Key)
 	}
@@ -76,10 +100,24 @@ func testChannel(channel *model.Channel, request ChatRequest) (err error, openai
 	if err != nil {
 		return err, nil
 	}
-	err = json.Unmarshal(body, &response)
-	if err != nil {
-		return fmt.Errorf("Error: %s\nResp body: %s", err, body), nil
+	if channel.Type == common.ChannelTypeCloudflare { // CF New Add
+		var cloudflareResponse CloudflareResponse
+		err = json.Unmarshal(body, &cloudflareResponse)
+		if err != nil {
+			return fmt.Errorf("Error: %s\nResp body: %s", err, body), nil
+		}
+		cloudflareConverResponse := responseCloudflare2OpenAI(&cloudflareResponse)
+		response = TextResponse{
+			Choices: cloudflareConverResponse.Choices,
+			Usage:   cloudflareConverResponse.Usage,
+		}
+	} else {
+		err = json.Unmarshal(body, &response)
+		if err != nil {
+			return fmt.Errorf("Error: %s\nResp body: %s", err, body), nil
+		}
 	}
+
 	if response.Usage.CompletionTokens == 0 {
 		if response.Error.Message == "" {
 			response.Error.Message = "补全 tokens 非预期返回 0"
diff --git a/controller/model.go b/controller/model.go
index 8f79524d..6ad306a0 100644
--- a/controller/model.go
+++ b/controller/model.go
@@ -540,6 +540,33 @@ func init() {
 			Root:       "hunyuan",
 			Parent:     nil,
 		},
+		{
+			Id:         "llama-2-7b-chat-fp16",
+			Object:     "model",
+			Created:    1677649963,
+			OwnedBy:    "cloudflare",
+			Permission: permission,
+			Root:       "cloudflare",
+			Parent:     nil,
+		},
+		{
+			Id:         "llama-2-7b-chat-int8",
+			Object:     "model",
+			Created:    1677649963,
+			OwnedBy:    "cloudflare",
+			Permission: permission,
+			Root:       "cloudflare",
+			Parent:     nil,
+		},
+		{
+			Id:         "mistral-7b-instruct-v0.1",
+			Object:     "model",
+			Created:    1677649963,
+			OwnedBy:    "cloudflare",
+			Permission: permission,
+			Root:       "cloudflare",
+			Parent:     nil,
+		},
 	}
 	openAIModelsMap = make(map[string]OpenAIModels)
 	for _, model := range openAIModels {
diff --git a/controller/relay-cloudflare.go b/controller/relay-cloudflare.go
new file mode 100644
index 00000000..b117ac6e
--- /dev/null
+++ b/controller/relay-cloudflare.go
@@ -0,0 +1,293 @@
+package controller
+
+import (
+	"bufio"
+	"encoding/json"
+	"errors"
+	"fmt"
+	"io"
+	"net/http"
+	"one-api/common"
+	"strings"
+
+	"github.com/gin-gonic/gin"
+)
+
+// CloudflareRequest is the JSON body sent to a Cloudflare Workers AI model.
+type CloudflareRequest struct {
+	Messages  []Message `json:"messages"`
+	Stream    bool      `json:"stream,omitempty"`
+	MaxTokens int       `json:"max_tokens"`
+	Prompt    any       `json:"prompt,omitempty"`
+}
+
+// CloudflareError is one entry of the "errors" array of a Cloudflare response.
+type CloudflareError struct {
+	Code    int    `json:"code"`
+	Message string `json:"message"`
+}
+
+// CloudflareResult holds the generated text of a non-streaming response.
+type CloudflareResult struct {
+	Response string `json:"response"`
+}
+
+// CloudflareResponse is the envelope of a non-streaming Cloudflare response.
+type CloudflareResponse struct {
+	Result   CloudflareResult  `json:"result"`
+	Success  bool              `json:"success"`
+	Errors   []CloudflareError `json:"errors"`
+	Messages []string          `json:"messages"`
+}
+
+// CloudflareStreamResponse is the JSON payload of one "data: " line of a
+// streaming response. The field keeps its original spelling because it is
+// exported.
+type CloudflareStreamResponse struct {
+	Reponse string `json:"response"`
+}
+
+// requestOpenAI2Cloudflare converts an OpenAI-style request into a Cloudflare
+// request by copying its messages, stream flag and prompt.
+//
+// NOTE(review): max_tokens is always sent as -1 and the value requested by the
+// caller is not forwarded; confirm that this is intended.
+func requestOpenAI2Cloudflare(textRequest GeneralOpenAIRequest) *CloudflareRequest {
+	return &CloudflareRequest{
+		Messages:  textRequest.Messages,
+		Stream:    textRequest.Stream,
+		MaxTokens: -1,
+		Prompt:    textRequest.Prompt,
+	}
+}
+
+// newCloudflareStreamChunk builds an OpenAI-style stream chunk that carries
+// content and, when finishReason is non-nil, a finish reason.
+func newCloudflareStreamChunk(content string, finishReason *string) *ChatCompletionsStreamResponse {
+	var choice ChatCompletionsStreamResponseChoice
+	choice.Delta.Content = content
+	choice.FinishReason = finishReason
+	var response ChatCompletionsStreamResponse
+	response.Object = "chat.completion.chunk"
+	response.Model = "cloudflare"
+	response.Choices = []ChatCompletionsStreamResponseChoice{choice}
+	return &response
+}
+
+// streamResponseCloudflare2OpenAI converts one Cloudflare stream payload into
+// an OpenAI-style content chunk. The chunk has no finish reason, because the
+// stream is not finished while content is still arriving;
+// cloudflareStreamHandler sends a separate closing chunk that carries one.
+func streamResponseCloudflare2OpenAI(cloudflareStreamResponse *CloudflareStreamResponse) *ChatCompletionsStreamResponse {
+	return newCloudflareStreamChunk(cloudflareStreamResponse.Reponse, nil)
+}
+
+// responseCloudflare2OpenAI converts a non-streaming Cloudflare response into
+// an OpenAI-style chat completion with a single assistant message.
+//
+// CloudflareResponse carries no token usage, so the Usage filled in here is a
+// placeholder: one prompt token, and the byte length of the answer as
+// completion tokens. cloudflareHandler replaces it with counted values.
+func responseCloudflare2OpenAI(cloudflareResponse *CloudflareResponse) *OpenAITextResponse {
+	content := strings.TrimPrefix(cloudflareResponse.Result.Response, " ")
+	choice := OpenAITextResponseChoice{
+		Index: 0,
+		Message: Message{
+			Role:    "assistant",
+			Content: content,
+			Name:    nil,
+		},
+		FinishReason: stopFinishReason,
+	}
+	promptTokens := 1
+	completionTokens := len(content)
+	return &OpenAITextResponse{
+		Id:      fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
+		Object:  "chat.completion",
+		Created: common.GetTimestamp(),
+		Choices: []OpenAITextResponseChoice{choice},
+		Usage: Usage{
+			PromptTokens:     promptTokens,
+			CompletionTokens: completionTokens,
+			TotalTokens:      promptTokens + completionTokens,
+		},
+	}
+}
+
+// cloudflareStreamHandler relays a Cloudflare event stream to the client as
+// OpenAI-style "chat.completion.chunk" events and returns the concatenated
+// generated text so that the caller can count completion tokens.
+//
+// A reader goroutine scans resp.Body line by line and hands every "data: "
+// line to the gin stream callback. The text is accumulated only in the
+// callback, so every fragment is counted exactly once and responseText is
+// never written by two goroutines. The reader stops at "[DONE]", at the end of
+// the body, or when the request context is done, so it cannot outlive the
+// request.
+func cloudflareStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, string) {
+	responseText := ""
+	responseId := fmt.Sprintf("chatcmpl-%s", common.GetUUID())
+	createdTime := common.GetTimestamp()
+	ctx := c.Request.Context()
+	scanner := bufio.NewScanner(resp.Body)
+	// ScanLines also drops a trailing "\r", which some implementations add.
+	scanner.Split(bufio.ScanLines)
+	dataChan := make(chan string)
+	// Buffered so that the reader can always signal the end and exit, even
+	// when the callback has already returned.
+	stopChan := make(chan bool, 1)
+	go func() {
+		defer func() { stopChan <- true }()
+		for scanner.Scan() {
+			data := scanner.Text()
+			if !strings.HasPrefix(data, "data: ") { // ignore blank line or wrong format
+				continue
+			}
+			select {
+			case dataChan <- data:
+			case <-ctx.Done():
+				return // nobody reads dataChan anymore
+			}
+			if strings.HasPrefix(data, "data: [DONE]") {
+				return // the callback stops here, so must the reader
+			}
+		}
+		if err := scanner.Err(); err != nil && ctx.Err() == nil {
+			common.SysError("error reading stream response: " + err.Error())
+		}
+	}()
+	// render marshals one chunk and sends it to the client as an SSE event.
+	render := func(response *ChatCompletionsStreamResponse) {
+		response.Id = responseId
+		response.Created = createdTime
+		jsonStr, err := json.Marshal(response)
+		if err != nil {
+			common.SysError("error marshalling stream response: " + err.Error())
+			return
+		}
+		c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonStr)})
+	}
+	// finish sends the closing chunk, which carries the finish reason, and
+	// then the terminating [DONE] event.
+	finish := func() {
+		render(newCloudflareStreamChunk("", &stopFinishReason))
+		c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
+	}
+	setEventStreamHeaders(c)
+	c.Stream(func(w io.Writer) bool {
+		select {
+		case data := <-dataChan:
+			data = strings.TrimPrefix(data, "data: ")
+			if strings.HasPrefix(data, "[DONE]") {
+				finish()
+				return false
+			}
+			var cloudflareStreamResponse CloudflareStreamResponse
+			err := json.Unmarshal([]byte(data), &cloudflareStreamResponse)
+			if err != nil {
+				common.SysError("error unmarshalling stream response: " + err.Error())
+				return true // just ignore the error
+			}
+			responseText += cloudflareStreamResponse.Reponse
+			render(streamResponseCloudflare2OpenAI(&cloudflareStreamResponse))
+			return true
+		case <-stopChan:
+			finish()
+			return false
+		case <-ctx.Done():
+			return false // the request was cancelled; nothing more to send
+		}
+	})
+	err := resp.Body.Close()
+	if err != nil {
+		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
+	}
+	return nil, responseText
+}
+
+// cloudflareHandler converts a non-streaming Cloudflare response into an
+// OpenAI-style chat completion, writes it to the client with the upstream
+// status code, and returns the usage for the request.
+//
+// promptTokens is taken over unchanged; completion tokens are counted from
+// the generated text with countTokenText for model. If the response carries
+// errors, the first one is returned as an OpenAI-style error and nothing is
+// written to the client.
+func cloudflareHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*OpenAIErrorWithStatusCode, *Usage) {
+	responseBody, err := io.ReadAll(resp.Body)
+	// Close the body before looking at the read error so it is never leaked.
+	closeErr := resp.Body.Close()
+	if err != nil {
+		return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
+	}
+	if closeErr != nil {
+		return errorWrapper(closeErr, "close_response_body_failed", http.StatusInternalServerError), nil
+	}
+	var cloudflareResponse CloudflareResponse
+	err = json.Unmarshal(responseBody, &cloudflareResponse)
+	if err != nil {
+		return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
+	}
+	if len(cloudflareResponse.Errors) > 0 {
+		return &OpenAIErrorWithStatusCode{
+			OpenAIError: OpenAIError{
+				Message: cloudflareResponse.Errors[0].Message,
+				Type:    "",
+				Param:   "",
+				Code:    cloudflareResponse.Errors[0].Code,
+			},
+			StatusCode: resp.StatusCode,
+		}, nil
+	}
+	fullTextResponse := responseCloudflare2OpenAI(&cloudflareResponse)
+	completionTokens := countTokenText(cloudflareResponse.Result.Response, model)
+	usage := Usage{
+		PromptTokens:     promptTokens,
+		CompletionTokens: completionTokens,
+		TotalTokens:      promptTokens + completionTokens,
+	}
+	fullTextResponse.Usage = usage
+	jsonResponse, err := json.Marshal(fullTextResponse)
+	if err != nil {
+		return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
+	}
+	c.Writer.Header().Set("Content-Type", "application/json")
+	c.Writer.WriteHeader(resp.StatusCode)
+	if _, err = c.Writer.Write(jsonResponse); err != nil {
+		// The status line is already out, so the failure can only be logged.
+		common.SysError("error writing cloudflare response: " + err.Error())
+	}
+	return nil, &usage
+}
+
+// splitCloudflareKey splits a channel key of the form "ACCOUNT_ID|API_TOKEN"
+// into its two parts. A leading "Bearer " is dropped first, so the value of an
+// Authorization header can be passed as well. ok is false unless the key has
+// exactly two non-empty parts.
+func splitCloudflareKey(apiKey string) (accountID, apiToken string, ok bool) {
+	parts := strings.Split(strings.TrimPrefix(apiKey, "Bearer "), "|")
+	if len(parts) != 2 || parts[0] == "" || parts[1] == "" {
+		return "", "", false
+	}
+	return parts[0], parts[1], true
+}
+
+// getCloudflareAccountID returns the account ID part of a channel key of the
+// form "ACCOUNT_ID|API_TOKEN", or an error if the key is malformed.
+func getCloudflareAccountID(apiKey string) (string, error) {
+	accountID, _, ok := splitCloudflareKey(apiKey)
+	if !ok {
+		return "", errors.New("getCloudflareAccountID: Invalid API key format")
+	}
+	return accountID, nil
+}
+
+// getCloudflareAPI_Token returns the API token part of a channel key of the
+// form "ACCOUNT_ID|API_TOKEN", or an error if the key is malformed.
+func getCloudflareAPI_Token(apiKey string) (string, error) {
+	_, apiToken, ok := splitCloudflareKey(apiKey)
+	if !ok {
+		return "", errors.New("getCloudflareAPI_Token: Invalid API key format")
+	}
+	return apiToken, nil
+}
diff --git a/controller/relay-text.go b/controller/relay-text.go
index a3e233d3..8934efe6 100644
--- a/controller/relay-text.go
+++ b/controller/relay-text.go
@@ -27,6 +27,7 @@ const (
 	APITypeXunfei
 	APITypeAIProxyLibrary
 	APITypeTencent
+	APITypeCloudflare
 )
 
 var httpClient *http.Client
@@ -118,7 +119,10 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 		apiType = APITypeAIProxyLibrary
 	case common.ChannelTypeTencent:
 		apiType = APITypeTencent
+	case common.ChannelTypeCloudflare:
+		apiType = APITypeCloudflare
 	}
+
 	baseURL := common.ChannelBaseURLs[channelType]
 	requestURL := c.Request.URL.String()
 	if c.GetString("base_url") != "" {
@@ -192,6 +196,28 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 		fullRequestURL = "https://hunyuan.cloud.tencent.com/hyllm/v1/chat/completions"
 	case APITypeAIProxyLibrary:
 		fullRequestURL = fmt.Sprintf("%s/api/library/ask", baseURL)
+	case APITypeCloudflare:
+		// Cloudflare Workers Ai request URL need input ACCOUNT_ID
+		// https://developers.cloudflare.com/workers-ai/get-started/rest-api/#2-run-a-model-via-api
+		apiKey := c.Request.Header.Get("Authorization")
+		apiKey = strings.TrimPrefix(apiKey, "Bearer ")
+		accountID, err := getCloudflareAccountID(apiKey)
+		if err != nil {
+			return errorWrapper(err, "invalid_cloudflare_config", http.StatusInternalServerError)
+		}
+		baseURL = c.GetString("base_url")
+		if !(strings.HasPrefix(baseURL, "https://gateway.ai.cloudflare.com") && strings.HasSuffix(baseURL, "/workers-ai")) {
+			// Cloudflare Ai Gateway on workers-ai URL: https://gateway.ai.cloudflare.com/v1/[ACCOUNT_ID]/cftest/workers-ai
+			baseURL = fmt.Sprintf("https://api.cloudflare.com/client/v4/accounts/%s/ai/run", accountID)
+		}
+		switch textRequest.Model {
+		case "llama-2-7b-chat-fp16":
+			fullRequestURL = fmt.Sprintf("%s/@cf/meta/llama-2-7b-chat-fp16", baseURL)
+		case "llama-2-7b-chat-int8":
+			fullRequestURL = fmt.Sprintf("%s/@cf/meta/llama-2-7b-chat-int8", baseURL)
+		case "mistral-7b-instruct-v0.1":
+			fullRequestURL = fmt.Sprintf("%s/@cf/mistral/mistral-7b-instruct-v0.1", baseURL)
+		}
 	}
 	var promptTokens int
 	var completionTokens int
@@ -321,6 +347,13 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 			return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
 		}
 		requestBody = bytes.NewBuffer(jsonStr)
+	case APITypeCloudflare:
+		cloudflareRequest := requestOpenAI2Cloudflare(textRequest)
+		jsonStr, err := json.Marshal(cloudflareRequest)
+		if err != nil {
+			return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
+		}
+		requestBody = bytes.NewBuffer(jsonStr)
 	}
 
 	var req *http.Request
@@ -364,6 +397,12 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 			req.Header.Set("Authorization", apiKey)
 		case APITypePaLM:
 			// do not set Authorization header
+		case APITypeCloudflare:
+			API_Token, err := getCloudflareAPI_Token(c.Request.Header.Get("Authorization"))
+			if err != nil {
+				return errorWrapper(err, "cloudflare_token_split_error", http.StatusInternalServerError)
+			}
+			req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", API_Token))
 		default:
 			req.Header.Set("Authorization", "Bearer "+apiKey)
 		}
@@ -635,6 +674,25 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 			}
 			return nil
 		}
+	case APITypeCloudflare:
+		if isStream {
+			err, responseText := cloudflareStreamHandler(c, resp)
+			if err != nil {
+				return err
+			}
+			textResponse.Usage.PromptTokens = promptTokens
+			textResponse.Usage.CompletionTokens = countTokenText(responseText, textRequest.Model)
+			return nil
+		} else {
+			err, usage := cloudflareHandler(c, resp, promptTokens, textRequest.Model)
+			if err != nil {
+				return err
+			}
+			if usage != nil {
+				textResponse.Usage = *usage
+			}
+			return nil
+		}
 	default:
 		return errorWrapper(errors.New("unknown api type"), "unknown_api_type", http.StatusInternalServerError)
 	}
diff --git a/web/src/constants/channel.constants.js b/web/src/constants/channel.constants.js
index 76407745..e90b3a32 100644
--- a/web/src/constants/channel.constants.js
+++ b/web/src/constants/channel.constants.js
@@ -21,5 +21,6 @@ export const CHANNEL_OPTIONS = [
   { key: 6, text: '代理:OpenAI Max', value: 6, color: 'violet' },
   { key: 9, text: '代理:AI.LS', value: 9, color: 'yellow' },
   { key: 12, text: '代理:API2GPT', value: 12, color: 'blue' },
-  { key: 13, text: '代理:AIGC2D', value: 13, color: 'purple' }
+  { key: 13, text: '代理:AIGC2D', value: 13, color: 'purple' },
+  { key: 24, text: 'Cloudflare Ai', value: 24, color: 'blue' }
 ];
\ No newline at end of file
diff --git a/web/src/pages/Channel/EditChannel.js b/web/src/pages/Channel/EditChannel.js
index bc3886a0..cf3ab622 100644
--- a/web/src/pages/Channel/EditChannel.js
+++ b/web/src/pages/Channel/EditChannel.js
@@ -83,6 +83,9 @@ const EditChannel = () => {
         case 23:
           localModels = ['hunyuan'];
           break;
+        case 24:
+          localModels = ['llama-2-7b-chat-fp16','llama-2-7b-chat-int8','mistral-7b-instruct-v0.1'];
+          break;
       }
       setInputs((inputs) => ({ ...inputs, models: localModels }));
     }