diff --git a/README.md b/README.md index 823e2522..c972c600 100644 --- a/README.md +++ b/README.md @@ -51,15 +51,17 @@ _✨ All in one 的 OpenAI 接口,整合各种 API 访问方式,开箱即用 + [x] **Azure OpenAI API** + [x] [API2D](https://api2d.com/r/197971) + [x] [OhMyGPT](https://aigptx.top?aff=uFpUl2Kf) - + [x] [CloseAI](https://console.openai-asia.com) - + [x] [OpenAI-SB](https://openai-sb.com) + + [x] [AI Proxy](https://aiproxy.io/?i=OneAPI) (邀请码:`OneAPI`) + + [x] [AI.LS](https://ai.ls) + [x] [OpenAI Max](https://openaimax.com) + + [x] [OpenAI-SB](https://openai-sb.com) + + [x] [CloseAI](https://console.openai-asia.com/r/2412) + [x] 自定义渠道:例如使用自行搭建的 OpenAI 代理 2. 支持通过**负载均衡**的方式访问多个渠道。 3. 支持 **stream 模式**,可以通过流式传输实现打字机效果。 4. 支持**多机部署**,[详见此处](#多机部署)。 5. 支持**令牌管理**,设置令牌的过期时间和使用次数。 -6. 支持**兑换码管理**,支持批量生成和导出兑换码,可使用兑换码为令牌进行充值。 +6. 支持**兑换码管理**,支持批量生成和导出兑换码,可使用兑换码为账户进行充值。 7. 支持**通道管理**,批量创建通道。 8. 支持发布公告,设置充值链接,设置新用户初始额度。 9. 支持丰富的**自定义**设置, diff --git a/common/constants.go b/common/constants.go index 5cb55dfb..7c1ff298 100644 --- a/common/constants.go +++ b/common/constants.go @@ -127,6 +127,9 @@ const ( ChannelTypeOpenAIMax = 6 ChannelTypeOhMyGPT = 7 ChannelTypeCustom = 8 + ChannelTypeAILS = 9 + ChannelTypeAIProxy = 10 + ChannelTypePaLM = 11 ) var ChannelBaseURLs = []string{ @@ -139,4 +142,7 @@ var ChannelBaseURLs = []string{ "https://api.openaimax.com", // 6 "https://api.ohmygpt.com", // 7 "", // 8 + "https://api.caipacity.com", // 9 + "https://api.aiproxy.io", // 10 + "", // 11 } diff --git a/controller/billing.go b/controller/billing.go new file mode 100644 index 00000000..2f0d90fe --- /dev/null +++ b/controller/billing.go @@ -0,0 +1,41 @@ +package controller + +import ( + "github.com/gin-gonic/gin" + "one-api/model" +) + +func GetSubscription(c *gin.Context) { + userId := c.GetInt("id") + quota, err := model.GetUserQuota(userId) + if err != nil { + openAIError := OpenAIError{ + Message: err.Error(), + Type: "one_api_error", + } + c.JSON(200, gin.H{ + "error": openAIError, + }) + return + } + subscription := OpenAISubscriptionResponse{ + Object: "billing_subscription", + HasPaymentMethod: true, + SoftLimitUSD: float64(quota), + HardLimitUSD: float64(quota), + SystemHardLimitUSD: float64(quota), + } + c.JSON(200, subscription) + return +} + +func GetUsage(c *gin.Context) { + //userId := c.GetInt("id") + // TODO: get usage from database + usage := OpenAIUsageResponse{ + Object: "list", + TotalUsage: 0, + } + c.JSON(200, usage) + return +} diff --git a/controller/channel-billing.go b/controller/channel-billing.go new file mode 100644 index 00000000..e135e5fc --- /dev/null +++ b/controller/channel-billing.go @@ -0,0 +1,179 @@ +package controller + +import ( + "encoding/json" + "errors" + "fmt" + "github.com/gin-gonic/gin" + "io" + "net/http" + "one-api/common" + "one-api/model" + "strconv" + "time" +) + +// https://github.com/songquanpeng/one-api/issues/79 + +type OpenAISubscriptionResponse struct { + Object string `json:"object"` + HasPaymentMethod bool `json:"has_payment_method"` + SoftLimitUSD float64 `json:"soft_limit_usd"` + HardLimitUSD float64 `json:"hard_limit_usd"` + SystemHardLimitUSD float64 `json:"system_hard_limit_usd"` +} + +type OpenAIUsageDailyCost struct { + Timestamp float64 `json:"timestamp"` + LineItems []struct { + Name string `json:"name"` + Cost float64 `json:"cost"` + } +} + +type OpenAIUsageResponse struct { + Object string `json:"object"` + //DailyCosts []OpenAIUsageDailyCost `json:"daily_costs"` + TotalUsage float64 `json:"total_usage"` // unit: 0.01 dollar +} + +func updateChannelBalance(channel *model.Channel) (float64, error) { + baseURL := common.ChannelBaseURLs[channel.Type] + switch channel.Type { + case common.ChannelTypeAzure: + return 0, errors.New("尚未实现") + case common.ChannelTypeCustom: + baseURL = channel.BaseURL + } + url := fmt.Sprintf("%s/v1/dashboard/billing/subscription", baseURL) + + client := &http.Client{} + req, err := http.NewRequest("GET", url, nil) + if err != nil { + return 0, err + } + auth := fmt.Sprintf("Bearer %s", channel.Key) + req.Header.Add("Authorization", auth) + res, err := client.Do(req) + if err != nil { + return 0, err + } + body, err := io.ReadAll(res.Body) + if err != nil { + return 0, err + } + err = res.Body.Close() + if err != nil { + return 0, err + } + subscription := OpenAISubscriptionResponse{} + err = json.Unmarshal(body, &subscription) + if err != nil { + return 0, err + } + now := time.Now() + startDate := fmt.Sprintf("%s-01", now.Format("2006-01")) + //endDate := now.Format("2006-01-02") + url = fmt.Sprintf("%s/v1/dashboard/billing/usage?start_date=%s&end_date=%s", baseURL, startDate, "2023-06-01") + req, err = http.NewRequest("GET", url, nil) + if err != nil { + return 0, err + } + req.Header.Add("Authorization", auth) + res, err = client.Do(req) + if err != nil { + return 0, err + } + body, err = io.ReadAll(res.Body) + if err != nil { + return 0, err + } + err = res.Body.Close() + if err != nil { + return 0, err + } + usage := OpenAIUsageResponse{} + err = json.Unmarshal(body, &usage) + if err != nil { + return 0, err + } + balance := subscription.HardLimitUSD - usage.TotalUsage/100 + channel.UpdateBalance(balance) + return balance, nil +} + +func UpdateChannelBalance(c *gin.Context) { + id, err := strconv.Atoi(c.Param("id")) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + channel, err := model.GetChannelById(id, true) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + balance, err := updateChannelBalance(channel) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "", + "balance": balance, + }) + return +} + +func updateAllChannelsBalance() error { + channels, err := model.GetAllChannels(0, 0, true) + if err != nil { + return err + } + for _, channel := range channels { + if channel.Status != common.ChannelStatusEnabled { + continue + } + // TODO: support Azure + if channel.Type != common.ChannelTypeOpenAI && channel.Type != common.ChannelTypeCustom { + continue + } + balance, err := updateChannelBalance(channel) + if err != nil { + continue + } else { + // err is nil & balance <= 0 means quota is used up + if balance <= 0 { + disableChannel(channel.Id, channel.Name, "余额不足") + } + } + } + return nil +} + +func UpdateAllChannelsBalance(c *gin.Context) { + // TODO: make it async + err := updateAllChannelsBalance() + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "", + }) + return +} diff --git a/controller/channel-test.go b/controller/channel-test.go new file mode 100644 index 00000000..0d32c8c6 --- /dev/null +++ b/controller/channel-test.go @@ -0,0 +1,199 @@ +package controller + +import ( + "bytes" + "encoding/json" + "errors" + "fmt" + "github.com/gin-gonic/gin" + "net/http" + "one-api/common" + "one-api/model" + "strconv" + "sync" + "time" +) + +func testChannel(channel *model.Channel, request *ChatRequest) error { + if request.Model == "" { + request.Model = "gpt-3.5-turbo" + if channel.Type == common.ChannelTypeAzure { + request.Model = "gpt-35-turbo" + } + } + requestURL := common.ChannelBaseURLs[channel.Type] + if channel.Type == common.ChannelTypeAzure { + requestURL = fmt.Sprintf("%s/openai/deployments/%s/chat/completions?api-version=2023-03-15-preview", channel.BaseURL, request.Model) + } else { + if channel.Type == common.ChannelTypeCustom { + requestURL = channel.BaseURL + } + requestURL += "/v1/chat/completions" + } + + jsonData, err := json.Marshal(request) + if err != nil { + return err + } + req, err := http.NewRequest("POST", requestURL, bytes.NewBuffer(jsonData)) + if err != nil { + return err + } + if channel.Type == common.ChannelTypeAzure { + req.Header.Set("api-key", channel.Key) + } else { + req.Header.Set("Authorization", "Bearer "+channel.Key) + } + req.Header.Set("Content-Type", "application/json") + client := &http.Client{} + resp, err := client.Do(req) + if err != nil { + return err + } + defer resp.Body.Close() + var response TextResponse + err = json.NewDecoder(resp.Body).Decode(&response) + if err != nil { + return err + } + if response.Error.Message != "" || response.Error.Code != "" { + return errors.New(fmt.Sprintf("type %s, code %s, message %s", response.Error.Type, response.Error.Code, response.Error.Message)) + } + return nil +} + +func buildTestRequest(c *gin.Context) *ChatRequest { + model_ := c.Query("model") + testRequest := &ChatRequest{ + Model: model_, + MaxTokens: 1, + } + testMessage := Message{ + Role: "user", + Content: "hi", + } + testRequest.Messages = append(testRequest.Messages, testMessage) + return testRequest +} + +func TestChannel(c *gin.Context) { + id, err := strconv.Atoi(c.Param("id")) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + channel, err := model.GetChannelById(id, true) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + testRequest := buildTestRequest(c) + tik := time.Now() + err = testChannel(channel, testRequest) + tok := time.Now() + milliseconds := tok.Sub(tik).Milliseconds() + go channel.UpdateResponseTime(milliseconds) + consumedTime := float64(milliseconds) / 1000.0 + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + "time": consumedTime, + }) + return + } + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "", + "time": consumedTime, + }) + return +} + +var testAllChannelsLock sync.Mutex +var testAllChannelsRunning bool = false + +// disable & notify +func disableChannel(channelId int, channelName string, reason string) { + if common.RootUserEmail == "" { + common.RootUserEmail = model.GetRootUserEmail() + } + model.UpdateChannelStatusById(channelId, common.ChannelStatusDisabled) + subject := fmt.Sprintf("通道「%s」(#%d)已被禁用", channelName, channelId) + content := fmt.Sprintf("通道「%s」(#%d)已被禁用,原因:%s", channelName, channelId, reason) + err := common.SendEmail(subject, common.RootUserEmail, content) + if err != nil { + common.SysError(fmt.Sprintf("发送邮件失败:%s", err.Error())) + } +} + +func testAllChannels(c *gin.Context) error { + testAllChannelsLock.Lock() + if testAllChannelsRunning { + testAllChannelsLock.Unlock() + return errors.New("测试已在运行中") + } + testAllChannelsRunning = true + testAllChannelsLock.Unlock() + channels, err := model.GetAllChannels(0, 0, true) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return err + } + testRequest := buildTestRequest(c) + var disableThreshold = int64(common.ChannelDisableThreshold * 1000) + if disableThreshold == 0 { + disableThreshold = 10000000 // a impossible value + } + go func() { + for _, channel := range channels { + if channel.Status != common.ChannelStatusEnabled { + continue + } + tik := time.Now() + err := testChannel(channel, testRequest) + tok := time.Now() + milliseconds := tok.Sub(tik).Milliseconds() + if err != nil || milliseconds > disableThreshold { + if milliseconds > disableThreshold { + err = errors.New(fmt.Sprintf("响应时间 %.2fs 超过阈值 %.2fs", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0)) + } + disableChannel(channel.Id, channel.Name, err.Error()) + } + channel.UpdateResponseTime(milliseconds) + } + err := common.SendEmail("通道测试完成", common.RootUserEmail, "通道测试完成,如果没有收到禁用通知,说明所有通道都正常") + if err != nil { + common.SysError(fmt.Sprintf("发送邮件失败:%s", err.Error())) + } + testAllChannelsLock.Lock() + testAllChannelsRunning = false + testAllChannelsLock.Unlock() + }() + return nil +} + +func TestAllChannels(c *gin.Context) { + err := testAllChannels(c) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "", + }) + return +} diff --git a/controller/channel.go b/controller/channel.go index 3f047546..8afc0eed 100644 --- a/controller/channel.go +++ b/controller/channel.go @@ -1,18 +1,12 @@ package controller import ( - "bytes" - "encoding/json" - "errors" - "fmt" "github.com/gin-gonic/gin" "net/http" "one-api/common" "one-api/model" "strconv" "strings" - "sync" - "time" ) func GetAllChannels(c *gin.Context) { @@ -158,187 +152,3 @@ func UpdateChannel(c *gin.Context) { }) return } - -func testChannel(channel *model.Channel, request *ChatRequest) error { - if request.Model == "" { - request.Model = "gpt-3.5-turbo" - if channel.Type == common.ChannelTypeAzure { - request.Model = "gpt-35-turbo" - } - } - requestURL := common.ChannelBaseURLs[channel.Type] - if channel.Type == common.ChannelTypeAzure { - requestURL = fmt.Sprintf("%s/openai/deployments/%s/chat/completions?api-version=2023-03-15-preview", channel.BaseURL, request.Model) - } else { - if channel.Type == common.ChannelTypeCustom { - requestURL = channel.BaseURL - } - requestURL += "/v1/chat/completions" - } - - jsonData, err := json.Marshal(request) - if err != nil { - return err - } - req, err := http.NewRequest("POST", requestURL, bytes.NewBuffer(jsonData)) - if err != nil { - return err - } - if channel.Type == common.ChannelTypeAzure { - req.Header.Set("api-key", channel.Key) - } else { - req.Header.Set("Authorization", "Bearer "+channel.Key) - } - req.Header.Set("Content-Type", "application/json") - client := &http.Client{} - resp, err := client.Do(req) - if err != nil { - return err - } - defer resp.Body.Close() - var response TextResponse - err = json.NewDecoder(resp.Body).Decode(&response) - if err != nil { - return err - } - if response.Error.Message != "" { - return errors.New(fmt.Sprintf("type %s, code %s, message %s", response.Error.Type, response.Error.Code, response.Error.Message)) - } - return nil -} - -func buildTestRequest(c *gin.Context) *ChatRequest { - model_ := c.Query("model") - testRequest := &ChatRequest{ - Model: model_, - MaxTokens: 1, - } - testMessage := Message{ - Role: "user", - Content: "hi", - } - testRequest.Messages = append(testRequest.Messages, testMessage) - return testRequest -} - -func TestChannel(c *gin.Context) { - id, err := strconv.Atoi(c.Param("id")) - if err != nil { - c.JSON(http.StatusOK, gin.H{ - "success": false, - "message": err.Error(), - }) - return - } - channel, err := model.GetChannelById(id, true) - if err != nil { - c.JSON(http.StatusOK, gin.H{ - "success": false, - "message": err.Error(), - }) - return - } - testRequest := buildTestRequest(c) - tik := time.Now() - err = testChannel(channel, testRequest) - tok := time.Now() - milliseconds := tok.Sub(tik).Milliseconds() - go channel.UpdateResponseTime(milliseconds) - consumedTime := float64(milliseconds) / 1000.0 - if err != nil { - c.JSON(http.StatusOK, gin.H{ - "success": false, - "message": err.Error(), - "time": consumedTime, - }) - return - } - c.JSON(http.StatusOK, gin.H{ - "success": true, - "message": "", - "time": consumedTime, - }) - return -} - -var testAllChannelsLock sync.Mutex -var testAllChannelsRunning bool = false - -// disable & notify -func disableChannel(channelId int, channelName string, reason string) { - if common.RootUserEmail == "" { - common.RootUserEmail = model.GetRootUserEmail() - } - model.UpdateChannelStatusById(channelId, common.ChannelStatusDisabled) - subject := fmt.Sprintf("通道「%s」(#%d)已被禁用", channelName, channelId) - content := fmt.Sprintf("通道「%s」(#%d)已被禁用,原因:%s", channelName, channelId, reason) - err := common.SendEmail(subject, common.RootUserEmail, content) - if err != nil { - common.SysError(fmt.Sprintf("发送邮件失败:%s", err.Error())) - } -} - -func testAllChannels(c *gin.Context) error { - testAllChannelsLock.Lock() - if testAllChannelsRunning { - testAllChannelsLock.Unlock() - return errors.New("测试已在运行中") - } - testAllChannelsRunning = true - testAllChannelsLock.Unlock() - channels, err := model.GetAllChannels(0, 0, true) - if err != nil { - c.JSON(http.StatusOK, gin.H{ - "success": false, - "message": err.Error(), - }) - return err - } - testRequest := buildTestRequest(c) - var disableThreshold = int64(common.ChannelDisableThreshold * 1000) - if disableThreshold == 0 { - disableThreshold = 10000000 // a impossible value - } - go func() { - for _, channel := range channels { - if channel.Status != common.ChannelStatusEnabled { - continue - } - tik := time.Now() - err := testChannel(channel, testRequest) - tok := time.Now() - milliseconds := tok.Sub(tik).Milliseconds() - if err != nil || milliseconds > disableThreshold { - if milliseconds > disableThreshold { - err = errors.New(fmt.Sprintf("响应时间 %.2fs 超过阈值 %.2fs", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0)) - } - disableChannel(channel.Id, channel.Name, err.Error()) - } - channel.UpdateResponseTime(milliseconds) - } - err := common.SendEmail("通道测试完成", common.RootUserEmail, "通道测试完成,如果没有收到禁用通知,说明所有通道都正常") - if err != nil { - common.SysError(fmt.Sprintf("发送邮件失败:%s", err.Error())) - } - testAllChannelsLock.Lock() - testAllChannelsRunning = false - testAllChannelsLock.Unlock() - }() - return nil -} - -func TestAllChannels(c *gin.Context) { - err := testAllChannels(c) - if err != nil { - c.JSON(http.StatusOK, gin.H{ - "success": false, - "message": err.Error(), - }) - return - } - c.JSON(http.StatusOK, gin.H{ - "success": true, - "message": "", - }) - return -} diff --git a/controller/relay-palm.go b/controller/relay-palm.go new file mode 100644 index 00000000..ae739ca0 --- /dev/null +++ b/controller/relay-palm.go @@ -0,0 +1,59 @@ +package controller + +import ( + "fmt" + "github.com/gin-gonic/gin" +) + +type PaLMChatMessage struct { + Author string `json:"author"` + Content string `json:"content"` +} + +type PaLMFilter struct { + Reason string `json:"reason"` + Message string `json:"message"` +} + +// https://developers.generativeai.google/api/rest/generativelanguage/models/generateMessage#request-body +type PaLMChatRequest struct { + Prompt []Message `json:"prompt"` + Temperature float64 `json:"temperature"` + CandidateCount int `json:"candidateCount"` + TopP float64 `json:"topP"` + TopK int `json:"topK"` +} + +// https://developers.generativeai.google/api/rest/generativelanguage/models/generateMessage#response-body +type PaLMChatResponse struct { + Candidates []Message `json:"candidates"` + Messages []Message `json:"messages"` + Filters []PaLMFilter `json:"filters"` +} + +func relayPaLM(openAIRequest GeneralOpenAIRequest, c *gin.Context) *OpenAIErrorWithStatusCode { + // https://developers.generativeai.google/api/rest/generativelanguage/models/generateMessage + messages := make([]PaLMChatMessage, 0, len(openAIRequest.Messages)) + for _, message := range openAIRequest.Messages { + var author string + if message.Role == "user" { + author = "0" + } else { + author = "1" + } + messages = append(messages, PaLMChatMessage{ + Author: author, + Content: message.Content, + }) + } + request := PaLMChatRequest{ + Prompt: nil, + Temperature: openAIRequest.Temperature, + CandidateCount: openAIRequest.N, + TopP: openAIRequest.TopP, + TopK: openAIRequest.MaxTokens, + } + // TODO: forward request to PaLM & convert response + fmt.Print(request) + return nil +} diff --git a/controller/relay-utils.go b/controller/relay-utils.go new file mode 100644 index 00000000..bb25fa3b --- /dev/null +++ b/controller/relay-utils.go @@ -0,0 +1,65 @@ +package controller + +import ( + "fmt" + "github.com/pkoukk/tiktoken-go" + "one-api/common" + "strings" +) + +var tokenEncoderMap = map[string]*tiktoken.Tiktoken{} + +func getTokenEncoder(model string) *tiktoken.Tiktoken { + if tokenEncoder, ok := tokenEncoderMap[model]; ok { + return tokenEncoder + } + tokenEncoder, err := tiktoken.EncodingForModel(model) + if err != nil { + common.SysError(fmt.Sprintf("failed to get token encoder for model %s: %s, using encoder for gpt-3.5-turbo", model, err.Error())) + tokenEncoder, err = tiktoken.EncodingForModel("gpt-3.5-turbo") + if err != nil { + common.FatalLog(fmt.Sprintf("failed to get token encoder for model gpt-3.5-turbo: %s", err.Error())) + } + } + tokenEncoderMap[model] = tokenEncoder + return tokenEncoder +} + +func countTokenMessages(messages []Message, model string) int { + tokenEncoder := getTokenEncoder(model) + // Reference: + // https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb + // https://github.com/pkoukk/tiktoken-go/issues/6 + // + // Every message follows <|start|>{role/name}\n{content}<|end|>\n + var tokensPerMessage int + var tokensPerName int + if strings.HasPrefix(model, "gpt-3.5") { + tokensPerMessage = 4 + tokensPerName = -1 // If there's a name, the role is omitted + } else if strings.HasPrefix(model, "gpt-4") { + tokensPerMessage = 3 + tokensPerName = 1 + } else { + tokensPerMessage = 3 + tokensPerName = 1 + } + tokenNum := 0 + for _, message := range messages { + tokenNum += tokensPerMessage + tokenNum += len(tokenEncoder.Encode(message.Content, nil, nil)) + tokenNum += len(tokenEncoder.Encode(message.Role, nil, nil)) + if message.Name != nil { + tokenNum += tokensPerName + tokenNum += len(tokenEncoder.Encode(*message.Name, nil, nil)) + } + } + tokenNum += 3 // Every reply is primed with <|start|>assistant<|message|> + return tokenNum +} + +func countTokenText(text string, model string) int { + tokenEncoder := getTokenEncoder(model) + token := tokenEncoder.Encode(text, nil, nil) + return len(token) +} diff --git a/controller/relay.go b/controller/relay.go index db6298fa..81497d81 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -6,7 +6,6 @@ import ( "encoding/json" "fmt" "github.com/gin-gonic/gin" - "github.com/pkoukk/tiktoken-go" "io" "net/http" "one-api/common" @@ -15,8 +14,22 @@ import ( ) type Message struct { - Role string `json:"role"` - Content string `json:"content"` + Role string `json:"role"` + Content string `json:"content"` + Name *string `json:"name,omitempty"` +} + +// https://platform.openai.com/docs/api-reference/chat + +type GeneralOpenAIRequest struct { + Model string `json:"model"` + Messages []Message `json:"messages"` + Prompt string `json:"prompt"` + Stream bool `json:"stream"` + MaxTokens int `json:"max_tokens"` + Temperature float64 `json:"temperature"` + TopP float64 `json:"top_p"` + N int `json:"n"` } type ChatRequest struct { @@ -65,40 +78,6 @@ type StreamResponse struct { } `json:"choices"` } -func countTokenMessages(messages []Message, model string) int { - // 获取模型的编码器 - tokenEncoder, _ := tiktoken.EncodingForModel(model) - - // 参照官方的token计算cookbook - // https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb - // https://github.com/pkoukk/tiktoken-go/issues/6 - var tokens_per_message int - if strings.HasPrefix(model, "gpt-3.5") { - tokens_per_message = 4 // every message follows <|start|>{role/name}\n{content}<|end|>\n - } else if strings.HasPrefix(model, "gpt-4") { - tokens_per_message = 3 - } else { - tokens_per_message = 3 - } - - token := 0 - for _, message := range messages { - token += tokens_per_message - token += len(tokenEncoder.Encode(message.Content, nil, nil)) - token += len(tokenEncoder.Encode(message.Role, nil, nil)) - } - // 经过测试这个assistant的token是算在prompt里面的,而不是算在Completion里面的 - token += 3 // every reply is primed with <|start|>assistant<|message|> - return token -} - -func countTokenText(text string, model string) int { - // 获取模型的编码器 - tokenEncoder, _ := tiktoken.EncodingForModel(model) - token := tokenEncoder.Encode(text, nil, nil) - return len(token) -} - func Relay(c *gin.Context) { err := relayHelper(c) if err != nil { @@ -110,8 +89,8 @@ func Relay(c *gin.Context) { }) channelId := c.GetInt("channel_id") common.SysError(fmt.Sprintf("Relay error (channel #%d): %s", channelId, err.Message)) - if err.Type != "invalid_request_error" && err.StatusCode != http.StatusTooManyRequests && - common.AutomaticDisableChannelEnabled { + // https://platform.openai.com/docs/guides/error-codes/api-errors + if common.AutomaticDisableChannelEnabled && (err.Type == "insufficient_quota" || err.Code == "invalid_api_key") { channelId := c.GetInt("channel_id") channelName := c.GetString("channel_name") disableChannel(channelId, channelName, err.Message) @@ -135,8 +114,8 @@ func relayHelper(c *gin.Context) *OpenAIErrorWithStatusCode { channelType := c.GetInt("channel") tokenId := c.GetInt("token_id") consumeQuota := c.GetBool("consume_quota") - var textRequest TextRequest - if consumeQuota || channelType == common.ChannelTypeAzure { + var textRequest GeneralOpenAIRequest + if consumeQuota || channelType == common.ChannelTypeAzure || channelType == common.ChannelTypePaLM { requestBody, err := io.ReadAll(c.Request.Body) if err != nil { return errorWrapper(err, "read_request_body_failed", http.StatusBadRequest) @@ -175,6 +154,9 @@ func relayHelper(c *gin.Context) *OpenAIErrorWithStatusCode { model_ = strings.TrimSuffix(model_, "-0301") model_ = strings.TrimSuffix(model_, "-0314") fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/%s", baseURL, model_, task) + } else if channelType == common.ChannelTypePaLM { + err := relayPaLM(textRequest, c) + return err } promptTokens := countTokenMessages(textRequest.Messages, textRequest.Model) @@ -230,7 +212,8 @@ func relayHelper(c *gin.Context) *OpenAIErrorWithStatusCode { completionRatio = 2 } if isStream { - quota = promptTokens + countTokenText(streamResponseText, textRequest.Model)*completionRatio + responseTokens := countTokenText(streamResponseText, textRequest.Model) + quota = promptTokens + responseTokens*completionRatio } else { quota = textResponse.Usage.PromptTokens + textResponse.Usage.CompletionTokens*completionRatio } @@ -265,6 +248,10 @@ func relayHelper(c *gin.Context) *OpenAIErrorWithStatusCode { go func() { for scanner.Scan() { data := scanner.Text() + if len(data) < 6 { // must be something wrong! + common.SysError("Invalid stream response: " + data) + continue + } dataChan <- data data = data[6:] if !strings.HasPrefix(data, "[DONE]") { diff --git a/middleware/cache.go b/middleware/cache.go index 7f6099f5..979734ab 100644 --- a/middleware/cache.go +++ b/middleware/cache.go @@ -6,7 +6,11 @@ import ( func Cache() func(c *gin.Context) { return func(c *gin.Context) { - c.Header("Cache-Control", "max-age=604800") // one week + if c.Request.RequestURI == "/" { + c.Header("Cache-Control", "no-cache") + } else { + c.Header("Cache-Control", "max-age=604800") // one week + } c.Next() } } diff --git a/model/channel.go b/model/channel.go index 0335207b..35d65827 100644 --- a/model/channel.go +++ b/model/channel.go @@ -6,17 +6,19 @@ import ( ) type Channel struct { - Id int `json:"id"` - Type int `json:"type" gorm:"default:0"` - Key string `json:"key" gorm:"not null"` - Status int `json:"status" gorm:"default:1"` - Name string `json:"name" gorm:"index"` - Weight int `json:"weight"` - CreatedTime int64 `json:"created_time" gorm:"bigint"` - TestTime int64 `json:"test_time" gorm:"bigint"` - ResponseTime int `json:"response_time"` // in milliseconds - BaseURL string `json:"base_url" gorm:"column:base_url"` - Other string `json:"other"` + Id int `json:"id"` + Type int `json:"type" gorm:"default:0"` + Key string `json:"key" gorm:"not null"` + Status int `json:"status" gorm:"default:1"` + Name string `json:"name" gorm:"index"` + Weight int `json:"weight"` + CreatedTime int64 `json:"created_time" gorm:"bigint"` + TestTime int64 `json:"test_time" gorm:"bigint"` + ResponseTime int `json:"response_time"` // in milliseconds + BaseURL string `json:"base_url" gorm:"column:base_url"` + Other string `json:"other"` + Balance float64 `json:"balance"` // in USD + BalanceUpdatedTime int64 `json:"balance_updated_time" gorm:"bigint"` } func GetAllChannels(startIdx int, num int, selectAll bool) ([]*Channel, error) { @@ -86,6 +88,16 @@ func (channel *Channel) UpdateResponseTime(responseTime int64) { } } +func (channel *Channel) UpdateBalance(balance float64) { + err := DB.Model(channel).Select("balance_updated_time", "balance").Updates(Channel{ + BalanceUpdatedTime: common.GetTimestamp(), + Balance: balance, + }).Error + if err != nil { + common.SysError("failed to update balance: " + err.Error()) + } +} + func (channel *Channel) Delete() error { var err error err = DB.Delete(channel).Error diff --git a/model/main.go b/model/main.go index 0bc09230..3f6fafbf 100644 --- a/model/main.go +++ b/model/main.go @@ -26,6 +26,7 @@ func createRootAccountIfNeed() error { Status: common.UserStatusEnabled, DisplayName: "Root User", AccessToken: common.GetUUID(), + Quota: 100000000, } DB.Create(&rootUser) } diff --git a/model/user.go b/model/user.go index a54351c7..2ca0d6a4 100644 --- a/model/user.go +++ b/model/user.go @@ -19,8 +19,7 @@ type User struct { Email string `json:"email" gorm:"index" validate:"max=50"` GitHubId string `json:"github_id" gorm:"column:github_id;index"` WeChatId string `json:"wechat_id" gorm:"column:wechat_id;index"` - VerificationCode string `json:"verification_code" gorm:"-:all"` // this field is only for Email verification, don't save it to database! - Balance int `json:"balance" gorm:"type:int;default:0"` + VerificationCode string `json:"verification_code" gorm:"-:all"` // this field is only for Email verification, don't save it to database! AccessToken string `json:"access_token" gorm:"type:char(32);column:access_token;uniqueIndex"` // this token is for system management Quota int `json:"quota" gorm:"type:int;default:0"` } diff --git a/router/api-router.go b/router/api-router.go index 5cd86e3e..9ca2226a 100644 --- a/router/api-router.go +++ b/router/api-router.go @@ -66,6 +66,8 @@ func SetApiRouter(router *gin.Engine) { channelRoute.GET("/:id", controller.GetChannel) channelRoute.GET("/test", controller.TestAllChannels) channelRoute.GET("/test/:id", controller.TestChannel) + channelRoute.GET("/update_balance", controller.UpdateAllChannelsBalance) + channelRoute.GET("/update_balance/:id", controller.UpdateChannelBalance) channelRoute.POST("/", controller.AddChannel) channelRoute.PUT("/", controller.UpdateChannel) channelRoute.DELETE("/:id", controller.DeleteChannel) diff --git a/router/dashboard.go b/router/dashboard.go index 3eacaf9a..39ed1f93 100644 --- a/router/dashboard.go +++ b/router/dashboard.go @@ -8,11 +8,14 @@ import ( ) func SetDashboardRouter(router *gin.Engine) { - apiRouter := router.Group("/dashboard") + apiRouter := router.Group("/") apiRouter.Use(gzip.Gzip(gzip.DefaultCompression)) apiRouter.Use(middleware.GlobalAPIRateLimit()) apiRouter.Use(middleware.TokenAuth()) { - apiRouter.GET("/billing/credit_grants", controller.GetTokenStatus) + apiRouter.GET("/dashboard/billing/subscription", controller.GetSubscription) + apiRouter.GET("/v1/dashboard/billing/subscription", controller.GetSubscription) + apiRouter.GET("/dashboard/billing/usage", controller.GetUsage) + apiRouter.GET("/v1/dashboard/billing/usage", controller.GetUsage) } } diff --git a/router/web-router.go b/router/web-router.go index 71e98c27..8f6d1ac4 100644 --- a/router/web-router.go +++ b/router/web-router.go @@ -16,6 +16,7 @@ func SetWebRouter(router *gin.Engine, buildFS embed.FS, indexPage []byte) { router.Use(middleware.Cache()) router.Use(static.Serve("/", common.EmbedFolder(buildFS, "web/build"))) router.NoRoute(func(c *gin.Context) { + c.Header("Cache-Control", "no-cache") c.Data(http.StatusOK, "text/html; charset=utf-8", indexPage) }) } diff --git a/web/src/components/ChannelsTable.js b/web/src/components/ChannelsTable.js index f0f33e96..a0a0f5dd 100644 --- a/web/src/components/ChannelsTable.js +++ b/web/src/components/ChannelsTable.js @@ -32,6 +32,7 @@ const ChannelsTable = () => { const [activePage, setActivePage] = useState(1); const [searchKeyword, setSearchKeyword] = useState(''); const [searching, setSearching] = useState(false); + const [updatingBalance, setUpdatingBalance] = useState(false); const loadChannels = async (startIdx) => { const res = await API.get(`/api/channel/?p=${startIdx}`); @@ -63,7 +64,7 @@ const ChannelsTable = () => { const refresh = async () => { setLoading(true); await loadChannels(0); - } + }; useEffect(() => { loadChannels(0) @@ -127,7 +128,7 @@ const ChannelsTable = () => { const renderResponseTime = (responseTime) => { let time = responseTime / 1000; - time = time.toFixed(2) + " 秒"; + time = time.toFixed(2) + ' 秒'; if (responseTime === 0) { return ; } else if (responseTime <= 1000) { @@ -179,11 +180,38 @@ const ChannelsTable = () => { const res = await API.get(`/api/channel/test`); const { success, message } = res.data; if (success) { - showInfo("已成功开始测试所有已启用通道,请刷新页面查看结果。"); + showInfo('已成功开始测试所有已启用通道,请刷新页面查看结果。'); } else { showError(message); } - } + }; + + const updateChannelBalance = async (id, name, idx) => { + const res = await API.get(`/api/channel/update_balance/${id}/`); + const { success, message, balance } = res.data; + if (success) { + let newChannels = [...channels]; + let realIdx = (activePage - 1) * ITEMS_PER_PAGE + idx; + newChannels[realIdx].balance = balance; + newChannels[realIdx].balance_updated_time = Date.now() / 1000; + setChannels(newChannels); + showInfo(`通道 ${name} 余额更新成功!`); + } else { + showError(message); + } + }; + + const updateAllChannelsBalance = async () => { + setUpdatingBalance(true); + const res = await API.get(`/api/channel/update_balance`); + const { success, message } = res.data; + if (success) { + showInfo('已更新完毕所有已启用通道余额!'); + } else { + showError(message); + } + setUpdatingBalance(false); + }; const handleKeywordChange = async (e, { value }) => { setSearchKeyword(value.trim()); @@ -263,10 +291,10 @@ const ChannelsTable = () => { { - sortChannel('test_time'); + sortChannel('balance'); }} > - 测试时间 + 余额 操作 @@ -286,8 +314,22 @@ const ChannelsTable = () => { {channel.name ? channel.name : '无'} {renderType(channel.type)} {renderStatus(channel.status)} - {renderResponseTime(channel.response_time)} - {channel.test_time ? renderTimestamp(channel.test_time) : "未测试"} + + + + + ${channel.balance.toFixed(2)}} + basic + /> +
+ @@ -353,6 +405,7 @@ const ChannelsTable = () => { + {
账号绑定
- + { + status.wechat_login && ( + + ) + } setShowWeChatBindModal(false)} onOpen={() => setShowWeChatBindModal(true)} @@ -148,7 +152,11 @@ const PersonalSetting = () => { - + { + status.github_oauth && ( + + ) + } + diff --git a/web/src/pages/Redemption/EditRedemption.js b/web/src/pages/Redemption/EditRedemption.js index 687864ba..3f418926 100644 --- a/web/src/pages/Redemption/EditRedemption.js +++ b/web/src/pages/Redemption/EditRedemption.js @@ -111,7 +111,7 @@ const EditRedemption = () => { } - + diff --git a/web/src/pages/Token/EditToken.js b/web/src/pages/Token/EditToken.js index dd8022e1..a2e4f9e8 100644 --- a/web/src/pages/Token/EditToken.js +++ b/web/src/pages/Token/EditToken.js @@ -106,6 +106,34 @@ const EditToken = () => { required={!isEdit} /> + + + +
+ + + + + +
注意,令牌的额度仅用于限制令牌本身的最大额度使用量,实际的使用受到账户的剩余额度限制。 { disabled={unlimited_quota} /> - - - - - - - - - - + diff --git a/web/src/pages/TopUp/index.js b/web/src/pages/TopUp/index.js index d32d1115..b710b14e 100644 --- a/web/src/pages/TopUp/index.js +++ b/web/src/pages/TopUp/index.js @@ -1,6 +1,6 @@ import React, { useEffect, useState } from 'react'; import { Button, Form, Grid, Header, Segment, Statistic } from 'semantic-ui-react'; -import { API, showError, showSuccess } from '../../helpers'; +import { API, showError, showInfo, showSuccess } from '../../helpers'; const TopUp = () => { const [redemptionCode, setRedemptionCode] = useState(''); @@ -9,6 +9,7 @@ const TopUp = () => { const topUp = async () => { if (redemptionCode === '') { + showInfo('请输入充值码!') return; } const res = await API.post('/api/user/topup', { @@ -80,7 +81,7 @@ const TopUp = () => { - {userQuota} + {userQuota.toLocaleString()} 剩余额度 diff --git a/web/src/pages/User/AddUser.js b/web/src/pages/User/AddUser.js index 73036ada..f9f4bc18 100644 --- a/web/src/pages/User/AddUser.js +++ b/web/src/pages/User/AddUser.js @@ -65,7 +65,7 @@ const AddUser = () => { required /> - diff --git a/web/src/pages/User/EditUser.js b/web/src/pages/User/EditUser.js index a8d4e7cf..bef421bc 100644 --- a/web/src/pages/User/EditUser.js +++ b/web/src/pages/User/EditUser.js @@ -14,8 +14,9 @@ const EditUser = () => { github_id: '', wechat_id: '', email: '', + quota: 0, }); - const { username, display_name, password, github_id, wechat_id, email } = + const { username, display_name, password, github_id, wechat_id, email, quota } = inputs; const handleInputChange = (e, { name, value }) => { setInputs((inputs) => ({ ...inputs, [name]: value })); @@ -44,7 +45,11 @@ const EditUser = () => { const submit = async () => { let res = undefined; if (userId) { - res = await API.put(`/api/user/`, { ...inputs, id: parseInt(userId) }); + let data = { ...inputs, id: parseInt(userId) }; + if (typeof data.quota === 'string') { + data.quota = parseInt(data.quota); + } + res = await API.put(`/api/user/`, data); } else { res = await API.put(`/api/user/self`, inputs); } @@ -92,6 +97,21 @@ const EditUser = () => { autoComplete='new-password' /> + { + userId && ( + + + + ) + } { readOnly /> - +