diff --git a/README.md b/README.md index af9f44f8..22923110 100644 --- a/README.md +++ b/README.md @@ -60,8 +60,9 @@ _✨ All in one 的 OpenAI 接口,整合各种 API 访问方式,开箱即用 ## 功能 1. 支持多种 API 访问渠道: + [x] OpenAI 官方通道(支持配置镜像) - + [x] [Anthropic Claude 系列模型](https://anthropic.com) + [x] **Azure OpenAI API** + + [x] [Anthropic Claude 系列模型](https://anthropic.com) + + [x] [百度文心一言系列模型](https://cloud.baidu.com/doc/WENXINWORKSHOP/index.html) + [x] [API Distribute](https://api.gptjk.top/register?aff=QGxj) + [x] [OpenAI-SB](https://openai-sb.com) + [x] [API2D](https://api2d.com/r/197971) diff --git a/common/constants.go b/common/constants.go index 50c0bd2f..0ba4a4fd 100644 --- a/common/constants.go +++ b/common/constants.go @@ -152,6 +152,7 @@ const ( ChannelTypeAPI2GPT = 12 ChannelTypeAIGC2D = 13 ChannelTypeAnthropic = 14 + ChannelTypeBaidu = 15 ) var ChannelBaseURLs = []string{ @@ -170,4 +171,5 @@ var ChannelBaseURLs = []string{ "https://api.api2gpt.com", // 12 "https://api.aigc2d.com", // 13 "https://api.anthropic.com", // 14 + "https://aip.baidubce.com", // 15 } diff --git a/common/model-ratio.go b/common/model-ratio.go index 0ba4c397..8f034ec6 100644 --- a/common/model-ratio.go +++ b/common/model-ratio.go @@ -4,6 +4,7 @@ import "encoding/json" // ModelRatio // https://platform.openai.com/docs/models/model-endpoint-compatibility +// https://cloud.baidu.com/doc/WENXINWORKSHOP/s/Blfmc9dlf // https://openai.com/pricing // TODO: when a new api is enabled, check the pricing here // 1 === $0.002 / 1K tokens @@ -38,6 +39,8 @@ var ModelRatio = map[string]float64{ "dall-e": 8, "claude-instant-1": 0.75, "claude-2": 30, + "ERNIE-Bot": 1, // 0.012元/千tokens + "ERNIE-Bot-turbo": 0.67, // 0.008元/千tokens } func ModelRatio2JSONString() string { diff --git a/controller/model.go b/controller/model.go index cc77f621..cfcb8d87 100644 --- a/controller/model.go +++ b/controller/model.go @@ -288,6 +288,24 @@ func init() { Root: "claude-2", Parent: nil, }, + { + Id: "ERNIE-Bot", + Object: "model", + Created: 1677649963, + OwnedBy: "baidu", + Permission: permission, + Root: "ERNIE-Bot", + Parent: nil, + }, + { + Id: "ERNIE-Bot-turbo", + Object: "model", + Created: 1677649963, + OwnedBy: "baidu", + Permission: permission, + Root: "ERNIE-Bot-turbo", + Parent: nil, + }, } openAIModelsMap = make(map[string]OpenAIModels) for _, model := range openAIModels { diff --git a/controller/relay-baidu.go b/controller/relay-baidu.go new file mode 100644 index 00000000..e82f4904 --- /dev/null +++ b/controller/relay-baidu.go @@ -0,0 +1,203 @@ +package controller + +import ( + "bufio" + "encoding/json" + "github.com/gin-gonic/gin" + "io" + "net/http" + "one-api/common" + "strings" +) + +// https://cloud.baidu.com/doc/WENXINWORKSHOP/s/flfmc9do2 + +type BaiduTokenResponse struct { + RefreshToken string `json:"refresh_token"` + ExpiresIn int `json:"expires_in"` + SessionKey string `json:"session_key"` + AccessToken string `json:"access_token"` + Scope string `json:"scope"` + SessionSecret string `json:"session_secret"` +} + +type BaiduMessage struct { + Role string `json:"role"` + Content string `json:"content"` +} + +type BaiduChatRequest struct { + Messages []BaiduMessage `json:"messages"` + Stream bool `json:"stream"` + UserId string `json:"user_id,omitempty"` +} + +type BaiduError struct { + ErrorCode int `json:"error_code"` + ErrorMsg string `json:"error_msg"` +} + +type BaiduChatResponse struct { + Id string `json:"id"` + Object string `json:"object"` + Created int64 `json:"created"` + Result string `json:"result"` + IsTruncated bool `json:"is_truncated"` + NeedClearHistory bool `json:"need_clear_history"` + Usage Usage `json:"usage"` + BaiduError +} + +type BaiduChatStreamResponse struct { + BaiduChatResponse + SentenceId int `json:"sentence_id"` + IsEnd bool `json:"is_end"` +} + +func requestOpenAI2Baidu(request GeneralOpenAIRequest) *BaiduChatRequest { + messages := make([]BaiduMessage, 0, len(request.Messages)) + for _, message := range request.Messages { + messages = append(messages, BaiduMessage{ + Role: message.Role, + Content: message.Content, + }) + } + return &BaiduChatRequest{ + Messages: messages, + Stream: request.Stream, + } +} + +func responseBaidu2OpenAI(response *BaiduChatResponse) *OpenAITextResponse { + choice := OpenAITextResponseChoice{ + Index: 0, + Message: Message{ + Role: "assistant", + Content: response.Result, + }, + FinishReason: "stop", + } + fullTextResponse := OpenAITextResponse{ + Id: response.Id, + Object: "chat.completion", + Created: response.Created, + Choices: []OpenAITextResponseChoice{choice}, + Usage: response.Usage, + } + return &fullTextResponse +} + +func streamResponseBaidu2OpenAI(baiduResponse *BaiduChatStreamResponse) *ChatCompletionsStreamResponse { + var choice ChatCompletionsStreamResponseChoice + choice.Delta.Content = baiduResponse.Result + choice.FinishReason = "stop" + response := ChatCompletionsStreamResponse{ + Id: baiduResponse.Id, + Object: "chat.completion.chunk", + Created: baiduResponse.Created, + Model: "ernie-bot", + Choices: []ChatCompletionsStreamResponseChoice{choice}, + } + return &response +} + +func baiduStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) { + var usage Usage + scanner := bufio.NewScanner(resp.Body) + scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { + if atEOF && len(data) == 0 { + return 0, nil, nil + } + if i := strings.Index(string(data), "\n"); i >= 0 { + return i + 1, data[0:i], nil + } + if atEOF { + return len(data), data, nil + } + return 0, nil, nil + }) + dataChan := make(chan string) + stopChan := make(chan bool) + go func() { + for scanner.Scan() { + data := scanner.Text() + if len(data) < 6 { // ignore blank line or wrong format + continue + } + data = data[6:] + dataChan <- data + } + stopChan <- true + }() + c.Writer.Header().Set("Content-Type", "text/event-stream") + c.Writer.Header().Set("Cache-Control", "no-cache") + c.Writer.Header().Set("Connection", "keep-alive") + c.Writer.Header().Set("Transfer-Encoding", "chunked") + c.Writer.Header().Set("X-Accel-Buffering", "no") + c.Stream(func(w io.Writer) bool { + select { + case data := <-dataChan: + var baiduResponse BaiduChatStreamResponse + err := json.Unmarshal([]byte(data), &baiduResponse) + if err != nil { + common.SysError("error unmarshalling stream response: " + err.Error()) + return true + } + usage.PromptTokens += baiduResponse.Usage.PromptTokens + usage.CompletionTokens += baiduResponse.Usage.CompletionTokens + usage.TotalTokens += baiduResponse.Usage.TotalTokens + response := streamResponseBaidu2OpenAI(&baiduResponse) + jsonResponse, err := json.Marshal(response) + if err != nil { + common.SysError("error marshalling stream response: " + err.Error()) + return true + } + c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)}) + return true + case <-stopChan: + c.Render(-1, common.CustomEvent{Data: "data: [DONE]"}) + return false + } + }) + err := resp.Body.Close() + if err != nil { + return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil + } + return nil, &usage +} + +func baiduHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) { + var baiduResponse BaiduChatResponse + responseBody, err := io.ReadAll(resp.Body) + if err != nil { + return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil + } + err = resp.Body.Close() + if err != nil { + return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil + } + err = json.Unmarshal(responseBody, &baiduResponse) + if err != nil { + return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil + } + if baiduResponse.ErrorMsg != "" { + return &OpenAIErrorWithStatusCode{ + OpenAIError: OpenAIError{ + Message: baiduResponse.ErrorMsg, + Type: "baidu_error", + Param: "", + Code: baiduResponse.ErrorCode, + }, + StatusCode: resp.StatusCode, + }, nil + } + fullTextResponse := responseBaidu2OpenAI(&baiduResponse) + jsonResponse, err := json.Marshal(fullTextResponse) + if err != nil { + return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil + } + c.Writer.Header().Set("Content-Type", "application/json") + c.Writer.WriteHeader(resp.StatusCode) + _, err = c.Writer.Write(jsonResponse) + return nil, &fullTextResponse.Usage +} diff --git a/controller/relay-text.go b/controller/relay-text.go index 723ed59e..0e7893a6 100644 --- a/controller/relay-text.go +++ b/controller/relay-text.go @@ -18,6 +18,7 @@ const ( APITypeOpenAI = iota APITypeClaude APITypePaLM + APITypeBaidu ) func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { @@ -79,6 +80,8 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { apiType := APITypeOpenAI if strings.HasPrefix(textRequest.Model, "claude") { apiType = APITypeClaude + } else if strings.HasPrefix(textRequest.Model, "ERNIE") { + apiType = APITypeBaidu } baseURL := common.ChannelBaseURLs[channelType] requestURL := c.Request.URL.String() @@ -112,6 +115,18 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { if baseURL != "" { fullRequestURL = fmt.Sprintf("%s/v1/complete", baseURL) } + case APITypeBaidu: + switch textRequest.Model { + case "ERNIE-Bot": + fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions" + case "ERNIE-Bot-turbo": + fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant" + case "BLOOMZ-7B": + fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/bloomz_7b1" + } + apiKey := c.Request.Header.Get("Authorization") + apiKey = strings.TrimPrefix(apiKey, "Bearer ") + fullRequestURL += "?access_token=" + apiKey // TODO: access token expire in 30 days } var promptTokens int var completionTokens int @@ -164,6 +179,13 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError) } requestBody = bytes.NewBuffer(jsonStr) + case APITypeBaidu: + baiduRequest := requestOpenAI2Baidu(textRequest) + jsonStr, err := json.Marshal(baiduRequest) + if err != nil { + return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError) + } + requestBody = bytes.NewBuffer(jsonStr) } req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody) if err != nil { @@ -216,7 +238,7 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { if strings.HasPrefix(textRequest.Model, "gpt-4") { completionRatio = 2 } - if isStream { + if isStream && apiType != APITypeBaidu { completionTokens = countTokenText(streamResponseText, textRequest.Model) } else { promptTokens = textResponse.Usage.PromptTokens @@ -285,6 +307,22 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { textResponse.Usage = *usage return nil } + case APITypeBaidu: + if isStream { + err, usage := baiduStreamHandler(c, resp) + if err != nil { + return err + } + textResponse.Usage = *usage + return nil + } else { + err, usage := baiduHandler(c, resp) + if err != nil { + return err + } + textResponse.Usage = *usage + return nil + } default: return errorWrapper(errors.New("unknown api type"), "unknown_api_type", http.StatusInternalServerError) } diff --git a/web/src/constants/channel.constants.js b/web/src/constants/channel.constants.js index db9f5e1e..28d20405 100644 --- a/web/src/constants/channel.constants.js +++ b/web/src/constants/channel.constants.js @@ -3,6 +3,7 @@ export const CHANNEL_OPTIONS = [ { key: 14, text: 'Anthropic', value: 14, color: 'black' }, { key: 8, text: '自定义', value: 8, color: 'pink' }, { key: 3, text: 'Azure', value: 3, color: 'olive' }, + { key: 15, text: 'Baidu', value: 15, color: 'blue' }, { key: 2, text: 'API2D', value: 2, color: 'blue' }, { key: 4, text: 'CloseAI', value: 4, color: 'teal' }, { key: 5, text: 'OpenAI-SB', value: 5, color: 'brown' },