diff --git a/.gitignore b/.gitignore
index 974fcf63..2a8ae16e 100644
--- a/.gitignore
+++ b/.gitignore
@@ -8,3 +8,4 @@ build
logs
data
/web/node_modules
+cmd.md
\ No newline at end of file
diff --git a/README.md b/README.md
index 2dcdbd4f..d5c939be 100644
--- a/README.md
+++ b/README.md
@@ -81,11 +81,12 @@ _✨ 通过标准的 OpenAI API 格式访问所有的大模型,开箱即用
+ [x] [Groq](https://wow.groq.com/)
+ [x] [Ollama](https://github.com/ollama/ollama)
+ [x] [零一万物](https://platform.lingyiwanwu.com/)
+ + [x] [阶跃星辰](https://platform.stepfun.com/)
2. 支持配置镜像以及众多[第三方代理服务](https://iamazing.cn/page/openai-api-third-party-services)。
3. 支持通过**负载均衡**的方式访问多个渠道。
4. 支持 **stream 模式**,可以通过流式传输实现打字机效果。
5. 支持**多机部署**,[详见此处](#多机部署)。
-6. 支持**令牌管理**,设置令牌的过期时间和额度。
+6. 支持**令牌管理**,设置令牌的过期时间、额度、允许的 IP 范围以及允许的模型访问。
7. 支持**兑换码管理**,支持批量生成和导出兑换码,可使用兑换码为账户进行充值。
8. 支持**渠道管理**,批量创建渠道。
9. 支持**用户分组**以及**渠道分组**,支持为不同分组设置不同的倍率。
@@ -101,10 +102,11 @@ _✨ 通过标准的 OpenAI API 格式访问所有的大模型,开箱即用
19. 支持丰富的**自定义**设置,
1. 支持自定义系统名称,logo 以及页脚。
2. 支持自定义首页和关于页面,可以选择使用 HTML & Markdown 代码进行自定义,或者使用一个单独的网页通过 iframe 嵌入。
-20. 支持通过系统访问令牌访问管理 API(bearer token,用以替代 cookie,你可以自行抓包来查看 API 的用法)。
+20. 支持通过系统访问令牌调用管理 API,进而**在无需二开的情况下扩展和自定义** One API 的功能,详情请参考此处 [API 文档](./docs/API.md)。。
21. 支持 Cloudflare Turnstile 用户校验。
22. 支持用户管理,支持**多种用户登录注册方式**:
+ 邮箱登录注册(支持注册邮箱白名单)以及通过邮箱进行密码重置。
+ + 支持使用飞书进行授权登录。
+ [GitHub 开放授权](https://github.com/settings/applications/new)。
+ 微信公众号授权(需要额外部署 [WeChat Server](https://github.com/songquanpeng/wechat-server))。
23. 支持主题切换,设置环境变量 `THEME` 即可,默认为 `default`,欢迎 PR 更多主题,具体参考[此处](./web/README.md)。
diff --git a/common/config/config.go b/common/config/config.go
index 3524183a..9fd7cba0 100644
--- a/common/config/config.go
+++ b/common/config/config.go
@@ -66,6 +66,9 @@ var SMTPToken = ""
var GitHubClientId = ""
var GitHubClientSecret = ""
+var LarkClientId = ""
+var LarkClientSecret = ""
+
var WeChatServerAddress = ""
var WeChatServerToken = ""
var WeChatAccountQRCodeImageURL = ""
diff --git a/common/constants.go b/common/constants.go
index 849bdce7..04a56649 100644
--- a/common/constants.go
+++ b/common/constants.go
@@ -71,6 +71,7 @@ const (
ChannelTypeGroq
ChannelTypeOllama
ChannelTypeLingYiWanWu
+ ChannelTypeStepFun
ChannelTypeDummy
)
@@ -108,6 +109,7 @@ var ChannelBaseURLs = []string{
"https://api.groq.com/openai", // 29
"http://localhost:11434", // 30
"https://api.lingyiwanwu.com", // 31
+ "https://api.stepfun.com", // 32
}
const (
diff --git a/common/conv/any.go b/common/conv/any.go
new file mode 100644
index 00000000..467e8bb7
--- /dev/null
+++ b/common/conv/any.go
@@ -0,0 +1,6 @@
+package conv
+
+func AsString(v any) string {
+ str, _ := v.(string)
+ return str
+}
diff --git a/common/model-ratio.go b/common/model-ratio.go
index b96f8d21..c8a2b5b8 100644
--- a/common/model-ratio.go
+++ b/common/model-ratio.go
@@ -72,14 +72,22 @@ var ModelRatio = map[string]float64{
"claude-3-sonnet-20240229": 3.0 / 1000 * USD,
"claude-3-opus-20240229": 15.0 / 1000 * USD,
// https://cloud.baidu.com/doc/WENXINWORKSHOP/s/hlrk4akp7
- "ERNIE-Bot": 0.8572, // ¥0.012 / 1k tokens
- "ERNIE-Bot-turbo": 0.5715, // ¥0.008 / 1k tokens
- "ERNIE-Bot-4": 0.12 * RMB, // ¥0.12 / 1k tokens
- "ERNIE-Bot-8k": 0.024 * RMB,
- "Embedding-V1": 0.1429, // ¥0.002 / 1k tokens
- "bge-large-zh": 0.002 * RMB,
- "bge-large-en": 0.002 * RMB,
- "bge-large-8k": 0.002 * RMB,
+ "ERNIE-4.0-8K": 0.120 * RMB,
+ "ERNIE-3.5-8K": 0.012 * RMB,
+ "ERNIE-3.5-8K-0205": 0.024 * RMB,
+ "ERNIE-3.5-8K-1222": 0.012 * RMB,
+ "ERNIE-Bot-8K": 0.024 * RMB,
+ "ERNIE-3.5-4K-0205": 0.012 * RMB,
+ "ERNIE-Speed-8K": 0.004 * RMB,
+ "ERNIE-Speed-128K": 0.004 * RMB,
+ "ERNIE-Lite-8K-0922": 0.008 * RMB,
+ "ERNIE-Lite-8K-0308": 0.003 * RMB,
+ "ERNIE-Tiny-8K": 0.001 * RMB,
+ "BLOOMZ-7B": 0.004 * RMB,
+ "Embedding-V1": 0.002 * RMB,
+ "bge-large-zh": 0.002 * RMB,
+ "bge-large-en": 0.002 * RMB,
+ "tao-8k": 0.002 * RMB,
// https://ai.google.dev/pricing
"PaLM-2": 1,
"gemini-pro": 1, // $0.00025 / 1k characters -> $0.001 / 1k tokens
@@ -88,13 +96,14 @@ var ModelRatio = map[string]float64{
"gemini-1.0-pro-001": 1,
"gemini-1.5-pro": 1,
// https://open.bigmodel.cn/pricing
- "glm-4": 0.1 * RMB,
- "glm-4v": 0.1 * RMB,
- "glm-3-turbo": 0.005 * RMB,
- "chatglm_turbo": 0.3572, // ¥0.005 / 1k tokens
- "chatglm_pro": 0.7143, // ¥0.01 / 1k tokens
- "chatglm_std": 0.3572, // ¥0.005 / 1k tokens
- "chatglm_lite": 0.1429, // ¥0.002 / 1k tokens
+ "glm-4": 0.1 * RMB,
+ "glm-4v": 0.1 * RMB,
+ "glm-3-turbo": 0.005 * RMB,
+ "embedding-2": 0.0005 * RMB,
+ "chatglm_turbo": 0.3572, // ¥0.005 / 1k tokens
+ "chatglm_pro": 0.7143, // ¥0.01 / 1k tokens
+ "chatglm_std": 0.3572, // ¥0.005 / 1k tokens
+ "chatglm_lite": 0.1429, // ¥0.002 / 1k tokens
// https://help.aliyun.com/zh/dashscope/developer-reference/tongyi-thousand-questions-metering-and-billing
"qwen-turbo": 0.5715, // ¥0.008 / 1k tokens
"qwen-plus": 1.4286, // ¥0.02 / 1k tokens
diff --git a/common/network/ip.go b/common/network/ip.go
new file mode 100644
index 00000000..0fbe5e6f
--- /dev/null
+++ b/common/network/ip.go
@@ -0,0 +1,52 @@
+package network
+
+import (
+ "context"
+ "fmt"
+ "github.com/songquanpeng/one-api/common/logger"
+ "net"
+ "strings"
+)
+
+func splitSubnets(subnets string) []string {
+ res := strings.Split(subnets, ",")
+ for i := 0; i < len(res); i++ {
+ res[i] = strings.TrimSpace(res[i])
+ }
+ return res
+}
+
+func isValidSubnet(subnet string) error {
+ _, _, err := net.ParseCIDR(subnet)
+ if err != nil {
+ return fmt.Errorf("failed to parse subnet: %w", err)
+ }
+ return nil
+}
+
+func isIpInSubnet(ctx context.Context, ip string, subnet string) bool {
+ _, ipNet, err := net.ParseCIDR(subnet)
+ if err != nil {
+ logger.Errorf(ctx, "failed to parse subnet: %s", err.Error())
+ return false
+ }
+ return ipNet.Contains(net.ParseIP(ip))
+}
+
+func IsValidSubnets(subnets string) error {
+ for _, subnet := range splitSubnets(subnets) {
+ if err := isValidSubnet(subnet); err != nil {
+ return err
+ }
+ }
+ return nil
+}
+
+func IsIpInSubnets(ctx context.Context, ip string, subnets string) bool {
+ for _, subnet := range splitSubnets(subnets) {
+ if isIpInSubnet(ctx, ip, subnet) {
+ return true
+ }
+ }
+ return false
+}
diff --git a/common/network/ip_test.go b/common/network/ip_test.go
new file mode 100644
index 00000000..6c593458
--- /dev/null
+++ b/common/network/ip_test.go
@@ -0,0 +1,19 @@
+package network
+
+import (
+ "context"
+ "testing"
+
+ . "github.com/smartystreets/goconvey/convey"
+)
+
+func TestIsIpInSubnet(t *testing.T) {
+ ctx := context.Background()
+ ip1 := "192.168.0.5"
+ ip2 := "125.216.250.89"
+ subnet := "192.168.0.0/24"
+ Convey("TestIsIpInSubnet", t, func() {
+ So(isIpInSubnet(ctx, ip1, subnet), ShouldBeTrue)
+ So(isIpInSubnet(ctx, ip2, subnet), ShouldBeFalse)
+ })
+}
diff --git a/controller/github.go b/controller/auth/github.go
similarity index 98%
rename from controller/github.go
rename to controller/auth/github.go
index 7d7fa106..cf073133 100644
--- a/controller/github.go
+++ b/controller/auth/github.go
@@ -1,4 +1,4 @@
-package controller
+package auth
import (
"bytes"
@@ -11,6 +11,7 @@ import (
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/common/logger"
+ "github.com/songquanpeng/one-api/controller"
"github.com/songquanpeng/one-api/model"
"net/http"
"strconv"
@@ -159,7 +160,7 @@ func GitHubOAuth(c *gin.Context) {
})
return
}
- setupLogin(&user, c)
+ controller.SetupLogin(&user, c)
}
func GitHubBind(c *gin.Context) {
diff --git a/controller/auth/lark.go b/controller/auth/lark.go
new file mode 100644
index 00000000..21446d46
--- /dev/null
+++ b/controller/auth/lark.go
@@ -0,0 +1,201 @@
+package auth
+
+import (
+ "bytes"
+ "encoding/json"
+ "errors"
+ "fmt"
+ "github.com/gin-contrib/sessions"
+ "github.com/gin-gonic/gin"
+ "github.com/songquanpeng/one-api/common"
+ "github.com/songquanpeng/one-api/common/config"
+ "github.com/songquanpeng/one-api/common/logger"
+ "github.com/songquanpeng/one-api/controller"
+ "github.com/songquanpeng/one-api/model"
+ "net/http"
+ "strconv"
+ "time"
+)
+
+type LarkOAuthResponse struct {
+ AccessToken string `json:"access_token"`
+}
+
+type LarkUser struct {
+ Name string `json:"name"`
+ OpenID string `json:"open_id"`
+}
+
+func getLarkUserInfoByCode(code string) (*LarkUser, error) {
+ if code == "" {
+ return nil, errors.New("无效的参数")
+ }
+ values := map[string]string{
+ "client_id": config.LarkClientId,
+ "client_secret": config.LarkClientSecret,
+ "code": code,
+ "grant_type": "authorization_code",
+ "redirect_uri": fmt.Sprintf("%s/oauth/lark", config.ServerAddress),
+ }
+ jsonData, err := json.Marshal(values)
+ if err != nil {
+ return nil, err
+ }
+ req, err := http.NewRequest("POST", "https://passport.feishu.cn/suite/passport/oauth/token", bytes.NewBuffer(jsonData))
+ if err != nil {
+ return nil, err
+ }
+ req.Header.Set("Content-Type", "application/json")
+ req.Header.Set("Accept", "application/json")
+ client := http.Client{
+ Timeout: 5 * time.Second,
+ }
+ res, err := client.Do(req)
+ if err != nil {
+ logger.SysLog(err.Error())
+ return nil, errors.New("无法连接至飞书服务器,请稍后重试!")
+ }
+ defer res.Body.Close()
+ var oAuthResponse LarkOAuthResponse
+ err = json.NewDecoder(res.Body).Decode(&oAuthResponse)
+ if err != nil {
+ return nil, err
+ }
+ req, err = http.NewRequest("GET", "https://passport.feishu.cn/suite/passport/oauth/userinfo", nil)
+ if err != nil {
+ return nil, err
+ }
+ req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", oAuthResponse.AccessToken))
+ res2, err := client.Do(req)
+ if err != nil {
+ logger.SysLog(err.Error())
+ return nil, errors.New("无法连接至飞书服务器,请稍后重试!")
+ }
+ var larkUser LarkUser
+ err = json.NewDecoder(res2.Body).Decode(&larkUser)
+ if err != nil {
+ return nil, err
+ }
+ return &larkUser, nil
+}
+
+func LarkOAuth(c *gin.Context) {
+ session := sessions.Default(c)
+ state := c.Query("state")
+ if state == "" || session.Get("oauth_state") == nil || state != session.Get("oauth_state").(string) {
+ c.JSON(http.StatusForbidden, gin.H{
+ "success": false,
+ "message": "state is empty or not same",
+ })
+ return
+ }
+ username := session.Get("username")
+ if username != nil {
+ LarkBind(c)
+ return
+ }
+ code := c.Query("code")
+ larkUser, err := getLarkUserInfoByCode(code)
+ if err != nil {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": err.Error(),
+ })
+ return
+ }
+ user := model.User{
+ LarkId: larkUser.OpenID,
+ }
+ if model.IsLarkIdAlreadyTaken(user.LarkId) {
+ err := user.FillUserByLarkId()
+ if err != nil {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": err.Error(),
+ })
+ return
+ }
+ } else {
+ if config.RegisterEnabled {
+ user.Username = "lark_" + strconv.Itoa(model.GetMaxUserId()+1)
+ if larkUser.Name != "" {
+ user.DisplayName = larkUser.Name
+ } else {
+ user.DisplayName = "Lark User"
+ }
+ user.Role = common.RoleCommonUser
+ user.Status = common.UserStatusEnabled
+
+ if err := user.Insert(0); err != nil {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": err.Error(),
+ })
+ return
+ }
+ } else {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": "管理员关闭了新用户注册",
+ })
+ return
+ }
+ }
+
+ if user.Status != common.UserStatusEnabled {
+ c.JSON(http.StatusOK, gin.H{
+ "message": "用户已被封禁",
+ "success": false,
+ })
+ return
+ }
+ controller.SetupLogin(&user, c)
+}
+
+func LarkBind(c *gin.Context) {
+ code := c.Query("code")
+ larkUser, err := getLarkUserInfoByCode(code)
+ if err != nil {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": err.Error(),
+ })
+ return
+ }
+ user := model.User{
+ LarkId: larkUser.OpenID,
+ }
+ if model.IsLarkIdAlreadyTaken(user.LarkId) {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": "该飞书账户已被绑定",
+ })
+ return
+ }
+ session := sessions.Default(c)
+ id := session.Get("id")
+ // id := c.GetInt("id") // critical bug!
+ user.Id = id.(int)
+ err = user.FillUserById()
+ if err != nil {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": err.Error(),
+ })
+ return
+ }
+ user.LarkId = larkUser.OpenID
+ err = user.Update(false)
+ if err != nil {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": err.Error(),
+ })
+ return
+ }
+ c.JSON(http.StatusOK, gin.H{
+ "success": true,
+ "message": "bind",
+ })
+ return
+}
diff --git a/controller/wechat.go b/controller/auth/wechat.go
similarity index 97%
rename from controller/wechat.go
rename to controller/auth/wechat.go
index 74be5604..80552c9a 100644
--- a/controller/wechat.go
+++ b/controller/auth/wechat.go
@@ -1,4 +1,4 @@
-package controller
+package auth
import (
"encoding/json"
@@ -7,6 +7,7 @@ import (
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/config"
+ "github.com/songquanpeng/one-api/controller"
"github.com/songquanpeng/one-api/model"
"net/http"
"strconv"
@@ -109,7 +110,7 @@ func WeChatAuth(c *gin.Context) {
})
return
}
- setupLogin(&user, c)
+ controller.SetupLogin(&user, c)
}
func WeChatBind(c *gin.Context) {
diff --git a/controller/misc.go b/controller/misc.go
index f27fdb12..2928b8fb 100644
--- a/controller/misc.go
+++ b/controller/misc.go
@@ -23,6 +23,7 @@ func GetStatus(c *gin.Context) {
"email_verification": config.EmailVerificationEnabled,
"github_oauth": config.GitHubOAuthEnabled,
"github_client_id": config.GitHubClientId,
+ "lark_client_id": config.LarkClientId,
"system_name": config.SystemName,
"logo": config.Logo,
"footer_html": config.Footer,
diff --git a/controller/model.go b/controller/model.go
index 4c5476b4..43e73c6c 100644
--- a/controller/model.go
+++ b/controller/model.go
@@ -4,12 +4,14 @@ import (
"fmt"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common"
+ "github.com/songquanpeng/one-api/model"
"github.com/songquanpeng/one-api/relay/channel/openai"
"github.com/songquanpeng/one-api/relay/constant"
"github.com/songquanpeng/one-api/relay/helper"
relaymodel "github.com/songquanpeng/one-api/relay/model"
"github.com/songquanpeng/one-api/relay/util"
"net/http"
+ "strings"
)
// https://platform.openai.com/docs/api-reference/models/list
@@ -120,9 +122,41 @@ func DashboardListModels(c *gin.Context) {
}
func ListModels(c *gin.Context) {
+ ctx := c.Request.Context()
+ var availableModels []string
+ if c.GetString("available_models") != "" {
+ availableModels = strings.Split(c.GetString("available_models"), ",")
+ } else {
+ userId := c.GetInt("id")
+ userGroup, _ := model.CacheGetUserGroup(userId)
+ availableModels, _ = model.CacheGetGroupModels(ctx, userGroup)
+ }
+ modelSet := make(map[string]bool)
+ for _, availableModel := range availableModels {
+ modelSet[availableModel] = true
+ }
+ availableOpenAIModels := make([]OpenAIModels, 0)
+ for _, model := range openAIModels {
+ if _, ok := modelSet[model.Id]; ok {
+ modelSet[model.Id] = false
+ availableOpenAIModels = append(availableOpenAIModels, model)
+ }
+ }
+ for modelName, ok := range modelSet {
+ if ok {
+ availableOpenAIModels = append(availableOpenAIModels, OpenAIModels{
+ Id: modelName,
+ Object: "model",
+ Created: 1626777600,
+ OwnedBy: "custom",
+ Root: modelName,
+ Parent: nil,
+ })
+ }
+ }
c.JSON(200, gin.H{
"object": "list",
- "data": openAIModels,
+ "data": availableOpenAIModels,
})
}
@@ -142,3 +176,30 @@ func RetrieveModel(c *gin.Context) {
})
}
}
+
+func GetUserAvailableModels(c *gin.Context) {
+ ctx := c.Request.Context()
+ id := c.GetInt("id")
+ userGroup, err := model.CacheGetUserGroup(id)
+ if err != nil {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": err.Error(),
+ })
+ return
+ }
+ models, err := model.CacheGetGroupModels(ctx, userGroup)
+ if err != nil {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": err.Error(),
+ })
+ return
+ }
+ c.JSON(http.StatusOK, gin.H{
+ "success": true,
+ "message": "",
+ "data": models,
+ })
+ return
+}
diff --git a/controller/token.go b/controller/token.go
index 949931da..7d20371c 100644
--- a/controller/token.go
+++ b/controller/token.go
@@ -1,10 +1,12 @@
package controller
import (
+ "fmt"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/helper"
+ "github.com/songquanpeng/one-api/common/network"
"github.com/songquanpeng/one-api/model"
"net/http"
"strconv"
@@ -104,6 +106,19 @@ func GetTokenStatus(c *gin.Context) {
})
}
+func validateToken(c *gin.Context, token model.Token) error {
+ if len(token.Name) > 30 {
+ return fmt.Errorf("令牌名称过长")
+ }
+ if token.Subnet != nil && *token.Subnet != "" {
+ err := network.IsValidSubnets(*token.Subnet)
+ if err != nil {
+ return fmt.Errorf("无效的网段:%s", err.Error())
+ }
+ }
+ return nil
+}
+
func AddToken(c *gin.Context) {
token := model.Token{}
err := c.ShouldBindJSON(&token)
@@ -114,13 +129,15 @@ func AddToken(c *gin.Context) {
})
return
}
- if len(token.Name) > 30 {
+ err = validateToken(c, token)
+ if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
- "message": "令牌名称过长",
+ "message": fmt.Sprintf("参数错误:%s", err.Error()),
})
return
}
+
cleanToken := model.Token{
UserId: c.GetInt("id"),
Name: token.Name,
@@ -130,6 +147,8 @@ func AddToken(c *gin.Context) {
ExpiredTime: token.ExpiredTime,
RemainQuota: token.RemainQuota,
UnlimitedQuota: token.UnlimitedQuota,
+ Models: token.Models,
+ Subnet: token.Subnet,
}
err = cleanToken.Insert()
if err != nil {
@@ -177,10 +196,11 @@ func UpdateToken(c *gin.Context) {
})
return
}
- if len(token.Name) > 30 {
+ err = validateToken(c, token)
+ if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
- "message": "令牌名称过长",
+ "message": fmt.Sprintf("参数错误:%s", err.Error()),
})
return
}
@@ -216,6 +236,8 @@ func UpdateToken(c *gin.Context) {
cleanToken.ExpiredTime = token.ExpiredTime
cleanToken.RemainQuota = token.RemainQuota
cleanToken.UnlimitedQuota = token.UnlimitedQuota
+ cleanToken.Models = token.Models
+ cleanToken.Subnet = token.Subnet
}
err = cleanToken.Update()
if err != nil {
diff --git a/controller/user.go b/controller/user.go
index 8b614e5d..e87a03a2 100644
--- a/controller/user.go
+++ b/controller/user.go
@@ -58,11 +58,11 @@ func Login(c *gin.Context) {
})
return
}
- setupLogin(&user, c)
+ SetupLogin(&user, c)
}
// setup session & cookies and then return user info
-func setupLogin(user *model.User, c *gin.Context) {
+func SetupLogin(user *model.User, c *gin.Context) {
session := sessions.Default(c)
session.Set("id", user.Id)
session.Set("username", user.Username)
@@ -180,27 +180,27 @@ func Register(c *gin.Context) {
}
func GetAllUsers(c *gin.Context) {
- p, _ := strconv.Atoi(c.Query("p"))
- if p < 0 {
- p = 0
- }
-
- order := c.DefaultQuery("order", "")
- users, err := model.GetAllUsers(p*config.ItemsPerPage, config.ItemsPerPage, order)
-
- if err != nil {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": err.Error(),
- })
- return
- }
-
- c.JSON(http.StatusOK, gin.H{
- "success": true,
- "message": "",
- "data": users,
- })
+ p, _ := strconv.Atoi(c.Query("p"))
+ if p < 0 {
+ p = 0
+ }
+
+ order := c.DefaultQuery("order", "")
+ users, err := model.GetAllUsers(p*config.ItemsPerPage, config.ItemsPerPage, order)
+
+ if err != nil {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": err.Error(),
+ })
+ return
+ }
+
+ c.JSON(http.StatusOK, gin.H{
+ "success": true,
+ "message": "",
+ "data": users,
+ })
}
func SearchUsers(c *gin.Context) {
@@ -770,3 +770,38 @@ func TopUp(c *gin.Context) {
})
return
}
+
+type adminTopUpRequest struct {
+ UserId int `json:"user_id"`
+ Quota int `json:"quota"`
+ Remark string `json:"remark"`
+}
+
+func AdminTopUp(c *gin.Context) {
+ req := adminTopUpRequest{}
+ err := c.ShouldBindJSON(&req)
+ if err != nil {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": err.Error(),
+ })
+ return
+ }
+ err = model.IncreaseUserQuota(req.UserId, int64(req.Quota))
+ if err != nil {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": err.Error(),
+ })
+ return
+ }
+ if req.Remark == "" {
+ req.Remark = fmt.Sprintf("通过 API 充值 %s", common.LogQuota(int64(req.Quota)))
+ }
+ model.RecordTopupLog(req.UserId, req.Remark, req.Quota)
+ c.JSON(http.StatusOK, gin.H{
+ "success": true,
+ "message": "",
+ })
+ return
+}
diff --git a/docs/API.md b/docs/API.md
new file mode 100644
index 00000000..0b7ddf5a
--- /dev/null
+++ b/docs/API.md
@@ -0,0 +1,53 @@
+# 使用 API 操控 & 扩展 One API
+> 欢迎提交 PR 在此放上你的拓展项目。
+
+例如,虽然 One API 本身没有直接支持支付,但是你可以通过系统扩展的 API 来实现支付功能。
+
+又或者你想自定义渠道管理策略,也可以通过 API 来实现渠道的禁用与启用。
+
+## 鉴权
+One API 支持两种鉴权方式:Cookie 和 Token,对于 Token,参照下图获取:
+
+
+
+之后,将 Token 作为请求头的 Authorization 字段的值即可,例如下面使用 Token 调用测试渠道的 API:
+
+
+## 请求格式与响应格式
+One API 使用 JSON 格式进行请求和响应。
+
+对于响应体,一般格式如下:
+```json
+{
+ "message": "请求信息",
+ "success": true,
+ "data": {}
+}
+```
+
+## API 列表
+> 当前 API 列表不全,请自行通过浏览器抓取前端请求
+
+如果现有的 API 没有办法满足你的需求,欢迎提交 issue 讨论。
+
+### 获取当前登录用户信息
+**GET** `/api/user/self`
+
+### 为给定用户充值额度
+**POST** `/api/topup`
+```json
+{
+ "user_id": 1,
+ "quota": 100000,
+ "remark": "充值 100000 额度"
+}
+```
+
+## 其他
+### 充值链接上的附加参数
+One API 会在用户点击充值按钮的时候,将用户的信息和充值信息附加在链接上,例如:
+`https://example.com?username=root&user_id=1&transaction_id=4b3eed80-55d5-443f-bd44-fb18c648c837`
+
+你可以通过解析链接上的参数来获取用户信息和充值信息,然后调用 API 来为用户充值。
+
+注意,不是所有主题都支持该功能,欢迎 PR 补齐。
\ No newline at end of file
diff --git a/go.mod b/go.mod
index f9ed96d3..6ace51f2 100644
--- a/go.mod
+++ b/go.mod
@@ -15,6 +15,7 @@ require (
github.com/google/uuid v1.3.0
github.com/gorilla/websocket v1.5.0
github.com/pkoukk/tiktoken-go v0.1.5
+ github.com/smartystreets/goconvey v1.8.1
github.com/stretchr/testify v1.8.3
golang.org/x/crypto v0.17.0
golang.org/x/image v0.14.0
@@ -37,6 +38,7 @@ require (
github.com/go-playground/universal-translator v0.18.1 // indirect
github.com/go-sql-driver/mysql v1.6.0 // indirect
github.com/goccy/go-json v0.10.2 // indirect
+ github.com/gopherjs/gopherjs v1.17.2 // indirect
github.com/gorilla/context v1.1.1 // indirect
github.com/gorilla/securecookie v1.1.1 // indirect
github.com/gorilla/sessions v1.2.1 // indirect
@@ -47,6 +49,7 @@ require (
github.com/jinzhu/inflection v1.0.0 // indirect
github.com/jinzhu/now v1.1.5 // indirect
github.com/json-iterator/go v1.1.12 // indirect
+ github.com/jtolds/gls v4.20.0+incompatible // indirect
github.com/klauspost/cpuid/v2 v2.2.4 // indirect
github.com/leodido/go-urn v1.2.4 // indirect
github.com/mattn/go-isatty v0.0.19 // indirect
@@ -55,6 +58,7 @@ require (
github.com/modern-go/reflect2 v1.0.2 // indirect
github.com/pelletier/go-toml/v2 v2.0.8 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
+ github.com/smarty/assertions v1.15.0 // indirect
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
github.com/ugorji/go/codec v1.2.11 // indirect
golang.org/x/arch v0.3.0 // indirect
diff --git a/go.sum b/go.sum
index 9cf056e5..3ead2711 100644
--- a/go.sum
+++ b/go.sum
@@ -56,11 +56,13 @@ github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keL
github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I=
github.com/golang/protobuf v1.3.3/go.mod h1:vzj43D7+SQXF/4pzW/hwtAqwc6iTitCiVSaWz5lYuqw=
github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk=
-github.com/google/go-cmp v0.5.5 h1:Khx7svrCpmxxtHBq5j2mp/xVjsi8hQMfNLvJFAlrGgU=
github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
+github.com/google/go-cmp v0.5.6 h1:BKbKCqvP6I+rmFHt06ZmyQtvB8xAkWdhFyr0ZUNZcxQ=
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I=
github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
+github.com/gopherjs/gopherjs v1.17.2 h1:fQnZVsXk8uxXIStYb0N4bGk7jeyTalG/wsZjQ25dO0g=
+github.com/gopherjs/gopherjs v1.17.2/go.mod h1:pRRIvn/QzFLrKfvEz3qUuEhtE/zLCWfreZ6J5gM2i+k=
github.com/gorilla/context v1.1.1 h1:AWwleXJkX/nhcU9bZSnZoi3h/qGYqQAGhq6zZe/aQW8=
github.com/gorilla/context v1.1.1/go.mod h1:kBGZzfjB9CEq2AlWe17Uuf7NDRt0dE0s8S51q0aT7Yg=
github.com/gorilla/securecookie v1.1.1 h1:miw7JPhV+b/lAHSXz4qd/nN9jRiAFV5FwjeKyCS8BvQ=
@@ -85,6 +87,8 @@ github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/
github.com/json-iterator/go v1.1.9/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4=
github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=
github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo=
+github.com/jtolds/gls v4.20.0+incompatible h1:xdiiI2gbIgH/gLH7ADydsJ1uDOEzR8yvV7C0MuV77Wo=
+github.com/jtolds/gls v4.20.0+incompatible/go.mod h1:QJZ7F/aHp+rZTRtaJ1ow/lLfFfVYBRgL+9YlvaHOwJU=
github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg=
github.com/klauspost/cpuid/v2 v2.2.4 h1:acbojRNwl3o09bUq+yDCtZFc1aiwaAAxtcn8YkZXnvk=
github.com/klauspost/cpuid/v2 v2.2.4/go.mod h1:RVVoqg1df56z8g3pUjL/3lE5UfnlrJX8tyFgg4nqhuY=
@@ -127,6 +131,10 @@ github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZN
github.com/rogpeppe/go-internal v1.6.1/go.mod h1:xXDCJY+GAPziupqXw64V24skbSoqbTEfhy4qGm1nDQc=
github.com/rogpeppe/go-internal v1.8.0 h1:FCbCCtXNOY3UtUuHUYaghJg4y7Fd14rXifAYUAtL9R8=
github.com/rogpeppe/go-internal v1.8.0/go.mod h1:WmiCO8CzOY8rg0OYDC4/i/2WRWAB6poM+XZ2dLUbcbE=
+github.com/smarty/assertions v1.15.0 h1:cR//PqUBUiQRakZWqBiFFQ9wb8emQGDb0HeGdqGByCY=
+github.com/smarty/assertions v1.15.0/go.mod h1:yABtdzeQs6l1brC900WlRNwj6ZR55d7B+E8C6HtKdec=
+github.com/smartystreets/goconvey v1.8.1 h1:qGjIddxOk4grTu9JPOU31tVfq3cNdBlNa5sSznIX1xY=
+github.com/smartystreets/goconvey v1.8.1/go.mod h1:+/u4qLyY6x1jReYOp7GOM2FSt8aP9CzCZL03bI28W60=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
@@ -177,8 +185,8 @@ golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ=
golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
-golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
+golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 h1:go1bK/D/BFZV2I8cIQd1NKEZ+0owSTG1fDTci4IqFcE=
google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
google.golang.org/protobuf v1.28.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=
google.golang.org/protobuf v1.33.0 h1:uNO2rsAINq/JlFpSdYEKIZ0uKD/R9cpdv0T+yoGwGmI=
diff --git a/middleware/auth.go b/middleware/auth.go
index 30997efd..223cef3d 100644
--- a/middleware/auth.go
+++ b/middleware/auth.go
@@ -1,10 +1,12 @@
package middleware
import (
+ "fmt"
"github.com/gin-contrib/sessions"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/blacklist"
+ "github.com/songquanpeng/one-api/common/network"
"github.com/songquanpeng/one-api/model"
"net/http"
"strings"
@@ -88,6 +90,7 @@ func RootAuth() func(c *gin.Context) {
func TokenAuth() func(c *gin.Context) {
return func(c *gin.Context) {
+ ctx := c.Request.Context()
key := c.Request.Header.Get("Authorization")
key = strings.TrimPrefix(key, "Bearer ")
key = strings.TrimPrefix(key, "sk-")
@@ -98,6 +101,12 @@ func TokenAuth() func(c *gin.Context) {
abortWithMessage(c, http.StatusUnauthorized, err.Error())
return
}
+ if token.Subnet != nil && *token.Subnet != "" {
+ if !network.IsIpInSubnets(ctx, c.ClientIP(), *token.Subnet) {
+ abortWithMessage(c, http.StatusForbidden, fmt.Sprintf("该令牌只能在指定网段使用:%s,当前 ip:%s", *token.Subnet, c.ClientIP()))
+ return
+ }
+ }
userEnabled, err := model.CacheIsUserEnabled(token.UserId)
if err != nil {
abortWithMessage(c, http.StatusInternalServerError, err.Error())
@@ -107,6 +116,19 @@ func TokenAuth() func(c *gin.Context) {
abortWithMessage(c, http.StatusForbidden, "用户已被封禁")
return
}
+ requestModel, err := getRequestModel(c)
+ if err != nil && !strings.HasPrefix(c.Request.URL.Path, "/v1/models") {
+ abortWithMessage(c, http.StatusBadRequest, err.Error())
+ return
+ }
+ c.Set("request_model", requestModel)
+ if token.Models != nil && *token.Models != "" {
+ c.Set("available_models", *token.Models)
+ if requestModel != "" && !isModelInList(requestModel, *token.Models) {
+ abortWithMessage(c, http.StatusForbidden, fmt.Sprintf("该令牌无权使用模型:%s", requestModel))
+ return
+ }
+ }
c.Set("id", token.UserId)
c.Set("token_id", token.Id)
c.Set("token_name", token.Name)
diff --git a/middleware/distributor.go b/middleware/distributor.go
index e845c2f8..04489a2b 100644
--- a/middleware/distributor.go
+++ b/middleware/distributor.go
@@ -2,14 +2,12 @@ package middleware
import (
"fmt"
+ "github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/model"
"net/http"
"strconv"
- "strings"
-
- "github.com/gin-gonic/gin"
)
type ModelRequest struct {
@@ -40,37 +38,11 @@ func Distribute() func(c *gin.Context) {
return
}
} else {
- // Select a channel for the user
- var modelRequest ModelRequest
- err := common.UnmarshalBodyReusable(c, &modelRequest)
+ requestModel := c.GetString("request_model")
+ var err error
+ channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, requestModel, false)
if err != nil {
- abortWithMessage(c, http.StatusBadRequest, "无效的请求")
- return
- }
- if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") {
- if modelRequest.Model == "" {
- modelRequest.Model = "text-moderation-stable"
- }
- }
- if strings.HasSuffix(c.Request.URL.Path, "embeddings") {
- if modelRequest.Model == "" {
- modelRequest.Model = c.Param("model")
- }
- }
- if strings.HasPrefix(c.Request.URL.Path, "/v1/images/generations") {
- if modelRequest.Model == "" {
- modelRequest.Model = "dall-e-2"
- }
- }
- if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") || strings.HasPrefix(c.Request.URL.Path, "/v1/audio/translations") {
- if modelRequest.Model == "" {
- modelRequest.Model = "whisper-1"
- }
- }
- requestModel = modelRequest.Model
- channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, modelRequest.Model, false)
- if err != nil {
- message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", userGroup, modelRequest.Model)
+ message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", userGroup, requestModel)
if channel != nil {
logger.SysError(fmt.Sprintf("渠道不存在:%d", channel.Id))
message = "数据库一致性已被破坏,请联系管理员"
diff --git a/middleware/utils.go b/middleware/utils.go
index bc14c367..b65b018b 100644
--- a/middleware/utils.go
+++ b/middleware/utils.go
@@ -1,9 +1,12 @@
package middleware
import (
+ "fmt"
"github.com/gin-gonic/gin"
+ "github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/common/logger"
+ "strings"
)
func abortWithMessage(c *gin.Context, statusCode int, message string) {
@@ -16,3 +19,42 @@ func abortWithMessage(c *gin.Context, statusCode int, message string) {
c.Abort()
logger.Error(c.Request.Context(), message)
}
+
+func getRequestModel(c *gin.Context) (string, error) {
+ var modelRequest ModelRequest
+ err := common.UnmarshalBodyReusable(c, &modelRequest)
+ if err != nil {
+ return "", fmt.Errorf("common.UnmarshalBodyReusable failed: %w", err)
+ }
+ if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") {
+ if modelRequest.Model == "" {
+ modelRequest.Model = "text-moderation-stable"
+ }
+ }
+ if strings.HasSuffix(c.Request.URL.Path, "embeddings") {
+ if modelRequest.Model == "" {
+ modelRequest.Model = c.Param("model")
+ }
+ }
+ if strings.HasPrefix(c.Request.URL.Path, "/v1/images/generations") {
+ if modelRequest.Model == "" {
+ modelRequest.Model = "dall-e-2"
+ }
+ }
+ if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") || strings.HasPrefix(c.Request.URL.Path, "/v1/audio/translations") {
+ if modelRequest.Model == "" {
+ modelRequest.Model = "whisper-1"
+ }
+ }
+ return modelRequest.Model, nil
+}
+
+func isModelInList(modelName string, models string) bool {
+ modelList := strings.Split(models, ",")
+ for _, model := range modelList {
+ if modelName == model {
+ return true
+ }
+ }
+ return false
+}
diff --git a/model/ability.go b/model/ability.go
index 48b856a2..4a48bc51 100644
--- a/model/ability.go
+++ b/model/ability.go
@@ -1,8 +1,10 @@
package model
import (
+ "context"
"github.com/songquanpeng/one-api/common"
"gorm.io/gorm"
+ "sort"
"strings"
)
@@ -88,3 +90,19 @@ func (channel *Channel) UpdateAbilities() error {
func UpdateAbilityStatus(channelId int, status bool) error {
return DB.Model(&Ability{}).Where("channel_id = ?", channelId).Select("enabled").Update("enabled", status).Error
}
+
+func GetGroupModels(ctx context.Context, group string) ([]string, error) {
+ groupCol := "`group`"
+ trueVal := "1"
+ if common.UsingPostgreSQL {
+ groupCol = `"group"`
+ trueVal = "true"
+ }
+ var models []string
+ err := DB.Model(&Ability{}).Distinct("model").Where(groupCol+" = ? and enabled = "+trueVal, group).Pluck("model", &models).Error
+ if err != nil {
+ return nil, err
+ }
+ sort.Strings(models)
+ return models, err
+}
diff --git a/model/cache.go b/model/cache.go
index 244fe6ac..cfc5445a 100644
--- a/model/cache.go
+++ b/model/cache.go
@@ -21,6 +21,7 @@ var (
UserId2GroupCacheSeconds = config.SyncFrequency
UserId2QuotaCacheSeconds = config.SyncFrequency
UserId2StatusCacheSeconds = config.SyncFrequency
+ GroupModelsCacheSeconds = config.SyncFrequency
)
func CacheGetTokenByKey(key string) (*Token, error) {
@@ -146,6 +147,25 @@ func CacheIsUserEnabled(userId int) (bool, error) {
return userEnabled, err
}
+func CacheGetGroupModels(ctx context.Context, group string) ([]string, error) {
+ if !common.RedisEnabled {
+ return GetGroupModels(ctx, group)
+ }
+ modelsStr, err := common.RedisGet(fmt.Sprintf("group_models:%s", group))
+ if err == nil {
+ return strings.Split(modelsStr, ","), nil
+ }
+ models, err := GetGroupModels(ctx, group)
+ if err != nil {
+ return nil, err
+ }
+ err = common.RedisSet(fmt.Sprintf("group_models:%s", group), strings.Join(models, ","), time.Duration(GroupModelsCacheSeconds)*time.Second)
+ if err != nil {
+ logger.SysError("Redis set group models error: " + err.Error())
+ }
+ return models, nil
+}
+
var group2model2channels map[string]map[string][]*Channel
var channelSyncLock sync.RWMutex
diff --git a/model/log.go b/model/log.go
index 4409f73e..6b679c36 100644
--- a/model/log.go
+++ b/model/log.go
@@ -51,6 +51,21 @@ func RecordLog(userId int, logType int, content string) {
}
}
+func RecordTopupLog(userId int, content string, quota int) {
+ log := &Log{
+ UserId: userId,
+ Username: GetUsernameById(userId),
+ CreatedAt: helper.GetTimestamp(),
+ Type: LogTypeTopup,
+ Content: content,
+ Quota: quota,
+ }
+ err := LOG_DB.Create(log).Error
+ if err != nil {
+ logger.SysError("failed to record log: " + err.Error())
+ }
+}
+
func RecordConsumeLog(ctx context.Context, userId int, channelId int, promptTokens int, completionTokens int, modelName string, tokenName string, quota int64, content string) {
logger.Info(ctx, fmt.Sprintf("record consume log: userId=%d, channelId=%d, promptTokens=%d, completionTokens=%d, modelName=%s, tokenName=%s, quota=%d, content=%s", userId, channelId, promptTokens, completionTokens, modelName, tokenName, quota, content))
if !config.LogConsumeEnabled {
diff --git a/model/option.go b/model/option.go
index 1d1c28b4..cee9bd3b 100644
--- a/model/option.go
+++ b/model/option.go
@@ -172,6 +172,10 @@ func updateOptionMap(key string, value string) (err error) {
config.GitHubClientId = value
case "GitHubClientSecret":
config.GitHubClientSecret = value
+ case "LarkClientId":
+ config.LarkClientId = value
+ case "LarkClientSecret":
+ config.LarkClientSecret = value
case "Footer":
config.Footer = value
case "SystemName":
diff --git a/model/token.go b/model/token.go
index 493e27c9..20228ec5 100644
--- a/model/token.go
+++ b/model/token.go
@@ -12,24 +12,26 @@ import (
)
type Token struct {
- Id int `json:"id"`
- UserId int `json:"user_id"`
- Key string `json:"key" gorm:"type:char(48);uniqueIndex"`
- Status int `json:"status" gorm:"default:1"`
- Name string `json:"name" gorm:"index" `
- CreatedTime int64 `json:"created_time" gorm:"bigint"`
- AccessedTime int64 `json:"accessed_time" gorm:"bigint"`
- ExpiredTime int64 `json:"expired_time" gorm:"bigint;default:-1"` // -1 means never expired
- RemainQuota int64 `json:"remain_quota" gorm:"bigint;default:0"`
- UnlimitedQuota bool `json:"unlimited_quota" gorm:"default:false"`
- UsedQuota int64 `json:"used_quota" gorm:"bigint;default:0"` // used quota
+ Id int `json:"id"`
+ UserId int `json:"user_id"`
+ Key string `json:"key" gorm:"type:char(48);uniqueIndex"`
+ Status int `json:"status" gorm:"default:1"`
+ Name string `json:"name" gorm:"index" `
+ CreatedTime int64 `json:"created_time" gorm:"bigint"`
+ AccessedTime int64 `json:"accessed_time" gorm:"bigint"`
+ ExpiredTime int64 `json:"expired_time" gorm:"bigint;default:-1"` // -1 means never expired
+ RemainQuota int64 `json:"remain_quota" gorm:"bigint;default:0"`
+ UnlimitedQuota bool `json:"unlimited_quota" gorm:"default:false"`
+ UsedQuota int64 `json:"used_quota" gorm:"bigint;default:0"` // used quota
+ Models *string `json:"models" gorm:"default:''"` // allowed models
+ Subnet *string `json:"subnet" gorm:"default:''"` // allowed subnet
}
func GetAllUserTokens(userId int, startIdx int, num int, order string) ([]*Token, error) {
var tokens []*Token
var err error
query := DB.Where("user_id = ?", userId)
-
+
switch order {
case "remain_quota":
query = query.Order("unlimited_quota desc, remain_quota desc")
@@ -38,7 +40,7 @@ func GetAllUserTokens(userId int, startIdx int, num int, order string) ([]*Token
default:
query = query.Order("id desc")
}
-
+
err = query.Limit(num).Offset(startIdx).Find(&tokens).Error
return tokens, err
}
@@ -61,7 +63,7 @@ func ValidateUserToken(key string) (token *Token, err error) {
return nil, errors.New("令牌验证失败")
}
if token.Status == common.TokenStatusExhausted {
- return nil, errors.New("该令牌额度已用尽")
+ return nil, fmt.Errorf("令牌 %s(#%d)额度已用尽", token.Name, token.Id)
} else if token.Status == common.TokenStatusExpired {
return nil, errors.New("该令牌已过期")
}
@@ -121,7 +123,7 @@ func (token *Token) Insert() error {
// Update Make sure your token's fields is completed, because this will update non-zero values
func (token *Token) Update() error {
var err error
- err = DB.Model(token).Select("name", "status", "expired_time", "remain_quota", "unlimited_quota").Updates(token).Error
+ err = DB.Model(token).Select("name", "status", "expired_time", "remain_quota", "unlimited_quota", "models", "subnet").Updates(token).Error
return err
}
diff --git a/model/user.go b/model/user.go
index 5e729b5e..42d8f7b1 100644
--- a/model/user.go
+++ b/model/user.go
@@ -24,6 +24,7 @@ type User struct {
Email string `json:"email" gorm:"index" validate:"max=50"`
GitHubId string `json:"github_id" gorm:"column:github_id;index"`
WeChatId string `json:"wechat_id" gorm:"column:wechat_id;index"`
+ LarkId string `json:"lark_id" gorm:"column:lark_id;index"`
VerificationCode string `json:"verification_code" gorm:"-:all"` // this field is only for Email verification, don't save it to database!
AccessToken string `json:"access_token" gorm:"type:char(32);column:access_token;uniqueIndex"` // this token is for system management
Quota int64 `json:"quota" gorm:"bigint;default:0"`
@@ -41,21 +42,21 @@ func GetMaxUserId() int {
}
func GetAllUsers(startIdx int, num int, order string) (users []*User, err error) {
- query := DB.Limit(num).Offset(startIdx).Omit("password").Where("status != ?", common.UserStatusDeleted)
-
- switch order {
- case "quota":
- query = query.Order("quota desc")
- case "used_quota":
- query = query.Order("used_quota desc")
- case "request_count":
- query = query.Order("request_count desc")
- default:
- query = query.Order("id desc")
- }
-
- err = query.Find(&users).Error
- return users, err
+ query := DB.Limit(num).Offset(startIdx).Omit("password").Where("status != ?", common.UserStatusDeleted)
+
+ switch order {
+ case "quota":
+ query = query.Order("quota desc")
+ case "used_quota":
+ query = query.Order("used_quota desc")
+ case "request_count":
+ query = query.Order("request_count desc")
+ default:
+ query = query.Order("id desc")
+ }
+
+ err = query.Find(&users).Error
+ return users, err
}
func SearchUsers(keyword string) (users []*User, err error) {
@@ -206,6 +207,14 @@ func (user *User) FillUserByGitHubId() error {
return nil
}
+func (user *User) FillUserByLarkId() error {
+ if user.LarkId == "" {
+ return errors.New("lark id 为空!")
+ }
+ DB.Where(User{LarkId: user.LarkId}).First(user)
+ return nil
+}
+
func (user *User) FillUserByWeChatId() error {
if user.WeChatId == "" {
return errors.New("WeChat id 为空!")
@@ -234,6 +243,10 @@ func IsGitHubIdAlreadyTaken(githubId string) bool {
return DB.Where("github_id = ?", githubId).Find(&User{}).RowsAffected == 1
}
+func IsLarkIdAlreadyTaken(githubId string) bool {
+ return DB.Where("lark_id = ?", githubId).Find(&User{}).RowsAffected == 1
+}
+
func IsUsernameAlreadyTaken(username string) bool {
return DB.Where("username = ?", username).Find(&User{}).RowsAffected == 1
}
diff --git a/relay/channel/ali/main.go b/relay/channel/ali/main.go
index df28f82b..c1832b11 100644
--- a/relay/channel/ali/main.go
+++ b/relay/channel/ali/main.go
@@ -52,6 +52,9 @@ func ConvertRequest(request model.GeneralOpenAIRequest) *ChatRequest {
MaxTokens: request.MaxTokens,
Temperature: request.Temperature,
TopP: request.TopP,
+ TopK: request.TopK,
+ ResultFormat: "message",
+ Tools: request.Tools,
},
}
}
@@ -132,19 +135,11 @@ func embeddingResponseAli2OpenAI(response *EmbeddingResponse) *openai.EmbeddingR
}
func responseAli2OpenAI(response *ChatResponse) *openai.TextResponse {
- choice := openai.TextResponseChoice{
- Index: 0,
- Message: model.Message{
- Role: "assistant",
- Content: response.Output.Text,
- },
- FinishReason: response.Output.FinishReason,
- }
fullTextResponse := openai.TextResponse{
Id: response.RequestId,
Object: "chat.completion",
Created: helper.GetTimestamp(),
- Choices: []openai.TextResponseChoice{choice},
+ Choices: response.Output.Choices,
Usage: model.Usage{
PromptTokens: response.Usage.InputTokens,
CompletionTokens: response.Usage.OutputTokens,
@@ -155,10 +150,14 @@ func responseAli2OpenAI(response *ChatResponse) *openai.TextResponse {
}
func streamResponseAli2OpenAI(aliResponse *ChatResponse) *openai.ChatCompletionsStreamResponse {
+ if len(aliResponse.Output.Choices) == 0 {
+ return nil
+ }
+ aliChoice := aliResponse.Output.Choices[0]
var choice openai.ChatCompletionsStreamResponseChoice
- choice.Delta.Content = aliResponse.Output.Text
- if aliResponse.Output.FinishReason != "null" {
- finishReason := aliResponse.Output.FinishReason
+ choice.Delta = aliChoice.Message
+ if aliChoice.FinishReason != "null" {
+ finishReason := aliChoice.FinishReason
choice.FinishReason = &finishReason
}
response := openai.ChatCompletionsStreamResponse{
@@ -219,6 +218,9 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC
usage.TotalTokens = aliResponse.Usage.InputTokens + aliResponse.Usage.OutputTokens
}
response := streamResponseAli2OpenAI(&aliResponse)
+ if response == nil {
+ return true
+ }
//response.Choices[0].Delta.Content = strings.TrimPrefix(response.Choices[0].Delta.Content, lastResponseText)
//lastResponseText = aliResponse.Output.Text
jsonResponse, err := json.Marshal(response)
@@ -241,6 +243,7 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC
}
func Handler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
+ ctx := c.Request.Context()
var aliResponse ChatResponse
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
@@ -250,6 +253,7 @@ func Handler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *
if err != nil {
return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
+ logger.Debugf(ctx, "response body: %s\n", responseBody)
err = json.Unmarshal(responseBody, &aliResponse)
if err != nil {
return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
diff --git a/relay/channel/ali/model.go b/relay/channel/ali/model.go
index 34515b16..6a83f0f4 100644
--- a/relay/channel/ali/model.go
+++ b/relay/channel/ali/model.go
@@ -1,5 +1,10 @@
package ali
+import (
+ "github.com/songquanpeng/one-api/relay/channel/openai"
+ "github.com/songquanpeng/one-api/relay/model"
+)
+
type Message struct {
Content string `json:"content"`
Role string `json:"role"`
@@ -11,13 +16,15 @@ type Input struct {
}
type Parameters struct {
- TopP float64 `json:"top_p,omitempty"`
- TopK int `json:"top_k,omitempty"`
- Seed uint64 `json:"seed,omitempty"`
- EnableSearch bool `json:"enable_search,omitempty"`
- IncrementalOutput bool `json:"incremental_output,omitempty"`
- MaxTokens int `json:"max_tokens,omitempty"`
- Temperature float64 `json:"temperature,omitempty"`
+ TopP float64 `json:"top_p,omitempty"`
+ TopK int `json:"top_k,omitempty"`
+ Seed uint64 `json:"seed,omitempty"`
+ EnableSearch bool `json:"enable_search,omitempty"`
+ IncrementalOutput bool `json:"incremental_output,omitempty"`
+ MaxTokens int `json:"max_tokens,omitempty"`
+ Temperature float64 `json:"temperature,omitempty"`
+ ResultFormat string `json:"result_format,omitempty"`
+ Tools []model.Tool `json:"tools,omitempty"`
}
type ChatRequest struct {
@@ -135,8 +142,9 @@ type Usage struct {
}
type Output struct {
- Text string `json:"text"`
- FinishReason string `json:"finish_reason"`
+ //Text string `json:"text"`
+ //FinishReason string `json:"finish_reason"`
+ Choices []openai.TextResponseChoice `json:"choices"`
}
type ChatResponse struct {
diff --git a/relay/channel/anthropic/main.go b/relay/channel/anthropic/main.go
index 3eeb0b2c..04e65d99 100644
--- a/relay/channel/anthropic/main.go
+++ b/relay/channel/anthropic/main.go
@@ -38,6 +38,7 @@ func ConvertRequest(textRequest model.GeneralOpenAIRequest) *Request {
MaxTokens: textRequest.MaxTokens,
Temperature: textRequest.Temperature,
TopP: textRequest.TopP,
+ TopK: textRequest.TopK,
Stream: textRequest.Stream,
}
if claudeRequest.MaxTokens == 0 {
diff --git a/relay/channel/baidu/adaptor.go b/relay/channel/baidu/adaptor.go
index 471f2dc5..6096eb31 100644
--- a/relay/channel/baidu/adaptor.go
+++ b/relay/channel/baidu/adaptor.go
@@ -38,16 +38,34 @@ func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) {
suffix += "completions_pro"
case "ERNIE-Bot-4":
suffix += "completions_pro"
- case "ERNIE-3.5-8K":
- suffix += "completions"
- case "ERNIE-Bot-8K":
- suffix += "ernie_bot_8k"
case "ERNIE-Bot":
suffix += "completions"
- case "ERNIE-Speed":
- suffix += "ernie_speed"
case "ERNIE-Bot-turbo":
suffix += "eb-instant"
+ case "ERNIE-Speed":
+ suffix += "ernie_speed"
+ case "ERNIE-4.0-8K":
+ suffix += "completions_pro"
+ case "ERNIE-3.5-8K":
+ suffix += "completions"
+ case "ERNIE-3.5-8K-0205":
+ suffix += "ernie-3.5-8k-0205"
+ case "ERNIE-3.5-8K-1222":
+ suffix += "ernie-3.5-8k-1222"
+ case "ERNIE-Bot-8K":
+ suffix += "ernie_bot_8k"
+ case "ERNIE-3.5-4K-0205":
+ suffix += "ernie-3.5-4k-0205"
+ case "ERNIE-Speed-8K":
+ suffix += "ernie_speed"
+ case "ERNIE-Speed-128K":
+ suffix += "ernie-speed-128k"
+ case "ERNIE-Lite-8K-0922":
+ suffix += "eb-instant"
+ case "ERNIE-Lite-8K-0308":
+ suffix += "ernie-lite-8k"
+ case "ERNIE-Tiny-8K":
+ suffix += "ernie-tiny-8k"
case "BLOOMZ-7B":
suffix += "bloomz_7b1"
case "Embedding-V1":
@@ -59,7 +77,7 @@ func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) {
case "tao-8k":
suffix += "tao_8k"
default:
- suffix += meta.ActualModelName
+ suffix += strings.ToLower(meta.ActualModelName)
}
fullRequestURL := fmt.Sprintf("%s/rpc/2.0/ai_custom/v1/wenxinworkshop/%s", meta.BaseURL, suffix)
var accessToken string
diff --git a/relay/channel/baidu/constants.go b/relay/channel/baidu/constants.go
index 45a4e901..f952adc6 100644
--- a/relay/channel/baidu/constants.go
+++ b/relay/channel/baidu/constants.go
@@ -1,11 +1,18 @@
package baidu
var ModelList = []string{
- "ERNIE-Bot-4",
+ "ERNIE-4.0-8K",
+ "ERNIE-3.5-8K",
+ "ERNIE-3.5-8K-0205",
+ "ERNIE-3.5-8K-1222",
"ERNIE-Bot-8K",
- "ERNIE-Bot",
- "ERNIE-Speed",
- "ERNIE-Bot-turbo",
+ "ERNIE-3.5-4K-0205",
+ "ERNIE-Speed-8K",
+ "ERNIE-Speed-128K",
+ "ERNIE-Lite-8K-0922",
+ "ERNIE-Lite-8K-0308",
+ "ERNIE-Tiny-8K",
+ "BLOOMZ-7B",
"Embedding-V1",
"bge-large-zh",
"bge-large-en",
diff --git a/relay/channel/openai/adaptor.go b/relay/channel/openai/adaptor.go
index bcbaf835..91b5fc2c 100644
--- a/relay/channel/openai/adaptor.go
+++ b/relay/channel/openai/adaptor.go
@@ -85,8 +85,10 @@ func (a *Adaptor) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
if meta.IsStream {
var responseText string
- err, responseText, _ = StreamHandler(c, resp, meta.Mode)
- usage = ResponseText2Usage(responseText, meta.ActualModelName, meta.PromptTokens)
+ err, responseText, usage = StreamHandler(c, resp, meta.Mode)
+ if usage == nil {
+ usage = ResponseText2Usage(responseText, meta.ActualModelName, meta.PromptTokens)
+ }
} else {
switch meta.Mode {
case constant.RelayModeImagesGenerations:
diff --git a/relay/channel/openai/compatible.go b/relay/channel/openai/compatible.go
index e4951a34..2a1447ab 100644
--- a/relay/channel/openai/compatible.go
+++ b/relay/channel/openai/compatible.go
@@ -9,6 +9,7 @@ import (
"github.com/songquanpeng/one-api/relay/channel/minimax"
"github.com/songquanpeng/one-api/relay/channel/mistral"
"github.com/songquanpeng/one-api/relay/channel/moonshot"
+ "github.com/songquanpeng/one-api/relay/channel/stepfun"
)
var CompatibleChannels = []int{
@@ -20,6 +21,7 @@ var CompatibleChannels = []int{
common.ChannelTypeMistral,
common.ChannelTypeGroq,
common.ChannelTypeLingYiWanWu,
+ common.ChannelTypeStepFun,
}
func GetCompatibleChannelMeta(channelType int) (string, []string) {
@@ -40,6 +42,8 @@ func GetCompatibleChannelMeta(channelType int) (string, []string) {
return "groq", groq.ModelList
case common.ChannelTypeLingYiWanWu:
return "lingyiwanwu", lingyiwanwu.ModelList
+ case common.ChannelTypeStepFun:
+ return "stepfun", stepfun.ModelList
default:
return "openai", ModelList
}
diff --git a/relay/channel/openai/main.go b/relay/channel/openai/main.go
index fa4ee25e..df3f0691 100644
--- a/relay/channel/openai/main.go
+++ b/relay/channel/openai/main.go
@@ -6,6 +6,7 @@ import (
"encoding/json"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common"
+ "github.com/songquanpeng/one-api/common/conv"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/relay/constant"
"github.com/songquanpeng/one-api/relay/model"
@@ -53,7 +54,7 @@ func StreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*model.E
continue // just ignore the error
}
for _, choice := range streamResponse.Choices {
- responseText += choice.Delta.Content
+ responseText += conv.AsString(choice.Delta.Content)
}
if streamResponse.Usage != nil {
usage = streamResponse.Usage
diff --git a/relay/channel/openai/model.go b/relay/channel/openai/model.go
index 31d9fe97..6e9c38f1 100644
--- a/relay/channel/openai/model.go
+++ b/relay/channel/openai/model.go
@@ -123,12 +123,9 @@ type ImageResponse struct {
}
type ChatCompletionsStreamResponseChoice struct {
- Index int `json:"index"`
- Delta struct {
- Content string `json:"content"`
- Role string `json:"role,omitempty"`
- } `json:"delta"`
- FinishReason *string `json:"finish_reason,omitempty"`
+ Index int `json:"index"`
+ Delta model.Message `json:"delta"`
+ FinishReason *string `json:"finish_reason,omitempty"`
}
type ChatCompletionsStreamResponse struct {
diff --git a/relay/channel/stepfun/constants.go b/relay/channel/stepfun/constants.go
new file mode 100644
index 00000000..a82e562b
--- /dev/null
+++ b/relay/channel/stepfun/constants.go
@@ -0,0 +1,7 @@
+package stepfun
+
+var ModelList = []string{
+ "step-1-32k",
+ "step-1v-32k",
+ "step-1-200k",
+}
diff --git a/relay/channel/tencent/main.go b/relay/channel/tencent/main.go
index cfdc0bfd..b5a64cde 100644
--- a/relay/channel/tencent/main.go
+++ b/relay/channel/tencent/main.go
@@ -10,6 +10,7 @@ import (
"fmt"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common"
+ "github.com/songquanpeng/one-api/common/conv"
"github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/relay/channel/openai"
@@ -129,7 +130,7 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC
}
response := streamResponseTencent2OpenAI(&TencentResponse)
if len(response.Choices) != 0 {
- responseText += response.Choices[0].Delta.Content
+ responseText += conv.AsString(response.Choices[0].Delta.Content)
}
jsonResponse, err := json.Marshal(response)
if err != nil {
diff --git a/relay/channel/xunfei/main.go b/relay/channel/xunfei/main.go
index 5e7014cb..67784a56 100644
--- a/relay/channel/xunfei/main.go
+++ b/relay/channel/xunfei/main.go
@@ -26,7 +26,11 @@ import (
func requestOpenAI2Xunfei(request model.GeneralOpenAIRequest, xunfeiAppId string, domain string) *ChatRequest {
messages := make([]Message, 0, len(request.Messages))
+ var lastToolCalls []model.Tool
for _, message := range request.Messages {
+ if message.ToolCalls != nil {
+ lastToolCalls = message.ToolCalls
+ }
messages = append(messages, Message{
Role: message.Role,
Content: message.StringContent(),
@@ -39,9 +43,33 @@ func requestOpenAI2Xunfei(request model.GeneralOpenAIRequest, xunfeiAppId string
xunfeiRequest.Parameter.Chat.TopK = request.N
xunfeiRequest.Parameter.Chat.MaxTokens = request.MaxTokens
xunfeiRequest.Payload.Message.Text = messages
+ if len(lastToolCalls) != 0 {
+ for _, toolCall := range lastToolCalls {
+ xunfeiRequest.Payload.Functions.Text = append(xunfeiRequest.Payload.Functions.Text, toolCall.Function)
+ }
+ }
+
return &xunfeiRequest
}
+func getToolCalls(response *ChatResponse) []model.Tool {
+ var toolCalls []model.Tool
+ if len(response.Payload.Choices.Text) == 0 {
+ return toolCalls
+ }
+ item := response.Payload.Choices.Text[0]
+ if item.FunctionCall == nil {
+ return toolCalls
+ }
+ toolCall := model.Tool{
+ Id: fmt.Sprintf("call_%s", helper.GetUUID()),
+ Type: "function",
+ Function: *item.FunctionCall,
+ }
+ toolCalls = append(toolCalls, toolCall)
+ return toolCalls
+}
+
func responseXunfei2OpenAI(response *ChatResponse) *openai.TextResponse {
if len(response.Payload.Choices.Text) == 0 {
response.Payload.Choices.Text = []ChatResponseTextItem{
@@ -53,8 +81,9 @@ func responseXunfei2OpenAI(response *ChatResponse) *openai.TextResponse {
choice := openai.TextResponseChoice{
Index: 0,
Message: model.Message{
- Role: "assistant",
- Content: response.Payload.Choices.Text[0].Content,
+ Role: "assistant",
+ Content: response.Payload.Choices.Text[0].Content,
+ ToolCalls: getToolCalls(response),
},
FinishReason: constant.StopFinishReason,
}
@@ -78,6 +107,7 @@ func streamResponseXunfei2OpenAI(xunfeiResponse *ChatResponse) *openai.ChatCompl
}
var choice openai.ChatCompletionsStreamResponseChoice
choice.Delta.Content = xunfeiResponse.Payload.Choices.Text[0].Content
+ choice.Delta.ToolCalls = getToolCalls(xunfeiResponse)
if xunfeiResponse.Payload.Choices.Status == 2 {
choice.FinishReason = &constant.StopFinishReason
}
diff --git a/relay/channel/xunfei/model.go b/relay/channel/xunfei/model.go
index 1266739d..97a43154 100644
--- a/relay/channel/xunfei/model.go
+++ b/relay/channel/xunfei/model.go
@@ -26,13 +26,18 @@ type ChatRequest struct {
Message struct {
Text []Message `json:"text"`
} `json:"message"`
+ Functions struct {
+ Text []model.Function `json:"text,omitempty"`
+ } `json:"functions,omitempty"`
} `json:"payload"`
}
type ChatResponseTextItem struct {
- Content string `json:"content"`
- Role string `json:"role"`
- Index int `json:"index"`
+ Content string `json:"content"`
+ Role string `json:"role"`
+ Index int `json:"index"`
+ ContentType string `json:"content_type"`
+ FunctionCall *model.Function `json:"function_call"`
}
type ChatResponse struct {
diff --git a/relay/channel/zhipu/adaptor.go b/relay/channel/zhipu/adaptor.go
index 75ff977b..dbcf240d 100644
--- a/relay/channel/zhipu/adaptor.go
+++ b/relay/channel/zhipu/adaptor.go
@@ -6,6 +6,7 @@ import (
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/relay/channel"
"github.com/songquanpeng/one-api/relay/channel/openai"
+ "github.com/songquanpeng/one-api/relay/constant"
"github.com/songquanpeng/one-api/relay/model"
"github.com/songquanpeng/one-api/relay/util"
"io"
@@ -35,6 +36,9 @@ func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) {
if a.APIVersion == "v4" {
return fmt.Sprintf("%s/api/paas/v4/chat/completions", meta.BaseURL), nil
}
+ if meta.Mode == constant.RelayModeEmbeddings {
+ return fmt.Sprintf("%s/api/paas/v4/embeddings", meta.BaseURL), nil
+ }
method := "invoke"
if meta.IsStream {
method = "sse-invoke"
@@ -53,18 +57,24 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G
if request == nil {
return nil, errors.New("request is nil")
}
- // TopP (0.0, 1.0)
- request.TopP = math.Min(0.99, request.TopP)
- request.TopP = math.Max(0.01, request.TopP)
+ switch relayMode {
+ case constant.RelayModeEmbeddings:
+ baiduEmbeddingRequest := ConvertEmbeddingRequest(*request)
+ return baiduEmbeddingRequest, nil
+ default:
+ // TopP (0.0, 1.0)
+ request.TopP = math.Min(0.99, request.TopP)
+ request.TopP = math.Max(0.01, request.TopP)
- // Temperature (0.0, 1.0)
- request.Temperature = math.Min(0.99, request.Temperature)
- request.Temperature = math.Max(0.01, request.Temperature)
- a.SetVersionByModeName(request.Model)
- if a.APIVersion == "v4" {
- return request, nil
+ // Temperature (0.0, 1.0)
+ request.Temperature = math.Min(0.99, request.Temperature)
+ request.Temperature = math.Max(0.01, request.Temperature)
+ a.SetVersionByModeName(request.Model)
+ if a.APIVersion == "v4" {
+ return request, nil
+ }
+ return ConvertRequest(*request), nil
}
- return ConvertRequest(*request), nil
}
func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {
@@ -91,14 +101,26 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.Rel
if a.APIVersion == "v4" {
return a.DoResponseV4(c, resp, meta)
}
+
if meta.IsStream {
err, usage = StreamHandler(c, resp)
} else {
- err, usage = Handler(c, resp)
+ if meta.Mode == constant.RelayModeEmbeddings {
+ err, usage = EmbeddingsHandler(c, resp)
+ } else {
+ err, usage = Handler(c, resp)
+ }
}
return
}
+func ConvertEmbeddingRequest(request model.GeneralOpenAIRequest) *EmbeddingRequest {
+ return &EmbeddingRequest{
+ Model: "embedding-2",
+ Input: request.Input.(string),
+ }
+}
+
func (a *Adaptor) GetModelList() []string {
return ModelList
}
diff --git a/relay/channel/zhipu/constants.go b/relay/channel/zhipu/constants.go
index 1655a59d..2daeb19c 100644
--- a/relay/channel/zhipu/constants.go
+++ b/relay/channel/zhipu/constants.go
@@ -2,5 +2,5 @@ package zhipu
var ModelList = []string{
"chatglm_turbo", "chatglm_pro", "chatglm_std", "chatglm_lite",
- "glm-4", "glm-4v", "glm-3-turbo",
+ "glm-4", "glm-4v", "glm-3-turbo", "embedding-2",
}
diff --git a/relay/channel/zhipu/main.go b/relay/channel/zhipu/main.go
index a46fd537..f54e0504 100644
--- a/relay/channel/zhipu/main.go
+++ b/relay/channel/zhipu/main.go
@@ -254,3 +254,50 @@ func Handler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *
_, err = c.Writer.Write(jsonResponse)
return nil, &fullTextResponse.Usage
}
+
+func EmbeddingsHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
+ var zhipuResponse EmbeddingRespone
+ responseBody, err := io.ReadAll(resp.Body)
+ if err != nil {
+ return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
+ }
+ err = resp.Body.Close()
+ if err != nil {
+ return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
+ }
+ err = json.Unmarshal(responseBody, &zhipuResponse)
+ if err != nil {
+ return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
+ }
+ fullTextResponse := embeddingResponseZhipu2OpenAI(&zhipuResponse)
+ jsonResponse, err := json.Marshal(fullTextResponse)
+ if err != nil {
+ return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
+ }
+ c.Writer.Header().Set("Content-Type", "application/json")
+ c.Writer.WriteHeader(resp.StatusCode)
+ _, err = c.Writer.Write(jsonResponse)
+ return nil, &fullTextResponse.Usage
+}
+
+func embeddingResponseZhipu2OpenAI(response *EmbeddingRespone) *openai.EmbeddingResponse {
+ openAIEmbeddingResponse := openai.EmbeddingResponse{
+ Object: "list",
+ Data: make([]openai.EmbeddingResponseItem, 0, len(response.Embeddings)),
+ Model: response.Model,
+ Usage: model.Usage{
+ PromptTokens: response.PromptTokens,
+ CompletionTokens: response.CompletionTokens,
+ TotalTokens: response.Usage.TotalTokens,
+ },
+ }
+
+ for _, item := range response.Embeddings {
+ openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, openai.EmbeddingResponseItem{
+ Object: `embedding`,
+ Index: item.Index,
+ Embedding: item.Embedding,
+ })
+ }
+ return &openAIEmbeddingResponse
+}
diff --git a/relay/channel/zhipu/model.go b/relay/channel/zhipu/model.go
index b63e1d6f..3c3a7443 100644
--- a/relay/channel/zhipu/model.go
+++ b/relay/channel/zhipu/model.go
@@ -44,3 +44,21 @@ type tokenData struct {
Token string
ExpiryTime time.Time
}
+
+type EmbeddingRequest struct {
+ Model string `json:"model"`
+ Input string `json:"input"`
+}
+
+type EmbeddingRespone struct {
+ Model string `json:"model"`
+ Object string `json:"object"`
+ Embeddings []EmbeddingData `json:"data"`
+ model.Usage `json:"usage"`
+}
+
+type EmbeddingData struct {
+ Index int `json:"index"`
+ Object string `json:"object"`
+ Embedding []float64 `json:"embedding"`
+}
diff --git a/relay/model/general.go b/relay/model/general.go
index fbcc04e8..30772894 100644
--- a/relay/model/general.go
+++ b/relay/model/general.go
@@ -5,25 +5,29 @@ type ResponseFormat struct {
}
type GeneralOpenAIRequest struct {
- Model string `json:"model,omitempty"`
Messages []Message `json:"messages,omitempty"`
- Prompt any `json:"prompt,omitempty"`
- Stream bool `json:"stream,omitempty"`
- MaxTokens int `json:"max_tokens,omitempty"`
- Temperature float64 `json:"temperature,omitempty"`
- TopP float64 `json:"top_p,omitempty"`
- N int `json:"n,omitempty"`
- Input any `json:"input,omitempty"`
- Instruction string `json:"instruction,omitempty"`
- Size string `json:"size,omitempty"`
- Functions any `json:"functions,omitempty"`
+ Model string `json:"model,omitempty"`
FrequencyPenalty float64 `json:"frequency_penalty,omitempty"`
+ MaxTokens int `json:"max_tokens,omitempty"`
+ N int `json:"n,omitempty"`
PresencePenalty float64 `json:"presence_penalty,omitempty"`
ResponseFormat *ResponseFormat `json:"response_format,omitempty"`
Seed float64 `json:"seed,omitempty"`
- Tools any `json:"tools,omitempty"`
+ Stream bool `json:"stream,omitempty"`
+ Temperature float64 `json:"temperature,omitempty"`
+ TopP float64 `json:"top_p,omitempty"`
+ TopK int `json:"top_k,omitempty"`
+ Tools []Tool `json:"tools,omitempty"`
ToolChoice any `json:"tool_choice,omitempty"`
+ FunctionCall any `json:"function_call,omitempty"`
+ Functions any `json:"functions,omitempty"`
User string `json:"user,omitempty"`
+ Prompt any `json:"prompt,omitempty"`
+ Input any `json:"input,omitempty"`
+ EncodingFormat string `json:"encoding_format,omitempty"`
+ Dimensions int `json:"dimensions,omitempty"`
+ Instruction string `json:"instruction,omitempty"`
+ Size string `json:"size,omitempty"`
}
func (r GeneralOpenAIRequest) ParseInput() []string {
diff --git a/relay/model/message.go b/relay/model/message.go
index c6c8a271..32a1055b 100644
--- a/relay/model/message.go
+++ b/relay/model/message.go
@@ -1,9 +1,10 @@
package model
type Message struct {
- Role string `json:"role"`
- Content any `json:"content"`
- Name *string `json:"name,omitempty"`
+ Role string `json:"role,omitempty"`
+ Content any `json:"content,omitempty"`
+ Name *string `json:"name,omitempty"`
+ ToolCalls []Tool `json:"tool_calls,omitempty"`
}
func (m Message) IsStringContent() bool {
diff --git a/relay/model/tool.go b/relay/model/tool.go
new file mode 100644
index 00000000..253dca35
--- /dev/null
+++ b/relay/model/tool.go
@@ -0,0 +1,14 @@
+package model
+
+type Tool struct {
+ Id string `json:"id,omitempty"`
+ Type string `json:"type"`
+ Function Function `json:"function"`
+}
+
+type Function struct {
+ Description string `json:"description,omitempty"`
+ Name string `json:"name"`
+ Parameters any `json:"parameters,omitempty"` // request
+ Arguments any `json:"arguments,omitempty"` // response
+}
diff --git a/relay/util/common.go b/relay/util/common.go
index 535ef680..5d787204 100644
--- a/relay/util/common.go
+++ b/relay/util/common.go
@@ -46,6 +46,15 @@ func ShouldDisableChannel(err *relaymodel.Error, statusCode int) bool {
} else if strings.HasPrefix(err.Message, "This organization has been disabled.") {
return true
}
+ //if strings.Contains(err.Message, "quota") {
+ // return true
+ //}
+ if strings.Contains(err.Message, "credit") {
+ return true
+ }
+ if strings.Contains(err.Message, "balance") {
+ return true
+ }
return false
}
diff --git a/router/api-router.go b/router/api-router.go
index 5b755ede..a36232b3 100644
--- a/router/api-router.go
+++ b/router/api-router.go
@@ -2,6 +2,7 @@ package router
import (
"github.com/songquanpeng/one-api/controller"
+ "github.com/songquanpeng/one-api/controller/auth"
"github.com/songquanpeng/one-api/middleware"
"github.com/gin-contrib/gzip"
@@ -21,11 +22,13 @@ func SetApiRouter(router *gin.Engine) {
apiRouter.GET("/verification", middleware.CriticalRateLimit(), middleware.TurnstileCheck(), controller.SendEmailVerification)
apiRouter.GET("/reset_password", middleware.CriticalRateLimit(), middleware.TurnstileCheck(), controller.SendPasswordResetEmail)
apiRouter.POST("/user/reset", middleware.CriticalRateLimit(), controller.ResetPassword)
- apiRouter.GET("/oauth/github", middleware.CriticalRateLimit(), controller.GitHubOAuth)
- apiRouter.GET("/oauth/state", middleware.CriticalRateLimit(), controller.GenerateOAuthCode)
- apiRouter.GET("/oauth/wechat", middleware.CriticalRateLimit(), controller.WeChatAuth)
- apiRouter.GET("/oauth/wechat/bind", middleware.CriticalRateLimit(), middleware.UserAuth(), controller.WeChatBind)
+ apiRouter.GET("/oauth/github", middleware.CriticalRateLimit(), auth.GitHubOAuth)
+ apiRouter.GET("/oauth/lark", middleware.CriticalRateLimit(), auth.LarkOAuth)
+ apiRouter.GET("/oauth/state", middleware.CriticalRateLimit(), auth.GenerateOAuthCode)
+ apiRouter.GET("/oauth/wechat", middleware.CriticalRateLimit(), auth.WeChatAuth)
+ apiRouter.GET("/oauth/wechat/bind", middleware.CriticalRateLimit(), middleware.UserAuth(), auth.WeChatBind)
apiRouter.GET("/oauth/email/bind", middleware.CriticalRateLimit(), middleware.UserAuth(), controller.EmailBind)
+ apiRouter.POST("/topup", middleware.AdminAuth(), controller.AdminTopUp)
userRoute := apiRouter.Group("/user")
{
@@ -43,6 +46,7 @@ func SetApiRouter(router *gin.Engine) {
selfRoute.GET("/token", controller.GenerateAccessToken)
selfRoute.GET("/aff", controller.GetAffCode)
selfRoute.POST("/topup", controller.TopUp)
+ selfRoute.GET("/available_models", controller.GetUserAvailableModels)
}
adminRoute := userRoute.Group("/")
diff --git a/web/README.md b/web/README.md
index 29f4713e..829271e2 100644
--- a/web/README.md
+++ b/web/README.md
@@ -2,6 +2,9 @@
> 每个文件夹代表一个主题,欢迎提交你的主题
+> [!WARNING]
+> 不是每一个主题都及时同步了所有功能,由于精力有限,优先更新默认主题,其他主题欢迎 & 期待 PR
+
## 提交新的主题
> 欢迎在页面底部保留你和 One API 的版权信息以及指向链接
diff --git a/web/berry/src/constants/ChannelConstants.js b/web/berry/src/constants/ChannelConstants.js
index ec049f7d..b74c58c7 100644
--- a/web/berry/src/constants/ChannelConstants.js
+++ b/web/berry/src/constants/ChannelConstants.js
@@ -107,6 +107,12 @@ export const CHANNEL_OPTIONS = {
value: 31,
color: 'primary'
},
+ 32: {
+ key: 32,
+ text: '阶跃星辰',
+ value: 32,
+ color: 'primary'
+ },
8: {
key: 8,
text: '自定义渠道',
diff --git a/web/berry/src/constants/SnackbarConstants.js b/web/berry/src/constants/SnackbarConstants.js
index a05c6652..19523da1 100644
--- a/web/berry/src/constants/SnackbarConstants.js
+++ b/web/berry/src/constants/SnackbarConstants.js
@@ -18,7 +18,7 @@ export const snackbarConstants = {
},
NOTICE: {
variant: 'info',
- autoHideDuration: 20000
+ autoHideDuration: 7000
}
},
Mobile: {
diff --git a/web/berry/src/utils/common.js b/web/berry/src/utils/common.js
index 25e5c635..d8dabac3 100644
--- a/web/berry/src/utils/common.js
+++ b/web/berry/src/utils/common.js
@@ -51,9 +51,9 @@ export function showError(error) {
export function showNotice(message, isHTML = false) {
if (isHTML) {
- enqueueSnackbar(, getSnackbarOptions('INFO'));
+ enqueueSnackbar(, getSnackbarOptions('NOTICE'));
} else {
- enqueueSnackbar(message, getSnackbarOptions('INFO'));
+ enqueueSnackbar(message, getSnackbarOptions('NOTICE'));
}
}
diff --git a/web/berry/src/views/Channel/component/EditModal.js b/web/berry/src/views/Channel/component/EditModal.js
index 07111c97..cbf411b9 100644
--- a/web/berry/src/views/Channel/component/EditModal.js
+++ b/web/berry/src/views/Channel/component/EditModal.js
@@ -340,7 +340,9 @@ const EditModal = ({ open, channelId, onCancel, onOk }) => {
},
}}
>
- {Object.values(CHANNEL_OPTIONS).map((option) => {
+ {Object.values(CHANNEL_OPTIONS).sort((a, b) => {
+ return a.text.localeCompare(b.text)
+ }).map((option) => {
return (