增加Cloudflare Workers Ai支持

This commit is contained in:
shxgreen 2023-12-04 14:59:59 +08:00
parent 01f7b0186f
commit d34fa60f79
9 changed files with 404 additions and 4 deletions

35
Dockerfile-cn Normal file
View File

@ -0,0 +1,35 @@
FROM node:16 as builder
WORKDIR /build
COPY web/package.json .
RUN npm config set registry https://registry.npm.taobao.org/ && npm install
COPY ./web .
COPY ./VERSION .
RUN DISABLE_ESLINT_PLUGIN='true' REACT_APP_VERSION=$(cat VERSION) npm run build
FROM golang AS builder2
ENV GO111MODULE=on \
CGO_ENABLED=1 \
GOOS=linux \
GOPROXY=https://goproxy.cn
WORKDIR /build
ADD go.mod go.sum ./
RUN go mod download
COPY . .
COPY --from=builder /build/build ./web/build
RUN go build -ldflags "-s -w -X 'one-api/common.Version=$(cat VERSION)' -extldflags '-static'" -o one-api
FROM alpine
RUN sed -i 's/dl-cdn.alpinelinux.org/mirrors.tuna.tsinghua.edu.cn/g' /etc/apk/repositories \
&& apk update \
&& apk upgrade \
&& apk add --no-cache ca-certificates tzdata \
&& update-ca-certificates 2>/dev/null || true
COPY --from=builder2 /build/one-api /
EXPOSE 3000
WORKDIR /data
ENTRYPOINT ["/one-api"]

View File

@ -187,6 +187,7 @@ const (
ChannelTypeAIProxyLibrary = 21
ChannelTypeFastGPT = 22
ChannelTypeTencent = 23
ChannelTypeCloudflare = 24
)
var ChannelBaseURLs = []string{
@ -214,4 +215,5 @@ var ChannelBaseURLs = []string{
"https://api.aiproxy.io", // 21
"https://fastgpt.run/api/openapi", // 22
"https://hunyuan.cloud.tencent.com", //23
"https://api.cloudflare.com", // 24
}

View File

@ -96,6 +96,9 @@ var ModelRatio = map[string]float64{
"embedding_s1_v1": 0.0715, // ¥0.001 / 1k tokens
"semantic_similarity_s1_v1": 0.0715, // ¥0.001 / 1k tokens
"hunyuan": 7.143, // ¥0.1 / 1k tokens // https://cloud.tencent.com/document/product/1729/97731#e0e6be58-60c8-469f-bdeb-6c264ce3b4d0
"llama-2-7b-chat-fp16": 0.05, // CF New Add Test Token
"llama-2-7b-chat-int8": 0.05, // CF New Add Test Token
"mistral-7b-instruct-v0.1": 0.05, // CF New Add Test Token
}
func ModelRatio2JSONString() string {

View File

@ -10,6 +10,7 @@ import (
"one-api/common"
"one-api/model"
"strconv"
"strings"
"sync"
"time"
@ -39,12 +40,28 @@ func testChannel(channel *model.Channel, request ChatRequest) (err error, openai
err = errors.New("请确保已在 Azure 上创建了 gpt-35-turbo 模型,并且 apiVersion 已正确填写!")
}
}()
case common.ChannelTypeCloudflare:
request.Model = "llama-2-7b-chat-fp16"
default:
request.Model = "gpt-3.5-turbo"
}
requestURL := common.ChannelBaseURLs[channel.Type]
if channel.Type == common.ChannelTypeAzure {
requestURL = getFullRequestURL(channel.GetBaseURL(), fmt.Sprintf("/openai/deployments/%s/chat/completions?api-version=2023-03-15-preview", request.Model), channel.Type)
} else if channel.Type == common.ChannelTypeCloudflare { // CF New Add
apiKey := channel.Key
accountID, err := getCloudflareAccountID(apiKey)
if err != nil {
return err, nil
}
baseURL := channel.GetBaseURL()
if !(strings.HasPrefix(baseURL, "https://gateway.ai.cloudflare.com") && strings.HasSuffix(baseURL, "/workers-ai")) {
// Cloudflare Ai Gateway on workers-ai URL: https://gateway.ai.cloudflare.com/v1/[ACCOUNT_ID]/cftest/workers-ai
baseURL = fmt.Sprintf("https://api.cloudflare.com/client/v4/accounts/%s/ai/run", accountID)
}
requestURL = fmt.Sprintf("%s/@cf/meta/llama-2-7b-chat-fp16", baseURL)
request.Messages[0].Content = "Hello What's Your Name"
request.MaxTokens = 256
} else {
if baseURL := channel.GetBaseURL(); len(baseURL) > 0 {
requestURL = baseURL
@ -62,6 +79,13 @@ func testChannel(channel *model.Channel, request ChatRequest) (err error, openai
}
if channel.Type == common.ChannelTypeAzure {
req.Header.Set("api-key", channel.Key)
} else if channel.Type == common.ChannelTypeCloudflare { // CF New Add
apiKey := channel.Key
API_Token, err := getCloudflareAPI_Token(apiKey)
if err != nil {
return err, nil
}
req.Header.Set("Authorization", "Bearer "+API_Token)
} else {
req.Header.Set("Authorization", "Bearer "+channel.Key)
}
@ -76,10 +100,24 @@ func testChannel(channel *model.Channel, request ChatRequest) (err error, openai
if err != nil {
return err, nil
}
if channel.Type == common.ChannelTypeCloudflare { // CF New Add
var cloudflareResponse CloudflareResponse
err = json.Unmarshal(body, &cloudflareResponse)
if err != nil {
return fmt.Errorf("Error: %s\nResp body: %s", err, body), nil
}
cloudflareConverResponse := responseCloudflare2OpenAI(&cloudflareResponse)
response = TextResponse{
Choices: cloudflareConverResponse.Choices,
Usage: cloudflareConverResponse.Usage,
}
} else {
err = json.Unmarshal(body, &response)
if err != nil {
return fmt.Errorf("Error: %s\nResp body: %s", err, body), nil
}
}
if response.Usage.CompletionTokens == 0 {
if response.Error.Message == "" {
response.Error.Message = "补全 tokens 非预期返回 0"

View File

@ -540,6 +540,33 @@ func init() {
Root: "hunyuan",
Parent: nil,
},
{
Id: "llama-2-7b-chat-fp16",
Object: "model",
Created: 1677649963,
OwnedBy: "cloudflare",
Permission: permission,
Root: "cloudflare",
Parent: nil,
},
{
Id: "llama-2-7b-chat-int8",
Object: "model",
Created: 1677649963,
OwnedBy: "cloudflare",
Permission: permission,
Root: "cloudflare",
Parent: nil,
},
{
Id: "mistral-7b-instruct-v0.1",
Object: "model",
Created: 1677649963,
OwnedBy: "cloudflare",
Permission: permission,
Root: "cloudflare",
Parent: nil,
},
}
openAIModelsMap = make(map[string]OpenAIModels)
for _, model := range openAIModels {

View File

@ -0,0 +1,233 @@
package controller
import (
"bufio"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"one-api/common"
"strings"
"github.com/gin-gonic/gin"
)
type CloudflareRequest struct {
Messages []Message `json:"messages"`
Stream bool `json:"stream,omitempty"`
MaxTokens int `json:"max_tokens"`
Prompt any `json:"prompt,omitempty"`
}
type CloudflareError struct {
Code int `json:"code"`
Message string `json:"message"`
}
type CloudflareResult struct {
Response string `json:"response"`
}
type CloudflareResponse struct {
Result CloudflareResult `json:"result"`
Success bool `json:"success"`
Errors []CloudflareError `json:"errors"`
Messages []string `json:"messages"`
}
type CloudflareStreamResponse struct {
Reponse string `json:"response"`
}
func requestOpenAI2Cloudflare(textRequest GeneralOpenAIRequest) *CloudflareRequest {
cloudflareRequest := CloudflareRequest{
Messages: textRequest.Messages,
Stream: textRequest.Stream,
MaxTokens: -1,
Prompt: textRequest.Prompt,
}
return &cloudflareRequest
}
func streamResponseCloudflare2OpenAI(cloudflareStreamResponse *CloudflareStreamResponse) *ChatCompletionsStreamResponse {
var choice ChatCompletionsStreamResponseChoice
choice.Delta.Content = cloudflareStreamResponse.Reponse
choice.FinishReason = &stopFinishReason
var response ChatCompletionsStreamResponse
response.Object = "chat.completion.chunk"
response.Model = "cloudflare"
response.Choices = []ChatCompletionsStreamResponseChoice{choice}
return &response
}
func responseCloudflare2OpenAI(cloudflareResponse *CloudflareResponse) *OpenAITextResponse {
choice := OpenAITextResponseChoice{
Index: 0,
Message: Message{
Role: "assistant",
Content: strings.TrimPrefix(cloudflareResponse.Result.Response, " "),
Name: nil,
},
FinishReason: stopFinishReason,
}
PromptTokens := 1
CompletionTokens := len(strings.TrimPrefix(cloudflareResponse.Result.Response, " "))
TotalTokens := CompletionTokens + PromptTokens
fullTextResponse := OpenAITextResponse{
Id: fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
Object: "chat.completion",
Created: common.GetTimestamp(),
Choices: []OpenAITextResponseChoice{choice},
Usage: Usage{
PromptTokens: PromptTokens,
TotalTokens: TotalTokens,
CompletionTokens: CompletionTokens,
},
}
return &fullTextResponse
}
func cloudflareStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, string) {
responseText := ""
responseId := fmt.Sprintf("chatcmpl-%s", common.GetUUID())
createdTime := common.GetTimestamp()
scanner := bufio.NewScanner(resp.Body)
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
if atEOF && len(data) == 0 {
return 0, nil, nil
}
if i := strings.Index(string(data), "\n"); i >= 0 {
return i + 1, data[0:i], nil
}
if atEOF {
return len(data), data, nil
}
return 0, nil, nil
})
dataChan := make(chan string)
stopChan := make(chan bool)
go func() {
for scanner.Scan() {
data := scanner.Text()
if len(data) < 6 { // ignore blank line or wrong format
continue
}
if data[:6] != "data: " && data[:6] != "[DONE]" {
continue
}
dataChan <- data
data = data[6:]
if !strings.HasPrefix(data, "[DONE]") {
var streamResponse CloudflareStreamResponse
err := json.Unmarshal([]byte(data), &streamResponse)
if err != nil {
common.SysError("error unmarshalling stream response: " + err.Error())
continue // just ignore the error
}
responseText += streamResponse.Reponse
}
}
stopChan <- true
}()
setEventStreamHeaders(c)
c.Stream(func(w io.Writer) bool {
select {
case data := <-dataChan:
if strings.HasPrefix(data, "data: [DONE]") {
c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
return false
}
data = data[6:]
// some implementations may add \r at the end of data
data = strings.TrimSuffix(data, "\r")
var cloudflareStreamResponse CloudflareStreamResponse
err := json.Unmarshal([]byte(data), &cloudflareStreamResponse)
if err != nil {
common.SysError("error unmarshalling stream response: " + err.Error())
return true
}
responseText += cloudflareStreamResponse.Reponse
response := streamResponseCloudflare2OpenAI(&cloudflareStreamResponse)
response.Id = responseId
response.Created = createdTime
jsonStr, err := json.Marshal(response)
if err != nil {
common.SysError("error marshalling stream response: " + err.Error())
return true
}
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonStr)})
return true
case <-stopChan:
c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
return false
}
})
err := resp.Body.Close()
if err != nil {
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
}
return nil, responseText
}
func cloudflareHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*OpenAIErrorWithStatusCode, *Usage) {
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
var cloudflareResponse CloudflareResponse
err = json.Unmarshal(responseBody, &cloudflareResponse)
if err != nil {
return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
if len(cloudflareResponse.Errors) > 0 {
return &OpenAIErrorWithStatusCode{
OpenAIError: OpenAIError{
Message: cloudflareResponse.Errors[0].Message,
Type: "",
Param: "",
Code: cloudflareResponse.Errors[0].Code,
},
StatusCode: resp.StatusCode,
}, nil
}
fullTextResponse := responseCloudflare2OpenAI(&cloudflareResponse)
// completionTokens := 0
completionTokens := countTokenText(cloudflareResponse.Result.Response, model)
usage := Usage{
PromptTokens: promptTokens,
CompletionTokens: completionTokens,
TotalTokens: promptTokens + completionTokens,
}
fullTextResponse.Usage = usage
jsonResponse, err := json.Marshal(fullTextResponse)
if err != nil {
return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
_, err = c.Writer.Write(jsonResponse)
return nil, &usage
}
func getCloudflareAccountID(apiKey string) (string, error) {
split := strings.Split(apiKey, "|")
if len(split) != 2 {
return "", errors.New("getCloudflareAccountID: Invalid API key format")
}
return split[0], nil
}
func getCloudflareAPI_Token(apiKey string) (string, error) {
split := strings.Split(apiKey, "|")
if len(split) != 2 {
return "", errors.New("getCloudflareAPI_Token: Invalid API key format")
}
return split[1], nil
}

View File

@ -27,6 +27,7 @@ const (
APITypeXunfei
APITypeAIProxyLibrary
APITypeTencent
APITypeCloudflare
)
var httpClient *http.Client
@ -118,7 +119,10 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
apiType = APITypeAIProxyLibrary
case common.ChannelTypeTencent:
apiType = APITypeTencent
case common.ChannelTypeCloudflare:
apiType = APITypeCloudflare
}
baseURL := common.ChannelBaseURLs[channelType]
requestURL := c.Request.URL.String()
if c.GetString("base_url") != "" {
@ -192,6 +196,28 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
fullRequestURL = "https://hunyuan.cloud.tencent.com/hyllm/v1/chat/completions"
case APITypeAIProxyLibrary:
fullRequestURL = fmt.Sprintf("%s/api/library/ask", baseURL)
case APITypeCloudflare:
// Cloudflare Workers Ai request URL need input ACCOUNT_ID
// https://developers.cloudflare.com/workers-ai/get-started/rest-api/#2-run-a-model-via-api
apiKey := c.Request.Header.Get("Authorization")
apiKey = strings.TrimPrefix(apiKey, "Bearer ")
accountID, err := getCloudflareAccountID(apiKey)
if err != nil {
return errorWrapper(err, "invalid_cloudflare_config", http.StatusInternalServerError)
}
baseURL = c.GetString("base_url")
if !(strings.HasPrefix(baseURL, "https://gateway.ai.cloudflare.com") && strings.HasSuffix(baseURL, "/workers-ai")) {
// Cloudflare Ai Gateway on workers-ai URL: https://gateway.ai.cloudflare.com/v1/[ACCOUNT_ID]/cftest/workers-ai
baseURL = fmt.Sprintf("https://api.cloudflare.com/client/v4/accounts/%s/ai/run", accountID)
}
switch textRequest.Model {
case "llama-2-7b-chat-fp16":
fullRequestURL = fmt.Sprintf("%s/@cf/meta/llama-2-7b-chat-fp16", baseURL)
case "llama-2-7b-chat-int8":
fullRequestURL = fmt.Sprintf("%s/@cf/meta/llama-2-7b-chat-int8", baseURL)
case "mistral-7b-instruct-v0.1":
fullRequestURL = fmt.Sprintf("%s/@cf/mistral/mistral-7b-instruct-v0.1", baseURL)
}
}
var promptTokens int
var completionTokens int
@ -321,6 +347,13 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
}
requestBody = bytes.NewBuffer(jsonStr)
case APITypeCloudflare:
cloudflareRequest := requestOpenAI2Cloudflare(textRequest)
jsonStr, err := json.Marshal(cloudflareRequest)
if err != nil {
return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
}
requestBody = bytes.NewBuffer(jsonStr)
}
var req *http.Request
@ -364,6 +397,12 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
req.Header.Set("Authorization", apiKey)
case APITypePaLM:
// do not set Authorization header
case APITypeCloudflare:
API_Token, err := getCloudflareAPI_Token(c.Request.Header.Get("Authorization"))
if err != nil {
return errorWrapper(err, "cloudflare_token_split_error", http.StatusInternalServerError)
}
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", API_Token))
default:
req.Header.Set("Authorization", "Bearer "+apiKey)
}
@ -635,6 +674,25 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
}
return nil
}
case APITypeCloudflare:
if isStream {
err, responseText := cloudflareStreamHandler(c, resp)
if err != nil {
return err
}
textResponse.Usage.PromptTokens = promptTokens
textResponse.Usage.CompletionTokens = countTokenText(responseText, textRequest.Model)
return nil
} else {
err, usage := cloudflareHandler(c, resp, promptTokens, textRequest.Model)
if err != nil {
return err
}
if usage != nil {
textResponse.Usage = *usage
}
return nil
}
default:
return errorWrapper(errors.New("unknown api type"), "unknown_api_type", http.StatusInternalServerError)
}

View File

@ -21,5 +21,6 @@ export const CHANNEL_OPTIONS = [
{ key: 6, text: '代理OpenAI Max', value: 6, color: 'violet' },
{ key: 9, text: '代理AI.LS', value: 9, color: 'yellow' },
{ key: 12, text: '代理API2GPT', value: 12, color: 'blue' },
{ key: 13, text: '代理AIGC2D', value: 13, color: 'purple' }
{ key: 13, text: '代理AIGC2D', value: 13, color: 'purple' },
{ key: 24, text: 'Cloudflare Ai', value: 24, color: 'blue' }
];

View File

@ -83,6 +83,9 @@ const EditChannel = () => {
case 23:
localModels = ['hunyuan'];
break;
case 24:
localModels = ['llama-2-7b-chat-fp16','llama-2-7b-chat-int8','mistral-7b-instruct-v0.1'];
break;
}
setInputs((inputs) => ({ ...inputs, models: localModels }));
}