diff --git a/Dockerfile b/Dockerfile index 22055553..ffb8c21b 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,9 +1,10 @@ FROM node:16 as builder WORKDIR /build +COPY web/package.json . +RUN npm install COPY ./web . COPY ./VERSION . -RUN npm install RUN DISABLE_ESLINT_PLUGIN='true' REACT_APP_VERSION=$(cat VERSION) npm run build FROM golang AS builder2 @@ -13,9 +14,10 @@ ENV GO111MODULE=on \ GOOS=linux WORKDIR /build +ADD go.mod go.sum ./ +RUN go mod download COPY . . COPY --from=builder /build/build ./web/build -RUN go mod download RUN go build -ldflags "-s -w -X 'one-api/common.Version=$(cat VERSION)' -extldflags '-static'" -o one-api FROM alpine diff --git a/README.md b/README.md index a53c8b9d..e2979961 100644 --- a/README.md +++ b/README.md @@ -279,13 +279,13 @@ graph LR 3. `SQL_DSN`:设置之后将使用指定数据库而非 SQLite,请使用 MySQL 或 PostgreSQL。 + 例子: + MySQL:`SQL_DSN=root:123456@tcp(localhost:3306)/oneapi` - + PostgreSQL:`SQL_DSN=postgres://postgres:123456@localhost:5432/oneapi` + + PostgreSQL:`SQL_DSN=postgres://postgres:123456@localhost:5432/oneapi`(适配中,欢迎反馈) + 注意需要提前建立数据库 `oneapi`,无需手动建表,程序将自动建表。 + 如果使用本地数据库:部署命令可添加 `--network="host"` 以使得容器内的程序可以访问到宿主机上的 MySQL。 + 如果使用云数据库:如果云服务器需要验证身份,需要在连接参数中添加 `?tls=skip-verify`。 + 请根据你的数据库配置修改下列参数(或者保持默认值): - + `SQL_MAX_IDLE_CONNS`:最大空闲连接数,默认为 `10`。 - + `SQL_MAX_OPEN_CONNS`:最大打开连接数,默认为 `100`。 + + `SQL_MAX_IDLE_CONNS`:最大空闲连接数,默认为 `100`。 + + `SQL_MAX_OPEN_CONNS`:最大打开连接数,默认为 `1000`。 + 如果报错 `Error 1040: Too many connections`,请适当减小该值。 + `SQL_CONN_MAX_LIFETIME`:连接的最大生命周期,默认为 `60`,单位分钟。 4. `FRONTEND_BASE_URL`:设置之后将重定向页面请求到指定的地址,仅限从服务器设置。 diff --git a/controller/relay-ali.go b/controller/relay-ali.go index e94abd6a..014f6b84 100644 --- a/controller/relay-ali.go +++ b/controller/relay-ali.go @@ -166,11 +166,7 @@ func aliStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStat } stopChan <- true }() - c.Writer.Header().Set("Content-Type", "text/event-stream") - c.Writer.Header().Set("Cache-Control", "no-cache") - c.Writer.Header().Set("Connection", "keep-alive") - c.Writer.Header().Set("Transfer-Encoding", "chunked") - c.Writer.Header().Set("X-Accel-Buffering", "no") + setEventStreamHeaders(c) lastResponseText := "" c.Stream(func(w io.Writer) bool { select { diff --git a/controller/relay-baidu.go b/controller/relay-baidu.go index 664bbd11..ad20d6d6 100644 --- a/controller/relay-baidu.go +++ b/controller/relay-baidu.go @@ -3,22 +3,22 @@ package controller import ( "bufio" "encoding/json" + "errors" + "fmt" "github.com/gin-gonic/gin" "io" "net/http" "one-api/common" "strings" + "sync" + "time" ) // https://cloud.baidu.com/doc/WENXINWORKSHOP/s/flfmc9do2 type BaiduTokenResponse struct { - RefreshToken string `json:"refresh_token"` - ExpiresIn int `json:"expires_in"` - SessionKey string `json:"session_key"` - AccessToken string `json:"access_token"` - Scope string `json:"scope"` - SessionSecret string `json:"session_secret"` + ExpiresIn int `json:"expires_in"` + AccessToken string `json:"access_token"` } type BaiduMessage struct { @@ -73,6 +73,16 @@ type BaiduEmbeddingResponse struct { BaiduError } +type BaiduAccessToken struct { + AccessToken string `json:"access_token"` + Error string `json:"error,omitempty"` + ErrorDescription string `json:"error_description,omitempty"` + ExpiresIn int64 `json:"expires_in,omitempty"` + ExpiresAt time.Time `json:"-"` +} + +var baiduTokenStore sync.Map + func requestOpenAI2Baidu(request GeneralOpenAIRequest) *BaiduChatRequest { messages := make([]BaiduMessage, 0, len(request.Messages)) for _, message := range request.Messages { @@ -140,8 +150,12 @@ func embeddingRequestOpenAI2Baidu(request GeneralOpenAIRequest) *BaiduEmbeddingR switch request.Input.(type) { case string: baiduEmbeddingRequest.Input = []string{request.Input.(string)} - case []string: - baiduEmbeddingRequest.Input = request.Input.([]string) + case []any: + for _, item := range request.Input.([]any) { + if str, ok := item.(string); ok { + baiduEmbeddingRequest.Input = append(baiduEmbeddingRequest.Input, str) + } + } } return &baiduEmbeddingRequest } @@ -191,11 +205,7 @@ func baiduStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithSt } stopChan <- true }() - c.Writer.Header().Set("Content-Type", "text/event-stream") - c.Writer.Header().Set("Cache-Control", "no-cache") - c.Writer.Header().Set("Connection", "keep-alive") - c.Writer.Header().Set("Transfer-Encoding", "chunked") - c.Writer.Header().Set("X-Accel-Buffering", "no") + setEventStreamHeaders(c) c.Stream(func(w io.Writer) bool { select { case data := <-dataChan: @@ -299,3 +309,60 @@ func baiduEmbeddingHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWit _, err = c.Writer.Write(jsonResponse) return nil, &fullTextResponse.Usage } + +func getBaiduAccessToken(apiKey string) (string, error) { + if val, ok := baiduTokenStore.Load(apiKey); ok { + var accessToken BaiduAccessToken + if accessToken, ok = val.(BaiduAccessToken); ok { + // soon this will expire + if time.Now().Add(time.Hour).After(accessToken.ExpiresAt) { + go func() { + _, _ = getBaiduAccessTokenHelper(apiKey) + }() + } + return accessToken.AccessToken, nil + } + } + accessToken, err := getBaiduAccessTokenHelper(apiKey) + if err != nil { + return "", err + } + if accessToken == nil { + return "", errors.New("getBaiduAccessToken return a nil token") + } + return (*accessToken).AccessToken, nil +} + +func getBaiduAccessTokenHelper(apiKey string) (*BaiduAccessToken, error) { + parts := strings.Split(apiKey, "|") + if len(parts) != 2 { + return nil, errors.New("invalid baidu apikey") + } + req, err := http.NewRequest("POST", fmt.Sprintf("https://aip.baidubce.com/oauth/2.0/token?grant_type=client_credentials&client_id=%s&client_secret=%s", + parts[0], parts[1]), nil) + if err != nil { + return nil, err + } + req.Header.Add("Content-Type", "application/json") + req.Header.Add("Accept", "application/json") + res, err := impatientHTTPClient.Do(req) + if err != nil { + return nil, err + } + defer res.Body.Close() + + var accessToken BaiduAccessToken + err = json.NewDecoder(res.Body).Decode(&accessToken) + if err != nil { + return nil, err + } + if accessToken.Error != "" { + return nil, errors.New(accessToken.Error + ": " + accessToken.ErrorDescription) + } + if accessToken.AccessToken == "" { + return nil, errors.New("getBaiduAccessTokenHelper get empty access token") + } + accessToken.ExpiresAt = time.Now().Add(time.Duration(accessToken.ExpiresIn) * time.Second) + baiduTokenStore.Store(apiKey, accessToken) + return &accessToken, nil +} diff --git a/controller/relay-claude.go b/controller/relay-claude.go index 052e5605..1f4a3e7b 100644 --- a/controller/relay-claude.go +++ b/controller/relay-claude.go @@ -141,11 +141,7 @@ func claudeStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithS } stopChan <- true }() - c.Writer.Header().Set("Content-Type", "text/event-stream") - c.Writer.Header().Set("Cache-Control", "no-cache") - c.Writer.Header().Set("Connection", "keep-alive") - c.Writer.Header().Set("Transfer-Encoding", "chunked") - c.Writer.Header().Set("X-Accel-Buffering", "no") + setEventStreamHeaders(c) c.Stream(func(w io.Writer) bool { select { case data := <-dataChan: diff --git a/controller/relay-openai.go b/controller/relay-openai.go index 298dbe95..6bdfbc08 100644 --- a/controller/relay-openai.go +++ b/controller/relay-openai.go @@ -66,11 +66,7 @@ func openaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*O } stopChan <- true }() - c.Writer.Header().Set("Content-Type", "text/event-stream") - c.Writer.Header().Set("Cache-Control", "no-cache") - c.Writer.Header().Set("Connection", "keep-alive") - c.Writer.Header().Set("Transfer-Encoding", "chunked") - c.Writer.Header().Set("X-Accel-Buffering", "no") + setEventStreamHeaders(c) c.Stream(func(w io.Writer) bool { select { case data := <-dataChan: diff --git a/controller/relay-palm.go b/controller/relay-palm.go index 0053c9b8..a705b318 100644 --- a/controller/relay-palm.go +++ b/controller/relay-palm.go @@ -143,11 +143,7 @@ func palmStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithSta dataChan <- string(jsonResponse) stopChan <- true }() - c.Writer.Header().Set("Content-Type", "text/event-stream") - c.Writer.Header().Set("Cache-Control", "no-cache") - c.Writer.Header().Set("Connection", "keep-alive") - c.Writer.Header().Set("Transfer-Encoding", "chunked") - c.Writer.Header().Set("X-Accel-Buffering", "no") + setEventStreamHeaders(c) c.Stream(func(w io.Writer) bool { select { case data := <-dataChan: diff --git a/controller/relay-text.go b/controller/relay-text.go index a3e28ff7..cb63822c 100644 --- a/controller/relay-text.go +++ b/controller/relay-text.go @@ -11,6 +11,7 @@ import ( "one-api/common" "one-api/model" "strings" + "time" ) const ( @@ -25,9 +26,13 @@ const ( ) var httpClient *http.Client +var impatientHTTPClient *http.Client func init() { httpClient = &http.Client{} + impatientHTTPClient = &http.Client{ + Timeout: 5 * time.Second, + } } func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { @@ -154,7 +159,11 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { } apiKey := c.Request.Header.Get("Authorization") apiKey = strings.TrimPrefix(apiKey, "Bearer ") - fullRequestURL += "?access_token=" + apiKey // TODO: access token expire in 30 days + var err error + if apiKey, err = getBaiduAccessToken(apiKey); err != nil { + return errorWrapper(err, "invalid_baidu_config", http.StatusInternalServerError) + } + fullRequestURL += "?access_token=" + apiKey case APITypePaLM: fullRequestURL = "https://generativelanguage.googleapis.com/v1beta2/models/chat-bison-001:generateMessage" if baseURL != "" { diff --git a/controller/relay-utils.go b/controller/relay-utils.go index 5622d3e0..2df09eff 100644 --- a/controller/relay-utils.go +++ b/controller/relay-utils.go @@ -2,6 +2,7 @@ package controller import ( "fmt" + "github.com/gin-gonic/gin" "github.com/pkoukk/tiktoken-go" "one-api/common" "reflect" @@ -123,3 +124,11 @@ func shouldDisableChannel(err *OpenAIError) bool { } return false } + +func setEventStreamHeaders(c *gin.Context) { + c.Writer.Header().Set("Content-Type", "text/event-stream") + c.Writer.Header().Set("Cache-Control", "no-cache") + c.Writer.Header().Set("Connection", "keep-alive") + c.Writer.Header().Set("Transfer-Encoding", "chunked") + c.Writer.Header().Set("X-Accel-Buffering", "no") +} diff --git a/controller/relay-xunfei.go b/controller/relay-xunfei.go index 48472456..87037e34 100644 --- a/controller/relay-xunfei.go +++ b/controller/relay-xunfei.go @@ -217,11 +217,7 @@ func xunfeiStreamHandler(c *gin.Context, textRequest GeneralOpenAIRequest, appId } stopChan <- true }() - c.Writer.Header().Set("Content-Type", "text/event-stream") - c.Writer.Header().Set("Cache-Control", "no-cache") - c.Writer.Header().Set("Connection", "keep-alive") - c.Writer.Header().Set("Transfer-Encoding", "chunked") - c.Writer.Header().Set("X-Accel-Buffering", "no") + setEventStreamHeaders(c) c.Stream(func(w io.Writer) bool { select { case xunfeiResponse := <-dataChan: diff --git a/controller/relay-zhipu.go b/controller/relay-zhipu.go index b125f1e7..7a4a582d 100644 --- a/controller/relay-zhipu.go +++ b/controller/relay-zhipu.go @@ -224,11 +224,7 @@ func zhipuStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithSt } stopChan <- true }() - c.Writer.Header().Set("Content-Type", "text/event-stream") - c.Writer.Header().Set("Cache-Control", "no-cache") - c.Writer.Header().Set("Connection", "keep-alive") - c.Writer.Header().Set("Transfer-Encoding", "chunked") - c.Writer.Header().Set("X-Accel-Buffering", "no") + setEventStreamHeaders(c) c.Stream(func(w io.Writer) bool { select { case data := <-dataChan: diff --git a/i18n/en.json b/i18n/en.json index 67ce8a56..a9402419 100644 --- a/i18n/en.json +++ b/i18n/en.json @@ -519,5 +519,6 @@ "令牌创建成功,请在列表页面点击复制获取令牌!": "Token created successfully, please click copy on the list page to get the token!", "代理": "Proxy", "此项可选,用于通过代理站来进行 API 调用,请输入代理站地址,格式为:https://domain.com": "This is optional, used to make API calls through the proxy site, please enter the proxy site address, the format is: https://domain.com", - "取消密码登录将导致所有未绑定其他登录方式的用户(包括管理员)无法通过密码登录,确认取消?": "Canceling password login will cause all users (including administrators) who have not bound other login methods to be unable to log in via password, confirm cancel?" + "取消密码登录将导致所有未绑定其他登录方式的用户(包括管理员)无法通过密码登录,确认取消?": "Canceling password login will cause all users (including administrators) who have not bound other login methods to be unable to log in via password, confirm cancel?", + "按照如下格式输入:": "Enter in the following format:" } diff --git a/model/channel.go b/model/channel.go index b0d6e644..7cc9fa9b 100644 --- a/model/channel.go +++ b/model/channel.go @@ -141,7 +141,7 @@ func UpdateChannelStatusById(id int, status int) { } func UpdateChannelUsedQuota(id int, quota int) { - err := DB.Set("gorm:query_option", "FOR UPDATE").Model(&Channel{}).Where("id = ?", id).Update("used_quota", gorm.Expr("used_quota + ?", quota)).Error + err := DB.Model(&Channel{}).Where("id = ?", id).Update("used_quota", gorm.Expr("used_quota + ?", quota)).Error if err != nil { common.SysError("failed to update channel used quota: " + err.Error()) } diff --git a/model/main.go b/model/main.go index 213db58c..d422c4e0 100644 --- a/model/main.go +++ b/model/main.go @@ -74,8 +74,8 @@ func InitDB() (err error) { if err != nil { return err } - sqlDB.SetMaxIdleConns(common.GetOrDefault("SQL_MAX_IDLE_CONNS", 10)) - sqlDB.SetMaxOpenConns(common.GetOrDefault("SQL_MAX_OPEN_CONNS", 100)) + sqlDB.SetMaxIdleConns(common.GetOrDefault("SQL_MAX_IDLE_CONNS", 100)) + sqlDB.SetMaxOpenConns(common.GetOrDefault("SQL_MAX_OPEN_CONNS", 1000)) sqlDB.SetConnMaxLifetime(time.Second * time.Duration(common.GetOrDefault("SQL_MAX_LIFETIME", 60))) if !common.IsMasterNode { diff --git a/model/token.go b/model/token.go index 0e2395ad..7cd226c6 100644 --- a/model/token.go +++ b/model/token.go @@ -131,7 +131,7 @@ func IncreaseTokenQuota(id int, quota int) (err error) { if quota < 0 { return errors.New("quota 不能为负数!") } - err = DB.Set("gorm:query_option", "FOR UPDATE").Model(&Token{}).Where("id = ?", id).Updates( + err = DB.Model(&Token{}).Where("id = ?", id).Updates( map[string]interface{}{ "remain_quota": gorm.Expr("remain_quota + ?", quota), "used_quota": gorm.Expr("used_quota - ?", quota), @@ -144,7 +144,7 @@ func DecreaseTokenQuota(id int, quota int) (err error) { if quota < 0 { return errors.New("quota 不能为负数!") } - err = DB.Set("gorm:query_option", "FOR UPDATE").Model(&Token{}).Where("id = ?", id).Updates( + err = DB.Model(&Token{}).Where("id = ?", id).Updates( map[string]interface{}{ "remain_quota": gorm.Expr("remain_quota - ?", quota), "used_quota": gorm.Expr("used_quota + ?", quota), diff --git a/model/user.go b/model/user.go index c7080450..7c771840 100644 --- a/model/user.go +++ b/model/user.go @@ -275,7 +275,7 @@ func IncreaseUserQuota(id int, quota int) (err error) { if quota < 0 { return errors.New("quota 不能为负数!") } - err = DB.Set("gorm:query_option", "FOR UPDATE").Model(&User{}).Where("id = ?", id).Update("quota", gorm.Expr("quota + ?", quota)).Error + err = DB.Model(&User{}).Where("id = ?", id).Update("quota", gorm.Expr("quota + ?", quota)).Error return err } @@ -283,7 +283,7 @@ func DecreaseUserQuota(id int, quota int) (err error) { if quota < 0 { return errors.New("quota 不能为负数!") } - err = DB.Set("gorm:query_option", "FOR UPDATE").Model(&User{}).Where("id = ?", id).Update("quota", gorm.Expr("quota - ?", quota)).Error + err = DB.Model(&User{}).Where("id = ?", id).Update("quota", gorm.Expr("quota - ?", quota)).Error return err } @@ -293,7 +293,7 @@ func GetRootUserEmail() (email string) { } func UpdateUserUsedQuotaAndRequestCount(id int, quota int) { - err := DB.Set("gorm:query_option", "FOR UPDATE").Model(&User{}).Where("id = ?", id).Updates( + err := DB.Model(&User{}).Where("id = ?", id).Updates( map[string]interface{}{ "used_quota": gorm.Expr("used_quota + ?", quota), "request_count": gorm.Expr("request_count + ?", 1), diff --git a/web/src/components/ChannelsTable.js b/web/src/components/ChannelsTable.js index 072f5b90..5eb39783 100644 --- a/web/src/components/ChannelsTable.js +++ b/web/src/components/ChannelsTable.js @@ -368,7 +368,7 @@ const ChannelsTable = () => { }} style={{ cursor: 'pointer' }}> {renderBalance(channel.type, channel.balance)} } - content="点击更新" + content='点击更新' basic /> @@ -447,8 +447,8 @@ const ChannelsTable = () => { - {/* */} + { label='密钥' name='key' required - placeholder={inputs.type === 15 ? '请输入 access token,当前版本暂不支持自动刷新,请每 30 天更新一次' : (inputs.type === 18 ? '按照如下格式输入:APPID|APISecret|APIKey' : '请输入渠道对应的鉴权密钥')} + placeholder={inputs.type === 15 ? '按照如下格式输入:APIKey|SecretKey' : (inputs.type === 18 ? '按照如下格式输入:APPID|APISecret|APIKey' : '请输入渠道对应的鉴权密钥')} onChange={handleInputChange} value={inputs.key} autoComplete='new-password'