diff --git a/.github/workflows/linux-release.yml b/.github/workflows/linux-release.yml
index d92ed0da..04782864 100644
--- a/.github/workflows/linux-release.yml
+++ b/.github/workflows/linux-release.yml
@@ -23,7 +23,7 @@ jobs:
- uses: actions/setup-node@v3
with:
node-version: 16
- - name: Build Frontend (theme default)
+ - name: Build Frontend
env:
CI: ""
run: |
diff --git a/.github/workflows/macos-release.yml b/.github/workflows/macos-release.yml
index ce9d1f11..9142609f 100644
--- a/.github/workflows/macos-release.yml
+++ b/.github/workflows/macos-release.yml
@@ -23,7 +23,7 @@ jobs:
- uses: actions/setup-node@v3
with:
node-version: 16
- - name: Build Frontend (theme default)
+ - name: Build Frontend
env:
CI: ""
run: |
diff --git a/.github/workflows/windows-release.yml b/.github/workflows/windows-release.yml
index 9b1f16ba..c058f41d 100644
--- a/.github/workflows/windows-release.yml
+++ b/.github/workflows/windows-release.yml
@@ -26,7 +26,7 @@ jobs:
- uses: actions/setup-node@v3
with:
node-version: 16
- - name: Build Frontend (theme default)
+ - name: Build Frontend
env:
CI: ""
run: |
diff --git a/Dockerfile b/Dockerfile
index 94cd8468..ec2f9d43 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -23,7 +23,7 @@ ADD go.mod go.sum ./
RUN go mod download
COPY . .
COPY --from=builder /web/build ./web/build
-RUN go build -ldflags "-s -w -X 'one-api/common.Version=$(cat VERSION)' -extldflags '-static'" -o one-api
+RUN go build -ldflags "-s -w -X 'github.com/songquanpeng/one-api/common.Version=$(cat VERSION)' -extldflags '-static'" -o one-api
FROM alpine
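
Note on the -X fix above: the linker only injects a value when the flag names the variable's fully qualified import path. With the module named github.com/songquanpeng/one-api, the short one-api/common.Version no longer matches anything, so the injection was being silently skipped and common.Version stayed at its compiled-in default.
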
diff --git a/README.md b/README.md
index ff5e07d4..ff1fffd2 100644
--- a/README.md
+++ b/README.md
@@ -73,6 +73,9 @@ _✨ 通过标准的 OpenAI API 格式访问所有的大模型,开箱即用
+ [x] [智谱 ChatGLM 系列模型](https://bigmodel.cn)
+ [x] [360 智脑](https://ai.360.cn)
+ [x] [腾讯混元大模型](https://cloud.tencent.com/document/product/1729)
+ + [x] [Moonshot AI](https://platform.moonshot.cn/)
+ + [ ] [字节云雀大模型](https://www.volcengine.com/product/ark) (WIP)
+ + [ ] [MINIMAX](https://api.minimax.chat/) (WIP)
2. 支持配置镜像以及众多[第三方代理服务](https://iamazing.cn/page/openai-api-third-party-services)。
3. 支持通过**负载均衡**的方式访问多个渠道。
4. 支持 **stream 模式**,可以通过流式传输实现打字机效果。
diff --git a/common/constants.go b/common/constants.go
index 325454d4..ccaa3560 100644
--- a/common/constants.go
+++ b/common/constants.go
@@ -63,6 +63,7 @@ const (
ChannelTypeFastGPT = 22
ChannelTypeTencent = 23
ChannelTypeGemini = 24
+ ChannelTypeMoonshot = 25
)
var ChannelBaseURLs = []string{
@@ -91,4 +92,13 @@ var ChannelBaseURLs = []string{
"https://fastgpt.run/api/openapi", // 22
"https://hunyuan.cloud.tencent.com", // 23
"https://generativelanguage.googleapis.com", // 24
+ "https://api.moonshot.cn", // 25
}
+
+const (
+ ConfigKeyPrefix = "cfg_"
+
+ ConfigKeyAPIVersion = ConfigKeyPrefix + "api_version"
+ ConfigKeyLibraryID = ConfigKeyPrefix + "library_id"
+ ConfigKeyPlugin = ConfigKeyPrefix + "plugin"
+)
diff --git a/common/gin.go b/common/gin.go
index bed2c2b1..b6ef96a6 100644
--- a/common/gin.go
+++ b/common/gin.go
@@ -8,12 +8,24 @@ import (
"strings"
)
-func UnmarshalBodyReusable(c *gin.Context, v any) error {
+const KeyRequestBody = "key_request_body"
+
+func GetRequestBody(c *gin.Context) ([]byte, error) {
+ requestBody, _ := c.Get(KeyRequestBody)
+ if requestBody != nil {
+ return requestBody.([]byte), nil
+ }
requestBody, err := io.ReadAll(c.Request.Body)
if err != nil {
- return err
+ return nil, err
}
- err = c.Request.Body.Close()
+ _ = c.Request.Body.Close()
+ c.Set(KeyRequestBody, requestBody)
+ return requestBody.([]byte), nil
+}
+
+func UnmarshalBodyReusable(c *gin.Context, v any) error {
+ requestBody, err := GetRequestBody(c)
if err != nil {
return err
}
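
A minimal usage sketch (the handler and probe struct are illustrative, not part of the patch): the first read caches the raw bytes under KeyRequestBody, so later consumers, including the cross-channel retry in controller/relay.go below, see the same payload even after c.Request.Body has been drained.

package demo

import (
	"bytes"
	"io"
	"net/http"

	"github.com/gin-gonic/gin"
	"github.com/songquanpeng/one-api/common"
)

func handler(c *gin.Context) {
	var probe struct {
		Model string `json:"model"`
	}
	// first read: parses the body and caches the raw bytes in the gin context
	if err := common.UnmarshalBodyReusable(c, &probe); err != nil {
		c.AbortWithStatus(http.StatusBadRequest)
		return
	}
	// second read: served from the cache, no dependency on the drained stream
	body, _ := common.GetRequestBody(c)
	// restore a readable body for whatever consumes the request next
	c.Request.Body = io.NopCloser(bytes.NewBuffer(body))
}
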
diff --git a/common/helper/helper.go b/common/helper/helper.go
index a0d88ec2..babe422b 100644
--- a/common/helper/helper.go
+++ b/common/helper/helper.go
@@ -137,6 +137,7 @@ func GetUUID() string {
}
const keyChars = "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
+const keyNumbers = "0123456789"
func init() {
rand.Seed(time.Now().UnixNano())
@@ -168,6 +169,15 @@ func GetRandomString(length int) string {
return string(key)
}
+func GetRandomNumberString(length int) string {
+ rand.Seed(time.Now().UnixNano())
+ key := make([]byte, length)
+ for i := 0; i < length; i++ {
+ key[i] = keyNumbers[rand.Intn(len(keyNumbers))]
+ }
+ return string(key)
+}
+
func GetTimestamp() int64 {
return time.Now().Unix()
}
diff --git a/common/model-ratio.go b/common/model-ratio.go
index 08cde8c7..2e7aae71 100644
--- a/common/model-ratio.go
+++ b/common/model-ratio.go
@@ -30,6 +30,12 @@ var DalleImagePromptLengthLimitations = map[string]int{
"dall-e-3": 4000,
}
+const (
+ USD2RMB = 7
+ USD = 500 // $0.002 = 1 -> $1 = 500
+ RMB = 1.0 * USD / USD2RMB // keep this a float constant: plain USD / USD2RMB is integer constant division and truncates 500/7 to 71
+)
+
// ModelRatio
// https://platform.openai.com/docs/models/model-endpoint-compatibility
// https://cloud.baidu.com/doc/WENXINWORKSHOP/s/Blfmc9dlf
@@ -38,57 +44,60 @@ var DalleImagePromptLengthLimitations = map[string]int{
// 1 === $0.002 / 1K tokens
// 1 === ¥0.014 / 1k tokens
var ModelRatio = map[string]float64{
- "gpt-4": 15,
- "gpt-4-0314": 15,
- "gpt-4-0613": 15,
- "gpt-4-32k": 30,
- "gpt-4-32k-0314": 30,
- "gpt-4-32k-0613": 30,
- "gpt-4-1106-preview": 5, // $0.01 / 1K tokens
- "gpt-4-0125-preview": 5, // $0.01 / 1K tokens
- "gpt-4-turbo-preview": 5, // $0.01 / 1K tokens
- "gpt-4-vision-preview": 5, // $0.01 / 1K tokens
- "gpt-3.5-turbo": 0.75, // $0.0015 / 1K tokens
- "gpt-3.5-turbo-0301": 0.75,
- "gpt-3.5-turbo-0613": 0.75,
- "gpt-3.5-turbo-16k": 1.5, // $0.003 / 1K tokens
- "gpt-3.5-turbo-16k-0613": 1.5,
- "gpt-3.5-turbo-instruct": 0.75, // $0.0015 / 1K tokens
- "gpt-3.5-turbo-1106": 0.5, // $0.001 / 1K tokens
- "gpt-3.5-turbo-0125": 0.25, // $0.0005 / 1K tokens
- "davinci-002": 1, // $0.002 / 1K tokens
- "babbage-002": 0.2, // $0.0004 / 1K tokens
- "text-ada-001": 0.2,
- "text-babbage-001": 0.25,
- "text-curie-001": 1,
- "text-davinci-002": 10,
- "text-davinci-003": 10,
- "text-davinci-edit-001": 10,
- "code-davinci-edit-001": 10,
- "whisper-1": 15, // $0.006 / minute -> $0.006 / 150 words -> $0.006 / 200 tokens -> $0.03 / 1k tokens
- "tts-1": 7.5, // $0.015 / 1K characters
- "tts-1-1106": 7.5,
- "tts-1-hd": 15, // $0.030 / 1K characters
- "tts-1-hd-1106": 15,
- "davinci": 10,
- "curie": 10,
- "babbage": 10,
- "ada": 10,
- "text-embedding-ada-002": 0.05,
- "text-embedding-3-small": 0.01,
- "text-embedding-3-large": 0.065,
- "text-search-ada-doc-001": 10,
- "text-moderation-stable": 0.1,
- "text-moderation-latest": 0.1,
- "dall-e-2": 8, // $0.016 - $0.020 / image
- "dall-e-3": 20, // $0.040 - $0.120 / image
- "claude-instant-1": 0.815, // $1.63 / 1M tokens
- "claude-2": 5.51, // $11.02 / 1M tokens
- "claude-2.0": 5.51, // $11.02 / 1M tokens
- "claude-2.1": 5.51, // $11.02 / 1M tokens
- "ERNIE-Bot": 0.8572, // ¥0.012 / 1k tokens
- "ERNIE-Bot-turbo": 0.5715, // ¥0.008 / 1k tokens
- "ERNIE-Bot-4": 8.572, // ¥0.12 / 1k tokens
+ // https://openai.com/pricing
+ "gpt-4": 15,
+ "gpt-4-0314": 15,
+ "gpt-4-0613": 15,
+ "gpt-4-32k": 30,
+ "gpt-4-32k-0314": 30,
+ "gpt-4-32k-0613": 30,
+ "gpt-4-1106-preview": 5, // $0.01 / 1K tokens
+ "gpt-4-0125-preview": 5, // $0.01 / 1K tokens
+ "gpt-4-turbo-preview": 5, // $0.01 / 1K tokens
+ "gpt-4-vision-preview": 5, // $0.01 / 1K tokens
+ "gpt-3.5-turbo": 0.75, // $0.0015 / 1K tokens
+ "gpt-3.5-turbo-0301": 0.75,
+ "gpt-3.5-turbo-0613": 0.75,
+ "gpt-3.5-turbo-16k": 1.5, // $0.003 / 1K tokens
+ "gpt-3.5-turbo-16k-0613": 1.5,
+ "gpt-3.5-turbo-instruct": 0.75, // $0.0015 / 1K tokens
+ "gpt-3.5-turbo-1106": 0.5, // $0.001 / 1K tokens
+ "gpt-3.5-turbo-0125": 0.25, // $0.0005 / 1K tokens
+ "davinci-002": 1, // $0.002 / 1K tokens
+ "babbage-002": 0.2, // $0.0004 / 1K tokens
+ "text-ada-001": 0.2,
+ "text-babbage-001": 0.25,
+ "text-curie-001": 1,
+ "text-davinci-002": 10,
+ "text-davinci-003": 10,
+ "text-davinci-edit-001": 10,
+ "code-davinci-edit-001": 10,
+ "whisper-1": 15, // $0.006 / minute -> $0.006 / 150 words -> $0.006 / 200 tokens -> $0.03 / 1k tokens
+ "tts-1": 7.5, // $0.015 / 1K characters
+ "tts-1-1106": 7.5,
+ "tts-1-hd": 15, // $0.030 / 1K characters
+ "tts-1-hd-1106": 15,
+ "davinci": 10,
+ "curie": 10,
+ "babbage": 10,
+ "ada": 10,
+ "text-embedding-ada-002": 0.05,
+ "text-embedding-3-small": 0.01,
+ "text-embedding-3-large": 0.065,
+ "text-search-ada-doc-001": 10,
+ "text-moderation-stable": 0.1,
+ "text-moderation-latest": 0.1,
+ "dall-e-2": 8, // $0.016 - $0.020 / image
+ "dall-e-3": 20, // $0.040 - $0.120 / image
+ "claude-instant-1": 0.815, // $1.63 / 1M tokens
+ "claude-2": 5.51, // $11.02 / 1M tokens
+ "claude-2.0": 5.51, // $11.02 / 1M tokens
+ "claude-2.1": 5.51, // $11.02 / 1M tokens
+ // https://cloud.baidu.com/doc/WENXINWORKSHOP/s/hlrk4akp7
+ "ERNIE-Bot": 0.8572, // ¥0.012 / 1k tokens
+ "ERNIE-Bot-turbo": 0.5715, // ¥0.008 / 1k tokens
+ "ERNIE-Bot-4": 0.12 * RMB, // ¥0.12 / 1k tokens
+ "ERNIE-Bot-8k": 0.024 * RMB,
"Embedding-V1": 0.1429, // ¥0.002 / 1k tokens
"PaLM-2": 1,
"gemini-pro": 1, // $0.00025 / 1k characters -> $0.001 / 1k tokens
@@ -103,11 +112,21 @@ var ModelRatio = map[string]float64{
"qwen-max-longcontext": 1.4286, // ¥0.02 / 1k tokens
"text-embedding-v1": 0.05, // ¥0.0007 / 1k tokens
"SparkDesk": 1.2858, // ¥0.018 / 1k tokens
+ "SparkDesk-v1.1": 1.2858, // ¥0.018 / 1k tokens
+ "SparkDesk-v2.1": 1.2858, // ¥0.018 / 1k tokens
+ "SparkDesk-v3.1": 1.2858, // ¥0.018 / 1k tokens
+ "SparkDesk-v3.5": 1.2858, // ¥0.018 / 1k tokens
"360GPT_S2_V9": 0.8572, // ¥0.012 / 1k tokens
"embedding-bert-512-v1": 0.0715, // ¥0.001 / 1k tokens
"embedding_s1_v1": 0.0715, // ¥0.001 / 1k tokens
"semantic_similarity_s1_v1": 0.0715, // ¥0.001 / 1k tokens
"hunyuan": 7.143, // ¥0.1 / 1k tokens // https://cloud.tencent.com/document/product/1729/97731#e0e6be58-60c8-469f-bdeb-6c264ce3b4d0
+ "ChatStd": 0.01 * RMB,
+ "ChatPro": 0.1 * RMB,
+ // https://platform.moonshot.cn/pricing
+ "moonshot-v1-8k": 0.012 * RMB,
+ "moonshot-v1-32k": 0.024 * RMB,
+ "moonshot-v1-128k": 0.06 * RMB,
}
func ModelRatio2JSONString() string {
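
For reference, the RMB constant keeps the migrated entries consistent with the old literals: RMB = 500 / 7 ≈ 71.43 quota units per ¥1, so "ERNIE-Bot-4" at ¥0.12 / 1k tokens is 0.12 * 71.43 ≈ 8.571 (the previous hard-coded 8.572, up to rounding), and "moonshot-v1-8k" at ¥0.012 / 1k tokens comes to roughly 0.857.
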
diff --git a/controller/billing.go b/controller/billing.go
index 7bc19b49..7317913d 100644
--- a/controller/billing.go
+++ b/controller/billing.go
@@ -4,7 +4,7 @@ import (
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/model"
- "github.com/songquanpeng/one-api/relay/channel/openai"
+ relaymodel "github.com/songquanpeng/one-api/relay/model"
)
func GetSubscription(c *gin.Context) {
@@ -30,7 +30,7 @@ func GetSubscription(c *gin.Context) {
expiredTime = 0
}
if err != nil {
- Error := openai.Error{
+ Error := relaymodel.Error{
Message: err.Error(),
Type: "upstream_error",
}
@@ -72,7 +72,7 @@ func GetUsage(c *gin.Context) {
quota, err = model.GetUserUsedQuota(userId)
}
if err != nil {
- Error := openai.Error{
+ Error := relaymodel.Error{
Message: err.Error(),
Type: "one_api_error",
}
diff --git a/controller/channel-test.go b/controller/channel-test.go
index 9d21b469..b498f4f1 100644
--- a/controller/channel-test.go
+++ b/controller/channel-test.go
@@ -9,10 +9,14 @@ import (
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/model"
- "github.com/songquanpeng/one-api/relay/channel/openai"
+ "github.com/songquanpeng/one-api/relay/constant"
+ "github.com/songquanpeng/one-api/relay/helper"
+ relaymodel "github.com/songquanpeng/one-api/relay/model"
"github.com/songquanpeng/one-api/relay/util"
"io"
"net/http"
+ "net/http/httptest"
+ "net/url"
"strconv"
"sync"
"time"
@@ -20,87 +24,13 @@ import (
"github.com/gin-gonic/gin"
)
-func testChannel(channel *model.Channel, request openai.ChatRequest) (err error, openaiErr *openai.Error) {
- switch channel.Type {
- case common.ChannelTypePaLM:
- fallthrough
- case common.ChannelTypeGemini:
- fallthrough
- case common.ChannelTypeAnthropic:
- fallthrough
- case common.ChannelTypeBaidu:
- fallthrough
- case common.ChannelTypeZhipu:
- fallthrough
- case common.ChannelTypeAli:
- fallthrough
- case common.ChannelType360:
- fallthrough
- case common.ChannelTypeXunfei:
- return errors.New("该渠道类型当前版本不支持测试,请手动测试"), nil
- case common.ChannelTypeAzure:
- request.Model = "gpt-35-turbo"
- defer func() {
- if err != nil {
- err = errors.New("请确保已在 Azure 上创建了 gpt-35-turbo 模型,并且 apiVersion 已正确填写!")
- }
- }()
- default:
- request.Model = "gpt-3.5-turbo"
- }
- requestURL := common.ChannelBaseURLs[channel.Type]
- if channel.Type == common.ChannelTypeAzure {
- requestURL = util.GetFullRequestURL(channel.GetBaseURL(), fmt.Sprintf("/openai/deployments/%s/chat/completions?api-version=2023-03-15-preview", request.Model), channel.Type)
- } else {
- if baseURL := channel.GetBaseURL(); len(baseURL) > 0 {
- requestURL = baseURL
- }
-
- requestURL = util.GetFullRequestURL(requestURL, "/v1/chat/completions", channel.Type)
- }
- jsonData, err := json.Marshal(request)
- if err != nil {
- return err, nil
- }
- req, err := http.NewRequest("POST", requestURL, bytes.NewBuffer(jsonData))
- if err != nil {
- return err, nil
- }
- if channel.Type == common.ChannelTypeAzure {
- req.Header.Set("api-key", channel.Key)
- } else {
- req.Header.Set("Authorization", "Bearer "+channel.Key)
- }
- req.Header.Set("Content-Type", "application/json")
- resp, err := util.HTTPClient.Do(req)
- if err != nil {
- return err, nil
- }
- defer resp.Body.Close()
- var response openai.SlimTextResponse
- body, err := io.ReadAll(resp.Body)
- if err != nil {
- return err, nil
- }
- err = json.Unmarshal(body, &response)
- if err != nil {
- return fmt.Errorf("Error: %s\nResp body: %s", err, body), nil
- }
- if response.Usage.CompletionTokens == 0 {
- if response.Error.Message == "" {
- response.Error.Message = "补全 tokens 非预期返回 0"
- }
- return fmt.Errorf("type %s, code %v, message %s", response.Error.Type, response.Error.Code, response.Error.Message), &response.Error
- }
- return nil, nil
-}
-
-func buildTestRequest() *openai.ChatRequest {
- testRequest := &openai.ChatRequest{
- Model: "", // this will be set later
+func buildTestRequest() *relaymodel.GeneralOpenAIRequest {
+ testRequest := &relaymodel.GeneralOpenAIRequest{
MaxTokens: 1,
+ Stream: false,
+ Model: "gpt-3.5-turbo",
}
- testMessage := openai.Message{
+ testMessage := relaymodel.Message{
Role: "user",
Content: "hi",
}
@@ -108,6 +38,65 @@ func buildTestRequest() *openai.ChatRequest {
return testRequest
}
+func testChannel(channel *model.Channel) (err error, openaiErr *relaymodel.Error) {
+ w := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(w)
+ c.Request = &http.Request{
+ Method: "POST",
+ URL: &url.URL{Path: "/v1/chat/completions"},
+ Body: nil,
+ Header: make(http.Header),
+ }
+ c.Request.Header.Set("Authorization", "Bearer "+channel.Key)
+ c.Request.Header.Set("Content-Type", "application/json")
+ c.Set("channel", channel.Type)
+ c.Set("base_url", channel.GetBaseURL())
+ meta := util.GetRelayMeta(c)
+ apiType := constant.ChannelType2APIType(channel.Type)
+ adaptor := helper.GetAdaptor(apiType)
+ if adaptor == nil {
+ return fmt.Errorf("invalid api type: %d, adaptor is nil", apiType), nil
+ }
+ adaptor.Init(meta)
+ modelName := adaptor.GetModelList()[0]
+ request := buildTestRequest()
+ request.Model = modelName
+ meta.OriginModelName, meta.ActualModelName = modelName, modelName
+ convertedRequest, err := adaptor.ConvertRequest(c, constant.RelayModeChatCompletions, request)
+ if err != nil {
+ return err, nil
+ }
+ jsonData, err := json.Marshal(convertedRequest)
+ if err != nil {
+ return err, nil
+ }
+ requestBody := bytes.NewBuffer(jsonData)
+ c.Request.Body = io.NopCloser(requestBody)
+ resp, err := adaptor.DoRequest(c, meta, requestBody)
+ if err != nil {
+ return err, nil
+ }
+ if resp.StatusCode != http.StatusOK {
+ err := util.RelayErrorHandler(resp)
+ return fmt.Errorf("status code %d: %s", resp.StatusCode, err.Error.Message), &err.Error
+ }
+ usage, respErr := adaptor.DoResponse(c, resp, meta)
+ if respErr != nil {
+ return fmt.Errorf("%s", respErr.Error.Message), &respErr.Error
+ }
+ if usage == nil {
+ return errors.New("usage is nil"), nil
+ }
+ result := w.Result()
+ // read back the response recorded by httptest so it can be logged below
+ respBody, err := io.ReadAll(result.Body)
+ if err != nil {
+ return err, nil
+ }
+ logger.SysLog(fmt.Sprintf("testing channel #%d, response: \n%s", channel.Id, string(respBody)))
+ return nil, nil
+}
+
func TestChannel(c *gin.Context) {
id, err := strconv.Atoi(c.Param("id"))
if err != nil {
@@ -125,9 +114,8 @@ func TestChannel(c *gin.Context) {
})
return
}
- testRequest := buildTestRequest()
tik := time.Now()
- err, _ = testChannel(channel, *testRequest)
+ err, _ = testChannel(channel)
tok := time.Now()
milliseconds := tok.Sub(tik).Milliseconds()
go channel.UpdateResponseTime(milliseconds)
@@ -192,7 +180,6 @@ func testAllChannels(notify bool) error {
if err != nil {
return err
}
- testRequest := buildTestRequest()
var disableThreshold = int64(config.ChannelDisableThreshold * 1000)
if disableThreshold == 0 {
disableThreshold = 10000000 // an impossible value
@@ -201,7 +188,7 @@ func testAllChannels(notify bool) error {
for _, channel := range channels {
isChannelEnabled := channel.Status == common.ChannelStatusEnabled
tik := time.Now()
- err, openaiErr := testChannel(channel, *testRequest)
+ err, openaiErr := testChannel(channel)
tok := time.Now()
milliseconds := tok.Sub(tik).Milliseconds()
if isChannelEnabled && milliseconds > disableThreshold {
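
Design note: the rewritten testChannel drives the same adaptor pipeline as live traffic, with httptest.NewRecorder standing in for the client connection. Each adaptor probes its own first-listed model and is judged on whether DoResponse yields usage, which is why the old switch over untestable channel types could be deleted: any channel with a working adaptor is now testable.
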
diff --git a/controller/model.go b/controller/model.go
index e3e83fcd..f5760901 100644
--- a/controller/model.go
+++ b/controller/model.go
@@ -3,7 +3,11 @@ package controller
import (
"fmt"
"github.com/gin-gonic/gin"
- "github.com/songquanpeng/one-api/relay/channel/openai"
+ "github.com/songquanpeng/one-api/relay/channel/ai360"
+ "github.com/songquanpeng/one-api/relay/channel/moonshot"
+ "github.com/songquanpeng/one-api/relay/constant"
+ "github.com/songquanpeng/one-api/relay/helper"
+ relaymodel "github.com/songquanpeng/one-api/relay/model"
)
// https://platform.openai.com/docs/api-reference/models/list
@@ -53,592 +57,46 @@ func init() {
IsBlocking: false,
})
// https://platform.openai.com/docs/models/model-endpoint-compatibility
- openAIModels = []OpenAIModels{
- {
- Id: "dall-e-2",
+ for i := 0; i < constant.APITypeDummy; i++ {
+ if i == constant.APITypeAIProxyLibrary {
+ continue
+ }
+ adaptor := helper.GetAdaptor(i)
+ channelName := adaptor.GetChannelName()
+ modelNames := adaptor.GetModelList()
+ for _, modelName := range modelNames {
+ openAIModels = append(openAIModels, OpenAIModels{
+ Id: modelName,
+ Object: "model",
+ Created: 1626777600,
+ OwnedBy: channelName,
+ Permission: permission,
+ Root: modelName,
+ Parent: nil,
+ })
+ }
+ }
+ for _, modelName := range ai360.ModelList {
+ openAIModels = append(openAIModels, OpenAIModels{
+ Id: modelName,
Object: "model",
- Created: 1677649963,
- OwnedBy: "openai",
- Permission: permission,
- Root: "dall-e-2",
- Parent: nil,
- },
- {
- Id: "dall-e-3",
- Object: "model",
- Created: 1677649963,
- OwnedBy: "openai",
- Permission: permission,
- Root: "dall-e-3",
- Parent: nil,
- },
- {
- Id: "whisper-1",
- Object: "model",
- Created: 1677649963,
- OwnedBy: "openai",
- Permission: permission,
- Root: "whisper-1",
- Parent: nil,
- },
- {
- Id: "tts-1",
- Object: "model",
- Created: 1677649963,
- OwnedBy: "openai",
- Permission: permission,
- Root: "tts-1",
- Parent: nil,
- },
- {
- Id: "tts-1-1106",
- Object: "model",
- Created: 1677649963,
- OwnedBy: "openai",
- Permission: permission,
- Root: "tts-1-1106",
- Parent: nil,
- },
- {
- Id: "tts-1-hd",
- Object: "model",
- Created: 1677649963,
- OwnedBy: "openai",
- Permission: permission,
- Root: "tts-1-hd",
- Parent: nil,
- },
- {
- Id: "tts-1-hd-1106",
- Object: "model",
- Created: 1677649963,
- OwnedBy: "openai",
- Permission: permission,
- Root: "tts-1-hd-1106",
- Parent: nil,
- },
- {
- Id: "gpt-3.5-turbo",
- Object: "model",
- Created: 1677649963,
- OwnedBy: "openai",
- Permission: permission,
- Root: "gpt-3.5-turbo",
- Parent: nil,
- },
- {
- Id: "gpt-3.5-turbo-0301",
- Object: "model",
- Created: 1677649963,
- OwnedBy: "openai",
- Permission: permission,
- Root: "gpt-3.5-turbo-0301",
- Parent: nil,
- },
- {
- Id: "gpt-3.5-turbo-0613",
- Object: "model",
- Created: 1677649963,
- OwnedBy: "openai",
- Permission: permission,
- Root: "gpt-3.5-turbo-0613",
- Parent: nil,
- },
- {
- Id: "gpt-3.5-turbo-16k",
- Object: "model",
- Created: 1677649963,
- OwnedBy: "openai",
- Permission: permission,
- Root: "gpt-3.5-turbo-16k",
- Parent: nil,
- },
- {
- Id: "gpt-3.5-turbo-16k-0613",
- Object: "model",
- Created: 1677649963,
- OwnedBy: "openai",
- Permission: permission,
- Root: "gpt-3.5-turbo-16k-0613",
- Parent: nil,
- },
- {
- Id: "gpt-3.5-turbo-1106",
- Object: "model",
- Created: 1699593571,
- OwnedBy: "openai",
- Permission: permission,
- Root: "gpt-3.5-turbo-1106",
- Parent: nil,
- },
- {
- Id: "gpt-3.5-turbo-0125",
- Object: "model",
- Created: 1706232090,
- OwnedBy: "openai",
- Permission: permission,
- Root: "gpt-3.5-turbo-0125",
- Parent: nil,
- },
- {
- Id: "gpt-3.5-turbo-instruct",
- Object: "model",
- Created: 1677649963,
- OwnedBy: "openai",
- Permission: permission,
- Root: "gpt-3.5-turbo-instruct",
- Parent: nil,
- },
- {
- Id: "gpt-4",
- Object: "model",
- Created: 1677649963,
- OwnedBy: "openai",
- Permission: permission,
- Root: "gpt-4",
- Parent: nil,
- },
- {
- Id: "gpt-4-0314",
- Object: "model",
- Created: 1677649963,
- OwnedBy: "openai",
- Permission: permission,
- Root: "gpt-4-0314",
- Parent: nil,
- },
- {
- Id: "gpt-4-0613",
- Object: "model",
- Created: 1677649963,
- OwnedBy: "openai",
- Permission: permission,
- Root: "gpt-4-0613",
- Parent: nil,
- },
- {
- Id: "gpt-4-32k",
- Object: "model",
- Created: 1677649963,
- OwnedBy: "openai",
- Permission: permission,
- Root: "gpt-4-32k",
- Parent: nil,
- },
- {
- Id: "gpt-4-32k-0314",
- Object: "model",
- Created: 1677649963,
- OwnedBy: "openai",
- Permission: permission,
- Root: "gpt-4-32k-0314",
- Parent: nil,
- },
- {
- Id: "gpt-4-32k-0613",
- Object: "model",
- Created: 1677649963,
- OwnedBy: "openai",
- Permission: permission,
- Root: "gpt-4-32k-0613",
- Parent: nil,
- },
- {
- Id: "gpt-4-1106-preview",
- Object: "model",
- Created: 1699593571,
- OwnedBy: "openai",
- Permission: permission,
- Root: "gpt-4-1106-preview",
- Parent: nil,
- },
- {
- Id: "gpt-4-0125-preview",
- Object: "model",
- Created: 1706232090,
- OwnedBy: "openai",
- Permission: permission,
- Root: "gpt-4-0125-preview",
- Parent: nil,
- },
- {
- Id: "gpt-4-turbo-preview",
- Object: "model",
- Created: 1706232090,
- OwnedBy: "openai",
- Permission: permission,
- Root: "gpt-4-turbo-preview",
- Parent: nil,
- },
- {
- Id: "gpt-4-vision-preview",
- Object: "model",
- Created: 1699593571,
- OwnedBy: "openai",
- Permission: permission,
- Root: "gpt-4-vision-preview",
- Parent: nil,
- },
- {
- Id: "text-embedding-ada-002",
- Object: "model",
- Created: 1677649963,
- OwnedBy: "openai",
- Permission: permission,
- Root: "text-embedding-ada-002",
- Parent: nil,
- },
- {
- Id: "text-embedding-3-small",
- Object: "model",
- Created: 1706232090,
- OwnedBy: "openai",
- Permission: permission,
- Root: "text-embedding-3-small",
- Parent: nil,
- },
- {
- Id: "text-embedding-3-large",
- Object: "model",
- Created: 1706232090,
- OwnedBy: "openai",
- Permission: permission,
- Root: "text-embedding-3-large",
- Parent: nil,
- },
- {
- Id: "text-davinci-003",
- Object: "model",
- Created: 1677649963,
- OwnedBy: "openai",
- Permission: permission,
- Root: "text-davinci-003",
- Parent: nil,
- },
- {
- Id: "text-davinci-002",
- Object: "model",
- Created: 1677649963,
- OwnedBy: "openai",
- Permission: permission,
- Root: "text-davinci-002",
- Parent: nil,
- },
- {
- Id: "text-curie-001",
- Object: "model",
- Created: 1677649963,
- OwnedBy: "openai",
- Permission: permission,
- Root: "text-curie-001",
- Parent: nil,
- },
- {
- Id: "text-babbage-001",
- Object: "model",
- Created: 1677649963,
- OwnedBy: "openai",
- Permission: permission,
- Root: "text-babbage-001",
- Parent: nil,
- },
- {
- Id: "text-ada-001",
- Object: "model",
- Created: 1677649963,
- OwnedBy: "openai",
- Permission: permission,
- Root: "text-ada-001",
- Parent: nil,
- },
- {
- Id: "text-moderation-latest",
- Object: "model",
- Created: 1677649963,
- OwnedBy: "openai",
- Permission: permission,
- Root: "text-moderation-latest",
- Parent: nil,
- },
- {
- Id: "text-moderation-stable",
- Object: "model",
- Created: 1677649963,
- OwnedBy: "openai",
- Permission: permission,
- Root: "text-moderation-stable",
- Parent: nil,
- },
- {
- Id: "text-davinci-edit-001",
- Object: "model",
- Created: 1677649963,
- OwnedBy: "openai",
- Permission: permission,
- Root: "text-davinci-edit-001",
- Parent: nil,
- },
- {
- Id: "code-davinci-edit-001",
- Object: "model",
- Created: 1677649963,
- OwnedBy: "openai",
- Permission: permission,
- Root: "code-davinci-edit-001",
- Parent: nil,
- },
- {
- Id: "davinci-002",
- Object: "model",
- Created: 1677649963,
- OwnedBy: "openai",
- Permission: permission,
- Root: "davinci-002",
- Parent: nil,
- },
- {
- Id: "babbage-002",
- Object: "model",
- Created: 1677649963,
- OwnedBy: "openai",
- Permission: permission,
- Root: "babbage-002",
- Parent: nil,
- },
- {
- Id: "claude-instant-1",
- Object: "model",
- Created: 1677649963,
- OwnedBy: "anthropic",
- Permission: permission,
- Root: "claude-instant-1",
- Parent: nil,
- },
- {
- Id: "claude-2",
- Object: "model",
- Created: 1677649963,
- OwnedBy: "anthropic",
- Permission: permission,
- Root: "claude-2",
- Parent: nil,
- },
- {
- Id: "claude-2.1",
- Object: "model",
- Created: 1677649963,
- OwnedBy: "anthropic",
- Permission: permission,
- Root: "claude-2.1",
- Parent: nil,
- },
- {
- Id: "claude-2.0",
- Object: "model",
- Created: 1677649963,
- OwnedBy: "anthropic",
- Permission: permission,
- Root: "claude-2.0",
- Parent: nil,
- },
- {
- Id: "ERNIE-Bot",
- Object: "model",
- Created: 1677649963,
- OwnedBy: "baidu",
- Permission: permission,
- Root: "ERNIE-Bot",
- Parent: nil,
- },
- {
- Id: "ERNIE-Bot-turbo",
- Object: "model",
- Created: 1677649963,
- OwnedBy: "baidu",
- Permission: permission,
- Root: "ERNIE-Bot-turbo",
- Parent: nil,
- },
- {
- Id: "ERNIE-Bot-4",
- Object: "model",
- Created: 1677649963,
- OwnedBy: "baidu",
- Permission: permission,
- Root: "ERNIE-Bot-4",
- Parent: nil,
- },
- {
- Id: "Embedding-V1",
- Object: "model",
- Created: 1677649963,
- OwnedBy: "baidu",
- Permission: permission,
- Root: "Embedding-V1",
- Parent: nil,
- },
- {
- Id: "PaLM-2",
- Object: "model",
- Created: 1677649963,
- OwnedBy: "google palm",
- Permission: permission,
- Root: "PaLM-2",
- Parent: nil,
- },
- {
- Id: "gemini-pro",
- Object: "model",
- Created: 1677649963,
- OwnedBy: "google gemini",
- Permission: permission,
- Root: "gemini-pro",
- Parent: nil,
- },
- {
- Id: "gemini-pro-vision",
- Object: "model",
- Created: 1677649963,
- OwnedBy: "google gemini",
- Permission: permission,
- Root: "gemini-pro-vision",
- Parent: nil,
- },
- {
- Id: "chatglm_turbo",
- Object: "model",
- Created: 1677649963,
- OwnedBy: "zhipu",
- Permission: permission,
- Root: "chatglm_turbo",
- Parent: nil,
- },
- {
- Id: "chatglm_pro",
- Object: "model",
- Created: 1677649963,
- OwnedBy: "zhipu",
- Permission: permission,
- Root: "chatglm_pro",
- Parent: nil,
- },
- {
- Id: "chatglm_std",
- Object: "model",
- Created: 1677649963,
- OwnedBy: "zhipu",
- Permission: permission,
- Root: "chatglm_std",
- Parent: nil,
- },
- {
- Id: "chatglm_lite",
- Object: "model",
- Created: 1677649963,
- OwnedBy: "zhipu",
- Permission: permission,
- Root: "chatglm_lite",
- Parent: nil,
- },
- {
- Id: "qwen-turbo",
- Object: "model",
- Created: 1677649963,
- OwnedBy: "ali",
- Permission: permission,
- Root: "qwen-turbo",
- Parent: nil,
- },
- {
- Id: "qwen-plus",
- Object: "model",
- Created: 1677649963,
- OwnedBy: "ali",
- Permission: permission,
- Root: "qwen-plus",
- Parent: nil,
- },
- {
- Id: "qwen-max",
- Object: "model",
- Created: 1677649963,
- OwnedBy: "ali",
- Permission: permission,
- Root: "qwen-max",
- Parent: nil,
- },
- {
- Id: "qwen-max-longcontext",
- Object: "model",
- Created: 1677649963,
- OwnedBy: "ali",
- Permission: permission,
- Root: "qwen-max-longcontext",
- Parent: nil,
- },
- {
- Id: "text-embedding-v1",
- Object: "model",
- Created: 1677649963,
- OwnedBy: "ali",
- Permission: permission,
- Root: "text-embedding-v1",
- Parent: nil,
- },
- {
- Id: "SparkDesk",
- Object: "model",
- Created: 1677649963,
- OwnedBy: "xunfei",
- Permission: permission,
- Root: "SparkDesk",
- Parent: nil,
- },
- {
- Id: "360GPT_S2_V9",
- Object: "model",
- Created: 1677649963,
+ Created: 1626777600,
OwnedBy: "360",
Permission: permission,
- Root: "360GPT_S2_V9",
+ Root: modelName,
Parent: nil,
- },
- {
- Id: "embedding-bert-512-v1",
+ })
+ }
+ for _, modelName := range moonshot.ModelList {
+ openAIModels = append(openAIModels, OpenAIModels{
+ Id: modelName,
Object: "model",
- Created: 1677649963,
- OwnedBy: "360",
+ Created: 1626777600,
+ OwnedBy: "moonshot",
Permission: permission,
- Root: "embedding-bert-512-v1",
+ Root: modelName,
Parent: nil,
- },
- {
- Id: "embedding_s1_v1",
- Object: "model",
- Created: 1677649963,
- OwnedBy: "360",
- Permission: permission,
- Root: "embedding_s1_v1",
- Parent: nil,
- },
- {
- Id: "semantic_similarity_s1_v1",
- Object: "model",
- Created: 1677649963,
- OwnedBy: "360",
- Permission: permission,
- Root: "semantic_similarity_s1_v1",
- Parent: nil,
- },
- {
- Id: "hunyuan",
- Object: "model",
- Created: 1677649963,
- OwnedBy: "tencent",
- Permission: permission,
- Root: "hunyuan",
- Parent: nil,
- },
+ })
}
openAIModelsMap = make(map[string]OpenAIModels)
for _, model := range openAIModels {
@@ -658,7 +116,7 @@ func RetrieveModel(c *gin.Context) {
if model, ok := openAIModelsMap[modelId]; ok {
c.JSON(200, model)
} else {
- Error := openai.Error{
+ Error := relaymodel.Error{
Message: fmt.Sprintf("The model '%s' does not exist", modelId),
Type: "invalid_request_error",
Param: "model",
diff --git a/controller/relay.go b/controller/relay.go
index cfe37984..499e8ddc 100644
--- a/controller/relay.go
+++ b/controller/relay.go
@@ -1,24 +1,28 @@
package controller
import (
+ "bytes"
+ "context"
"fmt"
"github.com/gin-gonic/gin"
+ "github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/common/logger"
- "github.com/songquanpeng/one-api/relay/channel/openai"
+ "github.com/songquanpeng/one-api/middleware"
+ dbmodel "github.com/songquanpeng/one-api/model"
"github.com/songquanpeng/one-api/relay/constant"
"github.com/songquanpeng/one-api/relay/controller"
+ "github.com/songquanpeng/one-api/relay/model"
"github.com/songquanpeng/one-api/relay/util"
+ "io"
"net/http"
- "strconv"
)
// https://platform.openai.com/docs/api-reference/chat
-func Relay(c *gin.Context) {
- relayMode := constant.Path2RelayMode(c.Request.URL.Path)
- var err *openai.ErrorWithStatusCode
+func relay(c *gin.Context, relayMode int) *model.ErrorWithStatusCode {
+ var err *model.ErrorWithStatusCode
switch relayMode {
case constant.RelayModeImagesGenerations:
err = controller.RelayImageHelper(c, relayMode)
@@ -31,37 +35,90 @@ func Relay(c *gin.Context) {
default:
err = controller.RelayTextHelper(c)
}
- if err != nil {
- requestId := c.GetString(logger.RequestIdKey)
- retryTimesStr := c.Query("retry")
- retryTimes, _ := strconv.Atoi(retryTimesStr)
- if retryTimesStr == "" {
- retryTimes = config.RetryTimes
+ return err
+}
+
+func Relay(c *gin.Context) {
+ ctx := c.Request.Context()
+ relayMode := constant.Path2RelayMode(c.Request.URL.Path)
+ bizErr := relay(c, relayMode)
+ if bizErr == nil {
+ return
+ }
+ channelId := c.GetInt("channel_id")
+ lastFailedChannelId := channelId
+ channelName := c.GetString("channel_name")
+ group := c.GetString("group")
+ originalModel := c.GetString("original_model")
+ go processChannelRelayError(ctx, channelId, channelName, bizErr)
+ requestId := c.GetString(logger.RequestIdKey)
+ retryTimes := config.RetryTimes
+ if !shouldRetry(c, bizErr.StatusCode) {
+ logger.Errorf(ctx, "relay error happen, status code is %d, won't retry in this case", bizErr.StatusCode)
+ retryTimes = 0
+ }
+ for i := retryTimes; i > 0; i-- {
+ channel, err := dbmodel.CacheGetRandomSatisfiedChannel(group, originalModel)
+ if err != nil {
+ logger.Errorf(ctx, "CacheGetRandomSatisfiedChannel failed: %w", err)
+ break
}
- if retryTimes > 0 {
- c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s?retry=%d", c.Request.URL.Path, retryTimes-1))
- } else {
- if err.StatusCode == http.StatusTooManyRequests {
- err.Error.Message = "当前分组上游负载已饱和,请稍后再试"
- }
- err.Error.Message = helper.MessageWithRequestId(err.Error.Message, requestId)
- c.JSON(err.StatusCode, gin.H{
- "error": err.Error,
- })
+ if channel.Id == lastFailedChannelId {
+ continue
+ }
+ logger.Infof(ctx, "using channel #%d to retry (remain times %d)", channel.Id, i)
+ middleware.SetupContextForSelectedChannel(c, channel, originalModel)
+ requestBody, err := common.GetRequestBody(c)
+ if err != nil {
+ logger.Errorf(ctx, "GetRequestBody failed: %v", err)
+ break
+ }
+ c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
+ bizErr = relay(c, relayMode)
+ if bizErr == nil {
+ return
}
channelId := c.GetInt("channel_id")
- logger.Error(c.Request.Context(), fmt.Sprintf("relay error (channel #%d): %s", channelId, err.Message))
- // https://platform.openai.com/docs/guides/error-codes/api-errors
- if util.ShouldDisableChannel(&err.Error, err.StatusCode) {
- channelId := c.GetInt("channel_id")
- channelName := c.GetString("channel_name")
- disableChannel(channelId, channelName, err.Message)
+ lastFailedChannelId = channelId
+ channelName := c.GetString("channel_name")
+ go processChannelRelayError(ctx, channelId, channelName, bizErr)
+ }
+ if bizErr != nil {
+ if bizErr.StatusCode == http.StatusTooManyRequests {
+ bizErr.Error.Message = "当前分组上游负载已饱和,请稍后再试"
}
+ bizErr.Error.Message = helper.MessageWithRequestId(bizErr.Error.Message, requestId)
+ c.JSON(bizErr.StatusCode, gin.H{
+ "error": bizErr.Error,
+ })
+ }
+}
+
+func shouldRetry(c *gin.Context, statusCode int) bool {
+ if _, ok := c.Get("specific_channel_id"); ok {
+ return false
+ }
+ if statusCode == http.StatusTooManyRequests {
+ return true
+ }
+ if statusCode/100 == 5 {
+ return true
+ }
+ if statusCode == http.StatusBadRequest {
+ return false
+ }
+ if statusCode/100 == 2 {
+ return false
+ }
+ return true
+}
+
+func processChannelRelayError(ctx context.Context, channelId int, channelName string, err *model.ErrorWithStatusCode) {
+ logger.Errorf(ctx, "relay error (channel #%d): %s", channelId, err.Message)
+ // https://platform.openai.com/docs/guides/error-codes/api-errors
+ if util.ShouldDisableChannel(&err.Error, err.StatusCode) {
+ disableChannel(channelId, channelName, err.Message)
}
}
func RelayNotImplemented(c *gin.Context) {
- err := openai.Error{
+ err := model.Error{
Message: "API not implemented",
Type: "one_api_error",
Param: "",
@@ -73,7 +130,7 @@ func RelayNotImplemented(c *gin.Context) {
}
func RelayNotFound(c *gin.Context) {
- err := openai.Error{
+ err := model.Error{
Message: fmt.Sprintf("Invalid URL (%s %s)", c.Request.Method, c.Request.URL.Path),
Type: "invalid_request_error",
Param: "",
diff --git a/i18n/en.json b/i18n/en.json
index 774be837..54728e2f 100644
--- a/i18n/en.json
+++ b/i18n/en.json
@@ -456,6 +456,7 @@
"已绑定的邮箱账户": "Email Account Bound",
"用户信息更新成功!": "User information updated successfully!",
"模型倍率 %.2f,分组倍率 %.2f": "model rate %.2f, group rate %.2f",
+ "模型倍率 %.2f,分组倍率 %.2f,补全倍率 %.2f": "model rate %.2f, group rate %.2f, completion rate %.2f",
"使用明细(总消耗额度:{renderQuota(stat.quota)})": "Usage Details (Total Consumption Quota: {renderQuota(stat.quota)})",
"用户名称": "User Name",
"令牌名称": "Token Name",
diff --git a/middleware/auth.go b/middleware/auth.go
index 42a599d0..9d25f395 100644
--- a/middleware/auth.go
+++ b/middleware/auth.go
@@ -108,7 +108,7 @@ func TokenAuth() func(c *gin.Context) {
c.Set("token_name", token.Name)
if len(parts) > 1 {
if model.IsAdmin(token.UserId) {
- c.Set("channelId", parts[1])
+ c.Set("specific_channel_id", parts[1])
} else {
abortWithMessage(c, http.StatusForbidden, "普通用户不支持指定渠道")
return
diff --git a/middleware/distributor.go b/middleware/distributor.go
index 0ed250fd..aeb2796a 100644
--- a/middleware/distributor.go
+++ b/middleware/distributor.go
@@ -21,8 +21,9 @@ func Distribute() func(c *gin.Context) {
userId := c.GetInt("id")
userGroup, _ := model.CacheGetUserGroup(userId)
c.Set("group", userGroup)
+ var requestModel string
var channel *model.Channel
- channelId, ok := c.Get("channelId")
+ channelId, ok := c.Get("specific_channel_id")
if ok {
id, err := strconv.Atoi(channelId.(string))
if err != nil {
@@ -66,6 +67,7 @@ func Distribute() func(c *gin.Context) {
modelRequest.Model = "whisper-1"
}
}
+ requestModel = modelRequest.Model
channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, modelRequest.Model)
if err != nil {
message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", userGroup, modelRequest.Model)
@@ -77,24 +79,34 @@ func Distribute() func(c *gin.Context) {
return
}
}
- c.Set("channel", channel.Type)
- c.Set("channel_id", channel.Id)
- c.Set("channel_name", channel.Name)
- c.Set("model_mapping", channel.GetModelMapping())
- c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key))
- c.Set("base_url", channel.GetBaseURL())
- switch channel.Type {
- case common.ChannelTypeAzure:
- c.Set("api_version", channel.Other)
- case common.ChannelTypeXunfei:
- c.Set("api_version", channel.Other)
- case common.ChannelTypeGemini:
- c.Set("api_version", channel.Other)
- case common.ChannelTypeAIProxyLibrary:
- c.Set("library_id", channel.Other)
- case common.ChannelTypeAli:
- c.Set("plugin", channel.Other)
- }
+ SetupContextForSelectedChannel(c, channel, requestModel)
c.Next()
}
}
+
+func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, modelName string) {
+ c.Set("channel", channel.Type)
+ c.Set("channel_id", channel.Id)
+ c.Set("channel_name", channel.Name)
+ c.Set("model_mapping", channel.GetModelMapping())
+ c.Set("original_model", modelName) // for retry
+ c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key))
+ c.Set("base_url", channel.GetBaseURL())
+ // this is for backward compatibility
+ switch channel.Type {
+ case common.ChannelTypeAzure:
+ c.Set(common.ConfigKeyAPIVersion, channel.Other)
+ case common.ChannelTypeXunfei:
+ c.Set(common.ConfigKeyAPIVersion, channel.Other)
+ case common.ChannelTypeGemini:
+ c.Set(common.ConfigKeyAPIVersion, channel.Other)
+ case common.ChannelTypeAIProxyLibrary:
+ c.Set(common.ConfigKeyLibraryID, channel.Other)
+ case common.ChannelTypeAli:
+ c.Set(common.ConfigKeyPlugin, channel.Other)
+ }
+ cfg, _ := channel.LoadConfig()
+ for k, v := range cfg {
+ c.Set(common.ConfigKeyPrefix+k, v)
+ }
+}
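
The cfg_ prefix is what makes the legacy and new storage styles interchangeable: channel.Other lands in the context under the well-known constants, and every entry of the Config JSON map (via channel.LoadConfig, defined in model/channel.go below) lands under ConfigKeyPrefix plus its key, which resolves to the same strings. A minimal sketch, assuming a hypothetical Azure channel whose Config column holds {"api_version": "2023-07-01-preview"}:

package middleware

import (
	"github.com/gin-gonic/gin"
	"github.com/songquanpeng/one-api/common"
)

// azureAPIVersion reads the same context key regardless of whether the
// channel stored its version in the deprecated Other field or in Config,
// since ConfigKeyAPIVersion == ConfigKeyPrefix + "api_version".
func azureAPIVersion(c *gin.Context) string {
	return c.GetString(common.ConfigKeyAPIVersion)
}
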
diff --git a/middleware/request-id.go b/middleware/request-id.go
index 7cb66e93..234a93d8 100644
--- a/middleware/request-id.go
+++ b/middleware/request-id.go
@@ -9,7 +9,7 @@ import (
func RequestId() func(c *gin.Context) {
return func(c *gin.Context) {
- id := helper.GetTimeString() + helper.GetRandomString(8)
+ id := helper.GetTimeString() + helper.GetRandomNumberString(8)
c.Set(logger.RequestIdKey, id)
ctx := context.WithValue(c.Request.Context(), logger.RequestIdKey, id)
c.Request = c.Request.WithContext(ctx)
diff --git a/model/cache.go b/model/cache.go
index 297df153..04a60348 100644
--- a/model/cache.go
+++ b/model/cache.go
@@ -94,7 +94,7 @@ func CacheUpdateUserQuota(id int) error {
if !common.RedisEnabled {
return nil
}
- quota, err := GetUserQuota(id)
+ quota, err := CacheGetUserQuota(id)
if err != nil {
return err
}
diff --git a/model/channel.go b/model/channel.go
index 0503a620..19af2263 100644
--- a/model/channel.go
+++ b/model/channel.go
@@ -21,7 +21,7 @@ type Channel struct {
TestTime int64 `json:"test_time" gorm:"bigint"`
ResponseTime int `json:"response_time"` // in milliseconds
BaseURL *string `json:"base_url" gorm:"column:base_url;default:''"`
- Other string `json:"other"`
+ Other string `json:"other"` // DEPRECATED: please save config to field Config
Balance float64 `json:"balance"` // in USD
BalanceUpdatedTime int64 `json:"balance_updated_time" gorm:"bigint"`
Models string `json:"models"`
@@ -29,6 +29,7 @@ type Channel struct {
UsedQuota int64 `json:"used_quota" gorm:"bigint;default:0"`
ModelMapping *string `json:"model_mapping" gorm:"type:varchar(1024);default:''"`
Priority *int64 `json:"priority" gorm:"bigint;default:0"`
+ Config string `json:"config"`
}
func GetAllChannels(startIdx int, num int, selectAll bool) ([]*Channel, error) {
@@ -155,6 +156,18 @@ func (channel *Channel) Delete() error {
return err
}
+func (channel *Channel) LoadConfig() (map[string]string, error) {
+ if channel.Config == "" {
+ return nil, nil
+ }
+ cfg := make(map[string]string)
+ err := json.Unmarshal([]byte(channel.Config), &cfg)
+ if err != nil {
+ return nil, err
+ }
+ return cfg, nil
+}
+
func UpdateChannelStatusById(id int, status int) {
err := UpdateAbilityStatus(id, status == common.ChannelStatusEnabled)
if err != nil {
diff --git a/relay/channel/ai360/constants.go b/relay/channel/ai360/constants.go
new file mode 100644
index 00000000..cfc3cb28
--- /dev/null
+++ b/relay/channel/ai360/constants.go
@@ -0,0 +1,8 @@
+package ai360
+
+var ModelList = []string{
+ "360GPT_S2_V9",
+ "embedding-bert-512-v1",
+ "embedding_s1_v1",
+ "semantic_similarity_s1_v1",
+}
diff --git a/relay/channel/aiproxy/adaptor.go b/relay/channel/aiproxy/adaptor.go
index 7e737e8f..2b4e3022 100644
--- a/relay/channel/aiproxy/adaptor.go
+++ b/relay/channel/aiproxy/adaptor.go
@@ -1,22 +1,60 @@
package aiproxy
import (
+ "errors"
+ "fmt"
"github.com/gin-gonic/gin"
- "github.com/songquanpeng/one-api/relay/channel/openai"
+ "github.com/songquanpeng/one-api/common"
+ "github.com/songquanpeng/one-api/relay/channel"
+ "github.com/songquanpeng/one-api/relay/model"
+ "github.com/songquanpeng/one-api/relay/util"
+ "io"
"net/http"
)
type Adaptor struct {
}
-func (a *Adaptor) Auth(c *gin.Context) error {
+func (a *Adaptor) Init(meta *util.RelayMeta) {
+
+}
+
+func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) {
+ return fmt.Sprintf("%s/api/library/ask", meta.BaseURL), nil
+}
+
+func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *util.RelayMeta) error {
+ channel.SetupCommonRequestHeader(c, req, meta)
+ req.Header.Set("Authorization", "Bearer "+meta.APIKey)
return nil
}
-func (a *Adaptor) ConvertRequest(request *openai.GeneralOpenAIRequest) (any, error) {
- return nil, nil
+func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) {
+ if request == nil {
+ return nil, errors.New("request is nil")
+ }
+ aiProxyLibraryRequest := ConvertRequest(*request)
+ aiProxyLibraryRequest.LibraryId = c.GetString(common.ConfigKeyLibraryID)
+ return aiProxyLibraryRequest, nil
}
-func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, *openai.Usage, error) {
- return nil, nil, nil
+func (a *Adaptor) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error) {
+ return channel.DoRequestHelper(a, c, meta, requestBody)
+}
+
+func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
+ if meta.IsStream {
+ err, usage = StreamHandler(c, resp)
+ } else {
+ err, usage = Handler(c, resp)
+ }
+ return
+}
+
+func (a *Adaptor) GetModelList() []string {
+ return ModelList
+}
+
+func (a *Adaptor) GetChannelName() string {
+ return "aiproxy"
}
diff --git a/relay/channel/aiproxy/constants.go b/relay/channel/aiproxy/constants.go
new file mode 100644
index 00000000..c4df51c4
--- /dev/null
+++ b/relay/channel/aiproxy/constants.go
@@ -0,0 +1,6 @@
+package aiproxy
+
+import "github.com/songquanpeng/one-api/relay/channel/openai"
+
+// the AI Proxy library channel fronts OpenAI models, so reuse that list
+var ModelList = openai.ModelList
diff --git a/relay/channel/aiproxy/main.go b/relay/channel/aiproxy/main.go
index 0bd345c7..96972407 100644
--- a/relay/channel/aiproxy/main.go
+++ b/relay/channel/aiproxy/main.go
@@ -10,6 +10,7 @@ import (
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/relay/channel/openai"
"github.com/songquanpeng/one-api/relay/constant"
+ "github.com/songquanpeng/one-api/relay/model"
"io"
"net/http"
"strconv"
@@ -18,7 +19,7 @@ import (
// https://docs.aiproxy.io/dev/library#使用已经定制好的知识库进行对话问答
-func ConvertRequest(request openai.GeneralOpenAIRequest) *LibraryRequest {
+func ConvertRequest(request model.GeneralOpenAIRequest) *LibraryRequest {
query := ""
if len(request.Messages) != 0 {
query = request.Messages[len(request.Messages)-1].StringContent()
@@ -45,14 +46,14 @@ func responseAIProxyLibrary2OpenAI(response *LibraryResponse) *openai.TextRespon
content := response.Answer + aiProxyDocuments2Markdown(response.Documents)
choice := openai.TextResponseChoice{
Index: 0,
- Message: openai.Message{
+ Message: model.Message{
Role: "assistant",
Content: content,
},
FinishReason: "stop",
}
fullTextResponse := openai.TextResponse{
- Id: helper.GetUUID(),
+ Id: fmt.Sprintf("chatcmpl-%s", helper.GetUUID()),
Object: "chat.completion",
Created: helper.GetTimestamp(),
Choices: []openai.TextResponseChoice{choice},
@@ -65,7 +66,7 @@ func documentsAIProxyLibrary(documents []LibraryDocument) *openai.ChatCompletion
choice.Delta.Content = aiProxyDocuments2Markdown(documents)
choice.FinishReason = &constant.StopFinishReason
return &openai.ChatCompletionsStreamResponse{
- Id: helper.GetUUID(),
+ Id: fmt.Sprintf("chatcmpl-%s", helper.GetUUID()),
Object: "chat.completion.chunk",
Created: helper.GetTimestamp(),
Model: "",
@@ -77,7 +78,7 @@ func streamResponseAIProxyLibrary2OpenAI(response *LibraryStreamResponse) *opena
var choice openai.ChatCompletionsStreamResponseChoice
choice.Delta.Content = response.Content
return &openai.ChatCompletionsStreamResponse{
- Id: helper.GetUUID(),
+ Id: fmt.Sprintf("chatcmpl-%s", helper.GetUUID()),
Object: "chat.completion.chunk",
Created: helper.GetTimestamp(),
Model: response.Model,
@@ -85,8 +86,8 @@ func streamResponseAIProxyLibrary2OpenAI(response *LibraryStreamResponse) *opena
}
}
-func StreamHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, *openai.Usage) {
- var usage openai.Usage
+func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
+ var usage model.Usage
scanner := bufio.NewScanner(resp.Body)
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
if atEOF && len(data) == 0 {
@@ -157,7 +158,7 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatus
return nil, &usage
}
-func Handler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, *openai.Usage) {
+func Handler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
var AIProxyLibraryResponse LibraryResponse
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
@@ -172,8 +173,8 @@ func Handler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode,
return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
if AIProxyLibraryResponse.ErrCode != 0 {
- return &openai.ErrorWithStatusCode{
- Error: openai.Error{
+ return &model.ErrorWithStatusCode{
+ Error: model.Error{
Message: AIProxyLibraryResponse.Message,
Type: strconv.Itoa(AIProxyLibraryResponse.ErrCode),
Code: AIProxyLibraryResponse.ErrCode,
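
The chatcmpl- prefix on the generated ids mirrors what OpenAI returns for chat completions; some client SDKs and logging pipelines key off that prefix, so a bare UUID could trip them up.
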
diff --git a/relay/channel/ali/adaptor.go b/relay/channel/ali/adaptor.go
index 9470eff0..6c6f433e 100644
--- a/relay/channel/ali/adaptor.go
+++ b/relay/channel/ali/adaptor.go
@@ -1,22 +1,83 @@
package ali
import (
+ "errors"
+ "fmt"
"github.com/gin-gonic/gin"
- "github.com/songquanpeng/one-api/relay/channel/openai"
+ "github.com/songquanpeng/one-api/common"
+ "github.com/songquanpeng/one-api/relay/channel"
+ "github.com/songquanpeng/one-api/relay/constant"
+ "github.com/songquanpeng/one-api/relay/model"
+ "github.com/songquanpeng/one-api/relay/util"
+ "io"
"net/http"
)
+// https://help.aliyun.com/zh/dashscope/developer-reference/api-details
+
type Adaptor struct {
}
-func (a *Adaptor) Auth(c *gin.Context) error {
+func (a *Adaptor) Init(meta *util.RelayMeta) {
+
+}
+
+func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) {
+ fullRequestURL := fmt.Sprintf("%s/api/v1/services/aigc/text-generation/generation", meta.BaseURL)
+ if meta.Mode == constant.RelayModeEmbeddings {
+ fullRequestURL = fmt.Sprintf("%s/api/v1/services/embeddings/text-embedding/text-embedding", meta.BaseURL)
+ }
+ return fullRequestURL, nil
+}
+
+func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *util.RelayMeta) error {
+ channel.SetupCommonRequestHeader(c, req, meta)
+ req.Header.Set("Authorization", "Bearer "+meta.APIKey)
+ if meta.IsStream {
+ req.Header.Set("X-DashScope-SSE", "enable")
+ }
+ if c.GetString(common.ConfigKeyPlugin) != "" {
+ req.Header.Set("X-DashScope-Plugin", c.GetString(common.ConfigKeyPlugin))
+ }
return nil
}
-func (a *Adaptor) ConvertRequest(request *openai.GeneralOpenAIRequest) (any, error) {
- return nil, nil
+func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) {
+ if request == nil {
+ return nil, errors.New("request is nil")
+ }
+ switch relayMode {
+ case constant.RelayModeEmbeddings:
+ aliEmbeddingRequest := ConvertEmbeddingRequest(*request)
+ return aliEmbeddingRequest, nil
+ default:
+ aliRequest := ConvertRequest(*request)
+ return aliRequest, nil
+ }
}
-func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, *openai.Usage, error) {
- return nil, nil, nil
+func (a *Adaptor) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error) {
+ return channel.DoRequestHelper(a, c, meta, requestBody)
+}
+
+func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
+ if meta.IsStream {
+ err, usage = StreamHandler(c, resp)
+ } else {
+ switch meta.Mode {
+ case constant.RelayModeEmbeddings:
+ err, usage = EmbeddingHandler(c, resp)
+ default:
+ err, usage = Handler(c, resp)
+ }
+ }
+ return
+}
+
+func (a *Adaptor) GetModelList() []string {
+ return ModelList
+}
+
+func (a *Adaptor) GetChannelName() string {
+ return "ali"
}
diff --git a/relay/channel/ali/constants.go b/relay/channel/ali/constants.go
new file mode 100644
index 00000000..16bcfca4
--- /dev/null
+++ b/relay/channel/ali/constants.go
@@ -0,0 +1,6 @@
+package ali
+
+var ModelList = []string{
+ "qwen-turbo", "qwen-plus", "qwen-max", "qwen-max-longcontext",
+ "text-embedding-v1",
+}
diff --git a/relay/channel/ali/main.go b/relay/channel/ali/main.go
index 70476d2e..b9625584 100644
--- a/relay/channel/ali/main.go
+++ b/relay/channel/ali/main.go
@@ -8,6 +8,7 @@ import (
"github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/relay/channel/openai"
+ "github.com/songquanpeng/one-api/relay/model"
"io"
"net/http"
"strings"
@@ -17,7 +18,7 @@ import (
const EnableSearchModelSuffix = "-internet"
-func ConvertRequest(request openai.GeneralOpenAIRequest) *ChatRequest {
+func ConvertRequest(request model.GeneralOpenAIRequest) *ChatRequest {
messages := make([]Message, 0, len(request.Messages))
for i := 0; i < len(request.Messages); i++ {
message := request.Messages[i]
@@ -40,11 +41,12 @@ func ConvertRequest(request openai.GeneralOpenAIRequest) *ChatRequest {
Parameters: Parameters{
EnableSearch: enableSearch,
IncrementalOutput: request.Stream,
+ Seed: uint64(request.Seed),
},
}
}
-func ConvertEmbeddingRequest(request openai.GeneralOpenAIRequest) *EmbeddingRequest {
+func ConvertEmbeddingRequest(request model.GeneralOpenAIRequest) *EmbeddingRequest {
return &EmbeddingRequest{
Model: "text-embedding-v1",
Input: struct {
@@ -55,7 +57,7 @@ func ConvertEmbeddingRequest(request openai.GeneralOpenAIRequest) *EmbeddingRequ
}
}
-func EmbeddingHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, *openai.Usage) {
+func EmbeddingHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
var aliResponse EmbeddingResponse
err := json.NewDecoder(resp.Body).Decode(&aliResponse)
if err != nil {
@@ -68,8 +70,8 @@ func EmbeddingHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithSta
}
if aliResponse.Code != "" {
- return &openai.ErrorWithStatusCode{
- Error: openai.Error{
+ return &model.ErrorWithStatusCode{
+ Error: model.Error{
Message: aliResponse.Message,
Type: aliResponse.Code,
Param: aliResponse.RequestId,
@@ -95,7 +97,7 @@ func embeddingResponseAli2OpenAI(response *EmbeddingResponse) *openai.EmbeddingR
Object: "list",
Data: make([]openai.EmbeddingResponseItem, 0, len(response.Output.Embeddings)),
Model: "text-embedding-v1",
- Usage: openai.Usage{TotalTokens: response.Usage.TotalTokens},
+ Usage: model.Usage{TotalTokens: response.Usage.TotalTokens},
}
for _, item := range response.Output.Embeddings {
@@ -111,7 +113,7 @@ func embeddingResponseAli2OpenAI(response *EmbeddingResponse) *openai.EmbeddingR
func responseAli2OpenAI(response *ChatResponse) *openai.TextResponse {
choice := openai.TextResponseChoice{
Index: 0,
- Message: openai.Message{
+ Message: model.Message{
Role: "assistant",
Content: response.Output.Text,
},
@@ -122,7 +124,7 @@ func responseAli2OpenAI(response *ChatResponse) *openai.TextResponse {
Object: "chat.completion",
Created: helper.GetTimestamp(),
Choices: []openai.TextResponseChoice{choice},
- Usage: openai.Usage{
+ Usage: model.Usage{
PromptTokens: response.Usage.InputTokens,
CompletionTokens: response.Usage.OutputTokens,
TotalTokens: response.Usage.InputTokens + response.Usage.OutputTokens,
@@ -148,8 +150,8 @@ func streamResponseAli2OpenAI(aliResponse *ChatResponse) *openai.ChatCompletions
return &response
}
-func StreamHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, *openai.Usage) {
- var usage openai.Usage
+func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
+ var usage model.Usage
scanner := bufio.NewScanner(resp.Body)
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
if atEOF && len(data) == 0 {
@@ -217,7 +219,7 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatus
return nil, &usage
}
-func Handler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, *openai.Usage) {
+func Handler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
var aliResponse ChatResponse
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
@@ -232,8 +234,8 @@ func Handler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode,
return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
if aliResponse.Code != "" {
- return &openai.ErrorWithStatusCode{
- Error: openai.Error{
+ return &model.ErrorWithStatusCode{
+ Error: model.Error{
Message: aliResponse.Message,
Type: aliResponse.Code,
Param: aliResponse.RequestId,
diff --git a/relay/channel/anthropic/adaptor.go b/relay/channel/anthropic/adaptor.go
index ee02fab5..4b873715 100644
--- a/relay/channel/anthropic/adaptor.go
+++ b/relay/channel/anthropic/adaptor.go
@@ -1,22 +1,65 @@
package anthropic
import (
+ "errors"
+ "fmt"
"github.com/gin-gonic/gin"
+ "github.com/songquanpeng/one-api/relay/channel"
"github.com/songquanpeng/one-api/relay/channel/openai"
+ "github.com/songquanpeng/one-api/relay/model"
+ "github.com/songquanpeng/one-api/relay/util"
+ "io"
"net/http"
)
type Adaptor struct {
}
-func (a *Adaptor) Auth(c *gin.Context) error {
+func (a *Adaptor) Init(meta *util.RelayMeta) {
+
+}
+
+func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) {
+ return fmt.Sprintf("%s/v1/complete", meta.BaseURL), nil
+}
+
+func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *util.RelayMeta) error {
+ channel.SetupCommonRequestHeader(c, req, meta)
+ req.Header.Set("x-api-key", meta.APIKey)
+ anthropicVersion := c.Request.Header.Get("anthropic-version")
+ if anthropicVersion == "" {
+ anthropicVersion = "2023-06-01"
+ }
+ req.Header.Set("anthropic-version", anthropicVersion)
return nil
}
-func (a *Adaptor) ConvertRequest(request *openai.GeneralOpenAIRequest) (any, error) {
- return nil, nil
+func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) {
+ if request == nil {
+ return nil, errors.New("request is nil")
+ }
+ return ConvertRequest(*request), nil
}
-func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, *openai.Usage, error) {
- return nil, nil, nil
+func (a *Adaptor) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error) {
+ return channel.DoRequestHelper(a, c, meta, requestBody)
+}
+
+func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
+ if meta.IsStream {
+ var responseText string
+ err, responseText = StreamHandler(c, resp)
+ usage = openai.ResponseText2Usage(responseText, meta.ActualModelName, meta.PromptTokens)
+ } else {
+ err, usage = Handler(c, resp, meta.PromptTokens, meta.ActualModelName)
+ }
+ return
+}
+
+func (a *Adaptor) GetModelList() []string {
+ return ModelList
+}
+
+func (a *Adaptor) GetChannelName() string {
+ return "authropic"
}
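
Aside: with the stub gone, ConvertRequest now delegates to the Claude converter in main.go. A hedged sketch of what that yields for a one-message chat; the exact prompt layout is whatever the Human/Assistant transcript built by ConvertRequest produces, and the comment is illustrative:

```go
package main

import (
	"encoding/json"
	"fmt"

	"github.com/songquanpeng/one-api/relay/channel/anthropic"
	"github.com/songquanpeng/one-api/relay/model"
)

func main() {
	req := model.GeneralOpenAIRequest{
		Model:    "claude-2",
		Messages: []model.Message{{Role: "user", Content: "Hello"}},
	}
	// ConvertRequest flattens chat messages into the Human/Assistant
	// transcript style prompt that /v1/complete expects.
	claudeReq := anthropic.ConvertRequest(req)
	b, _ := json.MarshalIndent(claudeReq, "", "  ")
	fmt.Println(string(b))
}
```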
diff --git a/relay/channel/anthropic/constants.go b/relay/channel/anthropic/constants.go
new file mode 100644
index 00000000..b98c15c2
--- /dev/null
+++ b/relay/channel/anthropic/constants.go
@@ -0,0 +1,5 @@
+package anthropic
+
+var ModelList = []string{
+ "claude-instant-1", "claude-2", "claude-2.0", "claude-2.1",
+}
diff --git a/relay/channel/anthropic/main.go b/relay/channel/anthropic/main.go
index c1e39494..e2c575fa 100644
--- a/relay/channel/anthropic/main.go
+++ b/relay/channel/anthropic/main.go
@@ -9,6 +9,7 @@ import (
"github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/relay/channel/openai"
+ "github.com/songquanpeng/one-api/relay/model"
"io"
"net/http"
"strings"
@@ -25,7 +26,7 @@ func stopReasonClaude2OpenAI(reason string) string {
}
}
-func ConvertRequest(textRequest openai.GeneralOpenAIRequest) *Request {
+func ConvertRequest(textRequest model.GeneralOpenAIRequest) *Request {
claudeRequest := Request{
Model: textRequest.Model,
Prompt: "",
@@ -72,7 +73,7 @@ func streamResponseClaude2OpenAI(claudeResponse *Response) *openai.ChatCompletio
func responseClaude2OpenAI(claudeResponse *Response) *openai.TextResponse {
choice := openai.TextResponseChoice{
Index: 0,
- Message: openai.Message{
+ Message: model.Message{
Role: "assistant",
Content: strings.TrimPrefix(claudeResponse.Completion, " "),
Name: nil,
@@ -88,7 +89,7 @@ func responseClaude2OpenAI(claudeResponse *Response) *openai.TextResponse {
return &fullTextResponse
}
-func StreamHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, string) {
+func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, string) {
responseText := ""
responseId := fmt.Sprintf("chatcmpl-%s", helper.GetUUID())
createdTime := helper.GetTimestamp()
@@ -153,7 +154,7 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatus
return nil, responseText
}
-func Handler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*openai.ErrorWithStatusCode, *openai.Usage) {
+func Handler(c *gin.Context, resp *http.Response, promptTokens int, modelName string) (*model.ErrorWithStatusCode, *model.Usage) {
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
@@ -168,8 +169,8 @@ func Handler(c *gin.Context, resp *http.Response, promptTokens int, model string
return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
if claudeResponse.Error.Type != "" {
- return &openai.ErrorWithStatusCode{
- Error: openai.Error{
+ return &model.ErrorWithStatusCode{
+ Error: model.Error{
Message: claudeResponse.Error.Message,
Type: claudeResponse.Error.Type,
Param: "",
@@ -179,9 +180,9 @@ func Handler(c *gin.Context, resp *http.Response, promptTokens int, model string
}, nil
}
fullTextResponse := responseClaude2OpenAI(&claudeResponse)
- fullTextResponse.Model = model
- completionTokens := openai.CountTokenText(claudeResponse.Completion, model)
- usage := openai.Usage{
+ fullTextResponse.Model = modelName
+ completionTokens := openai.CountTokenText(claudeResponse.Completion, modelName)
+ usage := model.Usage{
PromptTokens: promptTokens,
CompletionTokens: completionTokens,
TotalTokens: promptTokens + completionTokens,
diff --git a/relay/channel/baidu/adaptor.go b/relay/channel/baidu/adaptor.go
index c6304a74..d2d06ce0 100644
--- a/relay/channel/baidu/adaptor.go
+++ b/relay/channel/baidu/adaptor.go
@@ -1,22 +1,93 @@
package baidu
import (
+ "errors"
"github.com/gin-gonic/gin"
- "github.com/songquanpeng/one-api/relay/channel/openai"
+ "github.com/songquanpeng/one-api/relay/channel"
+ "github.com/songquanpeng/one-api/relay/constant"
+ "github.com/songquanpeng/one-api/relay/model"
+ "github.com/songquanpeng/one-api/relay/util"
+ "io"
"net/http"
)
type Adaptor struct {
}
-func (a *Adaptor) Auth(c *gin.Context) error {
+func (a *Adaptor) Init(meta *util.RelayMeta) {
+
+}
+
+func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) {
+ // https://cloud.baidu.com/doc/WENXINWORKSHOP/s/clntwmv7t
+ var fullRequestURL string
+ switch meta.ActualModelName {
+ case "ERNIE-Bot-4":
+ fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro"
+ case "ERNIE-Bot-8K":
+ fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie_bot_8k"
+ case "ERNIE-Bot":
+ fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions"
+ case "ERNIE-Speed":
+ fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie_speed"
+ case "ERNIE-Bot-turbo":
+ fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant"
+ case "BLOOMZ-7B":
+ fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/bloomz_7b1"
+ case "Embedding-V1":
+ fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings/embedding-v1"
+ }
+ var accessToken string
+ var err error
+ if accessToken, err = GetAccessToken(meta.APIKey); err != nil {
+ return "", err
+ }
+ fullRequestURL += "?access_token=" + accessToken
+ return fullRequestURL, nil
+}
+
+func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *util.RelayMeta) error {
+ channel.SetupCommonRequestHeader(c, req, meta)
+ req.Header.Set("Authorization", "Bearer "+meta.APIKey)
return nil
}
-func (a *Adaptor) ConvertRequest(request *openai.GeneralOpenAIRequest) (any, error) {
- return nil, nil
+func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) {
+ if request == nil {
+ return nil, errors.New("request is nil")
+ }
+ switch relayMode {
+ case constant.RelayModeEmbeddings:
+ baiduEmbeddingRequest := ConvertEmbeddingRequest(*request)
+ return baiduEmbeddingRequest, nil
+ default:
+ baiduRequest := ConvertRequest(*request)
+ return baiduRequest, nil
+ }
}
-func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, *openai.Usage, error) {
- return nil, nil, nil
+func (a *Adaptor) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error) {
+ return channel.DoRequestHelper(a, c, meta, requestBody)
+}
+
+func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
+ if meta.IsStream {
+ err, usage = StreamHandler(c, resp)
+ } else {
+ switch meta.Mode {
+ case constant.RelayModeEmbeddings:
+ err, usage = EmbeddingHandler(c, resp)
+ default:
+ err, usage = Handler(c, resp)
+ }
+ }
+ return
+}
+
+func (a *Adaptor) GetModelList() []string {
+ return ModelList
+}
+
+func (a *Adaptor) GetChannelName() string {
+ return "baidu"
}
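
Note that Baidu is the only channel here that authenticates via the URL instead of a header: GetRequestURL exchanges the channel key for an access token and appends it as a query parameter (an unknown model name falls through the switch and leaves the path empty, so only the listed models resolve). Because the exchange costs an extra round trip, the real GetAccessToken is backed by the baiduTokenStore sync.Map visible in main.go below. A simplified cache-aside sketch of that pattern, with fetchToken standing in for the HTTP token exchange and expiry handling omitted:

```go
package main

import (
	"fmt"
	"sync"
)

var tokenStore sync.Map // apiKey -> access token

// getAccessToken is a simplified cache-aside sketch of the pattern behind
// baiduTokenStore; the real helper also refreshes tokens before expiry.
func getAccessToken(apiKey string, fetchToken func(string) (string, error)) (string, error) {
	if v, ok := tokenStore.Load(apiKey); ok {
		return v.(string), nil
	}
	token, err := fetchToken(apiKey)
	if err != nil {
		return "", err
	}
	tokenStore.Store(apiKey, token)
	return token, nil
}

func main() {
	fetch := func(key string) (string, error) { return "token-for-" + key, nil }
	t, _ := getAccessToken("demo-key", fetch)
	fmt.Println(t) // token-for-demo-key
}
```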
diff --git a/relay/channel/baidu/constants.go b/relay/channel/baidu/constants.go
new file mode 100644
index 00000000..0fa8f2d6
--- /dev/null
+++ b/relay/channel/baidu/constants.go
@@ -0,0 +1,10 @@
+package baidu
+
+var ModelList = []string{
+ "ERNIE-Bot-4",
+ "ERNIE-Bot-8K",
+ "ERNIE-Bot",
+ "ERNIE-Speed",
+ "ERNIE-Bot-turbo",
+ "Embedding-V1",
+}
diff --git a/relay/channel/baidu/main.go b/relay/channel/baidu/main.go
index 00391602..4f2b13fc 100644
--- a/relay/channel/baidu/main.go
+++ b/relay/channel/baidu/main.go
@@ -10,6 +10,7 @@ import (
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/relay/channel/openai"
"github.com/songquanpeng/one-api/relay/constant"
+ "github.com/songquanpeng/one-api/relay/model"
"github.com/songquanpeng/one-api/relay/util"
"io"
"net/http"
@@ -43,7 +44,7 @@ type Error struct {
var baiduTokenStore sync.Map
-func ConvertRequest(request openai.GeneralOpenAIRequest) *ChatRequest {
+func ConvertRequest(request model.GeneralOpenAIRequest) *ChatRequest {
messages := make([]Message, 0, len(request.Messages))
for _, message := range request.Messages {
if message.Role == "system" {
@@ -71,7 +72,7 @@ func ConvertRequest(request openai.GeneralOpenAIRequest) *ChatRequest {
func responseBaidu2OpenAI(response *ChatResponse) *openai.TextResponse {
choice := openai.TextResponseChoice{
Index: 0,
- Message: openai.Message{
+ Message: model.Message{
Role: "assistant",
Content: response.Result,
},
@@ -103,7 +104,7 @@ func streamResponseBaidu2OpenAI(baiduResponse *ChatStreamResponse) *openai.ChatC
return &response
}
-func ConvertEmbeddingRequest(request openai.GeneralOpenAIRequest) *EmbeddingRequest {
+func ConvertEmbeddingRequest(request model.GeneralOpenAIRequest) *EmbeddingRequest {
return &EmbeddingRequest{
Input: request.ParseInput(),
}
@@ -126,8 +127,8 @@ func embeddingResponseBaidu2OpenAI(response *EmbeddingResponse) *openai.Embeddin
return &openAIEmbeddingResponse
}
-func StreamHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, *openai.Usage) {
- var usage openai.Usage
+func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
+ var usage model.Usage
scanner := bufio.NewScanner(resp.Body)
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
if atEOF && len(data) == 0 {
@@ -189,7 +190,7 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatus
return nil, &usage
}
-func Handler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, *openai.Usage) {
+func Handler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
var baiduResponse ChatResponse
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
@@ -204,8 +205,8 @@ func Handler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode,
return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
if baiduResponse.ErrorMsg != "" {
- return &openai.ErrorWithStatusCode{
- Error: openai.Error{
+ return &model.ErrorWithStatusCode{
+ Error: model.Error{
Message: baiduResponse.ErrorMsg,
Type: "baidu_error",
Param: "",
@@ -226,7 +227,7 @@ func Handler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode,
return nil, &fullTextResponse.Usage
}
-func EmbeddingHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, *openai.Usage) {
+func EmbeddingHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
var baiduResponse EmbeddingResponse
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
@@ -241,8 +242,8 @@ func EmbeddingHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithSta
return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
if baiduResponse.ErrorMsg != "" {
- return &openai.ErrorWithStatusCode{
- Error: openai.Error{
+ return &model.ErrorWithStatusCode{
+ Error: model.Error{
Message: baiduResponse.ErrorMsg,
Type: "baidu_error",
Param: "",
diff --git a/relay/channel/baidu/model.go b/relay/channel/baidu/model.go
index 524418e1..cc1feb2f 100644
--- a/relay/channel/baidu/model.go
+++ b/relay/channel/baidu/model.go
@@ -1,18 +1,18 @@
package baidu
import (
- "github.com/songquanpeng/one-api/relay/channel/openai"
+ "github.com/songquanpeng/one-api/relay/model"
"time"
)
type ChatResponse struct {
- Id string `json:"id"`
- Object string `json:"object"`
- Created int64 `json:"created"`
- Result string `json:"result"`
- IsTruncated bool `json:"is_truncated"`
- NeedClearHistory bool `json:"need_clear_history"`
- Usage openai.Usage `json:"usage"`
+ Id string `json:"id"`
+ Object string `json:"object"`
+ Created int64 `json:"created"`
+ Result string `json:"result"`
+ IsTruncated bool `json:"is_truncated"`
+ NeedClearHistory bool `json:"need_clear_history"`
+ Usage model.Usage `json:"usage"`
Error
}
@@ -37,7 +37,7 @@ type EmbeddingResponse struct {
Object string `json:"object"`
Created int64 `json:"created"`
Data []EmbeddingData `json:"data"`
- Usage openai.Usage `json:"usage"`
+ Usage model.Usage `json:"usage"`
Error
}
diff --git a/relay/channel/common.go b/relay/channel/common.go
new file mode 100644
index 00000000..c6e1abf2
--- /dev/null
+++ b/relay/channel/common.go
@@ -0,0 +1,51 @@
+package channel
+
+import (
+ "errors"
+ "fmt"
+ "github.com/gin-gonic/gin"
+ "github.com/songquanpeng/one-api/relay/util"
+ "io"
+ "net/http"
+)
+
+func SetupCommonRequestHeader(c *gin.Context, req *http.Request, meta *util.RelayMeta) {
+ req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
+ req.Header.Set("Accept", c.Request.Header.Get("Accept"))
+ if meta.IsStream && c.Request.Header.Get("Accept") == "" {
+ req.Header.Set("Accept", "text/event-stream")
+ }
+}
+
+func DoRequestHelper(a Adaptor, c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error) {
+ fullRequestURL, err := a.GetRequestURL(meta)
+ if err != nil {
+ return nil, fmt.Errorf("get request url failed: %w", err)
+ }
+ req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
+ if err != nil {
+ return nil, fmt.Errorf("new request failed: %w", err)
+ }
+ err = a.SetupRequestHeader(c, req, meta)
+ if err != nil {
+ return nil, fmt.Errorf("setup request header failed: %w", err)
+ }
+ resp, err := DoRequest(c, req)
+ if err != nil {
+ return nil, fmt.Errorf("do request failed: %w", err)
+ }
+ return resp, nil
+}
+
+func DoRequest(c *gin.Context, req *http.Request) (*http.Response, error) {
+ resp, err := util.HTTPClient.Do(req)
+ if err != nil {
+ return nil, err
+ }
+ if resp == nil {
+ return nil, errors.New("resp is nil")
+ }
+ _ = req.Body.Close()
+ _ = c.Request.Body.Close()
+ return resp, nil
+}
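
DoRequestHelper is the shared glue every adaptor's DoRequest delegates to: resolve the URL, build the outbound request with the client's original method, let the adaptor stamp headers, then send it through util.HTTPClient. A hedged sketch of the full round trip a caller builds on top of it; relayOnce is hypothetical, and only meta fields visible in this diff are used:

```go
package relay

import (
	"bytes"
	"encoding/json"
	"net/http"

	"github.com/gin-gonic/gin"
	"github.com/songquanpeng/one-api/relay/channel"
	"github.com/songquanpeng/one-api/relay/channel/openai"
	"github.com/songquanpeng/one-api/relay/model"
	"github.com/songquanpeng/one-api/relay/util"
)

// relayOnce drives a single request through an Adaptor:
// Init -> ConvertRequest -> DoRequest -> DoResponse.
func relayOnce(c *gin.Context, a channel.Adaptor, meta *util.RelayMeta, req *model.GeneralOpenAIRequest) (*model.Usage, *model.ErrorWithStatusCode) {
	a.Init(meta)
	converted, err := a.ConvertRequest(c, meta.Mode, req)
	if err != nil {
		return nil, openai.ErrorWrapper(err, "convert_request_failed", http.StatusInternalServerError)
	}
	body, err := json.Marshal(converted)
	if err != nil {
		return nil, openai.ErrorWrapper(err, "marshal_request_failed", http.StatusInternalServerError)
	}
	resp, err := a.DoRequest(c, meta, bytes.NewReader(body))
	if err != nil {
		return nil, openai.ErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
	}
	return a.DoResponse(c, resp, meta)
}
```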
diff --git a/relay/channel/gemini/adaptor.go b/relay/channel/gemini/adaptor.go
new file mode 100644
index 00000000..f3305e5d
--- /dev/null
+++ b/relay/channel/gemini/adaptor.go
@@ -0,0 +1,66 @@
+package gemini
+
+import (
+ "errors"
+ "fmt"
+ "github.com/gin-gonic/gin"
+ "github.com/songquanpeng/one-api/common/helper"
+ channelhelper "github.com/songquanpeng/one-api/relay/channel"
+ "github.com/songquanpeng/one-api/relay/channel/openai"
+ "github.com/songquanpeng/one-api/relay/model"
+ "github.com/songquanpeng/one-api/relay/util"
+ "io"
+ "net/http"
+)
+
+type Adaptor struct {
+}
+
+func (a *Adaptor) Init(meta *util.RelayMeta) {
+
+}
+
+func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) {
+ version := helper.AssignOrDefault(meta.APIVersion, "v1")
+ action := "generateContent"
+ if meta.IsStream {
+ action = "streamGenerateContent"
+ }
+ return fmt.Sprintf("%s/%s/models/%s:%s", meta.BaseURL, version, meta.ActualModelName, action), nil
+}
+
+func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *util.RelayMeta) error {
+ channelhelper.SetupCommonRequestHeader(c, req, meta)
+ req.Header.Set("x-goog-api-key", meta.APIKey)
+ return nil
+}
+
+func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) {
+ if request == nil {
+ return nil, errors.New("request is nil")
+ }
+ return ConvertRequest(*request), nil
+}
+
+func (a *Adaptor) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error) {
+ return channelhelper.DoRequestHelper(a, c, meta, requestBody)
+}
+
+func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
+ if meta.IsStream {
+ var responseText string
+ err, responseText = StreamHandler(c, resp)
+ usage = openai.ResponseText2Usage(responseText, meta.ActualModelName, meta.PromptTokens)
+ } else {
+ err, usage = Handler(c, resp, meta.PromptTokens, meta.ActualModelName)
+ }
+ return
+}
+
+func (a *Adaptor) GetModelList() []string {
+ return ModelList
+}
+
+func (a *Adaptor) GetChannelName() string {
+ return "google gemini"
+}
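
GetRequestURL switches the RPC verb on stream mode, so the same model maps to two endpoints; "v1" is the fallback when no API version is configured. A quick illustration using the Gemini base URL registered in common/constants.go:

```go
package main

import "fmt"

func main() {
	base, version, model := "https://generativelanguage.googleapis.com", "v1", "gemini-pro"
	for _, stream := range []bool{false, true} {
		action := "generateContent"
		if stream {
			action = "streamGenerateContent"
		}
		fmt.Printf("%s/%s/models/%s:%s\n", base, version, model, action)
	}
	// Output:
	// https://generativelanguage.googleapis.com/v1/models/gemini-pro:generateContent
	// https://generativelanguage.googleapis.com/v1/models/gemini-pro:streamGenerateContent
}
```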
diff --git a/relay/channel/gemini/constants.go b/relay/channel/gemini/constants.go
new file mode 100644
index 00000000..5bb0c168
--- /dev/null
+++ b/relay/channel/gemini/constants.go
@@ -0,0 +1,6 @@
+package gemini
+
+var ModelList = []string{
+ "gemini-pro",
+ "gemini-pro-vision",
+}
diff --git a/relay/channel/google/gemini.go b/relay/channel/gemini/main.go
similarity index 77%
rename from relay/channel/google/gemini.go
rename to relay/channel/gemini/main.go
index 13e6a4e8..c24694c8 100644
--- a/relay/channel/google/gemini.go
+++ b/relay/channel/gemini/main.go
@@ -1,4 +1,4 @@
-package google
+package gemini
import (
"bufio"
@@ -11,6 +11,7 @@ import (
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/relay/channel/openai"
"github.com/songquanpeng/one-api/relay/constant"
+ "github.com/songquanpeng/one-api/relay/model"
"io"
"net/http"
"strings"
@@ -21,14 +22,14 @@ import (
// https://ai.google.dev/docs/gemini_api_overview?hl=zh-cn
const (
- GeminiVisionMaxImageNum = 16
+ VisionMaxImageNum = 16
)
// Setting safety to the lowest possible values since Gemini is already powerless enough
-func ConvertGeminiRequest(textRequest openai.GeneralOpenAIRequest) *GeminiChatRequest {
- geminiRequest := GeminiChatRequest{
- Contents: make([]GeminiChatContent, 0, len(textRequest.Messages)),
- SafetySettings: []GeminiChatSafetySettings{
+func ConvertRequest(textRequest model.GeneralOpenAIRequest) *ChatRequest {
+ geminiRequest := ChatRequest{
+ Contents: make([]ChatContent, 0, len(textRequest.Messages)),
+ SafetySettings: []ChatSafetySettings{
{
Category: "HARM_CATEGORY_HARASSMENT",
Threshold: config.GeminiSafetySetting,
@@ -46,14 +47,14 @@ func ConvertGeminiRequest(textRequest openai.GeneralOpenAIRequest) *GeminiChatRe
Threshold: config.GeminiSafetySetting,
},
},
- GenerationConfig: GeminiChatGenerationConfig{
+ GenerationConfig: ChatGenerationConfig{
Temperature: textRequest.Temperature,
TopP: textRequest.TopP,
MaxOutputTokens: textRequest.MaxTokens,
},
}
if textRequest.Functions != nil {
- geminiRequest.Tools = []GeminiChatTools{
+ geminiRequest.Tools = []ChatTools{
{
FunctionDeclarations: textRequest.Functions,
},
@@ -61,30 +62,30 @@ func ConvertGeminiRequest(textRequest openai.GeneralOpenAIRequest) *GeminiChatRe
}
shouldAddDummyModelMessage := false
for _, message := range textRequest.Messages {
- content := GeminiChatContent{
+ content := ChatContent{
Role: message.Role,
- Parts: []GeminiPart{
+ Parts: []Part{
{
Text: message.StringContent(),
},
},
}
openaiContent := message.ParseContent()
- var parts []GeminiPart
+ var parts []Part
imageNum := 0
for _, part := range openaiContent {
- if part.Type == openai.ContentTypeText {
- parts = append(parts, GeminiPart{
+ if part.Type == model.ContentTypeText {
+ parts = append(parts, Part{
Text: part.Text,
})
- } else if part.Type == openai.ContentTypeImageURL {
+ } else if part.Type == model.ContentTypeImageURL {
imageNum += 1
- if imageNum > GeminiVisionMaxImageNum {
+ if imageNum > VisionMaxImageNum {
continue
}
mimeType, data, _ := image.GetImageFromUrl(part.ImageURL.Url)
- parts = append(parts, GeminiPart{
- InlineData: &GeminiInlineData{
+ parts = append(parts, Part{
+ InlineData: &InlineData{
MimeType: mimeType,
Data: data,
},
@@ -106,9 +107,9 @@ func ConvertGeminiRequest(textRequest openai.GeneralOpenAIRequest) *GeminiChatRe
// If a system message is the last message, we need to add a dummy model message to make gemini happy
if shouldAddDummyModelMessage {
- geminiRequest.Contents = append(geminiRequest.Contents, GeminiChatContent{
+ geminiRequest.Contents = append(geminiRequest.Contents, ChatContent{
Role: "model",
- Parts: []GeminiPart{
+ Parts: []Part{
{
Text: "Okay",
},
@@ -121,12 +122,12 @@ func ConvertGeminiRequest(textRequest openai.GeneralOpenAIRequest) *GeminiChatRe
return &geminiRequest
}
-type GeminiChatResponse struct {
- Candidates []GeminiChatCandidate `json:"candidates"`
- PromptFeedback GeminiChatPromptFeedback `json:"promptFeedback"`
+type ChatResponse struct {
+ Candidates []ChatCandidate `json:"candidates"`
+ PromptFeedback ChatPromptFeedback `json:"promptFeedback"`
}
-func (g *GeminiChatResponse) GetResponseText() string {
+func (g *ChatResponse) GetResponseText() string {
if g == nil {
return ""
}
@@ -136,23 +137,23 @@ func (g *GeminiChatResponse) GetResponseText() string {
return ""
}
-type GeminiChatCandidate struct {
- Content GeminiChatContent `json:"content"`
- FinishReason string `json:"finishReason"`
- Index int64 `json:"index"`
- SafetyRatings []GeminiChatSafetyRating `json:"safetyRatings"`
+type ChatCandidate struct {
+ Content ChatContent `json:"content"`
+ FinishReason string `json:"finishReason"`
+ Index int64 `json:"index"`
+ SafetyRatings []ChatSafetyRating `json:"safetyRatings"`
}
-type GeminiChatSafetyRating struct {
+type ChatSafetyRating struct {
Category string `json:"category"`
Probability string `json:"probability"`
}
-type GeminiChatPromptFeedback struct {
- SafetyRatings []GeminiChatSafetyRating `json:"safetyRatings"`
+type ChatPromptFeedback struct {
+ SafetyRatings []ChatSafetyRating `json:"safetyRatings"`
}
-func responseGeminiChat2OpenAI(response *GeminiChatResponse) *openai.TextResponse {
+func responseGeminiChat2OpenAI(response *ChatResponse) *openai.TextResponse {
fullTextResponse := openai.TextResponse{
Id: fmt.Sprintf("chatcmpl-%s", helper.GetUUID()),
Object: "chat.completion",
@@ -162,7 +163,7 @@ func responseGeminiChat2OpenAI(response *GeminiChatResponse) *openai.TextRespons
for i, candidate := range response.Candidates {
choice := openai.TextResponseChoice{
Index: i,
- Message: openai.Message{
+ Message: model.Message{
Role: "assistant",
Content: "",
},
@@ -176,7 +177,7 @@ func responseGeminiChat2OpenAI(response *GeminiChatResponse) *openai.TextRespons
return &fullTextResponse
}
-func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) *openai.ChatCompletionsStreamResponse {
+func streamResponseGeminiChat2OpenAI(geminiResponse *ChatResponse) *openai.ChatCompletionsStreamResponse {
var choice openai.ChatCompletionsStreamResponseChoice
choice.Delta.Content = geminiResponse.GetResponseText()
choice.FinishReason = &constant.StopFinishReason
@@ -187,7 +188,7 @@ func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) *openai
return &response
}
-func StreamHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, string) {
+func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, string) {
responseText := ""
dataChan := make(chan string)
stopChan := make(chan bool)
@@ -257,7 +258,7 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatus
return nil, responseText
}
-func GeminiHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*openai.ErrorWithStatusCode, *openai.Usage) {
+func Handler(c *gin.Context, resp *http.Response, promptTokens int, modelName string) (*model.ErrorWithStatusCode, *model.Usage) {
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
@@ -266,14 +267,14 @@ func GeminiHandler(c *gin.Context, resp *http.Response, promptTokens int, model
if err != nil {
return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
- var geminiResponse GeminiChatResponse
+ var geminiResponse ChatResponse
err = json.Unmarshal(responseBody, &geminiResponse)
if err != nil {
return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
if len(geminiResponse.Candidates) == 0 {
- return &openai.ErrorWithStatusCode{
- Error: openai.Error{
+ return &model.ErrorWithStatusCode{
+ Error: model.Error{
Message: "No candidates returned",
Type: "server_error",
Param: "",
@@ -283,9 +284,9 @@ func GeminiHandler(c *gin.Context, resp *http.Response, promptTokens int, model
}, nil
}
fullTextResponse := responseGeminiChat2OpenAI(&geminiResponse)
- fullTextResponse.Model = model
- completionTokens := openai.CountTokenText(geminiResponse.GetResponseText(), model)
- usage := openai.Usage{
+ fullTextResponse.Model = modelName
+ completionTokens := openai.CountTokenText(geminiResponse.GetResponseText(), modelName)
+ usage := model.Usage{
PromptTokens: promptTokens,
CompletionTokens: completionTokens,
TotalTokens: promptTokens + completionTokens,
diff --git a/relay/channel/gemini/model.go b/relay/channel/gemini/model.go
new file mode 100644
index 00000000..d1e3c4fd
--- /dev/null
+++ b/relay/channel/gemini/model.go
@@ -0,0 +1,41 @@
+package gemini
+
+type ChatRequest struct {
+ Contents []ChatContent `json:"contents"`
+ SafetySettings []ChatSafetySettings `json:"safety_settings,omitempty"`
+ GenerationConfig ChatGenerationConfig `json:"generation_config,omitempty"`
+ Tools []ChatTools `json:"tools,omitempty"`
+}
+
+type InlineData struct {
+ MimeType string `json:"mimeType"`
+ Data string `json:"data"`
+}
+
+type Part struct {
+ Text string `json:"text,omitempty"`
+ InlineData *InlineData `json:"inlineData,omitempty"`
+}
+
+type ChatContent struct {
+ Role string `json:"role,omitempty"`
+ Parts []Part `json:"parts"`
+}
+
+type ChatSafetySettings struct {
+ Category string `json:"category"`
+ Threshold string `json:"threshold"`
+}
+
+type ChatTools struct {
+ FunctionDeclarations any `json:"functionDeclarations,omitempty"`
+}
+
+type ChatGenerationConfig struct {
+ Temperature float64 `json:"temperature,omitempty"`
+ TopP float64 `json:"topP,omitempty"`
+ TopK float64 `json:"topK,omitempty"`
+ MaxOutputTokens int `json:"maxOutputTokens,omitempty"`
+ CandidateCount int `json:"candidateCount,omitempty"`
+ StopSequences []string `json:"stopSequences,omitempty"`
+}
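
The renamed types drop the now-redundant Gemini prefix since they live in their own package. A minimal request built against them to show the wire shape; values are illustrative, and safety settings and tools are omitted:

```go
package main

import (
	"encoding/json"
	"fmt"

	"github.com/songquanpeng/one-api/relay/channel/gemini"
)

func main() {
	req := gemini.ChatRequest{
		Contents: []gemini.ChatContent{
			{Role: "user", Parts: []gemini.Part{{Text: "Hello"}}},
		},
		GenerationConfig: gemini.ChatGenerationConfig{
			Temperature:     0.7,
			MaxOutputTokens: 256,
		},
	}
	b, _ := json.MarshalIndent(req, "", "  ")
	fmt.Println(string(b))
}
```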
diff --git a/relay/channel/google/adaptor.go b/relay/channel/google/adaptor.go
deleted file mode 100644
index ad45bc48..00000000
--- a/relay/channel/google/adaptor.go
+++ /dev/null
@@ -1,22 +0,0 @@
-package google
-
-import (
- "github.com/gin-gonic/gin"
- "github.com/songquanpeng/one-api/relay/channel/openai"
- "net/http"
-)
-
-type Adaptor struct {
-}
-
-func (a *Adaptor) Auth(c *gin.Context) error {
- return nil
-}
-
-func (a *Adaptor) ConvertRequest(request *openai.GeneralOpenAIRequest) (any, error) {
- return nil, nil
-}
-
-func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, *openai.Usage, error) {
- return nil, nil, nil
-}
diff --git a/relay/channel/google/model.go b/relay/channel/google/model.go
deleted file mode 100644
index e69a9445..00000000
--- a/relay/channel/google/model.go
+++ /dev/null
@@ -1,80 +0,0 @@
-package google
-
-import (
- "github.com/songquanpeng/one-api/relay/channel/openai"
-)
-
-type GeminiChatRequest struct {
- Contents []GeminiChatContent `json:"contents"`
- SafetySettings []GeminiChatSafetySettings `json:"safety_settings,omitempty"`
- GenerationConfig GeminiChatGenerationConfig `json:"generation_config,omitempty"`
- Tools []GeminiChatTools `json:"tools,omitempty"`
-}
-
-type GeminiInlineData struct {
- MimeType string `json:"mimeType"`
- Data string `json:"data"`
-}
-
-type GeminiPart struct {
- Text string `json:"text,omitempty"`
- InlineData *GeminiInlineData `json:"inlineData,omitempty"`
-}
-
-type GeminiChatContent struct {
- Role string `json:"role,omitempty"`
- Parts []GeminiPart `json:"parts"`
-}
-
-type GeminiChatSafetySettings struct {
- Category string `json:"category"`
- Threshold string `json:"threshold"`
-}
-
-type GeminiChatTools struct {
- FunctionDeclarations any `json:"functionDeclarations,omitempty"`
-}
-
-type GeminiChatGenerationConfig struct {
- Temperature float64 `json:"temperature,omitempty"`
- TopP float64 `json:"topP,omitempty"`
- TopK float64 `json:"topK,omitempty"`
- MaxOutputTokens int `json:"maxOutputTokens,omitempty"`
- CandidateCount int `json:"candidateCount,omitempty"`
- StopSequences []string `json:"stopSequences,omitempty"`
-}
-
-type PaLMChatMessage struct {
- Author string `json:"author"`
- Content string `json:"content"`
-}
-
-type PaLMFilter struct {
- Reason string `json:"reason"`
- Message string `json:"message"`
-}
-
-type PaLMPrompt struct {
- Messages []PaLMChatMessage `json:"messages"`
-}
-
-type PaLMChatRequest struct {
- Prompt PaLMPrompt `json:"prompt"`
- Temperature float64 `json:"temperature,omitempty"`
- CandidateCount int `json:"candidateCount,omitempty"`
- TopP float64 `json:"topP,omitempty"`
- TopK int `json:"topK,omitempty"`
-}
-
-type PaLMError struct {
- Code int `json:"code"`
- Message string `json:"message"`
- Status string `json:"status"`
-}
-
-type PaLMChatResponse struct {
- Candidates []PaLMChatMessage `json:"candidates"`
- Messages []openai.Message `json:"messages"`
- Filters []PaLMFilter `json:"filters"`
- Error PaLMError `json:"error"`
-}
diff --git a/relay/channel/interface.go b/relay/channel/interface.go
index 2a28abb8..e25db83f 100644
--- a/relay/channel/interface.go
+++ b/relay/channel/interface.go
@@ -2,14 +2,19 @@ package channel
import (
"github.com/gin-gonic/gin"
- "github.com/songquanpeng/one-api/relay/channel/openai"
+ "github.com/songquanpeng/one-api/relay/model"
+ "github.com/songquanpeng/one-api/relay/util"
+ "io"
"net/http"
)
type Adaptor interface {
- GetRequestURL() string
- Auth(c *gin.Context) error
- ConvertRequest(request *openai.GeneralOpenAIRequest) (any, error)
- DoRequest(request *openai.GeneralOpenAIRequest) error
- DoResponse(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, *openai.Usage, error)
+ Init(meta *util.RelayMeta)
+ GetRequestURL(meta *util.RelayMeta) (string, error)
+ SetupRequestHeader(c *gin.Context, req *http.Request, meta *util.RelayMeta) error
+ ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error)
+ DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error)
+ DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode)
+ GetModelList() []string
+ GetChannelName() string
}
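
The interface has grown from three methods to eight, so an implementation can silently drift out of sync until runtime. A conventional guard, not part of this diff, is a set of compile-time assertions over the channels touched here:

```go
package relay

import (
	"github.com/songquanpeng/one-api/relay/channel"
	"github.com/songquanpeng/one-api/relay/channel/anthropic"
	"github.com/songquanpeng/one-api/relay/channel/baidu"
	"github.com/songquanpeng/one-api/relay/channel/gemini"
	"github.com/songquanpeng/one-api/relay/channel/openai"
	"github.com/songquanpeng/one-api/relay/channel/palm"
	"github.com/songquanpeng/one-api/relay/channel/tencent"
)

// Compile-time checks that each channel still satisfies channel.Adaptor;
// illustrative only, this block is not part of the diff itself.
var (
	_ channel.Adaptor = (*anthropic.Adaptor)(nil)
	_ channel.Adaptor = (*baidu.Adaptor)(nil)
	_ channel.Adaptor = (*gemini.Adaptor)(nil)
	_ channel.Adaptor = (*openai.Adaptor)(nil)
	_ channel.Adaptor = (*palm.Adaptor)(nil)
	_ channel.Adaptor = (*tencent.Adaptor)(nil)
)
```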
diff --git a/relay/channel/moonshot/constants.go b/relay/channel/moonshot/constants.go
new file mode 100644
index 00000000..1b86f0fa
--- /dev/null
+++ b/relay/channel/moonshot/constants.go
@@ -0,0 +1,7 @@
+package moonshot
+
+var ModelList = []string{
+ "moonshot-v1-8k",
+ "moonshot-v1-32k",
+ "moonshot-v1-128k",
+}
diff --git a/relay/channel/openai/adaptor.go b/relay/channel/openai/adaptor.go
index cc302611..1313e317 100644
--- a/relay/channel/openai/adaptor.go
+++ b/relay/channel/openai/adaptor.go
@@ -1,21 +1,103 @@
package openai
import (
+ "errors"
+ "fmt"
"github.com/gin-gonic/gin"
+ "github.com/songquanpeng/one-api/common"
+ "github.com/songquanpeng/one-api/relay/channel"
+ "github.com/songquanpeng/one-api/relay/channel/ai360"
+ "github.com/songquanpeng/one-api/relay/channel/moonshot"
+ "github.com/songquanpeng/one-api/relay/model"
+ "github.com/songquanpeng/one-api/relay/util"
+ "io"
"net/http"
+ "strings"
)
type Adaptor struct {
+ ChannelType int
}
-func (a *Adaptor) Auth(c *gin.Context) error {
+func (a *Adaptor) Init(meta *util.RelayMeta) {
+ a.ChannelType = meta.ChannelType
+}
+
+func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) {
+ if meta.ChannelType == common.ChannelTypeAzure {
+ // https://learn.microsoft.com/en-us/azure/cognitive-services/openai/chatgpt-quickstart?pivots=rest-api&tabs=command-line#rest-api
+ requestURL := strings.Split(meta.RequestURLPath, "?")[0]
+ requestURL = fmt.Sprintf("%s?api-version=%s", requestURL, meta.APIVersion)
+ task := strings.TrimPrefix(requestURL, "/v1/")
+ model_ := meta.ActualModelName
+ model_ = strings.Replace(model_, ".", "", -1)
+ // https://github.com/songquanpeng/one-api/issues/67
+ model_ = strings.TrimSuffix(model_, "-0301")
+ model_ = strings.TrimSuffix(model_, "-0314")
+ model_ = strings.TrimSuffix(model_, "-0613")
+
+ requestURL = fmt.Sprintf("/openai/deployments/%s/%s", model_, task)
+ return util.GetFullRequestURL(meta.BaseURL, requestURL, meta.ChannelType), nil
+ }
+ return util.GetFullRequestURL(meta.BaseURL, meta.RequestURLPath, meta.ChannelType), nil
+}
+
+func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *util.RelayMeta) error {
+ channel.SetupCommonRequestHeader(c, req, meta)
+ if meta.ChannelType == common.ChannelTypeAzure {
+ req.Header.Set("api-key", meta.APIKey)
+ return nil
+ }
+ req.Header.Set("Authorization", "Bearer "+meta.APIKey)
+ if meta.ChannelType == common.ChannelTypeOpenRouter {
+ req.Header.Set("HTTP-Referer", "https://github.com/songquanpeng/one-api")
+ req.Header.Set("X-Title", "One API")
+ }
return nil
}
-func (a *Adaptor) ConvertRequest(request *GeneralOpenAIRequest) (any, error) {
- return nil, nil
+func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) {
+ if request == nil {
+ return nil, errors.New("request is nil")
+ }
+ return request, nil
}
-func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response) (*ErrorWithStatusCode, *Usage, error) {
- return nil, nil, nil
+func (a *Adaptor) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error) {
+ return channel.DoRequestHelper(a, c, meta, requestBody)
+}
+
+func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
+ if meta.IsStream {
+ var responseText string
+ err, responseText = StreamHandler(c, resp, meta.Mode)
+ usage = ResponseText2Usage(responseText, meta.ActualModelName, meta.PromptTokens)
+ } else {
+ err, usage = Handler(c, resp, meta.PromptTokens, meta.ActualModelName)
+ }
+ return
+}
+
+func (a *Adaptor) GetModelList() []string {
+ switch a.ChannelType {
+ case common.ChannelType360:
+ return ai360.ModelList
+ case common.ChannelTypeMoonshot:
+ return moonshot.ModelList
+ default:
+ return ModelList
+ }
+}
+
+func (a *Adaptor) GetChannelName() string {
+ switch a.ChannelType {
+ case common.ChannelTypeAzure:
+ return "azure"
+ case common.ChannelType360:
+ return "360"
+ case common.ChannelTypeMoonshot:
+ return "moonshot"
+ default:
+ return "openai"
+ }
}
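
The Azure branch of GetRequestURL rewrites an OpenAI-style path into Azure's deployment-based layout, normalizing the model name (dots stripped; the -0301/-0314/-0613 date suffixes trimmed, per issue #67) so it matches typical deployment names. A worked example mirroring that rewrite:

```go
package main

import (
	"fmt"
	"strings"
)

// Mirrors the Azure URL rewrite above for one concrete input.
func main() {
	requestURLPath := "/v1/chat/completions"
	apiVersion := "2023-05-15"
	modelName := "gpt-3.5-turbo-0613"

	requestURL := strings.Split(requestURLPath, "?")[0]
	requestURL = fmt.Sprintf("%s?api-version=%s", requestURL, apiVersion)
	task := strings.TrimPrefix(requestURL, "/v1/")
	model_ := strings.Replace(modelName, ".", "", -1)
	model_ = strings.TrimSuffix(model_, "-0613") // the real code also trims -0301/-0314
	fmt.Printf("/openai/deployments/%s/%s\n", model_, task)
	// Output: /openai/deployments/gpt-35-turbo/chat/completions?api-version=2023-05-15
}
```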
diff --git a/relay/channel/openai/constants.go b/relay/channel/openai/constants.go
new file mode 100644
index 00000000..ea236ea1
--- /dev/null
+++ b/relay/channel/openai/constants.go
@@ -0,0 +1,19 @@
+package openai
+
+var ModelList = []string{
+ "gpt-3.5-turbo", "gpt-3.5-turbo-0301", "gpt-3.5-turbo-0613", "gpt-3.5-turbo-1106", "gpt-3.5-turbo-0125",
+ "gpt-3.5-turbo-16k", "gpt-3.5-turbo-16k-0613",
+ "gpt-3.5-turbo-instruct",
+ "gpt-4", "gpt-4-0314", "gpt-4-0613", "gpt-4-1106-preview", "gpt-4-0125-preview",
+ "gpt-4-32k", "gpt-4-32k-0314", "gpt-4-32k-0613",
+ "gpt-4-turbo-preview",
+ "gpt-4-vision-preview",
+ "text-embedding-ada-002", "text-embedding-3-small", "text-embedding-3-large",
+ "text-curie-001", "text-babbage-001", "text-ada-001", "text-davinci-002", "text-davinci-003",
+ "text-moderation-latest", "text-moderation-stable",
+ "text-davinci-edit-001",
+ "davinci-002", "babbage-002",
+ "dall-e-2", "dall-e-3",
+ "whisper-1",
+ "tts-1", "tts-1-1106", "tts-1-hd", "tts-1-hd-1106",
+}
diff --git a/relay/channel/openai/helper.go b/relay/channel/openai/helper.go
new file mode 100644
index 00000000..9bca8cab
--- /dev/null
+++ b/relay/channel/openai/helper.go
@@ -0,0 +1,11 @@
+package openai
+
+import "github.com/songquanpeng/one-api/relay/model"
+
+func ResponseText2Usage(responseText string, modelName string, promptTokens int) *model.Usage {
+ usage := &model.Usage{}
+ usage.PromptTokens = promptTokens
+ usage.CompletionTokens = CountTokenText(responseText, modelName)
+ usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
+ return usage
+}
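
ResponseText2Usage is what every stream branch above falls back on when the upstream never reports usage: prompt tokens are known from the request, completion tokens are counted locally. A self-contained sketch of the same arithmetic, where countTokens stands in for CountTokenText so it runs without tokenizer data:

```go
package main

import "fmt"

// usage mirrors model.Usage; countTokens is a stand-in for
// openai.CountTokenText, injected so the sketch has no dependencies.
type usage struct{ Prompt, Completion, Total int }

func responseText2Usage(text string, promptTokens int, countTokens func(string) int) usage {
	u := usage{Prompt: promptTokens, Completion: countTokens(text)}
	u.Total = u.Prompt + u.Completion
	return u
}

func main() {
	count := func(s string) int { return len(s) / 4 } // crude stand-in
	fmt.Printf("%+v\n", responseText2Usage("hello from the model", 12, count))
}
```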
diff --git a/relay/channel/openai/main.go b/relay/channel/openai/main.go
index f56028a2..fbe55cf9 100644
--- a/relay/channel/openai/main.go
+++ b/relay/channel/openai/main.go
@@ -8,12 +8,13 @@ import (
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/relay/constant"
+ "github.com/songquanpeng/one-api/relay/model"
"io"
"net/http"
"strings"
)
-func StreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*ErrorWithStatusCode, string) {
+func StreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*model.ErrorWithStatusCode, string) {
responseText := ""
scanner := bufio.NewScanner(resp.Body)
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
@@ -90,7 +91,7 @@ func StreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*ErrorWi
return nil, responseText
}
-func Handler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*ErrorWithStatusCode, *Usage) {
+func Handler(c *gin.Context, resp *http.Response, promptTokens int, modelName string) (*model.ErrorWithStatusCode, *model.Usage) {
var textResponse SlimTextResponse
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
@@ -105,7 +106,7 @@ func Handler(c *gin.Context, resp *http.Response, promptTokens int, model string
return ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
if textResponse.Error.Type != "" {
- return &ErrorWithStatusCode{
+ return &model.ErrorWithStatusCode{
Error: textResponse.Error,
StatusCode: resp.StatusCode,
}, nil
@@ -133,9 +134,9 @@ func Handler(c *gin.Context, resp *http.Response, promptTokens int, model string
if textResponse.Usage.TotalTokens == 0 {
completionTokens := 0
for _, choice := range textResponse.Choices {
- completionTokens += CountTokenText(choice.Message.StringContent(), model)
+ completionTokens += CountTokenText(choice.Message.StringContent(), modelName)
}
- textResponse.Usage = Usage{
+ textResponse.Usage = model.Usage{
PromptTokens: promptTokens,
CompletionTokens: completionTokens,
TotalTokens: promptTokens + completionTokens,
diff --git a/relay/channel/openai/model.go b/relay/channel/openai/model.go
index 937fb424..b24485a8 100644
--- a/relay/channel/openai/model.go
+++ b/relay/channel/openai/model.go
@@ -1,15 +1,6 @@
package openai
-type Message struct {
- Role string `json:"role"`
- Content any `json:"content"`
- Name *string `json:"name,omitempty"`
-}
-
-type ImageURL struct {
- Url string `json:"url,omitempty"`
- Detail string `json:"detail,omitempty"`
-}
+import "github.com/songquanpeng/one-api/relay/model"
type TextContent struct {
Type string `json:"type,omitempty"`
@@ -17,142 +8,21 @@ type TextContent struct {
}
type ImageContent struct {
- Type string `json:"type,omitempty"`
- ImageURL *ImageURL `json:"image_url,omitempty"`
-}
-
-type OpenAIMessageContent struct {
- Type string `json:"type,omitempty"`
- Text string `json:"text"`
- ImageURL *ImageURL `json:"image_url,omitempty"`
-}
-
-func (m Message) IsStringContent() bool {
- _, ok := m.Content.(string)
- return ok
-}
-
-func (m Message) StringContent() string {
- content, ok := m.Content.(string)
- if ok {
- return content
- }
- contentList, ok := m.Content.([]any)
- if ok {
- var contentStr string
- for _, contentItem := range contentList {
- contentMap, ok := contentItem.(map[string]any)
- if !ok {
- continue
- }
- if contentMap["type"] == ContentTypeText {
- if subStr, ok := contentMap["text"].(string); ok {
- contentStr += subStr
- }
- }
- }
- return contentStr
- }
- return ""
-}
-
-func (m Message) ParseContent() []OpenAIMessageContent {
- var contentList []OpenAIMessageContent
- content, ok := m.Content.(string)
- if ok {
- contentList = append(contentList, OpenAIMessageContent{
- Type: ContentTypeText,
- Text: content,
- })
- return contentList
- }
- anyList, ok := m.Content.([]any)
- if ok {
- for _, contentItem := range anyList {
- contentMap, ok := contentItem.(map[string]any)
- if !ok {
- continue
- }
- switch contentMap["type"] {
- case ContentTypeText:
- if subStr, ok := contentMap["text"].(string); ok {
- contentList = append(contentList, OpenAIMessageContent{
- Type: ContentTypeText,
- Text: subStr,
- })
- }
- case ContentTypeImageURL:
- if subObj, ok := contentMap["image_url"].(map[string]any); ok {
- contentList = append(contentList, OpenAIMessageContent{
- Type: ContentTypeImageURL,
- ImageURL: &ImageURL{
- Url: subObj["url"].(string),
- },
- })
- }
- }
- }
- return contentList
- }
- return nil
-}
-
-type ResponseFormat struct {
- Type string `json:"type,omitempty"`
-}
-
-type GeneralOpenAIRequest struct {
- Model string `json:"model,omitempty"`
- Messages []Message `json:"messages,omitempty"`
- Prompt any `json:"prompt,omitempty"`
- Stream bool `json:"stream,omitempty"`
- MaxTokens int `json:"max_tokens,omitempty"`
- Temperature float64 `json:"temperature,omitempty"`
- TopP float64 `json:"top_p,omitempty"`
- N int `json:"n,omitempty"`
- Input any `json:"input,omitempty"`
- Instruction string `json:"instruction,omitempty"`
- Size string `json:"size,omitempty"`
- Functions any `json:"functions,omitempty"`
- FrequencyPenalty float64 `json:"frequency_penalty,omitempty"`
- PresencePenalty float64 `json:"presence_penalty,omitempty"`
- ResponseFormat *ResponseFormat `json:"response_format,omitempty"`
- Seed float64 `json:"seed,omitempty"`
- Tools any `json:"tools,omitempty"`
- ToolChoice any `json:"tool_choice,omitempty"`
- User string `json:"user,omitempty"`
-}
-
-func (r GeneralOpenAIRequest) ParseInput() []string {
- if r.Input == nil {
- return nil
- }
- var input []string
- switch r.Input.(type) {
- case string:
- input = []string{r.Input.(string)}
- case []any:
- input = make([]string, 0, len(r.Input.([]any)))
- for _, item := range r.Input.([]any) {
- if str, ok := item.(string); ok {
- input = append(input, str)
- }
- }
- }
- return input
+ Type string `json:"type,omitempty"`
+ ImageURL *model.ImageURL `json:"image_url,omitempty"`
}
type ChatRequest struct {
- Model string `json:"model"`
- Messages []Message `json:"messages"`
- MaxTokens int `json:"max_tokens"`
+ Model string `json:"model"`
+ Messages []model.Message `json:"messages"`
+ MaxTokens int `json:"max_tokens"`
}
type TextRequest struct {
- Model string `json:"model"`
- Messages []Message `json:"messages"`
- Prompt string `json:"prompt"`
- MaxTokens int `json:"max_tokens"`
+ Model string `json:"model"`
+ Messages []model.Message `json:"messages"`
+ Prompt string `json:"prompt"`
+ MaxTokens int `json:"max_tokens"`
//Stream bool `json:"stream"`
}
@@ -201,48 +71,30 @@ type TextToSpeechRequest struct {
ResponseFormat string `json:"response_format"`
}
-type Usage struct {
- PromptTokens int `json:"prompt_tokens"`
- CompletionTokens int `json:"completion_tokens"`
- TotalTokens int `json:"total_tokens"`
-}
-
type UsageOrResponseText struct {
- *Usage
+ *model.Usage
ResponseText string
}
-type Error struct {
- Message string `json:"message"`
- Type string `json:"type"`
- Param string `json:"param"`
- Code any `json:"code"`
-}
-
-type ErrorWithStatusCode struct {
- Error
- StatusCode int `json:"status_code"`
-}
-
type SlimTextResponse struct {
- Choices []TextResponseChoice `json:"choices"`
- Usage `json:"usage"`
- Error Error `json:"error"`
+ Choices []TextResponseChoice `json:"choices"`
+ model.Usage `json:"usage"`
+ Error model.Error `json:"error"`
}
type TextResponseChoice struct {
- Index int `json:"index"`
- Message `json:"message"`
- FinishReason string `json:"finish_reason"`
+ Index int `json:"index"`
+ model.Message `json:"message"`
+ FinishReason string `json:"finish_reason"`
}
type TextResponse struct {
- Id string `json:"id"`
- Model string `json:"model,omitempty"`
- Object string `json:"object"`
- Created int64 `json:"created"`
- Choices []TextResponseChoice `json:"choices"`
- Usage `json:"usage"`
+ Id string `json:"id"`
+ Model string `json:"model,omitempty"`
+ Object string `json:"object"`
+ Created int64 `json:"created"`
+ Choices []TextResponseChoice `json:"choices"`
+ model.Usage `json:"usage"`
}
type EmbeddingResponseItem struct {
@@ -252,10 +104,10 @@ type EmbeddingResponseItem struct {
}
type EmbeddingResponse struct {
- Object string `json:"object"`
- Data []EmbeddingResponseItem `json:"data"`
- Model string `json:"model"`
- Usage `json:"usage"`
+ Object string `json:"object"`
+ Data []EmbeddingResponseItem `json:"data"`
+ Model string `json:"model"`
+ model.Usage `json:"usage"`
}
type ImageResponse struct {
@@ -266,8 +118,10 @@ type ImageResponse struct {
}
type ChatCompletionsStreamResponseChoice struct {
+ Index int `json:"index"`
Delta struct {
Content string `json:"content"`
+ Role string `json:"role,omitempty"`
} `json:"delta"`
FinishReason *string `json:"finish_reason,omitempty"`
}
diff --git a/relay/channel/openai/token.go b/relay/channel/openai/token.go
index 686ac39f..0720425f 100644
--- a/relay/channel/openai/token.go
+++ b/relay/channel/openai/token.go
@@ -8,6 +8,7 @@ import (
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/image"
"github.com/songquanpeng/one-api/common/logger"
+ "github.com/songquanpeng/one-api/relay/model"
"math"
"strings"
)
@@ -63,7 +64,7 @@ func getTokenNum(tokenEncoder *tiktoken.Tiktoken, text string) int {
return len(tokenEncoder.Encode(text, nil, nil))
}
-func CountTokenMessages(messages []Message, model string) int {
+func CountTokenMessages(messages []model.Message, model string) int {
tokenEncoder := getTokenEncoder(model)
// Reference:
// https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
diff --git a/relay/channel/openai/util.go b/relay/channel/openai/util.go
index 69ece6b3..ba0cab7d 100644
--- a/relay/channel/openai/util.go
+++ b/relay/channel/openai/util.go
@@ -1,12 +1,14 @@
package openai
-func ErrorWrapper(err error, code string, statusCode int) *ErrorWithStatusCode {
- Error := Error{
+import "github.com/songquanpeng/one-api/relay/model"
+
+func ErrorWrapper(err error, code string, statusCode int) *model.ErrorWithStatusCode {
+ Error := model.Error{
Message: err.Error(),
Type: "one_api_error",
Code: code,
}
- return &ErrorWithStatusCode{
+ return &model.ErrorWithStatusCode{
Error: Error,
StatusCode: statusCode,
}
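
ErrorWrapper is the single funnel through which adaptors translate transport and decoding failures into the OpenAI error envelope, so clients see one consistent shape regardless of channel. A standalone sketch of the envelope it produces, with local copies of the model types purely for illustration:

```go
package main

import (
	"encoding/json"
	"errors"
	"fmt"
	"net/http"
)

// Local stand-ins for model.Error / model.ErrorWithStatusCode so the
// sketch runs on its own; the real types live in relay/model.
type apiError struct {
	Message string `json:"message"`
	Type    string `json:"type"`
	Param   string `json:"param"`
	Code    any    `json:"code"`
}

type errorWithStatusCode struct {
	apiError
	StatusCode int `json:"status_code"`
}

func errorWrapper(err error, code string, statusCode int) *errorWithStatusCode {
	return &errorWithStatusCode{
		apiError:   apiError{Message: err.Error(), Type: "one_api_error", Code: code},
		StatusCode: statusCode,
	}
}

func main() {
	e := errorWrapper(errors.New("connection refused"), "do_request_failed", http.StatusInternalServerError)
	b, _ := json.Marshal(e)
	fmt.Println(string(b))
	// {"message":"connection refused","type":"one_api_error","param":"","code":"do_request_failed","status_code":500}
}
```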
diff --git a/relay/channel/palm/adaptor.go b/relay/channel/palm/adaptor.go
new file mode 100644
index 00000000..efd0620c
--- /dev/null
+++ b/relay/channel/palm/adaptor.go
@@ -0,0 +1,60 @@
+package palm
+
+import (
+ "errors"
+ "fmt"
+ "github.com/gin-gonic/gin"
+ "github.com/songquanpeng/one-api/relay/channel"
+ "github.com/songquanpeng/one-api/relay/channel/openai"
+ "github.com/songquanpeng/one-api/relay/model"
+ "github.com/songquanpeng/one-api/relay/util"
+ "io"
+ "net/http"
+)
+
+type Adaptor struct {
+}
+
+func (a *Adaptor) Init(meta *util.RelayMeta) {
+
+}
+
+func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) {
+ return fmt.Sprintf("%s/v1beta2/models/chat-bison-001:generateMessage", meta.BaseURL), nil
+}
+
+func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *util.RelayMeta) error {
+ channel.SetupCommonRequestHeader(c, req, meta)
+ req.Header.Set("x-goog-api-key", meta.APIKey)
+ return nil
+}
+
+func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) {
+ if request == nil {
+ return nil, errors.New("request is nil")
+ }
+ return ConvertRequest(*request), nil
+}
+
+func (a *Adaptor) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error) {
+ return channel.DoRequestHelper(a, c, meta, requestBody)
+}
+
+func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
+ if meta.IsStream {
+ var responseText string
+ err, responseText = StreamHandler(c, resp)
+ usage = openai.ResponseText2Usage(responseText, meta.ActualModelName, meta.PromptTokens)
+ } else {
+ err, usage = Handler(c, resp, meta.PromptTokens, meta.ActualModelName)
+ }
+ return
+}
+
+func (a *Adaptor) GetModelList() []string {
+ return ModelList
+}
+
+func (a *Adaptor) GetChannelName() string {
+ return "google palm"
+}
diff --git a/relay/channel/palm/constants.go b/relay/channel/palm/constants.go
new file mode 100644
index 00000000..a8349362
--- /dev/null
+++ b/relay/channel/palm/constants.go
@@ -0,0 +1,5 @@
+package palm
+
+var ModelList = []string{
+ "PaLM-2",
+}
diff --git a/relay/channel/palm/model.go b/relay/channel/palm/model.go
new file mode 100644
index 00000000..f653022c
--- /dev/null
+++ b/relay/channel/palm/model.go
@@ -0,0 +1,40 @@
+package palm
+
+import (
+ "github.com/songquanpeng/one-api/relay/model"
+)
+
+type ChatMessage struct {
+ Author string `json:"author"`
+ Content string `json:"content"`
+}
+
+type Filter struct {
+ Reason string `json:"reason"`
+ Message string `json:"message"`
+}
+
+type Prompt struct {
+ Messages []ChatMessage `json:"messages"`
+}
+
+type ChatRequest struct {
+ Prompt Prompt `json:"prompt"`
+ Temperature float64 `json:"temperature,omitempty"`
+ CandidateCount int `json:"candidateCount,omitempty"`
+ TopP float64 `json:"topP,omitempty"`
+ TopK int `json:"topK,omitempty"`
+}
+
+type Error struct {
+ Code int `json:"code"`
+ Message string `json:"message"`
+ Status string `json:"status"`
+}
+
+type ChatResponse struct {
+ Candidates []ChatMessage `json:"candidates"`
+ Messages []model.Message `json:"messages"`
+ Filters []Filter `json:"filters"`
+ Error Error `json:"error"`
+}
diff --git a/relay/channel/google/palm.go b/relay/channel/palm/palm.go
similarity index 84%
rename from relay/channel/google/palm.go
rename to relay/channel/palm/palm.go
index 7b9ee600..56738544 100644
--- a/relay/channel/google/palm.go
+++ b/relay/channel/palm/palm.go
@@ -1,4 +1,4 @@
-package google
+package palm
import (
"encoding/json"
@@ -9,6 +9,7 @@ import (
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/relay/channel/openai"
"github.com/songquanpeng/one-api/relay/constant"
+ "github.com/songquanpeng/one-api/relay/model"
"io"
"net/http"
)
@@ -16,10 +17,10 @@ import (
// https://developers.generativeai.google/api/rest/generativelanguage/models/generateMessage#request-body
// https://developers.generativeai.google/api/rest/generativelanguage/models/generateMessage#response-body
-func ConvertPaLMRequest(textRequest openai.GeneralOpenAIRequest) *PaLMChatRequest {
- palmRequest := PaLMChatRequest{
- Prompt: PaLMPrompt{
- Messages: make([]PaLMChatMessage, 0, len(textRequest.Messages)),
+func ConvertRequest(textRequest model.GeneralOpenAIRequest) *ChatRequest {
+ palmRequest := ChatRequest{
+ Prompt: Prompt{
+ Messages: make([]ChatMessage, 0, len(textRequest.Messages)),
},
Temperature: textRequest.Temperature,
CandidateCount: textRequest.N,
@@ -27,7 +28,7 @@ func ConvertPaLMRequest(textRequest openai.GeneralOpenAIRequest) *PaLMChatReques
TopK: textRequest.MaxTokens,
}
for _, message := range textRequest.Messages {
- palmMessage := PaLMChatMessage{
+ palmMessage := ChatMessage{
Content: message.StringContent(),
}
if message.Role == "user" {
@@ -40,14 +41,14 @@ func ConvertPaLMRequest(textRequest openai.GeneralOpenAIRequest) *PaLMChatReques
return &palmRequest
}
-func responsePaLM2OpenAI(response *PaLMChatResponse) *openai.TextResponse {
+func responsePaLM2OpenAI(response *ChatResponse) *openai.TextResponse {
fullTextResponse := openai.TextResponse{
Choices: make([]openai.TextResponseChoice, 0, len(response.Candidates)),
}
for i, candidate := range response.Candidates {
choice := openai.TextResponseChoice{
Index: i,
- Message: openai.Message{
+ Message: model.Message{
Role: "assistant",
Content: candidate.Content,
},
@@ -58,7 +59,7 @@ func responsePaLM2OpenAI(response *PaLMChatResponse) *openai.TextResponse {
return &fullTextResponse
}
-func streamResponsePaLM2OpenAI(palmResponse *PaLMChatResponse) *openai.ChatCompletionsStreamResponse {
+func streamResponsePaLM2OpenAI(palmResponse *ChatResponse) *openai.ChatCompletionsStreamResponse {
var choice openai.ChatCompletionsStreamResponseChoice
if len(palmResponse.Candidates) > 0 {
choice.Delta.Content = palmResponse.Candidates[0].Content
@@ -71,7 +72,7 @@ func streamResponsePaLM2OpenAI(palmResponse *PaLMChatResponse) *openai.ChatCompl
return &response
}
-func PaLMStreamHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, string) {
+func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, string) {
responseText := ""
responseId := fmt.Sprintf("chatcmpl-%s", helper.GetUUID())
createdTime := helper.GetTimestamp()
@@ -90,7 +91,7 @@ func PaLMStreamHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithSt
stopChan <- true
return
}
- var palmResponse PaLMChatResponse
+ var palmResponse ChatResponse
err = json.Unmarshal(responseBody, &palmResponse)
if err != nil {
logger.SysError("error unmarshalling stream response: " + err.Error())
@@ -130,7 +131,7 @@ func PaLMStreamHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithSt
return nil, responseText
}
-func PaLMHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*openai.ErrorWithStatusCode, *openai.Usage) {
+func Handler(c *gin.Context, resp *http.Response, promptTokens int, modelName string) (*model.ErrorWithStatusCode, *model.Usage) {
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
@@ -139,14 +140,14 @@ func PaLMHandler(c *gin.Context, resp *http.Response, promptTokens int, model st
if err != nil {
return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
- var palmResponse PaLMChatResponse
+ var palmResponse ChatResponse
err = json.Unmarshal(responseBody, &palmResponse)
if err != nil {
return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
if palmResponse.Error.Code != 0 || len(palmResponse.Candidates) == 0 {
- return &openai.ErrorWithStatusCode{
- Error: openai.Error{
+ return &model.ErrorWithStatusCode{
+ Error: model.Error{
Message: palmResponse.Error.Message,
Type: palmResponse.Error.Status,
Param: "",
@@ -156,9 +157,9 @@ func PaLMHandler(c *gin.Context, resp *http.Response, promptTokens int, model st
}, nil
}
fullTextResponse := responsePaLM2OpenAI(&palmResponse)
- fullTextResponse.Model = model
- completionTokens := openai.CountTokenText(palmResponse.Candidates[0].Content, model)
- usage := openai.Usage{
+ fullTextResponse.Model = modelName
+ completionTokens := openai.CountTokenText(palmResponse.Candidates[0].Content, modelName)
+ usage := model.Usage{
PromptTokens: promptTokens,
CompletionTokens: completionTokens,
TotalTokens: promptTokens + completionTokens,
diff --git a/relay/channel/tencent/adaptor.go b/relay/channel/tencent/adaptor.go
index c90509ca..f348674e 100644
--- a/relay/channel/tencent/adaptor.go
+++ b/relay/channel/tencent/adaptor.go
@@ -1,22 +1,76 @@
package tencent
import (
+ "errors"
+ "fmt"
"github.com/gin-gonic/gin"
+ "github.com/songquanpeng/one-api/relay/channel"
"github.com/songquanpeng/one-api/relay/channel/openai"
+ "github.com/songquanpeng/one-api/relay/model"
+ "github.com/songquanpeng/one-api/relay/util"
+ "io"
"net/http"
+ "strings"
)
+// https://cloud.tencent.com/document/api/1729/101837
+
type Adaptor struct {
+ Sign string
}
-func (a *Adaptor) Auth(c *gin.Context) error {
+func (a *Adaptor) Init(meta *util.RelayMeta) {
+
+}
+
+func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) {
+ return fmt.Sprintf("%s/hyllm/v1/chat/completions", meta.BaseURL), nil
+}
+
+func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *util.RelayMeta) error {
+ channel.SetupCommonRequestHeader(c, req, meta)
+ req.Header.Set("Authorization", a.Sign)
+ req.Header.Set("X-TC-Action", meta.ActualModelName)
return nil
}
-func (a *Adaptor) ConvertRequest(request *openai.GeneralOpenAIRequest) (any, error) {
- return nil, nil
+func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) {
+ if request == nil {
+ return nil, errors.New("request is nil")
+ }
+ apiKey := c.Request.Header.Get("Authorization")
+ apiKey = strings.TrimPrefix(apiKey, "Bearer ")
+ appId, secretId, secretKey, err := ParseConfig(apiKey)
+ if err != nil {
+ return nil, err
+ }
+ tencentRequest := ConvertRequest(*request)
+ tencentRequest.AppId = appId
+ tencentRequest.SecretId = secretId
+ // we have to calculate the sign here
+ a.Sign = GetSign(*tencentRequest, secretKey)
+ return tencentRequest, nil
}
-func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, *openai.Usage, error) {
- return nil, nil, nil
+func (a *Adaptor) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error) {
+ return channel.DoRequestHelper(a, c, meta, requestBody)
+}
+
+func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
+ if meta.IsStream {
+ var responseText string
+ err, responseText = StreamHandler(c, resp)
+ usage = openai.ResponseText2Usage(responseText, meta.ActualModelName, meta.PromptTokens)
+ } else {
+ err, usage = Handler(c, resp)
+ }
+ return
+}
+
+func (a *Adaptor) GetModelList() []string {
+ return ModelList
+}
+
+func (a *Adaptor) GetChannelName() string {
+ return "tencent"
}
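
// Sketch of what ParseConfig (called above, defined outside this excerpt) is
// assumed to do: split a pipe-separated channel key "appId|secretId|secretKey",
// mirroring the pipe convention the xunfei channel uses. The int64 appId, the
// package name, and the error text are illustrative assumptions.
package tencentsketch

import (
	"errors"
	"strconv"
	"strings"
)

func parseConfigSketch(apiKey string) (appId int64, secretId string, secretKey string, err error) {
	parts := strings.Split(apiKey, "|")
	if len(parts) != 3 {
		return 0, "", "", errors.New("invalid tencent channel key, expected appId|secretId|secretKey")
	}
	// the app id is numeric in the upstream API, hence the integer parse
	appId, err = strconv.ParseInt(parts[0], 10, 64)
	if err != nil {
		return 0, "", "", err
	}
	return appId, parts[1], parts[2], nil
}
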
diff --git a/relay/channel/tencent/constants.go b/relay/channel/tencent/constants.go
new file mode 100644
index 00000000..fe176c2c
--- /dev/null
+++ b/relay/channel/tencent/constants.go
@@ -0,0 +1,7 @@
+package tencent
+
+var ModelList = []string{
+ "ChatPro",
+ "ChatStd",
+ "hunyuan",
+}
diff --git a/relay/channel/tencent/main.go b/relay/channel/tencent/main.go
index 784f86fd..05edac20 100644
--- a/relay/channel/tencent/main.go
+++ b/relay/channel/tencent/main.go
@@ -14,6 +14,7 @@ import (
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/relay/channel/openai"
"github.com/songquanpeng/one-api/relay/constant"
+ "github.com/songquanpeng/one-api/relay/model"
"io"
"net/http"
"sort"
@@ -23,7 +24,7 @@ import (
// https://cloud.tencent.com/document/product/1729/97732
-func ConvertRequest(request openai.GeneralOpenAIRequest) *ChatRequest {
+func ConvertRequest(request model.GeneralOpenAIRequest) *ChatRequest {
messages := make([]Message, 0, len(request.Messages))
for i := 0; i < len(request.Messages); i++ {
message := request.Messages[i]
@@ -67,7 +68,7 @@ func responseTencent2OpenAI(response *ChatResponse) *openai.TextResponse {
if len(response.Choices) > 0 {
choice := openai.TextResponseChoice{
Index: 0,
- Message: openai.Message{
+ Message: model.Message{
Role: "assistant",
Content: response.Choices[0].Messages.Content,
},
@@ -95,7 +96,7 @@ func streamResponseTencent2OpenAI(TencentResponse *ChatResponse) *openai.ChatCom
return &response
}
-func StreamHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, string) {
+func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, string) {
var responseText string
scanner := bufio.NewScanner(resp.Body)
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
@@ -159,7 +160,7 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatus
return nil, responseText
}
-func Handler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, *openai.Usage) {
+func Handler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
var TencentResponse ChatResponse
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
@@ -174,8 +175,8 @@ func Handler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode,
return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
if TencentResponse.Error.Code != 0 {
- return &openai.ErrorWithStatusCode{
- Error: openai.Error{
+ return &model.ErrorWithStatusCode{
+ Error: model.Error{
Message: TencentResponse.Error.Message,
Code: TencentResponse.Error.Code,
},
diff --git a/relay/channel/tencent/model.go b/relay/channel/tencent/model.go
index b8aa7698..71286be9 100644
--- a/relay/channel/tencent/model.go
+++ b/relay/channel/tencent/model.go
@@ -1,7 +1,7 @@
package tencent
import (
- "github.com/songquanpeng/one-api/relay/channel/openai"
+ "github.com/songquanpeng/one-api/relay/model"
)
type Message struct {
@@ -56,7 +56,7 @@ type ChatResponse struct {
Choices []ResponseChoices `json:"choices,omitempty"` // 结果
Created string `json:"created,omitempty"` // unix 时间戳的字符串
Id string `json:"id,omitempty"` // 会话 id
- Usage openai.Usage `json:"usage,omitempty"` // token 数量
+ Usage model.Usage `json:"usage,omitempty"` // token 数量
Error Error `json:"error,omitempty"` // 错误信息 注意:此字段可能返回 null,表示取不到有效值
Note string `json:"note,omitempty"` // 注释
ReqID string `json:"req_id,omitempty"` // 唯一请求 Id,每次请求都会返回。用于反馈接口入参
diff --git a/relay/channel/xunfei/adaptor.go b/relay/channel/xunfei/adaptor.go
index d2c80c64..92d9d7d6 100644
--- a/relay/channel/xunfei/adaptor.go
+++ b/relay/channel/xunfei/adaptor.go
@@ -1,22 +1,70 @@
package xunfei
import (
+ "errors"
"github.com/gin-gonic/gin"
+ "github.com/songquanpeng/one-api/relay/channel"
"github.com/songquanpeng/one-api/relay/channel/openai"
+ "github.com/songquanpeng/one-api/relay/model"
+ "github.com/songquanpeng/one-api/relay/util"
+ "io"
"net/http"
+ "strings"
)
type Adaptor struct {
+ request *model.GeneralOpenAIRequest
}
-func (a *Adaptor) Auth(c *gin.Context) error {
+func (a *Adaptor) Init(meta *util.RelayMeta) {
+
+}
+
+func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) {
+ return "", nil
+}
+
+func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *util.RelayMeta) error {
+ channel.SetupCommonRequestHeader(c, req, meta)
+ // auth is handled in DoResponse, since xunfei authenticates via a signed per-connection websocket URL
return nil
}
-func (a *Adaptor) ConvertRequest(request *openai.GeneralOpenAIRequest) (any, error) {
+func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) {
+ if request == nil {
+ return nil, errors.New("request is nil")
+ }
+ a.request = request
return nil, nil
}
-func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, *openai.Usage, error) {
- return nil, nil, nil
+func (a *Adaptor) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error) {
+ // xunfei's request is made over websocket, not HTTP, so nothing is sent here;
+ // return a dummy 200 response to satisfy the adaptor interface
+ dummyResp := &http.Response{}
+ dummyResp.StatusCode = http.StatusOK
+ return dummyResp, nil
+}
+
+func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
+ splits := strings.Split(meta.APIKey, "|")
+ if len(splits) != 3 {
+ return nil, openai.ErrorWrapper(errors.New("invalid auth"), "invalid_auth", http.StatusBadRequest)
+ }
+ if a.request == nil {
+ return nil, openai.ErrorWrapper(errors.New("request is nil"), "request_is_nil", http.StatusBadRequest)
+ }
+ if meta.IsStream {
+ err, usage = StreamHandler(c, *a.request, splits[0], splits[1], splits[2])
+ } else {
+ err, usage = Handler(c, *a.request, splits[0], splits[1], splits[2])
+ }
+ return
+}
+
+func (a *Adaptor) GetModelList() []string {
+ return ModelList
+}
+
+func (a *Adaptor) GetChannelName() string {
+ return "xunfei"
}
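
// Sketch: the xunfei channel key is entered as "APPID|APISecret|APIKey" (the
// format the EditChannel prompt later in this diff asks for), and DoResponse
// above splits it in exactly that order before calling the handlers. The key
// value below is made up.
package main

import (
	"fmt"
	"strings"
)

func main() {
	apiKey := "my-app-id|my-api-secret|my-api-key" // hypothetical channel key
	splits := strings.Split(apiKey, "|")
	if len(splits) != 3 {
		fmt.Println("invalid auth")
		return
	}
	appId, apiSecret, key := splits[0], splits[1], splits[2]
	fmt.Println(appId, apiSecret, key) // handed to StreamHandler/Handler in this order
}
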
diff --git a/relay/channel/xunfei/constants.go b/relay/channel/xunfei/constants.go
new file mode 100644
index 00000000..31dcec71
--- /dev/null
+++ b/relay/channel/xunfei/constants.go
@@ -0,0 +1,9 @@
+package xunfei
+
+var ModelList = []string{
+ "SparkDesk",
+ "SparkDesk-v1.1",
+ "SparkDesk-v2.1",
+ "SparkDesk-v3.1",
+ "SparkDesk-v3.5",
+}
diff --git a/relay/channel/xunfei/main.go b/relay/channel/xunfei/main.go
index ff5cdbea..620e808f 100644
--- a/relay/channel/xunfei/main.go
+++ b/relay/channel/xunfei/main.go
@@ -13,6 +13,7 @@ import (
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/relay/channel/openai"
"github.com/songquanpeng/one-api/relay/constant"
+ "github.com/songquanpeng/one-api/relay/model"
"io"
"net/http"
"net/url"
@@ -23,7 +24,7 @@ import (
// https://console.xfyun.cn/services/cbm
// https://www.xfyun.cn/doc/spark/Web.html
-func requestOpenAI2Xunfei(request openai.GeneralOpenAIRequest, xunfeiAppId string, domain string) *ChatRequest {
+func requestOpenAI2Xunfei(request model.GeneralOpenAIRequest, xunfeiAppId string, domain string) *ChatRequest {
messages := make([]Message, 0, len(request.Messages))
for _, message := range request.Messages {
if message.Role == "system" {
@@ -62,13 +63,14 @@ func responseXunfei2OpenAI(response *ChatResponse) *openai.TextResponse {
}
choice := openai.TextResponseChoice{
Index: 0,
- Message: openai.Message{
+ Message: model.Message{
Role: "assistant",
Content: response.Payload.Choices.Text[0].Content,
},
FinishReason: constant.StopFinishReason,
}
fullTextResponse := openai.TextResponse{
+ Id: fmt.Sprintf("chatcmpl-%s", helper.GetUUID()),
Object: "chat.completion",
Created: helper.GetTimestamp(),
Choices: []openai.TextResponseChoice{choice},
@@ -91,6 +93,7 @@ func streamResponseXunfei2OpenAI(xunfeiResponse *ChatResponse) *openai.ChatCompl
choice.FinishReason = &constant.StopFinishReason
}
response := openai.ChatCompletionsStreamResponse{
+ Id: fmt.Sprintf("chatcmpl-%s", helper.GetUUID()),
Object: "chat.completion.chunk",
Created: helper.GetTimestamp(),
Model: "SparkDesk",
@@ -125,14 +128,14 @@ func buildXunfeiAuthUrl(hostUrl string, apiKey, apiSecret string) string {
return callUrl
}
-func StreamHandler(c *gin.Context, textRequest openai.GeneralOpenAIRequest, appId string, apiSecret string, apiKey string) (*openai.ErrorWithStatusCode, *openai.Usage) {
- domain, authUrl := getXunfeiAuthUrl(c, apiKey, apiSecret)
+func StreamHandler(c *gin.Context, textRequest model.GeneralOpenAIRequest, appId string, apiSecret string, apiKey string) (*model.ErrorWithStatusCode, *model.Usage) {
+ domain, authUrl := getXunfeiAuthUrl(c, apiKey, apiSecret, textRequest.Model)
dataChan, stopChan, err := xunfeiMakeRequest(textRequest, domain, authUrl, appId)
if err != nil {
return openai.ErrorWrapper(err, "make xunfei request err", http.StatusInternalServerError), nil
}
common.SetEventStreamHeaders(c)
- var usage openai.Usage
+ var usage model.Usage
c.Stream(func(w io.Writer) bool {
select {
case xunfeiResponse := <-dataChan:
@@ -155,13 +158,13 @@ func StreamHandler(c *gin.Context, textRequest openai.GeneralOpenAIRequest, appI
return nil, &usage
}
-func Handler(c *gin.Context, textRequest openai.GeneralOpenAIRequest, appId string, apiSecret string, apiKey string) (*openai.ErrorWithStatusCode, *openai.Usage) {
- domain, authUrl := getXunfeiAuthUrl(c, apiKey, apiSecret)
+func Handler(c *gin.Context, textRequest model.GeneralOpenAIRequest, appId string, apiSecret string, apiKey string) (*model.ErrorWithStatusCode, *model.Usage) {
+ domain, authUrl := getXunfeiAuthUrl(c, apiKey, apiSecret, textRequest.Model)
dataChan, stopChan, err := xunfeiMakeRequest(textRequest, domain, authUrl, appId)
if err != nil {
return openai.ErrorWrapper(err, "make xunfei request err", http.StatusInternalServerError), nil
}
- var usage openai.Usage
+ var usage model.Usage
var content string
var xunfeiResponse ChatResponse
stop := false
@@ -197,7 +200,7 @@ func Handler(c *gin.Context, textRequest openai.GeneralOpenAIRequest, appId stri
return nil, &usage
}
-func xunfeiMakeRequest(textRequest openai.GeneralOpenAIRequest, domain, authUrl, appId string) (chan ChatResponse, chan bool, error) {
+func xunfeiMakeRequest(textRequest model.GeneralOpenAIRequest, domain, authUrl, appId string) (chan ChatResponse, chan bool, error) {
d := websocket.Dialer{
HandshakeTimeout: 5 * time.Second,
}
@@ -241,20 +244,45 @@ func xunfeiMakeRequest(textRequest openai.GeneralOpenAIRequest, domain, authUrl,
return dataChan, stopChan, nil
}
-func getXunfeiAuthUrl(c *gin.Context, apiKey string, apiSecret string) (string, string) {
+func getAPIVersion(c *gin.Context, modelName string) string {
query := c.Request.URL.Query()
apiVersion := query.Get("api-version")
- if apiVersion == "" {
- apiVersion = c.GetString("api_version")
+ if apiVersion != "" {
+ return apiVersion
}
- if apiVersion == "" {
- apiVersion = "v1.1"
- logger.SysLog("api_version not found, use default: " + apiVersion)
+ parts := strings.Split(modelName, "-")
+ if len(parts) == 2 {
+ apiVersion = parts[1]
+ return apiVersion
}
- domain := "general"
- if apiVersion != "v1.1" {
- domain += strings.Split(apiVersion, ".")[0]
+ apiVersion = c.GetString(common.ConfigKeyAPIVersion)
+ if apiVersion != "" {
+ return apiVersion
}
+ apiVersion = "v1.1"
+ logger.SysLog("api_version not found, using default: " + apiVersion)
+ return apiVersion
+}
+
+// https://www.xfyun.cn/doc/spark/Web.html#_1-%E6%8E%A5%E5%8F%A3%E8%AF%B4%E6%98%8E
+func apiVersion2domain(apiVersion string) string {
+ switch apiVersion {
+ case "v1.1":
+ return "general"
+ case "v2.1":
+ return "generalv2"
+ case "v3.1":
+ return "generalv3"
+ case "v3.5":
+ return "generalv3.5"
+ }
+ return "general" + apiVersion
+}
+
+func getXunfeiAuthUrl(c *gin.Context, apiKey string, apiSecret string, modelName string) (string, string) {
+ apiVersion := getAPIVersion(c, modelName)
+ domain := apiVersion2domain(apiVersion)
authUrl := buildXunfeiAuthUrl(fmt.Sprintf("wss://spark-api.xf-yun.com/%s/chat", apiVersion), apiKey, apiSecret)
return domain, authUrl
}
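
// Sketch of the version resolution above: the "api-version" query param wins,
// then the model name suffix (e.g. SparkDesk-v3.5 -> v3.5), then the channel
// config, then the v1.1 default; the request domain and websocket URL both
// derive from the resolved version.
package main

import "fmt"

// same mapping as apiVersion2domain in the diff above
func apiVersion2domain(apiVersion string) string {
	switch apiVersion {
	case "v1.1":
		return "general"
	case "v2.1":
		return "generalv2"
	case "v3.1":
		return "generalv3"
	case "v3.5":
		return "generalv3.5"
	}
	return "general" + apiVersion
}

func main() {
	for _, v := range []string{"v1.1", "v2.1", "v3.1", "v3.5"} {
		fmt.Printf("%-5s -> domain %-14q url wss://spark-api.xf-yun.com/%s/chat\n",
			v, apiVersion2domain(v), v)
	}
}
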
diff --git a/relay/channel/xunfei/model.go b/relay/channel/xunfei/model.go
index e015d164..1266739d 100644
--- a/relay/channel/xunfei/model.go
+++ b/relay/channel/xunfei/model.go
@@ -1,7 +1,7 @@
package xunfei
import (
- "github.com/songquanpeng/one-api/relay/channel/openai"
+ "github.com/songquanpeng/one-api/relay/model"
)
type Message struct {
@@ -55,7 +55,7 @@ type ChatResponse struct {
// CompletionTokens string `json:"completion_tokens"`
// TotalTokens string `json:"total_tokens"`
//} `json:"text"`
- Text openai.Usage `json:"text"`
+ Text model.Usage `json:"text"`
} `json:"usage"`
} `json:"payload"`
}
diff --git a/relay/channel/zhipu/adaptor.go b/relay/channel/zhipu/adaptor.go
index ae0f6faa..7a822853 100644
--- a/relay/channel/zhipu/adaptor.go
+++ b/relay/channel/zhipu/adaptor.go
@@ -1,22 +1,62 @@
package zhipu
import (
+ "errors"
+ "fmt"
"github.com/gin-gonic/gin"
- "github.com/songquanpeng/one-api/relay/channel/openai"
+ "github.com/songquanpeng/one-api/relay/channel"
+ "github.com/songquanpeng/one-api/relay/model"
+ "github.com/songquanpeng/one-api/relay/util"
+ "io"
"net/http"
)
type Adaptor struct {
}
-func (a *Adaptor) Auth(c *gin.Context) error {
+func (a *Adaptor) Init(meta *util.RelayMeta) {
+
+}
+
+func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) {
+ method := "invoke"
+ if meta.IsStream {
+ method = "sse-invoke"
+ }
+ return fmt.Sprintf("%s/api/paas/v3/model-api/%s/%s", meta.BaseURL, meta.ActualModelName, method), nil
+}
+
+func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *util.RelayMeta) error {
+ channel.SetupCommonRequestHeader(c, req, meta)
+ token := GetToken(meta.APIKey)
+ req.Header.Set("Authorization", token)
return nil
}
-func (a *Adaptor) ConvertRequest(request *openai.GeneralOpenAIRequest) (any, error) {
- return nil, nil
+func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) {
+ if request == nil {
+ return nil, errors.New("request is nil")
+ }
+ return ConvertRequest(*request), nil
}
-func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, *openai.Usage, error) {
- return nil, nil, nil
+func (a *Adaptor) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error) {
+ return channel.DoRequestHelper(a, c, meta, requestBody)
+}
+
+func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
+ if meta.IsStream {
+ err, usage = StreamHandler(c, resp)
+ } else {
+ err, usage = Handler(c, resp)
+ }
+ return
+}
+
+func (a *Adaptor) GetModelList() []string {
+ return ModelList
+}
+
+func (a *Adaptor) GetChannelName() string {
+ return "zhipu"
}
diff --git a/relay/channel/zhipu/constants.go b/relay/channel/zhipu/constants.go
new file mode 100644
index 00000000..f0367b82
--- /dev/null
+++ b/relay/channel/zhipu/constants.go
@@ -0,0 +1,5 @@
+package zhipu
+
+var ModelList = []string{
+ "chatglm_turbo", "chatglm_pro", "chatglm_std", "chatglm_lite",
+}
diff --git a/relay/channel/zhipu/main.go b/relay/channel/zhipu/main.go
index fe7b7b4b..7c3e83f3 100644
--- a/relay/channel/zhipu/main.go
+++ b/relay/channel/zhipu/main.go
@@ -10,6 +10,7 @@ import (
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/relay/channel/openai"
"github.com/songquanpeng/one-api/relay/constant"
+ "github.com/songquanpeng/one-api/relay/model"
"io"
"net/http"
"strings"
@@ -72,7 +73,7 @@ func GetToken(apikey string) string {
return tokenString
}
-func ConvertRequest(request openai.GeneralOpenAIRequest) *Request {
+func ConvertRequest(request model.GeneralOpenAIRequest) *Request {
messages := make([]Message, 0, len(request.Messages))
for _, message := range request.Messages {
if message.Role == "system" {
@@ -110,7 +111,7 @@ func responseZhipu2OpenAI(response *Response) *openai.TextResponse {
for i, choice := range response.Data.Choices {
openaiChoice := openai.TextResponseChoice{
Index: i,
- Message: openai.Message{
+ Message: model.Message{
Role: choice.Role,
Content: strings.Trim(choice.Content, "\""),
},
@@ -136,7 +137,7 @@ func streamResponseZhipu2OpenAI(zhipuResponse string) *openai.ChatCompletionsStr
return &response
}
-func streamMetaResponseZhipu2OpenAI(zhipuResponse *StreamMetaResponse) (*openai.ChatCompletionsStreamResponse, *openai.Usage) {
+func streamMetaResponseZhipu2OpenAI(zhipuResponse *StreamMetaResponse) (*openai.ChatCompletionsStreamResponse, *model.Usage) {
var choice openai.ChatCompletionsStreamResponseChoice
choice.Delta.Content = ""
choice.FinishReason = &constant.StopFinishReason
@@ -150,8 +151,8 @@ func streamMetaResponseZhipu2OpenAI(zhipuResponse *StreamMetaResponse) (*openai.
return &response, &zhipuResponse.Usage
}
-func StreamHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, *openai.Usage) {
- var usage *openai.Usage
+func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
+ var usage *model.Usage
scanner := bufio.NewScanner(resp.Body)
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
if atEOF && len(data) == 0 {
@@ -228,7 +229,7 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatus
return nil, usage
}
-func Handler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, *openai.Usage) {
+func Handler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
var zhipuResponse Response
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
@@ -243,8 +244,8 @@ func Handler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode,
return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
if !zhipuResponse.Success {
- return &openai.ErrorWithStatusCode{
- Error: openai.Error{
+ return &model.ErrorWithStatusCode{
+ Error: model.Error{
Message: zhipuResponse.Msg,
Type: "zhipu_error",
Param: "",
diff --git a/relay/channel/zhipu/model.go b/relay/channel/zhipu/model.go
index 67f1caeb..b63e1d6f 100644
--- a/relay/channel/zhipu/model.go
+++ b/relay/channel/zhipu/model.go
@@ -1,7 +1,7 @@
package zhipu
import (
- "github.com/songquanpeng/one-api/relay/channel/openai"
+ "github.com/songquanpeng/one-api/relay/model"
"time"
)
@@ -19,11 +19,11 @@ type Request struct {
}
type ResponseData struct {
- TaskId string `json:"task_id"`
- RequestId string `json:"request_id"`
- TaskStatus string `json:"task_status"`
- Choices []Message `json:"choices"`
- openai.Usage `json:"usage"`
+ TaskId string `json:"task_id"`
+ RequestId string `json:"request_id"`
+ TaskStatus string `json:"task_status"`
+ Choices []Message `json:"choices"`
+ model.Usage `json:"usage"`
}
type Response struct {
@@ -34,10 +34,10 @@ type Response struct {
}
type StreamMetaResponse struct {
- RequestId string `json:"request_id"`
- TaskId string `json:"task_id"`
- TaskStatus string `json:"task_status"`
- openai.Usage `json:"usage"`
+ RequestId string `json:"request_id"`
+ TaskId string `json:"task_id"`
+ TaskStatus string `json:"task_status"`
+ model.Usage `json:"usage"`
}
type tokenData struct {
diff --git a/relay/constant/api_type.go b/relay/constant/api_type.go
index e0458279..d2184dac 100644
--- a/relay/constant/api_type.go
+++ b/relay/constant/api_type.go
@@ -6,7 +6,7 @@ import (
const (
APITypeOpenAI = iota
- APITypeClaude
+ APITypeAnthropic
APITypePaLM
APITypeBaidu
APITypeZhipu
@@ -15,13 +15,15 @@ const (
APITypeAIProxyLibrary
APITypeTencent
APITypeGemini
+
+ APITypeDummy // sentinel used only for counting API types; do not add any new type after it
)
func ChannelType2APIType(channelType int) int {
apiType := APITypeOpenAI
switch channelType {
case common.ChannelTypeAnthropic:
- apiType = APITypeClaude
+ apiType = APITypeAnthropic
case common.ChannelTypeBaidu:
apiType = APITypeBaidu
case common.ChannelTypePaLM:
@@ -41,29 +43,3 @@ func ChannelType2APIType(channelType int) int {
}
return apiType
}
-
-//func GetAdaptor(apiType int) channel.Adaptor {
-// switch apiType {
-// case APITypeOpenAI:
-// return &openai.Adaptor{}
-// case APITypeClaude:
-// return &anthropic.Adaptor{}
-// case APITypePaLM:
-// return &google.Adaptor{}
-// case APITypeZhipu:
-// return &baidu.Adaptor{}
-// case APITypeBaidu:
-// return &baidu.Adaptor{}
-// case APITypeAli:
-// return &ali.Adaptor{}
-// case APITypeXunfei:
-// return &xunfei.Adaptor{}
-// case APITypeAIProxyLibrary:
-// return &aiproxy.Adaptor{}
-// case APITypeTencent:
-// return &tencent.Adaptor{}
-// case APITypeGemini:
-// return &google.Adaptor{}
-// }
-// return nil
-//}
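
// Sketch: what the new APITypeDummy sentinel enables — ranging over every API
// type, e.g. in a test asserting each one has a registered adaptor in the
// helper.GetAdaptor factory added later in this diff (which covers all ten
// types before the sentinel). Test name and package are illustrative.
package relaytest

import (
	"testing"

	"github.com/songquanpeng/one-api/relay/constant"
	"github.com/songquanpeng/one-api/relay/helper"
)

func TestEveryAPITypeHasAnAdaptor(t *testing.T) {
	for apiType := 0; apiType < constant.APITypeDummy; apiType++ {
		if helper.GetAdaptor(apiType) == nil {
			t.Errorf("api type %d has no adaptor", apiType)
		}
	}
}
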
diff --git a/relay/controller/audio.go b/relay/controller/audio.go
index cbbd8a04..ee8771c9 100644
--- a/relay/controller/audio.go
+++ b/relay/controller/audio.go
@@ -14,13 +14,14 @@ import (
"github.com/songquanpeng/one-api/model"
"github.com/songquanpeng/one-api/relay/channel/openai"
"github.com/songquanpeng/one-api/relay/constant"
+ relaymodel "github.com/songquanpeng/one-api/relay/model"
"github.com/songquanpeng/one-api/relay/util"
"io"
"net/http"
"strings"
)
-func RelayAudioHelper(c *gin.Context, relayMode int) *openai.ErrorWithStatusCode {
+func RelayAudioHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatusCode {
audioModel := "whisper-1"
tokenId := c.GetInt("token_id")
diff --git a/relay/controller/helper.go b/relay/controller/helper.go
index 6154f291..a06b2768 100644
--- a/relay/controller/helper.go
+++ b/relay/controller/helper.go
@@ -11,14 +11,14 @@ import (
"github.com/songquanpeng/one-api/model"
"github.com/songquanpeng/one-api/relay/channel/openai"
"github.com/songquanpeng/one-api/relay/constant"
+ relaymodel "github.com/songquanpeng/one-api/relay/model"
"github.com/songquanpeng/one-api/relay/util"
- "io"
"math"
"net/http"
)
-func getAndValidateTextRequest(c *gin.Context, relayMode int) (*openai.GeneralOpenAIRequest, error) {
- textRequest := &openai.GeneralOpenAIRequest{}
+func getAndValidateTextRequest(c *gin.Context, relayMode int) (*relaymodel.GeneralOpenAIRequest, error) {
+ textRequest := &relaymodel.GeneralOpenAIRequest{}
err := common.UnmarshalBodyReusable(c, textRequest)
if err != nil {
return nil, err
@@ -36,7 +36,7 @@ func getAndValidateTextRequest(c *gin.Context, relayMode int) (*openai.GeneralOp
return textRequest, nil
}
-func getPromptTokens(textRequest *openai.GeneralOpenAIRequest, relayMode int) int {
+func getPromptTokens(textRequest *relaymodel.GeneralOpenAIRequest, relayMode int) int {
switch relayMode {
case constant.RelayModeChatCompletions:
return openai.CountTokenMessages(textRequest.Messages, textRequest.Model)
@@ -48,7 +48,7 @@ func getPromptTokens(textRequest *openai.GeneralOpenAIRequest, relayMode int) in
return 0
}
-func getPreConsumedQuota(textRequest *openai.GeneralOpenAIRequest, promptTokens int, ratio float64) int {
+func getPreConsumedQuota(textRequest *relaymodel.GeneralOpenAIRequest, promptTokens int, ratio float64) int {
preConsumedTokens := config.PreConsumedQuota
if textRequest.MaxTokens != 0 {
preConsumedTokens = promptTokens + textRequest.MaxTokens
@@ -56,7 +56,7 @@ func getPreConsumedQuota(textRequest *openai.GeneralOpenAIRequest, promptTokens
return int(float64(preConsumedTokens) * ratio)
}
-func preConsumeQuota(ctx context.Context, textRequest *openai.GeneralOpenAIRequest, promptTokens int, ratio float64, meta *util.RelayMeta) (int, *openai.ErrorWithStatusCode) {
+func preConsumeQuota(ctx context.Context, textRequest *relaymodel.GeneralOpenAIRequest, promptTokens int, ratio float64, meta *util.RelayMeta) (int, *relaymodel.ErrorWithStatusCode) {
preConsumedQuota := getPreConsumedQuota(textRequest, promptTokens, ratio)
userQuota, err := model.CacheGetUserQuota(meta.UserId)
@@ -85,7 +85,7 @@ func preConsumeQuota(ctx context.Context, textRequest *openai.GeneralOpenAIReque
return preConsumedQuota, nil
}
-func postConsumeQuota(ctx context.Context, usage *openai.Usage, meta *util.RelayMeta, textRequest *openai.GeneralOpenAIRequest, ratio float64, preConsumedQuota int, modelRatio float64, groupRatio float64) {
+func postConsumeQuota(ctx context.Context, usage *relaymodel.Usage, meta *util.RelayMeta, textRequest *relaymodel.GeneralOpenAIRequest, ratio float64, preConsumedQuota int, modelRatio float64, groupRatio float64) {
if usage == nil {
logger.Error(ctx, "usage is nil, which is unexpected")
return
@@ -120,27 +120,3 @@ func postConsumeQuota(ctx context.Context, usage *openai.Usage, meta *util.Relay
model.UpdateChannelUsedQuota(meta.ChannelId, quota)
}
}
-
-func doRequest(ctx context.Context, c *gin.Context, meta *util.RelayMeta, isStream bool, fullRequestURL string, requestBody io.Reader) (*http.Response, error) {
- req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
- if err != nil {
- return nil, err
- }
- SetupRequestHeaders(c, req, meta, isStream)
- resp, err := util.HTTPClient.Do(req)
- if err != nil {
- return nil, err
- }
- if resp == nil {
- return nil, errors.New("resp is nil")
- }
- err = req.Body.Close()
- if err != nil {
- logger.Warnf(ctx, "close req.Body failed: %+v", err)
- }
- err = c.Request.Body.Close()
- if err != nil {
- logger.Warnf(ctx, "close c.Request.Body failed: %+v", err)
- }
- return resp, nil
-}
diff --git a/relay/controller/image.go b/relay/controller/image.go
index c64e001b..6ec368f5 100644
--- a/relay/controller/image.go
+++ b/relay/controller/image.go
@@ -10,6 +10,7 @@ import (
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/model"
"github.com/songquanpeng/one-api/relay/channel/openai"
+ relaymodel "github.com/songquanpeng/one-api/relay/model"
"github.com/songquanpeng/one-api/relay/util"
"io"
"net/http"
@@ -28,7 +29,7 @@ func isWithinRange(element string, value int) bool {
return value >= min && value <= max
}
-func RelayImageHelper(c *gin.Context, relayMode int) *openai.ErrorWithStatusCode {
+func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatusCode {
imageModel := "dall-e-2"
imageSize := "1024x1024"
diff --git a/relay/controller/temp.go b/relay/controller/temp.go
deleted file mode 100644
index 6339bdab..00000000
--- a/relay/controller/temp.go
+++ /dev/null
@@ -1,337 +0,0 @@
-package controller
-
-import (
- "bytes"
- "encoding/json"
- "errors"
- "fmt"
- "github.com/gin-gonic/gin"
- "github.com/songquanpeng/one-api/common"
- "github.com/songquanpeng/one-api/common/helper"
- "github.com/songquanpeng/one-api/relay/channel/aiproxy"
- "github.com/songquanpeng/one-api/relay/channel/ali"
- "github.com/songquanpeng/one-api/relay/channel/anthropic"
- "github.com/songquanpeng/one-api/relay/channel/baidu"
- "github.com/songquanpeng/one-api/relay/channel/google"
- "github.com/songquanpeng/one-api/relay/channel/openai"
- "github.com/songquanpeng/one-api/relay/channel/tencent"
- "github.com/songquanpeng/one-api/relay/channel/xunfei"
- "github.com/songquanpeng/one-api/relay/channel/zhipu"
- "github.com/songquanpeng/one-api/relay/constant"
- "github.com/songquanpeng/one-api/relay/util"
- "io"
- "net/http"
- "strings"
-)
-
-func GetRequestURL(requestURL string, meta *util.RelayMeta, textRequest *openai.GeneralOpenAIRequest) (string, error) {
- fullRequestURL := util.GetFullRequestURL(meta.BaseURL, requestURL, meta.ChannelType)
- switch meta.APIType {
- case constant.APITypeOpenAI:
- if meta.ChannelType == common.ChannelTypeAzure {
- // https://learn.microsoft.com/en-us/azure/cognitive-services/openai/chatgpt-quickstart?pivots=rest-api&tabs=command-line#rest-api
- requestURL := strings.Split(requestURL, "?")[0]
- requestURL = fmt.Sprintf("%s?api-version=%s", requestURL, meta.APIVersion)
- task := strings.TrimPrefix(requestURL, "/v1/")
- model_ := textRequest.Model
- model_ = strings.Replace(model_, ".", "", -1)
- // https://github.com/songquanpeng/one-api/issues/67
- model_ = strings.TrimSuffix(model_, "-0301")
- model_ = strings.TrimSuffix(model_, "-0314")
- model_ = strings.TrimSuffix(model_, "-0613")
-
- requestURL = fmt.Sprintf("/openai/deployments/%s/%s", model_, task)
- fullRequestURL = util.GetFullRequestURL(meta.BaseURL, requestURL, meta.ChannelType)
- }
- case constant.APITypeClaude:
- fullRequestURL = fmt.Sprintf("%s/v1/complete", meta.BaseURL)
- case constant.APITypeBaidu:
- switch textRequest.Model {
- case "ERNIE-Bot":
- fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions"
- case "ERNIE-Bot-turbo":
- fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant"
- case "ERNIE-Bot-4":
- fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro"
- case "BLOOMZ-7B":
- fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/bloomz_7b1"
- case "Embedding-V1":
- fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings/embedding-v1"
- }
- var accessToken string
- var err error
- if accessToken, err = baidu.GetAccessToken(meta.APIKey); err != nil {
- return "", fmt.Errorf("failed to get baidu access token: %w", err)
- }
- fullRequestURL += "?access_token=" + accessToken
- case constant.APITypePaLM:
- fullRequestURL = fmt.Sprintf("%s/v1beta2/models/chat-bison-001:generateMessage", meta.BaseURL)
- case constant.APITypeGemini:
- version := helper.AssignOrDefault(meta.APIVersion, "v1")
- action := "generateContent"
- if textRequest.Stream {
- action = "streamGenerateContent"
- }
- fullRequestURL = fmt.Sprintf("%s/%s/models/%s:%s", meta.BaseURL, version, textRequest.Model, action)
- case constant.APITypeZhipu:
- method := "invoke"
- if textRequest.Stream {
- method = "sse-invoke"
- }
- fullRequestURL = fmt.Sprintf("https://open.bigmodel.cn/api/paas/v3/model-api/%s/%s", textRequest.Model, method)
- case constant.APITypeAli:
- fullRequestURL = "https://dashscope.aliyuncs.com/api/v1/services/aigc/text-generation/generation"
- if meta.Mode == constant.RelayModeEmbeddings {
- fullRequestURL = "https://dashscope.aliyuncs.com/api/v1/services/embeddings/text-embedding/text-embedding"
- }
- case constant.APITypeTencent:
- fullRequestURL = "https://hunyuan.cloud.tencent.com/hyllm/v1/chat/completions"
- case constant.APITypeAIProxyLibrary:
- fullRequestURL = fmt.Sprintf("%s/api/library/ask", meta.BaseURL)
- }
- return fullRequestURL, nil
-}
-
-func GetRequestBody(c *gin.Context, textRequest openai.GeneralOpenAIRequest, isModelMapped bool, apiType int, relayMode int) (io.Reader, error) {
- var requestBody io.Reader
- if isModelMapped {
- jsonStr, err := json.Marshal(textRequest)
- if err != nil {
- return nil, err
- }
- requestBody = bytes.NewBuffer(jsonStr)
- } else {
- requestBody = c.Request.Body
- }
- switch apiType {
- case constant.APITypeClaude:
- claudeRequest := anthropic.ConvertRequest(textRequest)
- jsonStr, err := json.Marshal(claudeRequest)
- if err != nil {
- return nil, err
- }
- requestBody = bytes.NewBuffer(jsonStr)
- case constant.APITypeBaidu:
- var jsonData []byte
- var err error
- switch relayMode {
- case constant.RelayModeEmbeddings:
- baiduEmbeddingRequest := baidu.ConvertEmbeddingRequest(textRequest)
- jsonData, err = json.Marshal(baiduEmbeddingRequest)
- default:
- baiduRequest := baidu.ConvertRequest(textRequest)
- jsonData, err = json.Marshal(baiduRequest)
- }
- if err != nil {
- return nil, err
- }
- requestBody = bytes.NewBuffer(jsonData)
- case constant.APITypePaLM:
- palmRequest := google.ConvertPaLMRequest(textRequest)
- jsonStr, err := json.Marshal(palmRequest)
- if err != nil {
- return nil, err
- }
- requestBody = bytes.NewBuffer(jsonStr)
- case constant.APITypeGemini:
- geminiChatRequest := google.ConvertGeminiRequest(textRequest)
- jsonStr, err := json.Marshal(geminiChatRequest)
- if err != nil {
- return nil, err
- }
- requestBody = bytes.NewBuffer(jsonStr)
- case constant.APITypeZhipu:
- zhipuRequest := zhipu.ConvertRequest(textRequest)
- jsonStr, err := json.Marshal(zhipuRequest)
- if err != nil {
- return nil, err
- }
- requestBody = bytes.NewBuffer(jsonStr)
- case constant.APITypeAli:
- var jsonStr []byte
- var err error
- switch relayMode {
- case constant.RelayModeEmbeddings:
- aliEmbeddingRequest := ali.ConvertEmbeddingRequest(textRequest)
- jsonStr, err = json.Marshal(aliEmbeddingRequest)
- default:
- aliRequest := ali.ConvertRequest(textRequest)
- jsonStr, err = json.Marshal(aliRequest)
- }
- if err != nil {
- return nil, err
- }
- requestBody = bytes.NewBuffer(jsonStr)
- case constant.APITypeTencent:
- apiKey := c.Request.Header.Get("Authorization")
- apiKey = strings.TrimPrefix(apiKey, "Bearer ")
- appId, secretId, secretKey, err := tencent.ParseConfig(apiKey)
- if err != nil {
- return nil, err
- }
- tencentRequest := tencent.ConvertRequest(textRequest)
- tencentRequest.AppId = appId
- tencentRequest.SecretId = secretId
- jsonStr, err := json.Marshal(tencentRequest)
- if err != nil {
- return nil, err
- }
- sign := tencent.GetSign(*tencentRequest, secretKey)
- c.Request.Header.Set("Authorization", sign)
- requestBody = bytes.NewBuffer(jsonStr)
- case constant.APITypeAIProxyLibrary:
- aiProxyLibraryRequest := aiproxy.ConvertRequest(textRequest)
- aiProxyLibraryRequest.LibraryId = c.GetString("library_id")
- jsonStr, err := json.Marshal(aiProxyLibraryRequest)
- if err != nil {
- return nil, err
- }
- requestBody = bytes.NewBuffer(jsonStr)
- }
- return requestBody, nil
-}
-
-func SetupRequestHeaders(c *gin.Context, req *http.Request, meta *util.RelayMeta, isStream bool) {
- SetupAuthHeaders(c, req, meta, isStream)
- req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
- req.Header.Set("Accept", c.Request.Header.Get("Accept"))
- if isStream && c.Request.Header.Get("Accept") == "" {
- req.Header.Set("Accept", "text/event-stream")
- }
-}
-
-func SetupAuthHeaders(c *gin.Context, req *http.Request, meta *util.RelayMeta, isStream bool) {
- apiKey := meta.APIKey
- switch meta.APIType {
- case constant.APITypeOpenAI:
- if meta.ChannelType == common.ChannelTypeAzure {
- req.Header.Set("api-key", apiKey)
- } else {
- req.Header.Set("Authorization", c.Request.Header.Get("Authorization"))
- if meta.ChannelType == common.ChannelTypeOpenRouter {
- req.Header.Set("HTTP-Referer", "https://github.com/songquanpeng/one-api")
- req.Header.Set("X-Title", "One API")
- }
- }
- case constant.APITypeClaude:
- req.Header.Set("x-api-key", apiKey)
- anthropicVersion := c.Request.Header.Get("anthropic-version")
- if anthropicVersion == "" {
- anthropicVersion = "2023-06-01"
- }
- req.Header.Set("anthropic-version", anthropicVersion)
- case constant.APITypeZhipu:
- token := zhipu.GetToken(apiKey)
- req.Header.Set("Authorization", token)
- case constant.APITypeAli:
- req.Header.Set("Authorization", "Bearer "+apiKey)
- if isStream {
- req.Header.Set("X-DashScope-SSE", "enable")
- }
- if c.GetString("plugin") != "" {
- req.Header.Set("X-DashScope-Plugin", c.GetString("plugin"))
- }
- case constant.APITypeTencent:
- req.Header.Set("Authorization", apiKey)
- case constant.APITypePaLM:
- req.Header.Set("x-goog-api-key", apiKey)
- case constant.APITypeGemini:
- req.Header.Set("x-goog-api-key", apiKey)
- default:
- req.Header.Set("Authorization", "Bearer "+apiKey)
- }
-}
-
-func DoResponse(c *gin.Context, textRequest *openai.GeneralOpenAIRequest, resp *http.Response, relayMode int, apiType int, isStream bool, promptTokens int) (usage *openai.Usage, err *openai.ErrorWithStatusCode) {
- var responseText string
- switch apiType {
- case constant.APITypeOpenAI:
- if isStream {
- err, responseText = openai.StreamHandler(c, resp, relayMode)
- } else {
- err, usage = openai.Handler(c, resp, promptTokens, textRequest.Model)
- }
- case constant.APITypeClaude:
- if isStream {
- err, responseText = anthropic.StreamHandler(c, resp)
- } else {
- err, usage = anthropic.Handler(c, resp, promptTokens, textRequest.Model)
- }
- case constant.APITypeBaidu:
- if isStream {
- err, usage = baidu.StreamHandler(c, resp)
- } else {
- switch relayMode {
- case constant.RelayModeEmbeddings:
- err, usage = baidu.EmbeddingHandler(c, resp)
- default:
- err, usage = baidu.Handler(c, resp)
- }
- }
- case constant.APITypePaLM:
- if isStream { // PaLM2 API does not support stream
- err, responseText = google.PaLMStreamHandler(c, resp)
- } else {
- err, usage = google.PaLMHandler(c, resp, promptTokens, textRequest.Model)
- }
- case constant.APITypeGemini:
- if isStream {
- err, responseText = google.StreamHandler(c, resp)
- } else {
- err, usage = google.GeminiHandler(c, resp, promptTokens, textRequest.Model)
- }
- case constant.APITypeZhipu:
- if isStream {
- err, usage = zhipu.StreamHandler(c, resp)
- } else {
- err, usage = zhipu.Handler(c, resp)
- }
- case constant.APITypeAli:
- if isStream {
- err, usage = ali.StreamHandler(c, resp)
- } else {
- switch relayMode {
- case constant.RelayModeEmbeddings:
- err, usage = ali.EmbeddingHandler(c, resp)
- default:
- err, usage = ali.Handler(c, resp)
- }
- }
- case constant.APITypeXunfei:
- auth := c.Request.Header.Get("Authorization")
- auth = strings.TrimPrefix(auth, "Bearer ")
- splits := strings.Split(auth, "|")
- if len(splits) != 3 {
- return nil, openai.ErrorWrapper(errors.New("invalid auth"), "invalid_auth", http.StatusBadRequest)
- }
- if isStream {
- err, usage = xunfei.StreamHandler(c, *textRequest, splits[0], splits[1], splits[2])
- } else {
- err, usage = xunfei.Handler(c, *textRequest, splits[0], splits[1], splits[2])
- }
- case constant.APITypeAIProxyLibrary:
- if isStream {
- err, usage = aiproxy.StreamHandler(c, resp)
- } else {
- err, usage = aiproxy.Handler(c, resp)
- }
- case constant.APITypeTencent:
- if isStream {
- err, responseText = tencent.StreamHandler(c, resp)
- } else {
- err, usage = tencent.Handler(c, resp)
- }
- default:
- return nil, openai.ErrorWrapper(errors.New("unknown api type"), "unknown_api_type", http.StatusInternalServerError)
- }
- if err != nil {
- return nil, err
- }
- if usage == nil && responseText != "" {
- usage = &openai.Usage{}
- usage.PromptTokens = promptTokens
- usage.CompletionTokens = openai.CountTokenText(responseText, textRequest.Model)
- usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
- }
- return usage, nil
-}
diff --git a/relay/controller/text.go b/relay/controller/text.go
index 0445aa90..cc460511 100644
--- a/relay/controller/text.go
+++ b/relay/controller/text.go
@@ -1,18 +1,23 @@
package controller
import (
+ "bytes"
+ "encoding/json"
"fmt"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/relay/channel/openai"
"github.com/songquanpeng/one-api/relay/constant"
+ "github.com/songquanpeng/one-api/relay/helper"
+ "github.com/songquanpeng/one-api/relay/model"
"github.com/songquanpeng/one-api/relay/util"
+ "io"
"net/http"
"strings"
)
-func RelayTextHelper(c *gin.Context) *openai.ErrorWithStatusCode {
+func RelayTextHelper(c *gin.Context) *model.ErrorWithStatusCode {
ctx := c.Request.Context()
meta := util.GetRelayMeta(c)
// get & validate textRequest
@@ -21,50 +26,70 @@ func RelayTextHelper(c *gin.Context) *openai.ErrorWithStatusCode {
logger.Errorf(ctx, "getAndValidateTextRequest failed: %s", err.Error())
return openai.ErrorWrapper(err, "invalid_text_request", http.StatusBadRequest)
}
+ meta.IsStream = textRequest.Stream
+
// map model name
var isModelMapped bool
+ meta.OriginModelName = textRequest.Model
textRequest.Model, isModelMapped = util.GetMappedModelName(textRequest.Model, meta.ModelMapping)
+ meta.ActualModelName = textRequest.Model
// get model ratio & group ratio
modelRatio := common.GetModelRatio(textRequest.Model)
groupRatio := common.GetGroupRatio(meta.Group)
ratio := modelRatio * groupRatio
// pre-consume quota
promptTokens := getPromptTokens(textRequest, meta.Mode)
+ meta.PromptTokens = promptTokens
preConsumedQuota, bizErr := preConsumeQuota(ctx, textRequest, promptTokens, ratio, meta)
if bizErr != nil {
logger.Warnf(ctx, "preConsumeQuota failed: %+v", *bizErr)
return bizErr
}
+ adaptor := helper.GetAdaptor(meta.APIType)
+ if adaptor == nil {
+ return openai.ErrorWrapper(fmt.Errorf("invalid api type: %d", meta.APIType), "invalid_api_type", http.StatusBadRequest)
+ }
+
// get request body
- requestBody, err := GetRequestBody(c, *textRequest, isModelMapped, meta.APIType, meta.Mode)
- if err != nil {
- return openai.ErrorWrapper(err, "get_request_body_failed", http.StatusInternalServerError)
+ var requestBody io.Reader
+ if meta.APIType == constant.APITypeOpenAI {
+ // no need to convert request for openai
+ if isModelMapped {
+ jsonStr, err := json.Marshal(textRequest)
+ if err != nil {
+ return openai.ErrorWrapper(err, "json_marshal_failed", http.StatusInternalServerError)
+ }
+ requestBody = bytes.NewBuffer(jsonStr)
+ } else {
+ requestBody = c.Request.Body
+ }
+ } else {
+ convertedRequest, err := adaptor.ConvertRequest(c, meta.Mode, textRequest)
+ if err != nil {
+ return openai.ErrorWrapper(err, "convert_request_failed", http.StatusInternalServerError)
+ }
+ jsonData, err := json.Marshal(convertedRequest)
+ if err != nil {
+ return openai.ErrorWrapper(err, "json_marshal_failed", http.StatusInternalServerError)
+ }
+ requestBody = bytes.NewBuffer(jsonData)
}
+
// do request
- var resp *http.Response
- isStream := textRequest.Stream
- if meta.APIType != constant.APITypeXunfei { // cause xunfei use websocket
- fullRequestURL, err := GetRequestURL(c.Request.URL.String(), meta, textRequest)
- if err != nil {
- logger.Error(ctx, fmt.Sprintf("util.GetRequestURL failed: %s", err.Error()))
- return openai.ErrorWrapper(fmt.Errorf("util.GetRequestURL failed"), "get_request_url_failed", http.StatusInternalServerError)
- }
-
- resp, err = doRequest(ctx, c, meta, isStream, fullRequestURL, requestBody)
- if err != nil {
- logger.Errorf(ctx, "doRequest failed: %s", err.Error())
- return openai.ErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
- }
- isStream = isStream || strings.HasPrefix(resp.Header.Get("Content-Type"), "text/event-stream")
-
- if resp.StatusCode != http.StatusOK {
- util.ReturnPreConsumedQuota(ctx, preConsumedQuota, meta.TokenId)
- return util.RelayErrorHandler(resp)
- }
+ resp, err := adaptor.DoRequest(c, meta, requestBody)
+ if err != nil {
+ logger.Errorf(ctx, "DoRequest failed: %s", err.Error())
+ return openai.ErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
}
+ meta.IsStream = meta.IsStream || strings.HasPrefix(resp.Header.Get("Content-Type"), "text/event-stream")
+ if resp.StatusCode != http.StatusOK {
+ util.ReturnPreConsumedQuota(ctx, preConsumedQuota, meta.TokenId)
+ return util.RelayErrorHandler(resp)
+ }
+
// do response
- usage, respErr := DoResponse(c, textRequest, resp, meta.Mode, meta.APIType, isStream, promptTokens)
+ usage, respErr := adaptor.DoResponse(c, resp, meta)
if respErr != nil {
logger.Errorf(ctx, "respErr is not nil: %+v", respErr)
util.ReturnPreConsumedQuota(ctx, preConsumedQuota, meta.TokenId)
diff --git a/relay/helper/main.go b/relay/helper/main.go
new file mode 100644
index 00000000..c2b6e6af
--- /dev/null
+++ b/relay/helper/main.go
@@ -0,0 +1,42 @@
+package helper
+
+import (
+ "github.com/songquanpeng/one-api/relay/channel"
+ "github.com/songquanpeng/one-api/relay/channel/aiproxy"
+ "github.com/songquanpeng/one-api/relay/channel/ali"
+ "github.com/songquanpeng/one-api/relay/channel/anthropic"
+ "github.com/songquanpeng/one-api/relay/channel/baidu"
+ "github.com/songquanpeng/one-api/relay/channel/gemini"
+ "github.com/songquanpeng/one-api/relay/channel/openai"
+ "github.com/songquanpeng/one-api/relay/channel/palm"
+ "github.com/songquanpeng/one-api/relay/channel/tencent"
+ "github.com/songquanpeng/one-api/relay/channel/xunfei"
+ "github.com/songquanpeng/one-api/relay/channel/zhipu"
+ "github.com/songquanpeng/one-api/relay/constant"
+)
+
+func GetAdaptor(apiType int) channel.Adaptor {
+ switch apiType {
+ case constant.APITypeAIProxyLibrary:
+ return &aiproxy.Adaptor{}
+ case constant.APITypeAli:
+ return &ali.Adaptor{}
+ case constant.APITypeAnthropic:
+ return &anthropic.Adaptor{}
+ case constant.APITypeBaidu:
+ return &baidu.Adaptor{}
+ case constant.APITypeGemini:
+ return &gemini.Adaptor{}
+ case constant.APITypeOpenAI:
+ return &openai.Adaptor{}
+ case constant.APITypePaLM:
+ return &palm.Adaptor{}
+ case constant.APITypeTencent:
+ return &tencent.Adaptor{}
+ case constant.APITypeXunfei:
+ return &xunfei.Adaptor{}
+ case constant.APITypeZhipu:
+ return &zhipu.Adaptor{}
+ }
+ return nil
+}
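
// Sketch of the channel.Adaptor interface this factory returns, reconstructed
// from the method sets the tencent/xunfei/zhipu adaptors implement in this
// diff; the real definition lives in relay/channel and is not part of this
// excerpt.
package channelsketch

import (
	"io"
	"net/http"

	"github.com/gin-gonic/gin"
	"github.com/songquanpeng/one-api/relay/model"
	"github.com/songquanpeng/one-api/relay/util"
)

type Adaptor interface {
	Init(meta *util.RelayMeta)
	GetRequestURL(meta *util.RelayMeta) (string, error)
	SetupRequestHeader(c *gin.Context, req *http.Request, meta *util.RelayMeta) error
	ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error)
	DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error)
	DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode)
	GetModelList() []string
	GetChannelName() string
}
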
diff --git a/relay/channel/openai/constant.go b/relay/model/constant.go
similarity index 83%
rename from relay/channel/openai/constant.go
rename to relay/model/constant.go
index 000f72ee..f6cf1924 100644
--- a/relay/channel/openai/constant.go
+++ b/relay/model/constant.go
@@ -1,4 +1,4 @@
-package openai
+package model
const (
ContentTypeText = "text"
diff --git a/relay/model/general.go b/relay/model/general.go
new file mode 100644
index 00000000..fbcc04e8
--- /dev/null
+++ b/relay/model/general.go
@@ -0,0 +1,46 @@
+package model
+
+type ResponseFormat struct {
+ Type string `json:"type,omitempty"`
+}
+
+type GeneralOpenAIRequest struct {
+ Model string `json:"model,omitempty"`
+ Messages []Message `json:"messages,omitempty"`
+ Prompt any `json:"prompt,omitempty"`
+ Stream bool `json:"stream,omitempty"`
+ MaxTokens int `json:"max_tokens,omitempty"`
+ Temperature float64 `json:"temperature,omitempty"`
+ TopP float64 `json:"top_p,omitempty"`
+ N int `json:"n,omitempty"`
+ Input any `json:"input,omitempty"`
+ Instruction string `json:"instruction,omitempty"`
+ Size string `json:"size,omitempty"`
+ Functions any `json:"functions,omitempty"`
+ FrequencyPenalty float64 `json:"frequency_penalty,omitempty"`
+ PresencePenalty float64 `json:"presence_penalty,omitempty"`
+ ResponseFormat *ResponseFormat `json:"response_format,omitempty"`
+ Seed float64 `json:"seed,omitempty"`
+ Tools any `json:"tools,omitempty"`
+ ToolChoice any `json:"tool_choice,omitempty"`
+ User string `json:"user,omitempty"`
+}
+
+func (r GeneralOpenAIRequest) ParseInput() []string {
+ if r.Input == nil {
+ return nil
+ }
+ var input []string
+ switch v := r.Input.(type) {
+ case string:
+ input = []string{v}
+ case []any:
+ input = make([]string, 0, len(v))
+ for _, item := range v {
+ if str, ok := item.(string); ok {
+ input = append(input, str)
+ }
+ }
+ }
+ return input
+}
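
// Sketch: what ParseInput above normalizes — OpenAI embedding clients may
// send "input" either as a single string or as an array of strings.
package main

import (
	"encoding/json"
	"fmt"

	"github.com/songquanpeng/one-api/relay/model"
)

func main() {
	var single, multi model.GeneralOpenAIRequest
	_ = json.Unmarshal([]byte(`{"input": "hello"}`), &single)
	_ = json.Unmarshal([]byte(`{"input": ["hello", "world"]}`), &multi)
	fmt.Println(single.ParseInput()) // [hello]
	fmt.Println(multi.ParseInput())  // [hello world]
}
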
diff --git a/relay/model/message.go b/relay/model/message.go
new file mode 100644
index 00000000..c6c8a271
--- /dev/null
+++ b/relay/model/message.go
@@ -0,0 +1,88 @@
+package model
+
+type Message struct {
+ Role string `json:"role"`
+ Content any `json:"content"`
+ Name *string `json:"name,omitempty"`
+}
+
+func (m Message) IsStringContent() bool {
+ _, ok := m.Content.(string)
+ return ok
+}
+
+func (m Message) StringContent() string {
+ content, ok := m.Content.(string)
+ if ok {
+ return content
+ }
+ contentList, ok := m.Content.([]any)
+ if ok {
+ var contentStr string
+ for _, contentItem := range contentList {
+ contentMap, ok := contentItem.(map[string]any)
+ if !ok {
+ continue
+ }
+ if contentMap["type"] == ContentTypeText {
+ if subStr, ok := contentMap["text"].(string); ok {
+ contentStr += subStr
+ }
+ }
+ }
+ return contentStr
+ }
+ return ""
+}
+
+func (m Message) ParseContent() []MessageContent {
+ var contentList []MessageContent
+ content, ok := m.Content.(string)
+ if ok {
+ contentList = append(contentList, MessageContent{
+ Type: ContentTypeText,
+ Text: content,
+ })
+ return contentList
+ }
+ anyList, ok := m.Content.([]any)
+ if ok {
+ for _, contentItem := range anyList {
+ contentMap, ok := contentItem.(map[string]any)
+ if !ok {
+ continue
+ }
+ switch contentMap["type"] {
+ case ContentTypeText:
+ if subStr, ok := contentMap["text"].(string); ok {
+ contentList = append(contentList, MessageContent{
+ Type: ContentTypeText,
+ Text: subStr,
+ })
+ }
+ case ContentTypeImageURL:
+ if subObj, ok := contentMap["image_url"].(map[string]any); ok {
+ // checked assertion so a missing or non-string url cannot panic
+ url, _ := subObj["url"].(string)
+ contentList = append(contentList, MessageContent{
+ Type: ContentTypeImageURL,
+ ImageURL: &ImageURL{
+ Url: url,
+ },
+ })
+ }
+ }
+ }
+ return contentList
+ }
+ return nil
+}
+
+type ImageURL struct {
+ Url string `json:"url,omitempty"`
+ Detail string `json:"detail,omitempty"`
+}
+
+type MessageContent struct {
+ Type string `json:"type,omitempty"`
+ Text string `json:"text"`
+ ImageURL *ImageURL `json:"image_url,omitempty"`
+}
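
// Sketch: the two content shapes Message above accepts — a plain string, or
// the OpenAI vision-style array of typed parts — and what the helpers recover
// from the latter; the image URL is a made-up example.
package main

import (
	"encoding/json"
	"fmt"

	"github.com/songquanpeng/one-api/relay/model"
)

func main() {
	var m model.Message
	_ = json.Unmarshal([]byte(`{
		"role": "user",
		"content": [
			{"type": "text", "text": "what is in this image?"},
			{"type": "image_url", "image_url": {"url": "https://example.com/cat.png"}}
		]
	}`), &m)
	fmt.Println(m.IsStringContent()) // false
	fmt.Println(m.StringContent())   // what is in this image? (text parts only)
	for _, part := range m.ParseContent() {
		fmt.Println(part.Type) // text, then image_url
	}
}
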
diff --git a/relay/model/misc.go b/relay/model/misc.go
new file mode 100644
index 00000000..163bc398
--- /dev/null
+++ b/relay/model/misc.go
@@ -0,0 +1,19 @@
+package model
+
+type Usage struct {
+ PromptTokens int `json:"prompt_tokens"`
+ CompletionTokens int `json:"completion_tokens"`
+ TotalTokens int `json:"total_tokens"`
+}
+
+type Error struct {
+ Message string `json:"message"`
+ Type string `json:"type"`
+ Param string `json:"param"`
+ Code any `json:"code"`
+}
+
+type ErrorWithStatusCode struct {
+ Error
+ StatusCode int `json:"status_code"`
+}
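
// Sketch of what openai.ErrorWrapper (used throughout this diff) plausibly
// looks like once these types live in relay/model: wrap a Go error into an
// ErrorWithStatusCode carrying a machine-readable code. Reconstructed for
// illustration; the real helper stays in relay/channel/openai, and the Type
// string here is an assumption.
package model

func errorWrapperSketch(err error, code string, statusCode int) *ErrorWithStatusCode {
	return &ErrorWithStatusCode{
		Error: Error{
			Message: err.Error(),
			Type:    "one_api_error",
			Code:    code,
		},
		StatusCode: statusCode,
	}
}
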
diff --git a/relay/util/common.go b/relay/util/common.go
index 3a28b09e..6d993378 100644
--- a/relay/util/common.go
+++ b/relay/util/common.go
@@ -8,7 +8,7 @@ import (
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/model"
- "github.com/songquanpeng/one-api/relay/channel/openai"
+ relaymodel "github.com/songquanpeng/one-api/relay/model"
"io"
"net/http"
"strconv"
@@ -17,7 +17,7 @@ import (
"github.com/gin-gonic/gin"
)
-func ShouldDisableChannel(err *openai.Error, statusCode int) bool {
+func ShouldDisableChannel(err *relaymodel.Error, statusCode int) bool {
if !config.AutomaticDisableChannelEnabled {
return false
}
@@ -33,7 +33,7 @@ func ShouldDisableChannel(err *openai.Error, statusCode int) bool {
return false
}
-func ShouldEnableChannel(err error, openAIErr *openai.Error) bool {
+func ShouldEnableChannel(err error, openAIErr *relaymodel.Error) bool {
if !config.AutomaticEnableChannelEnabled {
return false
}
@@ -47,11 +47,11 @@ func ShouldEnableChannel(err error, openAIErr *openai.Error) bool {
}
type GeneralErrorResponse struct {
- Error openai.Error `json:"error"`
- Message string `json:"message"`
- Msg string `json:"msg"`
- Err string `json:"err"`
- ErrorMsg string `json:"error_msg"`
+ Error relaymodel.Error `json:"error"`
+ Message string `json:"message"`
+ Msg string `json:"msg"`
+ Err string `json:"err"`
+ ErrorMsg string `json:"error_msg"`
Header struct {
Message string `json:"message"`
} `json:"header"`
@@ -87,10 +87,10 @@ func (e GeneralErrorResponse) ToMessage() string {
return ""
}
-func RelayErrorHandler(resp *http.Response) (ErrorWithStatusCode *openai.ErrorWithStatusCode) {
- ErrorWithStatusCode = &openai.ErrorWithStatusCode{
+func RelayErrorHandler(resp *http.Response) (ErrorWithStatusCode *relaymodel.ErrorWithStatusCode) {
+ ErrorWithStatusCode = &relaymodel.ErrorWithStatusCode{
StatusCode: resp.StatusCode,
- Error: openai.Error{
+ Error: relaymodel.Error{
Message: "",
Type: "upstream_error",
Code: "bad_response_status_code",
@@ -162,7 +162,7 @@ func GetAzureAPIVersion(c *gin.Context) string {
query := c.Request.URL.Query()
apiVersion := query.Get("api-version")
if apiVersion == "" {
- apiVersion = c.GetString("api_version")
+ apiVersion = c.GetString(common.ConfigKeyAPIVersion)
}
return apiVersion
}
diff --git a/relay/util/relay_meta.go b/relay/util/relay_meta.go
index 27757dcf..31b9d2b4 100644
--- a/relay/util/relay_meta.go
+++ b/relay/util/relay_meta.go
@@ -8,35 +8,41 @@ import (
)
type RelayMeta struct {
- Mode int
- ChannelType int
- ChannelId int
- TokenId int
- TokenName string
- UserId int
- Group string
- ModelMapping map[string]string
- BaseURL string
- APIVersion string
- APIKey string
- APIType int
- Config map[string]string
+ Mode int
+ ChannelType int
+ ChannelId int
+ TokenId int
+ TokenName string
+ UserId int
+ Group string
+ ModelMapping map[string]string
+ BaseURL string
+ APIVersion string
+ APIKey string
+ APIType int
+ Config map[string]string
+ IsStream bool
+ OriginModelName string
+ ActualModelName string
+ RequestURLPath string
+ PromptTokens int // only for DoResponse
}
func GetRelayMeta(c *gin.Context) *RelayMeta {
meta := RelayMeta{
- Mode: constant.Path2RelayMode(c.Request.URL.Path),
- ChannelType: c.GetInt("channel"),
- ChannelId: c.GetInt("channel_id"),
- TokenId: c.GetInt("token_id"),
- TokenName: c.GetString("token_name"),
- UserId: c.GetInt("id"),
- Group: c.GetString("group"),
- ModelMapping: c.GetStringMapString("model_mapping"),
- BaseURL: c.GetString("base_url"),
- APIVersion: c.GetString("api_version"),
- APIKey: strings.TrimPrefix(c.Request.Header.Get("Authorization"), "Bearer "),
- Config: nil,
+ Mode: constant.Path2RelayMode(c.Request.URL.Path),
+ ChannelType: c.GetInt("channel"),
+ ChannelId: c.GetInt("channel_id"),
+ TokenId: c.GetInt("token_id"),
+ TokenName: c.GetString("token_name"),
+ UserId: c.GetInt("id"),
+ Group: c.GetString("group"),
+ ModelMapping: c.GetStringMapString("model_mapping"),
+ BaseURL: c.GetString("base_url"),
+ APIVersion: c.GetString(common.ConfigKeyAPIVersion),
+ APIKey: strings.TrimPrefix(c.Request.Header.Get("Authorization"), "Bearer "),
+ Config: nil,
+ RequestURLPath: c.Request.URL.String(),
}
if meta.ChannelType == common.ChannelTypeAzure {
meta.APIVersion = GetAzureAPIVersion(c)
diff --git a/relay/util/validation.go b/relay/util/validation.go
index 8848af8e..ef8d840c 100644
--- a/relay/util/validation.go
+++ b/relay/util/validation.go
@@ -2,12 +2,12 @@ package util
import (
"errors"
- "github.com/songquanpeng/one-api/relay/channel/openai"
"github.com/songquanpeng/one-api/relay/constant"
+ "github.com/songquanpeng/one-api/relay/model"
"math"
)
-func ValidateTextRequest(textRequest *openai.GeneralOpenAIRequest, relayMode int) error {
+func ValidateTextRequest(textRequest *model.GeneralOpenAIRequest, relayMode int) error {
if textRequest.MaxTokens < 0 || textRequest.MaxTokens > math.MaxInt32/2 {
return errors.New("max_tokens is invalid")
}
diff --git a/web/THEMES b/web/THEMES
index b6597eeb..6b0157cb 100644
--- a/web/THEMES
+++ b/web/THEMES
@@ -1,2 +1,2 @@
default
-berry
\ No newline at end of file
+berry
diff --git a/web/berry/src/constants/ChannelConstants.js b/web/berry/src/constants/ChannelConstants.js
index 3ce27838..aeff5190 100644
--- a/web/berry/src/constants/ChannelConstants.js
+++ b/web/berry/src/constants/ChannelConstants.js
@@ -59,6 +59,12 @@ export const CHANNEL_OPTIONS = {
value: 19,
color: 'default'
},
+ 25: {
+ key: 25,
+ text: 'Moonshot AI',
+ value: 25,
+ color: 'default'
+ },
23: {
key: 23,
text: '腾讯混元',
diff --git a/web/berry/src/views/Channel/index.js b/web/berry/src/views/Channel/index.js
index 5b7f1722..39ab5d82 100644
--- a/web/berry/src/views/Channel/index.js
+++ b/web/berry/src/views/Channel/index.js
@@ -202,9 +202,7 @@ export default function ChannelPage() {
- 当前版本测试是通过按照 OpenAI API 格式使用 gpt-3.5-turbo
- 模型进行非流式请求实现的,因此测试报错并不一定代表通道不可用,该功能后续会修复。 另外,OpenAI 渠道已经不再支持通过 key
- 获取余额,因此余额显示为 0。对于支持的渠道类型,请点击余额进行刷新。
+ OpenAI 渠道已经不再支持通过 key 获取余额,因此余额显示为 0。对于支持的渠道类型,请点击余额进行刷新。
@@ -229,9 +227,9 @@ export default function ChannelPage() {
}>
测试启用渠道
- }>
- 更新启用余额
-
+ {/*}>*/}
+ {/* 更新启用余额*/}
+ {/**/}
}>
删除禁用渠道
diff --git a/web/berry/src/views/Channel/type/Config.js b/web/berry/src/views/Channel/type/Config.js
index d270f527..a091c8d6 100644
--- a/web/berry/src/views/Channel/type/Config.js
+++ b/web/berry/src/views/Channel/type/Config.js
@@ -94,7 +94,13 @@ const typeConfig = {
other: "版本号",
},
input: {
- models: ["SparkDesk"],
+ models: [
+ "SparkDesk",
+ 'SparkDesk-v1.1',
+ 'SparkDesk-v2.1',
+ 'SparkDesk-v3.1',
+ 'SparkDesk-v3.5'
+ ],
},
prompt: {
key: "按照如下格式输入:APPID|APISecret|APIKey",
diff --git a/web/build.sh b/web/build.sh
index b3751ff4..b59babe4 100644
--- a/web/build.sh
+++ b/web/build.sh
@@ -1,13 +1,13 @@
#!/bin/sh
version=$(cat VERSION)
-themes=$(cat THEMES)
-IFS=$'\n'
+pwd
-for theme in $themes; do
+while IFS= read -r theme; do
echo "Building theme: $theme"
- cd $theme
+ rm -rf "build/$theme"
+ cd "$theme"
npm install
DISABLE_ESLINT_PLUGIN='true' REACT_APP_VERSION=$version npm run build
cd ..
-done
+done < THEMES
diff --git a/web/default/src/components/ChannelsTable.js b/web/default/src/components/ChannelsTable.js
index a2adfd32..7117fe53 100644
--- a/web/default/src/components/ChannelsTable.js
+++ b/web/default/src/components/ChannelsTable.js
@@ -322,10 +322,7 @@ const ChannelsTable = () => {
setShowPrompt(false);
setPromptShown("channel-test");
}}>
- 当前版本测试是通过按照 OpenAI API 格式使用 gpt-3.5-turbo
- 模型进行非流式请求实现的,因此测试报错并不一定代表通道不可用,该功能后续会修复。
-
- 另外,OpenAI 渠道已经不再支持通过 key 获取余额,因此余额显示为 0。对于支持的渠道类型,请点击余额进行刷新。
+ OpenAI 渠道已经不再支持通过 key 获取余额,因此余额显示为 0。对于支持的渠道类型,请点击余额进行刷新。
)
}
@@ -525,8 +522,8 @@ const ChannelsTable = () => {
-
+ {/**/}
diff --git a/web/default/src/constants/channel.constants.js b/web/default/src/constants/channel.constants.js
index 264dbefb..16da1b97 100644
--- a/web/default/src/constants/channel.constants.js
+++ b/web/default/src/constants/channel.constants.js
@@ -9,6 +9,7 @@ export const CHANNEL_OPTIONS = [
{ key: 18, text: '讯飞星火认知', value: 18, color: 'blue' },
{ key: 16, text: '智谱 ChatGLM', value: 16, color: 'violet' },
{ key: 19, text: '360 智脑', value: 19, color: 'blue' },
+ { key: 25, text: 'Moonshot AI', value: 25, color: 'black' },
{ key: 23, text: '腾讯混元', value: 23, color: 'teal' },
{ key: 8, text: '自定义渠道', value: 8, color: 'pink' },
{ key: 22, text: '知识库:FastGPT', value: 22, color: 'blue' },
diff --git a/web/default/src/pages/Channel/EditChannel.js b/web/default/src/pages/Channel/EditChannel.js
index 0d4e114d..4f4633ff 100644
--- a/web/default/src/pages/Channel/EditChannel.js
+++ b/web/default/src/pages/Channel/EditChannel.js
@@ -82,7 +82,13 @@ const EditChannel = () => {
localModels = ['chatglm_turbo', 'chatglm_pro', 'chatglm_std', 'chatglm_lite'];
break;
case 18:
- localModels = ['SparkDesk'];
+ localModels = [
+ 'SparkDesk',
+ 'SparkDesk-v1.1',
+ 'SparkDesk-v2.1',
+ 'SparkDesk-v3.1',
+ 'SparkDesk-v3.5'
+ ];
break;
case 19:
localModels = ['360GPT_S2_V9', 'embedding-bert-512-v1', 'embedding_s1_v1', 'semantic_similarity_s1_v1'];
@@ -93,6 +99,9 @@ const EditChannel = () => {
case 24:
localModels = ['gemini-pro', 'gemini-pro-vision'];
break;
+ case 25:
+ localModels = ['moonshot-v1-8k', 'moonshot-v1-32k', 'moonshot-v1-128k'];
+ break;
}
setInputs((inputs) => ({ ...inputs, models: localModels }));
}