Merge branch 'main' into feat/add_ali_embedding

JustSong 2023-09-03 22:11:32 +08:00 committed by GitHub
commit 24fc299a43
18 changed files with 592 additions and 108 deletions


@@ -109,6 +109,8 @@ _✨ Access all large models through the standard OpenAI API format, ready to use out of the box ✨_
Data is saved in the host directory `/home/ubuntu/data/one-api`; make sure the directory exists and is writable, or change it to a suitable one.
If startup fails, add `--privileged=true`; see https://github.com/songquanpeng/one-api/issues/482 for details.
If the image above cannot be pulled, try GitHub's Docker registry by replacing `justsong/one-api` with `ghcr.io/songquanpeng/one-api`.
If you expect high concurrency, **be sure** to set `SQL_DSN`; see the [Environment Variables](#环境变量) section below.
@@ -275,8 +277,9 @@ graph LR
Without it, multiple channels are used in a load-balanced fashion.
### Environment Variables
1. `REDIS_CONN_STRING`: when set, Redis is used as the cache.
   + Example: `REDIS_CONN_STRING=redis://default:redispw@localhost:49153`
   + If your database latency is already low, there is no need to enable Redis; doing so can actually cause stale data.
2. `SESSION_SECRET`: when set, a fixed session secret is used, so cookies of logged-in users stay valid across restarts.
   + Example: `SESSION_SECRET=random_string`
3. `SQL_DSN`: when set, the specified database is used instead of SQLite; please use MySQL or PostgreSQL.
@@ -303,6 +306,14 @@ graph LR
   + Example: `CHANNEL_TEST_FREQUENCY=1440`
9. `POLLING_INTERVAL`: the interval, in seconds, between requests when batch-updating channel balances and testing availability; defaults to no interval.
   + Example: `POLLING_INTERVAL=5`
10. `BATCH_UPDATE_ENABLED`: enables aggregated batch database updates, which makes user quota updates lag slightly; allowed values are `true` and `false`, defaulting to `false` when unset.
   + Example: `BATCH_UPDATE_ENABLED=true`
   + If you run into too many database connections, try enabling this option.
11. `BATCH_UPDATE_INTERVAL=5`: the interval, in seconds, between aggregated batch updates; defaults to `5`.
   + Example: `BATCH_UPDATE_INTERVAL=5`
12. Request rate limiting:
   + `GLOBAL_API_RATE_LIMIT`: global API rate limit (relay requests excluded), the maximum number of requests per IP within three minutes; defaults to `180`.
   + `GLOBAL_WEB_RATE_LIMIT`: global web rate limit, the maximum number of requests per IP within three minutes; defaults to `60`.
### Command-Line Arguments
1. `--port <port_number>`: the port the server listens on; defaults to `3000`.
@@ -339,6 +350,7 @@ https://openai.justsong.cn
5. ChatGPT Next Web reports: `Failed to fetch`
   + Do not set `BASE_URL` when deploying.
   + Check that the API base URL and the API key are filled in correctly.
   + Check whether HTTPS is enabled; browsers block HTTP requests under HTTPS domains.
6. Error: `当前分组负载已饱和,请稍后再试` (the current group's load is saturated, please try again later)
   + The upstream channel returned 429.


@@ -94,6 +94,9 @@ var RequestInterval = time.Duration(requestInterval) * time.Second
var SyncFrequency = 10 * 60 // unit is second, will be overwritten by SYNC_FREQUENCY

var BatchUpdateEnabled = false
var BatchUpdateInterval = GetOrDefault("BATCH_UPDATE_INTERVAL", 5)

const (
	RoleGuestUser  = 0
	RoleCommonUser = 1
@@ -111,10 +114,10 @@ var (
// All durations are in seconds
// Shouldn't be larger than RateLimitKeyExpirationDuration
var (
	GlobalApiRateLimitNum            = GetOrDefault("GLOBAL_API_RATE_LIMIT", 180)
	GlobalApiRateLimitDuration int64 = 3 * 60
	GlobalWebRateLimitNum            = GetOrDefault("GLOBAL_WEB_RATE_LIMIT", 60)
	GlobalWebRateLimitDuration int64 = 3 * 60
	UploadRateLimitNum = 10
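GetOrDefault itself is not shown in this diff. As a rough sketch, assuming only the name and call shape visible at the call sites above, such a helper presumably reads an integer environment variable with a fallback:

package common

import (
	"os"
	"strconv"
)

// GetOrDefault returns the integer value of the environment variable env,
// falling back to defaultValue when it is unset or not a valid integer.
func GetOrDefault(env string, defaultValue int) int {
	if value := os.Getenv(env); value != "" {
		if parsed, err := strconv.Atoi(value); err == nil {
			return parsed
		}
	}
	return defaultValue
}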
@@ -154,49 +157,53 @@
)

const (
	ChannelTypeUnknown        = 0
	ChannelTypeOpenAI         = 1
	ChannelTypeAPI2D          = 2
	ChannelTypeAzure          = 3
	ChannelTypeCloseAI        = 4
	ChannelTypeOpenAISB       = 5
	ChannelTypeOpenAIMax      = 6
	ChannelTypeOhMyGPT        = 7
	ChannelTypeCustom         = 8
	ChannelTypeAILS           = 9
	ChannelTypeAIProxy        = 10
	ChannelTypePaLM           = 11
	ChannelTypeAPI2GPT        = 12
	ChannelTypeAIGC2D         = 13
	ChannelTypeAnthropic      = 14
	ChannelTypeBaidu          = 15
	ChannelTypeZhipu          = 16
	ChannelTypeAli            = 17
	ChannelTypeXunfei         = 18
	ChannelType360            = 19
	ChannelTypeOpenRouter     = 20
	ChannelTypeAIProxyLibrary = 21
	ChannelTypeFastGPT        = 22
)
var ChannelBaseURLs = []string{
	"",                                // 0
	"https://api.openai.com",          // 1
	"https://oa.api2d.net",            // 2
	"",                                // 3
	"https://api.closeai-proxy.xyz",   // 4
	"https://api.openai-sb.com",       // 5
	"https://api.openaimax.com",       // 6
	"https://api.ohmygpt.com",         // 7
	"",                                // 8
	"https://api.caipacity.com",       // 9
	"https://api.aiproxy.io",          // 10
	"",                                // 11
	"https://api.api2gpt.com",         // 12
	"https://api.aigc2d.com",          // 13
	"https://api.anthropic.com",       // 14
	"https://aip.baidubce.com",        // 15
	"https://open.bigmodel.cn",        // 16
	"https://dashscope.aliyuncs.com",  // 17
	"",                                // 18
	"https://ai.360.cn",               // 19
	"https://openrouter.ai/api",       // 20
	"https://api.aiproxy.io",          // 21
	"https://fastgpt.run/api/openapi", // 22
}


@@ -14,7 +14,7 @@ import (
	"time"
)

func testChannel(channel *model.Channel, request ChatRequest) (err error, openaiErr *OpenAIError) {
	switch channel.Type {
	case common.ChannelTypePaLM:
		fallthrough
@@ -32,6 +32,11 @@ func testChannel(channel *model.Channel, request ChatRequest) (error, *OpenAIErr
		return errors.New("该渠道类型当前版本不支持测试,请手动测试"), nil
	case common.ChannelTypeAzure:
		request.Model = "gpt-35-turbo"
		defer func() {
			if err != nil {
				err = errors.New("请确保已在 Azure 上创建了 gpt-35-turbo 模型,并且 apiVersion 已正确填写!")
			}
		}()
	default:
		request.Model = "gpt-3.5-turbo"
	}
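The deferred hint above only works because the signature in the first hunk now uses named return values; a deferred closure can overwrite err after the function body has decided it. A minimal standalone demonstration of that Go mechanism (the function names here are illustrative, not from the codebase):

package main

import (
	"errors"
	"fmt"
)

// callAzure stands in for the real test request; it always fails here.
func callAzure() error {
	return errors.New("404 DeploymentNotFound")
}

// testAzureChannel rewrites whatever error comes back into a friendlier hint.
// This only works because err is a named return value visible to the defer.
func testAzureChannel() (err error) {
	defer func() {
		if err != nil {
			err = errors.New("make sure the gpt-35-turbo deployment exists and apiVersion is filled in correctly")
		}
	}()
	return callAzure()
}

func main() {
	fmt.Println(testAzureChannel())
}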


@@ -85,7 +85,7 @@ func AddChannel(c *gin.Context) {
	}
	channel.CreatedTime = common.GetTimestamp()
	keys := strings.Split(channel.Key, "\n")
	channels := make([]model.Channel, 0, len(keys))
	for _, key := range keys {
		if key == "" {
			continue

controller/relay-aiproxy.go (new file, 220 lines)

@@ -0,0 +1,220 @@
package controller

import (
	"bufio"
	"encoding/json"
	"fmt"
	"github.com/gin-gonic/gin"
	"io"
	"net/http"
	"one-api/common"
	"strconv"
	"strings"
)

// https://docs.aiproxy.io/dev/library#使用已经定制好的知识库进行对话问答

type AIProxyLibraryRequest struct {
	Model     string `json:"model"`
	Query     string `json:"query"`
	LibraryId string `json:"libraryId"`
	Stream    bool   `json:"stream"`
}

type AIProxyLibraryError struct {
	ErrCode int    `json:"errCode"`
	Message string `json:"message"`
}

type AIProxyLibraryDocument struct {
	Title string `json:"title"`
	URL   string `json:"url"`
}

type AIProxyLibraryResponse struct {
	Success   bool                     `json:"success"`
	Answer    string                   `json:"answer"`
	Documents []AIProxyLibraryDocument `json:"documents"`
	AIProxyLibraryError
}

type AIProxyLibraryStreamResponse struct {
	Content   string                   `json:"content"`
	Finish    bool                     `json:"finish"`
	Model     string                   `json:"model"`
	Documents []AIProxyLibraryDocument `json:"documents"`
}

func requestOpenAI2AIProxyLibrary(request GeneralOpenAIRequest) *AIProxyLibraryRequest {
	query := ""
	if len(request.Messages) != 0 {
		query = request.Messages[len(request.Messages)-1].Content
	}
	return &AIProxyLibraryRequest{
		Model:  request.Model,
		Stream: request.Stream,
		Query:  query,
	}
}

func aiProxyDocuments2Markdown(documents []AIProxyLibraryDocument) string {
	if len(documents) == 0 {
		return ""
	}
	content := "\n\n参考文档\n"
	for i, document := range documents {
		content += fmt.Sprintf("%d. [%s](%s)\n", i+1, document.Title, document.URL)
	}
	return content
}

func responseAIProxyLibrary2OpenAI(response *AIProxyLibraryResponse) *OpenAITextResponse {
	content := response.Answer + aiProxyDocuments2Markdown(response.Documents)
	choice := OpenAITextResponseChoice{
		Index: 0,
		Message: Message{
			Role:    "assistant",
			Content: content,
		},
		FinishReason: "stop",
	}
	fullTextResponse := OpenAITextResponse{
		Id:      common.GetUUID(),
		Object:  "chat.completion",
		Created: common.GetTimestamp(),
		Choices: []OpenAITextResponseChoice{choice},
	}
	return &fullTextResponse
}

func documentsAIProxyLibrary(documents []AIProxyLibraryDocument) *ChatCompletionsStreamResponse {
	var choice ChatCompletionsStreamResponseChoice
	choice.Delta.Content = aiProxyDocuments2Markdown(documents)
	choice.FinishReason = &stopFinishReason
	return &ChatCompletionsStreamResponse{
		Id:      common.GetUUID(),
		Object:  "chat.completion.chunk",
		Created: common.GetTimestamp(),
		Model:   "",
		Choices: []ChatCompletionsStreamResponseChoice{choice},
	}
}

func streamResponseAIProxyLibrary2OpenAI(response *AIProxyLibraryStreamResponse) *ChatCompletionsStreamResponse {
	var choice ChatCompletionsStreamResponseChoice
	choice.Delta.Content = response.Content
	return &ChatCompletionsStreamResponse{
		Id:      common.GetUUID(),
		Object:  "chat.completion.chunk",
		Created: common.GetTimestamp(),
		Model:   response.Model,
		Choices: []ChatCompletionsStreamResponseChoice{choice},
	}
}

func aiProxyLibraryStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
	var usage Usage
	scanner := bufio.NewScanner(resp.Body)
	scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
		if atEOF && len(data) == 0 {
			return 0, nil, nil
		}
		if i := strings.Index(string(data), "\n"); i >= 0 {
			return i + 1, data[0:i], nil
		}
		if atEOF {
			return len(data), data, nil
		}
		return 0, nil, nil
	})
	dataChan := make(chan string)
	stopChan := make(chan bool)
	go func() {
		for scanner.Scan() {
			data := scanner.Text()
			if len(data) < 5 { // ignore blank line or wrong format
				continue
			}
			if data[:5] != "data:" {
				continue
			}
			data = data[5:]
			dataChan <- data
		}
		stopChan <- true
	}()
	setEventStreamHeaders(c)
	var documents []AIProxyLibraryDocument
	c.Stream(func(w io.Writer) bool {
		select {
		case data := <-dataChan:
			var AIProxyLibraryResponse AIProxyLibraryStreamResponse
			err := json.Unmarshal([]byte(data), &AIProxyLibraryResponse)
			if err != nil {
				common.SysError("error unmarshalling stream response: " + err.Error())
				return true
			}
			if len(AIProxyLibraryResponse.Documents) != 0 {
				documents = AIProxyLibraryResponse.Documents
			}
			response := streamResponseAIProxyLibrary2OpenAI(&AIProxyLibraryResponse)
			jsonResponse, err := json.Marshal(response)
			if err != nil {
				common.SysError("error marshalling stream response: " + err.Error())
				return true
			}
			c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
			return true
		case <-stopChan:
			response := documentsAIProxyLibrary(documents)
			jsonResponse, err := json.Marshal(response)
			if err != nil {
				common.SysError("error marshalling stream response: " + err.Error())
				return true
			}
			c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
			c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
			return false
		}
	})
	err := resp.Body.Close()
	if err != nil {
		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
	}
	return nil, &usage
}

func aiProxyLibraryHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
	var AIProxyLibraryResponse AIProxyLibraryResponse
	responseBody, err := io.ReadAll(resp.Body)
	if err != nil {
		return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
	}
	err = resp.Body.Close()
	if err != nil {
		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
	}
	err = json.Unmarshal(responseBody, &AIProxyLibraryResponse)
	if err != nil {
		return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
	}
	if AIProxyLibraryResponse.ErrCode != 0 {
		return &OpenAIErrorWithStatusCode{
			OpenAIError: OpenAIError{
				Message: AIProxyLibraryResponse.Message,
				Type:    strconv.Itoa(AIProxyLibraryResponse.ErrCode),
				Code:    AIProxyLibraryResponse.ErrCode,
			},
			StatusCode: resp.StatusCode,
		}, nil
	}
	fullTextResponse := responseAIProxyLibrary2OpenAI(&AIProxyLibraryResponse)
	jsonResponse, err := json.Marshal(fullTextResponse)
	if err != nil {
		return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
	}
	c.Writer.Header().Set("Content-Type", "application/json")
	c.Writer.WriteHeader(resp.StatusCode)
	_, err = c.Writer.Write(jsonResponse)
	return nil, &fullTextResponse.Usage
}
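The stream handler above installs a custom bufio.Scanner split function that tokenizes on single newline characters and then keeps only data:-prefixed lines. A self-contained sketch of the same parsing logic against a made-up two-event stream (the payload is illustrative, not a real aiproxy response):

package main

import (
	"bufio"
	"fmt"
	"strings"
)

func main() {
	// Two newline-delimited SSE events, shaped like the handler expects.
	body := "data: {\"content\":\"Hello\"}\ndata: {\"content\":\" world\"}\n"
	scanner := bufio.NewScanner(strings.NewReader(body))
	scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
		if atEOF && len(data) == 0 {
			return 0, nil, nil // nothing left
		}
		if i := strings.Index(string(data), "\n"); i >= 0 {
			return i + 1, data[0:i], nil // consume through the newline, emit the line
		}
		if atEOF {
			return len(data), data, nil // final line without a trailing newline
		}
		return 0, nil, nil // ask for more input
	})
	for scanner.Scan() {
		line := scanner.Text()
		if len(line) < 5 || line[:5] != "data:" {
			continue // skip blank or malformed lines, as the handler does
		}
		fmt.Println("payload:", line[5:])
	}
}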


@@ -22,6 +22,7 @@ const (
	APITypeZhipu
	APITypeAli
	APITypeXunfei
	APITypeAIProxyLibrary
)

var httpClient *http.Client
@@ -104,6 +105,8 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
		apiType = APITypeAli
	case common.ChannelTypeXunfei:
		apiType = APITypeXunfei
	case common.ChannelTypeAIProxyLibrary:
		apiType = APITypeAIProxyLibrary
	}
	baseURL := common.ChannelBaseURLs[channelType]
	requestURL := c.Request.URL.String()
@@ -174,6 +177,8 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
		if relayMode == RelayModeEmbeddings {
			fullRequestURL = "https://dashscope.aliyuncs.com/api/v1/services/embeddings/text-embedding/text-embedding"
		}
	case APITypeAIProxyLibrary:
		fullRequestURL = fmt.Sprintf("%s/api/library/ask", baseURL)
	}
	var promptTokens int
	var completionTokens int
@@ -274,6 +279,14 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
			return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
		}
		requestBody = bytes.NewBuffer(jsonStr)
	case APITypeAIProxyLibrary:
		aiProxyLibraryRequest := requestOpenAI2AIProxyLibrary(textRequest)
		aiProxyLibraryRequest.LibraryId = c.GetString("library_id")
		jsonStr, err := json.Marshal(aiProxyLibraryRequest)
		if err != nil {
			return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
		}
		requestBody = bytes.NewBuffer(jsonStr)
	}
	var req *http.Request
@@ -313,6 +326,8 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
		if textRequest.Stream {
			req.Header.Set("X-DashScope-SSE", "enable")
		}
	default:
		req.Header.Set("Authorization", "Bearer "+apiKey)
	}
	req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
	req.Header.Set("Accept", c.Request.Header.Get("Accept"))
@@ -373,7 +388,6 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
				logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
				model.RecordConsumeLog(userId, promptTokens, completionTokens, textRequest.Model, tokenName, quota, logContent)
				model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
				model.UpdateChannelUsedQuota(channelId, quota)
			}
		}
@@ -534,6 +548,26 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
		} else {
			return errorWrapper(errors.New("xunfei api does not support non-stream mode"), "invalid_api_type", http.StatusBadRequest)
		}
	case APITypeAIProxyLibrary:
		if isStream {
			err, usage := aiProxyLibraryStreamHandler(c, resp)
			if err != nil {
				return err
			}
			if usage != nil {
				textResponse.Usage = *usage
			}
			return nil
		} else {
			err, usage := aiProxyLibraryHandler(c, resp)
			if err != nil {
				return err
			}
			if usage != nil {
				textResponse.Usage = *usage
			}
			return nil
		}
	default:
		return errorWrapper(errors.New("unknown api type"), "unknown_api_type", http.StatusInternalServerError)
	}


@@ -523,5 +523,6 @@
  "按照如下格式输入:": "Enter in the following format:",
  "模型版本": "Model version",
  "请输入星火大模型版本注意是接口地址中的版本号例如v2.1": "Please enter the version of the Starfire model, note that it is the version number in the interface address, for example: v2.1",
  "点击查看": "click to view",
  "请确保已在 Azure 上创建了 gpt-35-turbo 模型,并且 apiVersion 已正确填写!": "Please make sure that the gpt-35-turbo model has been created on Azure, and the apiVersion has been filled in correctly!"
}


@@ -77,6 +77,11 @@ func main() {
		}
		go controller.AutomaticallyTestChannels(frequency)
	}
	if os.Getenv("BATCH_UPDATE_ENABLED") == "true" {
		common.BatchUpdateEnabled = true
		common.SysLog("batch update enabled with interval " + strconv.Itoa(common.BatchUpdateInterval) + "s")
		model.InitBatchUpdater()
	}
	controller.InitTokenEncoders()

	// Initialize HTTP server


@@ -100,7 +100,18 @@ func TokenAuth() func(c *gin.Context) {
			c.Abort()
			return
		}
		userEnabled, err := model.IsUserEnabled(token.UserId)
		if err != nil {
			c.JSON(http.StatusInternalServerError, gin.H{
				"error": gin.H{
					"message": err.Error(),
					"type":    "one_api_error",
				},
			})
			c.Abort()
			return
		}
		if !userEnabled {
			c.JSON(http.StatusForbidden, gin.H{
				"error": gin.H{
					"message": "用户已被封禁",


@@ -115,8 +115,13 @@ func Distribute() func(c *gin.Context) {
		c.Set("model_mapping", channel.ModelMapping)
		c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key))
		c.Set("base_url", channel.BaseURL)
		switch channel.Type {
		case common.ChannelTypeAzure:
			c.Set("api_version", channel.Other)
		case common.ChannelTypeXunfei:
			c.Set("api_version", channel.Other)
		case common.ChannelTypeAIProxyLibrary:
			c.Set("library_id", channel.Other)
		}
		c.Next()
	}


@@ -103,23 +103,28 @@ func CacheDecreaseUserQuota(id int, quota int) error {
	return err
}

func CacheIsUserEnabled(userId int) (bool, error) {
	if !common.RedisEnabled {
		return IsUserEnabled(userId)
	}
	enabled, err := common.RedisGet(fmt.Sprintf("user_enabled:%d", userId))
	if err == nil {
		return enabled == "1", nil
	}

	userEnabled, err := IsUserEnabled(userId)
	if err != nil {
		return false, err
	}
	enabled = "0"
	if userEnabled {
		enabled = "1"
	}
	err = common.RedisSet(fmt.Sprintf("user_enabled:%d", userId), enabled, time.Duration(UserId2StatusCacheSeconds)*time.Second)
	if err != nil {
		common.SysError("Redis set user enabled error: " + err.Error())
	}
	return userEnabled, err
}

var group2model2channels map[string]map[string][]*Channel
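The rewritten CacheIsUserEnabled is a cache-aside lookup: try Redis, fall back to the database on a miss, then write the "0"/"1" flag back with a TTL while still propagating database errors to the caller. A rough self-contained sketch of the same pattern, with a plain map standing in for Redis and a fake database lookup (all names below are illustrative, not from the codebase):

package main

import (
	"errors"
	"fmt"
)

var cache = map[string]string{} // stands in for Redis (no TTL in this sketch)

// dbIsUserEnabled stands in for the real database query.
func dbIsUserEnabled(userId int) (bool, error) {
	if userId == 0 {
		return false, errors.New("user id is empty")
	}
	return userId%2 == 1, nil // fake data
}

func cacheIsUserEnabled(userId int) (bool, error) {
	key := fmt.Sprintf("user_enabled:%d", userId)
	if enabled, ok := cache[key]; ok {
		return enabled == "1", nil // cache hit: decode the stored flag
	}
	userEnabled, err := dbIsUserEnabled(userId) // miss: ask the source of truth
	if err != nil {
		return false, err // do not cache failures
	}
	enabled := "0"
	if userEnabled {
		enabled = "1"
	}
	cache[key] = enabled // write back; the real code also sets an expiration
	return userEnabled, nil
}

func main() {
	fmt.Println(cacheIsUserEnabled(1)) // miss, reads the database
	fmt.Println(cacheIsUserEnabled(1)) // hit, served from the cache
}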


@@ -141,6 +141,14 @@ func UpdateChannelStatusById(id int, status int) {
}

func UpdateChannelUsedQuota(id int, quota int) {
	if common.BatchUpdateEnabled {
		addNewRecord(BatchUpdateTypeChannelUsedQuota, id, quota)
		return
	}
	updateChannelUsedQuota(id, quota)
}

func updateChannelUsedQuota(id int, quota int) {
	err := DB.Model(&Channel{}).Where("id = ?", id).Update("used_quota", gorm.Expr("used_quota + ?", quota)).Error
	if err != nil {
		common.SysError("failed to update channel used quota: " + err.Error())


@@ -39,32 +39,35 @@ func ValidateUserToken(key string) (token *Token, err error) {
	}
	token, err = CacheGetTokenByKey(key)
	if err == nil {
		if token.Status == common.TokenStatusExhausted {
			return nil, errors.New("该令牌额度已用尽")
		} else if token.Status == common.TokenStatusExpired {
			return nil, errors.New("该令牌已过期")
		}
		if token.Status != common.TokenStatusEnabled {
			return nil, errors.New("该令牌状态不可用")
		}
		if token.ExpiredTime != -1 && token.ExpiredTime < common.GetTimestamp() {
			if !common.RedisEnabled {
				token.Status = common.TokenStatusExpired
				err := token.SelectUpdate()
				if err != nil {
					common.SysError("failed to update token status" + err.Error())
				}
			}
			return nil, errors.New("该令牌已过期")
		}
		if !token.UnlimitedQuota && token.RemainQuota <= 0 {
			if !common.RedisEnabled {
				// in this case, we can make sure the token is exhausted
				token.Status = common.TokenStatusExhausted
				err := token.SelectUpdate()
				if err != nil {
					common.SysError("failed to update token status" + err.Error())
				}
			}
			return nil, errors.New("该令牌额度已用尽")
		}
		return token, nil
	}
	return nil, errors.New("无效的令牌")
@@ -131,10 +134,19 @@ func IncreaseTokenQuota(id int, quota int) (err error) {
	if quota < 0 {
		return errors.New("quota 不能为负数!")
	}
	if common.BatchUpdateEnabled {
		addNewRecord(BatchUpdateTypeTokenQuota, id, quota)
		return nil
	}
	return increaseTokenQuota(id, quota)
}

func increaseTokenQuota(id int, quota int) (err error) {
	err = DB.Model(&Token{}).Where("id = ?", id).Updates(
		map[string]interface{}{
			"remain_quota":  gorm.Expr("remain_quota + ?", quota),
			"used_quota":    gorm.Expr("used_quota - ?", quota),
			"accessed_time": common.GetTimestamp(),
		},
	).Error
	return err
@@ -144,10 +156,19 @@ func DecreaseTokenQuota(id int, quota int) (err error) {
	if quota < 0 {
		return errors.New("quota 不能为负数!")
	}
	if common.BatchUpdateEnabled {
		addNewRecord(BatchUpdateTypeTokenQuota, id, -quota)
		return nil
	}
	return decreaseTokenQuota(id, quota)
}

func decreaseTokenQuota(id int, quota int) (err error) {
	err = DB.Model(&Token{}).Where("id = ?", id).Updates(
		map[string]interface{}{
			"remain_quota":  gorm.Expr("remain_quota - ?", quota),
			"used_quota":    gorm.Expr("used_quota + ?", quota),
			"accessed_time": common.GetTimestamp(),
		},
	).Error
	return err


@@ -226,17 +226,16 @@ func IsAdmin(userId int) bool {
	return user.Role >= common.RoleAdminUser
}

func IsUserEnabled(userId int) (bool, error) {
	if userId == 0 {
		return false, errors.New("user id is empty")
	}
	var user User
	err := DB.Where("id = ?", userId).Select("status").Find(&user).Error
	if err != nil {
		return false, err
	}
	return user.Status == common.UserStatusEnabled, nil
}

func ValidateAccessToken(token string) (user *User) {
@@ -275,6 +274,14 @@ func IncreaseUserQuota(id int, quota int) (err error) {
	if quota < 0 {
		return errors.New("quota 不能为负数!")
	}
	if common.BatchUpdateEnabled {
		addNewRecord(BatchUpdateTypeUserQuota, id, quota)
		return nil
	}
	return increaseUserQuota(id, quota)
}

func increaseUserQuota(id int, quota int) (err error) {
	err = DB.Model(&User{}).Where("id = ?", id).Update("quota", gorm.Expr("quota + ?", quota)).Error
	return err
}
@@ -283,6 +290,14 @@ func DecreaseUserQuota(id int, quota int) (err error) {
	if quota < 0 {
		return errors.New("quota 不能为负数!")
	}
	if common.BatchUpdateEnabled {
		addNewRecord(BatchUpdateTypeUserQuota, id, -quota)
		return nil
	}
	return decreaseUserQuota(id, quota)
}

func decreaseUserQuota(id int, quota int) (err error) {
	err = DB.Model(&User{}).Where("id = ?", id).Update("quota", gorm.Expr("quota - ?", quota)).Error
	return err
}
@@ -293,10 +308,18 @@ func GetRootUserEmail() (email string) {
}

func UpdateUserUsedQuotaAndRequestCount(id int, quota int) {
	if common.BatchUpdateEnabled {
		addNewRecord(BatchUpdateTypeUsedQuotaAndRequestCount, id, quota)
		return
	}
	updateUserUsedQuotaAndRequestCount(id, quota, 1)
}

func updateUserUsedQuotaAndRequestCount(id int, quota int, count int) {
	err := DB.Model(&User{}).Where("id = ?", id).Updates(
		map[string]interface{}{
			"used_quota":    gorm.Expr("used_quota + ?", quota),
			"request_count": gorm.Expr("request_count + ?", count),
		},
	).Error
	if err != nil {

model/utils.go (new file, 75 lines)

@@ -0,0 +1,75 @@
package model

import (
	"one-api/common"
	"sync"
	"time"
)

const BatchUpdateTypeCount = 4 // if you add a new type, you need to add a new map and a new lock

const (
	BatchUpdateTypeUserQuota = iota
	BatchUpdateTypeTokenQuota
	BatchUpdateTypeUsedQuotaAndRequestCount
	BatchUpdateTypeChannelUsedQuota
)

var batchUpdateStores []map[int]int
var batchUpdateLocks []sync.Mutex

func init() {
	for i := 0; i < BatchUpdateTypeCount; i++ {
		batchUpdateStores = append(batchUpdateStores, make(map[int]int))
		batchUpdateLocks = append(batchUpdateLocks, sync.Mutex{})
	}
}

func InitBatchUpdater() {
	go func() {
		for {
			time.Sleep(time.Duration(common.BatchUpdateInterval) * time.Second)
			batchUpdate()
		}
	}()
}

func addNewRecord(type_ int, id int, value int) {
	batchUpdateLocks[type_].Lock()
	defer batchUpdateLocks[type_].Unlock()
	if _, ok := batchUpdateStores[type_][id]; !ok {
		batchUpdateStores[type_][id] = value
	} else {
		batchUpdateStores[type_][id] += value
	}
}

func batchUpdate() {
	common.SysLog("batch update started")
	for i := 0; i < BatchUpdateTypeCount; i++ {
		batchUpdateLocks[i].Lock()
		store := batchUpdateStores[i]
		batchUpdateStores[i] = make(map[int]int)
		batchUpdateLocks[i].Unlock()
		for key, value := range store {
			switch i {
			case BatchUpdateTypeUserQuota:
				err := increaseUserQuota(key, value)
				if err != nil {
					common.SysError("failed to batch update user quota: " + err.Error())
				}
			case BatchUpdateTypeTokenQuota:
				err := increaseTokenQuota(key, value)
				if err != nil {
					common.SysError("failed to batch update token quota: " + err.Error())
				}
			case BatchUpdateTypeUsedQuotaAndRequestCount:
				updateUserUsedQuotaAndRequestCount(key, value, 1) // TODO: count is incorrect
			case BatchUpdateTypeChannelUsedQuota:
				updateChannelUsedQuota(key, value)
			}
		}
	}
	common.SysLog("batch update finished")
}
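addNewRecord merges increments per id under a per-type lock, and batchUpdate swaps each store out before flushing so writers are blocked only briefly. A condensed standalone sketch of that accumulate-and-flush idea (the names are illustrative; the real code keeps one store per update type and flushes on a timer):

package main

import (
	"fmt"
	"sync"
)

// accumulator sums pending deltas per id until the next flush,
// mirroring the batch updater's map-plus-mutex design.
type accumulator struct {
	mu      sync.Mutex
	pending map[int]int
}

func (a *accumulator) add(id, value int) {
	a.mu.Lock()
	defer a.mu.Unlock()
	a.pending[id] += value // missing keys read as zero, so this also inserts
}

// flush swaps the map out under the lock, then applies the merged deltas
// (printed here in place of a database write).
func (a *accumulator) flush() {
	a.mu.Lock()
	store := a.pending
	a.pending = make(map[int]int)
	a.mu.Unlock()
	for id, delta := range store {
		fmt.Printf("UPDATE users SET quota = quota + %d WHERE id = %d\n", delta, id)
	}
}

func main() {
	acc := &accumulator{pending: make(map[int]int)}
	acc.add(42, 100)
	acc.add(42, -30) // merged with the earlier delta for id 42
	acc.add(7, 5)
	acc.flush() // one statement per id instead of one per call
}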


@@ -324,7 +324,7 @@ const LogsTable = () => {
  .map((log, idx) => {
    if (log.deleted) return <></>;
    return (
      <Table.Row key={log.id}>
        <Table.Cell>{renderTimestamp(log.created_at)}</Table.Cell>
        {
          isAdminUser && (


@@ -9,6 +9,8 @@ export const CHANNEL_OPTIONS = [
  { key: 16, text: '智谱 ChatGLM', value: 16, color: 'violet' },
  { key: 19, text: '360 智脑', value: 19, color: 'blue' },
  { key: 8, text: '自定义渠道', value: 8, color: 'pink' },
  { key: 22, text: '知识库FastGPT', value: 22, color: 'blue' },
  { key: 21, text: '知识库AI Proxy', value: 21, color: 'purple' },
  { key: 20, text: '代理OpenRouter', value: 20, color: 'black' },
  { key: 2, text: '代理API2D', value: 2, color: 'blue' },
  { key: 5, text: '代理OpenAI-SB', value: 5, color: 'brown' },


@@ -10,6 +10,20 @@ const MODEL_MAPPING_EXAMPLE = {
  'gpt-4-32k-0314': 'gpt-4-32k'
};

function type2secretPrompt(type) {
  // inputs.type === 15 ? '按照如下格式输入APIKey|SecretKey' : (inputs.type === 18 ? '按照如下格式输入APPID|APISecret|APIKey' : '请输入渠道对应的鉴权密钥')
  switch (type) {
    case 15:
      return '按照如下格式输入APIKey|SecretKey';
    case 18:
      return '按照如下格式输入APPID|APISecret|APIKey';
    case 22:
      return '按照如下格式输入APIKey-AppId例如fastgpt-0sp2gtvfdgyi4k30jwlgwf1i-64f335d84283f05518e9e041';
    default:
      return '请输入渠道对应的鉴权密钥';
  }
}
const EditChannel = () => {
  const params = useParams();
  const navigate = useNavigate();
@@ -193,6 +207,24 @@
    }
  };

  const addCustomModel = () => {
    if (customModel.trim() === '') return;
    if (inputs.models.includes(customModel)) return;
    let localModels = [...inputs.models];
    localModels.push(customModel);
    let localModelOptions = [];
    localModelOptions.push({
      key: customModel,
      text: customModel,
      value: customModel
    });
    setModelOptions(modelOptions => {
      return [...modelOptions, ...localModelOptions];
    });
    setCustomModel('');
    handleInputChange(null, { name: 'models', value: localModels });
  };

  return (
    <>
      <Segment loading={loading}>
@@ -295,6 +327,20 @@
        </Form.Field>
      )
    }
    {
      inputs.type === 21 && (
        <Form.Field>
          <Form.Input
            label='知识库 ID'
            name='other'
            placeholder={'请输入知识库 ID例如123456'}
            onChange={handleInputChange}
            value={inputs.other}
            autoComplete='new-password'
          />
        </Form.Field>
      )
    }
    <Form.Field>
      <Form.Dropdown
        label='模型'
@@ -322,29 +368,19 @@
      }}>清除所有模型</Button>
      <Input
        action={
          <Button type={'button'} onClick={addCustomModel}>填入</Button>
        }
        placeholder='输入自定义模型名称'
        value={customModel}
        onChange={(e, { value }) => {
          setCustomModel(value);
        }}
        onKeyDown={(e) => {
          if (e.key === 'Enter') {
            addCustomModel();
            e.preventDefault();
          }
        }}
      />
    </div>
    <Form.Field>
@@ -375,7 +411,7 @@
      label='密钥'
      name='key'
      required
      placeholder={type2secretPrompt(inputs.type)}
      onChange={handleInputChange}
      value={inputs.key}
      autoComplete='new-password'
@@ -393,7 +429,7 @@
      )
    }
    {
      inputs.type !== 3 && inputs.type !== 8 && inputs.type !== 22 && (
        <Form.Field>
          <Form.Input
            label='代理'
@@ -406,6 +442,20 @@
        </Form.Field>
      )
    }
    {
      inputs.type === 22 && (
        <Form.Field>
          <Form.Input
            label='私有部署地址'
            name='base_url'
            placeholder={'请输入私有部署地址格式为https://fastgpt.run/api/openapi'}
            onChange={handleInputChange}
            value={inputs.base_url}
            autoComplete='new-password'
          />
        </Form.Field>
      )
    }
    <Button onClick={handleCancel}>取消</Button>
    <Button type={isEdit ? 'button' : 'submit'} positive onClick={submit}>提交</Button>
  </Form>
</Form> </Form>