diff --git a/README.md b/README.md index 40f6e4e0..167fe5ba 100644 --- a/README.md +++ b/README.md @@ -384,14 +384,17 @@ graph LR + `TIKTOKEN_CACHE_DIR`:默认程序启动时会联网下载一些通用的词元的编码,如:`gpt-3.5-turbo`,在一些网络环境不稳定,或者离线情况,可能会导致启动有问题,可以配置此目录缓存数据,可迁移到离线环境。 + `DATA_GYM_CACHE_DIR`:目前该配置作用与 `TIKTOKEN_CACHE_DIR` 一致,但是优先级没有它高。 17. `RELAY_TIMEOUT`:中继超时设置,单位为秒,默认不设置超时时间。 -18. `SQLITE_BUSY_TIMEOUT`:SQLite 锁等待超时设置,单位为毫秒,默认 `3000`。 -19. `GEMINI_SAFETY_SETTING`:Gemini 的安全设置,默认 `BLOCK_NONE`。 -20. `GEMINI_VERSION`:One API 所使用的 Gemini 版本,默认为 `v1`。 -21. `THEME`:系统的主题设置,默认为 `default`,具体可选值参考[此处](./web/README.md)。 -22. `ENABLE_METRIC`:是否根据请求成功率禁用渠道,默认不开启,可选值为 `true` 和 `false`。 -23. `METRIC_QUEUE_SIZE`:请求成功率统计队列大小,默认为 `10`。 -24. `METRIC_SUCCESS_RATE_THRESHOLD`:请求成功率阈值,默认为 `0.8`。 -25. `INITIAL_ROOT_TOKEN`:如果设置了该值,则在系统首次启动时会自动创建一个值为该环境变量值的 root 用户令牌。 +18. `RELAY_PROXY`:设置后使用该代理来请求 API。 +19. `USER_CONTENT_REQUEST_TIMEOUT`:用户上传内容下载超时时间,单位为秒。 +20. `USER_CONTENT_REQUEST_PROXY`:设置后使用该代理来请求用户上传的内容,例如图片。 +21. `SQLITE_BUSY_TIMEOUT`:SQLite 锁等待超时设置,单位为毫秒,默认 `3000`。 +22. `GEMINI_SAFETY_SETTING`:Gemini 的安全设置,默认 `BLOCK_NONE`。 +23. `GEMINI_VERSION`:One API 所使用的 Gemini 版本,默认为 `v1`。 +24. `THEME`:系统的主题设置,默认为 `default`,具体可选值参考[此处](./web/README.md)。 +25. `ENABLE_METRIC`:是否根据请求成功率禁用渠道,默认不开启,可选值为 `true` 和 `false`。 +26. `METRIC_QUEUE_SIZE`:请求成功率统计队列大小,默认为 `10`。 +27. `METRIC_SUCCESS_RATE_THRESHOLD`:请求成功率阈值,默认为 `0.8`。 +28. `INITIAL_ROOT_TOKEN`:如果设置了该值,则在系统首次启动时会自动创建一个值为该环境变量值的 root 用户令牌。 ### 命令行参数 1. `--port `: 指定服务器监听的端口号,默认为 `3000`。 diff --git a/common/client/init.go b/common/client/init.go new file mode 100644 index 00000000..f803cbf8 --- /dev/null +++ b/common/client/init.go @@ -0,0 +1,60 @@ +package client + +import ( + "fmt" + "github.com/songquanpeng/one-api/common/config" + "github.com/songquanpeng/one-api/common/logger" + "net/http" + "net/url" + "time" +) + +var HTTPClient *http.Client +var ImpatientHTTPClient *http.Client +var UserContentRequestHTTPClient *http.Client + +func Init() { + if config.UserContentRequestProxy != "" { + logger.SysLog(fmt.Sprintf("using %s as proxy to fetch user content", config.UserContentRequestProxy)) + proxyURL, err := url.Parse(config.UserContentRequestProxy) + if err != nil { + logger.FatalLog(fmt.Sprintf("USER_CONTENT_REQUEST_PROXY set but invalid: %s", config.UserContentRequestProxy)) + } + transport := &http.Transport{ + Proxy: http.ProxyURL(proxyURL), + } + UserContentRequestHTTPClient = &http.Client{ + Transport: transport, + Timeout: time.Second * time.Duration(config.UserContentRequestTimeout), + } + } else { + UserContentRequestHTTPClient = &http.Client{} + } + var transport http.RoundTripper + if config.RelayProxy != "" { + logger.SysLog(fmt.Sprintf("using %s as api relay proxy", config.RelayProxy)) + proxyURL, err := url.Parse(config.RelayProxy) + if err != nil { + logger.FatalLog(fmt.Sprintf("USER_CONTENT_REQUEST_PROXY set but invalid: %s", config.UserContentRequestProxy)) + } + transport = &http.Transport{ + Proxy: http.ProxyURL(proxyURL), + } + } + + if config.RelayTimeout == 0 { + HTTPClient = &http.Client{ + Transport: transport, + } + } else { + HTTPClient = &http.Client{ + Timeout: time.Duration(config.RelayTimeout) * time.Second, + Transport: transport, + } + } + + ImpatientHTTPClient = &http.Client{ + Timeout: 5 * time.Second, + Transport: transport, + } +} diff --git a/common/config/config.go b/common/config/config.go index 0864d844..539eeef4 100644 --- a/common/config/config.go +++ b/common/config/config.go @@ -144,3 +144,7 @@ var MetricFailChanSize = env.Int("METRIC_FAIL_CHAN_SIZE", 128) var InitialRootToken = os.Getenv("INITIAL_ROOT_TOKEN") var GeminiVersion = env.String("GEMINI_VERSION", "v1") + +var RelayProxy = env.String("RELAY_PROXY", "") +var UserContentRequestProxy = env.String("USER_CONTENT_REQUEST_PROXY", "") +var UserContentRequestTimeout = env.Int("USER_CONTENT_REQUEST_TIMEOUT", 30) diff --git a/common/image/image.go b/common/image/image.go index 12f0adff..beebd0c6 100644 --- a/common/image/image.go +++ b/common/image/image.go @@ -3,6 +3,7 @@ package image import ( "bytes" "encoding/base64" + "github.com/songquanpeng/one-api/common/client" "image" _ "image/gif" _ "image/jpeg" @@ -19,7 +20,7 @@ import ( var dataURLPattern = regexp.MustCompile(`data:image/([^;]+);base64,(.*)`) func IsImageUrl(url string) (bool, error) { - resp, err := http.Head(url) + resp, err := client.UserContentRequestHTTPClient.Head(url) if err != nil { return false, err } @@ -34,7 +35,7 @@ func GetImageSizeFromUrl(url string) (width int, height int, err error) { if !isImage { return } - resp, err := http.Get(url) + resp, err := client.UserContentRequestHTTPClient.Get(url) if err != nil { return } diff --git a/controller/channel-billing.go b/controller/channel-billing.go index b7ac61fd..53592744 100644 --- a/controller/channel-billing.go +++ b/controller/channel-billing.go @@ -4,12 +4,12 @@ import ( "encoding/json" "errors" "fmt" + "github.com/songquanpeng/one-api/common/client" "github.com/songquanpeng/one-api/common/config" "github.com/songquanpeng/one-api/common/logger" "github.com/songquanpeng/one-api/model" "github.com/songquanpeng/one-api/monitor" "github.com/songquanpeng/one-api/relay/channeltype" - "github.com/songquanpeng/one-api/relay/client" "io" "net/http" "strconv" diff --git a/controller/relay.go b/controller/relay.go index aba4cd94..5d8ac690 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -4,6 +4,9 @@ import ( "bytes" "context" "fmt" + "io" + "net/http" + "github.com/gin-gonic/gin" "github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common/config" @@ -16,8 +19,6 @@ import ( "github.com/songquanpeng/one-api/relay/controller" "github.com/songquanpeng/one-api/relay/model" "github.com/songquanpeng/one-api/relay/relaymode" - "io" - "net/http" ) // https://platform.openai.com/docs/api-reference/chat @@ -47,6 +48,7 @@ func Relay(c *gin.Context) { logger.Debugf(ctx, "request body: %s", string(requestBody)) } channelId := c.GetInt(ctxkey.ChannelId) + userId := c.GetInt("id") bizErr := relayHelper(c, relayMode) if bizErr == nil { monitor.Emit(channelId, true) @@ -56,7 +58,7 @@ func Relay(c *gin.Context) { channelName := c.GetString(ctxkey.ChannelName) group := c.GetString(ctxkey.Group) originalModel := c.GetString(ctxkey.OriginalModel) - go processChannelRelayError(ctx, channelId, channelName, bizErr) + go processChannelRelayError(ctx, userId, channelId, channelName, bizErr) requestId := c.GetString(helper.RequestIdKey) retryTimes := config.RetryTimes if !shouldRetry(c, bizErr.StatusCode) { @@ -83,7 +85,7 @@ func Relay(c *gin.Context) { channelId := c.GetInt(ctxkey.ChannelId) lastFailedChannelId = channelId channelName := c.GetString(ctxkey.ChannelName) - go processChannelRelayError(ctx, channelId, channelName, bizErr) + go processChannelRelayError(ctx, userId, channelId, channelName, bizErr) } if bizErr != nil { if bizErr.StatusCode == http.StatusTooManyRequests { @@ -115,8 +117,8 @@ func shouldRetry(c *gin.Context, statusCode int) bool { return true } -func processChannelRelayError(ctx context.Context, channelId int, channelName string, err *model.ErrorWithStatusCode) { - logger.Errorf(ctx, "relay error (channel #%d): %s", channelId, err.Message) +func processChannelRelayError(ctx context.Context, userId int, channelId int, channelName string, err *model.ErrorWithStatusCode) { + logger.Errorf(ctx, "relay error (channel id %d, user id: %d): %s", channelId, userId, err.Message) // https://platform.openai.com/docs/guides/error-codes/api-errors if monitor.ShouldDisableChannel(&err.Error, err.StatusCode) { monitor.DisableChannel(channelId, channelName, err.Message) diff --git a/controller/user.go b/controller/user.go index af90acf6..9ab37b5a 100644 --- a/controller/user.go +++ b/controller/user.go @@ -6,6 +6,8 @@ import ( "github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common/config" "github.com/songquanpeng/one-api/common/ctxkey" + "github.com/songquanpeng/one-api/common/helper" + "github.com/songquanpeng/one-api/common/logger" "github.com/songquanpeng/one-api/common/random" "github.com/songquanpeng/one-api/model" "net/http" @@ -109,6 +111,7 @@ func Logout(c *gin.Context) { } func Register(c *gin.Context) { + ctx := c.Request.Context() if !config.RegisterEnabled { c.JSON(http.StatusOK, gin.H{ "message": "管理员关闭了新用户注册", @@ -173,6 +176,28 @@ func Register(c *gin.Context) { }) return } + go func() { + err := user.ValidateAndFill() + if err != nil { + logger.Errorf(ctx, "user.ValidateAndFill failed: %w", err) + return + } + cleanToken := model.Token{ + UserId: user.Id, + Name: "default", + Key: random.GenerateKey(), + CreatedTime: helper.GetTimestamp(), + AccessedTime: helper.GetTimestamp(), + ExpiredTime: -1, + RemainQuota: -1, + UnlimitedQuota: true, + } + err = cleanToken.Insert() + if err != nil { + logger.Errorf(ctx, "cleanToken.Insert failed: %w", err) + return + } + }() c.JSON(http.StatusOK, gin.H{ "success": true, "message": "", diff --git a/main.go b/main.go index bdcdcd61..eb6f368c 100644 --- a/main.go +++ b/main.go @@ -7,6 +7,7 @@ import ( "github.com/gin-contrib/sessions/cookie" "github.com/gin-gonic/gin" "github.com/songquanpeng/one-api/common" + "github.com/songquanpeng/one-api/common/client" "github.com/songquanpeng/one-api/common/config" "github.com/songquanpeng/one-api/common/logger" "github.com/songquanpeng/one-api/controller" @@ -94,6 +95,7 @@ func main() { logger.SysLog("metric enabled, will disable channel if too much request failed") } openai.InitTokenEncoders() + client.Init() // Initialize HTTP server server := gin.New() diff --git a/relay/adaptor/baidu/main.go b/relay/adaptor/baidu/main.go index 6df5ce84..b816e0f4 100644 --- a/relay/adaptor/baidu/main.go +++ b/relay/adaptor/baidu/main.go @@ -7,9 +7,9 @@ import ( "fmt" "github.com/gin-gonic/gin" "github.com/songquanpeng/one-api/common" + "github.com/songquanpeng/one-api/common/client" "github.com/songquanpeng/one-api/common/logger" "github.com/songquanpeng/one-api/relay/adaptor/openai" - "github.com/songquanpeng/one-api/relay/client" "github.com/songquanpeng/one-api/relay/constant" "github.com/songquanpeng/one-api/relay/model" "io" diff --git a/relay/adaptor/common.go b/relay/adaptor/common.go index 82a5160e..8953d7a3 100644 --- a/relay/adaptor/common.go +++ b/relay/adaptor/common.go @@ -4,7 +4,7 @@ import ( "errors" "fmt" "github.com/gin-gonic/gin" - "github.com/songquanpeng/one-api/relay/client" + "github.com/songquanpeng/one-api/common/client" "github.com/songquanpeng/one-api/relay/meta" "io" "net/http" diff --git a/relay/adaptor/gemini/adaptor.go b/relay/adaptor/gemini/adaptor.go index a4dcae93..12f48c71 100644 --- a/relay/adaptor/gemini/adaptor.go +++ b/relay/adaptor/gemini/adaptor.go @@ -13,6 +13,7 @@ import ( "github.com/songquanpeng/one-api/relay/adaptor/openai" "github.com/songquanpeng/one-api/relay/meta" "github.com/songquanpeng/one-api/relay/model" + "github.com/songquanpeng/one-api/relay/relaymode" ) type Adaptor struct { @@ -24,7 +25,14 @@ func (a *Adaptor) Init(meta *meta.Meta) { func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) { version := helper.AssignOrDefault(meta.Config.APIVersion, config.GeminiVersion) - action := "generateContent" + action := "" + switch meta.Mode { + case relaymode.Embeddings: + action = "batchEmbedContents" + default: + action = "generateContent" + } + if meta.IsStream { action = "streamGenerateContent?alt=sse" } @@ -41,7 +49,14 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G if request == nil { return nil, errors.New("request is nil") } - return ConvertRequest(*request), nil + switch relayMode { + case relaymode.Embeddings: + geminiEmbeddingRequest := ConvertEmbeddingRequest(*request) + return geminiEmbeddingRequest, nil + default: + geminiRequest := ConvertRequest(*request) + return geminiRequest, nil + } } func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) { @@ -61,7 +76,12 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Met err, responseText = StreamHandler(c, resp) usage = openai.ResponseText2Usage(responseText, meta.ActualModelName, meta.PromptTokens) } else { - err, usage = Handler(c, resp, meta.PromptTokens, meta.ActualModelName) + switch meta.Mode { + case relaymode.Embeddings: + err, usage = EmbeddingHandler(c, resp) + default: + err, usage = Handler(c, resp, meta.PromptTokens, meta.ActualModelName) + } } return } diff --git a/relay/adaptor/gemini/constants.go b/relay/adaptor/gemini/constants.go index 32e7c240..f65e6bfc 100644 --- a/relay/adaptor/gemini/constants.go +++ b/relay/adaptor/gemini/constants.go @@ -4,5 +4,5 @@ package gemini var ModelList = []string{ "gemini-pro", "gemini-1.0-pro-001", "gemini-1.5-pro", - "gemini-pro-vision", "gemini-1.0-pro-vision-001", + "gemini-pro-vision", "gemini-1.0-pro-vision-001", "embedding-001", "text-embedding-004", } diff --git a/relay/adaptor/gemini/main.go b/relay/adaptor/gemini/main.go index faccc4cb..534b2708 100644 --- a/relay/adaptor/gemini/main.go +++ b/relay/adaptor/gemini/main.go @@ -134,6 +134,29 @@ func ConvertRequest(textRequest model.GeneralOpenAIRequest) *ChatRequest { return &geminiRequest } +func ConvertEmbeddingRequest(request model.GeneralOpenAIRequest) *BatchEmbeddingRequest { + inputs := request.ParseInput() + requests := make([]EmbeddingRequest, len(inputs)) + model := fmt.Sprintf("models/%s", request.Model) + + for i, input := range inputs { + requests[i] = EmbeddingRequest{ + Model: model, + Content: ChatContent{ + Parts: []Part{ + { + Text: input, + }, + }, + }, + } + } + + return &BatchEmbeddingRequest{ + Requests: requests, + } +} + type ChatResponse struct { Candidates []ChatCandidate `json:"candidates"` PromptFeedback ChatPromptFeedback `json:"promptFeedback"` @@ -230,6 +253,23 @@ func streamResponseGeminiChat2OpenAI(geminiResponse *ChatResponse) *openai.ChatC return &response } +func embeddingResponseGemini2OpenAI(response *EmbeddingResponse) *openai.EmbeddingResponse { + openAIEmbeddingResponse := openai.EmbeddingResponse{ + Object: "list", + Data: make([]openai.EmbeddingResponseItem, 0, len(response.Embeddings)), + Model: "gemini-embedding", + Usage: model.Usage{TotalTokens: 0}, + } + for _, item := range response.Embeddings { + openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, openai.EmbeddingResponseItem{ + Object: `embedding`, + Index: 0, + Embedding: item.Values, + }) + } + return &openAIEmbeddingResponse +} + func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, string) { responseText := "" scanner := bufio.NewScanner(resp.Body) @@ -337,3 +377,39 @@ func Handler(c *gin.Context, resp *http.Response, promptTokens int, modelName st _, err = c.Writer.Write(jsonResponse) return nil, &usage } + +func EmbeddingHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) { + var geminiEmbeddingResponse EmbeddingResponse + responseBody, err := io.ReadAll(resp.Body) + if err != nil { + return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil + } + err = resp.Body.Close() + if err != nil { + return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil + } + err = json.Unmarshal(responseBody, &geminiEmbeddingResponse) + if err != nil { + return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil + } + if geminiEmbeddingResponse.Error != nil { + return &model.ErrorWithStatusCode{ + Error: model.Error{ + Message: geminiEmbeddingResponse.Error.Message, + Type: "gemini_error", + Param: "", + Code: geminiEmbeddingResponse.Error.Code, + }, + StatusCode: resp.StatusCode, + }, nil + } + fullTextResponse := embeddingResponseGemini2OpenAI(&geminiEmbeddingResponse) + jsonResponse, err := json.Marshal(fullTextResponse) + if err != nil { + return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil + } + c.Writer.Header().Set("Content-Type", "application/json") + c.Writer.WriteHeader(resp.StatusCode) + _, err = c.Writer.Write(jsonResponse) + return nil, &fullTextResponse.Usage +} diff --git a/relay/adaptor/gemini/model.go b/relay/adaptor/gemini/model.go index 47b74fbc..f7179ea4 100644 --- a/relay/adaptor/gemini/model.go +++ b/relay/adaptor/gemini/model.go @@ -7,6 +7,33 @@ type ChatRequest struct { Tools []ChatTools `json:"tools,omitempty"` } +type EmbeddingRequest struct { + Model string `json:"model"` + Content ChatContent `json:"content"` + TaskType string `json:"taskType,omitempty"` + Title string `json:"title,omitempty"` + OutputDimensionality int `json:"outputDimensionality,omitempty"` +} + +type BatchEmbeddingRequest struct { + Requests []EmbeddingRequest `json:"requests"` +} + +type EmbeddingData struct { + Values []float64 `json:"values"` +} + +type EmbeddingResponse struct { + Embeddings []EmbeddingData `json:"embeddings"` + Error *Error `json:"error,omitempty"` +} + +type Error struct { + Code int `json:"code,omitempty"` + Message string `json:"message,omitempty"` + Status string `json:"status,omitempty"` +} + type InlineData struct { MimeType string `json:"mimeType"` Data string `json:"data"` diff --git a/relay/adaptor/openai/token.go b/relay/adaptor/openai/token.go index bb9c38a9..ddbfad86 100644 --- a/relay/adaptor/openai/token.go +++ b/relay/adaptor/openai/token.go @@ -24,6 +24,10 @@ func InitTokenEncoders() { logger.FatalLog(fmt.Sprintf("failed to get gpt-3.5-turbo token encoder: %s", err.Error())) } defaultTokenEncoder = gpt35TokenEncoder + gpt4oTokenEncoder, err := tiktoken.EncodingForModel("gpt-4o") + if err != nil { + logger.FatalLog(fmt.Sprintf("failed to get gpt-4o token encoder: %s", err.Error())) + } gpt4TokenEncoder, err := tiktoken.EncodingForModel("gpt-4") if err != nil { logger.FatalLog(fmt.Sprintf("failed to get gpt-4 token encoder: %s", err.Error())) @@ -31,6 +35,8 @@ func InitTokenEncoders() { for model := range billingratio.ModelRatio { if strings.HasPrefix(model, "gpt-3.5") { tokenEncoderMap[model] = gpt35TokenEncoder + } else if strings.HasPrefix(model, "gpt-4o") { + tokenEncoderMap[model] = gpt4oTokenEncoder } else if strings.HasPrefix(model, "gpt-4") { tokenEncoderMap[model] = gpt4TokenEncoder } else { diff --git a/relay/billing/ratio/image.go b/relay/billing/ratio/image.go index 5a29cddc..ced0c667 100644 --- a/relay/billing/ratio/image.go +++ b/relay/billing/ratio/image.go @@ -49,3 +49,8 @@ var ImagePromptLengthLimitations = map[string]int{ "wanx-v1": 4000, "cogview-3": 833, } + +var ImageOriginModelName = map[string]string{ + "ali-stable-diffusion-xl": "stable-diffusion-xl", + "ali-stable-diffusion-v1.5": "stable-diffusion-v1.5", +} diff --git a/relay/client/init.go b/relay/client/init.go deleted file mode 100644 index 4b59cba7..00000000 --- a/relay/client/init.go +++ /dev/null @@ -1,24 +0,0 @@ -package client - -import ( - "github.com/songquanpeng/one-api/common/config" - "net/http" - "time" -) - -var HTTPClient *http.Client -var ImpatientHTTPClient *http.Client - -func init() { - if config.RelayTimeout == 0 { - HTTPClient = &http.Client{} - } else { - HTTPClient = &http.Client{ - Timeout: time.Duration(config.RelayTimeout) * time.Second, - } - } - - ImpatientHTTPClient = &http.Client{ - Timeout: 5 * time.Second, - } -} diff --git a/relay/controller/audio.go b/relay/controller/audio.go index 15e74290..8f9708d0 100644 --- a/relay/controller/audio.go +++ b/relay/controller/audio.go @@ -9,6 +9,7 @@ import ( "fmt" "github.com/gin-gonic/gin" "github.com/songquanpeng/one-api/common" + "github.com/songquanpeng/one-api/common/client" "github.com/songquanpeng/one-api/common/config" "github.com/songquanpeng/one-api/common/ctxkey" "github.com/songquanpeng/one-api/common/logger" @@ -17,7 +18,6 @@ import ( "github.com/songquanpeng/one-api/relay/billing" billingratio "github.com/songquanpeng/one-api/relay/billing/ratio" "github.com/songquanpeng/one-api/relay/channeltype" - "github.com/songquanpeng/one-api/relay/client" "github.com/songquanpeng/one-api/relay/meta" relaymodel "github.com/songquanpeng/one-api/relay/model" "github.com/songquanpeng/one-api/relay/relaymode" diff --git a/relay/controller/image.go b/relay/controller/image.go index 6620bef5..691c7c0e 100644 --- a/relay/controller/image.go +++ b/relay/controller/image.go @@ -55,6 +55,11 @@ func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus return openai.ErrorWrapper(err, "get_image_cost_ratio_failed", http.StatusInternalServerError) } + imageModel := imageRequest.Model + // Convert the original image model + imageRequest.Model, _ = getMappedModelName(imageRequest.Model, billingratio.ImageOriginModelName) + c.Set("response_format", imageRequest.ResponseFormat) + var requestBody io.Reader if isModelMapped || meta.ChannelType == channeltype.Azure { // make Azure channel request body jsonStr, err := json.Marshal(imageRequest) @@ -89,7 +94,7 @@ func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus requestBody = bytes.NewBuffer(jsonStr) } - modelRatio := billingratio.GetModelRatio(imageRequest.Model) + modelRatio := billingratio.GetModelRatio(imageModel) groupRatio := billingratio.GetGroupRatio(meta.Group) ratio := modelRatio * groupRatio userQuota, err := model.CacheGetUserQuota(ctx, meta.UserId) diff --git a/web/berry/package.json b/web/berry/package.json index 2edb2355..f8265ef7 100644 --- a/web/berry/package.json +++ b/web/berry/package.json @@ -26,7 +26,7 @@ "notistack": "^3.0.1", "prop-types": "^15.8.1", "react": "^18.2.0", - "react-apexcharts": "^1.4.0", + "react-apexcharts": "1.4.0", "react-device-detect": "^2.2.2", "react-dom": "^18.2.0", "react-perfect-scrollbar": "^1.5.8", diff --git a/web/berry/src/constants/ChannelConstants.js b/web/berry/src/constants/ChannelConstants.js index e6b0aed5..589ef1fb 100644 --- a/web/berry/src/constants/ChannelConstants.js +++ b/web/berry/src/constants/ChannelConstants.js @@ -11,12 +11,18 @@ export const CHANNEL_OPTIONS = { value: 14, color: 'primary' }, - // 33: { - // key: 33, - // text: 'AWS Claude', - // value: 33, - // color: 'primary' - // }, + 33: { + key: 33, + text: 'AWS Claude', + value: 33, + color: 'primary' + }, + 37: { + key: 37, + text: 'Cloudflare', + value: 37, + color: 'success' + }, 3: { key: 3, text: 'Azure OpenAI', @@ -119,12 +125,12 @@ export const CHANNEL_OPTIONS = { value: 32, color: 'primary' }, - // 34: { - // key: 34, - // text: 'Coze', - // value: 34, - // color: 'primary' - // }, + 34: { + key: 34, + text: 'Coze', + value: 34, + color: 'primary' + }, 35: { key: 35, text: 'Cohere', diff --git a/web/berry/src/constants/SnackbarConstants.js b/web/berry/src/constants/SnackbarConstants.js index 19523da1..05f79231 100644 --- a/web/berry/src/constants/SnackbarConstants.js +++ b/web/berry/src/constants/SnackbarConstants.js @@ -1,24 +1,56 @@ +import { closeSnackbar } from 'notistack'; +import { IconX } from '@tabler/icons-react'; +import { IconButton } from '@mui/material'; +const action = (snackbarId) => ( + <> + { + closeSnackbar(snackbarId); + }} + > + + + +); + export const snackbarConstants = { Common: { ERROR: { variant: 'error', - autoHideDuration: 5000 + autoHideDuration: 5000, + preventDuplicate: true, + action }, WARNING: { variant: 'warning', - autoHideDuration: 10000 + autoHideDuration: 10000, + preventDuplicate: true, + action }, SUCCESS: { variant: 'success', - autoHideDuration: 1500 + autoHideDuration: 1500, + preventDuplicate: true, + action }, INFO: { variant: 'info', - autoHideDuration: 3000 + autoHideDuration: 3000, + preventDuplicate: true, + action }, NOTICE: { variant: 'info', - autoHideDuration: 7000 + autoHideDuration: 20000, + preventDuplicate: true, + action + }, + COPY: { + variant: 'copy', + persist: true, + preventDuplicate: true, + allowDownload: true, + action } }, Mobile: { diff --git a/web/berry/src/utils/common.js b/web/berry/src/utils/common.js index 947df3bf..d74d032e 100644 --- a/web/berry/src/utils/common.js +++ b/web/berry/src/utils/common.js @@ -193,3 +193,40 @@ export function removeTrailingSlash(url) { return url; } } + +let channelModels = undefined; +export async function loadChannelModels() { + const res = await API.get('/api/models'); + const { success, data } = res.data; + if (!success) { + return; + } + channelModels = data; + localStorage.setItem('channel_models', JSON.stringify(data)); +} + +export function getChannelModels(type) { + if (channelModels !== undefined && type in channelModels) { + return channelModels[type]; + } + let models = localStorage.getItem('channel_models'); + if (!models) { + return []; + } + channelModels = JSON.parse(models); + if (type in channelModels) { + return channelModels[type]; + } + return []; +} + +export function copy(text, name = '') { + try { + navigator.clipboard.writeText(text); + } catch (error) { + text = `复制${name}失败,请手动复制:

${text}`; + enqueueSnackbar(, getSnackbarOptions('COPY')); + return; + } + showSuccess(`复制${name}成功!`); +} diff --git a/web/berry/src/views/Authentication/AuthForms/ResetPasswordForm.js b/web/berry/src/views/Authentication/AuthForms/ResetPasswordForm.js index eaa8dc95..a9f0f9e3 100644 --- a/web/berry/src/views/Authentication/AuthForms/ResetPasswordForm.js +++ b/web/berry/src/views/Authentication/AuthForms/ResetPasswordForm.js @@ -1,22 +1,22 @@ -import { useState, useEffect } from "react"; -import { useSearchParams } from "react-router-dom"; +import { useState, useEffect } from 'react'; +import { useSearchParams } from 'react-router-dom'; // material-ui -import { Button, Stack, Typography, Alert } from "@mui/material"; +import { Button, Stack, Typography, Alert } from '@mui/material'; // assets -import { showError, showInfo } from "utils/common"; -import { API } from "utils/api"; +import { showError, copy } from 'utils/common'; +import { API } from 'utils/api'; // ===========================|| FIREBASE - REGISTER ||=========================== // const ResetPasswordForm = () => { const [searchParams] = useSearchParams(); const [inputs, setInputs] = useState({ - email: "", - token: "", + email: '', + token: '' }); - const [newPassword, setNewPassword] = useState(""); + const [newPassword, setNewPassword] = useState(''); const submit = async () => { const res = await API.post(`/api/user/reset`, inputs); @@ -24,31 +24,25 @@ const ResetPasswordForm = () => { if (success) { let password = res.data.data; setNewPassword(password); - navigator.clipboard.writeText(password); - showInfo(`新密码已复制到剪贴板:${password}`); + copy(password, '新密码'); } else { showError(message); } }; useEffect(() => { - let email = searchParams.get("email"); - let token = searchParams.get("token"); + let email = searchParams.get('email'); + let token = searchParams.get('token'); setInputs({ token, - email, + email }); }, []); return ( - + {!inputs.email || !inputs.token ? ( - + 无效的链接 ) : newPassword ? ( @@ -57,14 +51,7 @@ const ResetPasswordForm = () => { 请登录后及时修改密码 ) : ( - )} diff --git a/web/berry/src/views/Channel/component/EditModal.js b/web/berry/src/views/Channel/component/EditModal.js index 03b4df57..4f7f216d 100644 --- a/web/berry/src/views/Channel/component/EditModal.js +++ b/web/berry/src/views/Channel/component/EditModal.js @@ -1,9 +1,9 @@ -import PropTypes from "prop-types"; -import { useState, useEffect } from "react"; -import { CHANNEL_OPTIONS } from "constants/ChannelConstants"; -import { useTheme } from "@mui/material/styles"; -import { API } from "utils/api"; -import { showError, showSuccess } from "utils/common"; +import PropTypes from 'prop-types'; +import { useState, useEffect } from 'react'; +import { CHANNEL_OPTIONS } from 'constants/ChannelConstants'; +import { useTheme } from '@mui/material/styles'; +import { API } from 'utils/api'; +import { showError, showSuccess, getChannelModels } from 'utils/common'; import { Dialog, DialogTitle, @@ -22,15 +22,15 @@ import { Autocomplete, FormHelperText, Switch, - Checkbox, -} from "@mui/material"; + Checkbox +} from '@mui/material'; -import { Formik } from "formik"; -import * as Yup from "yup"; -import { defaultConfig, typeConfig } from "../type/Config"; //typeConfig -import { createFilterOptions } from "@mui/material/Autocomplete"; -import CheckBoxOutlineBlankIcon from "@mui/icons-material/CheckBoxOutlineBlank"; -import CheckBoxIcon from "@mui/icons-material/CheckBox"; +import { Formik } from 'formik'; +import * as Yup from 'yup'; +import { defaultConfig, typeConfig } from '../type/Config'; //typeConfig +import { createFilterOptions } from '@mui/material/Autocomplete'; +import CheckBoxOutlineBlankIcon from '@mui/icons-material/CheckBoxOutlineBlank'; +import CheckBoxIcon from '@mui/icons-material/CheckBox'; const icon = ; const checkedIcon = ; @@ -38,38 +38,34 @@ const checkedIcon = ; const filter = createFilterOptions(); const validationSchema = Yup.object().shape({ is_edit: Yup.boolean(), - name: Yup.string().required("名称 不能为空"), - type: Yup.number().required("渠道 不能为空"), - key: Yup.string().when("is_edit", { - is: false, - then: Yup.string().required("密钥 不能为空"), + name: Yup.string().required('名称 不能为空'), + type: Yup.number().required('渠道 不能为空'), + key: Yup.string().when(['is_edit', 'type'], { + is: (is_edit, type) => !is_edit && type !== 33, + then: Yup.string().required('密钥 不能为空') }), other: Yup.string(), - models: Yup.array().min(1, "模型 不能为空"), - groups: Yup.array().min(1, "用户组 不能为空"), - base_url: Yup.string().when("type", { + models: Yup.array().min(1, '模型 不能为空'), + groups: Yup.array().min(1, '用户组 不能为空'), + base_url: Yup.string().when('type', { is: (value) => [3, 8].includes(value), - then: Yup.string().required("渠道API地址 不能为空"), // base_url 是必需的 - otherwise: Yup.string(), // 在其他情况下,base_url 可以是任意字符串 + then: Yup.string().required('渠道API地址 不能为空'), // base_url 是必需的 + otherwise: Yup.string() // 在其他情况下,base_url 可以是任意字符串 }), - model_mapping: Yup.string().test( - "is-json", - "必须是有效的JSON字符串", - function (value) { - try { - if (value === "" || value === null || value === undefined) { - return true; - } - const parsedValue = JSON.parse(value); - if (typeof parsedValue === "object") { - return true; - } - } catch (e) { - return false; + model_mapping: Yup.string().test('is-json', '必须是有效的JSON字符串', function (value) { + try { + if (value === '' || value === null || value === undefined) { + return true; } + const parsedValue = JSON.parse(value); + if (typeof parsedValue === 'object') { + return true; + } + } catch (e) { return false; } - ), + return false; + }) }); const EditModal = ({ open, channelId, onCancel, onOk }) => { @@ -81,12 +77,13 @@ const EditModal = ({ open, channelId, onCancel, onOk }) => { const [groupOptions, setGroupOptions] = useState([]); const [modelOptions, setModelOptions] = useState([]); const [batchAdd, setBatchAdd] = useState(false); + const [basicModels, setBasicModels] = useState([]); const initChannel = (typeValue) => { if (typeConfig[typeValue]?.inputLabel) { setInputLabel({ ...defaultConfig.inputLabel, - ...typeConfig[typeValue].inputLabel, + ...typeConfig[typeValue].inputLabel }); } else { setInputLabel(defaultConfig.inputLabel); @@ -95,7 +92,7 @@ const EditModal = ({ open, channelId, onCancel, onOk }) => { if (typeConfig[typeValue]?.prompt) { setInputPrompt({ ...defaultConfig.prompt, - ...typeConfig[typeValue].prompt, + ...typeConfig[typeValue].prompt }); } else { setInputPrompt(defaultConfig.prompt); @@ -104,40 +101,14 @@ const EditModal = ({ open, channelId, onCancel, onOk }) => { return typeConfig[typeValue]?.input; }; const handleTypeChange = (setFieldValue, typeValue, values) => { - const newInput = initChannel(typeValue); - - if (newInput) { - Object.keys(newInput).forEach((key) => { - if ( - (!Array.isArray(values[key]) && - values[key] !== null && - values[key] !== undefined && - values[key] !== "") || - (Array.isArray(values[key]) && values[key].length > 0) - ) { - return; - } - - if (key === "models") { - setFieldValue(key, initialModel(newInput[key])); - return; - } - setFieldValue(key, newInput[key]); - }); + initChannel(typeValue); + let localModels = getChannelModels(typeValue); + setBasicModels(localModels); + if (localModels.length > 0 && Array.isArray(values['models']) && values['models'].length == 0) { + setFieldValue('models', initialModel(localModels)); } - }; - const basicModels = (channelType) => { - let modelGroup = - typeConfig[channelType]?.modelGroup || defaultConfig.modelGroup; - // 循环 modelOptions,找到 modelGroup 对应的模型 - let modelList = []; - modelOptions.forEach((model) => { - if (model.group === modelGroup) { - modelList.push(model); - } - }); - return modelList; + setFieldValue('config', {}); }; const fetchGroups = async () => { @@ -155,7 +126,7 @@ const EditModal = ({ open, channelId, onCancel, onOk }) => { const { data } = res.data; data.forEach((item) => { if (!item.owned_by) { - item.owned_by = "未知"; + item.owned_by = '未知'; } }); // 先对data排序 @@ -171,7 +142,7 @@ const EditModal = ({ open, channelId, onCancel, onOk }) => { data.map((model) => { return { id: model.id, - group: model.owned_by, + group: model.owned_by }; }) ); @@ -182,33 +153,41 @@ const EditModal = ({ open, channelId, onCancel, onOk }) => { const submit = async (values, { setErrors, setStatus, setSubmitting }) => { setSubmitting(true); - if (values.base_url && values.base_url.endsWith("/")) { + if (values.base_url && values.base_url.endsWith('/')) { values.base_url = values.base_url.slice(0, values.base_url.length - 1); } - if (values.type === 3 && values.other === "") { - values.other = "2023-09-01-preview"; + if (values.type === 3 && values.other === '') { + values.other = '2023-09-01-preview'; } - if (values.type === 18 && values.other === "") { - values.other = "v2.1"; + if (values.type === 18 && values.other === '') { + values.other = 'v2.1'; } + if (values.key === '') { + if (values.config.ak !== '' && values.config.sk !== '' && values.config.region !== '') { + values.key = `${values.config.ak}|${values.config.sk}|${values.config.region}`; + } + } + let res; - const modelsStr = values.models.map((model) => model.id).join(","); - values.group = values.groups.join(","); + const modelsStr = values.models.map((model) => model.id).join(','); + const configStr = JSON.stringify(values.config); + values.group = values.groups.join(','); if (channelId) { res = await API.put(`/api/channel/`, { ...values, id: parseInt(channelId), models: modelsStr, + config: configStr }); } else { - res = await API.post(`/api/channel/`, { ...values, models: modelsStr }); + res = await API.post(`/api/channel/`, { ...values, models: modelsStr, config: configStr }); } const { success, message } = res.data; if (success) { if (channelId) { - showSuccess("渠道更新成功!"); + showSuccess('渠道更新成功!'); } else { - showSuccess("渠道创建成功!"); + showSuccess('渠道创建成功!'); } setSubmitting(false); setStatus({ success: true }); @@ -226,15 +205,15 @@ const EditModal = ({ open, channelId, onCancel, onOk }) => { } // 如果 channelModel 是一个字符串 - if (typeof channelModel === "string") { - channelModel = channelModel.split(","); + if (typeof channelModel === 'string') { + channelModel = channelModel.split(','); } let modelList = channelModel.map((model) => { const modelOption = modelOptions.find((option) => option.id === model); if (modelOption) { return modelOption; } - return { id: model, group: "自定义:点击或回车输入" }; + return { id: model, group: '自定义:点击或回车输入' }; }); return modelList; } @@ -243,24 +222,24 @@ const EditModal = ({ open, channelId, onCancel, onOk }) => { let res = await API.get(`/api/channel/${channelId}`); const { success, message, data } = res.data; if (success) { - if (data.models === "") { + if (data.models === '') { data.models = []; } else { data.models = initialModel(data.models); } - if (data.group === "") { + if (data.group === '') { data.groups = []; } else { - data.groups = data.group.split(","); + data.groups = data.group.split(','); } - if (data.model_mapping !== "") { - data.model_mapping = JSON.stringify( - JSON.parse(data.model_mapping), - null, - 2 - ); + if (data.model_mapping !== '') { + data.model_mapping = JSON.stringify(JSON.parse(data.model_mapping), null, 2); } - data.base_url = data.base_url ?? ""; + if (data.config !== '') { + data.config = JSON.parse(data.config); + } + + data.base_url = data.base_url ?? ''; data.is_edit = true; initChannel(data.type); setInitialInput(data); @@ -286,45 +265,25 @@ const EditModal = ({ open, channelId, onCancel, onOk }) => { }, [channelId]); return ( - + - {channelId ? "编辑渠道" : "新建渠道"} + {channelId ? '编辑渠道' : '新建渠道'} - - {({ - errors, - handleBlur, - handleChange, - handleSubmit, - isSubmitting, - touched, - values, - setFieldValue, - }) => ( + + {({ errors, handleBlur, handleChange, handleSubmit, isSubmitting, touched, values, setFieldValue }) => (
- - - {inputLabel.type} - + + {inputLabel.type}