Merge remote-tracking branch 'songquanpeng/main'
# Conflicts: # controller/relay.go
This commit is contained in:
commit
5eea5af4fb
@ -51,15 +51,17 @@ _✨ All in one 的 OpenAI 接口,整合各种 API 访问方式,开箱即用
|
|||||||
+ [x] **Azure OpenAI API**
|
+ [x] **Azure OpenAI API**
|
||||||
+ [x] [API2D](https://api2d.com/r/197971)
|
+ [x] [API2D](https://api2d.com/r/197971)
|
||||||
+ [x] [OhMyGPT](https://aigptx.top?aff=uFpUl2Kf)
|
+ [x] [OhMyGPT](https://aigptx.top?aff=uFpUl2Kf)
|
||||||
+ [x] [CloseAI](https://console.openai-asia.com)
|
+ [x] [AI Proxy](https://aiproxy.io/?i=OneAPI) (邀请码:`OneAPI`)
|
||||||
+ [x] [OpenAI-SB](https://openai-sb.com)
|
+ [x] [AI.LS](https://ai.ls)
|
||||||
+ [x] [OpenAI Max](https://openaimax.com)
|
+ [x] [OpenAI Max](https://openaimax.com)
|
||||||
|
+ [x] [OpenAI-SB](https://openai-sb.com)
|
||||||
|
+ [x] [CloseAI](https://console.openai-asia.com/r/2412)
|
||||||
+ [x] 自定义渠道:例如使用自行搭建的 OpenAI 代理
|
+ [x] 自定义渠道:例如使用自行搭建的 OpenAI 代理
|
||||||
2. 支持通过**负载均衡**的方式访问多个渠道。
|
2. 支持通过**负载均衡**的方式访问多个渠道。
|
||||||
3. 支持 **stream 模式**,可以通过流式传输实现打字机效果。
|
3. 支持 **stream 模式**,可以通过流式传输实现打字机效果。
|
||||||
4. 支持**多机部署**,[详见此处](#多机部署)。
|
4. 支持**多机部署**,[详见此处](#多机部署)。
|
||||||
5. 支持**令牌管理**,设置令牌的过期时间和使用次数。
|
5. 支持**令牌管理**,设置令牌的过期时间和使用次数。
|
||||||
6. 支持**兑换码管理**,支持批量生成和导出兑换码,可使用兑换码为令牌进行充值。
|
6. 支持**兑换码管理**,支持批量生成和导出兑换码,可使用兑换码为账户进行充值。
|
||||||
7. 支持**通道管理**,批量创建通道。
|
7. 支持**通道管理**,批量创建通道。
|
||||||
8. 支持发布公告,设置充值链接,设置新用户初始额度。
|
8. 支持发布公告,设置充值链接,设置新用户初始额度。
|
||||||
9. 支持丰富的**自定义**设置,
|
9. 支持丰富的**自定义**设置,
|
||||||
|
@ -127,6 +127,9 @@ const (
|
|||||||
ChannelTypeOpenAIMax = 6
|
ChannelTypeOpenAIMax = 6
|
||||||
ChannelTypeOhMyGPT = 7
|
ChannelTypeOhMyGPT = 7
|
||||||
ChannelTypeCustom = 8
|
ChannelTypeCustom = 8
|
||||||
|
ChannelTypeAILS = 9
|
||||||
|
ChannelTypeAIProxy = 10
|
||||||
|
ChannelTypePaLM = 11
|
||||||
)
|
)
|
||||||
|
|
||||||
var ChannelBaseURLs = []string{
|
var ChannelBaseURLs = []string{
|
||||||
@ -139,4 +142,7 @@ var ChannelBaseURLs = []string{
|
|||||||
"https://api.openaimax.com", // 6
|
"https://api.openaimax.com", // 6
|
||||||
"https://api.ohmygpt.com", // 7
|
"https://api.ohmygpt.com", // 7
|
||||||
"", // 8
|
"", // 8
|
||||||
|
"https://api.caipacity.com", // 9
|
||||||
|
"https://api.aiproxy.io", // 10
|
||||||
|
"", // 11
|
||||||
}
|
}
|
||||||
|
41
controller/billing.go
Normal file
41
controller/billing.go
Normal file
@ -0,0 +1,41 @@
|
|||||||
|
package controller
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"one-api/model"
|
||||||
|
)
|
||||||
|
|
||||||
|
func GetSubscription(c *gin.Context) {
|
||||||
|
userId := c.GetInt("id")
|
||||||
|
quota, err := model.GetUserQuota(userId)
|
||||||
|
if err != nil {
|
||||||
|
openAIError := OpenAIError{
|
||||||
|
Message: err.Error(),
|
||||||
|
Type: "one_api_error",
|
||||||
|
}
|
||||||
|
c.JSON(200, gin.H{
|
||||||
|
"error": openAIError,
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
subscription := OpenAISubscriptionResponse{
|
||||||
|
Object: "billing_subscription",
|
||||||
|
HasPaymentMethod: true,
|
||||||
|
SoftLimitUSD: float64(quota),
|
||||||
|
HardLimitUSD: float64(quota),
|
||||||
|
SystemHardLimitUSD: float64(quota),
|
||||||
|
}
|
||||||
|
c.JSON(200, subscription)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetUsage(c *gin.Context) {
|
||||||
|
//userId := c.GetInt("id")
|
||||||
|
// TODO: get usage from database
|
||||||
|
usage := OpenAIUsageResponse{
|
||||||
|
Object: "list",
|
||||||
|
TotalUsage: 0,
|
||||||
|
}
|
||||||
|
c.JSON(200, usage)
|
||||||
|
return
|
||||||
|
}
|
179
controller/channel-billing.go
Normal file
179
controller/channel-billing.go
Normal file
@ -0,0 +1,179 @@
|
|||||||
|
package controller
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"one-api/common"
|
||||||
|
"one-api/model"
|
||||||
|
"strconv"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// https://github.com/songquanpeng/one-api/issues/79
|
||||||
|
|
||||||
|
type OpenAISubscriptionResponse struct {
|
||||||
|
Object string `json:"object"`
|
||||||
|
HasPaymentMethod bool `json:"has_payment_method"`
|
||||||
|
SoftLimitUSD float64 `json:"soft_limit_usd"`
|
||||||
|
HardLimitUSD float64 `json:"hard_limit_usd"`
|
||||||
|
SystemHardLimitUSD float64 `json:"system_hard_limit_usd"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type OpenAIUsageDailyCost struct {
|
||||||
|
Timestamp float64 `json:"timestamp"`
|
||||||
|
LineItems []struct {
|
||||||
|
Name string `json:"name"`
|
||||||
|
Cost float64 `json:"cost"`
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type OpenAIUsageResponse struct {
|
||||||
|
Object string `json:"object"`
|
||||||
|
//DailyCosts []OpenAIUsageDailyCost `json:"daily_costs"`
|
||||||
|
TotalUsage float64 `json:"total_usage"` // unit: 0.01 dollar
|
||||||
|
}
|
||||||
|
|
||||||
|
func updateChannelBalance(channel *model.Channel) (float64, error) {
|
||||||
|
baseURL := common.ChannelBaseURLs[channel.Type]
|
||||||
|
switch channel.Type {
|
||||||
|
case common.ChannelTypeAzure:
|
||||||
|
return 0, errors.New("尚未实现")
|
||||||
|
case common.ChannelTypeCustom:
|
||||||
|
baseURL = channel.BaseURL
|
||||||
|
}
|
||||||
|
url := fmt.Sprintf("%s/v1/dashboard/billing/subscription", baseURL)
|
||||||
|
|
||||||
|
client := &http.Client{}
|
||||||
|
req, err := http.NewRequest("GET", url, nil)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
auth := fmt.Sprintf("Bearer %s", channel.Key)
|
||||||
|
req.Header.Add("Authorization", auth)
|
||||||
|
res, err := client.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
body, err := io.ReadAll(res.Body)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
err = res.Body.Close()
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
subscription := OpenAISubscriptionResponse{}
|
||||||
|
err = json.Unmarshal(body, &subscription)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
now := time.Now()
|
||||||
|
startDate := fmt.Sprintf("%s-01", now.Format("2006-01"))
|
||||||
|
//endDate := now.Format("2006-01-02")
|
||||||
|
url = fmt.Sprintf("%s/v1/dashboard/billing/usage?start_date=%s&end_date=%s", baseURL, startDate, "2023-06-01")
|
||||||
|
req, err = http.NewRequest("GET", url, nil)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
req.Header.Add("Authorization", auth)
|
||||||
|
res, err = client.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
body, err = io.ReadAll(res.Body)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
err = res.Body.Close()
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
usage := OpenAIUsageResponse{}
|
||||||
|
err = json.Unmarshal(body, &usage)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
balance := subscription.HardLimitUSD - usage.TotalUsage/100
|
||||||
|
channel.UpdateBalance(balance)
|
||||||
|
return balance, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func UpdateChannelBalance(c *gin.Context) {
|
||||||
|
id, err := strconv.Atoi(c.Param("id"))
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"message": err.Error(),
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
channel, err := model.GetChannelById(id, true)
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"message": err.Error(),
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
balance, err := updateChannelBalance(channel)
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"message": err.Error(),
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": true,
|
||||||
|
"message": "",
|
||||||
|
"balance": balance,
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func updateAllChannelsBalance() error {
|
||||||
|
channels, err := model.GetAllChannels(0, 0, true)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
for _, channel := range channels {
|
||||||
|
if channel.Status != common.ChannelStatusEnabled {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
// TODO: support Azure
|
||||||
|
if channel.Type != common.ChannelTypeOpenAI && channel.Type != common.ChannelTypeCustom {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
balance, err := updateChannelBalance(channel)
|
||||||
|
if err != nil {
|
||||||
|
continue
|
||||||
|
} else {
|
||||||
|
// err is nil & balance <= 0 means quota is used up
|
||||||
|
if balance <= 0 {
|
||||||
|
disableChannel(channel.Id, channel.Name, "余额不足")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func UpdateAllChannelsBalance(c *gin.Context) {
|
||||||
|
// TODO: make it async
|
||||||
|
err := updateAllChannelsBalance()
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"message": err.Error(),
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": true,
|
||||||
|
"message": "",
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
199
controller/channel-test.go
Normal file
199
controller/channel-test.go
Normal file
@ -0,0 +1,199 @@
|
|||||||
|
package controller
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"net/http"
|
||||||
|
"one-api/common"
|
||||||
|
"one-api/model"
|
||||||
|
"strconv"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func testChannel(channel *model.Channel, request *ChatRequest) error {
|
||||||
|
if request.Model == "" {
|
||||||
|
request.Model = "gpt-3.5-turbo"
|
||||||
|
if channel.Type == common.ChannelTypeAzure {
|
||||||
|
request.Model = "gpt-35-turbo"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
requestURL := common.ChannelBaseURLs[channel.Type]
|
||||||
|
if channel.Type == common.ChannelTypeAzure {
|
||||||
|
requestURL = fmt.Sprintf("%s/openai/deployments/%s/chat/completions?api-version=2023-03-15-preview", channel.BaseURL, request.Model)
|
||||||
|
} else {
|
||||||
|
if channel.Type == common.ChannelTypeCustom {
|
||||||
|
requestURL = channel.BaseURL
|
||||||
|
}
|
||||||
|
requestURL += "/v1/chat/completions"
|
||||||
|
}
|
||||||
|
|
||||||
|
jsonData, err := json.Marshal(request)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
req, err := http.NewRequest("POST", requestURL, bytes.NewBuffer(jsonData))
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if channel.Type == common.ChannelTypeAzure {
|
||||||
|
req.Header.Set("api-key", channel.Key)
|
||||||
|
} else {
|
||||||
|
req.Header.Set("Authorization", "Bearer "+channel.Key)
|
||||||
|
}
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
client := &http.Client{}
|
||||||
|
resp, err := client.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
var response TextResponse
|
||||||
|
err = json.NewDecoder(resp.Body).Decode(&response)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if response.Error.Message != "" || response.Error.Code != "" {
|
||||||
|
return errors.New(fmt.Sprintf("type %s, code %s, message %s", response.Error.Type, response.Error.Code, response.Error.Message))
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func buildTestRequest(c *gin.Context) *ChatRequest {
|
||||||
|
model_ := c.Query("model")
|
||||||
|
testRequest := &ChatRequest{
|
||||||
|
Model: model_,
|
||||||
|
MaxTokens: 1,
|
||||||
|
}
|
||||||
|
testMessage := Message{
|
||||||
|
Role: "user",
|
||||||
|
Content: "hi",
|
||||||
|
}
|
||||||
|
testRequest.Messages = append(testRequest.Messages, testMessage)
|
||||||
|
return testRequest
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestChannel(c *gin.Context) {
|
||||||
|
id, err := strconv.Atoi(c.Param("id"))
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"message": err.Error(),
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
channel, err := model.GetChannelById(id, true)
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"message": err.Error(),
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
testRequest := buildTestRequest(c)
|
||||||
|
tik := time.Now()
|
||||||
|
err = testChannel(channel, testRequest)
|
||||||
|
tok := time.Now()
|
||||||
|
milliseconds := tok.Sub(tik).Milliseconds()
|
||||||
|
go channel.UpdateResponseTime(milliseconds)
|
||||||
|
consumedTime := float64(milliseconds) / 1000.0
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"message": err.Error(),
|
||||||
|
"time": consumedTime,
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": true,
|
||||||
|
"message": "",
|
||||||
|
"time": consumedTime,
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var testAllChannelsLock sync.Mutex
|
||||||
|
var testAllChannelsRunning bool = false
|
||||||
|
|
||||||
|
// disable & notify
|
||||||
|
func disableChannel(channelId int, channelName string, reason string) {
|
||||||
|
if common.RootUserEmail == "" {
|
||||||
|
common.RootUserEmail = model.GetRootUserEmail()
|
||||||
|
}
|
||||||
|
model.UpdateChannelStatusById(channelId, common.ChannelStatusDisabled)
|
||||||
|
subject := fmt.Sprintf("通道「%s」(#%d)已被禁用", channelName, channelId)
|
||||||
|
content := fmt.Sprintf("通道「%s」(#%d)已被禁用,原因:%s", channelName, channelId, reason)
|
||||||
|
err := common.SendEmail(subject, common.RootUserEmail, content)
|
||||||
|
if err != nil {
|
||||||
|
common.SysError(fmt.Sprintf("发送邮件失败:%s", err.Error()))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func testAllChannels(c *gin.Context) error {
|
||||||
|
testAllChannelsLock.Lock()
|
||||||
|
if testAllChannelsRunning {
|
||||||
|
testAllChannelsLock.Unlock()
|
||||||
|
return errors.New("测试已在运行中")
|
||||||
|
}
|
||||||
|
testAllChannelsRunning = true
|
||||||
|
testAllChannelsLock.Unlock()
|
||||||
|
channels, err := model.GetAllChannels(0, 0, true)
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"message": err.Error(),
|
||||||
|
})
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
testRequest := buildTestRequest(c)
|
||||||
|
var disableThreshold = int64(common.ChannelDisableThreshold * 1000)
|
||||||
|
if disableThreshold == 0 {
|
||||||
|
disableThreshold = 10000000 // a impossible value
|
||||||
|
}
|
||||||
|
go func() {
|
||||||
|
for _, channel := range channels {
|
||||||
|
if channel.Status != common.ChannelStatusEnabled {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
tik := time.Now()
|
||||||
|
err := testChannel(channel, testRequest)
|
||||||
|
tok := time.Now()
|
||||||
|
milliseconds := tok.Sub(tik).Milliseconds()
|
||||||
|
if err != nil || milliseconds > disableThreshold {
|
||||||
|
if milliseconds > disableThreshold {
|
||||||
|
err = errors.New(fmt.Sprintf("响应时间 %.2fs 超过阈值 %.2fs", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0))
|
||||||
|
}
|
||||||
|
disableChannel(channel.Id, channel.Name, err.Error())
|
||||||
|
}
|
||||||
|
channel.UpdateResponseTime(milliseconds)
|
||||||
|
}
|
||||||
|
err := common.SendEmail("通道测试完成", common.RootUserEmail, "通道测试完成,如果没有收到禁用通知,说明所有通道都正常")
|
||||||
|
if err != nil {
|
||||||
|
common.SysError(fmt.Sprintf("发送邮件失败:%s", err.Error()))
|
||||||
|
}
|
||||||
|
testAllChannelsLock.Lock()
|
||||||
|
testAllChannelsRunning = false
|
||||||
|
testAllChannelsLock.Unlock()
|
||||||
|
}()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAllChannels(c *gin.Context) {
|
||||||
|
err := testAllChannels(c)
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"message": err.Error(),
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": true,
|
||||||
|
"message": "",
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
@ -1,18 +1,12 @@
|
|||||||
package controller
|
package controller
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
|
||||||
"encoding/json"
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"net/http"
|
"net/http"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
"one-api/model"
|
"one-api/model"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
|
||||||
"time"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func GetAllChannels(c *gin.Context) {
|
func GetAllChannels(c *gin.Context) {
|
||||||
@ -158,187 +152,3 @@ func UpdateChannel(c *gin.Context) {
|
|||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func testChannel(channel *model.Channel, request *ChatRequest) error {
|
|
||||||
if request.Model == "" {
|
|
||||||
request.Model = "gpt-3.5-turbo"
|
|
||||||
if channel.Type == common.ChannelTypeAzure {
|
|
||||||
request.Model = "gpt-35-turbo"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
requestURL := common.ChannelBaseURLs[channel.Type]
|
|
||||||
if channel.Type == common.ChannelTypeAzure {
|
|
||||||
requestURL = fmt.Sprintf("%s/openai/deployments/%s/chat/completions?api-version=2023-03-15-preview", channel.BaseURL, request.Model)
|
|
||||||
} else {
|
|
||||||
if channel.Type == common.ChannelTypeCustom {
|
|
||||||
requestURL = channel.BaseURL
|
|
||||||
}
|
|
||||||
requestURL += "/v1/chat/completions"
|
|
||||||
}
|
|
||||||
|
|
||||||
jsonData, err := json.Marshal(request)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
req, err := http.NewRequest("POST", requestURL, bytes.NewBuffer(jsonData))
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if channel.Type == common.ChannelTypeAzure {
|
|
||||||
req.Header.Set("api-key", channel.Key)
|
|
||||||
} else {
|
|
||||||
req.Header.Set("Authorization", "Bearer "+channel.Key)
|
|
||||||
}
|
|
||||||
req.Header.Set("Content-Type", "application/json")
|
|
||||||
client := &http.Client{}
|
|
||||||
resp, err := client.Do(req)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
defer resp.Body.Close()
|
|
||||||
var response TextResponse
|
|
||||||
err = json.NewDecoder(resp.Body).Decode(&response)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if response.Error.Message != "" {
|
|
||||||
return errors.New(fmt.Sprintf("type %s, code %s, message %s", response.Error.Type, response.Error.Code, response.Error.Message))
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func buildTestRequest(c *gin.Context) *ChatRequest {
|
|
||||||
model_ := c.Query("model")
|
|
||||||
testRequest := &ChatRequest{
|
|
||||||
Model: model_,
|
|
||||||
MaxTokens: 1,
|
|
||||||
}
|
|
||||||
testMessage := Message{
|
|
||||||
Role: "user",
|
|
||||||
Content: "hi",
|
|
||||||
}
|
|
||||||
testRequest.Messages = append(testRequest.Messages, testMessage)
|
|
||||||
return testRequest
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestChannel(c *gin.Context) {
|
|
||||||
id, err := strconv.Atoi(c.Param("id"))
|
|
||||||
if err != nil {
|
|
||||||
c.JSON(http.StatusOK, gin.H{
|
|
||||||
"success": false,
|
|
||||||
"message": err.Error(),
|
|
||||||
})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
channel, err := model.GetChannelById(id, true)
|
|
||||||
if err != nil {
|
|
||||||
c.JSON(http.StatusOK, gin.H{
|
|
||||||
"success": false,
|
|
||||||
"message": err.Error(),
|
|
||||||
})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
testRequest := buildTestRequest(c)
|
|
||||||
tik := time.Now()
|
|
||||||
err = testChannel(channel, testRequest)
|
|
||||||
tok := time.Now()
|
|
||||||
milliseconds := tok.Sub(tik).Milliseconds()
|
|
||||||
go channel.UpdateResponseTime(milliseconds)
|
|
||||||
consumedTime := float64(milliseconds) / 1000.0
|
|
||||||
if err != nil {
|
|
||||||
c.JSON(http.StatusOK, gin.H{
|
|
||||||
"success": false,
|
|
||||||
"message": err.Error(),
|
|
||||||
"time": consumedTime,
|
|
||||||
})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
c.JSON(http.StatusOK, gin.H{
|
|
||||||
"success": true,
|
|
||||||
"message": "",
|
|
||||||
"time": consumedTime,
|
|
||||||
})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
var testAllChannelsLock sync.Mutex
|
|
||||||
var testAllChannelsRunning bool = false
|
|
||||||
|
|
||||||
// disable & notify
|
|
||||||
func disableChannel(channelId int, channelName string, reason string) {
|
|
||||||
if common.RootUserEmail == "" {
|
|
||||||
common.RootUserEmail = model.GetRootUserEmail()
|
|
||||||
}
|
|
||||||
model.UpdateChannelStatusById(channelId, common.ChannelStatusDisabled)
|
|
||||||
subject := fmt.Sprintf("通道「%s」(#%d)已被禁用", channelName, channelId)
|
|
||||||
content := fmt.Sprintf("通道「%s」(#%d)已被禁用,原因:%s", channelName, channelId, reason)
|
|
||||||
err := common.SendEmail(subject, common.RootUserEmail, content)
|
|
||||||
if err != nil {
|
|
||||||
common.SysError(fmt.Sprintf("发送邮件失败:%s", err.Error()))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func testAllChannels(c *gin.Context) error {
|
|
||||||
testAllChannelsLock.Lock()
|
|
||||||
if testAllChannelsRunning {
|
|
||||||
testAllChannelsLock.Unlock()
|
|
||||||
return errors.New("测试已在运行中")
|
|
||||||
}
|
|
||||||
testAllChannelsRunning = true
|
|
||||||
testAllChannelsLock.Unlock()
|
|
||||||
channels, err := model.GetAllChannels(0, 0, true)
|
|
||||||
if err != nil {
|
|
||||||
c.JSON(http.StatusOK, gin.H{
|
|
||||||
"success": false,
|
|
||||||
"message": err.Error(),
|
|
||||||
})
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
testRequest := buildTestRequest(c)
|
|
||||||
var disableThreshold = int64(common.ChannelDisableThreshold * 1000)
|
|
||||||
if disableThreshold == 0 {
|
|
||||||
disableThreshold = 10000000 // a impossible value
|
|
||||||
}
|
|
||||||
go func() {
|
|
||||||
for _, channel := range channels {
|
|
||||||
if channel.Status != common.ChannelStatusEnabled {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
tik := time.Now()
|
|
||||||
err := testChannel(channel, testRequest)
|
|
||||||
tok := time.Now()
|
|
||||||
milliseconds := tok.Sub(tik).Milliseconds()
|
|
||||||
if err != nil || milliseconds > disableThreshold {
|
|
||||||
if milliseconds > disableThreshold {
|
|
||||||
err = errors.New(fmt.Sprintf("响应时间 %.2fs 超过阈值 %.2fs", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0))
|
|
||||||
}
|
|
||||||
disableChannel(channel.Id, channel.Name, err.Error())
|
|
||||||
}
|
|
||||||
channel.UpdateResponseTime(milliseconds)
|
|
||||||
}
|
|
||||||
err := common.SendEmail("通道测试完成", common.RootUserEmail, "通道测试完成,如果没有收到禁用通知,说明所有通道都正常")
|
|
||||||
if err != nil {
|
|
||||||
common.SysError(fmt.Sprintf("发送邮件失败:%s", err.Error()))
|
|
||||||
}
|
|
||||||
testAllChannelsLock.Lock()
|
|
||||||
testAllChannelsRunning = false
|
|
||||||
testAllChannelsLock.Unlock()
|
|
||||||
}()
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestAllChannels(c *gin.Context) {
|
|
||||||
err := testAllChannels(c)
|
|
||||||
if err != nil {
|
|
||||||
c.JSON(http.StatusOK, gin.H{
|
|
||||||
"success": false,
|
|
||||||
"message": err.Error(),
|
|
||||||
})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
c.JSON(http.StatusOK, gin.H{
|
|
||||||
"success": true,
|
|
||||||
"message": "",
|
|
||||||
})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
59
controller/relay-palm.go
Normal file
59
controller/relay-palm.go
Normal file
@ -0,0 +1,59 @@
|
|||||||
|
package controller
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
)
|
||||||
|
|
||||||
|
type PaLMChatMessage struct {
|
||||||
|
Author string `json:"author"`
|
||||||
|
Content string `json:"content"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type PaLMFilter struct {
|
||||||
|
Reason string `json:"reason"`
|
||||||
|
Message string `json:"message"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// https://developers.generativeai.google/api/rest/generativelanguage/models/generateMessage#request-body
|
||||||
|
type PaLMChatRequest struct {
|
||||||
|
Prompt []Message `json:"prompt"`
|
||||||
|
Temperature float64 `json:"temperature"`
|
||||||
|
CandidateCount int `json:"candidateCount"`
|
||||||
|
TopP float64 `json:"topP"`
|
||||||
|
TopK int `json:"topK"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// https://developers.generativeai.google/api/rest/generativelanguage/models/generateMessage#response-body
|
||||||
|
type PaLMChatResponse struct {
|
||||||
|
Candidates []Message `json:"candidates"`
|
||||||
|
Messages []Message `json:"messages"`
|
||||||
|
Filters []PaLMFilter `json:"filters"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func relayPaLM(openAIRequest GeneralOpenAIRequest, c *gin.Context) *OpenAIErrorWithStatusCode {
|
||||||
|
// https://developers.generativeai.google/api/rest/generativelanguage/models/generateMessage
|
||||||
|
messages := make([]PaLMChatMessage, 0, len(openAIRequest.Messages))
|
||||||
|
for _, message := range openAIRequest.Messages {
|
||||||
|
var author string
|
||||||
|
if message.Role == "user" {
|
||||||
|
author = "0"
|
||||||
|
} else {
|
||||||
|
author = "1"
|
||||||
|
}
|
||||||
|
messages = append(messages, PaLMChatMessage{
|
||||||
|
Author: author,
|
||||||
|
Content: message.Content,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
request := PaLMChatRequest{
|
||||||
|
Prompt: nil,
|
||||||
|
Temperature: openAIRequest.Temperature,
|
||||||
|
CandidateCount: openAIRequest.N,
|
||||||
|
TopP: openAIRequest.TopP,
|
||||||
|
TopK: openAIRequest.MaxTokens,
|
||||||
|
}
|
||||||
|
// TODO: forward request to PaLM & convert response
|
||||||
|
fmt.Print(request)
|
||||||
|
return nil
|
||||||
|
}
|
65
controller/relay-utils.go
Normal file
65
controller/relay-utils.go
Normal file
@ -0,0 +1,65 @@
|
|||||||
|
package controller
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"github.com/pkoukk/tiktoken-go"
|
||||||
|
"one-api/common"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
var tokenEncoderMap = map[string]*tiktoken.Tiktoken{}
|
||||||
|
|
||||||
|
func getTokenEncoder(model string) *tiktoken.Tiktoken {
|
||||||
|
if tokenEncoder, ok := tokenEncoderMap[model]; ok {
|
||||||
|
return tokenEncoder
|
||||||
|
}
|
||||||
|
tokenEncoder, err := tiktoken.EncodingForModel(model)
|
||||||
|
if err != nil {
|
||||||
|
common.SysError(fmt.Sprintf("failed to get token encoder for model %s: %s, using encoder for gpt-3.5-turbo", model, err.Error()))
|
||||||
|
tokenEncoder, err = tiktoken.EncodingForModel("gpt-3.5-turbo")
|
||||||
|
if err != nil {
|
||||||
|
common.FatalLog(fmt.Sprintf("failed to get token encoder for model gpt-3.5-turbo: %s", err.Error()))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
tokenEncoderMap[model] = tokenEncoder
|
||||||
|
return tokenEncoder
|
||||||
|
}
|
||||||
|
|
||||||
|
func countTokenMessages(messages []Message, model string) int {
|
||||||
|
tokenEncoder := getTokenEncoder(model)
|
||||||
|
// Reference:
|
||||||
|
// https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
|
||||||
|
// https://github.com/pkoukk/tiktoken-go/issues/6
|
||||||
|
//
|
||||||
|
// Every message follows <|start|>{role/name}\n{content}<|end|>\n
|
||||||
|
var tokensPerMessage int
|
||||||
|
var tokensPerName int
|
||||||
|
if strings.HasPrefix(model, "gpt-3.5") {
|
||||||
|
tokensPerMessage = 4
|
||||||
|
tokensPerName = -1 // If there's a name, the role is omitted
|
||||||
|
} else if strings.HasPrefix(model, "gpt-4") {
|
||||||
|
tokensPerMessage = 3
|
||||||
|
tokensPerName = 1
|
||||||
|
} else {
|
||||||
|
tokensPerMessage = 3
|
||||||
|
tokensPerName = 1
|
||||||
|
}
|
||||||
|
tokenNum := 0
|
||||||
|
for _, message := range messages {
|
||||||
|
tokenNum += tokensPerMessage
|
||||||
|
tokenNum += len(tokenEncoder.Encode(message.Content, nil, nil))
|
||||||
|
tokenNum += len(tokenEncoder.Encode(message.Role, nil, nil))
|
||||||
|
if message.Name != nil {
|
||||||
|
tokenNum += tokensPerName
|
||||||
|
tokenNum += len(tokenEncoder.Encode(*message.Name, nil, nil))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
tokenNum += 3 // Every reply is primed with <|start|>assistant<|message|>
|
||||||
|
return tokenNum
|
||||||
|
}
|
||||||
|
|
||||||
|
func countTokenText(text string, model string) int {
|
||||||
|
tokenEncoder := getTokenEncoder(model)
|
||||||
|
token := tokenEncoder.Encode(text, nil, nil)
|
||||||
|
return len(token)
|
||||||
|
}
|
@ -6,7 +6,6 @@ import (
|
|||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/pkoukk/tiktoken-go"
|
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
@ -15,8 +14,22 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type Message struct {
|
type Message struct {
|
||||||
Role string `json:"role"`
|
Role string `json:"role"`
|
||||||
Content string `json:"content"`
|
Content string `json:"content"`
|
||||||
|
Name *string `json:"name,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// https://platform.openai.com/docs/api-reference/chat
|
||||||
|
|
||||||
|
type GeneralOpenAIRequest struct {
|
||||||
|
Model string `json:"model"`
|
||||||
|
Messages []Message `json:"messages"`
|
||||||
|
Prompt string `json:"prompt"`
|
||||||
|
Stream bool `json:"stream"`
|
||||||
|
MaxTokens int `json:"max_tokens"`
|
||||||
|
Temperature float64 `json:"temperature"`
|
||||||
|
TopP float64 `json:"top_p"`
|
||||||
|
N int `json:"n"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type ChatRequest struct {
|
type ChatRequest struct {
|
||||||
@ -65,40 +78,6 @@ type StreamResponse struct {
|
|||||||
} `json:"choices"`
|
} `json:"choices"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func countTokenMessages(messages []Message, model string) int {
|
|
||||||
// 获取模型的编码器
|
|
||||||
tokenEncoder, _ := tiktoken.EncodingForModel(model)
|
|
||||||
|
|
||||||
// 参照官方的token计算cookbook
|
|
||||||
// https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
|
|
||||||
// https://github.com/pkoukk/tiktoken-go/issues/6
|
|
||||||
var tokens_per_message int
|
|
||||||
if strings.HasPrefix(model, "gpt-3.5") {
|
|
||||||
tokens_per_message = 4 // every message follows <|start|>{role/name}\n{content}<|end|>\n
|
|
||||||
} else if strings.HasPrefix(model, "gpt-4") {
|
|
||||||
tokens_per_message = 3
|
|
||||||
} else {
|
|
||||||
tokens_per_message = 3
|
|
||||||
}
|
|
||||||
|
|
||||||
token := 0
|
|
||||||
for _, message := range messages {
|
|
||||||
token += tokens_per_message
|
|
||||||
token += len(tokenEncoder.Encode(message.Content, nil, nil))
|
|
||||||
token += len(tokenEncoder.Encode(message.Role, nil, nil))
|
|
||||||
}
|
|
||||||
// 经过测试这个assistant的token是算在prompt里面的,而不是算在Completion里面的
|
|
||||||
token += 3 // every reply is primed with <|start|>assistant<|message|>
|
|
||||||
return token
|
|
||||||
}
|
|
||||||
|
|
||||||
func countTokenText(text string, model string) int {
|
|
||||||
// 获取模型的编码器
|
|
||||||
tokenEncoder, _ := tiktoken.EncodingForModel(model)
|
|
||||||
token := tokenEncoder.Encode(text, nil, nil)
|
|
||||||
return len(token)
|
|
||||||
}
|
|
||||||
|
|
||||||
func Relay(c *gin.Context) {
|
func Relay(c *gin.Context) {
|
||||||
err := relayHelper(c)
|
err := relayHelper(c)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -110,8 +89,8 @@ func Relay(c *gin.Context) {
|
|||||||
})
|
})
|
||||||
channelId := c.GetInt("channel_id")
|
channelId := c.GetInt("channel_id")
|
||||||
common.SysError(fmt.Sprintf("Relay error (channel #%d): %s", channelId, err.Message))
|
common.SysError(fmt.Sprintf("Relay error (channel #%d): %s", channelId, err.Message))
|
||||||
if err.Type != "invalid_request_error" && err.StatusCode != http.StatusTooManyRequests &&
|
// https://platform.openai.com/docs/guides/error-codes/api-errors
|
||||||
common.AutomaticDisableChannelEnabled {
|
if common.AutomaticDisableChannelEnabled && (err.Type == "insufficient_quota" || err.Code == "invalid_api_key") {
|
||||||
channelId := c.GetInt("channel_id")
|
channelId := c.GetInt("channel_id")
|
||||||
channelName := c.GetString("channel_name")
|
channelName := c.GetString("channel_name")
|
||||||
disableChannel(channelId, channelName, err.Message)
|
disableChannel(channelId, channelName, err.Message)
|
||||||
@ -135,8 +114,8 @@ func relayHelper(c *gin.Context) *OpenAIErrorWithStatusCode {
|
|||||||
channelType := c.GetInt("channel")
|
channelType := c.GetInt("channel")
|
||||||
tokenId := c.GetInt("token_id")
|
tokenId := c.GetInt("token_id")
|
||||||
consumeQuota := c.GetBool("consume_quota")
|
consumeQuota := c.GetBool("consume_quota")
|
||||||
var textRequest TextRequest
|
var textRequest GeneralOpenAIRequest
|
||||||
if consumeQuota || channelType == common.ChannelTypeAzure {
|
if consumeQuota || channelType == common.ChannelTypeAzure || channelType == common.ChannelTypePaLM {
|
||||||
requestBody, err := io.ReadAll(c.Request.Body)
|
requestBody, err := io.ReadAll(c.Request.Body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errorWrapper(err, "read_request_body_failed", http.StatusBadRequest)
|
return errorWrapper(err, "read_request_body_failed", http.StatusBadRequest)
|
||||||
@ -175,6 +154,9 @@ func relayHelper(c *gin.Context) *OpenAIErrorWithStatusCode {
|
|||||||
model_ = strings.TrimSuffix(model_, "-0301")
|
model_ = strings.TrimSuffix(model_, "-0301")
|
||||||
model_ = strings.TrimSuffix(model_, "-0314")
|
model_ = strings.TrimSuffix(model_, "-0314")
|
||||||
fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/%s", baseURL, model_, task)
|
fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/%s", baseURL, model_, task)
|
||||||
|
} else if channelType == common.ChannelTypePaLM {
|
||||||
|
err := relayPaLM(textRequest, c)
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
promptTokens := countTokenMessages(textRequest.Messages, textRequest.Model)
|
promptTokens := countTokenMessages(textRequest.Messages, textRequest.Model)
|
||||||
@ -230,7 +212,8 @@ func relayHelper(c *gin.Context) *OpenAIErrorWithStatusCode {
|
|||||||
completionRatio = 2
|
completionRatio = 2
|
||||||
}
|
}
|
||||||
if isStream {
|
if isStream {
|
||||||
quota = promptTokens + countTokenText(streamResponseText, textRequest.Model)*completionRatio
|
responseTokens := countTokenText(streamResponseText, textRequest.Model)
|
||||||
|
quota = promptTokens + responseTokens*completionRatio
|
||||||
} else {
|
} else {
|
||||||
quota = textResponse.Usage.PromptTokens + textResponse.Usage.CompletionTokens*completionRatio
|
quota = textResponse.Usage.PromptTokens + textResponse.Usage.CompletionTokens*completionRatio
|
||||||
}
|
}
|
||||||
@ -265,6 +248,10 @@ func relayHelper(c *gin.Context) *OpenAIErrorWithStatusCode {
|
|||||||
go func() {
|
go func() {
|
||||||
for scanner.Scan() {
|
for scanner.Scan() {
|
||||||
data := scanner.Text()
|
data := scanner.Text()
|
||||||
|
if len(data) < 6 { // must be something wrong!
|
||||||
|
common.SysError("Invalid stream response: " + data)
|
||||||
|
continue
|
||||||
|
}
|
||||||
dataChan <- data
|
dataChan <- data
|
||||||
data = data[6:]
|
data = data[6:]
|
||||||
if !strings.HasPrefix(data, "[DONE]") {
|
if !strings.HasPrefix(data, "[DONE]") {
|
||||||
|
@ -6,7 +6,11 @@ import (
|
|||||||
|
|
||||||
func Cache() func(c *gin.Context) {
|
func Cache() func(c *gin.Context) {
|
||||||
return func(c *gin.Context) {
|
return func(c *gin.Context) {
|
||||||
c.Header("Cache-Control", "max-age=604800") // one week
|
if c.Request.RequestURI == "/" {
|
||||||
|
c.Header("Cache-Control", "no-cache")
|
||||||
|
} else {
|
||||||
|
c.Header("Cache-Control", "max-age=604800") // one week
|
||||||
|
}
|
||||||
c.Next()
|
c.Next()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -6,17 +6,19 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type Channel struct {
|
type Channel struct {
|
||||||
Id int `json:"id"`
|
Id int `json:"id"`
|
||||||
Type int `json:"type" gorm:"default:0"`
|
Type int `json:"type" gorm:"default:0"`
|
||||||
Key string `json:"key" gorm:"not null"`
|
Key string `json:"key" gorm:"not null"`
|
||||||
Status int `json:"status" gorm:"default:1"`
|
Status int `json:"status" gorm:"default:1"`
|
||||||
Name string `json:"name" gorm:"index"`
|
Name string `json:"name" gorm:"index"`
|
||||||
Weight int `json:"weight"`
|
Weight int `json:"weight"`
|
||||||
CreatedTime int64 `json:"created_time" gorm:"bigint"`
|
CreatedTime int64 `json:"created_time" gorm:"bigint"`
|
||||||
TestTime int64 `json:"test_time" gorm:"bigint"`
|
TestTime int64 `json:"test_time" gorm:"bigint"`
|
||||||
ResponseTime int `json:"response_time"` // in milliseconds
|
ResponseTime int `json:"response_time"` // in milliseconds
|
||||||
BaseURL string `json:"base_url" gorm:"column:base_url"`
|
BaseURL string `json:"base_url" gorm:"column:base_url"`
|
||||||
Other string `json:"other"`
|
Other string `json:"other"`
|
||||||
|
Balance float64 `json:"balance"` // in USD
|
||||||
|
BalanceUpdatedTime int64 `json:"balance_updated_time" gorm:"bigint"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetAllChannels(startIdx int, num int, selectAll bool) ([]*Channel, error) {
|
func GetAllChannels(startIdx int, num int, selectAll bool) ([]*Channel, error) {
|
||||||
@ -86,6 +88,16 @@ func (channel *Channel) UpdateResponseTime(responseTime int64) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (channel *Channel) UpdateBalance(balance float64) {
|
||||||
|
err := DB.Model(channel).Select("balance_updated_time", "balance").Updates(Channel{
|
||||||
|
BalanceUpdatedTime: common.GetTimestamp(),
|
||||||
|
Balance: balance,
|
||||||
|
}).Error
|
||||||
|
if err != nil {
|
||||||
|
common.SysError("failed to update balance: " + err.Error())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (channel *Channel) Delete() error {
|
func (channel *Channel) Delete() error {
|
||||||
var err error
|
var err error
|
||||||
err = DB.Delete(channel).Error
|
err = DB.Delete(channel).Error
|
||||||
|
@ -26,6 +26,7 @@ func createRootAccountIfNeed() error {
|
|||||||
Status: common.UserStatusEnabled,
|
Status: common.UserStatusEnabled,
|
||||||
DisplayName: "Root User",
|
DisplayName: "Root User",
|
||||||
AccessToken: common.GetUUID(),
|
AccessToken: common.GetUUID(),
|
||||||
|
Quota: 100000000,
|
||||||
}
|
}
|
||||||
DB.Create(&rootUser)
|
DB.Create(&rootUser)
|
||||||
}
|
}
|
||||||
|
@ -19,8 +19,7 @@ type User struct {
|
|||||||
Email string `json:"email" gorm:"index" validate:"max=50"`
|
Email string `json:"email" gorm:"index" validate:"max=50"`
|
||||||
GitHubId string `json:"github_id" gorm:"column:github_id;index"`
|
GitHubId string `json:"github_id" gorm:"column:github_id;index"`
|
||||||
WeChatId string `json:"wechat_id" gorm:"column:wechat_id;index"`
|
WeChatId string `json:"wechat_id" gorm:"column:wechat_id;index"`
|
||||||
VerificationCode string `json:"verification_code" gorm:"-:all"` // this field is only for Email verification, don't save it to database!
|
VerificationCode string `json:"verification_code" gorm:"-:all"` // this field is only for Email verification, don't save it to database!
|
||||||
Balance int `json:"balance" gorm:"type:int;default:0"`
|
|
||||||
AccessToken string `json:"access_token" gorm:"type:char(32);column:access_token;uniqueIndex"` // this token is for system management
|
AccessToken string `json:"access_token" gorm:"type:char(32);column:access_token;uniqueIndex"` // this token is for system management
|
||||||
Quota int `json:"quota" gorm:"type:int;default:0"`
|
Quota int `json:"quota" gorm:"type:int;default:0"`
|
||||||
}
|
}
|
||||||
|
@ -66,6 +66,8 @@ func SetApiRouter(router *gin.Engine) {
|
|||||||
channelRoute.GET("/:id", controller.GetChannel)
|
channelRoute.GET("/:id", controller.GetChannel)
|
||||||
channelRoute.GET("/test", controller.TestAllChannels)
|
channelRoute.GET("/test", controller.TestAllChannels)
|
||||||
channelRoute.GET("/test/:id", controller.TestChannel)
|
channelRoute.GET("/test/:id", controller.TestChannel)
|
||||||
|
channelRoute.GET("/update_balance", controller.UpdateAllChannelsBalance)
|
||||||
|
channelRoute.GET("/update_balance/:id", controller.UpdateChannelBalance)
|
||||||
channelRoute.POST("/", controller.AddChannel)
|
channelRoute.POST("/", controller.AddChannel)
|
||||||
channelRoute.PUT("/", controller.UpdateChannel)
|
channelRoute.PUT("/", controller.UpdateChannel)
|
||||||
channelRoute.DELETE("/:id", controller.DeleteChannel)
|
channelRoute.DELETE("/:id", controller.DeleteChannel)
|
||||||
|
@ -8,11 +8,14 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func SetDashboardRouter(router *gin.Engine) {
|
func SetDashboardRouter(router *gin.Engine) {
|
||||||
apiRouter := router.Group("/dashboard")
|
apiRouter := router.Group("/")
|
||||||
apiRouter.Use(gzip.Gzip(gzip.DefaultCompression))
|
apiRouter.Use(gzip.Gzip(gzip.DefaultCompression))
|
||||||
apiRouter.Use(middleware.GlobalAPIRateLimit())
|
apiRouter.Use(middleware.GlobalAPIRateLimit())
|
||||||
apiRouter.Use(middleware.TokenAuth())
|
apiRouter.Use(middleware.TokenAuth())
|
||||||
{
|
{
|
||||||
apiRouter.GET("/billing/credit_grants", controller.GetTokenStatus)
|
apiRouter.GET("/dashboard/billing/subscription", controller.GetSubscription)
|
||||||
|
apiRouter.GET("/v1/dashboard/billing/subscription", controller.GetSubscription)
|
||||||
|
apiRouter.GET("/dashboard/billing/usage", controller.GetUsage)
|
||||||
|
apiRouter.GET("/v1/dashboard/billing/usage", controller.GetUsage)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -16,6 +16,7 @@ func SetWebRouter(router *gin.Engine, buildFS embed.FS, indexPage []byte) {
|
|||||||
router.Use(middleware.Cache())
|
router.Use(middleware.Cache())
|
||||||
router.Use(static.Serve("/", common.EmbedFolder(buildFS, "web/build")))
|
router.Use(static.Serve("/", common.EmbedFolder(buildFS, "web/build")))
|
||||||
router.NoRoute(func(c *gin.Context) {
|
router.NoRoute(func(c *gin.Context) {
|
||||||
|
c.Header("Cache-Control", "no-cache")
|
||||||
c.Data(http.StatusOK, "text/html; charset=utf-8", indexPage)
|
c.Data(http.StatusOK, "text/html; charset=utf-8", indexPage)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
@ -32,6 +32,7 @@ const ChannelsTable = () => {
|
|||||||
const [activePage, setActivePage] = useState(1);
|
const [activePage, setActivePage] = useState(1);
|
||||||
const [searchKeyword, setSearchKeyword] = useState('');
|
const [searchKeyword, setSearchKeyword] = useState('');
|
||||||
const [searching, setSearching] = useState(false);
|
const [searching, setSearching] = useState(false);
|
||||||
|
const [updatingBalance, setUpdatingBalance] = useState(false);
|
||||||
|
|
||||||
const loadChannels = async (startIdx) => {
|
const loadChannels = async (startIdx) => {
|
||||||
const res = await API.get(`/api/channel/?p=${startIdx}`);
|
const res = await API.get(`/api/channel/?p=${startIdx}`);
|
||||||
@ -63,7 +64,7 @@ const ChannelsTable = () => {
|
|||||||
const refresh = async () => {
|
const refresh = async () => {
|
||||||
setLoading(true);
|
setLoading(true);
|
||||||
await loadChannels(0);
|
await loadChannels(0);
|
||||||
}
|
};
|
||||||
|
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
loadChannels(0)
|
loadChannels(0)
|
||||||
@ -127,7 +128,7 @@ const ChannelsTable = () => {
|
|||||||
|
|
||||||
const renderResponseTime = (responseTime) => {
|
const renderResponseTime = (responseTime) => {
|
||||||
let time = responseTime / 1000;
|
let time = responseTime / 1000;
|
||||||
time = time.toFixed(2) + " 秒";
|
time = time.toFixed(2) + ' 秒';
|
||||||
if (responseTime === 0) {
|
if (responseTime === 0) {
|
||||||
return <Label basic color='grey'>未测试</Label>;
|
return <Label basic color='grey'>未测试</Label>;
|
||||||
} else if (responseTime <= 1000) {
|
} else if (responseTime <= 1000) {
|
||||||
@ -179,11 +180,38 @@ const ChannelsTable = () => {
|
|||||||
const res = await API.get(`/api/channel/test`);
|
const res = await API.get(`/api/channel/test`);
|
||||||
const { success, message } = res.data;
|
const { success, message } = res.data;
|
||||||
if (success) {
|
if (success) {
|
||||||
showInfo("已成功开始测试所有已启用通道,请刷新页面查看结果。");
|
showInfo('已成功开始测试所有已启用通道,请刷新页面查看结果。');
|
||||||
} else {
|
} else {
|
||||||
showError(message);
|
showError(message);
|
||||||
}
|
}
|
||||||
}
|
};
|
||||||
|
|
||||||
|
const updateChannelBalance = async (id, name, idx) => {
|
||||||
|
const res = await API.get(`/api/channel/update_balance/${id}/`);
|
||||||
|
const { success, message, balance } = res.data;
|
||||||
|
if (success) {
|
||||||
|
let newChannels = [...channels];
|
||||||
|
let realIdx = (activePage - 1) * ITEMS_PER_PAGE + idx;
|
||||||
|
newChannels[realIdx].balance = balance;
|
||||||
|
newChannels[realIdx].balance_updated_time = Date.now() / 1000;
|
||||||
|
setChannels(newChannels);
|
||||||
|
showInfo(`通道 ${name} 余额更新成功!`);
|
||||||
|
} else {
|
||||||
|
showError(message);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
const updateAllChannelsBalance = async () => {
|
||||||
|
setUpdatingBalance(true);
|
||||||
|
const res = await API.get(`/api/channel/update_balance`);
|
||||||
|
const { success, message } = res.data;
|
||||||
|
if (success) {
|
||||||
|
showInfo('已更新完毕所有已启用通道余额!');
|
||||||
|
} else {
|
||||||
|
showError(message);
|
||||||
|
}
|
||||||
|
setUpdatingBalance(false);
|
||||||
|
};
|
||||||
|
|
||||||
const handleKeywordChange = async (e, { value }) => {
|
const handleKeywordChange = async (e, { value }) => {
|
||||||
setSearchKeyword(value.trim());
|
setSearchKeyword(value.trim());
|
||||||
@ -263,10 +291,10 @@ const ChannelsTable = () => {
|
|||||||
<Table.HeaderCell
|
<Table.HeaderCell
|
||||||
style={{ cursor: 'pointer' }}
|
style={{ cursor: 'pointer' }}
|
||||||
onClick={() => {
|
onClick={() => {
|
||||||
sortChannel('test_time');
|
sortChannel('balance');
|
||||||
}}
|
}}
|
||||||
>
|
>
|
||||||
测试时间
|
余额
|
||||||
</Table.HeaderCell>
|
</Table.HeaderCell>
|
||||||
<Table.HeaderCell>操作</Table.HeaderCell>
|
<Table.HeaderCell>操作</Table.HeaderCell>
|
||||||
</Table.Row>
|
</Table.Row>
|
||||||
@ -286,8 +314,22 @@ const ChannelsTable = () => {
|
|||||||
<Table.Cell>{channel.name ? channel.name : '无'}</Table.Cell>
|
<Table.Cell>{channel.name ? channel.name : '无'}</Table.Cell>
|
||||||
<Table.Cell>{renderType(channel.type)}</Table.Cell>
|
<Table.Cell>{renderType(channel.type)}</Table.Cell>
|
||||||
<Table.Cell>{renderStatus(channel.status)}</Table.Cell>
|
<Table.Cell>{renderStatus(channel.status)}</Table.Cell>
|
||||||
<Table.Cell>{renderResponseTime(channel.response_time)}</Table.Cell>
|
<Table.Cell>
|
||||||
<Table.Cell>{channel.test_time ? renderTimestamp(channel.test_time) : "未测试"}</Table.Cell>
|
<Popup
|
||||||
|
content={channel.test_time ? renderTimestamp(channel.test_time) : '未测试'}
|
||||||
|
key={channel.id}
|
||||||
|
trigger={renderResponseTime(channel.response_time)}
|
||||||
|
basic
|
||||||
|
/>
|
||||||
|
</Table.Cell>
|
||||||
|
<Table.Cell>
|
||||||
|
<Popup
|
||||||
|
content={channel.balance_updated_time ? renderTimestamp(channel.balance_updated_time) : '未更新'}
|
||||||
|
key={channel.id}
|
||||||
|
trigger={<span>${channel.balance.toFixed(2)}</span>}
|
||||||
|
basic
|
||||||
|
/>
|
||||||
|
</Table.Cell>
|
||||||
<Table.Cell>
|
<Table.Cell>
|
||||||
<div>
|
<div>
|
||||||
<Button
|
<Button
|
||||||
@ -299,6 +341,16 @@ const ChannelsTable = () => {
|
|||||||
>
|
>
|
||||||
测试
|
测试
|
||||||
</Button>
|
</Button>
|
||||||
|
<Button
|
||||||
|
size={'small'}
|
||||||
|
positive
|
||||||
|
loading={updatingBalance}
|
||||||
|
onClick={() => {
|
||||||
|
updateChannelBalance(channel.id, channel.name, idx);
|
||||||
|
}}
|
||||||
|
>
|
||||||
|
更新余额
|
||||||
|
</Button>
|
||||||
<Popup
|
<Popup
|
||||||
trigger={
|
trigger={
|
||||||
<Button size='small' negative>
|
<Button size='small' negative>
|
||||||
@ -353,6 +405,7 @@ const ChannelsTable = () => {
|
|||||||
<Button size='small' loading={loading} onClick={testAllChannels}>
|
<Button size='small' loading={loading} onClick={testAllChannels}>
|
||||||
测试所有已启用通道
|
测试所有已启用通道
|
||||||
</Button>
|
</Button>
|
||||||
|
<Button size='small' onClick={updateAllChannelsBalance} loading={loading || updatingBalance}>更新所有已启用通道余额</Button>
|
||||||
<Pagination
|
<Pagination
|
||||||
floated='right'
|
floated='right'
|
||||||
activePage={activePage}
|
activePage={activePage}
|
||||||
|
@ -112,13 +112,17 @@ const PersonalSetting = () => {
|
|||||||
<Button onClick={generateAccessToken}>生成系统访问令牌</Button>
|
<Button onClick={generateAccessToken}>生成系统访问令牌</Button>
|
||||||
<Divider />
|
<Divider />
|
||||||
<Header as='h3'>账号绑定</Header>
|
<Header as='h3'>账号绑定</Header>
|
||||||
<Button
|
{
|
||||||
onClick={() => {
|
status.wechat_login && (
|
||||||
setShowWeChatBindModal(true);
|
<Button
|
||||||
}}
|
onClick={() => {
|
||||||
>
|
setShowWeChatBindModal(true);
|
||||||
绑定微信账号
|
}}
|
||||||
</Button>
|
>
|
||||||
|
绑定微信账号
|
||||||
|
</Button>
|
||||||
|
)
|
||||||
|
}
|
||||||
<Modal
|
<Modal
|
||||||
onClose={() => setShowWeChatBindModal(false)}
|
onClose={() => setShowWeChatBindModal(false)}
|
||||||
onOpen={() => setShowWeChatBindModal(true)}
|
onOpen={() => setShowWeChatBindModal(true)}
|
||||||
@ -148,7 +152,11 @@ const PersonalSetting = () => {
|
|||||||
</Modal.Description>
|
</Modal.Description>
|
||||||
</Modal.Content>
|
</Modal.Content>
|
||||||
</Modal>
|
</Modal>
|
||||||
<Button onClick={openGitHubOAuth}>绑定 GitHub 账号</Button>
|
{
|
||||||
|
status.github_oauth && (
|
||||||
|
<Button onClick={openGitHubOAuth}>绑定 GitHub 账号</Button>
|
||||||
|
)
|
||||||
|
}
|
||||||
<Button
|
<Button
|
||||||
onClick={() => {
|
onClick={() => {
|
||||||
setShowEmailBindModal(true);
|
setShowEmailBindModal(true);
|
||||||
|
@ -1,10 +1,12 @@
|
|||||||
export const CHANNEL_OPTIONS = [
|
export const CHANNEL_OPTIONS = [
|
||||||
{ key: 1, text: 'OpenAI', value: 1, color: 'green' },
|
{ key: 1, text: 'OpenAI', value: 1, color: 'green' },
|
||||||
{ key: 2, text: 'API2D', value: 2, color: 'blue' },
|
{ key: 8, text: '自定义', value: 8, color: 'pink' },
|
||||||
{ key: 3, text: 'Azure', value: 3, color: 'olive' },
|
{ key: 3, text: 'Azure', value: 3, color: 'olive' },
|
||||||
|
{ key: 2, text: 'API2D', value: 2, color: 'blue' },
|
||||||
{ key: 4, text: 'CloseAI', value: 4, color: 'teal' },
|
{ key: 4, text: 'CloseAI', value: 4, color: 'teal' },
|
||||||
{ key: 5, text: 'OpenAI-SB', value: 5, color: 'brown' },
|
{ key: 5, text: 'OpenAI-SB', value: 5, color: 'brown' },
|
||||||
{ key: 6, text: 'OpenAI Max', value: 6, color: 'violet' },
|
{ key: 6, text: 'OpenAI Max', value: 6, color: 'violet' },
|
||||||
{ key: 7, text: 'OhMyGPT', value: 7, color: 'purple' },
|
{ key: 7, text: 'OhMyGPT', value: 7, color: 'purple' },
|
||||||
{ key: 8, text: '自定义', value: 8, color: 'pink' }
|
{ key: 9, text: 'AI.LS', value: 9, color: 'yellow' },
|
||||||
|
{ key: 10, text: 'AI Proxy', value: 10, color: 'purple' }
|
||||||
];
|
];
|
||||||
|
@ -46,6 +46,9 @@ const EditChannel = () => {
|
|||||||
if (localInputs.base_url.endsWith('/')) {
|
if (localInputs.base_url.endsWith('/')) {
|
||||||
localInputs.base_url = localInputs.base_url.slice(0, localInputs.base_url.length - 1);
|
localInputs.base_url = localInputs.base_url.slice(0, localInputs.base_url.length - 1);
|
||||||
}
|
}
|
||||||
|
if (localInputs.type === 3 && localInputs.other === '') {
|
||||||
|
localInputs.other = '2023-03-15-preview';
|
||||||
|
}
|
||||||
let res;
|
let res;
|
||||||
if (isEdit) {
|
if (isEdit) {
|
||||||
res = await API.put(`/api/channel/`, { ...localInputs, id: parseInt(channelId) });
|
res = await API.put(`/api/channel/`, { ...localInputs, id: parseInt(channelId) });
|
||||||
@ -164,7 +167,7 @@ const EditChannel = () => {
|
|||||||
/>
|
/>
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
<Button onClick={submit}>提交</Button>
|
<Button positive onClick={submit}>提交</Button>
|
||||||
</Form>
|
</Form>
|
||||||
</Segment>
|
</Segment>
|
||||||
</>
|
</>
|
||||||
|
@ -111,7 +111,7 @@ const EditRedemption = () => {
|
|||||||
</Form.Field>
|
</Form.Field>
|
||||||
</>
|
</>
|
||||||
}
|
}
|
||||||
<Button onClick={submit}>提交</Button>
|
<Button positive onClick={submit}>提交</Button>
|
||||||
</Form>
|
</Form>
|
||||||
</Segment>
|
</Segment>
|
||||||
</>
|
</>
|
||||||
|
@ -106,6 +106,34 @@ const EditToken = () => {
|
|||||||
required={!isEdit}
|
required={!isEdit}
|
||||||
/>
|
/>
|
||||||
</Form.Field>
|
</Form.Field>
|
||||||
|
<Form.Field>
|
||||||
|
<Form.Input
|
||||||
|
label='过期时间'
|
||||||
|
name='expired_time'
|
||||||
|
placeholder={'请输入过期时间,格式为 yyyy-MM-dd HH:mm:ss,-1 表示无限制'}
|
||||||
|
onChange={handleInputChange}
|
||||||
|
value={expired_time}
|
||||||
|
autoComplete='new-password'
|
||||||
|
type='datetime-local'
|
||||||
|
/>
|
||||||
|
</Form.Field>
|
||||||
|
<div style={{ lineHeight: '40px' }}>
|
||||||
|
<Button type={'button'} onClick={() => {
|
||||||
|
setExpiredTime(0, 0, 0, 0);
|
||||||
|
}}>永不过期</Button>
|
||||||
|
<Button type={'button'} onClick={() => {
|
||||||
|
setExpiredTime(1, 0, 0, 0);
|
||||||
|
}}>一个月后过期</Button>
|
||||||
|
<Button type={'button'} onClick={() => {
|
||||||
|
setExpiredTime(0, 1, 0, 0);
|
||||||
|
}}>一天后过期</Button>
|
||||||
|
<Button type={'button'} onClick={() => {
|
||||||
|
setExpiredTime(0, 0, 1, 0);
|
||||||
|
}}>一小时后过期</Button>
|
||||||
|
<Button type={'button'} onClick={() => {
|
||||||
|
setExpiredTime(0, 0, 0, 1);
|
||||||
|
}}>一分钟后过期</Button>
|
||||||
|
</div>
|
||||||
<Message>注意,令牌的额度仅用于限制令牌本身的最大额度使用量,实际的使用受到账户的剩余额度限制。</Message>
|
<Message>注意,令牌的额度仅用于限制令牌本身的最大额度使用量,实际的使用受到账户的剩余额度限制。</Message>
|
||||||
<Form.Field>
|
<Form.Field>
|
||||||
<Form.Input
|
<Form.Input
|
||||||
@ -119,36 +147,10 @@ const EditToken = () => {
|
|||||||
disabled={unlimited_quota}
|
disabled={unlimited_quota}
|
||||||
/>
|
/>
|
||||||
</Form.Field>
|
</Form.Field>
|
||||||
<Button type={'button'} style={{ marginBottom: '14px' }} onClick={() => {
|
<Button type={'button'} onClick={() => {
|
||||||
setUnlimitedQuota();
|
setUnlimitedQuota();
|
||||||
}}>{unlimited_quota ? '取消无限额度' : '设置为无限额度'}</Button>
|
}}>{unlimited_quota ? '取消无限额度' : '设置为无限额度'}</Button>
|
||||||
<Form.Field>
|
<Button positive onClick={submit}>提交</Button>
|
||||||
<Form.Input
|
|
||||||
label='过期时间'
|
|
||||||
name='expired_time'
|
|
||||||
placeholder={'请输入过期时间,格式为 yyyy-MM-dd HH:mm:ss,-1 表示无限制'}
|
|
||||||
onChange={handleInputChange}
|
|
||||||
value={expired_time}
|
|
||||||
autoComplete='new-password'
|
|
||||||
type='datetime-local'
|
|
||||||
/>
|
|
||||||
</Form.Field>
|
|
||||||
<Button type={'button'} onClick={() => {
|
|
||||||
setExpiredTime(0, 0, 0, 0);
|
|
||||||
}}>永不过期</Button>
|
|
||||||
<Button type={'button'} onClick={() => {
|
|
||||||
setExpiredTime(1, 0, 0, 0);
|
|
||||||
}}>一个月后过期</Button>
|
|
||||||
<Button type={'button'} onClick={() => {
|
|
||||||
setExpiredTime(0, 1, 0, 0);
|
|
||||||
}}>一天后过期</Button>
|
|
||||||
<Button type={'button'} onClick={() => {
|
|
||||||
setExpiredTime(0, 0, 1, 0);
|
|
||||||
}}>一小时后过期</Button>
|
|
||||||
<Button type={'button'} onClick={() => {
|
|
||||||
setExpiredTime(0, 0, 0, 1);
|
|
||||||
}}>一分钟后过期</Button>
|
|
||||||
<Button onClick={submit}>提交</Button>
|
|
||||||
</Form>
|
</Form>
|
||||||
</Segment>
|
</Segment>
|
||||||
</>
|
</>
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
import React, { useEffect, useState } from 'react';
|
import React, { useEffect, useState } from 'react';
|
||||||
import { Button, Form, Grid, Header, Segment, Statistic } from 'semantic-ui-react';
|
import { Button, Form, Grid, Header, Segment, Statistic } from 'semantic-ui-react';
|
||||||
import { API, showError, showSuccess } from '../../helpers';
|
import { API, showError, showInfo, showSuccess } from '../../helpers';
|
||||||
|
|
||||||
const TopUp = () => {
|
const TopUp = () => {
|
||||||
const [redemptionCode, setRedemptionCode] = useState('');
|
const [redemptionCode, setRedemptionCode] = useState('');
|
||||||
@ -9,6 +9,7 @@ const TopUp = () => {
|
|||||||
|
|
||||||
const topUp = async () => {
|
const topUp = async () => {
|
||||||
if (redemptionCode === '') {
|
if (redemptionCode === '') {
|
||||||
|
showInfo('请输入充值码!')
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
const res = await API.post('/api/user/topup', {
|
const res = await API.post('/api/user/topup', {
|
||||||
@ -80,7 +81,7 @@ const TopUp = () => {
|
|||||||
<Grid.Column>
|
<Grid.Column>
|
||||||
<Statistic.Group widths='one'>
|
<Statistic.Group widths='one'>
|
||||||
<Statistic>
|
<Statistic>
|
||||||
<Statistic.Value>{userQuota}</Statistic.Value>
|
<Statistic.Value>{userQuota.toLocaleString()}</Statistic.Value>
|
||||||
<Statistic.Label>剩余额度</Statistic.Label>
|
<Statistic.Label>剩余额度</Statistic.Label>
|
||||||
</Statistic>
|
</Statistic>
|
||||||
</Statistic.Group>
|
</Statistic.Group>
|
||||||
|
@ -65,7 +65,7 @@ const AddUser = () => {
|
|||||||
required
|
required
|
||||||
/>
|
/>
|
||||||
</Form.Field>
|
</Form.Field>
|
||||||
<Button type={'submit'} onClick={submit}>
|
<Button positive type={'submit'} onClick={submit}>
|
||||||
提交
|
提交
|
||||||
</Button>
|
</Button>
|
||||||
</Form>
|
</Form>
|
||||||
|
@ -14,8 +14,9 @@ const EditUser = () => {
|
|||||||
github_id: '',
|
github_id: '',
|
||||||
wechat_id: '',
|
wechat_id: '',
|
||||||
email: '',
|
email: '',
|
||||||
|
quota: 0,
|
||||||
});
|
});
|
||||||
const { username, display_name, password, github_id, wechat_id, email } =
|
const { username, display_name, password, github_id, wechat_id, email, quota } =
|
||||||
inputs;
|
inputs;
|
||||||
const handleInputChange = (e, { name, value }) => {
|
const handleInputChange = (e, { name, value }) => {
|
||||||
setInputs((inputs) => ({ ...inputs, [name]: value }));
|
setInputs((inputs) => ({ ...inputs, [name]: value }));
|
||||||
@ -44,7 +45,11 @@ const EditUser = () => {
|
|||||||
const submit = async () => {
|
const submit = async () => {
|
||||||
let res = undefined;
|
let res = undefined;
|
||||||
if (userId) {
|
if (userId) {
|
||||||
res = await API.put(`/api/user/`, { ...inputs, id: parseInt(userId) });
|
let data = { ...inputs, id: parseInt(userId) };
|
||||||
|
if (typeof data.quota === 'string') {
|
||||||
|
data.quota = parseInt(data.quota);
|
||||||
|
}
|
||||||
|
res = await API.put(`/api/user/`, data);
|
||||||
} else {
|
} else {
|
||||||
res = await API.put(`/api/user/self`, inputs);
|
res = await API.put(`/api/user/self`, inputs);
|
||||||
}
|
}
|
||||||
@ -92,6 +97,21 @@ const EditUser = () => {
|
|||||||
autoComplete='new-password'
|
autoComplete='new-password'
|
||||||
/>
|
/>
|
||||||
</Form.Field>
|
</Form.Field>
|
||||||
|
{
|
||||||
|
userId && (
|
||||||
|
<Form.Field>
|
||||||
|
<Form.Input
|
||||||
|
label='剩余额度'
|
||||||
|
name='quota'
|
||||||
|
placeholder={'请输入新的剩余额度'}
|
||||||
|
onChange={handleInputChange}
|
||||||
|
value={quota}
|
||||||
|
type={'number'}
|
||||||
|
autoComplete='new-password'
|
||||||
|
/>
|
||||||
|
</Form.Field>
|
||||||
|
)
|
||||||
|
}
|
||||||
<Form.Field>
|
<Form.Field>
|
||||||
<Form.Input
|
<Form.Input
|
||||||
label='已绑定的 GitHub 账户'
|
label='已绑定的 GitHub 账户'
|
||||||
@ -122,7 +142,7 @@ const EditUser = () => {
|
|||||||
readOnly
|
readOnly
|
||||||
/>
|
/>
|
||||||
</Form.Field>
|
</Form.Field>
|
||||||
<Button onClick={submit}>提交</Button>
|
<Button positive onClick={submit}>提交</Button>
|
||||||
</Form>
|
</Form>
|
||||||
</Segment>
|
</Segment>
|
||||||
</>
|
</>
|
||||||
|
Loading…
Reference in New Issue
Block a user