feat: support DeepL's model (close #1126)

This commit is contained in:
JustSong 2024-04-27 13:37:22 +08:00
parent e64e7707a0
commit 007906216d
20 changed files with 305 additions and 10 deletions

View File

@ -86,6 +86,7 @@ _✨ 通过标准的 OpenAI API 格式访问所有的大模型,开箱即用
+ [x] [Cohere](https://cohere.com/)
+ [x] [DeepSeek](https://www.deepseek.com/)
+ [x] [Cloudflare Workers AI](https://developers.cloudflare.com/workers-ai/)
+ [x] [DeepL](https://www.deepl.com/)
2. 支持配置镜像以及众多[第三方代理服务](https://iamazing.cn/page/openai-api-third-party-services)。
3. 支持通过**负载均衡**的方式访问多个渠道。
4. 支持 **stream 模式**,可以通过流式传输实现打字机效果。

View File

@ -10,6 +10,7 @@ import (
"github.com/songquanpeng/one-api/relay/adaptor/cloudflare" "github.com/songquanpeng/one-api/relay/adaptor/cloudflare"
"github.com/songquanpeng/one-api/relay/adaptor/cohere" "github.com/songquanpeng/one-api/relay/adaptor/cohere"
"github.com/songquanpeng/one-api/relay/adaptor/coze" "github.com/songquanpeng/one-api/relay/adaptor/coze"
"github.com/songquanpeng/one-api/relay/adaptor/deepl"
"github.com/songquanpeng/one-api/relay/adaptor/gemini" "github.com/songquanpeng/one-api/relay/adaptor/gemini"
"github.com/songquanpeng/one-api/relay/adaptor/ollama" "github.com/songquanpeng/one-api/relay/adaptor/ollama"
"github.com/songquanpeng/one-api/relay/adaptor/openai" "github.com/songquanpeng/one-api/relay/adaptor/openai"
@ -52,6 +53,8 @@ func GetAdaptor(apiType int) adaptor.Adaptor {
return &cohere.Adaptor{} return &cohere.Adaptor{}
case apitype.Cloudflare: case apitype.Cloudflare:
return &cloudflare.Adaptor{} return &cloudflare.Adaptor{}
case apitype.DeepL:
return &deepl.Adaptor{}
} }
return nil return nil
} }

View File

@ -0,0 +1,73 @@
package deepl
import (
	"errors"
	"fmt"
	"io"
	"net/http"
	"unicode/utf8"

	"github.com/gin-gonic/gin"
	"github.com/songquanpeng/one-api/relay/adaptor"
	"github.com/songquanpeng/one-api/relay/meta"
	"github.com/songquanpeng/one-api/relay/model"
)
// Adaptor implements the relay adaptor interface for DeepL's translate API.
type Adaptor struct {
	meta *meta.Meta // relay metadata captured in Init
	// promptText caches the text extracted during ConvertRequest so that
	// DoResponse can compute usage from it.
	promptText string
}
// Init stores the per-request relay metadata for later use by the adaptor.
func (a *Adaptor) Init(meta *meta.Meta) {
	a.meta = meta
}
// GetRequestURL builds the DeepL translate endpoint from the channel's base URL.
func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) {
	endpoint := fmt.Sprintf("%s/v2/translate", meta.BaseURL)
	return endpoint, nil
}
// SetupRequestHeader applies the common relay headers and DeepL's own
// authentication scheme ("DeepL-Auth-Key <key>" rather than a Bearer token).
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error {
	adaptor.SetupCommonRequestHeader(c, req, meta)
	req.Header.Set("Authorization", "DeepL-Auth-Key "+meta.APIKey)
	return nil
}
// ConvertRequest converts an OpenAI-style chat request into a DeepL
// translate request and caches the extracted prompt text for usage
// accounting in DoResponse.
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) {
	if request == nil {
		return nil, errors.New("request is nil")
	}
	deeplRequest, promptText := ConvertRequest(*request)
	a.promptText = promptText
	return deeplRequest, nil
}
// ConvertImageRequest passes the image request through unchanged after a
// nil check.
// NOTE(review): DeepL has no image endpoint; presumably this path is never
// reached for this channel — confirm with the relay dispatcher.
func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {
	if request == nil {
		return nil, errors.New("request is nil")
	}
	return request, nil
}
// DoRequest delegates the upstream HTTP call to the shared adaptor helper.
func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) {
	return adaptor.DoRequestHelper(a, c, meta, requestBody)
}
// DoResponse dispatches the upstream DeepL response to the stream or
// non-stream handler and reports usage.
//
// DeepL bills per character, so usage is approximated by the number of
// characters (runes) in the prompt text. The original used len(), which
// counts bytes and therefore overcharged multi-byte scripts such as
// Chinese or Japanese by up to 3-4x.
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
	if meta.IsStream {
		err = StreamHandler(c, resp, meta.ActualModelName)
	} else {
		err = Handler(c, resp, meta.ActualModelName)
	}
	promptCharacters := utf8.RuneCountInString(a.promptText)
	usage = &model.Usage{
		PromptTokens: promptCharacters,
		TotalTokens:  promptCharacters,
	}
	return
}
// GetModelList returns the pseudo-model names supported by this channel.
func (a *Adaptor) GetModelList() []string {
	return ModelList
}
// GetChannelName returns this adaptor's channel identifier.
func (a *Adaptor) GetChannelName() string {
	return "deepl"
}

View File

@ -0,0 +1,9 @@
package deepl
// ModelList enumerates the pseudo-model names exposed for the DeepL
// channel; the suffix after "deepl-" is the target-language code consumed
// by parseLangFromModelName.
// https://developers.deepl.com/docs/api-reference/glossaries
var ModelList = []string{
	"deepl-zh",
	"deepl-en",
	"deepl-ja",
}

View File

@ -0,0 +1,11 @@
package deepl
import "strings"
// parseLangFromModelName extracts the DeepL target-language code from a
// pseudo-model name such as "deepl-zh" or "deepl-en-us".
//
// Fixes two inconsistencies in the original: the fallback was uppercase
// ("ZH") while parsed codes were returned lowercase, and splitting on "-"
// and taking parts[1] truncated regional codes ("deepl-en-us" -> "en"
// instead of "EN-US"). DeepL treats target_lang case-insensitively, but
// its documented codes are uppercase.
func parseLangFromModelName(modelName string) string {
	_, lang, found := strings.Cut(modelName, "-")
	if !found {
		// No language suffix: default to Chinese, matching ModelList's first entry.
		return "ZH"
	}
	return strings.ToUpper(lang)
}

137
relay/adaptor/deepl/main.go Normal file
View File

@ -0,0 +1,137 @@
package deepl
import (
"encoding/json"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/relay/adaptor/openai"
"github.com/songquanpeng/one-api/relay/constant"
"github.com/songquanpeng/one-api/relay/constant/finishreason"
"github.com/songquanpeng/one-api/relay/constant/role"
"github.com/songquanpeng/one-api/relay/model"
"io"
"net/http"
)
// https://developers.deepl.com/docs/getting-started/your-first-api-request
// ConvertRequest maps an OpenAI chat request onto a DeepL translate
// request: the content of the last message becomes the translation input
// and the target language is derived from the model name. It returns the
// DeepL request together with the extracted prompt text.
func ConvertRequest(textRequest model.GeneralOpenAIRequest) (*Request, string) {
	var text string
	if n := len(textRequest.Messages); n > 0 {
		text = textRequest.Messages[n-1].StringContent()
	}
	return &Request{
		TargetLang: parseLangFromModelName(textRequest.Model),
		Text:       []string{text},
	}, text
}
// StreamResponseDeepL2OpenAI wraps a DeepL response in a single OpenAI
// stream chunk that carries the whole translation at once.
func StreamResponseDeepL2OpenAI(deeplResponse *Response) *openai.ChatCompletionsStreamResponse {
	var choice openai.ChatCompletionsStreamResponseChoice
	choice.Delta.Role = role.Assistant
	choice.FinishReason = &constant.StopFinishReason
	// Only set the content when a translation is present, leaving the
	// delta's zero value untouched otherwise.
	if translations := deeplResponse.Translations; len(translations) > 0 {
		choice.Delta.Content = translations[0].Text
	}
	return &openai.ChatCompletionsStreamResponse{
		Object:  constant.StreamObject,
		Created: helper.GetTimestamp(),
		Choices: []openai.ChatCompletionsStreamResponseChoice{choice},
	}
}
// ResponseDeepL2OpenAI converts a DeepL translate response into an
// OpenAI-style chat completion with a single assistant choice.
func ResponseDeepL2OpenAI(deeplResponse *Response) *openai.TextResponse {
	var translated string
	if len(deeplResponse.Translations) > 0 {
		translated = deeplResponse.Translations[0].Text
	}
	fullTextResponse := openai.TextResponse{
		Object:  constant.NonStreamObject,
		Created: helper.GetTimestamp(),
		Choices: []openai.TextResponseChoice{
			{
				Index: 0,
				Message: model.Message{
					Role:    role.Assistant,
					Content: translated,
					Name:    nil,
				},
				FinishReason: finishreason.Stop,
			},
		},
	}
	return &fullTextResponse
}
// StreamHandler relays a DeepL translation to the client as a minimal
// OpenAI-style SSE stream: one data chunk carrying the whole translation,
// followed by the [DONE] sentinel.
//
// Fixes over the original: the response body was closed twice (once after
// reading and again at the end), and an upstream error message was
// silently streamed as an empty translation; it is now surfaced the same
// way the non-stream Handler surfaces it.
func StreamHandler(c *gin.Context, resp *http.Response, modelName string) *model.ErrorWithStatusCode {
	responseBody, err := io.ReadAll(resp.Body)
	if err != nil {
		return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
	}
	if err = resp.Body.Close(); err != nil {
		return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError)
	}
	var deeplResponse Response
	if err = json.Unmarshal(responseBody, &deeplResponse); err != nil {
		return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError)
	}
	// Surface upstream errors instead of streaming an empty translation,
	// mirroring the non-stream Handler.
	if deeplResponse.Message != "" {
		return &model.ErrorWithStatusCode{
			Error: model.Error{
				Message: deeplResponse.Message,
				Code:    "deepl_error",
			},
			StatusCode: resp.StatusCode,
		}
	}
	fullTextResponse := StreamResponseDeepL2OpenAI(&deeplResponse)
	fullTextResponse.Model = modelName
	fullTextResponse.Id = helper.GetResponseID(c)
	jsonData, err := json.Marshal(fullTextResponse)
	if err != nil {
		return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError)
	}
	common.SetEventStreamHeaders(c)
	// First invocation emits the single content chunk; the second emits
	// [DONE] and ends the stream.
	c.Stream(func(w io.Writer) bool {
		if jsonData != nil {
			c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonData)})
			jsonData = nil
			return true
		}
		c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
		return false
	})
	return nil
}
// Handler converts a DeepL translate response into an OpenAI-style chat
// completion and writes it to the client. An upstream error (non-empty
// Message field) is forwarded with DeepL's status code instead.
//
// Fix over the original: the final Write's error was assigned to err and
// then discarded by "return nil" (an ineffectual assignment flagged by
// staticcheck); the discard is now explicit and documented.
func Handler(c *gin.Context, resp *http.Response, modelName string) *model.ErrorWithStatusCode {
	responseBody, err := io.ReadAll(resp.Body)
	if err != nil {
		return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
	}
	if err = resp.Body.Close(); err != nil {
		return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError)
	}
	var deeplResponse Response
	if err = json.Unmarshal(responseBody, &deeplResponse); err != nil {
		return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError)
	}
	if deeplResponse.Message != "" {
		return &model.ErrorWithStatusCode{
			Error: model.Error{
				Message: deeplResponse.Message,
				Code:    "deepl_error",
			},
			StatusCode: resp.StatusCode,
		}
	}
	fullTextResponse := ResponseDeepL2OpenAI(&deeplResponse)
	fullTextResponse.Model = modelName
	fullTextResponse.Id = helper.GetResponseID(c)
	jsonResponse, err := json.Marshal(fullTextResponse)
	if err != nil {
		return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError)
	}
	c.Writer.Header().Set("Content-Type", "application/json")
	c.Writer.WriteHeader(resp.StatusCode)
	// The status line and headers are already committed, so a write
	// failure here cannot change the response; discard it explicitly.
	_, _ = c.Writer.Write(jsonResponse)
	return nil
}

View File

@ -0,0 +1,16 @@
package deepl
// Request is the payload for DeepL's POST /v2/translate endpoint.
type Request struct {
	// Text holds the strings to translate; this adaptor sends exactly one.
	Text []string `json:"text"`
	// TargetLang is the target-language code derived from the model name.
	TargetLang string `json:"target_lang"`
}

// Translation is a single translated text within a DeepL response.
type Translation struct {
	DetectedSourceLanguage string `json:"detected_source_language,omitempty"`
	Text                   string `json:"text,omitempty"`
}

// Response models DeepL's translate response. Message is populated only
// when DeepL returns an error.
type Response struct {
	Translations []Translation `json:"translations,omitempty"`
	Message      string        `json:"message,omitempty"`
}

View File

@ -206,3 +206,7 @@ func CountTokenText(text string, model string) int {
tokenEncoder := getTokenEncoder(model) tokenEncoder := getTokenEncoder(model)
return getTokenNum(tokenEncoder, text) return getTokenNum(tokenEncoder, text)
} }
// CountToken estimates the token count of text using the gpt-3.5-turbo
// encoder as a model-agnostic fallback.
func CountToken(text string) int {
	return CountTokenInput(text, "gpt-3.5-turbo")
}

View File

@ -16,6 +16,7 @@ const (
Coze Coze
Cohere Cohere
Cloudflare Cloudflare
DeepL
Dummy // this one is only for count, do not add any channel after this Dummy // this one is only for count, do not add any channel after this
) )

View File

@ -173,6 +173,10 @@ var ModelRatio = map[string]float64{
// https://platform.deepseek.com/api-docs/pricing/ // https://platform.deepseek.com/api-docs/pricing/
"deepseek-chat": 1.0 / 1000 * RMB, "deepseek-chat": 1.0 / 1000 * RMB,
"deepseek-coder": 1.0 / 1000 * RMB, "deepseek-coder": 1.0 / 1000 * RMB,
// https://www.deepl.com/pro?cta=header-prices
"deepl-zh": 25.0 / 1000 * USD,
"deepl-en": 25.0 / 1000 * USD,
"deepl-ja": 25.0 / 1000 * USD,
} }
var CompletionRatio = map[string]float64{} var CompletionRatio = map[string]float64{}

View File

@ -39,6 +39,7 @@ const (
Cohere Cohere
DeepSeek DeepSeek
Cloudflare Cloudflare
DeepL
Dummy Dummy
) )

View File

@ -33,6 +33,8 @@ func ToAPIType(channelType int) int {
apiType = apitype.Cohere apiType = apitype.Cohere
case Cloudflare: case Cloudflare:
apiType = apitype.Cloudflare apiType = apitype.Cloudflare
case DeepL:
apiType = apitype.DeepL
} }
return apiType return apiType

View File

@ -39,6 +39,7 @@ var ChannelBaseURLs = []string{
"https://api.cohere.ai", // 35 "https://api.cohere.ai", // 35
"https://api.deepseek.com", // 36 "https://api.deepseek.com", // 36
"https://api.cloudflare.com", // 37 "https://api.cloudflare.com", // 37
"https://api-free.deepl.com", // 38
} }
func init() { func init() {

View File

@ -1,3 +1,5 @@
package constant package constant
var StopFinishReason = "stop" var StopFinishReason = "stop"
var StreamObject = "chat.completion.chunk"
var NonStreamObject = "chat.completion"

View File

@ -0,0 +1,5 @@
// Package finishreason defines OpenAI-compatible finish_reason values.
package finishreason

const (
	// Stop indicates the completion finished normally.
	Stop = "stop"
)

View File

@ -0,0 +1,5 @@
// Package role defines OpenAI-compatible chat message roles.
package role

const (
	// Assistant is the role attached to model-generated messages.
	Assistant = "assistant"
)

View File

@ -18,6 +18,7 @@ import (
"github.com/songquanpeng/one-api/relay/relaymode" "github.com/songquanpeng/one-api/relay/relaymode"
"math" "math"
"net/http" "net/http"
"strings"
) )
func getAndValidateTextRequest(c *gin.Context, relayMode int) (*relaymodel.GeneralOpenAIRequest, error) { func getAndValidateTextRequest(c *gin.Context, relayMode int) (*relaymodel.GeneralOpenAIRequest, error) {
@ -204,3 +205,20 @@ func getMappedModelName(modelName string, mapping map[string]string) (string, bo
} }
return modelName, false return modelName, false
} }
// isErrorHappened reports whether the upstream response must be treated as
// an error before relaying it to the client.
func isErrorHappened(meta *meta.Meta, resp *http.Response) bool {
	switch {
	case resp == nil:
		return true
	case resp.StatusCode != http.StatusOK:
		return true
	case meta.ChannelType == channeltype.DeepL:
		// DeepL replies with a JSON body even for streaming requests, so
		// the content-type heuristic below would misfire; skip it.
		return false
	}
	// A JSON body on a stream request usually means the upstream returned
	// an error object instead of an event stream.
	return meta.IsStream && strings.HasPrefix(resp.Header.Get("Content-Type"), "application/json")
}

View File

@ -4,10 +4,6 @@ import (
"bytes" "bytes"
"encoding/json" "encoding/json"
"fmt" "fmt"
"io"
"net/http"
"strings"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common/logger" "github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/relay" "github.com/songquanpeng/one-api/relay"
@ -18,6 +14,8 @@ import (
"github.com/songquanpeng/one-api/relay/channeltype" "github.com/songquanpeng/one-api/relay/channeltype"
"github.com/songquanpeng/one-api/relay/meta" "github.com/songquanpeng/one-api/relay/meta"
"github.com/songquanpeng/one-api/relay/model" "github.com/songquanpeng/one-api/relay/model"
"io"
"net/http"
) )
func RelayTextHelper(c *gin.Context) *model.ErrorWithStatusCode { func RelayTextHelper(c *gin.Context) *model.ErrorWithStatusCode {
@ -88,12 +86,9 @@ func RelayTextHelper(c *gin.Context) *model.ErrorWithStatusCode {
logger.Errorf(ctx, "DoRequest failed: %s", err.Error()) logger.Errorf(ctx, "DoRequest failed: %s", err.Error())
return openai.ErrorWrapper(err, "do_request_failed", http.StatusInternalServerError) return openai.ErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
} }
if resp != nil { if isErrorHappened(meta, resp) {
errorHappened := (resp.StatusCode != http.StatusOK) || (meta.IsStream && strings.HasPrefix(resp.Header.Get("Content-Type"), "application/json")) billing.ReturnPreConsumedQuota(ctx, preConsumedQuota, meta.TokenId)
if errorHappened { return RelayErrorHandler(resp)
billing.ReturnPreConsumedQuota(ctx, preConsumedQuota, meta.TokenId)
return RelayErrorHandler(resp)
}
} }
// do response // do response

View File

@ -137,6 +137,12 @@ export const CHANNEL_OPTIONS = {
value: 36, value: 36,
color: 'primary' color: 'primary'
}, },
38: {
key: 38,
text: 'DeepL',
value: 38,
color: 'primary'
},
8: { 8: {
key: 8, key: 8,
text: '自定义渠道', text: '自定义渠道',

View File

@ -23,6 +23,7 @@ export const CHANNEL_OPTIONS = [
{key: 35, text: 'Cohere', value: 35, color: 'blue'}, {key: 35, text: 'Cohere', value: 35, color: 'blue'},
{key: 36, text: 'DeepSeek', value: 36, color: 'black'}, {key: 36, text: 'DeepSeek', value: 36, color: 'black'},
{key: 37, text: 'Cloudflare', value: 37, color: 'orange'}, {key: 37, text: 'Cloudflare', value: 37, color: 'orange'},
{key: 38, text: 'DeepL', value: 38, color: 'black'},
{key: 8, text: '自定义渠道', value: 8, color: 'pink'}, {key: 8, text: '自定义渠道', value: 8, color: 'pink'},
{key: 22, text: '知识库FastGPT', value: 22, color: 'blue'}, {key: 22, text: '知识库FastGPT', value: 22, color: 'blue'},
{key: 21, text: '知识库AI Proxy', value: 21, color: 'purple'}, {key: 21, text: '知识库AI Proxy', value: 21, color: 'purple'},