diff --git a/Dockerfile b/Dockerfile
index b21a7b3c..94cd8468 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -1,9 +1,16 @@
FROM node:16 as builder
-WORKDIR /build
-COPY ./web .
+WORKDIR /web
COPY ./VERSION .
-RUN chmod u+x ./build.sh && ./build.sh
+COPY ./web .
+
+WORKDIR /web/default
+RUN npm install
+RUN DISABLE_ESLINT_PLUGIN='true' REACT_APP_VERSION=$(cat ../VERSION) npm run build
+
+WORKDIR /web/berry
+RUN npm install
+RUN DISABLE_ESLINT_PLUGIN='true' REACT_APP_VERSION=$(cat ../VERSION) npm run build
FROM golang AS builder2
@@ -15,7 +22,7 @@ WORKDIR /build
ADD go.mod go.sum ./
RUN go mod download
COPY . .
-COPY --from=builder /build/build ./web/build
+COPY --from=builder /web/build ./web/build
RUN go build -ldflags "-s -w -X 'one-api/common.Version=$(cat VERSION)' -extldflags '-static'" -o one-api
FROM alpine
@@ -28,4 +35,4 @@ RUN apk update \
COPY --from=builder2 /build/one-api /
EXPOSE 3000
WORKDIR /data
-ENTRYPOINT ["/one-api"]
+ENTRYPOINT ["/one-api"]
\ No newline at end of file
diff --git a/README.md b/README.md
index 27acfedd..02a62387 100644
--- a/README.md
+++ b/README.md
@@ -414,6 +414,9 @@ https://openai.justsong.cn
8. 升级之前数据库需要做变更吗?
+ 一般情况下不需要,系统将在初始化的时候自动调整。
+ 如果需要的话,我会在更新日志中说明,并给出脚本。
+9. 手动修改数据库后报错:`数据库一致性已被破坏,请联系管理员`?
+ + 这是检测到 ability 表里有些记录的通道 id 是不存在的,这大概率是因为你删了 channel 表里的记录但是没有同步在 ability 表里清理无效的通道。
+ + 对于每一个通道,其所支持的模型都需要有一个专门的 ability 表的记录,表示该通道支持该模型。
## 相关项目
* [FastGPT](https://github.com/labring/FastGPT): 基于 LLM 大语言模型的知识库问答系统
diff --git a/common/config/config.go b/common/config/config.go
new file mode 100644
index 00000000..34d871fe
--- /dev/null
+++ b/common/config/config.go
@@ -0,0 +1,127 @@
+package config
+
+import (
+ "one-api/common/helper"
+ "os"
+ "strconv"
+ "sync"
+ "time"
+
+ "github.com/google/uuid"
+)
+
+var SystemName = "One API"
+var ServerAddress = "http://localhost:3000"
+var Footer = ""
+var Logo = ""
+var TopUpLink = ""
+var ChatLink = ""
+var QuotaPerUnit = 500 * 1000.0 // $0.002 / 1K tokens
+var DisplayInCurrencyEnabled = true
+var DisplayTokenStatEnabled = true
+
+// Any options whose key contains "Secret" or "Token" won't be returned by GetOptions
+
+var SessionSecret = uuid.New().String()
+
+var OptionMap map[string]string
+var OptionMapRWMutex sync.RWMutex
+
+var ItemsPerPage = 10
+var MaxRecentItems = 100
+
+var PasswordLoginEnabled = true
+var PasswordRegisterEnabled = true
+var EmailVerificationEnabled = false
+var GitHubOAuthEnabled = false
+var WeChatAuthEnabled = false
+var TurnstileCheckEnabled = false
+var RegisterEnabled = true
+
+var EmailDomainRestrictionEnabled = false
+var EmailDomainWhitelist = []string{
+ "gmail.com",
+ "163.com",
+ "126.com",
+ "qq.com",
+ "outlook.com",
+ "hotmail.com",
+ "icloud.com",
+ "yahoo.com",
+ "foxmail.com",
+}
+
+var DebugEnabled = os.Getenv("DEBUG") == "true"
+var MemoryCacheEnabled = os.Getenv("MEMORY_CACHE_ENABLED") == "true"
+
+var LogConsumeEnabled = true
+
+var SMTPServer = ""
+var SMTPPort = 587
+var SMTPAccount = ""
+var SMTPFrom = ""
+var SMTPToken = ""
+
+var GitHubClientId = ""
+var GitHubClientSecret = ""
+
+var WeChatServerAddress = ""
+var WeChatServerToken = ""
+var WeChatAccountQRCodeImageURL = ""
+
+var TurnstileSiteKey = ""
+var TurnstileSecretKey = ""
+
+var QuotaForNewUser = 0
+var QuotaForInviter = 0
+var QuotaForInvitee = 0
+var ChannelDisableThreshold = 5.0
+var AutomaticDisableChannelEnabled = false
+var AutomaticEnableChannelEnabled = false
+var QuotaRemindThreshold = 1000
+var PreConsumedQuota = 500
+var ApproximateTokenEnabled = false
+var RetryTimes = 0
+
+var RootUserEmail = ""
+
+var IsMasterNode = os.Getenv("NODE_TYPE") != "slave"
+
+var requestInterval, _ = strconv.Atoi(os.Getenv("POLLING_INTERVAL"))
+var RequestInterval = time.Duration(requestInterval) * time.Second
+
+var SyncFrequency = helper.GetOrDefaultEnvInt("SYNC_FREQUENCY", 10*60) // unit is second
+
+var BatchUpdateEnabled = false
+var BatchUpdateInterval = helper.GetOrDefaultEnvInt("BATCH_UPDATE_INTERVAL", 5)
+
+var RelayTimeout = helper.GetOrDefaultEnvInt("RELAY_TIMEOUT", 0) // unit is second
+
+var GeminiSafetySetting = helper.GetOrDefaultEnvString("GEMINI_SAFETY_SETTING", "BLOCK_NONE")
+
+var Theme = helper.GetOrDefaultEnvString("THEME", "default")
+var ValidThemes = map[string]bool{
+ "default": true,
+ "berry": true,
+}
+
+// All durations are in seconds.
+// Shouldn't be larger than RateLimitKeyExpirationDuration
+var (
+ GlobalApiRateLimitNum = helper.GetOrDefaultEnvInt("GLOBAL_API_RATE_LIMIT", 180)
+ GlobalApiRateLimitDuration int64 = 3 * 60
+
+ GlobalWebRateLimitNum = helper.GetOrDefaultEnvInt("GLOBAL_WEB_RATE_LIMIT", 60)
+ GlobalWebRateLimitDuration int64 = 3 * 60
+
+ UploadRateLimitNum = 10
+ UploadRateLimitDuration int64 = 60
+
+ DownloadRateLimitNum = 10
+ DownloadRateLimitDuration int64 = 60
+
+ CriticalRateLimitNum = 20
+ CriticalRateLimitDuration int64 = 20 * 60
+)
+
+var RateLimitKeyExpirationDuration = 20 * time.Minute
diff --git a/common/constants.go b/common/constants.go
index 70589041..325454d4 100644
--- a/common/constants.go
+++ b/common/constants.go
@@ -1,110 +1,9 @@
package common
-import (
- "os"
- "strconv"
- "sync"
- "time"
-
- "github.com/google/uuid"
-)
+import "time"
var StartTime = time.Now().Unix() // unit: second
var Version = "v0.0.0" // this hard coding will be replaced automatically when building, no need to manually change
-var SystemName = "One API"
-var ServerAddress = "http://localhost:3000"
-var Footer = ""
-var Logo = ""
-var TopUpLink = ""
-var ChatLink = ""
-var QuotaPerUnit = 500 * 1000.0 // $0.002 / 1K tokens
-var DisplayInCurrencyEnabled = true
-var DisplayTokenStatEnabled = true
-
-// Any options with "Secret", "Token" in its key won't be return by GetOptions
-
-var SessionSecret = uuid.New().String()
-
-var OptionMap map[string]string
-var OptionMapRWMutex sync.RWMutex
-
-var ItemsPerPage = 10
-var MaxRecentItems = 100
-
-var PasswordLoginEnabled = true
-var PasswordRegisterEnabled = true
-var EmailVerificationEnabled = false
-var GitHubOAuthEnabled = false
-var WeChatAuthEnabled = false
-var TurnstileCheckEnabled = false
-var RegisterEnabled = true
-
-var EmailDomainRestrictionEnabled = false
-var EmailDomainWhitelist = []string{
- "gmail.com",
- "163.com",
- "126.com",
- "qq.com",
- "outlook.com",
- "hotmail.com",
- "icloud.com",
- "yahoo.com",
- "foxmail.com",
-}
-
-var DebugEnabled = os.Getenv("DEBUG") == "true"
-var MemoryCacheEnabled = os.Getenv("MEMORY_CACHE_ENABLED") == "true"
-
-var LogConsumeEnabled = true
-
-var SMTPServer = ""
-var SMTPPort = 587
-var SMTPAccount = ""
-var SMTPFrom = ""
-var SMTPToken = ""
-
-var GitHubClientId = ""
-var GitHubClientSecret = ""
-
-var WeChatServerAddress = ""
-var WeChatServerToken = ""
-var WeChatAccountQRCodeImageURL = ""
-
-var TurnstileSiteKey = ""
-var TurnstileSecretKey = ""
-
-var QuotaForNewUser = 0
-var QuotaForInviter = 0
-var QuotaForInvitee = 0
-var ChannelDisableThreshold = 5.0
-var AutomaticDisableChannelEnabled = false
-var AutomaticEnableChannelEnabled = false
-var QuotaRemindThreshold = 1000
-var PreConsumedQuota = 500
-var ApproximateTokenEnabled = false
-var RetryTimes = 0
-
-var RootUserEmail = ""
-
-var IsMasterNode = os.Getenv("NODE_TYPE") != "slave"
-
-var requestInterval, _ = strconv.Atoi(os.Getenv("POLLING_INTERVAL"))
-var RequestInterval = time.Duration(requestInterval) * time.Second
-
-var SyncFrequency = GetOrDefault("SYNC_FREQUENCY", 10*60) // unit is second
-
-var BatchUpdateEnabled = false
-var BatchUpdateInterval = GetOrDefault("BATCH_UPDATE_INTERVAL", 5)
-
-var RelayTimeout = GetOrDefault("RELAY_TIMEOUT", 0) // unit is second
-
-var GeminiSafetySetting = GetOrDefaultString("GEMINI_SAFETY_SETTING", "BLOCK_NONE")
-
-var Theme = GetOrDefaultString("THEME", "default")
-
-const (
- RequestIdKey = "X-Oneapi-Request-Id"
-)
const (
RoleGuestUser = 0
@@ -113,34 +12,6 @@ const (
RoleRootUser = 100
)
-var (
- FileUploadPermission = RoleGuestUser
- FileDownloadPermission = RoleGuestUser
- ImageUploadPermission = RoleGuestUser
- ImageDownloadPermission = RoleGuestUser
-)
-
-// All duration's unit is seconds
-// Shouldn't larger then RateLimitKeyExpirationDuration
-var (
- GlobalApiRateLimitNum = GetOrDefault("GLOBAL_API_RATE_LIMIT", 180)
- GlobalApiRateLimitDuration int64 = 3 * 60
-
- GlobalWebRateLimitNum = GetOrDefault("GLOBAL_WEB_RATE_LIMIT", 60)
- GlobalWebRateLimitDuration int64 = 3 * 60
-
- UploadRateLimitNum = 10
- UploadRateLimitDuration int64 = 60
-
- DownloadRateLimitNum = 10
- DownloadRateLimitDuration int64 = 60
-
- CriticalRateLimitNum = 20
- CriticalRateLimitDuration int64 = 20 * 60
-)
-
-var RateLimitKeyExpirationDuration = 20 * time.Minute
-
const (
UserStatusEnabled = 1 // don't use 0, 0 is the default value!
UserStatusDisabled = 2 // also don't use 0
@@ -195,29 +66,29 @@ const (
)
var ChannelBaseURLs = []string{
- "", // 0
- "https://api.openai.com", // 1
- "https://oa.api2d.net", // 2
- "", // 3
- "https://api.closeai-proxy.xyz", // 4
- "https://api.openai-sb.com", // 5
- "https://api.openaimax.com", // 6
- "https://api.ohmygpt.com", // 7
- "", // 8
- "https://api.caipacity.com", // 9
- "https://api.aiproxy.io", // 10
- "", // 11
- "https://api.api2gpt.com", // 12
- "https://api.aigc2d.com", // 13
- "https://api.anthropic.com", // 14
- "https://aip.baidubce.com", // 15
- "https://open.bigmodel.cn", // 16
- "https://dashscope.aliyuncs.com", // 17
- "", // 18
- "https://ai.360.cn", // 19
- "https://openrouter.ai/api", // 20
- "https://api.aiproxy.io", // 21
- "https://fastgpt.run/api/openapi", // 22
- "https://hunyuan.cloud.tencent.com", //23
- "", //24
+ "", // 0
+ "https://api.openai.com", // 1
+ "https://oa.api2d.net", // 2
+ "", // 3
+ "https://api.closeai-proxy.xyz", // 4
+ "https://api.openai-sb.com", // 5
+ "https://api.openaimax.com", // 6
+ "https://api.ohmygpt.com", // 7
+ "", // 8
+ "https://api.caipacity.com", // 9
+ "https://api.aiproxy.io", // 10
+ "https://generativelanguage.googleapis.com", // 11
+ "https://api.api2gpt.com", // 12
+ "https://api.aigc2d.com", // 13
+ "https://api.anthropic.com", // 14
+ "https://aip.baidubce.com", // 15
+ "https://open.bigmodel.cn", // 16
+ "https://dashscope.aliyuncs.com", // 17
+ "", // 18
+ "https://ai.360.cn", // 19
+ "https://openrouter.ai/api", // 20
+ "https://api.aiproxy.io", // 21
+ "https://fastgpt.run/api/openapi", // 22
+ "https://hunyuan.cloud.tencent.com", // 23
+ "https://generativelanguage.googleapis.com", // 24
}
diff --git a/common/database.go b/common/database.go
index 76f2cd55..4c2c2717 100644
--- a/common/database.go
+++ b/common/database.go
@@ -1,7 +1,9 @@
package common
+import "one-api/common/helper"
+
var UsingSQLite = false
var UsingPostgreSQL = false
var SQLitePath = "one-api.db"
-var SQLiteBusyTimeout = GetOrDefault("SQLITE_BUSY_TIMEOUT", 3000)
+var SQLiteBusyTimeout = helper.GetOrDefaultEnvInt("SQLITE_BUSY_TIMEOUT", 3000)
diff --git a/common/email.go b/common/email.go
index b915f0f9..28719742 100644
--- a/common/email.go
+++ b/common/email.go
@@ -6,18 +6,19 @@ import (
"encoding/base64"
"fmt"
"net/smtp"
+ "one-api/common/config"
"strings"
"time"
)
func SendEmail(subject string, receiver string, content string) error {
- if SMTPFrom == "" { // for compatibility
- SMTPFrom = SMTPAccount
+ if config.SMTPFrom == "" { // for compatibility
+ config.SMTPFrom = config.SMTPAccount
}
encodedSubject := fmt.Sprintf("=?UTF-8?B?%s?=", base64.StdEncoding.EncodeToString([]byte(subject)))
// Extract domain from SMTPFrom
- parts := strings.Split(SMTPFrom, "@")
+ parts := strings.Split(config.SMTPFrom, "@")
var domain string
if len(parts) > 1 {
domain = parts[1]
@@ -36,21 +37,21 @@ func SendEmail(subject string, receiver string, content string) error {
"Message-ID: %s\r\n"+ // add Message-ID header to avoid being treated as spam, RFC 5322
"Date: %s\r\n"+
"Content-Type: text/html; charset=UTF-8\r\n\r\n%s\r\n",
- receiver, SystemName, SMTPFrom, encodedSubject, messageId, time.Now().Format(time.RFC1123Z), content))
- auth := smtp.PlainAuth("", SMTPAccount, SMTPToken, SMTPServer)
- addr := fmt.Sprintf("%s:%d", SMTPServer, SMTPPort)
+ receiver, config.SystemName, config.SMTPFrom, encodedSubject, messageId, time.Now().Format(time.RFC1123Z), content))
+ auth := smtp.PlainAuth("", config.SMTPAccount, config.SMTPToken, config.SMTPServer)
+ addr := fmt.Sprintf("%s:%d", config.SMTPServer, config.SMTPPort)
to := strings.Split(receiver, ";")
- if SMTPPort == 465 {
+ if config.SMTPPort == 465 {
tlsConfig := &tls.Config{
InsecureSkipVerify: true,
- ServerName: SMTPServer,
+ ServerName: config.SMTPServer,
}
- conn, err := tls.Dial("tcp", fmt.Sprintf("%s:%d", SMTPServer, SMTPPort), tlsConfig)
+ conn, err := tls.Dial("tcp", fmt.Sprintf("%s:%d", config.SMTPServer, config.SMTPPort), tlsConfig)
if err != nil {
return err
}
- client, err := smtp.NewClient(conn, SMTPServer)
+ client, err := smtp.NewClient(conn, config.SMTPServer)
if err != nil {
return err
}
@@ -58,7 +59,7 @@ func SendEmail(subject string, receiver string, content string) error {
if err = client.Auth(auth); err != nil {
return err
}
- if err = client.Mail(SMTPFrom); err != nil {
+ if err = client.Mail(config.SMTPFrom); err != nil {
return err
}
receiverEmails := strings.Split(receiver, ";")
@@ -80,7 +81,7 @@ func SendEmail(subject string, receiver string, content string) error {
return err
}
} else {
- err = smtp.SendMail(addr, auth, SMTPAccount, to, mail)
+ err = smtp.SendMail(addr, auth, config.SMTPAccount, to, mail)
}
return err
}
diff --git a/common/gin.go b/common/gin.go
index f5012688..bed2c2b1 100644
--- a/common/gin.go
+++ b/common/gin.go
@@ -31,3 +31,11 @@ func UnmarshalBodyReusable(c *gin.Context, v any) error {
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
return nil
}
+
+func SetEventStreamHeaders(c *gin.Context) {
+ c.Writer.Header().Set("Content-Type", "text/event-stream")
+ c.Writer.Header().Set("Cache-Control", "no-cache")
+ c.Writer.Header().Set("Connection", "keep-alive")
+ c.Writer.Header().Set("Transfer-Encoding", "chunked")
+ c.Writer.Header().Set("X-Accel-Buffering", "no")
+}
diff --git a/common/group-ratio.go b/common/group-ratio.go
index 1ec73c78..86e9e1f6 100644
--- a/common/group-ratio.go
+++ b/common/group-ratio.go
@@ -1,6 +1,9 @@
package common
-import "encoding/json"
+import (
+ "encoding/json"
+ "one-api/common/logger"
+)
var GroupRatio = map[string]float64{
"default": 1,
@@ -11,7 +14,7 @@ var GroupRatio = map[string]float64{
func GroupRatio2JSONString() string {
jsonBytes, err := json.Marshal(GroupRatio)
if err != nil {
- SysError("error marshalling model ratio: " + err.Error())
+ logger.SysError("error marshalling model ratio: " + err.Error())
}
return string(jsonBytes)
}
@@ -24,7 +27,7 @@ func UpdateGroupRatioByJSONString(jsonStr string) error {
func GetGroupRatio(name string) float64 {
ratio, ok := GroupRatio[name]
if !ok {
- SysError("group ratio not found: " + name)
+ logger.SysError("group ratio not found: " + name)
return 1
}
return ratio
diff --git a/common/helper/helper.go b/common/helper/helper.go
new file mode 100644
index 00000000..09f1df29
--- /dev/null
+++ b/common/helper/helper.go
@@ -0,0 +1,224 @@
+package helper
+
+import (
+ "fmt"
+ "github.com/google/uuid"
+ "html/template"
+ "log"
+ "math/rand"
+ "net"
+ "one-api/common/logger"
+ "os"
+ "os/exec"
+ "runtime"
+ "strconv"
+ "strings"
+ "time"
+)
+
+func OpenBrowser(url string) {
+ var err error
+
+ switch runtime.GOOS {
+ case "linux":
+ err = exec.Command("xdg-open", url).Start()
+ case "windows":
+ err = exec.Command("rundll32", "url.dll,FileProtocolHandler", url).Start()
+ case "darwin":
+ err = exec.Command("open", url).Start()
+ }
+ if err != nil {
+ log.Println(err)
+ }
+}
+
+func GetIp() (ip string) {
+ ips, err := net.InterfaceAddrs()
+ if err != nil {
+ log.Println(err)
+ return ip
+ }
+
+ for _, a := range ips {
+ if ipNet, ok := a.(*net.IPNet); ok && !ipNet.IP.IsLoopback() {
+ if ipNet.IP.To4() != nil {
+ ip = ipNet.IP.String()
+ if strings.HasPrefix(ip, "10") {
+ return
+ }
+ if strings.HasPrefix(ip, "172") {
+ return
+ }
+ if strings.HasPrefix(ip, "192.168") {
+ return
+ }
+ ip = ""
+ }
+ }
+ }
+ return
+}
+
+var sizeKB = 1024
+var sizeMB = sizeKB * 1024
+var sizeGB = sizeMB * 1024
+
+func Bytes2Size(num int64) string {
+ numStr := ""
+ unit := "B"
+ if num/int64(sizeGB) > 1 {
+ numStr = fmt.Sprintf("%.2f", float64(num)/float64(sizeGB))
+ unit = "GB"
+ } else if num/int64(sizeMB) > 1 {
+ numStr = fmt.Sprintf("%d", int(float64(num)/float64(sizeMB)))
+ unit = "MB"
+ } else if num/int64(sizeKB) > 1 {
+ numStr = fmt.Sprintf("%d", int(float64(num)/float64(sizeKB)))
+ unit = "KB"
+ } else {
+ numStr = fmt.Sprintf("%d", num)
+ }
+ return numStr + " " + unit
+}
+
+func Seconds2Time(num int) (time string) {
+ if num/31104000 > 0 {
+ time += strconv.Itoa(num/31104000) + " 年 "
+ num %= 31104000
+ }
+ if num/2592000 > 0 {
+ time += strconv.Itoa(num/2592000) + " 个月 "
+ num %= 2592000
+ }
+ if num/86400 > 0 {
+ time += strconv.Itoa(num/86400) + " 天 "
+ num %= 86400
+ }
+ if num/3600 > 0 {
+ time += strconv.Itoa(num/3600) + " 小时 "
+ num %= 3600
+ }
+ if num/60 > 0 {
+ time += strconv.Itoa(num/60) + " 分钟 "
+ num %= 60
+ }
+ time += strconv.Itoa(num) + " 秒"
+ return
+}
+
+func Interface2String(inter interface{}) string {
+ switch inter.(type) {
+ case string:
+ return inter.(string)
+ case int:
+ return fmt.Sprintf("%d", inter.(int))
+ case float64:
+ return fmt.Sprintf("%f", inter.(float64))
+ }
+ return "Not Implemented"
+}
+
+func UnescapeHTML(x string) interface{} {
+ return template.HTML(x)
+}
+
+func IntMax(a int, b int) int {
+ if a >= b {
+ return a
+ } else {
+ return b
+ }
+}
+
+func GetUUID() string {
+ code := uuid.New().String()
+ code = strings.Replace(code, "-", "", -1)
+ return code
+}
+
+const keyChars = "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
+
+func init() {
+ rand.Seed(time.Now().UnixNano())
+}
+
+func GenerateKey() string {
+ rand.Seed(time.Now().UnixNano())
+ key := make([]byte, 48)
+ for i := 0; i < 16; i++ {
+ key[i] = keyChars[rand.Intn(len(keyChars))]
+ }
+ uuid_ := GetUUID()
+ for i := 0; i < 32; i++ {
+ c := uuid_[i]
+ if i%2 == 0 && c >= 'a' && c <= 'z' {
+ c = c - 'a' + 'A'
+ }
+ key[i+16] = c
+ }
+ return string(key)
+}
+
+func GetRandomString(length int) string {
+ rand.Seed(time.Now().UnixNano())
+ key := make([]byte, length)
+ for i := 0; i < length; i++ {
+ key[i] = keyChars[rand.Intn(len(keyChars))]
+ }
+ return string(key)
+}
+
+func GetTimestamp() int64 {
+ return time.Now().Unix()
+}
+
+func GetTimeString() string {
+ now := time.Now()
+ return fmt.Sprintf("%s%d", now.Format("20060102150405"), now.UnixNano()%1e9)
+}
+
+func Max(a int, b int) int {
+ if a >= b {
+ return a
+ } else {
+ return b
+ }
+}
+
+func GetOrDefaultEnvInt(env string, defaultValue int) int {
+ if env == "" || os.Getenv(env) == "" {
+ return defaultValue
+ }
+ num, err := strconv.Atoi(os.Getenv(env))
+ if err != nil {
+ logger.SysError(fmt.Sprintf("failed to parse %s: %s, using default value: %d", env, err.Error(), defaultValue))
+ return defaultValue
+ }
+ return num
+}
+
+func GetOrDefaultEnvString(env string, defaultValue string) string {
+ if env == "" || os.Getenv(env) == "" {
+ return defaultValue
+ }
+ return os.Getenv(env)
+}
+
+func AssignOrDefault(value string, defaultValue string) string {
+ if len(value) != 0 {
+ return value
+ }
+ return defaultValue
+}
+
+func MessageWithRequestId(message string, id string) string {
+ return fmt.Sprintf("%s (request id: %s)", message, id)
+}
+
+func String2Int(str string) int {
+ num, err := strconv.Atoi(str)
+ if err != nil {
+ return 0
+ }
+ return num
+}
diff --git a/common/init.go b/common/init.go
index 12df5f51..abc71108 100644
--- a/common/init.go
+++ b/common/init.go
@@ -4,6 +4,8 @@ import (
"flag"
"fmt"
"log"
+ "one-api/common/config"
+ "one-api/common/logger"
"os"
"path/filepath"
)
@@ -37,9 +39,9 @@ func init() {
if os.Getenv("SESSION_SECRET") != "" {
if os.Getenv("SESSION_SECRET") == "random_string" {
- SysError("SESSION_SECRET is set to an example value, please change it to a random string.")
+ logger.SysError("SESSION_SECRET is set to an example value, please change it to a random string.")
} else {
- SessionSecret = os.Getenv("SESSION_SECRET")
+ config.SessionSecret = os.Getenv("SESSION_SECRET")
}
}
if os.Getenv("SQLITE_PATH") != "" {
@@ -57,5 +59,6 @@ func init() {
log.Fatal(err)
}
}
+ logger.LogDir = *LogDir
}
}
diff --git a/common/logger/constants.go b/common/logger/constants.go
new file mode 100644
index 00000000..78d32062
--- /dev/null
+++ b/common/logger/constants.go
@@ -0,0 +1,7 @@
+package logger
+
+const (
+ RequestIdKey = "X-Oneapi-Request-Id"
+)
+
+var LogDir string
diff --git a/common/logger.go b/common/logger/logger.go
similarity index 76%
rename from common/logger.go
rename to common/logger/logger.go
index 61627217..b89dbdb7 100644
--- a/common/logger.go
+++ b/common/logger/logger.go
@@ -1,4 +1,4 @@
-package common
+package logger
import (
"context"
@@ -25,7 +25,7 @@ var setupLogLock sync.Mutex
var setupLogWorking bool
func SetupLogger() {
- if *LogDir != "" {
+ if LogDir != "" {
ok := setupLogLock.TryLock()
if !ok {
log.Println("setup log is already working")
@@ -35,7 +35,7 @@ func SetupLogger() {
setupLogLock.Unlock()
setupLogWorking = false
}()
- logPath := filepath.Join(*LogDir, fmt.Sprintf("oneapi-%s.log", time.Now().Format("20060102")))
+ logPath := filepath.Join(LogDir, fmt.Sprintf("oneapi-%s.log", time.Now().Format("20060102")))
fd, err := os.OpenFile(logPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
if err != nil {
log.Fatal("failed to open log file")
@@ -55,18 +55,30 @@ func SysError(s string) {
_, _ = fmt.Fprintf(gin.DefaultErrorWriter, "[SYS] %v | %s \n", t.Format("2006/01/02 - 15:04:05"), s)
}
-func LogInfo(ctx context.Context, msg string) {
+func Info(ctx context.Context, msg string) {
logHelper(ctx, loggerINFO, msg)
}
-func LogWarn(ctx context.Context, msg string) {
+func Warn(ctx context.Context, msg string) {
logHelper(ctx, loggerWarn, msg)
}
-func LogError(ctx context.Context, msg string) {
+func Error(ctx context.Context, msg string) {
logHelper(ctx, loggerError, msg)
}
+func Infof(ctx context.Context, format string, a ...any) {
+	Info(ctx, fmt.Sprintf(format, a...))
+}
+
+func Warnf(ctx context.Context, format string, a ...any) {
+	Warn(ctx, fmt.Sprintf(format, a...))
+}
+
+func Errorf(ctx context.Context, format string, a ...any) {
+	Error(ctx, fmt.Sprintf(format, a...))
+}
+
func logHelper(ctx context.Context, level string, msg string) {
writer := gin.DefaultErrorWriter
if level == loggerINFO {
@@ -90,11 +102,3 @@ func FatalLog(v ...any) {
_, _ = fmt.Fprintf(gin.DefaultErrorWriter, "[FATAL] %v | %v \n", t.Format("2006/01/02 - 15:04:05"), v)
os.Exit(1)
}
-
-func LogQuota(quota int) string {
- if DisplayInCurrencyEnabled {
- return fmt.Sprintf("$%.6f 额度", float64(quota)/QuotaPerUnit)
- } else {
- return fmt.Sprintf("%d 点额度", quota)
- }
-}
diff --git a/common/model-ratio.go b/common/model-ratio.go
index 97cb060d..9f31e0d7 100644
--- a/common/model-ratio.go
+++ b/common/model-ratio.go
@@ -2,6 +2,7 @@ package common
import (
"encoding/json"
+ "one-api/common/logger"
"strings"
"time"
)
@@ -107,7 +108,7 @@ var ModelRatio = map[string]float64{
func ModelRatio2JSONString() string {
jsonBytes, err := json.Marshal(ModelRatio)
if err != nil {
- SysError("error marshalling model ratio: " + err.Error())
+ logger.SysError("error marshalling model ratio: " + err.Error())
}
return string(jsonBytes)
}
@@ -123,7 +124,7 @@ func GetModelRatio(name string) float64 {
}
ratio, ok := ModelRatio[name]
if !ok {
- SysError("model ratio not found: " + name)
+ logger.SysError("model ratio not found: " + name)
return 30
}
return ratio
diff --git a/common/redis.go b/common/redis.go
index 12c477b8..ed3fcd9d 100644
--- a/common/redis.go
+++ b/common/redis.go
@@ -3,6 +3,7 @@ package common
import (
"context"
"github.com/go-redis/redis/v8"
+ "one-api/common/logger"
"os"
"time"
)
@@ -14,18 +15,18 @@ var RedisEnabled = true
func InitRedisClient() (err error) {
if os.Getenv("REDIS_CONN_STRING") == "" {
RedisEnabled = false
- SysLog("REDIS_CONN_STRING not set, Redis is not enabled")
+ logger.SysLog("REDIS_CONN_STRING not set, Redis is not enabled")
return nil
}
if os.Getenv("SYNC_FREQUENCY") == "" {
RedisEnabled = false
- SysLog("SYNC_FREQUENCY not set, Redis is disabled")
+ logger.SysLog("SYNC_FREQUENCY not set, Redis is disabled")
return nil
}
- SysLog("Redis is enabled")
+ logger.SysLog("Redis is enabled")
opt, err := redis.ParseURL(os.Getenv("REDIS_CONN_STRING"))
if err != nil {
- FatalLog("failed to parse Redis connection string: " + err.Error())
+ logger.FatalLog("failed to parse Redis connection string: " + err.Error())
}
RDB = redis.NewClient(opt)
@@ -34,7 +35,7 @@ func InitRedisClient() (err error) {
_, err = RDB.Ping(ctx).Result()
if err != nil {
- FatalLog("Redis ping test failed: " + err.Error())
+ logger.FatalLog("Redis ping test failed: " + err.Error())
}
return err
}
@@ -42,7 +43,7 @@ func InitRedisClient() (err error) {
func ParseRedisOption() *redis.Options {
opt, err := redis.ParseURL(os.Getenv("REDIS_CONN_STRING"))
if err != nil {
- FatalLog("failed to parse Redis connection string: " + err.Error())
+ logger.FatalLog("failed to parse Redis connection string: " + err.Error())
}
return opt
}
diff --git a/common/utils.go b/common/utils.go
index 9a7038e2..41de9367 100644
--- a/common/utils.go
+++ b/common/utils.go
@@ -2,215 +2,13 @@ package common
import (
"fmt"
- "github.com/google/uuid"
- "html/template"
- "log"
- "math/rand"
- "net"
- "os"
- "os/exec"
- "runtime"
- "strconv"
- "strings"
- "time"
+ "one-api/common/config"
)
-func OpenBrowser(url string) {
- var err error
-
- switch runtime.GOOS {
- case "linux":
- err = exec.Command("xdg-open", url).Start()
- case "windows":
- err = exec.Command("rundll32", "url.dll,FileProtocolHandler", url).Start()
- case "darwin":
- err = exec.Command("open", url).Start()
- }
- if err != nil {
- log.Println(err)
- }
-}
-
-func GetIp() (ip string) {
- ips, err := net.InterfaceAddrs()
- if err != nil {
- log.Println(err)
- return ip
- }
-
- for _, a := range ips {
- if ipNet, ok := a.(*net.IPNet); ok && !ipNet.IP.IsLoopback() {
- if ipNet.IP.To4() != nil {
- ip = ipNet.IP.String()
- if strings.HasPrefix(ip, "10") {
- return
- }
- if strings.HasPrefix(ip, "172") {
- return
- }
- if strings.HasPrefix(ip, "192.168") {
- return
- }
- ip = ""
- }
- }
- }
- return
-}
-
-var sizeKB = 1024
-var sizeMB = sizeKB * 1024
-var sizeGB = sizeMB * 1024
-
-func Bytes2Size(num int64) string {
- numStr := ""
- unit := "B"
- if num/int64(sizeGB) > 1 {
- numStr = fmt.Sprintf("%.2f", float64(num)/float64(sizeGB))
- unit = "GB"
- } else if num/int64(sizeMB) > 1 {
- numStr = fmt.Sprintf("%d", int(float64(num)/float64(sizeMB)))
- unit = "MB"
- } else if num/int64(sizeKB) > 1 {
- numStr = fmt.Sprintf("%d", int(float64(num)/float64(sizeKB)))
- unit = "KB"
+func LogQuota(quota int) string {
+ if config.DisplayInCurrencyEnabled {
+ return fmt.Sprintf("$%.6f 额度", float64(quota)/config.QuotaPerUnit)
} else {
- numStr = fmt.Sprintf("%d", num)
- }
- return numStr + " " + unit
-}
-
-func Seconds2Time(num int) (time string) {
- if num/31104000 > 0 {
- time += strconv.Itoa(num/31104000) + " 年 "
- num %= 31104000
- }
- if num/2592000 > 0 {
- time += strconv.Itoa(num/2592000) + " 个月 "
- num %= 2592000
- }
- if num/86400 > 0 {
- time += strconv.Itoa(num/86400) + " 天 "
- num %= 86400
- }
- if num/3600 > 0 {
- time += strconv.Itoa(num/3600) + " 小时 "
- num %= 3600
- }
- if num/60 > 0 {
- time += strconv.Itoa(num/60) + " 分钟 "
- num %= 60
- }
- time += strconv.Itoa(num) + " 秒"
- return
-}
-
-func Interface2String(inter interface{}) string {
- switch inter.(type) {
- case string:
- return inter.(string)
- case int:
- return fmt.Sprintf("%d", inter.(int))
- case float64:
- return fmt.Sprintf("%f", inter.(float64))
- }
- return "Not Implemented"
-}
-
-func UnescapeHTML(x string) interface{} {
- return template.HTML(x)
-}
-
-func IntMax(a int, b int) int {
- if a >= b {
- return a
- } else {
- return b
+ return fmt.Sprintf("%d 点额度", quota)
}
}
-
-func GetUUID() string {
- code := uuid.New().String()
- code = strings.Replace(code, "-", "", -1)
- return code
-}
-
-const keyChars = "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
-
-func init() {
- rand.Seed(time.Now().UnixNano())
-}
-
-func GenerateKey() string {
- rand.Seed(time.Now().UnixNano())
- key := make([]byte, 48)
- for i := 0; i < 16; i++ {
- key[i] = keyChars[rand.Intn(len(keyChars))]
- }
- uuid_ := GetUUID()
- for i := 0; i < 32; i++ {
- c := uuid_[i]
- if i%2 == 0 && c >= 'a' && c <= 'z' {
- c = c - 'a' + 'A'
- }
- key[i+16] = c
- }
- return string(key)
-}
-
-func GetRandomString(length int) string {
- rand.Seed(time.Now().UnixNano())
- key := make([]byte, length)
- for i := 0; i < length; i++ {
- key[i] = keyChars[rand.Intn(len(keyChars))]
- }
- return string(key)
-}
-
-func GetTimestamp() int64 {
- return time.Now().Unix()
-}
-
-func GetTimeString() string {
- now := time.Now()
- return fmt.Sprintf("%s%d", now.Format("20060102150405"), now.UnixNano()%1e9)
-}
-
-func Max(a int, b int) int {
- if a >= b {
- return a
- } else {
- return b
- }
-}
-
-func GetOrDefault(env string, defaultValue int) int {
- if env == "" || os.Getenv(env) == "" {
- return defaultValue
- }
- num, err := strconv.Atoi(os.Getenv(env))
- if err != nil {
- SysError(fmt.Sprintf("failed to parse %s: %s, using default value: %d", env, err.Error(), defaultValue))
- return defaultValue
- }
- return num
-}
-
-func GetOrDefaultString(env string, defaultValue string) string {
- if env == "" || os.Getenv(env) == "" {
- return defaultValue
- }
- return os.Getenv(env)
-}
-
-func MessageWithRequestId(message string, id string) string {
- return fmt.Sprintf("%s (request id: %s)", message, id)
-}
-
-func String2Int(str string) int {
- num, err := strconv.Atoi(str)
- if err != nil {
- return 0
- }
- return num
-}
diff --git a/controller/billing.go b/controller/billing.go
index 42e86aea..1003bf10 100644
--- a/controller/billing.go
+++ b/controller/billing.go
@@ -2,8 +2,9 @@ package controller
import (
"github.com/gin-gonic/gin"
- "one-api/common"
+ "one-api/common/config"
"one-api/model"
+ "one-api/relay/channel/openai"
)
func GetSubscription(c *gin.Context) {
@@ -12,7 +13,7 @@ func GetSubscription(c *gin.Context) {
var err error
var token *model.Token
var expiredTime int64
- if common.DisplayTokenStatEnabled {
+ if config.DisplayTokenStatEnabled {
tokenId := c.GetInt("token_id")
token, err = model.GetTokenById(tokenId)
expiredTime = token.ExpiredTime
@@ -27,19 +28,19 @@ func GetSubscription(c *gin.Context) {
expiredTime = 0
}
if err != nil {
- openAIError := OpenAIError{
+ Error := openai.Error{
Message: err.Error(),
Type: "upstream_error",
}
c.JSON(200, gin.H{
- "error": openAIError,
+ "error": Error,
})
return
}
quota := remainQuota + usedQuota
amount := float64(quota)
- if common.DisplayInCurrencyEnabled {
- amount /= common.QuotaPerUnit
+ if config.DisplayInCurrencyEnabled {
+ amount /= config.QuotaPerUnit
}
if token != nil && token.UnlimitedQuota {
amount = 100000000
@@ -60,7 +61,7 @@ func GetUsage(c *gin.Context) {
var quota int
var err error
var token *model.Token
- if common.DisplayTokenStatEnabled {
+ if config.DisplayTokenStatEnabled {
tokenId := c.GetInt("token_id")
token, err = model.GetTokenById(tokenId)
quota = token.UsedQuota
@@ -69,18 +70,18 @@ func GetUsage(c *gin.Context) {
quota, err = model.GetUserUsedQuota(userId)
}
if err != nil {
- openAIError := OpenAIError{
+ Error := openai.Error{
Message: err.Error(),
Type: "one_api_error",
}
c.JSON(200, gin.H{
- "error": openAIError,
+ "error": Error,
})
return
}
amount := float64(quota)
- if common.DisplayInCurrencyEnabled {
- amount /= common.QuotaPerUnit
+ if config.DisplayInCurrencyEnabled {
+ amount /= config.QuotaPerUnit
}
usage := OpenAIUsageResponse{
Object: "list",
diff --git a/controller/channel-billing.go b/controller/channel-billing.go
index 6ddad7ea..49c76760 100644
--- a/controller/channel-billing.go
+++ b/controller/channel-billing.go
@@ -7,7 +7,10 @@ import (
"io"
"net/http"
"one-api/common"
+ "one-api/common/config"
+ "one-api/common/logger"
"one-api/model"
+ "one-api/relay/util"
"strconv"
"time"
@@ -92,7 +95,7 @@ func GetResponseBody(method, url string, channel *model.Channel, headers http.He
for k := range headers {
req.Header.Add(k, headers.Get(k))
}
- res, err := httpClient.Do(req)
+ res, err := util.HTTPClient.Do(req)
if err != nil {
return nil, err
}
@@ -313,7 +316,7 @@ func updateAllChannelsBalance() error {
disableChannel(channel.Id, channel.Name, "余额不足")
}
}
- time.Sleep(common.RequestInterval)
+ time.Sleep(config.RequestInterval)
}
return nil
}
@@ -338,8 +341,8 @@ func UpdateAllChannelsBalance(c *gin.Context) {
func AutomaticallyUpdateChannels(frequency int) {
for {
time.Sleep(time.Duration(frequency) * time.Minute)
- common.SysLog("updating all channels")
+ logger.SysLog("updating all channels")
_ = updateAllChannelsBalance()
- common.SysLog("channels update done")
+ logger.SysLog("channels update done")
}
}
diff --git a/controller/channel-test.go b/controller/channel-test.go
index 3aaa4897..88d6e3f2 100644
--- a/controller/channel-test.go
+++ b/controller/channel-test.go
@@ -8,7 +8,11 @@ import (
"io"
"net/http"
"one-api/common"
+ "one-api/common/config"
+ "one-api/common/logger"
"one-api/model"
+ "one-api/relay/channel/openai"
+ "one-api/relay/util"
"strconv"
"sync"
"time"
@@ -16,7 +20,7 @@ import (
"github.com/gin-gonic/gin"
)
-func testChannel(channel *model.Channel, request ChatRequest) (err error, openaiErr *OpenAIError) {
+func testChannel(channel *model.Channel, request openai.ChatRequest) (err error, openaiErr *openai.Error) {
switch channel.Type {
case common.ChannelTypePaLM:
fallthrough
@@ -46,13 +50,13 @@ func testChannel(channel *model.Channel, request ChatRequest) (err error, openai
}
requestURL := common.ChannelBaseURLs[channel.Type]
if channel.Type == common.ChannelTypeAzure {
- requestURL = getFullRequestURL(channel.GetBaseURL(), fmt.Sprintf("/openai/deployments/%s/chat/completions?api-version=2023-03-15-preview", request.Model), channel.Type)
+ requestURL = util.GetFullRequestURL(channel.GetBaseURL(), fmt.Sprintf("/openai/deployments/%s/chat/completions?api-version=2023-03-15-preview", request.Model), channel.Type)
} else {
if baseURL := channel.GetBaseURL(); len(baseURL) > 0 {
requestURL = baseURL
}
- requestURL = getFullRequestURL(requestURL, "/v1/chat/completions", channel.Type)
+ requestURL = util.GetFullRequestURL(requestURL, "/v1/chat/completions", channel.Type)
}
jsonData, err := json.Marshal(request)
if err != nil {
@@ -68,12 +72,12 @@ func testChannel(channel *model.Channel, request ChatRequest) (err error, openai
req.Header.Set("Authorization", "Bearer "+channel.Key)
}
req.Header.Set("Content-Type", "application/json")
- resp, err := httpClient.Do(req)
+ resp, err := util.HTTPClient.Do(req)
if err != nil {
return err, nil
}
defer resp.Body.Close()
- var response TextResponse
+ var response openai.SlimTextResponse
body, err := io.ReadAll(resp.Body)
if err != nil {
return err, nil
@@ -91,12 +95,12 @@ func testChannel(channel *model.Channel, request ChatRequest) (err error, openai
return nil, nil
}
-func buildTestRequest() *ChatRequest {
- testRequest := &ChatRequest{
+func buildTestRequest() *openai.ChatRequest {
+ testRequest := &openai.ChatRequest{
Model: "", // this will be set later
MaxTokens: 1,
}
- testMessage := Message{
+ testMessage := openai.Message{
Role: "user",
Content: "hi",
}
@@ -148,12 +152,12 @@ var testAllChannelsLock sync.Mutex
var testAllChannelsRunning bool = false
func notifyRootUser(subject string, content string) {
- if common.RootUserEmail == "" {
- common.RootUserEmail = model.GetRootUserEmail()
+ if config.RootUserEmail == "" {
+ config.RootUserEmail = model.GetRootUserEmail()
}
- err := common.SendEmail(subject, common.RootUserEmail, content)
+ err := common.SendEmail(subject, config.RootUserEmail, content)
if err != nil {
- common.SysError(fmt.Sprintf("failed to send email: %s", err.Error()))
+ logger.SysError(fmt.Sprintf("failed to send email: %s", err.Error()))
}
}
@@ -174,8 +178,8 @@ func enableChannel(channelId int, channelName string) {
}
func testAllChannels(notify bool) error {
- if common.RootUserEmail == "" {
- common.RootUserEmail = model.GetRootUserEmail()
+ if config.RootUserEmail == "" {
+ config.RootUserEmail = model.GetRootUserEmail()
}
testAllChannelsLock.Lock()
if testAllChannelsRunning {
@@ -189,7 +193,7 @@ func testAllChannels(notify bool) error {
return err
}
testRequest := buildTestRequest()
- var disableThreshold = int64(common.ChannelDisableThreshold * 1000)
+ var disableThreshold = int64(config.ChannelDisableThreshold * 1000)
if disableThreshold == 0 {
disableThreshold = 10000000 // a impossible value
}
@@ -204,22 +208,22 @@ func testAllChannels(notify bool) error {
err = errors.New(fmt.Sprintf("响应时间 %.2fs 超过阈值 %.2fs", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0))
disableChannel(channel.Id, channel.Name, err.Error())
}
- if isChannelEnabled && shouldDisableChannel(openaiErr, -1) {
+ if isChannelEnabled && util.ShouldDisableChannel(openaiErr, -1) {
disableChannel(channel.Id, channel.Name, err.Error())
}
- if !isChannelEnabled && shouldEnableChannel(err, openaiErr) {
+ if !isChannelEnabled && util.ShouldEnableChannel(err, openaiErr) {
enableChannel(channel.Id, channel.Name)
}
channel.UpdateResponseTime(milliseconds)
- time.Sleep(common.RequestInterval)
+ time.Sleep(config.RequestInterval)
}
testAllChannelsLock.Lock()
testAllChannelsRunning = false
testAllChannelsLock.Unlock()
if notify {
- err := common.SendEmail("通道测试完成", common.RootUserEmail, "通道测试完成,如果没有收到禁用通知,说明所有通道都正常")
+ err := common.SendEmail("通道测试完成", config.RootUserEmail, "通道测试完成,如果没有收到禁用通知,说明所有通道都正常")
if err != nil {
- common.SysError(fmt.Sprintf("failed to send email: %s", err.Error()))
+ logger.SysError(fmt.Sprintf("failed to send email: %s", err.Error()))
}
}
}()
@@ -245,8 +249,8 @@ func TestAllChannels(c *gin.Context) {
func AutomaticallyTestChannels(frequency int) {
for {
time.Sleep(time.Duration(frequency) * time.Minute)
- common.SysLog("testing all channels")
+ logger.SysLog("testing all channels")
_ = testAllChannels(false)
- common.SysLog("channel test finished")
+ logger.SysLog("channel test finished")
}
}
diff --git a/controller/channel.go b/controller/channel.go
index 904abc23..6d368066 100644
--- a/controller/channel.go
+++ b/controller/channel.go
@@ -3,7 +3,8 @@ package controller
import (
"github.com/gin-gonic/gin"
"net/http"
- "one-api/common"
+ "one-api/common/config"
+ "one-api/common/helper"
"one-api/model"
"strconv"
"strings"
@@ -14,7 +15,7 @@ func GetAllChannels(c *gin.Context) {
if p < 0 {
p = 0
}
- channels, err := model.GetAllChannels(p*common.ItemsPerPage, common.ItemsPerPage, false)
+ channels, err := model.GetAllChannels(p*config.ItemsPerPage, config.ItemsPerPage, false)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
@@ -83,7 +84,7 @@ func AddChannel(c *gin.Context) {
})
return
}
- channel.CreatedTime = common.GetTimestamp()
+ channel.CreatedTime = helper.GetTimestamp()
keys := strings.Split(channel.Key, "\n")
channels := make([]model.Channel, 0, len(keys))
for _, key := range keys {
diff --git a/controller/github.go b/controller/github.go
index ee995379..f6049ddb 100644
--- a/controller/github.go
+++ b/controller/github.go
@@ -9,6 +9,9 @@ import (
"github.com/gin-gonic/gin"
"net/http"
"one-api/common"
+ "one-api/common/config"
+ "one-api/common/helper"
+ "one-api/common/logger"
"one-api/model"
"strconv"
"time"
@@ -30,7 +33,7 @@ func getGitHubUserInfoByCode(code string) (*GitHubUser, error) {
if code == "" {
return nil, errors.New("无效的参数")
}
- values := map[string]string{"client_id": common.GitHubClientId, "client_secret": common.GitHubClientSecret, "code": code}
+ values := map[string]string{"client_id": config.GitHubClientId, "client_secret": config.GitHubClientSecret, "code": code}
jsonData, err := json.Marshal(values)
if err != nil {
return nil, err
@@ -46,7 +49,7 @@ func getGitHubUserInfoByCode(code string) (*GitHubUser, error) {
}
res, err := client.Do(req)
if err != nil {
- common.SysLog(err.Error())
+ logger.SysLog(err.Error())
return nil, errors.New("无法连接至 GitHub 服务器,请稍后重试!")
}
defer res.Body.Close()
@@ -62,7 +65,7 @@ func getGitHubUserInfoByCode(code string) (*GitHubUser, error) {
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", oAuthResponse.AccessToken))
res2, err := client.Do(req)
if err != nil {
- common.SysLog(err.Error())
+ logger.SysLog(err.Error())
return nil, errors.New("无法连接至 GitHub 服务器,请稍后重试!")
}
defer res2.Body.Close()
@@ -93,7 +96,7 @@ func GitHubOAuth(c *gin.Context) {
return
}
- if !common.GitHubOAuthEnabled {
+ if !config.GitHubOAuthEnabled {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "管理员未开启通过 GitHub 登录以及注册",
@@ -122,7 +125,7 @@ func GitHubOAuth(c *gin.Context) {
return
}
} else {
- if common.RegisterEnabled {
+ if config.RegisterEnabled {
user.Username = "github_" + strconv.Itoa(model.GetMaxUserId()+1)
if githubUser.Name != "" {
user.DisplayName = githubUser.Name
@@ -160,7 +163,7 @@ func GitHubOAuth(c *gin.Context) {
}
func GitHubBind(c *gin.Context) {
- if !common.GitHubOAuthEnabled {
+ if !config.GitHubOAuthEnabled {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "管理员未开启通过 GitHub 登录以及注册",
@@ -216,7 +219,7 @@ func GitHubBind(c *gin.Context) {
func GenerateOAuthCode(c *gin.Context) {
session := sessions.Default(c)
- state := common.GetRandomString(12)
+ state := helper.GetRandomString(12)
session.Set("oauth_state", state)
err := session.Save()
if err != nil {
diff --git a/controller/log.go b/controller/log.go
index b65867fe..418c6859 100644
--- a/controller/log.go
+++ b/controller/log.go
@@ -3,7 +3,7 @@ package controller
import (
"github.com/gin-gonic/gin"
"net/http"
- "one-api/common"
+ "one-api/common/config"
"one-api/model"
"strconv"
)
@@ -20,7 +20,7 @@ func GetAllLogs(c *gin.Context) {
tokenName := c.Query("token_name")
modelName := c.Query("model_name")
channel, _ := strconv.Atoi(c.Query("channel"))
- logs, err := model.GetAllLogs(logType, startTimestamp, endTimestamp, modelName, username, tokenName, p*common.ItemsPerPage, common.ItemsPerPage, channel)
+ logs, err := model.GetAllLogs(logType, startTimestamp, endTimestamp, modelName, username, tokenName, p*config.ItemsPerPage, config.ItemsPerPage, channel)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
@@ -47,7 +47,7 @@ func GetUserLogs(c *gin.Context) {
endTimestamp, _ := strconv.ParseInt(c.Query("end_timestamp"), 10, 64)
tokenName := c.Query("token_name")
modelName := c.Query("model_name")
- logs, err := model.GetUserLogs(userId, logType, startTimestamp, endTimestamp, modelName, tokenName, p*common.ItemsPerPage, common.ItemsPerPage)
+ logs, err := model.GetUserLogs(userId, logType, startTimestamp, endTimestamp, modelName, tokenName, p*config.ItemsPerPage, config.ItemsPerPage)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
diff --git a/controller/misc.go b/controller/misc.go
index 2bcbb41f..df7c0728 100644
--- a/controller/misc.go
+++ b/controller/misc.go
@@ -5,6 +5,7 @@ import (
"fmt"
"net/http"
"one-api/common"
+ "one-api/common/config"
"one-api/model"
"strings"
@@ -18,55 +19,55 @@ func GetStatus(c *gin.Context) {
"data": gin.H{
"version": common.Version,
"start_time": common.StartTime,
- "email_verification": common.EmailVerificationEnabled,
- "github_oauth": common.GitHubOAuthEnabled,
- "github_client_id": common.GitHubClientId,
- "system_name": common.SystemName,
- "logo": common.Logo,
- "footer_html": common.Footer,
- "wechat_qrcode": common.WeChatAccountQRCodeImageURL,
- "wechat_login": common.WeChatAuthEnabled,
- "server_address": common.ServerAddress,
- "turnstile_check": common.TurnstileCheckEnabled,
- "turnstile_site_key": common.TurnstileSiteKey,
- "top_up_link": common.TopUpLink,
- "chat_link": common.ChatLink,
- "quota_per_unit": common.QuotaPerUnit,
- "display_in_currency": common.DisplayInCurrencyEnabled,
+ "email_verification": config.EmailVerificationEnabled,
+ "github_oauth": config.GitHubOAuthEnabled,
+ "github_client_id": config.GitHubClientId,
+ "system_name": config.SystemName,
+ "logo": config.Logo,
+ "footer_html": config.Footer,
+ "wechat_qrcode": config.WeChatAccountQRCodeImageURL,
+ "wechat_login": config.WeChatAuthEnabled,
+ "server_address": config.ServerAddress,
+ "turnstile_check": config.TurnstileCheckEnabled,
+ "turnstile_site_key": config.TurnstileSiteKey,
+ "top_up_link": config.TopUpLink,
+ "chat_link": config.ChatLink,
+ "quota_per_unit": config.QuotaPerUnit,
+ "display_in_currency": config.DisplayInCurrencyEnabled,
},
})
return
}
func GetNotice(c *gin.Context) {
- common.OptionMapRWMutex.RLock()
- defer common.OptionMapRWMutex.RUnlock()
+ config.OptionMapRWMutex.RLock()
+ defer config.OptionMapRWMutex.RUnlock()
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
- "data": common.OptionMap["Notice"],
+ "data": config.OptionMap["Notice"],
})
return
}
func GetAbout(c *gin.Context) {
- common.OptionMapRWMutex.RLock()
- defer common.OptionMapRWMutex.RUnlock()
+ config.OptionMapRWMutex.RLock()
+ defer config.OptionMapRWMutex.RUnlock()
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
- "data": common.OptionMap["About"],
+ "data": config.OptionMap["About"],
})
return
}
func GetHomePageContent(c *gin.Context) {
- common.OptionMapRWMutex.RLock()
- defer common.OptionMapRWMutex.RUnlock()
+ config.OptionMapRWMutex.RLock()
+ defer config.OptionMapRWMutex.RUnlock()
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
- "data": common.OptionMap["HomePageContent"],
+ "data": config.OptionMap["HomePageContent"],
})
return
}
@@ -80,9 +81,9 @@ func SendEmailVerification(c *gin.Context) {
})
return
}
- if common.EmailDomainRestrictionEnabled {
+ if config.EmailDomainRestrictionEnabled {
allowed := false
- for _, domain := range common.EmailDomainWhitelist {
+ for _, domain := range config.EmailDomainWhitelist {
if strings.HasSuffix(email, "@"+domain) {
allowed = true
break
@@ -105,10 +106,10 @@ func SendEmailVerification(c *gin.Context) {
}
code := common.GenerateVerificationCode(6)
common.RegisterVerificationCodeWithKey(email, code, common.EmailVerificationPurpose)
- subject := fmt.Sprintf("%s邮箱验证邮件", common.SystemName)
+ subject := fmt.Sprintf("%s邮箱验证邮件", config.SystemName)
content := fmt.Sprintf("
您好,你正在进行%s邮箱验证。
"+
"您的验证码为: %s
"+
- "验证码 %d 分钟内有效,如果不是本人操作,请忽略。
", common.SystemName, code, common.VerificationValidMinutes)
+ "验证码 %d 分钟内有效,如果不是本人操作,请忽略。
", config.SystemName, code, common.VerificationValidMinutes)
err := common.SendEmail(subject, email, content)
if err != nil {
c.JSON(http.StatusOK, gin.H{
@@ -142,12 +143,12 @@ func SendPasswordResetEmail(c *gin.Context) {
}
code := common.GenerateVerificationCode(0)
common.RegisterVerificationCodeWithKey(email, code, common.PasswordResetPurpose)
- link := fmt.Sprintf("%s/user/reset?email=%s&token=%s", common.ServerAddress, email, code)
- subject := fmt.Sprintf("%s密码重置", common.SystemName)
+ link := fmt.Sprintf("%s/user/reset?email=%s&token=%s", config.ServerAddress, email, code)
+ subject := fmt.Sprintf("%s密码重置", config.SystemName)
content := fmt.Sprintf("您好,你正在进行%s密码重置。
"+
"点击 此处 进行密码重置。
"+
"如果链接无法点击,请尝试点击下面的链接或将其复制到浏览器中打开:
%s
"+
- "重置链接 %d 分钟内有效,如果不是本人操作,请忽略。
", common.SystemName, link, link, common.VerificationValidMinutes)
+ "重置链接 %d 分钟内有效,如果不是本人操作,请忽略。
", config.SystemName, link, link, common.VerificationValidMinutes)
err := common.SendEmail(subject, email, content)
if err != nil {
c.JSON(http.StatusOK, gin.H{
diff --git a/controller/model.go b/controller/model.go
index 6cb530db..b7ec1b6a 100644
--- a/controller/model.go
+++ b/controller/model.go
@@ -2,8 +2,8 @@ package controller
import (
"fmt"
-
"github.com/gin-gonic/gin"
+ "one-api/relay/channel/openai"
)
// https://platform.openai.com/docs/api-reference/models/list
@@ -436,7 +436,7 @@ func init() {
Id: "PaLM-2",
Object: "model",
Created: 1677649963,
- OwnedBy: "google",
+ OwnedBy: "google palm",
Permission: permission,
Root: "PaLM-2",
Parent: nil,
@@ -445,7 +445,7 @@ func init() {
Id: "gemini-pro",
Object: "model",
Created: 1677649963,
- OwnedBy: "google",
+ OwnedBy: "google gemini",
Permission: permission,
Root: "gemini-pro",
Parent: nil,
@@ -454,7 +454,7 @@ func init() {
Id: "gemini-pro-vision",
Object: "model",
Created: 1677649963,
- OwnedBy: "google",
+ OwnedBy: "google gemini",
Permission: permission,
Root: "gemini-pro-vision",
Parent: nil,
@@ -613,14 +613,14 @@ func RetrieveModel(c *gin.Context) {
if model, ok := openAIModelsMap[modelId]; ok {
c.JSON(200, model)
} else {
- openAIError := OpenAIError{
+ Error := openai.Error{
Message: fmt.Sprintf("The model '%s' does not exist", modelId),
Type: "invalid_request_error",
Param: "model",
Code: "model_not_found",
}
c.JSON(200, gin.H{
- "error": openAIError,
+ "error": Error,
})
}
}
diff --git a/controller/option.go b/controller/option.go
index bbf83578..593ac2ae 100644
--- a/controller/option.go
+++ b/controller/option.go
@@ -3,7 +3,8 @@ package controller
import (
"encoding/json"
"net/http"
- "one-api/common"
+ "one-api/common/config"
+ "one-api/common/helper"
"one-api/model"
"strings"
@@ -12,17 +13,17 @@ import (
func GetOptions(c *gin.Context) {
var options []*model.Option
- common.OptionMapRWMutex.Lock()
- for k, v := range common.OptionMap {
+ config.OptionMapRWMutex.Lock()
+ for k, v := range config.OptionMap {
if strings.HasSuffix(k, "Token") || strings.HasSuffix(k, "Secret") {
continue
}
options = append(options, &model.Option{
Key: k,
- Value: common.Interface2String(v),
+ Value: helper.Interface2String(v),
})
}
- common.OptionMapRWMutex.Unlock()
+ config.OptionMapRWMutex.Unlock()
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
@@ -42,8 +43,16 @@ func UpdateOption(c *gin.Context) {
return
}
switch option.Key {
+ case "Theme":
+ if !config.ValidThemes[option.Value] {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": "无效的主题",
+ })
+ return
+ }
case "GitHubOAuthEnabled":
- if option.Value == "true" && common.GitHubClientId == "" {
+ if option.Value == "true" && config.GitHubClientId == "" {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "无法启用 GitHub OAuth,请先填入 GitHub Client Id 以及 GitHub Client Secret!",
@@ -51,7 +60,7 @@ func UpdateOption(c *gin.Context) {
return
}
case "EmailDomainRestrictionEnabled":
- if option.Value == "true" && len(common.EmailDomainWhitelist) == 0 {
+ if option.Value == "true" && len(config.EmailDomainWhitelist) == 0 {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "无法启用邮箱域名限制,请先填入限制的邮箱域名!",
@@ -59,7 +68,7 @@ func UpdateOption(c *gin.Context) {
return
}
case "WeChatAuthEnabled":
- if option.Value == "true" && common.WeChatServerAddress == "" {
+ if option.Value == "true" && config.WeChatServerAddress == "" {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "无法启用微信登录,请先填入微信登录相关配置信息!",
@@ -67,7 +76,7 @@ func UpdateOption(c *gin.Context) {
return
}
case "TurnstileCheckEnabled":
- if option.Value == "true" && common.TurnstileSiteKey == "" {
+ if option.Value == "true" && config.TurnstileSiteKey == "" {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "无法启用 Turnstile 校验,请先填入 Turnstile 校验相关配置信息!",
diff --git a/controller/redemption.go b/controller/redemption.go
index 0f656be0..9eeeb943 100644
--- a/controller/redemption.go
+++ b/controller/redemption.go
@@ -3,7 +3,8 @@ package controller
import (
"github.com/gin-gonic/gin"
"net/http"
- "one-api/common"
+ "one-api/common/config"
+ "one-api/common/helper"
"one-api/model"
"strconv"
)
@@ -13,7 +14,7 @@ func GetAllRedemptions(c *gin.Context) {
if p < 0 {
p = 0
}
- redemptions, err := model.GetAllRedemptions(p*common.ItemsPerPage, common.ItemsPerPage)
+ redemptions, err := model.GetAllRedemptions(p*config.ItemsPerPage, config.ItemsPerPage)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
@@ -105,12 +106,12 @@ func AddRedemption(c *gin.Context) {
}
var keys []string
for i := 0; i < redemption.Count; i++ {
- key := common.GetUUID()
+ key := helper.GetUUID()
cleanRedemption := model.Redemption{
UserId: c.GetInt("id"),
Name: redemption.Name,
Key: key,
- CreatedTime: common.GetTimestamp(),
+ CreatedTime: helper.GetTimestamp(),
Quota: redemption.Quota,
}
err = cleanRedemption.Insert()
diff --git a/controller/relay-tencent.go b/controller/relay-tencent.go
deleted file mode 100644
index 5930ae89..00000000
--- a/controller/relay-tencent.go
+++ /dev/null
@@ -1,288 +0,0 @@
-package controller
-
-import (
- "bufio"
- "crypto/hmac"
- "crypto/sha1"
- "encoding/base64"
- "encoding/json"
- "errors"
- "fmt"
- "github.com/gin-gonic/gin"
- "io"
- "net/http"
- "one-api/common"
- "sort"
- "strconv"
- "strings"
-)
-
-// https://cloud.tencent.com/document/product/1729/97732
-
-type TencentMessage struct {
- Role string `json:"role"`
- Content string `json:"content"`
-}
-
-type TencentChatRequest struct {
- AppId int64 `json:"app_id"` // 腾讯云账号的 APPID
- SecretId string `json:"secret_id"` // 官网 SecretId
- // Timestamp当前 UNIX 时间戳,单位为秒,可记录发起 API 请求的时间。
- // 例如1529223702,如果与当前时间相差过大,会引起签名过期错误
- Timestamp int64 `json:"timestamp"`
- // Expired 签名的有效期,是一个符合 UNIX Epoch 时间戳规范的数值,
- // 单位为秒;Expired 必须大于 Timestamp 且 Expired-Timestamp 小于90天
- Expired int64 `json:"expired"`
- QueryID string `json:"query_id"` //请求 Id,用于问题排查
- // Temperature 较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定
- // 默认 1.0,取值区间为[0.0,2.0],非必要不建议使用,不合理的取值会影响效果
- // 建议该参数和 top_p 只设置1个,不要同时更改 top_p
- Temperature float64 `json:"temperature"`
- // TopP 影响输出文本的多样性,取值越大,生成文本的多样性越强
- // 默认1.0,取值区间为[0.0, 1.0],非必要不建议使用, 不合理的取值会影响效果
- // 建议该参数和 temperature 只设置1个,不要同时更改
- TopP float64 `json:"top_p"`
- // Stream 0:同步,1:流式 (默认,协议:SSE)
- // 同步请求超时:60s,如果内容较长建议使用流式
- Stream int `json:"stream"`
- // Messages 会话内容, 长度最多为40, 按对话时间从旧到新在数组中排列
- // 输入 content 总数最大支持 3000 token。
- Messages []TencentMessage `json:"messages"`
-}
-
-type TencentError struct {
- Code int `json:"code"`
- Message string `json:"message"`
-}
-
-type TencentUsage struct {
- InputTokens int `json:"input_tokens"`
- OutputTokens int `json:"output_tokens"`
- TotalTokens int `json:"total_tokens"`
-}
-
-type TencentResponseChoices struct {
- FinishReason string `json:"finish_reason,omitempty"` // 流式结束标志位,为 stop 则表示尾包
- Messages TencentMessage `json:"messages,omitempty"` // 内容,同步模式返回内容,流模式为 null 输出 content 内容总数最多支持 1024token。
- Delta TencentMessage `json:"delta,omitempty"` // 内容,流模式返回内容,同步模式为 null 输出 content 内容总数最多支持 1024token。
-}
-
-type TencentChatResponse struct {
- Choices []TencentResponseChoices `json:"choices,omitempty"` // 结果
- Created string `json:"created,omitempty"` // unix 时间戳的字符串
- Id string `json:"id,omitempty"` // 会话 id
- Usage Usage `json:"usage,omitempty"` // token 数量
- Error TencentError `json:"error,omitempty"` // 错误信息 注意:此字段可能返回 null,表示取不到有效值
- Note string `json:"note,omitempty"` // 注释
- ReqID string `json:"req_id,omitempty"` // 唯一请求 Id,每次请求都会返回。用于反馈接口入参
-}
-
-func requestOpenAI2Tencent(request GeneralOpenAIRequest) *TencentChatRequest {
- messages := make([]TencentMessage, 0, len(request.Messages))
- for i := 0; i < len(request.Messages); i++ {
- message := request.Messages[i]
- if message.Role == "system" {
- messages = append(messages, TencentMessage{
- Role: "user",
- Content: message.StringContent(),
- })
- messages = append(messages, TencentMessage{
- Role: "assistant",
- Content: "Okay",
- })
- continue
- }
- messages = append(messages, TencentMessage{
- Content: message.StringContent(),
- Role: message.Role,
- })
- }
- stream := 0
- if request.Stream {
- stream = 1
- }
- return &TencentChatRequest{
- Timestamp: common.GetTimestamp(),
- Expired: common.GetTimestamp() + 24*60*60,
- QueryID: common.GetUUID(),
- Temperature: request.Temperature,
- TopP: request.TopP,
- Stream: stream,
- Messages: messages,
- }
-}
-
-func responseTencent2OpenAI(response *TencentChatResponse) *OpenAITextResponse {
- fullTextResponse := OpenAITextResponse{
- Object: "chat.completion",
- Created: common.GetTimestamp(),
- Usage: response.Usage,
- }
- if len(response.Choices) > 0 {
- choice := OpenAITextResponseChoice{
- Index: 0,
- Message: Message{
- Role: "assistant",
- Content: response.Choices[0].Messages.Content,
- },
- FinishReason: response.Choices[0].FinishReason,
- }
- fullTextResponse.Choices = append(fullTextResponse.Choices, choice)
- }
- return &fullTextResponse
-}
-
-func streamResponseTencent2OpenAI(TencentResponse *TencentChatResponse) *ChatCompletionsStreamResponse {
- response := ChatCompletionsStreamResponse{
- Object: "chat.completion.chunk",
- Created: common.GetTimestamp(),
- Model: "tencent-hunyuan",
- }
- if len(TencentResponse.Choices) > 0 {
- var choice ChatCompletionsStreamResponseChoice
- choice.Delta.Content = TencentResponse.Choices[0].Delta.Content
- if TencentResponse.Choices[0].FinishReason == "stop" {
- choice.FinishReason = &stopFinishReason
- }
- response.Choices = append(response.Choices, choice)
- }
- return &response
-}
-
-func tencentStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, string) {
- var responseText string
- scanner := bufio.NewScanner(resp.Body)
- scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
- if atEOF && len(data) == 0 {
- return 0, nil, nil
- }
- if i := strings.Index(string(data), "\n"); i >= 0 {
- return i + 1, data[0:i], nil
- }
- if atEOF {
- return len(data), data, nil
- }
- return 0, nil, nil
- })
- dataChan := make(chan string)
- stopChan := make(chan bool)
- go func() {
- for scanner.Scan() {
- data := scanner.Text()
- if len(data) < 5 { // ignore blank line or wrong format
- continue
- }
- if data[:5] != "data:" {
- continue
- }
- data = data[5:]
- dataChan <- data
- }
- stopChan <- true
- }()
- setEventStreamHeaders(c)
- c.Stream(func(w io.Writer) bool {
- select {
- case data := <-dataChan:
- var TencentResponse TencentChatResponse
- err := json.Unmarshal([]byte(data), &TencentResponse)
- if err != nil {
- common.SysError("error unmarshalling stream response: " + err.Error())
- return true
- }
- response := streamResponseTencent2OpenAI(&TencentResponse)
- if len(response.Choices) != 0 {
- responseText += response.Choices[0].Delta.Content
- }
- jsonResponse, err := json.Marshal(response)
- if err != nil {
- common.SysError("error marshalling stream response: " + err.Error())
- return true
- }
- c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
- return true
- case <-stopChan:
- c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
- return false
- }
- })
- err := resp.Body.Close()
- if err != nil {
- return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
- }
- return nil, responseText
-}
-
-func tencentHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
- var TencentResponse TencentChatResponse
- responseBody, err := io.ReadAll(resp.Body)
- if err != nil {
- return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
- }
- err = resp.Body.Close()
- if err != nil {
- return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
- }
- err = json.Unmarshal(responseBody, &TencentResponse)
- if err != nil {
- return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
- }
- if TencentResponse.Error.Code != 0 {
- return &OpenAIErrorWithStatusCode{
- OpenAIError: OpenAIError{
- Message: TencentResponse.Error.Message,
- Code: TencentResponse.Error.Code,
- },
- StatusCode: resp.StatusCode,
- }, nil
- }
- fullTextResponse := responseTencent2OpenAI(&TencentResponse)
- fullTextResponse.Model = "hunyuan"
- jsonResponse, err := json.Marshal(fullTextResponse)
- if err != nil {
- return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
- }
- c.Writer.Header().Set("Content-Type", "application/json")
- c.Writer.WriteHeader(resp.StatusCode)
- _, err = c.Writer.Write(jsonResponse)
- return nil, &fullTextResponse.Usage
-}
-
-func parseTencentConfig(config string) (appId int64, secretId string, secretKey string, err error) {
- parts := strings.Split(config, "|")
- if len(parts) != 3 {
- err = errors.New("invalid tencent config")
- return
- }
- appId, err = strconv.ParseInt(parts[0], 10, 64)
- secretId = parts[1]
- secretKey = parts[2]
- return
-}
-
-func getTencentSign(req TencentChatRequest, secretKey string) string {
- params := make([]string, 0)
- params = append(params, "app_id="+strconv.FormatInt(req.AppId, 10))
- params = append(params, "secret_id="+req.SecretId)
- params = append(params, "timestamp="+strconv.FormatInt(req.Timestamp, 10))
- params = append(params, "query_id="+req.QueryID)
- params = append(params, "temperature="+strconv.FormatFloat(req.Temperature, 'f', -1, 64))
- params = append(params, "top_p="+strconv.FormatFloat(req.TopP, 'f', -1, 64))
- params = append(params, "stream="+strconv.Itoa(req.Stream))
- params = append(params, "expired="+strconv.FormatInt(req.Expired, 10))
-
- var messageStr string
- for _, msg := range req.Messages {
- messageStr += fmt.Sprintf(`{"role":"%s","content":"%s"},`, msg.Role, msg.Content)
- }
- messageStr = strings.TrimSuffix(messageStr, ",")
- params = append(params, "messages=["+messageStr+"]")
-
- sort.Sort(sort.StringSlice(params))
- url := "hunyuan.cloud.tencent.com/hyllm/v1/chat/completions?" + strings.Join(params, "&")
- mac := hmac.New(sha1.New, []byte(secretKey))
- signURL := url
- mac.Write([]byte(signURL))
- sign := mac.Sum([]byte(nil))
- return base64.StdEncoding.EncodeToString(sign)
-}
diff --git a/controller/relay-text.go b/controller/relay-text.go
deleted file mode 100644
index 64338545..00000000
--- a/controller/relay-text.go
+++ /dev/null
@@ -1,689 +0,0 @@
-package controller
-
-import (
- "bytes"
- "context"
- "encoding/json"
- "errors"
- "fmt"
- "io"
- "math"
- "net/http"
- "one-api/common"
- "one-api/model"
- "strings"
- "time"
-
- "github.com/gin-gonic/gin"
-)
-
-const (
- APITypeOpenAI = iota
- APITypeClaude
- APITypePaLM
- APITypeBaidu
- APITypeZhipu
- APITypeAli
- APITypeXunfei
- APITypeAIProxyLibrary
- APITypeTencent
- APITypeGemini
-)
-
-var httpClient *http.Client
-var impatientHTTPClient *http.Client
-
-func init() {
- if common.RelayTimeout == 0 {
- httpClient = &http.Client{}
- } else {
- httpClient = &http.Client{
- Timeout: time.Duration(common.RelayTimeout) * time.Second,
- }
- }
-
- impatientHTTPClient = &http.Client{
- Timeout: 5 * time.Second,
- }
-}
-
-func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
- channelType := c.GetInt("channel")
- channelId := c.GetInt("channel_id")
- tokenId := c.GetInt("token_id")
- userId := c.GetInt("id")
- group := c.GetString("group")
- var textRequest GeneralOpenAIRequest
- err := common.UnmarshalBodyReusable(c, &textRequest)
- if err != nil {
- return errorWrapper(err, "bind_request_body_failed", http.StatusBadRequest)
- }
- if textRequest.MaxTokens < 0 || textRequest.MaxTokens > math.MaxInt32/2 {
- return errorWrapper(errors.New("max_tokens is invalid"), "invalid_max_tokens", http.StatusBadRequest)
- }
- if relayMode == RelayModeModerations && textRequest.Model == "" {
- textRequest.Model = "text-moderation-latest"
- }
- if relayMode == RelayModeEmbeddings && textRequest.Model == "" {
- textRequest.Model = c.Param("model")
- }
- // request validation
- if textRequest.Model == "" {
- return errorWrapper(errors.New("model is required"), "required_field_missing", http.StatusBadRequest)
- }
- switch relayMode {
- case RelayModeCompletions:
- if textRequest.Prompt == "" {
- return errorWrapper(errors.New("field prompt is required"), "required_field_missing", http.StatusBadRequest)
- }
- case RelayModeChatCompletions:
- if textRequest.Messages == nil || len(textRequest.Messages) == 0 {
- return errorWrapper(errors.New("field messages is required"), "required_field_missing", http.StatusBadRequest)
- }
- case RelayModeEmbeddings:
- case RelayModeModerations:
- if textRequest.Input == "" {
- return errorWrapper(errors.New("field input is required"), "required_field_missing", http.StatusBadRequest)
- }
- case RelayModeEdits:
- if textRequest.Instruction == "" {
- return errorWrapper(errors.New("field instruction is required"), "required_field_missing", http.StatusBadRequest)
- }
- }
- // map model name
- modelMapping := c.GetString("model_mapping")
- isModelMapped := false
- if modelMapping != "" && modelMapping != "{}" {
- modelMap := make(map[string]string)
- err := json.Unmarshal([]byte(modelMapping), &modelMap)
- if err != nil {
- return errorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError)
- }
- if modelMap[textRequest.Model] != "" {
- textRequest.Model = modelMap[textRequest.Model]
- isModelMapped = true
- }
- }
- apiType := APITypeOpenAI
- switch channelType {
- case common.ChannelTypeAnthropic:
- apiType = APITypeClaude
- case common.ChannelTypeBaidu:
- apiType = APITypeBaidu
- case common.ChannelTypePaLM:
- apiType = APITypePaLM
- case common.ChannelTypeZhipu:
- apiType = APITypeZhipu
- case common.ChannelTypeAli:
- apiType = APITypeAli
- case common.ChannelTypeXunfei:
- apiType = APITypeXunfei
- case common.ChannelTypeAIProxyLibrary:
- apiType = APITypeAIProxyLibrary
- case common.ChannelTypeTencent:
- apiType = APITypeTencent
- case common.ChannelTypeGemini:
- apiType = APITypeGemini
- }
- baseURL := common.ChannelBaseURLs[channelType]
- requestURL := c.Request.URL.String()
- if c.GetString("base_url") != "" {
- baseURL = c.GetString("base_url")
- }
- fullRequestURL := getFullRequestURL(baseURL, requestURL, channelType)
- switch apiType {
- case APITypeOpenAI:
- if channelType == common.ChannelTypeAzure {
- // https://learn.microsoft.com/en-us/azure/cognitive-services/openai/chatgpt-quickstart?pivots=rest-api&tabs=command-line#rest-api
- apiVersion := GetAPIVersion(c)
- requestURL := strings.Split(requestURL, "?")[0]
- requestURL = fmt.Sprintf("%s?api-version=%s", requestURL, apiVersion)
- baseURL = c.GetString("base_url")
- task := strings.TrimPrefix(requestURL, "/v1/")
- model_ := textRequest.Model
- model_ = strings.Replace(model_, ".", "", -1)
- // https://github.com/songquanpeng/one-api/issues/67
- model_ = strings.TrimSuffix(model_, "-0301")
- model_ = strings.TrimSuffix(model_, "-0314")
- model_ = strings.TrimSuffix(model_, "-0613")
-
- requestURL = fmt.Sprintf("/openai/deployments/%s/%s", model_, task)
- fullRequestURL = getFullRequestURL(baseURL, requestURL, channelType)
- }
- case APITypeClaude:
- fullRequestURL = "https://api.anthropic.com/v1/complete"
- if baseURL != "" {
- fullRequestURL = fmt.Sprintf("%s/v1/complete", baseURL)
- }
- case APITypeBaidu:
- switch textRequest.Model {
- case "ERNIE-Bot":
- fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions"
- case "ERNIE-Bot-turbo":
- fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant"
- case "ERNIE-Bot-4":
- fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro"
- case "BLOOMZ-7B":
- fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/bloomz_7b1"
- case "Embedding-V1":
- fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings/embedding-v1"
- }
- apiKey := c.Request.Header.Get("Authorization")
- apiKey = strings.TrimPrefix(apiKey, "Bearer ")
- var err error
- if apiKey, err = getBaiduAccessToken(apiKey); err != nil {
- return errorWrapper(err, "invalid_baidu_config", http.StatusInternalServerError)
- }
- fullRequestURL += "?access_token=" + apiKey
- case APITypePaLM:
- fullRequestURL = "https://generativelanguage.googleapis.com/v1beta2/models/chat-bison-001:generateMessage"
- if baseURL != "" {
- fullRequestURL = fmt.Sprintf("%s/v1beta2/models/chat-bison-001:generateMessage", baseURL)
- }
- case APITypeGemini:
- requestBaseURL := "https://generativelanguage.googleapis.com"
- if baseURL != "" {
- requestBaseURL = baseURL
- }
- version := "v1"
- if c.GetString("api_version") != "" {
- version = c.GetString("api_version")
- }
- action := "generateContent"
- if textRequest.Stream {
- action = "streamGenerateContent"
- }
- fullRequestURL = fmt.Sprintf("%s/%s/models/%s:%s", requestBaseURL, version, textRequest.Model, action)
- case APITypeZhipu:
- method := "invoke"
- if textRequest.Stream {
- method = "sse-invoke"
- }
- fullRequestURL = fmt.Sprintf("https://open.bigmodel.cn/api/paas/v3/model-api/%s/%s", textRequest.Model, method)
- case APITypeAli:
- fullRequestURL = "https://dashscope.aliyuncs.com/api/v1/services/aigc/text-generation/generation"
- if relayMode == RelayModeEmbeddings {
- fullRequestURL = "https://dashscope.aliyuncs.com/api/v1/services/embeddings/text-embedding/text-embedding"
- }
- case APITypeTencent:
- fullRequestURL = "https://hunyuan.cloud.tencent.com/hyllm/v1/chat/completions"
- case APITypeAIProxyLibrary:
- fullRequestURL = fmt.Sprintf("%s/api/library/ask", baseURL)
- }
- var promptTokens int
- var completionTokens int
- switch relayMode {
- case RelayModeChatCompletions:
- promptTokens = countTokenMessages(textRequest.Messages, textRequest.Model)
- case RelayModeCompletions:
- promptTokens = countTokenInput(textRequest.Prompt, textRequest.Model)
- case RelayModeModerations:
- promptTokens = countTokenInput(textRequest.Input, textRequest.Model)
- }
- preConsumedTokens := common.PreConsumedQuota
- if textRequest.MaxTokens != 0 {
- preConsumedTokens = promptTokens + textRequest.MaxTokens
- }
- modelRatio := common.GetModelRatio(textRequest.Model)
- groupRatio := common.GetGroupRatio(group)
- ratio := modelRatio * groupRatio
- preConsumedQuota := int(float64(preConsumedTokens) * ratio)
- userQuota, err := model.CacheGetUserQuota(userId)
- if err != nil {
- return errorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError)
- }
- if userQuota-preConsumedQuota < 0 {
- return errorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
- }
- err = model.CacheDecreaseUserQuota(userId, preConsumedQuota)
- if err != nil {
- return errorWrapper(err, "decrease_user_quota_failed", http.StatusInternalServerError)
- }
- if userQuota > 100*preConsumedQuota {
- // in this case, we do not pre-consume quota
- // because the user has enough quota
- preConsumedQuota = 0
- common.LogInfo(c.Request.Context(), fmt.Sprintf("user %d has enough quota %d, trusted and no need to pre-consume", userId, userQuota))
- }
- if preConsumedQuota > 0 {
- err := model.PreConsumeTokenQuota(tokenId, preConsumedQuota)
- if err != nil {
- return errorWrapper(err, "pre_consume_token_quota_failed", http.StatusForbidden)
- }
- }
- var requestBody io.Reader
- if isModelMapped {
- jsonStr, err := json.Marshal(textRequest)
- if err != nil {
- return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
- }
- requestBody = bytes.NewBuffer(jsonStr)
- } else {
- requestBody = c.Request.Body
- }
- switch apiType {
- case APITypeClaude:
- claudeRequest := requestOpenAI2Claude(textRequest)
- jsonStr, err := json.Marshal(claudeRequest)
- if err != nil {
- return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
- }
- requestBody = bytes.NewBuffer(jsonStr)
- case APITypeBaidu:
- var jsonData []byte
- var err error
- switch relayMode {
- case RelayModeEmbeddings:
- baiduEmbeddingRequest := embeddingRequestOpenAI2Baidu(textRequest)
- jsonData, err = json.Marshal(baiduEmbeddingRequest)
- default:
- baiduRequest := requestOpenAI2Baidu(textRequest)
- jsonData, err = json.Marshal(baiduRequest)
- }
- if err != nil {
- return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
- }
- requestBody = bytes.NewBuffer(jsonData)
- case APITypePaLM:
- palmRequest := requestOpenAI2PaLM(textRequest)
- jsonStr, err := json.Marshal(palmRequest)
- if err != nil {
- return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
- }
- requestBody = bytes.NewBuffer(jsonStr)
- case APITypeGemini:
- geminiChatRequest := requestOpenAI2Gemini(textRequest)
- jsonStr, err := json.Marshal(geminiChatRequest)
- if err != nil {
- return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
- }
- requestBody = bytes.NewBuffer(jsonStr)
- case APITypeZhipu:
- zhipuRequest := requestOpenAI2Zhipu(textRequest)
- jsonStr, err := json.Marshal(zhipuRequest)
- if err != nil {
- return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
- }
- requestBody = bytes.NewBuffer(jsonStr)
- case APITypeAli:
- var jsonStr []byte
- var err error
- switch relayMode {
- case RelayModeEmbeddings:
- aliEmbeddingRequest := embeddingRequestOpenAI2Ali(textRequest)
- jsonStr, err = json.Marshal(aliEmbeddingRequest)
- default:
- aliRequest := requestOpenAI2Ali(textRequest)
- jsonStr, err = json.Marshal(aliRequest)
- }
- if err != nil {
- return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
- }
- requestBody = bytes.NewBuffer(jsonStr)
- case APITypeTencent:
- apiKey := c.Request.Header.Get("Authorization")
- apiKey = strings.TrimPrefix(apiKey, "Bearer ")
- appId, secretId, secretKey, err := parseTencentConfig(apiKey)
- if err != nil {
- return errorWrapper(err, "invalid_tencent_config", http.StatusInternalServerError)
- }
- tencentRequest := requestOpenAI2Tencent(textRequest)
- tencentRequest.AppId = appId
- tencentRequest.SecretId = secretId
- jsonStr, err := json.Marshal(tencentRequest)
- if err != nil {
- return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
- }
- sign := getTencentSign(*tencentRequest, secretKey)
- c.Request.Header.Set("Authorization", sign)
- requestBody = bytes.NewBuffer(jsonStr)
- case APITypeAIProxyLibrary:
- aiProxyLibraryRequest := requestOpenAI2AIProxyLibrary(textRequest)
- aiProxyLibraryRequest.LibraryId = c.GetString("library_id")
- jsonStr, err := json.Marshal(aiProxyLibraryRequest)
- if err != nil {
- return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
- }
- requestBody = bytes.NewBuffer(jsonStr)
- }
-
- var req *http.Request
- var resp *http.Response
- isStream := textRequest.Stream
-
- if apiType != APITypeXunfei { // cause xunfei use websocket
- req, err = http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
- if err != nil {
- return errorWrapper(err, "new_request_failed", http.StatusInternalServerError)
- }
- apiKey := c.Request.Header.Get("Authorization")
- apiKey = strings.TrimPrefix(apiKey, "Bearer ")
- switch apiType {
- case APITypeOpenAI:
- if channelType == common.ChannelTypeAzure {
- req.Header.Set("api-key", apiKey)
- } else {
- req.Header.Set("Authorization", c.Request.Header.Get("Authorization"))
- if channelType == common.ChannelTypeOpenRouter {
- req.Header.Set("HTTP-Referer", "https://github.com/songquanpeng/one-api")
- req.Header.Set("X-Title", "One API")
- }
- }
- case APITypeClaude:
- req.Header.Set("x-api-key", apiKey)
- anthropicVersion := c.Request.Header.Get("anthropic-version")
- if anthropicVersion == "" {
- anthropicVersion = "2023-06-01"
- }
- req.Header.Set("anthropic-version", anthropicVersion)
- case APITypeZhipu:
- token := getZhipuToken(apiKey)
- req.Header.Set("Authorization", token)
- case APITypeAli:
- req.Header.Set("Authorization", "Bearer "+apiKey)
- if textRequest.Stream {
- req.Header.Set("X-DashScope-SSE", "enable")
- }
- if c.GetString("plugin") != "" {
- req.Header.Set("X-DashScope-Plugin", c.GetString("plugin"))
- }
- case APITypeTencent:
- req.Header.Set("Authorization", apiKey)
- case APITypePaLM:
- req.Header.Set("x-goog-api-key", apiKey)
- case APITypeGemini:
- req.Header.Set("x-goog-api-key", apiKey)
- default:
- req.Header.Set("Authorization", "Bearer "+apiKey)
- }
- req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
- req.Header.Set("Accept", c.Request.Header.Get("Accept"))
- if isStream && c.Request.Header.Get("Accept") == "" {
- req.Header.Set("Accept", "text/event-stream")
- }
- //req.Header.Set("Connection", c.Request.Header.Get("Connection"))
- resp, err = httpClient.Do(req)
- if err != nil {
- return errorWrapper(err, "do_request_failed", http.StatusInternalServerError)
- }
- err = req.Body.Close()
- if err != nil {
- return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError)
- }
- err = c.Request.Body.Close()
- if err != nil {
- return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError)
- }
- isStream = isStream || strings.HasPrefix(resp.Header.Get("Content-Type"), "text/event-stream")
-
- if resp.StatusCode != http.StatusOK {
- if preConsumedQuota != 0 {
- go func(ctx context.Context) {
- // return pre-consumed quota
- err := model.PostConsumeTokenQuota(tokenId, -preConsumedQuota)
- if err != nil {
- common.LogError(ctx, "error return pre-consumed quota: "+err.Error())
- }
- }(c.Request.Context())
- }
- return relayErrorHandler(resp)
- }
- }
-
- var textResponse TextResponse
- tokenName := c.GetString("token_name")
-
- defer func(ctx context.Context) {
- // c.Writer.Flush()
- go func() {
- quota := 0
- completionRatio := common.GetCompletionRatio(textRequest.Model)
- promptTokens = textResponse.Usage.PromptTokens
- completionTokens = textResponse.Usage.CompletionTokens
- quota = int(math.Ceil((float64(promptTokens) + float64(completionTokens)*completionRatio) * ratio))
- if ratio != 0 && quota <= 0 {
- quota = 1
- }
- totalTokens := promptTokens + completionTokens
- if totalTokens == 0 {
- // in this case, must be some error happened
- // we cannot just return, because we may have to return the pre-consumed quota
- quota = 0
- }
- quotaDelta := quota - preConsumedQuota
- err := model.PostConsumeTokenQuota(tokenId, quotaDelta)
- if err != nil {
- common.LogError(ctx, "error consuming token remain quota: "+err.Error())
- }
- err = model.CacheUpdateUserQuota(userId)
- if err != nil {
- common.LogError(ctx, "error update user quota cache: "+err.Error())
- }
- if quota != 0 {
- logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
- model.RecordConsumeLog(ctx, userId, channelId, promptTokens, completionTokens, textRequest.Model, tokenName, quota, logContent)
- model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
- model.UpdateChannelUsedQuota(channelId, quota)
- }
-
- }()
- }(c.Request.Context())
- switch apiType {
- case APITypeOpenAI:
- if isStream {
- err, responseText := openaiStreamHandler(c, resp, relayMode)
- if err != nil {
- return err
- }
- textResponse.Usage.PromptTokens = promptTokens
- textResponse.Usage.CompletionTokens = countTokenText(responseText, textRequest.Model)
- return nil
- } else {
- err, usage := openaiHandler(c, resp, promptTokens, textRequest.Model)
- if err != nil {
- return err
- }
- if usage != nil {
- textResponse.Usage = *usage
- }
- return nil
- }
- case APITypeClaude:
- if isStream {
- err, responseText := claudeStreamHandler(c, resp)
- if err != nil {
- return err
- }
- textResponse.Usage.PromptTokens = promptTokens
- textResponse.Usage.CompletionTokens = countTokenText(responseText, textRequest.Model)
- return nil
- } else {
- err, usage := claudeHandler(c, resp, promptTokens, textRequest.Model)
- if err != nil {
- return err
- }
- if usage != nil {
- textResponse.Usage = *usage
- }
- return nil
- }
- case APITypeBaidu:
- if isStream {
- err, usage := baiduStreamHandler(c, resp)
- if err != nil {
- return err
- }
- if usage != nil {
- textResponse.Usage = *usage
- }
- return nil
- } else {
- var err *OpenAIErrorWithStatusCode
- var usage *Usage
- switch relayMode {
- case RelayModeEmbeddings:
- err, usage = baiduEmbeddingHandler(c, resp)
- default:
- err, usage = baiduHandler(c, resp)
- }
- if err != nil {
- return err
- }
- if usage != nil {
- textResponse.Usage = *usage
- }
- return nil
- }
- case APITypePaLM:
- if textRequest.Stream { // PaLM2 API does not support stream
- err, responseText := palmStreamHandler(c, resp)
- if err != nil {
- return err
- }
- textResponse.Usage.PromptTokens = promptTokens
- textResponse.Usage.CompletionTokens = countTokenText(responseText, textRequest.Model)
- return nil
- } else {
- err, usage := palmHandler(c, resp, promptTokens, textRequest.Model)
- if err != nil {
- return err
- }
- if usage != nil {
- textResponse.Usage = *usage
- }
- return nil
- }
- case APITypeGemini:
- if textRequest.Stream {
- err, responseText := geminiChatStreamHandler(c, resp)
- if err != nil {
- return err
- }
- textResponse.Usage.PromptTokens = promptTokens
- textResponse.Usage.CompletionTokens = countTokenText(responseText, textRequest.Model)
- return nil
- } else {
- err, usage := geminiChatHandler(c, resp, promptTokens, textRequest.Model)
- if err != nil {
- return err
- }
- if usage != nil {
- textResponse.Usage = *usage
- }
- return nil
- }
- case APITypeZhipu:
- if isStream {
- err, usage := zhipuStreamHandler(c, resp)
- if err != nil {
- return err
- }
- if usage != nil {
- textResponse.Usage = *usage
- }
- // zhipu's API does not return prompt tokens & completion tokens
- textResponse.Usage.PromptTokens = textResponse.Usage.TotalTokens
- return nil
- } else {
- err, usage := zhipuHandler(c, resp)
- if err != nil {
- return err
- }
- if usage != nil {
- textResponse.Usage = *usage
- }
- // zhipu's API does not return prompt tokens & completion tokens
- textResponse.Usage.PromptTokens = textResponse.Usage.TotalTokens
- return nil
- }
- case APITypeAli:
- if isStream {
- err, usage := aliStreamHandler(c, resp)
- if err != nil {
- return err
- }
- if usage != nil {
- textResponse.Usage = *usage
- }
- return nil
- } else {
- var err *OpenAIErrorWithStatusCode
- var usage *Usage
- switch relayMode {
- case RelayModeEmbeddings:
- err, usage = aliEmbeddingHandler(c, resp)
- default:
- err, usage = aliHandler(c, resp)
- }
- if err != nil {
- return err
- }
- if usage != nil {
- textResponse.Usage = *usage
- }
- return nil
- }
- case APITypeXunfei:
- auth := c.Request.Header.Get("Authorization")
- auth = strings.TrimPrefix(auth, "Bearer ")
- splits := strings.Split(auth, "|")
- if len(splits) != 3 {
- return errorWrapper(errors.New("invalid auth"), "invalid_auth", http.StatusBadRequest)
- }
- var err *OpenAIErrorWithStatusCode
- var usage *Usage
- if isStream {
- err, usage = xunfeiStreamHandler(c, textRequest, splits[0], splits[1], splits[2])
- } else {
- err, usage = xunfeiHandler(c, textRequest, splits[0], splits[1], splits[2])
- }
- if err != nil {
- return err
- }
- if usage != nil {
- textResponse.Usage = *usage
- }
- return nil
- case APITypeAIProxyLibrary:
- if isStream {
- err, usage := aiProxyLibraryStreamHandler(c, resp)
- if err != nil {
- return err
- }
- if usage != nil {
- textResponse.Usage = *usage
- }
- return nil
- } else {
- err, usage := aiProxyLibraryHandler(c, resp)
- if err != nil {
- return err
- }
- if usage != nil {
- textResponse.Usage = *usage
- }
- return nil
- }
- case APITypeTencent:
- if isStream {
- err, responseText := tencentStreamHandler(c, resp)
- if err != nil {
- return err
- }
- textResponse.Usage.PromptTokens = promptTokens
- textResponse.Usage.CompletionTokens = countTokenText(responseText, textRequest.Model)
- return nil
- } else {
- err, usage := tencentHandler(c, resp)
- if err != nil {
- return err
- }
- if usage != nil {
- textResponse.Usage = *usage
- }
- return nil
- }
- default:
- return errorWrapper(errors.New("unknown api type"), "unknown_api_type", http.StatusInternalServerError)
- }
-}
diff --git a/controller/relay.go b/controller/relay.go
index e45fd3eb..46fedc7e 100644
--- a/controller/relay.go
+++ b/controller/relay.go
@@ -2,374 +2,57 @@ package controller
import (
"fmt"
- "net/http"
- "one-api/common"
- "strconv"
- "strings"
-
"github.com/gin-gonic/gin"
-)
-
-type Message struct {
- Role string `json:"role"`
- Content any `json:"content"`
- Name *string `json:"name,omitempty"`
-}
-
-type ImageURL struct {
- Url string `json:"url,omitempty"`
- Detail string `json:"detail,omitempty"`
-}
-
-type TextContent struct {
- Type string `json:"type,omitempty"`
- Text string `json:"text,omitempty"`
-}
-
-type ImageContent struct {
- Type string `json:"type,omitempty"`
- ImageURL *ImageURL `json:"image_url,omitempty"`
-}
-
-const (
- ContentTypeText = "text"
- ContentTypeImageURL = "image_url"
-)
-
-type OpenAIMessageContent struct {
- Type string `json:"type,omitempty"`
- Text string `json:"text"`
- ImageURL *ImageURL `json:"image_url,omitempty"`
-}
-
-func (m Message) IsStringContent() bool {
- _, ok := m.Content.(string)
- return ok
-}
-
-func (m Message) StringContent() string {
- content, ok := m.Content.(string)
- if ok {
- return content
- }
- contentList, ok := m.Content.([]any)
- if ok {
- var contentStr string
- for _, contentItem := range contentList {
- contentMap, ok := contentItem.(map[string]any)
- if !ok {
- continue
- }
- if contentMap["type"] == ContentTypeText {
- if subStr, ok := contentMap["text"].(string); ok {
- contentStr += subStr
- }
- }
- }
- return contentStr
- }
- return ""
-}
-
-func (m Message) ParseContent() []OpenAIMessageContent {
- var contentList []OpenAIMessageContent
- content, ok := m.Content.(string)
- if ok {
- contentList = append(contentList, OpenAIMessageContent{
- Type: ContentTypeText,
- Text: content,
- })
- return contentList
- }
- anyList, ok := m.Content.([]any)
- if ok {
- for _, contentItem := range anyList {
- contentMap, ok := contentItem.(map[string]any)
- if !ok {
- continue
- }
- switch contentMap["type"] {
- case ContentTypeText:
- if subStr, ok := contentMap["text"].(string); ok {
- contentList = append(contentList, OpenAIMessageContent{
- Type: ContentTypeText,
- Text: subStr,
- })
- }
- case ContentTypeImageURL:
- if subObj, ok := contentMap["image_url"].(map[string]any); ok {
- contentList = append(contentList, OpenAIMessageContent{
- Type: ContentTypeImageURL,
- ImageURL: &ImageURL{
- Url: subObj["url"].(string),
- },
- })
- }
- }
- }
- return contentList
- }
- return nil
-}
-
-const (
- RelayModeUnknown = iota
- RelayModeChatCompletions
- RelayModeCompletions
- RelayModeEmbeddings
- RelayModeModerations
- RelayModeImagesGenerations
- RelayModeEdits
- RelayModeAudioSpeech
- RelayModeAudioTranscription
- RelayModeAudioTranslation
+ "net/http"
+ "one-api/common/config"
+ "one-api/common/helper"
+ "one-api/common/logger"
+ "one-api/relay/channel/openai"
+ "one-api/relay/constant"
+ "one-api/relay/controller"
+ "one-api/relay/util"
+ "strconv"
)
// https://platform.openai.com/docs/api-reference/chat
-type ResponseFormat struct {
- Type string `json:"type,omitempty"`
-}
-
-type GeneralOpenAIRequest struct {
- Model string `json:"model,omitempty"`
- Messages []Message `json:"messages,omitempty"`
- Prompt any `json:"prompt,omitempty"`
- Stream bool `json:"stream,omitempty"`
- MaxTokens int `json:"max_tokens,omitempty"`
- Temperature float64 `json:"temperature,omitempty"`
- TopP float64 `json:"top_p,omitempty"`
- N int `json:"n,omitempty"`
- Input any `json:"input,omitempty"`
- Instruction string `json:"instruction,omitempty"`
- Size string `json:"size,omitempty"`
- Functions any `json:"functions,omitempty"`
- FrequencyPenalty float64 `json:"frequency_penalty,omitempty"`
- PresencePenalty float64 `json:"presence_penalty,omitempty"`
- ResponseFormat *ResponseFormat `json:"response_format,omitempty"`
- Seed float64 `json:"seed,omitempty"`
- Tools any `json:"tools,omitempty"`
- ToolChoice any `json:"tool_choice,omitempty"`
- User string `json:"user,omitempty"`
-}
-
-func (r GeneralOpenAIRequest) ParseInput() []string {
- if r.Input == nil {
- return nil
- }
- var input []string
- switch r.Input.(type) {
- case string:
- input = []string{r.Input.(string)}
- case []any:
- input = make([]string, 0, len(r.Input.([]any)))
- for _, item := range r.Input.([]any) {
- if str, ok := item.(string); ok {
- input = append(input, str)
- }
- }
- }
- return input
-}
-
-type ChatRequest struct {
- Model string `json:"model"`
- Messages []Message `json:"messages"`
- MaxTokens int `json:"max_tokens"`
-}
-
-type TextRequest struct {
- Model string `json:"model"`
- Messages []Message `json:"messages"`
- Prompt string `json:"prompt"`
- MaxTokens int `json:"max_tokens"`
- //Stream bool `json:"stream"`
-}
-
-// ImageRequest docs: https://platform.openai.com/docs/api-reference/images/create
-type ImageRequest struct {
- Model string `json:"model"`
- Prompt string `json:"prompt" binding:"required"`
- N int `json:"n,omitempty"`
- Size string `json:"size,omitempty"`
- Quality string `json:"quality,omitempty"`
- ResponseFormat string `json:"response_format,omitempty"`
- Style string `json:"style,omitempty"`
- User string `json:"user,omitempty"`
-}
-
-type WhisperJSONResponse struct {
- Text string `json:"text,omitempty"`
-}
-
-type WhisperVerboseJSONResponse struct {
- Task string `json:"task,omitempty"`
- Language string `json:"language,omitempty"`
- Duration float64 `json:"duration,omitempty"`
- Text string `json:"text,omitempty"`
- Segments []Segment `json:"segments,omitempty"`
-}
-
-type Segment struct {
- Id int `json:"id"`
- Seek int `json:"seek"`
- Start float64 `json:"start"`
- End float64 `json:"end"`
- Text string `json:"text"`
- Tokens []int `json:"tokens"`
- Temperature float64 `json:"temperature"`
- AvgLogprob float64 `json:"avg_logprob"`
- CompressionRatio float64 `json:"compression_ratio"`
- NoSpeechProb float64 `json:"no_speech_prob"`
-}
-
-type TextToSpeechRequest struct {
- Model string `json:"model" binding:"required"`
- Input string `json:"input" binding:"required"`
- Voice string `json:"voice" binding:"required"`
- Speed float64 `json:"speed"`
- ResponseFormat string `json:"response_format"`
-}
-
-type Usage struct {
- PromptTokens int `json:"prompt_tokens"`
- CompletionTokens int `json:"completion_tokens"`
- TotalTokens int `json:"total_tokens"`
-}
-
-type OpenAIError struct {
- Message string `json:"message"`
- Type string `json:"type"`
- Param string `json:"param"`
- Code any `json:"code"`
-}
-
-type OpenAIErrorWithStatusCode struct {
- OpenAIError
- StatusCode int `json:"status_code"`
-}
-
-type TextResponse struct {
- Choices []OpenAITextResponseChoice `json:"choices"`
- Usage `json:"usage"`
- Error OpenAIError `json:"error"`
-}
-
-type OpenAITextResponseChoice struct {
- Index int `json:"index"`
- Message `json:"message"`
- FinishReason string `json:"finish_reason"`
-}
-
-type OpenAITextResponse struct {
- Id string `json:"id"`
- Model string `json:"model,omitempty"`
- Object string `json:"object"`
- Created int64 `json:"created"`
- Choices []OpenAITextResponseChoice `json:"choices"`
- Usage `json:"usage"`
-}
-
-type OpenAIEmbeddingResponseItem struct {
- Object string `json:"object"`
- Index int `json:"index"`
- Embedding []float64 `json:"embedding"`
-}
-
-type OpenAIEmbeddingResponse struct {
- Object string `json:"object"`
- Data []OpenAIEmbeddingResponseItem `json:"data"`
- Model string `json:"model"`
- Usage `json:"usage"`
-}
-
-type ImageResponse struct {
- Created int `json:"created"`
- Data []struct {
- Url string `json:"url"`
- }
-}
-
-type ChatCompletionsStreamResponseChoice struct {
- Delta struct {
- Content string `json:"content"`
- } `json:"delta"`
- FinishReason *string `json:"finish_reason,omitempty"`
-}
-
-type ChatCompletionsStreamResponse struct {
- Id string `json:"id"`
- Object string `json:"object"`
- Created int64 `json:"created"`
- Model string `json:"model"`
- Choices []ChatCompletionsStreamResponseChoice `json:"choices"`
-}
-
-type CompletionsStreamResponse struct {
- Choices []struct {
- Text string `json:"text"`
- FinishReason string `json:"finish_reason"`
- } `json:"choices"`
-}
-
func Relay(c *gin.Context) {
- relayMode := RelayModeUnknown
- if strings.HasPrefix(c.Request.URL.Path, "/v1/chat/completions") {
- relayMode = RelayModeChatCompletions
- } else if strings.HasPrefix(c.Request.URL.Path, "/v1/completions") {
- relayMode = RelayModeCompletions
- } else if strings.HasPrefix(c.Request.URL.Path, "/v1/embeddings") {
- relayMode = RelayModeEmbeddings
- } else if strings.HasSuffix(c.Request.URL.Path, "embeddings") {
- relayMode = RelayModeEmbeddings
- } else if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") {
- relayMode = RelayModeModerations
- } else if strings.HasPrefix(c.Request.URL.Path, "/v1/images/generations") {
- relayMode = RelayModeImagesGenerations
- } else if strings.HasPrefix(c.Request.URL.Path, "/v1/edits") {
- relayMode = RelayModeEdits
- } else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/speech") {
- relayMode = RelayModeAudioSpeech
- } else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") {
- relayMode = RelayModeAudioTranscription
- } else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/translations") {
- relayMode = RelayModeAudioTranslation
- }
- var err *OpenAIErrorWithStatusCode
+ relayMode := constant.Path2RelayMode(c.Request.URL.Path)
+ var err *openai.ErrorWithStatusCode
switch relayMode {
- case RelayModeImagesGenerations:
- err = relayImageHelper(c, relayMode)
- case RelayModeAudioSpeech:
+ case constant.RelayModeImagesGenerations:
+ err = controller.RelayImageHelper(c, relayMode)
+ case constant.RelayModeAudioSpeech:
fallthrough
- case RelayModeAudioTranslation:
+ case constant.RelayModeAudioTranslation:
fallthrough
- case RelayModeAudioTranscription:
- err = relayAudioHelper(c, relayMode)
+ case constant.RelayModeAudioTranscription:
+ err = controller.RelayAudioHelper(c, relayMode)
default:
- err = relayTextHelper(c, relayMode)
+ err = controller.RelayTextHelper(c, relayMode)
}
if err != nil {
- requestId := c.GetString(common.RequestIdKey)
+ requestId := c.GetString(logger.RequestIdKey)
retryTimesStr := c.Query("retry")
retryTimes, _ := strconv.Atoi(retryTimesStr)
if retryTimesStr == "" {
- retryTimes = common.RetryTimes
+ retryTimes = config.RetryTimes
}
if retryTimes > 0 {
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s?retry=%d", c.Request.URL.Path, retryTimes-1))
} else {
if err.StatusCode == http.StatusTooManyRequests {
- err.OpenAIError.Message = "当前分组上游负载已饱和,请稍后再试"
+ err.Error.Message = "当前分组上游负载已饱和,请稍后再试"
}
- err.OpenAIError.Message = common.MessageWithRequestId(err.OpenAIError.Message, requestId)
+ err.Error.Message = helper.MessageWithRequestId(err.Error.Message, requestId)
c.JSON(err.StatusCode, gin.H{
- "error": err.OpenAIError,
+ "error": err.Error,
})
}
channelId := c.GetInt("channel_id")
- common.LogError(c.Request.Context(), fmt.Sprintf("relay error (channel #%d): %s", channelId, err.Message))
+ logger.Error(c.Request.Context(), fmt.Sprintf("relay error (channel #%d): %s", channelId, err.Message))
// https://platform.openai.com/docs/guides/error-codes/api-errors
- if shouldDisableChannel(&err.OpenAIError, err.StatusCode) {
+ if util.ShouldDisableChannel(&err.Error, err.StatusCode) {
channelId := c.GetInt("channel_id")
channelName := c.GetString("channel_name")
disableChannel(channelId, channelName, err.Message)
@@ -378,7 +61,7 @@ func Relay(c *gin.Context) {
}
func RelayNotImplemented(c *gin.Context) {
- err := OpenAIError{
+ err := openai.Error{
Message: "API not implemented",
Type: "one_api_error",
Param: "",
@@ -390,7 +73,7 @@ func RelayNotImplemented(c *gin.Context) {
}
func RelayNotFound(c *gin.Context) {
- err := OpenAIError{
+ err := openai.Error{
Message: fmt.Sprintf("Invalid URL (%s %s)", c.Request.Method, c.Request.URL.Path),
Type: "invalid_request_error",
Param: "",
diff --git a/controller/token.go b/controller/token.go
index 8642122c..d6554abe 100644
--- a/controller/token.go
+++ b/controller/token.go
@@ -4,6 +4,8 @@ import (
"github.com/gin-gonic/gin"
"net/http"
"one-api/common"
+ "one-api/common/config"
+ "one-api/common/helper"
"one-api/model"
"strconv"
)
@@ -14,7 +16,7 @@ func GetAllTokens(c *gin.Context) {
if p < 0 {
p = 0
}
- tokens, err := model.GetAllUserTokens(userId, p*common.ItemsPerPage, common.ItemsPerPage)
+ tokens, err := model.GetAllUserTokens(userId, p*config.ItemsPerPage, config.ItemsPerPage)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
@@ -119,9 +121,9 @@ func AddToken(c *gin.Context) {
cleanToken := model.Token{
UserId: c.GetInt("id"),
Name: token.Name,
- Key: common.GenerateKey(),
- CreatedTime: common.GetTimestamp(),
- AccessedTime: common.GetTimestamp(),
+ Key: helper.GenerateKey(),
+ CreatedTime: helper.GetTimestamp(),
+ AccessedTime: helper.GetTimestamp(),
ExpiredTime: token.ExpiredTime,
RemainQuota: token.RemainQuota,
UnlimitedQuota: token.UnlimitedQuota,
@@ -187,7 +189,7 @@ func UpdateToken(c *gin.Context) {
return
}
if token.Status == common.TokenStatusEnabled {
- if cleanToken.Status == common.TokenStatusExpired && cleanToken.ExpiredTime <= common.GetTimestamp() && cleanToken.ExpiredTime != -1 {
+ if cleanToken.Status == common.TokenStatusExpired && cleanToken.ExpiredTime <= helper.GetTimestamp() && cleanToken.ExpiredTime != -1 {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "令牌已过期,无法启用,请先修改令牌过期时间,或者设置为永不过期",
diff --git a/controller/user.go b/controller/user.go
index 8fd10b82..d13acddd 100644
--- a/controller/user.go
+++ b/controller/user.go
@@ -5,8 +5,11 @@ import (
"fmt"
"net/http"
"one-api/common"
+ "one-api/common/config"
+ "one-api/common/helper"
"one-api/model"
"strconv"
+ "time"
"github.com/gin-contrib/sessions"
"github.com/gin-gonic/gin"
@@ -18,7 +21,7 @@ type LoginRequest struct {
}
func Login(c *gin.Context) {
- if !common.PasswordLoginEnabled {
+ if !config.PasswordLoginEnabled {
c.JSON(http.StatusOK, gin.H{
"message": "管理员关闭了密码登录",
"success": false,
@@ -105,14 +108,14 @@ func Logout(c *gin.Context) {
}
func Register(c *gin.Context) {
- if !common.RegisterEnabled {
+ if !config.RegisterEnabled {
c.JSON(http.StatusOK, gin.H{
"message": "管理员关闭了新用户注册",
"success": false,
})
return
}
- if !common.PasswordRegisterEnabled {
+ if !config.PasswordRegisterEnabled {
c.JSON(http.StatusOK, gin.H{
"message": "管理员关闭了通过密码进行注册,请使用第三方账户验证的形式进行注册",
"success": false,
@@ -135,7 +138,7 @@ func Register(c *gin.Context) {
})
return
}
- if common.EmailVerificationEnabled {
+ if config.EmailVerificationEnabled {
if user.Email == "" || user.VerificationCode == "" {
c.JSON(http.StatusOK, gin.H{
"success": false,
@@ -159,7 +162,7 @@ func Register(c *gin.Context) {
DisplayName: user.Username,
InviterId: inviterId,
}
- if common.EmailVerificationEnabled {
+ if config.EmailVerificationEnabled {
cleanUser.Email = user.Email
}
if err := cleanUser.Insert(inviterId); err != nil {
@@ -181,7 +184,7 @@ func GetAllUsers(c *gin.Context) {
if p < 0 {
p = 0
}
- users, err := model.GetAllUsers(p*common.ItemsPerPage, common.ItemsPerPage)
+ users, err := model.GetAllUsers(p*config.ItemsPerPage, config.ItemsPerPage)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
@@ -248,6 +251,29 @@ func GetUser(c *gin.Context) {
return
}
+// GetUserDashboard returns per-day, per-model usage statistics for the
+// authenticated user, covering the last 7 calendar days (today inclusive).
+// Responds with {success, message, data} where data is []*model.LogStatistic.
+func GetUserDashboard(c *gin.Context) {
+	id := c.GetInt("id")
+	now := time.Now()
+	// Anchor the window to local midnight. time.Truncate(24h) rounds against
+	// the Unix epoch (i.e. UTC day boundaries), which misaligns the window by
+	// the local UTC offset; time.Date with now.Location() gives the real
+	// start of the local day.
+	startOfToday := time.Date(now.Year(), now.Month(), now.Day(), 0, 0, 0, 0, now.Location())
+	start := startOfToday.AddDate(0, 0, -6).Unix()                  // 6 days ago, 00:00:00 local
+	end := startOfToday.AddDate(0, 0, 1).Add(-time.Second).Unix()   // today, 23:59:59 local
+
+	dashboards, err := model.SearchLogsByDayAndModel(id, int(start), int(end))
+	if err != nil {
+		c.JSON(http.StatusOK, gin.H{
+			"success": false,
+			"message": "无法获取统计信息",
+			"data":    nil,
+		})
+		return
+	}
+	c.JSON(http.StatusOK, gin.H{
+		"success": true,
+		"message": "",
+		"data":    dashboards,
+	})
+}
+
func GenerateAccessToken(c *gin.Context) {
id := c.GetInt("id")
user, err := model.GetUserById(id, true)
@@ -258,7 +284,7 @@ func GenerateAccessToken(c *gin.Context) {
})
return
}
- user.AccessToken = common.GetUUID()
+ user.AccessToken = helper.GetUUID()
if model.DB.Where("access_token = ?", user.AccessToken).First(user).RowsAffected != 0 {
c.JSON(http.StatusOK, gin.H{
@@ -295,7 +321,7 @@ func GetAffCode(c *gin.Context) {
return
}
if user.AffCode == "" {
- user.AffCode = common.GetRandomString(4)
+ user.AffCode = helper.GetRandomString(4)
if err := user.Update(false); err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
@@ -702,7 +728,7 @@ func EmailBind(c *gin.Context) {
return
}
if user.Role == common.RoleRootUser {
- common.RootUserEmail = email
+ config.RootUserEmail = email
}
c.JSON(http.StatusOK, gin.H{
"success": true,
diff --git a/controller/wechat.go b/controller/wechat.go
index ff4c9fb6..fc231a24 100644
--- a/controller/wechat.go
+++ b/controller/wechat.go
@@ -7,6 +7,7 @@ import (
"github.com/gin-gonic/gin"
"net/http"
"one-api/common"
+ "one-api/common/config"
"one-api/model"
"strconv"
"time"
@@ -22,11 +23,11 @@ func getWeChatIdByCode(code string) (string, error) {
if code == "" {
return "", errors.New("无效的参数")
}
- req, err := http.NewRequest("GET", fmt.Sprintf("%s/api/wechat/user?code=%s", common.WeChatServerAddress, code), nil)
+ req, err := http.NewRequest("GET", fmt.Sprintf("%s/api/wechat/user?code=%s", config.WeChatServerAddress, code), nil)
if err != nil {
return "", err
}
- req.Header.Set("Authorization", common.WeChatServerToken)
+ req.Header.Set("Authorization", config.WeChatServerToken)
client := http.Client{
Timeout: 5 * time.Second,
}
@@ -50,7 +51,7 @@ func getWeChatIdByCode(code string) (string, error) {
}
func WeChatAuth(c *gin.Context) {
- if !common.WeChatAuthEnabled {
+ if !config.WeChatAuthEnabled {
c.JSON(http.StatusOK, gin.H{
"message": "管理员未开启通过微信登录以及注册",
"success": false,
@@ -79,7 +80,7 @@ func WeChatAuth(c *gin.Context) {
return
}
} else {
- if common.RegisterEnabled {
+ if config.RegisterEnabled {
user.Username = "wechat_" + strconv.Itoa(model.GetMaxUserId()+1)
user.DisplayName = "WeChat User"
user.Role = common.RoleCommonUser
@@ -112,7 +113,7 @@ func WeChatAuth(c *gin.Context) {
}
func WeChatBind(c *gin.Context) {
- if !common.WeChatAuthEnabled {
+ if !config.WeChatAuthEnabled {
c.JSON(http.StatusOK, gin.H{
"message": "管理员未开启通过微信登录以及注册",
"success": false,
diff --git a/i18n/en.json b/i18n/en.json
index f67d8665..774be837 100644
--- a/i18n/en.json
+++ b/i18n/en.json
@@ -86,6 +86,7 @@
"该令牌已过期": "The token has expired",
"该令牌额度已用尽": "The token quota has been used up",
"无效的令牌": "Invalid token",
+ "令牌验证失败": "Token verification failed",
"id 或 userId 为空!": "id or userId is empty!",
"quota 不能为负数!": "quota cannot be negative!",
"令牌额度不足": "Insufficient token quota",
diff --git a/main.go b/main.go
index 3ab1872c..b79c7bf7 100644
--- a/main.go
+++ b/main.go
@@ -7,9 +7,12 @@ import (
"github.com/gin-contrib/sessions/cookie"
"github.com/gin-gonic/gin"
"one-api/common"
+ "one-api/common/config"
+ "one-api/common/logger"
"one-api/controller"
"one-api/middleware"
"one-api/model"
+ "one-api/relay/channel/openai"
"one-api/router"
"os"
"strconv"
@@ -19,67 +22,68 @@ import (
var buildFS embed.FS
func main() {
- common.SetupLogger()
- common.SysLog(fmt.Sprintf("One API %s started with theme %s", common.Version, common.Theme))
+ logger.SetupLogger()
+ logger.SysLog(fmt.Sprintf("One API %s started", common.Version))
if os.Getenv("GIN_MODE") != "debug" {
gin.SetMode(gin.ReleaseMode)
}
- if common.DebugEnabled {
- common.SysLog("running in debug mode")
+ if config.DebugEnabled {
+ logger.SysLog("running in debug mode")
}
// Initialize SQL Database
err := model.InitDB()
if err != nil {
- common.FatalLog("failed to initialize database: " + err.Error())
+ logger.FatalLog("failed to initialize database: " + err.Error())
}
defer func() {
err := model.CloseDB()
if err != nil {
- common.FatalLog("failed to close database: " + err.Error())
+ logger.FatalLog("failed to close database: " + err.Error())
}
}()
// Initialize Redis
err = common.InitRedisClient()
if err != nil {
- common.FatalLog("failed to initialize Redis: " + err.Error())
+ logger.FatalLog("failed to initialize Redis: " + err.Error())
}
// Initialize options
model.InitOptionMap()
+ logger.SysLog(fmt.Sprintf("using theme %s", config.Theme))
if common.RedisEnabled {
// for compatibility with old versions
- common.MemoryCacheEnabled = true
+ config.MemoryCacheEnabled = true
}
- if common.MemoryCacheEnabled {
- common.SysLog("memory cache enabled")
- common.SysError(fmt.Sprintf("sync frequency: %d seconds", common.SyncFrequency))
+ if config.MemoryCacheEnabled {
+ logger.SysLog("memory cache enabled")
+ logger.SysError(fmt.Sprintf("sync frequency: %d seconds", config.SyncFrequency))
model.InitChannelCache()
}
- if common.MemoryCacheEnabled {
- go model.SyncOptions(common.SyncFrequency)
- go model.SyncChannelCache(common.SyncFrequency)
+ if config.MemoryCacheEnabled {
+ go model.SyncOptions(config.SyncFrequency)
+ go model.SyncChannelCache(config.SyncFrequency)
}
if os.Getenv("CHANNEL_UPDATE_FREQUENCY") != "" {
frequency, err := strconv.Atoi(os.Getenv("CHANNEL_UPDATE_FREQUENCY"))
if err != nil {
- common.FatalLog("failed to parse CHANNEL_UPDATE_FREQUENCY: " + err.Error())
+ logger.FatalLog("failed to parse CHANNEL_UPDATE_FREQUENCY: " + err.Error())
}
go controller.AutomaticallyUpdateChannels(frequency)
}
if os.Getenv("CHANNEL_TEST_FREQUENCY") != "" {
frequency, err := strconv.Atoi(os.Getenv("CHANNEL_TEST_FREQUENCY"))
if err != nil {
- common.FatalLog("failed to parse CHANNEL_TEST_FREQUENCY: " + err.Error())
+ logger.FatalLog("failed to parse CHANNEL_TEST_FREQUENCY: " + err.Error())
}
go controller.AutomaticallyTestChannels(frequency)
}
if os.Getenv("BATCH_UPDATE_ENABLED") == "true" {
- common.BatchUpdateEnabled = true
- common.SysLog("batch update enabled with interval " + strconv.Itoa(common.BatchUpdateInterval) + "s")
+ config.BatchUpdateEnabled = true
+ logger.SysLog("batch update enabled with interval " + strconv.Itoa(config.BatchUpdateInterval) + "s")
model.InitBatchUpdater()
}
- controller.InitTokenEncoders()
+ openai.InitTokenEncoders()
// Initialize HTTP server
server := gin.New()
@@ -89,7 +93,7 @@ func main() {
server.Use(middleware.RequestId())
middleware.SetUpLogger(server)
// Initialize session store
- store := cookie.NewStore([]byte(common.SessionSecret))
+ store := cookie.NewStore([]byte(config.SessionSecret))
server.Use(sessions.Sessions("session", store))
router.SetRouter(server, buildFS)
@@ -99,6 +103,6 @@ func main() {
}
err = server.Run(":" + port)
if err != nil {
- common.FatalLog("failed to start HTTP server: " + err.Error())
+ logger.FatalLog("failed to start HTTP server: " + err.Error())
}
}
diff --git a/middleware/distributor.go b/middleware/distributor.go
index 81338130..6b607d68 100644
--- a/middleware/distributor.go
+++ b/middleware/distributor.go
@@ -4,6 +4,7 @@ import (
"fmt"
"net/http"
"one-api/common"
+ "one-api/common/logger"
"one-api/model"
"strconv"
"strings"
@@ -69,7 +70,7 @@ func Distribute() func(c *gin.Context) {
if err != nil {
message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", userGroup, modelRequest.Model)
if channel != nil {
- common.SysError(fmt.Sprintf("渠道不存在:%d", channel.Id))
+ logger.SysError(fmt.Sprintf("渠道不存在:%d", channel.Id))
message = "数据库一致性已被破坏,请联系管理员"
}
abortWithMessage(c, http.StatusServiceUnavailable, message)
diff --git a/middleware/logger.go b/middleware/logger.go
index 02f2e0a9..ca372c52 100644
--- a/middleware/logger.go
+++ b/middleware/logger.go
@@ -3,14 +3,14 @@ package middleware
import (
"fmt"
"github.com/gin-gonic/gin"
- "one-api/common"
+ "one-api/common/logger"
)
func SetUpLogger(server *gin.Engine) {
server.Use(gin.LoggerWithFormatter(func(param gin.LogFormatterParams) string {
var requestID string
if param.Keys != nil {
- requestID = param.Keys[common.RequestIdKey].(string)
+ requestID = param.Keys[logger.RequestIdKey].(string)
}
return fmt.Sprintf("[GIN] %s | %s | %3d | %13v | %15s | %7s %s\n",
param.TimeStamp.Format("2006/01/02 - 15:04:05"),
diff --git a/middleware/rate-limit.go b/middleware/rate-limit.go
index 8e5cff6c..e89ef8d6 100644
--- a/middleware/rate-limit.go
+++ b/middleware/rate-limit.go
@@ -6,6 +6,7 @@ import (
"github.com/gin-gonic/gin"
"net/http"
"one-api/common"
+ "one-api/common/config"
"time"
)
@@ -26,7 +27,7 @@ func redisRateLimiter(c *gin.Context, maxRequestNum int, duration int64, mark st
}
if listLength < int64(maxRequestNum) {
rdb.LPush(ctx, key, time.Now().Format(timeFormat))
- rdb.Expire(ctx, key, common.RateLimitKeyExpirationDuration)
+ rdb.Expire(ctx, key, config.RateLimitKeyExpirationDuration)
} else {
oldTimeStr, _ := rdb.LIndex(ctx, key, -1).Result()
oldTime, err := time.Parse(timeFormat, oldTimeStr)
@@ -47,14 +48,14 @@ func redisRateLimiter(c *gin.Context, maxRequestNum int, duration int64, mark st
// time.Since will return negative number!
// See: https://stackoverflow.com/questions/50970900/why-is-time-since-returning-negative-durations-on-windows
if int64(nowTime.Sub(oldTime).Seconds()) < duration {
- rdb.Expire(ctx, key, common.RateLimitKeyExpirationDuration)
+ rdb.Expire(ctx, key, config.RateLimitKeyExpirationDuration)
c.Status(http.StatusTooManyRequests)
c.Abort()
return
} else {
rdb.LPush(ctx, key, time.Now().Format(timeFormat))
rdb.LTrim(ctx, key, 0, int64(maxRequestNum-1))
- rdb.Expire(ctx, key, common.RateLimitKeyExpirationDuration)
+ rdb.Expire(ctx, key, config.RateLimitKeyExpirationDuration)
}
}
}
@@ -75,7 +76,7 @@ func rateLimitFactory(maxRequestNum int, duration int64, mark string) func(c *gi
}
} else {
// It's safe to call multi times.
- inMemoryRateLimiter.Init(common.RateLimitKeyExpirationDuration)
+ inMemoryRateLimiter.Init(config.RateLimitKeyExpirationDuration)
return func(c *gin.Context) {
memoryRateLimiter(c, maxRequestNum, duration, mark)
}
@@ -83,21 +84,21 @@ func rateLimitFactory(maxRequestNum int, duration int64, mark string) func(c *gi
}
func GlobalWebRateLimit() func(c *gin.Context) {
- return rateLimitFactory(common.GlobalWebRateLimitNum, common.GlobalWebRateLimitDuration, "GW")
+ return rateLimitFactory(config.GlobalWebRateLimitNum, config.GlobalWebRateLimitDuration, "GW")
}
func GlobalAPIRateLimit() func(c *gin.Context) {
- return rateLimitFactory(common.GlobalApiRateLimitNum, common.GlobalApiRateLimitDuration, "GA")
+ return rateLimitFactory(config.GlobalApiRateLimitNum, config.GlobalApiRateLimitDuration, "GA")
}
func CriticalRateLimit() func(c *gin.Context) {
- return rateLimitFactory(common.CriticalRateLimitNum, common.CriticalRateLimitDuration, "CT")
+ return rateLimitFactory(config.CriticalRateLimitNum, config.CriticalRateLimitDuration, "CT")
}
func DownloadRateLimit() func(c *gin.Context) {
- return rateLimitFactory(common.DownloadRateLimitNum, common.DownloadRateLimitDuration, "DW")
+ return rateLimitFactory(config.DownloadRateLimitNum, config.DownloadRateLimitDuration, "DW")
}
func UploadRateLimit() func(c *gin.Context) {
- return rateLimitFactory(common.UploadRateLimitNum, common.UploadRateLimitDuration, "UP")
+ return rateLimitFactory(config.UploadRateLimitNum, config.UploadRateLimitDuration, "UP")
}
diff --git a/middleware/recover.go b/middleware/recover.go
index 8338a514..9d3edc27 100644
--- a/middleware/recover.go
+++ b/middleware/recover.go
@@ -4,7 +4,7 @@ import (
"fmt"
"github.com/gin-gonic/gin"
"net/http"
- "one-api/common"
+ "one-api/common/logger"
"runtime/debug"
)
@@ -12,8 +12,8 @@ func RelayPanicRecover() gin.HandlerFunc {
return func(c *gin.Context) {
defer func() {
if err := recover(); err != nil {
- common.SysError(fmt.Sprintf("panic detected: %v", err))
- common.SysError(fmt.Sprintf("stacktrace from panic: %s", string(debug.Stack())))
+ logger.SysError(fmt.Sprintf("panic detected: %v", err))
+ logger.SysError(fmt.Sprintf("stacktrace from panic: %s", string(debug.Stack())))
c.JSON(http.StatusInternalServerError, gin.H{
"error": gin.H{
"message": fmt.Sprintf("Panic detected, error: %v. Please submit a issue here: https://github.com/songquanpeng/one-api", err),
diff --git a/middleware/request-id.go b/middleware/request-id.go
index e623be7a..811a7ad8 100644
--- a/middleware/request-id.go
+++ b/middleware/request-id.go
@@ -3,16 +3,17 @@ package middleware
import (
"context"
"github.com/gin-gonic/gin"
- "one-api/common"
+ "one-api/common/helper"
+ "one-api/common/logger"
)
func RequestId() func(c *gin.Context) {
return func(c *gin.Context) {
- id := common.GetTimeString() + common.GetRandomString(8)
- c.Set(common.RequestIdKey, id)
- ctx := context.WithValue(c.Request.Context(), common.RequestIdKey, id)
+ id := helper.GetTimeString() + helper.GetRandomString(8)
+ c.Set(logger.RequestIdKey, id)
+ ctx := context.WithValue(c.Request.Context(), logger.RequestIdKey, id)
c.Request = c.Request.WithContext(ctx)
- c.Header(common.RequestIdKey, id)
+ c.Header(logger.RequestIdKey, id)
c.Next()
}
}
diff --git a/middleware/turnstile-check.go b/middleware/turnstile-check.go
index 26688810..629395e7 100644
--- a/middleware/turnstile-check.go
+++ b/middleware/turnstile-check.go
@@ -6,7 +6,8 @@ import (
"github.com/gin-gonic/gin"
"net/http"
"net/url"
- "one-api/common"
+ "one-api/common/config"
+ "one-api/common/logger"
)
type turnstileCheckResponse struct {
@@ -15,7 +16,7 @@ type turnstileCheckResponse struct {
func TurnstileCheck() gin.HandlerFunc {
return func(c *gin.Context) {
- if common.TurnstileCheckEnabled {
+ if config.TurnstileCheckEnabled {
session := sessions.Default(c)
turnstileChecked := session.Get("turnstile")
if turnstileChecked != nil {
@@ -32,12 +33,12 @@ func TurnstileCheck() gin.HandlerFunc {
return
}
rawRes, err := http.PostForm("https://challenges.cloudflare.com/turnstile/v0/siteverify", url.Values{
- "secret": {common.TurnstileSecretKey},
+ "secret": {config.TurnstileSecretKey},
"response": {response},
"remoteip": {c.ClientIP()},
})
if err != nil {
- common.SysError(err.Error())
+ logger.SysError(err.Error())
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
@@ -49,7 +50,7 @@ func TurnstileCheck() gin.HandlerFunc {
var res turnstileCheckResponse
err = json.NewDecoder(rawRes.Body).Decode(&res)
if err != nil {
- common.SysError(err.Error())
+ logger.SysError(err.Error())
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
diff --git a/middleware/utils.go b/middleware/utils.go
index 536125cc..d866d75b 100644
--- a/middleware/utils.go
+++ b/middleware/utils.go
@@ -2,16 +2,17 @@ package middleware
import (
"github.com/gin-gonic/gin"
- "one-api/common"
+ "one-api/common/helper"
+ "one-api/common/logger"
)
func abortWithMessage(c *gin.Context, statusCode int, message string) {
c.JSON(statusCode, gin.H{
"error": gin.H{
- "message": common.MessageWithRequestId(message, c.GetString(common.RequestIdKey)),
+ "message": helper.MessageWithRequestId(message, c.GetString(logger.RequestIdKey)),
"type": "one_api_error",
},
})
c.Abort()
- common.LogError(c.Request.Context(), message)
+ logger.Error(c.Request.Context(), message)
}
diff --git a/model/cache.go b/model/cache.go
index c6d0c70a..a81bdddd 100644
--- a/model/cache.go
+++ b/model/cache.go
@@ -6,6 +6,8 @@ import (
"fmt"
"math/rand"
"one-api/common"
+ "one-api/common/config"
+ "one-api/common/logger"
"sort"
"strconv"
"strings"
@@ -14,10 +16,10 @@ import (
)
var (
- TokenCacheSeconds = common.SyncFrequency
- UserId2GroupCacheSeconds = common.SyncFrequency
- UserId2QuotaCacheSeconds = common.SyncFrequency
- UserId2StatusCacheSeconds = common.SyncFrequency
+ TokenCacheSeconds = config.SyncFrequency
+ UserId2GroupCacheSeconds = config.SyncFrequency
+ UserId2QuotaCacheSeconds = config.SyncFrequency
+ UserId2StatusCacheSeconds = config.SyncFrequency
)
func CacheGetTokenByKey(key string) (*Token, error) {
@@ -42,7 +44,7 @@ func CacheGetTokenByKey(key string) (*Token, error) {
}
err = common.RedisSet(fmt.Sprintf("token:%s", key), string(jsonBytes), time.Duration(TokenCacheSeconds)*time.Second)
if err != nil {
- common.SysError("Redis set token error: " + err.Error())
+ logger.SysError("Redis set token error: " + err.Error())
}
return &token, nil
}
@@ -62,7 +64,7 @@ func CacheGetUserGroup(id int) (group string, err error) {
}
err = common.RedisSet(fmt.Sprintf("user_group:%d", id), group, time.Duration(UserId2GroupCacheSeconds)*time.Second)
if err != nil {
- common.SysError("Redis set user group error: " + err.Error())
+ logger.SysError("Redis set user group error: " + err.Error())
}
}
return group, err
@@ -80,7 +82,7 @@ func CacheGetUserQuota(id int) (quota int, err error) {
}
err = common.RedisSet(fmt.Sprintf("user_quota:%d", id), fmt.Sprintf("%d", quota), time.Duration(UserId2QuotaCacheSeconds)*time.Second)
if err != nil {
- common.SysError("Redis set user quota error: " + err.Error())
+ logger.SysError("Redis set user quota error: " + err.Error())
}
return quota, err
}
@@ -127,7 +129,7 @@ func CacheIsUserEnabled(userId int) (bool, error) {
}
err = common.RedisSet(fmt.Sprintf("user_enabled:%d", userId), enabled, time.Duration(UserId2StatusCacheSeconds)*time.Second)
if err != nil {
- common.SysError("Redis set user enabled error: " + err.Error())
+ logger.SysError("Redis set user enabled error: " + err.Error())
}
return userEnabled, err
}
@@ -178,19 +180,19 @@ func InitChannelCache() {
channelSyncLock.Lock()
group2model2channels = newGroup2model2channels
channelSyncLock.Unlock()
- common.SysLog("channels synced from database")
+ logger.SysLog("channels synced from database")
}
func SyncChannelCache(frequency int) {
for {
time.Sleep(time.Duration(frequency) * time.Second)
- common.SysLog("syncing channels from database")
+ logger.SysLog("syncing channels from database")
InitChannelCache()
}
}
func CacheGetRandomSatisfiedChannel(group string, model string) (*Channel, error) {
- if !common.MemoryCacheEnabled {
+ if !config.MemoryCacheEnabled {
return GetRandomSatisfiedChannel(group, model)
}
channelSyncLock.RLock()
diff --git a/model/channel.go b/model/channel.go
index 085e3ca4..4c77d902 100644
--- a/model/channel.go
+++ b/model/channel.go
@@ -1,8 +1,13 @@
package model
import (
+ "encoding/json"
+ "fmt"
"gorm.io/gorm"
"one-api/common"
+ "one-api/common/config"
+ "one-api/common/helper"
+ "one-api/common/logger"
)
type Channel struct {
@@ -42,7 +47,7 @@ func SearchChannels(keyword string) (channels []*Channel, err error) {
if common.UsingPostgreSQL {
keyCol = `"key"`
}
- err = DB.Omit("key").Where("id = ? or name LIKE ? or "+keyCol+" = ?", common.String2Int(keyword), keyword+"%", keyword).Find(&channels).Error
+ err = DB.Omit("key").Where("id = ? or name LIKE ? or "+keyCol+" = ?", helper.String2Int(keyword), keyword+"%", keyword).Find(&channels).Error
return channels, err
}
@@ -86,11 +91,17 @@ func (channel *Channel) GetBaseURL() string {
return *channel.BaseURL
}
-func (channel *Channel) GetModelMapping() string {
- if channel.ModelMapping == nil {
- return ""
+func (channel *Channel) GetModelMapping() map[string]string {
+ if channel.ModelMapping == nil || *channel.ModelMapping == "" || *channel.ModelMapping == "{}" {
+ return nil
}
- return *channel.ModelMapping
+ modelMapping := make(map[string]string)
+ err := json.Unmarshal([]byte(*channel.ModelMapping), &modelMapping)
+ if err != nil {
+ logger.SysError(fmt.Sprintf("failed to unmarshal model mapping for channel %d, error: %s", channel.Id, err.Error()))
+ return nil
+ }
+ return modelMapping
}
func (channel *Channel) Insert() error {
@@ -116,21 +127,21 @@ func (channel *Channel) Update() error {
func (channel *Channel) UpdateResponseTime(responseTime int64) {
err := DB.Model(channel).Select("response_time", "test_time").Updates(Channel{
- TestTime: common.GetTimestamp(),
+ TestTime: helper.GetTimestamp(),
ResponseTime: int(responseTime),
}).Error
if err != nil {
- common.SysError("failed to update response time: " + err.Error())
+ logger.SysError("failed to update response time: " + err.Error())
}
}
func (channel *Channel) UpdateBalance(balance float64) {
err := DB.Model(channel).Select("balance_updated_time", "balance").Updates(Channel{
- BalanceUpdatedTime: common.GetTimestamp(),
+ BalanceUpdatedTime: helper.GetTimestamp(),
Balance: balance,
}).Error
if err != nil {
- common.SysError("failed to update balance: " + err.Error())
+ logger.SysError("failed to update balance: " + err.Error())
}
}
@@ -147,16 +158,16 @@ func (channel *Channel) Delete() error {
func UpdateChannelStatusById(id int, status int) {
err := UpdateAbilityStatus(id, status == common.ChannelStatusEnabled)
if err != nil {
- common.SysError("failed to update ability status: " + err.Error())
+ logger.SysError("failed to update ability status: " + err.Error())
}
err = DB.Model(&Channel{}).Where("id = ?", id).Update("status", status).Error
if err != nil {
- common.SysError("failed to update channel status: " + err.Error())
+ logger.SysError("failed to update channel status: " + err.Error())
}
}
func UpdateChannelUsedQuota(id int, quota int) {
- if common.BatchUpdateEnabled {
+ if config.BatchUpdateEnabled {
addNewRecord(BatchUpdateTypeChannelUsedQuota, id, quota)
return
}
@@ -166,7 +177,7 @@ func UpdateChannelUsedQuota(id int, quota int) {
func updateChannelUsedQuota(id int, quota int) {
err := DB.Model(&Channel{}).Where("id = ?", id).Update("used_quota", gorm.Expr("used_quota + ?", quota)).Error
if err != nil {
- common.SysError("failed to update channel used quota: " + err.Error())
+ logger.SysError("failed to update channel used quota: " + err.Error())
}
}
diff --git a/model/log.go b/model/log.go
index 3d3ffae3..78b6d9b3 100644
--- a/model/log.go
+++ b/model/log.go
@@ -3,14 +3,18 @@ package model
import (
"context"
"fmt"
- "gorm.io/gorm"
"one-api/common"
+ "one-api/common/config"
+ "one-api/common/helper"
+ "one-api/common/logger"
+
+ "gorm.io/gorm"
)
type Log struct {
- Id int `json:"id;index:idx_created_at_id,priority:1"`
+ Id int `json:"id"`
UserId int `json:"user_id" gorm:"index"`
- CreatedAt int64 `json:"created_at" gorm:"bigint;index:idx_created_at_id,priority:2;index:idx_created_at_type"`
+ CreatedAt int64 `json:"created_at" gorm:"bigint;index:idx_created_at_type"`
Type int `json:"type" gorm:"index:idx_created_at_type"`
Content string `json:"content"`
Username string `json:"username" gorm:"index:index_username_model_name,priority:2;default:''"`
@@ -31,31 +35,31 @@ const (
)
func RecordLog(userId int, logType int, content string) {
- if logType == LogTypeConsume && !common.LogConsumeEnabled {
+ if logType == LogTypeConsume && !config.LogConsumeEnabled {
return
}
log := &Log{
UserId: userId,
Username: GetUsernameById(userId),
- CreatedAt: common.GetTimestamp(),
+ CreatedAt: helper.GetTimestamp(),
Type: logType,
Content: content,
}
err := DB.Create(log).Error
if err != nil {
- common.SysError("failed to record log: " + err.Error())
+ logger.SysError("failed to record log: " + err.Error())
}
}
func RecordConsumeLog(ctx context.Context, userId int, channelId int, promptTokens int, completionTokens int, modelName string, tokenName string, quota int, content string) {
- common.LogInfo(ctx, fmt.Sprintf("record consume log: userId=%d, channelId=%d, promptTokens=%d, completionTokens=%d, modelName=%s, tokenName=%s, quota=%d, content=%s", userId, channelId, promptTokens, completionTokens, modelName, tokenName, quota, content))
- if !common.LogConsumeEnabled {
+ logger.Info(ctx, fmt.Sprintf("record consume log: userId=%d, channelId=%d, promptTokens=%d, completionTokens=%d, modelName=%s, tokenName=%s, quota=%d, content=%s", userId, channelId, promptTokens, completionTokens, modelName, tokenName, quota, content))
+ if !config.LogConsumeEnabled {
return
}
log := &Log{
UserId: userId,
Username: GetUsernameById(userId),
- CreatedAt: common.GetTimestamp(),
+ CreatedAt: helper.GetTimestamp(),
Type: LogTypeConsume,
Content: content,
PromptTokens: promptTokens,
@@ -67,7 +71,7 @@ func RecordConsumeLog(ctx context.Context, userId int, channelId int, promptToke
}
err := DB.Create(log).Error
if err != nil {
- common.LogError(ctx, "failed to record log: "+err.Error())
+ logger.Error(ctx, "failed to record log: "+err.Error())
}
}
@@ -124,12 +128,12 @@ func GetUserLogs(userId int, logType int, startTimestamp int64, endTimestamp int
}
func SearchAllLogs(keyword string) (logs []*Log, err error) {
- err = DB.Where("type = ? or content LIKE ?", keyword, keyword+"%").Order("id desc").Limit(common.MaxRecentItems).Find(&logs).Error
+ err = DB.Where("type = ? or content LIKE ?", keyword, keyword+"%").Order("id desc").Limit(config.MaxRecentItems).Find(&logs).Error
return logs, err
}
func SearchUserLogs(userId int, keyword string) (logs []*Log, err error) {
- err = DB.Where("user_id = ? and type = ?", userId, keyword).Order("id desc").Limit(common.MaxRecentItems).Omit("id").Find(&logs).Error
+ err = DB.Where("user_id = ? and type = ?", userId, keyword).Order("id desc").Limit(config.MaxRecentItems).Omit("id").Find(&logs).Error
return logs, err
}
@@ -182,3 +186,40 @@ func DeleteOldLog(targetTimestamp int64) (int64, error) {
result := DB.Where("created_at < ?", targetTimestamp).Delete(&Log{})
return result.RowsAffected, result.Error
}
+
+// LogStatistic is one row of the per-day, per-model usage aggregation
+// produced by SearchLogsByDayAndModel.
+type LogStatistic struct {
+	Day              string `gorm:"column:day"`
+	ModelName        string `gorm:"column:model_name"`
+	RequestCount     int    `gorm:"column:request_count"`
+	Quota            int    `gorm:"column:quota"`
+	PromptTokens     int    `gorm:"column:prompt_tokens"`
+	CompletionTokens int    `gorm:"column:completion_tokens"`
+}
+
+// SearchLogsByDayAndModel aggregates the user's consume logs (type = 2,
+// LogTypeConsume) between the start and end Unix timestamps (inclusive),
+// grouped by calendar day and model name. The day-formatting expression is
+// chosen per database dialect, since each engine spells the
+// unix-timestamp-to-date conversion differently.
+func SearchLogsByDayAndModel(userId, start, end int) (statistics []*LogStatistic, err error) {
+	// MySQL (default dialect)
+	groupSelect := "DATE_FORMAT(FROM_UNIXTIME(created_at), '%Y-%m-%d') as day"
+
+	if common.UsingPostgreSQL {
+		groupSelect = "TO_CHAR(date_trunc('day', to_timestamp(created_at)), 'YYYY-MM-DD') as day"
+	}
+
+	if common.UsingSQLite {
+		groupSelect = "strftime('%Y-%m-%d', datetime(created_at, 'unixepoch')) as day"
+	}
+
+	err = DB.Raw(`
+		SELECT `+groupSelect+`,
+		model_name, count(1) as request_count,
+		sum(quota) as quota,
+		sum(prompt_tokens) as prompt_tokens,
+		sum(completion_tokens) as completion_tokens
+		FROM logs
+		WHERE type = 2
+		AND user_id = ?
+		AND created_at BETWEEN ? AND ?
+		GROUP BY day, model_name
+		ORDER BY day, model_name
+	`, userId, start, end).Scan(&statistics).Error
+
+	return statistics, err
+}
diff --git a/model/main.go b/model/main.go
index bfd6888b..2ed6f0e3 100644
--- a/model/main.go
+++ b/model/main.go
@@ -7,6 +7,9 @@ import (
"gorm.io/driver/sqlite"
"gorm.io/gorm"
"one-api/common"
+ "one-api/common/config"
+ "one-api/common/helper"
+ "one-api/common/logger"
"os"
"strings"
"time"
@@ -16,9 +19,9 @@ var DB *gorm.DB
func createRootAccountIfNeed() error {
var user User
- //if user.Status != common.UserStatusEnabled {
+ //if user.Status != util.UserStatusEnabled {
if err := DB.First(&user).Error; err != nil {
- common.SysLog("no user exists, create a root user for you: username is root, password is 123456")
+ logger.SysLog("no user exists, create a root user for you: username is root, password is 123456")
hashedPassword, err := common.Password2Hash("123456")
if err != nil {
return err
@@ -29,7 +32,7 @@ func createRootAccountIfNeed() error {
Role: common.RoleRootUser,
Status: common.UserStatusEnabled,
DisplayName: "Root User",
- AccessToken: common.GetUUID(),
+ AccessToken: helper.GetUUID(),
Quota: 100000000,
}
DB.Create(&rootUser)
@@ -42,7 +45,7 @@ func chooseDB() (*gorm.DB, error) {
dsn := os.Getenv("SQL_DSN")
if strings.HasPrefix(dsn, "postgres://") {
// Use PostgreSQL
- common.SysLog("using PostgreSQL as database")
+ logger.SysLog("using PostgreSQL as database")
common.UsingPostgreSQL = true
return gorm.Open(postgres.New(postgres.Config{
DSN: dsn,
@@ -52,13 +55,13 @@ func chooseDB() (*gorm.DB, error) {
})
}
// Use MySQL
- common.SysLog("using MySQL as database")
+ logger.SysLog("using MySQL as database")
return gorm.Open(mysql.Open(dsn), &gorm.Config{
PrepareStmt: true, // precompile SQL
})
}
// Use SQLite
- common.SysLog("SQL_DSN not set, using SQLite as database")
+ logger.SysLog("SQL_DSN not set, using SQLite as database")
common.UsingSQLite = true
config := fmt.Sprintf("?_busy_timeout=%d", common.SQLiteBusyTimeout)
return gorm.Open(sqlite.Open(common.SQLitePath+config), &gorm.Config{
@@ -69,7 +72,7 @@ func chooseDB() (*gorm.DB, error) {
func InitDB() (err error) {
db, err := chooseDB()
if err == nil {
- if common.DebugEnabled {
+ if config.DebugEnabled {
db = db.Debug()
}
DB = db
@@ -77,14 +80,14 @@ func InitDB() (err error) {
if err != nil {
return err
}
- sqlDB.SetMaxIdleConns(common.GetOrDefault("SQL_MAX_IDLE_CONNS", 100))
- sqlDB.SetMaxOpenConns(common.GetOrDefault("SQL_MAX_OPEN_CONNS", 1000))
- sqlDB.SetConnMaxLifetime(time.Second * time.Duration(common.GetOrDefault("SQL_MAX_LIFETIME", 60)))
+ sqlDB.SetMaxIdleConns(helper.GetOrDefaultEnvInt("SQL_MAX_IDLE_CONNS", 100))
+ sqlDB.SetMaxOpenConns(helper.GetOrDefaultEnvInt("SQL_MAX_OPEN_CONNS", 1000))
+ sqlDB.SetConnMaxLifetime(time.Second * time.Duration(helper.GetOrDefaultEnvInt("SQL_MAX_LIFETIME", 60)))
- if !common.IsMasterNode {
+ if !config.IsMasterNode {
return nil
}
- common.SysLog("database migration started")
+ logger.SysLog("database migration started")
err = db.AutoMigrate(&Channel{})
if err != nil {
return err
@@ -113,11 +116,11 @@ func InitDB() (err error) {
if err != nil {
return err
}
- common.SysLog("database migrated")
+ logger.SysLog("database migrated")
err = createRootAccountIfNeed()
return err
} else {
- common.FatalLog(err)
+ logger.FatalLog(err)
}
return err
}
diff --git a/model/option.go b/model/option.go
index bb8b709c..e211264c 100644
--- a/model/option.go
+++ b/model/option.go
@@ -2,6 +2,8 @@ package model
import (
"one-api/common"
+ "one-api/common/config"
+ "one-api/common/logger"
"strconv"
"strings"
"time"
@@ -20,59 +22,56 @@ func AllOption() ([]*Option, error) {
}
func InitOptionMap() {
- common.OptionMapRWMutex.Lock()
- common.OptionMap = make(map[string]string)
- common.OptionMap["FileUploadPermission"] = strconv.Itoa(common.FileUploadPermission)
- common.OptionMap["FileDownloadPermission"] = strconv.Itoa(common.FileDownloadPermission)
- common.OptionMap["ImageUploadPermission"] = strconv.Itoa(common.ImageUploadPermission)
- common.OptionMap["ImageDownloadPermission"] = strconv.Itoa(common.ImageDownloadPermission)
- common.OptionMap["PasswordLoginEnabled"] = strconv.FormatBool(common.PasswordLoginEnabled)
- common.OptionMap["PasswordRegisterEnabled"] = strconv.FormatBool(common.PasswordRegisterEnabled)
- common.OptionMap["EmailVerificationEnabled"] = strconv.FormatBool(common.EmailVerificationEnabled)
- common.OptionMap["GitHubOAuthEnabled"] = strconv.FormatBool(common.GitHubOAuthEnabled)
- common.OptionMap["WeChatAuthEnabled"] = strconv.FormatBool(common.WeChatAuthEnabled)
- common.OptionMap["TurnstileCheckEnabled"] = strconv.FormatBool(common.TurnstileCheckEnabled)
- common.OptionMap["RegisterEnabled"] = strconv.FormatBool(common.RegisterEnabled)
- common.OptionMap["AutomaticDisableChannelEnabled"] = strconv.FormatBool(common.AutomaticDisableChannelEnabled)
- common.OptionMap["AutomaticEnableChannelEnabled"] = strconv.FormatBool(common.AutomaticEnableChannelEnabled)
- common.OptionMap["ApproximateTokenEnabled"] = strconv.FormatBool(common.ApproximateTokenEnabled)
- common.OptionMap["LogConsumeEnabled"] = strconv.FormatBool(common.LogConsumeEnabled)
- common.OptionMap["DisplayInCurrencyEnabled"] = strconv.FormatBool(common.DisplayInCurrencyEnabled)
- common.OptionMap["DisplayTokenStatEnabled"] = strconv.FormatBool(common.DisplayTokenStatEnabled)
- common.OptionMap["ChannelDisableThreshold"] = strconv.FormatFloat(common.ChannelDisableThreshold, 'f', -1, 64)
- common.OptionMap["EmailDomainRestrictionEnabled"] = strconv.FormatBool(common.EmailDomainRestrictionEnabled)
- common.OptionMap["EmailDomainWhitelist"] = strings.Join(common.EmailDomainWhitelist, ",")
- common.OptionMap["SMTPServer"] = ""
- common.OptionMap["SMTPFrom"] = ""
- common.OptionMap["SMTPPort"] = strconv.Itoa(common.SMTPPort)
- common.OptionMap["SMTPAccount"] = ""
- common.OptionMap["SMTPToken"] = ""
- common.OptionMap["Notice"] = ""
- common.OptionMap["About"] = ""
- common.OptionMap["HomePageContent"] = ""
- common.OptionMap["Footer"] = common.Footer
- common.OptionMap["SystemName"] = common.SystemName
- common.OptionMap["Logo"] = common.Logo
- common.OptionMap["ServerAddress"] = ""
- common.OptionMap["GitHubClientId"] = ""
- common.OptionMap["GitHubClientSecret"] = ""
- common.OptionMap["WeChatServerAddress"] = ""
- common.OptionMap["WeChatServerToken"] = ""
- common.OptionMap["WeChatAccountQRCodeImageURL"] = ""
- common.OptionMap["TurnstileSiteKey"] = ""
- common.OptionMap["TurnstileSecretKey"] = ""
- common.OptionMap["QuotaForNewUser"] = strconv.Itoa(common.QuotaForNewUser)
- common.OptionMap["QuotaForInviter"] = strconv.Itoa(common.QuotaForInviter)
- common.OptionMap["QuotaForInvitee"] = strconv.Itoa(common.QuotaForInvitee)
- common.OptionMap["QuotaRemindThreshold"] = strconv.Itoa(common.QuotaRemindThreshold)
- common.OptionMap["PreConsumedQuota"] = strconv.Itoa(common.PreConsumedQuota)
- common.OptionMap["ModelRatio"] = common.ModelRatio2JSONString()
- common.OptionMap["GroupRatio"] = common.GroupRatio2JSONString()
- common.OptionMap["TopUpLink"] = common.TopUpLink
- common.OptionMap["ChatLink"] = common.ChatLink
- common.OptionMap["QuotaPerUnit"] = strconv.FormatFloat(common.QuotaPerUnit, 'f', -1, 64)
- common.OptionMap["RetryTimes"] = strconv.Itoa(common.RetryTimes)
- common.OptionMapRWMutex.Unlock()
+ config.OptionMapRWMutex.Lock()
+ config.OptionMap = make(map[string]string)
+ config.OptionMap["PasswordLoginEnabled"] = strconv.FormatBool(config.PasswordLoginEnabled)
+ config.OptionMap["PasswordRegisterEnabled"] = strconv.FormatBool(config.PasswordRegisterEnabled)
+ config.OptionMap["EmailVerificationEnabled"] = strconv.FormatBool(config.EmailVerificationEnabled)
+ config.OptionMap["GitHubOAuthEnabled"] = strconv.FormatBool(config.GitHubOAuthEnabled)
+ config.OptionMap["WeChatAuthEnabled"] = strconv.FormatBool(config.WeChatAuthEnabled)
+ config.OptionMap["TurnstileCheckEnabled"] = strconv.FormatBool(config.TurnstileCheckEnabled)
+ config.OptionMap["RegisterEnabled"] = strconv.FormatBool(config.RegisterEnabled)
+ config.OptionMap["AutomaticDisableChannelEnabled"] = strconv.FormatBool(config.AutomaticDisableChannelEnabled)
+ config.OptionMap["AutomaticEnableChannelEnabled"] = strconv.FormatBool(config.AutomaticEnableChannelEnabled)
+ config.OptionMap["ApproximateTokenEnabled"] = strconv.FormatBool(config.ApproximateTokenEnabled)
+ config.OptionMap["LogConsumeEnabled"] = strconv.FormatBool(config.LogConsumeEnabled)
+ config.OptionMap["DisplayInCurrencyEnabled"] = strconv.FormatBool(config.DisplayInCurrencyEnabled)
+ config.OptionMap["DisplayTokenStatEnabled"] = strconv.FormatBool(config.DisplayTokenStatEnabled)
+ config.OptionMap["ChannelDisableThreshold"] = strconv.FormatFloat(config.ChannelDisableThreshold, 'f', -1, 64)
+ config.OptionMap["EmailDomainRestrictionEnabled"] = strconv.FormatBool(config.EmailDomainRestrictionEnabled)
+ config.OptionMap["EmailDomainWhitelist"] = strings.Join(config.EmailDomainWhitelist, ",")
+ config.OptionMap["SMTPServer"] = ""
+ config.OptionMap["SMTPFrom"] = ""
+ config.OptionMap["SMTPPort"] = strconv.Itoa(config.SMTPPort)
+ config.OptionMap["SMTPAccount"] = ""
+ config.OptionMap["SMTPToken"] = ""
+ config.OptionMap["Notice"] = ""
+ config.OptionMap["About"] = ""
+ config.OptionMap["HomePageContent"] = ""
+ config.OptionMap["Footer"] = config.Footer
+ config.OptionMap["SystemName"] = config.SystemName
+ config.OptionMap["Logo"] = config.Logo
+ config.OptionMap["ServerAddress"] = ""
+ config.OptionMap["GitHubClientId"] = ""
+ config.OptionMap["GitHubClientSecret"] = ""
+ config.OptionMap["WeChatServerAddress"] = ""
+ config.OptionMap["WeChatServerToken"] = ""
+ config.OptionMap["WeChatAccountQRCodeImageURL"] = ""
+ config.OptionMap["TurnstileSiteKey"] = ""
+ config.OptionMap["TurnstileSecretKey"] = ""
+ config.OptionMap["QuotaForNewUser"] = strconv.Itoa(config.QuotaForNewUser)
+ config.OptionMap["QuotaForInviter"] = strconv.Itoa(config.QuotaForInviter)
+ config.OptionMap["QuotaForInvitee"] = strconv.Itoa(config.QuotaForInvitee)
+ config.OptionMap["QuotaRemindThreshold"] = strconv.Itoa(config.QuotaRemindThreshold)
+ config.OptionMap["PreConsumedQuota"] = strconv.Itoa(config.PreConsumedQuota)
+ config.OptionMap["ModelRatio"] = common.ModelRatio2JSONString()
+ config.OptionMap["GroupRatio"] = common.GroupRatio2JSONString()
+ config.OptionMap["TopUpLink"] = config.TopUpLink
+ config.OptionMap["ChatLink"] = config.ChatLink
+ config.OptionMap["QuotaPerUnit"] = strconv.FormatFloat(config.QuotaPerUnit, 'f', -1, 64)
+ config.OptionMap["RetryTimes"] = strconv.Itoa(config.RetryTimes)
+ config.OptionMap["Theme"] = config.Theme
+ config.OptionMapRWMutex.Unlock()
loadOptionsFromDatabase()
}
@@ -81,7 +80,7 @@ func loadOptionsFromDatabase() {
for _, option := range options {
err := updateOptionMap(option.Key, option.Value)
if err != nil {
- common.SysError("failed to update option map: " + err.Error())
+ logger.SysError("failed to update option map: " + err.Error())
}
}
}
@@ -89,7 +88,7 @@ func loadOptionsFromDatabase() {
func SyncOptions(frequency int) {
for {
time.Sleep(time.Duration(frequency) * time.Second)
- common.SysLog("syncing options from database")
+ logger.SysLog("syncing options from database")
loadOptionsFromDatabase()
}
}
@@ -111,115 +110,104 @@ func UpdateOption(key string, value string) error {
}
func updateOptionMap(key string, value string) (err error) {
- common.OptionMapRWMutex.Lock()
- defer common.OptionMapRWMutex.Unlock()
- common.OptionMap[key] = value
- if strings.HasSuffix(key, "Permission") {
- intValue, _ := strconv.Atoi(value)
- switch key {
- case "FileUploadPermission":
- common.FileUploadPermission = intValue
- case "FileDownloadPermission":
- common.FileDownloadPermission = intValue
- case "ImageUploadPermission":
- common.ImageUploadPermission = intValue
- case "ImageDownloadPermission":
- common.ImageDownloadPermission = intValue
- }
- }
+ config.OptionMapRWMutex.Lock()
+ defer config.OptionMapRWMutex.Unlock()
+ config.OptionMap[key] = value
if strings.HasSuffix(key, "Enabled") {
boolValue := value == "true"
switch key {
case "PasswordRegisterEnabled":
- common.PasswordRegisterEnabled = boolValue
+ config.PasswordRegisterEnabled = boolValue
case "PasswordLoginEnabled":
- common.PasswordLoginEnabled = boolValue
+ config.PasswordLoginEnabled = boolValue
case "EmailVerificationEnabled":
- common.EmailVerificationEnabled = boolValue
+ config.EmailVerificationEnabled = boolValue
case "GitHubOAuthEnabled":
- common.GitHubOAuthEnabled = boolValue
+ config.GitHubOAuthEnabled = boolValue
case "WeChatAuthEnabled":
- common.WeChatAuthEnabled = boolValue
+ config.WeChatAuthEnabled = boolValue
case "TurnstileCheckEnabled":
- common.TurnstileCheckEnabled = boolValue
+ config.TurnstileCheckEnabled = boolValue
case "RegisterEnabled":
- common.RegisterEnabled = boolValue
+ config.RegisterEnabled = boolValue
case "EmailDomainRestrictionEnabled":
- common.EmailDomainRestrictionEnabled = boolValue
+ config.EmailDomainRestrictionEnabled = boolValue
case "AutomaticDisableChannelEnabled":
- common.AutomaticDisableChannelEnabled = boolValue
+ config.AutomaticDisableChannelEnabled = boolValue
case "AutomaticEnableChannelEnabled":
- common.AutomaticEnableChannelEnabled = boolValue
+ config.AutomaticEnableChannelEnabled = boolValue
case "ApproximateTokenEnabled":
- common.ApproximateTokenEnabled = boolValue
+ config.ApproximateTokenEnabled = boolValue
case "LogConsumeEnabled":
- common.LogConsumeEnabled = boolValue
+ config.LogConsumeEnabled = boolValue
case "DisplayInCurrencyEnabled":
- common.DisplayInCurrencyEnabled = boolValue
+ config.DisplayInCurrencyEnabled = boolValue
case "DisplayTokenStatEnabled":
- common.DisplayTokenStatEnabled = boolValue
+ config.DisplayTokenStatEnabled = boolValue
}
}
switch key {
case "EmailDomainWhitelist":
- common.EmailDomainWhitelist = strings.Split(value, ",")
+ config.EmailDomainWhitelist = strings.Split(value, ",")
case "SMTPServer":
- common.SMTPServer = value
+ config.SMTPServer = value
case "SMTPPort":
intValue, _ := strconv.Atoi(value)
- common.SMTPPort = intValue
+ config.SMTPPort = intValue
case "SMTPAccount":
- common.SMTPAccount = value
+ config.SMTPAccount = value
case "SMTPFrom":
- common.SMTPFrom = value
+ config.SMTPFrom = value
case "SMTPToken":
- common.SMTPToken = value
+ config.SMTPToken = value
case "ServerAddress":
- common.ServerAddress = value
+ config.ServerAddress = value
case "GitHubClientId":
- common.GitHubClientId = value
+ config.GitHubClientId = value
case "GitHubClientSecret":
- common.GitHubClientSecret = value
+ config.GitHubClientSecret = value
case "Footer":
- common.Footer = value
+ config.Footer = value
case "SystemName":
- common.SystemName = value
+ config.SystemName = value
case "Logo":
- common.Logo = value
+ config.Logo = value
case "WeChatServerAddress":
- common.WeChatServerAddress = value
+ config.WeChatServerAddress = value
case "WeChatServerToken":
- common.WeChatServerToken = value
+ config.WeChatServerToken = value
case "WeChatAccountQRCodeImageURL":
- common.WeChatAccountQRCodeImageURL = value
+ config.WeChatAccountQRCodeImageURL = value
case "TurnstileSiteKey":
- common.TurnstileSiteKey = value
+ config.TurnstileSiteKey = value
case "TurnstileSecretKey":
- common.TurnstileSecretKey = value
+ config.TurnstileSecretKey = value
case "QuotaForNewUser":
- common.QuotaForNewUser, _ = strconv.Atoi(value)
+ config.QuotaForNewUser, _ = strconv.Atoi(value)
case "QuotaForInviter":
- common.QuotaForInviter, _ = strconv.Atoi(value)
+ config.QuotaForInviter, _ = strconv.Atoi(value)
case "QuotaForInvitee":
- common.QuotaForInvitee, _ = strconv.Atoi(value)
+ config.QuotaForInvitee, _ = strconv.Atoi(value)
case "QuotaRemindThreshold":
- common.QuotaRemindThreshold, _ = strconv.Atoi(value)
+ config.QuotaRemindThreshold, _ = strconv.Atoi(value)
case "PreConsumedQuota":
- common.PreConsumedQuota, _ = strconv.Atoi(value)
+ config.PreConsumedQuota, _ = strconv.Atoi(value)
case "RetryTimes":
- common.RetryTimes, _ = strconv.Atoi(value)
+ config.RetryTimes, _ = strconv.Atoi(value)
case "ModelRatio":
err = common.UpdateModelRatioByJSONString(value)
case "GroupRatio":
err = common.UpdateGroupRatioByJSONString(value)
case "TopUpLink":
- common.TopUpLink = value
+ config.TopUpLink = value
case "ChatLink":
- common.ChatLink = value
+ config.ChatLink = value
case "ChannelDisableThreshold":
- common.ChannelDisableThreshold, _ = strconv.ParseFloat(value, 64)
+ config.ChannelDisableThreshold, _ = strconv.ParseFloat(value, 64)
case "QuotaPerUnit":
- common.QuotaPerUnit, _ = strconv.ParseFloat(value, 64)
+ config.QuotaPerUnit, _ = strconv.ParseFloat(value, 64)
+ case "Theme":
+ config.Theme = value
}
return err
}
diff --git a/model/redemption.go b/model/redemption.go
index f16412b5..026794e0 100644
--- a/model/redemption.go
+++ b/model/redemption.go
@@ -5,6 +5,7 @@ import (
"fmt"
"gorm.io/gorm"
"one-api/common"
+ "one-api/common/helper"
)
type Redemption struct {
@@ -67,7 +68,7 @@ func Redeem(key string, userId int) (quota int, err error) {
if err != nil {
return err
}
- redemption.RedeemedTime = common.GetTimestamp()
+ redemption.RedeemedTime = helper.GetTimestamp()
redemption.Status = common.RedemptionCodeStatusUsed
err = tx.Save(redemption).Error
return err
diff --git a/model/token.go b/model/token.go
index 0fa984d3..2087225b 100644
--- a/model/token.go
+++ b/model/token.go
@@ -5,6 +5,9 @@ import (
"fmt"
"gorm.io/gorm"
"one-api/common"
+ "one-api/common/config"
+ "one-api/common/helper"
+ "one-api/common/logger"
)
type Token struct {
@@ -38,39 +41,43 @@ func ValidateUserToken(key string) (token *Token, err error) {
return nil, errors.New("未提供令牌")
}
token, err = CacheGetTokenByKey(key)
- if err == nil {
- if token.Status == common.TokenStatusExhausted {
- return nil, errors.New("该令牌额度已用尽")
- } else if token.Status == common.TokenStatusExpired {
- return nil, errors.New("该令牌已过期")
+ if err != nil {
+ logger.SysError("CacheGetTokenByKey failed: " + err.Error())
+ if errors.Is(err, gorm.ErrRecordNotFound) {
+ return nil, errors.New("无效的令牌")
}
- if token.Status != common.TokenStatusEnabled {
- return nil, errors.New("该令牌状态不可用")
- }
- if token.ExpiredTime != -1 && token.ExpiredTime < common.GetTimestamp() {
- if !common.RedisEnabled {
- token.Status = common.TokenStatusExpired
- err := token.SelectUpdate()
- if err != nil {
- common.SysError("failed to update token status" + err.Error())
- }
- }
- return nil, errors.New("该令牌已过期")
- }
- if !token.UnlimitedQuota && token.RemainQuota <= 0 {
- if !common.RedisEnabled {
- // in this case, we can make sure the token is exhausted
- token.Status = common.TokenStatusExhausted
- err := token.SelectUpdate()
- if err != nil {
- common.SysError("failed to update token status" + err.Error())
- }
- }
- return nil, errors.New("该令牌额度已用尽")
- }
- return token, nil
+ return nil, errors.New("令牌验证失败")
}
- return nil, errors.New("无效的令牌")
+ if token.Status == common.TokenStatusExhausted {
+ return nil, errors.New("该令牌额度已用尽")
+ } else if token.Status == common.TokenStatusExpired {
+ return nil, errors.New("该令牌已过期")
+ }
+ if token.Status != common.TokenStatusEnabled {
+ return nil, errors.New("该令牌状态不可用")
+ }
+ if token.ExpiredTime != -1 && token.ExpiredTime < helper.GetTimestamp() {
+ if !common.RedisEnabled {
+ token.Status = common.TokenStatusExpired
+ err := token.SelectUpdate()
+ if err != nil {
 + logger.SysError("failed to update token status: " + err.Error())
+ }
+ }
+ return nil, errors.New("该令牌已过期")
+ }
+ if !token.UnlimitedQuota && token.RemainQuota <= 0 {
+ if !common.RedisEnabled {
+ // in this case, we can make sure the token is exhausted
+ token.Status = common.TokenStatusExhausted
+ err := token.SelectUpdate()
+ if err != nil {
 + logger.SysError("failed to update token status: " + err.Error())
+ }
+ }
+ return nil, errors.New("该令牌额度已用尽")
+ }
+ return token, nil
}
func GetTokenByIds(id int, userId int) (*Token, error) {
@@ -134,7 +141,7 @@ func IncreaseTokenQuota(id int, quota int) (err error) {
if quota < 0 {
return errors.New("quota 不能为负数!")
}
- if common.BatchUpdateEnabled {
+ if config.BatchUpdateEnabled {
addNewRecord(BatchUpdateTypeTokenQuota, id, quota)
return nil
}
@@ -146,7 +153,7 @@ func increaseTokenQuota(id int, quota int) (err error) {
map[string]interface{}{
"remain_quota": gorm.Expr("remain_quota + ?", quota),
"used_quota": gorm.Expr("used_quota - ?", quota),
- "accessed_time": common.GetTimestamp(),
+ "accessed_time": helper.GetTimestamp(),
},
).Error
return err
@@ -156,7 +163,7 @@ func DecreaseTokenQuota(id int, quota int) (err error) {
if quota < 0 {
return errors.New("quota 不能为负数!")
}
- if common.BatchUpdateEnabled {
+ if config.BatchUpdateEnabled {
addNewRecord(BatchUpdateTypeTokenQuota, id, -quota)
return nil
}
@@ -168,7 +175,7 @@ func decreaseTokenQuota(id int, quota int) (err error) {
map[string]interface{}{
"remain_quota": gorm.Expr("remain_quota - ?", quota),
"used_quota": gorm.Expr("used_quota + ?", quota),
- "accessed_time": common.GetTimestamp(),
+ "accessed_time": helper.GetTimestamp(),
},
).Error
return err
@@ -192,24 +199,24 @@ func PreConsumeTokenQuota(tokenId int, quota int) (err error) {
if userQuota < quota {
return errors.New("用户额度不足")
}
- quotaTooLow := userQuota >= common.QuotaRemindThreshold && userQuota-quota < common.QuotaRemindThreshold
+ quotaTooLow := userQuota >= config.QuotaRemindThreshold && userQuota-quota < config.QuotaRemindThreshold
noMoreQuota := userQuota-quota <= 0
if quotaTooLow || noMoreQuota {
go func() {
email, err := GetUserEmail(token.UserId)
if err != nil {
- common.SysError("failed to fetch user email: " + err.Error())
+ logger.SysError("failed to fetch user email: " + err.Error())
}
prompt := "您的额度即将用尽"
if noMoreQuota {
prompt = "您的额度已用尽"
}
if email != "" {
- topUpLink := fmt.Sprintf("%s/topup", common.ServerAddress)
+ topUpLink := fmt.Sprintf("%s/topup", config.ServerAddress)
err = common.SendEmail(prompt, email,
 fmt.Sprintf("%s,当前剩余额度为 %d,为了不影响您的使用,请及时充值。<br/>充值链接:<a href='%s'>%s</a>", prompt, userQuota, topUpLink, topUpLink))
if err != nil {
- common.SysError("failed to send email" + err.Error())
+ logger.SysError("failed to send email" + err.Error())
}
}
}()
diff --git a/model/user.go b/model/user.go
index e738b1ba..82e9707b 100644
--- a/model/user.go
+++ b/model/user.go
@@ -5,6 +5,9 @@ import (
"fmt"
"gorm.io/gorm"
"one-api/common"
+ "one-api/common/config"
+ "one-api/common/helper"
+ "one-api/common/logger"
"strings"
)
@@ -15,7 +18,7 @@ type User struct {
Username string `json:"username" gorm:"unique;index" validate:"max=12"`
Password string `json:"password" gorm:"not null;" validate:"min=8,max=20"`
DisplayName string `json:"display_name" gorm:"index" validate:"max=20"`
- Role int `json:"role" gorm:"type:int;default:1"` // admin, common
 - Role int `json:"role" gorm:"type:int;default:1"` // admin, common
Status int `json:"status" gorm:"type:int;default:1"` // enabled, disabled
Email string `json:"email" gorm:"index" validate:"max=50"`
GitHubId string `json:"github_id" gorm:"column:github_id;index"`
@@ -89,24 +92,24 @@ func (user *User) Insert(inviterId int) error {
return err
}
}
- user.Quota = common.QuotaForNewUser
- user.AccessToken = common.GetUUID()
- user.AffCode = common.GetRandomString(4)
+ user.Quota = config.QuotaForNewUser
+ user.AccessToken = helper.GetUUID()
+ user.AffCode = helper.GetRandomString(4)
result := DB.Create(user)
if result.Error != nil {
return result.Error
}
- if common.QuotaForNewUser > 0 {
- RecordLog(user.Id, LogTypeSystem, fmt.Sprintf("新用户注册赠送 %s", common.LogQuota(common.QuotaForNewUser)))
+ if config.QuotaForNewUser > 0 {
+ RecordLog(user.Id, LogTypeSystem, fmt.Sprintf("新用户注册赠送 %s", common.LogQuota(config.QuotaForNewUser)))
}
if inviterId != 0 {
- if common.QuotaForInvitee > 0 {
- _ = IncreaseUserQuota(user.Id, common.QuotaForInvitee)
- RecordLog(user.Id, LogTypeSystem, fmt.Sprintf("使用邀请码赠送 %s", common.LogQuota(common.QuotaForInvitee)))
+ if config.QuotaForInvitee > 0 {
+ _ = IncreaseUserQuota(user.Id, config.QuotaForInvitee)
+ RecordLog(user.Id, LogTypeSystem, fmt.Sprintf("使用邀请码赠送 %s", common.LogQuota(config.QuotaForInvitee)))
}
- if common.QuotaForInviter > 0 {
- _ = IncreaseUserQuota(inviterId, common.QuotaForInviter)
- RecordLog(inviterId, LogTypeSystem, fmt.Sprintf("邀请用户赠送 %s", common.LogQuota(common.QuotaForInviter)))
+ if config.QuotaForInviter > 0 {
+ _ = IncreaseUserQuota(inviterId, config.QuotaForInviter)
+ RecordLog(inviterId, LogTypeSystem, fmt.Sprintf("邀请用户赠送 %s", common.LogQuota(config.QuotaForInviter)))
}
}
return nil
@@ -141,7 +144,15 @@ func (user *User) ValidateAndFill() (err error) {
if user.Username == "" || password == "" {
return errors.New("用户名或密码为空")
}
- DB.Where(User{Username: user.Username}).First(user)
+ err = DB.Where("username = ?", user.Username).First(user).Error
+ if err != nil {
 + // we must check the username first
 + // consider this case: a malicious user sets his username to another user's email
+ err := DB.Where("email = ?", user.Username).First(user).Error
+ if err != nil {
+ return errors.New("用户名或密码错误,或用户已被封禁")
+ }
+ }
okay := common.ValidatePasswordAndHash(password, user.Password)
if !okay || user.Status != common.UserStatusEnabled {
return errors.New("用户名或密码错误,或用户已被封禁")
@@ -224,7 +235,7 @@ func IsAdmin(userId int) bool {
var user User
err := DB.Where("id = ?", userId).Select("role").Find(&user).Error
if err != nil {
 - common.SysError("failed to query user role: " + err.Error())
+ logger.SysError("no such user " + err.Error())
return false
}
return user.Role >= common.RoleAdminUser
@@ -283,7 +294,7 @@ func IncreaseUserQuota(id int, quota int) (err error) {
if quota < 0 {
return errors.New("quota 不能为负数!")
}
- if common.BatchUpdateEnabled {
+ if config.BatchUpdateEnabled {
addNewRecord(BatchUpdateTypeUserQuota, id, quota)
return nil
}
@@ -299,7 +310,7 @@ func DecreaseUserQuota(id int, quota int) (err error) {
if quota < 0 {
return errors.New("quota 不能为负数!")
}
- if common.BatchUpdateEnabled {
+ if config.BatchUpdateEnabled {
addNewRecord(BatchUpdateTypeUserQuota, id, -quota)
return nil
}
@@ -317,7 +328,7 @@ func GetRootUserEmail() (email string) {
}
func UpdateUserUsedQuotaAndRequestCount(id int, quota int) {
- if common.BatchUpdateEnabled {
+ if config.BatchUpdateEnabled {
addNewRecord(BatchUpdateTypeUsedQuota, id, quota)
addNewRecord(BatchUpdateTypeRequestCount, id, 1)
return
@@ -333,7 +344,7 @@ func updateUserUsedQuotaAndRequestCount(id int, quota int, count int) {
},
).Error
if err != nil {
- common.SysError("failed to update user used quota and request count: " + err.Error())
+ logger.SysError("failed to update user used quota and request count: " + err.Error())
}
}
@@ -344,14 +355,14 @@ func updateUserUsedQuota(id int, quota int) {
},
).Error
if err != nil {
- common.SysError("failed to update user used quota: " + err.Error())
+ logger.SysError("failed to update user used quota: " + err.Error())
}
}
func updateUserRequestCount(id int, count int) {
err := DB.Model(&User{}).Where("id = ?", id).Update("request_count", gorm.Expr("request_count + ?", count)).Error
if err != nil {
- common.SysError("failed to update user request count: " + err.Error())
+ logger.SysError("failed to update user request count: " + err.Error())
}
}
diff --git a/model/utils.go b/model/utils.go
index 1c28340b..e0826e0d 100644
--- a/model/utils.go
+++ b/model/utils.go
@@ -1,7 +1,8 @@
package model
import (
- "one-api/common"
+ "one-api/common/config"
+ "one-api/common/logger"
"sync"
"time"
)
@@ -28,7 +29,7 @@ func init() {
func InitBatchUpdater() {
go func() {
for {
- time.Sleep(time.Duration(common.BatchUpdateInterval) * time.Second)
+ time.Sleep(time.Duration(config.BatchUpdateInterval) * time.Second)
batchUpdate()
}
}()
@@ -45,7 +46,7 @@ func addNewRecord(type_ int, id int, value int) {
}
func batchUpdate() {
- common.SysLog("batch update started")
+ logger.SysLog("batch update started")
for i := 0; i < BatchUpdateTypeCount; i++ {
batchUpdateLocks[i].Lock()
store := batchUpdateStores[i]
@@ -57,12 +58,12 @@ func batchUpdate() {
case BatchUpdateTypeUserQuota:
err := increaseUserQuota(key, value)
if err != nil {
- common.SysError("failed to batch update user quota: " + err.Error())
+ logger.SysError("failed to batch update user quota: " + err.Error())
}
case BatchUpdateTypeTokenQuota:
err := increaseTokenQuota(key, value)
if err != nil {
- common.SysError("failed to batch update token quota: " + err.Error())
+ logger.SysError("failed to batch update token quota: " + err.Error())
}
case BatchUpdateTypeUsedQuota:
updateUserUsedQuota(key, value)
@@ -73,5 +74,5 @@ func batchUpdate() {
}
}
}
- common.SysLog("batch update finished")
+ logger.SysLog("batch update finished")
}
diff --git a/relay/channel/aiproxy/adaptor.go b/relay/channel/aiproxy/adaptor.go
new file mode 100644
index 00000000..44b6f58d
--- /dev/null
+++ b/relay/channel/aiproxy/adaptor.go
@@ -0,0 +1,22 @@
+package aiproxy
+
+import (
+ "github.com/gin-gonic/gin"
+ "net/http"
+ "one-api/relay/channel/openai"
+)
+
+type Adaptor struct {
+}
+
+func (a *Adaptor) Auth(c *gin.Context) error {
+ return nil
+}
+
+func (a *Adaptor) ConvertRequest(request *openai.GeneralOpenAIRequest) (any, error) {
+ return nil, nil
+}
+
+func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, *openai.Usage, error) {
+ return nil, nil, nil
+}
diff --git a/controller/relay-aiproxy.go b/relay/channel/aiproxy/main.go
similarity index 52%
rename from controller/relay-aiproxy.go
rename to relay/channel/aiproxy/main.go
index 543954f7..af9cd6f6 100644
--- a/controller/relay-aiproxy.go
+++ b/relay/channel/aiproxy/main.go
@@ -1,4 +1,4 @@
-package controller
+package aiproxy
import (
"bufio"
@@ -8,56 +8,29 @@ import (
"io"
"net/http"
"one-api/common"
+ "one-api/common/helper"
+ "one-api/common/logger"
+ "one-api/relay/channel/openai"
+ "one-api/relay/constant"
"strconv"
"strings"
)
// https://docs.aiproxy.io/dev/library#使用已经定制好的知识库进行对话问答
-type AIProxyLibraryRequest struct {
- Model string `json:"model"`
- Query string `json:"query"`
- LibraryId string `json:"libraryId"`
- Stream bool `json:"stream"`
-}
-
-type AIProxyLibraryError struct {
- ErrCode int `json:"errCode"`
- Message string `json:"message"`
-}
-
-type AIProxyLibraryDocument struct {
- Title string `json:"title"`
- URL string `json:"url"`
-}
-
-type AIProxyLibraryResponse struct {
- Success bool `json:"success"`
- Answer string `json:"answer"`
- Documents []AIProxyLibraryDocument `json:"documents"`
- AIProxyLibraryError
-}
-
-type AIProxyLibraryStreamResponse struct {
- Content string `json:"content"`
- Finish bool `json:"finish"`
- Model string `json:"model"`
- Documents []AIProxyLibraryDocument `json:"documents"`
-}
-
-func requestOpenAI2AIProxyLibrary(request GeneralOpenAIRequest) *AIProxyLibraryRequest {
+func ConvertRequest(request openai.GeneralOpenAIRequest) *LibraryRequest {
query := ""
if len(request.Messages) != 0 {
query = request.Messages[len(request.Messages)-1].StringContent()
}
- return &AIProxyLibraryRequest{
+ return &LibraryRequest{
Model: request.Model,
Stream: request.Stream,
Query: query,
}
}
-func aiProxyDocuments2Markdown(documents []AIProxyLibraryDocument) string {
+func aiProxyDocuments2Markdown(documents []LibraryDocument) string {
if len(documents) == 0 {
return ""
}
@@ -68,52 +41,52 @@ func aiProxyDocuments2Markdown(documents []AIProxyLibraryDocument) string {
return content
}
-func responseAIProxyLibrary2OpenAI(response *AIProxyLibraryResponse) *OpenAITextResponse {
+func responseAIProxyLibrary2OpenAI(response *LibraryResponse) *openai.TextResponse {
content := response.Answer + aiProxyDocuments2Markdown(response.Documents)
- choice := OpenAITextResponseChoice{
+ choice := openai.TextResponseChoice{
Index: 0,
- Message: Message{
+ Message: openai.Message{
Role: "assistant",
Content: content,
},
FinishReason: "stop",
}
- fullTextResponse := OpenAITextResponse{
- Id: common.GetUUID(),
+ fullTextResponse := openai.TextResponse{
+ Id: helper.GetUUID(),
Object: "chat.completion",
- Created: common.GetTimestamp(),
- Choices: []OpenAITextResponseChoice{choice},
+ Created: helper.GetTimestamp(),
+ Choices: []openai.TextResponseChoice{choice},
}
return &fullTextResponse
}
-func documentsAIProxyLibrary(documents []AIProxyLibraryDocument) *ChatCompletionsStreamResponse {
- var choice ChatCompletionsStreamResponseChoice
+func documentsAIProxyLibrary(documents []LibraryDocument) *openai.ChatCompletionsStreamResponse {
+ var choice openai.ChatCompletionsStreamResponseChoice
choice.Delta.Content = aiProxyDocuments2Markdown(documents)
- choice.FinishReason = &stopFinishReason
- return &ChatCompletionsStreamResponse{
- Id: common.GetUUID(),
+ choice.FinishReason = &constant.StopFinishReason
+ return &openai.ChatCompletionsStreamResponse{
+ Id: helper.GetUUID(),
Object: "chat.completion.chunk",
- Created: common.GetTimestamp(),
+ Created: helper.GetTimestamp(),
Model: "",
- Choices: []ChatCompletionsStreamResponseChoice{choice},
+ Choices: []openai.ChatCompletionsStreamResponseChoice{choice},
}
}
-func streamResponseAIProxyLibrary2OpenAI(response *AIProxyLibraryStreamResponse) *ChatCompletionsStreamResponse {
- var choice ChatCompletionsStreamResponseChoice
+func streamResponseAIProxyLibrary2OpenAI(response *LibraryStreamResponse) *openai.ChatCompletionsStreamResponse {
+ var choice openai.ChatCompletionsStreamResponseChoice
choice.Delta.Content = response.Content
- return &ChatCompletionsStreamResponse{
- Id: common.GetUUID(),
+ return &openai.ChatCompletionsStreamResponse{
+ Id: helper.GetUUID(),
Object: "chat.completion.chunk",
- Created: common.GetTimestamp(),
+ Created: helper.GetTimestamp(),
Model: response.Model,
- Choices: []ChatCompletionsStreamResponseChoice{choice},
+ Choices: []openai.ChatCompletionsStreamResponseChoice{choice},
}
}
-func aiProxyLibraryStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
- var usage Usage
+func StreamHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, *openai.Usage) {
+ var usage openai.Usage
scanner := bufio.NewScanner(resp.Body)
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
if atEOF && len(data) == 0 {
@@ -143,15 +116,15 @@ func aiProxyLibraryStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIEr
}
stopChan <- true
}()
- setEventStreamHeaders(c)
- var documents []AIProxyLibraryDocument
+ common.SetEventStreamHeaders(c)
+ var documents []LibraryDocument
c.Stream(func(w io.Writer) bool {
select {
case data := <-dataChan:
- var AIProxyLibraryResponse AIProxyLibraryStreamResponse
+ var AIProxyLibraryResponse LibraryStreamResponse
err := json.Unmarshal([]byte(data), &AIProxyLibraryResponse)
if err != nil {
- common.SysError("error unmarshalling stream response: " + err.Error())
+ logger.SysError("error unmarshalling stream response: " + err.Error())
return true
}
if len(AIProxyLibraryResponse.Documents) != 0 {
@@ -160,7 +133,7 @@ func aiProxyLibraryStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIEr
response := streamResponseAIProxyLibrary2OpenAI(&AIProxyLibraryResponse)
jsonResponse, err := json.Marshal(response)
if err != nil {
- common.SysError("error marshalling stream response: " + err.Error())
+ logger.SysError("error marshalling stream response: " + err.Error())
return true
}
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
@@ -169,7 +142,7 @@ func aiProxyLibraryStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIEr
response := documentsAIProxyLibrary(documents)
jsonResponse, err := json.Marshal(response)
if err != nil {
- common.SysError("error marshalling stream response: " + err.Error())
+ logger.SysError("error marshalling stream response: " + err.Error())
return true
}
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
@@ -179,28 +152,28 @@ func aiProxyLibraryStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIEr
})
err := resp.Body.Close()
if err != nil {
- return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
+ return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
return nil, &usage
}
-func aiProxyLibraryHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
- var AIProxyLibraryResponse AIProxyLibraryResponse
+func Handler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, *openai.Usage) {
+ var AIProxyLibraryResponse LibraryResponse
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
- return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
+ return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
- return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
+ return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
err = json.Unmarshal(responseBody, &AIProxyLibraryResponse)
if err != nil {
- return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
+ return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
if AIProxyLibraryResponse.ErrCode != 0 {
- return &OpenAIErrorWithStatusCode{
- OpenAIError: OpenAIError{
+ return &openai.ErrorWithStatusCode{
+ Error: openai.Error{
Message: AIProxyLibraryResponse.Message,
Type: strconv.Itoa(AIProxyLibraryResponse.ErrCode),
Code: AIProxyLibraryResponse.ErrCode,
@@ -211,7 +184,7 @@ func aiProxyLibraryHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWit
fullTextResponse := responseAIProxyLibrary2OpenAI(&AIProxyLibraryResponse)
jsonResponse, err := json.Marshal(fullTextResponse)
if err != nil {
- return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
+ return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
diff --git a/relay/channel/aiproxy/model.go b/relay/channel/aiproxy/model.go
new file mode 100644
index 00000000..39689b3d
--- /dev/null
+++ b/relay/channel/aiproxy/model.go
@@ -0,0 +1,32 @@
+package aiproxy
+
+type LibraryRequest struct {
+ Model string `json:"model"`
+ Query string `json:"query"`
+ LibraryId string `json:"libraryId"`
+ Stream bool `json:"stream"`
+}
+
+type LibraryError struct {
+ ErrCode int `json:"errCode"`
+ Message string `json:"message"`
+}
+
+type LibraryDocument struct {
+ Title string `json:"title"`
+ URL string `json:"url"`
+}
+
+type LibraryResponse struct {
+ Success bool `json:"success"`
+ Answer string `json:"answer"`
+ Documents []LibraryDocument `json:"documents"`
+ LibraryError
+}
+
+type LibraryStreamResponse struct {
+ Content string `json:"content"`
+ Finish bool `json:"finish"`
+ Model string `json:"model"`
+ Documents []LibraryDocument `json:"documents"`
+}
diff --git a/relay/channel/ali/adaptor.go b/relay/channel/ali/adaptor.go
new file mode 100644
index 00000000..49022cfc
--- /dev/null
+++ b/relay/channel/ali/adaptor.go
@@ -0,0 +1,22 @@
+package ali
+
+import (
+ "github.com/gin-gonic/gin"
+ "net/http"
+ "one-api/relay/channel/openai"
+)
+
+type Adaptor struct {
+}
+
+func (a *Adaptor) Auth(c *gin.Context) error {
+ return nil
+}
+
+func (a *Adaptor) ConvertRequest(request *openai.GeneralOpenAIRequest) (any, error) {
+ return nil, nil
+}
+
+func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, *openai.Usage, error) {
+ return nil, nil, nil
+}
diff --git a/controller/relay-ali.go b/relay/channel/ali/main.go
similarity index 51%
rename from controller/relay-ali.go
rename to relay/channel/ali/main.go
index df1cc084..81dc5370 100644
--- a/controller/relay-ali.go
+++ b/relay/channel/ali/main.go
@@ -1,4 +1,4 @@
-package controller
+package ali
import (
"bufio"
@@ -7,112 +7,45 @@ import (
"io"
"net/http"
"one-api/common"
+ "one-api/common/helper"
+ "one-api/common/logger"
+ "one-api/relay/channel/openai"
"strings"
)
// https://help.aliyun.com/document_detail/613695.html?spm=a2c4g.2399480.0.0.1adb778fAdzP9w#341800c0f8w0r
-type AliMessage struct {
- Content string `json:"content"`
- Role string `json:"role"`
-}
+const EnableSearchModelSuffix = "-internet"
-type AliInput struct {
- //Prompt string `json:"prompt"`
- Messages []AliMessage `json:"messages"`
-}
-
-type AliParameters struct {
- TopP float64 `json:"top_p,omitempty"`
- TopK int `json:"top_k,omitempty"`
- Seed uint64 `json:"seed,omitempty"`
- EnableSearch bool `json:"enable_search,omitempty"`
- IncrementalOutput bool `json:"incremental_output,omitempty"`
-}
-
-type AliChatRequest struct {
- Model string `json:"model"`
- Input AliInput `json:"input"`
- Parameters AliParameters `json:"parameters,omitempty"`
-}
-
-type AliEmbeddingRequest struct {
- Model string `json:"model"`
- Input struct {
- Texts []string `json:"texts"`
- } `json:"input"`
- Parameters *struct {
- TextType string `json:"text_type,omitempty"`
- } `json:"parameters,omitempty"`
-}
-
-type AliEmbedding struct {
- Embedding []float64 `json:"embedding"`
- TextIndex int `json:"text_index"`
-}
-
-type AliEmbeddingResponse struct {
- Output struct {
- Embeddings []AliEmbedding `json:"embeddings"`
- } `json:"output"`
- Usage AliUsage `json:"usage"`
- AliError
-}
-
-type AliError struct {
- Code string `json:"code"`
- Message string `json:"message"`
- RequestId string `json:"request_id"`
-}
-
-type AliUsage struct {
- InputTokens int `json:"input_tokens"`
- OutputTokens int `json:"output_tokens"`
- TotalTokens int `json:"total_tokens"`
-}
-
-type AliOutput struct {
- Text string `json:"text"`
- FinishReason string `json:"finish_reason"`
-}
-
-type AliChatResponse struct {
- Output AliOutput `json:"output"`
- Usage AliUsage `json:"usage"`
- AliError
-}
-
-const AliEnableSearchModelSuffix = "-internet"
-
-func requestOpenAI2Ali(request GeneralOpenAIRequest) *AliChatRequest {
- messages := make([]AliMessage, 0, len(request.Messages))
+func ConvertRequest(request openai.GeneralOpenAIRequest) *ChatRequest {
+ messages := make([]Message, 0, len(request.Messages))
for i := 0; i < len(request.Messages); i++ {
message := request.Messages[i]
- messages = append(messages, AliMessage{
+ messages = append(messages, Message{
Content: message.StringContent(),
Role: strings.ToLower(message.Role),
})
}
enableSearch := false
aliModel := request.Model
- if strings.HasSuffix(aliModel, AliEnableSearchModelSuffix) {
+ if strings.HasSuffix(aliModel, EnableSearchModelSuffix) {
enableSearch = true
- aliModel = strings.TrimSuffix(aliModel, AliEnableSearchModelSuffix)
+ aliModel = strings.TrimSuffix(aliModel, EnableSearchModelSuffix)
}
- return &AliChatRequest{
+ return &ChatRequest{
Model: aliModel,
- Input: AliInput{
+ Input: Input{
Messages: messages,
},
- Parameters: AliParameters{
+ Parameters: Parameters{
EnableSearch: enableSearch,
IncrementalOutput: request.Stream,
},
}
}
-func embeddingRequestOpenAI2Ali(request GeneralOpenAIRequest) *AliEmbeddingRequest {
- return &AliEmbeddingRequest{
+func ConvertEmbeddingRequest(request openai.GeneralOpenAIRequest) *EmbeddingRequest {
+ return &EmbeddingRequest{
Model: "text-embedding-v1",
Input: struct {
Texts []string `json:"texts"`
@@ -122,21 +55,21 @@ func embeddingRequestOpenAI2Ali(request GeneralOpenAIRequest) *AliEmbeddingReque
}
}
-func aliEmbeddingHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
- var aliResponse AliEmbeddingResponse
+func EmbeddingHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, *openai.Usage) {
+ var aliResponse EmbeddingResponse
err := json.NewDecoder(resp.Body).Decode(&aliResponse)
if err != nil {
- return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
+ return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
- return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
+ return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
if aliResponse.Code != "" {
- return &OpenAIErrorWithStatusCode{
- OpenAIError: OpenAIError{
+ return &openai.ErrorWithStatusCode{
+ Error: openai.Error{
Message: aliResponse.Message,
Type: aliResponse.Code,
Param: aliResponse.RequestId,
@@ -149,7 +82,7 @@ func aliEmbeddingHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithS
fullTextResponse := embeddingResponseAli2OpenAI(&aliResponse)
jsonResponse, err := json.Marshal(fullTextResponse)
if err != nil {
- return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
+ return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
@@ -157,16 +90,16 @@ func aliEmbeddingHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithS
return nil, &fullTextResponse.Usage
}
-func embeddingResponseAli2OpenAI(response *AliEmbeddingResponse) *OpenAIEmbeddingResponse {
- openAIEmbeddingResponse := OpenAIEmbeddingResponse{
+func embeddingResponseAli2OpenAI(response *EmbeddingResponse) *openai.EmbeddingResponse {
+ openAIEmbeddingResponse := openai.EmbeddingResponse{
Object: "list",
- Data: make([]OpenAIEmbeddingResponseItem, 0, len(response.Output.Embeddings)),
+ Data: make([]openai.EmbeddingResponseItem, 0, len(response.Output.Embeddings)),
Model: "text-embedding-v1",
- Usage: Usage{TotalTokens: response.Usage.TotalTokens},
+ Usage: openai.Usage{TotalTokens: response.Usage.TotalTokens},
}
for _, item := range response.Output.Embeddings {
- openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, OpenAIEmbeddingResponseItem{
+ openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, openai.EmbeddingResponseItem{
Object: `embedding`,
Index: item.TextIndex,
Embedding: item.Embedding,
@@ -175,21 +108,21 @@ func embeddingResponseAli2OpenAI(response *AliEmbeddingResponse) *OpenAIEmbeddin
return &openAIEmbeddingResponse
}
-func responseAli2OpenAI(response *AliChatResponse) *OpenAITextResponse {
- choice := OpenAITextResponseChoice{
+func responseAli2OpenAI(response *ChatResponse) *openai.TextResponse {
+ choice := openai.TextResponseChoice{
Index: 0,
- Message: Message{
+ Message: openai.Message{
Role: "assistant",
Content: response.Output.Text,
},
FinishReason: response.Output.FinishReason,
}
- fullTextResponse := OpenAITextResponse{
+ fullTextResponse := openai.TextResponse{
Id: response.RequestId,
Object: "chat.completion",
- Created: common.GetTimestamp(),
- Choices: []OpenAITextResponseChoice{choice},
- Usage: Usage{
+ Created: helper.GetTimestamp(),
+ Choices: []openai.TextResponseChoice{choice},
+ Usage: openai.Usage{
PromptTokens: response.Usage.InputTokens,
CompletionTokens: response.Usage.OutputTokens,
TotalTokens: response.Usage.InputTokens + response.Usage.OutputTokens,
@@ -198,25 +131,25 @@ func responseAli2OpenAI(response *AliChatResponse) *OpenAITextResponse {
return &fullTextResponse
}
-func streamResponseAli2OpenAI(aliResponse *AliChatResponse) *ChatCompletionsStreamResponse {
- var choice ChatCompletionsStreamResponseChoice
+func streamResponseAli2OpenAI(aliResponse *ChatResponse) *openai.ChatCompletionsStreamResponse {
+ var choice openai.ChatCompletionsStreamResponseChoice
choice.Delta.Content = aliResponse.Output.Text
if aliResponse.Output.FinishReason != "null" {
finishReason := aliResponse.Output.FinishReason
choice.FinishReason = &finishReason
}
- response := ChatCompletionsStreamResponse{
+ response := openai.ChatCompletionsStreamResponse{
Id: aliResponse.RequestId,
Object: "chat.completion.chunk",
- Created: common.GetTimestamp(),
+ Created: helper.GetTimestamp(),
Model: "qwen",
- Choices: []ChatCompletionsStreamResponseChoice{choice},
+ Choices: []openai.ChatCompletionsStreamResponseChoice{choice},
}
return &response
}
-func aliStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
- var usage Usage
+func StreamHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, *openai.Usage) {
+ var usage openai.Usage
scanner := bufio.NewScanner(resp.Body)
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
if atEOF && len(data) == 0 {
@@ -246,15 +179,15 @@ func aliStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStat
}
stopChan <- true
}()
- setEventStreamHeaders(c)
+ common.SetEventStreamHeaders(c)
//lastResponseText := ""
c.Stream(func(w io.Writer) bool {
select {
case data := <-dataChan:
- var aliResponse AliChatResponse
+ var aliResponse ChatResponse
err := json.Unmarshal([]byte(data), &aliResponse)
if err != nil {
- common.SysError("error unmarshalling stream response: " + err.Error())
+ logger.SysError("error unmarshalling stream response: " + err.Error())
return true
}
if aliResponse.Usage.OutputTokens != 0 {
@@ -267,7 +200,7 @@ func aliStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStat
//lastResponseText = aliResponse.Output.Text
jsonResponse, err := json.Marshal(response)
if err != nil {
- common.SysError("error marshalling stream response: " + err.Error())
+ logger.SysError("error marshalling stream response: " + err.Error())
return true
}
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
@@ -279,28 +212,28 @@ func aliStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStat
})
err := resp.Body.Close()
if err != nil {
- return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
+ return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
return nil, &usage
}
-func aliHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
- var aliResponse AliChatResponse
+func Handler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, *openai.Usage) {
+ var aliResponse ChatResponse
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
- return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
+ return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
- return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
+ return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
err = json.Unmarshal(responseBody, &aliResponse)
if err != nil {
- return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
+ return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
if aliResponse.Code != "" {
- return &OpenAIErrorWithStatusCode{
- OpenAIError: OpenAIError{
+ return &openai.ErrorWithStatusCode{
+ Error: openai.Error{
Message: aliResponse.Message,
Type: aliResponse.Code,
Param: aliResponse.RequestId,
@@ -313,7 +246,7 @@ func aliHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode
fullTextResponse.Model = "qwen"
jsonResponse, err := json.Marshal(fullTextResponse)
if err != nil {
- return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
+ return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
diff --git a/relay/channel/ali/model.go b/relay/channel/ali/model.go
new file mode 100644
index 00000000..54f13041
--- /dev/null
+++ b/relay/channel/ali/model.go
@@ -0,0 +1,71 @@
+package ali
+
+type Message struct {
+ Content string `json:"content"`
+ Role string `json:"role"`
+}
+
+type Input struct {
+ //Prompt string `json:"prompt"`
+ Messages []Message `json:"messages"`
+}
+
+type Parameters struct {
+ TopP float64 `json:"top_p,omitempty"`
+ TopK int `json:"top_k,omitempty"`
+ Seed uint64 `json:"seed,omitempty"`
+ EnableSearch bool `json:"enable_search,omitempty"`
+ IncrementalOutput bool `json:"incremental_output,omitempty"`
+}
+
+type ChatRequest struct {
+ Model string `json:"model"`
+ Input Input `json:"input"`
+ Parameters Parameters `json:"parameters,omitempty"`
+}
+
+type EmbeddingRequest struct {
+ Model string `json:"model"`
+ Input struct {
+ Texts []string `json:"texts"`
+ } `json:"input"`
+ Parameters *struct {
+ TextType string `json:"text_type,omitempty"`
+ } `json:"parameters,omitempty"`
+}
+
+type Embedding struct {
+ Embedding []float64 `json:"embedding"`
+ TextIndex int `json:"text_index"`
+}
+
+type EmbeddingResponse struct {
+ Output struct {
+ Embeddings []Embedding `json:"embeddings"`
+ } `json:"output"`
+ Usage Usage `json:"usage"`
+ Error
+}
+
+type Error struct {
+ Code string `json:"code"`
+ Message string `json:"message"`
+ RequestId string `json:"request_id"`
+}
+
+type Usage struct {
+ InputTokens int `json:"input_tokens"`
+ OutputTokens int `json:"output_tokens"`
+ TotalTokens int `json:"total_tokens"`
+}
+
+type Output struct {
+ Text string `json:"text"`
+ FinishReason string `json:"finish_reason"`
+}
+
+type ChatResponse struct {
+ Output Output `json:"output"`
+ Usage Usage `json:"usage"`
+ Error
+}
diff --git a/relay/channel/anthropic/adaptor.go b/relay/channel/anthropic/adaptor.go
new file mode 100644
index 00000000..55577228
--- /dev/null
+++ b/relay/channel/anthropic/adaptor.go
@@ -0,0 +1,22 @@
+package anthropic
+
+import (
+ "github.com/gin-gonic/gin"
+ "net/http"
+ "one-api/relay/channel/openai"
+)
+
+type Adaptor struct {
+}
+
+func (a *Adaptor) Auth(c *gin.Context) error {
+ return nil
+}
+
+func (a *Adaptor) ConvertRequest(request *openai.GeneralOpenAIRequest) (any, error) {
+ return nil, nil
+}
+
+func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, *openai.Usage, error) {
+ return nil, nil, nil
+}
diff --git a/controller/relay-claude.go b/relay/channel/anthropic/main.go
similarity index 59%
rename from controller/relay-claude.go
rename to relay/channel/anthropic/main.go
index ca7a701a..060fcde8 100644
--- a/controller/relay-claude.go
+++ b/relay/channel/anthropic/main.go
@@ -1,4 +1,4 @@
-package controller
+package anthropic
import (
"bufio"
@@ -8,37 +8,12 @@ import (
"io"
"net/http"
"one-api/common"
+ "one-api/common/helper"
+ "one-api/common/logger"
+ "one-api/relay/channel/openai"
"strings"
)
-type ClaudeMetadata struct {
- UserId string `json:"user_id"`
-}
-
-type ClaudeRequest struct {
- Model string `json:"model"`
- Prompt string `json:"prompt"`
- MaxTokensToSample int `json:"max_tokens_to_sample"`
- StopSequences []string `json:"stop_sequences,omitempty"`
- Temperature float64 `json:"temperature,omitempty"`
- TopP float64 `json:"top_p,omitempty"`
- TopK int `json:"top_k,omitempty"`
- //ClaudeMetadata `json:"metadata,omitempty"`
- Stream bool `json:"stream,omitempty"`
-}
-
-type ClaudeError struct {
- Type string `json:"type"`
- Message string `json:"message"`
-}
-
-type ClaudeResponse struct {
- Completion string `json:"completion"`
- StopReason string `json:"stop_reason"`
- Model string `json:"model"`
- Error ClaudeError `json:"error"`
-}
-
func stopReasonClaude2OpenAI(reason string) string {
switch reason {
case "stop_sequence":
@@ -50,8 +25,8 @@ func stopReasonClaude2OpenAI(reason string) string {
}
}
-func requestOpenAI2Claude(textRequest GeneralOpenAIRequest) *ClaudeRequest {
- claudeRequest := ClaudeRequest{
+func ConvertRequest(textRequest openai.GeneralOpenAIRequest) *Request {
+ claudeRequest := Request{
Model: textRequest.Model,
Prompt: "",
MaxTokensToSample: textRequest.MaxTokens,
@@ -80,43 +55,43 @@ func requestOpenAI2Claude(textRequest GeneralOpenAIRequest) *ClaudeRequest {
return &claudeRequest
}
-func streamResponseClaude2OpenAI(claudeResponse *ClaudeResponse) *ChatCompletionsStreamResponse {
- var choice ChatCompletionsStreamResponseChoice
+func streamResponseClaude2OpenAI(claudeResponse *Response) *openai.ChatCompletionsStreamResponse {
+ var choice openai.ChatCompletionsStreamResponseChoice
choice.Delta.Content = claudeResponse.Completion
finishReason := stopReasonClaude2OpenAI(claudeResponse.StopReason)
if finishReason != "null" {
choice.FinishReason = &finishReason
}
- var response ChatCompletionsStreamResponse
+ var response openai.ChatCompletionsStreamResponse
response.Object = "chat.completion.chunk"
response.Model = claudeResponse.Model
- response.Choices = []ChatCompletionsStreamResponseChoice{choice}
+ response.Choices = []openai.ChatCompletionsStreamResponseChoice{choice}
return &response
}
-func responseClaude2OpenAI(claudeResponse *ClaudeResponse) *OpenAITextResponse {
- choice := OpenAITextResponseChoice{
+func responseClaude2OpenAI(claudeResponse *Response) *openai.TextResponse {
+ choice := openai.TextResponseChoice{
Index: 0,
- Message: Message{
+ Message: openai.Message{
Role: "assistant",
Content: strings.TrimPrefix(claudeResponse.Completion, " "),
Name: nil,
},
FinishReason: stopReasonClaude2OpenAI(claudeResponse.StopReason),
}
- fullTextResponse := OpenAITextResponse{
- Id: fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
+ fullTextResponse := openai.TextResponse{
+ Id: fmt.Sprintf("chatcmpl-%s", helper.GetUUID()),
Object: "chat.completion",
- Created: common.GetTimestamp(),
- Choices: []OpenAITextResponseChoice{choice},
+ Created: helper.GetTimestamp(),
+ Choices: []openai.TextResponseChoice{choice},
}
return &fullTextResponse
}
-func claudeStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, string) {
+func StreamHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, string) {
responseText := ""
- responseId := fmt.Sprintf("chatcmpl-%s", common.GetUUID())
- createdTime := common.GetTimestamp()
+ responseId := fmt.Sprintf("chatcmpl-%s", helper.GetUUID())
+ createdTime := helper.GetTimestamp()
scanner := bufio.NewScanner(resp.Body)
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
if atEOF && len(data) == 0 {
@@ -143,16 +118,16 @@ func claudeStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithS
}
stopChan <- true
}()
- setEventStreamHeaders(c)
+ common.SetEventStreamHeaders(c)
c.Stream(func(w io.Writer) bool {
select {
case data := <-dataChan:
// some implementations may add \r at the end of data
data = strings.TrimSuffix(data, "\r")
- var claudeResponse ClaudeResponse
+ var claudeResponse Response
err := json.Unmarshal([]byte(data), &claudeResponse)
if err != nil {
- common.SysError("error unmarshalling stream response: " + err.Error())
+ logger.SysError("error unmarshalling stream response: " + err.Error())
return true
}
responseText += claudeResponse.Completion
@@ -161,7 +136,7 @@ func claudeStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithS
response.Created = createdTime
jsonStr, err := json.Marshal(response)
if err != nil {
- common.SysError("error marshalling stream response: " + err.Error())
+ logger.SysError("error marshalling stream response: " + err.Error())
return true
}
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonStr)})
@@ -173,28 +148,28 @@ func claudeStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithS
})
err := resp.Body.Close()
if err != nil {
- return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
+ return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
}
return nil, responseText
}
-func claudeHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*OpenAIErrorWithStatusCode, *Usage) {
+func Handler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*openai.ErrorWithStatusCode, *openai.Usage) {
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
- return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
+ return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
- return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
+ return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
- var claudeResponse ClaudeResponse
+ var claudeResponse Response
err = json.Unmarshal(responseBody, &claudeResponse)
if err != nil {
- return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
+ return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
if claudeResponse.Error.Type != "" {
- return &OpenAIErrorWithStatusCode{
- OpenAIError: OpenAIError{
+ return &openai.ErrorWithStatusCode{
+ Error: openai.Error{
Message: claudeResponse.Error.Message,
Type: claudeResponse.Error.Type,
Param: "",
@@ -205,8 +180,8 @@ func claudeHandler(c *gin.Context, resp *http.Response, promptTokens int, model
}
fullTextResponse := responseClaude2OpenAI(&claudeResponse)
fullTextResponse.Model = model
- completionTokens := countTokenText(claudeResponse.Completion, model)
- usage := Usage{
+ completionTokens := openai.CountTokenText(claudeResponse.Completion, model)
+ usage := openai.Usage{
PromptTokens: promptTokens,
CompletionTokens: completionTokens,
TotalTokens: promptTokens + completionTokens,
@@ -214,7 +189,7 @@ func claudeHandler(c *gin.Context, resp *http.Response, promptTokens int, model
fullTextResponse.Usage = usage
jsonResponse, err := json.Marshal(fullTextResponse)
if err != nil {
- return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
+ return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
diff --git a/relay/channel/anthropic/model.go b/relay/channel/anthropic/model.go
new file mode 100644
index 00000000..70fc9430
--- /dev/null
+++ b/relay/channel/anthropic/model.go
@@ -0,0 +1,29 @@
+package anthropic
+
+type Metadata struct {
+ UserId string `json:"user_id"`
+}
+
+type Request struct {
+ Model string `json:"model"`
+ Prompt string `json:"prompt"`
+ MaxTokensToSample int `json:"max_tokens_to_sample"`
+ StopSequences []string `json:"stop_sequences,omitempty"`
+ Temperature float64 `json:"temperature,omitempty"`
+ TopP float64 `json:"top_p,omitempty"`
+ TopK int `json:"top_k,omitempty"`
+ //Metadata `json:"metadata,omitempty"`
+ Stream bool `json:"stream,omitempty"`
+}
+
+type Error struct {
+ Type string `json:"type"`
+ Message string `json:"message"`
+}
+
+type Response struct {
+ Completion string `json:"completion"`
+ StopReason string `json:"stop_reason"`
+ Model string `json:"model"`
+ Error Error `json:"error"`
+}
diff --git a/relay/channel/baidu/adaptor.go b/relay/channel/baidu/adaptor.go
new file mode 100644
index 00000000..498b664a
--- /dev/null
+++ b/relay/channel/baidu/adaptor.go
@@ -0,0 +1,22 @@
+package baidu
+
+import (
+ "github.com/gin-gonic/gin"
+ "net/http"
+ "one-api/relay/channel/openai"
+)
+
+type Adaptor struct {
+}
+
+func (a *Adaptor) Auth(c *gin.Context) error {
+ return nil
+}
+
+func (a *Adaptor) ConvertRequest(request *openai.GeneralOpenAIRequest) (any, error) {
+ return nil, nil
+}
+
+func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, *openai.Usage, error) {
+ return nil, nil, nil
+}
diff --git a/controller/relay-baidu.go b/relay/channel/baidu/main.go
similarity index 55%
rename from controller/relay-baidu.go
rename to relay/channel/baidu/main.go
index dca30da1..f5b98155 100644
--- a/controller/relay-baidu.go
+++ b/relay/channel/baidu/main.go
@@ -1,4 +1,4 @@
-package controller
+package baidu
import (
"bufio"
@@ -9,6 +9,10 @@ import (
"io"
"net/http"
"one-api/common"
+ "one-api/common/logger"
+ "one-api/relay/channel/openai"
+ "one-api/relay/constant"
+ "one-api/relay/util"
"strings"
"sync"
"time"
@@ -16,148 +20,104 @@ import (
// https://cloud.baidu.com/doc/WENXINWORKSHOP/s/flfmc9do2
-type BaiduTokenResponse struct {
+type TokenResponse struct {
ExpiresIn int `json:"expires_in"`
AccessToken string `json:"access_token"`
}
-type BaiduMessage struct {
+type Message struct {
Role string `json:"role"`
Content string `json:"content"`
}
-type BaiduChatRequest struct {
- Messages []BaiduMessage `json:"messages"`
- Stream bool `json:"stream"`
- UserId string `json:"user_id,omitempty"`
+type ChatRequest struct {
+ Messages []Message `json:"messages"`
+ Stream bool `json:"stream"`
+ UserId string `json:"user_id,omitempty"`
}
-type BaiduError struct {
+type Error struct {
ErrorCode int `json:"error_code"`
ErrorMsg string `json:"error_msg"`
}
-type BaiduChatResponse struct {
- Id string `json:"id"`
- Object string `json:"object"`
- Created int64 `json:"created"`
- Result string `json:"result"`
- IsTruncated bool `json:"is_truncated"`
- NeedClearHistory bool `json:"need_clear_history"`
- Usage Usage `json:"usage"`
- BaiduError
-}
-
-type BaiduChatStreamResponse struct {
- BaiduChatResponse
- SentenceId int `json:"sentence_id"`
- IsEnd bool `json:"is_end"`
-}
-
-type BaiduEmbeddingRequest struct {
- Input []string `json:"input"`
-}
-
-type BaiduEmbeddingData struct {
- Object string `json:"object"`
- Embedding []float64 `json:"embedding"`
- Index int `json:"index"`
-}
-
-type BaiduEmbeddingResponse struct {
- Id string `json:"id"`
- Object string `json:"object"`
- Created int64 `json:"created"`
- Data []BaiduEmbeddingData `json:"data"`
- Usage Usage `json:"usage"`
- BaiduError
-}
-
-type BaiduAccessToken struct {
- AccessToken string `json:"access_token"`
- Error string `json:"error,omitempty"`
- ErrorDescription string `json:"error_description,omitempty"`
- ExpiresIn int64 `json:"expires_in,omitempty"`
- ExpiresAt time.Time `json:"-"`
-}
-
var baiduTokenStore sync.Map
-func requestOpenAI2Baidu(request GeneralOpenAIRequest) *BaiduChatRequest {
- messages := make([]BaiduMessage, 0, len(request.Messages))
+func ConvertRequest(request openai.GeneralOpenAIRequest) *ChatRequest {
+ messages := make([]Message, 0, len(request.Messages))
for _, message := range request.Messages {
if message.Role == "system" {
- messages = append(messages, BaiduMessage{
+ messages = append(messages, Message{
Role: "user",
Content: message.StringContent(),
})
- messages = append(messages, BaiduMessage{
+ messages = append(messages, Message{
Role: "assistant",
Content: "Okay",
})
} else {
- messages = append(messages, BaiduMessage{
+ messages = append(messages, Message{
Role: message.Role,
Content: message.StringContent(),
})
}
}
- return &BaiduChatRequest{
+ return &ChatRequest{
Messages: messages,
Stream: request.Stream,
}
}
-func responseBaidu2OpenAI(response *BaiduChatResponse) *OpenAITextResponse {
- choice := OpenAITextResponseChoice{
+func responseBaidu2OpenAI(response *ChatResponse) *openai.TextResponse {
+ choice := openai.TextResponseChoice{
Index: 0,
- Message: Message{
+ Message: openai.Message{
Role: "assistant",
Content: response.Result,
},
FinishReason: "stop",
}
- fullTextResponse := OpenAITextResponse{
+ fullTextResponse := openai.TextResponse{
Id: response.Id,
Object: "chat.completion",
Created: response.Created,
- Choices: []OpenAITextResponseChoice{choice},
+ Choices: []openai.TextResponseChoice{choice},
Usage: response.Usage,
}
return &fullTextResponse
}
-func streamResponseBaidu2OpenAI(baiduResponse *BaiduChatStreamResponse) *ChatCompletionsStreamResponse {
- var choice ChatCompletionsStreamResponseChoice
+func streamResponseBaidu2OpenAI(baiduResponse *ChatStreamResponse) *openai.ChatCompletionsStreamResponse {
+ var choice openai.ChatCompletionsStreamResponseChoice
choice.Delta.Content = baiduResponse.Result
if baiduResponse.IsEnd {
- choice.FinishReason = &stopFinishReason
+ choice.FinishReason = &constant.StopFinishReason
}
- response := ChatCompletionsStreamResponse{
+ response := openai.ChatCompletionsStreamResponse{
Id: baiduResponse.Id,
Object: "chat.completion.chunk",
Created: baiduResponse.Created,
Model: "ernie-bot",
- Choices: []ChatCompletionsStreamResponseChoice{choice},
+ Choices: []openai.ChatCompletionsStreamResponseChoice{choice},
}
return &response
}
-func embeddingRequestOpenAI2Baidu(request GeneralOpenAIRequest) *BaiduEmbeddingRequest {
- return &BaiduEmbeddingRequest{
+func ConvertEmbeddingRequest(request openai.GeneralOpenAIRequest) *EmbeddingRequest {
+ return &EmbeddingRequest{
Input: request.ParseInput(),
}
}
-func embeddingResponseBaidu2OpenAI(response *BaiduEmbeddingResponse) *OpenAIEmbeddingResponse {
- openAIEmbeddingResponse := OpenAIEmbeddingResponse{
+func embeddingResponseBaidu2OpenAI(response *EmbeddingResponse) *openai.EmbeddingResponse {
+ openAIEmbeddingResponse := openai.EmbeddingResponse{
Object: "list",
- Data: make([]OpenAIEmbeddingResponseItem, 0, len(response.Data)),
+ Data: make([]openai.EmbeddingResponseItem, 0, len(response.Data)),
Model: "baidu-embedding",
Usage: response.Usage,
}
for _, item := range response.Data {
- openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, OpenAIEmbeddingResponseItem{
+ openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, openai.EmbeddingResponseItem{
Object: item.Object,
Index: item.Index,
Embedding: item.Embedding,
@@ -166,8 +126,8 @@ func embeddingResponseBaidu2OpenAI(response *BaiduEmbeddingResponse) *OpenAIEmbe
return &openAIEmbeddingResponse
}
-func baiduStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
- var usage Usage
+func StreamHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, *openai.Usage) {
+ var usage openai.Usage
scanner := bufio.NewScanner(resp.Body)
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
if atEOF && len(data) == 0 {
@@ -194,14 +154,14 @@ func baiduStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithSt
}
stopChan <- true
}()
- setEventStreamHeaders(c)
+ common.SetEventStreamHeaders(c)
c.Stream(func(w io.Writer) bool {
select {
case data := <-dataChan:
- var baiduResponse BaiduChatStreamResponse
+ var baiduResponse ChatStreamResponse
err := json.Unmarshal([]byte(data), &baiduResponse)
if err != nil {
- common.SysError("error unmarshalling stream response: " + err.Error())
+ logger.SysError("error unmarshalling stream response: " + err.Error())
return true
}
if baiduResponse.Usage.TotalTokens != 0 {
@@ -212,7 +172,7 @@ func baiduStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithSt
response := streamResponseBaidu2OpenAI(&baiduResponse)
jsonResponse, err := json.Marshal(response)
if err != nil {
- common.SysError("error marshalling stream response: " + err.Error())
+ logger.SysError("error marshalling stream response: " + err.Error())
return true
}
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
@@ -224,28 +184,28 @@ func baiduStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithSt
})
err := resp.Body.Close()
if err != nil {
- return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
+ return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
return nil, &usage
}
-func baiduHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
- var baiduResponse BaiduChatResponse
+func Handler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, *openai.Usage) {
+ var baiduResponse ChatResponse
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
- return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
+ return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
- return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
+ return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
err = json.Unmarshal(responseBody, &baiduResponse)
if err != nil {
- return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
+ return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
if baiduResponse.ErrorMsg != "" {
- return &OpenAIErrorWithStatusCode{
- OpenAIError: OpenAIError{
+ return &openai.ErrorWithStatusCode{
+ Error: openai.Error{
Message: baiduResponse.ErrorMsg,
Type: "baidu_error",
Param: "",
@@ -258,7 +218,7 @@ func baiduHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCo
fullTextResponse.Model = "ernie-bot"
jsonResponse, err := json.Marshal(fullTextResponse)
if err != nil {
- return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
+ return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
@@ -266,23 +226,23 @@ func baiduHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCo
return nil, &fullTextResponse.Usage
}
-func baiduEmbeddingHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
- var baiduResponse BaiduEmbeddingResponse
+func EmbeddingHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, *openai.Usage) {
+ var baiduResponse EmbeddingResponse
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
- return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
+ return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
- return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
+ return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
err = json.Unmarshal(responseBody, &baiduResponse)
if err != nil {
- return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
+ return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
if baiduResponse.ErrorMsg != "" {
- return &OpenAIErrorWithStatusCode{
- OpenAIError: OpenAIError{
+ return &openai.ErrorWithStatusCode{
+ Error: openai.Error{
Message: baiduResponse.ErrorMsg,
Type: "baidu_error",
Param: "",
@@ -294,7 +254,7 @@ func baiduEmbeddingHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWit
fullTextResponse := embeddingResponseBaidu2OpenAI(&baiduResponse)
jsonResponse, err := json.Marshal(fullTextResponse)
if err != nil {
- return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
+ return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
@@ -302,10 +262,10 @@ func baiduEmbeddingHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWit
return nil, &fullTextResponse.Usage
}
-func getBaiduAccessToken(apiKey string) (string, error) {
+func GetAccessToken(apiKey string) (string, error) {
if val, ok := baiduTokenStore.Load(apiKey); ok {
- var accessToken BaiduAccessToken
- if accessToken, ok = val.(BaiduAccessToken); ok {
+ var accessToken AccessToken
+ if accessToken, ok = val.(AccessToken); ok {
// soon this will expire
if time.Now().Add(time.Hour).After(accessToken.ExpiresAt) {
go func() {
@@ -320,12 +280,12 @@ func getBaiduAccessToken(apiKey string) (string, error) {
return "", err
}
if accessToken == nil {
- return "", errors.New("getBaiduAccessToken return a nil token")
+ return "", errors.New("GetAccessToken return a nil token")
}
return (*accessToken).AccessToken, nil
}
-func getBaiduAccessTokenHelper(apiKey string) (*BaiduAccessToken, error) {
+func getBaiduAccessTokenHelper(apiKey string) (*AccessToken, error) {
parts := strings.Split(apiKey, "|")
if len(parts) != 2 {
return nil, errors.New("invalid baidu apikey")
@@ -337,13 +297,13 @@ func getBaiduAccessTokenHelper(apiKey string) (*BaiduAccessToken, error) {
}
req.Header.Add("Content-Type", "application/json")
req.Header.Add("Accept", "application/json")
- res, err := impatientHTTPClient.Do(req)
+ res, err := util.ImpatientHTTPClient.Do(req)
if err != nil {
return nil, err
}
defer res.Body.Close()
- var accessToken BaiduAccessToken
+ var accessToken AccessToken
err = json.NewDecoder(res.Body).Decode(&accessToken)
if err != nil {
return nil, err
diff --git a/relay/channel/baidu/model.go b/relay/channel/baidu/model.go
new file mode 100644
index 00000000..e182f5dd
--- /dev/null
+++ b/relay/channel/baidu/model.go
@@ -0,0 +1,50 @@
+package baidu
+
+import (
+ "one-api/relay/channel/openai"
+ "time"
+)
+
+type ChatResponse struct {
+ Id string `json:"id"`
+ Object string `json:"object"`
+ Created int64 `json:"created"`
+ Result string `json:"result"`
+ IsTruncated bool `json:"is_truncated"`
+ NeedClearHistory bool `json:"need_clear_history"`
+ Usage openai.Usage `json:"usage"`
+ Error
+}
+
+type ChatStreamResponse struct {
+ ChatResponse
+ SentenceId int `json:"sentence_id"`
+ IsEnd bool `json:"is_end"`
+}
+
+type EmbeddingRequest struct {
+ Input []string `json:"input"`
+}
+
+type EmbeddingData struct {
+ Object string `json:"object"`
+ Embedding []float64 `json:"embedding"`
+ Index int `json:"index"`
+}
+
+type EmbeddingResponse struct {
+ Id string `json:"id"`
+ Object string `json:"object"`
+ Created int64 `json:"created"`
+ Data []EmbeddingData `json:"data"`
+ Usage openai.Usage `json:"usage"`
+ Error
+}
+
+type AccessToken struct {
+ AccessToken string `json:"access_token"`
+ Error string `json:"error,omitempty"`
+ ErrorDescription string `json:"error_description,omitempty"`
+ ExpiresIn int64 `json:"expires_in,omitempty"`
+ ExpiresAt time.Time `json:"-"`
+}
diff --git a/relay/channel/google/adaptor.go b/relay/channel/google/adaptor.go
new file mode 100644
index 00000000..b328db32
--- /dev/null
+++ b/relay/channel/google/adaptor.go
@@ -0,0 +1,22 @@
+package google
+
+import (
+ "github.com/gin-gonic/gin"
+ "net/http"
+ "one-api/relay/channel/openai"
+)
+
+type Adaptor struct {
+}
+
+func (a *Adaptor) Auth(c *gin.Context) error {
+ return nil
+}
+
+func (a *Adaptor) ConvertRequest(request *openai.GeneralOpenAIRequest) (any, error) {
+ return nil, nil
+}
+
+func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, *openai.Usage, error) {
+ return nil, nil, nil
+}
diff --git a/controller/relay-gemini.go b/relay/channel/google/gemini.go
similarity index 65%
rename from controller/relay-gemini.go
rename to relay/channel/google/gemini.go
index d8ab58d6..3adc3fdd 100644
--- a/controller/relay-gemini.go
+++ b/relay/channel/google/gemini.go
@@ -1,4 +1,4 @@
-package controller
+package google
import (
"bufio"
@@ -7,7 +7,12 @@ import (
"io"
"net/http"
"one-api/common"
+ "one-api/common/config"
+ "one-api/common/helper"
"one-api/common/image"
+ "one-api/common/logger"
+ "one-api/relay/channel/openai"
+ "one-api/relay/constant"
"strings"
"github.com/gin-gonic/gin"
@@ -19,66 +24,26 @@ const (
GeminiVisionMaxImageNum = 16
)
-type GeminiChatRequest struct {
- Contents []GeminiChatContent `json:"contents"`
- SafetySettings []GeminiChatSafetySettings `json:"safety_settings,omitempty"`
- GenerationConfig GeminiChatGenerationConfig `json:"generation_config,omitempty"`
- Tools []GeminiChatTools `json:"tools,omitempty"`
-}
-
-type GeminiInlineData struct {
- MimeType string `json:"mimeType"`
- Data string `json:"data"`
-}
-
-type GeminiPart struct {
- Text string `json:"text,omitempty"`
- InlineData *GeminiInlineData `json:"inlineData,omitempty"`
-}
-
-type GeminiChatContent struct {
- Role string `json:"role,omitempty"`
- Parts []GeminiPart `json:"parts"`
-}
-
-type GeminiChatSafetySettings struct {
- Category string `json:"category"`
- Threshold string `json:"threshold"`
-}
-
-type GeminiChatTools struct {
- FunctionDeclarations any `json:"functionDeclarations,omitempty"`
-}
-
-type GeminiChatGenerationConfig struct {
- Temperature float64 `json:"temperature,omitempty"`
- TopP float64 `json:"topP,omitempty"`
- TopK float64 `json:"topK,omitempty"`
- MaxOutputTokens int `json:"maxOutputTokens,omitempty"`
- CandidateCount int `json:"candidateCount,omitempty"`
- StopSequences []string `json:"stopSequences,omitempty"`
-}
-
// Setting safety to the lowest possible values since Gemini is already powerless enough
-func requestOpenAI2Gemini(textRequest GeneralOpenAIRequest) *GeminiChatRequest {
+func ConvertGeminiRequest(textRequest openai.GeneralOpenAIRequest) *GeminiChatRequest {
geminiRequest := GeminiChatRequest{
Contents: make([]GeminiChatContent, 0, len(textRequest.Messages)),
SafetySettings: []GeminiChatSafetySettings{
{
Category: "HARM_CATEGORY_HARASSMENT",
- Threshold: common.GeminiSafetySetting,
+ Threshold: config.GeminiSafetySetting,
},
{
Category: "HARM_CATEGORY_HATE_SPEECH",
- Threshold: common.GeminiSafetySetting,
+ Threshold: config.GeminiSafetySetting,
},
{
Category: "HARM_CATEGORY_SEXUALLY_EXPLICIT",
- Threshold: common.GeminiSafetySetting,
+ Threshold: config.GeminiSafetySetting,
},
{
Category: "HARM_CATEGORY_DANGEROUS_CONTENT",
- Threshold: common.GeminiSafetySetting,
+ Threshold: config.GeminiSafetySetting,
},
},
GenerationConfig: GeminiChatGenerationConfig{
@@ -108,11 +73,11 @@ func requestOpenAI2Gemini(textRequest GeneralOpenAIRequest) *GeminiChatRequest {
var parts []GeminiPart
imageNum := 0
for _, part := range openaiContent {
- if part.Type == ContentTypeText {
+ if part.Type == openai.ContentTypeText {
parts = append(parts, GeminiPart{
Text: part.Text,
})
- } else if part.Type == ContentTypeImageURL {
+ } else if part.Type == openai.ContentTypeImageURL {
imageNum += 1
if imageNum > GeminiVisionMaxImageNum {
continue
@@ -187,21 +152,21 @@ type GeminiChatPromptFeedback struct {
SafetyRatings []GeminiChatSafetyRating `json:"safetyRatings"`
}
-func responseGeminiChat2OpenAI(response *GeminiChatResponse) *OpenAITextResponse {
- fullTextResponse := OpenAITextResponse{
- Id: fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
+func responseGeminiChat2OpenAI(response *GeminiChatResponse) *openai.TextResponse {
+ fullTextResponse := openai.TextResponse{
+ Id: fmt.Sprintf("chatcmpl-%s", helper.GetUUID()),
Object: "chat.completion",
- Created: common.GetTimestamp(),
- Choices: make([]OpenAITextResponseChoice, 0, len(response.Candidates)),
+ Created: helper.GetTimestamp(),
+ Choices: make([]openai.TextResponseChoice, 0, len(response.Candidates)),
}
for i, candidate := range response.Candidates {
- choice := OpenAITextResponseChoice{
+ choice := openai.TextResponseChoice{
Index: i,
- Message: Message{
+ Message: openai.Message{
Role: "assistant",
Content: "",
},
- FinishReason: stopFinishReason,
+ FinishReason: constant.StopFinishReason,
}
if len(candidate.Content.Parts) > 0 {
choice.Message.Content = candidate.Content.Parts[0].Text
@@ -211,18 +176,18 @@ func responseGeminiChat2OpenAI(response *GeminiChatResponse) *OpenAITextResponse
return &fullTextResponse
}
-func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) *ChatCompletionsStreamResponse {
- var choice ChatCompletionsStreamResponseChoice
+func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) *openai.ChatCompletionsStreamResponse {
+ var choice openai.ChatCompletionsStreamResponseChoice
choice.Delta.Content = geminiResponse.GetResponseText()
- choice.FinishReason = &stopFinishReason
- var response ChatCompletionsStreamResponse
+ choice.FinishReason = &constant.StopFinishReason
+ var response openai.ChatCompletionsStreamResponse
response.Object = "chat.completion.chunk"
response.Model = "gemini"
- response.Choices = []ChatCompletionsStreamResponseChoice{choice}
+ response.Choices = []openai.ChatCompletionsStreamResponseChoice{choice}
return &response
}
-func geminiChatStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, string) {
+func StreamHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, string) {
responseText := ""
dataChan := make(chan string)
stopChan := make(chan bool)
@@ -252,7 +217,7 @@ func geminiChatStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorW
}
stopChan <- true
}()
- setEventStreamHeaders(c)
+ common.SetEventStreamHeaders(c)
c.Stream(func(w io.Writer) bool {
select {
case data := <-dataChan:
@@ -264,18 +229,18 @@ func geminiChatStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorW
var dummy dummyStruct
err := json.Unmarshal([]byte(data), &dummy)
responseText += dummy.Content
- var choice ChatCompletionsStreamResponseChoice
+ var choice openai.ChatCompletionsStreamResponseChoice
choice.Delta.Content = dummy.Content
- response := ChatCompletionsStreamResponse{
- Id: fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
+ response := openai.ChatCompletionsStreamResponse{
+ Id: fmt.Sprintf("chatcmpl-%s", helper.GetUUID()),
Object: "chat.completion.chunk",
- Created: common.GetTimestamp(),
+ Created: helper.GetTimestamp(),
Model: "gemini-pro",
- Choices: []ChatCompletionsStreamResponseChoice{choice},
+ Choices: []openai.ChatCompletionsStreamResponseChoice{choice},
}
jsonResponse, err := json.Marshal(response)
if err != nil {
- common.SysError("error marshalling stream response: " + err.Error())
+ logger.SysError("error marshalling stream response: " + err.Error())
return true
}
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
@@ -287,28 +252,28 @@ func geminiChatStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorW
})
err := resp.Body.Close()
if err != nil {
- return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
+ return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
}
return nil, responseText
}
-func geminiChatHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*OpenAIErrorWithStatusCode, *Usage) {
+func GeminiHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*openai.ErrorWithStatusCode, *openai.Usage) {
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
- return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
+ return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
- return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
+ return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
var geminiResponse GeminiChatResponse
err = json.Unmarshal(responseBody, &geminiResponse)
if err != nil {
- return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
+ return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
if len(geminiResponse.Candidates) == 0 {
- return &OpenAIErrorWithStatusCode{
- OpenAIError: OpenAIError{
+ return &openai.ErrorWithStatusCode{
+ Error: openai.Error{
Message: "No candidates returned",
Type: "server_error",
Param: "",
@@ -319,8 +284,8 @@ func geminiChatHandler(c *gin.Context, resp *http.Response, promptTokens int, mo
}
fullTextResponse := responseGeminiChat2OpenAI(&geminiResponse)
fullTextResponse.Model = model
- completionTokens := countTokenText(geminiResponse.GetResponseText(), model)
- usage := Usage{
+ completionTokens := openai.CountTokenText(geminiResponse.GetResponseText(), model)
+ usage := openai.Usage{
PromptTokens: promptTokens,
CompletionTokens: completionTokens,
TotalTokens: promptTokens + completionTokens,
@@ -328,7 +293,7 @@ func geminiChatHandler(c *gin.Context, resp *http.Response, promptTokens int, mo
fullTextResponse.Usage = usage
jsonResponse, err := json.Marshal(fullTextResponse)
if err != nil {
- return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
+ return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
diff --git a/relay/channel/google/model.go b/relay/channel/google/model.go
new file mode 100644
index 00000000..694c2dd1
--- /dev/null
+++ b/relay/channel/google/model.go
@@ -0,0 +1,80 @@
+package google
+
+import (
+ "one-api/relay/channel/openai"
+)
+
+type GeminiChatRequest struct {
+ Contents []GeminiChatContent `json:"contents"`
+ SafetySettings []GeminiChatSafetySettings `json:"safety_settings,omitempty"`
+ GenerationConfig GeminiChatGenerationConfig `json:"generation_config,omitempty"`
+ Tools []GeminiChatTools `json:"tools,omitempty"`
+}
+
+type GeminiInlineData struct {
+ MimeType string `json:"mimeType"`
+ Data string `json:"data"`
+}
+
+type GeminiPart struct {
+ Text string `json:"text,omitempty"`
+ InlineData *GeminiInlineData `json:"inlineData,omitempty"`
+}
+
+type GeminiChatContent struct {
+ Role string `json:"role,omitempty"`
+ Parts []GeminiPart `json:"parts"`
+}
+
+type GeminiChatSafetySettings struct {
+ Category string `json:"category"`
+ Threshold string `json:"threshold"`
+}
+
+type GeminiChatTools struct {
+ FunctionDeclarations any `json:"functionDeclarations,omitempty"`
+}
+
+type GeminiChatGenerationConfig struct {
+ Temperature float64 `json:"temperature,omitempty"`
+ TopP float64 `json:"topP,omitempty"`
+ TopK float64 `json:"topK,omitempty"`
+ MaxOutputTokens int `json:"maxOutputTokens,omitempty"`
+ CandidateCount int `json:"candidateCount,omitempty"`
+ StopSequences []string `json:"stopSequences,omitempty"`
+}
+
+type PaLMChatMessage struct {
+ Author string `json:"author"`
+ Content string `json:"content"`
+}
+
+type PaLMFilter struct {
+ Reason string `json:"reason"`
+ Message string `json:"message"`
+}
+
+type PaLMPrompt struct {
+ Messages []PaLMChatMessage `json:"messages"`
+}
+
+type PaLMChatRequest struct {
+ Prompt PaLMPrompt `json:"prompt"`
+ Temperature float64 `json:"temperature,omitempty"`
+ CandidateCount int `json:"candidateCount,omitempty"`
+ TopP float64 `json:"topP,omitempty"`
+ TopK int `json:"topK,omitempty"`
+}
+
+type PaLMError struct {
+ Code int `json:"code"`
+ Message string `json:"message"`
+ Status string `json:"status"`
+}
+
+type PaLMChatResponse struct {
+ Candidates []PaLMChatMessage `json:"candidates"`
+ Messages []openai.Message `json:"messages"`
+ Filters []PaLMFilter `json:"filters"`
+ Error PaLMError `json:"error"`
+}
diff --git a/controller/relay-palm.go b/relay/channel/google/palm.go
similarity index 56%
rename from controller/relay-palm.go
rename to relay/channel/google/palm.go
index 0c1c8af6..3c86a432 100644
--- a/controller/relay-palm.go
+++ b/relay/channel/google/palm.go
@@ -1,4 +1,4 @@
-package controller
+package google
import (
"encoding/json"
@@ -7,47 +7,16 @@ import (
"io"
"net/http"
"one-api/common"
+ "one-api/common/helper"
+ "one-api/common/logger"
+ "one-api/relay/channel/openai"
+ "one-api/relay/constant"
)
// https://developers.generativeai.google/api/rest/generativelanguage/models/generateMessage#request-body
// https://developers.generativeai.google/api/rest/generativelanguage/models/generateMessage#response-body
-type PaLMChatMessage struct {
- Author string `json:"author"`
- Content string `json:"content"`
-}
-
-type PaLMFilter struct {
- Reason string `json:"reason"`
- Message string `json:"message"`
-}
-
-type PaLMPrompt struct {
- Messages []PaLMChatMessage `json:"messages"`
-}
-
-type PaLMChatRequest struct {
- Prompt PaLMPrompt `json:"prompt"`
- Temperature float64 `json:"temperature,omitempty"`
- CandidateCount int `json:"candidateCount,omitempty"`
- TopP float64 `json:"topP,omitempty"`
- TopK int `json:"topK,omitempty"`
-}
-
-type PaLMError struct {
- Code int `json:"code"`
- Message string `json:"message"`
- Status string `json:"status"`
-}
-
-type PaLMChatResponse struct {
- Candidates []PaLMChatMessage `json:"candidates"`
- Messages []Message `json:"messages"`
- Filters []PaLMFilter `json:"filters"`
- Error PaLMError `json:"error"`
-}
-
-func requestOpenAI2PaLM(textRequest GeneralOpenAIRequest) *PaLMChatRequest {
+func ConvertPaLMRequest(textRequest openai.GeneralOpenAIRequest) *PaLMChatRequest {
palmRequest := PaLMChatRequest{
Prompt: PaLMPrompt{
Messages: make([]PaLMChatMessage, 0, len(textRequest.Messages)),
@@ -71,14 +40,14 @@ func requestOpenAI2PaLM(textRequest GeneralOpenAIRequest) *PaLMChatRequest {
return &palmRequest
}
-func responsePaLM2OpenAI(response *PaLMChatResponse) *OpenAITextResponse {
- fullTextResponse := OpenAITextResponse{
- Choices: make([]OpenAITextResponseChoice, 0, len(response.Candidates)),
+func responsePaLM2OpenAI(response *PaLMChatResponse) *openai.TextResponse {
+ fullTextResponse := openai.TextResponse{
+ Choices: make([]openai.TextResponseChoice, 0, len(response.Candidates)),
}
for i, candidate := range response.Candidates {
- choice := OpenAITextResponseChoice{
+ choice := openai.TextResponseChoice{
Index: i,
- Message: Message{
+ Message: openai.Message{
Role: "assistant",
Content: candidate.Content,
},
@@ -89,42 +58,42 @@ func responsePaLM2OpenAI(response *PaLMChatResponse) *OpenAITextResponse {
return &fullTextResponse
}
-func streamResponsePaLM2OpenAI(palmResponse *PaLMChatResponse) *ChatCompletionsStreamResponse {
- var choice ChatCompletionsStreamResponseChoice
+func streamResponsePaLM2OpenAI(palmResponse *PaLMChatResponse) *openai.ChatCompletionsStreamResponse {
+ var choice openai.ChatCompletionsStreamResponseChoice
if len(palmResponse.Candidates) > 0 {
choice.Delta.Content = palmResponse.Candidates[0].Content
}
- choice.FinishReason = &stopFinishReason
- var response ChatCompletionsStreamResponse
+ choice.FinishReason = &constant.StopFinishReason
+ var response openai.ChatCompletionsStreamResponse
response.Object = "chat.completion.chunk"
response.Model = "palm2"
- response.Choices = []ChatCompletionsStreamResponseChoice{choice}
+ response.Choices = []openai.ChatCompletionsStreamResponseChoice{choice}
return &response
}
-func palmStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, string) {
+func PaLMStreamHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, string) {
responseText := ""
- responseId := fmt.Sprintf("chatcmpl-%s", common.GetUUID())
- createdTime := common.GetTimestamp()
+ responseId := fmt.Sprintf("chatcmpl-%s", helper.GetUUID())
+ createdTime := helper.GetTimestamp()
dataChan := make(chan string)
stopChan := make(chan bool)
go func() {
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
- common.SysError("error reading stream response: " + err.Error())
+ logger.SysError("error reading stream response: " + err.Error())
stopChan <- true
return
}
err = resp.Body.Close()
if err != nil {
- common.SysError("error closing stream response: " + err.Error())
+ logger.SysError("error closing stream response: " + err.Error())
stopChan <- true
return
}
var palmResponse PaLMChatResponse
err = json.Unmarshal(responseBody, &palmResponse)
if err != nil {
- common.SysError("error unmarshalling stream response: " + err.Error())
+ logger.SysError("error unmarshalling stream response: " + err.Error())
stopChan <- true
return
}
@@ -136,14 +105,14 @@ func palmStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithSta
}
jsonResponse, err := json.Marshal(fullTextResponse)
if err != nil {
- common.SysError("error marshalling stream response: " + err.Error())
+ logger.SysError("error marshalling stream response: " + err.Error())
stopChan <- true
return
}
dataChan <- string(jsonResponse)
stopChan <- true
}()
- setEventStreamHeaders(c)
+ common.SetEventStreamHeaders(c)
c.Stream(func(w io.Writer) bool {
select {
case data := <-dataChan:
@@ -156,28 +125,28 @@ func palmStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithSta
})
err := resp.Body.Close()
if err != nil {
- return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
+ return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
}
return nil, responseText
}
-func palmHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*OpenAIErrorWithStatusCode, *Usage) {
+func PaLMHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*openai.ErrorWithStatusCode, *openai.Usage) {
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
- return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
+ return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
- return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
+ return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
var palmResponse PaLMChatResponse
err = json.Unmarshal(responseBody, &palmResponse)
if err != nil {
- return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
+ return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
if palmResponse.Error.Code != 0 || len(palmResponse.Candidates) == 0 {
- return &OpenAIErrorWithStatusCode{
- OpenAIError: OpenAIError{
+ return &openai.ErrorWithStatusCode{
+ Error: openai.Error{
Message: palmResponse.Error.Message,
Type: palmResponse.Error.Status,
Param: "",
@@ -188,8 +157,8 @@ func palmHandler(c *gin.Context, resp *http.Response, promptTokens int, model st
}
fullTextResponse := responsePaLM2OpenAI(&palmResponse)
fullTextResponse.Model = model
- completionTokens := countTokenText(palmResponse.Candidates[0].Content, model)
- usage := Usage{
+ completionTokens := openai.CountTokenText(palmResponse.Candidates[0].Content, model)
+ usage := openai.Usage{
PromptTokens: promptTokens,
CompletionTokens: completionTokens,
TotalTokens: promptTokens + completionTokens,
@@ -197,7 +166,7 @@ func palmHandler(c *gin.Context, resp *http.Response, promptTokens int, model st
fullTextResponse.Usage = usage
jsonResponse, err := json.Marshal(fullTextResponse)
if err != nil {
- return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
+ return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
diff --git a/relay/channel/interface.go b/relay/channel/interface.go
new file mode 100644
index 00000000..7a0fcbd3
--- /dev/null
+++ b/relay/channel/interface.go
@@ -0,0 +1,15 @@
+package channel
+
+import (
+ "github.com/gin-gonic/gin"
+ "net/http"
+ "one-api/relay/channel/openai"
+)
+
+type Adaptor interface {
+ GetRequestURL() string
+ Auth(c *gin.Context) error
+ ConvertRequest(request *openai.GeneralOpenAIRequest) (any, error)
+ DoRequest(request *openai.GeneralOpenAIRequest) error
+ DoResponse(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, *openai.Usage, error)
+}
diff --git a/relay/channel/openai/adaptor.go b/relay/channel/openai/adaptor.go
new file mode 100644
index 00000000..cc302611
--- /dev/null
+++ b/relay/channel/openai/adaptor.go
@@ -0,0 +1,21 @@
+package openai
+
+import (
+ "github.com/gin-gonic/gin"
+ "net/http"
+)
+
+type Adaptor struct {
+}
+
+func (a *Adaptor) Auth(c *gin.Context) error {
+ return nil
+}
+
+func (a *Adaptor) ConvertRequest(request *GeneralOpenAIRequest) (any, error) {
+ return nil, nil
+}
+
+func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response) (*ErrorWithStatusCode, *Usage, error) {
+ return nil, nil, nil
+}
diff --git a/relay/channel/openai/constant.go b/relay/channel/openai/constant.go
new file mode 100644
index 00000000..000f72ee
--- /dev/null
+++ b/relay/channel/openai/constant.go
@@ -0,0 +1,6 @@
+package openai
+
+const (
+ ContentTypeText = "text"
+ ContentTypeImageURL = "image_url"
+)
diff --git a/controller/relay-openai.go b/relay/channel/openai/main.go
similarity index 73%
rename from controller/relay-openai.go
rename to relay/channel/openai/main.go
index 37867843..5f464249 100644
--- a/controller/relay-openai.go
+++ b/relay/channel/openai/main.go
@@ -1,4 +1,4 @@
-package controller
+package openai
import (
"bufio"
@@ -8,10 +8,12 @@ import (
"io"
"net/http"
"one-api/common"
+ "one-api/common/logger"
+ "one-api/relay/constant"
"strings"
)
-func openaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*OpenAIErrorWithStatusCode, string) {
+func StreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*ErrorWithStatusCode, string) {
responseText := ""
scanner := bufio.NewScanner(resp.Body)
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
@@ -41,21 +43,21 @@ func openaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*O
data = data[6:]
if !strings.HasPrefix(data, "[DONE]") {
switch relayMode {
- case RelayModeChatCompletions:
+ case constant.RelayModeChatCompletions:
var streamResponse ChatCompletionsStreamResponse
err := json.Unmarshal([]byte(data), &streamResponse)
if err != nil {
- common.SysError("error unmarshalling stream response: " + err.Error())
+ logger.SysError("error unmarshalling stream response: " + err.Error())
continue // just ignore the error
}
for _, choice := range streamResponse.Choices {
responseText += choice.Delta.Content
}
- case RelayModeCompletions:
+ case constant.RelayModeCompletions:
var streamResponse CompletionsStreamResponse
err := json.Unmarshal([]byte(data), &streamResponse)
if err != nil {
- common.SysError("error unmarshalling stream response: " + err.Error())
+ logger.SysError("error unmarshalling stream response: " + err.Error())
continue
}
for _, choice := range streamResponse.Choices {
@@ -66,7 +68,7 @@ func openaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*O
}
stopChan <- true
}()
- setEventStreamHeaders(c)
+ common.SetEventStreamHeaders(c)
c.Stream(func(w io.Writer) bool {
select {
case data := <-dataChan:
@@ -83,29 +85,29 @@ func openaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*O
})
err := resp.Body.Close()
if err != nil {
- return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
+ return ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
}
return nil, responseText
}
-func openaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*OpenAIErrorWithStatusCode, *Usage) {
- var textResponse TextResponse
+func Handler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*ErrorWithStatusCode, *Usage) {
+ var textResponse SlimTextResponse
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
- return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
+ return ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
- return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
+ return ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
err = json.Unmarshal(responseBody, &textResponse)
if err != nil {
- return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
+ return ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
if textResponse.Error.Type != "" {
- return &OpenAIErrorWithStatusCode{
- OpenAIError: textResponse.Error,
- StatusCode: resp.StatusCode,
+ return &ErrorWithStatusCode{
+ Error: textResponse.Error,
+ StatusCode: resp.StatusCode,
}, nil
}
// Reset response body
@@ -113,7 +115,7 @@ func openaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model
// We shouldn't set the header before we parse the response body, because the parse part may fail.
// And then we will have to send an error response, but in this case, the header has already been set.
- // So the httpClient will be confused by the response.
+ // So the HTTPClient will be confused by the response.
// For example, Postman will report error, and we cannot check the response at all.
for k, v := range resp.Header {
c.Writer.Header().Set(k, v[0])
@@ -121,17 +123,17 @@ func openaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model
c.Writer.WriteHeader(resp.StatusCode)
_, err = io.Copy(c.Writer, resp.Body)
if err != nil {
- return errorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil
+ return ErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
- return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
+ return ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
if textResponse.Usage.TotalTokens == 0 {
completionTokens := 0
for _, choice := range textResponse.Choices {
- completionTokens += countTokenText(choice.Message.StringContent(), model)
+ completionTokens += CountTokenText(choice.Message.StringContent(), model)
}
textResponse.Usage = Usage{
PromptTokens: promptTokens,
diff --git a/relay/channel/openai/model.go b/relay/channel/openai/model.go
new file mode 100644
index 00000000..937fb424
--- /dev/null
+++ b/relay/channel/openai/model.go
@@ -0,0 +1,288 @@
+package openai
+
+type Message struct {
+ Role string `json:"role"`
+ Content any `json:"content"`
+ Name *string `json:"name,omitempty"`
+}
+
+type ImageURL struct {
+ Url string `json:"url,omitempty"`
+ Detail string `json:"detail,omitempty"`
+}
+
+type TextContent struct {
+ Type string `json:"type,omitempty"`
+ Text string `json:"text,omitempty"`
+}
+
+type ImageContent struct {
+ Type string `json:"type,omitempty"`
+ ImageURL *ImageURL `json:"image_url,omitempty"`
+}
+
+type OpenAIMessageContent struct {
+ Type string `json:"type,omitempty"`
+ Text string `json:"text"`
+ ImageURL *ImageURL `json:"image_url,omitempty"`
+}
+
+func (m Message) IsStringContent() bool {
+ _, ok := m.Content.(string)
+ return ok
+}
+
+func (m Message) StringContent() string {
+ content, ok := m.Content.(string)
+ if ok {
+ return content
+ }
+ contentList, ok := m.Content.([]any)
+ if ok {
+ var contentStr string
+ for _, contentItem := range contentList {
+ contentMap, ok := contentItem.(map[string]any)
+ if !ok {
+ continue
+ }
+ if contentMap["type"] == ContentTypeText {
+ if subStr, ok := contentMap["text"].(string); ok {
+ contentStr += subStr
+ }
+ }
+ }
+ return contentStr
+ }
+ return ""
+}
+
+func (m Message) ParseContent() []OpenAIMessageContent {
+ var contentList []OpenAIMessageContent
+ content, ok := m.Content.(string)
+ if ok {
+ contentList = append(contentList, OpenAIMessageContent{
+ Type: ContentTypeText,
+ Text: content,
+ })
+ return contentList
+ }
+ anyList, ok := m.Content.([]any)
+ if ok {
+ for _, contentItem := range anyList {
+ contentMap, ok := contentItem.(map[string]any)
+ if !ok {
+ continue
+ }
+ switch contentMap["type"] {
+ case ContentTypeText:
+ if subStr, ok := contentMap["text"].(string); ok {
+ contentList = append(contentList, OpenAIMessageContent{
+ Type: ContentTypeText,
+ Text: subStr,
+ })
+ }
+ case ContentTypeImageURL:
+ if subObj, ok := contentMap["image_url"].(map[string]any); ok {
+ contentList = append(contentList, OpenAIMessageContent{
+ Type: ContentTypeImageURL,
+ ImageURL: &ImageURL{
+ Url: subObj["url"].(string),
+ },
+ })
+ }
+ }
+ }
+ return contentList
+ }
+ return nil
+}
+
+type ResponseFormat struct {
+ Type string `json:"type,omitempty"`
+}
+
+type GeneralOpenAIRequest struct {
+ Model string `json:"model,omitempty"`
+ Messages []Message `json:"messages,omitempty"`
+ Prompt any `json:"prompt,omitempty"`
+ Stream bool `json:"stream,omitempty"`
+ MaxTokens int `json:"max_tokens,omitempty"`
+ Temperature float64 `json:"temperature,omitempty"`
+ TopP float64 `json:"top_p,omitempty"`
+ N int `json:"n,omitempty"`
+ Input any `json:"input,omitempty"`
+ Instruction string `json:"instruction,omitempty"`
+ Size string `json:"size,omitempty"`
+ Functions any `json:"functions,omitempty"`
+ FrequencyPenalty float64 `json:"frequency_penalty,omitempty"`
+ PresencePenalty float64 `json:"presence_penalty,omitempty"`
+ ResponseFormat *ResponseFormat `json:"response_format,omitempty"`
+ Seed float64 `json:"seed,omitempty"`
+ Tools any `json:"tools,omitempty"`
+ ToolChoice any `json:"tool_choice,omitempty"`
+ User string `json:"user,omitempty"`
+}
+
+func (r GeneralOpenAIRequest) ParseInput() []string {
+ if r.Input == nil {
+ return nil
+ }
+ var input []string
+ switch r.Input.(type) {
+ case string:
+ input = []string{r.Input.(string)}
+ case []any:
+ input = make([]string, 0, len(r.Input.([]any)))
+ for _, item := range r.Input.([]any) {
+ if str, ok := item.(string); ok {
+ input = append(input, str)
+ }
+ }
+ }
+ return input
+}
+
+type ChatRequest struct {
+ Model string `json:"model"`
+ Messages []Message `json:"messages"`
+ MaxTokens int `json:"max_tokens"`
+}
+
+type TextRequest struct {
+ Model string `json:"model"`
+ Messages []Message `json:"messages"`
+ Prompt string `json:"prompt"`
+ MaxTokens int `json:"max_tokens"`
+ //Stream bool `json:"stream"`
+}
+
+// ImageRequest docs: https://platform.openai.com/docs/api-reference/images/create
+type ImageRequest struct {
+ Model string `json:"model"`
+ Prompt string `json:"prompt" binding:"required"`
+ N int `json:"n,omitempty"`
+ Size string `json:"size,omitempty"`
+ Quality string `json:"quality,omitempty"`
+ ResponseFormat string `json:"response_format,omitempty"`
+ Style string `json:"style,omitempty"`
+ User string `json:"user,omitempty"`
+}
+
+type WhisperJSONResponse struct {
+ Text string `json:"text,omitempty"`
+}
+
+type WhisperVerboseJSONResponse struct {
+ Task string `json:"task,omitempty"`
+ Language string `json:"language,omitempty"`
+ Duration float64 `json:"duration,omitempty"`
+ Text string `json:"text,omitempty"`
+ Segments []Segment `json:"segments,omitempty"`
+}
+
+type Segment struct {
+ Id int `json:"id"`
+ Seek int `json:"seek"`
+ Start float64 `json:"start"`
+ End float64 `json:"end"`
+ Text string `json:"text"`
+ Tokens []int `json:"tokens"`
+ Temperature float64 `json:"temperature"`
+ AvgLogprob float64 `json:"avg_logprob"`
+ CompressionRatio float64 `json:"compression_ratio"`
+ NoSpeechProb float64 `json:"no_speech_prob"`
+}
+
+type TextToSpeechRequest struct {
+ Model string `json:"model" binding:"required"`
+ Input string `json:"input" binding:"required"`
+ Voice string `json:"voice" binding:"required"`
+ Speed float64 `json:"speed"`
+ ResponseFormat string `json:"response_format"`
+}
+
+type Usage struct {
+ PromptTokens int `json:"prompt_tokens"`
+ CompletionTokens int `json:"completion_tokens"`
+ TotalTokens int `json:"total_tokens"`
+}
+
+type UsageOrResponseText struct {
+ *Usage
+ ResponseText string
+}
+
+type Error struct {
+ Message string `json:"message"`
+ Type string `json:"type"`
+ Param string `json:"param"`
+ Code any `json:"code"`
+}
+
+type ErrorWithStatusCode struct {
+ Error
+ StatusCode int `json:"status_code"`
+}
+
+type SlimTextResponse struct {
+ Choices []TextResponseChoice `json:"choices"`
+ Usage `json:"usage"`
+ Error Error `json:"error"`
+}
+
+type TextResponseChoice struct {
+ Index int `json:"index"`
+ Message `json:"message"`
+ FinishReason string `json:"finish_reason"`
+}
+
+type TextResponse struct {
+ Id string `json:"id"`
+ Model string `json:"model,omitempty"`
+ Object string `json:"object"`
+ Created int64 `json:"created"`
+ Choices []TextResponseChoice `json:"choices"`
+ Usage `json:"usage"`
+}
+
+type EmbeddingResponseItem struct {
+ Object string `json:"object"`
+ Index int `json:"index"`
+ Embedding []float64 `json:"embedding"`
+}
+
+type EmbeddingResponse struct {
+ Object string `json:"object"`
+ Data []EmbeddingResponseItem `json:"data"`
+ Model string `json:"model"`
+ Usage `json:"usage"`
+}
+
+type ImageResponse struct {
+	Created int `json:"created"`
+	Data    []struct {
+		Url string `json:"url"`
+	} `json:"data"`
+}
+
+type ChatCompletionsStreamResponseChoice struct {
+ Delta struct {
+ Content string `json:"content"`
+ } `json:"delta"`
+ FinishReason *string `json:"finish_reason,omitempty"`
+}
+
+type ChatCompletionsStreamResponse struct {
+ Id string `json:"id"`
+ Object string `json:"object"`
+ Created int64 `json:"created"`
+ Model string `json:"model"`
+ Choices []ChatCompletionsStreamResponseChoice `json:"choices"`
+}
+
+type CompletionsStreamResponse struct {
+ Choices []struct {
+ Text string `json:"text"`
+ FinishReason string `json:"finish_reason"`
+ } `json:"choices"`
+}
diff --git a/controller/relay-utils.go b/relay/channel/openai/token.go
similarity index 51%
rename from controller/relay-utils.go
rename to relay/channel/openai/token.go
index a6a1f0f6..6803770e 100644
--- a/controller/relay-utils.go
+++ b/relay/channel/openai/token.go
@@ -1,39 +1,31 @@
-package controller
+package openai
import (
- "context"
- "encoding/json"
"errors"
"fmt"
- "io"
- "math"
- "net/http"
- "one-api/common"
- "one-api/common/image"
- "one-api/model"
- "strconv"
- "strings"
-
- "github.com/gin-gonic/gin"
"github.com/pkoukk/tiktoken-go"
+ "math"
+ "one-api/common"
+ "one-api/common/config"
+ "one-api/common/image"
+ "one-api/common/logger"
+ "strings"
)
-var stopFinishReason = "stop"
-
// tokenEncoderMap won't grow after initialization
var tokenEncoderMap = map[string]*tiktoken.Tiktoken{}
var defaultTokenEncoder *tiktoken.Tiktoken
func InitTokenEncoders() {
- common.SysLog("initializing token encoders")
+ logger.SysLog("initializing token encoders")
gpt35TokenEncoder, err := tiktoken.EncodingForModel("gpt-3.5-turbo")
if err != nil {
- common.FatalLog(fmt.Sprintf("failed to get gpt-3.5-turbo token encoder: %s", err.Error()))
+ logger.FatalLog(fmt.Sprintf("failed to get gpt-3.5-turbo token encoder: %s", err.Error()))
}
defaultTokenEncoder = gpt35TokenEncoder
gpt4TokenEncoder, err := tiktoken.EncodingForModel("gpt-4")
if err != nil {
- common.FatalLog(fmt.Sprintf("failed to get gpt-4 token encoder: %s", err.Error()))
+ logger.FatalLog(fmt.Sprintf("failed to get gpt-4 token encoder: %s", err.Error()))
}
for model, _ := range common.ModelRatio {
if strings.HasPrefix(model, "gpt-3.5") {
@@ -44,7 +36,7 @@ func InitTokenEncoders() {
tokenEncoderMap[model] = nil
}
}
- common.SysLog("token encoders initialized")
+ logger.SysLog("token encoders initialized")
}
func getTokenEncoder(model string) *tiktoken.Tiktoken {
@@ -55,7 +47,7 @@ func getTokenEncoder(model string) *tiktoken.Tiktoken {
if ok {
tokenEncoder, err := tiktoken.EncodingForModel(model)
if err != nil {
- common.SysError(fmt.Sprintf("failed to get token encoder for model %s: %s, using encoder for gpt-3.5-turbo", model, err.Error()))
+ logger.SysError(fmt.Sprintf("failed to get token encoder for model %s: %s, using encoder for gpt-3.5-turbo", model, err.Error()))
tokenEncoder = defaultTokenEncoder
}
tokenEncoderMap[model] = tokenEncoder
@@ -65,13 +57,13 @@ func getTokenEncoder(model string) *tiktoken.Tiktoken {
}
func getTokenNum(tokenEncoder *tiktoken.Tiktoken, text string) int {
- if common.ApproximateTokenEnabled {
+ if config.ApproximateTokenEnabled {
return int(float64(len(text)) * 0.38)
}
return len(tokenEncoder.Encode(text, nil, nil))
}
-func countTokenMessages(messages []Message, model string) int {
+func CountTokenMessages(messages []Message, model string) int {
tokenEncoder := getTokenEncoder(model)
// Reference:
// https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
@@ -109,7 +101,7 @@ func countTokenMessages(messages []Message, model string) int {
}
imageTokens, err := countImageTokens(url, detail)
if err != nil {
- common.SysError("error counting image tokens: " + err.Error())
+ logger.SysError("error counting image tokens: " + err.Error())
} else {
tokenNum += imageTokens
}
@@ -195,191 +187,21 @@ func countImageTokens(url string, detail string) (_ int, err error) {
}
}
-func countTokenInput(input any, model string) int {
+func CountTokenInput(input any, model string) int {
switch v := input.(type) {
case string:
- return countTokenText(v, model)
+ return CountTokenText(v, model)
case []string:
text := ""
for _, s := range v {
text += s
}
- return countTokenText(text, model)
+ return CountTokenText(text, model)
}
return 0
}
-func countTokenText(text string, model string) int {
+func CountTokenText(text string, model string) int {
tokenEncoder := getTokenEncoder(model)
return getTokenNum(tokenEncoder, text)
}
-
-func errorWrapper(err error, code string, statusCode int) *OpenAIErrorWithStatusCode {
- openAIError := OpenAIError{
- Message: err.Error(),
- Type: "one_api_error",
- Code: code,
- }
- return &OpenAIErrorWithStatusCode{
- OpenAIError: openAIError,
- StatusCode: statusCode,
- }
-}
-
-func shouldDisableChannel(err *OpenAIError, statusCode int) bool {
- if !common.AutomaticDisableChannelEnabled {
- return false
- }
- if err == nil {
- return false
- }
- if statusCode == http.StatusUnauthorized {
- return true
- }
- if err.Type == "insufficient_quota" || err.Code == "invalid_api_key" || err.Code == "account_deactivated" {
- return true
- }
- return false
-}
-
-func shouldEnableChannel(err error, openAIErr *OpenAIError) bool {
- if !common.AutomaticEnableChannelEnabled {
- return false
- }
- if err != nil {
- return false
- }
- if openAIErr != nil {
- return false
- }
- return true
-}
-
-func setEventStreamHeaders(c *gin.Context) {
- c.Writer.Header().Set("Content-Type", "text/event-stream")
- c.Writer.Header().Set("Cache-Control", "no-cache")
- c.Writer.Header().Set("Connection", "keep-alive")
- c.Writer.Header().Set("Transfer-Encoding", "chunked")
- c.Writer.Header().Set("X-Accel-Buffering", "no")
-}
-
-type GeneralErrorResponse struct {
- Error OpenAIError `json:"error"`
- Message string `json:"message"`
- Msg string `json:"msg"`
- Err string `json:"err"`
- ErrorMsg string `json:"error_msg"`
- Header struct {
- Message string `json:"message"`
- } `json:"header"`
- Response struct {
- Error struct {
- Message string `json:"message"`
- } `json:"error"`
- } `json:"response"`
-}
-
-func (e GeneralErrorResponse) ToMessage() string {
- if e.Error.Message != "" {
- return e.Error.Message
- }
- if e.Message != "" {
- return e.Message
- }
- if e.Msg != "" {
- return e.Msg
- }
- if e.Err != "" {
- return e.Err
- }
- if e.ErrorMsg != "" {
- return e.ErrorMsg
- }
- if e.Header.Message != "" {
- return e.Header.Message
- }
- if e.Response.Error.Message != "" {
- return e.Response.Error.Message
- }
- return ""
-}
-
-func relayErrorHandler(resp *http.Response) (openAIErrorWithStatusCode *OpenAIErrorWithStatusCode) {
- openAIErrorWithStatusCode = &OpenAIErrorWithStatusCode{
- StatusCode: resp.StatusCode,
- OpenAIError: OpenAIError{
- Message: "",
- Type: "upstream_error",
- Code: "bad_response_status_code",
- Param: strconv.Itoa(resp.StatusCode),
- },
- }
- responseBody, err := io.ReadAll(resp.Body)
- if err != nil {
- return
- }
- err = resp.Body.Close()
- if err != nil {
- return
- }
- var errResponse GeneralErrorResponse
- err = json.Unmarshal(responseBody, &errResponse)
- if err != nil {
- return
- }
- if errResponse.Error.Message != "" {
- // OpenAI format error, so we override the default one
- openAIErrorWithStatusCode.OpenAIError = errResponse.Error
- } else {
- openAIErrorWithStatusCode.OpenAIError.Message = errResponse.ToMessage()
- }
- if openAIErrorWithStatusCode.OpenAIError.Message == "" {
- openAIErrorWithStatusCode.OpenAIError.Message = fmt.Sprintf("bad response status code %d", resp.StatusCode)
- }
- return
-}
-
-func getFullRequestURL(baseURL string, requestURL string, channelType int) string {
- fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL)
-
- if strings.HasPrefix(baseURL, "https://gateway.ai.cloudflare.com") {
- switch channelType {
- case common.ChannelTypeOpenAI:
- fullRequestURL = fmt.Sprintf("%s%s", baseURL, strings.TrimPrefix(requestURL, "/v1"))
- case common.ChannelTypeAzure:
- fullRequestURL = fmt.Sprintf("%s%s", baseURL, strings.TrimPrefix(requestURL, "/openai/deployments"))
- }
- }
- return fullRequestURL
-}
-
-func postConsumeQuota(ctx context.Context, tokenId int, quotaDelta int, totalQuota int, userId int, channelId int, modelRatio float64, groupRatio float64, modelName string, tokenName string) {
- // quotaDelta is remaining quota to be consumed
- err := model.PostConsumeTokenQuota(tokenId, quotaDelta)
- if err != nil {
- common.SysError("error consuming token remain quota: " + err.Error())
- }
- err = model.CacheUpdateUserQuota(userId)
- if err != nil {
- common.SysError("error update user quota cache: " + err.Error())
- }
- // totalQuota is total quota consumed
- if totalQuota != 0 {
- logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
- model.RecordConsumeLog(ctx, userId, channelId, totalQuota, 0, modelName, tokenName, totalQuota, logContent)
- model.UpdateUserUsedQuotaAndRequestCount(userId, totalQuota)
- model.UpdateChannelUsedQuota(channelId, totalQuota)
- }
- if totalQuota <= 0 {
- common.LogError(ctx, fmt.Sprintf("totalQuota consumed is %d, something is wrong", totalQuota))
- }
-}
-
-func GetAPIVersion(c *gin.Context) string {
- query := c.Request.URL.Query()
- apiVersion := query.Get("api-version")
- if apiVersion == "" {
- apiVersion = c.GetString("api_version")
- }
- return apiVersion
-}
diff --git a/relay/channel/openai/util.go b/relay/channel/openai/util.go
new file mode 100644
index 00000000..69ece6b3
--- /dev/null
+++ b/relay/channel/openai/util.go
@@ -0,0 +1,13 @@
+package openai
+
+func ErrorWrapper(err error, code string, statusCode int) *ErrorWithStatusCode {
+	wrappedErr := Error{
+		Message: err.Error(),
+		Type:    "one_api_error",
+		Code:    code,
+	}
+	return &ErrorWithStatusCode{
+		Error:      wrappedErr,
+		StatusCode: statusCode,
+	}
+}
diff --git a/relay/channel/tencent/adaptor.go b/relay/channel/tencent/adaptor.go
new file mode 100644
index 00000000..e9f86aff
--- /dev/null
+++ b/relay/channel/tencent/adaptor.go
@@ -0,0 +1,22 @@
+package tencent
+
+import (
+ "github.com/gin-gonic/gin"
+ "net/http"
+ "one-api/relay/channel/openai"
+)
+
+type Adaptor struct {
+}
+
+func (a *Adaptor) Auth(c *gin.Context) error {
+ return nil
+}
+
+func (a *Adaptor) ConvertRequest(request *openai.GeneralOpenAIRequest) (any, error) {
+ return nil, nil
+}
+
+func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, *openai.Usage, error) {
+ return nil, nil, nil
+}
diff --git a/relay/channel/tencent/main.go b/relay/channel/tencent/main.go
new file mode 100644
index 00000000..64091d62
--- /dev/null
+++ b/relay/channel/tencent/main.go
@@ -0,0 +1,234 @@
+package tencent
+
+import (
+ "bufio"
+ "crypto/hmac"
+ "crypto/sha1"
+ "encoding/base64"
+ "encoding/json"
+ "errors"
+ "fmt"
+ "github.com/gin-gonic/gin"
+ "io"
+ "net/http"
+ "one-api/common"
+ "one-api/common/helper"
+ "one-api/common/logger"
+ "one-api/relay/channel/openai"
+ "one-api/relay/constant"
+ "sort"
+ "strconv"
+ "strings"
+)
+
+// https://cloud.tencent.com/document/product/1729/97732
+
+func ConvertRequest(request openai.GeneralOpenAIRequest) *ChatRequest {
+ messages := make([]Message, 0, len(request.Messages))
+ for i := 0; i < len(request.Messages); i++ {
+ message := request.Messages[i]
+ if message.Role == "system" {
+ messages = append(messages, Message{
+ Role: "user",
+ Content: message.StringContent(),
+ })
+ messages = append(messages, Message{
+ Role: "assistant",
+ Content: "Okay",
+ })
+ continue
+ }
+ messages = append(messages, Message{
+ Content: message.StringContent(),
+ Role: message.Role,
+ })
+ }
+ stream := 0
+ if request.Stream {
+ stream = 1
+ }
+ return &ChatRequest{
+ Timestamp: helper.GetTimestamp(),
+ Expired: helper.GetTimestamp() + 24*60*60,
+ QueryID: helper.GetUUID(),
+ Temperature: request.Temperature,
+ TopP: request.TopP,
+ Stream: stream,
+ Messages: messages,
+ }
+}
+
+func responseTencent2OpenAI(response *ChatResponse) *openai.TextResponse {
+ fullTextResponse := openai.TextResponse{
+ Object: "chat.completion",
+ Created: helper.GetTimestamp(),
+ Usage: response.Usage,
+ }
+ if len(response.Choices) > 0 {
+ choice := openai.TextResponseChoice{
+ Index: 0,
+ Message: openai.Message{
+ Role: "assistant",
+ Content: response.Choices[0].Messages.Content,
+ },
+ FinishReason: response.Choices[0].FinishReason,
+ }
+ fullTextResponse.Choices = append(fullTextResponse.Choices, choice)
+ }
+ return &fullTextResponse
+}
+
+func streamResponseTencent2OpenAI(TencentResponse *ChatResponse) *openai.ChatCompletionsStreamResponse {
+ response := openai.ChatCompletionsStreamResponse{
+ Object: "chat.completion.chunk",
+ Created: helper.GetTimestamp(),
+ Model: "tencent-hunyuan",
+ }
+ if len(TencentResponse.Choices) > 0 {
+ var choice openai.ChatCompletionsStreamResponseChoice
+ choice.Delta.Content = TencentResponse.Choices[0].Delta.Content
+ if TencentResponse.Choices[0].FinishReason == "stop" {
+ choice.FinishReason = &constant.StopFinishReason
+ }
+ response.Choices = append(response.Choices, choice)
+ }
+ return &response
+}
+
+func StreamHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, string) {
+ var responseText string
+ scanner := bufio.NewScanner(resp.Body)
+ scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
+ if atEOF && len(data) == 0 {
+ return 0, nil, nil
+ }
+ if i := strings.Index(string(data), "\n"); i >= 0 {
+ return i + 1, data[0:i], nil
+ }
+ if atEOF {
+ return len(data), data, nil
+ }
+ return 0, nil, nil
+ })
+ dataChan := make(chan string)
+ stopChan := make(chan bool)
+ go func() {
+ for scanner.Scan() {
+ data := scanner.Text()
+ if len(data) < 5 { // ignore blank line or wrong format
+ continue
+ }
+ if data[:5] != "data:" {
+ continue
+ }
+ data = data[5:]
+ dataChan <- data
+ }
+ stopChan <- true
+ }()
+ common.SetEventStreamHeaders(c)
+ c.Stream(func(w io.Writer) bool {
+ select {
+ case data := <-dataChan:
+ var TencentResponse ChatResponse
+ err := json.Unmarshal([]byte(data), &TencentResponse)
+ if err != nil {
+ logger.SysError("error unmarshalling stream response: " + err.Error())
+ return true
+ }
+ response := streamResponseTencent2OpenAI(&TencentResponse)
+ if len(response.Choices) != 0 {
+ responseText += response.Choices[0].Delta.Content
+ }
+ jsonResponse, err := json.Marshal(response)
+ if err != nil {
+ logger.SysError("error marshalling stream response: " + err.Error())
+ return true
+ }
+ c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
+ return true
+ case <-stopChan:
+ c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
+ return false
+ }
+ })
+ err := resp.Body.Close()
+ if err != nil {
+ return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
+ }
+ return nil, responseText
+}
+
+func Handler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, *openai.Usage) {
+	var tencentResponse ChatResponse
+	responseBody, err := io.ReadAll(resp.Body)
+	if err != nil {
+		return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
+	}
+	err = resp.Body.Close()
+	if err != nil {
+		return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
+	}
+	err = json.Unmarshal(responseBody, &tencentResponse)
+	if err != nil {
+		return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
+	}
+	if tencentResponse.Error.Code != 0 {
+		return &openai.ErrorWithStatusCode{
+			Error: openai.Error{
+				Message: tencentResponse.Error.Message,
+				Code:    tencentResponse.Error.Code,
+			},
+			StatusCode: resp.StatusCode,
+		}, nil
+	}
+	fullTextResponse := responseTencent2OpenAI(&tencentResponse)
+	fullTextResponse.Model = "hunyuan"
+	jsonResponse, err := json.Marshal(fullTextResponse)
+	if err != nil {
+		return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
+	}
+	c.Writer.Header().Set("Content-Type", "application/json")
+	c.Writer.WriteHeader(resp.StatusCode)
+	_, _ = c.Writer.Write(jsonResponse) // status already written; a write error here cannot be reported
+	return nil, &fullTextResponse.Usage
+}
+
+func ParseConfig(config string) (appId int64, secretId string, secretKey string, err error) {
+ parts := strings.Split(config, "|")
+ if len(parts) != 3 {
+ err = errors.New("invalid tencent config")
+ return
+ }
+ appId, err = strconv.ParseInt(parts[0], 10, 64)
+ secretId = parts[1]
+ secretKey = parts[2]
+ return
+}
+
+func GetSign(req ChatRequest, secretKey string) string {
+ params := make([]string, 0)
+ params = append(params, "app_id="+strconv.FormatInt(req.AppId, 10))
+ params = append(params, "secret_id="+req.SecretId)
+ params = append(params, "timestamp="+strconv.FormatInt(req.Timestamp, 10))
+ params = append(params, "query_id="+req.QueryID)
+ params = append(params, "temperature="+strconv.FormatFloat(req.Temperature, 'f', -1, 64))
+ params = append(params, "top_p="+strconv.FormatFloat(req.TopP, 'f', -1, 64))
+ params = append(params, "stream="+strconv.Itoa(req.Stream))
+ params = append(params, "expired="+strconv.FormatInt(req.Expired, 10))
+
+ var messageStr string
+ for _, msg := range req.Messages {
+ messageStr += fmt.Sprintf(`{"role":"%s","content":"%s"},`, msg.Role, msg.Content)
+ }
+ messageStr = strings.TrimSuffix(messageStr, ",")
+ params = append(params, "messages=["+messageStr+"]")
+
+	sort.Strings(params)
+ url := "hunyuan.cloud.tencent.com/hyllm/v1/chat/completions?" + strings.Join(params, "&")
+	mac := hmac.New(sha1.New, []byte(secretKey))
+	// sign the canonical request string with HMAC-SHA1
+	mac.Write([]byte(url))
+	sign := mac.Sum(nil)
+ return base64.StdEncoding.EncodeToString(sign)
+}
diff --git a/relay/channel/tencent/model.go b/relay/channel/tencent/model.go
new file mode 100644
index 00000000..511f3d97
--- /dev/null
+++ b/relay/channel/tencent/model.go
@@ -0,0 +1,63 @@
+package tencent
+
+import (
+ "one-api/relay/channel/openai"
+)
+
+type Message struct {
+ Role string `json:"role"`
+ Content string `json:"content"`
+}
+
+type ChatRequest struct {
+ AppId int64 `json:"app_id"` // 腾讯云账号的 APPID
+ SecretId string `json:"secret_id"` // 官网 SecretId
+ // Timestamp当前 UNIX 时间戳,单位为秒,可记录发起 API 请求的时间。
+ // 例如1529223702,如果与当前时间相差过大,会引起签名过期错误
+ Timestamp int64 `json:"timestamp"`
+ // Expired 签名的有效期,是一个符合 UNIX Epoch 时间戳规范的数值,
+ // 单位为秒;Expired 必须大于 Timestamp 且 Expired-Timestamp 小于90天
+ Expired int64 `json:"expired"`
+ QueryID string `json:"query_id"` //请求 Id,用于问题排查
+ // Temperature 较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定
+ // 默认 1.0,取值区间为[0.0,2.0],非必要不建议使用,不合理的取值会影响效果
+ // 建议该参数和 top_p 只设置1个,不要同时更改 top_p
+ Temperature float64 `json:"temperature"`
+ // TopP 影响输出文本的多样性,取值越大,生成文本的多样性越强
+ // 默认1.0,取值区间为[0.0, 1.0],非必要不建议使用, 不合理的取值会影响效果
+ // 建议该参数和 temperature 只设置1个,不要同时更改
+ TopP float64 `json:"top_p"`
+ // Stream 0:同步,1:流式 (默认,协议:SSE)
+ // 同步请求超时:60s,如果内容较长建议使用流式
+ Stream int `json:"stream"`
+ // Messages 会话内容, 长度最多为40, 按对话时间从旧到新在数组中排列
+ // 输入 content 总数最大支持 3000 token。
+ Messages []Message `json:"messages"`
+}
+
+type Error struct {
+ Code int `json:"code"`
+ Message string `json:"message"`
+}
+
+type Usage struct { // NOTE(review): appears unused now that ChatResponse embeds openai.Usage — confirm before removing
+ InputTokens int `json:"input_tokens"`
+ OutputTokens int `json:"output_tokens"`
+ TotalTokens int `json:"total_tokens"`
+}
+
+type ResponseChoices struct {
+ FinishReason string `json:"finish_reason,omitempty"` // 流式结束标志位,为 stop 则表示尾包
+ Messages Message `json:"messages,omitempty"` // 内容,同步模式返回内容,流模式为 null 输出 content 内容总数最多支持 1024token。
+ Delta Message `json:"delta,omitempty"` // 内容,流模式返回内容,同步模式为 null 输出 content 内容总数最多支持 1024token。
+}
+
+type ChatResponse struct {
+ Choices []ResponseChoices `json:"choices,omitempty"` // 结果
+ Created string `json:"created,omitempty"` // unix 时间戳的字符串
+ Id string `json:"id,omitempty"` // 会话 id
+ Usage openai.Usage `json:"usage,omitempty"` // token 数量
+	Error Error `json:"error,omitempty"` // error info; upstream may return null (note: omitempty has no effect on struct-typed fields)
+ Note string `json:"note,omitempty"` // 注释
+ ReqID string `json:"req_id,omitempty"` // 唯一请求 Id,每次请求都会返回。用于反馈接口入参
+}
diff --git a/relay/channel/xunfei/adaptor.go b/relay/channel/xunfei/adaptor.go
new file mode 100644
index 00000000..89e58485
--- /dev/null
+++ b/relay/channel/xunfei/adaptor.go
@@ -0,0 +1,22 @@
+package xunfei
+
+import (
+ "github.com/gin-gonic/gin"
+ "net/http"
+ "one-api/relay/channel/openai"
+)
+
+type Adaptor struct {
+}
+
+func (a *Adaptor) Auth(c *gin.Context) error {
+ return nil
+}
+
+func (a *Adaptor) ConvertRequest(request *openai.GeneralOpenAIRequest) (any, error) {
+ return nil, nil
+}
+
+func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, *openai.Usage, error) {
+ return nil, nil, nil
+}
diff --git a/controller/relay-xunfei.go b/relay/channel/xunfei/main.go
similarity index 58%
rename from controller/relay-xunfei.go
rename to relay/channel/xunfei/main.go
index 904e6d14..906d2844 100644
--- a/controller/relay-xunfei.go
+++ b/relay/channel/xunfei/main.go
@@ -1,4 +1,4 @@
-package controller
+package xunfei
import (
"crypto/hmac"
@@ -12,6 +12,10 @@ import (
"net/http"
"net/url"
"one-api/common"
+ "one-api/common/helper"
+ "one-api/common/logger"
+ "one-api/relay/channel/openai"
+ "one-api/relay/constant"
"strings"
"time"
)
@@ -19,82 +23,26 @@ import (
// https://console.xfyun.cn/services/cbm
// https://www.xfyun.cn/doc/spark/Web.html
-type XunfeiMessage struct {
- Role string `json:"role"`
- Content string `json:"content"`
-}
-
-type XunfeiChatRequest struct {
- Header struct {
- AppId string `json:"app_id"`
- } `json:"header"`
- Parameter struct {
- Chat struct {
- Domain string `json:"domain,omitempty"`
- Temperature float64 `json:"temperature,omitempty"`
- TopK int `json:"top_k,omitempty"`
- MaxTokens int `json:"max_tokens,omitempty"`
- Auditing bool `json:"auditing,omitempty"`
- } `json:"chat"`
- } `json:"parameter"`
- Payload struct {
- Message struct {
- Text []XunfeiMessage `json:"text"`
- } `json:"message"`
- } `json:"payload"`
-}
-
-type XunfeiChatResponseTextItem struct {
- Content string `json:"content"`
- Role string `json:"role"`
- Index int `json:"index"`
-}
-
-type XunfeiChatResponse struct {
- Header struct {
- Code int `json:"code"`
- Message string `json:"message"`
- Sid string `json:"sid"`
- Status int `json:"status"`
- } `json:"header"`
- Payload struct {
- Choices struct {
- Status int `json:"status"`
- Seq int `json:"seq"`
- Text []XunfeiChatResponseTextItem `json:"text"`
- } `json:"choices"`
- Usage struct {
- //Text struct {
- // QuestionTokens string `json:"question_tokens"`
- // PromptTokens string `json:"prompt_tokens"`
- // CompletionTokens string `json:"completion_tokens"`
- // TotalTokens string `json:"total_tokens"`
- //} `json:"text"`
- Text Usage `json:"text"`
- } `json:"usage"`
- } `json:"payload"`
-}
-
-func requestOpenAI2Xunfei(request GeneralOpenAIRequest, xunfeiAppId string, domain string) *XunfeiChatRequest {
- messages := make([]XunfeiMessage, 0, len(request.Messages))
+func requestOpenAI2Xunfei(request openai.GeneralOpenAIRequest, xunfeiAppId string, domain string) *ChatRequest {
+ messages := make([]Message, 0, len(request.Messages))
for _, message := range request.Messages {
if message.Role == "system" {
- messages = append(messages, XunfeiMessage{
+ messages = append(messages, Message{
Role: "user",
Content: message.StringContent(),
})
- messages = append(messages, XunfeiMessage{
+ messages = append(messages, Message{
Role: "assistant",
Content: "Okay",
})
} else {
- messages = append(messages, XunfeiMessage{
+ messages = append(messages, Message{
Role: message.Role,
Content: message.StringContent(),
})
}
}
- xunfeiRequest := XunfeiChatRequest{}
+ xunfeiRequest := ChatRequest{}
xunfeiRequest.Header.AppId = xunfeiAppId
xunfeiRequest.Parameter.Chat.Domain = domain
xunfeiRequest.Parameter.Chat.Temperature = request.Temperature
@@ -104,49 +52,49 @@ func requestOpenAI2Xunfei(request GeneralOpenAIRequest, xunfeiAppId string, doma
return &xunfeiRequest
}
-func responseXunfei2OpenAI(response *XunfeiChatResponse) *OpenAITextResponse {
+func responseXunfei2OpenAI(response *ChatResponse) *openai.TextResponse {
if len(response.Payload.Choices.Text) == 0 {
- response.Payload.Choices.Text = []XunfeiChatResponseTextItem{
+ response.Payload.Choices.Text = []ChatResponseTextItem{
{
Content: "",
},
}
}
- choice := OpenAITextResponseChoice{
+ choice := openai.TextResponseChoice{
Index: 0,
- Message: Message{
+ Message: openai.Message{
Role: "assistant",
Content: response.Payload.Choices.Text[0].Content,
},
- FinishReason: stopFinishReason,
+ FinishReason: constant.StopFinishReason,
}
- fullTextResponse := OpenAITextResponse{
+ fullTextResponse := openai.TextResponse{
Object: "chat.completion",
- Created: common.GetTimestamp(),
- Choices: []OpenAITextResponseChoice{choice},
+ Created: helper.GetTimestamp(),
+ Choices: []openai.TextResponseChoice{choice},
Usage: response.Payload.Usage.Text,
}
return &fullTextResponse
}
-func streamResponseXunfei2OpenAI(xunfeiResponse *XunfeiChatResponse) *ChatCompletionsStreamResponse {
+func streamResponseXunfei2OpenAI(xunfeiResponse *ChatResponse) *openai.ChatCompletionsStreamResponse {
if len(xunfeiResponse.Payload.Choices.Text) == 0 {
- xunfeiResponse.Payload.Choices.Text = []XunfeiChatResponseTextItem{
+ xunfeiResponse.Payload.Choices.Text = []ChatResponseTextItem{
{
Content: "",
},
}
}
- var choice ChatCompletionsStreamResponseChoice
+ var choice openai.ChatCompletionsStreamResponseChoice
choice.Delta.Content = xunfeiResponse.Payload.Choices.Text[0].Content
if xunfeiResponse.Payload.Choices.Status == 2 {
- choice.FinishReason = &stopFinishReason
+ choice.FinishReason = &constant.StopFinishReason
}
- response := ChatCompletionsStreamResponse{
+ response := openai.ChatCompletionsStreamResponse{
Object: "chat.completion.chunk",
- Created: common.GetTimestamp(),
+ Created: helper.GetTimestamp(),
Model: "SparkDesk",
- Choices: []ChatCompletionsStreamResponseChoice{choice},
+ Choices: []openai.ChatCompletionsStreamResponseChoice{choice},
}
return &response
}
@@ -177,14 +125,14 @@ func buildXunfeiAuthUrl(hostUrl string, apiKey, apiSecret string) string {
return callUrl
}
-func xunfeiStreamHandler(c *gin.Context, textRequest GeneralOpenAIRequest, appId string, apiSecret string, apiKey string) (*OpenAIErrorWithStatusCode, *Usage) {
+func StreamHandler(c *gin.Context, textRequest openai.GeneralOpenAIRequest, appId string, apiSecret string, apiKey string) (*openai.ErrorWithStatusCode, *openai.Usage) {
domain, authUrl := getXunfeiAuthUrl(c, apiKey, apiSecret)
dataChan, stopChan, err := xunfeiMakeRequest(textRequest, domain, authUrl, appId)
if err != nil {
- return errorWrapper(err, "make xunfei request err", http.StatusInternalServerError), nil
+ return openai.ErrorWrapper(err, "make xunfei request err", http.StatusInternalServerError), nil
}
- setEventStreamHeaders(c)
- var usage Usage
+ common.SetEventStreamHeaders(c)
+ var usage openai.Usage
c.Stream(func(w io.Writer) bool {
select {
case xunfeiResponse := <-dataChan:
@@ -194,7 +142,7 @@ func xunfeiStreamHandler(c *gin.Context, textRequest GeneralOpenAIRequest, appId
response := streamResponseXunfei2OpenAI(&xunfeiResponse)
jsonResponse, err := json.Marshal(response)
if err != nil {
- common.SysError("error marshalling stream response: " + err.Error())
+ logger.SysError("error marshalling stream response: " + err.Error())
return true
}
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
@@ -207,15 +155,15 @@ func xunfeiStreamHandler(c *gin.Context, textRequest GeneralOpenAIRequest, appId
return nil, &usage
}
-func xunfeiHandler(c *gin.Context, textRequest GeneralOpenAIRequest, appId string, apiSecret string, apiKey string) (*OpenAIErrorWithStatusCode, *Usage) {
+func Handler(c *gin.Context, textRequest openai.GeneralOpenAIRequest, appId string, apiSecret string, apiKey string) (*openai.ErrorWithStatusCode, *openai.Usage) {
domain, authUrl := getXunfeiAuthUrl(c, apiKey, apiSecret)
dataChan, stopChan, err := xunfeiMakeRequest(textRequest, domain, authUrl, appId)
if err != nil {
- return errorWrapper(err, "make xunfei request err", http.StatusInternalServerError), nil
+ return openai.ErrorWrapper(err, "make xunfei request err", http.StatusInternalServerError), nil
}
- var usage Usage
+ var usage openai.Usage
var content string
- var xunfeiResponse XunfeiChatResponse
+ var xunfeiResponse ChatResponse
stop := false
for !stop {
select {
@@ -231,7 +179,7 @@ func xunfeiHandler(c *gin.Context, textRequest GeneralOpenAIRequest, appId strin
}
}
if len(xunfeiResponse.Payload.Choices.Text) == 0 {
- xunfeiResponse.Payload.Choices.Text = []XunfeiChatResponseTextItem{
+ xunfeiResponse.Payload.Choices.Text = []ChatResponseTextItem{
{
Content: "",
},
@@ -242,14 +190,14 @@ func xunfeiHandler(c *gin.Context, textRequest GeneralOpenAIRequest, appId strin
response := responseXunfei2OpenAI(&xunfeiResponse)
jsonResponse, err := json.Marshal(response)
if err != nil {
- return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
+ return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
}
c.Writer.Header().Set("Content-Type", "application/json")
_, _ = c.Writer.Write(jsonResponse)
return nil, &usage
}
-func xunfeiMakeRequest(textRequest GeneralOpenAIRequest, domain, authUrl, appId string) (chan XunfeiChatResponse, chan bool, error) {
+func xunfeiMakeRequest(textRequest openai.GeneralOpenAIRequest, domain, authUrl, appId string) (chan ChatResponse, chan bool, error) {
d := websocket.Dialer{
HandshakeTimeout: 5 * time.Second,
}
@@ -263,26 +211,26 @@ func xunfeiMakeRequest(textRequest GeneralOpenAIRequest, domain, authUrl, appId
return nil, nil, err
}
- dataChan := make(chan XunfeiChatResponse)
+ dataChan := make(chan ChatResponse)
stopChan := make(chan bool)
go func() {
for {
_, msg, err := conn.ReadMessage()
if err != nil {
- common.SysError("error reading stream response: " + err.Error())
+ logger.SysError("error reading stream response: " + err.Error())
break
}
- var response XunfeiChatResponse
+ var response ChatResponse
err = json.Unmarshal(msg, &response)
if err != nil {
- common.SysError("error unmarshalling stream response: " + err.Error())
+ logger.SysError("error unmarshalling stream response: " + err.Error())
break
}
dataChan <- response
if response.Payload.Choices.Status == 2 {
err := conn.Close()
if err != nil {
- common.SysError("error closing websocket connection: " + err.Error())
+ logger.SysError("error closing websocket connection: " + err.Error())
}
break
}
@@ -301,7 +249,7 @@ func getXunfeiAuthUrl(c *gin.Context, apiKey string, apiSecret string) (string,
}
if apiVersion == "" {
apiVersion = "v1.1"
- common.SysLog("api_version not found, use default: " + apiVersion)
+ logger.SysLog("api_version not found, use default: " + apiVersion)
}
domain := "general"
if apiVersion != "v1.1" {
diff --git a/relay/channel/xunfei/model.go b/relay/channel/xunfei/model.go
new file mode 100644
index 00000000..0ca42818
--- /dev/null
+++ b/relay/channel/xunfei/model.go
@@ -0,0 +1,61 @@
+package xunfei
+
+import (
+ "one-api/relay/channel/openai"
+)
+
+type Message struct {
+ Role string `json:"role"`
+ Content string `json:"content"`
+}
+
+type ChatRequest struct {
+ Header struct {
+ AppId string `json:"app_id"`
+ } `json:"header"`
+ Parameter struct {
+ Chat struct {
+ Domain string `json:"domain,omitempty"`
+ Temperature float64 `json:"temperature,omitempty"`
+ TopK int `json:"top_k,omitempty"`
+ MaxTokens int `json:"max_tokens,omitempty"`
+ Auditing bool `json:"auditing,omitempty"`
+ } `json:"chat"`
+ } `json:"parameter"`
+ Payload struct {
+ Message struct {
+ Text []Message `json:"text"`
+ } `json:"message"`
+ } `json:"payload"`
+}
+
+type ChatResponseTextItem struct {
+ Content string `json:"content"`
+ Role string `json:"role"`
+ Index int `json:"index"`
+}
+
+type ChatResponse struct {
+ Header struct {
+ Code int `json:"code"`
+ Message string `json:"message"`
+ Sid string `json:"sid"`
+ Status int `json:"status"`
+ } `json:"header"`
+ Payload struct {
+ Choices struct {
+ Status int `json:"status"`
+ Seq int `json:"seq"`
+ Text []ChatResponseTextItem `json:"text"`
+ } `json:"choices"`
+ Usage struct {
+ //Text struct {
+ // QuestionTokens string `json:"question_tokens"`
+ // PromptTokens string `json:"prompt_tokens"`
+ // CompletionTokens string `json:"completion_tokens"`
+ // TotalTokens string `json:"total_tokens"`
+ //} `json:"text"`
+ Text openai.Usage `json:"text"`
+ } `json:"usage"`
+ } `json:"payload"`
+}
diff --git a/relay/channel/zhipu/adaptor.go b/relay/channel/zhipu/adaptor.go
new file mode 100644
index 00000000..6d901bc3
--- /dev/null
+++ b/relay/channel/zhipu/adaptor.go
@@ -0,0 +1,22 @@
+package zhipu
+
+import (
+ "github.com/gin-gonic/gin"
+ "net/http"
+ "one-api/relay/channel/openai"
+)
+
+type Adaptor struct {
+}
+
+func (a *Adaptor) Auth(c *gin.Context) error {
+ return nil
+}
+
+func (a *Adaptor) ConvertRequest(request *openai.GeneralOpenAIRequest) (any, error) {
+ return nil, nil
+}
+
+func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, *openai.Usage, error) {
+ return nil, nil, nil
+}
diff --git a/controller/relay-zhipu.go b/relay/channel/zhipu/main.go
similarity index 58%
rename from controller/relay-zhipu.go
rename to relay/channel/zhipu/main.go
index cb5a78cf..d831f57a 100644
--- a/controller/relay-zhipu.go
+++ b/relay/channel/zhipu/main.go
@@ -1,4 +1,4 @@
-package controller
+package zhipu
import (
"bufio"
@@ -8,6 +8,10 @@ import (
"io"
"net/http"
"one-api/common"
+ "one-api/common/helper"
+ "one-api/common/logger"
+ "one-api/relay/channel/openai"
+ "one-api/relay/constant"
"strings"
"sync"
"time"
@@ -18,53 +22,13 @@ import (
// https://open.bigmodel.cn/api/paas/v3/model-api/chatglm_std/invoke
// https://open.bigmodel.cn/api/paas/v3/model-api/chatglm_std/sse-invoke
-type ZhipuMessage struct {
- Role string `json:"role"`
- Content string `json:"content"`
-}
-
-type ZhipuRequest struct {
- Prompt []ZhipuMessage `json:"prompt"`
- Temperature float64 `json:"temperature,omitempty"`
- TopP float64 `json:"top_p,omitempty"`
- RequestId string `json:"request_id,omitempty"`
- Incremental bool `json:"incremental,omitempty"`
-}
-
-type ZhipuResponseData struct {
- TaskId string `json:"task_id"`
- RequestId string `json:"request_id"`
- TaskStatus string `json:"task_status"`
- Choices []ZhipuMessage `json:"choices"`
- Usage `json:"usage"`
-}
-
-type ZhipuResponse struct {
- Code int `json:"code"`
- Msg string `json:"msg"`
- Success bool `json:"success"`
- Data ZhipuResponseData `json:"data"`
-}
-
-type ZhipuStreamMetaResponse struct {
- RequestId string `json:"request_id"`
- TaskId string `json:"task_id"`
- TaskStatus string `json:"task_status"`
- Usage `json:"usage"`
-}
-
-type zhipuTokenData struct {
- Token string
- ExpiryTime time.Time
-}
-
var zhipuTokens sync.Map
var expSeconds int64 = 24 * 3600
-func getZhipuToken(apikey string) string {
+func GetToken(apikey string) string {
data, ok := zhipuTokens.Load(apikey)
if ok {
- tokenData := data.(zhipuTokenData)
+ tokenData := data.(tokenData)
if time.Now().Before(tokenData.ExpiryTime) {
return tokenData.Token
}
@@ -72,7 +36,7 @@ func getZhipuToken(apikey string) string {
split := strings.Split(apikey, ".")
if len(split) != 2 {
- common.SysError("invalid zhipu key: " + apikey)
+ logger.SysError("invalid zhipu key: " + apikey)
return ""
}
@@ -100,7 +64,7 @@ func getZhipuToken(apikey string) string {
return ""
}
- zhipuTokens.Store(apikey, zhipuTokenData{
+ zhipuTokens.Store(apikey, tokenData{
Token: tokenString,
ExpiryTime: expiryTime,
})
@@ -108,26 +72,26 @@ func getZhipuToken(apikey string) string {
return tokenString
}
-func requestOpenAI2Zhipu(request GeneralOpenAIRequest) *ZhipuRequest {
- messages := make([]ZhipuMessage, 0, len(request.Messages))
+func ConvertRequest(request openai.GeneralOpenAIRequest) *Request {
+ messages := make([]Message, 0, len(request.Messages))
for _, message := range request.Messages {
if message.Role == "system" {
- messages = append(messages, ZhipuMessage{
+ messages = append(messages, Message{
Role: "system",
Content: message.StringContent(),
})
- messages = append(messages, ZhipuMessage{
+ messages = append(messages, Message{
Role: "user",
Content: "Okay",
})
} else {
- messages = append(messages, ZhipuMessage{
+ messages = append(messages, Message{
Role: message.Role,
Content: message.StringContent(),
})
}
}
- return &ZhipuRequest{
+ return &Request{
Prompt: messages,
Temperature: request.Temperature,
TopP: request.TopP,
@@ -135,18 +99,18 @@ func requestOpenAI2Zhipu(request GeneralOpenAIRequest) *ZhipuRequest {
}
}
-func responseZhipu2OpenAI(response *ZhipuResponse) *OpenAITextResponse {
- fullTextResponse := OpenAITextResponse{
+func responseZhipu2OpenAI(response *Response) *openai.TextResponse {
+ fullTextResponse := openai.TextResponse{
Id: response.Data.TaskId,
Object: "chat.completion",
- Created: common.GetTimestamp(),
- Choices: make([]OpenAITextResponseChoice, 0, len(response.Data.Choices)),
+ Created: helper.GetTimestamp(),
+ Choices: make([]openai.TextResponseChoice, 0, len(response.Data.Choices)),
Usage: response.Data.Usage,
}
for i, choice := range response.Data.Choices {
- openaiChoice := OpenAITextResponseChoice{
+ openaiChoice := openai.TextResponseChoice{
Index: i,
- Message: Message{
+ Message: openai.Message{
Role: choice.Role,
Content: strings.Trim(choice.Content, "\""),
},
@@ -160,34 +124,34 @@ func responseZhipu2OpenAI(response *ZhipuResponse) *OpenAITextResponse {
return &fullTextResponse
}
-func streamResponseZhipu2OpenAI(zhipuResponse string) *ChatCompletionsStreamResponse {
- var choice ChatCompletionsStreamResponseChoice
+func streamResponseZhipu2OpenAI(zhipuResponse string) *openai.ChatCompletionsStreamResponse {
+ var choice openai.ChatCompletionsStreamResponseChoice
choice.Delta.Content = zhipuResponse
- response := ChatCompletionsStreamResponse{
+ response := openai.ChatCompletionsStreamResponse{
Object: "chat.completion.chunk",
- Created: common.GetTimestamp(),
+ Created: helper.GetTimestamp(),
Model: "chatglm",
- Choices: []ChatCompletionsStreamResponseChoice{choice},
+ Choices: []openai.ChatCompletionsStreamResponseChoice{choice},
}
return &response
}
-func streamMetaResponseZhipu2OpenAI(zhipuResponse *ZhipuStreamMetaResponse) (*ChatCompletionsStreamResponse, *Usage) {
- var choice ChatCompletionsStreamResponseChoice
+func streamMetaResponseZhipu2OpenAI(zhipuResponse *StreamMetaResponse) (*openai.ChatCompletionsStreamResponse, *openai.Usage) {
+ var choice openai.ChatCompletionsStreamResponseChoice
choice.Delta.Content = ""
- choice.FinishReason = &stopFinishReason
- response := ChatCompletionsStreamResponse{
+ choice.FinishReason = &constant.StopFinishReason
+ response := openai.ChatCompletionsStreamResponse{
Id: zhipuResponse.RequestId,
Object: "chat.completion.chunk",
- Created: common.GetTimestamp(),
+ Created: helper.GetTimestamp(),
Model: "chatglm",
- Choices: []ChatCompletionsStreamResponseChoice{choice},
+ Choices: []openai.ChatCompletionsStreamResponseChoice{choice},
}
return &response, &zhipuResponse.Usage
}
-func zhipuStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
- var usage *Usage
+func StreamHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, *openai.Usage) {
+ var usage *openai.Usage
scanner := bufio.NewScanner(resp.Body)
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
if atEOF && len(data) == 0 {
@@ -224,29 +188,29 @@ func zhipuStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithSt
}
stopChan <- true
}()
- setEventStreamHeaders(c)
+ common.SetEventStreamHeaders(c)
c.Stream(func(w io.Writer) bool {
select {
case data := <-dataChan:
response := streamResponseZhipu2OpenAI(data)
jsonResponse, err := json.Marshal(response)
if err != nil {
- common.SysError("error marshalling stream response: " + err.Error())
+ logger.SysError("error marshalling stream response: " + err.Error())
return true
}
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
return true
case data := <-metaChan:
- var zhipuResponse ZhipuStreamMetaResponse
+ var zhipuResponse StreamMetaResponse
err := json.Unmarshal([]byte(data), &zhipuResponse)
if err != nil {
- common.SysError("error unmarshalling stream response: " + err.Error())
+ logger.SysError("error unmarshalling stream response: " + err.Error())
return true
}
response, zhipuUsage := streamMetaResponseZhipu2OpenAI(&zhipuResponse)
jsonResponse, err := json.Marshal(response)
if err != nil {
- common.SysError("error marshalling stream response: " + err.Error())
+ logger.SysError("error marshalling stream response: " + err.Error())
return true
}
usage = zhipuUsage
@@ -259,28 +223,28 @@ func zhipuStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithSt
})
err := resp.Body.Close()
if err != nil {
- return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
+ return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
return nil, usage
}
-func zhipuHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
- var zhipuResponse ZhipuResponse
+func Handler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, *openai.Usage) {
+ var zhipuResponse Response
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
- return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
+ return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
- return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
+ return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
err = json.Unmarshal(responseBody, &zhipuResponse)
if err != nil {
- return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
+ return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
if !zhipuResponse.Success {
- return &OpenAIErrorWithStatusCode{
- OpenAIError: OpenAIError{
+ return &openai.ErrorWithStatusCode{
+ Error: openai.Error{
Message: zhipuResponse.Msg,
Type: "zhipu_error",
Param: "",
@@ -293,7 +257,7 @@ func zhipuHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCo
fullTextResponse.Model = "chatglm"
jsonResponse, err := json.Marshal(fullTextResponse)
if err != nil {
- return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
+ return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
diff --git a/relay/channel/zhipu/model.go b/relay/channel/zhipu/model.go
new file mode 100644
index 00000000..08a5ec5f
--- /dev/null
+++ b/relay/channel/zhipu/model.go
@@ -0,0 +1,46 @@
+package zhipu
+
+import (
+ "one-api/relay/channel/openai"
+ "time"
+)
+
+type Message struct {
+ Role string `json:"role"`
+ Content string `json:"content"`
+}
+
+type Request struct {
+ Prompt []Message `json:"prompt"`
+ Temperature float64 `json:"temperature,omitempty"`
+ TopP float64 `json:"top_p,omitempty"`
+ RequestId string `json:"request_id,omitempty"`
+ Incremental bool `json:"incremental,omitempty"`
+}
+
+type ResponseData struct {
+ TaskId string `json:"task_id"`
+ RequestId string `json:"request_id"`
+ TaskStatus string `json:"task_status"`
+ Choices []Message `json:"choices"`
+ openai.Usage `json:"usage"`
+}
+
+type Response struct {
+ Code int `json:"code"`
+ Msg string `json:"msg"`
+ Success bool `json:"success"`
+ Data ResponseData `json:"data"`
+}
+
+type StreamMetaResponse struct {
+ RequestId string `json:"request_id"`
+ TaskId string `json:"task_id"`
+ TaskStatus string `json:"task_status"`
+ openai.Usage `json:"usage"`
+}
+
+type tokenData struct {
+ Token string
+ ExpiryTime time.Time
+}
diff --git a/relay/constant/api_type.go b/relay/constant/api_type.go
new file mode 100644
index 00000000..658bfb90
--- /dev/null
+++ b/relay/constant/api_type.go
@@ -0,0 +1,69 @@
+package constant
+
+import (
+ "one-api/common"
+)
+
+const (
+ APITypeOpenAI = iota
+ APITypeClaude
+ APITypePaLM
+ APITypeBaidu
+ APITypeZhipu
+ APITypeAli
+ APITypeXunfei
+ APITypeAIProxyLibrary
+ APITypeTencent
+ APITypeGemini
+)
+
+func ChannelType2APIType(channelType int) int {
+ apiType := APITypeOpenAI
+ switch channelType {
+ case common.ChannelTypeAnthropic:
+ apiType = APITypeClaude
+ case common.ChannelTypeBaidu:
+ apiType = APITypeBaidu
+ case common.ChannelTypePaLM:
+ apiType = APITypePaLM
+ case common.ChannelTypeZhipu:
+ apiType = APITypeZhipu
+ case common.ChannelTypeAli:
+ apiType = APITypeAli
+ case common.ChannelTypeXunfei:
+ apiType = APITypeXunfei
+ case common.ChannelTypeAIProxyLibrary:
+ apiType = APITypeAIProxyLibrary
+ case common.ChannelTypeTencent:
+ apiType = APITypeTencent
+ case common.ChannelTypeGemini:
+ apiType = APITypeGemini
+ }
+ return apiType
+}
+
+//func GetAdaptor(apiType int) channel.Adaptor {
+// switch apiType {
+// case APITypeOpenAI:
+// return &openai.Adaptor{}
+// case APITypeClaude:
+// return &anthropic.Adaptor{}
+// case APITypePaLM:
+// return &google.Adaptor{}
+// case APITypeZhipu:
+// return &baidu.Adaptor{}
+// case APITypeBaidu:
+// return &baidu.Adaptor{}
+// case APITypeAli:
+// return &ali.Adaptor{}
+// case APITypeXunfei:
+// return &xunfei.Adaptor{}
+// case APITypeAIProxyLibrary:
+// return &aiproxy.Adaptor{}
+// case APITypeTencent:
+// return &tencent.Adaptor{}
+// case APITypeGemini:
+// return &google.Adaptor{}
+// }
+// return nil
+//}
diff --git a/relay/constant/common.go b/relay/constant/common.go
new file mode 100644
index 00000000..b6606cc6
--- /dev/null
+++ b/relay/constant/common.go
@@ -0,0 +1,3 @@
+package constant
+
+var StopFinishReason = "stop"
diff --git a/relay/constant/relay_mode.go b/relay/constant/relay_mode.go
new file mode 100644
index 00000000..5e2fe574
--- /dev/null
+++ b/relay/constant/relay_mode.go
@@ -0,0 +1,42 @@
+package constant
+
+import "strings"
+
+const (
+ RelayModeUnknown = iota
+ RelayModeChatCompletions
+ RelayModeCompletions
+ RelayModeEmbeddings
+ RelayModeModerations
+ RelayModeImagesGenerations
+ RelayModeEdits
+ RelayModeAudioSpeech
+ RelayModeAudioTranscription
+ RelayModeAudioTranslation
+)
+
+func Path2RelayMode(path string) int {
+ relayMode := RelayModeUnknown
+ if strings.HasPrefix(path, "/v1/chat/completions") {
+ relayMode = RelayModeChatCompletions
+ } else if strings.HasPrefix(path, "/v1/completions") {
+ relayMode = RelayModeCompletions
+ } else if strings.HasPrefix(path, "/v1/embeddings") {
+ relayMode = RelayModeEmbeddings
+ } else if strings.HasSuffix(path, "embeddings") {
+ relayMode = RelayModeEmbeddings
+ } else if strings.HasPrefix(path, "/v1/moderations") {
+ relayMode = RelayModeModerations
+ } else if strings.HasPrefix(path, "/v1/images/generations") {
+ relayMode = RelayModeImagesGenerations
+ } else if strings.HasPrefix(path, "/v1/edits") {
+ relayMode = RelayModeEdits
+ } else if strings.HasPrefix(path, "/v1/audio/speech") {
+ relayMode = RelayModeAudioSpeech
+ } else if strings.HasPrefix(path, "/v1/audio/transcriptions") {
+ relayMode = RelayModeAudioTranscription
+ } else if strings.HasPrefix(path, "/v1/audio/translations") {
+ relayMode = RelayModeAudioTranslation
+ }
+ return relayMode
+}
diff --git a/controller/relay-audio.go b/relay/controller/audio.go
similarity index 63%
rename from controller/relay-audio.go
rename to relay/controller/audio.go
index 2247f4c7..822d7e39 100644
--- a/controller/relay-audio.go
+++ b/relay/controller/audio.go
@@ -11,11 +11,16 @@ import (
"io"
"net/http"
"one-api/common"
+ "one-api/common/config"
+ "one-api/common/logger"
"one-api/model"
+ "one-api/relay/channel/openai"
+ "one-api/relay/constant"
+ "one-api/relay/util"
"strings"
)
-func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
+func RelayAudioHelper(c *gin.Context, relayMode int) *openai.ErrorWithStatusCode {
audioModel := "whisper-1"
tokenId := c.GetInt("token_id")
@@ -25,18 +30,18 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
group := c.GetString("group")
tokenName := c.GetString("token_name")
- var ttsRequest TextToSpeechRequest
- if relayMode == RelayModeAudioSpeech {
+ var ttsRequest openai.TextToSpeechRequest
+ if relayMode == constant.RelayModeAudioSpeech {
// Read JSON
err := common.UnmarshalBodyReusable(c, &ttsRequest)
// Check if JSON is valid
if err != nil {
- return errorWrapper(err, "invalid_json", http.StatusBadRequest)
+ return openai.ErrorWrapper(err, "invalid_json", http.StatusBadRequest)
}
audioModel = ttsRequest.Model
// Check if text is too long 4096
if len(ttsRequest.Input) > 4096 {
- return errorWrapper(errors.New("input is too long (over 4096 characters)"), "text_too_long", http.StatusBadRequest)
+ return openai.ErrorWrapper(errors.New("input is too long (over 4096 characters)"), "text_too_long", http.StatusBadRequest)
}
}
@@ -46,24 +51,24 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
var quota int
var preConsumedQuota int
switch relayMode {
- case RelayModeAudioSpeech:
+ case constant.RelayModeAudioSpeech:
preConsumedQuota = int(float64(len(ttsRequest.Input)) * ratio)
quota = preConsumedQuota
default:
- preConsumedQuota = int(float64(common.PreConsumedQuota) * ratio)
+ preConsumedQuota = int(float64(config.PreConsumedQuota) * ratio)
}
userQuota, err := model.CacheGetUserQuota(userId)
if err != nil {
- return errorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError)
+ return openai.ErrorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError)
}
// Check if user quota is enough
if userQuota-preConsumedQuota < 0 {
- return errorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
+ return openai.ErrorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
}
err = model.CacheDecreaseUserQuota(userId, preConsumedQuota)
if err != nil {
- return errorWrapper(err, "decrease_user_quota_failed", http.StatusInternalServerError)
+ return openai.ErrorWrapper(err, "decrease_user_quota_failed", http.StatusInternalServerError)
}
if userQuota > 100*preConsumedQuota {
// in this case, we do not pre-consume quota
@@ -73,7 +78,7 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
if preConsumedQuota > 0 {
err := model.PreConsumeTokenQuota(tokenId, preConsumedQuota)
if err != nil {
- return errorWrapper(err, "pre_consume_token_quota_failed", http.StatusForbidden)
+ return openai.ErrorWrapper(err, "pre_consume_token_quota_failed", http.StatusForbidden)
}
}
@@ -83,7 +88,7 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
modelMap := make(map[string]string)
err := json.Unmarshal([]byte(modelMapping), &modelMap)
if err != nil {
- return errorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError)
+ return openai.ErrorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError)
}
if modelMap[audioModel] != "" {
audioModel = modelMap[audioModel]
@@ -96,27 +101,27 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
baseURL = c.GetString("base_url")
}
- fullRequestURL := getFullRequestURL(baseURL, requestURL, channelType)
- if relayMode == RelayModeAudioTranscription && channelType == common.ChannelTypeAzure {
+ fullRequestURL := util.GetFullRequestURL(baseURL, requestURL, channelType)
+ if relayMode == constant.RelayModeAudioTranscription && channelType == common.ChannelTypeAzure {
// https://learn.microsoft.com/en-us/azure/ai-services/openai/whisper-quickstart?tabs=command-line#rest-api
- apiVersion := GetAPIVersion(c)
+ apiVersion := util.GetAzureAPIVersion(c)
fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/audio/transcriptions?api-version=%s", baseURL, audioModel, apiVersion)
}
requestBody := &bytes.Buffer{}
_, err = io.Copy(requestBody, c.Request.Body)
if err != nil {
- return errorWrapper(err, "new_request_body_failed", http.StatusInternalServerError)
+ return openai.ErrorWrapper(err, "new_request_body_failed", http.StatusInternalServerError)
}
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody.Bytes()))
responseFormat := c.DefaultPostForm("response_format", "json")
req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
if err != nil {
- return errorWrapper(err, "new_request_failed", http.StatusInternalServerError)
+ return openai.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError)
}
- if relayMode == RelayModeAudioTranscription && channelType == common.ChannelTypeAzure {
+ if relayMode == constant.RelayModeAudioTranscription && channelType == common.ChannelTypeAzure {
// https://learn.microsoft.com/en-us/azure/ai-services/openai/whisper-quickstart?tabs=command-line#rest-api
apiKey := c.Request.Header.Get("Authorization")
apiKey = strings.TrimPrefix(apiKey, "Bearer ")
@@ -128,34 +133,34 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
req.Header.Set("Accept", c.Request.Header.Get("Accept"))
- resp, err := httpClient.Do(req)
+ resp, err := util.HTTPClient.Do(req)
if err != nil {
- return errorWrapper(err, "do_request_failed", http.StatusInternalServerError)
+ return openai.ErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
}
err = req.Body.Close()
if err != nil {
- return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError)
+ return openai.ErrorWrapper(err, "close_request_body_failed", http.StatusInternalServerError)
}
err = c.Request.Body.Close()
if err != nil {
- return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError)
+ return openai.ErrorWrapper(err, "close_request_body_failed", http.StatusInternalServerError)
}
- if relayMode != RelayModeAudioSpeech {
+ if relayMode != constant.RelayModeAudioSpeech {
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
- return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
+ return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
}
err = resp.Body.Close()
if err != nil {
- return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError)
+ return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError)
}
- var openAIErr TextResponse
+ var openAIErr openai.SlimTextResponse
if err = json.Unmarshal(responseBody, &openAIErr); err == nil {
if openAIErr.Error.Message != "" {
- return errorWrapper(fmt.Errorf("type %s, code %v, message %s", openAIErr.Error.Type, openAIErr.Error.Code, openAIErr.Error.Message), "request_error", http.StatusInternalServerError)
+ return openai.ErrorWrapper(fmt.Errorf("type %s, code %v, message %s", openAIErr.Error.Type, openAIErr.Error.Code, openAIErr.Error.Message), "request_error", http.StatusInternalServerError)
}
}
@@ -172,12 +177,12 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
case "vtt":
text, err = getTextFromVTT(responseBody)
default:
- return errorWrapper(errors.New("unexpected_response_format"), "unexpected_response_format", http.StatusInternalServerError)
+ return openai.ErrorWrapper(errors.New("unexpected_response_format"), "unexpected_response_format", http.StatusInternalServerError)
}
if err != nil {
- return errorWrapper(err, "get_text_from_body_err", http.StatusInternalServerError)
+ return openai.ErrorWrapper(err, "get_text_from_body_err", http.StatusInternalServerError)
}
- quota = countTokenText(text, audioModel)
+ quota = openai.CountTokenText(text, audioModel)
resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
}
if resp.StatusCode != http.StatusOK {
@@ -188,16 +193,16 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
// negative means add quota back for token & user
err := model.PostConsumeTokenQuota(tokenId, -preConsumedQuota)
if err != nil {
- common.LogError(ctx, fmt.Sprintf("error rollback pre-consumed quota: %s", err.Error()))
+ logger.Error(ctx, fmt.Sprintf("error rollback pre-consumed quota: %s", err.Error()))
}
}()
}(c.Request.Context())
}
- return relayErrorHandler(resp)
+ return util.RelayErrorHandler(resp)
}
quotaDelta := quota - preConsumedQuota
defer func(ctx context.Context) {
- go postConsumeQuota(ctx, tokenId, quotaDelta, quota, userId, channelId, modelRatio, groupRatio, audioModel, tokenName)
+ go util.PostConsumeQuota(ctx, tokenId, quotaDelta, quota, userId, channelId, modelRatio, groupRatio, audioModel, tokenName)
}(c.Request.Context())
for k, v := range resp.Header {
@@ -207,11 +212,11 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
_, err = io.Copy(c.Writer, resp.Body)
if err != nil {
- return errorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError)
+ return openai.ErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError)
}
err = resp.Body.Close()
if err != nil {
- return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError)
+ return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError)
}
return nil
}
@@ -221,7 +226,7 @@ func getTextFromVTT(body []byte) (string, error) {
}
func getTextFromVerboseJSON(body []byte) (string, error) {
- var whisperResponse WhisperVerboseJSONResponse
+ var whisperResponse openai.WhisperVerboseJSONResponse
if err := json.Unmarshal(body, &whisperResponse); err != nil {
return "", fmt.Errorf("unmarshal_response_body_failed err :%w", err)
}
@@ -254,7 +259,7 @@ func getTextFromText(body []byte) (string, error) {
}
func getTextFromJSON(body []byte) (string, error) {
- var whisperResponse WhisperJSONResponse
+ var whisperResponse openai.WhisperJSONResponse
if err := json.Unmarshal(body, &whisperResponse); err != nil {
return "", fmt.Errorf("unmarshal_response_body_failed err :%w", err)
}
diff --git a/controller/relay-image.go b/relay/controller/image.go
similarity index 68%
rename from controller/relay-image.go
rename to relay/controller/image.go
index 14a2983b..9502a4d7 100644
--- a/controller/relay-image.go
+++ b/relay/controller/image.go
@@ -9,7 +9,10 @@ import (
"io"
"net/http"
"one-api/common"
+ "one-api/common/logger"
"one-api/model"
+ "one-api/relay/channel/openai"
+ "one-api/relay/util"
"strings"
"github.com/gin-gonic/gin"
@@ -25,7 +28,7 @@ func isWithinRange(element string, value int) bool {
return value >= min && value <= max
}
-func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
+func RelayImageHelper(c *gin.Context, relayMode int) *openai.ErrorWithStatusCode {
imageModel := "dall-e-2"
imageSize := "1024x1024"
@@ -35,10 +38,10 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
userId := c.GetInt("id")
group := c.GetString("group")
- var imageRequest ImageRequest
+ var imageRequest openai.ImageRequest
err := common.UnmarshalBodyReusable(c, &imageRequest)
if err != nil {
- return errorWrapper(err, "bind_request_body_failed", http.StatusBadRequest)
+ return openai.ErrorWrapper(err, "bind_request_body_failed", http.StatusBadRequest)
}
if imageRequest.N == 0 {
@@ -67,24 +70,24 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
}
}
} else {
- return errorWrapper(errors.New("size not supported for this image model"), "size_not_supported", http.StatusBadRequest)
+ return openai.ErrorWrapper(errors.New("size not supported for this image model"), "size_not_supported", http.StatusBadRequest)
}
// Prompt validation
if imageRequest.Prompt == "" {
- return errorWrapper(errors.New("prompt is required"), "prompt_missing", http.StatusBadRequest)
+ return openai.ErrorWrapper(errors.New("prompt is required"), "prompt_missing", http.StatusBadRequest)
}
// Check prompt length
if len(imageRequest.Prompt) > common.DalleImagePromptLengthLimitations[imageModel] {
- return errorWrapper(errors.New("prompt is too long"), "prompt_too_long", http.StatusBadRequest)
+ return openai.ErrorWrapper(errors.New("prompt is too long"), "prompt_too_long", http.StatusBadRequest)
}
// Number of generated images validation
if isWithinRange(imageModel, imageRequest.N) == false {
// channel not azure
if channelType != common.ChannelTypeAzure {
- return errorWrapper(errors.New("invalid value of n"), "n_not_within_range", http.StatusBadRequest)
+ return openai.ErrorWrapper(errors.New("invalid value of n"), "n_not_within_range", http.StatusBadRequest)
}
}
@@ -95,7 +98,7 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
modelMap := make(map[string]string)
err := json.Unmarshal([]byte(modelMapping), &modelMap)
if err != nil {
- return errorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError)
+ return openai.ErrorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError)
}
if modelMap[imageModel] != "" {
imageModel = modelMap[imageModel]
@@ -107,10 +110,10 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
if c.GetString("base_url") != "" {
baseURL = c.GetString("base_url")
}
- fullRequestURL := getFullRequestURL(baseURL, requestURL, channelType)
+ fullRequestURL := util.GetFullRequestURL(baseURL, requestURL, channelType)
if channelType == common.ChannelTypeAzure {
// https://learn.microsoft.com/en-us/azure/ai-services/openai/dall-e-quickstart?tabs=dalle3%2Ccommand-line&pivots=rest-api
- apiVersion := GetAPIVersion(c)
+ apiVersion := util.GetAzureAPIVersion(c)
// https://{resource_name}.openai.azure.com/openai/deployments/dall-e-3/images/generations?api-version=2023-06-01-preview
fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/images/generations?api-version=%s", baseURL, imageModel, apiVersion)
}
@@ -119,7 +122,7 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
if isModelMapped || channelType == common.ChannelTypeAzure { // make Azure channel request body
jsonStr, err := json.Marshal(imageRequest)
if err != nil {
- return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
+ return openai.ErrorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
}
requestBody = bytes.NewBuffer(jsonStr)
} else {
@@ -134,12 +137,12 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
quota := int(ratio*imageCostRatio*1000) * imageRequest.N
if userQuota-quota < 0 {
- return errorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
+ return openai.ErrorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
}
req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
if err != nil {
- return errorWrapper(err, "new_request_failed", http.StatusInternalServerError)
+ return openai.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError)
}
token := c.Request.Header.Get("Authorization")
if channelType == common.ChannelTypeAzure { // Azure authentication
@@ -152,20 +155,20 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
req.Header.Set("Accept", c.Request.Header.Get("Accept"))
- resp, err := httpClient.Do(req)
+ resp, err := util.HTTPClient.Do(req)
if err != nil {
- return errorWrapper(err, "do_request_failed", http.StatusInternalServerError)
+ return openai.ErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
}
err = req.Body.Close()
if err != nil {
- return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError)
+ return openai.ErrorWrapper(err, "close_request_body_failed", http.StatusInternalServerError)
}
err = c.Request.Body.Close()
if err != nil {
- return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError)
+ return openai.ErrorWrapper(err, "close_request_body_failed", http.StatusInternalServerError)
}
- var textResponse ImageResponse
+ var textResponse openai.ImageResponse
defer func(ctx context.Context) {
if resp.StatusCode != http.StatusOK {
@@ -173,11 +176,11 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
}
err := model.PostConsumeTokenQuota(tokenId, quota)
if err != nil {
- common.SysError("error consuming token remain quota: " + err.Error())
+ logger.SysError("error consuming token remain quota: " + err.Error())
}
err = model.CacheUpdateUserQuota(userId)
if err != nil {
- common.SysError("error update user quota cache: " + err.Error())
+ logger.SysError("error update user quota cache: " + err.Error())
}
if quota != 0 {
tokenName := c.GetString("token_name")
@@ -192,15 +195,15 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
- return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
+ return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
}
err = resp.Body.Close()
if err != nil {
- return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError)
+ return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError)
}
err = json.Unmarshal(responseBody, &textResponse)
if err != nil {
- return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError)
+ return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError)
}
resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
@@ -212,11 +215,11 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
_, err = io.Copy(c.Writer, resp.Body)
if err != nil {
- return errorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError)
+ return openai.ErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError)
}
err = resp.Body.Close()
if err != nil {
- return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError)
+ return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError)
}
return nil
}
diff --git a/relay/controller/text.go b/relay/controller/text.go
new file mode 100644
index 00000000..68354628
--- /dev/null
+++ b/relay/controller/text.go
@@ -0,0 +1,173 @@
+package controller
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "github.com/gin-gonic/gin"
+ "math"
+ "net/http"
+ "one-api/common"
+ "one-api/common/config"
+ "one-api/common/logger"
+ "one-api/model"
+ "one-api/relay/channel/openai"
+ "one-api/relay/constant"
+ "one-api/relay/util"
+ "strings"
+)
+
+// RelayTextHelper proxies a text-style request (chat completions, completions,
+// embeddings, moderations) to the upstream channel bound to this context.
+// It validates the request, pre-consumes an estimated quota, forwards the
+// request, and settles the final quota from the usage reported by DoResponse.
+// A non-nil return value is an OpenAI-shaped error to be sent to the client.
+func RelayTextHelper(c *gin.Context, relayMode int) *openai.ErrorWithStatusCode {
+ ctx := c.Request.Context()
+ meta := util.GetRelayMeta(c)
+ var textRequest openai.GeneralOpenAIRequest
+ err := common.UnmarshalBodyReusable(c, &textRequest)
+ if err != nil {
+ return openai.ErrorWrapper(err, "bind_request_body_failed", http.StatusBadRequest)
+ }
+ // Fill in default model names for endpoints where the client may omit them.
+ if relayMode == constant.RelayModeModerations && textRequest.Model == "" {
+ textRequest.Model = "text-moderation-latest"
+ }
+ if relayMode == constant.RelayModeEmbeddings && textRequest.Model == "" {
+ textRequest.Model = c.Param("model")
+ }
+ err = util.ValidateTextRequest(&textRequest, relayMode)
+ if err != nil {
+ return openai.ErrorWrapper(err, "invalid_text_request", http.StatusBadRequest)
+ }
+ // Apply the channel's model-name mapping (if any) before building the URL;
+ // isModelMapped later decides whether the body must be re-marshalled.
+ var isModelMapped bool
+ textRequest.Model, isModelMapped = util.GetMappedModelName(textRequest.Model, meta.ModelMapping)
+ apiType := constant.ChannelType2APIType(meta.ChannelType)
+ fullRequestURL, err := GetRequestURL(c.Request.URL.String(), apiType, relayMode, meta, &textRequest)
+ if err != nil {
+ logger.Error(ctx, fmt.Sprintf("util.GetRequestURL failed: %s", err.Error()))
+ return openai.ErrorWrapper(fmt.Errorf("util.GetRequestURL failed"), "get_request_url_failed", http.StatusInternalServerError)
+ }
+ // Estimate prompt tokens so quota can be pre-consumed before relaying.
+ var promptTokens int
+ var completionTokens int
+ switch relayMode {
+ case constant.RelayModeChatCompletions:
+ promptTokens = openai.CountTokenMessages(textRequest.Messages, textRequest.Model)
+ case constant.RelayModeCompletions:
+ promptTokens = openai.CountTokenInput(textRequest.Prompt, textRequest.Model)
+ case constant.RelayModeModerations:
+ promptTokens = openai.CountTokenInput(textRequest.Input, textRequest.Model)
+ }
+ preConsumedTokens := config.PreConsumedQuota
+ if textRequest.MaxTokens != 0 {
+ // Upper bound when the client caps the completion length.
+ preConsumedTokens = promptTokens + textRequest.MaxTokens
+ }
+ modelRatio := common.GetModelRatio(textRequest.Model)
+ groupRatio := common.GetGroupRatio(meta.Group)
+ ratio := modelRatio * groupRatio
+ preConsumedQuota := int(float64(preConsumedTokens) * ratio)
+ userQuota, err := model.CacheGetUserQuota(meta.UserId)
+ if err != nil {
+ return openai.ErrorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError)
+ }
+ if userQuota-preConsumedQuota < 0 {
+ return openai.ErrorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
+ }
+ err = model.CacheDecreaseUserQuota(meta.UserId, preConsumedQuota)
+ if err != nil {
+ return openai.ErrorWrapper(err, "decrease_user_quota_failed", http.StatusInternalServerError)
+ }
+ if userQuota > 100*preConsumedQuota {
+ // in this case, we do not pre-consume quota
+ // because the user has enough quota
+ preConsumedQuota = 0
+ logger.Info(c.Request.Context(), fmt.Sprintf("user %d has enough quota %d, trusted and no need to pre-consume", meta.UserId, userQuota))
+ }
+ if preConsumedQuota > 0 {
+ err := model.PreConsumeTokenQuota(meta.TokenId, preConsumedQuota)
+ if err != nil {
+ return openai.ErrorWrapper(err, "pre_consume_token_quota_failed", http.StatusForbidden)
+ }
+ }
+ requestBody, err := GetRequestBody(c, textRequest, isModelMapped, apiType, relayMode)
+ if err != nil {
+ return openai.ErrorWrapper(err, "get_request_body_failed", http.StatusInternalServerError)
+ }
+ var req *http.Request
+ var resp *http.Response
+ isStream := textRequest.Stream
+
+ if apiType != constant.APITypeXunfei { // cause xunfei use websocket
+ req, err = http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
+ if err != nil {
+ return openai.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError)
+ }
+ SetupRequestHeaders(c, req, apiType, meta, isStream)
+ resp, err = util.HTTPClient.Do(req)
+ if err != nil {
+ return openai.ErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
+ }
+ err = req.Body.Close()
+ if err != nil {
+ return openai.ErrorWrapper(err, "close_request_body_failed", http.StatusInternalServerError)
+ }
+ err = c.Request.Body.Close()
+ if err != nil {
+ return openai.ErrorWrapper(err, "close_request_body_failed", http.StatusInternalServerError)
+ }
+ // Some upstreams answer with SSE even when the client did not request streaming.
+ isStream = isStream || strings.HasPrefix(resp.Header.Get("Content-Type"), "text/event-stream")
+
+ if resp.StatusCode != http.StatusOK {
+ // Upstream rejected the request: refund the pre-consumed quota and
+ // translate the upstream error for the client.
+ util.ReturnPreConsumedQuota(ctx, preConsumedQuota, meta.TokenId)
+ return util.RelayErrorHandler(resp)
+ }
+ }
+
+ var respErr *openai.ErrorWithStatusCode
+ var usage *openai.Usage
+
+ // Billing settlement runs after DoResponse below has populated usage/respErr.
+ defer func(ctx context.Context) {
+ // Why we use defer here? Because if error happened, we will have to return the pre-consumed quota.
+ if respErr != nil {
+ logger.Errorf(ctx, "respErr is not nil: %+v", respErr)
+ util.ReturnPreConsumedQuota(ctx, preConsumedQuota, meta.TokenId)
+ return
+ }
+ if usage == nil {
+ logger.Error(ctx, "usage is nil, which is unexpected")
+ return
+ }
+
+ go func() {
+ quota := 0
+ completionRatio := common.GetCompletionRatio(textRequest.Model)
+ promptTokens = usage.PromptTokens
+ completionTokens = usage.CompletionTokens
+ quota = int(math.Ceil((float64(promptTokens) + float64(completionTokens)*completionRatio) * ratio))
+ if ratio != 0 && quota <= 0 {
+ // Charge at least one unit for any billable request.
+ quota = 1
+ }
+ totalTokens := promptTokens + completionTokens
+ if totalTokens == 0 {
+ // in this case, must be some error happened
+ // we cannot just return, because we may have to return the pre-consumed quota
+ quota = 0
+ }
+ // Negative delta refunds the difference back to the token.
+ quotaDelta := quota - preConsumedQuota
+ err := model.PostConsumeTokenQuota(meta.TokenId, quotaDelta)
+ if err != nil {
+ logger.Error(ctx, "error consuming token remain quota: "+err.Error())
+ }
+ err = model.CacheUpdateUserQuota(meta.UserId)
+ if err != nil {
+ logger.Error(ctx, "error update user quota cache: "+err.Error())
+ }
+ if quota != 0 {
+ logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
+ model.RecordConsumeLog(ctx, meta.UserId, meta.ChannelId, promptTokens, completionTokens, textRequest.Model, meta.TokenName, quota, logContent)
+ model.UpdateUserUsedQuotaAndRequestCount(meta.UserId, quota)
+ model.UpdateChannelUsedQuota(meta.ChannelId, quota)
+ }
+ }()
+ }(ctx)
+ usage, respErr = DoResponse(c, &textRequest, resp, relayMode, apiType, isStream, promptTokens)
+ if respErr != nil {
+ return respErr
+ }
+ return nil
+}
diff --git a/relay/controller/util.go b/relay/controller/util.go
new file mode 100644
index 00000000..02f1b30f
--- /dev/null
+++ b/relay/controller/util.go
@@ -0,0 +1,337 @@
+package controller
+
+import (
+ "bytes"
+ "encoding/json"
+ "errors"
+ "fmt"
+ "github.com/gin-gonic/gin"
+ "io"
+ "net/http"
+ "one-api/common"
+ "one-api/common/helper"
+ "one-api/relay/channel/aiproxy"
+ "one-api/relay/channel/ali"
+ "one-api/relay/channel/anthropic"
+ "one-api/relay/channel/baidu"
+ "one-api/relay/channel/google"
+ "one-api/relay/channel/openai"
+ "one-api/relay/channel/tencent"
+ "one-api/relay/channel/xunfei"
+ "one-api/relay/channel/zhipu"
+ "one-api/relay/constant"
+ "one-api/relay/util"
+ "strings"
+)
+
+// GetRequestURL builds the upstream URL for a text relay request. The result
+// depends on the channel's API type: OpenAI-compatible channels reuse the
+// incoming request path (with Azure deployment-style rewriting), while the
+// other vendors use fixed, vendor-specific endpoints. Baidu is the only case
+// that can fail, because its access token must be fetched first.
+func GetRequestURL(requestURL string, apiType int, relayMode int, meta *util.RelayMeta, textRequest *openai.GeneralOpenAIRequest) (string, error) {
+ fullRequestURL := util.GetFullRequestURL(meta.BaseURL, requestURL, meta.ChannelType)
+ switch apiType {
+ case constant.APITypeOpenAI:
+ if meta.ChannelType == common.ChannelTypeAzure {
+ // https://learn.microsoft.com/en-us/azure/cognitive-services/openai/chatgpt-quickstart?pivots=rest-api&tabs=command-line#rest-api
+ // Azure addresses deployments, not models: strip any existing query,
+ // pin the api-version, and rewrite /v1/<task> into the deployment path.
+ requestURL := strings.Split(requestURL, "?")[0]
+ requestURL = fmt.Sprintf("%s?api-version=%s", requestURL, meta.APIVersion)
+ task := strings.TrimPrefix(requestURL, "/v1/")
+ model_ := textRequest.Model
+ model_ = strings.Replace(model_, ".", "", -1)
+ // https://github.com/songquanpeng/one-api/issues/67
+ model_ = strings.TrimSuffix(model_, "-0301")
+ model_ = strings.TrimSuffix(model_, "-0314")
+ model_ = strings.TrimSuffix(model_, "-0613")
+
+ requestURL = fmt.Sprintf("/openai/deployments/%s/%s", model_, task)
+ fullRequestURL = util.GetFullRequestURL(meta.BaseURL, requestURL, meta.ChannelType)
+ }
+ case constant.APITypeClaude:
+ fullRequestURL = fmt.Sprintf("%s/v1/complete", meta.BaseURL)
+ case constant.APITypeBaidu:
+ // Baidu exposes one endpoint per model rather than a model parameter.
+ switch textRequest.Model {
+ case "ERNIE-Bot":
+ fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions"
+ case "ERNIE-Bot-turbo":
+ fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant"
+ case "ERNIE-Bot-4":
+ fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro"
+ case "BLOOMZ-7B":
+ fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/bloomz_7b1"
+ case "Embedding-V1":
+ fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings/embedding-v1"
+ }
+ // Baidu authenticates via an access_token query parameter, not a header.
+ var accessToken string
+ var err error
+ if accessToken, err = baidu.GetAccessToken(meta.APIKey); err != nil {
+ return "", fmt.Errorf("failed to get baidu access token: %w", err)
+ }
+ fullRequestURL += "?access_token=" + accessToken
+ case constant.APITypePaLM:
+ fullRequestURL = fmt.Sprintf("%s/v1beta2/models/chat-bison-001:generateMessage", meta.BaseURL)
+ case constant.APITypeGemini:
+ version := helper.AssignOrDefault(meta.APIVersion, "v1")
+ // Gemini uses a different action verb for streaming responses.
+ action := "generateContent"
+ if textRequest.Stream {
+ action = "streamGenerateContent"
+ }
+ fullRequestURL = fmt.Sprintf("%s/%s/models/%s:%s", meta.BaseURL, version, textRequest.Model, action)
+ case constant.APITypeZhipu:
+ // Zhipu selects streaming vs. blocking via the invoke method in the path.
+ method := "invoke"
+ if textRequest.Stream {
+ method = "sse-invoke"
+ }
+ fullRequestURL = fmt.Sprintf("https://open.bigmodel.cn/api/paas/v3/model-api/%s/%s", textRequest.Model, method)
+ case constant.APITypeAli:
+ fullRequestURL = "https://dashscope.aliyuncs.com/api/v1/services/aigc/text-generation/generation"
+ if relayMode == constant.RelayModeEmbeddings {
+ fullRequestURL = "https://dashscope.aliyuncs.com/api/v1/services/embeddings/text-embedding/text-embedding"
+ }
+ case constant.APITypeTencent:
+ fullRequestURL = "https://hunyuan.cloud.tencent.com/hyllm/v1/chat/completions"
+ case constant.APITypeAIProxyLibrary:
+ fullRequestURL = fmt.Sprintf("%s/api/library/ask", meta.BaseURL)
+ }
+ return fullRequestURL, nil
+}
+
+// GetRequestBody converts the parsed OpenAI-style request into the body the
+// target vendor expects and returns it as a reader. For plain OpenAI channels
+// the client's original body is forwarded untouched, unless the model name was
+// remapped, in which case the struct is re-marshalled. The Tencent branch also
+// signs the request and rewrites the Authorization header as a side effect.
+func GetRequestBody(c *gin.Context, textRequest openai.GeneralOpenAIRequest, isModelMapped bool, apiType int, relayMode int) (io.Reader, error) {
+ var requestBody io.Reader
+ if isModelMapped {
+ // Model name changed: the original raw body no longer matches, re-encode.
+ jsonStr, err := json.Marshal(textRequest)
+ if err != nil {
+ return nil, err
+ }
+ requestBody = bytes.NewBuffer(jsonStr)
+ } else {
+ requestBody = c.Request.Body
+ }
+ // Vendor-specific branches below replace requestBody with a converted payload.
+ switch apiType {
+ case constant.APITypeClaude:
+ claudeRequest := anthropic.ConvertRequest(textRequest)
+ jsonStr, err := json.Marshal(claudeRequest)
+ if err != nil {
+ return nil, err
+ }
+ requestBody = bytes.NewBuffer(jsonStr)
+ case constant.APITypeBaidu:
+ var jsonData []byte
+ var err error
+ switch relayMode {
+ case constant.RelayModeEmbeddings:
+ baiduEmbeddingRequest := baidu.ConvertEmbeddingRequest(textRequest)
+ jsonData, err = json.Marshal(baiduEmbeddingRequest)
+ default:
+ baiduRequest := baidu.ConvertRequest(textRequest)
+ jsonData, err = json.Marshal(baiduRequest)
+ }
+ if err != nil {
+ return nil, err
+ }
+ requestBody = bytes.NewBuffer(jsonData)
+ case constant.APITypePaLM:
+ palmRequest := google.ConvertPaLMRequest(textRequest)
+ jsonStr, err := json.Marshal(palmRequest)
+ if err != nil {
+ return nil, err
+ }
+ requestBody = bytes.NewBuffer(jsonStr)
+ case constant.APITypeGemini:
+ geminiChatRequest := google.ConvertGeminiRequest(textRequest)
+ jsonStr, err := json.Marshal(geminiChatRequest)
+ if err != nil {
+ return nil, err
+ }
+ requestBody = bytes.NewBuffer(jsonStr)
+ case constant.APITypeZhipu:
+ zhipuRequest := zhipu.ConvertRequest(textRequest)
+ jsonStr, err := json.Marshal(zhipuRequest)
+ if err != nil {
+ return nil, err
+ }
+ requestBody = bytes.NewBuffer(jsonStr)
+ case constant.APITypeAli:
+ var jsonStr []byte
+ var err error
+ switch relayMode {
+ case constant.RelayModeEmbeddings:
+ aliEmbeddingRequest := ali.ConvertEmbeddingRequest(textRequest)
+ jsonStr, err = json.Marshal(aliEmbeddingRequest)
+ default:
+ aliRequest := ali.ConvertRequest(textRequest)
+ jsonStr, err = json.Marshal(aliRequest)
+ }
+ if err != nil {
+ return nil, err
+ }
+ requestBody = bytes.NewBuffer(jsonStr)
+ case constant.APITypeTencent:
+ // Tencent credentials are packed into the channel key as appId|secretId|secretKey.
+ apiKey := c.Request.Header.Get("Authorization")
+ apiKey = strings.TrimPrefix(apiKey, "Bearer ")
+ appId, secretId, secretKey, err := tencent.ParseConfig(apiKey)
+ if err != nil {
+ return nil, err
+ }
+ tencentRequest := tencent.ConvertRequest(textRequest)
+ tencentRequest.AppId = appId
+ tencentRequest.SecretId = secretId
+ jsonStr, err := json.Marshal(tencentRequest)
+ if err != nil {
+ return nil, err
+ }
+ // The computed signature replaces the Authorization header for this request.
+ sign := tencent.GetSign(*tencentRequest, secretKey)
+ c.Request.Header.Set("Authorization", sign)
+ requestBody = bytes.NewBuffer(jsonStr)
+ case constant.APITypeAIProxyLibrary:
+ aiProxyLibraryRequest := aiproxy.ConvertRequest(textRequest)
+ aiProxyLibraryRequest.LibraryId = c.GetString("library_id")
+ jsonStr, err := json.Marshal(aiProxyLibraryRequest)
+ if err != nil {
+ return nil, err
+ }
+ requestBody = bytes.NewBuffer(jsonStr)
+ }
+ return requestBody, nil
+}
+
+// SetupRequestHeaders copies content-negotiation headers from the client
+// request onto the upstream request and installs vendor authentication via
+// SetupAuthHeaders. For streaming requests with no explicit Accept header it
+// asks the upstream for server-sent events.
+func SetupRequestHeaders(c *gin.Context, req *http.Request, apiType int, meta *util.RelayMeta, isStream bool) {
+ SetupAuthHeaders(c, req, apiType, meta, isStream)
+ req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
+ req.Header.Set("Accept", c.Request.Header.Get("Accept"))
+ if isStream && c.Request.Header.Get("Accept") == "" {
+ req.Header.Set("Accept", "text/event-stream")
+ }
+}
+
+// SetupAuthHeaders sets the authentication header(s) required by the target
+// vendor. Most vendors take a bearer token; Azure uses api-key, Claude uses
+// x-api-key plus an anthropic-version, Zhipu derives a signed token from the
+// key, and Google APIs use x-goog-api-key.
+func SetupAuthHeaders(c *gin.Context, req *http.Request, apiType int, meta *util.RelayMeta, isStream bool) {
+ apiKey := meta.APIKey
+ switch apiType {
+ case constant.APITypeOpenAI:
+ if meta.ChannelType == common.ChannelTypeAzure {
+ req.Header.Set("api-key", apiKey)
+ } else {
+ // Pass through the client's Authorization header unchanged.
+ req.Header.Set("Authorization", c.Request.Header.Get("Authorization"))
+ if meta.ChannelType == common.ChannelTypeOpenRouter {
+ // OpenRouter attribution headers.
+ req.Header.Set("HTTP-Referer", "https://github.com/songquanpeng/one-api")
+ req.Header.Set("X-Title", "One API")
+ }
+ }
+ case constant.APITypeClaude:
+ req.Header.Set("x-api-key", apiKey)
+ // Honor the client's requested API version; default otherwise.
+ anthropicVersion := c.Request.Header.Get("anthropic-version")
+ if anthropicVersion == "" {
+ anthropicVersion = "2023-06-01"
+ }
+ req.Header.Set("anthropic-version", anthropicVersion)
+ case constant.APITypeZhipu:
+ token := zhipu.GetToken(apiKey)
+ req.Header.Set("Authorization", token)
+ case constant.APITypeAli:
+ req.Header.Set("Authorization", "Bearer "+apiKey)
+ if isStream {
+ // DashScope requires this header to enable SSE responses.
+ req.Header.Set("X-DashScope-SSE", "enable")
+ }
+ if c.GetString("plugin") != "" {
+ req.Header.Set("X-DashScope-Plugin", c.GetString("plugin"))
+ }
+ case constant.APITypeTencent:
+ // For Tencent the signature was already placed in Authorization by GetRequestBody.
+ req.Header.Set("Authorization", apiKey)
+ case constant.APITypePaLM:
+ req.Header.Set("x-goog-api-key", apiKey)
+ case constant.APITypeGemini:
+ req.Header.Set("x-goog-api-key", apiKey)
+ default:
+ req.Header.Set("Authorization", "Bearer "+apiKey)
+ }
+}
+
+// DoResponse dispatches the upstream response to the vendor-specific handler
+// (streaming or blocking) and normalizes the result to an OpenAI Usage.
+// Handlers that only return the streamed text (OpenAI/Claude/PaLM/Gemini/
+// Tencent streams) get their token usage reconstructed at the end by counting
+// tokens in the accumulated response text.
+func DoResponse(c *gin.Context, textRequest *openai.GeneralOpenAIRequest, resp *http.Response, relayMode int, apiType int, isStream bool, promptTokens int) (usage *openai.Usage, err *openai.ErrorWithStatusCode) {
+ var responseText string
+ switch apiType {
+ case constant.APITypeOpenAI:
+ if isStream {
+ err, responseText = openai.StreamHandler(c, resp, relayMode)
+ } else {
+ err, usage = openai.Handler(c, resp, promptTokens, textRequest.Model)
+ }
+ case constant.APITypeClaude:
+ if isStream {
+ err, responseText = anthropic.StreamHandler(c, resp)
+ } else {
+ err, usage = anthropic.Handler(c, resp, promptTokens, textRequest.Model)
+ }
+ case constant.APITypeBaidu:
+ if isStream {
+ err, usage = baidu.StreamHandler(c, resp)
+ } else {
+ switch relayMode {
+ case constant.RelayModeEmbeddings:
+ err, usage = baidu.EmbeddingHandler(c, resp)
+ default:
+ err, usage = baidu.Handler(c, resp)
+ }
+ }
+ case constant.APITypePaLM:
+ if isStream { // PaLM2 API does not support stream
+ err, responseText = google.PaLMStreamHandler(c, resp)
+ } else {
+ err, usage = google.PaLMHandler(c, resp, promptTokens, textRequest.Model)
+ }
+ case constant.APITypeGemini:
+ if isStream {
+ err, responseText = google.StreamHandler(c, resp)
+ } else {
+ err, usage = google.GeminiHandler(c, resp, promptTokens, textRequest.Model)
+ }
+ case constant.APITypeZhipu:
+ if isStream {
+ err, usage = zhipu.StreamHandler(c, resp)
+ } else {
+ err, usage = zhipu.Handler(c, resp)
+ }
+ case constant.APITypeAli:
+ if isStream {
+ err, usage = ali.StreamHandler(c, resp)
+ } else {
+ switch relayMode {
+ case constant.RelayModeEmbeddings:
+ err, usage = ali.EmbeddingHandler(c, resp)
+ default:
+ err, usage = ali.Handler(c, resp)
+ }
+ }
+ case constant.APITypeXunfei:
+ // Xunfei credentials are packed into the bearer token as appid|apiSecret|apiKey.
+ auth := c.Request.Header.Get("Authorization")
+ auth = strings.TrimPrefix(auth, "Bearer ")
+ splits := strings.Split(auth, "|")
+ if len(splits) != 3 {
+ return nil, openai.ErrorWrapper(errors.New("invalid auth"), "invalid_auth", http.StatusBadRequest)
+ }
+ if isStream {
+ err, usage = xunfei.StreamHandler(c, *textRequest, splits[0], splits[1], splits[2])
+ } else {
+ err, usage = xunfei.Handler(c, *textRequest, splits[0], splits[1], splits[2])
+ }
+ case constant.APITypeAIProxyLibrary:
+ if isStream {
+ err, usage = aiproxy.StreamHandler(c, resp)
+ } else {
+ err, usage = aiproxy.Handler(c, resp)
+ }
+ case constant.APITypeTencent:
+ if isStream {
+ err, responseText = tencent.StreamHandler(c, resp)
+ } else {
+ err, usage = tencent.Handler(c, resp)
+ }
+ default:
+ return nil, openai.ErrorWrapper(errors.New("unknown api type"), "unknown_api_type", http.StatusInternalServerError)
+ }
+ if err != nil {
+ return nil, err
+ }
+ if usage == nil && responseText != "" {
+ // Stream handlers above returned only text; derive usage by tokenizing it.
+ usage = &openai.Usage{}
+ usage.PromptTokens = promptTokens
+ usage.CompletionTokens = openai.CountTokenText(responseText, textRequest.Model)
+ usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
+ }
+ return usage, nil
+}
diff --git a/relay/util/billing.go b/relay/util/billing.go
new file mode 100644
index 00000000..35fb28a4
--- /dev/null
+++ b/relay/util/billing.go
@@ -0,0 +1,19 @@
+package util
+
+import (
+ "context"
+ "one-api/common/logger"
+ "one-api/model"
+)
+
+func ReturnPreConsumedQuota(ctx context.Context, preConsumedQuota int, tokenId int) {
+ if preConsumedQuota != 0 {
+ go func(ctx context.Context) {
+ // return pre-consumed quota
+ err := model.PostConsumeTokenQuota(tokenId, -preConsumedQuota)
+ if err != nil {
+ logger.Error(ctx, "error return pre-consumed quota: "+err.Error())
+ }
+ }(ctx)
+ }
+}
diff --git a/relay/util/common.go b/relay/util/common.go
new file mode 100644
index 00000000..be31857b
--- /dev/null
+++ b/relay/util/common.go
@@ -0,0 +1,168 @@
+package util
+
+import (
+ "context"
+ "encoding/json"
+ "fmt"
+ "io"
+ "net/http"
+ "one-api/common"
+ "one-api/common/config"
+ "one-api/common/logger"
+ "one-api/model"
+ "one-api/relay/channel/openai"
+ "strconv"
+ "strings"
+
+ "github.com/gin-gonic/gin"
+)
+
+func ShouldDisableChannel(err *openai.Error, statusCode int) bool {
+ if !config.AutomaticDisableChannelEnabled {
+ return false
+ }
+ if err == nil {
+ return false
+ }
+ if statusCode == http.StatusUnauthorized {
+ return true
+ }
+ if err.Type == "insufficient_quota" || err.Code == "invalid_api_key" || err.Code == "account_deactivated" {
+ return true
+ }
+ return false
+}
+
+func ShouldEnableChannel(err error, openAIErr *openai.Error) bool {
+ if !config.AutomaticEnableChannelEnabled {
+ return false
+ }
+ if err != nil {
+ return false
+ }
+ if openAIErr != nil {
+ return false
+ }
+ return true
+}
+
+type GeneralErrorResponse struct {
+ Error openai.Error `json:"error"`
+ Message string `json:"message"`
+ Msg string `json:"msg"`
+ Err string `json:"err"`
+ ErrorMsg string `json:"error_msg"`
+ Header struct {
+ Message string `json:"message"`
+ } `json:"header"`
+ Response struct {
+ Error struct {
+ Message string `json:"message"`
+ } `json:"error"`
+ } `json:"response"`
+}
+
+func (e GeneralErrorResponse) ToMessage() string {
+ if e.Error.Message != "" {
+ return e.Error.Message
+ }
+ if e.Message != "" {
+ return e.Message
+ }
+ if e.Msg != "" {
+ return e.Msg
+ }
+ if e.Err != "" {
+ return e.Err
+ }
+ if e.ErrorMsg != "" {
+ return e.ErrorMsg
+ }
+ if e.Header.Message != "" {
+ return e.Header.Message
+ }
+ if e.Response.Error.Message != "" {
+ return e.Response.Error.Message
+ }
+ return ""
+}
+
+func RelayErrorHandler(resp *http.Response) (ErrorWithStatusCode *openai.ErrorWithStatusCode) {
+ ErrorWithStatusCode = &openai.ErrorWithStatusCode{
+ StatusCode: resp.StatusCode,
+ Error: openai.Error{
+ Message: "",
+ Type: "upstream_error",
+ Code: "bad_response_status_code",
+ Param: strconv.Itoa(resp.StatusCode),
+ },
+ }
+ responseBody, err := io.ReadAll(resp.Body)
+ if err != nil {
+ return
+ }
+ err = resp.Body.Close()
+ if err != nil {
+ return
+ }
+ var errResponse GeneralErrorResponse
+ err = json.Unmarshal(responseBody, &errResponse)
+ if err != nil {
+ return
+ }
+ if errResponse.Error.Message != "" {
+ // OpenAI format error, so we override the default one
+ ErrorWithStatusCode.Error = errResponse.Error
+ } else {
+ ErrorWithStatusCode.Error.Message = errResponse.ToMessage()
+ }
+ if ErrorWithStatusCode.Error.Message == "" {
+ ErrorWithStatusCode.Error.Message = fmt.Sprintf("bad response status code %d", resp.StatusCode)
+ }
+ return
+}
+
+func GetFullRequestURL(baseURL string, requestURL string, channelType int) string {
+ fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL)
+
+ if strings.HasPrefix(baseURL, "https://gateway.ai.cloudflare.com") {
+ switch channelType {
+ case common.ChannelTypeOpenAI:
+ fullRequestURL = fmt.Sprintf("%s%s", baseURL, strings.TrimPrefix(requestURL, "/v1"))
+ case common.ChannelTypeAzure:
+ fullRequestURL = fmt.Sprintf("%s%s", baseURL, strings.TrimPrefix(requestURL, "/openai/deployments"))
+ }
+ }
+ return fullRequestURL
+}
+
+func PostConsumeQuota(ctx context.Context, tokenId int, quotaDelta int, totalQuota int, userId int, channelId int, modelRatio float64, groupRatio float64, modelName string, tokenName string) {
+ // quotaDelta is remaining quota to be consumed
+ err := model.PostConsumeTokenQuota(tokenId, quotaDelta)
+ if err != nil {
+ logger.SysError("error consuming token remain quota: " + err.Error())
+ }
+ err = model.CacheUpdateUserQuota(userId)
+ if err != nil {
+ logger.SysError("error update user quota cache: " + err.Error())
+ }
+ // totalQuota is total quota consumed
+ if totalQuota != 0 {
+ logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
+ model.RecordConsumeLog(ctx, userId, channelId, totalQuota, 0, modelName, tokenName, totalQuota, logContent)
+ model.UpdateUserUsedQuotaAndRequestCount(userId, totalQuota)
+ model.UpdateChannelUsedQuota(channelId, totalQuota)
+ }
+ if totalQuota <= 0 {
+ logger.Error(ctx, fmt.Sprintf("totalQuota consumed is %d, something is wrong", totalQuota))
+ }
+}
+
+func GetAzureAPIVersion(c *gin.Context) string {
+ query := c.Request.URL.Query()
+ apiVersion := query.Get("api-version")
+ if apiVersion == "" {
+ apiVersion = c.GetString("api_version")
+ }
+ return apiVersion
+}
diff --git a/relay/util/init.go b/relay/util/init.go
new file mode 100644
index 00000000..62d44d15
--- /dev/null
+++ b/relay/util/init.go
@@ -0,0 +1,24 @@
+package util
+
+import (
+ "net/http"
+ "one-api/common/config"
+ "time"
+)
+
+var HTTPClient *http.Client
+var ImpatientHTTPClient *http.Client
+
+func init() {
+ if config.RelayTimeout == 0 {
+ HTTPClient = &http.Client{}
+ } else {
+ HTTPClient = &http.Client{
+ Timeout: time.Duration(config.RelayTimeout) * time.Second,
+ }
+ }
+
+ ImpatientHTTPClient = &http.Client{
+ Timeout: 5 * time.Second,
+ }
+}
diff --git a/relay/util/model_mapping.go b/relay/util/model_mapping.go
new file mode 100644
index 00000000..39e062a1
--- /dev/null
+++ b/relay/util/model_mapping.go
@@ -0,0 +1,12 @@
+package util
+
+func GetMappedModelName(modelName string, mapping map[string]string) (string, bool) {
+ if mapping == nil {
+ return modelName, false
+ }
+ mappedModelName := mapping[modelName]
+ if mappedModelName != "" {
+ return mappedModelName, true
+ }
+ return modelName, false
+}
diff --git a/relay/util/relay_meta.go b/relay/util/relay_meta.go
new file mode 100644
index 00000000..19936e49
--- /dev/null
+++ b/relay/util/relay_meta.go
@@ -0,0 +1,44 @@
+package util
+
+import (
+ "github.com/gin-gonic/gin"
+ "one-api/common"
+ "strings"
+)
+
+type RelayMeta struct {
+ ChannelType int
+ ChannelId int
+ TokenId int
+ TokenName string
+ UserId int
+ Group string
+ ModelMapping map[string]string
+ BaseURL string
+ APIVersion string
+ APIKey string
+ Config map[string]string
+}
+
+func GetRelayMeta(c *gin.Context) *RelayMeta {
+ meta := RelayMeta{
+ ChannelType: c.GetInt("channel"),
+ ChannelId: c.GetInt("channel_id"),
+ TokenId: c.GetInt("token_id"),
+ TokenName: c.GetString("token_name"),
+ UserId: c.GetInt("id"),
+ Group: c.GetString("group"),
+ ModelMapping: c.GetStringMapString("model_mapping"),
+ BaseURL: c.GetString("base_url"),
+ APIVersion: c.GetString("api_version"),
+ APIKey: strings.TrimPrefix(c.Request.Header.Get("Authorization"), "Bearer "),
+ Config: nil,
+ }
+ if meta.ChannelType == common.ChannelTypeAzure {
+ meta.APIVersion = GetAzureAPIVersion(c)
+ }
+ if meta.BaseURL == "" {
+ meta.BaseURL = common.ChannelBaseURLs[meta.ChannelType]
+ }
+ return &meta
+}
diff --git a/relay/util/validation.go b/relay/util/validation.go
new file mode 100644
index 00000000..48b42d94
--- /dev/null
+++ b/relay/util/validation.go
@@ -0,0 +1,37 @@
+package util
+
+import (
+ "errors"
+ "math"
+ "one-api/relay/channel/openai"
+ "one-api/relay/constant"
+)
+
+func ValidateTextRequest(textRequest *openai.GeneralOpenAIRequest, relayMode int) error {
+ if textRequest.MaxTokens < 0 || textRequest.MaxTokens > math.MaxInt32/2 {
+ return errors.New("max_tokens is invalid")
+ }
+ if textRequest.Model == "" {
+ return errors.New("model is required")
+ }
+ switch relayMode {
+ case constant.RelayModeCompletions:
+ if textRequest.Prompt == "" {
+ return errors.New("field prompt is required")
+ }
+ case constant.RelayModeChatCompletions:
+ if textRequest.Messages == nil || len(textRequest.Messages) == 0 {
+ return errors.New("field messages is required")
+ }
+ case constant.RelayModeEmbeddings:
+ case constant.RelayModeModerations:
+ if textRequest.Input == "" {
+ return errors.New("field input is required")
+ }
+ case constant.RelayModeEdits:
+ if textRequest.Instruction == "" {
+ return errors.New("field instruction is required")
+ }
+ }
+ return nil
+}
diff --git a/router/api-router.go b/router/api-router.go
index da3f9e61..162675ce 100644
--- a/router/api-router.go
+++ b/router/api-router.go
@@ -35,6 +35,7 @@ func SetApiRouter(router *gin.Engine) {
selfRoute := userRoute.Group("/")
selfRoute.Use(middleware.UserAuth())
{
+ selfRoute.GET("/dashboard", controller.GetUserDashboard)
selfRoute.GET("/self", controller.GetSelf)
selfRoute.PUT("/self", controller.UpdateSelf)
selfRoute.DELETE("/self", controller.DeleteSelf)
diff --git a/router/main.go b/router/main.go
index 85127a1a..6504b312 100644
--- a/router/main.go
+++ b/router/main.go
@@ -5,7 +5,8 @@ import (
"fmt"
"github.com/gin-gonic/gin"
"net/http"
- "one-api/common"
+ "one-api/common/config"
+ "one-api/common/logger"
"os"
"strings"
)
@@ -15,9 +16,9 @@ func SetRouter(router *gin.Engine, buildFS embed.FS) {
SetDashboardRouter(router)
SetRelayRouter(router)
frontendBaseUrl := os.Getenv("FRONTEND_BASE_URL")
- if common.IsMasterNode && frontendBaseUrl != "" {
+ if config.IsMasterNode && frontendBaseUrl != "" {
frontendBaseUrl = ""
- common.SysLog("FRONTEND_BASE_URL is ignored on master node")
+ logger.SysLog("FRONTEND_BASE_URL is ignored on master node")
}
if frontendBaseUrl == "" {
SetWebRouter(router, buildFS)
diff --git a/router/web-router.go b/router/web-router.go
index 7328c7a3..95d8fcb9 100644
--- a/router/web-router.go
+++ b/router/web-router.go
@@ -8,17 +8,18 @@ import (
"github.com/gin-gonic/gin"
"net/http"
"one-api/common"
+ "one-api/common/config"
"one-api/controller"
"one-api/middleware"
"strings"
)
func SetWebRouter(router *gin.Engine, buildFS embed.FS) {
- indexPageData, _ := buildFS.ReadFile(fmt.Sprintf("web/build/%s/index.html", common.Theme))
+ indexPageData, _ := buildFS.ReadFile(fmt.Sprintf("web/build/%s/index.html", config.Theme))
router.Use(gzip.Gzip(gzip.DefaultCompression))
router.Use(middleware.GlobalWebRateLimit())
router.Use(middleware.Cache())
- router.Use(static.Serve("/", common.EmbedFolder(buildFS, fmt.Sprintf("web/build/%s", common.Theme))))
+ router.Use(static.Serve("/", common.EmbedFolder(buildFS, fmt.Sprintf("web/build/%s", config.Theme))))
router.NoRoute(func(c *gin.Context) {
if strings.HasPrefix(c.Request.RequestURI, "/v1") || strings.HasPrefix(c.Request.RequestURI, "/api") {
controller.RelayNotFound(c)
diff --git a/web/README.md b/web/README.md
index ca73b298..86486085 100644
--- a/web/README.md
+++ b/web/README.md
@@ -1,17 +1,38 @@
# One API 的前端界面
+
> 每个文件夹代表一个主题,欢迎提交你的主题
## 提交新的主题
+
> 欢迎在页面底部保留你和 One API 的版权信息以及指向链接
+
1. 在 `web` 文件夹下新建一个文件夹,文件夹名为主题名。
2. 把你的主题文件放到这个文件夹下。
-3. 修改 `package.json` 文件,把 `build` 命令改为:`"build": "react-scripts build && mv -f build ../build/default"`,其中 `default` 为你的主题名。
+3. 修改你的 `package.json` 文件,把 `build` 命令改为:`"build": "react-scripts build && mv -f build ../build/default"`,其中 `default` 为你的主题名。
+4. 修改 `common/constants.go` 中的 `ValidThemes`,把你的主题名称注册进去。
+5. 修改 `web/THEMES` 文件,这里也需要同步修改。
## 主题列表
+
### 主题:default
+
默认主题,由 [JustSong](https://github.com/songquanpeng) 开发。
预览:
|||
|:---:|:---:|
+### 主题:berry
+
+由 [MartialBE](https://github.com/MartialBE) 开发。
+
+预览:
+|||
+|:---:|:---:|
+|||
+|||
+|||
+
+#### 开发说明
+
+请查看 [web/berry/README.md](https://github.com/songquanpeng/one-api/tree/main/web/berry/README.md)
diff --git a/web/THEMES b/web/THEMES
index 331d858c..b6597eeb 100644
--- a/web/THEMES
+++ b/web/THEMES
@@ -1 +1,2 @@
-default
\ No newline at end of file
+default
+berry
\ No newline at end of file
diff --git a/web/berry/.gitignore b/web/berry/.gitignore
new file mode 100644
index 00000000..2b5bba76
--- /dev/null
+++ b/web/berry/.gitignore
@@ -0,0 +1,26 @@
+# See https://help.github.com/articles/ignoring-files/ for more about ignoring files.
+
+# dependencies
+/node_modules
+/.pnp
+.pnp.js
+
+# testing
+/coverage
+
+# production
+/build
+
+# misc
+.DS_Store
+.env.local
+.env.development.local
+.env.test.local
+.env.production.local
+
+npm-debug.log*
+yarn-debug.log*
+yarn-error.log*
+.idea
+package-lock.json
+yarn.lock
\ No newline at end of file
diff --git a/web/berry/README.md b/web/berry/README.md
new file mode 100644
index 00000000..170feedc
--- /dev/null
+++ b/web/berry/README.md
@@ -0,0 +1,61 @@
+# One API 前端界面
+
+这个项目是 One API 的前端界面,它基于 [Berry Free React Admin Template](https://github.com/codedthemes/berry-free-react-admin-template) 进行开发。
+
+## 使用的开源项目
+
+使用了以下开源项目作为我们项目的一部分:
+
+- [Berry Free React Admin Template](https://github.com/codedthemes/berry-free-react-admin-template)
+- [minimal-ui-kit](https://github.com/minimal-ui-kit/material-kit-react)
+
+## 开发说明
+
+当添加新的渠道时,需要修改以下地方:
+
+1. `web/berry/src/constants/ChannelConstants.js`
+
+在该文件中的 `CHANNEL_OPTIONS` 添加新的渠道
+
+```js
+export const CHANNEL_OPTIONS = {
+ //key 为渠道ID
+ 1: {
+ key: 1, // 渠道ID
+ text: "OpenAI", // 渠道名称
+ value: 1, // 渠道ID
+ color: "primary", // 渠道列表显示的颜色
+ },
+};
+```
+
+2. `web/berry/src/views/Channel/type/Config.js`
+
+在该文件中的`typeConfig`添加新的渠道配置, 如果无需配置,可以不添加
+
+```js
+const typeConfig = {
+ // key 为渠道ID
+ 3: {
+ inputLabel: {
+ // 输入框名称 配置
+ // 对应的字段名称
+ base_url: "AZURE_OPENAI_ENDPOINT",
+ other: "默认 API 版本",
+ },
+ prompt: {
+ // 输入框提示 配置
+ // 对应的字段名称
+ base_url: "请填写AZURE_OPENAI_ENDPOINT",
+
+ // 注意:通过判断 `other` 是否有值来判断是否需要显示 `other` 输入框, 默认是没有值的
+ other: "请输入默认API版本,例如:2023-06-01-preview",
+ },
+ modelGroup: "openai", // 模型组名称,这个值是给 填入渠道支持模型 按钮使用的。 填入渠道支持模型 按钮会根据这个值来获取模型组,如果填写默认是 openai
+ },
+};
+```
+
+## 许可证
+
+本项目中使用的代码遵循 MIT 许可证。
diff --git a/web/berry/jsconfig.json b/web/berry/jsconfig.json
new file mode 100644
index 00000000..35332c70
--- /dev/null
+++ b/web/berry/jsconfig.json
@@ -0,0 +1,9 @@
+{
+ "compilerOptions": {
+ "target": "esnext",
+ "module": "commonjs",
+ "baseUrl": "src"
+ },
+ "include": ["src/**/*"],
+ "exclude": ["node_modules"]
+}
diff --git a/web/berry/package.json b/web/berry/package.json
new file mode 100644
index 00000000..f428fd9c
--- /dev/null
+++ b/web/berry/package.json
@@ -0,0 +1,84 @@
+{
+ "name": "one_api_web",
+ "version": "1.0.0",
+ "proxy": "http://127.0.0.1:3000",
+ "private": true,
+ "homepage": "",
+ "dependencies": {
+ "@emotion/cache": "^11.9.3",
+ "@emotion/react": "^11.9.3",
+ "@emotion/styled": "^11.9.3",
+ "@mui/icons-material": "^5.8.4",
+ "@mui/lab": "^5.0.0-alpha.88",
+ "@mui/material": "^5.8.6",
+ "@mui/system": "^5.8.6",
+ "@mui/utils": "^5.8.6",
+ "@mui/x-date-pickers": "^6.18.5",
+ "@tabler/icons-react": "^2.44.0",
+ "apexcharts": "^3.35.3",
+ "axios": "^0.27.2",
+ "dayjs": "^1.11.10",
+ "formik": "^2.2.9",
+ "framer-motion": "^6.3.16",
+ "history": "^5.3.0",
+ "marked": "^4.1.1",
+ "material-ui-popup-state": "^4.0.1",
+ "notistack": "^3.0.1",
+ "prop-types": "^15.8.1",
+ "react": "^18.2.0",
+ "react-apexcharts": "^1.4.0",
+ "react-device-detect": "^2.2.2",
+ "react-dom": "^18.2.0",
+ "react-perfect-scrollbar": "^1.5.8",
+ "react-redux": "^8.0.2",
+ "react-router": "6.3.0",
+ "react-router-dom": "6.3.0",
+ "react-scripts": "^5.0.1",
+ "react-turnstile": "^1.1.2",
+ "redux": "^4.2.0",
+ "yup": "^0.32.11"
+ },
+ "scripts": {
+ "start": "react-scripts start",
+ "build": "react-scripts build && mv -f build ../build/berry",
+ "test": "react-scripts test",
+ "eject": "react-scripts eject"
+ },
+ "eslintConfig": {
+ "extends": [
+ "react-app"
+ ]
+ },
+ "babel": {
+ "presets": [
+ "@babel/preset-react"
+ ]
+ },
+ "browserslist": {
+ "production": [
+ "defaults",
+ "not IE 11"
+ ],
+ "development": [
+ "last 1 chrome version",
+ "last 1 firefox version",
+ "last 1 safari version"
+ ]
+ },
+ "devDependencies": {
+ "@babel/core": "^7.21.4",
+ "@babel/eslint-parser": "^7.21.3",
+ "eslint": "^8.38.0",
+ "eslint-config-prettier": "^8.8.0",
+ "eslint-config-react-app": "^7.0.1",
+ "eslint-plugin-flowtype": "^8.0.3",
+ "eslint-plugin-import": "^2.27.5",
+ "eslint-plugin-jsx-a11y": "^6.7.1",
+ "eslint-plugin-prettier": "^4.2.1",
+ "eslint-plugin-react": "^7.32.2",
+ "eslint-plugin-react-hooks": "^4.6.0",
+ "immutable": "^4.3.0",
+ "prettier": "^2.8.7",
+ "sass": "^1.53.0"
+ }
+}
diff --git a/web/berry/public/favicon.ico b/web/berry/public/favicon.ico
new file mode 100644
index 00000000..fbcfb14a
Binary files /dev/null and b/web/berry/public/favicon.ico differ
diff --git a/web/berry/public/index.html b/web/berry/public/index.html
new file mode 100644
index 00000000..6f232250
--- /dev/null
+++ b/web/berry/public/index.html
@@ -0,0 +1,26 @@
+
+
+
+ One API
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/web/berry/src/App.js b/web/berry/src/App.js
new file mode 100644
index 00000000..fc54c632
--- /dev/null
+++ b/web/berry/src/App.js
@@ -0,0 +1,43 @@
+import { useSelector } from 'react-redux';
+
+import { ThemeProvider } from '@mui/material/styles';
+import { CssBaseline, StyledEngineProvider } from '@mui/material';
+
+// routing
+import Routes from 'routes';
+
+// defaultTheme
+import themes from 'themes';
+
+// project imports
+import NavigationScroll from 'layout/NavigationScroll';
+
+// auth
+import UserProvider from 'contexts/UserContext';
+import StatusProvider from 'contexts/StatusContext';
+import { SnackbarProvider } from 'notistack';
+
+// ==============================|| APP ||============================== //
+
+const App = () => {
+ const customization = useSelector((state) => state.customization);
+
+ return (
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ );
+};
+
+export default App;
diff --git a/web/berry/src/assets/images/404.svg b/web/berry/src/assets/images/404.svg
new file mode 100644
index 00000000..352a14ad
--- /dev/null
+++ b/web/berry/src/assets/images/404.svg
@@ -0,0 +1,40 @@
+
\ No newline at end of file
diff --git a/web/berry/src/assets/images/auth/auth-blue-card.svg b/web/berry/src/assets/images/auth/auth-blue-card.svg
new file mode 100644
index 00000000..6c9fe3e7
--- /dev/null
+++ b/web/berry/src/assets/images/auth/auth-blue-card.svg
@@ -0,0 +1,65 @@
+
diff --git a/web/berry/src/assets/images/auth/auth-pattern-dark.svg b/web/berry/src/assets/images/auth/auth-pattern-dark.svg
new file mode 100644
index 00000000..aa0e4ab2
--- /dev/null
+++ b/web/berry/src/assets/images/auth/auth-pattern-dark.svg
@@ -0,0 +1,39 @@
+
diff --git a/web/berry/src/assets/images/auth/auth-pattern.svg b/web/berry/src/assets/images/auth/auth-pattern.svg
new file mode 100644
index 00000000..b7ac8e27
--- /dev/null
+++ b/web/berry/src/assets/images/auth/auth-pattern.svg
@@ -0,0 +1,39 @@
+
diff --git a/web/berry/src/assets/images/auth/auth-purple-card.svg b/web/berry/src/assets/images/auth/auth-purple-card.svg
new file mode 100644
index 00000000..c724e0a3
--- /dev/null
+++ b/web/berry/src/assets/images/auth/auth-purple-card.svg
@@ -0,0 +1,69 @@
+
diff --git a/web/berry/src/assets/images/auth/auth-signup-blue-card.svg b/web/berry/src/assets/images/auth/auth-signup-blue-card.svg
new file mode 100644
index 00000000..ebb8e85f
--- /dev/null
+++ b/web/berry/src/assets/images/auth/auth-signup-blue-card.svg
@@ -0,0 +1 @@
+
\ No newline at end of file
diff --git a/web/berry/src/assets/images/auth/auth-signup-white-card.svg b/web/berry/src/assets/images/auth/auth-signup-white-card.svg
new file mode 100644
index 00000000..56b97e20
--- /dev/null
+++ b/web/berry/src/assets/images/auth/auth-signup-white-card.svg
@@ -0,0 +1,40 @@
+
diff --git a/web/berry/src/assets/images/icons/earning.svg b/web/berry/src/assets/images/icons/earning.svg
new file mode 100644
index 00000000..e877b599
--- /dev/null
+++ b/web/berry/src/assets/images/icons/earning.svg
@@ -0,0 +1,5 @@
+
diff --git a/web/berry/src/assets/images/icons/github.svg b/web/berry/src/assets/images/icons/github.svg
new file mode 100644
index 00000000..e5b1b82a
--- /dev/null
+++ b/web/berry/src/assets/images/icons/github.svg
@@ -0,0 +1 @@
+
\ No newline at end of file
diff --git a/web/berry/src/assets/images/icons/shape-avatar.svg b/web/berry/src/assets/images/icons/shape-avatar.svg
new file mode 100644
index 00000000..38aac7e2
--- /dev/null
+++ b/web/berry/src/assets/images/icons/shape-avatar.svg
@@ -0,0 +1 @@
+
\ No newline at end of file
diff --git a/web/berry/src/assets/images/icons/social-google.svg b/web/berry/src/assets/images/icons/social-google.svg
new file mode 100644
index 00000000..2231ce98
--- /dev/null
+++ b/web/berry/src/assets/images/icons/social-google.svg
@@ -0,0 +1,6 @@
+
diff --git a/web/berry/src/assets/images/icons/wechat.svg b/web/berry/src/assets/images/icons/wechat.svg
new file mode 100644
index 00000000..a0b2e36c
--- /dev/null
+++ b/web/berry/src/assets/images/icons/wechat.svg
@@ -0,0 +1 @@
+
\ No newline at end of file
diff --git a/web/berry/src/assets/images/invite/cover.jpg b/web/berry/src/assets/images/invite/cover.jpg
new file mode 100644
index 00000000..93be1a40
Binary files /dev/null and b/web/berry/src/assets/images/invite/cover.jpg differ
diff --git a/web/berry/src/assets/images/invite/cwok_casual_19.webp b/web/berry/src/assets/images/invite/cwok_casual_19.webp
new file mode 100644
index 00000000..1cf2c376
Binary files /dev/null and b/web/berry/src/assets/images/invite/cwok_casual_19.webp differ
diff --git a/web/berry/src/assets/images/logo-2.svg b/web/berry/src/assets/images/logo-2.svg
new file mode 100644
index 00000000..2e674a7e
--- /dev/null
+++ b/web/berry/src/assets/images/logo-2.svg
@@ -0,0 +1,15 @@
+
\ No newline at end of file
diff --git a/web/berry/src/assets/images/logo.svg b/web/berry/src/assets/images/logo.svg
new file mode 100644
index 00000000..348c7e5a
--- /dev/null
+++ b/web/berry/src/assets/images/logo.svg
@@ -0,0 +1,13 @@
+
\ No newline at end of file
diff --git a/web/berry/src/assets/images/users/user-round.svg b/web/berry/src/assets/images/users/user-round.svg
new file mode 100644
index 00000000..eaef7ed9
--- /dev/null
+++ b/web/berry/src/assets/images/users/user-round.svg
@@ -0,0 +1 @@
+
\ No newline at end of file
diff --git a/web/berry/src/assets/scss/_themes-vars.module.scss b/web/berry/src/assets/scss/_themes-vars.module.scss
new file mode 100644
index 00000000..a470b033
--- /dev/null
+++ b/web/berry/src/assets/scss/_themes-vars.module.scss
@@ -0,0 +1,157 @@
+// paper & background
+$paper: #ffffff;
+
+// primary
+$primaryLight: #eef2f6;
+$primaryMain: #2196f3;
+$primaryDark: #1e88e5;
+$primary200: #90caf9;
+$primary800: #1565c0;
+
+// secondary
+$secondaryLight: #ede7f6;
+$secondaryMain: #673ab7;
+$secondaryDark: #5e35b1;
+$secondary200: #b39ddb;
+$secondary800: #4527a0;
+
+// success Colors
+$successLight: #b9f6ca;
+$success200: #69f0ae;
+$successMain: #00e676;
+$successDark: #00c853;
+
+// error
+$errorLight: #ef9a9a;
+$errorMain: #f44336;
+$errorDark: #c62828;
+
+// orange
+$orangeLight: #fbe9e7;
+$orangeMain: #ffab91;
+$orangeDark: #d84315;
+
+// warning
+$warningLight: #fff8e1;
+$warningMain: #ffe57f;
+$warningDark: #ffc107;
+
+// grey
+$grey50: #f8fafc;
+$grey100: #eef2f6;
+$grey200: #e3e8ef;
+$grey300: #cdd5df;
+$grey500: #697586;
+$grey600: #4b5565;
+$grey700: #364152;
+$grey900: #121926;
+
+// ==============================|| DARK THEME VARIANTS ||============================== //
+
+// paper & background
+$darkBackground: #1a223f; // level 3
+$darkPaper: #111936; // level 4
+
+// dark 800 & 900
+$darkLevel1: #29314f; // level 1
+$darkLevel2: #212946; // level 2
+
+// primary dark
+$darkPrimaryLight: #eef2f6;
+$darkPrimaryMain: #2196f3;
+$darkPrimaryDark: #1e88e5;
+$darkPrimary200: #90caf9;
+$darkPrimary800: #1565c0;
+
+// secondary dark
+$darkSecondaryLight: #d1c4e9;
+$darkSecondaryMain: #7c4dff;
+$darkSecondaryDark: #651fff;
+$darkSecondary200: #b39ddb;
+$darkSecondary800: #6200ea;
+
+// text variants
+$darkTextTitle: #d7dcec;
+$darkTextPrimary: #bdc8f0;
+$darkTextSecondary: #8492c4;
+
+// ==============================|| JAVASCRIPT ||============================== //
+
+:export {
+ // paper & background
+ paper: $paper;
+
+ // primary
+ primaryLight: $primaryLight;
+ primary200: $primary200;
+ primaryMain: $primaryMain;
+ primaryDark: $primaryDark;
+ primary800: $primary800;
+
+ // secondary
+ secondaryLight: $secondaryLight;
+ secondary200: $secondary200;
+ secondaryMain: $secondaryMain;
+ secondaryDark: $secondaryDark;
+ secondary800: $secondary800;
+
+ // success
+ successLight: $successLight;
+ success200: $success200;
+ successMain: $successMain;
+ successDark: $successDark;
+
+ // error
+ errorLight: $errorLight;
+ errorMain: $errorMain;
+ errorDark: $errorDark;
+
+ // orange
+ orangeLight: $orangeLight;
+ orangeMain: $orangeMain;
+ orangeDark: $orangeDark;
+
+ // warning
+ warningLight: $warningLight;
+ warningMain: $warningMain;
+ warningDark: $warningDark;
+
+ // grey
+ grey50: $grey50;
+ grey100: $grey100;
+ grey200: $grey200;
+ grey300: $grey300;
+ grey500: $grey500;
+ grey600: $grey600;
+ grey700: $grey700;
+ grey900: $grey900;
+
+ // ==============================|| DARK THEME VARIANTS ||============================== //
+
+ // paper & background
+ darkPaper: $darkPaper;
+ darkBackground: $darkBackground;
+
+ // dark 800 & 900
+ darkLevel1: $darkLevel1;
+ darkLevel2: $darkLevel2;
+
+ // text variants
+ darkTextTitle: $darkTextTitle;
+ darkTextPrimary: $darkTextPrimary;
+ darkTextSecondary: $darkTextSecondary;
+
+ // primary dark
+ darkPrimaryLight: $darkPrimaryLight;
+ darkPrimaryMain: $darkPrimaryMain;
+ darkPrimaryDark: $darkPrimaryDark;
+ darkPrimary200: $darkPrimary200;
+ darkPrimary800: $darkPrimary800;
+
+ // secondary dark
+ darkSecondaryLight: $darkSecondaryLight;
+ darkSecondaryMain: $darkSecondaryMain;
+ darkSecondaryDark: $darkSecondaryDark;
+ darkSecondary200: $darkSecondary200;
+ darkSecondary800: $darkSecondary800;
+}
diff --git a/web/berry/src/assets/scss/style.scss b/web/berry/src/assets/scss/style.scss
new file mode 100644
index 00000000..17d566e6
--- /dev/null
+++ b/web/berry/src/assets/scss/style.scss
@@ -0,0 +1,128 @@
+// color variants
+@import 'themes-vars.module.scss';
+
+// third-party
+@import '~react-perfect-scrollbar/dist/css/styles.css';
+
+// ==============================|| LIGHT BOX ||============================== //
+.fullscreen .react-images__blanket {
+ z-index: 1200;
+}
+
+// ==============================|| APEXCHART ||============================== //
+
+.apexcharts-legend-series .apexcharts-legend-marker {
+ margin-right: 8px;
+}
+
+// ==============================|| PERFECT SCROLLBAR ||============================== //
+
+.scrollbar-container {
+ .ps__rail-y {
+ &:hover > .ps__thumb-y,
+ &:focus > .ps__thumb-y,
+ &.ps--clicking .ps__thumb-y {
+ background-color: $grey500;
+ width: 5px;
+ }
+ }
+ .ps__thumb-y {
+ background-color: $grey500;
+ border-radius: 6px;
+ width: 5px;
+ right: 0;
+ }
+}
+
+.scrollbar-container.ps,
+.scrollbar-container > .ps {
+ &.ps--active-y > .ps__rail-y {
+ width: 5px;
+ background-color: transparent !important;
+ z-index: 999;
+ &:hover,
+ &.ps--clicking {
+ width: 5px;
+ background-color: transparent;
+ }
+ }
+ &.ps--scrolling-y > .ps__rail-y,
+ &.ps--scrolling-x > .ps__rail-x {
+ opacity: 0.4;
+ background-color: transparent;
+ }
+}
+
+// ==============================|| ANIMATION KEYFRAMES ||============================== //
+
+@keyframes wings {
+ 50% {
+ transform: translateY(-40px);
+ }
+ 100% {
+ transform: translateY(0px);
+ }
+}
+
+@keyframes blink {
+ 50% {
+ opacity: 0;
+ }
+ 100% {
+ opacity: 1;
+ }
+}
+
+@keyframes bounce {
+ 0%,
+ 20%,
+ 53%,
+ to {
+ animation-timing-function: cubic-bezier(0.215, 0.61, 0.355, 1);
+ transform: translateZ(0);
+ }
+ 40%,
+ 43% {
+ animation-timing-function: cubic-bezier(0.755, 0.05, 0.855, 0.06);
+ transform: translate3d(0, -5px, 0);
+ }
+ 70% {
+ animation-timing-function: cubic-bezier(0.755, 0.05, 0.855, 0.06);
+ transform: translate3d(0, -7px, 0);
+ }
+ 80% {
+ transition-timing-function: cubic-bezier(0.215, 0.61, 0.355, 1);
+ transform: translateZ(0);
+ }
+ 90% {
+ transform: translate3d(0, -2px, 0);
+ }
+}
+
+@keyframes slideY {
+ 0%,
+ 50%,
+ 100% {
+ transform: translateY(0px);
+ }
+ 25% {
+ transform: translateY(-10px);
+ }
+ 75% {
+ transform: translateY(10px);
+ }
+}
+
+@keyframes slideX {
+ 0%,
+ 50%,
+ 100% {
+ transform: translateX(0px);
+ }
+ 25% {
+ transform: translateX(-10px);
+ }
+ 75% {
+ transform: translateX(10px);
+ }
+}
diff --git a/web/berry/src/config.js b/web/berry/src/config.js
new file mode 100644
index 00000000..eeeda99a
--- /dev/null
+++ b/web/berry/src/config.js
@@ -0,0 +1,29 @@
+const config = {
+  // basename: set only at build time. Don't add a trailing '/' to BASENAME (it breaks breadcrumbs); to mean the root, use an empty string ('') instead of '/',
+ // like '/berry-material-react/react/default'
+ basename: '/',
+ defaultPath: '/panel/dashboard',
+ fontFamily: `'Roboto', sans-serif, Helvetica, Arial, sans-serif`,
+ borderRadius: 12,
+ siteInfo: {
+ chat_link: '',
+ display_in_currency: true,
+ email_verification: false,
+ footer_html: '',
+ github_client_id: '',
+ github_oauth: false,
+ logo: '',
+ quota_per_unit: 500000,
+ server_address: '',
+ start_time: 0,
+ system_name: 'One API',
+ top_up_link: '',
+ turnstile_check: false,
+ turnstile_site_key: '',
+ version: '',
+ wechat_login: false,
+ wechat_qrcode: ''
+ }
+};
+
+export default config;
diff --git a/web/berry/src/constants/ChannelConstants.js b/web/berry/src/constants/ChannelConstants.js
new file mode 100644
index 00000000..3ce27838
--- /dev/null
+++ b/web/berry/src/constants/ChannelConstants.js
@@ -0,0 +1,146 @@
+export const CHANNEL_OPTIONS = {
+ 1: {
+ key: 1,
+ text: 'OpenAI',
+ value: 1,
+ color: 'primary'
+ },
+ 14: {
+ key: 14,
+ text: 'Anthropic Claude',
+ value: 14,
+ color: 'info'
+ },
+ 3: {
+ key: 3,
+ text: 'Azure OpenAI',
+ value: 3,
+ color: 'orange'
+ },
+ 11: {
+ key: 11,
+ text: 'Google PaLM2',
+ value: 11,
+ color: 'orange'
+ },
+ 24: {
+ key: 24,
+ text: 'Google Gemini',
+ value: 24,
+ color: 'orange'
+ },
+ 15: {
+ key: 15,
+ text: '百度文心千帆',
+ value: 15,
+ color: 'default'
+ },
+ 17: {
+ key: 17,
+ text: '阿里通义千问',
+ value: 17,
+ color: 'default'
+ },
+ 18: {
+ key: 18,
+ text: '讯飞星火认知',
+ value: 18,
+ color: 'default'
+ },
+ 16: {
+ key: 16,
+ text: '智谱 ChatGLM',
+ value: 16,
+ color: 'default'
+ },
+ 19: {
+ key: 19,
+ text: '360 智脑',
+ value: 19,
+ color: 'default'
+ },
+ 23: {
+ key: 23,
+ text: '腾讯混元',
+ value: 23,
+ color: 'default'
+ },
+ 8: {
+ key: 8,
+ text: '自定义渠道',
+ value: 8,
+ color: 'primary'
+ },
+ 22: {
+ key: 22,
+ text: '知识库:FastGPT',
+ value: 22,
+ color: 'default'
+ },
+ 21: {
+ key: 21,
+ text: '知识库:AI Proxy',
+ value: 21,
+ color: 'purple'
+ },
+ 20: {
+ key: 20,
+ text: '代理:OpenRouter',
+ value: 20,
+ color: 'primary'
+ },
+ 2: {
+ key: 2,
+ text: '代理:API2D',
+ value: 2,
+ color: 'primary'
+ },
+ 5: {
+ key: 5,
+ text: '代理:OpenAI-SB',
+ value: 5,
+ color: 'primary'
+ },
+ 7: {
+ key: 7,
+ text: '代理:OhMyGPT',
+ value: 7,
+ color: 'primary'
+ },
+ 10: {
+ key: 10,
+ text: '代理:AI Proxy',
+ value: 10,
+ color: 'primary'
+ },
+ 4: {
+ key: 4,
+ text: '代理:CloseAI',
+ value: 4,
+ color: 'primary'
+ },
+ 6: {
+ key: 6,
+ text: '代理:OpenAI Max',
+ value: 6,
+ color: 'primary'
+ },
+ 9: {
+ key: 9,
+ text: '代理:AI.LS',
+ value: 9,
+ color: 'primary'
+ },
+ 12: {
+ key: 12,
+ text: '代理:API2GPT',
+ value: 12,
+ color: 'primary'
+ },
+ 13: {
+ key: 13,
+ text: '代理:AIGC2D',
+ value: 13,
+ color: 'primary'
+ }
+};
diff --git a/web/berry/src/constants/CommonConstants.js b/web/berry/src/constants/CommonConstants.js
new file mode 100644
index 00000000..1a37d5f6
--- /dev/null
+++ b/web/berry/src/constants/CommonConstants.js
@@ -0,0 +1 @@
+export const ITEMS_PER_PAGE = 10; // this value must keep same as the one defined in backend!
diff --git a/web/berry/src/constants/SnackbarConstants.js b/web/berry/src/constants/SnackbarConstants.js
new file mode 100644
index 00000000..a05c6652
--- /dev/null
+++ b/web/berry/src/constants/SnackbarConstants.js
@@ -0,0 +1,27 @@
+export const snackbarConstants = {
+ Common: {
+ ERROR: {
+ variant: 'error',
+ autoHideDuration: 5000
+ },
+ WARNING: {
+ variant: 'warning',
+ autoHideDuration: 10000
+ },
+ SUCCESS: {
+ variant: 'success',
+ autoHideDuration: 1500
+ },
+ INFO: {
+ variant: 'info',
+ autoHideDuration: 3000
+ },
+ NOTICE: {
+ variant: 'info',
+ autoHideDuration: 20000
+ }
+ },
+ Mobile: {
+ anchorOrigin: { vertical: 'bottom', horizontal: 'center' }
+ }
+};
diff --git a/web/berry/src/constants/index.js b/web/berry/src/constants/index.js
new file mode 100644
index 00000000..716ef6aa
--- /dev/null
+++ b/web/berry/src/constants/index.js
@@ -0,0 +1,3 @@
+export * from './SnackbarConstants';
+export * from './CommonConstants';
+export * from './ChannelConstants';
diff --git a/web/berry/src/contexts/StatusContext.js b/web/berry/src/contexts/StatusContext.js
new file mode 100644
index 00000000..ed9f1621
--- /dev/null
+++ b/web/berry/src/contexts/StatusContext.js
@@ -0,0 +1,70 @@
+import { useEffect, useCallback, createContext } from "react";
+import { API } from "utils/api";
+import { showNotice, showError } from "utils/common";
+import { SET_SITE_INFO } from "store/actions";
+import { useDispatch } from "react-redux";
+
+export const LoadStatusContext = createContext();
+
+// eslint-disable-next-line
+const StatusProvider = ({ children }) => {
+ const dispatch = useDispatch();
+
+ const loadStatus = useCallback(async () => {
+ const res = await API.get("/api/status");
+ const { success, data } = res.data;
+ let system_name = "";
+ if (success) {
+ if (!data.chat_link) {
+ delete data.chat_link;
+ }
+ localStorage.setItem("siteInfo", JSON.stringify(data));
+ localStorage.setItem("quota_per_unit", data.quota_per_unit);
+ localStorage.setItem("display_in_currency", data.display_in_currency);
+ dispatch({ type: SET_SITE_INFO, payload: data });
+ if (
+ data.version !== process.env.REACT_APP_VERSION &&
+ data.version !== "v0.0.0" &&
+ data.version !== "" &&
+ process.env.REACT_APP_VERSION !== ""
+ ) {
+ showNotice(
+ `新版本可用:${data.version},请使用快捷键 Shift + F5 刷新页面`
+ );
+ }
+ if (data.system_name) {
+ system_name = data.system_name;
+ }
+ } else {
+ const backupSiteInfo = localStorage.getItem("siteInfo");
+ if (backupSiteInfo) {
+ const data = JSON.parse(backupSiteInfo);
+ if (data.system_name) {
+ system_name = data.system_name;
+ }
+ dispatch({
+ type: SET_SITE_INFO,
+ payload: data,
+ });
+ }
+ showError("无法正常连接至服务器!");
+ }
+
+ if (system_name) {
+ document.title = system_name;
+ }
+ }, [dispatch]);
+
+ useEffect(() => {
+ loadStatus().then();
+ }, [loadStatus]);
+
+ return (
+
+ {" "}
+ {children}{" "}
+
+ );
+};
+
+export default StatusProvider;
diff --git a/web/berry/src/contexts/UserContext.js b/web/berry/src/contexts/UserContext.js
new file mode 100644
index 00000000..491da9d9
--- /dev/null
+++ b/web/berry/src/contexts/UserContext.js
@@ -0,0 +1,29 @@
+// contexts/User/index.jsx
+import React, { useEffect, useCallback, createContext, useState } from 'react';
+import { LOGIN } from 'store/actions';
+import { useDispatch } from 'react-redux';
+
+export const UserContext = createContext();
+
+// eslint-disable-next-line
+const UserProvider = ({ children }) => {
+ const dispatch = useDispatch();
+ const [isUserLoaded, setIsUserLoaded] = useState(false);
+
+ const loadUser = useCallback(() => {
+ let user = localStorage.getItem('user');
+ if (user) {
+ let data = JSON.parse(user);
+ dispatch({ type: LOGIN, payload: data });
+ }
+ setIsUserLoaded(true);
+ }, [dispatch]);
+
+ useEffect(() => {
+ loadUser();
+ }, [loadUser]);
+
+ return {children} ;
+};
+
+export default UserProvider;
diff --git a/web/berry/src/hooks/useAuth.js b/web/berry/src/hooks/useAuth.js
new file mode 100644
index 00000000..fa7cb934
--- /dev/null
+++ b/web/berry/src/hooks/useAuth.js
@@ -0,0 +1,13 @@
+import { isAdmin } from 'utils/common';
+import { useNavigate } from 'react-router-dom';
+const navigate = useNavigate();
+
+const useAuth = () => {
+ const userIsAdmin = isAdmin();
+
+ if (!userIsAdmin) {
+ navigate('/panel/404');
+ }
+};
+
+export default useAuth;
diff --git a/web/berry/src/hooks/useLogin.js b/web/berry/src/hooks/useLogin.js
new file mode 100644
index 00000000..53626577
--- /dev/null
+++ b/web/berry/src/hooks/useLogin.js
@@ -0,0 +1,78 @@
+import { API } from 'utils/api';
+import { useDispatch } from 'react-redux';
+import { LOGIN } from 'store/actions';
+import { useNavigate } from 'react-router';
+import { showSuccess } from 'utils/common';
+
+const useLogin = () => {
+ const dispatch = useDispatch();
+ const navigate = useNavigate();
+ const login = async (username, password) => {
+ try {
+ const res = await API.post(`/api/user/login`, {
+ username,
+ password
+ });
+ const { success, message, data } = res.data;
+ if (success) {
+ localStorage.setItem('user', JSON.stringify(data));
+ dispatch({ type: LOGIN, payload: data });
+ navigate('/panel');
+ }
+ return { success, message };
+ } catch (err) {
+ // 请求失败,设置错误信息
+ return { success: false, message: '' };
+ }
+ };
+
+ const githubLogin = async (code, state) => {
+ try {
+ const res = await API.get(`/api/oauth/github?code=${code}&state=${state}`);
+ const { success, message, data } = res.data;
+ if (success) {
+ if (message === 'bind') {
+ showSuccess('绑定成功!');
+ navigate('/panel');
+ } else {
+ dispatch({ type: LOGIN, payload: data });
+ localStorage.setItem('user', JSON.stringify(data));
+ showSuccess('登录成功!');
+ navigate('/panel');
+ }
+ }
+ return { success, message };
+ } catch (err) {
+ // 请求失败,设置错误信息
+ return { success: false, message: '' };
+ }
+ };
+
+ const wechatLogin = async (code) => {
+ try {
+ const res = await API.get(`/api/oauth/wechat?code=${code}`);
+ const { success, message, data } = res.data;
+ if (success) {
+ dispatch({ type: LOGIN, payload: data });
+ localStorage.setItem('user', JSON.stringify(data));
+ showSuccess('登录成功!');
+ navigate('/panel');
+ }
+ return { success, message };
+ } catch (err) {
+ // 请求失败,设置错误信息
+ return { success: false, message: '' };
+ }
+ };
+
+ const logout = async () => {
+ await API.get('/api/user/logout');
+ localStorage.removeItem('user');
+ dispatch({ type: LOGIN, payload: null });
+ navigate('/');
+ };
+
+ return { login, logout, githubLogin, wechatLogin };
+};
+
+export default useLogin;
diff --git a/web/berry/src/hooks/useRegister.js b/web/berry/src/hooks/useRegister.js
new file mode 100644
index 00000000..5377e96d
--- /dev/null
+++ b/web/berry/src/hooks/useRegister.js
@@ -0,0 +1,43 @@
+import { API } from 'utils/api';
+import { useNavigate } from 'react-router';
+import { showSuccess } from 'utils/common';
+
+const useRegister = () => {
+ const navigate = useNavigate();
+ const register = async (input, turnstile) => {
+ try {
+ let affCode = localStorage.getItem('aff');
+ if (affCode) {
+ input = { ...input, aff_code: affCode };
+ }
+ const res = await API.post(`/api/user/register?turnstile=${turnstile}`, input);
+ const { success, message } = res.data;
+ if (success) {
+ showSuccess('注册成功!');
+ navigate('/login');
+ }
+ return { success, message };
+ } catch (err) {
+ // 请求失败,设置错误信息
+ return { success: false, message: '' };
+ }
+ };
+
+ const sendVerificationCode = async (email, turnstile) => {
+ try {
+ const res = await API.get(`/api/verification?email=${email}&turnstile=${turnstile}`);
+ const { success, message } = res.data;
+ if (success) {
+ showSuccess('验证码发送成功,请检查你的邮箱!');
+ }
+ return { success, message };
+ } catch (err) {
+ // 请求失败,设置错误信息
+ return { success: false, message: '' };
+ }
+ };
+
+ return { register, sendVerificationCode };
+};
+
+export default useRegister;
diff --git a/web/berry/src/hooks/useScriptRef.js b/web/berry/src/hooks/useScriptRef.js
new file mode 100644
index 00000000..bd300cbb
--- /dev/null
+++ b/web/berry/src/hooks/useScriptRef.js
@@ -0,0 +1,18 @@
+import { useEffect, useRef } from 'react';
+
+// ==============================|| ELEMENT REFERENCE HOOKS ||============================== //
+
+const useScriptRef = () => {
+ const scripted = useRef(true);
+
+ useEffect(
+ () => () => {
+ scripted.current = true;
+ },
+ []
+ );
+
+ return scripted;
+};
+
+export default useScriptRef;
diff --git a/web/berry/src/index.js b/web/berry/src/index.js
new file mode 100644
index 00000000..d1411be3
--- /dev/null
+++ b/web/berry/src/index.js
@@ -0,0 +1,31 @@
+import { createRoot } from 'react-dom/client';
+
+// third party
+import { BrowserRouter } from 'react-router-dom';
+import { Provider } from 'react-redux';
+
+// project imports
+import * as serviceWorker from 'serviceWorker';
+import App from 'App';
+import { store } from 'store';
+
+// style + assets
+import 'assets/scss/style.scss';
+import config from './config';
+
+// ==============================|| REACT DOM RENDER ||============================== //
+
+const container = document.getElementById('root');
+const root = createRoot(container); // createRoot(container!) if you use TypeScript
+root.render(
+
+
+
+
+
+);
+
+// If you want your app to work offline and load faster, you can change
+// unregister() to register() below. Note this comes with some pitfalls.
+// Learn more about service workers: https://bit.ly/CRA-PWA
+serviceWorker.register();
diff --git a/web/berry/src/layout/MainLayout/Header/ProfileSection/index.js b/web/berry/src/layout/MainLayout/Header/ProfileSection/index.js
new file mode 100644
index 00000000..3e351254
--- /dev/null
+++ b/web/berry/src/layout/MainLayout/Header/ProfileSection/index.js
@@ -0,0 +1,173 @@
+import { useState, useRef, useEffect } from 'react';
+
+import { useSelector } from 'react-redux';
+import { useNavigate } from 'react-router-dom';
+// material-ui
+import { useTheme } from '@mui/material/styles';
+import {
+ Avatar,
+ Chip,
+ ClickAwayListener,
+ List,
+ ListItemButton,
+ ListItemIcon,
+ ListItemText,
+ Paper,
+ Popper,
+ Typography
+} from '@mui/material';
+
+// project imports
+import MainCard from 'ui-component/cards/MainCard';
+import Transitions from 'ui-component/extended/Transitions';
+import User1 from 'assets/images/users/user-round.svg';
+import useLogin from 'hooks/useLogin';
+
+// assets
+import { IconLogout, IconSettings, IconUserScan } from '@tabler/icons-react';
+
+// ==============================|| PROFILE MENU ||============================== //
+
+const ProfileSection = () => {
+ const theme = useTheme();
+ const navigate = useNavigate();
+ const customization = useSelector((state) => state.customization);
+ const { logout } = useLogin();
+
+ const [open, setOpen] = useState(false);
+ /**
+ * anchorRef is used on different componets and specifying one type leads to other components throwing an error
+ * */
+ const anchorRef = useRef(null);
+ const handleLogout = async () => {
+ logout();
+ };
+
+ const handleClose = (event) => {
+ if (anchorRef.current && anchorRef.current.contains(event.target)) {
+ return;
+ }
+ setOpen(false);
+ };
+
+ const handleToggle = () => {
+ setOpen((prevOpen) => !prevOpen);
+ };
+
+ const prevOpen = useRef(open);
+ useEffect(() => {
+ if (prevOpen.current === true && open === false) {
+ anchorRef.current.focus();
+ }
+
+ prevOpen.current = open;
+ }, [open]);
+
+ return (
+ <>
+
+ }
+ label={}
+ variant="outlined"
+ ref={anchorRef}
+ aria-controls={open ? 'menu-list-grow' : undefined}
+ aria-haspopup="true"
+ onClick={handleToggle}
+ color="primary"
+ />
+
+ {({ TransitionProps }) => (
+
+
+
+
+
+ navigate('/panel/profile')}>
+
+
+
+ 设置} />
+
+
+
+
+
+
+ 登出} />
+
+
+
+
+
+
+ )}
+
+ >
+ );
+};
+
+export default ProfileSection;
diff --git a/web/berry/src/layout/MainLayout/Header/index.js b/web/berry/src/layout/MainLayout/Header/index.js
new file mode 100644
index 00000000..51d40c75
--- /dev/null
+++ b/web/berry/src/layout/MainLayout/Header/index.js
@@ -0,0 +1,68 @@
+import PropTypes from 'prop-types';
+
+// material-ui
+import { useTheme } from '@mui/material/styles';
+import { Avatar, Box, ButtonBase } from '@mui/material';
+
+// project imports
+import LogoSection from '../LogoSection';
+import ProfileSection from './ProfileSection';
+
+// assets
+import { IconMenu2 } from '@tabler/icons-react';
+
+// ==============================|| MAIN NAVBAR / HEADER ||============================== //
+
+const Header = ({ handleLeftDrawerToggle }) => {
+ const theme = useTheme();
+
+ return (
+ <>
+ {/* logo & toggler button */}
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ >
+ );
+};
+
+Header.propTypes = {
+ handleLeftDrawerToggle: PropTypes.func
+};
+
+export default Header;
diff --git a/web/berry/src/layout/MainLayout/LogoSection/index.js b/web/berry/src/layout/MainLayout/LogoSection/index.js
new file mode 100644
index 00000000..1d70e48c
--- /dev/null
+++ b/web/berry/src/layout/MainLayout/LogoSection/index.js
@@ -0,0 +1,23 @@
+import { Link } from 'react-router-dom';
+import { useDispatch, useSelector } from 'react-redux';
+
+// material-ui
+import { ButtonBase } from '@mui/material';
+
+// project imports
+import Logo from 'ui-component/Logo';
+import { MENU_OPEN } from 'store/actions';
+
+// ==============================|| MAIN LOGO ||============================== //
+
+const LogoSection = () => {
+ const defaultId = useSelector((state) => state.customization.defaultId);
+ const dispatch = useDispatch();
+ return (
+ dispatch({ type: MENU_OPEN, id: defaultId })} component={Link} to="/">
+
+
+ );
+};
+
+export default LogoSection;
diff --git a/web/berry/src/layout/MainLayout/Sidebar/MenuCard/index.js b/web/berry/src/layout/MainLayout/Sidebar/MenuCard/index.js
new file mode 100644
index 00000000..16b13231
--- /dev/null
+++ b/web/berry/src/layout/MainLayout/Sidebar/MenuCard/index.js
@@ -0,0 +1,130 @@
+// import PropTypes from 'prop-types';
+import { useSelector } from 'react-redux';
+
+// material-ui
+import { styled, useTheme } from '@mui/material/styles';
+import {
+ Avatar,
+ Card,
+ CardContent,
+ // Grid,
+ // LinearProgress,
+ List,
+ ListItem,
+ ListItemAvatar,
+ ListItemText,
+ Typography
+ // linearProgressClasses
+} from '@mui/material';
+import User1 from 'assets/images/users/user-round.svg';
+import { useNavigate } from 'react-router-dom';
+
+// assets
+// import TableChartOutlinedIcon from '@mui/icons-material/TableChartOutlined';
+
+// styles
+// const BorderLinearProgress = styled(LinearProgress)(({ theme }) => ({
+// height: 10,
+// borderRadius: 30,
+// [`&.${linearProgressClasses.colorPrimary}`]: {
+// backgroundColor: '#fff'
+// },
+// [`& .${linearProgressClasses.bar}`]: {
+// borderRadius: 5,
+// backgroundColor: theme.palette.primary.main
+// }
+// }));
+
+const CardStyle = styled(Card)(({ theme }) => ({
+ background: theme.palette.primary.light,
+ marginBottom: '22px',
+ overflow: 'hidden',
+ position: 'relative',
+ '&:after': {
+ content: '""',
+ position: 'absolute',
+ width: '157px',
+ height: '157px',
+ background: theme.palette.primary[200],
+ borderRadius: '50%',
+ top: '-105px',
+ right: '-96px'
+ }
+}));
+
+// ==============================|| PROGRESS BAR WITH LABEL ||============================== //
+
+// function LinearProgressWithLabel({ value, ...others }) {
+// const theme = useTheme();
+
+// return (
+//
+//
+//
+//
+//
+// Progress
+//
+//
+//
+// {`${Math.round(value)}%`}
+//
+//
+//
+//
+//
+//
+//
+// );
+// }
+
+// LinearProgressWithLabel.propTypes = {
+// value: PropTypes.number
+// };
+
+// ==============================|| SIDEBAR MENU Card ||============================== //
+
+const MenuCard = () => {
+ const theme = useTheme();
+ const account = useSelector((state) => state.account);
+ const navigate = useNavigate();
+
+ return (
+
+
+
+
+
+ navigate('/panel/profile')}
+ >
+
+
+ {account.user?.username}
+
+ }
+ secondary={ 欢迎回来 }
+ />
+
+
+ {/* */}
+
+
+ );
+};
+
+export default MenuCard;
diff --git a/web/berry/src/layout/MainLayout/Sidebar/MenuList/NavCollapse/index.js b/web/berry/src/layout/MainLayout/Sidebar/MenuList/NavCollapse/index.js
new file mode 100644
index 00000000..0632d56f
--- /dev/null
+++ b/web/berry/src/layout/MainLayout/Sidebar/MenuList/NavCollapse/index.js
@@ -0,0 +1,158 @@
+import PropTypes from 'prop-types';
+import { useEffect, useState } from 'react';
+import { useSelector } from 'react-redux';
+import { useLocation, useNavigate } from 'react-router';
+
+// material-ui
+import { useTheme } from '@mui/material/styles';
+import { Collapse, List, ListItemButton, ListItemIcon, ListItemText, Typography } from '@mui/material';
+
+// project imports
+import NavItem from '../NavItem';
+
+// assets
+import FiberManualRecordIcon from '@mui/icons-material/FiberManualRecord';
+import { IconChevronDown, IconChevronUp } from '@tabler/icons-react';
+
+// ==============================|| SIDEBAR MENU LIST COLLAPSE ITEMS ||============================== //
+
+const NavCollapse = ({ menu, level }) => {
+ const theme = useTheme();
+ const customization = useSelector((state) => state.customization);
+ const navigate = useNavigate();
+
+ const [open, setOpen] = useState(false);
+ const [selected, setSelected] = useState(null);
+
+ const handleClick = () => {
+ setOpen(!open);
+ setSelected(!selected ? menu.id : null);
+ if (menu?.id !== 'authentication') {
+ navigate(menu.children[0]?.url);
+ }
+ };
+
+ const { pathname } = useLocation();
+ const checkOpenForParent = (child, id) => {
+ child.forEach((item) => {
+ if (item.url === pathname) {
+ setOpen(true);
+ setSelected(id);
+ }
+ });
+ };
+
+ // menu collapse for sub-levels
+ useEffect(() => {
+ setOpen(false);
+ setSelected(null);
+ if (menu.children) {
+ menu.children.forEach((item) => {
+ if (item.children?.length) {
+ checkOpenForParent(item.children, menu.id);
+ }
+ if (item.url === pathname) {
+ setSelected(menu.id);
+ setOpen(true);
+ }
+ });
+ }
+
+ // eslint-disable-next-line react-hooks/exhaustive-deps
+ }, [pathname, menu.children]);
+
+ // menu collapse & item
+ const menus = menu.children?.map((item) => {
+ switch (item.type) {
+ case 'collapse':
+ return ;
+ case 'item':
+ return ;
+ default:
+ return (
+
+ Menu Items Error
+
+ );
+ }
+ });
+
+ const Icon = menu.icon;
+ const menuIcon = menu.icon ? (
+
+ ) : (
+ 0 ? 'inherit' : 'medium'}
+ />
+ );
+
+ return (
+ <>
+ 1 ? 'transparent !important' : 'inherit',
+ py: level > 1 ? 1 : 1.25,
+ pl: `${level * 24}px`
+ }}
+ selected={selected === menu.id}
+ onClick={handleClick}
+ >
+ {menuIcon}
+
+ {menu.title}
+
+ }
+ secondary={
+ menu.caption && (
+
+ {menu.caption}
+
+ )
+ }
+ />
+ {open ? (
+
+ ) : (
+
+ )}
+
+
+
+ {menus}
+
+
+ >
+ );
+};
+
+NavCollapse.propTypes = {
+ menu: PropTypes.object,
+ level: PropTypes.number
+};
+
+export default NavCollapse;
diff --git a/web/berry/src/layout/MainLayout/Sidebar/MenuList/NavGroup/index.js b/web/berry/src/layout/MainLayout/Sidebar/MenuList/NavGroup/index.js
new file mode 100644
index 00000000..b6479bc2
--- /dev/null
+++ b/web/berry/src/layout/MainLayout/Sidebar/MenuList/NavGroup/index.js
@@ -0,0 +1,61 @@
+import PropTypes from 'prop-types';
+
+// material-ui
+import { useTheme } from '@mui/material/styles';
+import { Divider, List, Typography } from '@mui/material';
+
+// project imports
+import NavItem from '../NavItem';
+import NavCollapse from '../NavCollapse';
+
+// ==============================|| SIDEBAR MENU LIST GROUP ||============================== //
+
+const NavGroup = ({ item }) => {
+ const theme = useTheme();
+
+ // menu list collapse & items
+ const items = item.children?.map((menu) => {
+ switch (menu.type) {
+ case 'collapse':
+ return ;
+ case 'item':
+ return ;
+ default:
+ return (
+
+ Menu Items Error
+
+ );
+ }
+ });
+
+ return (
+ <>
+
+ {item.title}
+ {item.caption && (
+
+ {item.caption}
+
+ )}
+
+ )
+ }
+ >
+ {items}
+
+
+ {/* group divider */}
+
+ >
+ );
+};
+
+NavGroup.propTypes = {
+ item: PropTypes.object
+};
+
+export default NavGroup;
diff --git a/web/berry/src/layout/MainLayout/Sidebar/MenuList/NavItem/index.js b/web/berry/src/layout/MainLayout/Sidebar/MenuList/NavItem/index.js
new file mode 100644
index 00000000..ddce9cf4
--- /dev/null
+++ b/web/berry/src/layout/MainLayout/Sidebar/MenuList/NavItem/index.js
@@ -0,0 +1,115 @@
+import PropTypes from 'prop-types';
+import { forwardRef, useEffect } from 'react';
+import { Link, useLocation } from 'react-router-dom';
+import { useDispatch, useSelector } from 'react-redux';
+
+// material-ui
+import { useTheme } from '@mui/material/styles';
+import { Avatar, Chip, ListItemButton, ListItemIcon, ListItemText, Typography, useMediaQuery } from '@mui/material';
+
+// project imports
+import { MENU_OPEN, SET_MENU } from 'store/actions';
+
+// assets
+import FiberManualRecordIcon from '@mui/icons-material/FiberManualRecord';
+
+// ==============================|| SIDEBAR MENU LIST ITEMS ||============================== //
+
+const NavItem = ({ item, level }) => {
+ const theme = useTheme();
+ const dispatch = useDispatch();
+ const { pathname } = useLocation();
+ const customization = useSelector((state) => state.customization);
+ const matchesSM = useMediaQuery(theme.breakpoints.down('lg'));
+
+ const Icon = item.icon;
+ const itemIcon = item?.icon ? (
+
+ ) : (
+ id === item?.id) > -1 ? 8 : 6,
+ height: customization.isOpen.findIndex((id) => id === item?.id) > -1 ? 8 : 6
+ }}
+ fontSize={level > 0 ? 'inherit' : 'medium'}
+ />
+ );
+
+ let itemTarget = '_self';
+ if (item.target) {
+ itemTarget = '_blank';
+ }
+
+ let listItemProps = {
+ component: forwardRef((props, ref) => )
+ };
+ if (item?.external) {
+ listItemProps = { component: 'a', href: item.url, target: itemTarget };
+ }
+
+ const itemHandler = (id) => {
+ dispatch({ type: MENU_OPEN, id });
+ if (matchesSM) dispatch({ type: SET_MENU, opened: false });
+ };
+
+ // active menu item on page load
+ useEffect(() => {
+ const currentIndex = document.location.pathname
+ .toString()
+ .split('/')
+ .findIndex((id) => id === item.id);
+ if (currentIndex > -1) {
+ dispatch({ type: MENU_OPEN, id: item.id });
+ }
+ // eslint-disable-next-line
+ }, [pathname]);
+
+ return (
+ 1 ? 'transparent !important' : 'inherit',
+ py: level > 1 ? 1 : 1.25,
+ pl: `${level * 24}px`
+ }}
+ selected={customization.isOpen.findIndex((id) => id === item.id) > -1}
+ onClick={() => itemHandler(item.id)}
+ >
+ {itemIcon}
+ id === item.id) > -1 ? 'h5' : 'body1'} color="inherit">
+ {item.title}
+
+ }
+ secondary={
+ item.caption && (
+
+ {item.caption}
+
+ )
+ }
+ />
+ {item.chip && (
+ {item.chip.avatar}}
+ />
+ )}
+
+ );
+};
+
+NavItem.propTypes = {
+ item: PropTypes.object,
+ level: PropTypes.number
+};
+
+export default NavItem;
diff --git a/web/berry/src/layout/MainLayout/Sidebar/MenuList/index.js b/web/berry/src/layout/MainLayout/Sidebar/MenuList/index.js
new file mode 100644
index 00000000..4872057a
--- /dev/null
+++ b/web/berry/src/layout/MainLayout/Sidebar/MenuList/index.js
@@ -0,0 +1,36 @@
+// material-ui
+import { Typography } from '@mui/material';
+
+// project imports
+import NavGroup from './NavGroup';
+import menuItem from 'menu-items';
+import { isAdmin } from 'utils/common';
+
+// ==============================|| SIDEBAR MENU LIST ||============================== //
+const MenuList = () => {
+ const userIsAdmin = isAdmin();
+
+ return (
+ <>
+ {menuItem.items.map((item) => {
+ if (item.type !== 'group') {
+ return (
+
+ Menu Items Error
+
+ );
+ }
+
+ const filteredChildren = item.children.filter((child) => !child.isAdmin || userIsAdmin);
+
+ if (filteredChildren.length === 0) {
+ return null;
+ }
+
+ return ;
+ })}
+ >
+ );
+};
+
+export default MenuList;
diff --git a/web/berry/src/layout/MainLayout/Sidebar/index.js b/web/berry/src/layout/MainLayout/Sidebar/index.js
new file mode 100644
index 00000000..e3c6d12d
--- /dev/null
+++ b/web/berry/src/layout/MainLayout/Sidebar/index.js
@@ -0,0 +1,94 @@
+import PropTypes from 'prop-types';
+
+// material-ui
+import { useTheme } from '@mui/material/styles';
+import { Box, Chip, Drawer, Stack, useMediaQuery } from '@mui/material';
+
+// third-party
+import PerfectScrollbar from 'react-perfect-scrollbar';
+import { BrowserView, MobileView } from 'react-device-detect';
+
+// project imports
+import MenuList from './MenuList';
+import LogoSection from '../LogoSection';
+import MenuCard from './MenuCard';
+import { drawerWidth } from 'store/constant';
+
+// ==============================|| SIDEBAR DRAWER ||============================== //
+
+const Sidebar = ({ drawerOpen, drawerToggle, window }) => {
+ const theme = useTheme();
+ const matchUpMd = useMediaQuery(theme.breakpoints.up('md'));
+
+ const drawer = (
+ <>
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ >
+ );
+
+ const container = window !== undefined ? () => window.document.body : undefined;
+
+ return (
+
+
+ {drawer}
+
+
+ );
+};
+
+Sidebar.propTypes = {
+ drawerOpen: PropTypes.bool,
+ drawerToggle: PropTypes.func,
+ window: PropTypes.object
+};
+
+export default Sidebar;
diff --git a/web/berry/src/layout/MainLayout/index.js b/web/berry/src/layout/MainLayout/index.js
new file mode 100644
index 00000000..973a167b
--- /dev/null
+++ b/web/berry/src/layout/MainLayout/index.js
@@ -0,0 +1,103 @@
+import { useDispatch, useSelector } from 'react-redux';
+import { Outlet } from 'react-router-dom';
+import AuthGuard from 'utils/route-guard/AuthGuard';
+
+// material-ui
+import { styled, useTheme } from '@mui/material/styles';
+import { AppBar, Box, CssBaseline, Toolbar, useMediaQuery } from '@mui/material';
+import AdminContainer from 'ui-component/AdminContainer';
+
+// project imports
+import Breadcrumbs from 'ui-component/extended/Breadcrumbs';
+import Header from './Header';
+import Sidebar from './Sidebar';
+import navigation from 'menu-items';
+import { drawerWidth } from 'store/constant';
+import { SET_MENU } from 'store/actions';
+
+// assets
+import { IconChevronRight } from '@tabler/icons-react';
+
+// styles
+const Main = styled('main', { shouldForwardProp: (prop) => prop !== 'open' })(({ theme, open }) => ({
+ ...theme.typography.mainContent,
+ borderBottomLeftRadius: 0,
+ borderBottomRightRadius: 0,
+ transition: theme.transitions.create(
+ 'margin',
+ open
+ ? {
+ easing: theme.transitions.easing.easeOut,
+ duration: theme.transitions.duration.enteringScreen
+ }
+ : {
+ easing: theme.transitions.easing.sharp,
+ duration: theme.transitions.duration.leavingScreen
+ }
+ ),
+ [theme.breakpoints.up('md')]: {
+ marginLeft: open ? 0 : -(drawerWidth - 20),
+ width: `calc(100% - ${drawerWidth}px)`
+ },
+ [theme.breakpoints.down('md')]: {
+ marginLeft: '20px',
+ width: `calc(100% - ${drawerWidth}px)`,
+ padding: '16px'
+ },
+ [theme.breakpoints.down('sm')]: {
+ marginLeft: '10px',
+ width: `calc(100% - ${drawerWidth}px)`,
+ padding: '16px',
+ marginRight: '10px'
+ }
+}));
+
+// ==============================|| MAIN LAYOUT ||============================== //
+
+const MainLayout = () => {
+ const theme = useTheme();
+ const matchDownMd = useMediaQuery(theme.breakpoints.down('md'));
+ // Handle left drawer
+ const leftDrawerOpened = useSelector((state) => state.customization.opened);
+ const dispatch = useDispatch();
+ const handleLeftDrawerToggle = () => {
+ dispatch({ type: SET_MENU, opened: !leftDrawerOpened });
+ };
+
+ return (
+
+
+ {/* header */}
+
+
+
+
+
+
+ {/* drawer */}
+
+
+ {/* main content */}
+
+ {/* breadcrumb */}
+
+
+
+
+
+
+
+
+ );
+};
+
+export default MainLayout;
diff --git a/web/berry/src/layout/MinimalLayout/Header/index.js b/web/berry/src/layout/MinimalLayout/Header/index.js
new file mode 100644
index 00000000..4f61da60
--- /dev/null
+++ b/web/berry/src/layout/MinimalLayout/Header/index.js
@@ -0,0 +1,75 @@
+// material-ui
+import { useTheme } from "@mui/material/styles";
+import { Box, Button, Stack } from "@mui/material";
+import LogoSection from "layout/MainLayout/LogoSection";
+import { Link } from "react-router-dom";
+import { useLocation } from "react-router-dom";
+import { useSelector } from "react-redux";
+
+// ==============================|| MAIN NAVBAR / HEADER ||============================== //
+
+const Header = () => {
+ const theme = useTheme();
+ const { pathname } = useLocation();
+ const account = useSelector((state) => state.account);
+
+ return (
+ <>
+
+
+
+
+
+
+
+
+
+
+
+ {account.user ? (
+
+ ) : (
+
+ )}
+
+ >
+ );
+};
+
+export default Header;
diff --git a/web/berry/src/layout/MinimalLayout/index.js b/web/berry/src/layout/MinimalLayout/index.js
new file mode 100644
index 00000000..c2919c6d
--- /dev/null
+++ b/web/berry/src/layout/MinimalLayout/index.js
@@ -0,0 +1,39 @@
+import { Outlet } from 'react-router-dom';
+import { useTheme } from '@mui/material/styles';
+import { AppBar, Box, CssBaseline, Toolbar } from '@mui/material';
+import Header from './Header';
+import Footer from 'ui-component/Footer';
+
+// ==============================|| MINIMAL LAYOUT ||============================== //
+
+const MinimalLayout = () => {
+ const theme = useTheme();
+
+ return (
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ );
+};
+
+export default MinimalLayout;
diff --git a/web/berry/src/layout/NavMotion.js b/web/berry/src/layout/NavMotion.js
new file mode 100644
index 00000000..d82f7e4f
--- /dev/null
+++ b/web/berry/src/layout/NavMotion.js
@@ -0,0 +1,39 @@
+import PropTypes from 'prop-types';
+import { motion } from 'framer-motion';
+
+// ==============================|| ANIMATION FOR CONTENT ||============================== //
+
+const NavMotion = ({ children }) => {
+ const motionVariants = {
+ initial: {
+ opacity: 0,
+ scale: 0.99
+ },
+ in: {
+ opacity: 1,
+ scale: 1
+ },
+ out: {
+ opacity: 0,
+ scale: 1.01
+ }
+ };
+
+ const motionTransition = {
+ type: 'tween',
+ ease: 'anticipate',
+ duration: 0.4
+ };
+
+ return (
+
+ {children}
+
+ );
+};
+
+NavMotion.propTypes = {
+ children: PropTypes.node
+};
+
+export default NavMotion;
diff --git a/web/berry/src/layout/NavigationScroll.js b/web/berry/src/layout/NavigationScroll.js
new file mode 100644
index 00000000..89b22e65
--- /dev/null
+++ b/web/berry/src/layout/NavigationScroll.js
@@ -0,0 +1,26 @@
+import PropTypes from 'prop-types';
+import { useEffect } from 'react';
+import { useLocation } from 'react-router-dom';
+
+// ==============================|| NAVIGATION SCROLL TO TOP ||============================== //
+
+const NavigationScroll = ({ children }) => {
+ const location = useLocation();
+ const { pathname } = location;
+
+ useEffect(() => {
+ window.scrollTo({
+ top: 0,
+ left: 0,
+ behavior: 'smooth'
+ });
+ }, [pathname]);
+
+ return children || null;
+};
+
+NavigationScroll.propTypes = {
+ children: PropTypes.node
+};
+
+export default NavigationScroll;
diff --git a/web/berry/src/menu-items/index.js b/web/berry/src/menu-items/index.js
new file mode 100644
index 00000000..e732f8af
--- /dev/null
+++ b/web/berry/src/menu-items/index.js
@@ -0,0 +1,18 @@
+import panel from './panel';
+
+// ==============================|| MENU ITEMS ||============================== //
+
+const menuItems = {
+ items: [panel],
+ urlMap: {}
+};
+
+// Initialize urlMap
+menuItems.urlMap = menuItems.items.reduce((map, item) => {
+ item.children.forEach((child) => {
+ map[child.url] = child;
+ });
+ return map;
+}, {});
+
+export default menuItems;
diff --git a/web/berry/src/menu-items/panel.js b/web/berry/src/menu-items/panel.js
new file mode 100644
index 00000000..556b157f
--- /dev/null
+++ b/web/berry/src/menu-items/panel.js
@@ -0,0 +1,104 @@
+// assets
+import {
+ IconDashboard,
+ IconSitemap,
+ IconArticle,
+ IconCoin,
+ IconAdjustments,
+ IconKey,
+ IconGardenCart,
+ IconUser,
+ IconUserScan
+} from '@tabler/icons-react';
+
+// constant
+const icons = { IconDashboard, IconSitemap, IconArticle, IconCoin, IconAdjustments, IconKey, IconGardenCart, IconUser, IconUserScan };
+
+// ==============================|| DASHBOARD MENU ITEMS ||============================== //
+
+const panel = {
+ id: 'panel',
+ type: 'group',
+ children: [
+ {
+ id: 'dashboard',
+ title: '总览',
+ type: 'item',
+ url: '/panel/dashboard',
+ icon: icons.IconDashboard,
+ breadcrumbs: false,
+ isAdmin: false
+ },
+ {
+ id: 'channel',
+ title: '渠道',
+ type: 'item',
+ url: '/panel/channel',
+ icon: icons.IconSitemap,
+ breadcrumbs: false,
+ isAdmin: true
+ },
+ {
+ id: 'token',
+ title: '令牌',
+ type: 'item',
+ url: '/panel/token',
+ icon: icons.IconKey,
+ breadcrumbs: false
+ },
+ {
+ id: 'log',
+ title: '日志',
+ type: 'item',
+ url: '/panel/log',
+ icon: icons.IconArticle,
+ breadcrumbs: false
+ },
+ {
+ id: 'redemption',
+ title: '兑换',
+ type: 'item',
+ url: '/panel/redemption',
+ icon: icons.IconCoin,
+ breadcrumbs: false,
+ isAdmin: true
+ },
+ {
+ id: 'topup',
+ title: '充值',
+ type: 'item',
+ url: '/panel/topup',
+ icon: icons.IconGardenCart,
+ breadcrumbs: false
+ },
+ {
+ id: 'user',
+ title: '用户',
+ type: 'item',
+ url: '/panel/user',
+ icon: icons.IconUser,
+ breadcrumbs: false,
+ isAdmin: true
+ },
+ {
+ id: 'profile',
+ title: '我的',
+ type: 'item',
+ url: '/panel/profile',
+ icon: icons.IconUserScan,
+ breadcrumbs: false,
+ isAdmin: false
+ },
+ {
+ id: 'setting',
+ title: '设置',
+ type: 'item',
+ url: '/panel/setting',
+ icon: icons.IconAdjustments,
+ breadcrumbs: false,
+ isAdmin: true
+ }
+ ]
+};
+
+export default panel;
diff --git a/web/berry/src/routes/MainRoutes.js b/web/berry/src/routes/MainRoutes.js
new file mode 100644
index 00000000..74f7e4c2
--- /dev/null
+++ b/web/berry/src/routes/MainRoutes.js
@@ -0,0 +1,73 @@
+import { lazy } from 'react';
+
+// project imports
+import MainLayout from 'layout/MainLayout';
+import Loadable from 'ui-component/Loadable';
+
+const Channel = Loadable(lazy(() => import('views/Channel')));
+const Log = Loadable(lazy(() => import('views/Log')));
+const Redemption = Loadable(lazy(() => import('views/Redemption')));
+const Setting = Loadable(lazy(() => import('views/Setting')));
+const Token = Loadable(lazy(() => import('views/Token')));
+const Topup = Loadable(lazy(() => import('views/Topup')));
+const User = Loadable(lazy(() => import('views/User')));
+const Profile = Loadable(lazy(() => import('views/Profile')));
+const NotFoundView = Loadable(lazy(() => import('views/Error')));
+
+// dashboard routing
+const Dashboard = Loadable(lazy(() => import('views/Dashboard')));
+
+// ==============================|| MAIN ROUTING ||============================== //
+
+const MainRoutes = {
+ path: '/panel',
+ element: ,
+ children: [
+ {
+ path: '',
+ element:
+ },
+ {
+ path: 'dashboard',
+ element:
+ },
+ {
+ path: 'channel',
+ element:
+ },
+ {
+ path: 'log',
+ element:
+ },
+ {
+ path: 'redemption',
+ element:
+ },
+ {
+ path: 'setting',
+ element:
+ },
+ {
+ path: 'token',
+ element:
+ },
+ {
+ path: 'topup',
+ element:
+ },
+ {
+ path: 'user',
+ element:
+ },
+ {
+ path: 'profile',
+ element:
+ },
+ {
+ path: '404',
+ element:
+ }
+ ]
+};
+
+export default MainRoutes;
diff --git a/web/berry/src/routes/OtherRoutes.js b/web/berry/src/routes/OtherRoutes.js
new file mode 100644
index 00000000..085c4add
--- /dev/null
+++ b/web/berry/src/routes/OtherRoutes.js
@@ -0,0 +1,58 @@
+import { lazy } from 'react';
+
+// project imports
+import Loadable from 'ui-component/Loadable';
+import MinimalLayout from 'layout/MinimalLayout';
+
+// login option 3 routing
+const AuthLogin = Loadable(lazy(() => import('views/Authentication/Auth/Login')));
+const AuthRegister = Loadable(lazy(() => import('views/Authentication/Auth/Register')));
+const GitHubOAuth = Loadable(lazy(() => import('views/Authentication/Auth/GitHubOAuth')));
+const ForgetPassword = Loadable(lazy(() => import('views/Authentication/Auth/ForgetPassword')));
+const ResetPassword = Loadable(lazy(() => import('views/Authentication/Auth/ResetPassword')));
+const Home = Loadable(lazy(() => import('views/Home')));
+const About = Loadable(lazy(() => import('views/About')));
+const NotFoundView = Loadable(lazy(() => import('views/Error')));
+
+// ==============================|| AUTHENTICATION ROUTING ||============================== //
+
+const OtherRoutes = {
+ path: '/',
+ element: ,
+ children: [
+ {
+ path: '',
+ element:
+ },
+ {
+ path: '/about',
+ element:
+ },
+ {
+ path: '/login',
+ element:
+ },
+ {
+ path: '/register',
+ element:
+ },
+ {
+ path: '/reset',
+ element:
+ },
+ {
+ path: '/user/reset',
+ element:
+ },
+ {
+ path: '/oauth/github',
+ element:
+ },
+ {
+ path: '/404',
+ element:
+ }
+ ]
+};
+
+export default OtherRoutes;
diff --git a/web/berry/src/routes/index.js b/web/berry/src/routes/index.js
new file mode 100644
index 00000000..e77c610a
--- /dev/null
+++ b/web/berry/src/routes/index.js
@@ -0,0 +1,11 @@
+import { useRoutes } from 'react-router-dom';
+
+// routes
+import MainRoutes from './MainRoutes';
+import OtherRoutes from './OtherRoutes';
+
+// ==============================|| ROUTING RENDER ||============================== //
+
+export default function ThemeRoutes() {
+ return useRoutes([MainRoutes, OtherRoutes]);
+}
diff --git a/web/berry/src/serviceWorker.js b/web/berry/src/serviceWorker.js
new file mode 100644
index 00000000..02320234
--- /dev/null
+++ b/web/berry/src/serviceWorker.js
@@ -0,0 +1,128 @@
+// This optional code is used to register a service worker.
+// register() is not called by default.
+
+// This lets the app load faster on subsequent visits in production, and gives
+// it offline capabilities. However, it also means that developers (and users)
+// will only see deployed updates on subsequent visits to a page, after all the
+// existing tabs open on the page have been closed, since previously cached
+// resources are updated in the background.
+
+// To learn more about the benefits of this model and instructions on how to
+// opt-in, read https://bit.ly/CRA-PWA
+
+const isLocalhost = Boolean(
+ window.location.hostname === 'localhost' ||
+ // [::1] is the IPv6 localhost address.
+ window.location.hostname === '[::1]' ||
+ // 127.0.0.0/8 are considered localhost for IPv4.
+ window.location.hostname.match(/^127(?:\.(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)){3}$/)
+);
+
+function registerValidSW(swUrl, config) {
+ navigator.serviceWorker
+ .register(swUrl)
+ .then((registration) => {
+ registration.onupdatefound = () => {
+ const installingWorker = registration.installing;
+ if (installingWorker == null) {
+ return;
+ }
+ installingWorker.onstatechange = () => {
+ if (installingWorker.state === 'installed') {
+ if (navigator.serviceWorker.controller) {
+ // At this point, the updated precached content has been fetched,
+ // but the previous service worker will still serve the older
+ // content until all client tabs are closed.
+ console.log('New content is available and will be used when all tabs for this page are closed. See https://bit.ly/CRA-PWA.');
+
+ // Execute callback
+ if (config && config.onUpdate) {
+ config.onUpdate(registration);
+ }
+ } else {
+ // At this point, everything has been precached.
+ // It's the perfect time to display a
+ // "Content is cached for offline use." message.
+ console.log('Content is cached for offline use.');
+
+ // Execute callback
+ if (config && config.onSuccess) {
+ config.onSuccess(registration);
+ }
+ }
+ }
+ };
+ };
+ })
+ .catch((error) => {
+ console.error('Error during service worker registration:', error);
+ });
+}
+
+function checkValidServiceWorker(swUrl, config) {
+ // Check if the service worker can be found. If it can't reload the page.
+ fetch(swUrl, {
+ headers: { 'Service-Worker': 'script' }
+ })
+ .then((response) => {
+ // Ensure service worker exists, and that we really are getting a JS file.
+ const contentType = response.headers.get('content-type');
+ if (response.status === 404 || (contentType != null && contentType.indexOf('javascript') === -1)) {
+ // No service worker found. Probably a different app. Reload the page.
+ navigator.serviceWorker.ready.then((registration) => {
+ registration.unregister().then(() => {
+ window.location.reload();
+ });
+ });
+ } else {
+ // Service worker found. Proceed as normal.
+ registerValidSW(swUrl, config);
+ }
+ })
+ .catch(() => {
+ console.log('No internet connection found. App is running in offline mode.');
+ });
+}
+
+export function register(config) {
+ if (process.env.NODE_ENV === 'production' && 'serviceWorker' in navigator) {
+ // The URL constructor is available in all browsers that support SW.
+ const publicUrl = new URL(process.env.PUBLIC_URL, window.location.href);
+ if (publicUrl.origin !== window.location.origin) {
+ // Our service worker won't work if PUBLIC_URL is on a different origin
+ // from what our page is served on. This might happen if a CDN is used to
+ // serve assets; see https://github.com/facebook/create-react-app/issues/2374
+ return;
+ }
+
+ window.addEventListener('load', () => {
+ const swUrl = `${process.env.PUBLIC_URL}/service-worker.js`;
+
+ if (isLocalhost) {
+ // This is running on localhost. Let's check if a service worker still exists or not.
+ checkValidServiceWorker(swUrl, config);
+
+ // Add some additional logging to localhost, pointing developers to the
+ // service worker/PWA documentation.
+ navigator.serviceWorker.ready.then(() => {
+ console.log('This web app is being served cache-first by a service worker. To learn more, visit https://bit.ly/CRA-PWA');
+ });
+ } else {
+ // Is not localhost. Just register service worker
+ registerValidSW(swUrl, config);
+ }
+ });
+ }
+}
+
+export function unregister() {
+ if ('serviceWorker' in navigator) {
+ navigator.serviceWorker.ready
+ .then((registration) => {
+ registration.unregister();
+ })
+ .catch((error) => {
+ console.error(error.message);
+ });
+ }
+}
diff --git a/web/berry/src/store/accountReducer.js b/web/berry/src/store/accountReducer.js
new file mode 100644
index 00000000..5414bb97
--- /dev/null
+++ b/web/berry/src/store/accountReducer.js
@@ -0,0 +1,24 @@
+import * as actionTypes from './actions';
+
+export const initialState = {
+ user: undefined
+};
+
+const accountReducer = (state = initialState, action) => {
+ switch (action.type) {
+ case actionTypes.LOGIN:
+ return {
+ ...state,
+ user: action.payload
+ };
+ case actionTypes.LOGOUT:
+ return {
+ ...state,
+ user: undefined
+ };
+ default:
+ return state;
+ }
+};
+
+export default accountReducer;
diff --git a/web/berry/src/store/actions.js b/web/berry/src/store/actions.js
new file mode 100644
index 00000000..221e8578
--- /dev/null
+++ b/web/berry/src/store/actions.js
@@ -0,0 +1,9 @@
+// action - customization reducer
+export const SET_MENU = '@customization/SET_MENU';
+export const MENU_TOGGLE = '@customization/MENU_TOGGLE';
+export const MENU_OPEN = '@customization/MENU_OPEN';
+export const SET_FONT_FAMILY = '@customization/SET_FONT_FAMILY';
+export const SET_BORDER_RADIUS = '@customization/SET_BORDER_RADIUS';
+export const SET_SITE_INFO = '@siteInfo/SET_SITE_INFO';
+export const LOGIN = '@account/LOGIN';
+export const LOGOUT = '@account/LOGOUT';
diff --git a/web/berry/src/store/constant.js b/web/berry/src/store/constant.js
new file mode 100644
index 00000000..75da000c
--- /dev/null
+++ b/web/berry/src/store/constant.js
@@ -0,0 +1,4 @@
+// theme constant
+export const gridSpacing = 3;
+export const drawerWidth = 260;
+export const appDrawerWidth = 320;
diff --git a/web/berry/src/store/customizationReducer.js b/web/berry/src/store/customizationReducer.js
new file mode 100644
index 00000000..bd8e5f00
--- /dev/null
+++ b/web/berry/src/store/customizationReducer.js
@@ -0,0 +1,46 @@
+// project imports
+import config from 'config';
+
+// action - state management
+import * as actionTypes from './actions';
+
+export const initialState = {
+ isOpen: [], // for active default menu
+ defaultId: 'default',
+ fontFamily: config.fontFamily,
+ borderRadius: config.borderRadius,
+ opened: true
+};
+
+// ==============================|| CUSTOMIZATION REDUCER ||============================== //
+
+const customizationReducer = (state = initialState, action) => {
+ let id;
+ switch (action.type) {
+ case actionTypes.MENU_OPEN:
+ id = action.id;
+ return {
+ ...state,
+ isOpen: [id]
+ };
+ case actionTypes.SET_MENU:
+ return {
+ ...state,
+ opened: action.opened
+ };
+ case actionTypes.SET_FONT_FAMILY:
+ return {
+ ...state,
+ fontFamily: action.fontFamily
+ };
+ case actionTypes.SET_BORDER_RADIUS:
+ return {
+ ...state,
+ borderRadius: action.borderRadius
+ };
+ default:
+ return state;
+ }
+};
+
+export default customizationReducer;
diff --git a/web/berry/src/store/index.js b/web/berry/src/store/index.js
new file mode 100644
index 00000000..b9ec2a68
--- /dev/null
+++ b/web/berry/src/store/index.js
@@ -0,0 +1,9 @@
+import { createStore } from 'redux';
+import reducer from './reducer';
+
+// ==============================|| REDUX - MAIN STORE ||============================== //
+
+const store = createStore(reducer);
+const persister = 'Free';
+
+export { store, persister };
diff --git a/web/berry/src/store/reducer.js b/web/berry/src/store/reducer.js
new file mode 100644
index 00000000..220f585f
--- /dev/null
+++ b/web/berry/src/store/reducer.js
@@ -0,0 +1,16 @@
+import { combineReducers } from 'redux';
+
+// reducer import
+import customizationReducer from './customizationReducer';
+import accountReducer from './accountReducer';
+import siteInfoReducer from './siteInfoReducer';
+
+// ==============================|| COMBINE REDUCER ||============================== //
+
+const reducer = combineReducers({
+ customization: customizationReducer,
+ account: accountReducer,
+ siteInfo: siteInfoReducer
+});
+
+export default reducer;
diff --git a/web/berry/src/store/siteInfoReducer.js b/web/berry/src/store/siteInfoReducer.js
new file mode 100644
index 00000000..e14bc245
--- /dev/null
+++ b/web/berry/src/store/siteInfoReducer.js
@@ -0,0 +1,18 @@
+import config from 'config';
+import * as actionTypes from './actions';
+
+export const initialState = config.siteInfo;
+
+const siteInfoReducer = (state = initialState, action) => {
+ switch (action.type) {
+ case actionTypes.SET_SITE_INFO:
+ return {
+ ...state,
+ ...action.payload
+ };
+ default:
+ return state;
+ }
+};
+
+export default siteInfoReducer;
diff --git a/web/berry/src/themes/compStyleOverride.js b/web/berry/src/themes/compStyleOverride.js
new file mode 100644
index 00000000..b6e87e01
--- /dev/null
+++ b/web/berry/src/themes/compStyleOverride.js
@@ -0,0 +1,256 @@
+export default function componentStyleOverrides(theme) {
+ const bgColor = theme.colors?.grey50;
+ return {
+ MuiButton: {
+ styleOverrides: {
+ root: {
+ fontWeight: 500,
+ borderRadius: '4px',
+ '&.Mui-disabled': {
+ color: theme.colors?.grey600
+ }
+ }
+ }
+ },
+ MuiMenuItem: {
+ styleOverrides: {
+ root: {
+ '&:hover': {
+ backgroundColor: theme.colors?.grey100
+ }
+ }
+ }
+ }, //MuiAutocomplete-popper MuiPopover-root
+ MuiAutocomplete: {
+ styleOverrides: {
+ popper: {
+ // 继承 MuiPopover-root
+ boxShadow: '0px 5px 5px -3px rgba(0,0,0,0.2),0px 8px 10px 1px rgba(0,0,0,0.14),0px 3px 14px 2px rgba(0,0,0,0.12)',
+ borderRadius: '12px',
+ color: '#364152'
+ },
+ listbox: {
+ // 继承 MuiPopover-root
+ padding: '0px',
+ paddingTop: '8px',
+ paddingBottom: '8px'
+ },
+ option: {
+ fontSize: '16px',
+ fontWeight: '400',
+ lineHeight: '1.334em',
+ alignItems: 'center',
+ paddingTop: '6px',
+ paddingBottom: '6px',
+ paddingLeft: '16px',
+ paddingRight: '16px'
+ }
+ }
+ },
+ MuiIconButton: {
+ styleOverrides: {
+ root: {
+ color: theme.darkTextPrimary,
+ '&:hover': {
+ backgroundColor: theme.colors?.grey200
+ }
+ }
+ }
+ },
+ MuiPaper: {
+ defaultProps: {
+ elevation: 0
+ },
+ styleOverrides: {
+ root: {
+ backgroundImage: 'none'
+ },
+ rounded: {
+ borderRadius: `${theme?.customization?.borderRadius}px`
+ }
+ }
+ },
+ MuiCardHeader: {
+ styleOverrides: {
+ root: {
+ color: theme.colors?.textDark,
+ padding: '24px'
+ },
+ title: {
+ fontSize: '1.125rem'
+ }
+ }
+ },
+ MuiCardContent: {
+ styleOverrides: {
+ root: {
+ padding: '24px'
+ }
+ }
+ },
+ MuiCardActions: {
+ styleOverrides: {
+ root: {
+ padding: '24px'
+ }
+ }
+ },
+ MuiListItemButton: {
+ styleOverrides: {
+ root: {
+ color: theme.darkTextPrimary,
+ paddingTop: '10px',
+ paddingBottom: '10px',
+ '&.Mui-selected': {
+ color: theme.menuSelected,
+ backgroundColor: theme.menuSelectedBack,
+ '&:hover': {
+ backgroundColor: theme.menuSelectedBack
+ },
+ '& .MuiListItemIcon-root': {
+ color: theme.menuSelected
+ }
+ },
+ '&:hover': {
+ backgroundColor: theme.menuSelectedBack,
+ color: theme.menuSelected,
+ '& .MuiListItemIcon-root': {
+ color: theme.menuSelected
+ }
+ }
+ }
+ }
+ },
+ MuiListItemIcon: {
+ styleOverrides: {
+ root: {
+ color: theme.darkTextPrimary,
+ minWidth: '36px'
+ }
+ }
+ },
+ MuiListItemText: {
+ styleOverrides: {
+ primary: {
+ color: theme.textDark
+ }
+ }
+ },
+ MuiInputBase: {
+ styleOverrides: {
+ input: {
+ color: theme.textDark,
+ '&::placeholder': {
+ color: theme.darkTextSecondary,
+ fontSize: '0.875rem'
+ }
+ }
+ }
+ },
+ MuiOutlinedInput: {
+ styleOverrides: {
+ root: {
+ background: bgColor,
+ borderRadius: `${theme?.customization?.borderRadius}px`,
+ '& .MuiOutlinedInput-notchedOutline': {
+ borderColor: theme.colors?.grey400
+ },
+ '&:hover $notchedOutline': {
+ borderColor: theme.colors?.primaryLight
+ },
+ '&.MuiInputBase-multiline': {
+ padding: 1
+ }
+ },
+ input: {
+ fontWeight: 500,
+ background: bgColor,
+ padding: '15.5px 14px',
+ borderRadius: `${theme?.customization?.borderRadius}px`,
+ '&.MuiInputBase-inputSizeSmall': {
+ padding: '10px 14px',
+ '&.MuiInputBase-inputAdornedStart': {
+ paddingLeft: 0
+ }
+ }
+ },
+ inputAdornedStart: {
+ paddingLeft: 4
+ },
+ notchedOutline: {
+ borderRadius: `${theme?.customization?.borderRadius}px`
+ }
+ }
+ },
+ MuiSlider: {
+ styleOverrides: {
+ root: {
+ '&.Mui-disabled': {
+ color: theme.colors?.grey300
+ }
+ },
+ mark: {
+ backgroundColor: theme.paper,
+ width: '4px'
+ },
+ valueLabel: {
+ color: theme?.colors?.primaryLight
+ }
+ }
+ },
+ MuiDivider: {
+ styleOverrides: {
+ root: {
+ borderColor: theme.divider,
+ opacity: 1
+ }
+ }
+ },
+ MuiAvatar: {
+ styleOverrides: {
+ root: {
+ color: theme.colors?.primaryDark,
+ background: theme.colors?.primary200
+ }
+ }
+ },
+ MuiChip: {
+ styleOverrides: {
+ root: {
+ '&.MuiChip-deletable .MuiChip-deleteIcon': {
+ color: 'inherit'
+ }
+ }
+ }
+ },
+ MuiTableCell: {
+ styleOverrides: {
+ root: {
+ borderBottom: '1px solid rgb(241, 243, 244)',
+ textAlign: 'center'
+ },
+ head: {
+ color: theme.darkTextSecondary,
+ backgroundColor: 'rgb(244, 246, 248)'
+ }
+ }
+ },
+ MuiTableRow: {
+ styleOverrides: {
+ root: {
+ '&:hover': {
+ backgroundColor: 'rgb(244, 246, 248)'
+ }
+ }
+ }
+ },
+ MuiTooltip: {
+ styleOverrides: {
+ tooltip: {
+ color: theme.paper,
+ background: theme.colors?.grey700
+ }
+ }
+ }
+ };
+}
diff --git a/web/berry/src/themes/index.js b/web/berry/src/themes/index.js
new file mode 100644
index 00000000..6e694aa6
--- /dev/null
+++ b/web/berry/src/themes/index.js
@@ -0,0 +1,55 @@
+import { createTheme } from '@mui/material/styles';
+
+// assets
+import colors from 'assets/scss/_themes-vars.module.scss';
+
+// project imports
+import componentStyleOverrides from './compStyleOverride';
+import themePalette from './palette';
+import themeTypography from './typography';
+
+/**
+ * Represent theme style and structure as per Material-UI
+ * @param {JsonObject} customization customization parameter object
+ */
+
+export const theme = (customization) => {
+ const color = colors;
+
+ const themeOption = {
+ colors: color,
+ heading: color.grey900,
+ paper: color.paper,
+ backgroundDefault: color.paper,
+ background: color.primaryLight,
+ darkTextPrimary: color.grey700,
+ darkTextSecondary: color.grey500,
+ textDark: color.grey900,
+ menuSelected: color.secondaryDark,
+ menuSelectedBack: color.secondaryLight,
+ divider: color.grey200,
+ customization
+ };
+
+ const themeOptions = {
+ direction: 'ltr',
+ palette: themePalette(themeOption),
+ mixins: {
+ toolbar: {
+ minHeight: '48px',
+ padding: '16px',
+ '@media (min-width: 600px)': {
+ minHeight: '48px'
+ }
+ }
+ },
+ typography: themeTypography(themeOption)
+ };
+
+ const themes = createTheme(themeOptions);
+ themes.components = componentStyleOverrides(themeOption);
+
+ return themes;
+};
+
+export default theme;
diff --git a/web/berry/src/themes/palette.js b/web/berry/src/themes/palette.js
new file mode 100644
index 00000000..09768555
--- /dev/null
+++ b/web/berry/src/themes/palette.js
@@ -0,0 +1,73 @@
+/**
+ * Color intention that you want to used in your theme
+ * @param {JsonObject} theme Theme customization object
+ */
+
+export default function themePalette(theme) {
+ return {
+ mode: 'light',
+ common: {
+ black: theme.colors?.darkPaper
+ },
+ primary: {
+ light: theme.colors?.primaryLight,
+ main: theme.colors?.primaryMain,
+ dark: theme.colors?.primaryDark,
+ 200: theme.colors?.primary200,
+ 800: theme.colors?.primary800
+ },
+ secondary: {
+ light: theme.colors?.secondaryLight,
+ main: theme.colors?.secondaryMain,
+ dark: theme.colors?.secondaryDark,
+ 200: theme.colors?.secondary200,
+ 800: theme.colors?.secondary800
+ },
+ error: {
+ light: theme.colors?.errorLight,
+ main: theme.colors?.errorMain,
+ dark: theme.colors?.errorDark
+ },
+ orange: {
+ light: theme.colors?.orangeLight,
+ main: theme.colors?.orangeMain,
+ dark: theme.colors?.orangeDark
+ },
+ warning: {
+ light: theme.colors?.warningLight,
+ main: theme.colors?.warningMain,
+ dark: theme.colors?.warningDark
+ },
+ success: {
+ light: theme.colors?.successLight,
+ 200: theme.colors?.success200,
+ main: theme.colors?.successMain,
+ dark: theme.colors?.successDark
+ },
+ grey: {
+ 50: theme.colors?.grey50,
+ 100: theme.colors?.grey100,
+ 500: theme.darkTextSecondary,
+ 600: theme.heading,
+ 700: theme.darkTextPrimary,
+ 900: theme.textDark
+ },
+ dark: {
+ light: theme.colors?.darkTextPrimary,
+ main: theme.colors?.darkLevel1,
+ dark: theme.colors?.darkLevel2,
+ 800: theme.colors?.darkBackground,
+ 900: theme.colors?.darkPaper
+ },
+ text: {
+ primary: theme.darkTextPrimary,
+ secondary: theme.darkTextSecondary,
+ dark: theme.textDark,
+ hint: theme.colors?.grey100
+ },
+ background: {
+ paper: theme.paper,
+ default: theme.backgroundDefault
+ }
+ };
+}
diff --git a/web/berry/src/themes/typography.js b/web/berry/src/themes/typography.js
new file mode 100644
index 00000000..24bfabb9
--- /dev/null
+++ b/web/berry/src/themes/typography.js
@@ -0,0 +1,137 @@
+/**
+ * Typography used in theme
+ * @param {JsonObject} theme theme customization object
+ */
+
+export default function themeTypography(theme) {
+ return {
+ fontFamily: theme?.customization?.fontFamily,
+ h6: {
+ fontWeight: 500,
+ color: theme.heading,
+ fontSize: '0.75rem'
+ },
+ h5: {
+ fontSize: '0.875rem',
+ color: theme.heading,
+ fontWeight: 500
+ },
+ h4: {
+ fontSize: '1rem',
+ color: theme.heading,
+ fontWeight: 600
+ },
+ h3: {
+ fontSize: '1.25rem',
+ color: theme.heading,
+ fontWeight: 600
+ },
+ h2: {
+ fontSize: '1.5rem',
+ color: theme.heading,
+ fontWeight: 700
+ },
+ h1: {
+ fontSize: '2.125rem',
+ color: theme.heading,
+ fontWeight: 700
+ },
+ subtitle1: {
+ fontSize: '0.875rem',
+ fontWeight: 500,
+ color: theme.textDark
+ },
+ subtitle2: {
+ fontSize: '0.75rem',
+ fontWeight: 400,
+ color: theme.darkTextSecondary
+ },
+ caption: {
+ fontSize: '0.75rem',
+ color: theme.darkTextSecondary,
+ fontWeight: 400
+ },
+ body1: {
+ fontSize: '0.875rem',
+ fontWeight: 400,
+ lineHeight: '1.334em'
+ },
+ body2: {
+ letterSpacing: '0em',
+ fontWeight: 400,
+ lineHeight: '1.5em',
+ color: theme.darkTextPrimary
+ },
+ button: {
+ textTransform: 'capitalize'
+ },
+ customInput: {
+ marginTop: 1,
+ marginBottom: 1,
+ '& > label': {
+ top: 23,
+ left: 0,
+ color: theme.grey500,
+ '&[data-shrink="false"]': {
+ top: 5
+ }
+ },
+ '& > div > input': {
+ padding: '30.5px 14px 11.5px !important'
+ },
+ '& legend': {
+ display: 'none'
+ },
+ '& fieldset': {
+ top: 0
+ }
+ },
+ otherInput: {
+ marginTop: 1,
+ marginBottom: 1
+ },
+ mainContent: {
+ backgroundColor: theme.background,
+ width: '100%',
+ minHeight: 'calc(100vh - 88px)',
+ flexGrow: 1,
+ padding: '20px',
+ marginTop: '88px',
+ marginRight: '20px',
+ borderRadius: `${theme?.customization?.borderRadius}px`
+ },
+ menuCaption: {
+ fontSize: '0.875rem',
+ fontWeight: 500,
+ color: theme.heading,
+ padding: '6px',
+ textTransform: 'capitalize',
+ marginTop: '10px'
+ },
+ subMenuCaption: {
+ fontSize: '0.6875rem',
+ fontWeight: 500,
+ color: theme.darkTextSecondary,
+ textTransform: 'capitalize'
+ },
+ commonAvatar: {
+ cursor: 'pointer',
+ borderRadius: '8px'
+ },
+ smallAvatar: {
+ width: '22px',
+ height: '22px',
+ fontSize: '1rem'
+ },
+ mediumAvatar: {
+ width: '34px',
+ height: '34px',
+ fontSize: '1.2rem'
+ },
+ largeAvatar: {
+ width: '44px',
+ height: '44px',
+ fontSize: '1.5rem'
+ }
+ };
+}
diff --git a/web/berry/src/ui-component/AdminContainer.js b/web/berry/src/ui-component/AdminContainer.js
new file mode 100644
index 00000000..eff42a22
--- /dev/null
+++ b/web/berry/src/ui-component/AdminContainer.js
@@ -0,0 +1,11 @@
+import { styled } from '@mui/material/styles';
+import { Container } from '@mui/material';
+
+const AdminContainer = styled(Container)(({ theme }) => ({
+ [theme.breakpoints.down('md')]: {
+ paddingLeft: '0px',
+ paddingRight: '0px'
+ }
+}));
+
+export default AdminContainer;
diff --git a/web/berry/src/ui-component/Footer.js b/web/berry/src/ui-component/Footer.js
new file mode 100644
index 00000000..38f61993
--- /dev/null
+++ b/web/berry/src/ui-component/Footer.js
@@ -0,0 +1,37 @@
+// material-ui
+import { Link, Container, Box } from '@mui/material';
+import React from 'react';
+import { useSelector } from 'react-redux';
+
+// ==============================|| FOOTER - AUTHENTICATION 2 & 3 ||============================== //
+
+const Footer = () => {
+ const siteInfo = useSelector((state) => state.siteInfo);
+
+ return (
+
+
+ {siteInfo.footer_html ? (
+
+ ) : (
+ <>
+
+ {siteInfo.system_name} {process.env.REACT_APP_VERSION}{' '}
+
+ 由{' '}
+
+ JustSong
+ {' '}
+ 构建,主题 berry 来自{' '}
+
+ MartialBE
+ {' '},源代码遵循
+ MIT 协议
+ >
+ )}
+
+
+ );
+};
+
+export default Footer;
diff --git a/web/berry/src/ui-component/Label.js b/web/berry/src/ui-component/Label.js
new file mode 100644
index 00000000..715c6248
--- /dev/null
+++ b/web/berry/src/ui-component/Label.js
@@ -0,0 +1,158 @@
+/*
+ * Label.js
+ *
+ * This file uses code from the Minimal UI project, available at
+ * https://github.com/minimal-ui-kit/material-kit-react/blob/main/src/components/label/label.jsx
+ *
+ * Minimal UI is licensed under the MIT License. A copy of the license is included below:
+ *
+ * MIT License
+ *
+ * Copyright (c) 2021 Minimal UI (https://minimals.cc/)
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+import PropTypes from 'prop-types';
+import { forwardRef } from 'react';
+
+import Box from '@mui/material/Box';
+import { useTheme } from '@mui/material/styles';
+import { alpha, styled } from '@mui/material/styles';
+
+// ----------------------------------------------------------------------
+
+const Label = forwardRef(({ children, color = 'default', variant = 'soft', startIcon, endIcon, sx, ...other }, ref) => {
+ const theme = useTheme();
+
+ const iconStyles = {
+ width: 16,
+ height: 16,
+ '& svg, img': { width: 1, height: 1, objectFit: 'cover' }
+ };
+
+ return (
+
+ {startIcon && {startIcon} }
+
+ {children}
+
+ {endIcon && {endIcon} }
+
+ );
+});
+
+Label.propTypes = {
+ children: PropTypes.node,
+ endIcon: PropTypes.object,
+ startIcon: PropTypes.object,
+ sx: PropTypes.object,
+ variant: PropTypes.oneOf(['filled', 'outlined', 'ghost', 'soft']),
+ color: PropTypes.oneOf(['default', 'primary', 'secondary', 'info', 'success', 'warning', 'orange', 'error'])
+};
+
+export default Label;
+
+const StyledLabel = styled(Box)(({ theme, ownerState }) => {
+ // const lightMode = theme.palette.mode === 'light';
+
+ const filledVariant = ownerState.variant === 'filled';
+
+ const outlinedVariant = ownerState.variant === 'outlined';
+
+ const softVariant = ownerState.variant === 'soft';
+
+ const ghostVariant = ownerState.variant === 'ghost';
+
+ const defaultStyle = {
+ ...(ownerState.color === 'default' && {
+ // FILLED
+ ...(filledVariant && {
+ color: theme.palette.grey[300],
+ backgroundColor: theme.palette.text.primary
+ }),
+ // OUTLINED
+ ...(outlinedVariant && {
+ color: theme.palette.grey[500],
+ border: `2px solid ${theme.palette.grey[500]}`
+ }),
+ // SOFT
+ ...(softVariant && {
+ color: theme.palette.text.secondary,
+ backgroundColor: alpha(theme.palette.grey[500], 0.16)
+ })
+ })
+ };
+
+ const colorStyle = {
+ ...(ownerState.color !== 'default' && {
+ // FILLED
+ ...(filledVariant && {
+ color: theme.palette.background.paper,
+ backgroundColor: theme.palette[ownerState.color]?.main
+ }),
+ // OUTLINED
+ ...(outlinedVariant && {
+ backgroundColor: 'transparent',
+ color: theme.palette[ownerState.color]?.main,
+ border: `2px solid ${theme.palette[ownerState.color]?.main}`
+ }),
+ // SOFT
+ ...(softVariant && {
+ color: theme.palette[ownerState.color]['dark'],
+ backgroundColor: alpha(theme.palette[ownerState.color]?.main, 0.16)
+ }),
+ // GHOST
+ ...(ghostVariant && {
+ color: theme.palette[ownerState.color]?.main
+ })
+ })
+ };
+
+ return {
+ height: 24,
+ minWidth: 24,
+ lineHeight: 0,
+ borderRadius: 6,
+ cursor: 'default',
+ alignItems: 'center',
+ whiteSpace: 'nowrap',
+ display: 'inline-flex',
+ justifyContent: 'center',
+ // textTransform: 'capitalize',
+ padding: theme.spacing(0, 0.75),
+ fontSize: theme.typography.pxToRem(12),
+ fontWeight: theme.typography.fontWeightBold,
+ transition: theme.transitions.create('all', {
+ duration: theme.transitions.duration.shorter
+ }),
+ ...defaultStyle,
+ ...colorStyle
+ };
+});
diff --git a/web/berry/src/ui-component/Loadable.js b/web/berry/src/ui-component/Loadable.js
new file mode 100644
index 00000000..01de3f90
--- /dev/null
+++ b/web/berry/src/ui-component/Loadable.js
@@ -0,0 +1,15 @@
+import { Suspense } from 'react';
+
+// project imports
+import Loader from './Loader';
+
+// ==============================|| LOADABLE - LAZY LOADING ||============================== //
+
+const Loadable = (Component) => (props) =>
+ (
+ }>
+
+
+ );
+
+export default Loadable;
diff --git a/web/berry/src/ui-component/Loader.js b/web/berry/src/ui-component/Loader.js
new file mode 100644
index 00000000..9072dcdb
--- /dev/null
+++ b/web/berry/src/ui-component/Loader.js
@@ -0,0 +1,21 @@
+// material-ui
+import LinearProgress from '@mui/material/LinearProgress';
+import { styled } from '@mui/material/styles';
+
+// styles
+const LoaderWrapper = styled('div')({
+ position: 'fixed',
+ top: 0,
+ left: 0,
+ zIndex: 1301,
+ width: '100%'
+});
+
+// ==============================|| LOADER ||============================== //
+const Loader = () => (
+
+
+
+);
+
+export default Loader;
diff --git a/web/berry/src/ui-component/Logo.js b/web/berry/src/ui-component/Logo.js
new file mode 100644
index 00000000..a34fe895
--- /dev/null
+++ b/web/berry/src/ui-component/Logo.js
@@ -0,0 +1,21 @@
+// material-ui
+import logo from 'assets/images/logo.svg';
+import { useSelector } from 'react-redux';
+
+/**
+ * if you want to use image instead of