From 9026ec7510f6dd7500e55567ce9de5f28f71b259 Mon Sep 17 00:00:00 2001 From: JustSong Date: Fri, 26 Apr 2024 23:05:48 +0800 Subject: [PATCH] feat: support cloudflare now --- README.md | 1 + common/ctxkey/config.go | 13 -- common/ctxkey/key.go | 1 + common/helper/helper.go | 6 + common/helper/key.go | 5 + common/logger/constants.go | 4 - common/logger/logger.go | 2 +- controller/channel-test.go | 2 + controller/relay.go | 2 +- middleware/distributor.go | 26 +-- middleware/logger.go | 4 +- middleware/request-id.go | 7 +- middleware/utils.go | 2 +- model/channel.go | 18 ++- relay/adaptor.go | 3 + relay/adaptor/aiproxy/adaptor.go | 6 +- relay/adaptor/ali/adaptor.go | 8 +- relay/adaptor/aws/adapter.go | 16 +- relay/adaptor/aws/main.go | 27 +--- relay/adaptor/azure/helper.go | 15 -- relay/adaptor/cloudflare/adaptor.go | 66 ++++++++ relay/adaptor/cloudflare/constant.go | 36 +++++ relay/adaptor/cloudflare/main.go | 152 ++++++++++++++++++ relay/adaptor/cloudflare/model.go | 25 +++ relay/adaptor/coze/adaptor.go | 6 +- relay/adaptor/gemini/adaptor.go | 2 +- relay/adaptor/openai/adaptor.go | 4 +- relay/adaptor/openai/model.go | 2 +- relay/adaptor/xunfei/adaptor.go | 15 +- relay/adaptor/xunfei/main.go | 32 ++-- relay/apitype/define.go | 1 + relay/channeltype/define.go | 1 + relay/channeltype/helper.go | 2 + relay/channeltype/url.go | 1 + relay/controller/audio.go | 5 +- relay/controller/image.go | 1 + relay/controller/text.go | 1 + relay/meta/relay_meta.go | 35 ++-- .../src/constants/channel.constants.js | 73 ++++----- web/default/src/pages/Channel/EditChannel.js | 15 ++ 40 files changed, 464 insertions(+), 179 deletions(-) delete mode 100644 common/ctxkey/config.go create mode 100644 common/helper/key.go delete mode 100644 relay/adaptor/azure/helper.go create mode 100644 relay/adaptor/cloudflare/adaptor.go create mode 100644 relay/adaptor/cloudflare/constant.go create mode 100644 relay/adaptor/cloudflare/main.go create mode 100644 relay/adaptor/cloudflare/model.go diff --git a/README.md b/README.md index 01236a43..62834fb8 100644 --- a/README.md +++ b/README.md @@ -85,6 +85,7 @@ _✨ 通过标准的 OpenAI API 格式访问所有的大模型,开箱即用 + [x] [Coze](https://www.coze.com/) + [x] [Cohere](https://cohere.com/) + [x] [DeepSeek](https://www.deepseek.com/) + + [x] [Cloudflare Workers AI](https://developers.cloudflare.com/workers-ai/) 2. 支持配置镜像以及众多[第三方代理服务](https://iamazing.cn/page/openai-api-third-party-services)。 3. 支持通过**负载均衡**的方式访问多个渠道。 4. 支持 **stream 模式**,可以通过流式传输实现打字机效果。 diff --git a/common/ctxkey/config.go b/common/ctxkey/config.go deleted file mode 100644 index 69e8a27a..00000000 --- a/common/ctxkey/config.go +++ /dev/null @@ -1,13 +0,0 @@ -package ctxkey - -const ( - ConfigPrefix = "cfg_" - - ConfigAPIVersion = ConfigPrefix + "api_version" - ConfigLibraryID = ConfigPrefix + "library_id" - ConfigPlugin = ConfigPrefix + "plugin" - ConfigSK = ConfigPrefix + "sk" - ConfigAK = ConfigPrefix + "ak" - ConfigRegion = ConfigPrefix + "region" - ConfigUserID = ConfigPrefix + "user_id" -) diff --git a/common/ctxkey/key.go b/common/ctxkey/key.go index 568cb095..6c640870 100644 --- a/common/ctxkey/key.go +++ b/common/ctxkey/key.go @@ -1,6 +1,7 @@ package ctxkey const ( + Config = "config" Id = "id" Username = "username" Role = "role" diff --git a/common/helper/helper.go b/common/helper/helper.go index cf2e1635..e06dfb6e 100644 --- a/common/helper/helper.go +++ b/common/helper/helper.go @@ -2,6 +2,7 @@ package helper import ( "fmt" + "github.com/gin-gonic/gin" "github.com/songquanpeng/one-api/common/random" "html/template" "log" @@ -105,6 +106,11 @@ func GenRequestID() string { return GetTimeString() + random.GetRandomNumberString(8) } +func GetResponseID(c *gin.Context) string { + logID := c.GetString(RequestIdKey) + return fmt.Sprintf("chatcmpl-%s", logID) +} + func Max(a int, b int) int { if a >= b { return a diff --git a/common/helper/key.go b/common/helper/key.go new file mode 100644 index 00000000..17aee2e0 --- /dev/null +++ b/common/helper/key.go @@ -0,0 +1,5 @@ +package helper + +const ( + RequestIdKey = "X-Oneapi-Request-Id" +) diff --git a/common/logger/constants.go b/common/logger/constants.go index 78d32062..49df31ec 100644 --- a/common/logger/constants.go +++ b/common/logger/constants.go @@ -1,7 +1,3 @@ package logger -const ( - RequestIdKey = "X-Oneapi-Request-Id" -) - var LogDir string diff --git a/common/logger/logger.go b/common/logger/logger.go index 858e33e2..c3dcd89d 100644 --- a/common/logger/logger.go +++ b/common/logger/logger.go @@ -87,7 +87,7 @@ func logHelper(ctx context.Context, level string, msg string) { if level == loggerINFO { writer = gin.DefaultWriter } - id := ctx.Value(RequestIdKey) + id := ctx.Value(helper.RequestIdKey) if id == nil { id = helper.GenRequestID() } diff --git a/controller/channel-test.go b/controller/channel-test.go index a84dc797..a9f03c45 100644 --- a/controller/channel-test.go +++ b/controller/channel-test.go @@ -57,6 +57,8 @@ func testChannel(channel *model.Channel) (err error, openaiErr *relaymodel.Error c.Request.Header.Set("Content-Type", "application/json") c.Set(ctxkey.Channel, channel.Type) c.Set(ctxkey.BaseURL, channel.GetBaseURL()) + cfg, _ := channel.LoadConfig() + c.Set(ctxkey.Config, cfg) middleware.SetupContextForSelectedChannel(c, channel, "") meta := meta.GetByContext(c) apiType := channeltype.ToAPIType(channel.Type) diff --git a/controller/relay.go b/controller/relay.go index 5fd22f85..aba4cd94 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -57,7 +57,7 @@ func Relay(c *gin.Context) { group := c.GetString(ctxkey.Group) originalModel := c.GetString(ctxkey.OriginalModel) go processChannelRelayError(ctx, channelId, channelName, bizErr) - requestId := c.GetString(logger.RequestIdKey) + requestId := c.GetString(helper.RequestIdKey) retryTimes := config.RetryTimes if !shouldRetry(c, bizErr.StatusCode) { logger.Errorf(ctx, "relay error happen, status code is %d, won't retry in this case", bizErr.StatusCode) diff --git a/middleware/distributor.go b/middleware/distributor.go index a4c34085..d0fd7ba5 100644 --- a/middleware/distributor.go +++ b/middleware/distributor.go @@ -65,21 +65,29 @@ func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, mode c.Set(ctxkey.OriginalModel, modelName) // for retry c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key)) c.Set(ctxkey.BaseURL, channel.GetBaseURL()) + cfg, _ := channel.LoadConfig() // this is for backward compatibility switch channel.Type { case channeltype.Azure: - c.Set(ctxkey.ConfigAPIVersion, channel.Other) + if cfg.APIVersion == "" { + cfg.APIVersion = channel.Other + } case channeltype.Xunfei: - c.Set(ctxkey.ConfigAPIVersion, channel.Other) + if cfg.APIVersion == "" { + cfg.APIVersion = channel.Other + } case channeltype.Gemini: - c.Set(ctxkey.ConfigAPIVersion, channel.Other) + if cfg.APIVersion == "" { + cfg.APIVersion = channel.Other + } case channeltype.AIProxyLibrary: - c.Set(ctxkey.ConfigLibraryID, channel.Other) + if cfg.LibraryID == "" { + cfg.LibraryID = channel.Other + } case channeltype.Ali: - c.Set(ctxkey.ConfigPlugin, channel.Other) - } - cfg, _ := channel.LoadConfig() - for k, v := range cfg { - c.Set(ctxkey.ConfigPrefix+k, v) + if cfg.Plugin == "" { + cfg.Plugin = channel.Other + } } + c.Set(ctxkey.Config, cfg) } diff --git a/middleware/logger.go b/middleware/logger.go index 6aae4f23..191364f8 100644 --- a/middleware/logger.go +++ b/middleware/logger.go @@ -3,14 +3,14 @@ package middleware import ( "fmt" "github.com/gin-gonic/gin" - "github.com/songquanpeng/one-api/common/logger" + "github.com/songquanpeng/one-api/common/helper" ) func SetUpLogger(server *gin.Engine) { server.Use(gin.LoggerWithFormatter(func(param gin.LogFormatterParams) string { var requestID string if param.Keys != nil { - requestID = param.Keys[logger.RequestIdKey].(string) + requestID = param.Keys[helper.RequestIdKey].(string) } return fmt.Sprintf("[GIN] %s | %s | %3d | %13v | %15s | %7s %s\n", param.TimeStamp.Format("2006/01/02 - 15:04:05"), diff --git a/middleware/request-id.go b/middleware/request-id.go index a4c49ddb..bef09e32 100644 --- a/middleware/request-id.go +++ b/middleware/request-id.go @@ -4,16 +4,15 @@ import ( "context" "github.com/gin-gonic/gin" "github.com/songquanpeng/one-api/common/helper" - "github.com/songquanpeng/one-api/common/logger" ) func RequestId() func(c *gin.Context) { return func(c *gin.Context) { id := helper.GenRequestID() - c.Set(logger.RequestIdKey, id) - ctx := context.WithValue(c.Request.Context(), logger.RequestIdKey, id) + c.Set(helper.RequestIdKey, id) + ctx := context.WithValue(c.Request.Context(), helper.RequestIdKey, id) c.Request = c.Request.WithContext(ctx) - c.Header(logger.RequestIdKey, id) + c.Header(helper.RequestIdKey, id) c.Next() } } diff --git a/middleware/utils.go b/middleware/utils.go index b65b018b..4d2f8092 100644 --- a/middleware/utils.go +++ b/middleware/utils.go @@ -12,7 +12,7 @@ import ( func abortWithMessage(c *gin.Context, statusCode int, message string) { c.JSON(statusCode, gin.H{ "error": gin.H{ - "message": helper.MessageWithRequestId(message, c.GetString(logger.RequestIdKey)), + "message": helper.MessageWithRequestId(message, c.GetString(helper.RequestIdKey)), "type": "one_api_error", }, }) diff --git a/model/channel.go b/model/channel.go index e667f7e7..ec52683e 100644 --- a/model/channel.go +++ b/model/channel.go @@ -38,6 +38,16 @@ type Channel struct { Config string `json:"config"` } +type ChannelConfig struct { + Region string `json:"region,omitempty"` + SK string `json:"sk,omitempty"` + AK string `json:"ak,omitempty"` + UserID string `json:"user_id,omitempty"` + APIVersion string `json:"api_version,omitempty"` + LibraryID string `json:"library_id,omitempty"` + Plugin string `json:"plugin,omitempty"` +} + func GetAllChannels(startIdx int, num int, scope string) ([]*Channel, error) { var channels []*Channel var err error @@ -161,14 +171,14 @@ func (channel *Channel) Delete() error { return err } -func (channel *Channel) LoadConfig() (map[string]string, error) { +func (channel *Channel) LoadConfig() (ChannelConfig, error) { + var cfg ChannelConfig if channel.Config == "" { - return nil, nil + return cfg, nil } - cfg := make(map[string]string) err := json.Unmarshal([]byte(channel.Config), &cfg) if err != nil { - return nil, err + return cfg, err } return cfg, nil } diff --git a/relay/adaptor.go b/relay/adaptor.go index 293b6d79..87021a04 100644 --- a/relay/adaptor.go +++ b/relay/adaptor.go @@ -7,6 +7,7 @@ import ( "github.com/songquanpeng/one-api/relay/adaptor/anthropic" "github.com/songquanpeng/one-api/relay/adaptor/aws" "github.com/songquanpeng/one-api/relay/adaptor/baidu" + "github.com/songquanpeng/one-api/relay/adaptor/cloudflare" "github.com/songquanpeng/one-api/relay/adaptor/cohere" "github.com/songquanpeng/one-api/relay/adaptor/coze" "github.com/songquanpeng/one-api/relay/adaptor/gemini" @@ -49,6 +50,8 @@ func GetAdaptor(apiType int) adaptor.Adaptor { return &coze.Adaptor{} case apitype.Cohere: return &cohere.Adaptor{} + case apitype.Cloudflare: + return &cloudflare.Adaptor{} } return nil } diff --git a/relay/adaptor/aiproxy/adaptor.go b/relay/adaptor/aiproxy/adaptor.go index a446f026..42d49c0a 100644 --- a/relay/adaptor/aiproxy/adaptor.go +++ b/relay/adaptor/aiproxy/adaptor.go @@ -4,7 +4,6 @@ import ( "errors" "fmt" "github.com/gin-gonic/gin" - "github.com/songquanpeng/one-api/common/ctxkey" "github.com/songquanpeng/one-api/relay/adaptor" "github.com/songquanpeng/one-api/relay/meta" "github.com/songquanpeng/one-api/relay/model" @@ -13,10 +12,11 @@ import ( ) type Adaptor struct { + meta *meta.Meta } func (a *Adaptor) Init(meta *meta.Meta) { - + a.meta = meta } func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) { @@ -34,7 +34,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G return nil, errors.New("request is nil") } aiProxyLibraryRequest := ConvertRequest(*request) - aiProxyLibraryRequest.LibraryId = c.GetString(ctxkey.ConfigLibraryID) + aiProxyLibraryRequest.LibraryId = a.meta.Config.LibraryID return aiProxyLibraryRequest, nil } diff --git a/relay/adaptor/ali/adaptor.go b/relay/adaptor/ali/adaptor.go index 8e7220ff..4aa8a11a 100644 --- a/relay/adaptor/ali/adaptor.go +++ b/relay/adaptor/ali/adaptor.go @@ -4,7 +4,6 @@ import ( "errors" "fmt" "github.com/gin-gonic/gin" - "github.com/songquanpeng/one-api/common/ctxkey" "github.com/songquanpeng/one-api/relay/adaptor" "github.com/songquanpeng/one-api/relay/meta" "github.com/songquanpeng/one-api/relay/model" @@ -16,10 +15,11 @@ import ( // https://help.aliyun.com/zh/dashscope/developer-reference/api-details type Adaptor struct { + meta *meta.Meta } func (a *Adaptor) Init(meta *meta.Meta) { - + a.meta = meta } func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) { @@ -47,8 +47,8 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *me if meta.Mode == relaymode.ImagesGenerations { req.Header.Set("X-DashScope-Async", "enable") } - if c.GetString(ctxkey.ConfigPlugin) != "" { - req.Header.Set("X-DashScope-Plugin", c.GetString(ctxkey.ConfigPlugin)) + if a.meta.Config.Plugin != "" { + req.Header.Set("X-DashScope-Plugin", a.meta.Config.Plugin) } return nil } diff --git a/relay/adaptor/aws/adapter.go b/relay/adaptor/aws/adapter.go index 7f064efe..7245d3d9 100644 --- a/relay/adaptor/aws/adapter.go +++ b/relay/adaptor/aws/adapter.go @@ -1,6 +1,9 @@ package aws import ( + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/credentials" + "github.com/aws/aws-sdk-go-v2/service/bedrockruntime" "github.com/songquanpeng/one-api/common/ctxkey" "io" "net/http" @@ -16,10 +19,16 @@ import ( var _ adaptor.Adaptor = new(Adaptor) type Adaptor struct { + meta *meta.Meta + awsClient *bedrockruntime.Client } func (a *Adaptor) Init(meta *meta.Meta) { - + a.meta = meta + a.awsClient = bedrockruntime.New(bedrockruntime.Options{ + Region: meta.Config.Region, + Credentials: aws.NewCredentialsCache(credentials.NewStaticCredentialsProvider(meta.Config.AK, meta.Config.SK, "")), + }) } func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) { @@ -54,9 +63,9 @@ func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Read func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) { if meta.IsStream { - err, usage = StreamHandler(c, resp) + err, usage = StreamHandler(c, a.awsClient) } else { - err, usage = Handler(c, resp, meta.PromptTokens, meta.ActualModelName) + err, usage = Handler(c, a.awsClient, meta.ActualModelName) } return } @@ -65,7 +74,6 @@ func (a *Adaptor) GetModelList() (models []string) { for n := range awsModelIDMap { models = append(models, n) } - return } diff --git a/relay/adaptor/aws/main.go b/relay/adaptor/aws/main.go index 3db38d22..0776f985 100644 --- a/relay/adaptor/aws/main.go +++ b/relay/adaptor/aws/main.go @@ -10,7 +10,6 @@ import ( "net/http" "github.com/aws/aws-sdk-go-v2/aws" - "github.com/aws/aws-sdk-go-v2/credentials" "github.com/aws/aws-sdk-go-v2/service/bedrockruntime" "github.com/aws/aws-sdk-go-v2/service/bedrockruntime/types" "github.com/gin-gonic/gin" @@ -23,18 +22,6 @@ import ( relaymodel "github.com/songquanpeng/one-api/relay/model" ) -func newAwsClient(c *gin.Context) (*bedrockruntime.Client, error) { - ak := c.GetString(ctxkey.ConfigAK) - sk := c.GetString(ctxkey.ConfigSK) - region := c.GetString(ctxkey.ConfigRegion) - client := bedrockruntime.New(bedrockruntime.Options{ - Region: region, - Credentials: aws.NewCredentialsCache(credentials.NewStaticCredentialsProvider(ak, sk, "")), - }) - - return client, nil -} - func wrapErr(err error) *relaymodel.ErrorWithStatusCode { return &relaymodel.ErrorWithStatusCode{ StatusCode: http.StatusInternalServerError, @@ -62,12 +49,7 @@ func awsModelID(requestModel string) (string, error) { return "", errors.Errorf("model %s not found", requestModel) } -func Handler(c *gin.Context, resp *http.Response, promptTokens int, modelName string) (*relaymodel.ErrorWithStatusCode, *relaymodel.Usage) { - awsCli, err := newAwsClient(c) - if err != nil { - return wrapErr(errors.Wrap(err, "newAwsClient")), nil - } - +func Handler(c *gin.Context, awsCli *bedrockruntime.Client, modelName string) (*relaymodel.ErrorWithStatusCode, *relaymodel.Usage) { awsModelId, err := awsModelID(c.GetString(ctxkey.RequestModel)) if err != nil { return wrapErr(errors.Wrap(err, "awsModelID")), nil @@ -120,13 +102,8 @@ func Handler(c *gin.Context, resp *http.Response, promptTokens int, modelName st return nil, &usage } -func StreamHandler(c *gin.Context, resp *http.Response) (*relaymodel.ErrorWithStatusCode, *relaymodel.Usage) { +func StreamHandler(c *gin.Context, awsCli *bedrockruntime.Client) (*relaymodel.ErrorWithStatusCode, *relaymodel.Usage) { createdTime := helper.GetTimestamp() - awsCli, err := newAwsClient(c) - if err != nil { - return wrapErr(errors.Wrap(err, "newAwsClient")), nil - } - awsModelId, err := awsModelID(c.GetString(ctxkey.RequestModel)) if err != nil { return wrapErr(errors.Wrap(err, "awsModelID")), nil diff --git a/relay/adaptor/azure/helper.go b/relay/adaptor/azure/helper.go deleted file mode 100644 index 26443bc4..00000000 --- a/relay/adaptor/azure/helper.go +++ /dev/null @@ -1,15 +0,0 @@ -package azure - -import ( - "github.com/gin-gonic/gin" - "github.com/songquanpeng/one-api/common/ctxkey" -) - -func GetAPIVersion(c *gin.Context) string { - query := c.Request.URL.Query() - apiVersion := query.Get("api-version") - if apiVersion == "" { - apiVersion = c.GetString(ctxkey.ConfigAPIVersion) - } - return apiVersion -} diff --git a/relay/adaptor/cloudflare/adaptor.go b/relay/adaptor/cloudflare/adaptor.go new file mode 100644 index 00000000..6ff6b0d3 --- /dev/null +++ b/relay/adaptor/cloudflare/adaptor.go @@ -0,0 +1,66 @@ +package cloudflare + +import ( + "errors" + "fmt" + "io" + "net/http" + + "github.com/gin-gonic/gin" + "github.com/songquanpeng/one-api/relay/adaptor" + "github.com/songquanpeng/one-api/relay/meta" + "github.com/songquanpeng/one-api/relay/model" +) + +type Adaptor struct { + meta *meta.Meta +} + +// ConvertImageRequest implements adaptor.Adaptor. +func (*Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) { + return nil, errors.New("not implemented") +} + +// ConvertImageRequest implements adaptor.Adaptor. + +func (a *Adaptor) Init(meta *meta.Meta) { + a.meta = meta +} + +func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) { + return fmt.Sprintf("%s/client/v4/accounts/%s/ai/run/%s", meta.BaseURL, meta.Config.UserID, meta.ActualModelName), nil +} + +func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error { + adaptor.SetupCommonRequestHeader(c, req, meta) + req.Header.Set("Authorization", "Bearer "+meta.APIKey) + return nil +} + +func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) { + if request == nil { + return nil, errors.New("request is nil") + } + return ConvertRequest(*request), nil +} + +func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) { + return adaptor.DoRequestHelper(a, c, meta, requestBody) +} + +func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) { + if meta.IsStream { + err, usage = StreamHandler(c, resp, meta.PromptTokens, meta.ActualModelName) + } else { + err, usage = Handler(c, resp, meta.PromptTokens, meta.ActualModelName) + } + return +} + +func (a *Adaptor) GetModelList() []string { + return ModelList +} + +func (a *Adaptor) GetChannelName() string { + return "cloudflare" +} diff --git a/relay/adaptor/cloudflare/constant.go b/relay/adaptor/cloudflare/constant.go new file mode 100644 index 00000000..dee79a76 --- /dev/null +++ b/relay/adaptor/cloudflare/constant.go @@ -0,0 +1,36 @@ +package cloudflare + +var ModelList = []string{ + "@cf/meta/llama-2-7b-chat-fp16", + "@cf/meta/llama-2-7b-chat-int8", + "@cf/mistral/mistral-7b-instruct-v0.1", + "@hf/thebloke/deepseek-coder-6.7b-base-awq", + "@hf/thebloke/deepseek-coder-6.7b-instruct-awq", + "@cf/deepseek-ai/deepseek-math-7b-base", + "@cf/deepseek-ai/deepseek-math-7b-instruct", + "@cf/thebloke/discolm-german-7b-v1-awq", + "@cf/tiiuae/falcon-7b-instruct", + "@cf/google/gemma-2b-it-lora", + "@hf/google/gemma-7b-it", + "@cf/google/gemma-7b-it-lora", + "@hf/nousresearch/hermes-2-pro-mistral-7b", + "@hf/thebloke/llama-2-13b-chat-awq", + "@cf/meta-llama/llama-2-7b-chat-hf-lora", + "@cf/meta/llama-3-8b-instruct", + "@hf/thebloke/llamaguard-7b-awq", + "@hf/thebloke/mistral-7b-instruct-v0.1-awq", + "@hf/mistralai/mistral-7b-instruct-v0.2", + "@cf/mistral/mistral-7b-instruct-v0.2-lora", + "@hf/thebloke/neural-chat-7b-v3-1-awq", + "@cf/openchat/openchat-3.5-0106", + "@hf/thebloke/openhermes-2.5-mistral-7b-awq", + "@cf/microsoft/phi-2", + "@cf/qwen/qwen1.5-0.5b-chat", + "@cf/qwen/qwen1.5-1.8b-chat", + "@cf/qwen/qwen1.5-14b-chat-awq", + "@cf/qwen/qwen1.5-7b-chat-awq", + "@cf/defog/sqlcoder-7b-2", + "@hf/nexusflow/starling-lm-7b-beta", + "@cf/tinyllama/tinyllama-1.1b-chat-v1.0", + "@hf/thebloke/zephyr-7b-beta-awq", +} diff --git a/relay/adaptor/cloudflare/main.go b/relay/adaptor/cloudflare/main.go new file mode 100644 index 00000000..e85bbc25 --- /dev/null +++ b/relay/adaptor/cloudflare/main.go @@ -0,0 +1,152 @@ +package cloudflare + +import ( + "bufio" + "bytes" + "encoding/json" + "io" + "net/http" + "strings" + + "github.com/gin-gonic/gin" + "github.com/songquanpeng/one-api/common" + "github.com/songquanpeng/one-api/common/helper" + "github.com/songquanpeng/one-api/common/logger" + "github.com/songquanpeng/one-api/relay/adaptor/openai" + "github.com/songquanpeng/one-api/relay/model" +) + +func ConvertRequest(textRequest model.GeneralOpenAIRequest) *Request { + lastMessage := textRequest.Messages[len(textRequest.Messages)-1] + return &Request{ + MaxTokens: textRequest.MaxTokens, + Prompt: lastMessage.StringContent(), + Stream: textRequest.Stream, + Temperature: textRequest.Temperature, + } +} + +func ResponseCloudflare2OpenAI(cloudflareResponse *Response) *openai.TextResponse { + choice := openai.TextResponseChoice{ + Index: 0, + Message: model.Message{ + Role: "assistant", + Content: cloudflareResponse.Result.Response, + }, + FinishReason: "stop", + } + fullTextResponse := openai.TextResponse{ + Object: "chat.completion", + Created: helper.GetTimestamp(), + Choices: []openai.TextResponseChoice{choice}, + } + return &fullTextResponse +} + +func StreamResponseCloudflare2OpenAI(cloudflareResponse *StreamResponse) *openai.ChatCompletionsStreamResponse { + var choice openai.ChatCompletionsStreamResponseChoice + choice.Delta.Content = cloudflareResponse.Response + choice.Delta.Role = "assistant" + openaiResponse := openai.ChatCompletionsStreamResponse{ + Object: "chat.completion.chunk", + Choices: []openai.ChatCompletionsStreamResponseChoice{choice}, + Created: helper.GetTimestamp(), + } + return &openaiResponse +} + +func StreamHandler(c *gin.Context, resp *http.Response, promptTokens int, modelName string) (*model.ErrorWithStatusCode, *model.Usage) { + scanner := bufio.NewScanner(resp.Body) + scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { + if atEOF && len(data) == 0 { + return 0, nil, nil + } + if i := bytes.IndexByte(data, '\n'); i >= 0 { + return i + 1, data[0:i], nil + } + if atEOF { + return len(data), data, nil + } + return 0, nil, nil + }) + + dataChan := make(chan string) + stopChan := make(chan bool) + go func() { + for scanner.Scan() { + data := scanner.Text() + if len(data) < len("data: ") { + continue + } + data = strings.TrimPrefix(data, "data: ") + dataChan <- data + } + stopChan <- true + }() + common.SetEventStreamHeaders(c) + id := helper.GetResponseID(c) + responseModel := c.GetString("original_model") + var responseText string + c.Stream(func(w io.Writer) bool { + select { + case data := <-dataChan: + // some implementations may add \r at the end of data + data = strings.TrimSuffix(data, "\r") + var cloudflareResponse StreamResponse + err := json.Unmarshal([]byte(data), &cloudflareResponse) + if err != nil { + logger.SysError("error unmarshalling stream response: " + err.Error()) + return true + } + response := StreamResponseCloudflare2OpenAI(&cloudflareResponse) + if response == nil { + return true + } + responseText += cloudflareResponse.Response + response.Id = id + response.Model = responseModel + jsonStr, err := json.Marshal(response) + if err != nil { + logger.SysError("error marshalling stream response: " + err.Error()) + return true + } + c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonStr)}) + return true + case <-stopChan: + c.Render(-1, common.CustomEvent{Data: "data: [DONE]"}) + return false + } + }) + _ = resp.Body.Close() + usage := openai.ResponseText2Usage(responseText, responseModel, promptTokens) + return nil, usage +} + +func Handler(c *gin.Context, resp *http.Response, promptTokens int, modelName string) (*model.ErrorWithStatusCode, *model.Usage) { + responseBody, err := io.ReadAll(resp.Body) + if err != nil { + return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil + } + err = resp.Body.Close() + if err != nil { + return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil + } + var cloudflareResponse Response + err = json.Unmarshal(responseBody, &cloudflareResponse) + if err != nil { + return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil + } + fullTextResponse := ResponseCloudflare2OpenAI(&cloudflareResponse) + fullTextResponse.Model = modelName + usage := openai.ResponseText2Usage(cloudflareResponse.Result.Response, modelName, promptTokens) + fullTextResponse.Usage = *usage + fullTextResponse.Id = helper.GetResponseID(c) + jsonResponse, err := json.Marshal(fullTextResponse) + if err != nil { + return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil + } + c.Writer.Header().Set("Content-Type", "application/json") + c.Writer.WriteHeader(resp.StatusCode) + _, err = c.Writer.Write(jsonResponse) + return nil, usage +} diff --git a/relay/adaptor/cloudflare/model.go b/relay/adaptor/cloudflare/model.go new file mode 100644 index 00000000..0664ecd1 --- /dev/null +++ b/relay/adaptor/cloudflare/model.go @@ -0,0 +1,25 @@ +package cloudflare + +type Request struct { + Lora string `json:"lora,omitempty"` + MaxTokens int `json:"max_tokens,omitempty"` + Prompt string `json:"prompt,omitempty"` + Raw bool `json:"raw,omitempty"` + Stream bool `json:"stream,omitempty"` + Temperature float64 `json:"temperature,omitempty"` +} + +type Result struct { + Response string `json:"response"` +} + +type Response struct { + Result Result `json:"result"` + Success bool `json:"success"` + Errors []string `json:"errors"` + Messages []string `json:"messages"` +} + +type StreamResponse struct { + Response string `json:"response"` +} diff --git a/relay/adaptor/coze/adaptor.go b/relay/adaptor/coze/adaptor.go index 49979ef6..44f560e8 100644 --- a/relay/adaptor/coze/adaptor.go +++ b/relay/adaptor/coze/adaptor.go @@ -4,7 +4,6 @@ import ( "errors" "fmt" "github.com/gin-gonic/gin" - "github.com/songquanpeng/one-api/common/ctxkey" "github.com/songquanpeng/one-api/relay/adaptor" "github.com/songquanpeng/one-api/relay/adaptor/openai" "github.com/songquanpeng/one-api/relay/meta" @@ -14,10 +13,11 @@ import ( ) type Adaptor struct { + meta *meta.Meta } func (a *Adaptor) Init(meta *meta.Meta) { - + a.meta = meta } func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) { @@ -34,7 +34,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G if request == nil { return nil, errors.New("request is nil") } - request.User = c.GetString(ctxkey.ConfigUserID) + request.User = a.meta.Config.UserID return ConvertRequest(*request), nil } diff --git a/relay/adaptor/gemini/adaptor.go b/relay/adaptor/gemini/adaptor.go index 6a2867e4..839e45d6 100644 --- a/relay/adaptor/gemini/adaptor.go +++ b/relay/adaptor/gemini/adaptor.go @@ -22,7 +22,7 @@ func (a *Adaptor) Init(meta *meta.Meta) { } func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) { - version := helper.AssignOrDefault(meta.APIVersion, config.GeminiVersion) + version := helper.AssignOrDefault(meta.Config.APIVersion, config.GeminiVersion) action := "generateContent" if meta.IsStream { action = "streamGenerateContent" diff --git a/relay/adaptor/openai/adaptor.go b/relay/adaptor/openai/adaptor.go index 4bb2384e..57940558 100644 --- a/relay/adaptor/openai/adaptor.go +++ b/relay/adaptor/openai/adaptor.go @@ -29,13 +29,13 @@ func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) { if meta.Mode == relaymode.ImagesGenerations { // https://learn.microsoft.com/en-us/azure/ai-services/openai/dall-e-quickstart?tabs=dalle3%2Ccommand-line&pivots=rest-api // https://{resource_name}.openai.azure.com/openai/deployments/dall-e-3/images/generations?api-version=2024-03-01-preview - fullRequestURL := fmt.Sprintf("%s/openai/deployments/%s/images/generations?api-version=%s", meta.BaseURL, meta.ActualModelName, meta.APIVersion) + fullRequestURL := fmt.Sprintf("%s/openai/deployments/%s/images/generations?api-version=%s", meta.BaseURL, meta.ActualModelName, meta.Config.APIVersion) return fullRequestURL, nil } // https://learn.microsoft.com/en-us/azure/cognitive-services/openai/chatgpt-quickstart?pivots=rest-api&tabs=command-line#rest-api requestURL := strings.Split(meta.RequestURLPath, "?")[0] - requestURL = fmt.Sprintf("%s?api-version=%s", requestURL, meta.APIVersion) + requestURL = fmt.Sprintf("%s?api-version=%s", requestURL, meta.Config.APIVersion) task := strings.TrimPrefix(requestURL, "/v1/") model_ := meta.ActualModelName model_ = strings.Replace(model_, ".", "", -1) diff --git a/relay/adaptor/openai/model.go b/relay/adaptor/openai/model.go index ce252ff6..4c974de4 100644 --- a/relay/adaptor/openai/model.go +++ b/relay/adaptor/openai/model.go @@ -134,7 +134,7 @@ type ChatCompletionsStreamResponse struct { Created int64 `json:"created"` Model string `json:"model"` Choices []ChatCompletionsStreamResponseChoice `json:"choices"` - Usage *model.Usage `json:"usage"` + Usage *model.Usage `json:"usage,omitempty"` } type CompletionsStreamResponse struct { diff --git a/relay/adaptor/xunfei/adaptor.go b/relay/adaptor/xunfei/adaptor.go index edcd719f..3af97831 100644 --- a/relay/adaptor/xunfei/adaptor.go +++ b/relay/adaptor/xunfei/adaptor.go @@ -14,10 +14,11 @@ import ( type Adaptor struct { request *model.GeneralOpenAIRequest + meta *meta.Meta } func (a *Adaptor) Init(meta *meta.Meta) { - + a.meta = meta } func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) { @@ -26,6 +27,14 @@ func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) { func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error { adaptor.SetupCommonRequestHeader(c, req, meta) + version := parseAPIVersionByModelName(meta.ActualModelName) + if version == "" { + version = a.meta.Config.APIVersion + } + if version == "" { + version = "v1.1" + } + a.meta.Config.APIVersion = version // check DoResponse for auth part return nil } @@ -61,9 +70,9 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Met return nil, openai.ErrorWrapper(errors.New("request is nil"), "request_is_nil", http.StatusBadRequest) } if meta.IsStream { - err, usage = StreamHandler(c, *a.request, splits[0], splits[1], splits[2]) + err, usage = StreamHandler(c, meta, *a.request, splits[0], splits[1], splits[2]) } else { - err, usage = Handler(c, *a.request, splits[0], splits[1], splits[2]) + err, usage = Handler(c, meta, *a.request, splits[0], splits[1], splits[2]) } return } diff --git a/relay/adaptor/xunfei/main.go b/relay/adaptor/xunfei/main.go index 70a926fd..c3e768b7 100644 --- a/relay/adaptor/xunfei/main.go +++ b/relay/adaptor/xunfei/main.go @@ -9,12 +9,12 @@ import ( "github.com/gin-gonic/gin" "github.com/gorilla/websocket" "github.com/songquanpeng/one-api/common" - "github.com/songquanpeng/one-api/common/ctxkey" "github.com/songquanpeng/one-api/common/helper" "github.com/songquanpeng/one-api/common/logger" "github.com/songquanpeng/one-api/common/random" "github.com/songquanpeng/one-api/relay/adaptor/openai" "github.com/songquanpeng/one-api/relay/constant" + "github.com/songquanpeng/one-api/relay/meta" "github.com/songquanpeng/one-api/relay/model" "io" "net/http" @@ -149,8 +149,8 @@ func buildXunfeiAuthUrl(hostUrl string, apiKey, apiSecret string) string { return callUrl } -func StreamHandler(c *gin.Context, textRequest model.GeneralOpenAIRequest, appId string, apiSecret string, apiKey string) (*model.ErrorWithStatusCode, *model.Usage) { - domain, authUrl := getXunfeiAuthUrl(c, apiKey, apiSecret, textRequest.Model) +func StreamHandler(c *gin.Context, meta *meta.Meta, textRequest model.GeneralOpenAIRequest, appId string, apiSecret string, apiKey string) (*model.ErrorWithStatusCode, *model.Usage) { + domain, authUrl := getXunfeiAuthUrl(meta.Config.APIVersion, apiKey, apiSecret) dataChan, stopChan, err := xunfeiMakeRequest(textRequest, domain, authUrl, appId) if err != nil { return openai.ErrorWrapper(err, "xunfei_request_failed", http.StatusInternalServerError), nil @@ -179,8 +179,8 @@ func StreamHandler(c *gin.Context, textRequest model.GeneralOpenAIRequest, appId return nil, &usage } -func Handler(c *gin.Context, textRequest model.GeneralOpenAIRequest, appId string, apiSecret string, apiKey string) (*model.ErrorWithStatusCode, *model.Usage) { - domain, authUrl := getXunfeiAuthUrl(c, apiKey, apiSecret, textRequest.Model) +func Handler(c *gin.Context, meta *meta.Meta, textRequest model.GeneralOpenAIRequest, appId string, apiSecret string, apiKey string) (*model.ErrorWithStatusCode, *model.Usage) { + domain, authUrl := getXunfeiAuthUrl(meta.Config.APIVersion, apiKey, apiSecret) dataChan, stopChan, err := xunfeiMakeRequest(textRequest, domain, authUrl, appId) if err != nil { return openai.ErrorWrapper(err, "xunfei_request_failed", http.StatusInternalServerError), nil @@ -268,25 +268,12 @@ func xunfeiMakeRequest(textRequest model.GeneralOpenAIRequest, domain, authUrl, return dataChan, stopChan, nil } -func getAPIVersion(c *gin.Context, modelName string) string { - query := c.Request.URL.Query() - apiVersion := query.Get("api-version") - if apiVersion != "" { - return apiVersion - } +func parseAPIVersionByModelName(modelName string) string { parts := strings.Split(modelName, "-") if len(parts) == 2 { - apiVersion = parts[1] - return apiVersion - + return parts[1] } - apiVersion = c.GetString(ctxkey.ConfigAPIVersion) - if apiVersion != "" { - return apiVersion - } - apiVersion = "v1.1" - logger.SysLog("api_version not found, using default: " + apiVersion) - return apiVersion + return "" } // https://www.xfyun.cn/doc/spark/Web.html#_1-%E6%8E%A5%E5%8F%A3%E8%AF%B4%E6%98%8E @@ -304,8 +291,7 @@ func apiVersion2domain(apiVersion string) string { return "general" + apiVersion } -func getXunfeiAuthUrl(c *gin.Context, apiKey string, apiSecret string, modelName string) (string, string) { - apiVersion := getAPIVersion(c, modelName) +func getXunfeiAuthUrl(apiVersion string, apiKey string, apiSecret string) (string, string) { domain := apiVersion2domain(apiVersion) authUrl := buildXunfeiAuthUrl(fmt.Sprintf("wss://spark-api.xf-yun.com/%s/chat", apiVersion), apiKey, apiSecret) return domain, authUrl diff --git a/relay/apitype/define.go b/relay/apitype/define.go index a1c8e6e1..e38eff7e 100644 --- a/relay/apitype/define.go +++ b/relay/apitype/define.go @@ -15,6 +15,7 @@ const ( AwsClaude Coze Cohere + Cloudflare Dummy // this one is only for count, do not add any channel after this ) diff --git a/relay/channeltype/define.go b/relay/channeltype/define.go index 7f29afb3..3aa585a9 100644 --- a/relay/channeltype/define.go +++ b/relay/channeltype/define.go @@ -38,6 +38,7 @@ const ( Coze Cohere DeepSeek + Cloudflare Dummy ) diff --git a/relay/channeltype/helper.go b/relay/channeltype/helper.go index 42b77891..a608c80e 100644 --- a/relay/channeltype/helper.go +++ b/relay/channeltype/helper.go @@ -31,6 +31,8 @@ func ToAPIType(channelType int) int { apiType = apitype.Coze case Cohere: apiType = apitype.Cohere + case Cloudflare: + apiType = apitype.Cloudflare } return apiType diff --git a/relay/channeltype/url.go b/relay/channeltype/url.go index ea4dfb95..657b677e 100644 --- a/relay/channeltype/url.go +++ b/relay/channeltype/url.go @@ -38,6 +38,7 @@ var ChannelBaseURLs = []string{ "https://api.coze.com", // 34 "https://api.cohere.ai", // 35 "https://api.deepseek.com", // 36 + "https://api.cloudflare.com", // 37 } func init() { diff --git a/relay/controller/audio.go b/relay/controller/audio.go index db543318..15e74290 100644 --- a/relay/controller/audio.go +++ b/relay/controller/audio.go @@ -13,12 +13,12 @@ import ( "github.com/songquanpeng/one-api/common/ctxkey" "github.com/songquanpeng/one-api/common/logger" "github.com/songquanpeng/one-api/model" - "github.com/songquanpeng/one-api/relay/adaptor/azure" "github.com/songquanpeng/one-api/relay/adaptor/openai" "github.com/songquanpeng/one-api/relay/billing" billingratio "github.com/songquanpeng/one-api/relay/billing/ratio" "github.com/songquanpeng/one-api/relay/channeltype" "github.com/songquanpeng/one-api/relay/client" + "github.com/songquanpeng/one-api/relay/meta" relaymodel "github.com/songquanpeng/one-api/relay/model" "github.com/songquanpeng/one-api/relay/relaymode" "io" @@ -28,6 +28,7 @@ import ( func RelayAudioHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatusCode { ctx := c.Request.Context() + meta := meta.GetByContext(c) audioModel := "whisper-1" tokenId := c.GetInt(ctxkey.TokenId) @@ -128,7 +129,7 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus fullRequestURL := openai.GetFullRequestURL(baseURL, requestURL, channelType) if channelType == channeltype.Azure { - apiVersion := azure.GetAPIVersion(c) + apiVersion := meta.Config.APIVersion if relayMode == relaymode.AudioTranscription { // https://learn.microsoft.com/en-us/azure/ai-services/openai/whisper-quickstart?tabs=command-line#rest-api fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/audio/transcriptions?api-version=%s", baseURL, audioModel, apiVersion) diff --git a/relay/controller/image.go b/relay/controller/image.go index 216e4700..6620bef5 100644 --- a/relay/controller/image.go +++ b/relay/controller/image.go @@ -70,6 +70,7 @@ func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus if adaptor == nil { return openai.ErrorWrapper(fmt.Errorf("invalid api type: %d", meta.APIType), "invalid_api_type", http.StatusBadRequest) } + adaptor.Init(meta) switch meta.ChannelType { case channeltype.Ali: diff --git a/relay/controller/text.go b/relay/controller/text.go index 23e94234..9bfd3e76 100644 --- a/relay/controller/text.go +++ b/relay/controller/text.go @@ -53,6 +53,7 @@ func RelayTextHelper(c *gin.Context) *model.ErrorWithStatusCode { if adaptor == nil { return openai.ErrorWrapper(fmt.Errorf("invalid api type: %d", meta.APIType), "invalid_api_type", http.StatusBadRequest) } + adaptor.Init(meta) // get request body var requestBody io.Reader diff --git a/relay/meta/relay_meta.go b/relay/meta/relay_meta.go index 0e8f72fe..9714ebb5 100644 --- a/relay/meta/relay_meta.go +++ b/relay/meta/relay_meta.go @@ -3,7 +3,7 @@ package meta import ( "github.com/gin-gonic/gin" "github.com/songquanpeng/one-api/common/ctxkey" - "github.com/songquanpeng/one-api/relay/adaptor/azure" + "github.com/songquanpeng/one-api/model" "github.com/songquanpeng/one-api/relay/channeltype" "github.com/songquanpeng/one-api/relay/relaymode" "strings" @@ -19,10 +19,9 @@ type Meta struct { Group string ModelMapping map[string]string BaseURL string - APIVersion string APIKey string APIType int - Config map[string]string + Config model.ChannelConfig IsStream bool OriginModelName string ActualModelName string @@ -32,22 +31,22 @@ type Meta struct { func GetByContext(c *gin.Context) *Meta { meta := Meta{ - Mode: relaymode.GetByPath(c.Request.URL.Path), - ChannelType: c.GetInt(ctxkey.Channel), - ChannelId: c.GetInt(ctxkey.ChannelId), - TokenId: c.GetInt(ctxkey.TokenId), - TokenName: c.GetString(ctxkey.TokenName), - UserId: c.GetInt(ctxkey.Id), - Group: c.GetString(ctxkey.Group), - ModelMapping: c.GetStringMapString(ctxkey.ModelMapping), - BaseURL: c.GetString(ctxkey.BaseURL), - APIVersion: c.GetString(ctxkey.ConfigAPIVersion), - APIKey: strings.TrimPrefix(c.Request.Header.Get("Authorization"), "Bearer "), - Config: nil, - RequestURLPath: c.Request.URL.String(), + Mode: relaymode.GetByPath(c.Request.URL.Path), + ChannelType: c.GetInt(ctxkey.Channel), + ChannelId: c.GetInt(ctxkey.ChannelId), + TokenId: c.GetInt(ctxkey.TokenId), + TokenName: c.GetString(ctxkey.TokenName), + UserId: c.GetInt(ctxkey.Id), + Group: c.GetString(ctxkey.Group), + ModelMapping: c.GetStringMapString(ctxkey.ModelMapping), + OriginModelName: c.GetString(ctxkey.RequestModel), + BaseURL: c.GetString(ctxkey.BaseURL), + APIKey: strings.TrimPrefix(c.Request.Header.Get("Authorization"), "Bearer "), + RequestURLPath: c.Request.URL.String(), } - if meta.ChannelType == channeltype.Azure { - meta.APIVersion = azure.GetAPIVersion(c) + cfg, ok := c.Get(ctxkey.Config) + if ok { + meta.Config = cfg.(model.ChannelConfig) } if meta.BaseURL == "" { meta.BaseURL = channeltype.ChannelBaseURLs[meta.ChannelType] diff --git a/web/default/src/constants/channel.constants.js b/web/default/src/constants/channel.constants.js index ff124501..a689ef27 100644 --- a/web/default/src/constants/channel.constants.js +++ b/web/default/src/constants/channel.constants.js @@ -1,38 +1,39 @@ export const CHANNEL_OPTIONS = [ - { key: 1, text: 'OpenAI', value: 1, color: 'green' }, - { key: 14, text: 'Anthropic Claude', value: 14, color: 'black' }, - { key: 33, text: 'AWS Claude', value: 33, color: 'black' }, - { key: 3, text: 'Azure OpenAI', value: 3, color: 'olive' }, - { key: 11, text: 'Google PaLM2', value: 11, color: 'orange' }, - { key: 24, text: 'Google Gemini', value: 24, color: 'orange' }, - { key: 28, text: 'Mistral AI', value: 28, color: 'orange' }, - { key: 15, text: '百度文心千帆', value: 15, color: 'blue' }, - { key: 17, text: '阿里通义千问', value: 17, color: 'orange' }, - { key: 18, text: '讯飞星火认知', value: 18, color: 'blue' }, - { key: 16, text: '智谱 ChatGLM', value: 16, color: 'violet' }, - { key: 19, text: '360 智脑', value: 19, color: 'blue' }, - { key: 25, text: 'Moonshot AI', value: 25, color: 'black' }, - { key: 23, text: '腾讯混元', value: 23, color: 'teal' }, - { key: 26, text: '百川大模型', value: 26, color: 'orange' }, - { key: 27, text: 'MiniMax', value: 27, color: 'red' }, - { key: 29, text: 'Groq', value: 29, color: 'orange' }, - { key: 30, text: 'Ollama', value: 30, color: 'black' }, - { key: 31, text: '零一万物', value: 31, color: 'green' }, - { key: 32, text: '阶跃星辰', value: 32, color: 'blue' }, - { key: 34, text: 'Coze', value: 34, color: 'blue' }, - { key: 35, text: 'Cohere', value: 35, color: 'blue' }, - { key: 36, text: 'DeepSeek', value: 36, color: 'black' }, - { key: 8, text: '自定义渠道', value: 8, color: 'pink' }, - { key: 22, text: '知识库:FastGPT', value: 22, color: 'blue' }, - { key: 21, text: '知识库:AI Proxy', value: 21, color: 'purple' }, - { key: 20, text: '代理:OpenRouter', value: 20, color: 'black' }, - { key: 2, text: '代理:API2D', value: 2, color: 'blue' }, - { key: 5, text: '代理:OpenAI-SB', value: 5, color: 'brown' }, - { key: 7, text: '代理:OhMyGPT', value: 7, color: 'purple' }, - { key: 10, text: '代理:AI Proxy', value: 10, color: 'purple' }, - { key: 4, text: '代理:CloseAI', value: 4, color: 'teal' }, - { key: 6, text: '代理:OpenAI Max', value: 6, color: 'violet' }, - { key: 9, text: '代理:AI.LS', value: 9, color: 'yellow' }, - { key: 12, text: '代理:API2GPT', value: 12, color: 'blue' }, - { key: 13, text: '代理:AIGC2D', value: 13, color: 'purple' } + {key: 1, text: 'OpenAI', value: 1, color: 'green'}, + {key: 14, text: 'Anthropic Claude', value: 14, color: 'black'}, + {key: 33, text: 'AWS Claude', value: 33, color: 'black'}, + {key: 3, text: 'Azure OpenAI', value: 3, color: 'olive'}, + {key: 11, text: 'Google PaLM2', value: 11, color: 'orange'}, + {key: 24, text: 'Google Gemini', value: 24, color: 'orange'}, + {key: 28, text: 'Mistral AI', value: 28, color: 'orange'}, + {key: 15, text: '百度文心千帆', value: 15, color: 'blue'}, + {key: 17, text: '阿里通义千问', value: 17, color: 'orange'}, + {key: 18, text: '讯飞星火认知', value: 18, color: 'blue'}, + {key: 16, text: '智谱 ChatGLM', value: 16, color: 'violet'}, + {key: 19, text: '360 智脑', value: 19, color: 'blue'}, + {key: 25, text: 'Moonshot AI', value: 25, color: 'black'}, + {key: 23, text: '腾讯混元', value: 23, color: 'teal'}, + {key: 26, text: '百川大模型', value: 26, color: 'orange'}, + {key: 27, text: 'MiniMax', value: 27, color: 'red'}, + {key: 29, text: 'Groq', value: 29, color: 'orange'}, + {key: 30, text: 'Ollama', value: 30, color: 'black'}, + {key: 31, text: '零一万物', value: 31, color: 'green'}, + {key: 32, text: '阶跃星辰', value: 32, color: 'blue'}, + {key: 34, text: 'Coze', value: 34, color: 'blue'}, + {key: 35, text: 'Cohere', value: 35, color: 'blue'}, + {key: 36, text: 'DeepSeek', value: 36, color: 'black'}, + {key: 37, text: 'Cloudflare', value: 37, color: 'orange'}, + {key: 8, text: '自定义渠道', value: 8, color: 'pink'}, + {key: 22, text: '知识库:FastGPT', value: 22, color: 'blue'}, + {key: 21, text: '知识库:AI Proxy', value: 21, color: 'purple'}, + {key: 20, text: '代理:OpenRouter', value: 20, color: 'black'}, + {key: 2, text: '代理:API2D', value: 2, color: 'blue'}, + {key: 5, text: '代理:OpenAI-SB', value: 5, color: 'brown'}, + {key: 7, text: '代理:OhMyGPT', value: 7, color: 'purple'}, + {key: 10, text: '代理:AI Proxy', value: 10, color: 'purple'}, + {key: 4, text: '代理:CloseAI', value: 4, color: 'teal'}, + {key: 6, text: '代理:OpenAI Max', value: 6, color: 'violet'}, + {key: 9, text: '代理:AI.LS', value: 9, color: 'yellow'}, + {key: 12, text: '代理:API2GPT', value: 12, color: 'blue'}, + {key: 13, text: '代理:AIGC2D', value: 13, color: 'purple'} ]; diff --git a/web/default/src/pages/Channel/EditChannel.js b/web/default/src/pages/Channel/EditChannel.js index ffc9fc5f..5c7f13ff 100644 --- a/web/default/src/pages/Channel/EditChannel.js +++ b/web/default/src/pages/Channel/EditChannel.js @@ -488,6 +488,21 @@ const EditChannel = () => { /> ) } + { + inputs.type === 37 && ( + + + + ) + } { inputs.type !== 33 && !isEdit && (