diff --git a/.gitignore b/.gitignore index 974fcf63..2a8ae16e 100644 --- a/.gitignore +++ b/.gitignore @@ -8,3 +8,4 @@ build logs data /web/node_modules +cmd.md \ No newline at end of file diff --git a/README.md b/README.md index 2dcdbd4f..d5c939be 100644 --- a/README.md +++ b/README.md @@ -81,11 +81,12 @@ _✨ 通过标准的 OpenAI API 格式访问所有的大模型,开箱即用 + [x] [Groq](https://wow.groq.com/) + [x] [Ollama](https://github.com/ollama/ollama) + [x] [零一万物](https://platform.lingyiwanwu.com/) + + [x] [阶跃星辰](https://platform.stepfun.com/) 2. 支持配置镜像以及众多[第三方代理服务](https://iamazing.cn/page/openai-api-third-party-services)。 3. 支持通过**负载均衡**的方式访问多个渠道。 4. 支持 **stream 模式**,可以通过流式传输实现打字机效果。 5. 支持**多机部署**,[详见此处](#多机部署)。 -6. 支持**令牌管理**,设置令牌的过期时间和额度。 +6. 支持**令牌管理**,设置令牌的过期时间、额度、允许的 IP 范围以及允许的模型访问。 7. 支持**兑换码管理**,支持批量生成和导出兑换码,可使用兑换码为账户进行充值。 8. 支持**渠道管理**,批量创建渠道。 9. 支持**用户分组**以及**渠道分组**,支持为不同分组设置不同的倍率。 @@ -101,10 +102,11 @@ _✨ 通过标准的 OpenAI API 格式访问所有的大模型,开箱即用 19. 支持丰富的**自定义**设置, 1. 支持自定义系统名称,logo 以及页脚。 2. 支持自定义首页和关于页面,可以选择使用 HTML & Markdown 代码进行自定义,或者使用一个单独的网页通过 iframe 嵌入。 -20. 支持通过系统访问令牌访问管理 API(bearer token,用以替代 cookie,你可以自行抓包来查看 API 的用法)。 +20. 支持通过系统访问令牌调用管理 API,进而**在无需二开的情况下扩展和自定义** One API 的功能,详情请参考此处 [API 文档](./docs/API.md)。。 21. 支持 Cloudflare Turnstile 用户校验。 22. 支持用户管理,支持**多种用户登录注册方式**: + 邮箱登录注册(支持注册邮箱白名单)以及通过邮箱进行密码重置。 + + 支持使用飞书进行授权登录。 + [GitHub 开放授权](https://github.com/settings/applications/new)。 + 微信公众号授权(需要额外部署 [WeChat Server](https://github.com/songquanpeng/wechat-server))。 23. 支持主题切换,设置环境变量 `THEME` 即可,默认为 `default`,欢迎 PR 更多主题,具体参考[此处](./web/README.md)。 diff --git a/common/config/config.go b/common/config/config.go index 3524183a..9fd7cba0 100644 --- a/common/config/config.go +++ b/common/config/config.go @@ -66,6 +66,9 @@ var SMTPToken = "" var GitHubClientId = "" var GitHubClientSecret = "" +var LarkClientId = "" +var LarkClientSecret = "" + var WeChatServerAddress = "" var WeChatServerToken = "" var WeChatAccountQRCodeImageURL = "" diff --git a/common/constants.go b/common/constants.go index 849bdce7..04a56649 100644 --- a/common/constants.go +++ b/common/constants.go @@ -71,6 +71,7 @@ const ( ChannelTypeGroq ChannelTypeOllama ChannelTypeLingYiWanWu + ChannelTypeStepFun ChannelTypeDummy ) @@ -108,6 +109,7 @@ var ChannelBaseURLs = []string{ "https://api.groq.com/openai", // 29 "http://localhost:11434", // 30 "https://api.lingyiwanwu.com", // 31 + "https://api.stepfun.com", // 32 } const ( diff --git a/common/conv/any.go b/common/conv/any.go new file mode 100644 index 00000000..467e8bb7 --- /dev/null +++ b/common/conv/any.go @@ -0,0 +1,6 @@ +package conv + +func AsString(v any) string { + str, _ := v.(string) + return str +} diff --git a/common/model-ratio.go b/common/model-ratio.go index b96f8d21..c8a2b5b8 100644 --- a/common/model-ratio.go +++ b/common/model-ratio.go @@ -72,14 +72,22 @@ var ModelRatio = map[string]float64{ "claude-3-sonnet-20240229": 3.0 / 1000 * USD, "claude-3-opus-20240229": 15.0 / 1000 * USD, // https://cloud.baidu.com/doc/WENXINWORKSHOP/s/hlrk4akp7 - "ERNIE-Bot": 0.8572, // ¥0.012 / 1k tokens - "ERNIE-Bot-turbo": 0.5715, // ¥0.008 / 1k tokens - "ERNIE-Bot-4": 0.12 * RMB, // ¥0.12 / 1k tokens - "ERNIE-Bot-8k": 0.024 * RMB, - "Embedding-V1": 0.1429, // ¥0.002 / 1k tokens - "bge-large-zh": 0.002 * RMB, - "bge-large-en": 0.002 * RMB, - "bge-large-8k": 0.002 * RMB, + "ERNIE-4.0-8K": 0.120 * RMB, + "ERNIE-3.5-8K": 0.012 * RMB, + "ERNIE-3.5-8K-0205": 0.024 * RMB, + "ERNIE-3.5-8K-1222": 0.012 * RMB, + "ERNIE-Bot-8K": 0.024 * RMB, + "ERNIE-3.5-4K-0205": 0.012 * RMB, + "ERNIE-Speed-8K": 0.004 * RMB, + "ERNIE-Speed-128K": 0.004 * RMB, + "ERNIE-Lite-8K-0922": 0.008 * RMB, + "ERNIE-Lite-8K-0308": 0.003 * RMB, + "ERNIE-Tiny-8K": 0.001 * RMB, + "BLOOMZ-7B": 0.004 * RMB, + "Embedding-V1": 0.002 * RMB, + "bge-large-zh": 0.002 * RMB, + "bge-large-en": 0.002 * RMB, + "tao-8k": 0.002 * RMB, // https://ai.google.dev/pricing "PaLM-2": 1, "gemini-pro": 1, // $0.00025 / 1k characters -> $0.001 / 1k tokens @@ -88,13 +96,14 @@ var ModelRatio = map[string]float64{ "gemini-1.0-pro-001": 1, "gemini-1.5-pro": 1, // https://open.bigmodel.cn/pricing - "glm-4": 0.1 * RMB, - "glm-4v": 0.1 * RMB, - "glm-3-turbo": 0.005 * RMB, - "chatglm_turbo": 0.3572, // ¥0.005 / 1k tokens - "chatglm_pro": 0.7143, // ¥0.01 / 1k tokens - "chatglm_std": 0.3572, // ¥0.005 / 1k tokens - "chatglm_lite": 0.1429, // ¥0.002 / 1k tokens + "glm-4": 0.1 * RMB, + "glm-4v": 0.1 * RMB, + "glm-3-turbo": 0.005 * RMB, + "embedding-2": 0.0005 * RMB, + "chatglm_turbo": 0.3572, // ¥0.005 / 1k tokens + "chatglm_pro": 0.7143, // ¥0.01 / 1k tokens + "chatglm_std": 0.3572, // ¥0.005 / 1k tokens + "chatglm_lite": 0.1429, // ¥0.002 / 1k tokens // https://help.aliyun.com/zh/dashscope/developer-reference/tongyi-thousand-questions-metering-and-billing "qwen-turbo": 0.5715, // ¥0.008 / 1k tokens "qwen-plus": 1.4286, // ¥0.02 / 1k tokens diff --git a/common/network/ip.go b/common/network/ip.go new file mode 100644 index 00000000..0fbe5e6f --- /dev/null +++ b/common/network/ip.go @@ -0,0 +1,52 @@ +package network + +import ( + "context" + "fmt" + "github.com/songquanpeng/one-api/common/logger" + "net" + "strings" +) + +func splitSubnets(subnets string) []string { + res := strings.Split(subnets, ",") + for i := 0; i < len(res); i++ { + res[i] = strings.TrimSpace(res[i]) + } + return res +} + +func isValidSubnet(subnet string) error { + _, _, err := net.ParseCIDR(subnet) + if err != nil { + return fmt.Errorf("failed to parse subnet: %w", err) + } + return nil +} + +func isIpInSubnet(ctx context.Context, ip string, subnet string) bool { + _, ipNet, err := net.ParseCIDR(subnet) + if err != nil { + logger.Errorf(ctx, "failed to parse subnet: %s", err.Error()) + return false + } + return ipNet.Contains(net.ParseIP(ip)) +} + +func IsValidSubnets(subnets string) error { + for _, subnet := range splitSubnets(subnets) { + if err := isValidSubnet(subnet); err != nil { + return err + } + } + return nil +} + +func IsIpInSubnets(ctx context.Context, ip string, subnets string) bool { + for _, subnet := range splitSubnets(subnets) { + if isIpInSubnet(ctx, ip, subnet) { + return true + } + } + return false +} diff --git a/common/network/ip_test.go b/common/network/ip_test.go new file mode 100644 index 00000000..6c593458 --- /dev/null +++ b/common/network/ip_test.go @@ -0,0 +1,19 @@ +package network + +import ( + "context" + "testing" + + . "github.com/smartystreets/goconvey/convey" +) + +func TestIsIpInSubnet(t *testing.T) { + ctx := context.Background() + ip1 := "192.168.0.5" + ip2 := "125.216.250.89" + subnet := "192.168.0.0/24" + Convey("TestIsIpInSubnet", t, func() { + So(isIpInSubnet(ctx, ip1, subnet), ShouldBeTrue) + So(isIpInSubnet(ctx, ip2, subnet), ShouldBeFalse) + }) +} diff --git a/controller/github.go b/controller/auth/github.go similarity index 98% rename from controller/github.go rename to controller/auth/github.go index 7d7fa106..cf073133 100644 --- a/controller/github.go +++ b/controller/auth/github.go @@ -1,4 +1,4 @@ -package controller +package auth import ( "bytes" @@ -11,6 +11,7 @@ import ( "github.com/songquanpeng/one-api/common/config" "github.com/songquanpeng/one-api/common/helper" "github.com/songquanpeng/one-api/common/logger" + "github.com/songquanpeng/one-api/controller" "github.com/songquanpeng/one-api/model" "net/http" "strconv" @@ -159,7 +160,7 @@ func GitHubOAuth(c *gin.Context) { }) return } - setupLogin(&user, c) + controller.SetupLogin(&user, c) } func GitHubBind(c *gin.Context) { diff --git a/controller/auth/lark.go b/controller/auth/lark.go new file mode 100644 index 00000000..21446d46 --- /dev/null +++ b/controller/auth/lark.go @@ -0,0 +1,201 @@ +package auth + +import ( + "bytes" + "encoding/json" + "errors" + "fmt" + "github.com/gin-contrib/sessions" + "github.com/gin-gonic/gin" + "github.com/songquanpeng/one-api/common" + "github.com/songquanpeng/one-api/common/config" + "github.com/songquanpeng/one-api/common/logger" + "github.com/songquanpeng/one-api/controller" + "github.com/songquanpeng/one-api/model" + "net/http" + "strconv" + "time" +) + +type LarkOAuthResponse struct { + AccessToken string `json:"access_token"` +} + +type LarkUser struct { + Name string `json:"name"` + OpenID string `json:"open_id"` +} + +func getLarkUserInfoByCode(code string) (*LarkUser, error) { + if code == "" { + return nil, errors.New("无效的参数") + } + values := map[string]string{ + "client_id": config.LarkClientId, + "client_secret": config.LarkClientSecret, + "code": code, + "grant_type": "authorization_code", + "redirect_uri": fmt.Sprintf("%s/oauth/lark", config.ServerAddress), + } + jsonData, err := json.Marshal(values) + if err != nil { + return nil, err + } + req, err := http.NewRequest("POST", "https://passport.feishu.cn/suite/passport/oauth/token", bytes.NewBuffer(jsonData)) + if err != nil { + return nil, err + } + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Accept", "application/json") + client := http.Client{ + Timeout: 5 * time.Second, + } + res, err := client.Do(req) + if err != nil { + logger.SysLog(err.Error()) + return nil, errors.New("无法连接至飞书服务器,请稍后重试!") + } + defer res.Body.Close() + var oAuthResponse LarkOAuthResponse + err = json.NewDecoder(res.Body).Decode(&oAuthResponse) + if err != nil { + return nil, err + } + req, err = http.NewRequest("GET", "https://passport.feishu.cn/suite/passport/oauth/userinfo", nil) + if err != nil { + return nil, err + } + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", oAuthResponse.AccessToken)) + res2, err := client.Do(req) + if err != nil { + logger.SysLog(err.Error()) + return nil, errors.New("无法连接至飞书服务器,请稍后重试!") + } + var larkUser LarkUser + err = json.NewDecoder(res2.Body).Decode(&larkUser) + if err != nil { + return nil, err + } + return &larkUser, nil +} + +func LarkOAuth(c *gin.Context) { + session := sessions.Default(c) + state := c.Query("state") + if state == "" || session.Get("oauth_state") == nil || state != session.Get("oauth_state").(string) { + c.JSON(http.StatusForbidden, gin.H{ + "success": false, + "message": "state is empty or not same", + }) + return + } + username := session.Get("username") + if username != nil { + LarkBind(c) + return + } + code := c.Query("code") + larkUser, err := getLarkUserInfoByCode(code) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + user := model.User{ + LarkId: larkUser.OpenID, + } + if model.IsLarkIdAlreadyTaken(user.LarkId) { + err := user.FillUserByLarkId() + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + } else { + if config.RegisterEnabled { + user.Username = "lark_" + strconv.Itoa(model.GetMaxUserId()+1) + if larkUser.Name != "" { + user.DisplayName = larkUser.Name + } else { + user.DisplayName = "Lark User" + } + user.Role = common.RoleCommonUser + user.Status = common.UserStatusEnabled + + if err := user.Insert(0); err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + } else { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "管理员关闭了新用户注册", + }) + return + } + } + + if user.Status != common.UserStatusEnabled { + c.JSON(http.StatusOK, gin.H{ + "message": "用户已被封禁", + "success": false, + }) + return + } + controller.SetupLogin(&user, c) +} + +func LarkBind(c *gin.Context) { + code := c.Query("code") + larkUser, err := getLarkUserInfoByCode(code) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + user := model.User{ + LarkId: larkUser.OpenID, + } + if model.IsLarkIdAlreadyTaken(user.LarkId) { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "该飞书账户已被绑定", + }) + return + } + session := sessions.Default(c) + id := session.Get("id") + // id := c.GetInt("id") // critical bug! + user.Id = id.(int) + err = user.FillUserById() + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + user.LarkId = larkUser.OpenID + err = user.Update(false) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "bind", + }) + return +} diff --git a/controller/wechat.go b/controller/auth/wechat.go similarity index 97% rename from controller/wechat.go rename to controller/auth/wechat.go index 74be5604..80552c9a 100644 --- a/controller/wechat.go +++ b/controller/auth/wechat.go @@ -1,4 +1,4 @@ -package controller +package auth import ( "encoding/json" @@ -7,6 +7,7 @@ import ( "github.com/gin-gonic/gin" "github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common/config" + "github.com/songquanpeng/one-api/controller" "github.com/songquanpeng/one-api/model" "net/http" "strconv" @@ -109,7 +110,7 @@ func WeChatAuth(c *gin.Context) { }) return } - setupLogin(&user, c) + controller.SetupLogin(&user, c) } func WeChatBind(c *gin.Context) { diff --git a/controller/misc.go b/controller/misc.go index f27fdb12..2928b8fb 100644 --- a/controller/misc.go +++ b/controller/misc.go @@ -23,6 +23,7 @@ func GetStatus(c *gin.Context) { "email_verification": config.EmailVerificationEnabled, "github_oauth": config.GitHubOAuthEnabled, "github_client_id": config.GitHubClientId, + "lark_client_id": config.LarkClientId, "system_name": config.SystemName, "logo": config.Logo, "footer_html": config.Footer, diff --git a/controller/model.go b/controller/model.go index 4c5476b4..43e73c6c 100644 --- a/controller/model.go +++ b/controller/model.go @@ -4,12 +4,14 @@ import ( "fmt" "github.com/gin-gonic/gin" "github.com/songquanpeng/one-api/common" + "github.com/songquanpeng/one-api/model" "github.com/songquanpeng/one-api/relay/channel/openai" "github.com/songquanpeng/one-api/relay/constant" "github.com/songquanpeng/one-api/relay/helper" relaymodel "github.com/songquanpeng/one-api/relay/model" "github.com/songquanpeng/one-api/relay/util" "net/http" + "strings" ) // https://platform.openai.com/docs/api-reference/models/list @@ -120,9 +122,41 @@ func DashboardListModels(c *gin.Context) { } func ListModels(c *gin.Context) { + ctx := c.Request.Context() + var availableModels []string + if c.GetString("available_models") != "" { + availableModels = strings.Split(c.GetString("available_models"), ",") + } else { + userId := c.GetInt("id") + userGroup, _ := model.CacheGetUserGroup(userId) + availableModels, _ = model.CacheGetGroupModels(ctx, userGroup) + } + modelSet := make(map[string]bool) + for _, availableModel := range availableModels { + modelSet[availableModel] = true + } + availableOpenAIModels := make([]OpenAIModels, 0) + for _, model := range openAIModels { + if _, ok := modelSet[model.Id]; ok { + modelSet[model.Id] = false + availableOpenAIModels = append(availableOpenAIModels, model) + } + } + for modelName, ok := range modelSet { + if ok { + availableOpenAIModels = append(availableOpenAIModels, OpenAIModels{ + Id: modelName, + Object: "model", + Created: 1626777600, + OwnedBy: "custom", + Root: modelName, + Parent: nil, + }) + } + } c.JSON(200, gin.H{ "object": "list", - "data": openAIModels, + "data": availableOpenAIModels, }) } @@ -142,3 +176,30 @@ func RetrieveModel(c *gin.Context) { }) } } + +func GetUserAvailableModels(c *gin.Context) { + ctx := c.Request.Context() + id := c.GetInt("id") + userGroup, err := model.CacheGetUserGroup(id) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + models, err := model.CacheGetGroupModels(ctx, userGroup) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "", + "data": models, + }) + return +} diff --git a/controller/token.go b/controller/token.go index 949931da..7d20371c 100644 --- a/controller/token.go +++ b/controller/token.go @@ -1,10 +1,12 @@ package controller import ( + "fmt" "github.com/gin-gonic/gin" "github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common/config" "github.com/songquanpeng/one-api/common/helper" + "github.com/songquanpeng/one-api/common/network" "github.com/songquanpeng/one-api/model" "net/http" "strconv" @@ -104,6 +106,19 @@ func GetTokenStatus(c *gin.Context) { }) } +func validateToken(c *gin.Context, token model.Token) error { + if len(token.Name) > 30 { + return fmt.Errorf("令牌名称过长") + } + if token.Subnet != nil && *token.Subnet != "" { + err := network.IsValidSubnets(*token.Subnet) + if err != nil { + return fmt.Errorf("无效的网段:%s", err.Error()) + } + } + return nil +} + func AddToken(c *gin.Context) { token := model.Token{} err := c.ShouldBindJSON(&token) @@ -114,13 +129,15 @@ func AddToken(c *gin.Context) { }) return } - if len(token.Name) > 30 { + err = validateToken(c, token) + if err != nil { c.JSON(http.StatusOK, gin.H{ "success": false, - "message": "令牌名称过长", + "message": fmt.Sprintf("参数错误:%s", err.Error()), }) return } + cleanToken := model.Token{ UserId: c.GetInt("id"), Name: token.Name, @@ -130,6 +147,8 @@ func AddToken(c *gin.Context) { ExpiredTime: token.ExpiredTime, RemainQuota: token.RemainQuota, UnlimitedQuota: token.UnlimitedQuota, + Models: token.Models, + Subnet: token.Subnet, } err = cleanToken.Insert() if err != nil { @@ -177,10 +196,11 @@ func UpdateToken(c *gin.Context) { }) return } - if len(token.Name) > 30 { + err = validateToken(c, token) + if err != nil { c.JSON(http.StatusOK, gin.H{ "success": false, - "message": "令牌名称过长", + "message": fmt.Sprintf("参数错误:%s", err.Error()), }) return } @@ -216,6 +236,8 @@ func UpdateToken(c *gin.Context) { cleanToken.ExpiredTime = token.ExpiredTime cleanToken.RemainQuota = token.RemainQuota cleanToken.UnlimitedQuota = token.UnlimitedQuota + cleanToken.Models = token.Models + cleanToken.Subnet = token.Subnet } err = cleanToken.Update() if err != nil { diff --git a/controller/user.go b/controller/user.go index 8b614e5d..e87a03a2 100644 --- a/controller/user.go +++ b/controller/user.go @@ -58,11 +58,11 @@ func Login(c *gin.Context) { }) return } - setupLogin(&user, c) + SetupLogin(&user, c) } // setup session & cookies and then return user info -func setupLogin(user *model.User, c *gin.Context) { +func SetupLogin(user *model.User, c *gin.Context) { session := sessions.Default(c) session.Set("id", user.Id) session.Set("username", user.Username) @@ -180,27 +180,27 @@ func Register(c *gin.Context) { } func GetAllUsers(c *gin.Context) { - p, _ := strconv.Atoi(c.Query("p")) - if p < 0 { - p = 0 - } - - order := c.DefaultQuery("order", "") - users, err := model.GetAllUsers(p*config.ItemsPerPage, config.ItemsPerPage, order) - - if err != nil { - c.JSON(http.StatusOK, gin.H{ - "success": false, - "message": err.Error(), - }) - return - } - - c.JSON(http.StatusOK, gin.H{ - "success": true, - "message": "", - "data": users, - }) + p, _ := strconv.Atoi(c.Query("p")) + if p < 0 { + p = 0 + } + + order := c.DefaultQuery("order", "") + users, err := model.GetAllUsers(p*config.ItemsPerPage, config.ItemsPerPage, order) + + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "", + "data": users, + }) } func SearchUsers(c *gin.Context) { @@ -770,3 +770,38 @@ func TopUp(c *gin.Context) { }) return } + +type adminTopUpRequest struct { + UserId int `json:"user_id"` + Quota int `json:"quota"` + Remark string `json:"remark"` +} + +func AdminTopUp(c *gin.Context) { + req := adminTopUpRequest{} + err := c.ShouldBindJSON(&req) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + err = model.IncreaseUserQuota(req.UserId, int64(req.Quota)) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + if req.Remark == "" { + req.Remark = fmt.Sprintf("通过 API 充值 %s", common.LogQuota(int64(req.Quota))) + } + model.RecordTopupLog(req.UserId, req.Remark, req.Quota) + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "", + }) + return +} diff --git a/docs/API.md b/docs/API.md new file mode 100644 index 00000000..0b7ddf5a --- /dev/null +++ b/docs/API.md @@ -0,0 +1,53 @@ +# 使用 API 操控 & 扩展 One API +> 欢迎提交 PR 在此放上你的拓展项目。 + +例如,虽然 One API 本身没有直接支持支付,但是你可以通过系统扩展的 API 来实现支付功能。 + +又或者你想自定义渠道管理策略,也可以通过 API 来实现渠道的禁用与启用。 + +## 鉴权 +One API 支持两种鉴权方式:Cookie 和 Token,对于 Token,参照下图获取: + +![image](https://github.com/songquanpeng/songquanpeng.github.io/assets/39998050/c15281a7-83ed-47cb-a1f6-913cb6bf4a7c) + +之后,将 Token 作为请求头的 Authorization 字段的值即可,例如下面使用 Token 调用测试渠道的 API: +![image](https://github.com/songquanpeng/songquanpeng.github.io/assets/39998050/1273b7ae-cb60-4c0d-93a6-b1cbc039c4f8) + +## 请求格式与响应格式 +One API 使用 JSON 格式进行请求和响应。 + +对于响应体,一般格式如下: +```json +{ + "message": "请求信息", + "success": true, + "data": {} +} +``` + +## API 列表 +> 当前 API 列表不全,请自行通过浏览器抓取前端请求 + +如果现有的 API 没有办法满足你的需求,欢迎提交 issue 讨论。 + +### 获取当前登录用户信息 +**GET** `/api/user/self` + +### 为给定用户充值额度 +**POST** `/api/topup` +```json +{ + "user_id": 1, + "quota": 100000, + "remark": "充值 100000 额度" +} +``` + +## 其他 +### 充值链接上的附加参数 +One API 会在用户点击充值按钮的时候,将用户的信息和充值信息附加在链接上,例如: +`https://example.com?username=root&user_id=1&transaction_id=4b3eed80-55d5-443f-bd44-fb18c648c837` + +你可以通过解析链接上的参数来获取用户信息和充值信息,然后调用 API 来为用户充值。 + +注意,不是所有主题都支持该功能,欢迎 PR 补齐。 \ No newline at end of file diff --git a/go.mod b/go.mod index f9ed96d3..6ace51f2 100644 --- a/go.mod +++ b/go.mod @@ -15,6 +15,7 @@ require ( github.com/google/uuid v1.3.0 github.com/gorilla/websocket v1.5.0 github.com/pkoukk/tiktoken-go v0.1.5 + github.com/smartystreets/goconvey v1.8.1 github.com/stretchr/testify v1.8.3 golang.org/x/crypto v0.17.0 golang.org/x/image v0.14.0 @@ -37,6 +38,7 @@ require ( github.com/go-playground/universal-translator v0.18.1 // indirect github.com/go-sql-driver/mysql v1.6.0 // indirect github.com/goccy/go-json v0.10.2 // indirect + github.com/gopherjs/gopherjs v1.17.2 // indirect github.com/gorilla/context v1.1.1 // indirect github.com/gorilla/securecookie v1.1.1 // indirect github.com/gorilla/sessions v1.2.1 // indirect @@ -47,6 +49,7 @@ require ( github.com/jinzhu/inflection v1.0.0 // indirect github.com/jinzhu/now v1.1.5 // indirect github.com/json-iterator/go v1.1.12 // indirect + github.com/jtolds/gls v4.20.0+incompatible // indirect github.com/klauspost/cpuid/v2 v2.2.4 // indirect github.com/leodido/go-urn v1.2.4 // indirect github.com/mattn/go-isatty v0.0.19 // indirect @@ -55,6 +58,7 @@ require ( github.com/modern-go/reflect2 v1.0.2 // indirect github.com/pelletier/go-toml/v2 v2.0.8 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect + github.com/smarty/assertions v1.15.0 // indirect github.com/twitchyliquid64/golang-asm v0.15.1 // indirect github.com/ugorji/go/codec v1.2.11 // indirect golang.org/x/arch v0.3.0 // indirect diff --git a/go.sum b/go.sum index 9cf056e5..3ead2711 100644 --- a/go.sum +++ b/go.sum @@ -56,11 +56,13 @@ github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keL github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I= github.com/golang/protobuf v1.3.3/go.mod h1:vzj43D7+SQXF/4pzW/hwtAqwc6iTitCiVSaWz5lYuqw= github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= -github.com/google/go-cmp v0.5.5 h1:Khx7svrCpmxxtHBq5j2mp/xVjsi8hQMfNLvJFAlrGgU= github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/go-cmp v0.5.6 h1:BKbKCqvP6I+rmFHt06ZmyQtvB8xAkWdhFyr0ZUNZcxQ= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I= github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/gopherjs/gopherjs v1.17.2 h1:fQnZVsXk8uxXIStYb0N4bGk7jeyTalG/wsZjQ25dO0g= +github.com/gopherjs/gopherjs v1.17.2/go.mod h1:pRRIvn/QzFLrKfvEz3qUuEhtE/zLCWfreZ6J5gM2i+k= github.com/gorilla/context v1.1.1 h1:AWwleXJkX/nhcU9bZSnZoi3h/qGYqQAGhq6zZe/aQW8= github.com/gorilla/context v1.1.1/go.mod h1:kBGZzfjB9CEq2AlWe17Uuf7NDRt0dE0s8S51q0aT7Yg= github.com/gorilla/securecookie v1.1.1 h1:miw7JPhV+b/lAHSXz4qd/nN9jRiAFV5FwjeKyCS8BvQ= @@ -85,6 +87,8 @@ github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/ github.com/json-iterator/go v1.1.9/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4= github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM= github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo= +github.com/jtolds/gls v4.20.0+incompatible h1:xdiiI2gbIgH/gLH7ADydsJ1uDOEzR8yvV7C0MuV77Wo= +github.com/jtolds/gls v4.20.0+incompatible/go.mod h1:QJZ7F/aHp+rZTRtaJ1ow/lLfFfVYBRgL+9YlvaHOwJU= github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg= github.com/klauspost/cpuid/v2 v2.2.4 h1:acbojRNwl3o09bUq+yDCtZFc1aiwaAAxtcn8YkZXnvk= github.com/klauspost/cpuid/v2 v2.2.4/go.mod h1:RVVoqg1df56z8g3pUjL/3lE5UfnlrJX8tyFgg4nqhuY= @@ -127,6 +131,10 @@ github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZN github.com/rogpeppe/go-internal v1.6.1/go.mod h1:xXDCJY+GAPziupqXw64V24skbSoqbTEfhy4qGm1nDQc= github.com/rogpeppe/go-internal v1.8.0 h1:FCbCCtXNOY3UtUuHUYaghJg4y7Fd14rXifAYUAtL9R8= github.com/rogpeppe/go-internal v1.8.0/go.mod h1:WmiCO8CzOY8rg0OYDC4/i/2WRWAB6poM+XZ2dLUbcbE= +github.com/smarty/assertions v1.15.0 h1:cR//PqUBUiQRakZWqBiFFQ9wb8emQGDb0HeGdqGByCY= +github.com/smarty/assertions v1.15.0/go.mod h1:yABtdzeQs6l1brC900WlRNwj6ZR55d7B+E8C6HtKdec= +github.com/smartystreets/goconvey v1.8.1 h1:qGjIddxOk4grTu9JPOU31tVfq3cNdBlNa5sSznIX1xY= +github.com/smartystreets/goconvey v1.8.1/go.mod h1:+/u4qLyY6x1jReYOp7GOM2FSt8aP9CzCZL03bI28W60= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= @@ -177,8 +185,8 @@ golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ= golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= -golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 h1:go1bK/D/BFZV2I8cIQd1NKEZ+0owSTG1fDTci4IqFcE= google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= google.golang.org/protobuf v1.28.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= google.golang.org/protobuf v1.33.0 h1:uNO2rsAINq/JlFpSdYEKIZ0uKD/R9cpdv0T+yoGwGmI= diff --git a/middleware/auth.go b/middleware/auth.go index 30997efd..223cef3d 100644 --- a/middleware/auth.go +++ b/middleware/auth.go @@ -1,10 +1,12 @@ package middleware import ( + "fmt" "github.com/gin-contrib/sessions" "github.com/gin-gonic/gin" "github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common/blacklist" + "github.com/songquanpeng/one-api/common/network" "github.com/songquanpeng/one-api/model" "net/http" "strings" @@ -88,6 +90,7 @@ func RootAuth() func(c *gin.Context) { func TokenAuth() func(c *gin.Context) { return func(c *gin.Context) { + ctx := c.Request.Context() key := c.Request.Header.Get("Authorization") key = strings.TrimPrefix(key, "Bearer ") key = strings.TrimPrefix(key, "sk-") @@ -98,6 +101,12 @@ func TokenAuth() func(c *gin.Context) { abortWithMessage(c, http.StatusUnauthorized, err.Error()) return } + if token.Subnet != nil && *token.Subnet != "" { + if !network.IsIpInSubnets(ctx, c.ClientIP(), *token.Subnet) { + abortWithMessage(c, http.StatusForbidden, fmt.Sprintf("该令牌只能在指定网段使用:%s,当前 ip:%s", *token.Subnet, c.ClientIP())) + return + } + } userEnabled, err := model.CacheIsUserEnabled(token.UserId) if err != nil { abortWithMessage(c, http.StatusInternalServerError, err.Error()) @@ -107,6 +116,19 @@ func TokenAuth() func(c *gin.Context) { abortWithMessage(c, http.StatusForbidden, "用户已被封禁") return } + requestModel, err := getRequestModel(c) + if err != nil && !strings.HasPrefix(c.Request.URL.Path, "/v1/models") { + abortWithMessage(c, http.StatusBadRequest, err.Error()) + return + } + c.Set("request_model", requestModel) + if token.Models != nil && *token.Models != "" { + c.Set("available_models", *token.Models) + if requestModel != "" && !isModelInList(requestModel, *token.Models) { + abortWithMessage(c, http.StatusForbidden, fmt.Sprintf("该令牌无权使用模型:%s", requestModel)) + return + } + } c.Set("id", token.UserId) c.Set("token_id", token.Id) c.Set("token_name", token.Name) diff --git a/middleware/distributor.go b/middleware/distributor.go index e845c2f8..04489a2b 100644 --- a/middleware/distributor.go +++ b/middleware/distributor.go @@ -2,14 +2,12 @@ package middleware import ( "fmt" + "github.com/gin-gonic/gin" "github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common/logger" "github.com/songquanpeng/one-api/model" "net/http" "strconv" - "strings" - - "github.com/gin-gonic/gin" ) type ModelRequest struct { @@ -40,37 +38,11 @@ func Distribute() func(c *gin.Context) { return } } else { - // Select a channel for the user - var modelRequest ModelRequest - err := common.UnmarshalBodyReusable(c, &modelRequest) + requestModel := c.GetString("request_model") + var err error + channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, requestModel, false) if err != nil { - abortWithMessage(c, http.StatusBadRequest, "无效的请求") - return - } - if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") { - if modelRequest.Model == "" { - modelRequest.Model = "text-moderation-stable" - } - } - if strings.HasSuffix(c.Request.URL.Path, "embeddings") { - if modelRequest.Model == "" { - modelRequest.Model = c.Param("model") - } - } - if strings.HasPrefix(c.Request.URL.Path, "/v1/images/generations") { - if modelRequest.Model == "" { - modelRequest.Model = "dall-e-2" - } - } - if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") || strings.HasPrefix(c.Request.URL.Path, "/v1/audio/translations") { - if modelRequest.Model == "" { - modelRequest.Model = "whisper-1" - } - } - requestModel = modelRequest.Model - channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, modelRequest.Model, false) - if err != nil { - message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", userGroup, modelRequest.Model) + message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", userGroup, requestModel) if channel != nil { logger.SysError(fmt.Sprintf("渠道不存在:%d", channel.Id)) message = "数据库一致性已被破坏,请联系管理员" diff --git a/middleware/utils.go b/middleware/utils.go index bc14c367..b65b018b 100644 --- a/middleware/utils.go +++ b/middleware/utils.go @@ -1,9 +1,12 @@ package middleware import ( + "fmt" "github.com/gin-gonic/gin" + "github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common/helper" "github.com/songquanpeng/one-api/common/logger" + "strings" ) func abortWithMessage(c *gin.Context, statusCode int, message string) { @@ -16,3 +19,42 @@ func abortWithMessage(c *gin.Context, statusCode int, message string) { c.Abort() logger.Error(c.Request.Context(), message) } + +func getRequestModel(c *gin.Context) (string, error) { + var modelRequest ModelRequest + err := common.UnmarshalBodyReusable(c, &modelRequest) + if err != nil { + return "", fmt.Errorf("common.UnmarshalBodyReusable failed: %w", err) + } + if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") { + if modelRequest.Model == "" { + modelRequest.Model = "text-moderation-stable" + } + } + if strings.HasSuffix(c.Request.URL.Path, "embeddings") { + if modelRequest.Model == "" { + modelRequest.Model = c.Param("model") + } + } + if strings.HasPrefix(c.Request.URL.Path, "/v1/images/generations") { + if modelRequest.Model == "" { + modelRequest.Model = "dall-e-2" + } + } + if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") || strings.HasPrefix(c.Request.URL.Path, "/v1/audio/translations") { + if modelRequest.Model == "" { + modelRequest.Model = "whisper-1" + } + } + return modelRequest.Model, nil +} + +func isModelInList(modelName string, models string) bool { + modelList := strings.Split(models, ",") + for _, model := range modelList { + if modelName == model { + return true + } + } + return false +} diff --git a/model/ability.go b/model/ability.go index 48b856a2..4a48bc51 100644 --- a/model/ability.go +++ b/model/ability.go @@ -1,8 +1,10 @@ package model import ( + "context" "github.com/songquanpeng/one-api/common" "gorm.io/gorm" + "sort" "strings" ) @@ -88,3 +90,19 @@ func (channel *Channel) UpdateAbilities() error { func UpdateAbilityStatus(channelId int, status bool) error { return DB.Model(&Ability{}).Where("channel_id = ?", channelId).Select("enabled").Update("enabled", status).Error } + +func GetGroupModels(ctx context.Context, group string) ([]string, error) { + groupCol := "`group`" + trueVal := "1" + if common.UsingPostgreSQL { + groupCol = `"group"` + trueVal = "true" + } + var models []string + err := DB.Model(&Ability{}).Distinct("model").Where(groupCol+" = ? and enabled = "+trueVal, group).Pluck("model", &models).Error + if err != nil { + return nil, err + } + sort.Strings(models) + return models, err +} diff --git a/model/cache.go b/model/cache.go index 244fe6ac..cfc5445a 100644 --- a/model/cache.go +++ b/model/cache.go @@ -21,6 +21,7 @@ var ( UserId2GroupCacheSeconds = config.SyncFrequency UserId2QuotaCacheSeconds = config.SyncFrequency UserId2StatusCacheSeconds = config.SyncFrequency + GroupModelsCacheSeconds = config.SyncFrequency ) func CacheGetTokenByKey(key string) (*Token, error) { @@ -146,6 +147,25 @@ func CacheIsUserEnabled(userId int) (bool, error) { return userEnabled, err } +func CacheGetGroupModels(ctx context.Context, group string) ([]string, error) { + if !common.RedisEnabled { + return GetGroupModels(ctx, group) + } + modelsStr, err := common.RedisGet(fmt.Sprintf("group_models:%s", group)) + if err == nil { + return strings.Split(modelsStr, ","), nil + } + models, err := GetGroupModels(ctx, group) + if err != nil { + return nil, err + } + err = common.RedisSet(fmt.Sprintf("group_models:%s", group), strings.Join(models, ","), time.Duration(GroupModelsCacheSeconds)*time.Second) + if err != nil { + logger.SysError("Redis set group models error: " + err.Error()) + } + return models, nil +} + var group2model2channels map[string]map[string][]*Channel var channelSyncLock sync.RWMutex diff --git a/model/log.go b/model/log.go index 4409f73e..6b679c36 100644 --- a/model/log.go +++ b/model/log.go @@ -51,6 +51,21 @@ func RecordLog(userId int, logType int, content string) { } } +func RecordTopupLog(userId int, content string, quota int) { + log := &Log{ + UserId: userId, + Username: GetUsernameById(userId), + CreatedAt: helper.GetTimestamp(), + Type: LogTypeTopup, + Content: content, + Quota: quota, + } + err := LOG_DB.Create(log).Error + if err != nil { + logger.SysError("failed to record log: " + err.Error()) + } +} + func RecordConsumeLog(ctx context.Context, userId int, channelId int, promptTokens int, completionTokens int, modelName string, tokenName string, quota int64, content string) { logger.Info(ctx, fmt.Sprintf("record consume log: userId=%d, channelId=%d, promptTokens=%d, completionTokens=%d, modelName=%s, tokenName=%s, quota=%d, content=%s", userId, channelId, promptTokens, completionTokens, modelName, tokenName, quota, content)) if !config.LogConsumeEnabled { diff --git a/model/option.go b/model/option.go index 1d1c28b4..cee9bd3b 100644 --- a/model/option.go +++ b/model/option.go @@ -172,6 +172,10 @@ func updateOptionMap(key string, value string) (err error) { config.GitHubClientId = value case "GitHubClientSecret": config.GitHubClientSecret = value + case "LarkClientId": + config.LarkClientId = value + case "LarkClientSecret": + config.LarkClientSecret = value case "Footer": config.Footer = value case "SystemName": diff --git a/model/token.go b/model/token.go index 493e27c9..20228ec5 100644 --- a/model/token.go +++ b/model/token.go @@ -12,24 +12,26 @@ import ( ) type Token struct { - Id int `json:"id"` - UserId int `json:"user_id"` - Key string `json:"key" gorm:"type:char(48);uniqueIndex"` - Status int `json:"status" gorm:"default:1"` - Name string `json:"name" gorm:"index" ` - CreatedTime int64 `json:"created_time" gorm:"bigint"` - AccessedTime int64 `json:"accessed_time" gorm:"bigint"` - ExpiredTime int64 `json:"expired_time" gorm:"bigint;default:-1"` // -1 means never expired - RemainQuota int64 `json:"remain_quota" gorm:"bigint;default:0"` - UnlimitedQuota bool `json:"unlimited_quota" gorm:"default:false"` - UsedQuota int64 `json:"used_quota" gorm:"bigint;default:0"` // used quota + Id int `json:"id"` + UserId int `json:"user_id"` + Key string `json:"key" gorm:"type:char(48);uniqueIndex"` + Status int `json:"status" gorm:"default:1"` + Name string `json:"name" gorm:"index" ` + CreatedTime int64 `json:"created_time" gorm:"bigint"` + AccessedTime int64 `json:"accessed_time" gorm:"bigint"` + ExpiredTime int64 `json:"expired_time" gorm:"bigint;default:-1"` // -1 means never expired + RemainQuota int64 `json:"remain_quota" gorm:"bigint;default:0"` + UnlimitedQuota bool `json:"unlimited_quota" gorm:"default:false"` + UsedQuota int64 `json:"used_quota" gorm:"bigint;default:0"` // used quota + Models *string `json:"models" gorm:"default:''"` // allowed models + Subnet *string `json:"subnet" gorm:"default:''"` // allowed subnet } func GetAllUserTokens(userId int, startIdx int, num int, order string) ([]*Token, error) { var tokens []*Token var err error query := DB.Where("user_id = ?", userId) - + switch order { case "remain_quota": query = query.Order("unlimited_quota desc, remain_quota desc") @@ -38,7 +40,7 @@ func GetAllUserTokens(userId int, startIdx int, num int, order string) ([]*Token default: query = query.Order("id desc") } - + err = query.Limit(num).Offset(startIdx).Find(&tokens).Error return tokens, err } @@ -61,7 +63,7 @@ func ValidateUserToken(key string) (token *Token, err error) { return nil, errors.New("令牌验证失败") } if token.Status == common.TokenStatusExhausted { - return nil, errors.New("该令牌额度已用尽") + return nil, fmt.Errorf("令牌 %s(#%d)额度已用尽", token.Name, token.Id) } else if token.Status == common.TokenStatusExpired { return nil, errors.New("该令牌已过期") } @@ -121,7 +123,7 @@ func (token *Token) Insert() error { // Update Make sure your token's fields is completed, because this will update non-zero values func (token *Token) Update() error { var err error - err = DB.Model(token).Select("name", "status", "expired_time", "remain_quota", "unlimited_quota").Updates(token).Error + err = DB.Model(token).Select("name", "status", "expired_time", "remain_quota", "unlimited_quota", "models", "subnet").Updates(token).Error return err } diff --git a/model/user.go b/model/user.go index 5e729b5e..42d8f7b1 100644 --- a/model/user.go +++ b/model/user.go @@ -24,6 +24,7 @@ type User struct { Email string `json:"email" gorm:"index" validate:"max=50"` GitHubId string `json:"github_id" gorm:"column:github_id;index"` WeChatId string `json:"wechat_id" gorm:"column:wechat_id;index"` + LarkId string `json:"lark_id" gorm:"column:lark_id;index"` VerificationCode string `json:"verification_code" gorm:"-:all"` // this field is only for Email verification, don't save it to database! AccessToken string `json:"access_token" gorm:"type:char(32);column:access_token;uniqueIndex"` // this token is for system management Quota int64 `json:"quota" gorm:"bigint;default:0"` @@ -41,21 +42,21 @@ func GetMaxUserId() int { } func GetAllUsers(startIdx int, num int, order string) (users []*User, err error) { - query := DB.Limit(num).Offset(startIdx).Omit("password").Where("status != ?", common.UserStatusDeleted) - - switch order { - case "quota": - query = query.Order("quota desc") - case "used_quota": - query = query.Order("used_quota desc") - case "request_count": - query = query.Order("request_count desc") - default: - query = query.Order("id desc") - } - - err = query.Find(&users).Error - return users, err + query := DB.Limit(num).Offset(startIdx).Omit("password").Where("status != ?", common.UserStatusDeleted) + + switch order { + case "quota": + query = query.Order("quota desc") + case "used_quota": + query = query.Order("used_quota desc") + case "request_count": + query = query.Order("request_count desc") + default: + query = query.Order("id desc") + } + + err = query.Find(&users).Error + return users, err } func SearchUsers(keyword string) (users []*User, err error) { @@ -206,6 +207,14 @@ func (user *User) FillUserByGitHubId() error { return nil } +func (user *User) FillUserByLarkId() error { + if user.LarkId == "" { + return errors.New("lark id 为空!") + } + DB.Where(User{LarkId: user.LarkId}).First(user) + return nil +} + func (user *User) FillUserByWeChatId() error { if user.WeChatId == "" { return errors.New("WeChat id 为空!") @@ -234,6 +243,10 @@ func IsGitHubIdAlreadyTaken(githubId string) bool { return DB.Where("github_id = ?", githubId).Find(&User{}).RowsAffected == 1 } +func IsLarkIdAlreadyTaken(githubId string) bool { + return DB.Where("lark_id = ?", githubId).Find(&User{}).RowsAffected == 1 +} + func IsUsernameAlreadyTaken(username string) bool { return DB.Where("username = ?", username).Find(&User{}).RowsAffected == 1 } diff --git a/relay/channel/ali/main.go b/relay/channel/ali/main.go index df28f82b..c1832b11 100644 --- a/relay/channel/ali/main.go +++ b/relay/channel/ali/main.go @@ -52,6 +52,9 @@ func ConvertRequest(request model.GeneralOpenAIRequest) *ChatRequest { MaxTokens: request.MaxTokens, Temperature: request.Temperature, TopP: request.TopP, + TopK: request.TopK, + ResultFormat: "message", + Tools: request.Tools, }, } } @@ -132,19 +135,11 @@ func embeddingResponseAli2OpenAI(response *EmbeddingResponse) *openai.EmbeddingR } func responseAli2OpenAI(response *ChatResponse) *openai.TextResponse { - choice := openai.TextResponseChoice{ - Index: 0, - Message: model.Message{ - Role: "assistant", - Content: response.Output.Text, - }, - FinishReason: response.Output.FinishReason, - } fullTextResponse := openai.TextResponse{ Id: response.RequestId, Object: "chat.completion", Created: helper.GetTimestamp(), - Choices: []openai.TextResponseChoice{choice}, + Choices: response.Output.Choices, Usage: model.Usage{ PromptTokens: response.Usage.InputTokens, CompletionTokens: response.Usage.OutputTokens, @@ -155,10 +150,14 @@ func responseAli2OpenAI(response *ChatResponse) *openai.TextResponse { } func streamResponseAli2OpenAI(aliResponse *ChatResponse) *openai.ChatCompletionsStreamResponse { + if len(aliResponse.Output.Choices) == 0 { + return nil + } + aliChoice := aliResponse.Output.Choices[0] var choice openai.ChatCompletionsStreamResponseChoice - choice.Delta.Content = aliResponse.Output.Text - if aliResponse.Output.FinishReason != "null" { - finishReason := aliResponse.Output.FinishReason + choice.Delta = aliChoice.Message + if aliChoice.FinishReason != "null" { + finishReason := aliChoice.FinishReason choice.FinishReason = &finishReason } response := openai.ChatCompletionsStreamResponse{ @@ -219,6 +218,9 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC usage.TotalTokens = aliResponse.Usage.InputTokens + aliResponse.Usage.OutputTokens } response := streamResponseAli2OpenAI(&aliResponse) + if response == nil { + return true + } //response.Choices[0].Delta.Content = strings.TrimPrefix(response.Choices[0].Delta.Content, lastResponseText) //lastResponseText = aliResponse.Output.Text jsonResponse, err := json.Marshal(response) @@ -241,6 +243,7 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC } func Handler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) { + ctx := c.Request.Context() var aliResponse ChatResponse responseBody, err := io.ReadAll(resp.Body) if err != nil { @@ -250,6 +253,7 @@ func Handler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, * if err != nil { return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil } + logger.Debugf(ctx, "response body: %s\n", responseBody) err = json.Unmarshal(responseBody, &aliResponse) if err != nil { return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil diff --git a/relay/channel/ali/model.go b/relay/channel/ali/model.go index 34515b16..6a83f0f4 100644 --- a/relay/channel/ali/model.go +++ b/relay/channel/ali/model.go @@ -1,5 +1,10 @@ package ali +import ( + "github.com/songquanpeng/one-api/relay/channel/openai" + "github.com/songquanpeng/one-api/relay/model" +) + type Message struct { Content string `json:"content"` Role string `json:"role"` @@ -11,13 +16,15 @@ type Input struct { } type Parameters struct { - TopP float64 `json:"top_p,omitempty"` - TopK int `json:"top_k,omitempty"` - Seed uint64 `json:"seed,omitempty"` - EnableSearch bool `json:"enable_search,omitempty"` - IncrementalOutput bool `json:"incremental_output,omitempty"` - MaxTokens int `json:"max_tokens,omitempty"` - Temperature float64 `json:"temperature,omitempty"` + TopP float64 `json:"top_p,omitempty"` + TopK int `json:"top_k,omitempty"` + Seed uint64 `json:"seed,omitempty"` + EnableSearch bool `json:"enable_search,omitempty"` + IncrementalOutput bool `json:"incremental_output,omitempty"` + MaxTokens int `json:"max_tokens,omitempty"` + Temperature float64 `json:"temperature,omitempty"` + ResultFormat string `json:"result_format,omitempty"` + Tools []model.Tool `json:"tools,omitempty"` } type ChatRequest struct { @@ -135,8 +142,9 @@ type Usage struct { } type Output struct { - Text string `json:"text"` - FinishReason string `json:"finish_reason"` + //Text string `json:"text"` + //FinishReason string `json:"finish_reason"` + Choices []openai.TextResponseChoice `json:"choices"` } type ChatResponse struct { diff --git a/relay/channel/anthropic/main.go b/relay/channel/anthropic/main.go index 3eeb0b2c..04e65d99 100644 --- a/relay/channel/anthropic/main.go +++ b/relay/channel/anthropic/main.go @@ -38,6 +38,7 @@ func ConvertRequest(textRequest model.GeneralOpenAIRequest) *Request { MaxTokens: textRequest.MaxTokens, Temperature: textRequest.Temperature, TopP: textRequest.TopP, + TopK: textRequest.TopK, Stream: textRequest.Stream, } if claudeRequest.MaxTokens == 0 { diff --git a/relay/channel/baidu/adaptor.go b/relay/channel/baidu/adaptor.go index 471f2dc5..6096eb31 100644 --- a/relay/channel/baidu/adaptor.go +++ b/relay/channel/baidu/adaptor.go @@ -38,16 +38,34 @@ func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) { suffix += "completions_pro" case "ERNIE-Bot-4": suffix += "completions_pro" - case "ERNIE-3.5-8K": - suffix += "completions" - case "ERNIE-Bot-8K": - suffix += "ernie_bot_8k" case "ERNIE-Bot": suffix += "completions" - case "ERNIE-Speed": - suffix += "ernie_speed" case "ERNIE-Bot-turbo": suffix += "eb-instant" + case "ERNIE-Speed": + suffix += "ernie_speed" + case "ERNIE-4.0-8K": + suffix += "completions_pro" + case "ERNIE-3.5-8K": + suffix += "completions" + case "ERNIE-3.5-8K-0205": + suffix += "ernie-3.5-8k-0205" + case "ERNIE-3.5-8K-1222": + suffix += "ernie-3.5-8k-1222" + case "ERNIE-Bot-8K": + suffix += "ernie_bot_8k" + case "ERNIE-3.5-4K-0205": + suffix += "ernie-3.5-4k-0205" + case "ERNIE-Speed-8K": + suffix += "ernie_speed" + case "ERNIE-Speed-128K": + suffix += "ernie-speed-128k" + case "ERNIE-Lite-8K-0922": + suffix += "eb-instant" + case "ERNIE-Lite-8K-0308": + suffix += "ernie-lite-8k" + case "ERNIE-Tiny-8K": + suffix += "ernie-tiny-8k" case "BLOOMZ-7B": suffix += "bloomz_7b1" case "Embedding-V1": @@ -59,7 +77,7 @@ func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) { case "tao-8k": suffix += "tao_8k" default: - suffix += meta.ActualModelName + suffix += strings.ToLower(meta.ActualModelName) } fullRequestURL := fmt.Sprintf("%s/rpc/2.0/ai_custom/v1/wenxinworkshop/%s", meta.BaseURL, suffix) var accessToken string diff --git a/relay/channel/baidu/constants.go b/relay/channel/baidu/constants.go index 45a4e901..f952adc6 100644 --- a/relay/channel/baidu/constants.go +++ b/relay/channel/baidu/constants.go @@ -1,11 +1,18 @@ package baidu var ModelList = []string{ - "ERNIE-Bot-4", + "ERNIE-4.0-8K", + "ERNIE-3.5-8K", + "ERNIE-3.5-8K-0205", + "ERNIE-3.5-8K-1222", "ERNIE-Bot-8K", - "ERNIE-Bot", - "ERNIE-Speed", - "ERNIE-Bot-turbo", + "ERNIE-3.5-4K-0205", + "ERNIE-Speed-8K", + "ERNIE-Speed-128K", + "ERNIE-Lite-8K-0922", + "ERNIE-Lite-8K-0308", + "ERNIE-Tiny-8K", + "BLOOMZ-7B", "Embedding-V1", "bge-large-zh", "bge-large-en", diff --git a/relay/channel/openai/adaptor.go b/relay/channel/openai/adaptor.go index bcbaf835..91b5fc2c 100644 --- a/relay/channel/openai/adaptor.go +++ b/relay/channel/openai/adaptor.go @@ -85,8 +85,10 @@ func (a *Adaptor) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode) { if meta.IsStream { var responseText string - err, responseText, _ = StreamHandler(c, resp, meta.Mode) - usage = ResponseText2Usage(responseText, meta.ActualModelName, meta.PromptTokens) + err, responseText, usage = StreamHandler(c, resp, meta.Mode) + if usage == nil { + usage = ResponseText2Usage(responseText, meta.ActualModelName, meta.PromptTokens) + } } else { switch meta.Mode { case constant.RelayModeImagesGenerations: diff --git a/relay/channel/openai/compatible.go b/relay/channel/openai/compatible.go index e4951a34..2a1447ab 100644 --- a/relay/channel/openai/compatible.go +++ b/relay/channel/openai/compatible.go @@ -9,6 +9,7 @@ import ( "github.com/songquanpeng/one-api/relay/channel/minimax" "github.com/songquanpeng/one-api/relay/channel/mistral" "github.com/songquanpeng/one-api/relay/channel/moonshot" + "github.com/songquanpeng/one-api/relay/channel/stepfun" ) var CompatibleChannels = []int{ @@ -20,6 +21,7 @@ var CompatibleChannels = []int{ common.ChannelTypeMistral, common.ChannelTypeGroq, common.ChannelTypeLingYiWanWu, + common.ChannelTypeStepFun, } func GetCompatibleChannelMeta(channelType int) (string, []string) { @@ -40,6 +42,8 @@ func GetCompatibleChannelMeta(channelType int) (string, []string) { return "groq", groq.ModelList case common.ChannelTypeLingYiWanWu: return "lingyiwanwu", lingyiwanwu.ModelList + case common.ChannelTypeStepFun: + return "stepfun", stepfun.ModelList default: return "openai", ModelList } diff --git a/relay/channel/openai/main.go b/relay/channel/openai/main.go index fa4ee25e..df3f0691 100644 --- a/relay/channel/openai/main.go +++ b/relay/channel/openai/main.go @@ -6,6 +6,7 @@ import ( "encoding/json" "github.com/gin-gonic/gin" "github.com/songquanpeng/one-api/common" + "github.com/songquanpeng/one-api/common/conv" "github.com/songquanpeng/one-api/common/logger" "github.com/songquanpeng/one-api/relay/constant" "github.com/songquanpeng/one-api/relay/model" @@ -53,7 +54,7 @@ func StreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*model.E continue // just ignore the error } for _, choice := range streamResponse.Choices { - responseText += choice.Delta.Content + responseText += conv.AsString(choice.Delta.Content) } if streamResponse.Usage != nil { usage = streamResponse.Usage diff --git a/relay/channel/openai/model.go b/relay/channel/openai/model.go index 31d9fe97..6e9c38f1 100644 --- a/relay/channel/openai/model.go +++ b/relay/channel/openai/model.go @@ -123,12 +123,9 @@ type ImageResponse struct { } type ChatCompletionsStreamResponseChoice struct { - Index int `json:"index"` - Delta struct { - Content string `json:"content"` - Role string `json:"role,omitempty"` - } `json:"delta"` - FinishReason *string `json:"finish_reason,omitempty"` + Index int `json:"index"` + Delta model.Message `json:"delta"` + FinishReason *string `json:"finish_reason,omitempty"` } type ChatCompletionsStreamResponse struct { diff --git a/relay/channel/stepfun/constants.go b/relay/channel/stepfun/constants.go new file mode 100644 index 00000000..a82e562b --- /dev/null +++ b/relay/channel/stepfun/constants.go @@ -0,0 +1,7 @@ +package stepfun + +var ModelList = []string{ + "step-1-32k", + "step-1v-32k", + "step-1-200k", +} diff --git a/relay/channel/tencent/main.go b/relay/channel/tencent/main.go index cfdc0bfd..b5a64cde 100644 --- a/relay/channel/tencent/main.go +++ b/relay/channel/tencent/main.go @@ -10,6 +10,7 @@ import ( "fmt" "github.com/gin-gonic/gin" "github.com/songquanpeng/one-api/common" + "github.com/songquanpeng/one-api/common/conv" "github.com/songquanpeng/one-api/common/helper" "github.com/songquanpeng/one-api/common/logger" "github.com/songquanpeng/one-api/relay/channel/openai" @@ -129,7 +130,7 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC } response := streamResponseTencent2OpenAI(&TencentResponse) if len(response.Choices) != 0 { - responseText += response.Choices[0].Delta.Content + responseText += conv.AsString(response.Choices[0].Delta.Content) } jsonResponse, err := json.Marshal(response) if err != nil { diff --git a/relay/channel/xunfei/main.go b/relay/channel/xunfei/main.go index 5e7014cb..67784a56 100644 --- a/relay/channel/xunfei/main.go +++ b/relay/channel/xunfei/main.go @@ -26,7 +26,11 @@ import ( func requestOpenAI2Xunfei(request model.GeneralOpenAIRequest, xunfeiAppId string, domain string) *ChatRequest { messages := make([]Message, 0, len(request.Messages)) + var lastToolCalls []model.Tool for _, message := range request.Messages { + if message.ToolCalls != nil { + lastToolCalls = message.ToolCalls + } messages = append(messages, Message{ Role: message.Role, Content: message.StringContent(), @@ -39,9 +43,33 @@ func requestOpenAI2Xunfei(request model.GeneralOpenAIRequest, xunfeiAppId string xunfeiRequest.Parameter.Chat.TopK = request.N xunfeiRequest.Parameter.Chat.MaxTokens = request.MaxTokens xunfeiRequest.Payload.Message.Text = messages + if len(lastToolCalls) != 0 { + for _, toolCall := range lastToolCalls { + xunfeiRequest.Payload.Functions.Text = append(xunfeiRequest.Payload.Functions.Text, toolCall.Function) + } + } + return &xunfeiRequest } +func getToolCalls(response *ChatResponse) []model.Tool { + var toolCalls []model.Tool + if len(response.Payload.Choices.Text) == 0 { + return toolCalls + } + item := response.Payload.Choices.Text[0] + if item.FunctionCall == nil { + return toolCalls + } + toolCall := model.Tool{ + Id: fmt.Sprintf("call_%s", helper.GetUUID()), + Type: "function", + Function: *item.FunctionCall, + } + toolCalls = append(toolCalls, toolCall) + return toolCalls +} + func responseXunfei2OpenAI(response *ChatResponse) *openai.TextResponse { if len(response.Payload.Choices.Text) == 0 { response.Payload.Choices.Text = []ChatResponseTextItem{ @@ -53,8 +81,9 @@ func responseXunfei2OpenAI(response *ChatResponse) *openai.TextResponse { choice := openai.TextResponseChoice{ Index: 0, Message: model.Message{ - Role: "assistant", - Content: response.Payload.Choices.Text[0].Content, + Role: "assistant", + Content: response.Payload.Choices.Text[0].Content, + ToolCalls: getToolCalls(response), }, FinishReason: constant.StopFinishReason, } @@ -78,6 +107,7 @@ func streamResponseXunfei2OpenAI(xunfeiResponse *ChatResponse) *openai.ChatCompl } var choice openai.ChatCompletionsStreamResponseChoice choice.Delta.Content = xunfeiResponse.Payload.Choices.Text[0].Content + choice.Delta.ToolCalls = getToolCalls(xunfeiResponse) if xunfeiResponse.Payload.Choices.Status == 2 { choice.FinishReason = &constant.StopFinishReason } diff --git a/relay/channel/xunfei/model.go b/relay/channel/xunfei/model.go index 1266739d..97a43154 100644 --- a/relay/channel/xunfei/model.go +++ b/relay/channel/xunfei/model.go @@ -26,13 +26,18 @@ type ChatRequest struct { Message struct { Text []Message `json:"text"` } `json:"message"` + Functions struct { + Text []model.Function `json:"text,omitempty"` + } `json:"functions,omitempty"` } `json:"payload"` } type ChatResponseTextItem struct { - Content string `json:"content"` - Role string `json:"role"` - Index int `json:"index"` + Content string `json:"content"` + Role string `json:"role"` + Index int `json:"index"` + ContentType string `json:"content_type"` + FunctionCall *model.Function `json:"function_call"` } type ChatResponse struct { diff --git a/relay/channel/zhipu/adaptor.go b/relay/channel/zhipu/adaptor.go index 75ff977b..dbcf240d 100644 --- a/relay/channel/zhipu/adaptor.go +++ b/relay/channel/zhipu/adaptor.go @@ -6,6 +6,7 @@ import ( "github.com/gin-gonic/gin" "github.com/songquanpeng/one-api/relay/channel" "github.com/songquanpeng/one-api/relay/channel/openai" + "github.com/songquanpeng/one-api/relay/constant" "github.com/songquanpeng/one-api/relay/model" "github.com/songquanpeng/one-api/relay/util" "io" @@ -35,6 +36,9 @@ func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) { if a.APIVersion == "v4" { return fmt.Sprintf("%s/api/paas/v4/chat/completions", meta.BaseURL), nil } + if meta.Mode == constant.RelayModeEmbeddings { + return fmt.Sprintf("%s/api/paas/v4/embeddings", meta.BaseURL), nil + } method := "invoke" if meta.IsStream { method = "sse-invoke" @@ -53,18 +57,24 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G if request == nil { return nil, errors.New("request is nil") } - // TopP (0.0, 1.0) - request.TopP = math.Min(0.99, request.TopP) - request.TopP = math.Max(0.01, request.TopP) + switch relayMode { + case constant.RelayModeEmbeddings: + baiduEmbeddingRequest := ConvertEmbeddingRequest(*request) + return baiduEmbeddingRequest, nil + default: + // TopP (0.0, 1.0) + request.TopP = math.Min(0.99, request.TopP) + request.TopP = math.Max(0.01, request.TopP) - // Temperature (0.0, 1.0) - request.Temperature = math.Min(0.99, request.Temperature) - request.Temperature = math.Max(0.01, request.Temperature) - a.SetVersionByModeName(request.Model) - if a.APIVersion == "v4" { - return request, nil + // Temperature (0.0, 1.0) + request.Temperature = math.Min(0.99, request.Temperature) + request.Temperature = math.Max(0.01, request.Temperature) + a.SetVersionByModeName(request.Model) + if a.APIVersion == "v4" { + return request, nil + } + return ConvertRequest(*request), nil } - return ConvertRequest(*request), nil } func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) { @@ -91,14 +101,26 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.Rel if a.APIVersion == "v4" { return a.DoResponseV4(c, resp, meta) } + if meta.IsStream { err, usage = StreamHandler(c, resp) } else { - err, usage = Handler(c, resp) + if meta.Mode == constant.RelayModeEmbeddings { + err, usage = EmbeddingsHandler(c, resp) + } else { + err, usage = Handler(c, resp) + } } return } +func ConvertEmbeddingRequest(request model.GeneralOpenAIRequest) *EmbeddingRequest { + return &EmbeddingRequest{ + Model: "embedding-2", + Input: request.Input.(string), + } +} + func (a *Adaptor) GetModelList() []string { return ModelList } diff --git a/relay/channel/zhipu/constants.go b/relay/channel/zhipu/constants.go index 1655a59d..2daeb19c 100644 --- a/relay/channel/zhipu/constants.go +++ b/relay/channel/zhipu/constants.go @@ -2,5 +2,5 @@ package zhipu var ModelList = []string{ "chatglm_turbo", "chatglm_pro", "chatglm_std", "chatglm_lite", - "glm-4", "glm-4v", "glm-3-turbo", + "glm-4", "glm-4v", "glm-3-turbo", "embedding-2", } diff --git a/relay/channel/zhipu/main.go b/relay/channel/zhipu/main.go index a46fd537..f54e0504 100644 --- a/relay/channel/zhipu/main.go +++ b/relay/channel/zhipu/main.go @@ -254,3 +254,50 @@ func Handler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, * _, err = c.Writer.Write(jsonResponse) return nil, &fullTextResponse.Usage } + +func EmbeddingsHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) { + var zhipuResponse EmbeddingRespone + responseBody, err := io.ReadAll(resp.Body) + if err != nil { + return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil + } + err = resp.Body.Close() + if err != nil { + return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil + } + err = json.Unmarshal(responseBody, &zhipuResponse) + if err != nil { + return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil + } + fullTextResponse := embeddingResponseZhipu2OpenAI(&zhipuResponse) + jsonResponse, err := json.Marshal(fullTextResponse) + if err != nil { + return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil + } + c.Writer.Header().Set("Content-Type", "application/json") + c.Writer.WriteHeader(resp.StatusCode) + _, err = c.Writer.Write(jsonResponse) + return nil, &fullTextResponse.Usage +} + +func embeddingResponseZhipu2OpenAI(response *EmbeddingRespone) *openai.EmbeddingResponse { + openAIEmbeddingResponse := openai.EmbeddingResponse{ + Object: "list", + Data: make([]openai.EmbeddingResponseItem, 0, len(response.Embeddings)), + Model: response.Model, + Usage: model.Usage{ + PromptTokens: response.PromptTokens, + CompletionTokens: response.CompletionTokens, + TotalTokens: response.Usage.TotalTokens, + }, + } + + for _, item := range response.Embeddings { + openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, openai.EmbeddingResponseItem{ + Object: `embedding`, + Index: item.Index, + Embedding: item.Embedding, + }) + } + return &openAIEmbeddingResponse +} diff --git a/relay/channel/zhipu/model.go b/relay/channel/zhipu/model.go index b63e1d6f..3c3a7443 100644 --- a/relay/channel/zhipu/model.go +++ b/relay/channel/zhipu/model.go @@ -44,3 +44,21 @@ type tokenData struct { Token string ExpiryTime time.Time } + +type EmbeddingRequest struct { + Model string `json:"model"` + Input string `json:"input"` +} + +type EmbeddingRespone struct { + Model string `json:"model"` + Object string `json:"object"` + Embeddings []EmbeddingData `json:"data"` + model.Usage `json:"usage"` +} + +type EmbeddingData struct { + Index int `json:"index"` + Object string `json:"object"` + Embedding []float64 `json:"embedding"` +} diff --git a/relay/model/general.go b/relay/model/general.go index fbcc04e8..30772894 100644 --- a/relay/model/general.go +++ b/relay/model/general.go @@ -5,25 +5,29 @@ type ResponseFormat struct { } type GeneralOpenAIRequest struct { - Model string `json:"model,omitempty"` Messages []Message `json:"messages,omitempty"` - Prompt any `json:"prompt,omitempty"` - Stream bool `json:"stream,omitempty"` - MaxTokens int `json:"max_tokens,omitempty"` - Temperature float64 `json:"temperature,omitempty"` - TopP float64 `json:"top_p,omitempty"` - N int `json:"n,omitempty"` - Input any `json:"input,omitempty"` - Instruction string `json:"instruction,omitempty"` - Size string `json:"size,omitempty"` - Functions any `json:"functions,omitempty"` + Model string `json:"model,omitempty"` FrequencyPenalty float64 `json:"frequency_penalty,omitempty"` + MaxTokens int `json:"max_tokens,omitempty"` + N int `json:"n,omitempty"` PresencePenalty float64 `json:"presence_penalty,omitempty"` ResponseFormat *ResponseFormat `json:"response_format,omitempty"` Seed float64 `json:"seed,omitempty"` - Tools any `json:"tools,omitempty"` + Stream bool `json:"stream,omitempty"` + Temperature float64 `json:"temperature,omitempty"` + TopP float64 `json:"top_p,omitempty"` + TopK int `json:"top_k,omitempty"` + Tools []Tool `json:"tools,omitempty"` ToolChoice any `json:"tool_choice,omitempty"` + FunctionCall any `json:"function_call,omitempty"` + Functions any `json:"functions,omitempty"` User string `json:"user,omitempty"` + Prompt any `json:"prompt,omitempty"` + Input any `json:"input,omitempty"` + EncodingFormat string `json:"encoding_format,omitempty"` + Dimensions int `json:"dimensions,omitempty"` + Instruction string `json:"instruction,omitempty"` + Size string `json:"size,omitempty"` } func (r GeneralOpenAIRequest) ParseInput() []string { diff --git a/relay/model/message.go b/relay/model/message.go index c6c8a271..32a1055b 100644 --- a/relay/model/message.go +++ b/relay/model/message.go @@ -1,9 +1,10 @@ package model type Message struct { - Role string `json:"role"` - Content any `json:"content"` - Name *string `json:"name,omitempty"` + Role string `json:"role,omitempty"` + Content any `json:"content,omitempty"` + Name *string `json:"name,omitempty"` + ToolCalls []Tool `json:"tool_calls,omitempty"` } func (m Message) IsStringContent() bool { diff --git a/relay/model/tool.go b/relay/model/tool.go new file mode 100644 index 00000000..253dca35 --- /dev/null +++ b/relay/model/tool.go @@ -0,0 +1,14 @@ +package model + +type Tool struct { + Id string `json:"id,omitempty"` + Type string `json:"type"` + Function Function `json:"function"` +} + +type Function struct { + Description string `json:"description,omitempty"` + Name string `json:"name"` + Parameters any `json:"parameters,omitempty"` // request + Arguments any `json:"arguments,omitempty"` // response +} diff --git a/relay/util/common.go b/relay/util/common.go index 535ef680..5d787204 100644 --- a/relay/util/common.go +++ b/relay/util/common.go @@ -46,6 +46,15 @@ func ShouldDisableChannel(err *relaymodel.Error, statusCode int) bool { } else if strings.HasPrefix(err.Message, "This organization has been disabled.") { return true } + //if strings.Contains(err.Message, "quota") { + // return true + //} + if strings.Contains(err.Message, "credit") { + return true + } + if strings.Contains(err.Message, "balance") { + return true + } return false } diff --git a/router/api-router.go b/router/api-router.go index 5b755ede..a36232b3 100644 --- a/router/api-router.go +++ b/router/api-router.go @@ -2,6 +2,7 @@ package router import ( "github.com/songquanpeng/one-api/controller" + "github.com/songquanpeng/one-api/controller/auth" "github.com/songquanpeng/one-api/middleware" "github.com/gin-contrib/gzip" @@ -21,11 +22,13 @@ func SetApiRouter(router *gin.Engine) { apiRouter.GET("/verification", middleware.CriticalRateLimit(), middleware.TurnstileCheck(), controller.SendEmailVerification) apiRouter.GET("/reset_password", middleware.CriticalRateLimit(), middleware.TurnstileCheck(), controller.SendPasswordResetEmail) apiRouter.POST("/user/reset", middleware.CriticalRateLimit(), controller.ResetPassword) - apiRouter.GET("/oauth/github", middleware.CriticalRateLimit(), controller.GitHubOAuth) - apiRouter.GET("/oauth/state", middleware.CriticalRateLimit(), controller.GenerateOAuthCode) - apiRouter.GET("/oauth/wechat", middleware.CriticalRateLimit(), controller.WeChatAuth) - apiRouter.GET("/oauth/wechat/bind", middleware.CriticalRateLimit(), middleware.UserAuth(), controller.WeChatBind) + apiRouter.GET("/oauth/github", middleware.CriticalRateLimit(), auth.GitHubOAuth) + apiRouter.GET("/oauth/lark", middleware.CriticalRateLimit(), auth.LarkOAuth) + apiRouter.GET("/oauth/state", middleware.CriticalRateLimit(), auth.GenerateOAuthCode) + apiRouter.GET("/oauth/wechat", middleware.CriticalRateLimit(), auth.WeChatAuth) + apiRouter.GET("/oauth/wechat/bind", middleware.CriticalRateLimit(), middleware.UserAuth(), auth.WeChatBind) apiRouter.GET("/oauth/email/bind", middleware.CriticalRateLimit(), middleware.UserAuth(), controller.EmailBind) + apiRouter.POST("/topup", middleware.AdminAuth(), controller.AdminTopUp) userRoute := apiRouter.Group("/user") { @@ -43,6 +46,7 @@ func SetApiRouter(router *gin.Engine) { selfRoute.GET("/token", controller.GenerateAccessToken) selfRoute.GET("/aff", controller.GetAffCode) selfRoute.POST("/topup", controller.TopUp) + selfRoute.GET("/available_models", controller.GetUserAvailableModels) } adminRoute := userRoute.Group("/") diff --git a/web/README.md b/web/README.md index 29f4713e..829271e2 100644 --- a/web/README.md +++ b/web/README.md @@ -2,6 +2,9 @@ > 每个文件夹代表一个主题,欢迎提交你的主题 +> [!WARNING] +> 不是每一个主题都及时同步了所有功能,由于精力有限,优先更新默认主题,其他主题欢迎 & 期待 PR + ## 提交新的主题 > 欢迎在页面底部保留你和 One API 的版权信息以及指向链接 diff --git a/web/berry/src/constants/ChannelConstants.js b/web/berry/src/constants/ChannelConstants.js index ec049f7d..b74c58c7 100644 --- a/web/berry/src/constants/ChannelConstants.js +++ b/web/berry/src/constants/ChannelConstants.js @@ -107,6 +107,12 @@ export const CHANNEL_OPTIONS = { value: 31, color: 'primary' }, + 32: { + key: 32, + text: '阶跃星辰', + value: 32, + color: 'primary' + }, 8: { key: 8, text: '自定义渠道', diff --git a/web/berry/src/constants/SnackbarConstants.js b/web/berry/src/constants/SnackbarConstants.js index a05c6652..19523da1 100644 --- a/web/berry/src/constants/SnackbarConstants.js +++ b/web/berry/src/constants/SnackbarConstants.js @@ -18,7 +18,7 @@ export const snackbarConstants = { }, NOTICE: { variant: 'info', - autoHideDuration: 20000 + autoHideDuration: 7000 } }, Mobile: { diff --git a/web/berry/src/utils/common.js b/web/berry/src/utils/common.js index 25e5c635..d8dabac3 100644 --- a/web/berry/src/utils/common.js +++ b/web/berry/src/utils/common.js @@ -51,9 +51,9 @@ export function showError(error) { export function showNotice(message, isHTML = false) { if (isHTML) { - enqueueSnackbar(, getSnackbarOptions('INFO')); + enqueueSnackbar(, getSnackbarOptions('NOTICE')); } else { - enqueueSnackbar(message, getSnackbarOptions('INFO')); + enqueueSnackbar(message, getSnackbarOptions('NOTICE')); } } diff --git a/web/berry/src/views/Channel/component/EditModal.js b/web/berry/src/views/Channel/component/EditModal.js index 07111c97..cbf411b9 100644 --- a/web/berry/src/views/Channel/component/EditModal.js +++ b/web/berry/src/views/Channel/component/EditModal.js @@ -340,7 +340,9 @@ const EditModal = ({ open, channelId, onCancel, onOk }) => { }, }} > - {Object.values(CHANNEL_OPTIONS).map((option) => { + {Object.values(CHANNEL_OPTIONS).sort((a, b) => { + return a.text.localeCompare(b.text) + }).map((option) => { return ( {option.text} diff --git a/web/berry/src/views/Token/component/EditModal.js b/web/berry/src/views/Token/component/EditModal.js index cf65a96a..a5e08be1 100644 --- a/web/berry/src/views/Token/component/EditModal.js +++ b/web/berry/src/views/Token/component/EditModal.js @@ -103,7 +103,7 @@ const EditModal = ({ open, tokenId, onCancel, onOk }) => { fontSize: "1.125rem", }} > - {tokenId ? "编辑Token" : "新建Token"} + {tokenId ? "编辑令牌" : "新建令牌"} diff --git a/web/default/src/App.js b/web/default/src/App.js index 13c884dc..4ece4eeb 100644 --- a/web/default/src/App.js +++ b/web/default/src/App.js @@ -24,6 +24,7 @@ import EditRedemption from './pages/Redemption/EditRedemption'; import TopUp from './pages/TopUp'; import Log from './pages/Log'; import Chat from './pages/Chat'; +import LarkOAuth from './components/LarkOAuth'; const Home = lazy(() => import('./pages/Home')); const About = lazy(() => import('./pages/About')); @@ -239,6 +240,14 @@ function App() { } /> + }> + + + } + /> { + const [searchParams, setSearchParams] = useSearchParams(); + + const [userState, userDispatch] = useContext(UserContext); + const [prompt, setPrompt] = useState('处理中...'); + const [processing, setProcessing] = useState(true); + + let navigate = useNavigate(); + + const sendCode = async (code, state, count) => { + const res = await API.get(`/api/oauth/lark?code=${code}&state=${state}`); + const { success, message, data } = res.data; + if (success) { + if (message === 'bind') { + showSuccess('绑定成功!'); + navigate('/setting'); + } else { + userDispatch({ type: 'login', payload: data }); + localStorage.setItem('user', JSON.stringify(data)); + showSuccess('登录成功!'); + navigate('/'); + } + } else { + showError(message); + if (count === 0) { + setPrompt(`操作失败,重定向至登录界面中...`); + navigate('/setting'); // in case this is failed to bind lark + return; + } + count++; + setPrompt(`出现错误,第 ${count} 次重试中...`); + await new Promise((resolve) => setTimeout(resolve, count * 2000)); + await sendCode(code, state, count); + } + }; + + useEffect(() => { + let code = searchParams.get('code'); + let state = searchParams.get('state'); + sendCode(code, state, 0).then(); + }, []); + + return ( + + + {prompt} + + + ); +}; + +export default LarkOAuth; diff --git a/web/default/src/components/LoginForm.js b/web/default/src/components/LoginForm.js index b48f64c4..01408f56 100644 --- a/web/default/src/components/LoginForm.js +++ b/web/default/src/components/LoginForm.js @@ -3,7 +3,8 @@ import { Button, Divider, Form, Grid, Header, Image, Message, Modal, Segment } f import { Link, useNavigate, useSearchParams } from 'react-router-dom'; import { UserContext } from '../context/User'; import { API, getLogo, showError, showSuccess, showWarning } from '../helpers'; -import { onGitHubOAuthClicked } from './utils'; +import { onGitHubOAuthClicked, onLarkOAuthClicked } from './utils'; +import larkIcon from '../images/lark.svg'; const LoginForm = () => { const [inputs, setInputs] = useState({ @@ -124,7 +125,7 @@ const LoginForm = () => { 点击注册 - {status.github_oauth || status.wechat_login ? ( + {status.github_oauth || status.wechat_login || status.lark_client_id ? ( <> Or {status.github_oauth ? ( @@ -137,6 +138,18 @@ const LoginForm = () => { ) : ( <> )} + {status.lark_client_id ? ( + + ) : ( + <> + )} {status.wechat_login ? ( ) } + { + status.lark_client_id && ( + + ) + }