diff --git a/common/model-ratio.go b/common/model-ratio.go index 08cde8c7..e16df566 100644 --- a/common/model-ratio.go +++ b/common/model-ratio.go @@ -30,6 +30,12 @@ var DalleImagePromptLengthLimitations = map[string]int{ "dall-e-3": 4000, } +const ( + USD2RMB = 7 + USD = 500 // $0.002 = 1 -> $1 = 500 + RMB = USD / USD2RMB +) + // ModelRatio // https://platform.openai.com/docs/models/model-endpoint-compatibility // https://cloud.baidu.com/doc/WENXINWORKSHOP/s/Blfmc9dlf @@ -38,57 +44,60 @@ var DalleImagePromptLengthLimitations = map[string]int{ // 1 === $0.002 / 1K tokens // 1 === ¥0.014 / 1k tokens var ModelRatio = map[string]float64{ - "gpt-4": 15, - "gpt-4-0314": 15, - "gpt-4-0613": 15, - "gpt-4-32k": 30, - "gpt-4-32k-0314": 30, - "gpt-4-32k-0613": 30, - "gpt-4-1106-preview": 5, // $0.01 / 1K tokens - "gpt-4-0125-preview": 5, // $0.01 / 1K tokens - "gpt-4-turbo-preview": 5, // $0.01 / 1K tokens - "gpt-4-vision-preview": 5, // $0.01 / 1K tokens - "gpt-3.5-turbo": 0.75, // $0.0015 / 1K tokens - "gpt-3.5-turbo-0301": 0.75, - "gpt-3.5-turbo-0613": 0.75, - "gpt-3.5-turbo-16k": 1.5, // $0.003 / 1K tokens - "gpt-3.5-turbo-16k-0613": 1.5, - "gpt-3.5-turbo-instruct": 0.75, // $0.0015 / 1K tokens - "gpt-3.5-turbo-1106": 0.5, // $0.001 / 1K tokens - "gpt-3.5-turbo-0125": 0.25, // $0.0005 / 1K tokens - "davinci-002": 1, // $0.002 / 1K tokens - "babbage-002": 0.2, // $0.0004 / 1K tokens - "text-ada-001": 0.2, - "text-babbage-001": 0.25, - "text-curie-001": 1, - "text-davinci-002": 10, - "text-davinci-003": 10, - "text-davinci-edit-001": 10, - "code-davinci-edit-001": 10, - "whisper-1": 15, // $0.006 / minute -> $0.006 / 150 words -> $0.006 / 200 tokens -> $0.03 / 1k tokens - "tts-1": 7.5, // $0.015 / 1K characters - "tts-1-1106": 7.5, - "tts-1-hd": 15, // $0.030 / 1K characters - "tts-1-hd-1106": 15, - "davinci": 10, - "curie": 10, - "babbage": 10, - "ada": 10, - "text-embedding-ada-002": 0.05, - "text-embedding-3-small": 0.01, - "text-embedding-3-large": 0.065, - "text-search-ada-doc-001": 10, - "text-moderation-stable": 0.1, - "text-moderation-latest": 0.1, - "dall-e-2": 8, // $0.016 - $0.020 / image - "dall-e-3": 20, // $0.040 - $0.120 / image - "claude-instant-1": 0.815, // $1.63 / 1M tokens - "claude-2": 5.51, // $11.02 / 1M tokens - "claude-2.0": 5.51, // $11.02 / 1M tokens - "claude-2.1": 5.51, // $11.02 / 1M tokens - "ERNIE-Bot": 0.8572, // ¥0.012 / 1k tokens - "ERNIE-Bot-turbo": 0.5715, // ¥0.008 / 1k tokens - "ERNIE-Bot-4": 8.572, // ¥0.12 / 1k tokens + // https://openai.com/pricing + "gpt-4": 15, + "gpt-4-0314": 15, + "gpt-4-0613": 15, + "gpt-4-32k": 30, + "gpt-4-32k-0314": 30, + "gpt-4-32k-0613": 30, + "gpt-4-1106-preview": 5, // $0.01 / 1K tokens + "gpt-4-0125-preview": 5, // $0.01 / 1K tokens + "gpt-4-turbo-preview": 5, // $0.01 / 1K tokens + "gpt-4-vision-preview": 5, // $0.01 / 1K tokens + "gpt-3.5-turbo": 0.75, // $0.0015 / 1K tokens + "gpt-3.5-turbo-0301": 0.75, + "gpt-3.5-turbo-0613": 0.75, + "gpt-3.5-turbo-16k": 1.5, // $0.003 / 1K tokens + "gpt-3.5-turbo-16k-0613": 1.5, + "gpt-3.5-turbo-instruct": 0.75, // $0.0015 / 1K tokens + "gpt-3.5-turbo-1106": 0.5, // $0.001 / 1K tokens + "gpt-3.5-turbo-0125": 0.25, // $0.0005 / 1K tokens + "davinci-002": 1, // $0.002 / 1K tokens + "babbage-002": 0.2, // $0.0004 / 1K tokens + "text-ada-001": 0.2, + "text-babbage-001": 0.25, + "text-curie-001": 1, + "text-davinci-002": 10, + "text-davinci-003": 10, + "text-davinci-edit-001": 10, + "code-davinci-edit-001": 10, + "whisper-1": 15, // $0.006 / minute -> $0.006 / 150 words -> $0.006 / 200 tokens 
-> $0.03 / 1k tokens + "tts-1": 7.5, // $0.015 / 1K characters + "tts-1-1106": 7.5, + "tts-1-hd": 15, // $0.030 / 1K characters + "tts-1-hd-1106": 15, + "davinci": 10, + "curie": 10, + "babbage": 10, + "ada": 10, + "text-embedding-ada-002": 0.05, + "text-embedding-3-small": 0.01, + "text-embedding-3-large": 0.065, + "text-search-ada-doc-001": 10, + "text-moderation-stable": 0.1, + "text-moderation-latest": 0.1, + "dall-e-2": 8, // $0.016 - $0.020 / image + "dall-e-3": 20, // $0.040 - $0.120 / image + "claude-instant-1": 0.815, // $1.63 / 1M tokens + "claude-2": 5.51, // $11.02 / 1M tokens + "claude-2.0": 5.51, // $11.02 / 1M tokens + "claude-2.1": 5.51, // $11.02 / 1M tokens + // https://cloud.baidu.com/doc/WENXINWORKSHOP/s/hlrk4akp7 + "ERNIE-Bot": 0.8572, // ¥0.012 / 1k tokens + "ERNIE-Bot-turbo": 0.5715, // ¥0.008 / 1k tokens + "ERNIE-Bot-4": 0.12 * RMB, // ¥0.12 / 1k tokens + "ERNIE-Bot-8K": 0.024 * RMB, "Embedding-V1": 0.1429, // ¥0.002 / 1k tokens "PaLM-2": 1, "gemini-pro": 1, // $0.00025 / 1k characters -> $0.001 / 1k tokens diff --git a/controller/billing.go b/controller/billing.go index 7bc19b49..7317913d 100644 --- a/controller/billing.go +++ b/controller/billing.go @@ -4,7 +4,7 @@ import ( "github.com/gin-gonic/gin" "github.com/songquanpeng/one-api/common/config" "github.com/songquanpeng/one-api/model" - "github.com/songquanpeng/one-api/relay/channel/openai" + relaymodel "github.com/songquanpeng/one-api/relay/model" ) func GetSubscription(c *gin.Context) { @@ -30,7 +30,7 @@ func GetSubscription(c *gin.Context) { expiredTime = 0 } if err != nil { - Error := openai.Error{ + Error := relaymodel.Error{ Message: err.Error(), Type: "upstream_error", } @@ -72,7 +72,7 @@ func GetUsage(c *gin.Context) { quota, err = model.GetUserUsedQuota(userId) } if err != nil { - Error := openai.Error{ + Error := relaymodel.Error{ Message: err.Error(), Type: "one_api_error", } diff --git a/controller/channel-test.go b/controller/channel-test.go index 9d21b469..8b2bb40d 100644 --- a/controller/channel-test.go +++ b/controller/channel-test.go @@ -9,10 +9,14 @@ import ( "github.com/songquanpeng/one-api/common/config" "github.com/songquanpeng/one-api/common/logger" "github.com/songquanpeng/one-api/model" - "github.com/songquanpeng/one-api/relay/channel/openai" + "github.com/songquanpeng/one-api/relay/constant" + "github.com/songquanpeng/one-api/relay/helper" + relaymodel "github.com/songquanpeng/one-api/relay/model" "github.com/songquanpeng/one-api/relay/util" "io" "net/http" + "net/http/httptest" + "net/url" "strconv" "sync" "time" @@ -20,87 +24,13 @@ import ( "github.com/gin-gonic/gin" ) -func testChannel(channel *model.Channel, request openai.ChatRequest) (err error, openaiErr *openai.Error) { - switch channel.Type { - case common.ChannelTypePaLM: - fallthrough - case common.ChannelTypeGemini: - fallthrough - case common.ChannelTypeAnthropic: - fallthrough - case common.ChannelTypeBaidu: - fallthrough - case common.ChannelTypeZhipu: - fallthrough - case common.ChannelTypeAli: - fallthrough - case common.ChannelType360: - fallthrough - case common.ChannelTypeXunfei: - return errors.New("该渠道类型当前版本不支持测试,请手动测试"), nil - case common.ChannelTypeAzure: - request.Model = "gpt-35-turbo" - defer func() { - if err != nil { - err = errors.New("请确保已在 Azure 上创建了 gpt-35-turbo 模型,并且 apiVersion 已正确填写!") - } - }() - default: - request.Model = "gpt-3.5-turbo" - } - requestURL := common.ChannelBaseURLs[channel.Type] - if channel.Type == common.ChannelTypeAzure { - requestURL = util.GetFullRequestURL(channel.GetBaseURL(),
fmt.Sprintf("/openai/deployments/%s/chat/completions?api-version=2023-03-15-preview", request.Model), channel.Type) - } else { - if baseURL := channel.GetBaseURL(); len(baseURL) > 0 { - requestURL = baseURL - } - - requestURL = util.GetFullRequestURL(requestURL, "/v1/chat/completions", channel.Type) - } - jsonData, err := json.Marshal(request) - if err != nil { - return err, nil - } - req, err := http.NewRequest("POST", requestURL, bytes.NewBuffer(jsonData)) - if err != nil { - return err, nil - } - if channel.Type == common.ChannelTypeAzure { - req.Header.Set("api-key", channel.Key) - } else { - req.Header.Set("Authorization", "Bearer "+channel.Key) - } - req.Header.Set("Content-Type", "application/json") - resp, err := util.HTTPClient.Do(req) - if err != nil { - return err, nil - } - defer resp.Body.Close() - var response openai.SlimTextResponse - body, err := io.ReadAll(resp.Body) - if err != nil { - return err, nil - } - err = json.Unmarshal(body, &response) - if err != nil { - return fmt.Errorf("Error: %s\nResp body: %s", err, body), nil - } - if response.Usage.CompletionTokens == 0 { - if response.Error.Message == "" { - response.Error.Message = "补全 tokens 非预期返回 0" - } - return fmt.Errorf("type %s, code %v, message %s", response.Error.Type, response.Error.Code, response.Error.Message), &response.Error - } - return nil, nil -} - -func buildTestRequest() *openai.ChatRequest { - testRequest := &openai.ChatRequest{ - Model: "", // this will be set later +func buildTestRequest() *relaymodel.GeneralOpenAIRequest { + testRequest := &relaymodel.GeneralOpenAIRequest{ MaxTokens: 1, + Stream: false, + Model: "gpt-3.5-turbo", } - testMessage := openai.Message{ + testMessage := relaymodel.Message{ Role: "user", Content: "hi", } @@ -108,6 +38,64 @@ func buildTestRequest() *openai.ChatRequest { return testRequest } +func testChannel(channel *model.Channel) (err error, openaiErr *relaymodel.Error) { + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = &http.Request{ + Method: "POST", + URL: &url.URL{Path: "/v1/chat/completions"}, + Body: nil, + Header: make(http.Header), + } + c.Request.Header.Set("Authorization", "Bearer "+channel.Key) + c.Request.Header.Set("Content-Type", "application/json") + c.Set("channel", channel.Type) + c.Set("base_url", channel.GetBaseURL()) + meta := util.GetRelayMeta(c) + apiType := constant.ChannelType2APIType(channel.Type) + adaptor := helper.GetAdaptor(apiType) + if adaptor == nil { + return fmt.Errorf("invalid api type: %d, adaptor is nil", apiType), nil + } + modelName := adaptor.GetModelList()[0] + request := buildTestRequest() + request.Model = modelName + meta.OriginModelName, meta.ActualModelName = modelName, modelName + convertedRequest, err := adaptor.ConvertRequest(c, constant.RelayModeChatCompletions, request) + if err != nil { + return err, nil + } + jsonData, err := json.Marshal(convertedRequest) + if err != nil { + return err, nil + } + requestBody := bytes.NewBuffer(jsonData) + c.Request.Body = io.NopCloser(requestBody) + resp, err := adaptor.DoRequest(c, meta, requestBody) + if err != nil { + return err, nil + } + if resp.StatusCode != http.StatusOK { + err := util.RelayErrorHandler(resp) + return fmt.Errorf("status code %d: %s", resp.StatusCode, err.Error.Message), &err.Error + } + usage, respErr := adaptor.DoResponse(c, resp, meta) + if respErr != nil { + return fmt.Errorf("%s", respErr.Error.Message), &respErr.Error + } + if usage == nil { + return errors.New("usage is nil"), nil + } + result := w.Result() + // print 
result.Body + respBody, err := io.ReadAll(result.Body) + if err != nil { + return err, nil + } + logger.SysLog(fmt.Sprintf("testing channel #%d, response: \n%s", channel.Id, string(respBody))) + return nil, nil +} + func TestChannel(c *gin.Context) { id, err := strconv.Atoi(c.Param("id")) if err != nil { @@ -125,9 +113,8 @@ func TestChannel(c *gin.Context) { }) return } - testRequest := buildTestRequest() tik := time.Now() - err, _ = testChannel(channel, *testRequest) + err, _ = testChannel(channel) tok := time.Now() milliseconds := tok.Sub(tik).Milliseconds() go channel.UpdateResponseTime(milliseconds) @@ -192,7 +179,6 @@ func testAllChannels(notify bool) error { if err != nil { return err } - testRequest := buildTestRequest() var disableThreshold = int64(config.ChannelDisableThreshold * 1000) if disableThreshold == 0 { disableThreshold = 10000000 // a impossible value @@ -201,7 +187,7 @@ func testAllChannels(notify bool) error { for _, channel := range channels { isChannelEnabled := channel.Status == common.ChannelStatusEnabled tik := time.Now() - err, openaiErr := testChannel(channel, *testRequest) + err, openaiErr := testChannel(channel) tok := time.Now() milliseconds := tok.Sub(tik).Milliseconds() if isChannelEnabled && milliseconds > disableThreshold { diff --git a/controller/model.go b/controller/model.go index e3e83fcd..6ad21eda 100644 --- a/controller/model.go +++ b/controller/model.go @@ -3,7 +3,9 @@ package controller import ( "fmt" "github.com/gin-gonic/gin" - "github.com/songquanpeng/one-api/relay/channel/openai" + "github.com/songquanpeng/one-api/relay/constant" + "github.com/songquanpeng/one-api/relay/helper" + relaymodel "github.com/songquanpeng/one-api/relay/model" ) // https://platform.openai.com/docs/api-reference/models/list @@ -53,547 +55,24 @@ func init() { IsBlocking: false, }) // https://platform.openai.com/docs/models/model-endpoint-compatibility - openAIModels = []OpenAIModels{ - { - Id: "dall-e-2", - Object: "model", - Created: 1677649963, - OwnedBy: "openai", - Permission: permission, - Root: "dall-e-2", - Parent: nil, - }, - { - Id: "dall-e-3", - Object: "model", - Created: 1677649963, - OwnedBy: "openai", - Permission: permission, - Root: "dall-e-3", - Parent: nil, - }, - { - Id: "whisper-1", - Object: "model", - Created: 1677649963, - OwnedBy: "openai", - Permission: permission, - Root: "whisper-1", - Parent: nil, - }, - { - Id: "tts-1", - Object: "model", - Created: 1677649963, - OwnedBy: "openai", - Permission: permission, - Root: "tts-1", - Parent: nil, - }, - { - Id: "tts-1-1106", - Object: "model", - Created: 1677649963, - OwnedBy: "openai", - Permission: permission, - Root: "tts-1-1106", - Parent: nil, - }, - { - Id: "tts-1-hd", - Object: "model", - Created: 1677649963, - OwnedBy: "openai", - Permission: permission, - Root: "tts-1-hd", - Parent: nil, - }, - { - Id: "tts-1-hd-1106", - Object: "model", - Created: 1677649963, - OwnedBy: "openai", - Permission: permission, - Root: "tts-1-hd-1106", - Parent: nil, - }, - { - Id: "gpt-3.5-turbo", - Object: "model", - Created: 1677649963, - OwnedBy: "openai", - Permission: permission, - Root: "gpt-3.5-turbo", - Parent: nil, - }, - { - Id: "gpt-3.5-turbo-0301", - Object: "model", - Created: 1677649963, - OwnedBy: "openai", - Permission: permission, - Root: "gpt-3.5-turbo-0301", - Parent: nil, - }, - { - Id: "gpt-3.5-turbo-0613", - Object: "model", - Created: 1677649963, - OwnedBy: "openai", - Permission: permission, - Root: "gpt-3.5-turbo-0613", - Parent: nil, - }, - { - Id: "gpt-3.5-turbo-16k", - Object: 
"model", - Created: 1677649963, - OwnedBy: "openai", - Permission: permission, - Root: "gpt-3.5-turbo-16k", - Parent: nil, - }, - { - Id: "gpt-3.5-turbo-16k-0613", - Object: "model", - Created: 1677649963, - OwnedBy: "openai", - Permission: permission, - Root: "gpt-3.5-turbo-16k-0613", - Parent: nil, - }, - { - Id: "gpt-3.5-turbo-1106", - Object: "model", - Created: 1699593571, - OwnedBy: "openai", - Permission: permission, - Root: "gpt-3.5-turbo-1106", - Parent: nil, - }, - { - Id: "gpt-3.5-turbo-0125", - Object: "model", - Created: 1706232090, - OwnedBy: "openai", - Permission: permission, - Root: "gpt-3.5-turbo-0125", - Parent: nil, - }, - { - Id: "gpt-3.5-turbo-instruct", - Object: "model", - Created: 1677649963, - OwnedBy: "openai", - Permission: permission, - Root: "gpt-3.5-turbo-instruct", - Parent: nil, - }, - { - Id: "gpt-4", - Object: "model", - Created: 1677649963, - OwnedBy: "openai", - Permission: permission, - Root: "gpt-4", - Parent: nil, - }, - { - Id: "gpt-4-0314", - Object: "model", - Created: 1677649963, - OwnedBy: "openai", - Permission: permission, - Root: "gpt-4-0314", - Parent: nil, - }, - { - Id: "gpt-4-0613", - Object: "model", - Created: 1677649963, - OwnedBy: "openai", - Permission: permission, - Root: "gpt-4-0613", - Parent: nil, - }, - { - Id: "gpt-4-32k", - Object: "model", - Created: 1677649963, - OwnedBy: "openai", - Permission: permission, - Root: "gpt-4-32k", - Parent: nil, - }, - { - Id: "gpt-4-32k-0314", - Object: "model", - Created: 1677649963, - OwnedBy: "openai", - Permission: permission, - Root: "gpt-4-32k-0314", - Parent: nil, - }, - { - Id: "gpt-4-32k-0613", - Object: "model", - Created: 1677649963, - OwnedBy: "openai", - Permission: permission, - Root: "gpt-4-32k-0613", - Parent: nil, - }, - { - Id: "gpt-4-1106-preview", - Object: "model", - Created: 1699593571, - OwnedBy: "openai", - Permission: permission, - Root: "gpt-4-1106-preview", - Parent: nil, - }, - { - Id: "gpt-4-0125-preview", - Object: "model", - Created: 1706232090, - OwnedBy: "openai", - Permission: permission, - Root: "gpt-4-0125-preview", - Parent: nil, - }, - { - Id: "gpt-4-turbo-preview", - Object: "model", - Created: 1706232090, - OwnedBy: "openai", - Permission: permission, - Root: "gpt-4-turbo-preview", - Parent: nil, - }, - { - Id: "gpt-4-vision-preview", - Object: "model", - Created: 1699593571, - OwnedBy: "openai", - Permission: permission, - Root: "gpt-4-vision-preview", - Parent: nil, - }, - { - Id: "text-embedding-ada-002", - Object: "model", - Created: 1677649963, - OwnedBy: "openai", - Permission: permission, - Root: "text-embedding-ada-002", - Parent: nil, - }, - { - Id: "text-embedding-3-small", - Object: "model", - Created: 1706232090, - OwnedBy: "openai", - Permission: permission, - Root: "text-embedding-3-small", - Parent: nil, - }, - { - Id: "text-embedding-3-large", - Object: "model", - Created: 1706232090, - OwnedBy: "openai", - Permission: permission, - Root: "text-embedding-3-large", - Parent: nil, - }, - { - Id: "text-davinci-003", - Object: "model", - Created: 1677649963, - OwnedBy: "openai", - Permission: permission, - Root: "text-davinci-003", - Parent: nil, - }, - { - Id: "text-davinci-002", - Object: "model", - Created: 1677649963, - OwnedBy: "openai", - Permission: permission, - Root: "text-davinci-002", - Parent: nil, - }, - { - Id: "text-curie-001", - Object: "model", - Created: 1677649963, - OwnedBy: "openai", - Permission: permission, - Root: "text-curie-001", - Parent: nil, - }, - { - Id: "text-babbage-001", - Object: "model", - Created: 
1677649963, - OwnedBy: "openai", - Permission: permission, - Root: "text-babbage-001", - Parent: nil, - }, - { - Id: "text-ada-001", - Object: "model", - Created: 1677649963, - OwnedBy: "openai", - Permission: permission, - Root: "text-ada-001", - Parent: nil, - }, - { - Id: "text-moderation-latest", - Object: "model", - Created: 1677649963, - OwnedBy: "openai", - Permission: permission, - Root: "text-moderation-latest", - Parent: nil, - }, - { - Id: "text-moderation-stable", - Object: "model", - Created: 1677649963, - OwnedBy: "openai", - Permission: permission, - Root: "text-moderation-stable", - Parent: nil, - }, - { - Id: "text-davinci-edit-001", - Object: "model", - Created: 1677649963, - OwnedBy: "openai", - Permission: permission, - Root: "text-davinci-edit-001", - Parent: nil, - }, - { - Id: "code-davinci-edit-001", - Object: "model", - Created: 1677649963, - OwnedBy: "openai", - Permission: permission, - Root: "code-davinci-edit-001", - Parent: nil, - }, - { - Id: "davinci-002", - Object: "model", - Created: 1677649963, - OwnedBy: "openai", - Permission: permission, - Root: "davinci-002", - Parent: nil, - }, - { - Id: "babbage-002", - Object: "model", - Created: 1677649963, - OwnedBy: "openai", - Permission: permission, - Root: "babbage-002", - Parent: nil, - }, - { - Id: "claude-instant-1", - Object: "model", - Created: 1677649963, - OwnedBy: "anthropic", - Permission: permission, - Root: "claude-instant-1", - Parent: nil, - }, - { - Id: "claude-2", - Object: "model", - Created: 1677649963, - OwnedBy: "anthropic", - Permission: permission, - Root: "claude-2", - Parent: nil, - }, - { - Id: "claude-2.1", - Object: "model", - Created: 1677649963, - OwnedBy: "anthropic", - Permission: permission, - Root: "claude-2.1", - Parent: nil, - }, - { - Id: "claude-2.0", - Object: "model", - Created: 1677649963, - OwnedBy: "anthropic", - Permission: permission, - Root: "claude-2.0", - Parent: nil, - }, - { - Id: "ERNIE-Bot", - Object: "model", - Created: 1677649963, - OwnedBy: "baidu", - Permission: permission, - Root: "ERNIE-Bot", - Parent: nil, - }, - { - Id: "ERNIE-Bot-turbo", - Object: "model", - Created: 1677649963, - OwnedBy: "baidu", - Permission: permission, - Root: "ERNIE-Bot-turbo", - Parent: nil, - }, - { - Id: "ERNIE-Bot-4", - Object: "model", - Created: 1677649963, - OwnedBy: "baidu", - Permission: permission, - Root: "ERNIE-Bot-4", - Parent: nil, - }, - { - Id: "Embedding-V1", - Object: "model", - Created: 1677649963, - OwnedBy: "baidu", - Permission: permission, - Root: "Embedding-V1", - Parent: nil, - }, - { - Id: "PaLM-2", - Object: "model", - Created: 1677649963, - OwnedBy: "google palm", - Permission: permission, - Root: "PaLM-2", - Parent: nil, - }, - { - Id: "gemini-pro", - Object: "model", - Created: 1677649963, - OwnedBy: "google gemini", - Permission: permission, - Root: "gemini-pro", - Parent: nil, - }, - { - Id: "gemini-pro-vision", - Object: "model", - Created: 1677649963, - OwnedBy: "google gemini", - Permission: permission, - Root: "gemini-pro-vision", - Parent: nil, - }, - { - Id: "chatglm_turbo", - Object: "model", - Created: 1677649963, - OwnedBy: "zhipu", - Permission: permission, - Root: "chatglm_turbo", - Parent: nil, - }, - { - Id: "chatglm_pro", - Object: "model", - Created: 1677649963, - OwnedBy: "zhipu", - Permission: permission, - Root: "chatglm_pro", - Parent: nil, - }, - { - Id: "chatglm_std", - Object: "model", - Created: 1677649963, - OwnedBy: "zhipu", - Permission: permission, - Root: "chatglm_std", - Parent: nil, - }, - { - Id: "chatglm_lite", - 
Object: "model", - Created: 1677649963, - OwnedBy: "zhipu", - Permission: permission, - Root: "chatglm_lite", - Parent: nil, - }, - { - Id: "qwen-turbo", - Object: "model", - Created: 1677649963, - OwnedBy: "ali", - Permission: permission, - Root: "qwen-turbo", - Parent: nil, - }, - { - Id: "qwen-plus", - Object: "model", - Created: 1677649963, - OwnedBy: "ali", - Permission: permission, - Root: "qwen-plus", - Parent: nil, - }, - { - Id: "qwen-max", - Object: "model", - Created: 1677649963, - OwnedBy: "ali", - Permission: permission, - Root: "qwen-max", - Parent: nil, - }, - { - Id: "qwen-max-longcontext", - Object: "model", - Created: 1677649963, - OwnedBy: "ali", - Permission: permission, - Root: "qwen-max-longcontext", - Parent: nil, - }, - { - Id: "text-embedding-v1", - Object: "model", - Created: 1677649963, - OwnedBy: "ali", - Permission: permission, - Root: "text-embedding-v1", - Parent: nil, - }, - { - Id: "SparkDesk", - Object: "model", - Created: 1677649963, - OwnedBy: "xunfei", - Permission: permission, - Root: "SparkDesk", - Parent: nil, - }, + for i := 0; i < constant.APITypeDummy; i++ { + adaptor := helper.GetAdaptor(i) + channelName := adaptor.GetChannelName() + modelNames := adaptor.GetModelList() + for _, modelName := range modelNames { + openAIModels = append(openAIModels, OpenAIModels{ + Id: modelName, + Object: "model", + Created: 1626777600, + OwnedBy: channelName, + Permission: permission, + Root: modelName, + Parent: nil, + }) + } + } + // extra + openAIModels = append(openAIModels, []OpenAIModels{ { Id: "360GPT_S2_V9", Object: "model", @@ -630,16 +109,7 @@ func init() { Root: "semantic_similarity_s1_v1", Parent: nil, }, - { - Id: "hunyuan", - Object: "model", - Created: 1677649963, - OwnedBy: "tencent", - Permission: permission, - Root: "hunyuan", - Parent: nil, - }, - } + }...) 
openAIModelsMap = make(map[string]OpenAIModels) for _, model := range openAIModels { openAIModelsMap[model.Id] = model @@ -658,7 +128,7 @@ func RetrieveModel(c *gin.Context) { if model, ok := openAIModelsMap[modelId]; ok { c.JSON(200, model) } else { - Error := openai.Error{ + Error := relaymodel.Error{ Message: fmt.Sprintf("The model '%s' does not exist", modelId), Type: "invalid_request_error", Param: "model", diff --git a/controller/relay.go b/controller/relay.go index cfe37984..6c6d268e 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -6,9 +6,9 @@ import ( "github.com/songquanpeng/one-api/common/config" "github.com/songquanpeng/one-api/common/helper" "github.com/songquanpeng/one-api/common/logger" - "github.com/songquanpeng/one-api/relay/channel/openai" "github.com/songquanpeng/one-api/relay/constant" "github.com/songquanpeng/one-api/relay/controller" + "github.com/songquanpeng/one-api/relay/model" "github.com/songquanpeng/one-api/relay/util" "net/http" "strconv" @@ -18,7 +18,7 @@ import ( func Relay(c *gin.Context) { relayMode := constant.Path2RelayMode(c.Request.URL.Path) - var err *openai.ErrorWithStatusCode + var err *model.ErrorWithStatusCode switch relayMode { case constant.RelayModeImagesGenerations: err = controller.RelayImageHelper(c, relayMode) @@ -61,7 +61,7 @@ func Relay(c *gin.Context) { } func RelayNotImplemented(c *gin.Context) { - err := openai.Error{ + err := model.Error{ Message: "API not implemented", Type: "one_api_error", Param: "", @@ -73,7 +73,7 @@ func RelayNotImplemented(c *gin.Context) { } func RelayNotFound(c *gin.Context) { - err := openai.Error{ + err := model.Error{ Message: fmt.Sprintf("Invalid URL (%s %s)", c.Request.Method, c.Request.URL.Path), Type: "invalid_request_error", Param: "", diff --git a/relay/channel/aiproxy/adaptor.go b/relay/channel/aiproxy/adaptor.go index 7e737e8f..eab79c30 100644 --- a/relay/channel/aiproxy/adaptor.go +++ b/relay/channel/aiproxy/adaptor.go @@ -1,22 +1,55 @@ package aiproxy import ( + "errors" + "fmt" "github.com/gin-gonic/gin" - "github.com/songquanpeng/one-api/relay/channel/openai" + "github.com/songquanpeng/one-api/relay/channel" + "github.com/songquanpeng/one-api/relay/model" + "github.com/songquanpeng/one-api/relay/util" + "io" "net/http" ) type Adaptor struct { } -func (a *Adaptor) Auth(c *gin.Context) error { +func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) { + return fmt.Sprintf("%s/api/library/ask", meta.BaseURL), nil +} + +func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *util.RelayMeta) error { + channel.SetupCommonRequestHeader(c, req, meta) + req.Header.Set("Authorization", "Bearer "+meta.APIKey) return nil } -func (a *Adaptor) ConvertRequest(request *openai.GeneralOpenAIRequest) (any, error) { - return nil, nil +func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) { + if request == nil { + return nil, errors.New("request is nil") + } + aiProxyLibraryRequest := ConvertRequest(*request) + aiProxyLibraryRequest.LibraryId = c.GetString("library_id") + return aiProxyLibraryRequest, nil } -func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, *openai.Usage, error) { - return nil, nil, nil +func (a *Adaptor) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error) { + return channel.DoRequestHelper(a, c, meta, requestBody) +} + +func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta 
*util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode) { + if meta.IsStream { + err, usage = StreamHandler(c, resp) + } else { + err, usage = Handler(c, resp) + } + return +} + +func (a *Adaptor) GetModelList() []string { + return ModelList +} + +func (a *Adaptor) GetChannelName() string { + return "aiproxy" } diff --git a/relay/channel/aiproxy/constants.go b/relay/channel/aiproxy/constants.go new file mode 100644 index 00000000..c4df51c4 --- /dev/null +++ b/relay/channel/aiproxy/constants.go @@ -0,0 +1,9 @@ +package aiproxy + +import "github.com/songquanpeng/one-api/relay/channel/openai" + +var ModelList = []string{""} + +func init() { + ModelList = openai.ModelList +} diff --git a/relay/channel/aiproxy/main.go b/relay/channel/aiproxy/main.go index 0bd345c7..0d3d0b60 100644 --- a/relay/channel/aiproxy/main.go +++ b/relay/channel/aiproxy/main.go @@ -10,6 +10,7 @@ import ( "github.com/songquanpeng/one-api/common/logger" "github.com/songquanpeng/one-api/relay/channel/openai" "github.com/songquanpeng/one-api/relay/constant" + "github.com/songquanpeng/one-api/relay/model" "io" "net/http" "strconv" @@ -18,7 +19,7 @@ import ( // https://docs.aiproxy.io/dev/library#使用已经定制好的知识库进行对话问答 -func ConvertRequest(request openai.GeneralOpenAIRequest) *LibraryRequest { +func ConvertRequest(request model.GeneralOpenAIRequest) *LibraryRequest { query := "" if len(request.Messages) != 0 { query = request.Messages[len(request.Messages)-1].StringContent() @@ -45,7 +46,7 @@ func responseAIProxyLibrary2OpenAI(response *LibraryResponse) *openai.TextRespon content := response.Answer + aiProxyDocuments2Markdown(response.Documents) choice := openai.TextResponseChoice{ Index: 0, - Message: openai.Message{ + Message: model.Message{ Role: "assistant", Content: content, }, @@ -85,8 +86,8 @@ func streamResponseAIProxyLibrary2OpenAI(response *LibraryStreamResponse) *opena } } -func StreamHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, *openai.Usage) { - var usage openai.Usage +func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) { + var usage model.Usage scanner := bufio.NewScanner(resp.Body) scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { if atEOF && len(data) == 0 { @@ -157,7 +158,7 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatus return nil, &usage } -func Handler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, *openai.Usage) { +func Handler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) { var AIProxyLibraryResponse LibraryResponse responseBody, err := io.ReadAll(resp.Body) if err != nil { @@ -172,8 +173,8 @@ func Handler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil } if AIProxyLibraryResponse.ErrCode != 0 { - return &openai.ErrorWithStatusCode{ - Error: openai.Error{ + return &model.ErrorWithStatusCode{ + Error: model.Error{ Message: AIProxyLibraryResponse.Message, Type: strconv.Itoa(AIProxyLibraryResponse.ErrCode), Code: AIProxyLibraryResponse.ErrCode, diff --git a/relay/channel/ali/adaptor.go b/relay/channel/ali/adaptor.go index 9470eff0..177aa49e 100644 --- a/relay/channel/ali/adaptor.go +++ b/relay/channel/ali/adaptor.go @@ -1,22 +1,76 @@ package ali import ( + "errors" + "fmt" "github.com/gin-gonic/gin" - "github.com/songquanpeng/one-api/relay/channel/openai" + 
"github.com/songquanpeng/one-api/relay/channel" + "github.com/songquanpeng/one-api/relay/constant" + "github.com/songquanpeng/one-api/relay/model" + "github.com/songquanpeng/one-api/relay/util" + "io" "net/http" ) type Adaptor struct { } -func (a *Adaptor) Auth(c *gin.Context) error { +func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) { + fullRequestURL := fmt.Sprintf("%s/api/v1/services/aigc/text-generation/generation", meta.BaseURL) + if meta.Mode == constant.RelayModeEmbeddings { + fullRequestURL = fmt.Sprintf("%s/api/v1/services/embeddings/text-embedding/text-embedding", meta.BaseURL) + } + return fullRequestURL, nil +} + +func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *util.RelayMeta) error { + channel.SetupCommonRequestHeader(c, req, meta) + req.Header.Set("Authorization", "Bearer "+meta.APIKey) + if meta.IsStream { + req.Header.Set("X-DashScope-SSE", "enable") + } + if c.GetString("plugin") != "" { + req.Header.Set("X-DashScope-Plugin", c.GetString("plugin")) + } return nil } -func (a *Adaptor) ConvertRequest(request *openai.GeneralOpenAIRequest) (any, error) { - return nil, nil +func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) { + if request == nil { + return nil, errors.New("request is nil") + } + switch relayMode { + case constant.RelayModeEmbeddings: + baiduEmbeddingRequest := ConvertEmbeddingRequest(*request) + return baiduEmbeddingRequest, nil + default: + baiduRequest := ConvertRequest(*request) + return baiduRequest, nil + } } -func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, *openai.Usage, error) { - return nil, nil, nil +func (a *Adaptor) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error) { + return channel.DoRequestHelper(a, c, meta, requestBody) +} + +func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode) { + if meta.IsStream { + err, usage = StreamHandler(c, resp) + } else { + switch meta.Mode { + case constant.RelayModeEmbeddings: + err, usage = EmbeddingHandler(c, resp) + default: + err, usage = Handler(c, resp) + } + } + return +} + +func (a *Adaptor) GetModelList() []string { + return ModelList +} + +func (a *Adaptor) GetChannelName() string { + return "ali" } diff --git a/relay/channel/ali/constants.go b/relay/channel/ali/constants.go new file mode 100644 index 00000000..16bcfca4 --- /dev/null +++ b/relay/channel/ali/constants.go @@ -0,0 +1,6 @@ +package ali + +var ModelList = []string{ + "qwen-turbo", "qwen-plus", "qwen-max", "qwen-max-longcontext", + "text-embedding-v1", +} diff --git a/relay/channel/ali/main.go b/relay/channel/ali/main.go index 70476d2e..416093d0 100644 --- a/relay/channel/ali/main.go +++ b/relay/channel/ali/main.go @@ -8,6 +8,7 @@ import ( "github.com/songquanpeng/one-api/common/helper" "github.com/songquanpeng/one-api/common/logger" "github.com/songquanpeng/one-api/relay/channel/openai" + "github.com/songquanpeng/one-api/relay/model" "io" "net/http" "strings" @@ -17,7 +18,7 @@ import ( const EnableSearchModelSuffix = "-internet" -func ConvertRequest(request openai.GeneralOpenAIRequest) *ChatRequest { +func ConvertRequest(request model.GeneralOpenAIRequest) *ChatRequest { messages := make([]Message, 0, len(request.Messages)) for i := 0; i < len(request.Messages); i++ { message := request.Messages[i] @@ -44,7 +45,7 @@ func ConvertRequest(request 
openai.GeneralOpenAIRequest) *ChatRequest { } } -func ConvertEmbeddingRequest(request openai.GeneralOpenAIRequest) *EmbeddingRequest { +func ConvertEmbeddingRequest(request model.GeneralOpenAIRequest) *EmbeddingRequest { return &EmbeddingRequest{ Model: "text-embedding-v1", Input: struct { @@ -55,7 +56,7 @@ func ConvertEmbeddingRequest(request openai.GeneralOpenAIRequest) *EmbeddingRequ } } -func EmbeddingHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, *openai.Usage) { +func EmbeddingHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) { var aliResponse EmbeddingResponse err := json.NewDecoder(resp.Body).Decode(&aliResponse) if err != nil { @@ -68,8 +69,8 @@ func EmbeddingHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithSta } if aliResponse.Code != "" { - return &openai.ErrorWithStatusCode{ - Error: openai.Error{ + return &model.ErrorWithStatusCode{ + Error: model.Error{ Message: aliResponse.Message, Type: aliResponse.Code, Param: aliResponse.RequestId, @@ -95,7 +96,7 @@ func embeddingResponseAli2OpenAI(response *EmbeddingResponse) *openai.EmbeddingR Object: "list", Data: make([]openai.EmbeddingResponseItem, 0, len(response.Output.Embeddings)), Model: "text-embedding-v1", - Usage: openai.Usage{TotalTokens: response.Usage.TotalTokens}, + Usage: model.Usage{TotalTokens: response.Usage.TotalTokens}, } for _, item := range response.Output.Embeddings { @@ -111,7 +112,7 @@ func embeddingResponseAli2OpenAI(response *EmbeddingResponse) *openai.EmbeddingR func responseAli2OpenAI(response *ChatResponse) *openai.TextResponse { choice := openai.TextResponseChoice{ Index: 0, - Message: openai.Message{ + Message: model.Message{ Role: "assistant", Content: response.Output.Text, }, @@ -122,7 +123,7 @@ func responseAli2OpenAI(response *ChatResponse) *openai.TextResponse { Object: "chat.completion", Created: helper.GetTimestamp(), Choices: []openai.TextResponseChoice{choice}, - Usage: openai.Usage{ + Usage: model.Usage{ PromptTokens: response.Usage.InputTokens, CompletionTokens: response.Usage.OutputTokens, TotalTokens: response.Usage.InputTokens + response.Usage.OutputTokens, @@ -148,8 +149,8 @@ func streamResponseAli2OpenAI(aliResponse *ChatResponse) *openai.ChatCompletions return &response } -func StreamHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, *openai.Usage) { - var usage openai.Usage +func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) { + var usage model.Usage scanner := bufio.NewScanner(resp.Body) scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { if atEOF && len(data) == 0 { @@ -217,7 +218,7 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatus return nil, &usage } -func Handler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, *openai.Usage) { +func Handler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) { var aliResponse ChatResponse responseBody, err := io.ReadAll(resp.Body) if err != nil { @@ -232,8 +233,8 @@ func Handler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil } if aliResponse.Code != "" { - return &openai.ErrorWithStatusCode{ - Error: openai.Error{ + return &model.ErrorWithStatusCode{ + Error: model.Error{ Message: aliResponse.Message, Type: aliResponse.Code, Param: aliResponse.RequestId, diff --git 
a/relay/channel/anthropic/adaptor.go b/relay/channel/anthropic/adaptor.go index ee02fab5..b95516a8 100644 --- a/relay/channel/anthropic/adaptor.go +++ b/relay/channel/anthropic/adaptor.go @@ -1,22 +1,61 @@ package anthropic import ( + "errors" + "fmt" "github.com/gin-gonic/gin" + "github.com/songquanpeng/one-api/relay/channel" "github.com/songquanpeng/one-api/relay/channel/openai" + "github.com/songquanpeng/one-api/relay/model" + "github.com/songquanpeng/one-api/relay/util" + "io" "net/http" ) type Adaptor struct { } -func (a *Adaptor) Auth(c *gin.Context) error { +func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) { + return fmt.Sprintf("%s/v1/complete", meta.BaseURL), nil +} + +func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *util.RelayMeta) error { + channel.SetupCommonRequestHeader(c, req, meta) + req.Header.Set("x-api-key", meta.APIKey) + anthropicVersion := c.Request.Header.Get("anthropic-version") + if anthropicVersion == "" { + anthropicVersion = "2023-06-01" + } + req.Header.Set("anthropic-version", anthropicVersion) return nil } -func (a *Adaptor) ConvertRequest(request *openai.GeneralOpenAIRequest) (any, error) { - return nil, nil +func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) { + if request == nil { + return nil, errors.New("request is nil") + } + return ConvertRequest(*request), nil } -func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, *openai.Usage, error) { - return nil, nil, nil +func (a *Adaptor) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error) { + return channel.DoRequestHelper(a, c, meta, requestBody) +} + +func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode) { + if meta.IsStream { + var responseText string + err, responseText = StreamHandler(c, resp) + usage = openai.ResponseText2Usage(responseText, meta.ActualModelName, meta.PromptTokens) + } else { + err, usage = Handler(c, resp, meta.PromptTokens, meta.ActualModelName) + } + return +} + +func (a *Adaptor) GetModelList() []string { + return ModelList +} + +func (a *Adaptor) GetChannelName() string { + return "anthropic" } diff --git a/relay/channel/anthropic/constants.go b/relay/channel/anthropic/constants.go new file mode 100644 index 00000000..b98c15c2 --- /dev/null +++ b/relay/channel/anthropic/constants.go @@ -0,0 +1,5 @@ +package anthropic + +var ModelList = []string{ + "claude-instant-1", "claude-2", "claude-2.0", "claude-2.1", +} diff --git a/relay/channel/anthropic/main.go b/relay/channel/anthropic/main.go index c1e39494..e2c575fa 100644 --- a/relay/channel/anthropic/main.go +++ b/relay/channel/anthropic/main.go @@ -9,6 +9,7 @@ import ( "github.com/songquanpeng/one-api/common/helper" "github.com/songquanpeng/one-api/common/logger" "github.com/songquanpeng/one-api/relay/channel/openai" + "github.com/songquanpeng/one-api/relay/model" "io" "net/http" "strings" @@ -25,7 +26,7 @@ func stopReasonClaude2OpenAI(reason string) string { } } -func ConvertRequest(textRequest openai.GeneralOpenAIRequest) *Request { +func ConvertRequest(textRequest model.GeneralOpenAIRequest) *Request { claudeRequest := Request{ Model: textRequest.Model, Prompt: "", @@ -72,7 +73,7 @@ func streamResponseClaude2OpenAI(claudeResponse *Response) *openai.ChatCompletio func responseClaude2OpenAI(claudeResponse *Response) *openai.TextResponse { choice :=
openai.TextResponseChoice{ Index: 0, - Message: openai.Message{ + Message: model.Message{ Role: "assistant", Content: strings.TrimPrefix(claudeResponse.Completion, " "), Name: nil, @@ -88,7 +89,7 @@ func responseClaude2OpenAI(claudeResponse *Response) *openai.TextResponse { return &fullTextResponse } -func StreamHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, string) { +func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, string) { responseText := "" responseId := fmt.Sprintf("chatcmpl-%s", helper.GetUUID()) createdTime := helper.GetTimestamp() @@ -153,7 +154,7 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatus return nil, responseText } -func Handler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*openai.ErrorWithStatusCode, *openai.Usage) { +func Handler(c *gin.Context, resp *http.Response, promptTokens int, modelName string) (*model.ErrorWithStatusCode, *model.Usage) { responseBody, err := io.ReadAll(resp.Body) if err != nil { return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil @@ -168,8 +169,8 @@ func Handler(c *gin.Context, resp *http.Response, promptTokens int, model string return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil } if claudeResponse.Error.Type != "" { - return &openai.ErrorWithStatusCode{ - Error: openai.Error{ + return &model.ErrorWithStatusCode{ + Error: model.Error{ Message: claudeResponse.Error.Message, Type: claudeResponse.Error.Type, Param: "", @@ -179,9 +180,9 @@ func Handler(c *gin.Context, resp *http.Response, promptTokens int, model string }, nil } fullTextResponse := responseClaude2OpenAI(&claudeResponse) - fullTextResponse.Model = model - completionTokens := openai.CountTokenText(claudeResponse.Completion, model) - usage := openai.Usage{ + fullTextResponse.Model = modelName + completionTokens := openai.CountTokenText(claudeResponse.Completion, modelName) + usage := model.Usage{ PromptTokens: promptTokens, CompletionTokens: completionTokens, TotalTokens: promptTokens + completionTokens, diff --git a/relay/channel/baidu/adaptor.go b/relay/channel/baidu/adaptor.go index c6304a74..afaf2393 100644 --- a/relay/channel/baidu/adaptor.go +++ b/relay/channel/baidu/adaptor.go @@ -1,22 +1,89 @@ package baidu import ( + "errors" "github.com/gin-gonic/gin" - "github.com/songquanpeng/one-api/relay/channel/openai" + "github.com/songquanpeng/one-api/relay/channel" + "github.com/songquanpeng/one-api/relay/constant" + "github.com/songquanpeng/one-api/relay/model" + "github.com/songquanpeng/one-api/relay/util" + "io" "net/http" ) type Adaptor struct { } -func (a *Adaptor) Auth(c *gin.Context) error { +func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) { + // https://cloud.baidu.com/doc/WENXINWORKSHOP/s/clntwmv7t + var fullRequestURL string + switch meta.ActualModelName { + case "ERNIE-Bot-4": + fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro" + case "ERNIE-Bot-8K": + fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie_bot_8k" + case "ERNIE-Bot": + fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions" + case "ERNIE-Speed": + fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie_speed" + case "ERNIE-Bot-turbo": + fullRequestURL = 
"https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant" + case "BLOOMZ-7B": + fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/bloomz_7b1" + case "Embedding-V1": + fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings/embedding-v1" + } + var accessToken string + var err error + if accessToken, err = GetAccessToken(meta.APIKey); err != nil { + return "", err + } + fullRequestURL += "?access_token=" + accessToken + return fullRequestURL, nil +} + +func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *util.RelayMeta) error { + channel.SetupCommonRequestHeader(c, req, meta) + req.Header.Set("Authorization", "Bearer "+meta.APIKey) return nil } -func (a *Adaptor) ConvertRequest(request *openai.GeneralOpenAIRequest) (any, error) { - return nil, nil +func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) { + if request == nil { + return nil, errors.New("request is nil") + } + switch relayMode { + case constant.RelayModeEmbeddings: + baiduEmbeddingRequest := ConvertEmbeddingRequest(*request) + return baiduEmbeddingRequest, nil + default: + baiduRequest := ConvertRequest(*request) + return baiduRequest, nil + } } -func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, *openai.Usage, error) { - return nil, nil, nil +func (a *Adaptor) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error) { + return channel.DoRequestHelper(a, c, meta, requestBody) +} + +func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode) { + if meta.IsStream { + err, usage = StreamHandler(c, resp) + } else { + switch meta.Mode { + case constant.RelayModeEmbeddings: + err, usage = EmbeddingHandler(c, resp) + default: + err, usage = Handler(c, resp) + } + } + return +} + +func (a *Adaptor) GetModelList() []string { + return ModelList +} + +func (a *Adaptor) GetChannelName() string { + return "baidu" } diff --git a/relay/channel/baidu/constants.go b/relay/channel/baidu/constants.go new file mode 100644 index 00000000..0fa8f2d6 --- /dev/null +++ b/relay/channel/baidu/constants.go @@ -0,0 +1,10 @@ +package baidu + +var ModelList = []string{ + "ERNIE-Bot-4", + "ERNIE-Bot-8K", + "ERNIE-Bot", + "ERNIE-Speed", + "ERNIE-Bot-turbo", + "Embedding-V1", +} diff --git a/relay/channel/baidu/main.go b/relay/channel/baidu/main.go index 00391602..4f2b13fc 100644 --- a/relay/channel/baidu/main.go +++ b/relay/channel/baidu/main.go @@ -10,6 +10,7 @@ import ( "github.com/songquanpeng/one-api/common/logger" "github.com/songquanpeng/one-api/relay/channel/openai" "github.com/songquanpeng/one-api/relay/constant" + "github.com/songquanpeng/one-api/relay/model" "github.com/songquanpeng/one-api/relay/util" "io" "net/http" @@ -43,7 +44,7 @@ type Error struct { var baiduTokenStore sync.Map -func ConvertRequest(request openai.GeneralOpenAIRequest) *ChatRequest { +func ConvertRequest(request model.GeneralOpenAIRequest) *ChatRequest { messages := make([]Message, 0, len(request.Messages)) for _, message := range request.Messages { if message.Role == "system" { @@ -71,7 +72,7 @@ func ConvertRequest(request openai.GeneralOpenAIRequest) *ChatRequest { func responseBaidu2OpenAI(response *ChatResponse) *openai.TextResponse { choice := openai.TextResponseChoice{ Index: 0, - Message: openai.Message{ + Message: model.Message{ Role: 
"assistant", Content: response.Result, }, @@ -103,7 +104,7 @@ func streamResponseBaidu2OpenAI(baiduResponse *ChatStreamResponse) *openai.ChatC return &response } -func ConvertEmbeddingRequest(request openai.GeneralOpenAIRequest) *EmbeddingRequest { +func ConvertEmbeddingRequest(request model.GeneralOpenAIRequest) *EmbeddingRequest { return &EmbeddingRequest{ Input: request.ParseInput(), } @@ -126,8 +127,8 @@ func embeddingResponseBaidu2OpenAI(response *EmbeddingResponse) *openai.Embeddin return &openAIEmbeddingResponse } -func StreamHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, *openai.Usage) { - var usage openai.Usage +func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) { + var usage model.Usage scanner := bufio.NewScanner(resp.Body) scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { if atEOF && len(data) == 0 { @@ -189,7 +190,7 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatus return nil, &usage } -func Handler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, *openai.Usage) { +func Handler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) { var baiduResponse ChatResponse responseBody, err := io.ReadAll(resp.Body) if err != nil { @@ -204,8 +205,8 @@ func Handler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil } if baiduResponse.ErrorMsg != "" { - return &openai.ErrorWithStatusCode{ - Error: openai.Error{ + return &model.ErrorWithStatusCode{ + Error: model.Error{ Message: baiduResponse.ErrorMsg, Type: "baidu_error", Param: "", @@ -226,7 +227,7 @@ func Handler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, return nil, &fullTextResponse.Usage } -func EmbeddingHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, *openai.Usage) { +func EmbeddingHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) { var baiduResponse EmbeddingResponse responseBody, err := io.ReadAll(resp.Body) if err != nil { @@ -241,8 +242,8 @@ func EmbeddingHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithSta return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil } if baiduResponse.ErrorMsg != "" { - return &openai.ErrorWithStatusCode{ - Error: openai.Error{ + return &model.ErrorWithStatusCode{ + Error: model.Error{ Message: baiduResponse.ErrorMsg, Type: "baidu_error", Param: "", diff --git a/relay/channel/baidu/model.go b/relay/channel/baidu/model.go index 524418e1..cc1feb2f 100644 --- a/relay/channel/baidu/model.go +++ b/relay/channel/baidu/model.go @@ -1,18 +1,18 @@ package baidu import ( - "github.com/songquanpeng/one-api/relay/channel/openai" + "github.com/songquanpeng/one-api/relay/model" "time" ) type ChatResponse struct { - Id string `json:"id"` - Object string `json:"object"` - Created int64 `json:"created"` - Result string `json:"result"` - IsTruncated bool `json:"is_truncated"` - NeedClearHistory bool `json:"need_clear_history"` - Usage openai.Usage `json:"usage"` + Id string `json:"id"` + Object string `json:"object"` + Created int64 `json:"created"` + Result string `json:"result"` + IsTruncated bool `json:"is_truncated"` + NeedClearHistory bool `json:"need_clear_history"` + Usage model.Usage `json:"usage"` Error } @@ -37,7 +37,7 @@ type EmbeddingResponse struct { 
Object string `json:"object"` Created int64 `json:"created"` Data []EmbeddingData `json:"data"` - Usage openai.Usage `json:"usage"` + Usage model.Usage `json:"usage"` Error } diff --git a/relay/channel/common.go b/relay/channel/common.go new file mode 100644 index 00000000..c6e1abf2 --- /dev/null +++ b/relay/channel/common.go @@ -0,0 +1,51 @@ +package channel + +import ( + "errors" + "fmt" + "github.com/gin-gonic/gin" + "github.com/songquanpeng/one-api/relay/util" + "io" + "net/http" +) + +func SetupCommonRequestHeader(c *gin.Context, req *http.Request, meta *util.RelayMeta) { + req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type")) + req.Header.Set("Accept", c.Request.Header.Get("Accept")) + if meta.IsStream && c.Request.Header.Get("Accept") == "" { + req.Header.Set("Accept", "text/event-stream") + } +} + +func DoRequestHelper(a Adaptor, c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error) { + fullRequestURL, err := a.GetRequestURL(meta) + if err != nil { + return nil, fmt.Errorf("get request url failed: %w", err) + } + req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody) + if err != nil { + return nil, fmt.Errorf("new request failed: %w", err) + } + err = a.SetupRequestHeader(c, req, meta) + if err != nil { + return nil, fmt.Errorf("setup request header failed: %w", err) + } + resp, err := DoRequest(c, req) + if err != nil { + return nil, fmt.Errorf("do request failed: %w", err) + } + return resp, nil +} + +func DoRequest(c *gin.Context, req *http.Request) (*http.Response, error) { + resp, err := util.HTTPClient.Do(req) + if err != nil { + return nil, err + } + if resp == nil { + return nil, errors.New("resp is nil") + } + _ = req.Body.Close() + _ = c.Request.Body.Close() + return resp, nil +} diff --git a/relay/channel/gemini/adaptor.go b/relay/channel/gemini/adaptor.go new file mode 100644 index 00000000..ef2935fd --- /dev/null +++ b/relay/channel/gemini/adaptor.go @@ -0,0 +1,62 @@ +package gemini + +import ( + "errors" + "fmt" + "github.com/gin-gonic/gin" + "github.com/songquanpeng/one-api/common/helper" + channelhelper "github.com/songquanpeng/one-api/relay/channel" + "github.com/songquanpeng/one-api/relay/channel/openai" + "github.com/songquanpeng/one-api/relay/model" + "github.com/songquanpeng/one-api/relay/util" + "io" + "net/http" +) + +type Adaptor struct { +} + +func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) { + version := helper.AssignOrDefault(meta.APIVersion, "v1") + action := "generateContent" + if meta.IsStream { + action = "streamGenerateContent" + } + return fmt.Sprintf("%s/%s/models/%s:%s", meta.BaseURL, version, meta.ActualModelName, action), nil +} + +func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *util.RelayMeta) error { + channelhelper.SetupCommonRequestHeader(c, req, meta) + req.Header.Set("x-goog-api-key", meta.APIKey) + return nil +} + +func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) { + if request == nil { + return nil, errors.New("request is nil") + } + return ConvertRequest(*request), nil +} + +func (a *Adaptor) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error) { + return channelhelper.DoRequestHelper(a, c, meta, requestBody) +} + +func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode) { + if meta.IsStream { + var responseText string + err, 
responseText = StreamHandler(c, resp) + usage = openai.ResponseText2Usage(responseText, meta.ActualModelName, meta.PromptTokens) + } else { + err, usage = Handler(c, resp, meta.PromptTokens, meta.ActualModelName) + } + return +} + +func (a *Adaptor) GetModelList() []string { + return ModelList +} + +func (a *Adaptor) GetChannelName() string { + return "google gemini" +} diff --git a/relay/channel/gemini/constants.go b/relay/channel/gemini/constants.go new file mode 100644 index 00000000..5bb0c168 --- /dev/null +++ b/relay/channel/gemini/constants.go @@ -0,0 +1,6 @@ +package gemini + +var ModelList = []string{ + "gemini-pro", + "gemini-pro-vision", +} diff --git a/relay/channel/google/gemini.go b/relay/channel/gemini/main.go similarity index 77% rename from relay/channel/google/gemini.go rename to relay/channel/gemini/main.go index 13e6a4e8..c24694c8 100644 --- a/relay/channel/google/gemini.go +++ b/relay/channel/gemini/main.go @@ -1,4 +1,4 @@ -package google +package gemini import ( "bufio" @@ -11,6 +11,7 @@ import ( "github.com/songquanpeng/one-api/common/logger" "github.com/songquanpeng/one-api/relay/channel/openai" "github.com/songquanpeng/one-api/relay/constant" + "github.com/songquanpeng/one-api/relay/model" "io" "net/http" "strings" @@ -21,14 +22,14 @@ import ( // https://ai.google.dev/docs/gemini_api_overview?hl=zh-cn const ( - GeminiVisionMaxImageNum = 16 + VisionMaxImageNum = 16 ) // Setting safety to the lowest possible values since Gemini is already powerless enough -func ConvertGeminiRequest(textRequest openai.GeneralOpenAIRequest) *GeminiChatRequest { - geminiRequest := GeminiChatRequest{ - Contents: make([]GeminiChatContent, 0, len(textRequest.Messages)), - SafetySettings: []GeminiChatSafetySettings{ +func ConvertRequest(textRequest model.GeneralOpenAIRequest) *ChatRequest { + geminiRequest := ChatRequest{ + Contents: make([]ChatContent, 0, len(textRequest.Messages)), + SafetySettings: []ChatSafetySettings{ { Category: "HARM_CATEGORY_HARASSMENT", Threshold: config.GeminiSafetySetting, @@ -46,14 +47,14 @@ func ConvertGeminiRequest(textRequest openai.GeneralOpenAIRequest) *GeminiChatRe Threshold: config.GeminiSafetySetting, }, }, - GenerationConfig: GeminiChatGenerationConfig{ + GenerationConfig: ChatGenerationConfig{ Temperature: textRequest.Temperature, TopP: textRequest.TopP, MaxOutputTokens: textRequest.MaxTokens, }, } if textRequest.Functions != nil { - geminiRequest.Tools = []GeminiChatTools{ + geminiRequest.Tools = []ChatTools{ { FunctionDeclarations: textRequest.Functions, }, @@ -61,30 +62,30 @@ func ConvertGeminiRequest(textRequest openai.GeneralOpenAIRequest) *GeminiChatRe } shouldAddDummyModelMessage := false for _, message := range textRequest.Messages { - content := GeminiChatContent{ + content := ChatContent{ Role: message.Role, - Parts: []GeminiPart{ + Parts: []Part{ { Text: message.StringContent(), }, }, } openaiContent := message.ParseContent() - var parts []GeminiPart + var parts []Part imageNum := 0 for _, part := range openaiContent { - if part.Type == openai.ContentTypeText { - parts = append(parts, GeminiPart{ + if part.Type == model.ContentTypeText { + parts = append(parts, Part{ Text: part.Text, }) - } else if part.Type == openai.ContentTypeImageURL { + } else if part.Type == model.ContentTypeImageURL { imageNum += 1 - if imageNum > GeminiVisionMaxImageNum { + if imageNum > VisionMaxImageNum { continue } mimeType, data, _ := image.GetImageFromUrl(part.ImageURL.Url) - parts = append(parts, GeminiPart{ - InlineData: &GeminiInlineData{ + parts = 
append(parts, Part{ + InlineData: &InlineData{ MimeType: mimeType, Data: data, }, @@ -106,9 +107,9 @@ func ConvertGeminiRequest(textRequest openai.GeneralOpenAIRequest) *GeminiChatRe // If a system message is the last message, we need to add a dummy model message to make gemini happy if shouldAddDummyModelMessage { - geminiRequest.Contents = append(geminiRequest.Contents, GeminiChatContent{ + geminiRequest.Contents = append(geminiRequest.Contents, ChatContent{ Role: "model", - Parts: []GeminiPart{ + Parts: []Part{ { Text: "Okay", }, @@ -121,12 +122,12 @@ func ConvertGeminiRequest(textRequest openai.GeneralOpenAIRequest) *GeminiChatRe return &geminiRequest } -type GeminiChatResponse struct { - Candidates []GeminiChatCandidate `json:"candidates"` - PromptFeedback GeminiChatPromptFeedback `json:"promptFeedback"` +type ChatResponse struct { + Candidates []ChatCandidate `json:"candidates"` + PromptFeedback ChatPromptFeedback `json:"promptFeedback"` } -func (g *GeminiChatResponse) GetResponseText() string { +func (g *ChatResponse) GetResponseText() string { if g == nil { return "" } @@ -136,23 +137,23 @@ func (g *GeminiChatResponse) GetResponseText() string { return "" } -type GeminiChatCandidate struct { - Content GeminiChatContent `json:"content"` - FinishReason string `json:"finishReason"` - Index int64 `json:"index"` - SafetyRatings []GeminiChatSafetyRating `json:"safetyRatings"` +type ChatCandidate struct { + Content ChatContent `json:"content"` + FinishReason string `json:"finishReason"` + Index int64 `json:"index"` + SafetyRatings []ChatSafetyRating `json:"safetyRatings"` } -type GeminiChatSafetyRating struct { +type ChatSafetyRating struct { Category string `json:"category"` Probability string `json:"probability"` } -type GeminiChatPromptFeedback struct { - SafetyRatings []GeminiChatSafetyRating `json:"safetyRatings"` +type ChatPromptFeedback struct { + SafetyRatings []ChatSafetyRating `json:"safetyRatings"` } -func responseGeminiChat2OpenAI(response *GeminiChatResponse) *openai.TextResponse { +func responseGeminiChat2OpenAI(response *ChatResponse) *openai.TextResponse { fullTextResponse := openai.TextResponse{ Id: fmt.Sprintf("chatcmpl-%s", helper.GetUUID()), Object: "chat.completion", @@ -162,7 +163,7 @@ func responseGeminiChat2OpenAI(response *GeminiChatResponse) *openai.TextRespons for i, candidate := range response.Candidates { choice := openai.TextResponseChoice{ Index: i, - Message: openai.Message{ + Message: model.Message{ Role: "assistant", Content: "", }, @@ -176,7 +177,7 @@ func responseGeminiChat2OpenAI(response *GeminiChatResponse) *openai.TextRespons return &fullTextResponse } -func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) *openai.ChatCompletionsStreamResponse { +func streamResponseGeminiChat2OpenAI(geminiResponse *ChatResponse) *openai.ChatCompletionsStreamResponse { var choice openai.ChatCompletionsStreamResponseChoice choice.Delta.Content = geminiResponse.GetResponseText() choice.FinishReason = &constant.StopFinishReason @@ -187,7 +188,7 @@ func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) *openai return &response } -func StreamHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, string) { +func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, string) { responseText := "" dataChan := make(chan string) stopChan := make(chan bool) @@ -257,7 +258,7 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatus return nil, responseText } -func 
GeminiHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*openai.ErrorWithStatusCode, *openai.Usage) { +func Handler(c *gin.Context, resp *http.Response, promptTokens int, modelName string) (*model.ErrorWithStatusCode, *model.Usage) { responseBody, err := io.ReadAll(resp.Body) if err != nil { return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil @@ -266,14 +267,14 @@ func GeminiHandler(c *gin.Context, resp *http.Response, promptTokens int, model if err != nil { return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil } - var geminiResponse GeminiChatResponse + var geminiResponse ChatResponse err = json.Unmarshal(responseBody, &geminiResponse) if err != nil { return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil } if len(geminiResponse.Candidates) == 0 { - return &openai.ErrorWithStatusCode{ - Error: openai.Error{ + return &model.ErrorWithStatusCode{ + Error: model.Error{ Message: "No candidates returned", Type: "server_error", Param: "", @@ -283,9 +284,9 @@ func GeminiHandler(c *gin.Context, resp *http.Response, promptTokens int, model }, nil } fullTextResponse := responseGeminiChat2OpenAI(&geminiResponse) - fullTextResponse.Model = model - completionTokens := openai.CountTokenText(geminiResponse.GetResponseText(), model) - usage := openai.Usage{ + fullTextResponse.Model = modelName + completionTokens := openai.CountTokenText(geminiResponse.GetResponseText(), modelName) + usage := model.Usage{ PromptTokens: promptTokens, CompletionTokens: completionTokens, TotalTokens: promptTokens + completionTokens, diff --git a/relay/channel/gemini/model.go b/relay/channel/gemini/model.go new file mode 100644 index 00000000..d1e3c4fd --- /dev/null +++ b/relay/channel/gemini/model.go @@ -0,0 +1,41 @@ +package gemini + +type ChatRequest struct { + Contents []ChatContent `json:"contents"` + SafetySettings []ChatSafetySettings `json:"safety_settings,omitempty"` + GenerationConfig ChatGenerationConfig `json:"generation_config,omitempty"` + Tools []ChatTools `json:"tools,omitempty"` +} + +type InlineData struct { + MimeType string `json:"mimeType"` + Data string `json:"data"` +} + +type Part struct { + Text string `json:"text,omitempty"` + InlineData *InlineData `json:"inlineData,omitempty"` +} + +type ChatContent struct { + Role string `json:"role,omitempty"` + Parts []Part `json:"parts"` +} + +type ChatSafetySettings struct { + Category string `json:"category"` + Threshold string `json:"threshold"` +} + +type ChatTools struct { + FunctionDeclarations any `json:"functionDeclarations,omitempty"` +} + +type ChatGenerationConfig struct { + Temperature float64 `json:"temperature,omitempty"` + TopP float64 `json:"topP,omitempty"` + TopK float64 `json:"topK,omitempty"` + MaxOutputTokens int `json:"maxOutputTokens,omitempty"` + CandidateCount int `json:"candidateCount,omitempty"` + StopSequences []string `json:"stopSequences,omitempty"` +} diff --git a/relay/channel/google/adaptor.go b/relay/channel/google/adaptor.go deleted file mode 100644 index ad45bc48..00000000 --- a/relay/channel/google/adaptor.go +++ /dev/null @@ -1,22 +0,0 @@ -package google - -import ( - "github.com/gin-gonic/gin" - "github.com/songquanpeng/one-api/relay/channel/openai" - "net/http" -) - -type Adaptor struct { -} - -func (a *Adaptor) Auth(c *gin.Context) error { - return nil -} - -func (a *Adaptor) ConvertRequest(request *openai.GeneralOpenAIRequest) (any, error) { - return nil, 
nil -} - -func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, *openai.Usage, error) { - return nil, nil, nil -} diff --git a/relay/channel/google/model.go b/relay/channel/google/model.go deleted file mode 100644 index e69a9445..00000000 --- a/relay/channel/google/model.go +++ /dev/null @@ -1,80 +0,0 @@ -package google - -import ( - "github.com/songquanpeng/one-api/relay/channel/openai" -) - -type GeminiChatRequest struct { - Contents []GeminiChatContent `json:"contents"` - SafetySettings []GeminiChatSafetySettings `json:"safety_settings,omitempty"` - GenerationConfig GeminiChatGenerationConfig `json:"generation_config,omitempty"` - Tools []GeminiChatTools `json:"tools,omitempty"` -} - -type GeminiInlineData struct { - MimeType string `json:"mimeType"` - Data string `json:"data"` -} - -type GeminiPart struct { - Text string `json:"text,omitempty"` - InlineData *GeminiInlineData `json:"inlineData,omitempty"` -} - -type GeminiChatContent struct { - Role string `json:"role,omitempty"` - Parts []GeminiPart `json:"parts"` -} - -type GeminiChatSafetySettings struct { - Category string `json:"category"` - Threshold string `json:"threshold"` -} - -type GeminiChatTools struct { - FunctionDeclarations any `json:"functionDeclarations,omitempty"` -} - -type GeminiChatGenerationConfig struct { - Temperature float64 `json:"temperature,omitempty"` - TopP float64 `json:"topP,omitempty"` - TopK float64 `json:"topK,omitempty"` - MaxOutputTokens int `json:"maxOutputTokens,omitempty"` - CandidateCount int `json:"candidateCount,omitempty"` - StopSequences []string `json:"stopSequences,omitempty"` -} - -type PaLMChatMessage struct { - Author string `json:"author"` - Content string `json:"content"` -} - -type PaLMFilter struct { - Reason string `json:"reason"` - Message string `json:"message"` -} - -type PaLMPrompt struct { - Messages []PaLMChatMessage `json:"messages"` -} - -type PaLMChatRequest struct { - Prompt PaLMPrompt `json:"prompt"` - Temperature float64 `json:"temperature,omitempty"` - CandidateCount int `json:"candidateCount,omitempty"` - TopP float64 `json:"topP,omitempty"` - TopK int `json:"topK,omitempty"` -} - -type PaLMError struct { - Code int `json:"code"` - Message string `json:"message"` - Status string `json:"status"` -} - -type PaLMChatResponse struct { - Candidates []PaLMChatMessage `json:"candidates"` - Messages []openai.Message `json:"messages"` - Filters []PaLMFilter `json:"filters"` - Error PaLMError `json:"error"` -} diff --git a/relay/channel/interface.go b/relay/channel/interface.go index 2a28abb8..2ecf2677 100644 --- a/relay/channel/interface.go +++ b/relay/channel/interface.go @@ -2,14 +2,18 @@ package channel import ( "github.com/gin-gonic/gin" - "github.com/songquanpeng/one-api/relay/channel/openai" + "github.com/songquanpeng/one-api/relay/model" + "github.com/songquanpeng/one-api/relay/util" + "io" "net/http" ) type Adaptor interface { - GetRequestURL() string - Auth(c *gin.Context) error - ConvertRequest(request *openai.GeneralOpenAIRequest) (any, error) - DoRequest(request *openai.GeneralOpenAIRequest) error - DoResponse(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, *openai.Usage, error) + GetRequestURL(meta *util.RelayMeta) (string, error) + SetupRequestHeader(c *gin.Context, req *http.Request, meta *util.RelayMeta) error + ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) + DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error) 
+ DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode) + GetModelList() []string + GetChannelName() string } diff --git a/relay/channel/openai/adaptor.go b/relay/channel/openai/adaptor.go index cc302611..9af2e5c1 100644 --- a/relay/channel/openai/adaptor.go +++ b/relay/channel/openai/adaptor.go @@ -1,21 +1,80 @@ package openai import ( + "errors" + "fmt" "github.com/gin-gonic/gin" + "github.com/songquanpeng/one-api/common" + "github.com/songquanpeng/one-api/relay/channel" + "github.com/songquanpeng/one-api/relay/model" + "github.com/songquanpeng/one-api/relay/util" + "io" "net/http" + "strings" ) type Adaptor struct { } -func (a *Adaptor) Auth(c *gin.Context) error { +func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) { + if meta.ChannelType == common.ChannelTypeAzure { + // https://learn.microsoft.com/en-us/azure/cognitive-services/openai/chatgpt-quickstart?pivots=rest-api&tabs=command-line#rest-api + requestURL := strings.Split(meta.RequestURLPath, "?")[0] + requestURL = fmt.Sprintf("%s?api-version=%s", requestURL, meta.APIVersion) + task := strings.TrimPrefix(requestURL, "/v1/") + model_ := meta.ActualModelName + model_ = strings.Replace(model_, ".", "", -1) + // https://github.com/songquanpeng/one-api/issues/67 + model_ = strings.TrimSuffix(model_, "-0301") + model_ = strings.TrimSuffix(model_, "-0314") + model_ = strings.TrimSuffix(model_, "-0613") + + requestURL = fmt.Sprintf("/openai/deployments/%s/%s", model_, task) + return util.GetFullRequestURL(meta.BaseURL, requestURL, meta.ChannelType), nil + } + return util.GetFullRequestURL(meta.BaseURL, meta.RequestURLPath, meta.ChannelType), nil +} + +func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *util.RelayMeta) error { + channel.SetupCommonRequestHeader(c, req, meta) + if meta.ChannelType == common.ChannelTypeAzure { + req.Header.Set("api-key", meta.APIKey) + return nil + } + req.Header.Set("Authorization", "Bearer "+meta.APIKey) + if meta.ChannelType == common.ChannelTypeOpenRouter { + req.Header.Set("HTTP-Referer", "https://github.com/songquanpeng/one-api") + req.Header.Set("X-Title", "One API") + } return nil } -func (a *Adaptor) ConvertRequest(request *GeneralOpenAIRequest) (any, error) { - return nil, nil +func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) { + if request == nil { + return nil, errors.New("request is nil") + } + return request, nil } -func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response) (*ErrorWithStatusCode, *Usage, error) { - return nil, nil, nil +func (a *Adaptor) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error) { + return channel.DoRequestHelper(a, c, meta, requestBody) +} + +func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode) { + if meta.IsStream { + var responseText string + err, responseText = StreamHandler(c, resp, meta.Mode) + usage = ResponseText2Usage(responseText, meta.ActualModelName, meta.PromptTokens) + } else { + err, usage = Handler(c, resp, meta.PromptTokens, meta.ActualModelName) + } + return +} + +func (a *Adaptor) GetModelList() []string { + return ModelList +} + +func (a *Adaptor) GetChannelName() string { + return "openai" } diff --git a/relay/channel/openai/constants.go b/relay/channel/openai/constants.go new file mode 100644 index 00000000..ea236ea1 --- /dev/null +++ 
b/relay/channel/openai/constants.go @@ -0,0 +1,19 @@ +package openai + +var ModelList = []string{ + "gpt-3.5-turbo", "gpt-3.5-turbo-0301", "gpt-3.5-turbo-0613", "gpt-3.5-turbo-1106", "gpt-3.5-turbo-0125", + "gpt-3.5-turbo-16k", "gpt-3.5-turbo-16k-0613", + "gpt-3.5-turbo-instruct", + "gpt-4", "gpt-4-0314", "gpt-4-0613", "gpt-4-1106-preview", "gpt-4-0125-preview", + "gpt-4-32k", "gpt-4-32k-0314", "gpt-4-32k-0613", + "gpt-4-turbo-preview", + "gpt-4-vision-preview", + "text-embedding-ada-002", "text-embedding-3-small", "text-embedding-3-large", + "text-curie-001", "text-babbage-001", "text-ada-001", "text-davinci-002", "text-davinci-003", + "text-moderation-latest", "text-moderation-stable", + "text-davinci-edit-001", + "davinci-002", "babbage-002", + "dall-e-2", "dall-e-3", + "whisper-1", + "tts-1", "tts-1-1106", "tts-1-hd", "tts-1-hd-1106", +} diff --git a/relay/channel/openai/helper.go b/relay/channel/openai/helper.go new file mode 100644 index 00000000..9bca8cab --- /dev/null +++ b/relay/channel/openai/helper.go @@ -0,0 +1,11 @@ +package openai + +import "github.com/songquanpeng/one-api/relay/model" + +func ResponseText2Usage(responseText string, modelName string, promptTokens int) *model.Usage { + usage := &model.Usage{} + usage.PromptTokens = promptTokens + usage.CompletionTokens = CountTokenText(responseText, modelName) + usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens + return usage +} diff --git a/relay/channel/openai/main.go b/relay/channel/openai/main.go index f56028a2..fbe55cf9 100644 --- a/relay/channel/openai/main.go +++ b/relay/channel/openai/main.go @@ -8,12 +8,13 @@ import ( "github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common/logger" "github.com/songquanpeng/one-api/relay/constant" + "github.com/songquanpeng/one-api/relay/model" "io" "net/http" "strings" ) -func StreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*ErrorWithStatusCode, string) { +func StreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*model.ErrorWithStatusCode, string) { responseText := "" scanner := bufio.NewScanner(resp.Body) scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { @@ -90,7 +91,7 @@ func StreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*ErrorWi return nil, responseText } -func Handler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*ErrorWithStatusCode, *Usage) { +func Handler(c *gin.Context, resp *http.Response, promptTokens int, modelName string) (*model.ErrorWithStatusCode, *model.Usage) { var textResponse SlimTextResponse responseBody, err := io.ReadAll(resp.Body) if err != nil { @@ -105,7 +106,7 @@ func Handler(c *gin.Context, resp *http.Response, promptTokens int, model string return ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil } if textResponse.Error.Type != "" { - return &ErrorWithStatusCode{ + return &model.ErrorWithStatusCode{ Error: textResponse.Error, StatusCode: resp.StatusCode, }, nil @@ -133,9 +134,9 @@ func Handler(c *gin.Context, resp *http.Response, promptTokens int, model string if textResponse.Usage.TotalTokens == 0 { completionTokens := 0 for _, choice := range textResponse.Choices { - completionTokens += CountTokenText(choice.Message.StringContent(), model) + completionTokens += CountTokenText(choice.Message.StringContent(), modelName) } - textResponse.Usage = Usage{ + textResponse.Usage = model.Usage{ PromptTokens: promptTokens, CompletionTokens: completionTokens, TotalTokens:
promptTokens + completionTokens, diff --git a/relay/channel/openai/model.go b/relay/channel/openai/model.go index 937fb424..c09f2334 100644 --- a/relay/channel/openai/model.go +++ b/relay/channel/openai/model.go @@ -1,15 +1,6 @@ package openai -type Message struct { - Role string `json:"role"` - Content any `json:"content"` - Name *string `json:"name,omitempty"` -} - -type ImageURL struct { - Url string `json:"url,omitempty"` - Detail string `json:"detail,omitempty"` -} +import "github.com/songquanpeng/one-api/relay/model" type TextContent struct { Type string `json:"type,omitempty"` @@ -17,142 +8,21 @@ type TextContent struct { } type ImageContent struct { - Type string `json:"type,omitempty"` - ImageURL *ImageURL `json:"image_url,omitempty"` -} - -type OpenAIMessageContent struct { - Type string `json:"type,omitempty"` - Text string `json:"text"` - ImageURL *ImageURL `json:"image_url,omitempty"` -} - -func (m Message) IsStringContent() bool { - _, ok := m.Content.(string) - return ok -} - -func (m Message) StringContent() string { - content, ok := m.Content.(string) - if ok { - return content - } - contentList, ok := m.Content.([]any) - if ok { - var contentStr string - for _, contentItem := range contentList { - contentMap, ok := contentItem.(map[string]any) - if !ok { - continue - } - if contentMap["type"] == ContentTypeText { - if subStr, ok := contentMap["text"].(string); ok { - contentStr += subStr - } - } - } - return contentStr - } - return "" -} - -func (m Message) ParseContent() []OpenAIMessageContent { - var contentList []OpenAIMessageContent - content, ok := m.Content.(string) - if ok { - contentList = append(contentList, OpenAIMessageContent{ - Type: ContentTypeText, - Text: content, - }) - return contentList - } - anyList, ok := m.Content.([]any) - if ok { - for _, contentItem := range anyList { - contentMap, ok := contentItem.(map[string]any) - if !ok { - continue - } - switch contentMap["type"] { - case ContentTypeText: - if subStr, ok := contentMap["text"].(string); ok { - contentList = append(contentList, OpenAIMessageContent{ - Type: ContentTypeText, - Text: subStr, - }) - } - case ContentTypeImageURL: - if subObj, ok := contentMap["image_url"].(map[string]any); ok { - contentList = append(contentList, OpenAIMessageContent{ - Type: ContentTypeImageURL, - ImageURL: &ImageURL{ - Url: subObj["url"].(string), - }, - }) - } - } - } - return contentList - } - return nil -} - -type ResponseFormat struct { - Type string `json:"type,omitempty"` -} - -type GeneralOpenAIRequest struct { - Model string `json:"model,omitempty"` - Messages []Message `json:"messages,omitempty"` - Prompt any `json:"prompt,omitempty"` - Stream bool `json:"stream,omitempty"` - MaxTokens int `json:"max_tokens,omitempty"` - Temperature float64 `json:"temperature,omitempty"` - TopP float64 `json:"top_p,omitempty"` - N int `json:"n,omitempty"` - Input any `json:"input,omitempty"` - Instruction string `json:"instruction,omitempty"` - Size string `json:"size,omitempty"` - Functions any `json:"functions,omitempty"` - FrequencyPenalty float64 `json:"frequency_penalty,omitempty"` - PresencePenalty float64 `json:"presence_penalty,omitempty"` - ResponseFormat *ResponseFormat `json:"response_format,omitempty"` - Seed float64 `json:"seed,omitempty"` - Tools any `json:"tools,omitempty"` - ToolChoice any `json:"tool_choice,omitempty"` - User string `json:"user,omitempty"` -} - -func (r GeneralOpenAIRequest) ParseInput() []string { - if r.Input == nil { - return nil - } - var input []string - switch r.Input.(type) { - 
case string: - input = []string{r.Input.(string)} - case []any: - input = make([]string, 0, len(r.Input.([]any))) - for _, item := range r.Input.([]any) { - if str, ok := item.(string); ok { - input = append(input, str) - } - } - } - return input + Type string `json:"type,omitempty"` + ImageURL *model.ImageURL `json:"image_url,omitempty"` } type ChatRequest struct { - Model string `json:"model"` - Messages []Message `json:"messages"` - MaxTokens int `json:"max_tokens"` + Model string `json:"model"` + Messages []model.Message `json:"messages"` + MaxTokens int `json:"max_tokens"` } type TextRequest struct { - Model string `json:"model"` - Messages []Message `json:"messages"` - Prompt string `json:"prompt"` - MaxTokens int `json:"max_tokens"` + Model string `json:"model"` + Messages []model.Message `json:"messages"` + Prompt string `json:"prompt"` + MaxTokens int `json:"max_tokens"` //Stream bool `json:"stream"` } @@ -201,48 +71,30 @@ type TextToSpeechRequest struct { ResponseFormat string `json:"response_format"` } -type Usage struct { - PromptTokens int `json:"prompt_tokens"` - CompletionTokens int `json:"completion_tokens"` - TotalTokens int `json:"total_tokens"` -} - type UsageOrResponseText struct { - *Usage + *model.Usage ResponseText string } -type Error struct { - Message string `json:"message"` - Type string `json:"type"` - Param string `json:"param"` - Code any `json:"code"` -} - -type ErrorWithStatusCode struct { - Error - StatusCode int `json:"status_code"` -} - type SlimTextResponse struct { - Choices []TextResponseChoice `json:"choices"` - Usage `json:"usage"` - Error Error `json:"error"` + Choices []TextResponseChoice `json:"choices"` + model.Usage `json:"usage"` + Error model.Error `json:"error"` } type TextResponseChoice struct { - Index int `json:"index"` - Message `json:"message"` - FinishReason string `json:"finish_reason"` + Index int `json:"index"` + model.Message `json:"message"` + FinishReason string `json:"finish_reason"` } type TextResponse struct { - Id string `json:"id"` - Model string `json:"model,omitempty"` - Object string `json:"object"` - Created int64 `json:"created"` - Choices []TextResponseChoice `json:"choices"` - Usage `json:"usage"` + Id string `json:"id"` + Model string `json:"model,omitempty"` + Object string `json:"object"` + Created int64 `json:"created"` + Choices []TextResponseChoice `json:"choices"` + model.Usage `json:"usage"` } type EmbeddingResponseItem struct { @@ -252,10 +104,10 @@ type EmbeddingResponseItem struct { } type EmbeddingResponse struct { - Object string `json:"object"` - Data []EmbeddingResponseItem `json:"data"` - Model string `json:"model"` - Usage `json:"usage"` + Object string `json:"object"` + Data []EmbeddingResponseItem `json:"data"` + Model string `json:"model"` + model.Usage `json:"usage"` } type ImageResponse struct { diff --git a/relay/channel/openai/token.go b/relay/channel/openai/token.go index 686ac39f..0720425f 100644 --- a/relay/channel/openai/token.go +++ b/relay/channel/openai/token.go @@ -8,6 +8,7 @@ import ( "github.com/songquanpeng/one-api/common/config" "github.com/songquanpeng/one-api/common/image" "github.com/songquanpeng/one-api/common/logger" + "github.com/songquanpeng/one-api/relay/model" "math" "strings" ) @@ -63,7 +64,7 @@ func getTokenNum(tokenEncoder *tiktoken.Tiktoken, text string) int { return len(tokenEncoder.Encode(text, nil, nil)) } -func CountTokenMessages(messages []Message, model string) int { +func CountTokenMessages(messages []model.Message, model string) int { tokenEncoder := 
getTokenEncoder(model) // Reference: // https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb diff --git a/relay/channel/openai/util.go b/relay/channel/openai/util.go index 69ece6b3..ba0cab7d 100644 --- a/relay/channel/openai/util.go +++ b/relay/channel/openai/util.go @@ -1,12 +1,14 @@ package openai -func ErrorWrapper(err error, code string, statusCode int) *ErrorWithStatusCode { - Error := Error{ +import "github.com/songquanpeng/one-api/relay/model" + +func ErrorWrapper(err error, code string, statusCode int) *model.ErrorWithStatusCode { + Error := model.Error{ Message: err.Error(), Type: "one_api_error", Code: code, } - return &ErrorWithStatusCode{ + return &model.ErrorWithStatusCode{ Error: Error, StatusCode: statusCode, } diff --git a/relay/channel/palm/adaptor.go b/relay/channel/palm/adaptor.go new file mode 100644 index 00000000..dfba56ff --- /dev/null +++ b/relay/channel/palm/adaptor.go @@ -0,0 +1,56 @@ +package palm + +import ( + "errors" + "fmt" + "github.com/gin-gonic/gin" + "github.com/songquanpeng/one-api/relay/channel" + "github.com/songquanpeng/one-api/relay/channel/openai" + "github.com/songquanpeng/one-api/relay/model" + "github.com/songquanpeng/one-api/relay/util" + "io" + "net/http" +) + +type Adaptor struct { +} + +func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) { + return fmt.Sprintf("%s/v1beta2/models/chat-bison-001:generateMessage", meta.BaseURL), nil +} + +func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *util.RelayMeta) error { + channel.SetupCommonRequestHeader(c, req, meta) + req.Header.Set("x-goog-api-key", meta.APIKey) + return nil +} + +func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) { + if request == nil { + return nil, errors.New("request is nil") + } + return ConvertRequest(*request), nil +} + +func (a *Adaptor) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error) { + return channel.DoRequestHelper(a, c, meta, requestBody) +} + +func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode) { + if meta.IsStream { + var responseText string + err, responseText = StreamHandler(c, resp) + usage = openai.ResponseText2Usage(responseText, meta.ActualModelName, meta.PromptTokens) + } else { + err, usage = Handler(c, resp, meta.PromptTokens, meta.ActualModelName) + } + return +} + +func (a *Adaptor) GetModelList() []string { + return ModelList +} + +func (a *Adaptor) GetChannelName() string { + return "google palm" +} diff --git a/relay/channel/palm/constants.go b/relay/channel/palm/constants.go new file mode 100644 index 00000000..a8349362 --- /dev/null +++ b/relay/channel/palm/constants.go @@ -0,0 +1,5 @@ +package palm + +var ModelList = []string{ + "PaLM-2", +} diff --git a/relay/channel/palm/model.go b/relay/channel/palm/model.go new file mode 100644 index 00000000..f653022c --- /dev/null +++ b/relay/channel/palm/model.go @@ -0,0 +1,40 @@ +package palm + +import ( + "github.com/songquanpeng/one-api/relay/model" +) + +type ChatMessage struct { + Author string `json:"author"` + Content string `json:"content"` +} + +type Filter struct { + Reason string `json:"reason"` + Message string `json:"message"` +} + +type Prompt struct { + Messages []ChatMessage `json:"messages"` +} + +type ChatRequest struct { + Prompt Prompt `json:"prompt"` + Temperature float64 `json:"temperature,omitempty"` + 
CandidateCount int `json:"candidateCount,omitempty"` + TopP float64 `json:"topP,omitempty"` + TopK int `json:"topK,omitempty"` +} + +type Error struct { + Code int `json:"code"` + Message string `json:"message"` + Status string `json:"status"` +} + +type ChatResponse struct { + Candidates []ChatMessage `json:"candidates"` + Messages []model.Message `json:"messages"` + Filters []Filter `json:"filters"` + Error Error `json:"error"` +} diff --git a/relay/channel/google/palm.go b/relay/channel/palm/palm.go similarity index 84% rename from relay/channel/google/palm.go rename to relay/channel/palm/palm.go index 7b9ee600..56738544 100644 --- a/relay/channel/google/palm.go +++ b/relay/channel/palm/palm.go @@ -1,4 +1,4 @@ -package google +package palm import ( "encoding/json" @@ -9,6 +9,7 @@ import ( "github.com/songquanpeng/one-api/common/logger" "github.com/songquanpeng/one-api/relay/channel/openai" "github.com/songquanpeng/one-api/relay/constant" + "github.com/songquanpeng/one-api/relay/model" "io" "net/http" ) @@ -16,10 +17,10 @@ import ( // https://developers.generativeai.google/api/rest/generativelanguage/models/generateMessage#request-body // https://developers.generativeai.google/api/rest/generativelanguage/models/generateMessage#response-body -func ConvertPaLMRequest(textRequest openai.GeneralOpenAIRequest) *PaLMChatRequest { - palmRequest := PaLMChatRequest{ - Prompt: PaLMPrompt{ - Messages: make([]PaLMChatMessage, 0, len(textRequest.Messages)), +func ConvertRequest(textRequest model.GeneralOpenAIRequest) *ChatRequest { + palmRequest := ChatRequest{ + Prompt: Prompt{ + Messages: make([]ChatMessage, 0, len(textRequest.Messages)), }, Temperature: textRequest.Temperature, CandidateCount: textRequest.N, @@ -27,7 +28,7 @@ func ConvertPaLMRequest(textRequest openai.GeneralOpenAIRequest) *PaLMChatReques TopK: textRequest.MaxTokens, } for _, message := range textRequest.Messages { - palmMessage := PaLMChatMessage{ + palmMessage := ChatMessage{ Content: message.StringContent(), } if message.Role == "user" { @@ -40,14 +41,14 @@ func ConvertPaLMRequest(textRequest openai.GeneralOpenAIRequest) *PaLMChatReques return &palmRequest } -func responsePaLM2OpenAI(response *PaLMChatResponse) *openai.TextResponse { +func responsePaLM2OpenAI(response *ChatResponse) *openai.TextResponse { fullTextResponse := openai.TextResponse{ Choices: make([]openai.TextResponseChoice, 0, len(response.Candidates)), } for i, candidate := range response.Candidates { choice := openai.TextResponseChoice{ Index: i, - Message: openai.Message{ + Message: model.Message{ Role: "assistant", Content: candidate.Content, }, @@ -58,7 +59,7 @@ func responsePaLM2OpenAI(response *PaLMChatResponse) *openai.TextResponse { return &fullTextResponse } -func streamResponsePaLM2OpenAI(palmResponse *PaLMChatResponse) *openai.ChatCompletionsStreamResponse { +func streamResponsePaLM2OpenAI(palmResponse *ChatResponse) *openai.ChatCompletionsStreamResponse { var choice openai.ChatCompletionsStreamResponseChoice if len(palmResponse.Candidates) > 0 { choice.Delta.Content = palmResponse.Candidates[0].Content @@ -71,7 +72,7 @@ func streamResponsePaLM2OpenAI(palmResponse *PaLMChatResponse) *openai.ChatCompl return &response } -func PaLMStreamHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, string) { +func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, string) { responseText := "" responseId := fmt.Sprintf("chatcmpl-%s", helper.GetUUID()) createdTime := helper.GetTimestamp() @@ -90,7 +91,7 @@ func 
PaLMStreamHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithSt stopChan <- true return } - var palmResponse PaLMChatResponse + var palmResponse ChatResponse err = json.Unmarshal(responseBody, &palmResponse) if err != nil { logger.SysError("error unmarshalling stream response: " + err.Error()) @@ -130,7 +131,7 @@ func PaLMStreamHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithSt return nil, responseText } -func PaLMHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*openai.ErrorWithStatusCode, *openai.Usage) { +func Handler(c *gin.Context, resp *http.Response, promptTokens int, modelName string) (*model.ErrorWithStatusCode, *model.Usage) { responseBody, err := io.ReadAll(resp.Body) if err != nil { return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil @@ -139,14 +140,14 @@ func PaLMHandler(c *gin.Context, resp *http.Response, promptTokens int, model st if err != nil { return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil } - var palmResponse PaLMChatResponse + var palmResponse ChatResponse err = json.Unmarshal(responseBody, &palmResponse) if err != nil { return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil } if palmResponse.Error.Code != 0 || len(palmResponse.Candidates) == 0 { - return &openai.ErrorWithStatusCode{ - Error: openai.Error{ + return &model.ErrorWithStatusCode{ + Error: model.Error{ Message: palmResponse.Error.Message, Type: palmResponse.Error.Status, Param: "", @@ -156,9 +157,9 @@ func PaLMHandler(c *gin.Context, resp *http.Response, promptTokens int, model st }, nil } fullTextResponse := responsePaLM2OpenAI(&palmResponse) - fullTextResponse.Model = model - completionTokens := openai.CountTokenText(palmResponse.Candidates[0].Content, model) - usage := openai.Usage{ + fullTextResponse.Model = modelName + completionTokens := openai.CountTokenText(palmResponse.Candidates[0].Content, modelName) + usage := model.Usage{ PromptTokens: promptTokens, CompletionTokens: completionTokens, TotalTokens: promptTokens + completionTokens, diff --git a/relay/channel/tencent/adaptor.go b/relay/channel/tencent/adaptor.go index c90509ca..e262bfb7 100644 --- a/relay/channel/tencent/adaptor.go +++ b/relay/channel/tencent/adaptor.go @@ -1,22 +1,69 @@ package tencent import ( + "errors" + "fmt" "github.com/gin-gonic/gin" + "github.com/songquanpeng/one-api/relay/channel" "github.com/songquanpeng/one-api/relay/channel/openai" + "github.com/songquanpeng/one-api/relay/model" + "github.com/songquanpeng/one-api/relay/util" + "io" "net/http" + "strings" ) type Adaptor struct { + Sign string } -func (a *Adaptor) Auth(c *gin.Context) error { +func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) { + return fmt.Sprintf("%s/hyllm/v1/chat/completions", meta.BaseURL), nil +} + +func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *util.RelayMeta) error { + channel.SetupCommonRequestHeader(c, req, meta) + req.Header.Set("Authorization", a.Sign) return nil } -func (a *Adaptor) ConvertRequest(request *openai.GeneralOpenAIRequest) (any, error) { - return nil, nil +func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) { + if request == nil { + return nil, errors.New("request is nil") + } + apiKey := c.Request.Header.Get("Authorization") + apiKey = strings.TrimPrefix(apiKey, "Bearer ") + appId, secretId, secretKey, err := 
ParseConfig(apiKey) + if err != nil { + return nil, err + } + tencentRequest := ConvertRequest(*request) + tencentRequest.AppId = appId + tencentRequest.SecretId = secretId + // we have to calculate the sign here + a.Sign = GetSign(*tencentRequest, secretKey) + return tencentRequest, nil } -func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, *openai.Usage, error) { - return nil, nil, nil +func (a *Adaptor) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error) { + return channel.DoRequestHelper(a, c, meta, requestBody) +} + +func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode) { + if meta.IsStream { + var responseText string + err, responseText = StreamHandler(c, resp) + usage = openai.ResponseText2Usage(responseText, meta.ActualModelName, meta.PromptTokens) + } else { + err, usage = Handler(c, resp) + } + return +} + +func (a *Adaptor) GetModelList() []string { + return ModelList +} + +func (a *Adaptor) GetChannelName() string { + return "tencent" } diff --git a/relay/channel/tencent/constants.go b/relay/channel/tencent/constants.go new file mode 100644 index 00000000..1d13066d --- /dev/null +++ b/relay/channel/tencent/constants.go @@ -0,0 +1,5 @@ +package tencent + +var ModelList = []string{ + "hunyuan", +} diff --git a/relay/channel/tencent/main.go b/relay/channel/tencent/main.go index 784f86fd..05edac20 100644 --- a/relay/channel/tencent/main.go +++ b/relay/channel/tencent/main.go @@ -14,6 +14,7 @@ import ( "github.com/songquanpeng/one-api/common/logger" "github.com/songquanpeng/one-api/relay/channel/openai" "github.com/songquanpeng/one-api/relay/constant" + "github.com/songquanpeng/one-api/relay/model" "io" "net/http" "sort" @@ -23,7 +24,7 @@ import ( // https://cloud.tencent.com/document/product/1729/97732 -func ConvertRequest(request openai.GeneralOpenAIRequest) *ChatRequest { +func ConvertRequest(request model.GeneralOpenAIRequest) *ChatRequest { messages := make([]Message, 0, len(request.Messages)) for i := 0; i < len(request.Messages); i++ { message := request.Messages[i] @@ -67,7 +68,7 @@ func responseTencent2OpenAI(response *ChatResponse) *openai.TextResponse { if len(response.Choices) > 0 { choice := openai.TextResponseChoice{ Index: 0, - Message: openai.Message{ + Message: model.Message{ Role: "assistant", Content: response.Choices[0].Messages.Content, }, @@ -95,7 +96,7 @@ func streamResponseTencent2OpenAI(TencentResponse *ChatResponse) *openai.ChatCom return &response } -func StreamHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, string) { +func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, string) { var responseText string scanner := bufio.NewScanner(resp.Body) scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { @@ -159,7 +160,7 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatus return nil, responseText } -func Handler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, *openai.Usage) { +func Handler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) { var TencentResponse ChatResponse responseBody, err := io.ReadAll(resp.Body) if err != nil { @@ -174,8 +175,8 @@ func Handler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), 
nil } if TencentResponse.Error.Code != 0 { - return &openai.ErrorWithStatusCode{ - Error: openai.Error{ + return &model.ErrorWithStatusCode{ + Error: model.Error{ Message: TencentResponse.Error.Message, Code: TencentResponse.Error.Code, }, diff --git a/relay/channel/tencent/model.go b/relay/channel/tencent/model.go index b8aa7698..71286be9 100644 --- a/relay/channel/tencent/model.go +++ b/relay/channel/tencent/model.go @@ -1,7 +1,7 @@ package tencent import ( - "github.com/songquanpeng/one-api/relay/channel/openai" + "github.com/songquanpeng/one-api/relay/model" ) type Message struct { @@ -56,7 +56,7 @@ type ChatResponse struct { Choices []ResponseChoices `json:"choices,omitempty"` // results Created string `json:"created,omitempty"` // unix timestamp as a string Id string `json:"id,omitempty"` // conversation id - Usage openai.Usage `json:"usage,omitempty"` // token counts + Usage model.Usage `json:"usage,omitempty"` // token counts Error Error `json:"error,omitempty"` // error info; note: this field may return null, meaning no valid value could be obtained Note string `json:"note,omitempty"` // note ReqID string `json:"req_id,omitempty"` // unique request Id, returned with every request; used as an input parameter when reporting issues diff --git a/relay/channel/xunfei/adaptor.go b/relay/channel/xunfei/adaptor.go index d2c80c64..8683fcf3 100644 --- a/relay/channel/xunfei/adaptor.go +++ b/relay/channel/xunfei/adaptor.go @@ -1,22 +1,66 @@ package xunfei import ( + "errors" "github.com/gin-gonic/gin" + "github.com/songquanpeng/one-api/relay/channel" "github.com/songquanpeng/one-api/relay/channel/openai" + "github.com/songquanpeng/one-api/relay/model" + "github.com/songquanpeng/one-api/relay/util" + "io" "net/http" + "strings" ) type Adaptor struct { + request *model.GeneralOpenAIRequest } -func (a *Adaptor) Auth(c *gin.Context) error { +func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) { + return "", nil +} + +func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *util.RelayMeta) error { + channel.SetupCommonRequestHeader(c, req, meta) + // check DoResponse for auth part return nil } -func (a *Adaptor) ConvertRequest(request *openai.GeneralOpenAIRequest) (any, error) { +func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) { + if request == nil { + return nil, errors.New("request is nil") + } + a.request = request return nil, nil } -func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, *openai.Usage, error) { - return nil, nil, nil +func (a *Adaptor) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error) { + // xunfei's request is not an HTTP request, so we don't need to do anything here + dummyResp := &http.Response{} + dummyResp.StatusCode = http.StatusOK + return dummyResp, nil +} + +func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode) { + splits := strings.Split(meta.APIKey, "|") + if len(splits) != 3 { + return nil, openai.ErrorWrapper(errors.New("invalid auth"), "invalid_auth", http.StatusBadRequest) + } + if a.request == nil { + return nil, openai.ErrorWrapper(errors.New("request is nil"), "request_is_nil", http.StatusBadRequest) + } + if meta.IsStream { + err, usage = StreamHandler(c, *a.request, splits[0], splits[1], splits[2]) + } else { + err, usage = Handler(c, *a.request, splits[0], splits[1], splits[2]) + } + return +} + +func (a *Adaptor) GetModelList() []string { + return ModelList +} + +func (a *Adaptor) GetChannelName() string { + return "xunfei" } diff --git
a/relay/channel/xunfei/constants.go b/relay/channel/xunfei/constants.go new file mode 100644 index 00000000..41846c41 --- /dev/null +++ b/relay/channel/xunfei/constants.go @@ -0,0 +1,5 @@ +package xunfei + +var ModelList = []string{ + "SparkDesk", +} diff --git a/relay/channel/xunfei/main.go b/relay/channel/xunfei/main.go index ff5cdbea..8efade87 100644 --- a/relay/channel/xunfei/main.go +++ b/relay/channel/xunfei/main.go @@ -13,6 +13,7 @@ import ( "github.com/songquanpeng/one-api/common/logger" "github.com/songquanpeng/one-api/relay/channel/openai" "github.com/songquanpeng/one-api/relay/constant" + "github.com/songquanpeng/one-api/relay/model" "io" "net/http" "net/url" @@ -23,7 +24,7 @@ import ( // https://console.xfyun.cn/services/cbm // https://www.xfyun.cn/doc/spark/Web.html -func requestOpenAI2Xunfei(request openai.GeneralOpenAIRequest, xunfeiAppId string, domain string) *ChatRequest { +func requestOpenAI2Xunfei(request model.GeneralOpenAIRequest, xunfeiAppId string, domain string) *ChatRequest { messages := make([]Message, 0, len(request.Messages)) for _, message := range request.Messages { if message.Role == "system" { @@ -62,7 +63,7 @@ func responseXunfei2OpenAI(response *ChatResponse) *openai.TextResponse { } choice := openai.TextResponseChoice{ Index: 0, - Message: openai.Message{ + Message: model.Message{ Role: "assistant", Content: response.Payload.Choices.Text[0].Content, }, @@ -125,14 +126,14 @@ func buildXunfeiAuthUrl(hostUrl string, apiKey, apiSecret string) string { return callUrl } -func StreamHandler(c *gin.Context, textRequest openai.GeneralOpenAIRequest, appId string, apiSecret string, apiKey string) (*openai.ErrorWithStatusCode, *openai.Usage) { +func StreamHandler(c *gin.Context, textRequest model.GeneralOpenAIRequest, appId string, apiSecret string, apiKey string) (*model.ErrorWithStatusCode, *model.Usage) { domain, authUrl := getXunfeiAuthUrl(c, apiKey, apiSecret) dataChan, stopChan, err := xunfeiMakeRequest(textRequest, domain, authUrl, appId) if err != nil { return openai.ErrorWrapper(err, "make xunfei request err", http.StatusInternalServerError), nil } common.SetEventStreamHeaders(c) - var usage openai.Usage + var usage model.Usage c.Stream(func(w io.Writer) bool { select { case xunfeiResponse := <-dataChan: @@ -155,13 +156,13 @@ func StreamHandler(c *gin.Context, textRequest openai.GeneralOpenAIRequest, appI return nil, &usage } -func Handler(c *gin.Context, textRequest openai.GeneralOpenAIRequest, appId string, apiSecret string, apiKey string) (*openai.ErrorWithStatusCode, *openai.Usage) { +func Handler(c *gin.Context, textRequest model.GeneralOpenAIRequest, appId string, apiSecret string, apiKey string) (*model.ErrorWithStatusCode, *model.Usage) { domain, authUrl := getXunfeiAuthUrl(c, apiKey, apiSecret) dataChan, stopChan, err := xunfeiMakeRequest(textRequest, domain, authUrl, appId) if err != nil { return openai.ErrorWrapper(err, "make xunfei request err", http.StatusInternalServerError), nil } - var usage openai.Usage + var usage model.Usage var content string var xunfeiResponse ChatResponse stop := false @@ -197,7 +198,7 @@ func Handler(c *gin.Context, textRequest openai.GeneralOpenAIRequest, appId stri return nil, &usage } -func xunfeiMakeRequest(textRequest openai.GeneralOpenAIRequest, domain, authUrl, appId string) (chan ChatResponse, chan bool, error) { +func xunfeiMakeRequest(textRequest model.GeneralOpenAIRequest, domain, authUrl, appId string) (chan ChatResponse, chan bool, error) { d := websocket.Dialer{ HandshakeTimeout: 5 * time.Second, } diff 
--git a/relay/channel/xunfei/model.go b/relay/channel/xunfei/model.go index e015d164..1266739d 100644 --- a/relay/channel/xunfei/model.go +++ b/relay/channel/xunfei/model.go @@ -1,7 +1,7 @@ package xunfei import ( - "github.com/songquanpeng/one-api/relay/channel/openai" + "github.com/songquanpeng/one-api/relay/model" ) type Message struct { @@ -55,7 +55,7 @@ type ChatResponse struct { // CompletionTokens string `json:"completion_tokens"` // TotalTokens string `json:"total_tokens"` //} `json:"text"` - Text openai.Usage `json:"text"` + Text model.Usage `json:"text"` } `json:"usage"` } `json:"payload"` } diff --git a/relay/channel/zhipu/adaptor.go b/relay/channel/zhipu/adaptor.go index ae0f6faa..3afbe5c6 100644 --- a/relay/channel/zhipu/adaptor.go +++ b/relay/channel/zhipu/adaptor.go @@ -1,22 +1,58 @@ package zhipu import ( + "errors" + "fmt" "github.com/gin-gonic/gin" - "github.com/songquanpeng/one-api/relay/channel/openai" + "github.com/songquanpeng/one-api/relay/channel" + "github.com/songquanpeng/one-api/relay/model" + "github.com/songquanpeng/one-api/relay/util" + "io" "net/http" ) type Adaptor struct { } -func (a *Adaptor) Auth(c *gin.Context) error { +func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) { + method := "invoke" + if meta.IsStream { + method = "sse-invoke" + } + return fmt.Sprintf("%s/api/paas/v3/model-api/%s/%s", meta.BaseURL, meta.ActualModelName, method), nil +} + +func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *util.RelayMeta) error { + channel.SetupCommonRequestHeader(c, req, meta) + token := GetToken(meta.APIKey) + req.Header.Set("Authorization", token) return nil } -func (a *Adaptor) ConvertRequest(request *openai.GeneralOpenAIRequest) (any, error) { - return nil, nil +func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) { + if request == nil { + return nil, errors.New("request is nil") + } + return ConvertRequest(*request), nil } -func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, *openai.Usage, error) { - return nil, nil, nil +func (a *Adaptor) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error) { + return channel.DoRequestHelper(a, c, meta, requestBody) +} + +func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode) { + if meta.IsStream { + err, usage = StreamHandler(c, resp) + } else { + err, usage = Handler(c, resp) + } + return +} + +func (a *Adaptor) GetModelList() []string { + return ModelList +} + +func (a *Adaptor) GetChannelName() string { + return "zhipu" } diff --git a/relay/channel/zhipu/constants.go b/relay/channel/zhipu/constants.go new file mode 100644 index 00000000..f0367b82 --- /dev/null +++ b/relay/channel/zhipu/constants.go @@ -0,0 +1,5 @@ +package zhipu + +var ModelList = []string{ + "chatglm_turbo", "chatglm_pro", "chatglm_std", "chatglm_lite", +} diff --git a/relay/channel/zhipu/main.go b/relay/channel/zhipu/main.go index fe7b7b4b..7c3e83f3 100644 --- a/relay/channel/zhipu/main.go +++ b/relay/channel/zhipu/main.go @@ -10,6 +10,7 @@ import ( "github.com/songquanpeng/one-api/common/logger" "github.com/songquanpeng/one-api/relay/channel/openai" "github.com/songquanpeng/one-api/relay/constant" + "github.com/songquanpeng/one-api/relay/model" "io" "net/http" "strings" @@ -72,7 +73,7 @@ func GetToken(apikey string) string { return tokenString } -func ConvertRequest(request 
openai.GeneralOpenAIRequest) *Request { +func ConvertRequest(request model.GeneralOpenAIRequest) *Request { messages := make([]Message, 0, len(request.Messages)) for _, message := range request.Messages { if message.Role == "system" { @@ -110,7 +111,7 @@ func responseZhipu2OpenAI(response *Response) *openai.TextResponse { for i, choice := range response.Data.Choices { openaiChoice := openai.TextResponseChoice{ Index: i, - Message: openai.Message{ + Message: model.Message{ Role: choice.Role, Content: strings.Trim(choice.Content, "\""), }, @@ -136,7 +137,7 @@ func streamResponseZhipu2OpenAI(zhipuResponse string) *openai.ChatCompletionsStr return &response } -func streamMetaResponseZhipu2OpenAI(zhipuResponse *StreamMetaResponse) (*openai.ChatCompletionsStreamResponse, *openai.Usage) { +func streamMetaResponseZhipu2OpenAI(zhipuResponse *StreamMetaResponse) (*openai.ChatCompletionsStreamResponse, *model.Usage) { var choice openai.ChatCompletionsStreamResponseChoice choice.Delta.Content = "" choice.FinishReason = &constant.StopFinishReason @@ -150,8 +151,8 @@ func streamMetaResponseZhipu2OpenAI(zhipuResponse *StreamMetaResponse) (*openai. return &response, &zhipuResponse.Usage } -func StreamHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, *openai.Usage) { - var usage *openai.Usage +func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) { + var usage *model.Usage scanner := bufio.NewScanner(resp.Body) scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { if atEOF && len(data) == 0 { @@ -228,7 +229,7 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatus return nil, usage } -func Handler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, *openai.Usage) { +func Handler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) { var zhipuResponse Response responseBody, err := io.ReadAll(resp.Body) if err != nil { @@ -243,8 +244,8 @@ func Handler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil } if !zhipuResponse.Success { - return &openai.ErrorWithStatusCode{ - Error: openai.Error{ + return &model.ErrorWithStatusCode{ + Error: model.Error{ Message: zhipuResponse.Msg, Type: "zhipu_error", Param: "", diff --git a/relay/channel/zhipu/model.go b/relay/channel/zhipu/model.go index 67f1caeb..b63e1d6f 100644 --- a/relay/channel/zhipu/model.go +++ b/relay/channel/zhipu/model.go @@ -1,7 +1,7 @@ package zhipu import ( - "github.com/songquanpeng/one-api/relay/channel/openai" + "github.com/songquanpeng/one-api/relay/model" "time" ) @@ -19,11 +19,11 @@ type Request struct { } type ResponseData struct { - TaskId string `json:"task_id"` - RequestId string `json:"request_id"` - TaskStatus string `json:"task_status"` - Choices []Message `json:"choices"` - openai.Usage `json:"usage"` + TaskId string `json:"task_id"` + RequestId string `json:"request_id"` + TaskStatus string `json:"task_status"` + Choices []Message `json:"choices"` + model.Usage `json:"usage"` } type Response struct { @@ -34,10 +34,10 @@ type Response struct { } type StreamMetaResponse struct { - RequestId string `json:"request_id"` - TaskId string `json:"task_id"` - TaskStatus string `json:"task_status"` - openai.Usage `json:"usage"` + RequestId string `json:"request_id"` + TaskId string `json:"task_id"` + TaskStatus string `json:"task_status"` + model.Usage 
`json:"usage"` } type tokenData struct { diff --git a/relay/constant/api_type.go b/relay/constant/api_type.go index e0458279..d2184dac 100644 --- a/relay/constant/api_type.go +++ b/relay/constant/api_type.go @@ -6,7 +6,7 @@ import ( const ( APITypeOpenAI = iota - APITypeClaude + APITypeAnthropic APITypePaLM APITypeBaidu APITypeZhipu @@ -15,13 +15,15 @@ const ( APITypeAIProxyLibrary APITypeTencent APITypeGemini + + APITypeDummy // this one is only for count, do not add any channel after this ) func ChannelType2APIType(channelType int) int { apiType := APITypeOpenAI switch channelType { case common.ChannelTypeAnthropic: - apiType = APITypeClaude + apiType = APITypeAnthropic case common.ChannelTypeBaidu: apiType = APITypeBaidu case common.ChannelTypePaLM: @@ -41,29 +43,3 @@ func ChannelType2APIType(channelType int) int { } return apiType } - -//func GetAdaptor(apiType int) channel.Adaptor { -// switch apiType { -// case APITypeOpenAI: -// return &openai.Adaptor{} -// case APITypeClaude: -// return &anthropic.Adaptor{} -// case APITypePaLM: -// return &google.Adaptor{} -// case APITypeZhipu: -// return &baidu.Adaptor{} -// case APITypeBaidu: -// return &baidu.Adaptor{} -// case APITypeAli: -// return &ali.Adaptor{} -// case APITypeXunfei: -// return &xunfei.Adaptor{} -// case APITypeAIProxyLibrary: -// return &aiproxy.Adaptor{} -// case APITypeTencent: -// return &tencent.Adaptor{} -// case APITypeGemini: -// return &google.Adaptor{} -// } -// return nil -//} diff --git a/relay/controller/audio.go b/relay/controller/audio.go index cbbd8a04..ee8771c9 100644 --- a/relay/controller/audio.go +++ b/relay/controller/audio.go @@ -14,13 +14,14 @@ import ( "github.com/songquanpeng/one-api/model" "github.com/songquanpeng/one-api/relay/channel/openai" "github.com/songquanpeng/one-api/relay/constant" + relaymodel "github.com/songquanpeng/one-api/relay/model" "github.com/songquanpeng/one-api/relay/util" "io" "net/http" "strings" ) -func RelayAudioHelper(c *gin.Context, relayMode int) *openai.ErrorWithStatusCode { +func RelayAudioHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatusCode { audioModel := "whisper-1" tokenId := c.GetInt("token_id") diff --git a/relay/controller/helper.go b/relay/controller/helper.go index 6154f291..a06b2768 100644 --- a/relay/controller/helper.go +++ b/relay/controller/helper.go @@ -11,14 +11,14 @@ import ( "github.com/songquanpeng/one-api/model" "github.com/songquanpeng/one-api/relay/channel/openai" "github.com/songquanpeng/one-api/relay/constant" + relaymodel "github.com/songquanpeng/one-api/relay/model" "github.com/songquanpeng/one-api/relay/util" - "io" "math" "net/http" ) -func getAndValidateTextRequest(c *gin.Context, relayMode int) (*openai.GeneralOpenAIRequest, error) { - textRequest := &openai.GeneralOpenAIRequest{} +func getAndValidateTextRequest(c *gin.Context, relayMode int) (*relaymodel.GeneralOpenAIRequest, error) { + textRequest := &relaymodel.GeneralOpenAIRequest{} err := common.UnmarshalBodyReusable(c, textRequest) if err != nil { return nil, err @@ -36,7 +36,7 @@ func getAndValidateTextRequest(c *gin.Context, relayMode int) (*openai.GeneralOp return textRequest, nil } -func getPromptTokens(textRequest *openai.GeneralOpenAIRequest, relayMode int) int { +func getPromptTokens(textRequest *relaymodel.GeneralOpenAIRequest, relayMode int) int { switch relayMode { case constant.RelayModeChatCompletions: return openai.CountTokenMessages(textRequest.Messages, textRequest.Model) @@ -48,7 +48,7 @@ func getPromptTokens(textRequest 
*openai.GeneralOpenAIRequest, relayMode int) in return 0 } -func getPreConsumedQuota(textRequest *openai.GeneralOpenAIRequest, promptTokens int, ratio float64) int { +func getPreConsumedQuota(textRequest *relaymodel.GeneralOpenAIRequest, promptTokens int, ratio float64) int { preConsumedTokens := config.PreConsumedQuota if textRequest.MaxTokens != 0 { preConsumedTokens = promptTokens + textRequest.MaxTokens @@ -56,7 +56,7 @@ func getPreConsumedQuota(textRequest *openai.GeneralOpenAIRequest, promptTokens return int(float64(preConsumedTokens) * ratio) } -func preConsumeQuota(ctx context.Context, textRequest *openai.GeneralOpenAIRequest, promptTokens int, ratio float64, meta *util.RelayMeta) (int, *openai.ErrorWithStatusCode) { +func preConsumeQuota(ctx context.Context, textRequest *relaymodel.GeneralOpenAIRequest, promptTokens int, ratio float64, meta *util.RelayMeta) (int, *relaymodel.ErrorWithStatusCode) { preConsumedQuota := getPreConsumedQuota(textRequest, promptTokens, ratio) userQuota, err := model.CacheGetUserQuota(meta.UserId) @@ -85,7 +85,7 @@ func preConsumeQuota(ctx context.Context, textRequest *openai.GeneralOpenAIReque return preConsumedQuota, nil } -func postConsumeQuota(ctx context.Context, usage *openai.Usage, meta *util.RelayMeta, textRequest *openai.GeneralOpenAIRequest, ratio float64, preConsumedQuota int, modelRatio float64, groupRatio float64) { +func postConsumeQuota(ctx context.Context, usage *relaymodel.Usage, meta *util.RelayMeta, textRequest *relaymodel.GeneralOpenAIRequest, ratio float64, preConsumedQuota int, modelRatio float64, groupRatio float64) { if usage == nil { logger.Error(ctx, "usage is nil, which is unexpected") return @@ -120,27 +120,3 @@ func postConsumeQuota(ctx context.Context, usage *openai.Usage, meta *util.Relay model.UpdateChannelUsedQuota(meta.ChannelId, quota) } } - -func doRequest(ctx context.Context, c *gin.Context, meta *util.RelayMeta, isStream bool, fullRequestURL string, requestBody io.Reader) (*http.Response, error) { - req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody) - if err != nil { - return nil, err - } - SetupRequestHeaders(c, req, meta, isStream) - resp, err := util.HTTPClient.Do(req) - if err != nil { - return nil, err - } - if resp == nil { - return nil, errors.New("resp is nil") - } - err = req.Body.Close() - if err != nil { - logger.Warnf(ctx, "close req.Body failed: %+v", err) - } - err = c.Request.Body.Close() - if err != nil { - logger.Warnf(ctx, "close c.Request.Body failed: %+v", err) - } - return resp, nil -} diff --git a/relay/controller/image.go b/relay/controller/image.go index c64e001b..6ec368f5 100644 --- a/relay/controller/image.go +++ b/relay/controller/image.go @@ -10,6 +10,7 @@ import ( "github.com/songquanpeng/one-api/common/logger" "github.com/songquanpeng/one-api/model" "github.com/songquanpeng/one-api/relay/channel/openai" + relaymodel "github.com/songquanpeng/one-api/relay/model" "github.com/songquanpeng/one-api/relay/util" "io" "net/http" @@ -28,7 +29,7 @@ func isWithinRange(element string, value int) bool { return value >= min && value <= max } -func RelayImageHelper(c *gin.Context, relayMode int) *openai.ErrorWithStatusCode { +func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatusCode { imageModel := "dall-e-2" imageSize := "1024x1024" diff --git a/relay/controller/temp.go b/relay/controller/temp.go index 6339bdab..75aea4ff 100644 --- a/relay/controller/temp.go +++ b/relay/controller/temp.go @@ -12,19 +12,21 @@ import ( 
"github.com/songquanpeng/one-api/relay/channel/ali" "github.com/songquanpeng/one-api/relay/channel/anthropic" "github.com/songquanpeng/one-api/relay/channel/baidu" - "github.com/songquanpeng/one-api/relay/channel/google" + "github.com/songquanpeng/one-api/relay/channel/gemini" "github.com/songquanpeng/one-api/relay/channel/openai" + "github.com/songquanpeng/one-api/relay/channel/palm" "github.com/songquanpeng/one-api/relay/channel/tencent" "github.com/songquanpeng/one-api/relay/channel/xunfei" "github.com/songquanpeng/one-api/relay/channel/zhipu" "github.com/songquanpeng/one-api/relay/constant" + "github.com/songquanpeng/one-api/relay/model" "github.com/songquanpeng/one-api/relay/util" "io" "net/http" "strings" ) -func GetRequestURL(requestURL string, meta *util.RelayMeta, textRequest *openai.GeneralOpenAIRequest) (string, error) { +func GetRequestURL(requestURL string, meta *util.RelayMeta, textRequest *model.GeneralOpenAIRequest) (string, error) { fullRequestURL := util.GetFullRequestURL(meta.BaseURL, requestURL, meta.ChannelType) switch meta.APIType { case constant.APITypeOpenAI: @@ -43,7 +45,7 @@ func GetRequestURL(requestURL string, meta *util.RelayMeta, textRequest *openai. requestURL = fmt.Sprintf("/openai/deployments/%s/%s", model_, task) fullRequestURL = util.GetFullRequestURL(meta.BaseURL, requestURL, meta.ChannelType) } - case constant.APITypeClaude: + case constant.APITypeAnthropic: fullRequestURL = fmt.Sprintf("%s/v1/complete", meta.BaseURL) case constant.APITypeBaidu: switch textRequest.Model { @@ -92,19 +94,10 @@ func GetRequestURL(requestURL string, meta *util.RelayMeta, textRequest *openai. return fullRequestURL, nil } -func GetRequestBody(c *gin.Context, textRequest openai.GeneralOpenAIRequest, isModelMapped bool, apiType int, relayMode int) (io.Reader, error) { +func GetRequestBody(c *gin.Context, textRequest model.GeneralOpenAIRequest, isModelMapped bool, apiType int, relayMode int) (io.Reader, error) { var requestBody io.Reader - if isModelMapped { - jsonStr, err := json.Marshal(textRequest) - if err != nil { - return nil, err - } - requestBody = bytes.NewBuffer(jsonStr) - } else { - requestBody = c.Request.Body - } switch apiType { - case constant.APITypeClaude: + case constant.APITypeAnthropic: claudeRequest := anthropic.ConvertRequest(textRequest) jsonStr, err := json.Marshal(claudeRequest) if err != nil { @@ -127,14 +120,14 @@ func GetRequestBody(c *gin.Context, textRequest openai.GeneralOpenAIRequest, isM } requestBody = bytes.NewBuffer(jsonData) case constant.APITypePaLM: - palmRequest := google.ConvertPaLMRequest(textRequest) + palmRequest := palm.ConvertRequest(textRequest) jsonStr, err := json.Marshal(palmRequest) if err != nil { return nil, err } requestBody = bytes.NewBuffer(jsonStr) case constant.APITypeGemini: - geminiChatRequest := google.ConvertGeminiRequest(textRequest) + geminiChatRequest := gemini.ConvertRequest(textRequest) jsonStr, err := json.Marshal(geminiChatRequest) if err != nil { return nil, err @@ -187,19 +180,20 @@ func GetRequestBody(c *gin.Context, textRequest openai.GeneralOpenAIRequest, isM return nil, err } requestBody = bytes.NewBuffer(jsonStr) + default: + if isModelMapped { + jsonStr, err := json.Marshal(textRequest) + if err != nil { + return nil, err + } + requestBody = bytes.NewBuffer(jsonStr) + } else { + requestBody = c.Request.Body + } } return requestBody, nil } -func SetupRequestHeaders(c *gin.Context, req *http.Request, meta *util.RelayMeta, isStream bool) { - SetupAuthHeaders(c, req, meta, isStream) - 
req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type")) - req.Header.Set("Accept", c.Request.Header.Get("Accept")) - if isStream && c.Request.Header.Get("Accept") == "" { - req.Header.Set("Accept", "text/event-stream") - } -} - func SetupAuthHeaders(c *gin.Context, req *http.Request, meta *util.RelayMeta, isStream bool) { apiKey := meta.APIKey switch meta.APIType { @@ -213,7 +207,7 @@ func SetupAuthHeaders(c *gin.Context, req *http.Request, meta *util.RelayMeta, i req.Header.Set("X-Title", "One API") } } - case constant.APITypeClaude: + case constant.APITypeAnthropic: req.Header.Set("x-api-key", apiKey) anthropicVersion := c.Request.Header.Get("anthropic-version") if anthropicVersion == "" { @@ -242,7 +236,7 @@ func SetupAuthHeaders(c *gin.Context, req *http.Request, meta *util.RelayMeta, i } } -func DoResponse(c *gin.Context, textRequest *openai.GeneralOpenAIRequest, resp *http.Response, relayMode int, apiType int, isStream bool, promptTokens int) (usage *openai.Usage, err *openai.ErrorWithStatusCode) { +func DoResponse(c *gin.Context, textRequest *model.GeneralOpenAIRequest, resp *http.Response, relayMode int, apiType int, isStream bool, promptTokens int) (usage *model.Usage, err *model.ErrorWithStatusCode) { var responseText string switch apiType { case constant.APITypeOpenAI: @@ -251,7 +245,7 @@ func DoResponse(c *gin.Context, textRequest *openai.GeneralOpenAIRequest, resp * } else { err, usage = openai.Handler(c, resp, promptTokens, textRequest.Model) } - case constant.APITypeClaude: + case constant.APITypeAnthropic: if isStream { err, responseText = anthropic.StreamHandler(c, resp) } else { @@ -270,15 +264,15 @@ func DoResponse(c *gin.Context, textRequest *openai.GeneralOpenAIRequest, resp * } case constant.APITypePaLM: if isStream { // PaLM2 API does not support stream - err, responseText = google.PaLMStreamHandler(c, resp) + err, responseText = palm.StreamHandler(c, resp) } else { - err, usage = google.PaLMHandler(c, resp, promptTokens, textRequest.Model) + err, usage = palm.Handler(c, resp, promptTokens, textRequest.Model) } case constant.APITypeGemini: if isStream { - err, responseText = google.StreamHandler(c, resp) + err, responseText = gemini.StreamHandler(c, resp) } else { - err, usage = google.GeminiHandler(c, resp, promptTokens, textRequest.Model) + err, usage = gemini.Handler(c, resp, promptTokens, textRequest.Model) } case constant.APITypeZhipu: if isStream { @@ -328,7 +322,7 @@ func DoResponse(c *gin.Context, textRequest *openai.GeneralOpenAIRequest, resp * return nil, err } if usage == nil && responseText != "" { - usage = &openai.Usage{} + usage = &model.Usage{} usage.PromptTokens = promptTokens usage.CompletionTokens = openai.CountTokenText(responseText, textRequest.Model) usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens diff --git a/relay/controller/text.go b/relay/controller/text.go index 0445aa90..7c49bcce 100644 --- a/relay/controller/text.go +++ b/relay/controller/text.go @@ -1,18 +1,23 @@ package controller import ( + "bytes" + "encoding/json" "fmt" "github.com/gin-gonic/gin" "github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common/logger" "github.com/songquanpeng/one-api/relay/channel/openai" "github.com/songquanpeng/one-api/relay/constant" + "github.com/songquanpeng/one-api/relay/helper" + "github.com/songquanpeng/one-api/relay/model" "github.com/songquanpeng/one-api/relay/util" + "io" "net/http" "strings" ) -func RelayTextHelper(c *gin.Context) *openai.ErrorWithStatusCode { +func RelayTextHelper(c 
diff --git a/relay/helper/main.go b/relay/helper/main.go
new file mode 100644
index 00000000..c2b6e6af
--- /dev/null
+++ b/relay/helper/main.go
@@ -0,0 +1,42 @@
+package helper
+
+import (
+	"github.com/songquanpeng/one-api/relay/channel"
+	"github.com/songquanpeng/one-api/relay/channel/aiproxy"
+	"github.com/songquanpeng/one-api/relay/channel/ali"
+	"github.com/songquanpeng/one-api/relay/channel/anthropic"
+	"github.com/songquanpeng/one-api/relay/channel/baidu"
+	"github.com/songquanpeng/one-api/relay/channel/gemini"
+	"github.com/songquanpeng/one-api/relay/channel/openai"
+	"github.com/songquanpeng/one-api/relay/channel/palm"
+	"github.com/songquanpeng/one-api/relay/channel/tencent"
+	"github.com/songquanpeng/one-api/relay/channel/xunfei"
+	"github.com/songquanpeng/one-api/relay/channel/zhipu"
+	"github.com/songquanpeng/one-api/relay/constant"
+)
+
+func GetAdaptor(apiType int) channel.Adaptor {
+	switch apiType {
+	case constant.APITypeAIProxyLibrary:
+		return &aiproxy.Adaptor{}
+	case constant.APITypeAli:
+		return &ali.Adaptor{}
+	case constant.APITypeAnthropic:
+		return &anthropic.Adaptor{}
+	case constant.APITypeBaidu:
+		return &baidu.Adaptor{}
+	case constant.APITypeGemini:
+		return &gemini.Adaptor{}
+	case constant.APITypeOpenAI:
+		return &openai.Adaptor{}
+	case constant.APITypePaLM:
+		return &palm.Adaptor{}
+	case constant.APITypeTencent:
+		return &tencent.Adaptor{}
+	case constant.APITypeXunfei:
+		return &xunfei.Adaptor{}
+	case constant.APITypeZhipu:
+		return &zhipu.Adaptor{}
+	}
+	return nil
+}
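// Editor's note: a hedged test sketch (not part of the diff) showing how the
// APITypeDummy sentinel pairs with GetAdaptor above: every real API type below the
// sentinel should map to a non-nil adaptor. The file and test name are hypothetical.
package helper

import (
	"testing"

	"github.com/songquanpeng/one-api/relay/constant"
)

func TestGetAdaptorCoverage(t *testing.T) {
	for apiType := 0; apiType < constant.APITypeDummy; apiType++ {
		if GetAdaptor(apiType) == nil {
			t.Errorf("no adaptor registered for api type %d", apiType)
		}
	}
}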
diff --git a/relay/channel/openai/constant.go b/relay/model/constant.go
similarity index 83%
rename from relay/channel/openai/constant.go
rename to relay/model/constant.go
index 000f72ee..f6cf1924 100644
--- a/relay/channel/openai/constant.go
+++ b/relay/model/constant.go
@@ -1,4 +1,4 @@
-package openai
+package model
 
 const (
 	ContentTypeText = "text"
diff --git a/relay/model/general.go b/relay/model/general.go
new file mode 100644
index 00000000..fbcc04e8
--- /dev/null
+++ b/relay/model/general.go
@@ -0,0 +1,46 @@
+package model
+
+type ResponseFormat struct {
+	Type string `json:"type,omitempty"`
+}
+
+type GeneralOpenAIRequest struct {
+	Model            string          `json:"model,omitempty"`
+	Messages         []Message       `json:"messages,omitempty"`
+	Prompt           any             `json:"prompt,omitempty"`
+	Stream           bool            `json:"stream,omitempty"`
+	MaxTokens        int             `json:"max_tokens,omitempty"`
+	Temperature      float64         `json:"temperature,omitempty"`
+	TopP             float64         `json:"top_p,omitempty"`
+	N                int             `json:"n,omitempty"`
+	Input            any             `json:"input,omitempty"`
+	Instruction      string          `json:"instruction,omitempty"`
+	Size             string          `json:"size,omitempty"`
+	Functions        any             `json:"functions,omitempty"`
+	FrequencyPenalty float64         `json:"frequency_penalty,omitempty"`
+	PresencePenalty  float64         `json:"presence_penalty,omitempty"`
+	ResponseFormat   *ResponseFormat `json:"response_format,omitempty"`
+	Seed             float64         `json:"seed,omitempty"`
+	Tools            any             `json:"tools,omitempty"`
+	ToolChoice       any             `json:"tool_choice,omitempty"`
+	User             string          `json:"user,omitempty"`
+}
+
+func (r GeneralOpenAIRequest) ParseInput() []string {
+	if r.Input == nil {
+		return nil
+	}
+	var input []string
+	switch v := r.Input.(type) {
+	case string:
+		input = []string{v}
+	case []any:
+		input = make([]string, 0, len(v))
+		for _, item := range v {
+			if str, ok := item.(string); ok {
+				input = append(input, str)
+			}
+		}
+	}
+	return input
+}
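// Editor's note: a self-contained illustration (not from the repo) of why ParseInput
// above switches on the dynamic type: the embeddings-style "input" field may arrive
// as a single JSON string or as an array, and both must normalize to []string.
package main

import (
	"encoding/json"
	"fmt"
)

func parseInput(input any) []string {
	switch v := input.(type) {
	case string:
		return []string{v}
	case []any:
		out := make([]string, 0, len(v))
		for _, item := range v {
			if s, ok := item.(string); ok {
				out = append(out, s)
			}
		}
		return out
	}
	return nil
}

func main() {
	for _, raw := range []string{`{"input":"hello"}`, `{"input":["a","b"]}`} {
		var req struct {
			Input any `json:"input"`
		}
		if err := json.Unmarshal([]byte(raw), &req); err != nil {
			panic(err)
		}
		fmt.Println(parseInput(req.Input)) // [hello], then [a b]
	}
}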
diff --git a/relay/model/message.go b/relay/model/message.go
new file mode 100644
index 00000000..c6c8a271
--- /dev/null
+++ b/relay/model/message.go
@@ -0,0 +1,88 @@
+package model
+
+type Message struct {
+	Role    string  `json:"role"`
+	Content any     `json:"content"`
+	Name    *string `json:"name,omitempty"`
+}
+
+func (m Message) IsStringContent() bool {
+	_, ok := m.Content.(string)
+	return ok
+}
+
+func (m Message) StringContent() string {
+	content, ok := m.Content.(string)
+	if ok {
+		return content
+	}
+	contentList, ok := m.Content.([]any)
+	if ok {
+		var contentStr string
+		for _, contentItem := range contentList {
+			contentMap, ok := contentItem.(map[string]any)
+			if !ok {
+				continue
+			}
+			if contentMap["type"] == ContentTypeText {
+				if subStr, ok := contentMap["text"].(string); ok {
+					contentStr += subStr
+				}
+			}
+		}
+		return contentStr
+	}
+	return ""
+}
+
+func (m Message) ParseContent() []MessageContent {
+	var contentList []MessageContent
+	content, ok := m.Content.(string)
+	if ok {
+		contentList = append(contentList, MessageContent{
+			Type: ContentTypeText,
+			Text: content,
+		})
+		return contentList
+	}
+	anyList, ok := m.Content.([]any)
+	if ok {
+		for _, contentItem := range anyList {
+			contentMap, ok := contentItem.(map[string]any)
+			if !ok {
+				continue
+			}
+			switch contentMap["type"] {
+			case ContentTypeText:
+				if subStr, ok := contentMap["text"].(string); ok {
+					contentList = append(contentList, MessageContent{
+						Type: ContentTypeText,
+						Text: subStr,
+					})
+				}
+			case ContentTypeImageURL:
+				if subObj, ok := contentMap["image_url"].(map[string]any); ok {
+					url, _ := subObj["url"].(string) // checked assertion: a malformed payload must not panic the relay
+					contentList = append(contentList, MessageContent{
+						Type: ContentTypeImageURL,
+						ImageURL: &ImageURL{
+							Url: url,
+						},
+					})
+				}
+			}
+		}
+		return contentList
+	}
+	return nil
+}
+
+type ImageURL struct {
+	Url    string `json:"url,omitempty"`
+	Detail string `json:"detail,omitempty"`
+}
+
+type MessageContent struct {
+	Type     string    `json:"type,omitempty"`
+	Text     string    `json:"text"`
+	ImageURL *ImageURL `json:"image_url,omitempty"`
+}
diff --git a/relay/model/misc.go b/relay/model/misc.go
new file mode 100644
index 00000000..163bc398
--- /dev/null
+++ b/relay/model/misc.go
@@ -0,0 +1,19 @@
+package model
+
+type Usage struct {
+	PromptTokens     int `json:"prompt_tokens"`
+	CompletionTokens int `json:"completion_tokens"`
+	TotalTokens      int `json:"total_tokens"`
+}
+
+type Error struct {
+	Message string `json:"message"`
+	Type    string `json:"type"`
+	Param   string `json:"param"`
+	Code    any    `json:"code"`
+}
+
+type ErrorWithStatusCode struct {
+	Error
+	StatusCode int `json:"status_code"`
+}
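// Editor's note: an illustrative round-trip (not from the repo) for the Message type
// above. OpenAI-style "content" is either a plain string or a list of typed parts;
// StringContent flattens the text parts either way. The payload is made up.
package main

import (
	"encoding/json"
	"fmt"

	"github.com/songquanpeng/one-api/relay/model"
)

func main() {
	raw := `{
		"role": "user",
		"content": [
			{"type": "text", "text": "What is in "},
			{"type": "image_url", "image_url": {"url": "https://example.com/cat.png"}},
			{"type": "text", "text": "this image?"}
		]
	}`
	var msg model.Message
	if err := json.Unmarshal([]byte(raw), &msg); err != nil {
		panic(err)
	}
	fmt.Println(msg.IsStringContent())   // false
	fmt.Println(msg.StringContent())     // What is in this image?
	fmt.Println(len(msg.ParseContent())) // 3 parts: text, image_url, text
}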
diff --git a/relay/util/common.go b/relay/util/common.go
index 3a28b09e..21e1dfaf 100644
--- a/relay/util/common.go
+++ b/relay/util/common.go
@@ -8,7 +8,7 @@ import (
 	"github.com/songquanpeng/one-api/common/config"
 	"github.com/songquanpeng/one-api/common/logger"
 	"github.com/songquanpeng/one-api/model"
-	"github.com/songquanpeng/one-api/relay/channel/openai"
+	relaymodel "github.com/songquanpeng/one-api/relay/model"
 	"io"
 	"net/http"
 	"strconv"
@@ -17,7 +17,7 @@ import (
 	"github.com/gin-gonic/gin"
 )
 
-func ShouldDisableChannel(err *openai.Error, statusCode int) bool {
+func ShouldDisableChannel(err *relaymodel.Error, statusCode int) bool {
 	if !config.AutomaticDisableChannelEnabled {
 		return false
 	}
@@ -33,7 +33,7 @@ func ShouldDisableChannel(err *openai.Error, statusCode int) bool {
 	return false
 }
 
-func ShouldEnableChannel(err error, openAIErr *openai.Error) bool {
+func ShouldEnableChannel(err error, openAIErr *relaymodel.Error) bool {
 	if !config.AutomaticEnableChannelEnabled {
 		return false
 	}
@@ -47,11 +47,11 @@ func ShouldEnableChannel(err error, openAIErr *openai.Error) bool {
 }
 
 type GeneralErrorResponse struct {
-	Error    openai.Error `json:"error"`
-	Message  string       `json:"message"`
-	Msg      string       `json:"msg"`
-	Err      string       `json:"err"`
-	ErrorMsg string       `json:"error_msg"`
+	Error    relaymodel.Error `json:"error"`
+	Message  string           `json:"message"`
+	Msg      string           `json:"msg"`
+	Err      string           `json:"err"`
+	ErrorMsg string           `json:"error_msg"`
 	Header   struct {
 		Message string `json:"message"`
 	} `json:"header"`
@@ -87,10 +87,10 @@ func (e GeneralErrorResponse) ToMessage() string {
 	return ""
 }
 
-func RelayErrorHandler(resp *http.Response) (ErrorWithStatusCode *openai.ErrorWithStatusCode) {
-	ErrorWithStatusCode = &openai.ErrorWithStatusCode{
+func RelayErrorHandler(resp *http.Response) (ErrorWithStatusCode *relaymodel.ErrorWithStatusCode) {
+	ErrorWithStatusCode = &relaymodel.ErrorWithStatusCode{
 		StatusCode: resp.StatusCode,
-		Error: openai.Error{
+		Error: relaymodel.Error{
 			Message: "",
 			Type:    "upstream_error",
 			Code:    "bad_response_status_code",
diff --git a/relay/util/relay_meta.go b/relay/util/relay_meta.go
index 27757dcf..58afc23f 100644
--- a/relay/util/relay_meta.go
+++ b/relay/util/relay_meta.go
@@ -8,35 +8,41 @@ import (
 )
 
 type RelayMeta struct {
-	Mode         int
-	ChannelType  int
-	ChannelId    int
-	TokenId      int
-	TokenName    string
-	UserId       int
-	Group        string
-	ModelMapping map[string]string
-	BaseURL      string
-	APIVersion   string
-	APIKey       string
-	APIType      int
-	Config       map[string]string
+	Mode            int
+	ChannelType     int
+	ChannelId       int
+	TokenId         int
+	TokenName       string
+	UserId          int
+	Group           string
+	ModelMapping    map[string]string
+	BaseURL         string
+	APIVersion      string
+	APIKey          string
+	APIType         int
+	Config          map[string]string
+	IsStream        bool
+	OriginModelName string
+	ActualModelName string
+	RequestURLPath  string
+	PromptTokens    int // only for DoResponse
 }
 
 func GetRelayMeta(c *gin.Context) *RelayMeta {
 	meta := RelayMeta{
-		Mode:         constant.Path2RelayMode(c.Request.URL.Path),
-		ChannelType:  c.GetInt("channel"),
-		ChannelId:    c.GetInt("channel_id"),
-		TokenId:      c.GetInt("token_id"),
-		TokenName:    c.GetString("token_name"),
-		UserId:       c.GetInt("id"),
-		Group:        c.GetString("group"),
-		ModelMapping: c.GetStringMapString("model_mapping"),
-		BaseURL:      c.GetString("base_url"),
-		APIVersion:   c.GetString("api_version"),
-		APIKey:       strings.TrimPrefix(c.Request.Header.Get("Authorization"), "Bearer "),
-		Config:       nil,
+		Mode:           constant.Path2RelayMode(c.Request.URL.Path),
+		ChannelType:    c.GetInt("channel"),
+		ChannelId:      c.GetInt("channel_id"),
+		TokenId:        c.GetInt("token_id"),
+		TokenName:      c.GetString("token_name"),
+		UserId:         c.GetInt("id"),
+		Group:          c.GetString("group"),
+		ModelMapping:   c.GetStringMapString("model_mapping"),
+		BaseURL:        c.GetString("base_url"),
+		APIVersion:     c.GetString("api_version"),
+		APIKey:         strings.TrimPrefix(c.Request.Header.Get("Authorization"), "Bearer "),
+		Config:         nil,
+		RequestURLPath: c.Request.URL.String(),
 	}
 	if meta.ChannelType == common.ChannelTypeAzure {
 		meta.APIVersion = GetAzureAPIVersion(c)
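// Editor's note: a small illustration (not from the repo) of why ErrorWithStatusCode
// embeds Error: the OpenAI-compatible error body marshals from the embedded fields,
// while StatusCode travels alongside for the HTTP layer. Values are made up.
package main

import (
	"encoding/json"
	"fmt"

	relaymodel "github.com/songquanpeng/one-api/relay/model"
)

func main() {
	e := &relaymodel.ErrorWithStatusCode{
		StatusCode: 502,
		Error: relaymodel.Error{
			Message: "upstream timed out",
			Type:    "upstream_error",
			Code:    "bad_response_status_code",
		},
	}
	// Embedded fields are promoted, so handlers can read e.Type directly...
	fmt.Println(e.Type, e.StatusCode)
	// ...and the embedded struct re-wraps into the wire format clients expect.
	body, _ := json.Marshal(map[string]any{"error": e.Error})
	fmt.Println(string(body))
}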
"github.com/songquanpeng/one-api/relay/constant" + "github.com/songquanpeng/one-api/relay/model" "math" ) -func ValidateTextRequest(textRequest *openai.GeneralOpenAIRequest, relayMode int) error { +func ValidateTextRequest(textRequest *model.GeneralOpenAIRequest, relayMode int) error { if textRequest.MaxTokens < 0 || textRequest.MaxTokens > math.MaxInt32/2 { return errors.New("max_tokens is invalid") }