diff --git a/Dockerfile b/Dockerfile index b21a7b3c..94cd8468 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,9 +1,16 @@ FROM node:16 as builder -WORKDIR /build -COPY ./web . +WORKDIR /web COPY ./VERSION . -RUN chmod u+x ./build.sh && ./build.sh +COPY ./web . + +WORKDIR /web/default +RUN npm install +RUN DISABLE_ESLINT_PLUGIN='true' REACT_APP_VERSION=$(cat VERSION) npm run build + +WORKDIR /web/berry +RUN npm install +RUN DISABLE_ESLINT_PLUGIN='true' REACT_APP_VERSION=$(cat VERSION) npm run build FROM golang AS builder2 @@ -15,7 +22,7 @@ WORKDIR /build ADD go.mod go.sum ./ RUN go mod download COPY . . -COPY --from=builder /build/build ./web/build +COPY --from=builder /web/build ./web/build RUN go build -ldflags "-s -w -X 'one-api/common.Version=$(cat VERSION)' -extldflags '-static'" -o one-api FROM alpine @@ -28,4 +35,4 @@ RUN apk update \ COPY --from=builder2 /build/one-api / EXPOSE 3000 WORKDIR /data -ENTRYPOINT ["/one-api"] +ENTRYPOINT ["/one-api"] \ No newline at end of file diff --git a/README.md b/README.md index 27acfedd..02a62387 100644 --- a/README.md +++ b/README.md @@ -414,6 +414,9 @@ https://openai.justsong.cn 8. 升级之前数据库需要做变更吗? + 一般情况下不需要,系统将在初始化的时候自动调整。 + 如果需要的话,我会在更新日志中说明,并给出脚本。 +9. 手动修改数据库后报错:`数据库一致性已被破坏,请联系管理员`? + + 这是检测到 ability 表里有些记录的通道 id 是不存在的,这大概率是因为你删了 channel 表里的记录但是没有同步在 ability 表里清理无效的通道。 + + 对于每一个通道,其所支持的模型都需要有一个专门的 ability 表的记录,表示该通道支持该模型。 ## 相关项目 * [FastGPT](https://github.com/labring/FastGPT): 基于 LLM 大语言模型的知识库问答系统 diff --git a/common/config/config.go b/common/config/config.go new file mode 100644 index 00000000..34d871fe --- /dev/null +++ b/common/config/config.go @@ -0,0 +1,127 @@ +package config + +import ( + "one-api/common/helper" + "os" + "strconv" + "sync" + "time" + + "github.com/google/uuid" +) + +var SystemName = "One API" +var ServerAddress = "http://localhost:3000" +var Footer = "" +var Logo = "" +var TopUpLink = "" +var ChatLink = "" +var QuotaPerUnit = 500 * 1000.0 // $0.002 / 1K tokens +var DisplayInCurrencyEnabled = true +var DisplayTokenStatEnabled = true + +// Any options with "Secret", "Token" in its key won't be return by GetOptions + +var SessionSecret = uuid.New().String() + +var OptionMap map[string]string +var OptionMapRWMutex sync.RWMutex + +var ItemsPerPage = 10 +var MaxRecentItems = 100 + +var PasswordLoginEnabled = true +var PasswordRegisterEnabled = true +var EmailVerificationEnabled = false +var GitHubOAuthEnabled = false +var WeChatAuthEnabled = false +var TurnstileCheckEnabled = false +var RegisterEnabled = true + +var EmailDomainRestrictionEnabled = false +var EmailDomainWhitelist = []string{ + "gmail.com", + "163.com", + "126.com", + "qq.com", + "outlook.com", + "hotmail.com", + "icloud.com", + "yahoo.com", + "foxmail.com", +} + +var DebugEnabled = os.Getenv("DEBUG") == "true" +var MemoryCacheEnabled = os.Getenv("MEMORY_CACHE_ENABLED") == "true" + +var LogConsumeEnabled = true + +var SMTPServer = "" +var SMTPPort = 587 +var SMTPAccount = "" +var SMTPFrom = "" +var SMTPToken = "" + +var GitHubClientId = "" +var GitHubClientSecret = "" + +var WeChatServerAddress = "" +var WeChatServerToken = "" +var WeChatAccountQRCodeImageURL = "" + +var TurnstileSiteKey = "" +var TurnstileSecretKey = "" + +var QuotaForNewUser = 0 +var QuotaForInviter = 0 +var QuotaForInvitee = 0 +var ChannelDisableThreshold = 5.0 +var AutomaticDisableChannelEnabled = false +var AutomaticEnableChannelEnabled = false +var QuotaRemindThreshold = 1000 +var PreConsumedQuota = 500 +var ApproximateTokenEnabled = false +var RetryTimes = 0 + +var RootUserEmail = "" + +var IsMasterNode = os.Getenv("NODE_TYPE") != "slave" + +var requestInterval, _ = strconv.Atoi(os.Getenv("POLLING_INTERVAL")) +var RequestInterval = time.Duration(requestInterval) * time.Second + +var SyncFrequency = helper.GetOrDefaultEnvInt("SYNC_FREQUENCY", 10*60) // unit is second + +var BatchUpdateEnabled = false +var BatchUpdateInterval = helper.GetOrDefaultEnvInt("BATCH_UPDATE_INTERVAL", 5) + +var RelayTimeout = helper.GetOrDefaultEnvInt("RELAY_TIMEOUT", 0) // unit is second + +var GeminiSafetySetting = helper.GetOrDefaultEnvString("GEMINI_SAFETY_SETTING", "BLOCK_NONE") + +var Theme = helper.GetOrDefaultEnvString("THEME", "default") +var ValidThemes = map[string]bool{ + "default": true, + "berry": true, +} + +// All duration's unit is seconds +// Shouldn't larger then RateLimitKeyExpirationDuration +var ( + GlobalApiRateLimitNum = helper.GetOrDefaultEnvInt("GLOBAL_API_RATE_LIMIT", 180) + GlobalApiRateLimitDuration int64 = 3 * 60 + + GlobalWebRateLimitNum = helper.GetOrDefaultEnvInt("GLOBAL_WEB_RATE_LIMIT", 60) + GlobalWebRateLimitDuration int64 = 3 * 60 + + UploadRateLimitNum = 10 + UploadRateLimitDuration int64 = 60 + + DownloadRateLimitNum = 10 + DownloadRateLimitDuration int64 = 60 + + CriticalRateLimitNum = 20 + CriticalRateLimitDuration int64 = 20 * 60 +) + +var RateLimitKeyExpirationDuration = 20 * time.Minute diff --git a/common/constants.go b/common/constants.go index 70589041..325454d4 100644 --- a/common/constants.go +++ b/common/constants.go @@ -1,110 +1,9 @@ package common -import ( - "os" - "strconv" - "sync" - "time" - - "github.com/google/uuid" -) +import "time" var StartTime = time.Now().Unix() // unit: second var Version = "v0.0.0" // this hard coding will be replaced automatically when building, no need to manually change -var SystemName = "One API" -var ServerAddress = "http://localhost:3000" -var Footer = "" -var Logo = "" -var TopUpLink = "" -var ChatLink = "" -var QuotaPerUnit = 500 * 1000.0 // $0.002 / 1K tokens -var DisplayInCurrencyEnabled = true -var DisplayTokenStatEnabled = true - -// Any options with "Secret", "Token" in its key won't be return by GetOptions - -var SessionSecret = uuid.New().String() - -var OptionMap map[string]string -var OptionMapRWMutex sync.RWMutex - -var ItemsPerPage = 10 -var MaxRecentItems = 100 - -var PasswordLoginEnabled = true -var PasswordRegisterEnabled = true -var EmailVerificationEnabled = false -var GitHubOAuthEnabled = false -var WeChatAuthEnabled = false -var TurnstileCheckEnabled = false -var RegisterEnabled = true - -var EmailDomainRestrictionEnabled = false -var EmailDomainWhitelist = []string{ - "gmail.com", - "163.com", - "126.com", - "qq.com", - "outlook.com", - "hotmail.com", - "icloud.com", - "yahoo.com", - "foxmail.com", -} - -var DebugEnabled = os.Getenv("DEBUG") == "true" -var MemoryCacheEnabled = os.Getenv("MEMORY_CACHE_ENABLED") == "true" - -var LogConsumeEnabled = true - -var SMTPServer = "" -var SMTPPort = 587 -var SMTPAccount = "" -var SMTPFrom = "" -var SMTPToken = "" - -var GitHubClientId = "" -var GitHubClientSecret = "" - -var WeChatServerAddress = "" -var WeChatServerToken = "" -var WeChatAccountQRCodeImageURL = "" - -var TurnstileSiteKey = "" -var TurnstileSecretKey = "" - -var QuotaForNewUser = 0 -var QuotaForInviter = 0 -var QuotaForInvitee = 0 -var ChannelDisableThreshold = 5.0 -var AutomaticDisableChannelEnabled = false -var AutomaticEnableChannelEnabled = false -var QuotaRemindThreshold = 1000 -var PreConsumedQuota = 500 -var ApproximateTokenEnabled = false -var RetryTimes = 0 - -var RootUserEmail = "" - -var IsMasterNode = os.Getenv("NODE_TYPE") != "slave" - -var requestInterval, _ = strconv.Atoi(os.Getenv("POLLING_INTERVAL")) -var RequestInterval = time.Duration(requestInterval) * time.Second - -var SyncFrequency = GetOrDefault("SYNC_FREQUENCY", 10*60) // unit is second - -var BatchUpdateEnabled = false -var BatchUpdateInterval = GetOrDefault("BATCH_UPDATE_INTERVAL", 5) - -var RelayTimeout = GetOrDefault("RELAY_TIMEOUT", 0) // unit is second - -var GeminiSafetySetting = GetOrDefaultString("GEMINI_SAFETY_SETTING", "BLOCK_NONE") - -var Theme = GetOrDefaultString("THEME", "default") - -const ( - RequestIdKey = "X-Oneapi-Request-Id" -) const ( RoleGuestUser = 0 @@ -113,34 +12,6 @@ const ( RoleRootUser = 100 ) -var ( - FileUploadPermission = RoleGuestUser - FileDownloadPermission = RoleGuestUser - ImageUploadPermission = RoleGuestUser - ImageDownloadPermission = RoleGuestUser -) - -// All duration's unit is seconds -// Shouldn't larger then RateLimitKeyExpirationDuration -var ( - GlobalApiRateLimitNum = GetOrDefault("GLOBAL_API_RATE_LIMIT", 180) - GlobalApiRateLimitDuration int64 = 3 * 60 - - GlobalWebRateLimitNum = GetOrDefault("GLOBAL_WEB_RATE_LIMIT", 60) - GlobalWebRateLimitDuration int64 = 3 * 60 - - UploadRateLimitNum = 10 - UploadRateLimitDuration int64 = 60 - - DownloadRateLimitNum = 10 - DownloadRateLimitDuration int64 = 60 - - CriticalRateLimitNum = 20 - CriticalRateLimitDuration int64 = 20 * 60 -) - -var RateLimitKeyExpirationDuration = 20 * time.Minute - const ( UserStatusEnabled = 1 // don't use 0, 0 is the default value! UserStatusDisabled = 2 // also don't use 0 @@ -195,29 +66,29 @@ const ( ) var ChannelBaseURLs = []string{ - "", // 0 - "https://api.openai.com", // 1 - "https://oa.api2d.net", // 2 - "", // 3 - "https://api.closeai-proxy.xyz", // 4 - "https://api.openai-sb.com", // 5 - "https://api.openaimax.com", // 6 - "https://api.ohmygpt.com", // 7 - "", // 8 - "https://api.caipacity.com", // 9 - "https://api.aiproxy.io", // 10 - "", // 11 - "https://api.api2gpt.com", // 12 - "https://api.aigc2d.com", // 13 - "https://api.anthropic.com", // 14 - "https://aip.baidubce.com", // 15 - "https://open.bigmodel.cn", // 16 - "https://dashscope.aliyuncs.com", // 17 - "", // 18 - "https://ai.360.cn", // 19 - "https://openrouter.ai/api", // 20 - "https://api.aiproxy.io", // 21 - "https://fastgpt.run/api/openapi", // 22 - "https://hunyuan.cloud.tencent.com", //23 - "", //24 + "", // 0 + "https://api.openai.com", // 1 + "https://oa.api2d.net", // 2 + "", // 3 + "https://api.closeai-proxy.xyz", // 4 + "https://api.openai-sb.com", // 5 + "https://api.openaimax.com", // 6 + "https://api.ohmygpt.com", // 7 + "", // 8 + "https://api.caipacity.com", // 9 + "https://api.aiproxy.io", // 10 + "https://generativelanguage.googleapis.com", // 11 + "https://api.api2gpt.com", // 12 + "https://api.aigc2d.com", // 13 + "https://api.anthropic.com", // 14 + "https://aip.baidubce.com", // 15 + "https://open.bigmodel.cn", // 16 + "https://dashscope.aliyuncs.com", // 17 + "", // 18 + "https://ai.360.cn", // 19 + "https://openrouter.ai/api", // 20 + "https://api.aiproxy.io", // 21 + "https://fastgpt.run/api/openapi", // 22 + "https://hunyuan.cloud.tencent.com", // 23 + "https://generativelanguage.googleapis.com", // 24 } diff --git a/common/database.go b/common/database.go index 76f2cd55..4c2c2717 100644 --- a/common/database.go +++ b/common/database.go @@ -1,7 +1,9 @@ package common +import "one-api/common/helper" + var UsingSQLite = false var UsingPostgreSQL = false var SQLitePath = "one-api.db" -var SQLiteBusyTimeout = GetOrDefault("SQLITE_BUSY_TIMEOUT", 3000) +var SQLiteBusyTimeout = helper.GetOrDefaultEnvInt("SQLITE_BUSY_TIMEOUT", 3000) diff --git a/common/email.go b/common/email.go index b915f0f9..28719742 100644 --- a/common/email.go +++ b/common/email.go @@ -6,18 +6,19 @@ import ( "encoding/base64" "fmt" "net/smtp" + "one-api/common/config" "strings" "time" ) func SendEmail(subject string, receiver string, content string) error { - if SMTPFrom == "" { // for compatibility - SMTPFrom = SMTPAccount + if config.SMTPFrom == "" { // for compatibility + config.SMTPFrom = config.SMTPAccount } encodedSubject := fmt.Sprintf("=?UTF-8?B?%s?=", base64.StdEncoding.EncodeToString([]byte(subject))) // Extract domain from SMTPFrom - parts := strings.Split(SMTPFrom, "@") + parts := strings.Split(config.SMTPFrom, "@") var domain string if len(parts) > 1 { domain = parts[1] @@ -36,21 +37,21 @@ func SendEmail(subject string, receiver string, content string) error { "Message-ID: %s\r\n"+ // add Message-ID header to avoid being treated as spam, RFC 5322 "Date: %s\r\n"+ "Content-Type: text/html; charset=UTF-8\r\n\r\n%s\r\n", - receiver, SystemName, SMTPFrom, encodedSubject, messageId, time.Now().Format(time.RFC1123Z), content)) - auth := smtp.PlainAuth("", SMTPAccount, SMTPToken, SMTPServer) - addr := fmt.Sprintf("%s:%d", SMTPServer, SMTPPort) + receiver, config.SystemName, config.SMTPFrom, encodedSubject, messageId, time.Now().Format(time.RFC1123Z), content)) + auth := smtp.PlainAuth("", config.SMTPAccount, config.SMTPToken, config.SMTPServer) + addr := fmt.Sprintf("%s:%d", config.SMTPServer, config.SMTPPort) to := strings.Split(receiver, ";") - if SMTPPort == 465 { + if config.SMTPPort == 465 { tlsConfig := &tls.Config{ InsecureSkipVerify: true, - ServerName: SMTPServer, + ServerName: config.SMTPServer, } - conn, err := tls.Dial("tcp", fmt.Sprintf("%s:%d", SMTPServer, SMTPPort), tlsConfig) + conn, err := tls.Dial("tcp", fmt.Sprintf("%s:%d", config.SMTPServer, config.SMTPPort), tlsConfig) if err != nil { return err } - client, err := smtp.NewClient(conn, SMTPServer) + client, err := smtp.NewClient(conn, config.SMTPServer) if err != nil { return err } @@ -58,7 +59,7 @@ func SendEmail(subject string, receiver string, content string) error { if err = client.Auth(auth); err != nil { return err } - if err = client.Mail(SMTPFrom); err != nil { + if err = client.Mail(config.SMTPFrom); err != nil { return err } receiverEmails := strings.Split(receiver, ";") @@ -80,7 +81,7 @@ func SendEmail(subject string, receiver string, content string) error { return err } } else { - err = smtp.SendMail(addr, auth, SMTPAccount, to, mail) + err = smtp.SendMail(addr, auth, config.SMTPAccount, to, mail) } return err } diff --git a/common/gin.go b/common/gin.go index f5012688..bed2c2b1 100644 --- a/common/gin.go +++ b/common/gin.go @@ -31,3 +31,11 @@ func UnmarshalBodyReusable(c *gin.Context, v any) error { c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody)) return nil } + +func SetEventStreamHeaders(c *gin.Context) { + c.Writer.Header().Set("Content-Type", "text/event-stream") + c.Writer.Header().Set("Cache-Control", "no-cache") + c.Writer.Header().Set("Connection", "keep-alive") + c.Writer.Header().Set("Transfer-Encoding", "chunked") + c.Writer.Header().Set("X-Accel-Buffering", "no") +} diff --git a/common/group-ratio.go b/common/group-ratio.go index 1ec73c78..86e9e1f6 100644 --- a/common/group-ratio.go +++ b/common/group-ratio.go @@ -1,6 +1,9 @@ package common -import "encoding/json" +import ( + "encoding/json" + "one-api/common/logger" +) var GroupRatio = map[string]float64{ "default": 1, @@ -11,7 +14,7 @@ var GroupRatio = map[string]float64{ func GroupRatio2JSONString() string { jsonBytes, err := json.Marshal(GroupRatio) if err != nil { - SysError("error marshalling model ratio: " + err.Error()) + logger.SysError("error marshalling model ratio: " + err.Error()) } return string(jsonBytes) } @@ -24,7 +27,7 @@ func UpdateGroupRatioByJSONString(jsonStr string) error { func GetGroupRatio(name string) float64 { ratio, ok := GroupRatio[name] if !ok { - SysError("group ratio not found: " + name) + logger.SysError("group ratio not found: " + name) return 1 } return ratio diff --git a/common/helper/helper.go b/common/helper/helper.go new file mode 100644 index 00000000..09f1df29 --- /dev/null +++ b/common/helper/helper.go @@ -0,0 +1,224 @@ +package helper + +import ( + "fmt" + "github.com/google/uuid" + "html/template" + "log" + "math/rand" + "net" + "one-api/common/logger" + "os" + "os/exec" + "runtime" + "strconv" + "strings" + "time" +) + +func OpenBrowser(url string) { + var err error + + switch runtime.GOOS { + case "linux": + err = exec.Command("xdg-open", url).Start() + case "windows": + err = exec.Command("rundll32", "url.dll,FileProtocolHandler", url).Start() + case "darwin": + err = exec.Command("open", url).Start() + } + if err != nil { + log.Println(err) + } +} + +func GetIp() (ip string) { + ips, err := net.InterfaceAddrs() + if err != nil { + log.Println(err) + return ip + } + + for _, a := range ips { + if ipNet, ok := a.(*net.IPNet); ok && !ipNet.IP.IsLoopback() { + if ipNet.IP.To4() != nil { + ip = ipNet.IP.String() + if strings.HasPrefix(ip, "10") { + return + } + if strings.HasPrefix(ip, "172") { + return + } + if strings.HasPrefix(ip, "192.168") { + return + } + ip = "" + } + } + } + return +} + +var sizeKB = 1024 +var sizeMB = sizeKB * 1024 +var sizeGB = sizeMB * 1024 + +func Bytes2Size(num int64) string { + numStr := "" + unit := "B" + if num/int64(sizeGB) > 1 { + numStr = fmt.Sprintf("%.2f", float64(num)/float64(sizeGB)) + unit = "GB" + } else if num/int64(sizeMB) > 1 { + numStr = fmt.Sprintf("%d", int(float64(num)/float64(sizeMB))) + unit = "MB" + } else if num/int64(sizeKB) > 1 { + numStr = fmt.Sprintf("%d", int(float64(num)/float64(sizeKB))) + unit = "KB" + } else { + numStr = fmt.Sprintf("%d", num) + } + return numStr + " " + unit +} + +func Seconds2Time(num int) (time string) { + if num/31104000 > 0 { + time += strconv.Itoa(num/31104000) + " 年 " + num %= 31104000 + } + if num/2592000 > 0 { + time += strconv.Itoa(num/2592000) + " 个月 " + num %= 2592000 + } + if num/86400 > 0 { + time += strconv.Itoa(num/86400) + " 天 " + num %= 86400 + } + if num/3600 > 0 { + time += strconv.Itoa(num/3600) + " 小时 " + num %= 3600 + } + if num/60 > 0 { + time += strconv.Itoa(num/60) + " 分钟 " + num %= 60 + } + time += strconv.Itoa(num) + " 秒" + return +} + +func Interface2String(inter interface{}) string { + switch inter.(type) { + case string: + return inter.(string) + case int: + return fmt.Sprintf("%d", inter.(int)) + case float64: + return fmt.Sprintf("%f", inter.(float64)) + } + return "Not Implemented" +} + +func UnescapeHTML(x string) interface{} { + return template.HTML(x) +} + +func IntMax(a int, b int) int { + if a >= b { + return a + } else { + return b + } +} + +func GetUUID() string { + code := uuid.New().String() + code = strings.Replace(code, "-", "", -1) + return code +} + +const keyChars = "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ" + +func init() { + rand.Seed(time.Now().UnixNano()) +} + +func GenerateKey() string { + rand.Seed(time.Now().UnixNano()) + key := make([]byte, 48) + for i := 0; i < 16; i++ { + key[i] = keyChars[rand.Intn(len(keyChars))] + } + uuid_ := GetUUID() + for i := 0; i < 32; i++ { + c := uuid_[i] + if i%2 == 0 && c >= 'a' && c <= 'z' { + c = c - 'a' + 'A' + } + key[i+16] = c + } + return string(key) +} + +func GetRandomString(length int) string { + rand.Seed(time.Now().UnixNano()) + key := make([]byte, length) + for i := 0; i < length; i++ { + key[i] = keyChars[rand.Intn(len(keyChars))] + } + return string(key) +} + +func GetTimestamp() int64 { + return time.Now().Unix() +} + +func GetTimeString() string { + now := time.Now() + return fmt.Sprintf("%s%d", now.Format("20060102150405"), now.UnixNano()%1e9) +} + +func Max(a int, b int) int { + if a >= b { + return a + } else { + return b + } +} + +func GetOrDefaultEnvInt(env string, defaultValue int) int { + if env == "" || os.Getenv(env) == "" { + return defaultValue + } + num, err := strconv.Atoi(os.Getenv(env)) + if err != nil { + logger.SysError(fmt.Sprintf("failed to parse %s: %s, using default value: %d", env, err.Error(), defaultValue)) + return defaultValue + } + return num +} + +func GetOrDefaultEnvString(env string, defaultValue string) string { + if env == "" || os.Getenv(env) == "" { + return defaultValue + } + return os.Getenv(env) +} + +func AssignOrDefault(value string, defaultValue string) string { + if len(value) != 0 { + return value + } + return defaultValue +} + +func MessageWithRequestId(message string, id string) string { + return fmt.Sprintf("%s (request id: %s)", message, id) +} + +func String2Int(str string) int { + num, err := strconv.Atoi(str) + if err != nil { + return 0 + } + return num +} diff --git a/common/init.go b/common/init.go index 12df5f51..abc71108 100644 --- a/common/init.go +++ b/common/init.go @@ -4,6 +4,8 @@ import ( "flag" "fmt" "log" + "one-api/common/config" + "one-api/common/logger" "os" "path/filepath" ) @@ -37,9 +39,9 @@ func init() { if os.Getenv("SESSION_SECRET") != "" { if os.Getenv("SESSION_SECRET") == "random_string" { - SysError("SESSION_SECRET is set to an example value, please change it to a random string.") + logger.SysError("SESSION_SECRET is set to an example value, please change it to a random string.") } else { - SessionSecret = os.Getenv("SESSION_SECRET") + config.SessionSecret = os.Getenv("SESSION_SECRET") } } if os.Getenv("SQLITE_PATH") != "" { @@ -57,5 +59,6 @@ func init() { log.Fatal(err) } } + logger.LogDir = *LogDir } } diff --git a/common/logger/constants.go b/common/logger/constants.go new file mode 100644 index 00000000..78d32062 --- /dev/null +++ b/common/logger/constants.go @@ -0,0 +1,7 @@ +package logger + +const ( + RequestIdKey = "X-Oneapi-Request-Id" +) + +var LogDir string diff --git a/common/logger.go b/common/logger/logger.go similarity index 76% rename from common/logger.go rename to common/logger/logger.go index 61627217..b89dbdb7 100644 --- a/common/logger.go +++ b/common/logger/logger.go @@ -1,4 +1,4 @@ -package common +package logger import ( "context" @@ -25,7 +25,7 @@ var setupLogLock sync.Mutex var setupLogWorking bool func SetupLogger() { - if *LogDir != "" { + if LogDir != "" { ok := setupLogLock.TryLock() if !ok { log.Println("setup log is already working") @@ -35,7 +35,7 @@ func SetupLogger() { setupLogLock.Unlock() setupLogWorking = false }() - logPath := filepath.Join(*LogDir, fmt.Sprintf("oneapi-%s.log", time.Now().Format("20060102"))) + logPath := filepath.Join(LogDir, fmt.Sprintf("oneapi-%s.log", time.Now().Format("20060102"))) fd, err := os.OpenFile(logPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644) if err != nil { log.Fatal("failed to open log file") @@ -55,18 +55,30 @@ func SysError(s string) { _, _ = fmt.Fprintf(gin.DefaultErrorWriter, "[SYS] %v | %s \n", t.Format("2006/01/02 - 15:04:05"), s) } -func LogInfo(ctx context.Context, msg string) { +func Info(ctx context.Context, msg string) { logHelper(ctx, loggerINFO, msg) } -func LogWarn(ctx context.Context, msg string) { +func Warn(ctx context.Context, msg string) { logHelper(ctx, loggerWarn, msg) } -func LogError(ctx context.Context, msg string) { +func Error(ctx context.Context, msg string) { logHelper(ctx, loggerError, msg) } +func Infof(ctx context.Context, format string, a ...any) { + Info(ctx, fmt.Sprintf(format, a)) +} + +func Warnf(ctx context.Context, format string, a ...any) { + Warn(ctx, fmt.Sprintf(format, a)) +} + +func Errorf(ctx context.Context, format string, a ...any) { + Error(ctx, fmt.Sprintf(format, a)) +} + func logHelper(ctx context.Context, level string, msg string) { writer := gin.DefaultErrorWriter if level == loggerINFO { @@ -90,11 +102,3 @@ func FatalLog(v ...any) { _, _ = fmt.Fprintf(gin.DefaultErrorWriter, "[FATAL] %v | %v \n", t.Format("2006/01/02 - 15:04:05"), v) os.Exit(1) } - -func LogQuota(quota int) string { - if DisplayInCurrencyEnabled { - return fmt.Sprintf("$%.6f 额度", float64(quota)/QuotaPerUnit) - } else { - return fmt.Sprintf("%d 点额度", quota) - } -} diff --git a/common/model-ratio.go b/common/model-ratio.go index 97cb060d..9f31e0d7 100644 --- a/common/model-ratio.go +++ b/common/model-ratio.go @@ -2,6 +2,7 @@ package common import ( "encoding/json" + "one-api/common/logger" "strings" "time" ) @@ -107,7 +108,7 @@ var ModelRatio = map[string]float64{ func ModelRatio2JSONString() string { jsonBytes, err := json.Marshal(ModelRatio) if err != nil { - SysError("error marshalling model ratio: " + err.Error()) + logger.SysError("error marshalling model ratio: " + err.Error()) } return string(jsonBytes) } @@ -123,7 +124,7 @@ func GetModelRatio(name string) float64 { } ratio, ok := ModelRatio[name] if !ok { - SysError("model ratio not found: " + name) + logger.SysError("model ratio not found: " + name) return 30 } return ratio diff --git a/common/redis.go b/common/redis.go index 12c477b8..ed3fcd9d 100644 --- a/common/redis.go +++ b/common/redis.go @@ -3,6 +3,7 @@ package common import ( "context" "github.com/go-redis/redis/v8" + "one-api/common/logger" "os" "time" ) @@ -14,18 +15,18 @@ var RedisEnabled = true func InitRedisClient() (err error) { if os.Getenv("REDIS_CONN_STRING") == "" { RedisEnabled = false - SysLog("REDIS_CONN_STRING not set, Redis is not enabled") + logger.SysLog("REDIS_CONN_STRING not set, Redis is not enabled") return nil } if os.Getenv("SYNC_FREQUENCY") == "" { RedisEnabled = false - SysLog("SYNC_FREQUENCY not set, Redis is disabled") + logger.SysLog("SYNC_FREQUENCY not set, Redis is disabled") return nil } - SysLog("Redis is enabled") + logger.SysLog("Redis is enabled") opt, err := redis.ParseURL(os.Getenv("REDIS_CONN_STRING")) if err != nil { - FatalLog("failed to parse Redis connection string: " + err.Error()) + logger.FatalLog("failed to parse Redis connection string: " + err.Error()) } RDB = redis.NewClient(opt) @@ -34,7 +35,7 @@ func InitRedisClient() (err error) { _, err = RDB.Ping(ctx).Result() if err != nil { - FatalLog("Redis ping test failed: " + err.Error()) + logger.FatalLog("Redis ping test failed: " + err.Error()) } return err } @@ -42,7 +43,7 @@ func InitRedisClient() (err error) { func ParseRedisOption() *redis.Options { opt, err := redis.ParseURL(os.Getenv("REDIS_CONN_STRING")) if err != nil { - FatalLog("failed to parse Redis connection string: " + err.Error()) + logger.FatalLog("failed to parse Redis connection string: " + err.Error()) } return opt } diff --git a/common/utils.go b/common/utils.go index 9a7038e2..41de9367 100644 --- a/common/utils.go +++ b/common/utils.go @@ -2,215 +2,13 @@ package common import ( "fmt" - "github.com/google/uuid" - "html/template" - "log" - "math/rand" - "net" - "os" - "os/exec" - "runtime" - "strconv" - "strings" - "time" + "one-api/common/config" ) -func OpenBrowser(url string) { - var err error - - switch runtime.GOOS { - case "linux": - err = exec.Command("xdg-open", url).Start() - case "windows": - err = exec.Command("rundll32", "url.dll,FileProtocolHandler", url).Start() - case "darwin": - err = exec.Command("open", url).Start() - } - if err != nil { - log.Println(err) - } -} - -func GetIp() (ip string) { - ips, err := net.InterfaceAddrs() - if err != nil { - log.Println(err) - return ip - } - - for _, a := range ips { - if ipNet, ok := a.(*net.IPNet); ok && !ipNet.IP.IsLoopback() { - if ipNet.IP.To4() != nil { - ip = ipNet.IP.String() - if strings.HasPrefix(ip, "10") { - return - } - if strings.HasPrefix(ip, "172") { - return - } - if strings.HasPrefix(ip, "192.168") { - return - } - ip = "" - } - } - } - return -} - -var sizeKB = 1024 -var sizeMB = sizeKB * 1024 -var sizeGB = sizeMB * 1024 - -func Bytes2Size(num int64) string { - numStr := "" - unit := "B" - if num/int64(sizeGB) > 1 { - numStr = fmt.Sprintf("%.2f", float64(num)/float64(sizeGB)) - unit = "GB" - } else if num/int64(sizeMB) > 1 { - numStr = fmt.Sprintf("%d", int(float64(num)/float64(sizeMB))) - unit = "MB" - } else if num/int64(sizeKB) > 1 { - numStr = fmt.Sprintf("%d", int(float64(num)/float64(sizeKB))) - unit = "KB" +func LogQuota(quota int) string { + if config.DisplayInCurrencyEnabled { + return fmt.Sprintf("$%.6f 额度", float64(quota)/config.QuotaPerUnit) } else { - numStr = fmt.Sprintf("%d", num) - } - return numStr + " " + unit -} - -func Seconds2Time(num int) (time string) { - if num/31104000 > 0 { - time += strconv.Itoa(num/31104000) + " 年 " - num %= 31104000 - } - if num/2592000 > 0 { - time += strconv.Itoa(num/2592000) + " 个月 " - num %= 2592000 - } - if num/86400 > 0 { - time += strconv.Itoa(num/86400) + " 天 " - num %= 86400 - } - if num/3600 > 0 { - time += strconv.Itoa(num/3600) + " 小时 " - num %= 3600 - } - if num/60 > 0 { - time += strconv.Itoa(num/60) + " 分钟 " - num %= 60 - } - time += strconv.Itoa(num) + " 秒" - return -} - -func Interface2String(inter interface{}) string { - switch inter.(type) { - case string: - return inter.(string) - case int: - return fmt.Sprintf("%d", inter.(int)) - case float64: - return fmt.Sprintf("%f", inter.(float64)) - } - return "Not Implemented" -} - -func UnescapeHTML(x string) interface{} { - return template.HTML(x) -} - -func IntMax(a int, b int) int { - if a >= b { - return a - } else { - return b + return fmt.Sprintf("%d 点额度", quota) } } - -func GetUUID() string { - code := uuid.New().String() - code = strings.Replace(code, "-", "", -1) - return code -} - -const keyChars = "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ" - -func init() { - rand.Seed(time.Now().UnixNano()) -} - -func GenerateKey() string { - rand.Seed(time.Now().UnixNano()) - key := make([]byte, 48) - for i := 0; i < 16; i++ { - key[i] = keyChars[rand.Intn(len(keyChars))] - } - uuid_ := GetUUID() - for i := 0; i < 32; i++ { - c := uuid_[i] - if i%2 == 0 && c >= 'a' && c <= 'z' { - c = c - 'a' + 'A' - } - key[i+16] = c - } - return string(key) -} - -func GetRandomString(length int) string { - rand.Seed(time.Now().UnixNano()) - key := make([]byte, length) - for i := 0; i < length; i++ { - key[i] = keyChars[rand.Intn(len(keyChars))] - } - return string(key) -} - -func GetTimestamp() int64 { - return time.Now().Unix() -} - -func GetTimeString() string { - now := time.Now() - return fmt.Sprintf("%s%d", now.Format("20060102150405"), now.UnixNano()%1e9) -} - -func Max(a int, b int) int { - if a >= b { - return a - } else { - return b - } -} - -func GetOrDefault(env string, defaultValue int) int { - if env == "" || os.Getenv(env) == "" { - return defaultValue - } - num, err := strconv.Atoi(os.Getenv(env)) - if err != nil { - SysError(fmt.Sprintf("failed to parse %s: %s, using default value: %d", env, err.Error(), defaultValue)) - return defaultValue - } - return num -} - -func GetOrDefaultString(env string, defaultValue string) string { - if env == "" || os.Getenv(env) == "" { - return defaultValue - } - return os.Getenv(env) -} - -func MessageWithRequestId(message string, id string) string { - return fmt.Sprintf("%s (request id: %s)", message, id) -} - -func String2Int(str string) int { - num, err := strconv.Atoi(str) - if err != nil { - return 0 - } - return num -} diff --git a/controller/billing.go b/controller/billing.go index 42e86aea..1003bf10 100644 --- a/controller/billing.go +++ b/controller/billing.go @@ -2,8 +2,9 @@ package controller import ( "github.com/gin-gonic/gin" - "one-api/common" + "one-api/common/config" "one-api/model" + "one-api/relay/channel/openai" ) func GetSubscription(c *gin.Context) { @@ -12,7 +13,7 @@ func GetSubscription(c *gin.Context) { var err error var token *model.Token var expiredTime int64 - if common.DisplayTokenStatEnabled { + if config.DisplayTokenStatEnabled { tokenId := c.GetInt("token_id") token, err = model.GetTokenById(tokenId) expiredTime = token.ExpiredTime @@ -27,19 +28,19 @@ func GetSubscription(c *gin.Context) { expiredTime = 0 } if err != nil { - openAIError := OpenAIError{ + Error := openai.Error{ Message: err.Error(), Type: "upstream_error", } c.JSON(200, gin.H{ - "error": openAIError, + "error": Error, }) return } quota := remainQuota + usedQuota amount := float64(quota) - if common.DisplayInCurrencyEnabled { - amount /= common.QuotaPerUnit + if config.DisplayInCurrencyEnabled { + amount /= config.QuotaPerUnit } if token != nil && token.UnlimitedQuota { amount = 100000000 @@ -60,7 +61,7 @@ func GetUsage(c *gin.Context) { var quota int var err error var token *model.Token - if common.DisplayTokenStatEnabled { + if config.DisplayTokenStatEnabled { tokenId := c.GetInt("token_id") token, err = model.GetTokenById(tokenId) quota = token.UsedQuota @@ -69,18 +70,18 @@ func GetUsage(c *gin.Context) { quota, err = model.GetUserUsedQuota(userId) } if err != nil { - openAIError := OpenAIError{ + Error := openai.Error{ Message: err.Error(), Type: "one_api_error", } c.JSON(200, gin.H{ - "error": openAIError, + "error": Error, }) return } amount := float64(quota) - if common.DisplayInCurrencyEnabled { - amount /= common.QuotaPerUnit + if config.DisplayInCurrencyEnabled { + amount /= config.QuotaPerUnit } usage := OpenAIUsageResponse{ Object: "list", diff --git a/controller/channel-billing.go b/controller/channel-billing.go index 6ddad7ea..49c76760 100644 --- a/controller/channel-billing.go +++ b/controller/channel-billing.go @@ -7,7 +7,10 @@ import ( "io" "net/http" "one-api/common" + "one-api/common/config" + "one-api/common/logger" "one-api/model" + "one-api/relay/util" "strconv" "time" @@ -92,7 +95,7 @@ func GetResponseBody(method, url string, channel *model.Channel, headers http.He for k := range headers { req.Header.Add(k, headers.Get(k)) } - res, err := httpClient.Do(req) + res, err := util.HTTPClient.Do(req) if err != nil { return nil, err } @@ -313,7 +316,7 @@ func updateAllChannelsBalance() error { disableChannel(channel.Id, channel.Name, "余额不足") } } - time.Sleep(common.RequestInterval) + time.Sleep(config.RequestInterval) } return nil } @@ -338,8 +341,8 @@ func UpdateAllChannelsBalance(c *gin.Context) { func AutomaticallyUpdateChannels(frequency int) { for { time.Sleep(time.Duration(frequency) * time.Minute) - common.SysLog("updating all channels") + logger.SysLog("updating all channels") _ = updateAllChannelsBalance() - common.SysLog("channels update done") + logger.SysLog("channels update done") } } diff --git a/controller/channel-test.go b/controller/channel-test.go index 3aaa4897..88d6e3f2 100644 --- a/controller/channel-test.go +++ b/controller/channel-test.go @@ -8,7 +8,11 @@ import ( "io" "net/http" "one-api/common" + "one-api/common/config" + "one-api/common/logger" "one-api/model" + "one-api/relay/channel/openai" + "one-api/relay/util" "strconv" "sync" "time" @@ -16,7 +20,7 @@ import ( "github.com/gin-gonic/gin" ) -func testChannel(channel *model.Channel, request ChatRequest) (err error, openaiErr *OpenAIError) { +func testChannel(channel *model.Channel, request openai.ChatRequest) (err error, openaiErr *openai.Error) { switch channel.Type { case common.ChannelTypePaLM: fallthrough @@ -46,13 +50,13 @@ func testChannel(channel *model.Channel, request ChatRequest) (err error, openai } requestURL := common.ChannelBaseURLs[channel.Type] if channel.Type == common.ChannelTypeAzure { - requestURL = getFullRequestURL(channel.GetBaseURL(), fmt.Sprintf("/openai/deployments/%s/chat/completions?api-version=2023-03-15-preview", request.Model), channel.Type) + requestURL = util.GetFullRequestURL(channel.GetBaseURL(), fmt.Sprintf("/openai/deployments/%s/chat/completions?api-version=2023-03-15-preview", request.Model), channel.Type) } else { if baseURL := channel.GetBaseURL(); len(baseURL) > 0 { requestURL = baseURL } - requestURL = getFullRequestURL(requestURL, "/v1/chat/completions", channel.Type) + requestURL = util.GetFullRequestURL(requestURL, "/v1/chat/completions", channel.Type) } jsonData, err := json.Marshal(request) if err != nil { @@ -68,12 +72,12 @@ func testChannel(channel *model.Channel, request ChatRequest) (err error, openai req.Header.Set("Authorization", "Bearer "+channel.Key) } req.Header.Set("Content-Type", "application/json") - resp, err := httpClient.Do(req) + resp, err := util.HTTPClient.Do(req) if err != nil { return err, nil } defer resp.Body.Close() - var response TextResponse + var response openai.SlimTextResponse body, err := io.ReadAll(resp.Body) if err != nil { return err, nil @@ -91,12 +95,12 @@ func testChannel(channel *model.Channel, request ChatRequest) (err error, openai return nil, nil } -func buildTestRequest() *ChatRequest { - testRequest := &ChatRequest{ +func buildTestRequest() *openai.ChatRequest { + testRequest := &openai.ChatRequest{ Model: "", // this will be set later MaxTokens: 1, } - testMessage := Message{ + testMessage := openai.Message{ Role: "user", Content: "hi", } @@ -148,12 +152,12 @@ var testAllChannelsLock sync.Mutex var testAllChannelsRunning bool = false func notifyRootUser(subject string, content string) { - if common.RootUserEmail == "" { - common.RootUserEmail = model.GetRootUserEmail() + if config.RootUserEmail == "" { + config.RootUserEmail = model.GetRootUserEmail() } - err := common.SendEmail(subject, common.RootUserEmail, content) + err := common.SendEmail(subject, config.RootUserEmail, content) if err != nil { - common.SysError(fmt.Sprintf("failed to send email: %s", err.Error())) + logger.SysError(fmt.Sprintf("failed to send email: %s", err.Error())) } } @@ -174,8 +178,8 @@ func enableChannel(channelId int, channelName string) { } func testAllChannels(notify bool) error { - if common.RootUserEmail == "" { - common.RootUserEmail = model.GetRootUserEmail() + if config.RootUserEmail == "" { + config.RootUserEmail = model.GetRootUserEmail() } testAllChannelsLock.Lock() if testAllChannelsRunning { @@ -189,7 +193,7 @@ func testAllChannels(notify bool) error { return err } testRequest := buildTestRequest() - var disableThreshold = int64(common.ChannelDisableThreshold * 1000) + var disableThreshold = int64(config.ChannelDisableThreshold * 1000) if disableThreshold == 0 { disableThreshold = 10000000 // a impossible value } @@ -204,22 +208,22 @@ func testAllChannels(notify bool) error { err = errors.New(fmt.Sprintf("响应时间 %.2fs 超过阈值 %.2fs", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0)) disableChannel(channel.Id, channel.Name, err.Error()) } - if isChannelEnabled && shouldDisableChannel(openaiErr, -1) { + if isChannelEnabled && util.ShouldDisableChannel(openaiErr, -1) { disableChannel(channel.Id, channel.Name, err.Error()) } - if !isChannelEnabled && shouldEnableChannel(err, openaiErr) { + if !isChannelEnabled && util.ShouldEnableChannel(err, openaiErr) { enableChannel(channel.Id, channel.Name) } channel.UpdateResponseTime(milliseconds) - time.Sleep(common.RequestInterval) + time.Sleep(config.RequestInterval) } testAllChannelsLock.Lock() testAllChannelsRunning = false testAllChannelsLock.Unlock() if notify { - err := common.SendEmail("通道测试完成", common.RootUserEmail, "通道测试完成,如果没有收到禁用通知,说明所有通道都正常") + err := common.SendEmail("通道测试完成", config.RootUserEmail, "通道测试完成,如果没有收到禁用通知,说明所有通道都正常") if err != nil { - common.SysError(fmt.Sprintf("failed to send email: %s", err.Error())) + logger.SysError(fmt.Sprintf("failed to send email: %s", err.Error())) } } }() @@ -245,8 +249,8 @@ func TestAllChannels(c *gin.Context) { func AutomaticallyTestChannels(frequency int) { for { time.Sleep(time.Duration(frequency) * time.Minute) - common.SysLog("testing all channels") + logger.SysLog("testing all channels") _ = testAllChannels(false) - common.SysLog("channel test finished") + logger.SysLog("channel test finished") } } diff --git a/controller/channel.go b/controller/channel.go index 904abc23..6d368066 100644 --- a/controller/channel.go +++ b/controller/channel.go @@ -3,7 +3,8 @@ package controller import ( "github.com/gin-gonic/gin" "net/http" - "one-api/common" + "one-api/common/config" + "one-api/common/helper" "one-api/model" "strconv" "strings" @@ -14,7 +15,7 @@ func GetAllChannels(c *gin.Context) { if p < 0 { p = 0 } - channels, err := model.GetAllChannels(p*common.ItemsPerPage, common.ItemsPerPage, false) + channels, err := model.GetAllChannels(p*config.ItemsPerPage, config.ItemsPerPage, false) if err != nil { c.JSON(http.StatusOK, gin.H{ "success": false, @@ -83,7 +84,7 @@ func AddChannel(c *gin.Context) { }) return } - channel.CreatedTime = common.GetTimestamp() + channel.CreatedTime = helper.GetTimestamp() keys := strings.Split(channel.Key, "\n") channels := make([]model.Channel, 0, len(keys)) for _, key := range keys { diff --git a/controller/github.go b/controller/github.go index ee995379..f6049ddb 100644 --- a/controller/github.go +++ b/controller/github.go @@ -9,6 +9,9 @@ import ( "github.com/gin-gonic/gin" "net/http" "one-api/common" + "one-api/common/config" + "one-api/common/helper" + "one-api/common/logger" "one-api/model" "strconv" "time" @@ -30,7 +33,7 @@ func getGitHubUserInfoByCode(code string) (*GitHubUser, error) { if code == "" { return nil, errors.New("无效的参数") } - values := map[string]string{"client_id": common.GitHubClientId, "client_secret": common.GitHubClientSecret, "code": code} + values := map[string]string{"client_id": config.GitHubClientId, "client_secret": config.GitHubClientSecret, "code": code} jsonData, err := json.Marshal(values) if err != nil { return nil, err @@ -46,7 +49,7 @@ func getGitHubUserInfoByCode(code string) (*GitHubUser, error) { } res, err := client.Do(req) if err != nil { - common.SysLog(err.Error()) + logger.SysLog(err.Error()) return nil, errors.New("无法连接至 GitHub 服务器,请稍后重试!") } defer res.Body.Close() @@ -62,7 +65,7 @@ func getGitHubUserInfoByCode(code string) (*GitHubUser, error) { req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", oAuthResponse.AccessToken)) res2, err := client.Do(req) if err != nil { - common.SysLog(err.Error()) + logger.SysLog(err.Error()) return nil, errors.New("无法连接至 GitHub 服务器,请稍后重试!") } defer res2.Body.Close() @@ -93,7 +96,7 @@ func GitHubOAuth(c *gin.Context) { return } - if !common.GitHubOAuthEnabled { + if !config.GitHubOAuthEnabled { c.JSON(http.StatusOK, gin.H{ "success": false, "message": "管理员未开启通过 GitHub 登录以及注册", @@ -122,7 +125,7 @@ func GitHubOAuth(c *gin.Context) { return } } else { - if common.RegisterEnabled { + if config.RegisterEnabled { user.Username = "github_" + strconv.Itoa(model.GetMaxUserId()+1) if githubUser.Name != "" { user.DisplayName = githubUser.Name @@ -160,7 +163,7 @@ func GitHubOAuth(c *gin.Context) { } func GitHubBind(c *gin.Context) { - if !common.GitHubOAuthEnabled { + if !config.GitHubOAuthEnabled { c.JSON(http.StatusOK, gin.H{ "success": false, "message": "管理员未开启通过 GitHub 登录以及注册", @@ -216,7 +219,7 @@ func GitHubBind(c *gin.Context) { func GenerateOAuthCode(c *gin.Context) { session := sessions.Default(c) - state := common.GetRandomString(12) + state := helper.GetRandomString(12) session.Set("oauth_state", state) err := session.Save() if err != nil { diff --git a/controller/log.go b/controller/log.go index b65867fe..418c6859 100644 --- a/controller/log.go +++ b/controller/log.go @@ -3,7 +3,7 @@ package controller import ( "github.com/gin-gonic/gin" "net/http" - "one-api/common" + "one-api/common/config" "one-api/model" "strconv" ) @@ -20,7 +20,7 @@ func GetAllLogs(c *gin.Context) { tokenName := c.Query("token_name") modelName := c.Query("model_name") channel, _ := strconv.Atoi(c.Query("channel")) - logs, err := model.GetAllLogs(logType, startTimestamp, endTimestamp, modelName, username, tokenName, p*common.ItemsPerPage, common.ItemsPerPage, channel) + logs, err := model.GetAllLogs(logType, startTimestamp, endTimestamp, modelName, username, tokenName, p*config.ItemsPerPage, config.ItemsPerPage, channel) if err != nil { c.JSON(http.StatusOK, gin.H{ "success": false, @@ -47,7 +47,7 @@ func GetUserLogs(c *gin.Context) { endTimestamp, _ := strconv.ParseInt(c.Query("end_timestamp"), 10, 64) tokenName := c.Query("token_name") modelName := c.Query("model_name") - logs, err := model.GetUserLogs(userId, logType, startTimestamp, endTimestamp, modelName, tokenName, p*common.ItemsPerPage, common.ItemsPerPage) + logs, err := model.GetUserLogs(userId, logType, startTimestamp, endTimestamp, modelName, tokenName, p*config.ItemsPerPage, config.ItemsPerPage) if err != nil { c.JSON(http.StatusOK, gin.H{ "success": false, diff --git a/controller/misc.go b/controller/misc.go index 2bcbb41f..df7c0728 100644 --- a/controller/misc.go +++ b/controller/misc.go @@ -5,6 +5,7 @@ import ( "fmt" "net/http" "one-api/common" + "one-api/common/config" "one-api/model" "strings" @@ -18,55 +19,55 @@ func GetStatus(c *gin.Context) { "data": gin.H{ "version": common.Version, "start_time": common.StartTime, - "email_verification": common.EmailVerificationEnabled, - "github_oauth": common.GitHubOAuthEnabled, - "github_client_id": common.GitHubClientId, - "system_name": common.SystemName, - "logo": common.Logo, - "footer_html": common.Footer, - "wechat_qrcode": common.WeChatAccountQRCodeImageURL, - "wechat_login": common.WeChatAuthEnabled, - "server_address": common.ServerAddress, - "turnstile_check": common.TurnstileCheckEnabled, - "turnstile_site_key": common.TurnstileSiteKey, - "top_up_link": common.TopUpLink, - "chat_link": common.ChatLink, - "quota_per_unit": common.QuotaPerUnit, - "display_in_currency": common.DisplayInCurrencyEnabled, + "email_verification": config.EmailVerificationEnabled, + "github_oauth": config.GitHubOAuthEnabled, + "github_client_id": config.GitHubClientId, + "system_name": config.SystemName, + "logo": config.Logo, + "footer_html": config.Footer, + "wechat_qrcode": config.WeChatAccountQRCodeImageURL, + "wechat_login": config.WeChatAuthEnabled, + "server_address": config.ServerAddress, + "turnstile_check": config.TurnstileCheckEnabled, + "turnstile_site_key": config.TurnstileSiteKey, + "top_up_link": config.TopUpLink, + "chat_link": config.ChatLink, + "quota_per_unit": config.QuotaPerUnit, + "display_in_currency": config.DisplayInCurrencyEnabled, }, }) return } func GetNotice(c *gin.Context) { - common.OptionMapRWMutex.RLock() - defer common.OptionMapRWMutex.RUnlock() + config.OptionMapRWMutex.RLock() + defer config.OptionMapRWMutex.RUnlock() c.JSON(http.StatusOK, gin.H{ "success": true, "message": "", - "data": common.OptionMap["Notice"], + "data": config.OptionMap["Notice"], }) return } func GetAbout(c *gin.Context) { - common.OptionMapRWMutex.RLock() - defer common.OptionMapRWMutex.RUnlock() + config.OptionMapRWMutex.RLock() + defer config.OptionMapRWMutex.RUnlock() c.JSON(http.StatusOK, gin.H{ "success": true, "message": "", - "data": common.OptionMap["About"], + "data": config.OptionMap["About"], }) return } func GetHomePageContent(c *gin.Context) { - common.OptionMapRWMutex.RLock() - defer common.OptionMapRWMutex.RUnlock() + config.OptionMapRWMutex.RLock() + defer config.OptionMapRWMutex.RUnlock() c.JSON(http.StatusOK, gin.H{ "success": true, "message": "", - "data": common.OptionMap["HomePageContent"], + "data": config.OptionMap["HomePageContent"], }) return } @@ -80,9 +81,9 @@ func SendEmailVerification(c *gin.Context) { }) return } - if common.EmailDomainRestrictionEnabled { + if config.EmailDomainRestrictionEnabled { allowed := false - for _, domain := range common.EmailDomainWhitelist { + for _, domain := range config.EmailDomainWhitelist { if strings.HasSuffix(email, "@"+domain) { allowed = true break @@ -105,10 +106,10 @@ func SendEmailVerification(c *gin.Context) { } code := common.GenerateVerificationCode(6) common.RegisterVerificationCodeWithKey(email, code, common.EmailVerificationPurpose) - subject := fmt.Sprintf("%s邮箱验证邮件", common.SystemName) + subject := fmt.Sprintf("%s邮箱验证邮件", config.SystemName) content := fmt.Sprintf("

您好,你正在进行%s邮箱验证。

"+ "

您的验证码为: %s

"+ - "

验证码 %d 分钟内有效,如果不是本人操作,请忽略。

", common.SystemName, code, common.VerificationValidMinutes) + "

验证码 %d 分钟内有效,如果不是本人操作,请忽略。

", config.SystemName, code, common.VerificationValidMinutes) err := common.SendEmail(subject, email, content) if err != nil { c.JSON(http.StatusOK, gin.H{ @@ -142,12 +143,12 @@ func SendPasswordResetEmail(c *gin.Context) { } code := common.GenerateVerificationCode(0) common.RegisterVerificationCodeWithKey(email, code, common.PasswordResetPurpose) - link := fmt.Sprintf("%s/user/reset?email=%s&token=%s", common.ServerAddress, email, code) - subject := fmt.Sprintf("%s密码重置", common.SystemName) + link := fmt.Sprintf("%s/user/reset?email=%s&token=%s", config.ServerAddress, email, code) + subject := fmt.Sprintf("%s密码重置", config.SystemName) content := fmt.Sprintf("

您好,你正在进行%s密码重置。

"+ "

点击 此处 进行密码重置。

"+ "

如果链接无法点击,请尝试点击下面的链接或将其复制到浏览器中打开:
%s

"+ - "

重置链接 %d 分钟内有效,如果不是本人操作,请忽略。

", common.SystemName, link, link, common.VerificationValidMinutes) + "

重置链接 %d 分钟内有效,如果不是本人操作,请忽略。

", config.SystemName, link, link, common.VerificationValidMinutes) err := common.SendEmail(subject, email, content) if err != nil { c.JSON(http.StatusOK, gin.H{ diff --git a/controller/model.go b/controller/model.go index 6cb530db..b7ec1b6a 100644 --- a/controller/model.go +++ b/controller/model.go @@ -2,8 +2,8 @@ package controller import ( "fmt" - "github.com/gin-gonic/gin" + "one-api/relay/channel/openai" ) // https://platform.openai.com/docs/api-reference/models/list @@ -436,7 +436,7 @@ func init() { Id: "PaLM-2", Object: "model", Created: 1677649963, - OwnedBy: "google", + OwnedBy: "google palm", Permission: permission, Root: "PaLM-2", Parent: nil, @@ -445,7 +445,7 @@ func init() { Id: "gemini-pro", Object: "model", Created: 1677649963, - OwnedBy: "google", + OwnedBy: "google gemini", Permission: permission, Root: "gemini-pro", Parent: nil, @@ -454,7 +454,7 @@ func init() { Id: "gemini-pro-vision", Object: "model", Created: 1677649963, - OwnedBy: "google", + OwnedBy: "google gemini", Permission: permission, Root: "gemini-pro-vision", Parent: nil, @@ -613,14 +613,14 @@ func RetrieveModel(c *gin.Context) { if model, ok := openAIModelsMap[modelId]; ok { c.JSON(200, model) } else { - openAIError := OpenAIError{ + Error := openai.Error{ Message: fmt.Sprintf("The model '%s' does not exist", modelId), Type: "invalid_request_error", Param: "model", Code: "model_not_found", } c.JSON(200, gin.H{ - "error": openAIError, + "error": Error, }) } } diff --git a/controller/option.go b/controller/option.go index bbf83578..593ac2ae 100644 --- a/controller/option.go +++ b/controller/option.go @@ -3,7 +3,8 @@ package controller import ( "encoding/json" "net/http" - "one-api/common" + "one-api/common/config" + "one-api/common/helper" "one-api/model" "strings" @@ -12,17 +13,17 @@ import ( func GetOptions(c *gin.Context) { var options []*model.Option - common.OptionMapRWMutex.Lock() - for k, v := range common.OptionMap { + config.OptionMapRWMutex.Lock() + for k, v := range config.OptionMap { if strings.HasSuffix(k, "Token") || strings.HasSuffix(k, "Secret") { continue } options = append(options, &model.Option{ Key: k, - Value: common.Interface2String(v), + Value: helper.Interface2String(v), }) } - common.OptionMapRWMutex.Unlock() + config.OptionMapRWMutex.Unlock() c.JSON(http.StatusOK, gin.H{ "success": true, "message": "", @@ -42,8 +43,16 @@ func UpdateOption(c *gin.Context) { return } switch option.Key { + case "Theme": + if !config.ValidThemes[option.Value] { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "无效的主题", + }) + return + } case "GitHubOAuthEnabled": - if option.Value == "true" && common.GitHubClientId == "" { + if option.Value == "true" && config.GitHubClientId == "" { c.JSON(http.StatusOK, gin.H{ "success": false, "message": "无法启用 GitHub OAuth,请先填入 GitHub Client Id 以及 GitHub Client Secret!", @@ -51,7 +60,7 @@ func UpdateOption(c *gin.Context) { return } case "EmailDomainRestrictionEnabled": - if option.Value == "true" && len(common.EmailDomainWhitelist) == 0 { + if option.Value == "true" && len(config.EmailDomainWhitelist) == 0 { c.JSON(http.StatusOK, gin.H{ "success": false, "message": "无法启用邮箱域名限制,请先填入限制的邮箱域名!", @@ -59,7 +68,7 @@ func UpdateOption(c *gin.Context) { return } case "WeChatAuthEnabled": - if option.Value == "true" && common.WeChatServerAddress == "" { + if option.Value == "true" && config.WeChatServerAddress == "" { c.JSON(http.StatusOK, gin.H{ "success": false, "message": "无法启用微信登录,请先填入微信登录相关配置信息!", @@ -67,7 +76,7 @@ func UpdateOption(c *gin.Context) { return } case "TurnstileCheckEnabled": - if option.Value == "true" && common.TurnstileSiteKey == "" { + if option.Value == "true" && config.TurnstileSiteKey == "" { c.JSON(http.StatusOK, gin.H{ "success": false, "message": "无法启用 Turnstile 校验,请先填入 Turnstile 校验相关配置信息!", diff --git a/controller/redemption.go b/controller/redemption.go index 0f656be0..9eeeb943 100644 --- a/controller/redemption.go +++ b/controller/redemption.go @@ -3,7 +3,8 @@ package controller import ( "github.com/gin-gonic/gin" "net/http" - "one-api/common" + "one-api/common/config" + "one-api/common/helper" "one-api/model" "strconv" ) @@ -13,7 +14,7 @@ func GetAllRedemptions(c *gin.Context) { if p < 0 { p = 0 } - redemptions, err := model.GetAllRedemptions(p*common.ItemsPerPage, common.ItemsPerPage) + redemptions, err := model.GetAllRedemptions(p*config.ItemsPerPage, config.ItemsPerPage) if err != nil { c.JSON(http.StatusOK, gin.H{ "success": false, @@ -105,12 +106,12 @@ func AddRedemption(c *gin.Context) { } var keys []string for i := 0; i < redemption.Count; i++ { - key := common.GetUUID() + key := helper.GetUUID() cleanRedemption := model.Redemption{ UserId: c.GetInt("id"), Name: redemption.Name, Key: key, - CreatedTime: common.GetTimestamp(), + CreatedTime: helper.GetTimestamp(), Quota: redemption.Quota, } err = cleanRedemption.Insert() diff --git a/controller/relay-tencent.go b/controller/relay-tencent.go deleted file mode 100644 index 5930ae89..00000000 --- a/controller/relay-tencent.go +++ /dev/null @@ -1,288 +0,0 @@ -package controller - -import ( - "bufio" - "crypto/hmac" - "crypto/sha1" - "encoding/base64" - "encoding/json" - "errors" - "fmt" - "github.com/gin-gonic/gin" - "io" - "net/http" - "one-api/common" - "sort" - "strconv" - "strings" -) - -// https://cloud.tencent.com/document/product/1729/97732 - -type TencentMessage struct { - Role string `json:"role"` - Content string `json:"content"` -} - -type TencentChatRequest struct { - AppId int64 `json:"app_id"` // 腾讯云账号的 APPID - SecretId string `json:"secret_id"` // 官网 SecretId - // Timestamp当前 UNIX 时间戳,单位为秒,可记录发起 API 请求的时间。 - // 例如1529223702,如果与当前时间相差过大,会引起签名过期错误 - Timestamp int64 `json:"timestamp"` - // Expired 签名的有效期,是一个符合 UNIX Epoch 时间戳规范的数值, - // 单位为秒;Expired 必须大于 Timestamp 且 Expired-Timestamp 小于90天 - Expired int64 `json:"expired"` - QueryID string `json:"query_id"` //请求 Id,用于问题排查 - // Temperature 较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定 - // 默认 1.0,取值区间为[0.0,2.0],非必要不建议使用,不合理的取值会影响效果 - // 建议该参数和 top_p 只设置1个,不要同时更改 top_p - Temperature float64 `json:"temperature"` - // TopP 影响输出文本的多样性,取值越大,生成文本的多样性越强 - // 默认1.0,取值区间为[0.0, 1.0],非必要不建议使用, 不合理的取值会影响效果 - // 建议该参数和 temperature 只设置1个,不要同时更改 - TopP float64 `json:"top_p"` - // Stream 0:同步,1:流式 (默认,协议:SSE) - // 同步请求超时:60s,如果内容较长建议使用流式 - Stream int `json:"stream"` - // Messages 会话内容, 长度最多为40, 按对话时间从旧到新在数组中排列 - // 输入 content 总数最大支持 3000 token。 - Messages []TencentMessage `json:"messages"` -} - -type TencentError struct { - Code int `json:"code"` - Message string `json:"message"` -} - -type TencentUsage struct { - InputTokens int `json:"input_tokens"` - OutputTokens int `json:"output_tokens"` - TotalTokens int `json:"total_tokens"` -} - -type TencentResponseChoices struct { - FinishReason string `json:"finish_reason,omitempty"` // 流式结束标志位,为 stop 则表示尾包 - Messages TencentMessage `json:"messages,omitempty"` // 内容,同步模式返回内容,流模式为 null 输出 content 内容总数最多支持 1024token。 - Delta TencentMessage `json:"delta,omitempty"` // 内容,流模式返回内容,同步模式为 null 输出 content 内容总数最多支持 1024token。 -} - -type TencentChatResponse struct { - Choices []TencentResponseChoices `json:"choices,omitempty"` // 结果 - Created string `json:"created,omitempty"` // unix 时间戳的字符串 - Id string `json:"id,omitempty"` // 会话 id - Usage Usage `json:"usage,omitempty"` // token 数量 - Error TencentError `json:"error,omitempty"` // 错误信息 注意:此字段可能返回 null,表示取不到有效值 - Note string `json:"note,omitempty"` // 注释 - ReqID string `json:"req_id,omitempty"` // 唯一请求 Id,每次请求都会返回。用于反馈接口入参 -} - -func requestOpenAI2Tencent(request GeneralOpenAIRequest) *TencentChatRequest { - messages := make([]TencentMessage, 0, len(request.Messages)) - for i := 0; i < len(request.Messages); i++ { - message := request.Messages[i] - if message.Role == "system" { - messages = append(messages, TencentMessage{ - Role: "user", - Content: message.StringContent(), - }) - messages = append(messages, TencentMessage{ - Role: "assistant", - Content: "Okay", - }) - continue - } - messages = append(messages, TencentMessage{ - Content: message.StringContent(), - Role: message.Role, - }) - } - stream := 0 - if request.Stream { - stream = 1 - } - return &TencentChatRequest{ - Timestamp: common.GetTimestamp(), - Expired: common.GetTimestamp() + 24*60*60, - QueryID: common.GetUUID(), - Temperature: request.Temperature, - TopP: request.TopP, - Stream: stream, - Messages: messages, - } -} - -func responseTencent2OpenAI(response *TencentChatResponse) *OpenAITextResponse { - fullTextResponse := OpenAITextResponse{ - Object: "chat.completion", - Created: common.GetTimestamp(), - Usage: response.Usage, - } - if len(response.Choices) > 0 { - choice := OpenAITextResponseChoice{ - Index: 0, - Message: Message{ - Role: "assistant", - Content: response.Choices[0].Messages.Content, - }, - FinishReason: response.Choices[0].FinishReason, - } - fullTextResponse.Choices = append(fullTextResponse.Choices, choice) - } - return &fullTextResponse -} - -func streamResponseTencent2OpenAI(TencentResponse *TencentChatResponse) *ChatCompletionsStreamResponse { - response := ChatCompletionsStreamResponse{ - Object: "chat.completion.chunk", - Created: common.GetTimestamp(), - Model: "tencent-hunyuan", - } - if len(TencentResponse.Choices) > 0 { - var choice ChatCompletionsStreamResponseChoice - choice.Delta.Content = TencentResponse.Choices[0].Delta.Content - if TencentResponse.Choices[0].FinishReason == "stop" { - choice.FinishReason = &stopFinishReason - } - response.Choices = append(response.Choices, choice) - } - return &response -} - -func tencentStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, string) { - var responseText string - scanner := bufio.NewScanner(resp.Body) - scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { - if atEOF && len(data) == 0 { - return 0, nil, nil - } - if i := strings.Index(string(data), "\n"); i >= 0 { - return i + 1, data[0:i], nil - } - if atEOF { - return len(data), data, nil - } - return 0, nil, nil - }) - dataChan := make(chan string) - stopChan := make(chan bool) - go func() { - for scanner.Scan() { - data := scanner.Text() - if len(data) < 5 { // ignore blank line or wrong format - continue - } - if data[:5] != "data:" { - continue - } - data = data[5:] - dataChan <- data - } - stopChan <- true - }() - setEventStreamHeaders(c) - c.Stream(func(w io.Writer) bool { - select { - case data := <-dataChan: - var TencentResponse TencentChatResponse - err := json.Unmarshal([]byte(data), &TencentResponse) - if err != nil { - common.SysError("error unmarshalling stream response: " + err.Error()) - return true - } - response := streamResponseTencent2OpenAI(&TencentResponse) - if len(response.Choices) != 0 { - responseText += response.Choices[0].Delta.Content - } - jsonResponse, err := json.Marshal(response) - if err != nil { - common.SysError("error marshalling stream response: " + err.Error()) - return true - } - c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)}) - return true - case <-stopChan: - c.Render(-1, common.CustomEvent{Data: "data: [DONE]"}) - return false - } - }) - err := resp.Body.Close() - if err != nil { - return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), "" - } - return nil, responseText -} - -func tencentHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) { - var TencentResponse TencentChatResponse - responseBody, err := io.ReadAll(resp.Body) - if err != nil { - return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil - } - err = resp.Body.Close() - if err != nil { - return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil - } - err = json.Unmarshal(responseBody, &TencentResponse) - if err != nil { - return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil - } - if TencentResponse.Error.Code != 0 { - return &OpenAIErrorWithStatusCode{ - OpenAIError: OpenAIError{ - Message: TencentResponse.Error.Message, - Code: TencentResponse.Error.Code, - }, - StatusCode: resp.StatusCode, - }, nil - } - fullTextResponse := responseTencent2OpenAI(&TencentResponse) - fullTextResponse.Model = "hunyuan" - jsonResponse, err := json.Marshal(fullTextResponse) - if err != nil { - return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil - } - c.Writer.Header().Set("Content-Type", "application/json") - c.Writer.WriteHeader(resp.StatusCode) - _, err = c.Writer.Write(jsonResponse) - return nil, &fullTextResponse.Usage -} - -func parseTencentConfig(config string) (appId int64, secretId string, secretKey string, err error) { - parts := strings.Split(config, "|") - if len(parts) != 3 { - err = errors.New("invalid tencent config") - return - } - appId, err = strconv.ParseInt(parts[0], 10, 64) - secretId = parts[1] - secretKey = parts[2] - return -} - -func getTencentSign(req TencentChatRequest, secretKey string) string { - params := make([]string, 0) - params = append(params, "app_id="+strconv.FormatInt(req.AppId, 10)) - params = append(params, "secret_id="+req.SecretId) - params = append(params, "timestamp="+strconv.FormatInt(req.Timestamp, 10)) - params = append(params, "query_id="+req.QueryID) - params = append(params, "temperature="+strconv.FormatFloat(req.Temperature, 'f', -1, 64)) - params = append(params, "top_p="+strconv.FormatFloat(req.TopP, 'f', -1, 64)) - params = append(params, "stream="+strconv.Itoa(req.Stream)) - params = append(params, "expired="+strconv.FormatInt(req.Expired, 10)) - - var messageStr string - for _, msg := range req.Messages { - messageStr += fmt.Sprintf(`{"role":"%s","content":"%s"},`, msg.Role, msg.Content) - } - messageStr = strings.TrimSuffix(messageStr, ",") - params = append(params, "messages=["+messageStr+"]") - - sort.Sort(sort.StringSlice(params)) - url := "hunyuan.cloud.tencent.com/hyllm/v1/chat/completions?" + strings.Join(params, "&") - mac := hmac.New(sha1.New, []byte(secretKey)) - signURL := url - mac.Write([]byte(signURL)) - sign := mac.Sum([]byte(nil)) - return base64.StdEncoding.EncodeToString(sign) -} diff --git a/controller/relay-text.go b/controller/relay-text.go deleted file mode 100644 index 64338545..00000000 --- a/controller/relay-text.go +++ /dev/null @@ -1,689 +0,0 @@ -package controller - -import ( - "bytes" - "context" - "encoding/json" - "errors" - "fmt" - "io" - "math" - "net/http" - "one-api/common" - "one-api/model" - "strings" - "time" - - "github.com/gin-gonic/gin" -) - -const ( - APITypeOpenAI = iota - APITypeClaude - APITypePaLM - APITypeBaidu - APITypeZhipu - APITypeAli - APITypeXunfei - APITypeAIProxyLibrary - APITypeTencent - APITypeGemini -) - -var httpClient *http.Client -var impatientHTTPClient *http.Client - -func init() { - if common.RelayTimeout == 0 { - httpClient = &http.Client{} - } else { - httpClient = &http.Client{ - Timeout: time.Duration(common.RelayTimeout) * time.Second, - } - } - - impatientHTTPClient = &http.Client{ - Timeout: 5 * time.Second, - } -} - -func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { - channelType := c.GetInt("channel") - channelId := c.GetInt("channel_id") - tokenId := c.GetInt("token_id") - userId := c.GetInt("id") - group := c.GetString("group") - var textRequest GeneralOpenAIRequest - err := common.UnmarshalBodyReusable(c, &textRequest) - if err != nil { - return errorWrapper(err, "bind_request_body_failed", http.StatusBadRequest) - } - if textRequest.MaxTokens < 0 || textRequest.MaxTokens > math.MaxInt32/2 { - return errorWrapper(errors.New("max_tokens is invalid"), "invalid_max_tokens", http.StatusBadRequest) - } - if relayMode == RelayModeModerations && textRequest.Model == "" { - textRequest.Model = "text-moderation-latest" - } - if relayMode == RelayModeEmbeddings && textRequest.Model == "" { - textRequest.Model = c.Param("model") - } - // request validation - if textRequest.Model == "" { - return errorWrapper(errors.New("model is required"), "required_field_missing", http.StatusBadRequest) - } - switch relayMode { - case RelayModeCompletions: - if textRequest.Prompt == "" { - return errorWrapper(errors.New("field prompt is required"), "required_field_missing", http.StatusBadRequest) - } - case RelayModeChatCompletions: - if textRequest.Messages == nil || len(textRequest.Messages) == 0 { - return errorWrapper(errors.New("field messages is required"), "required_field_missing", http.StatusBadRequest) - } - case RelayModeEmbeddings: - case RelayModeModerations: - if textRequest.Input == "" { - return errorWrapper(errors.New("field input is required"), "required_field_missing", http.StatusBadRequest) - } - case RelayModeEdits: - if textRequest.Instruction == "" { - return errorWrapper(errors.New("field instruction is required"), "required_field_missing", http.StatusBadRequest) - } - } - // map model name - modelMapping := c.GetString("model_mapping") - isModelMapped := false - if modelMapping != "" && modelMapping != "{}" { - modelMap := make(map[string]string) - err := json.Unmarshal([]byte(modelMapping), &modelMap) - if err != nil { - return errorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError) - } - if modelMap[textRequest.Model] != "" { - textRequest.Model = modelMap[textRequest.Model] - isModelMapped = true - } - } - apiType := APITypeOpenAI - switch channelType { - case common.ChannelTypeAnthropic: - apiType = APITypeClaude - case common.ChannelTypeBaidu: - apiType = APITypeBaidu - case common.ChannelTypePaLM: - apiType = APITypePaLM - case common.ChannelTypeZhipu: - apiType = APITypeZhipu - case common.ChannelTypeAli: - apiType = APITypeAli - case common.ChannelTypeXunfei: - apiType = APITypeXunfei - case common.ChannelTypeAIProxyLibrary: - apiType = APITypeAIProxyLibrary - case common.ChannelTypeTencent: - apiType = APITypeTencent - case common.ChannelTypeGemini: - apiType = APITypeGemini - } - baseURL := common.ChannelBaseURLs[channelType] - requestURL := c.Request.URL.String() - if c.GetString("base_url") != "" { - baseURL = c.GetString("base_url") - } - fullRequestURL := getFullRequestURL(baseURL, requestURL, channelType) - switch apiType { - case APITypeOpenAI: - if channelType == common.ChannelTypeAzure { - // https://learn.microsoft.com/en-us/azure/cognitive-services/openai/chatgpt-quickstart?pivots=rest-api&tabs=command-line#rest-api - apiVersion := GetAPIVersion(c) - requestURL := strings.Split(requestURL, "?")[0] - requestURL = fmt.Sprintf("%s?api-version=%s", requestURL, apiVersion) - baseURL = c.GetString("base_url") - task := strings.TrimPrefix(requestURL, "/v1/") - model_ := textRequest.Model - model_ = strings.Replace(model_, ".", "", -1) - // https://github.com/songquanpeng/one-api/issues/67 - model_ = strings.TrimSuffix(model_, "-0301") - model_ = strings.TrimSuffix(model_, "-0314") - model_ = strings.TrimSuffix(model_, "-0613") - - requestURL = fmt.Sprintf("/openai/deployments/%s/%s", model_, task) - fullRequestURL = getFullRequestURL(baseURL, requestURL, channelType) - } - case APITypeClaude: - fullRequestURL = "https://api.anthropic.com/v1/complete" - if baseURL != "" { - fullRequestURL = fmt.Sprintf("%s/v1/complete", baseURL) - } - case APITypeBaidu: - switch textRequest.Model { - case "ERNIE-Bot": - fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions" - case "ERNIE-Bot-turbo": - fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant" - case "ERNIE-Bot-4": - fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro" - case "BLOOMZ-7B": - fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/bloomz_7b1" - case "Embedding-V1": - fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings/embedding-v1" - } - apiKey := c.Request.Header.Get("Authorization") - apiKey = strings.TrimPrefix(apiKey, "Bearer ") - var err error - if apiKey, err = getBaiduAccessToken(apiKey); err != nil { - return errorWrapper(err, "invalid_baidu_config", http.StatusInternalServerError) - } - fullRequestURL += "?access_token=" + apiKey - case APITypePaLM: - fullRequestURL = "https://generativelanguage.googleapis.com/v1beta2/models/chat-bison-001:generateMessage" - if baseURL != "" { - fullRequestURL = fmt.Sprintf("%s/v1beta2/models/chat-bison-001:generateMessage", baseURL) - } - case APITypeGemini: - requestBaseURL := "https://generativelanguage.googleapis.com" - if baseURL != "" { - requestBaseURL = baseURL - } - version := "v1" - if c.GetString("api_version") != "" { - version = c.GetString("api_version") - } - action := "generateContent" - if textRequest.Stream { - action = "streamGenerateContent" - } - fullRequestURL = fmt.Sprintf("%s/%s/models/%s:%s", requestBaseURL, version, textRequest.Model, action) - case APITypeZhipu: - method := "invoke" - if textRequest.Stream { - method = "sse-invoke" - } - fullRequestURL = fmt.Sprintf("https://open.bigmodel.cn/api/paas/v3/model-api/%s/%s", textRequest.Model, method) - case APITypeAli: - fullRequestURL = "https://dashscope.aliyuncs.com/api/v1/services/aigc/text-generation/generation" - if relayMode == RelayModeEmbeddings { - fullRequestURL = "https://dashscope.aliyuncs.com/api/v1/services/embeddings/text-embedding/text-embedding" - } - case APITypeTencent: - fullRequestURL = "https://hunyuan.cloud.tencent.com/hyllm/v1/chat/completions" - case APITypeAIProxyLibrary: - fullRequestURL = fmt.Sprintf("%s/api/library/ask", baseURL) - } - var promptTokens int - var completionTokens int - switch relayMode { - case RelayModeChatCompletions: - promptTokens = countTokenMessages(textRequest.Messages, textRequest.Model) - case RelayModeCompletions: - promptTokens = countTokenInput(textRequest.Prompt, textRequest.Model) - case RelayModeModerations: - promptTokens = countTokenInput(textRequest.Input, textRequest.Model) - } - preConsumedTokens := common.PreConsumedQuota - if textRequest.MaxTokens != 0 { - preConsumedTokens = promptTokens + textRequest.MaxTokens - } - modelRatio := common.GetModelRatio(textRequest.Model) - groupRatio := common.GetGroupRatio(group) - ratio := modelRatio * groupRatio - preConsumedQuota := int(float64(preConsumedTokens) * ratio) - userQuota, err := model.CacheGetUserQuota(userId) - if err != nil { - return errorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError) - } - if userQuota-preConsumedQuota < 0 { - return errorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden) - } - err = model.CacheDecreaseUserQuota(userId, preConsumedQuota) - if err != nil { - return errorWrapper(err, "decrease_user_quota_failed", http.StatusInternalServerError) - } - if userQuota > 100*preConsumedQuota { - // in this case, we do not pre-consume quota - // because the user has enough quota - preConsumedQuota = 0 - common.LogInfo(c.Request.Context(), fmt.Sprintf("user %d has enough quota %d, trusted and no need to pre-consume", userId, userQuota)) - } - if preConsumedQuota > 0 { - err := model.PreConsumeTokenQuota(tokenId, preConsumedQuota) - if err != nil { - return errorWrapper(err, "pre_consume_token_quota_failed", http.StatusForbidden) - } - } - var requestBody io.Reader - if isModelMapped { - jsonStr, err := json.Marshal(textRequest) - if err != nil { - return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError) - } - requestBody = bytes.NewBuffer(jsonStr) - } else { - requestBody = c.Request.Body - } - switch apiType { - case APITypeClaude: - claudeRequest := requestOpenAI2Claude(textRequest) - jsonStr, err := json.Marshal(claudeRequest) - if err != nil { - return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError) - } - requestBody = bytes.NewBuffer(jsonStr) - case APITypeBaidu: - var jsonData []byte - var err error - switch relayMode { - case RelayModeEmbeddings: - baiduEmbeddingRequest := embeddingRequestOpenAI2Baidu(textRequest) - jsonData, err = json.Marshal(baiduEmbeddingRequest) - default: - baiduRequest := requestOpenAI2Baidu(textRequest) - jsonData, err = json.Marshal(baiduRequest) - } - if err != nil { - return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError) - } - requestBody = bytes.NewBuffer(jsonData) - case APITypePaLM: - palmRequest := requestOpenAI2PaLM(textRequest) - jsonStr, err := json.Marshal(palmRequest) - if err != nil { - return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError) - } - requestBody = bytes.NewBuffer(jsonStr) - case APITypeGemini: - geminiChatRequest := requestOpenAI2Gemini(textRequest) - jsonStr, err := json.Marshal(geminiChatRequest) - if err != nil { - return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError) - } - requestBody = bytes.NewBuffer(jsonStr) - case APITypeZhipu: - zhipuRequest := requestOpenAI2Zhipu(textRequest) - jsonStr, err := json.Marshal(zhipuRequest) - if err != nil { - return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError) - } - requestBody = bytes.NewBuffer(jsonStr) - case APITypeAli: - var jsonStr []byte - var err error - switch relayMode { - case RelayModeEmbeddings: - aliEmbeddingRequest := embeddingRequestOpenAI2Ali(textRequest) - jsonStr, err = json.Marshal(aliEmbeddingRequest) - default: - aliRequest := requestOpenAI2Ali(textRequest) - jsonStr, err = json.Marshal(aliRequest) - } - if err != nil { - return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError) - } - requestBody = bytes.NewBuffer(jsonStr) - case APITypeTencent: - apiKey := c.Request.Header.Get("Authorization") - apiKey = strings.TrimPrefix(apiKey, "Bearer ") - appId, secretId, secretKey, err := parseTencentConfig(apiKey) - if err != nil { - return errorWrapper(err, "invalid_tencent_config", http.StatusInternalServerError) - } - tencentRequest := requestOpenAI2Tencent(textRequest) - tencentRequest.AppId = appId - tencentRequest.SecretId = secretId - jsonStr, err := json.Marshal(tencentRequest) - if err != nil { - return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError) - } - sign := getTencentSign(*tencentRequest, secretKey) - c.Request.Header.Set("Authorization", sign) - requestBody = bytes.NewBuffer(jsonStr) - case APITypeAIProxyLibrary: - aiProxyLibraryRequest := requestOpenAI2AIProxyLibrary(textRequest) - aiProxyLibraryRequest.LibraryId = c.GetString("library_id") - jsonStr, err := json.Marshal(aiProxyLibraryRequest) - if err != nil { - return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError) - } - requestBody = bytes.NewBuffer(jsonStr) - } - - var req *http.Request - var resp *http.Response - isStream := textRequest.Stream - - if apiType != APITypeXunfei { // cause xunfei use websocket - req, err = http.NewRequest(c.Request.Method, fullRequestURL, requestBody) - if err != nil { - return errorWrapper(err, "new_request_failed", http.StatusInternalServerError) - } - apiKey := c.Request.Header.Get("Authorization") - apiKey = strings.TrimPrefix(apiKey, "Bearer ") - switch apiType { - case APITypeOpenAI: - if channelType == common.ChannelTypeAzure { - req.Header.Set("api-key", apiKey) - } else { - req.Header.Set("Authorization", c.Request.Header.Get("Authorization")) - if channelType == common.ChannelTypeOpenRouter { - req.Header.Set("HTTP-Referer", "https://github.com/songquanpeng/one-api") - req.Header.Set("X-Title", "One API") - } - } - case APITypeClaude: - req.Header.Set("x-api-key", apiKey) - anthropicVersion := c.Request.Header.Get("anthropic-version") - if anthropicVersion == "" { - anthropicVersion = "2023-06-01" - } - req.Header.Set("anthropic-version", anthropicVersion) - case APITypeZhipu: - token := getZhipuToken(apiKey) - req.Header.Set("Authorization", token) - case APITypeAli: - req.Header.Set("Authorization", "Bearer "+apiKey) - if textRequest.Stream { - req.Header.Set("X-DashScope-SSE", "enable") - } - if c.GetString("plugin") != "" { - req.Header.Set("X-DashScope-Plugin", c.GetString("plugin")) - } - case APITypeTencent: - req.Header.Set("Authorization", apiKey) - case APITypePaLM: - req.Header.Set("x-goog-api-key", apiKey) - case APITypeGemini: - req.Header.Set("x-goog-api-key", apiKey) - default: - req.Header.Set("Authorization", "Bearer "+apiKey) - } - req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type")) - req.Header.Set("Accept", c.Request.Header.Get("Accept")) - if isStream && c.Request.Header.Get("Accept") == "" { - req.Header.Set("Accept", "text/event-stream") - } - //req.Header.Set("Connection", c.Request.Header.Get("Connection")) - resp, err = httpClient.Do(req) - if err != nil { - return errorWrapper(err, "do_request_failed", http.StatusInternalServerError) - } - err = req.Body.Close() - if err != nil { - return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError) - } - err = c.Request.Body.Close() - if err != nil { - return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError) - } - isStream = isStream || strings.HasPrefix(resp.Header.Get("Content-Type"), "text/event-stream") - - if resp.StatusCode != http.StatusOK { - if preConsumedQuota != 0 { - go func(ctx context.Context) { - // return pre-consumed quota - err := model.PostConsumeTokenQuota(tokenId, -preConsumedQuota) - if err != nil { - common.LogError(ctx, "error return pre-consumed quota: "+err.Error()) - } - }(c.Request.Context()) - } - return relayErrorHandler(resp) - } - } - - var textResponse TextResponse - tokenName := c.GetString("token_name") - - defer func(ctx context.Context) { - // c.Writer.Flush() - go func() { - quota := 0 - completionRatio := common.GetCompletionRatio(textRequest.Model) - promptTokens = textResponse.Usage.PromptTokens - completionTokens = textResponse.Usage.CompletionTokens - quota = int(math.Ceil((float64(promptTokens) + float64(completionTokens)*completionRatio) * ratio)) - if ratio != 0 && quota <= 0 { - quota = 1 - } - totalTokens := promptTokens + completionTokens - if totalTokens == 0 { - // in this case, must be some error happened - // we cannot just return, because we may have to return the pre-consumed quota - quota = 0 - } - quotaDelta := quota - preConsumedQuota - err := model.PostConsumeTokenQuota(tokenId, quotaDelta) - if err != nil { - common.LogError(ctx, "error consuming token remain quota: "+err.Error()) - } - err = model.CacheUpdateUserQuota(userId) - if err != nil { - common.LogError(ctx, "error update user quota cache: "+err.Error()) - } - if quota != 0 { - logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio) - model.RecordConsumeLog(ctx, userId, channelId, promptTokens, completionTokens, textRequest.Model, tokenName, quota, logContent) - model.UpdateUserUsedQuotaAndRequestCount(userId, quota) - model.UpdateChannelUsedQuota(channelId, quota) - } - - }() - }(c.Request.Context()) - switch apiType { - case APITypeOpenAI: - if isStream { - err, responseText := openaiStreamHandler(c, resp, relayMode) - if err != nil { - return err - } - textResponse.Usage.PromptTokens = promptTokens - textResponse.Usage.CompletionTokens = countTokenText(responseText, textRequest.Model) - return nil - } else { - err, usage := openaiHandler(c, resp, promptTokens, textRequest.Model) - if err != nil { - return err - } - if usage != nil { - textResponse.Usage = *usage - } - return nil - } - case APITypeClaude: - if isStream { - err, responseText := claudeStreamHandler(c, resp) - if err != nil { - return err - } - textResponse.Usage.PromptTokens = promptTokens - textResponse.Usage.CompletionTokens = countTokenText(responseText, textRequest.Model) - return nil - } else { - err, usage := claudeHandler(c, resp, promptTokens, textRequest.Model) - if err != nil { - return err - } - if usage != nil { - textResponse.Usage = *usage - } - return nil - } - case APITypeBaidu: - if isStream { - err, usage := baiduStreamHandler(c, resp) - if err != nil { - return err - } - if usage != nil { - textResponse.Usage = *usage - } - return nil - } else { - var err *OpenAIErrorWithStatusCode - var usage *Usage - switch relayMode { - case RelayModeEmbeddings: - err, usage = baiduEmbeddingHandler(c, resp) - default: - err, usage = baiduHandler(c, resp) - } - if err != nil { - return err - } - if usage != nil { - textResponse.Usage = *usage - } - return nil - } - case APITypePaLM: - if textRequest.Stream { // PaLM2 API does not support stream - err, responseText := palmStreamHandler(c, resp) - if err != nil { - return err - } - textResponse.Usage.PromptTokens = promptTokens - textResponse.Usage.CompletionTokens = countTokenText(responseText, textRequest.Model) - return nil - } else { - err, usage := palmHandler(c, resp, promptTokens, textRequest.Model) - if err != nil { - return err - } - if usage != nil { - textResponse.Usage = *usage - } - return nil - } - case APITypeGemini: - if textRequest.Stream { - err, responseText := geminiChatStreamHandler(c, resp) - if err != nil { - return err - } - textResponse.Usage.PromptTokens = promptTokens - textResponse.Usage.CompletionTokens = countTokenText(responseText, textRequest.Model) - return nil - } else { - err, usage := geminiChatHandler(c, resp, promptTokens, textRequest.Model) - if err != nil { - return err - } - if usage != nil { - textResponse.Usage = *usage - } - return nil - } - case APITypeZhipu: - if isStream { - err, usage := zhipuStreamHandler(c, resp) - if err != nil { - return err - } - if usage != nil { - textResponse.Usage = *usage - } - // zhipu's API does not return prompt tokens & completion tokens - textResponse.Usage.PromptTokens = textResponse.Usage.TotalTokens - return nil - } else { - err, usage := zhipuHandler(c, resp) - if err != nil { - return err - } - if usage != nil { - textResponse.Usage = *usage - } - // zhipu's API does not return prompt tokens & completion tokens - textResponse.Usage.PromptTokens = textResponse.Usage.TotalTokens - return nil - } - case APITypeAli: - if isStream { - err, usage := aliStreamHandler(c, resp) - if err != nil { - return err - } - if usage != nil { - textResponse.Usage = *usage - } - return nil - } else { - var err *OpenAIErrorWithStatusCode - var usage *Usage - switch relayMode { - case RelayModeEmbeddings: - err, usage = aliEmbeddingHandler(c, resp) - default: - err, usage = aliHandler(c, resp) - } - if err != nil { - return err - } - if usage != nil { - textResponse.Usage = *usage - } - return nil - } - case APITypeXunfei: - auth := c.Request.Header.Get("Authorization") - auth = strings.TrimPrefix(auth, "Bearer ") - splits := strings.Split(auth, "|") - if len(splits) != 3 { - return errorWrapper(errors.New("invalid auth"), "invalid_auth", http.StatusBadRequest) - } - var err *OpenAIErrorWithStatusCode - var usage *Usage - if isStream { - err, usage = xunfeiStreamHandler(c, textRequest, splits[0], splits[1], splits[2]) - } else { - err, usage = xunfeiHandler(c, textRequest, splits[0], splits[1], splits[2]) - } - if err != nil { - return err - } - if usage != nil { - textResponse.Usage = *usage - } - return nil - case APITypeAIProxyLibrary: - if isStream { - err, usage := aiProxyLibraryStreamHandler(c, resp) - if err != nil { - return err - } - if usage != nil { - textResponse.Usage = *usage - } - return nil - } else { - err, usage := aiProxyLibraryHandler(c, resp) - if err != nil { - return err - } - if usage != nil { - textResponse.Usage = *usage - } - return nil - } - case APITypeTencent: - if isStream { - err, responseText := tencentStreamHandler(c, resp) - if err != nil { - return err - } - textResponse.Usage.PromptTokens = promptTokens - textResponse.Usage.CompletionTokens = countTokenText(responseText, textRequest.Model) - return nil - } else { - err, usage := tencentHandler(c, resp) - if err != nil { - return err - } - if usage != nil { - textResponse.Usage = *usage - } - return nil - } - default: - return errorWrapper(errors.New("unknown api type"), "unknown_api_type", http.StatusInternalServerError) - } -} diff --git a/controller/relay.go b/controller/relay.go index e45fd3eb..46fedc7e 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -2,374 +2,57 @@ package controller import ( "fmt" - "net/http" - "one-api/common" - "strconv" - "strings" - "github.com/gin-gonic/gin" -) - -type Message struct { - Role string `json:"role"` - Content any `json:"content"` - Name *string `json:"name,omitempty"` -} - -type ImageURL struct { - Url string `json:"url,omitempty"` - Detail string `json:"detail,omitempty"` -} - -type TextContent struct { - Type string `json:"type,omitempty"` - Text string `json:"text,omitempty"` -} - -type ImageContent struct { - Type string `json:"type,omitempty"` - ImageURL *ImageURL `json:"image_url,omitempty"` -} - -const ( - ContentTypeText = "text" - ContentTypeImageURL = "image_url" -) - -type OpenAIMessageContent struct { - Type string `json:"type,omitempty"` - Text string `json:"text"` - ImageURL *ImageURL `json:"image_url,omitempty"` -} - -func (m Message) IsStringContent() bool { - _, ok := m.Content.(string) - return ok -} - -func (m Message) StringContent() string { - content, ok := m.Content.(string) - if ok { - return content - } - contentList, ok := m.Content.([]any) - if ok { - var contentStr string - for _, contentItem := range contentList { - contentMap, ok := contentItem.(map[string]any) - if !ok { - continue - } - if contentMap["type"] == ContentTypeText { - if subStr, ok := contentMap["text"].(string); ok { - contentStr += subStr - } - } - } - return contentStr - } - return "" -} - -func (m Message) ParseContent() []OpenAIMessageContent { - var contentList []OpenAIMessageContent - content, ok := m.Content.(string) - if ok { - contentList = append(contentList, OpenAIMessageContent{ - Type: ContentTypeText, - Text: content, - }) - return contentList - } - anyList, ok := m.Content.([]any) - if ok { - for _, contentItem := range anyList { - contentMap, ok := contentItem.(map[string]any) - if !ok { - continue - } - switch contentMap["type"] { - case ContentTypeText: - if subStr, ok := contentMap["text"].(string); ok { - contentList = append(contentList, OpenAIMessageContent{ - Type: ContentTypeText, - Text: subStr, - }) - } - case ContentTypeImageURL: - if subObj, ok := contentMap["image_url"].(map[string]any); ok { - contentList = append(contentList, OpenAIMessageContent{ - Type: ContentTypeImageURL, - ImageURL: &ImageURL{ - Url: subObj["url"].(string), - }, - }) - } - } - } - return contentList - } - return nil -} - -const ( - RelayModeUnknown = iota - RelayModeChatCompletions - RelayModeCompletions - RelayModeEmbeddings - RelayModeModerations - RelayModeImagesGenerations - RelayModeEdits - RelayModeAudioSpeech - RelayModeAudioTranscription - RelayModeAudioTranslation + "net/http" + "one-api/common/config" + "one-api/common/helper" + "one-api/common/logger" + "one-api/relay/channel/openai" + "one-api/relay/constant" + "one-api/relay/controller" + "one-api/relay/util" + "strconv" ) // https://platform.openai.com/docs/api-reference/chat -type ResponseFormat struct { - Type string `json:"type,omitempty"` -} - -type GeneralOpenAIRequest struct { - Model string `json:"model,omitempty"` - Messages []Message `json:"messages,omitempty"` - Prompt any `json:"prompt,omitempty"` - Stream bool `json:"stream,omitempty"` - MaxTokens int `json:"max_tokens,omitempty"` - Temperature float64 `json:"temperature,omitempty"` - TopP float64 `json:"top_p,omitempty"` - N int `json:"n,omitempty"` - Input any `json:"input,omitempty"` - Instruction string `json:"instruction,omitempty"` - Size string `json:"size,omitempty"` - Functions any `json:"functions,omitempty"` - FrequencyPenalty float64 `json:"frequency_penalty,omitempty"` - PresencePenalty float64 `json:"presence_penalty,omitempty"` - ResponseFormat *ResponseFormat `json:"response_format,omitempty"` - Seed float64 `json:"seed,omitempty"` - Tools any `json:"tools,omitempty"` - ToolChoice any `json:"tool_choice,omitempty"` - User string `json:"user,omitempty"` -} - -func (r GeneralOpenAIRequest) ParseInput() []string { - if r.Input == nil { - return nil - } - var input []string - switch r.Input.(type) { - case string: - input = []string{r.Input.(string)} - case []any: - input = make([]string, 0, len(r.Input.([]any))) - for _, item := range r.Input.([]any) { - if str, ok := item.(string); ok { - input = append(input, str) - } - } - } - return input -} - -type ChatRequest struct { - Model string `json:"model"` - Messages []Message `json:"messages"` - MaxTokens int `json:"max_tokens"` -} - -type TextRequest struct { - Model string `json:"model"` - Messages []Message `json:"messages"` - Prompt string `json:"prompt"` - MaxTokens int `json:"max_tokens"` - //Stream bool `json:"stream"` -} - -// ImageRequest docs: https://platform.openai.com/docs/api-reference/images/create -type ImageRequest struct { - Model string `json:"model"` - Prompt string `json:"prompt" binding:"required"` - N int `json:"n,omitempty"` - Size string `json:"size,omitempty"` - Quality string `json:"quality,omitempty"` - ResponseFormat string `json:"response_format,omitempty"` - Style string `json:"style,omitempty"` - User string `json:"user,omitempty"` -} - -type WhisperJSONResponse struct { - Text string `json:"text,omitempty"` -} - -type WhisperVerboseJSONResponse struct { - Task string `json:"task,omitempty"` - Language string `json:"language,omitempty"` - Duration float64 `json:"duration,omitempty"` - Text string `json:"text,omitempty"` - Segments []Segment `json:"segments,omitempty"` -} - -type Segment struct { - Id int `json:"id"` - Seek int `json:"seek"` - Start float64 `json:"start"` - End float64 `json:"end"` - Text string `json:"text"` - Tokens []int `json:"tokens"` - Temperature float64 `json:"temperature"` - AvgLogprob float64 `json:"avg_logprob"` - CompressionRatio float64 `json:"compression_ratio"` - NoSpeechProb float64 `json:"no_speech_prob"` -} - -type TextToSpeechRequest struct { - Model string `json:"model" binding:"required"` - Input string `json:"input" binding:"required"` - Voice string `json:"voice" binding:"required"` - Speed float64 `json:"speed"` - ResponseFormat string `json:"response_format"` -} - -type Usage struct { - PromptTokens int `json:"prompt_tokens"` - CompletionTokens int `json:"completion_tokens"` - TotalTokens int `json:"total_tokens"` -} - -type OpenAIError struct { - Message string `json:"message"` - Type string `json:"type"` - Param string `json:"param"` - Code any `json:"code"` -} - -type OpenAIErrorWithStatusCode struct { - OpenAIError - StatusCode int `json:"status_code"` -} - -type TextResponse struct { - Choices []OpenAITextResponseChoice `json:"choices"` - Usage `json:"usage"` - Error OpenAIError `json:"error"` -} - -type OpenAITextResponseChoice struct { - Index int `json:"index"` - Message `json:"message"` - FinishReason string `json:"finish_reason"` -} - -type OpenAITextResponse struct { - Id string `json:"id"` - Model string `json:"model,omitempty"` - Object string `json:"object"` - Created int64 `json:"created"` - Choices []OpenAITextResponseChoice `json:"choices"` - Usage `json:"usage"` -} - -type OpenAIEmbeddingResponseItem struct { - Object string `json:"object"` - Index int `json:"index"` - Embedding []float64 `json:"embedding"` -} - -type OpenAIEmbeddingResponse struct { - Object string `json:"object"` - Data []OpenAIEmbeddingResponseItem `json:"data"` - Model string `json:"model"` - Usage `json:"usage"` -} - -type ImageResponse struct { - Created int `json:"created"` - Data []struct { - Url string `json:"url"` - } -} - -type ChatCompletionsStreamResponseChoice struct { - Delta struct { - Content string `json:"content"` - } `json:"delta"` - FinishReason *string `json:"finish_reason,omitempty"` -} - -type ChatCompletionsStreamResponse struct { - Id string `json:"id"` - Object string `json:"object"` - Created int64 `json:"created"` - Model string `json:"model"` - Choices []ChatCompletionsStreamResponseChoice `json:"choices"` -} - -type CompletionsStreamResponse struct { - Choices []struct { - Text string `json:"text"` - FinishReason string `json:"finish_reason"` - } `json:"choices"` -} - func Relay(c *gin.Context) { - relayMode := RelayModeUnknown - if strings.HasPrefix(c.Request.URL.Path, "/v1/chat/completions") { - relayMode = RelayModeChatCompletions - } else if strings.HasPrefix(c.Request.URL.Path, "/v1/completions") { - relayMode = RelayModeCompletions - } else if strings.HasPrefix(c.Request.URL.Path, "/v1/embeddings") { - relayMode = RelayModeEmbeddings - } else if strings.HasSuffix(c.Request.URL.Path, "embeddings") { - relayMode = RelayModeEmbeddings - } else if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") { - relayMode = RelayModeModerations - } else if strings.HasPrefix(c.Request.URL.Path, "/v1/images/generations") { - relayMode = RelayModeImagesGenerations - } else if strings.HasPrefix(c.Request.URL.Path, "/v1/edits") { - relayMode = RelayModeEdits - } else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/speech") { - relayMode = RelayModeAudioSpeech - } else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") { - relayMode = RelayModeAudioTranscription - } else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/translations") { - relayMode = RelayModeAudioTranslation - } - var err *OpenAIErrorWithStatusCode + relayMode := constant.Path2RelayMode(c.Request.URL.Path) + var err *openai.ErrorWithStatusCode switch relayMode { - case RelayModeImagesGenerations: - err = relayImageHelper(c, relayMode) - case RelayModeAudioSpeech: + case constant.RelayModeImagesGenerations: + err = controller.RelayImageHelper(c, relayMode) + case constant.RelayModeAudioSpeech: fallthrough - case RelayModeAudioTranslation: + case constant.RelayModeAudioTranslation: fallthrough - case RelayModeAudioTranscription: - err = relayAudioHelper(c, relayMode) + case constant.RelayModeAudioTranscription: + err = controller.RelayAudioHelper(c, relayMode) default: - err = relayTextHelper(c, relayMode) + err = controller.RelayTextHelper(c, relayMode) } if err != nil { - requestId := c.GetString(common.RequestIdKey) + requestId := c.GetString(logger.RequestIdKey) retryTimesStr := c.Query("retry") retryTimes, _ := strconv.Atoi(retryTimesStr) if retryTimesStr == "" { - retryTimes = common.RetryTimes + retryTimes = config.RetryTimes } if retryTimes > 0 { c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s?retry=%d", c.Request.URL.Path, retryTimes-1)) } else { if err.StatusCode == http.StatusTooManyRequests { - err.OpenAIError.Message = "当前分组上游负载已饱和,请稍后再试" + err.Error.Message = "当前分组上游负载已饱和,请稍后再试" } - err.OpenAIError.Message = common.MessageWithRequestId(err.OpenAIError.Message, requestId) + err.Error.Message = helper.MessageWithRequestId(err.Error.Message, requestId) c.JSON(err.StatusCode, gin.H{ - "error": err.OpenAIError, + "error": err.Error, }) } channelId := c.GetInt("channel_id") - common.LogError(c.Request.Context(), fmt.Sprintf("relay error (channel #%d): %s", channelId, err.Message)) + logger.Error(c.Request.Context(), fmt.Sprintf("relay error (channel #%d): %s", channelId, err.Message)) // https://platform.openai.com/docs/guides/error-codes/api-errors - if shouldDisableChannel(&err.OpenAIError, err.StatusCode) { + if util.ShouldDisableChannel(&err.Error, err.StatusCode) { channelId := c.GetInt("channel_id") channelName := c.GetString("channel_name") disableChannel(channelId, channelName, err.Message) @@ -378,7 +61,7 @@ func Relay(c *gin.Context) { } func RelayNotImplemented(c *gin.Context) { - err := OpenAIError{ + err := openai.Error{ Message: "API not implemented", Type: "one_api_error", Param: "", @@ -390,7 +73,7 @@ func RelayNotImplemented(c *gin.Context) { } func RelayNotFound(c *gin.Context) { - err := OpenAIError{ + err := openai.Error{ Message: fmt.Sprintf("Invalid URL (%s %s)", c.Request.Method, c.Request.URL.Path), Type: "invalid_request_error", Param: "", diff --git a/controller/token.go b/controller/token.go index 8642122c..d6554abe 100644 --- a/controller/token.go +++ b/controller/token.go @@ -4,6 +4,8 @@ import ( "github.com/gin-gonic/gin" "net/http" "one-api/common" + "one-api/common/config" + "one-api/common/helper" "one-api/model" "strconv" ) @@ -14,7 +16,7 @@ func GetAllTokens(c *gin.Context) { if p < 0 { p = 0 } - tokens, err := model.GetAllUserTokens(userId, p*common.ItemsPerPage, common.ItemsPerPage) + tokens, err := model.GetAllUserTokens(userId, p*config.ItemsPerPage, config.ItemsPerPage) if err != nil { c.JSON(http.StatusOK, gin.H{ "success": false, @@ -119,9 +121,9 @@ func AddToken(c *gin.Context) { cleanToken := model.Token{ UserId: c.GetInt("id"), Name: token.Name, - Key: common.GenerateKey(), - CreatedTime: common.GetTimestamp(), - AccessedTime: common.GetTimestamp(), + Key: helper.GenerateKey(), + CreatedTime: helper.GetTimestamp(), + AccessedTime: helper.GetTimestamp(), ExpiredTime: token.ExpiredTime, RemainQuota: token.RemainQuota, UnlimitedQuota: token.UnlimitedQuota, @@ -187,7 +189,7 @@ func UpdateToken(c *gin.Context) { return } if token.Status == common.TokenStatusEnabled { - if cleanToken.Status == common.TokenStatusExpired && cleanToken.ExpiredTime <= common.GetTimestamp() && cleanToken.ExpiredTime != -1 { + if cleanToken.Status == common.TokenStatusExpired && cleanToken.ExpiredTime <= helper.GetTimestamp() && cleanToken.ExpiredTime != -1 { c.JSON(http.StatusOK, gin.H{ "success": false, "message": "令牌已过期,无法启用,请先修改令牌过期时间,或者设置为永不过期", diff --git a/controller/user.go b/controller/user.go index 8fd10b82..d13acddd 100644 --- a/controller/user.go +++ b/controller/user.go @@ -5,8 +5,11 @@ import ( "fmt" "net/http" "one-api/common" + "one-api/common/config" + "one-api/common/helper" "one-api/model" "strconv" + "time" "github.com/gin-contrib/sessions" "github.com/gin-gonic/gin" @@ -18,7 +21,7 @@ type LoginRequest struct { } func Login(c *gin.Context) { - if !common.PasswordLoginEnabled { + if !config.PasswordLoginEnabled { c.JSON(http.StatusOK, gin.H{ "message": "管理员关闭了密码登录", "success": false, @@ -105,14 +108,14 @@ func Logout(c *gin.Context) { } func Register(c *gin.Context) { - if !common.RegisterEnabled { + if !config.RegisterEnabled { c.JSON(http.StatusOK, gin.H{ "message": "管理员关闭了新用户注册", "success": false, }) return } - if !common.PasswordRegisterEnabled { + if !config.PasswordRegisterEnabled { c.JSON(http.StatusOK, gin.H{ "message": "管理员关闭了通过密码进行注册,请使用第三方账户验证的形式进行注册", "success": false, @@ -135,7 +138,7 @@ func Register(c *gin.Context) { }) return } - if common.EmailVerificationEnabled { + if config.EmailVerificationEnabled { if user.Email == "" || user.VerificationCode == "" { c.JSON(http.StatusOK, gin.H{ "success": false, @@ -159,7 +162,7 @@ func Register(c *gin.Context) { DisplayName: user.Username, InviterId: inviterId, } - if common.EmailVerificationEnabled { + if config.EmailVerificationEnabled { cleanUser.Email = user.Email } if err := cleanUser.Insert(inviterId); err != nil { @@ -181,7 +184,7 @@ func GetAllUsers(c *gin.Context) { if p < 0 { p = 0 } - users, err := model.GetAllUsers(p*common.ItemsPerPage, common.ItemsPerPage) + users, err := model.GetAllUsers(p*config.ItemsPerPage, config.ItemsPerPage) if err != nil { c.JSON(http.StatusOK, gin.H{ "success": false, @@ -248,6 +251,29 @@ func GetUser(c *gin.Context) { return } +func GetUserDashboard(c *gin.Context) { + id := c.GetInt("id") + now := time.Now() + startOfDay := now.Truncate(24*time.Hour).AddDate(0, 0, -6).Unix() + endOfDay := now.Truncate(24 * time.Hour).Add(24*time.Hour - time.Second).Unix() + + dashboards, err := model.SearchLogsByDayAndModel(id, int(startOfDay), int(endOfDay)) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "无法获取统计信息", + "data": nil, + }) + return + } + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "", + "data": dashboards, + }) + return +} + func GenerateAccessToken(c *gin.Context) { id := c.GetInt("id") user, err := model.GetUserById(id, true) @@ -258,7 +284,7 @@ func GenerateAccessToken(c *gin.Context) { }) return } - user.AccessToken = common.GetUUID() + user.AccessToken = helper.GetUUID() if model.DB.Where("access_token = ?", user.AccessToken).First(user).RowsAffected != 0 { c.JSON(http.StatusOK, gin.H{ @@ -295,7 +321,7 @@ func GetAffCode(c *gin.Context) { return } if user.AffCode == "" { - user.AffCode = common.GetRandomString(4) + user.AffCode = helper.GetRandomString(4) if err := user.Update(false); err != nil { c.JSON(http.StatusOK, gin.H{ "success": false, @@ -702,7 +728,7 @@ func EmailBind(c *gin.Context) { return } if user.Role == common.RoleRootUser { - common.RootUserEmail = email + config.RootUserEmail = email } c.JSON(http.StatusOK, gin.H{ "success": true, diff --git a/controller/wechat.go b/controller/wechat.go index ff4c9fb6..fc231a24 100644 --- a/controller/wechat.go +++ b/controller/wechat.go @@ -7,6 +7,7 @@ import ( "github.com/gin-gonic/gin" "net/http" "one-api/common" + "one-api/common/config" "one-api/model" "strconv" "time" @@ -22,11 +23,11 @@ func getWeChatIdByCode(code string) (string, error) { if code == "" { return "", errors.New("无效的参数") } - req, err := http.NewRequest("GET", fmt.Sprintf("%s/api/wechat/user?code=%s", common.WeChatServerAddress, code), nil) + req, err := http.NewRequest("GET", fmt.Sprintf("%s/api/wechat/user?code=%s", config.WeChatServerAddress, code), nil) if err != nil { return "", err } - req.Header.Set("Authorization", common.WeChatServerToken) + req.Header.Set("Authorization", config.WeChatServerToken) client := http.Client{ Timeout: 5 * time.Second, } @@ -50,7 +51,7 @@ func getWeChatIdByCode(code string) (string, error) { } func WeChatAuth(c *gin.Context) { - if !common.WeChatAuthEnabled { + if !config.WeChatAuthEnabled { c.JSON(http.StatusOK, gin.H{ "message": "管理员未开启通过微信登录以及注册", "success": false, @@ -79,7 +80,7 @@ func WeChatAuth(c *gin.Context) { return } } else { - if common.RegisterEnabled { + if config.RegisterEnabled { user.Username = "wechat_" + strconv.Itoa(model.GetMaxUserId()+1) user.DisplayName = "WeChat User" user.Role = common.RoleCommonUser @@ -112,7 +113,7 @@ func WeChatAuth(c *gin.Context) { } func WeChatBind(c *gin.Context) { - if !common.WeChatAuthEnabled { + if !config.WeChatAuthEnabled { c.JSON(http.StatusOK, gin.H{ "message": "管理员未开启通过微信登录以及注册", "success": false, diff --git a/i18n/en.json b/i18n/en.json index f67d8665..774be837 100644 --- a/i18n/en.json +++ b/i18n/en.json @@ -86,6 +86,7 @@ "该令牌已过期": "The token has expired", "该令牌额度已用尽": "The token quota has been used up", "无效的令牌": "Invalid token", + "令牌验证失败": "Token verification failed", "id 或 userId 为空!": "id or userId is empty!", "quota 不能为负数!": "quota cannot be negative!", "令牌额度不足": "Insufficient token quota", diff --git a/main.go b/main.go index 3ab1872c..b79c7bf7 100644 --- a/main.go +++ b/main.go @@ -7,9 +7,12 @@ import ( "github.com/gin-contrib/sessions/cookie" "github.com/gin-gonic/gin" "one-api/common" + "one-api/common/config" + "one-api/common/logger" "one-api/controller" "one-api/middleware" "one-api/model" + "one-api/relay/channel/openai" "one-api/router" "os" "strconv" @@ -19,67 +22,68 @@ import ( var buildFS embed.FS func main() { - common.SetupLogger() - common.SysLog(fmt.Sprintf("One API %s started with theme %s", common.Version, common.Theme)) + logger.SetupLogger() + logger.SysLog(fmt.Sprintf("One API %s started", common.Version)) if os.Getenv("GIN_MODE") != "debug" { gin.SetMode(gin.ReleaseMode) } - if common.DebugEnabled { - common.SysLog("running in debug mode") + if config.DebugEnabled { + logger.SysLog("running in debug mode") } // Initialize SQL Database err := model.InitDB() if err != nil { - common.FatalLog("failed to initialize database: " + err.Error()) + logger.FatalLog("failed to initialize database: " + err.Error()) } defer func() { err := model.CloseDB() if err != nil { - common.FatalLog("failed to close database: " + err.Error()) + logger.FatalLog("failed to close database: " + err.Error()) } }() // Initialize Redis err = common.InitRedisClient() if err != nil { - common.FatalLog("failed to initialize Redis: " + err.Error()) + logger.FatalLog("failed to initialize Redis: " + err.Error()) } // Initialize options model.InitOptionMap() + logger.SysLog(fmt.Sprintf("using theme %s", config.Theme)) if common.RedisEnabled { // for compatibility with old versions - common.MemoryCacheEnabled = true + config.MemoryCacheEnabled = true } - if common.MemoryCacheEnabled { - common.SysLog("memory cache enabled") - common.SysError(fmt.Sprintf("sync frequency: %d seconds", common.SyncFrequency)) + if config.MemoryCacheEnabled { + logger.SysLog("memory cache enabled") + logger.SysError(fmt.Sprintf("sync frequency: %d seconds", config.SyncFrequency)) model.InitChannelCache() } - if common.MemoryCacheEnabled { - go model.SyncOptions(common.SyncFrequency) - go model.SyncChannelCache(common.SyncFrequency) + if config.MemoryCacheEnabled { + go model.SyncOptions(config.SyncFrequency) + go model.SyncChannelCache(config.SyncFrequency) } if os.Getenv("CHANNEL_UPDATE_FREQUENCY") != "" { frequency, err := strconv.Atoi(os.Getenv("CHANNEL_UPDATE_FREQUENCY")) if err != nil { - common.FatalLog("failed to parse CHANNEL_UPDATE_FREQUENCY: " + err.Error()) + logger.FatalLog("failed to parse CHANNEL_UPDATE_FREQUENCY: " + err.Error()) } go controller.AutomaticallyUpdateChannels(frequency) } if os.Getenv("CHANNEL_TEST_FREQUENCY") != "" { frequency, err := strconv.Atoi(os.Getenv("CHANNEL_TEST_FREQUENCY")) if err != nil { - common.FatalLog("failed to parse CHANNEL_TEST_FREQUENCY: " + err.Error()) + logger.FatalLog("failed to parse CHANNEL_TEST_FREQUENCY: " + err.Error()) } go controller.AutomaticallyTestChannels(frequency) } if os.Getenv("BATCH_UPDATE_ENABLED") == "true" { - common.BatchUpdateEnabled = true - common.SysLog("batch update enabled with interval " + strconv.Itoa(common.BatchUpdateInterval) + "s") + config.BatchUpdateEnabled = true + logger.SysLog("batch update enabled with interval " + strconv.Itoa(config.BatchUpdateInterval) + "s") model.InitBatchUpdater() } - controller.InitTokenEncoders() + openai.InitTokenEncoders() // Initialize HTTP server server := gin.New() @@ -89,7 +93,7 @@ func main() { server.Use(middleware.RequestId()) middleware.SetUpLogger(server) // Initialize session store - store := cookie.NewStore([]byte(common.SessionSecret)) + store := cookie.NewStore([]byte(config.SessionSecret)) server.Use(sessions.Sessions("session", store)) router.SetRouter(server, buildFS) @@ -99,6 +103,6 @@ func main() { } err = server.Run(":" + port) if err != nil { - common.FatalLog("failed to start HTTP server: " + err.Error()) + logger.FatalLog("failed to start HTTP server: " + err.Error()) } } diff --git a/middleware/distributor.go b/middleware/distributor.go index 81338130..6b607d68 100644 --- a/middleware/distributor.go +++ b/middleware/distributor.go @@ -4,6 +4,7 @@ import ( "fmt" "net/http" "one-api/common" + "one-api/common/logger" "one-api/model" "strconv" "strings" @@ -69,7 +70,7 @@ func Distribute() func(c *gin.Context) { if err != nil { message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", userGroup, modelRequest.Model) if channel != nil { - common.SysError(fmt.Sprintf("渠道不存在:%d", channel.Id)) + logger.SysError(fmt.Sprintf("渠道不存在:%d", channel.Id)) message = "数据库一致性已被破坏,请联系管理员" } abortWithMessage(c, http.StatusServiceUnavailable, message) diff --git a/middleware/logger.go b/middleware/logger.go index 02f2e0a9..ca372c52 100644 --- a/middleware/logger.go +++ b/middleware/logger.go @@ -3,14 +3,14 @@ package middleware import ( "fmt" "github.com/gin-gonic/gin" - "one-api/common" + "one-api/common/logger" ) func SetUpLogger(server *gin.Engine) { server.Use(gin.LoggerWithFormatter(func(param gin.LogFormatterParams) string { var requestID string if param.Keys != nil { - requestID = param.Keys[common.RequestIdKey].(string) + requestID = param.Keys[logger.RequestIdKey].(string) } return fmt.Sprintf("[GIN] %s | %s | %3d | %13v | %15s | %7s %s\n", param.TimeStamp.Format("2006/01/02 - 15:04:05"), diff --git a/middleware/rate-limit.go b/middleware/rate-limit.go index 8e5cff6c..e89ef8d6 100644 --- a/middleware/rate-limit.go +++ b/middleware/rate-limit.go @@ -6,6 +6,7 @@ import ( "github.com/gin-gonic/gin" "net/http" "one-api/common" + "one-api/common/config" "time" ) @@ -26,7 +27,7 @@ func redisRateLimiter(c *gin.Context, maxRequestNum int, duration int64, mark st } if listLength < int64(maxRequestNum) { rdb.LPush(ctx, key, time.Now().Format(timeFormat)) - rdb.Expire(ctx, key, common.RateLimitKeyExpirationDuration) + rdb.Expire(ctx, key, config.RateLimitKeyExpirationDuration) } else { oldTimeStr, _ := rdb.LIndex(ctx, key, -1).Result() oldTime, err := time.Parse(timeFormat, oldTimeStr) @@ -47,14 +48,14 @@ func redisRateLimiter(c *gin.Context, maxRequestNum int, duration int64, mark st // time.Since will return negative number! // See: https://stackoverflow.com/questions/50970900/why-is-time-since-returning-negative-durations-on-windows if int64(nowTime.Sub(oldTime).Seconds()) < duration { - rdb.Expire(ctx, key, common.RateLimitKeyExpirationDuration) + rdb.Expire(ctx, key, config.RateLimitKeyExpirationDuration) c.Status(http.StatusTooManyRequests) c.Abort() return } else { rdb.LPush(ctx, key, time.Now().Format(timeFormat)) rdb.LTrim(ctx, key, 0, int64(maxRequestNum-1)) - rdb.Expire(ctx, key, common.RateLimitKeyExpirationDuration) + rdb.Expire(ctx, key, config.RateLimitKeyExpirationDuration) } } } @@ -75,7 +76,7 @@ func rateLimitFactory(maxRequestNum int, duration int64, mark string) func(c *gi } } else { // It's safe to call multi times. - inMemoryRateLimiter.Init(common.RateLimitKeyExpirationDuration) + inMemoryRateLimiter.Init(config.RateLimitKeyExpirationDuration) return func(c *gin.Context) { memoryRateLimiter(c, maxRequestNum, duration, mark) } @@ -83,21 +84,21 @@ func rateLimitFactory(maxRequestNum int, duration int64, mark string) func(c *gi } func GlobalWebRateLimit() func(c *gin.Context) { - return rateLimitFactory(common.GlobalWebRateLimitNum, common.GlobalWebRateLimitDuration, "GW") + return rateLimitFactory(config.GlobalWebRateLimitNum, config.GlobalWebRateLimitDuration, "GW") } func GlobalAPIRateLimit() func(c *gin.Context) { - return rateLimitFactory(common.GlobalApiRateLimitNum, common.GlobalApiRateLimitDuration, "GA") + return rateLimitFactory(config.GlobalApiRateLimitNum, config.GlobalApiRateLimitDuration, "GA") } func CriticalRateLimit() func(c *gin.Context) { - return rateLimitFactory(common.CriticalRateLimitNum, common.CriticalRateLimitDuration, "CT") + return rateLimitFactory(config.CriticalRateLimitNum, config.CriticalRateLimitDuration, "CT") } func DownloadRateLimit() func(c *gin.Context) { - return rateLimitFactory(common.DownloadRateLimitNum, common.DownloadRateLimitDuration, "DW") + return rateLimitFactory(config.DownloadRateLimitNum, config.DownloadRateLimitDuration, "DW") } func UploadRateLimit() func(c *gin.Context) { - return rateLimitFactory(common.UploadRateLimitNum, common.UploadRateLimitDuration, "UP") + return rateLimitFactory(config.UploadRateLimitNum, config.UploadRateLimitDuration, "UP") } diff --git a/middleware/recover.go b/middleware/recover.go index 8338a514..9d3edc27 100644 --- a/middleware/recover.go +++ b/middleware/recover.go @@ -4,7 +4,7 @@ import ( "fmt" "github.com/gin-gonic/gin" "net/http" - "one-api/common" + "one-api/common/logger" "runtime/debug" ) @@ -12,8 +12,8 @@ func RelayPanicRecover() gin.HandlerFunc { return func(c *gin.Context) { defer func() { if err := recover(); err != nil { - common.SysError(fmt.Sprintf("panic detected: %v", err)) - common.SysError(fmt.Sprintf("stacktrace from panic: %s", string(debug.Stack()))) + logger.SysError(fmt.Sprintf("panic detected: %v", err)) + logger.SysError(fmt.Sprintf("stacktrace from panic: %s", string(debug.Stack()))) c.JSON(http.StatusInternalServerError, gin.H{ "error": gin.H{ "message": fmt.Sprintf("Panic detected, error: %v. Please submit a issue here: https://github.com/songquanpeng/one-api", err), diff --git a/middleware/request-id.go b/middleware/request-id.go index e623be7a..811a7ad8 100644 --- a/middleware/request-id.go +++ b/middleware/request-id.go @@ -3,16 +3,17 @@ package middleware import ( "context" "github.com/gin-gonic/gin" - "one-api/common" + "one-api/common/helper" + "one-api/common/logger" ) func RequestId() func(c *gin.Context) { return func(c *gin.Context) { - id := common.GetTimeString() + common.GetRandomString(8) - c.Set(common.RequestIdKey, id) - ctx := context.WithValue(c.Request.Context(), common.RequestIdKey, id) + id := helper.GetTimeString() + helper.GetRandomString(8) + c.Set(logger.RequestIdKey, id) + ctx := context.WithValue(c.Request.Context(), logger.RequestIdKey, id) c.Request = c.Request.WithContext(ctx) - c.Header(common.RequestIdKey, id) + c.Header(logger.RequestIdKey, id) c.Next() } } diff --git a/middleware/turnstile-check.go b/middleware/turnstile-check.go index 26688810..629395e7 100644 --- a/middleware/turnstile-check.go +++ b/middleware/turnstile-check.go @@ -6,7 +6,8 @@ import ( "github.com/gin-gonic/gin" "net/http" "net/url" - "one-api/common" + "one-api/common/config" + "one-api/common/logger" ) type turnstileCheckResponse struct { @@ -15,7 +16,7 @@ type turnstileCheckResponse struct { func TurnstileCheck() gin.HandlerFunc { return func(c *gin.Context) { - if common.TurnstileCheckEnabled { + if config.TurnstileCheckEnabled { session := sessions.Default(c) turnstileChecked := session.Get("turnstile") if turnstileChecked != nil { @@ -32,12 +33,12 @@ func TurnstileCheck() gin.HandlerFunc { return } rawRes, err := http.PostForm("https://challenges.cloudflare.com/turnstile/v0/siteverify", url.Values{ - "secret": {common.TurnstileSecretKey}, + "secret": {config.TurnstileSecretKey}, "response": {response}, "remoteip": {c.ClientIP()}, }) if err != nil { - common.SysError(err.Error()) + logger.SysError(err.Error()) c.JSON(http.StatusOK, gin.H{ "success": false, "message": err.Error(), @@ -49,7 +50,7 @@ func TurnstileCheck() gin.HandlerFunc { var res turnstileCheckResponse err = json.NewDecoder(rawRes.Body).Decode(&res) if err != nil { - common.SysError(err.Error()) + logger.SysError(err.Error()) c.JSON(http.StatusOK, gin.H{ "success": false, "message": err.Error(), diff --git a/middleware/utils.go b/middleware/utils.go index 536125cc..d866d75b 100644 --- a/middleware/utils.go +++ b/middleware/utils.go @@ -2,16 +2,17 @@ package middleware import ( "github.com/gin-gonic/gin" - "one-api/common" + "one-api/common/helper" + "one-api/common/logger" ) func abortWithMessage(c *gin.Context, statusCode int, message string) { c.JSON(statusCode, gin.H{ "error": gin.H{ - "message": common.MessageWithRequestId(message, c.GetString(common.RequestIdKey)), + "message": helper.MessageWithRequestId(message, c.GetString(logger.RequestIdKey)), "type": "one_api_error", }, }) c.Abort() - common.LogError(c.Request.Context(), message) + logger.Error(c.Request.Context(), message) } diff --git a/model/cache.go b/model/cache.go index c6d0c70a..a81bdddd 100644 --- a/model/cache.go +++ b/model/cache.go @@ -6,6 +6,8 @@ import ( "fmt" "math/rand" "one-api/common" + "one-api/common/config" + "one-api/common/logger" "sort" "strconv" "strings" @@ -14,10 +16,10 @@ import ( ) var ( - TokenCacheSeconds = common.SyncFrequency - UserId2GroupCacheSeconds = common.SyncFrequency - UserId2QuotaCacheSeconds = common.SyncFrequency - UserId2StatusCacheSeconds = common.SyncFrequency + TokenCacheSeconds = config.SyncFrequency + UserId2GroupCacheSeconds = config.SyncFrequency + UserId2QuotaCacheSeconds = config.SyncFrequency + UserId2StatusCacheSeconds = config.SyncFrequency ) func CacheGetTokenByKey(key string) (*Token, error) { @@ -42,7 +44,7 @@ func CacheGetTokenByKey(key string) (*Token, error) { } err = common.RedisSet(fmt.Sprintf("token:%s", key), string(jsonBytes), time.Duration(TokenCacheSeconds)*time.Second) if err != nil { - common.SysError("Redis set token error: " + err.Error()) + logger.SysError("Redis set token error: " + err.Error()) } return &token, nil } @@ -62,7 +64,7 @@ func CacheGetUserGroup(id int) (group string, err error) { } err = common.RedisSet(fmt.Sprintf("user_group:%d", id), group, time.Duration(UserId2GroupCacheSeconds)*time.Second) if err != nil { - common.SysError("Redis set user group error: " + err.Error()) + logger.SysError("Redis set user group error: " + err.Error()) } } return group, err @@ -80,7 +82,7 @@ func CacheGetUserQuota(id int) (quota int, err error) { } err = common.RedisSet(fmt.Sprintf("user_quota:%d", id), fmt.Sprintf("%d", quota), time.Duration(UserId2QuotaCacheSeconds)*time.Second) if err != nil { - common.SysError("Redis set user quota error: " + err.Error()) + logger.SysError("Redis set user quota error: " + err.Error()) } return quota, err } @@ -127,7 +129,7 @@ func CacheIsUserEnabled(userId int) (bool, error) { } err = common.RedisSet(fmt.Sprintf("user_enabled:%d", userId), enabled, time.Duration(UserId2StatusCacheSeconds)*time.Second) if err != nil { - common.SysError("Redis set user enabled error: " + err.Error()) + logger.SysError("Redis set user enabled error: " + err.Error()) } return userEnabled, err } @@ -178,19 +180,19 @@ func InitChannelCache() { channelSyncLock.Lock() group2model2channels = newGroup2model2channels channelSyncLock.Unlock() - common.SysLog("channels synced from database") + logger.SysLog("channels synced from database") } func SyncChannelCache(frequency int) { for { time.Sleep(time.Duration(frequency) * time.Second) - common.SysLog("syncing channels from database") + logger.SysLog("syncing channels from database") InitChannelCache() } } func CacheGetRandomSatisfiedChannel(group string, model string) (*Channel, error) { - if !common.MemoryCacheEnabled { + if !config.MemoryCacheEnabled { return GetRandomSatisfiedChannel(group, model) } channelSyncLock.RLock() diff --git a/model/channel.go b/model/channel.go index 085e3ca4..4c77d902 100644 --- a/model/channel.go +++ b/model/channel.go @@ -1,8 +1,13 @@ package model import ( + "encoding/json" + "fmt" "gorm.io/gorm" "one-api/common" + "one-api/common/config" + "one-api/common/helper" + "one-api/common/logger" ) type Channel struct { @@ -42,7 +47,7 @@ func SearchChannels(keyword string) (channels []*Channel, err error) { if common.UsingPostgreSQL { keyCol = `"key"` } - err = DB.Omit("key").Where("id = ? or name LIKE ? or "+keyCol+" = ?", common.String2Int(keyword), keyword+"%", keyword).Find(&channels).Error + err = DB.Omit("key").Where("id = ? or name LIKE ? or "+keyCol+" = ?", helper.String2Int(keyword), keyword+"%", keyword).Find(&channels).Error return channels, err } @@ -86,11 +91,17 @@ func (channel *Channel) GetBaseURL() string { return *channel.BaseURL } -func (channel *Channel) GetModelMapping() string { - if channel.ModelMapping == nil { - return "" +func (channel *Channel) GetModelMapping() map[string]string { + if channel.ModelMapping == nil || *channel.ModelMapping == "" || *channel.ModelMapping == "{}" { + return nil } - return *channel.ModelMapping + modelMapping := make(map[string]string) + err := json.Unmarshal([]byte(*channel.ModelMapping), &modelMapping) + if err != nil { + logger.SysError(fmt.Sprintf("failed to unmarshal model mapping for channel %d, error: %s", channel.Id, err.Error())) + return nil + } + return modelMapping } func (channel *Channel) Insert() error { @@ -116,21 +127,21 @@ func (channel *Channel) Update() error { func (channel *Channel) UpdateResponseTime(responseTime int64) { err := DB.Model(channel).Select("response_time", "test_time").Updates(Channel{ - TestTime: common.GetTimestamp(), + TestTime: helper.GetTimestamp(), ResponseTime: int(responseTime), }).Error if err != nil { - common.SysError("failed to update response time: " + err.Error()) + logger.SysError("failed to update response time: " + err.Error()) } } func (channel *Channel) UpdateBalance(balance float64) { err := DB.Model(channel).Select("balance_updated_time", "balance").Updates(Channel{ - BalanceUpdatedTime: common.GetTimestamp(), + BalanceUpdatedTime: helper.GetTimestamp(), Balance: balance, }).Error if err != nil { - common.SysError("failed to update balance: " + err.Error()) + logger.SysError("failed to update balance: " + err.Error()) } } @@ -147,16 +158,16 @@ func (channel *Channel) Delete() error { func UpdateChannelStatusById(id int, status int) { err := UpdateAbilityStatus(id, status == common.ChannelStatusEnabled) if err != nil { - common.SysError("failed to update ability status: " + err.Error()) + logger.SysError("failed to update ability status: " + err.Error()) } err = DB.Model(&Channel{}).Where("id = ?", id).Update("status", status).Error if err != nil { - common.SysError("failed to update channel status: " + err.Error()) + logger.SysError("failed to update channel status: " + err.Error()) } } func UpdateChannelUsedQuota(id int, quota int) { - if common.BatchUpdateEnabled { + if config.BatchUpdateEnabled { addNewRecord(BatchUpdateTypeChannelUsedQuota, id, quota) return } @@ -166,7 +177,7 @@ func UpdateChannelUsedQuota(id int, quota int) { func updateChannelUsedQuota(id int, quota int) { err := DB.Model(&Channel{}).Where("id = ?", id).Update("used_quota", gorm.Expr("used_quota + ?", quota)).Error if err != nil { - common.SysError("failed to update channel used quota: " + err.Error()) + logger.SysError("failed to update channel used quota: " + err.Error()) } } diff --git a/model/log.go b/model/log.go index 3d3ffae3..78b6d9b3 100644 --- a/model/log.go +++ b/model/log.go @@ -3,14 +3,18 @@ package model import ( "context" "fmt" - "gorm.io/gorm" "one-api/common" + "one-api/common/config" + "one-api/common/helper" + "one-api/common/logger" + + "gorm.io/gorm" ) type Log struct { - Id int `json:"id;index:idx_created_at_id,priority:1"` + Id int `json:"id"` UserId int `json:"user_id" gorm:"index"` - CreatedAt int64 `json:"created_at" gorm:"bigint;index:idx_created_at_id,priority:2;index:idx_created_at_type"` + CreatedAt int64 `json:"created_at" gorm:"bigint;index:idx_created_at_type"` Type int `json:"type" gorm:"index:idx_created_at_type"` Content string `json:"content"` Username string `json:"username" gorm:"index:index_username_model_name,priority:2;default:''"` @@ -31,31 +35,31 @@ const ( ) func RecordLog(userId int, logType int, content string) { - if logType == LogTypeConsume && !common.LogConsumeEnabled { + if logType == LogTypeConsume && !config.LogConsumeEnabled { return } log := &Log{ UserId: userId, Username: GetUsernameById(userId), - CreatedAt: common.GetTimestamp(), + CreatedAt: helper.GetTimestamp(), Type: logType, Content: content, } err := DB.Create(log).Error if err != nil { - common.SysError("failed to record log: " + err.Error()) + logger.SysError("failed to record log: " + err.Error()) } } func RecordConsumeLog(ctx context.Context, userId int, channelId int, promptTokens int, completionTokens int, modelName string, tokenName string, quota int, content string) { - common.LogInfo(ctx, fmt.Sprintf("record consume log: userId=%d, channelId=%d, promptTokens=%d, completionTokens=%d, modelName=%s, tokenName=%s, quota=%d, content=%s", userId, channelId, promptTokens, completionTokens, modelName, tokenName, quota, content)) - if !common.LogConsumeEnabled { + logger.Info(ctx, fmt.Sprintf("record consume log: userId=%d, channelId=%d, promptTokens=%d, completionTokens=%d, modelName=%s, tokenName=%s, quota=%d, content=%s", userId, channelId, promptTokens, completionTokens, modelName, tokenName, quota, content)) + if !config.LogConsumeEnabled { return } log := &Log{ UserId: userId, Username: GetUsernameById(userId), - CreatedAt: common.GetTimestamp(), + CreatedAt: helper.GetTimestamp(), Type: LogTypeConsume, Content: content, PromptTokens: promptTokens, @@ -67,7 +71,7 @@ func RecordConsumeLog(ctx context.Context, userId int, channelId int, promptToke } err := DB.Create(log).Error if err != nil { - common.LogError(ctx, "failed to record log: "+err.Error()) + logger.Error(ctx, "failed to record log: "+err.Error()) } } @@ -124,12 +128,12 @@ func GetUserLogs(userId int, logType int, startTimestamp int64, endTimestamp int } func SearchAllLogs(keyword string) (logs []*Log, err error) { - err = DB.Where("type = ? or content LIKE ?", keyword, keyword+"%").Order("id desc").Limit(common.MaxRecentItems).Find(&logs).Error + err = DB.Where("type = ? or content LIKE ?", keyword, keyword+"%").Order("id desc").Limit(config.MaxRecentItems).Find(&logs).Error return logs, err } func SearchUserLogs(userId int, keyword string) (logs []*Log, err error) { - err = DB.Where("user_id = ? and type = ?", userId, keyword).Order("id desc").Limit(common.MaxRecentItems).Omit("id").Find(&logs).Error + err = DB.Where("user_id = ? and type = ?", userId, keyword).Order("id desc").Limit(config.MaxRecentItems).Omit("id").Find(&logs).Error return logs, err } @@ -182,3 +186,40 @@ func DeleteOldLog(targetTimestamp int64) (int64, error) { result := DB.Where("created_at < ?", targetTimestamp).Delete(&Log{}) return result.RowsAffected, result.Error } + +type LogStatistic struct { + Day string `gorm:"column:day"` + ModelName string `gorm:"column:model_name"` + RequestCount int `gorm:"column:request_count"` + Quota int `gorm:"column:quota"` + PromptTokens int `gorm:"column:prompt_tokens"` + CompletionTokens int `gorm:"column:completion_tokens"` +} + +func SearchLogsByDayAndModel(userId, start, end int) (LogStatistics []*LogStatistic, err error) { + groupSelect := "DATE_FORMAT(FROM_UNIXTIME(created_at), '%Y-%m-%d') as day" + + if common.UsingPostgreSQL { + groupSelect = "TO_CHAR(date_trunc('day', to_timestamp(created_at)), 'YYYY-MM-DD') as day" + } + + if common.UsingSQLite { + groupSelect = "strftime('%Y-%m-%d', datetime(created_at, 'unixepoch')) as day" + } + + err = DB.Raw(` + SELECT `+groupSelect+`, + model_name, count(1) as request_count, + sum(quota) as quota, + sum(prompt_tokens) as prompt_tokens, + sum(completion_tokens) as completion_tokens + FROM logs + WHERE type=2 + AND user_id= ? + AND created_at BETWEEN ? AND ? + GROUP BY day, model_name + ORDER BY day, model_name + `, userId, start, end).Scan(&LogStatistics).Error + + return LogStatistics, err +} diff --git a/model/main.go b/model/main.go index bfd6888b..2ed6f0e3 100644 --- a/model/main.go +++ b/model/main.go @@ -7,6 +7,9 @@ import ( "gorm.io/driver/sqlite" "gorm.io/gorm" "one-api/common" + "one-api/common/config" + "one-api/common/helper" + "one-api/common/logger" "os" "strings" "time" @@ -16,9 +19,9 @@ var DB *gorm.DB func createRootAccountIfNeed() error { var user User - //if user.Status != common.UserStatusEnabled { + //if user.Status != util.UserStatusEnabled { if err := DB.First(&user).Error; err != nil { - common.SysLog("no user exists, create a root user for you: username is root, password is 123456") + logger.SysLog("no user exists, create a root user for you: username is root, password is 123456") hashedPassword, err := common.Password2Hash("123456") if err != nil { return err @@ -29,7 +32,7 @@ func createRootAccountIfNeed() error { Role: common.RoleRootUser, Status: common.UserStatusEnabled, DisplayName: "Root User", - AccessToken: common.GetUUID(), + AccessToken: helper.GetUUID(), Quota: 100000000, } DB.Create(&rootUser) @@ -42,7 +45,7 @@ func chooseDB() (*gorm.DB, error) { dsn := os.Getenv("SQL_DSN") if strings.HasPrefix(dsn, "postgres://") { // Use PostgreSQL - common.SysLog("using PostgreSQL as database") + logger.SysLog("using PostgreSQL as database") common.UsingPostgreSQL = true return gorm.Open(postgres.New(postgres.Config{ DSN: dsn, @@ -52,13 +55,13 @@ func chooseDB() (*gorm.DB, error) { }) } // Use MySQL - common.SysLog("using MySQL as database") + logger.SysLog("using MySQL as database") return gorm.Open(mysql.Open(dsn), &gorm.Config{ PrepareStmt: true, // precompile SQL }) } // Use SQLite - common.SysLog("SQL_DSN not set, using SQLite as database") + logger.SysLog("SQL_DSN not set, using SQLite as database") common.UsingSQLite = true config := fmt.Sprintf("?_busy_timeout=%d", common.SQLiteBusyTimeout) return gorm.Open(sqlite.Open(common.SQLitePath+config), &gorm.Config{ @@ -69,7 +72,7 @@ func chooseDB() (*gorm.DB, error) { func InitDB() (err error) { db, err := chooseDB() if err == nil { - if common.DebugEnabled { + if config.DebugEnabled { db = db.Debug() } DB = db @@ -77,14 +80,14 @@ func InitDB() (err error) { if err != nil { return err } - sqlDB.SetMaxIdleConns(common.GetOrDefault("SQL_MAX_IDLE_CONNS", 100)) - sqlDB.SetMaxOpenConns(common.GetOrDefault("SQL_MAX_OPEN_CONNS", 1000)) - sqlDB.SetConnMaxLifetime(time.Second * time.Duration(common.GetOrDefault("SQL_MAX_LIFETIME", 60))) + sqlDB.SetMaxIdleConns(helper.GetOrDefaultEnvInt("SQL_MAX_IDLE_CONNS", 100)) + sqlDB.SetMaxOpenConns(helper.GetOrDefaultEnvInt("SQL_MAX_OPEN_CONNS", 1000)) + sqlDB.SetConnMaxLifetime(time.Second * time.Duration(helper.GetOrDefaultEnvInt("SQL_MAX_LIFETIME", 60))) - if !common.IsMasterNode { + if !config.IsMasterNode { return nil } - common.SysLog("database migration started") + logger.SysLog("database migration started") err = db.AutoMigrate(&Channel{}) if err != nil { return err @@ -113,11 +116,11 @@ func InitDB() (err error) { if err != nil { return err } - common.SysLog("database migrated") + logger.SysLog("database migrated") err = createRootAccountIfNeed() return err } else { - common.FatalLog(err) + logger.FatalLog(err) } return err } diff --git a/model/option.go b/model/option.go index bb8b709c..e211264c 100644 --- a/model/option.go +++ b/model/option.go @@ -2,6 +2,8 @@ package model import ( "one-api/common" + "one-api/common/config" + "one-api/common/logger" "strconv" "strings" "time" @@ -20,59 +22,56 @@ func AllOption() ([]*Option, error) { } func InitOptionMap() { - common.OptionMapRWMutex.Lock() - common.OptionMap = make(map[string]string) - common.OptionMap["FileUploadPermission"] = strconv.Itoa(common.FileUploadPermission) - common.OptionMap["FileDownloadPermission"] = strconv.Itoa(common.FileDownloadPermission) - common.OptionMap["ImageUploadPermission"] = strconv.Itoa(common.ImageUploadPermission) - common.OptionMap["ImageDownloadPermission"] = strconv.Itoa(common.ImageDownloadPermission) - common.OptionMap["PasswordLoginEnabled"] = strconv.FormatBool(common.PasswordLoginEnabled) - common.OptionMap["PasswordRegisterEnabled"] = strconv.FormatBool(common.PasswordRegisterEnabled) - common.OptionMap["EmailVerificationEnabled"] = strconv.FormatBool(common.EmailVerificationEnabled) - common.OptionMap["GitHubOAuthEnabled"] = strconv.FormatBool(common.GitHubOAuthEnabled) - common.OptionMap["WeChatAuthEnabled"] = strconv.FormatBool(common.WeChatAuthEnabled) - common.OptionMap["TurnstileCheckEnabled"] = strconv.FormatBool(common.TurnstileCheckEnabled) - common.OptionMap["RegisterEnabled"] = strconv.FormatBool(common.RegisterEnabled) - common.OptionMap["AutomaticDisableChannelEnabled"] = strconv.FormatBool(common.AutomaticDisableChannelEnabled) - common.OptionMap["AutomaticEnableChannelEnabled"] = strconv.FormatBool(common.AutomaticEnableChannelEnabled) - common.OptionMap["ApproximateTokenEnabled"] = strconv.FormatBool(common.ApproximateTokenEnabled) - common.OptionMap["LogConsumeEnabled"] = strconv.FormatBool(common.LogConsumeEnabled) - common.OptionMap["DisplayInCurrencyEnabled"] = strconv.FormatBool(common.DisplayInCurrencyEnabled) - common.OptionMap["DisplayTokenStatEnabled"] = strconv.FormatBool(common.DisplayTokenStatEnabled) - common.OptionMap["ChannelDisableThreshold"] = strconv.FormatFloat(common.ChannelDisableThreshold, 'f', -1, 64) - common.OptionMap["EmailDomainRestrictionEnabled"] = strconv.FormatBool(common.EmailDomainRestrictionEnabled) - common.OptionMap["EmailDomainWhitelist"] = strings.Join(common.EmailDomainWhitelist, ",") - common.OptionMap["SMTPServer"] = "" - common.OptionMap["SMTPFrom"] = "" - common.OptionMap["SMTPPort"] = strconv.Itoa(common.SMTPPort) - common.OptionMap["SMTPAccount"] = "" - common.OptionMap["SMTPToken"] = "" - common.OptionMap["Notice"] = "" - common.OptionMap["About"] = "" - common.OptionMap["HomePageContent"] = "" - common.OptionMap["Footer"] = common.Footer - common.OptionMap["SystemName"] = common.SystemName - common.OptionMap["Logo"] = common.Logo - common.OptionMap["ServerAddress"] = "" - common.OptionMap["GitHubClientId"] = "" - common.OptionMap["GitHubClientSecret"] = "" - common.OptionMap["WeChatServerAddress"] = "" - common.OptionMap["WeChatServerToken"] = "" - common.OptionMap["WeChatAccountQRCodeImageURL"] = "" - common.OptionMap["TurnstileSiteKey"] = "" - common.OptionMap["TurnstileSecretKey"] = "" - common.OptionMap["QuotaForNewUser"] = strconv.Itoa(common.QuotaForNewUser) - common.OptionMap["QuotaForInviter"] = strconv.Itoa(common.QuotaForInviter) - common.OptionMap["QuotaForInvitee"] = strconv.Itoa(common.QuotaForInvitee) - common.OptionMap["QuotaRemindThreshold"] = strconv.Itoa(common.QuotaRemindThreshold) - common.OptionMap["PreConsumedQuota"] = strconv.Itoa(common.PreConsumedQuota) - common.OptionMap["ModelRatio"] = common.ModelRatio2JSONString() - common.OptionMap["GroupRatio"] = common.GroupRatio2JSONString() - common.OptionMap["TopUpLink"] = common.TopUpLink - common.OptionMap["ChatLink"] = common.ChatLink - common.OptionMap["QuotaPerUnit"] = strconv.FormatFloat(common.QuotaPerUnit, 'f', -1, 64) - common.OptionMap["RetryTimes"] = strconv.Itoa(common.RetryTimes) - common.OptionMapRWMutex.Unlock() + config.OptionMapRWMutex.Lock() + config.OptionMap = make(map[string]string) + config.OptionMap["PasswordLoginEnabled"] = strconv.FormatBool(config.PasswordLoginEnabled) + config.OptionMap["PasswordRegisterEnabled"] = strconv.FormatBool(config.PasswordRegisterEnabled) + config.OptionMap["EmailVerificationEnabled"] = strconv.FormatBool(config.EmailVerificationEnabled) + config.OptionMap["GitHubOAuthEnabled"] = strconv.FormatBool(config.GitHubOAuthEnabled) + config.OptionMap["WeChatAuthEnabled"] = strconv.FormatBool(config.WeChatAuthEnabled) + config.OptionMap["TurnstileCheckEnabled"] = strconv.FormatBool(config.TurnstileCheckEnabled) + config.OptionMap["RegisterEnabled"] = strconv.FormatBool(config.RegisterEnabled) + config.OptionMap["AutomaticDisableChannelEnabled"] = strconv.FormatBool(config.AutomaticDisableChannelEnabled) + config.OptionMap["AutomaticEnableChannelEnabled"] = strconv.FormatBool(config.AutomaticEnableChannelEnabled) + config.OptionMap["ApproximateTokenEnabled"] = strconv.FormatBool(config.ApproximateTokenEnabled) + config.OptionMap["LogConsumeEnabled"] = strconv.FormatBool(config.LogConsumeEnabled) + config.OptionMap["DisplayInCurrencyEnabled"] = strconv.FormatBool(config.DisplayInCurrencyEnabled) + config.OptionMap["DisplayTokenStatEnabled"] = strconv.FormatBool(config.DisplayTokenStatEnabled) + config.OptionMap["ChannelDisableThreshold"] = strconv.FormatFloat(config.ChannelDisableThreshold, 'f', -1, 64) + config.OptionMap["EmailDomainRestrictionEnabled"] = strconv.FormatBool(config.EmailDomainRestrictionEnabled) + config.OptionMap["EmailDomainWhitelist"] = strings.Join(config.EmailDomainWhitelist, ",") + config.OptionMap["SMTPServer"] = "" + config.OptionMap["SMTPFrom"] = "" + config.OptionMap["SMTPPort"] = strconv.Itoa(config.SMTPPort) + config.OptionMap["SMTPAccount"] = "" + config.OptionMap["SMTPToken"] = "" + config.OptionMap["Notice"] = "" + config.OptionMap["About"] = "" + config.OptionMap["HomePageContent"] = "" + config.OptionMap["Footer"] = config.Footer + config.OptionMap["SystemName"] = config.SystemName + config.OptionMap["Logo"] = config.Logo + config.OptionMap["ServerAddress"] = "" + config.OptionMap["GitHubClientId"] = "" + config.OptionMap["GitHubClientSecret"] = "" + config.OptionMap["WeChatServerAddress"] = "" + config.OptionMap["WeChatServerToken"] = "" + config.OptionMap["WeChatAccountQRCodeImageURL"] = "" + config.OptionMap["TurnstileSiteKey"] = "" + config.OptionMap["TurnstileSecretKey"] = "" + config.OptionMap["QuotaForNewUser"] = strconv.Itoa(config.QuotaForNewUser) + config.OptionMap["QuotaForInviter"] = strconv.Itoa(config.QuotaForInviter) + config.OptionMap["QuotaForInvitee"] = strconv.Itoa(config.QuotaForInvitee) + config.OptionMap["QuotaRemindThreshold"] = strconv.Itoa(config.QuotaRemindThreshold) + config.OptionMap["PreConsumedQuota"] = strconv.Itoa(config.PreConsumedQuota) + config.OptionMap["ModelRatio"] = common.ModelRatio2JSONString() + config.OptionMap["GroupRatio"] = common.GroupRatio2JSONString() + config.OptionMap["TopUpLink"] = config.TopUpLink + config.OptionMap["ChatLink"] = config.ChatLink + config.OptionMap["QuotaPerUnit"] = strconv.FormatFloat(config.QuotaPerUnit, 'f', -1, 64) + config.OptionMap["RetryTimes"] = strconv.Itoa(config.RetryTimes) + config.OptionMap["Theme"] = config.Theme + config.OptionMapRWMutex.Unlock() loadOptionsFromDatabase() } @@ -81,7 +80,7 @@ func loadOptionsFromDatabase() { for _, option := range options { err := updateOptionMap(option.Key, option.Value) if err != nil { - common.SysError("failed to update option map: " + err.Error()) + logger.SysError("failed to update option map: " + err.Error()) } } } @@ -89,7 +88,7 @@ func loadOptionsFromDatabase() { func SyncOptions(frequency int) { for { time.Sleep(time.Duration(frequency) * time.Second) - common.SysLog("syncing options from database") + logger.SysLog("syncing options from database") loadOptionsFromDatabase() } } @@ -111,115 +110,104 @@ func UpdateOption(key string, value string) error { } func updateOptionMap(key string, value string) (err error) { - common.OptionMapRWMutex.Lock() - defer common.OptionMapRWMutex.Unlock() - common.OptionMap[key] = value - if strings.HasSuffix(key, "Permission") { - intValue, _ := strconv.Atoi(value) - switch key { - case "FileUploadPermission": - common.FileUploadPermission = intValue - case "FileDownloadPermission": - common.FileDownloadPermission = intValue - case "ImageUploadPermission": - common.ImageUploadPermission = intValue - case "ImageDownloadPermission": - common.ImageDownloadPermission = intValue - } - } + config.OptionMapRWMutex.Lock() + defer config.OptionMapRWMutex.Unlock() + config.OptionMap[key] = value if strings.HasSuffix(key, "Enabled") { boolValue := value == "true" switch key { case "PasswordRegisterEnabled": - common.PasswordRegisterEnabled = boolValue + config.PasswordRegisterEnabled = boolValue case "PasswordLoginEnabled": - common.PasswordLoginEnabled = boolValue + config.PasswordLoginEnabled = boolValue case "EmailVerificationEnabled": - common.EmailVerificationEnabled = boolValue + config.EmailVerificationEnabled = boolValue case "GitHubOAuthEnabled": - common.GitHubOAuthEnabled = boolValue + config.GitHubOAuthEnabled = boolValue case "WeChatAuthEnabled": - common.WeChatAuthEnabled = boolValue + config.WeChatAuthEnabled = boolValue case "TurnstileCheckEnabled": - common.TurnstileCheckEnabled = boolValue + config.TurnstileCheckEnabled = boolValue case "RegisterEnabled": - common.RegisterEnabled = boolValue + config.RegisterEnabled = boolValue case "EmailDomainRestrictionEnabled": - common.EmailDomainRestrictionEnabled = boolValue + config.EmailDomainRestrictionEnabled = boolValue case "AutomaticDisableChannelEnabled": - common.AutomaticDisableChannelEnabled = boolValue + config.AutomaticDisableChannelEnabled = boolValue case "AutomaticEnableChannelEnabled": - common.AutomaticEnableChannelEnabled = boolValue + config.AutomaticEnableChannelEnabled = boolValue case "ApproximateTokenEnabled": - common.ApproximateTokenEnabled = boolValue + config.ApproximateTokenEnabled = boolValue case "LogConsumeEnabled": - common.LogConsumeEnabled = boolValue + config.LogConsumeEnabled = boolValue case "DisplayInCurrencyEnabled": - common.DisplayInCurrencyEnabled = boolValue + config.DisplayInCurrencyEnabled = boolValue case "DisplayTokenStatEnabled": - common.DisplayTokenStatEnabled = boolValue + config.DisplayTokenStatEnabled = boolValue } } switch key { case "EmailDomainWhitelist": - common.EmailDomainWhitelist = strings.Split(value, ",") + config.EmailDomainWhitelist = strings.Split(value, ",") case "SMTPServer": - common.SMTPServer = value + config.SMTPServer = value case "SMTPPort": intValue, _ := strconv.Atoi(value) - common.SMTPPort = intValue + config.SMTPPort = intValue case "SMTPAccount": - common.SMTPAccount = value + config.SMTPAccount = value case "SMTPFrom": - common.SMTPFrom = value + config.SMTPFrom = value case "SMTPToken": - common.SMTPToken = value + config.SMTPToken = value case "ServerAddress": - common.ServerAddress = value + config.ServerAddress = value case "GitHubClientId": - common.GitHubClientId = value + config.GitHubClientId = value case "GitHubClientSecret": - common.GitHubClientSecret = value + config.GitHubClientSecret = value case "Footer": - common.Footer = value + config.Footer = value case "SystemName": - common.SystemName = value + config.SystemName = value case "Logo": - common.Logo = value + config.Logo = value case "WeChatServerAddress": - common.WeChatServerAddress = value + config.WeChatServerAddress = value case "WeChatServerToken": - common.WeChatServerToken = value + config.WeChatServerToken = value case "WeChatAccountQRCodeImageURL": - common.WeChatAccountQRCodeImageURL = value + config.WeChatAccountQRCodeImageURL = value case "TurnstileSiteKey": - common.TurnstileSiteKey = value + config.TurnstileSiteKey = value case "TurnstileSecretKey": - common.TurnstileSecretKey = value + config.TurnstileSecretKey = value case "QuotaForNewUser": - common.QuotaForNewUser, _ = strconv.Atoi(value) + config.QuotaForNewUser, _ = strconv.Atoi(value) case "QuotaForInviter": - common.QuotaForInviter, _ = strconv.Atoi(value) + config.QuotaForInviter, _ = strconv.Atoi(value) case "QuotaForInvitee": - common.QuotaForInvitee, _ = strconv.Atoi(value) + config.QuotaForInvitee, _ = strconv.Atoi(value) case "QuotaRemindThreshold": - common.QuotaRemindThreshold, _ = strconv.Atoi(value) + config.QuotaRemindThreshold, _ = strconv.Atoi(value) case "PreConsumedQuota": - common.PreConsumedQuota, _ = strconv.Atoi(value) + config.PreConsumedQuota, _ = strconv.Atoi(value) case "RetryTimes": - common.RetryTimes, _ = strconv.Atoi(value) + config.RetryTimes, _ = strconv.Atoi(value) case "ModelRatio": err = common.UpdateModelRatioByJSONString(value) case "GroupRatio": err = common.UpdateGroupRatioByJSONString(value) case "TopUpLink": - common.TopUpLink = value + config.TopUpLink = value case "ChatLink": - common.ChatLink = value + config.ChatLink = value case "ChannelDisableThreshold": - common.ChannelDisableThreshold, _ = strconv.ParseFloat(value, 64) + config.ChannelDisableThreshold, _ = strconv.ParseFloat(value, 64) case "QuotaPerUnit": - common.QuotaPerUnit, _ = strconv.ParseFloat(value, 64) + config.QuotaPerUnit, _ = strconv.ParseFloat(value, 64) + case "Theme": + config.Theme = value } return err } diff --git a/model/redemption.go b/model/redemption.go index f16412b5..026794e0 100644 --- a/model/redemption.go +++ b/model/redemption.go @@ -5,6 +5,7 @@ import ( "fmt" "gorm.io/gorm" "one-api/common" + "one-api/common/helper" ) type Redemption struct { @@ -67,7 +68,7 @@ func Redeem(key string, userId int) (quota int, err error) { if err != nil { return err } - redemption.RedeemedTime = common.GetTimestamp() + redemption.RedeemedTime = helper.GetTimestamp() redemption.Status = common.RedemptionCodeStatusUsed err = tx.Save(redemption).Error return err diff --git a/model/token.go b/model/token.go index 0fa984d3..2087225b 100644 --- a/model/token.go +++ b/model/token.go @@ -5,6 +5,9 @@ import ( "fmt" "gorm.io/gorm" "one-api/common" + "one-api/common/config" + "one-api/common/helper" + "one-api/common/logger" ) type Token struct { @@ -38,39 +41,43 @@ func ValidateUserToken(key string) (token *Token, err error) { return nil, errors.New("未提供令牌") } token, err = CacheGetTokenByKey(key) - if err == nil { - if token.Status == common.TokenStatusExhausted { - return nil, errors.New("该令牌额度已用尽") - } else if token.Status == common.TokenStatusExpired { - return nil, errors.New("该令牌已过期") + if err != nil { + logger.SysError("CacheGetTokenByKey failed: " + err.Error()) + if errors.Is(err, gorm.ErrRecordNotFound) { + return nil, errors.New("无效的令牌") } - if token.Status != common.TokenStatusEnabled { - return nil, errors.New("该令牌状态不可用") - } - if token.ExpiredTime != -1 && token.ExpiredTime < common.GetTimestamp() { - if !common.RedisEnabled { - token.Status = common.TokenStatusExpired - err := token.SelectUpdate() - if err != nil { - common.SysError("failed to update token status" + err.Error()) - } - } - return nil, errors.New("该令牌已过期") - } - if !token.UnlimitedQuota && token.RemainQuota <= 0 { - if !common.RedisEnabled { - // in this case, we can make sure the token is exhausted - token.Status = common.TokenStatusExhausted - err := token.SelectUpdate() - if err != nil { - common.SysError("failed to update token status" + err.Error()) - } - } - return nil, errors.New("该令牌额度已用尽") - } - return token, nil + return nil, errors.New("令牌验证失败") } - return nil, errors.New("无效的令牌") + if token.Status == common.TokenStatusExhausted { + return nil, errors.New("该令牌额度已用尽") + } else if token.Status == common.TokenStatusExpired { + return nil, errors.New("该令牌已过期") + } + if token.Status != common.TokenStatusEnabled { + return nil, errors.New("该令牌状态不可用") + } + if token.ExpiredTime != -1 && token.ExpiredTime < helper.GetTimestamp() { + if !common.RedisEnabled { + token.Status = common.TokenStatusExpired + err := token.SelectUpdate() + if err != nil { + logger.SysError("failed to update token status" + err.Error()) + } + } + return nil, errors.New("该令牌已过期") + } + if !token.UnlimitedQuota && token.RemainQuota <= 0 { + if !common.RedisEnabled { + // in this case, we can make sure the token is exhausted + token.Status = common.TokenStatusExhausted + err := token.SelectUpdate() + if err != nil { + logger.SysError("failed to update token status" + err.Error()) + } + } + return nil, errors.New("该令牌额度已用尽") + } + return token, nil } func GetTokenByIds(id int, userId int) (*Token, error) { @@ -134,7 +141,7 @@ func IncreaseTokenQuota(id int, quota int) (err error) { if quota < 0 { return errors.New("quota 不能为负数!") } - if common.BatchUpdateEnabled { + if config.BatchUpdateEnabled { addNewRecord(BatchUpdateTypeTokenQuota, id, quota) return nil } @@ -146,7 +153,7 @@ func increaseTokenQuota(id int, quota int) (err error) { map[string]interface{}{ "remain_quota": gorm.Expr("remain_quota + ?", quota), "used_quota": gorm.Expr("used_quota - ?", quota), - "accessed_time": common.GetTimestamp(), + "accessed_time": helper.GetTimestamp(), }, ).Error return err @@ -156,7 +163,7 @@ func DecreaseTokenQuota(id int, quota int) (err error) { if quota < 0 { return errors.New("quota 不能为负数!") } - if common.BatchUpdateEnabled { + if config.BatchUpdateEnabled { addNewRecord(BatchUpdateTypeTokenQuota, id, -quota) return nil } @@ -168,7 +175,7 @@ func decreaseTokenQuota(id int, quota int) (err error) { map[string]interface{}{ "remain_quota": gorm.Expr("remain_quota - ?", quota), "used_quota": gorm.Expr("used_quota + ?", quota), - "accessed_time": common.GetTimestamp(), + "accessed_time": helper.GetTimestamp(), }, ).Error return err @@ -192,24 +199,24 @@ func PreConsumeTokenQuota(tokenId int, quota int) (err error) { if userQuota < quota { return errors.New("用户额度不足") } - quotaTooLow := userQuota >= common.QuotaRemindThreshold && userQuota-quota < common.QuotaRemindThreshold + quotaTooLow := userQuota >= config.QuotaRemindThreshold && userQuota-quota < config.QuotaRemindThreshold noMoreQuota := userQuota-quota <= 0 if quotaTooLow || noMoreQuota { go func() { email, err := GetUserEmail(token.UserId) if err != nil { - common.SysError("failed to fetch user email: " + err.Error()) + logger.SysError("failed to fetch user email: " + err.Error()) } prompt := "您的额度即将用尽" if noMoreQuota { prompt = "您的额度已用尽" } if email != "" { - topUpLink := fmt.Sprintf("%s/topup", common.ServerAddress) + topUpLink := fmt.Sprintf("%s/topup", config.ServerAddress) err = common.SendEmail(prompt, email, fmt.Sprintf("%s,当前剩余额度为 %d,为了不影响您的使用,请及时充值。
充值链接:%s", prompt, userQuota, topUpLink, topUpLink)) if err != nil { - common.SysError("failed to send email" + err.Error()) + logger.SysError("failed to send email" + err.Error()) } } }() diff --git a/model/user.go b/model/user.go index e738b1ba..82e9707b 100644 --- a/model/user.go +++ b/model/user.go @@ -5,6 +5,9 @@ import ( "fmt" "gorm.io/gorm" "one-api/common" + "one-api/common/config" + "one-api/common/helper" + "one-api/common/logger" "strings" ) @@ -15,7 +18,7 @@ type User struct { Username string `json:"username" gorm:"unique;index" validate:"max=12"` Password string `json:"password" gorm:"not null;" validate:"min=8,max=20"` DisplayName string `json:"display_name" gorm:"index" validate:"max=20"` - Role int `json:"role" gorm:"type:int;default:1"` // admin, common + Role int `json:"role" gorm:"type:int;default:1"` // admin, util Status int `json:"status" gorm:"type:int;default:1"` // enabled, disabled Email string `json:"email" gorm:"index" validate:"max=50"` GitHubId string `json:"github_id" gorm:"column:github_id;index"` @@ -89,24 +92,24 @@ func (user *User) Insert(inviterId int) error { return err } } - user.Quota = common.QuotaForNewUser - user.AccessToken = common.GetUUID() - user.AffCode = common.GetRandomString(4) + user.Quota = config.QuotaForNewUser + user.AccessToken = helper.GetUUID() + user.AffCode = helper.GetRandomString(4) result := DB.Create(user) if result.Error != nil { return result.Error } - if common.QuotaForNewUser > 0 { - RecordLog(user.Id, LogTypeSystem, fmt.Sprintf("新用户注册赠送 %s", common.LogQuota(common.QuotaForNewUser))) + if config.QuotaForNewUser > 0 { + RecordLog(user.Id, LogTypeSystem, fmt.Sprintf("新用户注册赠送 %s", common.LogQuota(config.QuotaForNewUser))) } if inviterId != 0 { - if common.QuotaForInvitee > 0 { - _ = IncreaseUserQuota(user.Id, common.QuotaForInvitee) - RecordLog(user.Id, LogTypeSystem, fmt.Sprintf("使用邀请码赠送 %s", common.LogQuota(common.QuotaForInvitee))) + if config.QuotaForInvitee > 0 { + _ = IncreaseUserQuota(user.Id, config.QuotaForInvitee) + RecordLog(user.Id, LogTypeSystem, fmt.Sprintf("使用邀请码赠送 %s", common.LogQuota(config.QuotaForInvitee))) } - if common.QuotaForInviter > 0 { - _ = IncreaseUserQuota(inviterId, common.QuotaForInviter) - RecordLog(inviterId, LogTypeSystem, fmt.Sprintf("邀请用户赠送 %s", common.LogQuota(common.QuotaForInviter))) + if config.QuotaForInviter > 0 { + _ = IncreaseUserQuota(inviterId, config.QuotaForInviter) + RecordLog(inviterId, LogTypeSystem, fmt.Sprintf("邀请用户赠送 %s", common.LogQuota(config.QuotaForInviter))) } } return nil @@ -141,7 +144,15 @@ func (user *User) ValidateAndFill() (err error) { if user.Username == "" || password == "" { return errors.New("用户名或密码为空") } - DB.Where(User{Username: user.Username}).First(user) + err = DB.Where("username = ?", user.Username).First(user).Error + if err != nil { + // we must make sure check username firstly + // consider this case: a malicious user set his username as other's email + err := DB.Where("email = ?", user.Username).First(user).Error + if err != nil { + return errors.New("用户名或密码错误,或用户已被封禁") + } + } okay := common.ValidatePasswordAndHash(password, user.Password) if !okay || user.Status != common.UserStatusEnabled { return errors.New("用户名或密码错误,或用户已被封禁") @@ -224,7 +235,7 @@ func IsAdmin(userId int) bool { var user User err := DB.Where("id = ?", userId).Select("role").Find(&user).Error if err != nil { - common.SysError("no such user " + err.Error()) + logger.SysError("no such user " + err.Error()) return false } return user.Role >= common.RoleAdminUser @@ -283,7 +294,7 @@ func IncreaseUserQuota(id int, quota int) (err error) { if quota < 0 { return errors.New("quota 不能为负数!") } - if common.BatchUpdateEnabled { + if config.BatchUpdateEnabled { addNewRecord(BatchUpdateTypeUserQuota, id, quota) return nil } @@ -299,7 +310,7 @@ func DecreaseUserQuota(id int, quota int) (err error) { if quota < 0 { return errors.New("quota 不能为负数!") } - if common.BatchUpdateEnabled { + if config.BatchUpdateEnabled { addNewRecord(BatchUpdateTypeUserQuota, id, -quota) return nil } @@ -317,7 +328,7 @@ func GetRootUserEmail() (email string) { } func UpdateUserUsedQuotaAndRequestCount(id int, quota int) { - if common.BatchUpdateEnabled { + if config.BatchUpdateEnabled { addNewRecord(BatchUpdateTypeUsedQuota, id, quota) addNewRecord(BatchUpdateTypeRequestCount, id, 1) return @@ -333,7 +344,7 @@ func updateUserUsedQuotaAndRequestCount(id int, quota int, count int) { }, ).Error if err != nil { - common.SysError("failed to update user used quota and request count: " + err.Error()) + logger.SysError("failed to update user used quota and request count: " + err.Error()) } } @@ -344,14 +355,14 @@ func updateUserUsedQuota(id int, quota int) { }, ).Error if err != nil { - common.SysError("failed to update user used quota: " + err.Error()) + logger.SysError("failed to update user used quota: " + err.Error()) } } func updateUserRequestCount(id int, count int) { err := DB.Model(&User{}).Where("id = ?", id).Update("request_count", gorm.Expr("request_count + ?", count)).Error if err != nil { - common.SysError("failed to update user request count: " + err.Error()) + logger.SysError("failed to update user request count: " + err.Error()) } } diff --git a/model/utils.go b/model/utils.go index 1c28340b..e0826e0d 100644 --- a/model/utils.go +++ b/model/utils.go @@ -1,7 +1,8 @@ package model import ( - "one-api/common" + "one-api/common/config" + "one-api/common/logger" "sync" "time" ) @@ -28,7 +29,7 @@ func init() { func InitBatchUpdater() { go func() { for { - time.Sleep(time.Duration(common.BatchUpdateInterval) * time.Second) + time.Sleep(time.Duration(config.BatchUpdateInterval) * time.Second) batchUpdate() } }() @@ -45,7 +46,7 @@ func addNewRecord(type_ int, id int, value int) { } func batchUpdate() { - common.SysLog("batch update started") + logger.SysLog("batch update started") for i := 0; i < BatchUpdateTypeCount; i++ { batchUpdateLocks[i].Lock() store := batchUpdateStores[i] @@ -57,12 +58,12 @@ func batchUpdate() { case BatchUpdateTypeUserQuota: err := increaseUserQuota(key, value) if err != nil { - common.SysError("failed to batch update user quota: " + err.Error()) + logger.SysError("failed to batch update user quota: " + err.Error()) } case BatchUpdateTypeTokenQuota: err := increaseTokenQuota(key, value) if err != nil { - common.SysError("failed to batch update token quota: " + err.Error()) + logger.SysError("failed to batch update token quota: " + err.Error()) } case BatchUpdateTypeUsedQuota: updateUserUsedQuota(key, value) @@ -73,5 +74,5 @@ func batchUpdate() { } } } - common.SysLog("batch update finished") + logger.SysLog("batch update finished") } diff --git a/relay/channel/aiproxy/adaptor.go b/relay/channel/aiproxy/adaptor.go new file mode 100644 index 00000000..44b6f58d --- /dev/null +++ b/relay/channel/aiproxy/adaptor.go @@ -0,0 +1,22 @@ +package aiproxy + +import ( + "github.com/gin-gonic/gin" + "net/http" + "one-api/relay/channel/openai" +) + +type Adaptor struct { +} + +func (a *Adaptor) Auth(c *gin.Context) error { + return nil +} + +func (a *Adaptor) ConvertRequest(request *openai.GeneralOpenAIRequest) (any, error) { + return nil, nil +} + +func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, *openai.Usage, error) { + return nil, nil, nil +} diff --git a/controller/relay-aiproxy.go b/relay/channel/aiproxy/main.go similarity index 52% rename from controller/relay-aiproxy.go rename to relay/channel/aiproxy/main.go index 543954f7..af9cd6f6 100644 --- a/controller/relay-aiproxy.go +++ b/relay/channel/aiproxy/main.go @@ -1,4 +1,4 @@ -package controller +package aiproxy import ( "bufio" @@ -8,56 +8,29 @@ import ( "io" "net/http" "one-api/common" + "one-api/common/helper" + "one-api/common/logger" + "one-api/relay/channel/openai" + "one-api/relay/constant" "strconv" "strings" ) // https://docs.aiproxy.io/dev/library#使用已经定制好的知识库进行对话问答 -type AIProxyLibraryRequest struct { - Model string `json:"model"` - Query string `json:"query"` - LibraryId string `json:"libraryId"` - Stream bool `json:"stream"` -} - -type AIProxyLibraryError struct { - ErrCode int `json:"errCode"` - Message string `json:"message"` -} - -type AIProxyLibraryDocument struct { - Title string `json:"title"` - URL string `json:"url"` -} - -type AIProxyLibraryResponse struct { - Success bool `json:"success"` - Answer string `json:"answer"` - Documents []AIProxyLibraryDocument `json:"documents"` - AIProxyLibraryError -} - -type AIProxyLibraryStreamResponse struct { - Content string `json:"content"` - Finish bool `json:"finish"` - Model string `json:"model"` - Documents []AIProxyLibraryDocument `json:"documents"` -} - -func requestOpenAI2AIProxyLibrary(request GeneralOpenAIRequest) *AIProxyLibraryRequest { +func ConvertRequest(request openai.GeneralOpenAIRequest) *LibraryRequest { query := "" if len(request.Messages) != 0 { query = request.Messages[len(request.Messages)-1].StringContent() } - return &AIProxyLibraryRequest{ + return &LibraryRequest{ Model: request.Model, Stream: request.Stream, Query: query, } } -func aiProxyDocuments2Markdown(documents []AIProxyLibraryDocument) string { +func aiProxyDocuments2Markdown(documents []LibraryDocument) string { if len(documents) == 0 { return "" } @@ -68,52 +41,52 @@ func aiProxyDocuments2Markdown(documents []AIProxyLibraryDocument) string { return content } -func responseAIProxyLibrary2OpenAI(response *AIProxyLibraryResponse) *OpenAITextResponse { +func responseAIProxyLibrary2OpenAI(response *LibraryResponse) *openai.TextResponse { content := response.Answer + aiProxyDocuments2Markdown(response.Documents) - choice := OpenAITextResponseChoice{ + choice := openai.TextResponseChoice{ Index: 0, - Message: Message{ + Message: openai.Message{ Role: "assistant", Content: content, }, FinishReason: "stop", } - fullTextResponse := OpenAITextResponse{ - Id: common.GetUUID(), + fullTextResponse := openai.TextResponse{ + Id: helper.GetUUID(), Object: "chat.completion", - Created: common.GetTimestamp(), - Choices: []OpenAITextResponseChoice{choice}, + Created: helper.GetTimestamp(), + Choices: []openai.TextResponseChoice{choice}, } return &fullTextResponse } -func documentsAIProxyLibrary(documents []AIProxyLibraryDocument) *ChatCompletionsStreamResponse { - var choice ChatCompletionsStreamResponseChoice +func documentsAIProxyLibrary(documents []LibraryDocument) *openai.ChatCompletionsStreamResponse { + var choice openai.ChatCompletionsStreamResponseChoice choice.Delta.Content = aiProxyDocuments2Markdown(documents) - choice.FinishReason = &stopFinishReason - return &ChatCompletionsStreamResponse{ - Id: common.GetUUID(), + choice.FinishReason = &constant.StopFinishReason + return &openai.ChatCompletionsStreamResponse{ + Id: helper.GetUUID(), Object: "chat.completion.chunk", - Created: common.GetTimestamp(), + Created: helper.GetTimestamp(), Model: "", - Choices: []ChatCompletionsStreamResponseChoice{choice}, + Choices: []openai.ChatCompletionsStreamResponseChoice{choice}, } } -func streamResponseAIProxyLibrary2OpenAI(response *AIProxyLibraryStreamResponse) *ChatCompletionsStreamResponse { - var choice ChatCompletionsStreamResponseChoice +func streamResponseAIProxyLibrary2OpenAI(response *LibraryStreamResponse) *openai.ChatCompletionsStreamResponse { + var choice openai.ChatCompletionsStreamResponseChoice choice.Delta.Content = response.Content - return &ChatCompletionsStreamResponse{ - Id: common.GetUUID(), + return &openai.ChatCompletionsStreamResponse{ + Id: helper.GetUUID(), Object: "chat.completion.chunk", - Created: common.GetTimestamp(), + Created: helper.GetTimestamp(), Model: response.Model, - Choices: []ChatCompletionsStreamResponseChoice{choice}, + Choices: []openai.ChatCompletionsStreamResponseChoice{choice}, } } -func aiProxyLibraryStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) { - var usage Usage +func StreamHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, *openai.Usage) { + var usage openai.Usage scanner := bufio.NewScanner(resp.Body) scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { if atEOF && len(data) == 0 { @@ -143,15 +116,15 @@ func aiProxyLibraryStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIEr } stopChan <- true }() - setEventStreamHeaders(c) - var documents []AIProxyLibraryDocument + common.SetEventStreamHeaders(c) + var documents []LibraryDocument c.Stream(func(w io.Writer) bool { select { case data := <-dataChan: - var AIProxyLibraryResponse AIProxyLibraryStreamResponse + var AIProxyLibraryResponse LibraryStreamResponse err := json.Unmarshal([]byte(data), &AIProxyLibraryResponse) if err != nil { - common.SysError("error unmarshalling stream response: " + err.Error()) + logger.SysError("error unmarshalling stream response: " + err.Error()) return true } if len(AIProxyLibraryResponse.Documents) != 0 { @@ -160,7 +133,7 @@ func aiProxyLibraryStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIEr response := streamResponseAIProxyLibrary2OpenAI(&AIProxyLibraryResponse) jsonResponse, err := json.Marshal(response) if err != nil { - common.SysError("error marshalling stream response: " + err.Error()) + logger.SysError("error marshalling stream response: " + err.Error()) return true } c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)}) @@ -169,7 +142,7 @@ func aiProxyLibraryStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIEr response := documentsAIProxyLibrary(documents) jsonResponse, err := json.Marshal(response) if err != nil { - common.SysError("error marshalling stream response: " + err.Error()) + logger.SysError("error marshalling stream response: " + err.Error()) return true } c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)}) @@ -179,28 +152,28 @@ func aiProxyLibraryStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIEr }) err := resp.Body.Close() if err != nil { - return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil + return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil } return nil, &usage } -func aiProxyLibraryHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) { - var AIProxyLibraryResponse AIProxyLibraryResponse +func Handler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, *openai.Usage) { + var AIProxyLibraryResponse LibraryResponse responseBody, err := io.ReadAll(resp.Body) if err != nil { - return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil + return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil } err = resp.Body.Close() if err != nil { - return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil + return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil } err = json.Unmarshal(responseBody, &AIProxyLibraryResponse) if err != nil { - return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil + return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil } if AIProxyLibraryResponse.ErrCode != 0 { - return &OpenAIErrorWithStatusCode{ - OpenAIError: OpenAIError{ + return &openai.ErrorWithStatusCode{ + Error: openai.Error{ Message: AIProxyLibraryResponse.Message, Type: strconv.Itoa(AIProxyLibraryResponse.ErrCode), Code: AIProxyLibraryResponse.ErrCode, @@ -211,7 +184,7 @@ func aiProxyLibraryHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWit fullTextResponse := responseAIProxyLibrary2OpenAI(&AIProxyLibraryResponse) jsonResponse, err := json.Marshal(fullTextResponse) if err != nil { - return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil + return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil } c.Writer.Header().Set("Content-Type", "application/json") c.Writer.WriteHeader(resp.StatusCode) diff --git a/relay/channel/aiproxy/model.go b/relay/channel/aiproxy/model.go new file mode 100644 index 00000000..39689b3d --- /dev/null +++ b/relay/channel/aiproxy/model.go @@ -0,0 +1,32 @@ +package aiproxy + +type LibraryRequest struct { + Model string `json:"model"` + Query string `json:"query"` + LibraryId string `json:"libraryId"` + Stream bool `json:"stream"` +} + +type LibraryError struct { + ErrCode int `json:"errCode"` + Message string `json:"message"` +} + +type LibraryDocument struct { + Title string `json:"title"` + URL string `json:"url"` +} + +type LibraryResponse struct { + Success bool `json:"success"` + Answer string `json:"answer"` + Documents []LibraryDocument `json:"documents"` + LibraryError +} + +type LibraryStreamResponse struct { + Content string `json:"content"` + Finish bool `json:"finish"` + Model string `json:"model"` + Documents []LibraryDocument `json:"documents"` +} diff --git a/relay/channel/ali/adaptor.go b/relay/channel/ali/adaptor.go new file mode 100644 index 00000000..49022cfc --- /dev/null +++ b/relay/channel/ali/adaptor.go @@ -0,0 +1,22 @@ +package ali + +import ( + "github.com/gin-gonic/gin" + "net/http" + "one-api/relay/channel/openai" +) + +type Adaptor struct { +} + +func (a *Adaptor) Auth(c *gin.Context) error { + return nil +} + +func (a *Adaptor) ConvertRequest(request *openai.GeneralOpenAIRequest) (any, error) { + return nil, nil +} + +func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, *openai.Usage, error) { + return nil, nil, nil +} diff --git a/controller/relay-ali.go b/relay/channel/ali/main.go similarity index 51% rename from controller/relay-ali.go rename to relay/channel/ali/main.go index df1cc084..81dc5370 100644 --- a/controller/relay-ali.go +++ b/relay/channel/ali/main.go @@ -1,4 +1,4 @@ -package controller +package ali import ( "bufio" @@ -7,112 +7,45 @@ import ( "io" "net/http" "one-api/common" + "one-api/common/helper" + "one-api/common/logger" + "one-api/relay/channel/openai" "strings" ) // https://help.aliyun.com/document_detail/613695.html?spm=a2c4g.2399480.0.0.1adb778fAdzP9w#341800c0f8w0r -type AliMessage struct { - Content string `json:"content"` - Role string `json:"role"` -} +const EnableSearchModelSuffix = "-internet" -type AliInput struct { - //Prompt string `json:"prompt"` - Messages []AliMessage `json:"messages"` -} - -type AliParameters struct { - TopP float64 `json:"top_p,omitempty"` - TopK int `json:"top_k,omitempty"` - Seed uint64 `json:"seed,omitempty"` - EnableSearch bool `json:"enable_search,omitempty"` - IncrementalOutput bool `json:"incremental_output,omitempty"` -} - -type AliChatRequest struct { - Model string `json:"model"` - Input AliInput `json:"input"` - Parameters AliParameters `json:"parameters,omitempty"` -} - -type AliEmbeddingRequest struct { - Model string `json:"model"` - Input struct { - Texts []string `json:"texts"` - } `json:"input"` - Parameters *struct { - TextType string `json:"text_type,omitempty"` - } `json:"parameters,omitempty"` -} - -type AliEmbedding struct { - Embedding []float64 `json:"embedding"` - TextIndex int `json:"text_index"` -} - -type AliEmbeddingResponse struct { - Output struct { - Embeddings []AliEmbedding `json:"embeddings"` - } `json:"output"` - Usage AliUsage `json:"usage"` - AliError -} - -type AliError struct { - Code string `json:"code"` - Message string `json:"message"` - RequestId string `json:"request_id"` -} - -type AliUsage struct { - InputTokens int `json:"input_tokens"` - OutputTokens int `json:"output_tokens"` - TotalTokens int `json:"total_tokens"` -} - -type AliOutput struct { - Text string `json:"text"` - FinishReason string `json:"finish_reason"` -} - -type AliChatResponse struct { - Output AliOutput `json:"output"` - Usage AliUsage `json:"usage"` - AliError -} - -const AliEnableSearchModelSuffix = "-internet" - -func requestOpenAI2Ali(request GeneralOpenAIRequest) *AliChatRequest { - messages := make([]AliMessage, 0, len(request.Messages)) +func ConvertRequest(request openai.GeneralOpenAIRequest) *ChatRequest { + messages := make([]Message, 0, len(request.Messages)) for i := 0; i < len(request.Messages); i++ { message := request.Messages[i] - messages = append(messages, AliMessage{ + messages = append(messages, Message{ Content: message.StringContent(), Role: strings.ToLower(message.Role), }) } enableSearch := false aliModel := request.Model - if strings.HasSuffix(aliModel, AliEnableSearchModelSuffix) { + if strings.HasSuffix(aliModel, EnableSearchModelSuffix) { enableSearch = true - aliModel = strings.TrimSuffix(aliModel, AliEnableSearchModelSuffix) + aliModel = strings.TrimSuffix(aliModel, EnableSearchModelSuffix) } - return &AliChatRequest{ + return &ChatRequest{ Model: aliModel, - Input: AliInput{ + Input: Input{ Messages: messages, }, - Parameters: AliParameters{ + Parameters: Parameters{ EnableSearch: enableSearch, IncrementalOutput: request.Stream, }, } } -func embeddingRequestOpenAI2Ali(request GeneralOpenAIRequest) *AliEmbeddingRequest { - return &AliEmbeddingRequest{ +func ConvertEmbeddingRequest(request openai.GeneralOpenAIRequest) *EmbeddingRequest { + return &EmbeddingRequest{ Model: "text-embedding-v1", Input: struct { Texts []string `json:"texts"` @@ -122,21 +55,21 @@ func embeddingRequestOpenAI2Ali(request GeneralOpenAIRequest) *AliEmbeddingReque } } -func aliEmbeddingHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) { - var aliResponse AliEmbeddingResponse +func EmbeddingHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, *openai.Usage) { + var aliResponse EmbeddingResponse err := json.NewDecoder(resp.Body).Decode(&aliResponse) if err != nil { - return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil + return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil } err = resp.Body.Close() if err != nil { - return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil + return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil } if aliResponse.Code != "" { - return &OpenAIErrorWithStatusCode{ - OpenAIError: OpenAIError{ + return &openai.ErrorWithStatusCode{ + Error: openai.Error{ Message: aliResponse.Message, Type: aliResponse.Code, Param: aliResponse.RequestId, @@ -149,7 +82,7 @@ func aliEmbeddingHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithS fullTextResponse := embeddingResponseAli2OpenAI(&aliResponse) jsonResponse, err := json.Marshal(fullTextResponse) if err != nil { - return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil + return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil } c.Writer.Header().Set("Content-Type", "application/json") c.Writer.WriteHeader(resp.StatusCode) @@ -157,16 +90,16 @@ func aliEmbeddingHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithS return nil, &fullTextResponse.Usage } -func embeddingResponseAli2OpenAI(response *AliEmbeddingResponse) *OpenAIEmbeddingResponse { - openAIEmbeddingResponse := OpenAIEmbeddingResponse{ +func embeddingResponseAli2OpenAI(response *EmbeddingResponse) *openai.EmbeddingResponse { + openAIEmbeddingResponse := openai.EmbeddingResponse{ Object: "list", - Data: make([]OpenAIEmbeddingResponseItem, 0, len(response.Output.Embeddings)), + Data: make([]openai.EmbeddingResponseItem, 0, len(response.Output.Embeddings)), Model: "text-embedding-v1", - Usage: Usage{TotalTokens: response.Usage.TotalTokens}, + Usage: openai.Usage{TotalTokens: response.Usage.TotalTokens}, } for _, item := range response.Output.Embeddings { - openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, OpenAIEmbeddingResponseItem{ + openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, openai.EmbeddingResponseItem{ Object: `embedding`, Index: item.TextIndex, Embedding: item.Embedding, @@ -175,21 +108,21 @@ func embeddingResponseAli2OpenAI(response *AliEmbeddingResponse) *OpenAIEmbeddin return &openAIEmbeddingResponse } -func responseAli2OpenAI(response *AliChatResponse) *OpenAITextResponse { - choice := OpenAITextResponseChoice{ +func responseAli2OpenAI(response *ChatResponse) *openai.TextResponse { + choice := openai.TextResponseChoice{ Index: 0, - Message: Message{ + Message: openai.Message{ Role: "assistant", Content: response.Output.Text, }, FinishReason: response.Output.FinishReason, } - fullTextResponse := OpenAITextResponse{ + fullTextResponse := openai.TextResponse{ Id: response.RequestId, Object: "chat.completion", - Created: common.GetTimestamp(), - Choices: []OpenAITextResponseChoice{choice}, - Usage: Usage{ + Created: helper.GetTimestamp(), + Choices: []openai.TextResponseChoice{choice}, + Usage: openai.Usage{ PromptTokens: response.Usage.InputTokens, CompletionTokens: response.Usage.OutputTokens, TotalTokens: response.Usage.InputTokens + response.Usage.OutputTokens, @@ -198,25 +131,25 @@ func responseAli2OpenAI(response *AliChatResponse) *OpenAITextResponse { return &fullTextResponse } -func streamResponseAli2OpenAI(aliResponse *AliChatResponse) *ChatCompletionsStreamResponse { - var choice ChatCompletionsStreamResponseChoice +func streamResponseAli2OpenAI(aliResponse *ChatResponse) *openai.ChatCompletionsStreamResponse { + var choice openai.ChatCompletionsStreamResponseChoice choice.Delta.Content = aliResponse.Output.Text if aliResponse.Output.FinishReason != "null" { finishReason := aliResponse.Output.FinishReason choice.FinishReason = &finishReason } - response := ChatCompletionsStreamResponse{ + response := openai.ChatCompletionsStreamResponse{ Id: aliResponse.RequestId, Object: "chat.completion.chunk", - Created: common.GetTimestamp(), + Created: helper.GetTimestamp(), Model: "qwen", - Choices: []ChatCompletionsStreamResponseChoice{choice}, + Choices: []openai.ChatCompletionsStreamResponseChoice{choice}, } return &response } -func aliStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) { - var usage Usage +func StreamHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, *openai.Usage) { + var usage openai.Usage scanner := bufio.NewScanner(resp.Body) scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { if atEOF && len(data) == 0 { @@ -246,15 +179,15 @@ func aliStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStat } stopChan <- true }() - setEventStreamHeaders(c) + common.SetEventStreamHeaders(c) //lastResponseText := "" c.Stream(func(w io.Writer) bool { select { case data := <-dataChan: - var aliResponse AliChatResponse + var aliResponse ChatResponse err := json.Unmarshal([]byte(data), &aliResponse) if err != nil { - common.SysError("error unmarshalling stream response: " + err.Error()) + logger.SysError("error unmarshalling stream response: " + err.Error()) return true } if aliResponse.Usage.OutputTokens != 0 { @@ -267,7 +200,7 @@ func aliStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStat //lastResponseText = aliResponse.Output.Text jsonResponse, err := json.Marshal(response) if err != nil { - common.SysError("error marshalling stream response: " + err.Error()) + logger.SysError("error marshalling stream response: " + err.Error()) return true } c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)}) @@ -279,28 +212,28 @@ func aliStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStat }) err := resp.Body.Close() if err != nil { - return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil + return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil } return nil, &usage } -func aliHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) { - var aliResponse AliChatResponse +func Handler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, *openai.Usage) { + var aliResponse ChatResponse responseBody, err := io.ReadAll(resp.Body) if err != nil { - return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil + return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil } err = resp.Body.Close() if err != nil { - return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil + return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil } err = json.Unmarshal(responseBody, &aliResponse) if err != nil { - return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil + return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil } if aliResponse.Code != "" { - return &OpenAIErrorWithStatusCode{ - OpenAIError: OpenAIError{ + return &openai.ErrorWithStatusCode{ + Error: openai.Error{ Message: aliResponse.Message, Type: aliResponse.Code, Param: aliResponse.RequestId, @@ -313,7 +246,7 @@ func aliHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode fullTextResponse.Model = "qwen" jsonResponse, err := json.Marshal(fullTextResponse) if err != nil { - return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil + return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil } c.Writer.Header().Set("Content-Type", "application/json") c.Writer.WriteHeader(resp.StatusCode) diff --git a/relay/channel/ali/model.go b/relay/channel/ali/model.go new file mode 100644 index 00000000..54f13041 --- /dev/null +++ b/relay/channel/ali/model.go @@ -0,0 +1,71 @@ +package ali + +type Message struct { + Content string `json:"content"` + Role string `json:"role"` +} + +type Input struct { + //Prompt string `json:"prompt"` + Messages []Message `json:"messages"` +} + +type Parameters struct { + TopP float64 `json:"top_p,omitempty"` + TopK int `json:"top_k,omitempty"` + Seed uint64 `json:"seed,omitempty"` + EnableSearch bool `json:"enable_search,omitempty"` + IncrementalOutput bool `json:"incremental_output,omitempty"` +} + +type ChatRequest struct { + Model string `json:"model"` + Input Input `json:"input"` + Parameters Parameters `json:"parameters,omitempty"` +} + +type EmbeddingRequest struct { + Model string `json:"model"` + Input struct { + Texts []string `json:"texts"` + } `json:"input"` + Parameters *struct { + TextType string `json:"text_type,omitempty"` + } `json:"parameters,omitempty"` +} + +type Embedding struct { + Embedding []float64 `json:"embedding"` + TextIndex int `json:"text_index"` +} + +type EmbeddingResponse struct { + Output struct { + Embeddings []Embedding `json:"embeddings"` + } `json:"output"` + Usage Usage `json:"usage"` + Error +} + +type Error struct { + Code string `json:"code"` + Message string `json:"message"` + RequestId string `json:"request_id"` +} + +type Usage struct { + InputTokens int `json:"input_tokens"` + OutputTokens int `json:"output_tokens"` + TotalTokens int `json:"total_tokens"` +} + +type Output struct { + Text string `json:"text"` + FinishReason string `json:"finish_reason"` +} + +type ChatResponse struct { + Output Output `json:"output"` + Usage Usage `json:"usage"` + Error +} diff --git a/relay/channel/anthropic/adaptor.go b/relay/channel/anthropic/adaptor.go new file mode 100644 index 00000000..55577228 --- /dev/null +++ b/relay/channel/anthropic/adaptor.go @@ -0,0 +1,22 @@ +package anthropic + +import ( + "github.com/gin-gonic/gin" + "net/http" + "one-api/relay/channel/openai" +) + +type Adaptor struct { +} + +func (a *Adaptor) Auth(c *gin.Context) error { + return nil +} + +func (a *Adaptor) ConvertRequest(request *openai.GeneralOpenAIRequest) (any, error) { + return nil, nil +} + +func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, *openai.Usage, error) { + return nil, nil, nil +} diff --git a/controller/relay-claude.go b/relay/channel/anthropic/main.go similarity index 59% rename from controller/relay-claude.go rename to relay/channel/anthropic/main.go index ca7a701a..060fcde8 100644 --- a/controller/relay-claude.go +++ b/relay/channel/anthropic/main.go @@ -1,4 +1,4 @@ -package controller +package anthropic import ( "bufio" @@ -8,37 +8,12 @@ import ( "io" "net/http" "one-api/common" + "one-api/common/helper" + "one-api/common/logger" + "one-api/relay/channel/openai" "strings" ) -type ClaudeMetadata struct { - UserId string `json:"user_id"` -} - -type ClaudeRequest struct { - Model string `json:"model"` - Prompt string `json:"prompt"` - MaxTokensToSample int `json:"max_tokens_to_sample"` - StopSequences []string `json:"stop_sequences,omitempty"` - Temperature float64 `json:"temperature,omitempty"` - TopP float64 `json:"top_p,omitempty"` - TopK int `json:"top_k,omitempty"` - //ClaudeMetadata `json:"metadata,omitempty"` - Stream bool `json:"stream,omitempty"` -} - -type ClaudeError struct { - Type string `json:"type"` - Message string `json:"message"` -} - -type ClaudeResponse struct { - Completion string `json:"completion"` - StopReason string `json:"stop_reason"` - Model string `json:"model"` - Error ClaudeError `json:"error"` -} - func stopReasonClaude2OpenAI(reason string) string { switch reason { case "stop_sequence": @@ -50,8 +25,8 @@ func stopReasonClaude2OpenAI(reason string) string { } } -func requestOpenAI2Claude(textRequest GeneralOpenAIRequest) *ClaudeRequest { - claudeRequest := ClaudeRequest{ +func ConvertRequest(textRequest openai.GeneralOpenAIRequest) *Request { + claudeRequest := Request{ Model: textRequest.Model, Prompt: "", MaxTokensToSample: textRequest.MaxTokens, @@ -80,43 +55,43 @@ func requestOpenAI2Claude(textRequest GeneralOpenAIRequest) *ClaudeRequest { return &claudeRequest } -func streamResponseClaude2OpenAI(claudeResponse *ClaudeResponse) *ChatCompletionsStreamResponse { - var choice ChatCompletionsStreamResponseChoice +func streamResponseClaude2OpenAI(claudeResponse *Response) *openai.ChatCompletionsStreamResponse { + var choice openai.ChatCompletionsStreamResponseChoice choice.Delta.Content = claudeResponse.Completion finishReason := stopReasonClaude2OpenAI(claudeResponse.StopReason) if finishReason != "null" { choice.FinishReason = &finishReason } - var response ChatCompletionsStreamResponse + var response openai.ChatCompletionsStreamResponse response.Object = "chat.completion.chunk" response.Model = claudeResponse.Model - response.Choices = []ChatCompletionsStreamResponseChoice{choice} + response.Choices = []openai.ChatCompletionsStreamResponseChoice{choice} return &response } -func responseClaude2OpenAI(claudeResponse *ClaudeResponse) *OpenAITextResponse { - choice := OpenAITextResponseChoice{ +func responseClaude2OpenAI(claudeResponse *Response) *openai.TextResponse { + choice := openai.TextResponseChoice{ Index: 0, - Message: Message{ + Message: openai.Message{ Role: "assistant", Content: strings.TrimPrefix(claudeResponse.Completion, " "), Name: nil, }, FinishReason: stopReasonClaude2OpenAI(claudeResponse.StopReason), } - fullTextResponse := OpenAITextResponse{ - Id: fmt.Sprintf("chatcmpl-%s", common.GetUUID()), + fullTextResponse := openai.TextResponse{ + Id: fmt.Sprintf("chatcmpl-%s", helper.GetUUID()), Object: "chat.completion", - Created: common.GetTimestamp(), - Choices: []OpenAITextResponseChoice{choice}, + Created: helper.GetTimestamp(), + Choices: []openai.TextResponseChoice{choice}, } return &fullTextResponse } -func claudeStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, string) { +func StreamHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, string) { responseText := "" - responseId := fmt.Sprintf("chatcmpl-%s", common.GetUUID()) - createdTime := common.GetTimestamp() + responseId := fmt.Sprintf("chatcmpl-%s", helper.GetUUID()) + createdTime := helper.GetTimestamp() scanner := bufio.NewScanner(resp.Body) scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { if atEOF && len(data) == 0 { @@ -143,16 +118,16 @@ func claudeStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithS } stopChan <- true }() - setEventStreamHeaders(c) + common.SetEventStreamHeaders(c) c.Stream(func(w io.Writer) bool { select { case data := <-dataChan: // some implementations may add \r at the end of data data = strings.TrimSuffix(data, "\r") - var claudeResponse ClaudeResponse + var claudeResponse Response err := json.Unmarshal([]byte(data), &claudeResponse) if err != nil { - common.SysError("error unmarshalling stream response: " + err.Error()) + logger.SysError("error unmarshalling stream response: " + err.Error()) return true } responseText += claudeResponse.Completion @@ -161,7 +136,7 @@ func claudeStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithS response.Created = createdTime jsonStr, err := json.Marshal(response) if err != nil { - common.SysError("error marshalling stream response: " + err.Error()) + logger.SysError("error marshalling stream response: " + err.Error()) return true } c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonStr)}) @@ -173,28 +148,28 @@ func claudeStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithS }) err := resp.Body.Close() if err != nil { - return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), "" + return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), "" } return nil, responseText } -func claudeHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*OpenAIErrorWithStatusCode, *Usage) { +func Handler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*openai.ErrorWithStatusCode, *openai.Usage) { responseBody, err := io.ReadAll(resp.Body) if err != nil { - return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil + return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil } err = resp.Body.Close() if err != nil { - return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil + return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil } - var claudeResponse ClaudeResponse + var claudeResponse Response err = json.Unmarshal(responseBody, &claudeResponse) if err != nil { - return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil + return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil } if claudeResponse.Error.Type != "" { - return &OpenAIErrorWithStatusCode{ - OpenAIError: OpenAIError{ + return &openai.ErrorWithStatusCode{ + Error: openai.Error{ Message: claudeResponse.Error.Message, Type: claudeResponse.Error.Type, Param: "", @@ -205,8 +180,8 @@ func claudeHandler(c *gin.Context, resp *http.Response, promptTokens int, model } fullTextResponse := responseClaude2OpenAI(&claudeResponse) fullTextResponse.Model = model - completionTokens := countTokenText(claudeResponse.Completion, model) - usage := Usage{ + completionTokens := openai.CountTokenText(claudeResponse.Completion, model) + usage := openai.Usage{ PromptTokens: promptTokens, CompletionTokens: completionTokens, TotalTokens: promptTokens + completionTokens, @@ -214,7 +189,7 @@ func claudeHandler(c *gin.Context, resp *http.Response, promptTokens int, model fullTextResponse.Usage = usage jsonResponse, err := json.Marshal(fullTextResponse) if err != nil { - return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil + return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil } c.Writer.Header().Set("Content-Type", "application/json") c.Writer.WriteHeader(resp.StatusCode) diff --git a/relay/channel/anthropic/model.go b/relay/channel/anthropic/model.go new file mode 100644 index 00000000..70fc9430 --- /dev/null +++ b/relay/channel/anthropic/model.go @@ -0,0 +1,29 @@ +package anthropic + +type Metadata struct { + UserId string `json:"user_id"` +} + +type Request struct { + Model string `json:"model"` + Prompt string `json:"prompt"` + MaxTokensToSample int `json:"max_tokens_to_sample"` + StopSequences []string `json:"stop_sequences,omitempty"` + Temperature float64 `json:"temperature,omitempty"` + TopP float64 `json:"top_p,omitempty"` + TopK int `json:"top_k,omitempty"` + //Metadata `json:"metadata,omitempty"` + Stream bool `json:"stream,omitempty"` +} + +type Error struct { + Type string `json:"type"` + Message string `json:"message"` +} + +type Response struct { + Completion string `json:"completion"` + StopReason string `json:"stop_reason"` + Model string `json:"model"` + Error Error `json:"error"` +} diff --git a/relay/channel/baidu/adaptor.go b/relay/channel/baidu/adaptor.go new file mode 100644 index 00000000..498b664a --- /dev/null +++ b/relay/channel/baidu/adaptor.go @@ -0,0 +1,22 @@ +package baidu + +import ( + "github.com/gin-gonic/gin" + "net/http" + "one-api/relay/channel/openai" +) + +type Adaptor struct { +} + +func (a *Adaptor) Auth(c *gin.Context) error { + return nil +} + +func (a *Adaptor) ConvertRequest(request *openai.GeneralOpenAIRequest) (any, error) { + return nil, nil +} + +func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, *openai.Usage, error) { + return nil, nil, nil +} diff --git a/controller/relay-baidu.go b/relay/channel/baidu/main.go similarity index 55% rename from controller/relay-baidu.go rename to relay/channel/baidu/main.go index dca30da1..f5b98155 100644 --- a/controller/relay-baidu.go +++ b/relay/channel/baidu/main.go @@ -1,4 +1,4 @@ -package controller +package baidu import ( "bufio" @@ -9,6 +9,10 @@ import ( "io" "net/http" "one-api/common" + "one-api/common/logger" + "one-api/relay/channel/openai" + "one-api/relay/constant" + "one-api/relay/util" "strings" "sync" "time" @@ -16,148 +20,104 @@ import ( // https://cloud.baidu.com/doc/WENXINWORKSHOP/s/flfmc9do2 -type BaiduTokenResponse struct { +type TokenResponse struct { ExpiresIn int `json:"expires_in"` AccessToken string `json:"access_token"` } -type BaiduMessage struct { +type Message struct { Role string `json:"role"` Content string `json:"content"` } -type BaiduChatRequest struct { - Messages []BaiduMessage `json:"messages"` - Stream bool `json:"stream"` - UserId string `json:"user_id,omitempty"` +type ChatRequest struct { + Messages []Message `json:"messages"` + Stream bool `json:"stream"` + UserId string `json:"user_id,omitempty"` } -type BaiduError struct { +type Error struct { ErrorCode int `json:"error_code"` ErrorMsg string `json:"error_msg"` } -type BaiduChatResponse struct { - Id string `json:"id"` - Object string `json:"object"` - Created int64 `json:"created"` - Result string `json:"result"` - IsTruncated bool `json:"is_truncated"` - NeedClearHistory bool `json:"need_clear_history"` - Usage Usage `json:"usage"` - BaiduError -} - -type BaiduChatStreamResponse struct { - BaiduChatResponse - SentenceId int `json:"sentence_id"` - IsEnd bool `json:"is_end"` -} - -type BaiduEmbeddingRequest struct { - Input []string `json:"input"` -} - -type BaiduEmbeddingData struct { - Object string `json:"object"` - Embedding []float64 `json:"embedding"` - Index int `json:"index"` -} - -type BaiduEmbeddingResponse struct { - Id string `json:"id"` - Object string `json:"object"` - Created int64 `json:"created"` - Data []BaiduEmbeddingData `json:"data"` - Usage Usage `json:"usage"` - BaiduError -} - -type BaiduAccessToken struct { - AccessToken string `json:"access_token"` - Error string `json:"error,omitempty"` - ErrorDescription string `json:"error_description,omitempty"` - ExpiresIn int64 `json:"expires_in,omitempty"` - ExpiresAt time.Time `json:"-"` -} - var baiduTokenStore sync.Map -func requestOpenAI2Baidu(request GeneralOpenAIRequest) *BaiduChatRequest { - messages := make([]BaiduMessage, 0, len(request.Messages)) +func ConvertRequest(request openai.GeneralOpenAIRequest) *ChatRequest { + messages := make([]Message, 0, len(request.Messages)) for _, message := range request.Messages { if message.Role == "system" { - messages = append(messages, BaiduMessage{ + messages = append(messages, Message{ Role: "user", Content: message.StringContent(), }) - messages = append(messages, BaiduMessage{ + messages = append(messages, Message{ Role: "assistant", Content: "Okay", }) } else { - messages = append(messages, BaiduMessage{ + messages = append(messages, Message{ Role: message.Role, Content: message.StringContent(), }) } } - return &BaiduChatRequest{ + return &ChatRequest{ Messages: messages, Stream: request.Stream, } } -func responseBaidu2OpenAI(response *BaiduChatResponse) *OpenAITextResponse { - choice := OpenAITextResponseChoice{ +func responseBaidu2OpenAI(response *ChatResponse) *openai.TextResponse { + choice := openai.TextResponseChoice{ Index: 0, - Message: Message{ + Message: openai.Message{ Role: "assistant", Content: response.Result, }, FinishReason: "stop", } - fullTextResponse := OpenAITextResponse{ + fullTextResponse := openai.TextResponse{ Id: response.Id, Object: "chat.completion", Created: response.Created, - Choices: []OpenAITextResponseChoice{choice}, + Choices: []openai.TextResponseChoice{choice}, Usage: response.Usage, } return &fullTextResponse } -func streamResponseBaidu2OpenAI(baiduResponse *BaiduChatStreamResponse) *ChatCompletionsStreamResponse { - var choice ChatCompletionsStreamResponseChoice +func streamResponseBaidu2OpenAI(baiduResponse *ChatStreamResponse) *openai.ChatCompletionsStreamResponse { + var choice openai.ChatCompletionsStreamResponseChoice choice.Delta.Content = baiduResponse.Result if baiduResponse.IsEnd { - choice.FinishReason = &stopFinishReason + choice.FinishReason = &constant.StopFinishReason } - response := ChatCompletionsStreamResponse{ + response := openai.ChatCompletionsStreamResponse{ Id: baiduResponse.Id, Object: "chat.completion.chunk", Created: baiduResponse.Created, Model: "ernie-bot", - Choices: []ChatCompletionsStreamResponseChoice{choice}, + Choices: []openai.ChatCompletionsStreamResponseChoice{choice}, } return &response } -func embeddingRequestOpenAI2Baidu(request GeneralOpenAIRequest) *BaiduEmbeddingRequest { - return &BaiduEmbeddingRequest{ +func ConvertEmbeddingRequest(request openai.GeneralOpenAIRequest) *EmbeddingRequest { + return &EmbeddingRequest{ Input: request.ParseInput(), } } -func embeddingResponseBaidu2OpenAI(response *BaiduEmbeddingResponse) *OpenAIEmbeddingResponse { - openAIEmbeddingResponse := OpenAIEmbeddingResponse{ +func embeddingResponseBaidu2OpenAI(response *EmbeddingResponse) *openai.EmbeddingResponse { + openAIEmbeddingResponse := openai.EmbeddingResponse{ Object: "list", - Data: make([]OpenAIEmbeddingResponseItem, 0, len(response.Data)), + Data: make([]openai.EmbeddingResponseItem, 0, len(response.Data)), Model: "baidu-embedding", Usage: response.Usage, } for _, item := range response.Data { - openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, OpenAIEmbeddingResponseItem{ + openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, openai.EmbeddingResponseItem{ Object: item.Object, Index: item.Index, Embedding: item.Embedding, @@ -166,8 +126,8 @@ func embeddingResponseBaidu2OpenAI(response *BaiduEmbeddingResponse) *OpenAIEmbe return &openAIEmbeddingResponse } -func baiduStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) { - var usage Usage +func StreamHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, *openai.Usage) { + var usage openai.Usage scanner := bufio.NewScanner(resp.Body) scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { if atEOF && len(data) == 0 { @@ -194,14 +154,14 @@ func baiduStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithSt } stopChan <- true }() - setEventStreamHeaders(c) + common.SetEventStreamHeaders(c) c.Stream(func(w io.Writer) bool { select { case data := <-dataChan: - var baiduResponse BaiduChatStreamResponse + var baiduResponse ChatStreamResponse err := json.Unmarshal([]byte(data), &baiduResponse) if err != nil { - common.SysError("error unmarshalling stream response: " + err.Error()) + logger.SysError("error unmarshalling stream response: " + err.Error()) return true } if baiduResponse.Usage.TotalTokens != 0 { @@ -212,7 +172,7 @@ func baiduStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithSt response := streamResponseBaidu2OpenAI(&baiduResponse) jsonResponse, err := json.Marshal(response) if err != nil { - common.SysError("error marshalling stream response: " + err.Error()) + logger.SysError("error marshalling stream response: " + err.Error()) return true } c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)}) @@ -224,28 +184,28 @@ func baiduStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithSt }) err := resp.Body.Close() if err != nil { - return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil + return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil } return nil, &usage } -func baiduHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) { - var baiduResponse BaiduChatResponse +func Handler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, *openai.Usage) { + var baiduResponse ChatResponse responseBody, err := io.ReadAll(resp.Body) if err != nil { - return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil + return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil } err = resp.Body.Close() if err != nil { - return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil + return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil } err = json.Unmarshal(responseBody, &baiduResponse) if err != nil { - return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil + return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil } if baiduResponse.ErrorMsg != "" { - return &OpenAIErrorWithStatusCode{ - OpenAIError: OpenAIError{ + return &openai.ErrorWithStatusCode{ + Error: openai.Error{ Message: baiduResponse.ErrorMsg, Type: "baidu_error", Param: "", @@ -258,7 +218,7 @@ func baiduHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCo fullTextResponse.Model = "ernie-bot" jsonResponse, err := json.Marshal(fullTextResponse) if err != nil { - return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil + return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil } c.Writer.Header().Set("Content-Type", "application/json") c.Writer.WriteHeader(resp.StatusCode) @@ -266,23 +226,23 @@ func baiduHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCo return nil, &fullTextResponse.Usage } -func baiduEmbeddingHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) { - var baiduResponse BaiduEmbeddingResponse +func EmbeddingHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, *openai.Usage) { + var baiduResponse EmbeddingResponse responseBody, err := io.ReadAll(resp.Body) if err != nil { - return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil + return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil } err = resp.Body.Close() if err != nil { - return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil + return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil } err = json.Unmarshal(responseBody, &baiduResponse) if err != nil { - return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil + return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil } if baiduResponse.ErrorMsg != "" { - return &OpenAIErrorWithStatusCode{ - OpenAIError: OpenAIError{ + return &openai.ErrorWithStatusCode{ + Error: openai.Error{ Message: baiduResponse.ErrorMsg, Type: "baidu_error", Param: "", @@ -294,7 +254,7 @@ func baiduEmbeddingHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWit fullTextResponse := embeddingResponseBaidu2OpenAI(&baiduResponse) jsonResponse, err := json.Marshal(fullTextResponse) if err != nil { - return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil + return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil } c.Writer.Header().Set("Content-Type", "application/json") c.Writer.WriteHeader(resp.StatusCode) @@ -302,10 +262,10 @@ func baiduEmbeddingHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWit return nil, &fullTextResponse.Usage } -func getBaiduAccessToken(apiKey string) (string, error) { +func GetAccessToken(apiKey string) (string, error) { if val, ok := baiduTokenStore.Load(apiKey); ok { - var accessToken BaiduAccessToken - if accessToken, ok = val.(BaiduAccessToken); ok { + var accessToken AccessToken + if accessToken, ok = val.(AccessToken); ok { // soon this will expire if time.Now().Add(time.Hour).After(accessToken.ExpiresAt) { go func() { @@ -320,12 +280,12 @@ func getBaiduAccessToken(apiKey string) (string, error) { return "", err } if accessToken == nil { - return "", errors.New("getBaiduAccessToken return a nil token") + return "", errors.New("GetAccessToken return a nil token") } return (*accessToken).AccessToken, nil } -func getBaiduAccessTokenHelper(apiKey string) (*BaiduAccessToken, error) { +func getBaiduAccessTokenHelper(apiKey string) (*AccessToken, error) { parts := strings.Split(apiKey, "|") if len(parts) != 2 { return nil, errors.New("invalid baidu apikey") @@ -337,13 +297,13 @@ func getBaiduAccessTokenHelper(apiKey string) (*BaiduAccessToken, error) { } req.Header.Add("Content-Type", "application/json") req.Header.Add("Accept", "application/json") - res, err := impatientHTTPClient.Do(req) + res, err := util.ImpatientHTTPClient.Do(req) if err != nil { return nil, err } defer res.Body.Close() - var accessToken BaiduAccessToken + var accessToken AccessToken err = json.NewDecoder(res.Body).Decode(&accessToken) if err != nil { return nil, err diff --git a/relay/channel/baidu/model.go b/relay/channel/baidu/model.go new file mode 100644 index 00000000..e182f5dd --- /dev/null +++ b/relay/channel/baidu/model.go @@ -0,0 +1,50 @@ +package baidu + +import ( + "one-api/relay/channel/openai" + "time" +) + +type ChatResponse struct { + Id string `json:"id"` + Object string `json:"object"` + Created int64 `json:"created"` + Result string `json:"result"` + IsTruncated bool `json:"is_truncated"` + NeedClearHistory bool `json:"need_clear_history"` + Usage openai.Usage `json:"usage"` + Error +} + +type ChatStreamResponse struct { + ChatResponse + SentenceId int `json:"sentence_id"` + IsEnd bool `json:"is_end"` +} + +type EmbeddingRequest struct { + Input []string `json:"input"` +} + +type EmbeddingData struct { + Object string `json:"object"` + Embedding []float64 `json:"embedding"` + Index int `json:"index"` +} + +type EmbeddingResponse struct { + Id string `json:"id"` + Object string `json:"object"` + Created int64 `json:"created"` + Data []EmbeddingData `json:"data"` + Usage openai.Usage `json:"usage"` + Error +} + +type AccessToken struct { + AccessToken string `json:"access_token"` + Error string `json:"error,omitempty"` + ErrorDescription string `json:"error_description,omitempty"` + ExpiresIn int64 `json:"expires_in,omitempty"` + ExpiresAt time.Time `json:"-"` +} diff --git a/relay/channel/google/adaptor.go b/relay/channel/google/adaptor.go new file mode 100644 index 00000000..b328db32 --- /dev/null +++ b/relay/channel/google/adaptor.go @@ -0,0 +1,22 @@ +package google + +import ( + "github.com/gin-gonic/gin" + "net/http" + "one-api/relay/channel/openai" +) + +type Adaptor struct { +} + +func (a *Adaptor) Auth(c *gin.Context) error { + return nil +} + +func (a *Adaptor) ConvertRequest(request *openai.GeneralOpenAIRequest) (any, error) { + return nil, nil +} + +func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, *openai.Usage, error) { + return nil, nil, nil +} diff --git a/controller/relay-gemini.go b/relay/channel/google/gemini.go similarity index 65% rename from controller/relay-gemini.go rename to relay/channel/google/gemini.go index d8ab58d6..3adc3fdd 100644 --- a/controller/relay-gemini.go +++ b/relay/channel/google/gemini.go @@ -1,4 +1,4 @@ -package controller +package google import ( "bufio" @@ -7,7 +7,12 @@ import ( "io" "net/http" "one-api/common" + "one-api/common/config" + "one-api/common/helper" "one-api/common/image" + "one-api/common/logger" + "one-api/relay/channel/openai" + "one-api/relay/constant" "strings" "github.com/gin-gonic/gin" @@ -19,66 +24,26 @@ const ( GeminiVisionMaxImageNum = 16 ) -type GeminiChatRequest struct { - Contents []GeminiChatContent `json:"contents"` - SafetySettings []GeminiChatSafetySettings `json:"safety_settings,omitempty"` - GenerationConfig GeminiChatGenerationConfig `json:"generation_config,omitempty"` - Tools []GeminiChatTools `json:"tools,omitempty"` -} - -type GeminiInlineData struct { - MimeType string `json:"mimeType"` - Data string `json:"data"` -} - -type GeminiPart struct { - Text string `json:"text,omitempty"` - InlineData *GeminiInlineData `json:"inlineData,omitempty"` -} - -type GeminiChatContent struct { - Role string `json:"role,omitempty"` - Parts []GeminiPart `json:"parts"` -} - -type GeminiChatSafetySettings struct { - Category string `json:"category"` - Threshold string `json:"threshold"` -} - -type GeminiChatTools struct { - FunctionDeclarations any `json:"functionDeclarations,omitempty"` -} - -type GeminiChatGenerationConfig struct { - Temperature float64 `json:"temperature,omitempty"` - TopP float64 `json:"topP,omitempty"` - TopK float64 `json:"topK,omitempty"` - MaxOutputTokens int `json:"maxOutputTokens,omitempty"` - CandidateCount int `json:"candidateCount,omitempty"` - StopSequences []string `json:"stopSequences,omitempty"` -} - // Setting safety to the lowest possible values since Gemini is already powerless enough -func requestOpenAI2Gemini(textRequest GeneralOpenAIRequest) *GeminiChatRequest { +func ConvertGeminiRequest(textRequest openai.GeneralOpenAIRequest) *GeminiChatRequest { geminiRequest := GeminiChatRequest{ Contents: make([]GeminiChatContent, 0, len(textRequest.Messages)), SafetySettings: []GeminiChatSafetySettings{ { Category: "HARM_CATEGORY_HARASSMENT", - Threshold: common.GeminiSafetySetting, + Threshold: config.GeminiSafetySetting, }, { Category: "HARM_CATEGORY_HATE_SPEECH", - Threshold: common.GeminiSafetySetting, + Threshold: config.GeminiSafetySetting, }, { Category: "HARM_CATEGORY_SEXUALLY_EXPLICIT", - Threshold: common.GeminiSafetySetting, + Threshold: config.GeminiSafetySetting, }, { Category: "HARM_CATEGORY_DANGEROUS_CONTENT", - Threshold: common.GeminiSafetySetting, + Threshold: config.GeminiSafetySetting, }, }, GenerationConfig: GeminiChatGenerationConfig{ @@ -108,11 +73,11 @@ func requestOpenAI2Gemini(textRequest GeneralOpenAIRequest) *GeminiChatRequest { var parts []GeminiPart imageNum := 0 for _, part := range openaiContent { - if part.Type == ContentTypeText { + if part.Type == openai.ContentTypeText { parts = append(parts, GeminiPart{ Text: part.Text, }) - } else if part.Type == ContentTypeImageURL { + } else if part.Type == openai.ContentTypeImageURL { imageNum += 1 if imageNum > GeminiVisionMaxImageNum { continue @@ -187,21 +152,21 @@ type GeminiChatPromptFeedback struct { SafetyRatings []GeminiChatSafetyRating `json:"safetyRatings"` } -func responseGeminiChat2OpenAI(response *GeminiChatResponse) *OpenAITextResponse { - fullTextResponse := OpenAITextResponse{ - Id: fmt.Sprintf("chatcmpl-%s", common.GetUUID()), +func responseGeminiChat2OpenAI(response *GeminiChatResponse) *openai.TextResponse { + fullTextResponse := openai.TextResponse{ + Id: fmt.Sprintf("chatcmpl-%s", helper.GetUUID()), Object: "chat.completion", - Created: common.GetTimestamp(), - Choices: make([]OpenAITextResponseChoice, 0, len(response.Candidates)), + Created: helper.GetTimestamp(), + Choices: make([]openai.TextResponseChoice, 0, len(response.Candidates)), } for i, candidate := range response.Candidates { - choice := OpenAITextResponseChoice{ + choice := openai.TextResponseChoice{ Index: i, - Message: Message{ + Message: openai.Message{ Role: "assistant", Content: "", }, - FinishReason: stopFinishReason, + FinishReason: constant.StopFinishReason, } if len(candidate.Content.Parts) > 0 { choice.Message.Content = candidate.Content.Parts[0].Text @@ -211,18 +176,18 @@ func responseGeminiChat2OpenAI(response *GeminiChatResponse) *OpenAITextResponse return &fullTextResponse } -func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) *ChatCompletionsStreamResponse { - var choice ChatCompletionsStreamResponseChoice +func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) *openai.ChatCompletionsStreamResponse { + var choice openai.ChatCompletionsStreamResponseChoice choice.Delta.Content = geminiResponse.GetResponseText() - choice.FinishReason = &stopFinishReason - var response ChatCompletionsStreamResponse + choice.FinishReason = &constant.StopFinishReason + var response openai.ChatCompletionsStreamResponse response.Object = "chat.completion.chunk" response.Model = "gemini" - response.Choices = []ChatCompletionsStreamResponseChoice{choice} + response.Choices = []openai.ChatCompletionsStreamResponseChoice{choice} return &response } -func geminiChatStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, string) { +func StreamHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, string) { responseText := "" dataChan := make(chan string) stopChan := make(chan bool) @@ -252,7 +217,7 @@ func geminiChatStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorW } stopChan <- true }() - setEventStreamHeaders(c) + common.SetEventStreamHeaders(c) c.Stream(func(w io.Writer) bool { select { case data := <-dataChan: @@ -264,18 +229,18 @@ func geminiChatStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorW var dummy dummyStruct err := json.Unmarshal([]byte(data), &dummy) responseText += dummy.Content - var choice ChatCompletionsStreamResponseChoice + var choice openai.ChatCompletionsStreamResponseChoice choice.Delta.Content = dummy.Content - response := ChatCompletionsStreamResponse{ - Id: fmt.Sprintf("chatcmpl-%s", common.GetUUID()), + response := openai.ChatCompletionsStreamResponse{ + Id: fmt.Sprintf("chatcmpl-%s", helper.GetUUID()), Object: "chat.completion.chunk", - Created: common.GetTimestamp(), + Created: helper.GetTimestamp(), Model: "gemini-pro", - Choices: []ChatCompletionsStreamResponseChoice{choice}, + Choices: []openai.ChatCompletionsStreamResponseChoice{choice}, } jsonResponse, err := json.Marshal(response) if err != nil { - common.SysError("error marshalling stream response: " + err.Error()) + logger.SysError("error marshalling stream response: " + err.Error()) return true } c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)}) @@ -287,28 +252,28 @@ func geminiChatStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorW }) err := resp.Body.Close() if err != nil { - return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), "" + return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), "" } return nil, responseText } -func geminiChatHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*OpenAIErrorWithStatusCode, *Usage) { +func GeminiHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*openai.ErrorWithStatusCode, *openai.Usage) { responseBody, err := io.ReadAll(resp.Body) if err != nil { - return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil + return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil } err = resp.Body.Close() if err != nil { - return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil + return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil } var geminiResponse GeminiChatResponse err = json.Unmarshal(responseBody, &geminiResponse) if err != nil { - return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil + return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil } if len(geminiResponse.Candidates) == 0 { - return &OpenAIErrorWithStatusCode{ - OpenAIError: OpenAIError{ + return &openai.ErrorWithStatusCode{ + Error: openai.Error{ Message: "No candidates returned", Type: "server_error", Param: "", @@ -319,8 +284,8 @@ func geminiChatHandler(c *gin.Context, resp *http.Response, promptTokens int, mo } fullTextResponse := responseGeminiChat2OpenAI(&geminiResponse) fullTextResponse.Model = model - completionTokens := countTokenText(geminiResponse.GetResponseText(), model) - usage := Usage{ + completionTokens := openai.CountTokenText(geminiResponse.GetResponseText(), model) + usage := openai.Usage{ PromptTokens: promptTokens, CompletionTokens: completionTokens, TotalTokens: promptTokens + completionTokens, @@ -328,7 +293,7 @@ func geminiChatHandler(c *gin.Context, resp *http.Response, promptTokens int, mo fullTextResponse.Usage = usage jsonResponse, err := json.Marshal(fullTextResponse) if err != nil { - return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil + return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil } c.Writer.Header().Set("Content-Type", "application/json") c.Writer.WriteHeader(resp.StatusCode) diff --git a/relay/channel/google/model.go b/relay/channel/google/model.go new file mode 100644 index 00000000..694c2dd1 --- /dev/null +++ b/relay/channel/google/model.go @@ -0,0 +1,80 @@ +package google + +import ( + "one-api/relay/channel/openai" +) + +type GeminiChatRequest struct { + Contents []GeminiChatContent `json:"contents"` + SafetySettings []GeminiChatSafetySettings `json:"safety_settings,omitempty"` + GenerationConfig GeminiChatGenerationConfig `json:"generation_config,omitempty"` + Tools []GeminiChatTools `json:"tools,omitempty"` +} + +type GeminiInlineData struct { + MimeType string `json:"mimeType"` + Data string `json:"data"` +} + +type GeminiPart struct { + Text string `json:"text,omitempty"` + InlineData *GeminiInlineData `json:"inlineData,omitempty"` +} + +type GeminiChatContent struct { + Role string `json:"role,omitempty"` + Parts []GeminiPart `json:"parts"` +} + +type GeminiChatSafetySettings struct { + Category string `json:"category"` + Threshold string `json:"threshold"` +} + +type GeminiChatTools struct { + FunctionDeclarations any `json:"functionDeclarations,omitempty"` +} + +type GeminiChatGenerationConfig struct { + Temperature float64 `json:"temperature,omitempty"` + TopP float64 `json:"topP,omitempty"` + TopK float64 `json:"topK,omitempty"` + MaxOutputTokens int `json:"maxOutputTokens,omitempty"` + CandidateCount int `json:"candidateCount,omitempty"` + StopSequences []string `json:"stopSequences,omitempty"` +} + +type PaLMChatMessage struct { + Author string `json:"author"` + Content string `json:"content"` +} + +type PaLMFilter struct { + Reason string `json:"reason"` + Message string `json:"message"` +} + +type PaLMPrompt struct { + Messages []PaLMChatMessage `json:"messages"` +} + +type PaLMChatRequest struct { + Prompt PaLMPrompt `json:"prompt"` + Temperature float64 `json:"temperature,omitempty"` + CandidateCount int `json:"candidateCount,omitempty"` + TopP float64 `json:"topP,omitempty"` + TopK int `json:"topK,omitempty"` +} + +type PaLMError struct { + Code int `json:"code"` + Message string `json:"message"` + Status string `json:"status"` +} + +type PaLMChatResponse struct { + Candidates []PaLMChatMessage `json:"candidates"` + Messages []openai.Message `json:"messages"` + Filters []PaLMFilter `json:"filters"` + Error PaLMError `json:"error"` +} diff --git a/controller/relay-palm.go b/relay/channel/google/palm.go similarity index 56% rename from controller/relay-palm.go rename to relay/channel/google/palm.go index 0c1c8af6..3c86a432 100644 --- a/controller/relay-palm.go +++ b/relay/channel/google/palm.go @@ -1,4 +1,4 @@ -package controller +package google import ( "encoding/json" @@ -7,47 +7,16 @@ import ( "io" "net/http" "one-api/common" + "one-api/common/helper" + "one-api/common/logger" + "one-api/relay/channel/openai" + "one-api/relay/constant" ) // https://developers.generativeai.google/api/rest/generativelanguage/models/generateMessage#request-body // https://developers.generativeai.google/api/rest/generativelanguage/models/generateMessage#response-body -type PaLMChatMessage struct { - Author string `json:"author"` - Content string `json:"content"` -} - -type PaLMFilter struct { - Reason string `json:"reason"` - Message string `json:"message"` -} - -type PaLMPrompt struct { - Messages []PaLMChatMessage `json:"messages"` -} - -type PaLMChatRequest struct { - Prompt PaLMPrompt `json:"prompt"` - Temperature float64 `json:"temperature,omitempty"` - CandidateCount int `json:"candidateCount,omitempty"` - TopP float64 `json:"topP,omitempty"` - TopK int `json:"topK,omitempty"` -} - -type PaLMError struct { - Code int `json:"code"` - Message string `json:"message"` - Status string `json:"status"` -} - -type PaLMChatResponse struct { - Candidates []PaLMChatMessage `json:"candidates"` - Messages []Message `json:"messages"` - Filters []PaLMFilter `json:"filters"` - Error PaLMError `json:"error"` -} - -func requestOpenAI2PaLM(textRequest GeneralOpenAIRequest) *PaLMChatRequest { +func ConvertPaLMRequest(textRequest openai.GeneralOpenAIRequest) *PaLMChatRequest { palmRequest := PaLMChatRequest{ Prompt: PaLMPrompt{ Messages: make([]PaLMChatMessage, 0, len(textRequest.Messages)), @@ -71,14 +40,14 @@ func requestOpenAI2PaLM(textRequest GeneralOpenAIRequest) *PaLMChatRequest { return &palmRequest } -func responsePaLM2OpenAI(response *PaLMChatResponse) *OpenAITextResponse { - fullTextResponse := OpenAITextResponse{ - Choices: make([]OpenAITextResponseChoice, 0, len(response.Candidates)), +func responsePaLM2OpenAI(response *PaLMChatResponse) *openai.TextResponse { + fullTextResponse := openai.TextResponse{ + Choices: make([]openai.TextResponseChoice, 0, len(response.Candidates)), } for i, candidate := range response.Candidates { - choice := OpenAITextResponseChoice{ + choice := openai.TextResponseChoice{ Index: i, - Message: Message{ + Message: openai.Message{ Role: "assistant", Content: candidate.Content, }, @@ -89,42 +58,42 @@ func responsePaLM2OpenAI(response *PaLMChatResponse) *OpenAITextResponse { return &fullTextResponse } -func streamResponsePaLM2OpenAI(palmResponse *PaLMChatResponse) *ChatCompletionsStreamResponse { - var choice ChatCompletionsStreamResponseChoice +func streamResponsePaLM2OpenAI(palmResponse *PaLMChatResponse) *openai.ChatCompletionsStreamResponse { + var choice openai.ChatCompletionsStreamResponseChoice if len(palmResponse.Candidates) > 0 { choice.Delta.Content = palmResponse.Candidates[0].Content } - choice.FinishReason = &stopFinishReason - var response ChatCompletionsStreamResponse + choice.FinishReason = &constant.StopFinishReason + var response openai.ChatCompletionsStreamResponse response.Object = "chat.completion.chunk" response.Model = "palm2" - response.Choices = []ChatCompletionsStreamResponseChoice{choice} + response.Choices = []openai.ChatCompletionsStreamResponseChoice{choice} return &response } -func palmStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, string) { +func PaLMStreamHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, string) { responseText := "" - responseId := fmt.Sprintf("chatcmpl-%s", common.GetUUID()) - createdTime := common.GetTimestamp() + responseId := fmt.Sprintf("chatcmpl-%s", helper.GetUUID()) + createdTime := helper.GetTimestamp() dataChan := make(chan string) stopChan := make(chan bool) go func() { responseBody, err := io.ReadAll(resp.Body) if err != nil { - common.SysError("error reading stream response: " + err.Error()) + logger.SysError("error reading stream response: " + err.Error()) stopChan <- true return } err = resp.Body.Close() if err != nil { - common.SysError("error closing stream response: " + err.Error()) + logger.SysError("error closing stream response: " + err.Error()) stopChan <- true return } var palmResponse PaLMChatResponse err = json.Unmarshal(responseBody, &palmResponse) if err != nil { - common.SysError("error unmarshalling stream response: " + err.Error()) + logger.SysError("error unmarshalling stream response: " + err.Error()) stopChan <- true return } @@ -136,14 +105,14 @@ func palmStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithSta } jsonResponse, err := json.Marshal(fullTextResponse) if err != nil { - common.SysError("error marshalling stream response: " + err.Error()) + logger.SysError("error marshalling stream response: " + err.Error()) stopChan <- true return } dataChan <- string(jsonResponse) stopChan <- true }() - setEventStreamHeaders(c) + common.SetEventStreamHeaders(c) c.Stream(func(w io.Writer) bool { select { case data := <-dataChan: @@ -156,28 +125,28 @@ func palmStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithSta }) err := resp.Body.Close() if err != nil { - return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), "" + return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), "" } return nil, responseText } -func palmHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*OpenAIErrorWithStatusCode, *Usage) { +func PaLMHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*openai.ErrorWithStatusCode, *openai.Usage) { responseBody, err := io.ReadAll(resp.Body) if err != nil { - return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil + return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil } err = resp.Body.Close() if err != nil { - return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil + return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil } var palmResponse PaLMChatResponse err = json.Unmarshal(responseBody, &palmResponse) if err != nil { - return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil + return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil } if palmResponse.Error.Code != 0 || len(palmResponse.Candidates) == 0 { - return &OpenAIErrorWithStatusCode{ - OpenAIError: OpenAIError{ + return &openai.ErrorWithStatusCode{ + Error: openai.Error{ Message: palmResponse.Error.Message, Type: palmResponse.Error.Status, Param: "", @@ -188,8 +157,8 @@ func palmHandler(c *gin.Context, resp *http.Response, promptTokens int, model st } fullTextResponse := responsePaLM2OpenAI(&palmResponse) fullTextResponse.Model = model - completionTokens := countTokenText(palmResponse.Candidates[0].Content, model) - usage := Usage{ + completionTokens := openai.CountTokenText(palmResponse.Candidates[0].Content, model) + usage := openai.Usage{ PromptTokens: promptTokens, CompletionTokens: completionTokens, TotalTokens: promptTokens + completionTokens, @@ -197,7 +166,7 @@ func palmHandler(c *gin.Context, resp *http.Response, promptTokens int, model st fullTextResponse.Usage = usage jsonResponse, err := json.Marshal(fullTextResponse) if err != nil { - return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil + return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil } c.Writer.Header().Set("Content-Type", "application/json") c.Writer.WriteHeader(resp.StatusCode) diff --git a/relay/channel/interface.go b/relay/channel/interface.go new file mode 100644 index 00000000..7a0fcbd3 --- /dev/null +++ b/relay/channel/interface.go @@ -0,0 +1,15 @@ +package channel + +import ( + "github.com/gin-gonic/gin" + "net/http" + "one-api/relay/channel/openai" +) + +type Adaptor interface { + GetRequestURL() string + Auth(c *gin.Context) error + ConvertRequest(request *openai.GeneralOpenAIRequest) (any, error) + DoRequest(request *openai.GeneralOpenAIRequest) error + DoResponse(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, *openai.Usage, error) +} diff --git a/relay/channel/openai/adaptor.go b/relay/channel/openai/adaptor.go new file mode 100644 index 00000000..cc302611 --- /dev/null +++ b/relay/channel/openai/adaptor.go @@ -0,0 +1,21 @@ +package openai + +import ( + "github.com/gin-gonic/gin" + "net/http" +) + +type Adaptor struct { +} + +func (a *Adaptor) Auth(c *gin.Context) error { + return nil +} + +func (a *Adaptor) ConvertRequest(request *GeneralOpenAIRequest) (any, error) { + return nil, nil +} + +func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response) (*ErrorWithStatusCode, *Usage, error) { + return nil, nil, nil +} diff --git a/relay/channel/openai/constant.go b/relay/channel/openai/constant.go new file mode 100644 index 00000000..000f72ee --- /dev/null +++ b/relay/channel/openai/constant.go @@ -0,0 +1,6 @@ +package openai + +const ( + ContentTypeText = "text" + ContentTypeImageURL = "image_url" +) diff --git a/controller/relay-openai.go b/relay/channel/openai/main.go similarity index 73% rename from controller/relay-openai.go rename to relay/channel/openai/main.go index 37867843..5f464249 100644 --- a/controller/relay-openai.go +++ b/relay/channel/openai/main.go @@ -1,4 +1,4 @@ -package controller +package openai import ( "bufio" @@ -8,10 +8,12 @@ import ( "io" "net/http" "one-api/common" + "one-api/common/logger" + "one-api/relay/constant" "strings" ) -func openaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*OpenAIErrorWithStatusCode, string) { +func StreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*ErrorWithStatusCode, string) { responseText := "" scanner := bufio.NewScanner(resp.Body) scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { @@ -41,21 +43,21 @@ func openaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*O data = data[6:] if !strings.HasPrefix(data, "[DONE]") { switch relayMode { - case RelayModeChatCompletions: + case constant.RelayModeChatCompletions: var streamResponse ChatCompletionsStreamResponse err := json.Unmarshal([]byte(data), &streamResponse) if err != nil { - common.SysError("error unmarshalling stream response: " + err.Error()) + logger.SysError("error unmarshalling stream response: " + err.Error()) continue // just ignore the error } for _, choice := range streamResponse.Choices { responseText += choice.Delta.Content } - case RelayModeCompletions: + case constant.RelayModeCompletions: var streamResponse CompletionsStreamResponse err := json.Unmarshal([]byte(data), &streamResponse) if err != nil { - common.SysError("error unmarshalling stream response: " + err.Error()) + logger.SysError("error unmarshalling stream response: " + err.Error()) continue } for _, choice := range streamResponse.Choices { @@ -66,7 +68,7 @@ func openaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*O } stopChan <- true }() - setEventStreamHeaders(c) + common.SetEventStreamHeaders(c) c.Stream(func(w io.Writer) bool { select { case data := <-dataChan: @@ -83,29 +85,29 @@ func openaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*O }) err := resp.Body.Close() if err != nil { - return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), "" + return ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), "" } return nil, responseText } -func openaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*OpenAIErrorWithStatusCode, *Usage) { - var textResponse TextResponse +func Handler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*ErrorWithStatusCode, *Usage) { + var textResponse SlimTextResponse responseBody, err := io.ReadAll(resp.Body) if err != nil { - return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil + return ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil } err = resp.Body.Close() if err != nil { - return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil + return ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil } err = json.Unmarshal(responseBody, &textResponse) if err != nil { - return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil + return ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil } if textResponse.Error.Type != "" { - return &OpenAIErrorWithStatusCode{ - OpenAIError: textResponse.Error, - StatusCode: resp.StatusCode, + return &ErrorWithStatusCode{ + Error: textResponse.Error, + StatusCode: resp.StatusCode, }, nil } // Reset response body @@ -113,7 +115,7 @@ func openaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model // We shouldn't set the header before we parse the response body, because the parse part may fail. // And then we will have to send an error response, but in this case, the header has already been set. - // So the httpClient will be confused by the response. + // So the HTTPClient will be confused by the response. // For example, Postman will report error, and we cannot check the response at all. for k, v := range resp.Header { c.Writer.Header().Set(k, v[0]) @@ -121,17 +123,17 @@ func openaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model c.Writer.WriteHeader(resp.StatusCode) _, err = io.Copy(c.Writer, resp.Body) if err != nil { - return errorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil + return ErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil } err = resp.Body.Close() if err != nil { - return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil + return ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil } if textResponse.Usage.TotalTokens == 0 { completionTokens := 0 for _, choice := range textResponse.Choices { - completionTokens += countTokenText(choice.Message.StringContent(), model) + completionTokens += CountTokenText(choice.Message.StringContent(), model) } textResponse.Usage = Usage{ PromptTokens: promptTokens, diff --git a/relay/channel/openai/model.go b/relay/channel/openai/model.go new file mode 100644 index 00000000..937fb424 --- /dev/null +++ b/relay/channel/openai/model.go @@ -0,0 +1,288 @@ +package openai + +type Message struct { + Role string `json:"role"` + Content any `json:"content"` + Name *string `json:"name,omitempty"` +} + +type ImageURL struct { + Url string `json:"url,omitempty"` + Detail string `json:"detail,omitempty"` +} + +type TextContent struct { + Type string `json:"type,omitempty"` + Text string `json:"text,omitempty"` +} + +type ImageContent struct { + Type string `json:"type,omitempty"` + ImageURL *ImageURL `json:"image_url,omitempty"` +} + +type OpenAIMessageContent struct { + Type string `json:"type,omitempty"` + Text string `json:"text"` + ImageURL *ImageURL `json:"image_url,omitempty"` +} + +func (m Message) IsStringContent() bool { + _, ok := m.Content.(string) + return ok +} + +func (m Message) StringContent() string { + content, ok := m.Content.(string) + if ok { + return content + } + contentList, ok := m.Content.([]any) + if ok { + var contentStr string + for _, contentItem := range contentList { + contentMap, ok := contentItem.(map[string]any) + if !ok { + continue + } + if contentMap["type"] == ContentTypeText { + if subStr, ok := contentMap["text"].(string); ok { + contentStr += subStr + } + } + } + return contentStr + } + return "" +} + +func (m Message) ParseContent() []OpenAIMessageContent { + var contentList []OpenAIMessageContent + content, ok := m.Content.(string) + if ok { + contentList = append(contentList, OpenAIMessageContent{ + Type: ContentTypeText, + Text: content, + }) + return contentList + } + anyList, ok := m.Content.([]any) + if ok { + for _, contentItem := range anyList { + contentMap, ok := contentItem.(map[string]any) + if !ok { + continue + } + switch contentMap["type"] { + case ContentTypeText: + if subStr, ok := contentMap["text"].(string); ok { + contentList = append(contentList, OpenAIMessageContent{ + Type: ContentTypeText, + Text: subStr, + }) + } + case ContentTypeImageURL: + if subObj, ok := contentMap["image_url"].(map[string]any); ok { + contentList = append(contentList, OpenAIMessageContent{ + Type: ContentTypeImageURL, + ImageURL: &ImageURL{ + Url: subObj["url"].(string), + }, + }) + } + } + } + return contentList + } + return nil +} + +type ResponseFormat struct { + Type string `json:"type,omitempty"` +} + +type GeneralOpenAIRequest struct { + Model string `json:"model,omitempty"` + Messages []Message `json:"messages,omitempty"` + Prompt any `json:"prompt,omitempty"` + Stream bool `json:"stream,omitempty"` + MaxTokens int `json:"max_tokens,omitempty"` + Temperature float64 `json:"temperature,omitempty"` + TopP float64 `json:"top_p,omitempty"` + N int `json:"n,omitempty"` + Input any `json:"input,omitempty"` + Instruction string `json:"instruction,omitempty"` + Size string `json:"size,omitempty"` + Functions any `json:"functions,omitempty"` + FrequencyPenalty float64 `json:"frequency_penalty,omitempty"` + PresencePenalty float64 `json:"presence_penalty,omitempty"` + ResponseFormat *ResponseFormat `json:"response_format,omitempty"` + Seed float64 `json:"seed,omitempty"` + Tools any `json:"tools,omitempty"` + ToolChoice any `json:"tool_choice,omitempty"` + User string `json:"user,omitempty"` +} + +func (r GeneralOpenAIRequest) ParseInput() []string { + if r.Input == nil { + return nil + } + var input []string + switch r.Input.(type) { + case string: + input = []string{r.Input.(string)} + case []any: + input = make([]string, 0, len(r.Input.([]any))) + for _, item := range r.Input.([]any) { + if str, ok := item.(string); ok { + input = append(input, str) + } + } + } + return input +} + +type ChatRequest struct { + Model string `json:"model"` + Messages []Message `json:"messages"` + MaxTokens int `json:"max_tokens"` +} + +type TextRequest struct { + Model string `json:"model"` + Messages []Message `json:"messages"` + Prompt string `json:"prompt"` + MaxTokens int `json:"max_tokens"` + //Stream bool `json:"stream"` +} + +// ImageRequest docs: https://platform.openai.com/docs/api-reference/images/create +type ImageRequest struct { + Model string `json:"model"` + Prompt string `json:"prompt" binding:"required"` + N int `json:"n,omitempty"` + Size string `json:"size,omitempty"` + Quality string `json:"quality,omitempty"` + ResponseFormat string `json:"response_format,omitempty"` + Style string `json:"style,omitempty"` + User string `json:"user,omitempty"` +} + +type WhisperJSONResponse struct { + Text string `json:"text,omitempty"` +} + +type WhisperVerboseJSONResponse struct { + Task string `json:"task,omitempty"` + Language string `json:"language,omitempty"` + Duration float64 `json:"duration,omitempty"` + Text string `json:"text,omitempty"` + Segments []Segment `json:"segments,omitempty"` +} + +type Segment struct { + Id int `json:"id"` + Seek int `json:"seek"` + Start float64 `json:"start"` + End float64 `json:"end"` + Text string `json:"text"` + Tokens []int `json:"tokens"` + Temperature float64 `json:"temperature"` + AvgLogprob float64 `json:"avg_logprob"` + CompressionRatio float64 `json:"compression_ratio"` + NoSpeechProb float64 `json:"no_speech_prob"` +} + +type TextToSpeechRequest struct { + Model string `json:"model" binding:"required"` + Input string `json:"input" binding:"required"` + Voice string `json:"voice" binding:"required"` + Speed float64 `json:"speed"` + ResponseFormat string `json:"response_format"` +} + +type Usage struct { + PromptTokens int `json:"prompt_tokens"` + CompletionTokens int `json:"completion_tokens"` + TotalTokens int `json:"total_tokens"` +} + +type UsageOrResponseText struct { + *Usage + ResponseText string +} + +type Error struct { + Message string `json:"message"` + Type string `json:"type"` + Param string `json:"param"` + Code any `json:"code"` +} + +type ErrorWithStatusCode struct { + Error + StatusCode int `json:"status_code"` +} + +type SlimTextResponse struct { + Choices []TextResponseChoice `json:"choices"` + Usage `json:"usage"` + Error Error `json:"error"` +} + +type TextResponseChoice struct { + Index int `json:"index"` + Message `json:"message"` + FinishReason string `json:"finish_reason"` +} + +type TextResponse struct { + Id string `json:"id"` + Model string `json:"model,omitempty"` + Object string `json:"object"` + Created int64 `json:"created"` + Choices []TextResponseChoice `json:"choices"` + Usage `json:"usage"` +} + +type EmbeddingResponseItem struct { + Object string `json:"object"` + Index int `json:"index"` + Embedding []float64 `json:"embedding"` +} + +type EmbeddingResponse struct { + Object string `json:"object"` + Data []EmbeddingResponseItem `json:"data"` + Model string `json:"model"` + Usage `json:"usage"` +} + +type ImageResponse struct { + Created int `json:"created"` + Data []struct { + Url string `json:"url"` + } +} + +type ChatCompletionsStreamResponseChoice struct { + Delta struct { + Content string `json:"content"` + } `json:"delta"` + FinishReason *string `json:"finish_reason,omitempty"` +} + +type ChatCompletionsStreamResponse struct { + Id string `json:"id"` + Object string `json:"object"` + Created int64 `json:"created"` + Model string `json:"model"` + Choices []ChatCompletionsStreamResponseChoice `json:"choices"` +} + +type CompletionsStreamResponse struct { + Choices []struct { + Text string `json:"text"` + FinishReason string `json:"finish_reason"` + } `json:"choices"` +} diff --git a/controller/relay-utils.go b/relay/channel/openai/token.go similarity index 51% rename from controller/relay-utils.go rename to relay/channel/openai/token.go index a6a1f0f6..6803770e 100644 --- a/controller/relay-utils.go +++ b/relay/channel/openai/token.go @@ -1,39 +1,31 @@ -package controller +package openai import ( - "context" - "encoding/json" "errors" "fmt" - "io" - "math" - "net/http" - "one-api/common" - "one-api/common/image" - "one-api/model" - "strconv" - "strings" - - "github.com/gin-gonic/gin" "github.com/pkoukk/tiktoken-go" + "math" + "one-api/common" + "one-api/common/config" + "one-api/common/image" + "one-api/common/logger" + "strings" ) -var stopFinishReason = "stop" - // tokenEncoderMap won't grow after initialization var tokenEncoderMap = map[string]*tiktoken.Tiktoken{} var defaultTokenEncoder *tiktoken.Tiktoken func InitTokenEncoders() { - common.SysLog("initializing token encoders") + logger.SysLog("initializing token encoders") gpt35TokenEncoder, err := tiktoken.EncodingForModel("gpt-3.5-turbo") if err != nil { - common.FatalLog(fmt.Sprintf("failed to get gpt-3.5-turbo token encoder: %s", err.Error())) + logger.FatalLog(fmt.Sprintf("failed to get gpt-3.5-turbo token encoder: %s", err.Error())) } defaultTokenEncoder = gpt35TokenEncoder gpt4TokenEncoder, err := tiktoken.EncodingForModel("gpt-4") if err != nil { - common.FatalLog(fmt.Sprintf("failed to get gpt-4 token encoder: %s", err.Error())) + logger.FatalLog(fmt.Sprintf("failed to get gpt-4 token encoder: %s", err.Error())) } for model, _ := range common.ModelRatio { if strings.HasPrefix(model, "gpt-3.5") { @@ -44,7 +36,7 @@ func InitTokenEncoders() { tokenEncoderMap[model] = nil } } - common.SysLog("token encoders initialized") + logger.SysLog("token encoders initialized") } func getTokenEncoder(model string) *tiktoken.Tiktoken { @@ -55,7 +47,7 @@ func getTokenEncoder(model string) *tiktoken.Tiktoken { if ok { tokenEncoder, err := tiktoken.EncodingForModel(model) if err != nil { - common.SysError(fmt.Sprintf("failed to get token encoder for model %s: %s, using encoder for gpt-3.5-turbo", model, err.Error())) + logger.SysError(fmt.Sprintf("failed to get token encoder for model %s: %s, using encoder for gpt-3.5-turbo", model, err.Error())) tokenEncoder = defaultTokenEncoder } tokenEncoderMap[model] = tokenEncoder @@ -65,13 +57,13 @@ func getTokenEncoder(model string) *tiktoken.Tiktoken { } func getTokenNum(tokenEncoder *tiktoken.Tiktoken, text string) int { - if common.ApproximateTokenEnabled { + if config.ApproximateTokenEnabled { return int(float64(len(text)) * 0.38) } return len(tokenEncoder.Encode(text, nil, nil)) } -func countTokenMessages(messages []Message, model string) int { +func CountTokenMessages(messages []Message, model string) int { tokenEncoder := getTokenEncoder(model) // Reference: // https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb @@ -109,7 +101,7 @@ func countTokenMessages(messages []Message, model string) int { } imageTokens, err := countImageTokens(url, detail) if err != nil { - common.SysError("error counting image tokens: " + err.Error()) + logger.SysError("error counting image tokens: " + err.Error()) } else { tokenNum += imageTokens } @@ -195,191 +187,21 @@ func countImageTokens(url string, detail string) (_ int, err error) { } } -func countTokenInput(input any, model string) int { +func CountTokenInput(input any, model string) int { switch v := input.(type) { case string: - return countTokenText(v, model) + return CountTokenText(v, model) case []string: text := "" for _, s := range v { text += s } - return countTokenText(text, model) + return CountTokenText(text, model) } return 0 } -func countTokenText(text string, model string) int { +func CountTokenText(text string, model string) int { tokenEncoder := getTokenEncoder(model) return getTokenNum(tokenEncoder, text) } - -func errorWrapper(err error, code string, statusCode int) *OpenAIErrorWithStatusCode { - openAIError := OpenAIError{ - Message: err.Error(), - Type: "one_api_error", - Code: code, - } - return &OpenAIErrorWithStatusCode{ - OpenAIError: openAIError, - StatusCode: statusCode, - } -} - -func shouldDisableChannel(err *OpenAIError, statusCode int) bool { - if !common.AutomaticDisableChannelEnabled { - return false - } - if err == nil { - return false - } - if statusCode == http.StatusUnauthorized { - return true - } - if err.Type == "insufficient_quota" || err.Code == "invalid_api_key" || err.Code == "account_deactivated" { - return true - } - return false -} - -func shouldEnableChannel(err error, openAIErr *OpenAIError) bool { - if !common.AutomaticEnableChannelEnabled { - return false - } - if err != nil { - return false - } - if openAIErr != nil { - return false - } - return true -} - -func setEventStreamHeaders(c *gin.Context) { - c.Writer.Header().Set("Content-Type", "text/event-stream") - c.Writer.Header().Set("Cache-Control", "no-cache") - c.Writer.Header().Set("Connection", "keep-alive") - c.Writer.Header().Set("Transfer-Encoding", "chunked") - c.Writer.Header().Set("X-Accel-Buffering", "no") -} - -type GeneralErrorResponse struct { - Error OpenAIError `json:"error"` - Message string `json:"message"` - Msg string `json:"msg"` - Err string `json:"err"` - ErrorMsg string `json:"error_msg"` - Header struct { - Message string `json:"message"` - } `json:"header"` - Response struct { - Error struct { - Message string `json:"message"` - } `json:"error"` - } `json:"response"` -} - -func (e GeneralErrorResponse) ToMessage() string { - if e.Error.Message != "" { - return e.Error.Message - } - if e.Message != "" { - return e.Message - } - if e.Msg != "" { - return e.Msg - } - if e.Err != "" { - return e.Err - } - if e.ErrorMsg != "" { - return e.ErrorMsg - } - if e.Header.Message != "" { - return e.Header.Message - } - if e.Response.Error.Message != "" { - return e.Response.Error.Message - } - return "" -} - -func relayErrorHandler(resp *http.Response) (openAIErrorWithStatusCode *OpenAIErrorWithStatusCode) { - openAIErrorWithStatusCode = &OpenAIErrorWithStatusCode{ - StatusCode: resp.StatusCode, - OpenAIError: OpenAIError{ - Message: "", - Type: "upstream_error", - Code: "bad_response_status_code", - Param: strconv.Itoa(resp.StatusCode), - }, - } - responseBody, err := io.ReadAll(resp.Body) - if err != nil { - return - } - err = resp.Body.Close() - if err != nil { - return - } - var errResponse GeneralErrorResponse - err = json.Unmarshal(responseBody, &errResponse) - if err != nil { - return - } - if errResponse.Error.Message != "" { - // OpenAI format error, so we override the default one - openAIErrorWithStatusCode.OpenAIError = errResponse.Error - } else { - openAIErrorWithStatusCode.OpenAIError.Message = errResponse.ToMessage() - } - if openAIErrorWithStatusCode.OpenAIError.Message == "" { - openAIErrorWithStatusCode.OpenAIError.Message = fmt.Sprintf("bad response status code %d", resp.StatusCode) - } - return -} - -func getFullRequestURL(baseURL string, requestURL string, channelType int) string { - fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL) - - if strings.HasPrefix(baseURL, "https://gateway.ai.cloudflare.com") { - switch channelType { - case common.ChannelTypeOpenAI: - fullRequestURL = fmt.Sprintf("%s%s", baseURL, strings.TrimPrefix(requestURL, "/v1")) - case common.ChannelTypeAzure: - fullRequestURL = fmt.Sprintf("%s%s", baseURL, strings.TrimPrefix(requestURL, "/openai/deployments")) - } - } - return fullRequestURL -} - -func postConsumeQuota(ctx context.Context, tokenId int, quotaDelta int, totalQuota int, userId int, channelId int, modelRatio float64, groupRatio float64, modelName string, tokenName string) { - // quotaDelta is remaining quota to be consumed - err := model.PostConsumeTokenQuota(tokenId, quotaDelta) - if err != nil { - common.SysError("error consuming token remain quota: " + err.Error()) - } - err = model.CacheUpdateUserQuota(userId) - if err != nil { - common.SysError("error update user quota cache: " + err.Error()) - } - // totalQuota is total quota consumed - if totalQuota != 0 { - logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio) - model.RecordConsumeLog(ctx, userId, channelId, totalQuota, 0, modelName, tokenName, totalQuota, logContent) - model.UpdateUserUsedQuotaAndRequestCount(userId, totalQuota) - model.UpdateChannelUsedQuota(channelId, totalQuota) - } - if totalQuota <= 0 { - common.LogError(ctx, fmt.Sprintf("totalQuota consumed is %d, something is wrong", totalQuota)) - } -} - -func GetAPIVersion(c *gin.Context) string { - query := c.Request.URL.Query() - apiVersion := query.Get("api-version") - if apiVersion == "" { - apiVersion = c.GetString("api_version") - } - return apiVersion -} diff --git a/relay/channel/openai/util.go b/relay/channel/openai/util.go new file mode 100644 index 00000000..69ece6b3 --- /dev/null +++ b/relay/channel/openai/util.go @@ -0,0 +1,13 @@ +package openai + +func ErrorWrapper(err error, code string, statusCode int) *ErrorWithStatusCode { + Error := Error{ + Message: err.Error(), + Type: "one_api_error", + Code: code, + } + return &ErrorWithStatusCode{ + Error: Error, + StatusCode: statusCode, + } +} diff --git a/relay/channel/tencent/adaptor.go b/relay/channel/tencent/adaptor.go new file mode 100644 index 00000000..e9f86aff --- /dev/null +++ b/relay/channel/tencent/adaptor.go @@ -0,0 +1,22 @@ +package tencent + +import ( + "github.com/gin-gonic/gin" + "net/http" + "one-api/relay/channel/openai" +) + +type Adaptor struct { +} + +func (a *Adaptor) Auth(c *gin.Context) error { + return nil +} + +func (a *Adaptor) ConvertRequest(request *openai.GeneralOpenAIRequest) (any, error) { + return nil, nil +} + +func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, *openai.Usage, error) { + return nil, nil, nil +} diff --git a/relay/channel/tencent/main.go b/relay/channel/tencent/main.go new file mode 100644 index 00000000..64091d62 --- /dev/null +++ b/relay/channel/tencent/main.go @@ -0,0 +1,234 @@ +package tencent + +import ( + "bufio" + "crypto/hmac" + "crypto/sha1" + "encoding/base64" + "encoding/json" + "errors" + "fmt" + "github.com/gin-gonic/gin" + "io" + "net/http" + "one-api/common" + "one-api/common/helper" + "one-api/common/logger" + "one-api/relay/channel/openai" + "one-api/relay/constant" + "sort" + "strconv" + "strings" +) + +// https://cloud.tencent.com/document/product/1729/97732 + +func ConvertRequest(request openai.GeneralOpenAIRequest) *ChatRequest { + messages := make([]Message, 0, len(request.Messages)) + for i := 0; i < len(request.Messages); i++ { + message := request.Messages[i] + if message.Role == "system" { + messages = append(messages, Message{ + Role: "user", + Content: message.StringContent(), + }) + messages = append(messages, Message{ + Role: "assistant", + Content: "Okay", + }) + continue + } + messages = append(messages, Message{ + Content: message.StringContent(), + Role: message.Role, + }) + } + stream := 0 + if request.Stream { + stream = 1 + } + return &ChatRequest{ + Timestamp: helper.GetTimestamp(), + Expired: helper.GetTimestamp() + 24*60*60, + QueryID: helper.GetUUID(), + Temperature: request.Temperature, + TopP: request.TopP, + Stream: stream, + Messages: messages, + } +} + +func responseTencent2OpenAI(response *ChatResponse) *openai.TextResponse { + fullTextResponse := openai.TextResponse{ + Object: "chat.completion", + Created: helper.GetTimestamp(), + Usage: response.Usage, + } + if len(response.Choices) > 0 { + choice := openai.TextResponseChoice{ + Index: 0, + Message: openai.Message{ + Role: "assistant", + Content: response.Choices[0].Messages.Content, + }, + FinishReason: response.Choices[0].FinishReason, + } + fullTextResponse.Choices = append(fullTextResponse.Choices, choice) + } + return &fullTextResponse +} + +func streamResponseTencent2OpenAI(TencentResponse *ChatResponse) *openai.ChatCompletionsStreamResponse { + response := openai.ChatCompletionsStreamResponse{ + Object: "chat.completion.chunk", + Created: helper.GetTimestamp(), + Model: "tencent-hunyuan", + } + if len(TencentResponse.Choices) > 0 { + var choice openai.ChatCompletionsStreamResponseChoice + choice.Delta.Content = TencentResponse.Choices[0].Delta.Content + if TencentResponse.Choices[0].FinishReason == "stop" { + choice.FinishReason = &constant.StopFinishReason + } + response.Choices = append(response.Choices, choice) + } + return &response +} + +func StreamHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, string) { + var responseText string + scanner := bufio.NewScanner(resp.Body) + scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { + if atEOF && len(data) == 0 { + return 0, nil, nil + } + if i := strings.Index(string(data), "\n"); i >= 0 { + return i + 1, data[0:i], nil + } + if atEOF { + return len(data), data, nil + } + return 0, nil, nil + }) + dataChan := make(chan string) + stopChan := make(chan bool) + go func() { + for scanner.Scan() { + data := scanner.Text() + if len(data) < 5 { // ignore blank line or wrong format + continue + } + if data[:5] != "data:" { + continue + } + data = data[5:] + dataChan <- data + } + stopChan <- true + }() + common.SetEventStreamHeaders(c) + c.Stream(func(w io.Writer) bool { + select { + case data := <-dataChan: + var TencentResponse ChatResponse + err := json.Unmarshal([]byte(data), &TencentResponse) + if err != nil { + logger.SysError("error unmarshalling stream response: " + err.Error()) + return true + } + response := streamResponseTencent2OpenAI(&TencentResponse) + if len(response.Choices) != 0 { + responseText += response.Choices[0].Delta.Content + } + jsonResponse, err := json.Marshal(response) + if err != nil { + logger.SysError("error marshalling stream response: " + err.Error()) + return true + } + c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)}) + return true + case <-stopChan: + c.Render(-1, common.CustomEvent{Data: "data: [DONE]"}) + return false + } + }) + err := resp.Body.Close() + if err != nil { + return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), "" + } + return nil, responseText +} + +func Handler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, *openai.Usage) { + var TencentResponse ChatResponse + responseBody, err := io.ReadAll(resp.Body) + if err != nil { + return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil + } + err = resp.Body.Close() + if err != nil { + return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil + } + err = json.Unmarshal(responseBody, &TencentResponse) + if err != nil { + return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil + } + if TencentResponse.Error.Code != 0 { + return &openai.ErrorWithStatusCode{ + Error: openai.Error{ + Message: TencentResponse.Error.Message, + Code: TencentResponse.Error.Code, + }, + StatusCode: resp.StatusCode, + }, nil + } + fullTextResponse := responseTencent2OpenAI(&TencentResponse) + fullTextResponse.Model = "hunyuan" + jsonResponse, err := json.Marshal(fullTextResponse) + if err != nil { + return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil + } + c.Writer.Header().Set("Content-Type", "application/json") + c.Writer.WriteHeader(resp.StatusCode) + _, err = c.Writer.Write(jsonResponse) + return nil, &fullTextResponse.Usage +} + +func ParseConfig(config string) (appId int64, secretId string, secretKey string, err error) { + parts := strings.Split(config, "|") + if len(parts) != 3 { + err = errors.New("invalid tencent config") + return + } + appId, err = strconv.ParseInt(parts[0], 10, 64) + secretId = parts[1] + secretKey = parts[2] + return +} + +func GetSign(req ChatRequest, secretKey string) string { + params := make([]string, 0) + params = append(params, "app_id="+strconv.FormatInt(req.AppId, 10)) + params = append(params, "secret_id="+req.SecretId) + params = append(params, "timestamp="+strconv.FormatInt(req.Timestamp, 10)) + params = append(params, "query_id="+req.QueryID) + params = append(params, "temperature="+strconv.FormatFloat(req.Temperature, 'f', -1, 64)) + params = append(params, "top_p="+strconv.FormatFloat(req.TopP, 'f', -1, 64)) + params = append(params, "stream="+strconv.Itoa(req.Stream)) + params = append(params, "expired="+strconv.FormatInt(req.Expired, 10)) + + var messageStr string + for _, msg := range req.Messages { + messageStr += fmt.Sprintf(`{"role":"%s","content":"%s"},`, msg.Role, msg.Content) + } + messageStr = strings.TrimSuffix(messageStr, ",") + params = append(params, "messages=["+messageStr+"]") + + sort.Sort(sort.StringSlice(params)) + url := "hunyuan.cloud.tencent.com/hyllm/v1/chat/completions?" + strings.Join(params, "&") + mac := hmac.New(sha1.New, []byte(secretKey)) + signURL := url + mac.Write([]byte(signURL)) + sign := mac.Sum([]byte(nil)) + return base64.StdEncoding.EncodeToString(sign) +} diff --git a/relay/channel/tencent/model.go b/relay/channel/tencent/model.go new file mode 100644 index 00000000..511f3d97 --- /dev/null +++ b/relay/channel/tencent/model.go @@ -0,0 +1,63 @@ +package tencent + +import ( + "one-api/relay/channel/openai" +) + +type Message struct { + Role string `json:"role"` + Content string `json:"content"` +} + +type ChatRequest struct { + AppId int64 `json:"app_id"` // 腾讯云账号的 APPID + SecretId string `json:"secret_id"` // 官网 SecretId + // Timestamp当前 UNIX 时间戳,单位为秒,可记录发起 API 请求的时间。 + // 例如1529223702,如果与当前时间相差过大,会引起签名过期错误 + Timestamp int64 `json:"timestamp"` + // Expired 签名的有效期,是一个符合 UNIX Epoch 时间戳规范的数值, + // 单位为秒;Expired 必须大于 Timestamp 且 Expired-Timestamp 小于90天 + Expired int64 `json:"expired"` + QueryID string `json:"query_id"` //请求 Id,用于问题排查 + // Temperature 较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定 + // 默认 1.0,取值区间为[0.0,2.0],非必要不建议使用,不合理的取值会影响效果 + // 建议该参数和 top_p 只设置1个,不要同时更改 top_p + Temperature float64 `json:"temperature"` + // TopP 影响输出文本的多样性,取值越大,生成文本的多样性越强 + // 默认1.0,取值区间为[0.0, 1.0],非必要不建议使用, 不合理的取值会影响效果 + // 建议该参数和 temperature 只设置1个,不要同时更改 + TopP float64 `json:"top_p"` + // Stream 0:同步,1:流式 (默认,协议:SSE) + // 同步请求超时:60s,如果内容较长建议使用流式 + Stream int `json:"stream"` + // Messages 会话内容, 长度最多为40, 按对话时间从旧到新在数组中排列 + // 输入 content 总数最大支持 3000 token。 + Messages []Message `json:"messages"` +} + +type Error struct { + Code int `json:"code"` + Message string `json:"message"` +} + +type Usage struct { + InputTokens int `json:"input_tokens"` + OutputTokens int `json:"output_tokens"` + TotalTokens int `json:"total_tokens"` +} + +type ResponseChoices struct { + FinishReason string `json:"finish_reason,omitempty"` // 流式结束标志位,为 stop 则表示尾包 + Messages Message `json:"messages,omitempty"` // 内容,同步模式返回内容,流模式为 null 输出 content 内容总数最多支持 1024token。 + Delta Message `json:"delta,omitempty"` // 内容,流模式返回内容,同步模式为 null 输出 content 内容总数最多支持 1024token。 +} + +type ChatResponse struct { + Choices []ResponseChoices `json:"choices,omitempty"` // 结果 + Created string `json:"created,omitempty"` // unix 时间戳的字符串 + Id string `json:"id,omitempty"` // 会话 id + Usage openai.Usage `json:"usage,omitempty"` // token 数量 + Error Error `json:"error,omitempty"` // 错误信息 注意:此字段可能返回 null,表示取不到有效值 + Note string `json:"note,omitempty"` // 注释 + ReqID string `json:"req_id,omitempty"` // 唯一请求 Id,每次请求都会返回。用于反馈接口入参 +} diff --git a/relay/channel/xunfei/adaptor.go b/relay/channel/xunfei/adaptor.go new file mode 100644 index 00000000..89e58485 --- /dev/null +++ b/relay/channel/xunfei/adaptor.go @@ -0,0 +1,22 @@ +package xunfei + +import ( + "github.com/gin-gonic/gin" + "net/http" + "one-api/relay/channel/openai" +) + +type Adaptor struct { +} + +func (a *Adaptor) Auth(c *gin.Context) error { + return nil +} + +func (a *Adaptor) ConvertRequest(request *openai.GeneralOpenAIRequest) (any, error) { + return nil, nil +} + +func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, *openai.Usage, error) { + return nil, nil, nil +} diff --git a/controller/relay-xunfei.go b/relay/channel/xunfei/main.go similarity index 58% rename from controller/relay-xunfei.go rename to relay/channel/xunfei/main.go index 904e6d14..906d2844 100644 --- a/controller/relay-xunfei.go +++ b/relay/channel/xunfei/main.go @@ -1,4 +1,4 @@ -package controller +package xunfei import ( "crypto/hmac" @@ -12,6 +12,10 @@ import ( "net/http" "net/url" "one-api/common" + "one-api/common/helper" + "one-api/common/logger" + "one-api/relay/channel/openai" + "one-api/relay/constant" "strings" "time" ) @@ -19,82 +23,26 @@ import ( // https://console.xfyun.cn/services/cbm // https://www.xfyun.cn/doc/spark/Web.html -type XunfeiMessage struct { - Role string `json:"role"` - Content string `json:"content"` -} - -type XunfeiChatRequest struct { - Header struct { - AppId string `json:"app_id"` - } `json:"header"` - Parameter struct { - Chat struct { - Domain string `json:"domain,omitempty"` - Temperature float64 `json:"temperature,omitempty"` - TopK int `json:"top_k,omitempty"` - MaxTokens int `json:"max_tokens,omitempty"` - Auditing bool `json:"auditing,omitempty"` - } `json:"chat"` - } `json:"parameter"` - Payload struct { - Message struct { - Text []XunfeiMessage `json:"text"` - } `json:"message"` - } `json:"payload"` -} - -type XunfeiChatResponseTextItem struct { - Content string `json:"content"` - Role string `json:"role"` - Index int `json:"index"` -} - -type XunfeiChatResponse struct { - Header struct { - Code int `json:"code"` - Message string `json:"message"` - Sid string `json:"sid"` - Status int `json:"status"` - } `json:"header"` - Payload struct { - Choices struct { - Status int `json:"status"` - Seq int `json:"seq"` - Text []XunfeiChatResponseTextItem `json:"text"` - } `json:"choices"` - Usage struct { - //Text struct { - // QuestionTokens string `json:"question_tokens"` - // PromptTokens string `json:"prompt_tokens"` - // CompletionTokens string `json:"completion_tokens"` - // TotalTokens string `json:"total_tokens"` - //} `json:"text"` - Text Usage `json:"text"` - } `json:"usage"` - } `json:"payload"` -} - -func requestOpenAI2Xunfei(request GeneralOpenAIRequest, xunfeiAppId string, domain string) *XunfeiChatRequest { - messages := make([]XunfeiMessage, 0, len(request.Messages)) +func requestOpenAI2Xunfei(request openai.GeneralOpenAIRequest, xunfeiAppId string, domain string) *ChatRequest { + messages := make([]Message, 0, len(request.Messages)) for _, message := range request.Messages { if message.Role == "system" { - messages = append(messages, XunfeiMessage{ + messages = append(messages, Message{ Role: "user", Content: message.StringContent(), }) - messages = append(messages, XunfeiMessage{ + messages = append(messages, Message{ Role: "assistant", Content: "Okay", }) } else { - messages = append(messages, XunfeiMessage{ + messages = append(messages, Message{ Role: message.Role, Content: message.StringContent(), }) } } - xunfeiRequest := XunfeiChatRequest{} + xunfeiRequest := ChatRequest{} xunfeiRequest.Header.AppId = xunfeiAppId xunfeiRequest.Parameter.Chat.Domain = domain xunfeiRequest.Parameter.Chat.Temperature = request.Temperature @@ -104,49 +52,49 @@ func requestOpenAI2Xunfei(request GeneralOpenAIRequest, xunfeiAppId string, doma return &xunfeiRequest } -func responseXunfei2OpenAI(response *XunfeiChatResponse) *OpenAITextResponse { +func responseXunfei2OpenAI(response *ChatResponse) *openai.TextResponse { if len(response.Payload.Choices.Text) == 0 { - response.Payload.Choices.Text = []XunfeiChatResponseTextItem{ + response.Payload.Choices.Text = []ChatResponseTextItem{ { Content: "", }, } } - choice := OpenAITextResponseChoice{ + choice := openai.TextResponseChoice{ Index: 0, - Message: Message{ + Message: openai.Message{ Role: "assistant", Content: response.Payload.Choices.Text[0].Content, }, - FinishReason: stopFinishReason, + FinishReason: constant.StopFinishReason, } - fullTextResponse := OpenAITextResponse{ + fullTextResponse := openai.TextResponse{ Object: "chat.completion", - Created: common.GetTimestamp(), - Choices: []OpenAITextResponseChoice{choice}, + Created: helper.GetTimestamp(), + Choices: []openai.TextResponseChoice{choice}, Usage: response.Payload.Usage.Text, } return &fullTextResponse } -func streamResponseXunfei2OpenAI(xunfeiResponse *XunfeiChatResponse) *ChatCompletionsStreamResponse { +func streamResponseXunfei2OpenAI(xunfeiResponse *ChatResponse) *openai.ChatCompletionsStreamResponse { if len(xunfeiResponse.Payload.Choices.Text) == 0 { - xunfeiResponse.Payload.Choices.Text = []XunfeiChatResponseTextItem{ + xunfeiResponse.Payload.Choices.Text = []ChatResponseTextItem{ { Content: "", }, } } - var choice ChatCompletionsStreamResponseChoice + var choice openai.ChatCompletionsStreamResponseChoice choice.Delta.Content = xunfeiResponse.Payload.Choices.Text[0].Content if xunfeiResponse.Payload.Choices.Status == 2 { - choice.FinishReason = &stopFinishReason + choice.FinishReason = &constant.StopFinishReason } - response := ChatCompletionsStreamResponse{ + response := openai.ChatCompletionsStreamResponse{ Object: "chat.completion.chunk", - Created: common.GetTimestamp(), + Created: helper.GetTimestamp(), Model: "SparkDesk", - Choices: []ChatCompletionsStreamResponseChoice{choice}, + Choices: []openai.ChatCompletionsStreamResponseChoice{choice}, } return &response } @@ -177,14 +125,14 @@ func buildXunfeiAuthUrl(hostUrl string, apiKey, apiSecret string) string { return callUrl } -func xunfeiStreamHandler(c *gin.Context, textRequest GeneralOpenAIRequest, appId string, apiSecret string, apiKey string) (*OpenAIErrorWithStatusCode, *Usage) { +func StreamHandler(c *gin.Context, textRequest openai.GeneralOpenAIRequest, appId string, apiSecret string, apiKey string) (*openai.ErrorWithStatusCode, *openai.Usage) { domain, authUrl := getXunfeiAuthUrl(c, apiKey, apiSecret) dataChan, stopChan, err := xunfeiMakeRequest(textRequest, domain, authUrl, appId) if err != nil { - return errorWrapper(err, "make xunfei request err", http.StatusInternalServerError), nil + return openai.ErrorWrapper(err, "make xunfei request err", http.StatusInternalServerError), nil } - setEventStreamHeaders(c) - var usage Usage + common.SetEventStreamHeaders(c) + var usage openai.Usage c.Stream(func(w io.Writer) bool { select { case xunfeiResponse := <-dataChan: @@ -194,7 +142,7 @@ func xunfeiStreamHandler(c *gin.Context, textRequest GeneralOpenAIRequest, appId response := streamResponseXunfei2OpenAI(&xunfeiResponse) jsonResponse, err := json.Marshal(response) if err != nil { - common.SysError("error marshalling stream response: " + err.Error()) + logger.SysError("error marshalling stream response: " + err.Error()) return true } c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)}) @@ -207,15 +155,15 @@ func xunfeiStreamHandler(c *gin.Context, textRequest GeneralOpenAIRequest, appId return nil, &usage } -func xunfeiHandler(c *gin.Context, textRequest GeneralOpenAIRequest, appId string, apiSecret string, apiKey string) (*OpenAIErrorWithStatusCode, *Usage) { +func Handler(c *gin.Context, textRequest openai.GeneralOpenAIRequest, appId string, apiSecret string, apiKey string) (*openai.ErrorWithStatusCode, *openai.Usage) { domain, authUrl := getXunfeiAuthUrl(c, apiKey, apiSecret) dataChan, stopChan, err := xunfeiMakeRequest(textRequest, domain, authUrl, appId) if err != nil { - return errorWrapper(err, "make xunfei request err", http.StatusInternalServerError), nil + return openai.ErrorWrapper(err, "make xunfei request err", http.StatusInternalServerError), nil } - var usage Usage + var usage openai.Usage var content string - var xunfeiResponse XunfeiChatResponse + var xunfeiResponse ChatResponse stop := false for !stop { select { @@ -231,7 +179,7 @@ func xunfeiHandler(c *gin.Context, textRequest GeneralOpenAIRequest, appId strin } } if len(xunfeiResponse.Payload.Choices.Text) == 0 { - xunfeiResponse.Payload.Choices.Text = []XunfeiChatResponseTextItem{ + xunfeiResponse.Payload.Choices.Text = []ChatResponseTextItem{ { Content: "", }, @@ -242,14 +190,14 @@ func xunfeiHandler(c *gin.Context, textRequest GeneralOpenAIRequest, appId strin response := responseXunfei2OpenAI(&xunfeiResponse) jsonResponse, err := json.Marshal(response) if err != nil { - return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil + return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil } c.Writer.Header().Set("Content-Type", "application/json") _, _ = c.Writer.Write(jsonResponse) return nil, &usage } -func xunfeiMakeRequest(textRequest GeneralOpenAIRequest, domain, authUrl, appId string) (chan XunfeiChatResponse, chan bool, error) { +func xunfeiMakeRequest(textRequest openai.GeneralOpenAIRequest, domain, authUrl, appId string) (chan ChatResponse, chan bool, error) { d := websocket.Dialer{ HandshakeTimeout: 5 * time.Second, } @@ -263,26 +211,26 @@ func xunfeiMakeRequest(textRequest GeneralOpenAIRequest, domain, authUrl, appId return nil, nil, err } - dataChan := make(chan XunfeiChatResponse) + dataChan := make(chan ChatResponse) stopChan := make(chan bool) go func() { for { _, msg, err := conn.ReadMessage() if err != nil { - common.SysError("error reading stream response: " + err.Error()) + logger.SysError("error reading stream response: " + err.Error()) break } - var response XunfeiChatResponse + var response ChatResponse err = json.Unmarshal(msg, &response) if err != nil { - common.SysError("error unmarshalling stream response: " + err.Error()) + logger.SysError("error unmarshalling stream response: " + err.Error()) break } dataChan <- response if response.Payload.Choices.Status == 2 { err := conn.Close() if err != nil { - common.SysError("error closing websocket connection: " + err.Error()) + logger.SysError("error closing websocket connection: " + err.Error()) } break } @@ -301,7 +249,7 @@ func getXunfeiAuthUrl(c *gin.Context, apiKey string, apiSecret string) (string, } if apiVersion == "" { apiVersion = "v1.1" - common.SysLog("api_version not found, use default: " + apiVersion) + logger.SysLog("api_version not found, use default: " + apiVersion) } domain := "general" if apiVersion != "v1.1" { diff --git a/relay/channel/xunfei/model.go b/relay/channel/xunfei/model.go new file mode 100644 index 00000000..0ca42818 --- /dev/null +++ b/relay/channel/xunfei/model.go @@ -0,0 +1,61 @@ +package xunfei + +import ( + "one-api/relay/channel/openai" +) + +type Message struct { + Role string `json:"role"` + Content string `json:"content"` +} + +type ChatRequest struct { + Header struct { + AppId string `json:"app_id"` + } `json:"header"` + Parameter struct { + Chat struct { + Domain string `json:"domain,omitempty"` + Temperature float64 `json:"temperature,omitempty"` + TopK int `json:"top_k,omitempty"` + MaxTokens int `json:"max_tokens,omitempty"` + Auditing bool `json:"auditing,omitempty"` + } `json:"chat"` + } `json:"parameter"` + Payload struct { + Message struct { + Text []Message `json:"text"` + } `json:"message"` + } `json:"payload"` +} + +type ChatResponseTextItem struct { + Content string `json:"content"` + Role string `json:"role"` + Index int `json:"index"` +} + +type ChatResponse struct { + Header struct { + Code int `json:"code"` + Message string `json:"message"` + Sid string `json:"sid"` + Status int `json:"status"` + } `json:"header"` + Payload struct { + Choices struct { + Status int `json:"status"` + Seq int `json:"seq"` + Text []ChatResponseTextItem `json:"text"` + } `json:"choices"` + Usage struct { + //Text struct { + // QuestionTokens string `json:"question_tokens"` + // PromptTokens string `json:"prompt_tokens"` + // CompletionTokens string `json:"completion_tokens"` + // TotalTokens string `json:"total_tokens"` + //} `json:"text"` + Text openai.Usage `json:"text"` + } `json:"usage"` + } `json:"payload"` +} diff --git a/relay/channel/zhipu/adaptor.go b/relay/channel/zhipu/adaptor.go new file mode 100644 index 00000000..6d901bc3 --- /dev/null +++ b/relay/channel/zhipu/adaptor.go @@ -0,0 +1,22 @@ +package zhipu + +import ( + "github.com/gin-gonic/gin" + "net/http" + "one-api/relay/channel/openai" +) + +type Adaptor struct { +} + +func (a *Adaptor) Auth(c *gin.Context) error { + return nil +} + +func (a *Adaptor) ConvertRequest(request *openai.GeneralOpenAIRequest) (any, error) { + return nil, nil +} + +func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, *openai.Usage, error) { + return nil, nil, nil +} diff --git a/controller/relay-zhipu.go b/relay/channel/zhipu/main.go similarity index 58% rename from controller/relay-zhipu.go rename to relay/channel/zhipu/main.go index cb5a78cf..d831f57a 100644 --- a/controller/relay-zhipu.go +++ b/relay/channel/zhipu/main.go @@ -1,4 +1,4 @@ -package controller +package zhipu import ( "bufio" @@ -8,6 +8,10 @@ import ( "io" "net/http" "one-api/common" + "one-api/common/helper" + "one-api/common/logger" + "one-api/relay/channel/openai" + "one-api/relay/constant" "strings" "sync" "time" @@ -18,53 +22,13 @@ import ( // https://open.bigmodel.cn/api/paas/v3/model-api/chatglm_std/invoke // https://open.bigmodel.cn/api/paas/v3/model-api/chatglm_std/sse-invoke -type ZhipuMessage struct { - Role string `json:"role"` - Content string `json:"content"` -} - -type ZhipuRequest struct { - Prompt []ZhipuMessage `json:"prompt"` - Temperature float64 `json:"temperature,omitempty"` - TopP float64 `json:"top_p,omitempty"` - RequestId string `json:"request_id,omitempty"` - Incremental bool `json:"incremental,omitempty"` -} - -type ZhipuResponseData struct { - TaskId string `json:"task_id"` - RequestId string `json:"request_id"` - TaskStatus string `json:"task_status"` - Choices []ZhipuMessage `json:"choices"` - Usage `json:"usage"` -} - -type ZhipuResponse struct { - Code int `json:"code"` - Msg string `json:"msg"` - Success bool `json:"success"` - Data ZhipuResponseData `json:"data"` -} - -type ZhipuStreamMetaResponse struct { - RequestId string `json:"request_id"` - TaskId string `json:"task_id"` - TaskStatus string `json:"task_status"` - Usage `json:"usage"` -} - -type zhipuTokenData struct { - Token string - ExpiryTime time.Time -} - var zhipuTokens sync.Map var expSeconds int64 = 24 * 3600 -func getZhipuToken(apikey string) string { +func GetToken(apikey string) string { data, ok := zhipuTokens.Load(apikey) if ok { - tokenData := data.(zhipuTokenData) + tokenData := data.(tokenData) if time.Now().Before(tokenData.ExpiryTime) { return tokenData.Token } @@ -72,7 +36,7 @@ func getZhipuToken(apikey string) string { split := strings.Split(apikey, ".") if len(split) != 2 { - common.SysError("invalid zhipu key: " + apikey) + logger.SysError("invalid zhipu key: " + apikey) return "" } @@ -100,7 +64,7 @@ func getZhipuToken(apikey string) string { return "" } - zhipuTokens.Store(apikey, zhipuTokenData{ + zhipuTokens.Store(apikey, tokenData{ Token: tokenString, ExpiryTime: expiryTime, }) @@ -108,26 +72,26 @@ func getZhipuToken(apikey string) string { return tokenString } -func requestOpenAI2Zhipu(request GeneralOpenAIRequest) *ZhipuRequest { - messages := make([]ZhipuMessage, 0, len(request.Messages)) +func ConvertRequest(request openai.GeneralOpenAIRequest) *Request { + messages := make([]Message, 0, len(request.Messages)) for _, message := range request.Messages { if message.Role == "system" { - messages = append(messages, ZhipuMessage{ + messages = append(messages, Message{ Role: "system", Content: message.StringContent(), }) - messages = append(messages, ZhipuMessage{ + messages = append(messages, Message{ Role: "user", Content: "Okay", }) } else { - messages = append(messages, ZhipuMessage{ + messages = append(messages, Message{ Role: message.Role, Content: message.StringContent(), }) } } - return &ZhipuRequest{ + return &Request{ Prompt: messages, Temperature: request.Temperature, TopP: request.TopP, @@ -135,18 +99,18 @@ func requestOpenAI2Zhipu(request GeneralOpenAIRequest) *ZhipuRequest { } } -func responseZhipu2OpenAI(response *ZhipuResponse) *OpenAITextResponse { - fullTextResponse := OpenAITextResponse{ +func responseZhipu2OpenAI(response *Response) *openai.TextResponse { + fullTextResponse := openai.TextResponse{ Id: response.Data.TaskId, Object: "chat.completion", - Created: common.GetTimestamp(), - Choices: make([]OpenAITextResponseChoice, 0, len(response.Data.Choices)), + Created: helper.GetTimestamp(), + Choices: make([]openai.TextResponseChoice, 0, len(response.Data.Choices)), Usage: response.Data.Usage, } for i, choice := range response.Data.Choices { - openaiChoice := OpenAITextResponseChoice{ + openaiChoice := openai.TextResponseChoice{ Index: i, - Message: Message{ + Message: openai.Message{ Role: choice.Role, Content: strings.Trim(choice.Content, "\""), }, @@ -160,34 +124,34 @@ func responseZhipu2OpenAI(response *ZhipuResponse) *OpenAITextResponse { return &fullTextResponse } -func streamResponseZhipu2OpenAI(zhipuResponse string) *ChatCompletionsStreamResponse { - var choice ChatCompletionsStreamResponseChoice +func streamResponseZhipu2OpenAI(zhipuResponse string) *openai.ChatCompletionsStreamResponse { + var choice openai.ChatCompletionsStreamResponseChoice choice.Delta.Content = zhipuResponse - response := ChatCompletionsStreamResponse{ + response := openai.ChatCompletionsStreamResponse{ Object: "chat.completion.chunk", - Created: common.GetTimestamp(), + Created: helper.GetTimestamp(), Model: "chatglm", - Choices: []ChatCompletionsStreamResponseChoice{choice}, + Choices: []openai.ChatCompletionsStreamResponseChoice{choice}, } return &response } -func streamMetaResponseZhipu2OpenAI(zhipuResponse *ZhipuStreamMetaResponse) (*ChatCompletionsStreamResponse, *Usage) { - var choice ChatCompletionsStreamResponseChoice +func streamMetaResponseZhipu2OpenAI(zhipuResponse *StreamMetaResponse) (*openai.ChatCompletionsStreamResponse, *openai.Usage) { + var choice openai.ChatCompletionsStreamResponseChoice choice.Delta.Content = "" - choice.FinishReason = &stopFinishReason - response := ChatCompletionsStreamResponse{ + choice.FinishReason = &constant.StopFinishReason + response := openai.ChatCompletionsStreamResponse{ Id: zhipuResponse.RequestId, Object: "chat.completion.chunk", - Created: common.GetTimestamp(), + Created: helper.GetTimestamp(), Model: "chatglm", - Choices: []ChatCompletionsStreamResponseChoice{choice}, + Choices: []openai.ChatCompletionsStreamResponseChoice{choice}, } return &response, &zhipuResponse.Usage } -func zhipuStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) { - var usage *Usage +func StreamHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, *openai.Usage) { + var usage *openai.Usage scanner := bufio.NewScanner(resp.Body) scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { if atEOF && len(data) == 0 { @@ -224,29 +188,29 @@ func zhipuStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithSt } stopChan <- true }() - setEventStreamHeaders(c) + common.SetEventStreamHeaders(c) c.Stream(func(w io.Writer) bool { select { case data := <-dataChan: response := streamResponseZhipu2OpenAI(data) jsonResponse, err := json.Marshal(response) if err != nil { - common.SysError("error marshalling stream response: " + err.Error()) + logger.SysError("error marshalling stream response: " + err.Error()) return true } c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)}) return true case data := <-metaChan: - var zhipuResponse ZhipuStreamMetaResponse + var zhipuResponse StreamMetaResponse err := json.Unmarshal([]byte(data), &zhipuResponse) if err != nil { - common.SysError("error unmarshalling stream response: " + err.Error()) + logger.SysError("error unmarshalling stream response: " + err.Error()) return true } response, zhipuUsage := streamMetaResponseZhipu2OpenAI(&zhipuResponse) jsonResponse, err := json.Marshal(response) if err != nil { - common.SysError("error marshalling stream response: " + err.Error()) + logger.SysError("error marshalling stream response: " + err.Error()) return true } usage = zhipuUsage @@ -259,28 +223,28 @@ func zhipuStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithSt }) err := resp.Body.Close() if err != nil { - return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil + return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil } return nil, usage } -func zhipuHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) { - var zhipuResponse ZhipuResponse +func Handler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, *openai.Usage) { + var zhipuResponse Response responseBody, err := io.ReadAll(resp.Body) if err != nil { - return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil + return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil } err = resp.Body.Close() if err != nil { - return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil + return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil } err = json.Unmarshal(responseBody, &zhipuResponse) if err != nil { - return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil + return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil } if !zhipuResponse.Success { - return &OpenAIErrorWithStatusCode{ - OpenAIError: OpenAIError{ + return &openai.ErrorWithStatusCode{ + Error: openai.Error{ Message: zhipuResponse.Msg, Type: "zhipu_error", Param: "", @@ -293,7 +257,7 @@ func zhipuHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCo fullTextResponse.Model = "chatglm" jsonResponse, err := json.Marshal(fullTextResponse) if err != nil { - return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil + return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil } c.Writer.Header().Set("Content-Type", "application/json") c.Writer.WriteHeader(resp.StatusCode) diff --git a/relay/channel/zhipu/model.go b/relay/channel/zhipu/model.go new file mode 100644 index 00000000..08a5ec5f --- /dev/null +++ b/relay/channel/zhipu/model.go @@ -0,0 +1,46 @@ +package zhipu + +import ( + "one-api/relay/channel/openai" + "time" +) + +type Message struct { + Role string `json:"role"` + Content string `json:"content"` +} + +type Request struct { + Prompt []Message `json:"prompt"` + Temperature float64 `json:"temperature,omitempty"` + TopP float64 `json:"top_p,omitempty"` + RequestId string `json:"request_id,omitempty"` + Incremental bool `json:"incremental,omitempty"` +} + +type ResponseData struct { + TaskId string `json:"task_id"` + RequestId string `json:"request_id"` + TaskStatus string `json:"task_status"` + Choices []Message `json:"choices"` + openai.Usage `json:"usage"` +} + +type Response struct { + Code int `json:"code"` + Msg string `json:"msg"` + Success bool `json:"success"` + Data ResponseData `json:"data"` +} + +type StreamMetaResponse struct { + RequestId string `json:"request_id"` + TaskId string `json:"task_id"` + TaskStatus string `json:"task_status"` + openai.Usage `json:"usage"` +} + +type tokenData struct { + Token string + ExpiryTime time.Time +} diff --git a/relay/constant/api_type.go b/relay/constant/api_type.go new file mode 100644 index 00000000..658bfb90 --- /dev/null +++ b/relay/constant/api_type.go @@ -0,0 +1,69 @@ +package constant + +import ( + "one-api/common" +) + +const ( + APITypeOpenAI = iota + APITypeClaude + APITypePaLM + APITypeBaidu + APITypeZhipu + APITypeAli + APITypeXunfei + APITypeAIProxyLibrary + APITypeTencent + APITypeGemini +) + +func ChannelType2APIType(channelType int) int { + apiType := APITypeOpenAI + switch channelType { + case common.ChannelTypeAnthropic: + apiType = APITypeClaude + case common.ChannelTypeBaidu: + apiType = APITypeBaidu + case common.ChannelTypePaLM: + apiType = APITypePaLM + case common.ChannelTypeZhipu: + apiType = APITypeZhipu + case common.ChannelTypeAli: + apiType = APITypeAli + case common.ChannelTypeXunfei: + apiType = APITypeXunfei + case common.ChannelTypeAIProxyLibrary: + apiType = APITypeAIProxyLibrary + case common.ChannelTypeTencent: + apiType = APITypeTencent + case common.ChannelTypeGemini: + apiType = APITypeGemini + } + return apiType +} + +//func GetAdaptor(apiType int) channel.Adaptor { +// switch apiType { +// case APITypeOpenAI: +// return &openai.Adaptor{} +// case APITypeClaude: +// return &anthropic.Adaptor{} +// case APITypePaLM: +// return &google.Adaptor{} +// case APITypeZhipu: +// return &baidu.Adaptor{} +// case APITypeBaidu: +// return &baidu.Adaptor{} +// case APITypeAli: +// return &ali.Adaptor{} +// case APITypeXunfei: +// return &xunfei.Adaptor{} +// case APITypeAIProxyLibrary: +// return &aiproxy.Adaptor{} +// case APITypeTencent: +// return &tencent.Adaptor{} +// case APITypeGemini: +// return &google.Adaptor{} +// } +// return nil +//} diff --git a/relay/constant/common.go b/relay/constant/common.go new file mode 100644 index 00000000..b6606cc6 --- /dev/null +++ b/relay/constant/common.go @@ -0,0 +1,3 @@ +package constant + +var StopFinishReason = "stop" diff --git a/relay/constant/relay_mode.go b/relay/constant/relay_mode.go new file mode 100644 index 00000000..5e2fe574 --- /dev/null +++ b/relay/constant/relay_mode.go @@ -0,0 +1,42 @@ +package constant + +import "strings" + +const ( + RelayModeUnknown = iota + RelayModeChatCompletions + RelayModeCompletions + RelayModeEmbeddings + RelayModeModerations + RelayModeImagesGenerations + RelayModeEdits + RelayModeAudioSpeech + RelayModeAudioTranscription + RelayModeAudioTranslation +) + +func Path2RelayMode(path string) int { + relayMode := RelayModeUnknown + if strings.HasPrefix(path, "/v1/chat/completions") { + relayMode = RelayModeChatCompletions + } else if strings.HasPrefix(path, "/v1/completions") { + relayMode = RelayModeCompletions + } else if strings.HasPrefix(path, "/v1/embeddings") { + relayMode = RelayModeEmbeddings + } else if strings.HasSuffix(path, "embeddings") { + relayMode = RelayModeEmbeddings + } else if strings.HasPrefix(path, "/v1/moderations") { + relayMode = RelayModeModerations + } else if strings.HasPrefix(path, "/v1/images/generations") { + relayMode = RelayModeImagesGenerations + } else if strings.HasPrefix(path, "/v1/edits") { + relayMode = RelayModeEdits + } else if strings.HasPrefix(path, "/v1/audio/speech") { + relayMode = RelayModeAudioSpeech + } else if strings.HasPrefix(path, "/v1/audio/transcriptions") { + relayMode = RelayModeAudioTranscription + } else if strings.HasPrefix(path, "/v1/audio/translations") { + relayMode = RelayModeAudioTranslation + } + return relayMode +} diff --git a/controller/relay-audio.go b/relay/controller/audio.go similarity index 63% rename from controller/relay-audio.go rename to relay/controller/audio.go index 2247f4c7..822d7e39 100644 --- a/controller/relay-audio.go +++ b/relay/controller/audio.go @@ -11,11 +11,16 @@ import ( "io" "net/http" "one-api/common" + "one-api/common/config" + "one-api/common/logger" "one-api/model" + "one-api/relay/channel/openai" + "one-api/relay/constant" + "one-api/relay/util" "strings" ) -func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { +func RelayAudioHelper(c *gin.Context, relayMode int) *openai.ErrorWithStatusCode { audioModel := "whisper-1" tokenId := c.GetInt("token_id") @@ -25,18 +30,18 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode group := c.GetString("group") tokenName := c.GetString("token_name") - var ttsRequest TextToSpeechRequest - if relayMode == RelayModeAudioSpeech { + var ttsRequest openai.TextToSpeechRequest + if relayMode == constant.RelayModeAudioSpeech { // Read JSON err := common.UnmarshalBodyReusable(c, &ttsRequest) // Check if JSON is valid if err != nil { - return errorWrapper(err, "invalid_json", http.StatusBadRequest) + return openai.ErrorWrapper(err, "invalid_json", http.StatusBadRequest) } audioModel = ttsRequest.Model // Check if text is too long 4096 if len(ttsRequest.Input) > 4096 { - return errorWrapper(errors.New("input is too long (over 4096 characters)"), "text_too_long", http.StatusBadRequest) + return openai.ErrorWrapper(errors.New("input is too long (over 4096 characters)"), "text_too_long", http.StatusBadRequest) } } @@ -46,24 +51,24 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode var quota int var preConsumedQuota int switch relayMode { - case RelayModeAudioSpeech: + case constant.RelayModeAudioSpeech: preConsumedQuota = int(float64(len(ttsRequest.Input)) * ratio) quota = preConsumedQuota default: - preConsumedQuota = int(float64(common.PreConsumedQuota) * ratio) + preConsumedQuota = int(float64(config.PreConsumedQuota) * ratio) } userQuota, err := model.CacheGetUserQuota(userId) if err != nil { - return errorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError) + return openai.ErrorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError) } // Check if user quota is enough if userQuota-preConsumedQuota < 0 { - return errorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden) + return openai.ErrorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden) } err = model.CacheDecreaseUserQuota(userId, preConsumedQuota) if err != nil { - return errorWrapper(err, "decrease_user_quota_failed", http.StatusInternalServerError) + return openai.ErrorWrapper(err, "decrease_user_quota_failed", http.StatusInternalServerError) } if userQuota > 100*preConsumedQuota { // in this case, we do not pre-consume quota @@ -73,7 +78,7 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode if preConsumedQuota > 0 { err := model.PreConsumeTokenQuota(tokenId, preConsumedQuota) if err != nil { - return errorWrapper(err, "pre_consume_token_quota_failed", http.StatusForbidden) + return openai.ErrorWrapper(err, "pre_consume_token_quota_failed", http.StatusForbidden) } } @@ -83,7 +88,7 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode modelMap := make(map[string]string) err := json.Unmarshal([]byte(modelMapping), &modelMap) if err != nil { - return errorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError) + return openai.ErrorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError) } if modelMap[audioModel] != "" { audioModel = modelMap[audioModel] @@ -96,27 +101,27 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode baseURL = c.GetString("base_url") } - fullRequestURL := getFullRequestURL(baseURL, requestURL, channelType) - if relayMode == RelayModeAudioTranscription && channelType == common.ChannelTypeAzure { + fullRequestURL := util.GetFullRequestURL(baseURL, requestURL, channelType) + if relayMode == constant.RelayModeAudioTranscription && channelType == common.ChannelTypeAzure { // https://learn.microsoft.com/en-us/azure/ai-services/openai/whisper-quickstart?tabs=command-line#rest-api - apiVersion := GetAPIVersion(c) + apiVersion := util.GetAzureAPIVersion(c) fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/audio/transcriptions?api-version=%s", baseURL, audioModel, apiVersion) } requestBody := &bytes.Buffer{} _, err = io.Copy(requestBody, c.Request.Body) if err != nil { - return errorWrapper(err, "new_request_body_failed", http.StatusInternalServerError) + return openai.ErrorWrapper(err, "new_request_body_failed", http.StatusInternalServerError) } c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody.Bytes())) responseFormat := c.DefaultPostForm("response_format", "json") req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody) if err != nil { - return errorWrapper(err, "new_request_failed", http.StatusInternalServerError) + return openai.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError) } - if relayMode == RelayModeAudioTranscription && channelType == common.ChannelTypeAzure { + if relayMode == constant.RelayModeAudioTranscription && channelType == common.ChannelTypeAzure { // https://learn.microsoft.com/en-us/azure/ai-services/openai/whisper-quickstart?tabs=command-line#rest-api apiKey := c.Request.Header.Get("Authorization") apiKey = strings.TrimPrefix(apiKey, "Bearer ") @@ -128,34 +133,34 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type")) req.Header.Set("Accept", c.Request.Header.Get("Accept")) - resp, err := httpClient.Do(req) + resp, err := util.HTTPClient.Do(req) if err != nil { - return errorWrapper(err, "do_request_failed", http.StatusInternalServerError) + return openai.ErrorWrapper(err, "do_request_failed", http.StatusInternalServerError) } err = req.Body.Close() if err != nil { - return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError) + return openai.ErrorWrapper(err, "close_request_body_failed", http.StatusInternalServerError) } err = c.Request.Body.Close() if err != nil { - return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError) + return openai.ErrorWrapper(err, "close_request_body_failed", http.StatusInternalServerError) } - if relayMode != RelayModeAudioSpeech { + if relayMode != constant.RelayModeAudioSpeech { responseBody, err := io.ReadAll(resp.Body) if err != nil { - return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError) + return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError) } err = resp.Body.Close() if err != nil { - return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError) + return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError) } - var openAIErr TextResponse + var openAIErr openai.SlimTextResponse if err = json.Unmarshal(responseBody, &openAIErr); err == nil { if openAIErr.Error.Message != "" { - return errorWrapper(fmt.Errorf("type %s, code %v, message %s", openAIErr.Error.Type, openAIErr.Error.Code, openAIErr.Error.Message), "request_error", http.StatusInternalServerError) + return openai.ErrorWrapper(fmt.Errorf("type %s, code %v, message %s", openAIErr.Error.Type, openAIErr.Error.Code, openAIErr.Error.Message), "request_error", http.StatusInternalServerError) } } @@ -172,12 +177,12 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode case "vtt": text, err = getTextFromVTT(responseBody) default: - return errorWrapper(errors.New("unexpected_response_format"), "unexpected_response_format", http.StatusInternalServerError) + return openai.ErrorWrapper(errors.New("unexpected_response_format"), "unexpected_response_format", http.StatusInternalServerError) } if err != nil { - return errorWrapper(err, "get_text_from_body_err", http.StatusInternalServerError) + return openai.ErrorWrapper(err, "get_text_from_body_err", http.StatusInternalServerError) } - quota = countTokenText(text, audioModel) + quota = openai.CountTokenText(text, audioModel) resp.Body = io.NopCloser(bytes.NewBuffer(responseBody)) } if resp.StatusCode != http.StatusOK { @@ -188,16 +193,16 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode // negative means add quota back for token & user err := model.PostConsumeTokenQuota(tokenId, -preConsumedQuota) if err != nil { - common.LogError(ctx, fmt.Sprintf("error rollback pre-consumed quota: %s", err.Error())) + logger.Error(ctx, fmt.Sprintf("error rollback pre-consumed quota: %s", err.Error())) } }() }(c.Request.Context()) } - return relayErrorHandler(resp) + return util.RelayErrorHandler(resp) } quotaDelta := quota - preConsumedQuota defer func(ctx context.Context) { - go postConsumeQuota(ctx, tokenId, quotaDelta, quota, userId, channelId, modelRatio, groupRatio, audioModel, tokenName) + go util.PostConsumeQuota(ctx, tokenId, quotaDelta, quota, userId, channelId, modelRatio, groupRatio, audioModel, tokenName) }(c.Request.Context()) for k, v := range resp.Header { @@ -207,11 +212,11 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode _, err = io.Copy(c.Writer, resp.Body) if err != nil { - return errorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError) + return openai.ErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError) } err = resp.Body.Close() if err != nil { - return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError) + return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError) } return nil } @@ -221,7 +226,7 @@ func getTextFromVTT(body []byte) (string, error) { } func getTextFromVerboseJSON(body []byte) (string, error) { - var whisperResponse WhisperVerboseJSONResponse + var whisperResponse openai.WhisperVerboseJSONResponse if err := json.Unmarshal(body, &whisperResponse); err != nil { return "", fmt.Errorf("unmarshal_response_body_failed err :%w", err) } @@ -254,7 +259,7 @@ func getTextFromText(body []byte) (string, error) { } func getTextFromJSON(body []byte) (string, error) { - var whisperResponse WhisperJSONResponse + var whisperResponse openai.WhisperJSONResponse if err := json.Unmarshal(body, &whisperResponse); err != nil { return "", fmt.Errorf("unmarshal_response_body_failed err :%w", err) } diff --git a/controller/relay-image.go b/relay/controller/image.go similarity index 68% rename from controller/relay-image.go rename to relay/controller/image.go index 14a2983b..9502a4d7 100644 --- a/controller/relay-image.go +++ b/relay/controller/image.go @@ -9,7 +9,10 @@ import ( "io" "net/http" "one-api/common" + "one-api/common/logger" "one-api/model" + "one-api/relay/channel/openai" + "one-api/relay/util" "strings" "github.com/gin-gonic/gin" @@ -25,7 +28,7 @@ func isWithinRange(element string, value int) bool { return value >= min && value <= max } -func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { +func RelayImageHelper(c *gin.Context, relayMode int) *openai.ErrorWithStatusCode { imageModel := "dall-e-2" imageSize := "1024x1024" @@ -35,10 +38,10 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode userId := c.GetInt("id") group := c.GetString("group") - var imageRequest ImageRequest + var imageRequest openai.ImageRequest err := common.UnmarshalBodyReusable(c, &imageRequest) if err != nil { - return errorWrapper(err, "bind_request_body_failed", http.StatusBadRequest) + return openai.ErrorWrapper(err, "bind_request_body_failed", http.StatusBadRequest) } if imageRequest.N == 0 { @@ -67,24 +70,24 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode } } } else { - return errorWrapper(errors.New("size not supported for this image model"), "size_not_supported", http.StatusBadRequest) + return openai.ErrorWrapper(errors.New("size not supported for this image model"), "size_not_supported", http.StatusBadRequest) } // Prompt validation if imageRequest.Prompt == "" { - return errorWrapper(errors.New("prompt is required"), "prompt_missing", http.StatusBadRequest) + return openai.ErrorWrapper(errors.New("prompt is required"), "prompt_missing", http.StatusBadRequest) } // Check prompt length if len(imageRequest.Prompt) > common.DalleImagePromptLengthLimitations[imageModel] { - return errorWrapper(errors.New("prompt is too long"), "prompt_too_long", http.StatusBadRequest) + return openai.ErrorWrapper(errors.New("prompt is too long"), "prompt_too_long", http.StatusBadRequest) } // Number of generated images validation if isWithinRange(imageModel, imageRequest.N) == false { // channel not azure if channelType != common.ChannelTypeAzure { - return errorWrapper(errors.New("invalid value of n"), "n_not_within_range", http.StatusBadRequest) + return openai.ErrorWrapper(errors.New("invalid value of n"), "n_not_within_range", http.StatusBadRequest) } } @@ -95,7 +98,7 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode modelMap := make(map[string]string) err := json.Unmarshal([]byte(modelMapping), &modelMap) if err != nil { - return errorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError) + return openai.ErrorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError) } if modelMap[imageModel] != "" { imageModel = modelMap[imageModel] @@ -107,10 +110,10 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode if c.GetString("base_url") != "" { baseURL = c.GetString("base_url") } - fullRequestURL := getFullRequestURL(baseURL, requestURL, channelType) + fullRequestURL := util.GetFullRequestURL(baseURL, requestURL, channelType) if channelType == common.ChannelTypeAzure { // https://learn.microsoft.com/en-us/azure/ai-services/openai/dall-e-quickstart?tabs=dalle3%2Ccommand-line&pivots=rest-api - apiVersion := GetAPIVersion(c) + apiVersion := util.GetAzureAPIVersion(c) // https://{resource_name}.openai.azure.com/openai/deployments/dall-e-3/images/generations?api-version=2023-06-01-preview fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/images/generations?api-version=%s", baseURL, imageModel, apiVersion) } @@ -119,7 +122,7 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode if isModelMapped || channelType == common.ChannelTypeAzure { // make Azure channel request body jsonStr, err := json.Marshal(imageRequest) if err != nil { - return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError) + return openai.ErrorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError) } requestBody = bytes.NewBuffer(jsonStr) } else { @@ -134,12 +137,12 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode quota := int(ratio*imageCostRatio*1000) * imageRequest.N if userQuota-quota < 0 { - return errorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden) + return openai.ErrorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden) } req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody) if err != nil { - return errorWrapper(err, "new_request_failed", http.StatusInternalServerError) + return openai.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError) } token := c.Request.Header.Get("Authorization") if channelType == common.ChannelTypeAzure { // Azure authentication @@ -152,20 +155,20 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type")) req.Header.Set("Accept", c.Request.Header.Get("Accept")) - resp, err := httpClient.Do(req) + resp, err := util.HTTPClient.Do(req) if err != nil { - return errorWrapper(err, "do_request_failed", http.StatusInternalServerError) + return openai.ErrorWrapper(err, "do_request_failed", http.StatusInternalServerError) } err = req.Body.Close() if err != nil { - return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError) + return openai.ErrorWrapper(err, "close_request_body_failed", http.StatusInternalServerError) } err = c.Request.Body.Close() if err != nil { - return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError) + return openai.ErrorWrapper(err, "close_request_body_failed", http.StatusInternalServerError) } - var textResponse ImageResponse + var textResponse openai.ImageResponse defer func(ctx context.Context) { if resp.StatusCode != http.StatusOK { @@ -173,11 +176,11 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode } err := model.PostConsumeTokenQuota(tokenId, quota) if err != nil { - common.SysError("error consuming token remain quota: " + err.Error()) + logger.SysError("error consuming token remain quota: " + err.Error()) } err = model.CacheUpdateUserQuota(userId) if err != nil { - common.SysError("error update user quota cache: " + err.Error()) + logger.SysError("error update user quota cache: " + err.Error()) } if quota != 0 { tokenName := c.GetString("token_name") @@ -192,15 +195,15 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode responseBody, err := io.ReadAll(resp.Body) if err != nil { - return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError) + return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError) } err = resp.Body.Close() if err != nil { - return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError) + return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError) } err = json.Unmarshal(responseBody, &textResponse) if err != nil { - return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError) + return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError) } resp.Body = io.NopCloser(bytes.NewBuffer(responseBody)) @@ -212,11 +215,11 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode _, err = io.Copy(c.Writer, resp.Body) if err != nil { - return errorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError) + return openai.ErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError) } err = resp.Body.Close() if err != nil { - return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError) + return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError) } return nil } diff --git a/relay/controller/text.go b/relay/controller/text.go new file mode 100644 index 00000000..68354628 --- /dev/null +++ b/relay/controller/text.go @@ -0,0 +1,173 @@ +package controller + +import ( + "context" + "errors" + "fmt" + "github.com/gin-gonic/gin" + "math" + "net/http" + "one-api/common" + "one-api/common/config" + "one-api/common/logger" + "one-api/model" + "one-api/relay/channel/openai" + "one-api/relay/constant" + "one-api/relay/util" + "strings" +) + +func RelayTextHelper(c *gin.Context, relayMode int) *openai.ErrorWithStatusCode { + ctx := c.Request.Context() + meta := util.GetRelayMeta(c) + var textRequest openai.GeneralOpenAIRequest + err := common.UnmarshalBodyReusable(c, &textRequest) + if err != nil { + return openai.ErrorWrapper(err, "bind_request_body_failed", http.StatusBadRequest) + } + if relayMode == constant.RelayModeModerations && textRequest.Model == "" { + textRequest.Model = "text-moderation-latest" + } + if relayMode == constant.RelayModeEmbeddings && textRequest.Model == "" { + textRequest.Model = c.Param("model") + } + err = util.ValidateTextRequest(&textRequest, relayMode) + if err != nil { + return openai.ErrorWrapper(err, "invalid_text_request", http.StatusBadRequest) + } + var isModelMapped bool + textRequest.Model, isModelMapped = util.GetMappedModelName(textRequest.Model, meta.ModelMapping) + apiType := constant.ChannelType2APIType(meta.ChannelType) + fullRequestURL, err := GetRequestURL(c.Request.URL.String(), apiType, relayMode, meta, &textRequest) + if err != nil { + logger.Error(ctx, fmt.Sprintf("util.GetRequestURL failed: %s", err.Error())) + return openai.ErrorWrapper(fmt.Errorf("util.GetRequestURL failed"), "get_request_url_failed", http.StatusInternalServerError) + } + var promptTokens int + var completionTokens int + switch relayMode { + case constant.RelayModeChatCompletions: + promptTokens = openai.CountTokenMessages(textRequest.Messages, textRequest.Model) + case constant.RelayModeCompletions: + promptTokens = openai.CountTokenInput(textRequest.Prompt, textRequest.Model) + case constant.RelayModeModerations: + promptTokens = openai.CountTokenInput(textRequest.Input, textRequest.Model) + } + preConsumedTokens := config.PreConsumedQuota + if textRequest.MaxTokens != 0 { + preConsumedTokens = promptTokens + textRequest.MaxTokens + } + modelRatio := common.GetModelRatio(textRequest.Model) + groupRatio := common.GetGroupRatio(meta.Group) + ratio := modelRatio * groupRatio + preConsumedQuota := int(float64(preConsumedTokens) * ratio) + userQuota, err := model.CacheGetUserQuota(meta.UserId) + if err != nil { + return openai.ErrorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError) + } + if userQuota-preConsumedQuota < 0 { + return openai.ErrorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden) + } + err = model.CacheDecreaseUserQuota(meta.UserId, preConsumedQuota) + if err != nil { + return openai.ErrorWrapper(err, "decrease_user_quota_failed", http.StatusInternalServerError) + } + if userQuota > 100*preConsumedQuota { + // in this case, we do not pre-consume quota + // because the user has enough quota + preConsumedQuota = 0 + logger.Info(c.Request.Context(), fmt.Sprintf("user %d has enough quota %d, trusted and no need to pre-consume", meta.UserId, userQuota)) + } + if preConsumedQuota > 0 { + err := model.PreConsumeTokenQuota(meta.TokenId, preConsumedQuota) + if err != nil { + return openai.ErrorWrapper(err, "pre_consume_token_quota_failed", http.StatusForbidden) + } + } + requestBody, err := GetRequestBody(c, textRequest, isModelMapped, apiType, relayMode) + if err != nil { + return openai.ErrorWrapper(err, "get_request_body_failed", http.StatusInternalServerError) + } + var req *http.Request + var resp *http.Response + isStream := textRequest.Stream + + if apiType != constant.APITypeXunfei { // cause xunfei use websocket + req, err = http.NewRequest(c.Request.Method, fullRequestURL, requestBody) + if err != nil { + return openai.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError) + } + SetupRequestHeaders(c, req, apiType, meta, isStream) + resp, err = util.HTTPClient.Do(req) + if err != nil { + return openai.ErrorWrapper(err, "do_request_failed", http.StatusInternalServerError) + } + err = req.Body.Close() + if err != nil { + return openai.ErrorWrapper(err, "close_request_body_failed", http.StatusInternalServerError) + } + err = c.Request.Body.Close() + if err != nil { + return openai.ErrorWrapper(err, "close_request_body_failed", http.StatusInternalServerError) + } + isStream = isStream || strings.HasPrefix(resp.Header.Get("Content-Type"), "text/event-stream") + + if resp.StatusCode != http.StatusOK { + util.ReturnPreConsumedQuota(ctx, preConsumedQuota, meta.TokenId) + return util.RelayErrorHandler(resp) + } + } + + var respErr *openai.ErrorWithStatusCode + var usage *openai.Usage + + defer func(ctx context.Context) { + // Why we use defer here? Because if error happened, we will have to return the pre-consumed quota. + if respErr != nil { + logger.Errorf(ctx, "respErr is not nil: %+v", respErr) + util.ReturnPreConsumedQuota(ctx, preConsumedQuota, meta.TokenId) + return + } + if usage == nil { + logger.Error(ctx, "usage is nil, which is unexpected") + return + } + + go func() { + quota := 0 + completionRatio := common.GetCompletionRatio(textRequest.Model) + promptTokens = usage.PromptTokens + completionTokens = usage.CompletionTokens + quota = int(math.Ceil((float64(promptTokens) + float64(completionTokens)*completionRatio) * ratio)) + if ratio != 0 && quota <= 0 { + quota = 1 + } + totalTokens := promptTokens + completionTokens + if totalTokens == 0 { + // in this case, must be some error happened + // we cannot just return, because we may have to return the pre-consumed quota + quota = 0 + } + quotaDelta := quota - preConsumedQuota + err := model.PostConsumeTokenQuota(meta.TokenId, quotaDelta) + if err != nil { + logger.Error(ctx, "error consuming token remain quota: "+err.Error()) + } + err = model.CacheUpdateUserQuota(meta.UserId) + if err != nil { + logger.Error(ctx, "error update user quota cache: "+err.Error()) + } + if quota != 0 { + logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio) + model.RecordConsumeLog(ctx, meta.UserId, meta.ChannelId, promptTokens, completionTokens, textRequest.Model, meta.TokenName, quota, logContent) + model.UpdateUserUsedQuotaAndRequestCount(meta.UserId, quota) + model.UpdateChannelUsedQuota(meta.ChannelId, quota) + } + }() + }(ctx) + usage, respErr = DoResponse(c, &textRequest, resp, relayMode, apiType, isStream, promptTokens) + if respErr != nil { + return respErr + } + return nil +} diff --git a/relay/controller/util.go b/relay/controller/util.go new file mode 100644 index 00000000..02f1b30f --- /dev/null +++ b/relay/controller/util.go @@ -0,0 +1,337 @@ +package controller + +import ( + "bytes" + "encoding/json" + "errors" + "fmt" + "github.com/gin-gonic/gin" + "io" + "net/http" + "one-api/common" + "one-api/common/helper" + "one-api/relay/channel/aiproxy" + "one-api/relay/channel/ali" + "one-api/relay/channel/anthropic" + "one-api/relay/channel/baidu" + "one-api/relay/channel/google" + "one-api/relay/channel/openai" + "one-api/relay/channel/tencent" + "one-api/relay/channel/xunfei" + "one-api/relay/channel/zhipu" + "one-api/relay/constant" + "one-api/relay/util" + "strings" +) + +func GetRequestURL(requestURL string, apiType int, relayMode int, meta *util.RelayMeta, textRequest *openai.GeneralOpenAIRequest) (string, error) { + fullRequestURL := util.GetFullRequestURL(meta.BaseURL, requestURL, meta.ChannelType) + switch apiType { + case constant.APITypeOpenAI: + if meta.ChannelType == common.ChannelTypeAzure { + // https://learn.microsoft.com/en-us/azure/cognitive-services/openai/chatgpt-quickstart?pivots=rest-api&tabs=command-line#rest-api + requestURL := strings.Split(requestURL, "?")[0] + requestURL = fmt.Sprintf("%s?api-version=%s", requestURL, meta.APIVersion) + task := strings.TrimPrefix(requestURL, "/v1/") + model_ := textRequest.Model + model_ = strings.Replace(model_, ".", "", -1) + // https://github.com/songquanpeng/one-api/issues/67 + model_ = strings.TrimSuffix(model_, "-0301") + model_ = strings.TrimSuffix(model_, "-0314") + model_ = strings.TrimSuffix(model_, "-0613") + + requestURL = fmt.Sprintf("/openai/deployments/%s/%s", model_, task) + fullRequestURL = util.GetFullRequestURL(meta.BaseURL, requestURL, meta.ChannelType) + } + case constant.APITypeClaude: + fullRequestURL = fmt.Sprintf("%s/v1/complete", meta.BaseURL) + case constant.APITypeBaidu: + switch textRequest.Model { + case "ERNIE-Bot": + fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions" + case "ERNIE-Bot-turbo": + fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant" + case "ERNIE-Bot-4": + fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro" + case "BLOOMZ-7B": + fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/bloomz_7b1" + case "Embedding-V1": + fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings/embedding-v1" + } + var accessToken string + var err error + if accessToken, err = baidu.GetAccessToken(meta.APIKey); err != nil { + return "", fmt.Errorf("failed to get baidu access token: %w", err) + } + fullRequestURL += "?access_token=" + accessToken + case constant.APITypePaLM: + fullRequestURL = fmt.Sprintf("%s/v1beta2/models/chat-bison-001:generateMessage", meta.BaseURL) + case constant.APITypeGemini: + version := helper.AssignOrDefault(meta.APIVersion, "v1") + action := "generateContent" + if textRequest.Stream { + action = "streamGenerateContent" + } + fullRequestURL = fmt.Sprintf("%s/%s/models/%s:%s", meta.BaseURL, version, textRequest.Model, action) + case constant.APITypeZhipu: + method := "invoke" + if textRequest.Stream { + method = "sse-invoke" + } + fullRequestURL = fmt.Sprintf("https://open.bigmodel.cn/api/paas/v3/model-api/%s/%s", textRequest.Model, method) + case constant.APITypeAli: + fullRequestURL = "https://dashscope.aliyuncs.com/api/v1/services/aigc/text-generation/generation" + if relayMode == constant.RelayModeEmbeddings { + fullRequestURL = "https://dashscope.aliyuncs.com/api/v1/services/embeddings/text-embedding/text-embedding" + } + case constant.APITypeTencent: + fullRequestURL = "https://hunyuan.cloud.tencent.com/hyllm/v1/chat/completions" + case constant.APITypeAIProxyLibrary: + fullRequestURL = fmt.Sprintf("%s/api/library/ask", meta.BaseURL) + } + return fullRequestURL, nil +} + +func GetRequestBody(c *gin.Context, textRequest openai.GeneralOpenAIRequest, isModelMapped bool, apiType int, relayMode int) (io.Reader, error) { + var requestBody io.Reader + if isModelMapped { + jsonStr, err := json.Marshal(textRequest) + if err != nil { + return nil, err + } + requestBody = bytes.NewBuffer(jsonStr) + } else { + requestBody = c.Request.Body + } + switch apiType { + case constant.APITypeClaude: + claudeRequest := anthropic.ConvertRequest(textRequest) + jsonStr, err := json.Marshal(claudeRequest) + if err != nil { + return nil, err + } + requestBody = bytes.NewBuffer(jsonStr) + case constant.APITypeBaidu: + var jsonData []byte + var err error + switch relayMode { + case constant.RelayModeEmbeddings: + baiduEmbeddingRequest := baidu.ConvertEmbeddingRequest(textRequest) + jsonData, err = json.Marshal(baiduEmbeddingRequest) + default: + baiduRequest := baidu.ConvertRequest(textRequest) + jsonData, err = json.Marshal(baiduRequest) + } + if err != nil { + return nil, err + } + requestBody = bytes.NewBuffer(jsonData) + case constant.APITypePaLM: + palmRequest := google.ConvertPaLMRequest(textRequest) + jsonStr, err := json.Marshal(palmRequest) + if err != nil { + return nil, err + } + requestBody = bytes.NewBuffer(jsonStr) + case constant.APITypeGemini: + geminiChatRequest := google.ConvertGeminiRequest(textRequest) + jsonStr, err := json.Marshal(geminiChatRequest) + if err != nil { + return nil, err + } + requestBody = bytes.NewBuffer(jsonStr) + case constant.APITypeZhipu: + zhipuRequest := zhipu.ConvertRequest(textRequest) + jsonStr, err := json.Marshal(zhipuRequest) + if err != nil { + return nil, err + } + requestBody = bytes.NewBuffer(jsonStr) + case constant.APITypeAli: + var jsonStr []byte + var err error + switch relayMode { + case constant.RelayModeEmbeddings: + aliEmbeddingRequest := ali.ConvertEmbeddingRequest(textRequest) + jsonStr, err = json.Marshal(aliEmbeddingRequest) + default: + aliRequest := ali.ConvertRequest(textRequest) + jsonStr, err = json.Marshal(aliRequest) + } + if err != nil { + return nil, err + } + requestBody = bytes.NewBuffer(jsonStr) + case constant.APITypeTencent: + apiKey := c.Request.Header.Get("Authorization") + apiKey = strings.TrimPrefix(apiKey, "Bearer ") + appId, secretId, secretKey, err := tencent.ParseConfig(apiKey) + if err != nil { + return nil, err + } + tencentRequest := tencent.ConvertRequest(textRequest) + tencentRequest.AppId = appId + tencentRequest.SecretId = secretId + jsonStr, err := json.Marshal(tencentRequest) + if err != nil { + return nil, err + } + sign := tencent.GetSign(*tencentRequest, secretKey) + c.Request.Header.Set("Authorization", sign) + requestBody = bytes.NewBuffer(jsonStr) + case constant.APITypeAIProxyLibrary: + aiProxyLibraryRequest := aiproxy.ConvertRequest(textRequest) + aiProxyLibraryRequest.LibraryId = c.GetString("library_id") + jsonStr, err := json.Marshal(aiProxyLibraryRequest) + if err != nil { + return nil, err + } + requestBody = bytes.NewBuffer(jsonStr) + } + return requestBody, nil +} + +func SetupRequestHeaders(c *gin.Context, req *http.Request, apiType int, meta *util.RelayMeta, isStream bool) { + SetupAuthHeaders(c, req, apiType, meta, isStream) + req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type")) + req.Header.Set("Accept", c.Request.Header.Get("Accept")) + if isStream && c.Request.Header.Get("Accept") == "" { + req.Header.Set("Accept", "text/event-stream") + } +} + +func SetupAuthHeaders(c *gin.Context, req *http.Request, apiType int, meta *util.RelayMeta, isStream bool) { + apiKey := meta.APIKey + switch apiType { + case constant.APITypeOpenAI: + if meta.ChannelType == common.ChannelTypeAzure { + req.Header.Set("api-key", apiKey) + } else { + req.Header.Set("Authorization", c.Request.Header.Get("Authorization")) + if meta.ChannelType == common.ChannelTypeOpenRouter { + req.Header.Set("HTTP-Referer", "https://github.com/songquanpeng/one-api") + req.Header.Set("X-Title", "One API") + } + } + case constant.APITypeClaude: + req.Header.Set("x-api-key", apiKey) + anthropicVersion := c.Request.Header.Get("anthropic-version") + if anthropicVersion == "" { + anthropicVersion = "2023-06-01" + } + req.Header.Set("anthropic-version", anthropicVersion) + case constant.APITypeZhipu: + token := zhipu.GetToken(apiKey) + req.Header.Set("Authorization", token) + case constant.APITypeAli: + req.Header.Set("Authorization", "Bearer "+apiKey) + if isStream { + req.Header.Set("X-DashScope-SSE", "enable") + } + if c.GetString("plugin") != "" { + req.Header.Set("X-DashScope-Plugin", c.GetString("plugin")) + } + case constant.APITypeTencent: + req.Header.Set("Authorization", apiKey) + case constant.APITypePaLM: + req.Header.Set("x-goog-api-key", apiKey) + case constant.APITypeGemini: + req.Header.Set("x-goog-api-key", apiKey) + default: + req.Header.Set("Authorization", "Bearer "+apiKey) + } +} + +func DoResponse(c *gin.Context, textRequest *openai.GeneralOpenAIRequest, resp *http.Response, relayMode int, apiType int, isStream bool, promptTokens int) (usage *openai.Usage, err *openai.ErrorWithStatusCode) { + var responseText string + switch apiType { + case constant.APITypeOpenAI: + if isStream { + err, responseText = openai.StreamHandler(c, resp, relayMode) + } else { + err, usage = openai.Handler(c, resp, promptTokens, textRequest.Model) + } + case constant.APITypeClaude: + if isStream { + err, responseText = anthropic.StreamHandler(c, resp) + } else { + err, usage = anthropic.Handler(c, resp, promptTokens, textRequest.Model) + } + case constant.APITypeBaidu: + if isStream { + err, usage = baidu.StreamHandler(c, resp) + } else { + switch relayMode { + case constant.RelayModeEmbeddings: + err, usage = baidu.EmbeddingHandler(c, resp) + default: + err, usage = baidu.Handler(c, resp) + } + } + case constant.APITypePaLM: + if isStream { // PaLM2 API does not support stream + err, responseText = google.PaLMStreamHandler(c, resp) + } else { + err, usage = google.PaLMHandler(c, resp, promptTokens, textRequest.Model) + } + case constant.APITypeGemini: + if isStream { + err, responseText = google.StreamHandler(c, resp) + } else { + err, usage = google.GeminiHandler(c, resp, promptTokens, textRequest.Model) + } + case constant.APITypeZhipu: + if isStream { + err, usage = zhipu.StreamHandler(c, resp) + } else { + err, usage = zhipu.Handler(c, resp) + } + case constant.APITypeAli: + if isStream { + err, usage = ali.StreamHandler(c, resp) + } else { + switch relayMode { + case constant.RelayModeEmbeddings: + err, usage = ali.EmbeddingHandler(c, resp) + default: + err, usage = ali.Handler(c, resp) + } + } + case constant.APITypeXunfei: + auth := c.Request.Header.Get("Authorization") + auth = strings.TrimPrefix(auth, "Bearer ") + splits := strings.Split(auth, "|") + if len(splits) != 3 { + return nil, openai.ErrorWrapper(errors.New("invalid auth"), "invalid_auth", http.StatusBadRequest) + } + if isStream { + err, usage = xunfei.StreamHandler(c, *textRequest, splits[0], splits[1], splits[2]) + } else { + err, usage = xunfei.Handler(c, *textRequest, splits[0], splits[1], splits[2]) + } + case constant.APITypeAIProxyLibrary: + if isStream { + err, usage = aiproxy.StreamHandler(c, resp) + } else { + err, usage = aiproxy.Handler(c, resp) + } + case constant.APITypeTencent: + if isStream { + err, responseText = tencent.StreamHandler(c, resp) + } else { + err, usage = tencent.Handler(c, resp) + } + default: + return nil, openai.ErrorWrapper(errors.New("unknown api type"), "unknown_api_type", http.StatusInternalServerError) + } + if err != nil { + return nil, err + } + if usage == nil && responseText != "" { + usage = &openai.Usage{} + usage.PromptTokens = promptTokens + usage.CompletionTokens = openai.CountTokenText(responseText, textRequest.Model) + usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens + } + return usage, nil +} diff --git a/relay/util/billing.go b/relay/util/billing.go new file mode 100644 index 00000000..35fb28a4 --- /dev/null +++ b/relay/util/billing.go @@ -0,0 +1,19 @@ +package util + +import ( + "context" + "one-api/common/logger" + "one-api/model" +) + +func ReturnPreConsumedQuota(ctx context.Context, preConsumedQuota int, tokenId int) { + if preConsumedQuota != 0 { + go func(ctx context.Context) { + // return pre-consumed quota + err := model.PostConsumeTokenQuota(tokenId, -preConsumedQuota) + if err != nil { + logger.Error(ctx, "error return pre-consumed quota: "+err.Error()) + } + }(ctx) + } +} diff --git a/relay/util/common.go b/relay/util/common.go new file mode 100644 index 00000000..be31857b --- /dev/null +++ b/relay/util/common.go @@ -0,0 +1,168 @@ +package util + +import ( + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "one-api/common" + "one-api/common/config" + "one-api/common/logger" + "one-api/model" + "one-api/relay/channel/openai" + "strconv" + "strings" + + "github.com/gin-gonic/gin" +) + +func ShouldDisableChannel(err *openai.Error, statusCode int) bool { + if !config.AutomaticDisableChannelEnabled { + return false + } + if err == nil { + return false + } + if statusCode == http.StatusUnauthorized { + return true + } + if err.Type == "insufficient_quota" || err.Code == "invalid_api_key" || err.Code == "account_deactivated" { + return true + } + return false +} + +func ShouldEnableChannel(err error, openAIErr *openai.Error) bool { + if !config.AutomaticEnableChannelEnabled { + return false + } + if err != nil { + return false + } + if openAIErr != nil { + return false + } + return true +} + +type GeneralErrorResponse struct { + Error openai.Error `json:"error"` + Message string `json:"message"` + Msg string `json:"msg"` + Err string `json:"err"` + ErrorMsg string `json:"error_msg"` + Header struct { + Message string `json:"message"` + } `json:"header"` + Response struct { + Error struct { + Message string `json:"message"` + } `json:"error"` + } `json:"response"` +} + +func (e GeneralErrorResponse) ToMessage() string { + if e.Error.Message != "" { + return e.Error.Message + } + if e.Message != "" { + return e.Message + } + if e.Msg != "" { + return e.Msg + } + if e.Err != "" { + return e.Err + } + if e.ErrorMsg != "" { + return e.ErrorMsg + } + if e.Header.Message != "" { + return e.Header.Message + } + if e.Response.Error.Message != "" { + return e.Response.Error.Message + } + return "" +} + +func RelayErrorHandler(resp *http.Response) (ErrorWithStatusCode *openai.ErrorWithStatusCode) { + ErrorWithStatusCode = &openai.ErrorWithStatusCode{ + StatusCode: resp.StatusCode, + Error: openai.Error{ + Message: "", + Type: "upstream_error", + Code: "bad_response_status_code", + Param: strconv.Itoa(resp.StatusCode), + }, + } + responseBody, err := io.ReadAll(resp.Body) + if err != nil { + return + } + err = resp.Body.Close() + if err != nil { + return + } + var errResponse GeneralErrorResponse + err = json.Unmarshal(responseBody, &errResponse) + if err != nil { + return + } + if errResponse.Error.Message != "" { + // OpenAI format error, so we override the default one + ErrorWithStatusCode.Error = errResponse.Error + } else { + ErrorWithStatusCode.Error.Message = errResponse.ToMessage() + } + if ErrorWithStatusCode.Error.Message == "" { + ErrorWithStatusCode.Error.Message = fmt.Sprintf("bad response status code %d", resp.StatusCode) + } + return +} + +func GetFullRequestURL(baseURL string, requestURL string, channelType int) string { + fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL) + + if strings.HasPrefix(baseURL, "https://gateway.ai.cloudflare.com") { + switch channelType { + case common.ChannelTypeOpenAI: + fullRequestURL = fmt.Sprintf("%s%s", baseURL, strings.TrimPrefix(requestURL, "/v1")) + case common.ChannelTypeAzure: + fullRequestURL = fmt.Sprintf("%s%s", baseURL, strings.TrimPrefix(requestURL, "/openai/deployments")) + } + } + return fullRequestURL +} + +func PostConsumeQuota(ctx context.Context, tokenId int, quotaDelta int, totalQuota int, userId int, channelId int, modelRatio float64, groupRatio float64, modelName string, tokenName string) { + // quotaDelta is remaining quota to be consumed + err := model.PostConsumeTokenQuota(tokenId, quotaDelta) + if err != nil { + logger.SysError("error consuming token remain quota: " + err.Error()) + } + err = model.CacheUpdateUserQuota(userId) + if err != nil { + logger.SysError("error update user quota cache: " + err.Error()) + } + // totalQuota is total quota consumed + if totalQuota != 0 { + logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio) + model.RecordConsumeLog(ctx, userId, channelId, totalQuota, 0, modelName, tokenName, totalQuota, logContent) + model.UpdateUserUsedQuotaAndRequestCount(userId, totalQuota) + model.UpdateChannelUsedQuota(channelId, totalQuota) + } + if totalQuota <= 0 { + logger.Error(ctx, fmt.Sprintf("totalQuota consumed is %d, something is wrong", totalQuota)) + } +} + +func GetAzureAPIVersion(c *gin.Context) string { + query := c.Request.URL.Query() + apiVersion := query.Get("api-version") + if apiVersion == "" { + apiVersion = c.GetString("api_version") + } + return apiVersion +} diff --git a/relay/util/init.go b/relay/util/init.go new file mode 100644 index 00000000..62d44d15 --- /dev/null +++ b/relay/util/init.go @@ -0,0 +1,24 @@ +package util + +import ( + "net/http" + "one-api/common/config" + "time" +) + +var HTTPClient *http.Client +var ImpatientHTTPClient *http.Client + +func init() { + if config.RelayTimeout == 0 { + HTTPClient = &http.Client{} + } else { + HTTPClient = &http.Client{ + Timeout: time.Duration(config.RelayTimeout) * time.Second, + } + } + + ImpatientHTTPClient = &http.Client{ + Timeout: 5 * time.Second, + } +} diff --git a/relay/util/model_mapping.go b/relay/util/model_mapping.go new file mode 100644 index 00000000..39e062a1 --- /dev/null +++ b/relay/util/model_mapping.go @@ -0,0 +1,12 @@ +package util + +func GetMappedModelName(modelName string, mapping map[string]string) (string, bool) { + if mapping == nil { + return modelName, false + } + mappedModelName := mapping[modelName] + if mappedModelName != "" { + return mappedModelName, true + } + return modelName, false +} diff --git a/relay/util/relay_meta.go b/relay/util/relay_meta.go new file mode 100644 index 00000000..19936e49 --- /dev/null +++ b/relay/util/relay_meta.go @@ -0,0 +1,44 @@ +package util + +import ( + "github.com/gin-gonic/gin" + "one-api/common" + "strings" +) + +type RelayMeta struct { + ChannelType int + ChannelId int + TokenId int + TokenName string + UserId int + Group string + ModelMapping map[string]string + BaseURL string + APIVersion string + APIKey string + Config map[string]string +} + +func GetRelayMeta(c *gin.Context) *RelayMeta { + meta := RelayMeta{ + ChannelType: c.GetInt("channel"), + ChannelId: c.GetInt("channel_id"), + TokenId: c.GetInt("token_id"), + TokenName: c.GetString("token_name"), + UserId: c.GetInt("id"), + Group: c.GetString("group"), + ModelMapping: c.GetStringMapString("model_mapping"), + BaseURL: c.GetString("base_url"), + APIVersion: c.GetString("api_version"), + APIKey: strings.TrimPrefix(c.Request.Header.Get("Authorization"), "Bearer "), + Config: nil, + } + if meta.ChannelType == common.ChannelTypeAzure { + meta.APIVersion = GetAzureAPIVersion(c) + } + if meta.BaseURL == "" { + meta.BaseURL = common.ChannelBaseURLs[meta.ChannelType] + } + return &meta +} diff --git a/relay/util/validation.go b/relay/util/validation.go new file mode 100644 index 00000000..48b42d94 --- /dev/null +++ b/relay/util/validation.go @@ -0,0 +1,37 @@ +package util + +import ( + "errors" + "math" + "one-api/relay/channel/openai" + "one-api/relay/constant" +) + +func ValidateTextRequest(textRequest *openai.GeneralOpenAIRequest, relayMode int) error { + if textRequest.MaxTokens < 0 || textRequest.MaxTokens > math.MaxInt32/2 { + return errors.New("max_tokens is invalid") + } + if textRequest.Model == "" { + return errors.New("model is required") + } + switch relayMode { + case constant.RelayModeCompletions: + if textRequest.Prompt == "" { + return errors.New("field prompt is required") + } + case constant.RelayModeChatCompletions: + if textRequest.Messages == nil || len(textRequest.Messages) == 0 { + return errors.New("field messages is required") + } + case constant.RelayModeEmbeddings: + case constant.RelayModeModerations: + if textRequest.Input == "" { + return errors.New("field input is required") + } + case constant.RelayModeEdits: + if textRequest.Instruction == "" { + return errors.New("field instruction is required") + } + } + return nil +} diff --git a/router/api-router.go b/router/api-router.go index da3f9e61..162675ce 100644 --- a/router/api-router.go +++ b/router/api-router.go @@ -35,6 +35,7 @@ func SetApiRouter(router *gin.Engine) { selfRoute := userRoute.Group("/") selfRoute.Use(middleware.UserAuth()) { + selfRoute.GET("/dashboard", controller.GetUserDashboard) selfRoute.GET("/self", controller.GetSelf) selfRoute.PUT("/self", controller.UpdateSelf) selfRoute.DELETE("/self", controller.DeleteSelf) diff --git a/router/main.go b/router/main.go index 85127a1a..6504b312 100644 --- a/router/main.go +++ b/router/main.go @@ -5,7 +5,8 @@ import ( "fmt" "github.com/gin-gonic/gin" "net/http" - "one-api/common" + "one-api/common/config" + "one-api/common/logger" "os" "strings" ) @@ -15,9 +16,9 @@ func SetRouter(router *gin.Engine, buildFS embed.FS) { SetDashboardRouter(router) SetRelayRouter(router) frontendBaseUrl := os.Getenv("FRONTEND_BASE_URL") - if common.IsMasterNode && frontendBaseUrl != "" { + if config.IsMasterNode && frontendBaseUrl != "" { frontendBaseUrl = "" - common.SysLog("FRONTEND_BASE_URL is ignored on master node") + logger.SysLog("FRONTEND_BASE_URL is ignored on master node") } if frontendBaseUrl == "" { SetWebRouter(router, buildFS) diff --git a/router/web-router.go b/router/web-router.go index 7328c7a3..95d8fcb9 100644 --- a/router/web-router.go +++ b/router/web-router.go @@ -8,17 +8,18 @@ import ( "github.com/gin-gonic/gin" "net/http" "one-api/common" + "one-api/common/config" "one-api/controller" "one-api/middleware" "strings" ) func SetWebRouter(router *gin.Engine, buildFS embed.FS) { - indexPageData, _ := buildFS.ReadFile(fmt.Sprintf("web/build/%s/index.html", common.Theme)) + indexPageData, _ := buildFS.ReadFile(fmt.Sprintf("web/build/%s/index.html", config.Theme)) router.Use(gzip.Gzip(gzip.DefaultCompression)) router.Use(middleware.GlobalWebRateLimit()) router.Use(middleware.Cache()) - router.Use(static.Serve("/", common.EmbedFolder(buildFS, fmt.Sprintf("web/build/%s", common.Theme)))) + router.Use(static.Serve("/", common.EmbedFolder(buildFS, fmt.Sprintf("web/build/%s", config.Theme)))) router.NoRoute(func(c *gin.Context) { if strings.HasPrefix(c.Request.RequestURI, "/v1") || strings.HasPrefix(c.Request.RequestURI, "/api") { controller.RelayNotFound(c) diff --git a/web/README.md b/web/README.md index ca73b298..86486085 100644 --- a/web/README.md +++ b/web/README.md @@ -1,17 +1,38 @@ # One API 的前端界面 + > 每个文件夹代表一个主题,欢迎提交你的主题 ## 提交新的主题 + > 欢迎在页面底部保留你和 One API 的版权信息以及指向链接 + 1. 在 `web` 文件夹下新建一个文件夹,文件夹名为主题名。 2. 把你的主题文件放到这个文件夹下。 -3. 修改 `package.json` 文件,把 `build` 命令改为:`"build": "react-scripts build && mv -f build ../build/default"`,其中 `default` 为你的主题名。 +3. 修改你的 `package.json` 文件,把 `build` 命令改为:`"build": "react-scripts build && mv -f build ../build/default"`,其中 `default` 为你的主题名。 +4. 修改 `common/constants.go` 中的 `ValidThemes`,把你的主题名称注册进去。 +5. 修改 `web/THEMES` 文件,这里也需要同步修改。 ## 主题列表 + ### 主题:default + 默认主题,由 [JustSong](https://github.com/songquanpeng) 开发。 预览: |![image](https://github.com/songquanpeng/one-api/assets/39998050/ccfbc668-3a7f-4bc1-87da-7eacfd7bf371)|![image](https://github.com/songquanpeng/one-api/assets/39998050/a63ed547-44b9-45db-b43a-ecea07d60840)| |:---:|:---:| +### 主题:berry + +由 [MartialBE](https://github.com/MartialBE) 开发。 + +预览: +||| +|:---:|:---:| +|![image](https://github.com/songquanpeng/one-api/assets/42402987/36aff5c6-c5ff-4a90-8e3d-33d5cff34cbf)|![image](https://github.com/songquanpeng/one-api/assets/42402987/9ac63b36-5140-4064-8fad-fc9d25821509)| +|![image](https://github.com/songquanpeng/one-api/assets/42402987/fb2b1c64-ef24-4027-9b80-0cd9d945a47f)|![image](https://github.com/songquanpeng/one-api/assets/42402987/b6b649ec-2888-4324-8b2d-d5e11554eed6)| +|![image](https://github.com/songquanpeng/one-api/assets/42402987/6d3b22e0-436b-4e26-8911-bcc993c6a2bd)|![image](https://github.com/songquanpeng/one-api/assets/42402987/eef1e224-7245-44d7-804e-9d1c8fa3f29c)| + +#### 开发说明 + +请查看 [web/berry/README.md](https://github.com/songquanpeng/one-api/tree/main/web/berry/README.md) diff --git a/web/THEMES b/web/THEMES index 331d858c..b6597eeb 100644 --- a/web/THEMES +++ b/web/THEMES @@ -1 +1,2 @@ -default \ No newline at end of file +default +berry \ No newline at end of file diff --git a/web/berry/.gitignore b/web/berry/.gitignore new file mode 100644 index 00000000..2b5bba76 --- /dev/null +++ b/web/berry/.gitignore @@ -0,0 +1,26 @@ +# See https://help.github.com/articles/ignoring-files/ for more about ignoring files. + +# dependencies +/node_modules +/.pnp +.pnp.js + +# testing +/coverage + +# production +/build + +# misc +.DS_Store +.env.local +.env.development.local +.env.test.local +.env.production.local + +npm-debug.log* +yarn-debug.log* +yarn-error.log* +.idea +package-lock.json +yarn.lock \ No newline at end of file diff --git a/web/berry/README.md b/web/berry/README.md new file mode 100644 index 00000000..170feedc --- /dev/null +++ b/web/berry/README.md @@ -0,0 +1,61 @@ +# One API 前端界面 + +这个项目是 One API 的前端界面,它基于 [Berry Free React Admin Template](https://github.com/codedthemes/berry-free-react-admin-template) 进行开发。 + +## 使用的开源项目 + +使用了以下开源项目作为我们项目的一部分: + +- [Berry Free React Admin Template](https://github.com/codedthemes/berry-free-react-admin-template) +- [minimal-ui-kit](minimal-ui-kit) + +## 开发说明 + +当添加新的渠道时,需要修改以下地方: + +1. `web/berry/src/constants/ChannelConstants.js` + +在该文件中的 `CHANNEL_OPTIONS` 添加新的渠道 + +```js +export const CHANNEL_OPTIONS = { + //key 为渠道ID + 1: { + key: 1, // 渠道ID + text: "OpenAI", // 渠道名称 + value: 1, // 渠道ID + color: "primary", // 渠道列表显示的颜色 + }, +}; +``` + +2. `web/berry/src/views/Channel/type/Config.js` + +在该文件中的`typeConfig`添加新的渠道配置, 如果无需配置,可以不添加 + +```js +const typeConfig = { + // key 为渠道ID + 3: { + inputLabel: { + // 输入框名称 配置 + // 对应的字段名称 + base_url: "AZURE_OPENAI_ENDPOINT", + other: "默认 API 版本", + }, + prompt: { + // 输入框提示 配置 + // 对应的字段名称 + base_url: "请填写AZURE_OPENAI_ENDPOINT", + + // 注意:通过判断 `other` 是否有值来判断是否需要显示 `other` 输入框, 默认是没有值的 + other: "请输入默认API版本,例如:2023-06-01-preview", + }, + modelGroup: "openai", // 模型组名称,这个值是给 填入渠道支持模型 按钮使用的。 填入渠道支持模型 按钮会根据这个值来获取模型组,如果填写默认是 openai + }, +}; +``` + +## 许可证 + +本项目中使用的代码遵循 MIT 许可证。 diff --git a/web/berry/jsconfig.json b/web/berry/jsconfig.json new file mode 100644 index 00000000..35332c70 --- /dev/null +++ b/web/berry/jsconfig.json @@ -0,0 +1,9 @@ +{ + "compilerOptions": { + "target": "esnext", + "module": "commonjs", + "baseUrl": "src" + }, + "include": ["src/**/*"], + "exclude": ["node_modules"] +} diff --git a/web/berry/package.json b/web/berry/package.json new file mode 100644 index 00000000..f428fd9c --- /dev/null +++ b/web/berry/package.json @@ -0,0 +1,84 @@ +{ + "name": "one_api_web", + "version": "1.0.0", + "proxy": "http://127.0.0.1:3000", + "private": true, + "homepage": "", + "dependencies": { + "@emotion/cache": "^11.9.3", + "@emotion/react": "^11.9.3", + "@emotion/styled": "^11.9.3", + "@mui/icons-material": "^5.8.4", + "@mui/lab": "^5.0.0-alpha.88", + "@mui/material": "^5.8.6", + "@mui/system": "^5.8.6", + "@mui/utils": "^5.8.6", + "@mui/x-date-pickers": "^6.18.5", + "@tabler/icons-react": "^2.44.0", + "apexcharts": "^3.35.3", + "axios": "^0.27.2", + "dayjs": "^1.11.10", + "formik": "^2.2.9", + "framer-motion": "^6.3.16", + "history": "^5.3.0", + "marked": "^4.1.1", + "material-ui-popup-state": "^4.0.1", + "notistack": "^3.0.1", + "prop-types": "^15.8.1", + "react": "^18.2.0", + "react-apexcharts": "^1.4.0", + "react-device-detect": "^2.2.2", + "react-dom": "^18.2.0", + "react-perfect-scrollbar": "^1.5.8", + "react-redux": "^8.0.2", + "react-router": "6.3.0", + "react-router-dom": "6.3.0", + "react-scripts": "^5.0.1", + "react-turnstile": "^1.1.2", + "redux": "^4.2.0", + "yup": "^0.32.11" + }, + "scripts": { + "start": "react-scripts start", + "build": "react-scripts build && mv -f build ../build/berry", + "test": "react-scripts test", + "eject": "react-scripts eject" + }, + "eslintConfig": { + "extends": [ + "react-app" + ] + }, + "babel": { + "presets": [ + "@babel/preset-react" + ] + }, + "browserslist": { + "production": [ + "defaults", + "not IE 11" + ], + "development": [ + "last 1 chrome version", + "last 1 firefox version", + "last 1 safari version" + ] + }, + "devDependencies": { + "@babel/core": "^7.21.4", + "@babel/eslint-parser": "^7.21.3", + "eslint": "^8.38.0", + "eslint-config-prettier": "^8.8.0", + "eslint-config-react-app": "^7.0.1", + "eslint-plugin-flowtype": "^8.0.3", + "eslint-plugin-import": "^2.27.5", + "eslint-plugin-jsx-a11y": "^6.7.1", + "eslint-plugin-prettier": "^4.2.1", + "eslint-plugin-react": "^7.32.2", + "eslint-plugin-react-hooks": "^4.6.0", + "immutable": "^4.3.0", + "prettier": "^2.8.7", + "sass": "^1.53.0" + } +} diff --git a/web/berry/public/favicon.ico b/web/berry/public/favicon.ico new file mode 100644 index 00000000..fbcfb14a Binary files /dev/null and b/web/berry/public/favicon.ico differ diff --git a/web/berry/public/index.html b/web/berry/public/index.html new file mode 100644 index 00000000..6f232250 --- /dev/null +++ b/web/berry/public/index.html @@ -0,0 +1,26 @@ + + + + One API + + + + + + + + + + + + +
+ + + diff --git a/web/berry/src/App.js b/web/berry/src/App.js new file mode 100644 index 00000000..fc54c632 --- /dev/null +++ b/web/berry/src/App.js @@ -0,0 +1,43 @@ +import { useSelector } from 'react-redux'; + +import { ThemeProvider } from '@mui/material/styles'; +import { CssBaseline, StyledEngineProvider } from '@mui/material'; + +// routing +import Routes from 'routes'; + +// defaultTheme +import themes from 'themes'; + +// project imports +import NavigationScroll from 'layout/NavigationScroll'; + +// auth +import UserProvider from 'contexts/UserContext'; +import StatusProvider from 'contexts/StatusContext'; +import { SnackbarProvider } from 'notistack'; + +// ==============================|| APP ||============================== // + +const App = () => { + const customization = useSelector((state) => state.customization); + + return ( + + + + + + + + + + + + + + + ); +}; + +export default App; diff --git a/web/berry/src/assets/images/404.svg b/web/berry/src/assets/images/404.svg new file mode 100644 index 00000000..352a14ad --- /dev/null +++ b/web/berry/src/assets/images/404.svg @@ -0,0 +1,40 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/web/berry/src/assets/images/auth/auth-blue-card.svg b/web/berry/src/assets/images/auth/auth-blue-card.svg new file mode 100644 index 00000000..6c9fe3e7 --- /dev/null +++ b/web/berry/src/assets/images/auth/auth-blue-card.svg @@ -0,0 +1,65 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/web/berry/src/assets/images/auth/auth-pattern-dark.svg b/web/berry/src/assets/images/auth/auth-pattern-dark.svg new file mode 100644 index 00000000..aa0e4ab2 --- /dev/null +++ b/web/berry/src/assets/images/auth/auth-pattern-dark.svg @@ -0,0 +1,39 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/web/berry/src/assets/images/auth/auth-pattern.svg b/web/berry/src/assets/images/auth/auth-pattern.svg new file mode 100644 index 00000000..b7ac8e27 --- /dev/null +++ b/web/berry/src/assets/images/auth/auth-pattern.svg @@ -0,0 +1,39 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/web/berry/src/assets/images/auth/auth-purple-card.svg b/web/berry/src/assets/images/auth/auth-purple-card.svg new file mode 100644 index 00000000..c724e0a3 --- /dev/null +++ b/web/berry/src/assets/images/auth/auth-purple-card.svg @@ -0,0 +1,69 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/web/berry/src/assets/images/auth/auth-signup-blue-card.svg b/web/berry/src/assets/images/auth/auth-signup-blue-card.svg new file mode 100644 index 00000000..ebb8e85f --- /dev/null +++ b/web/berry/src/assets/images/auth/auth-signup-blue-card.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/web/berry/src/assets/images/auth/auth-signup-white-card.svg b/web/berry/src/assets/images/auth/auth-signup-white-card.svg new file mode 100644 index 00000000..56b97e20 --- /dev/null +++ b/web/berry/src/assets/images/auth/auth-signup-white-card.svg @@ -0,0 +1,40 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/web/berry/src/assets/images/icons/earning.svg b/web/berry/src/assets/images/icons/earning.svg new file mode 100644 index 00000000..e877b599 --- /dev/null +++ b/web/berry/src/assets/images/icons/earning.svg @@ -0,0 +1,5 @@ + + + + + diff --git a/web/berry/src/assets/images/icons/github.svg b/web/berry/src/assets/images/icons/github.svg new file mode 100644 index 00000000..e5b1b82a --- /dev/null +++ b/web/berry/src/assets/images/icons/github.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/web/berry/src/assets/images/icons/shape-avatar.svg b/web/berry/src/assets/images/icons/shape-avatar.svg new file mode 100644 index 00000000..38aac7e2 --- /dev/null +++ b/web/berry/src/assets/images/icons/shape-avatar.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/web/berry/src/assets/images/icons/social-google.svg b/web/berry/src/assets/images/icons/social-google.svg new file mode 100644 index 00000000..2231ce98 --- /dev/null +++ b/web/berry/src/assets/images/icons/social-google.svg @@ -0,0 +1,6 @@ + + + + + + diff --git a/web/berry/src/assets/images/icons/wechat.svg b/web/berry/src/assets/images/icons/wechat.svg new file mode 100644 index 00000000..a0b2e36c --- /dev/null +++ b/web/berry/src/assets/images/icons/wechat.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/web/berry/src/assets/images/invite/cover.jpg b/web/berry/src/assets/images/invite/cover.jpg new file mode 100644 index 00000000..93be1a40 Binary files /dev/null and b/web/berry/src/assets/images/invite/cover.jpg differ diff --git a/web/berry/src/assets/images/invite/cwok_casual_19.webp b/web/berry/src/assets/images/invite/cwok_casual_19.webp new file mode 100644 index 00000000..1cf2c376 Binary files /dev/null and b/web/berry/src/assets/images/invite/cwok_casual_19.webp differ diff --git a/web/berry/src/assets/images/logo-2.svg b/web/berry/src/assets/images/logo-2.svg new file mode 100644 index 00000000..2e674a7e --- /dev/null +++ b/web/berry/src/assets/images/logo-2.svg @@ -0,0 +1,15 @@ + + + + Layer 1 + + + + + + + + + + + \ No newline at end of file diff --git a/web/berry/src/assets/images/logo.svg b/web/berry/src/assets/images/logo.svg new file mode 100644 index 00000000..348c7e5a --- /dev/null +++ b/web/berry/src/assets/images/logo.svg @@ -0,0 +1,13 @@ + + + + + + + + + + + + + \ No newline at end of file diff --git a/web/berry/src/assets/images/users/user-round.svg b/web/berry/src/assets/images/users/user-round.svg new file mode 100644 index 00000000..eaef7ed9 --- /dev/null +++ b/web/berry/src/assets/images/users/user-round.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/web/berry/src/assets/scss/_themes-vars.module.scss b/web/berry/src/assets/scss/_themes-vars.module.scss new file mode 100644 index 00000000..a470b033 --- /dev/null +++ b/web/berry/src/assets/scss/_themes-vars.module.scss @@ -0,0 +1,157 @@ +// paper & background +$paper: #ffffff; + +// primary +$primaryLight: #eef2f6; +$primaryMain: #2196f3; +$primaryDark: #1e88e5; +$primary200: #90caf9; +$primary800: #1565c0; + +// secondary +$secondaryLight: #ede7f6; +$secondaryMain: #673ab7; +$secondaryDark: #5e35b1; +$secondary200: #b39ddb; +$secondary800: #4527a0; + +// success Colors +$successLight: #b9f6ca; +$success200: #69f0ae; +$successMain: #00e676; +$successDark: #00c853; + +// error +$errorLight: #ef9a9a; +$errorMain: #f44336; +$errorDark: #c62828; + +// orange +$orangeLight: #fbe9e7; +$orangeMain: #ffab91; +$orangeDark: #d84315; + +// warning +$warningLight: #fff8e1; +$warningMain: #ffe57f; +$warningDark: #ffc107; + +// grey +$grey50: #f8fafc; +$grey100: #eef2f6; +$grey200: #e3e8ef; +$grey300: #cdd5df; +$grey500: #697586; +$grey600: #4b5565; +$grey700: #364152; +$grey900: #121926; + +// ==============================|| DARK THEME VARIANTS ||============================== // + +// paper & background +$darkBackground: #1a223f; // level 3 +$darkPaper: #111936; // level 4 + +// dark 800 & 900 +$darkLevel1: #29314f; // level 1 +$darkLevel2: #212946; // level 2 + +// primary dark +$darkPrimaryLight: #eef2f6; +$darkPrimaryMain: #2196f3; +$darkPrimaryDark: #1e88e5; +$darkPrimary200: #90caf9; +$darkPrimary800: #1565c0; + +// secondary dark +$darkSecondaryLight: #d1c4e9; +$darkSecondaryMain: #7c4dff; +$darkSecondaryDark: #651fff; +$darkSecondary200: #b39ddb; +$darkSecondary800: #6200ea; + +// text variants +$darkTextTitle: #d7dcec; +$darkTextPrimary: #bdc8f0; +$darkTextSecondary: #8492c4; + +// ==============================|| JAVASCRIPT ||============================== // + +:export { + // paper & background + paper: $paper; + + // primary + primaryLight: $primaryLight; + primary200: $primary200; + primaryMain: $primaryMain; + primaryDark: $primaryDark; + primary800: $primary800; + + // secondary + secondaryLight: $secondaryLight; + secondary200: $secondary200; + secondaryMain: $secondaryMain; + secondaryDark: $secondaryDark; + secondary800: $secondary800; + + // success + successLight: $successLight; + success200: $success200; + successMain: $successMain; + successDark: $successDark; + + // error + errorLight: $errorLight; + errorMain: $errorMain; + errorDark: $errorDark; + + // orange + orangeLight: $orangeLight; + orangeMain: $orangeMain; + orangeDark: $orangeDark; + + // warning + warningLight: $warningLight; + warningMain: $warningMain; + warningDark: $warningDark; + + // grey + grey50: $grey50; + grey100: $grey100; + grey200: $grey200; + grey300: $grey300; + grey500: $grey500; + grey600: $grey600; + grey700: $grey700; + grey900: $grey900; + + // ==============================|| DARK THEME VARIANTS ||============================== // + + // paper & background + darkPaper: $darkPaper; + darkBackground: $darkBackground; + + // dark 800 & 900 + darkLevel1: $darkLevel1; + darkLevel2: $darkLevel2; + + // text variants + darkTextTitle: $darkTextTitle; + darkTextPrimary: $darkTextPrimary; + darkTextSecondary: $darkTextSecondary; + + // primary dark + darkPrimaryLight: $darkPrimaryLight; + darkPrimaryMain: $darkPrimaryMain; + darkPrimaryDark: $darkPrimaryDark; + darkPrimary200: $darkPrimary200; + darkPrimary800: $darkPrimary800; + + // secondary dark + darkSecondaryLight: $darkSecondaryLight; + darkSecondaryMain: $darkSecondaryMain; + darkSecondaryDark: $darkSecondaryDark; + darkSecondary200: $darkSecondary200; + darkSecondary800: $darkSecondary800; +} diff --git a/web/berry/src/assets/scss/style.scss b/web/berry/src/assets/scss/style.scss new file mode 100644 index 00000000..17d566e6 --- /dev/null +++ b/web/berry/src/assets/scss/style.scss @@ -0,0 +1,128 @@ +// color variants +@import 'themes-vars.module.scss'; + +// third-party +@import '~react-perfect-scrollbar/dist/css/styles.css'; + +// ==============================|| LIGHT BOX ||============================== // +.fullscreen .react-images__blanket { + z-index: 1200; +} + +// ==============================|| APEXCHART ||============================== // + +.apexcharts-legend-series .apexcharts-legend-marker { + margin-right: 8px; +} + +// ==============================|| PERFECT SCROLLBAR ||============================== // + +.scrollbar-container { + .ps__rail-y { + &:hover > .ps__thumb-y, + &:focus > .ps__thumb-y, + &.ps--clicking .ps__thumb-y { + background-color: $grey500; + width: 5px; + } + } + .ps__thumb-y { + background-color: $grey500; + border-radius: 6px; + width: 5px; + right: 0; + } +} + +.scrollbar-container.ps, +.scrollbar-container > .ps { + &.ps--active-y > .ps__rail-y { + width: 5px; + background-color: transparent !important; + z-index: 999; + &:hover, + &.ps--clicking { + width: 5px; + background-color: transparent; + } + } + &.ps--scrolling-y > .ps__rail-y, + &.ps--scrolling-x > .ps__rail-x { + opacity: 0.4; + background-color: transparent; + } +} + +// ==============================|| ANIMATION KEYFRAMES ||============================== // + +@keyframes wings { + 50% { + transform: translateY(-40px); + } + 100% { + transform: translateY(0px); + } +} + +@keyframes blink { + 50% { + opacity: 0; + } + 100% { + opacity: 1; + } +} + +@keyframes bounce { + 0%, + 20%, + 53%, + to { + animation-timing-function: cubic-bezier(0.215, 0.61, 0.355, 1); + transform: translateZ(0); + } + 40%, + 43% { + animation-timing-function: cubic-bezier(0.755, 0.05, 0.855, 0.06); + transform: translate3d(0, -5px, 0); + } + 70% { + animation-timing-function: cubic-bezier(0.755, 0.05, 0.855, 0.06); + transform: translate3d(0, -7px, 0); + } + 80% { + transition-timing-function: cubic-bezier(0.215, 0.61, 0.355, 1); + transform: translateZ(0); + } + 90% { + transform: translate3d(0, -2px, 0); + } +} + +@keyframes slideY { + 0%, + 50%, + 100% { + transform: translateY(0px); + } + 25% { + transform: translateY(-10px); + } + 75% { + transform: translateY(10px); + } +} + +@keyframes slideX { + 0%, + 50%, + 100% { + transform: translateX(0px); + } + 25% { + transform: translateX(-10px); + } + 75% { + transform: translateX(10px); + } +} diff --git a/web/berry/src/config.js b/web/berry/src/config.js new file mode 100644 index 00000000..eeeda99a --- /dev/null +++ b/web/berry/src/config.js @@ -0,0 +1,29 @@ +const config = { + // basename: only at build time to set, and Don't add '/' at end off BASENAME for breadcrumbs, also Don't put only '/' use blank('') instead, + // like '/berry-material-react/react/default' + basename: '/', + defaultPath: '/panel/dashboard', + fontFamily: `'Roboto', sans-serif, Helvetica, Arial, sans-serif`, + borderRadius: 12, + siteInfo: { + chat_link: '', + display_in_currency: true, + email_verification: false, + footer_html: '', + github_client_id: '', + github_oauth: false, + logo: '', + quota_per_unit: 500000, + server_address: '', + start_time: 0, + system_name: 'One API', + top_up_link: '', + turnstile_check: false, + turnstile_site_key: '', + version: '', + wechat_login: false, + wechat_qrcode: '' + } +}; + +export default config; diff --git a/web/berry/src/constants/ChannelConstants.js b/web/berry/src/constants/ChannelConstants.js new file mode 100644 index 00000000..3ce27838 --- /dev/null +++ b/web/berry/src/constants/ChannelConstants.js @@ -0,0 +1,146 @@ +export const CHANNEL_OPTIONS = { + 1: { + key: 1, + text: 'OpenAI', + value: 1, + color: 'primary' + }, + 14: { + key: 14, + text: 'Anthropic Claude', + value: 14, + color: 'info' + }, + 3: { + key: 3, + text: 'Azure OpenAI', + value: 3, + color: 'orange' + }, + 11: { + key: 11, + text: 'Google PaLM2', + value: 11, + color: 'orange' + }, + 24: { + key: 24, + text: 'Google Gemini', + value: 24, + color: 'orange' + }, + 15: { + key: 15, + text: '百度文心千帆', + value: 15, + color: 'default' + }, + 17: { + key: 17, + text: '阿里通义千问', + value: 17, + color: 'default' + }, + 18: { + key: 18, + text: '讯飞星火认知', + value: 18, + color: 'default' + }, + 16: { + key: 16, + text: '智谱 ChatGLM', + value: 16, + color: 'default' + }, + 19: { + key: 19, + text: '360 智脑', + value: 19, + color: 'default' + }, + 23: { + key: 23, + text: '腾讯混元', + value: 23, + color: 'default' + }, + 8: { + key: 8, + text: '自定义渠道', + value: 8, + color: 'primary' + }, + 22: { + key: 22, + text: '知识库:FastGPT', + value: 22, + color: 'default' + }, + 21: { + key: 21, + text: '知识库:AI Proxy', + value: 21, + color: 'purple' + }, + 20: { + key: 20, + text: '代理:OpenRouter', + value: 20, + color: 'primary' + }, + 2: { + key: 2, + text: '代理:API2D', + value: 2, + color: 'primary' + }, + 5: { + key: 5, + text: '代理:OpenAI-SB', + value: 5, + color: 'primary' + }, + 7: { + key: 7, + text: '代理:OhMyGPT', + value: 7, + color: 'primary' + }, + 10: { + key: 10, + text: '代理:AI Proxy', + value: 10, + color: 'primary' + }, + 4: { + key: 4, + text: '代理:CloseAI', + value: 4, + color: 'primary' + }, + 6: { + key: 6, + text: '代理:OpenAI Max', + value: 6, + color: 'primary' + }, + 9: { + key: 9, + text: '代理:AI.LS', + value: 9, + color: 'primary' + }, + 12: { + key: 12, + text: '代理:API2GPT', + value: 12, + color: 'primary' + }, + 13: { + key: 13, + text: '代理:AIGC2D', + value: 13, + color: 'primary' + } +}; diff --git a/web/berry/src/constants/CommonConstants.js b/web/berry/src/constants/CommonConstants.js new file mode 100644 index 00000000..1a37d5f6 --- /dev/null +++ b/web/berry/src/constants/CommonConstants.js @@ -0,0 +1 @@ +export const ITEMS_PER_PAGE = 10; // this value must keep same as the one defined in backend! diff --git a/web/berry/src/constants/SnackbarConstants.js b/web/berry/src/constants/SnackbarConstants.js new file mode 100644 index 00000000..a05c6652 --- /dev/null +++ b/web/berry/src/constants/SnackbarConstants.js @@ -0,0 +1,27 @@ +export const snackbarConstants = { + Common: { + ERROR: { + variant: 'error', + autoHideDuration: 5000 + }, + WARNING: { + variant: 'warning', + autoHideDuration: 10000 + }, + SUCCESS: { + variant: 'success', + autoHideDuration: 1500 + }, + INFO: { + variant: 'info', + autoHideDuration: 3000 + }, + NOTICE: { + variant: 'info', + autoHideDuration: 20000 + } + }, + Mobile: { + anchorOrigin: { vertical: 'bottom', horizontal: 'center' } + } +}; diff --git a/web/berry/src/constants/index.js b/web/berry/src/constants/index.js new file mode 100644 index 00000000..716ef6aa --- /dev/null +++ b/web/berry/src/constants/index.js @@ -0,0 +1,3 @@ +export * from './SnackbarConstants'; +export * from './CommonConstants'; +export * from './ChannelConstants'; diff --git a/web/berry/src/contexts/StatusContext.js b/web/berry/src/contexts/StatusContext.js new file mode 100644 index 00000000..ed9f1621 --- /dev/null +++ b/web/berry/src/contexts/StatusContext.js @@ -0,0 +1,70 @@ +import { useEffect, useCallback, createContext } from "react"; +import { API } from "utils/api"; +import { showNotice, showError } from "utils/common"; +import { SET_SITE_INFO } from "store/actions"; +import { useDispatch } from "react-redux"; + +export const LoadStatusContext = createContext(); + +// eslint-disable-next-line +const StatusProvider = ({ children }) => { + const dispatch = useDispatch(); + + const loadStatus = useCallback(async () => { + const res = await API.get("/api/status"); + const { success, data } = res.data; + let system_name = ""; + if (success) { + if (!data.chat_link) { + delete data.chat_link; + } + localStorage.setItem("siteInfo", JSON.stringify(data)); + localStorage.setItem("quota_per_unit", data.quota_per_unit); + localStorage.setItem("display_in_currency", data.display_in_currency); + dispatch({ type: SET_SITE_INFO, payload: data }); + if ( + data.version !== process.env.REACT_APP_VERSION && + data.version !== "v0.0.0" && + data.version !== "" && + process.env.REACT_APP_VERSION !== "" + ) { + showNotice( + `新版本可用:${data.version},请使用快捷键 Shift + F5 刷新页面` + ); + } + if (data.system_name) { + system_name = data.system_name; + } + } else { + const backupSiteInfo = localStorage.getItem("siteInfo"); + if (backupSiteInfo) { + const data = JSON.parse(backupSiteInfo); + if (data.system_name) { + system_name = data.system_name; + } + dispatch({ + type: SET_SITE_INFO, + payload: data, + }); + } + showError("无法正常连接至服务器!"); + } + + if (system_name) { + document.title = system_name; + } + }, [dispatch]); + + useEffect(() => { + loadStatus().then(); + }, [loadStatus]); + + return ( + + {" "} + {children}{" "} + + ); +}; + +export default StatusProvider; diff --git a/web/berry/src/contexts/UserContext.js b/web/berry/src/contexts/UserContext.js new file mode 100644 index 00000000..491da9d9 --- /dev/null +++ b/web/berry/src/contexts/UserContext.js @@ -0,0 +1,29 @@ +// contexts/User/index.jsx +import React, { useEffect, useCallback, createContext, useState } from 'react'; +import { LOGIN } from 'store/actions'; +import { useDispatch } from 'react-redux'; + +export const UserContext = createContext(); + +// eslint-disable-next-line +const UserProvider = ({ children }) => { + const dispatch = useDispatch(); + const [isUserLoaded, setIsUserLoaded] = useState(false); + + const loadUser = useCallback(() => { + let user = localStorage.getItem('user'); + if (user) { + let data = JSON.parse(user); + dispatch({ type: LOGIN, payload: data }); + } + setIsUserLoaded(true); + }, [dispatch]); + + useEffect(() => { + loadUser(); + }, [loadUser]); + + return {children} ; +}; + +export default UserProvider; diff --git a/web/berry/src/hooks/useAuth.js b/web/berry/src/hooks/useAuth.js new file mode 100644 index 00000000..fa7cb934 --- /dev/null +++ b/web/berry/src/hooks/useAuth.js @@ -0,0 +1,13 @@ +import { isAdmin } from 'utils/common'; +import { useNavigate } from 'react-router-dom'; +const navigate = useNavigate(); + +const useAuth = () => { + const userIsAdmin = isAdmin(); + + if (!userIsAdmin) { + navigate('/panel/404'); + } +}; + +export default useAuth; diff --git a/web/berry/src/hooks/useLogin.js b/web/berry/src/hooks/useLogin.js new file mode 100644 index 00000000..53626577 --- /dev/null +++ b/web/berry/src/hooks/useLogin.js @@ -0,0 +1,78 @@ +import { API } from 'utils/api'; +import { useDispatch } from 'react-redux'; +import { LOGIN } from 'store/actions'; +import { useNavigate } from 'react-router'; +import { showSuccess } from 'utils/common'; + +const useLogin = () => { + const dispatch = useDispatch(); + const navigate = useNavigate(); + const login = async (username, password) => { + try { + const res = await API.post(`/api/user/login`, { + username, + password + }); + const { success, message, data } = res.data; + if (success) { + localStorage.setItem('user', JSON.stringify(data)); + dispatch({ type: LOGIN, payload: data }); + navigate('/panel'); + } + return { success, message }; + } catch (err) { + // 请求失败,设置错误信息 + return { success: false, message: '' }; + } + }; + + const githubLogin = async (code, state) => { + try { + const res = await API.get(`/api/oauth/github?code=${code}&state=${state}`); + const { success, message, data } = res.data; + if (success) { + if (message === 'bind') { + showSuccess('绑定成功!'); + navigate('/panel'); + } else { + dispatch({ type: LOGIN, payload: data }); + localStorage.setItem('user', JSON.stringify(data)); + showSuccess('登录成功!'); + navigate('/panel'); + } + } + return { success, message }; + } catch (err) { + // 请求失败,设置错误信息 + return { success: false, message: '' }; + } + }; + + const wechatLogin = async (code) => { + try { + const res = await API.get(`/api/oauth/wechat?code=${code}`); + const { success, message, data } = res.data; + if (success) { + dispatch({ type: LOGIN, payload: data }); + localStorage.setItem('user', JSON.stringify(data)); + showSuccess('登录成功!'); + navigate('/panel'); + } + return { success, message }; + } catch (err) { + // 请求失败,设置错误信息 + return { success: false, message: '' }; + } + }; + + const logout = async () => { + await API.get('/api/user/logout'); + localStorage.removeItem('user'); + dispatch({ type: LOGIN, payload: null }); + navigate('/'); + }; + + return { login, logout, githubLogin, wechatLogin }; +}; + +export default useLogin; diff --git a/web/berry/src/hooks/useRegister.js b/web/berry/src/hooks/useRegister.js new file mode 100644 index 00000000..5377e96d --- /dev/null +++ b/web/berry/src/hooks/useRegister.js @@ -0,0 +1,43 @@ +import { API } from 'utils/api'; +import { useNavigate } from 'react-router'; +import { showSuccess } from 'utils/common'; + +const useRegister = () => { + const navigate = useNavigate(); + const register = async (input, turnstile) => { + try { + let affCode = localStorage.getItem('aff'); + if (affCode) { + input = { ...input, aff_code: affCode }; + } + const res = await API.post(`/api/user/register?turnstile=${turnstile}`, input); + const { success, message } = res.data; + if (success) { + showSuccess('注册成功!'); + navigate('/login'); + } + return { success, message }; + } catch (err) { + // 请求失败,设置错误信息 + return { success: false, message: '' }; + } + }; + + const sendVerificationCode = async (email, turnstile) => { + try { + const res = await API.get(`/api/verification?email=${email}&turnstile=${turnstile}`); + const { success, message } = res.data; + if (success) { + showSuccess('验证码发送成功,请检查你的邮箱!'); + } + return { success, message }; + } catch (err) { + // 请求失败,设置错误信息 + return { success: false, message: '' }; + } + }; + + return { register, sendVerificationCode }; +}; + +export default useRegister; diff --git a/web/berry/src/hooks/useScriptRef.js b/web/berry/src/hooks/useScriptRef.js new file mode 100644 index 00000000..bd300cbb --- /dev/null +++ b/web/berry/src/hooks/useScriptRef.js @@ -0,0 +1,18 @@ +import { useEffect, useRef } from 'react'; + +// ==============================|| ELEMENT REFERENCE HOOKS ||============================== // + +const useScriptRef = () => { + const scripted = useRef(true); + + useEffect( + () => () => { + scripted.current = true; + }, + [] + ); + + return scripted; +}; + +export default useScriptRef; diff --git a/web/berry/src/index.js b/web/berry/src/index.js new file mode 100644 index 00000000..d1411be3 --- /dev/null +++ b/web/berry/src/index.js @@ -0,0 +1,31 @@ +import { createRoot } from 'react-dom/client'; + +// third party +import { BrowserRouter } from 'react-router-dom'; +import { Provider } from 'react-redux'; + +// project imports +import * as serviceWorker from 'serviceWorker'; +import App from 'App'; +import { store } from 'store'; + +// style + assets +import 'assets/scss/style.scss'; +import config from './config'; + +// ==============================|| REACT DOM RENDER ||============================== // + +const container = document.getElementById('root'); +const root = createRoot(container); // createRoot(container!) if you use TypeScript +root.render( + + + + + +); + +// If you want your app to work offline and load faster, you can change +// unregister() to register() below. Note this comes with some pitfalls. +// Learn more about service workers: https://bit.ly/CRA-PWA +serviceWorker.register(); diff --git a/web/berry/src/layout/MainLayout/Header/ProfileSection/index.js b/web/berry/src/layout/MainLayout/Header/ProfileSection/index.js new file mode 100644 index 00000000..3e351254 --- /dev/null +++ b/web/berry/src/layout/MainLayout/Header/ProfileSection/index.js @@ -0,0 +1,173 @@ +import { useState, useRef, useEffect } from 'react'; + +import { useSelector } from 'react-redux'; +import { useNavigate } from 'react-router-dom'; +// material-ui +import { useTheme } from '@mui/material/styles'; +import { + Avatar, + Chip, + ClickAwayListener, + List, + ListItemButton, + ListItemIcon, + ListItemText, + Paper, + Popper, + Typography +} from '@mui/material'; + +// project imports +import MainCard from 'ui-component/cards/MainCard'; +import Transitions from 'ui-component/extended/Transitions'; +import User1 from 'assets/images/users/user-round.svg'; +import useLogin from 'hooks/useLogin'; + +// assets +import { IconLogout, IconSettings, IconUserScan } from '@tabler/icons-react'; + +// ==============================|| PROFILE MENU ||============================== // + +const ProfileSection = () => { + const theme = useTheme(); + const navigate = useNavigate(); + const customization = useSelector((state) => state.customization); + const { logout } = useLogin(); + + const [open, setOpen] = useState(false); + /** + * anchorRef is used on different componets and specifying one type leads to other components throwing an error + * */ + const anchorRef = useRef(null); + const handleLogout = async () => { + logout(); + }; + + const handleClose = (event) => { + if (anchorRef.current && anchorRef.current.contains(event.target)) { + return; + } + setOpen(false); + }; + + const handleToggle = () => { + setOpen((prevOpen) => !prevOpen); + }; + + const prevOpen = useRef(open); + useEffect(() => { + if (prevOpen.current === true && open === false) { + anchorRef.current.focus(); + } + + prevOpen.current = open; + }, [open]); + + return ( + <> + + } + label={} + variant="outlined" + ref={anchorRef} + aria-controls={open ? 'menu-list-grow' : undefined} + aria-haspopup="true" + onClick={handleToggle} + color="primary" + /> + + {({ TransitionProps }) => ( + + + + + + navigate('/panel/profile')}> + + + + 设置} /> + + + + + + + 登出} /> + + + + + + + )} + + + ); +}; + +export default ProfileSection; diff --git a/web/berry/src/layout/MainLayout/Header/index.js b/web/berry/src/layout/MainLayout/Header/index.js new file mode 100644 index 00000000..51d40c75 --- /dev/null +++ b/web/berry/src/layout/MainLayout/Header/index.js @@ -0,0 +1,68 @@ +import PropTypes from 'prop-types'; + +// material-ui +import { useTheme } from '@mui/material/styles'; +import { Avatar, Box, ButtonBase } from '@mui/material'; + +// project imports +import LogoSection from '../LogoSection'; +import ProfileSection from './ProfileSection'; + +// assets +import { IconMenu2 } from '@tabler/icons-react'; + +// ==============================|| MAIN NAVBAR / HEADER ||============================== // + +const Header = ({ handleLeftDrawerToggle }) => { + const theme = useTheme(); + + return ( + <> + {/* logo & toggler button */} + + + + + + + + + + + + + + + + + ); +}; + +Header.propTypes = { + handleLeftDrawerToggle: PropTypes.func +}; + +export default Header; diff --git a/web/berry/src/layout/MainLayout/LogoSection/index.js b/web/berry/src/layout/MainLayout/LogoSection/index.js new file mode 100644 index 00000000..1d70e48c --- /dev/null +++ b/web/berry/src/layout/MainLayout/LogoSection/index.js @@ -0,0 +1,23 @@ +import { Link } from 'react-router-dom'; +import { useDispatch, useSelector } from 'react-redux'; + +// material-ui +import { ButtonBase } from '@mui/material'; + +// project imports +import Logo from 'ui-component/Logo'; +import { MENU_OPEN } from 'store/actions'; + +// ==============================|| MAIN LOGO ||============================== // + +const LogoSection = () => { + const defaultId = useSelector((state) => state.customization.defaultId); + const dispatch = useDispatch(); + return ( + dispatch({ type: MENU_OPEN, id: defaultId })} component={Link} to="/"> + + + ); +}; + +export default LogoSection; diff --git a/web/berry/src/layout/MainLayout/Sidebar/MenuCard/index.js b/web/berry/src/layout/MainLayout/Sidebar/MenuCard/index.js new file mode 100644 index 00000000..16b13231 --- /dev/null +++ b/web/berry/src/layout/MainLayout/Sidebar/MenuCard/index.js @@ -0,0 +1,130 @@ +// import PropTypes from 'prop-types'; +import { useSelector } from 'react-redux'; + +// material-ui +import { styled, useTheme } from '@mui/material/styles'; +import { + Avatar, + Card, + CardContent, + // Grid, + // LinearProgress, + List, + ListItem, + ListItemAvatar, + ListItemText, + Typography + // linearProgressClasses +} from '@mui/material'; +import User1 from 'assets/images/users/user-round.svg'; +import { useNavigate } from 'react-router-dom'; + +// assets +// import TableChartOutlinedIcon from '@mui/icons-material/TableChartOutlined'; + +// styles +// const BorderLinearProgress = styled(LinearProgress)(({ theme }) => ({ +// height: 10, +// borderRadius: 30, +// [`&.${linearProgressClasses.colorPrimary}`]: { +// backgroundColor: '#fff' +// }, +// [`& .${linearProgressClasses.bar}`]: { +// borderRadius: 5, +// backgroundColor: theme.palette.primary.main +// } +// })); + +const CardStyle = styled(Card)(({ theme }) => ({ + background: theme.palette.primary.light, + marginBottom: '22px', + overflow: 'hidden', + position: 'relative', + '&:after': { + content: '""', + position: 'absolute', + width: '157px', + height: '157px', + background: theme.palette.primary[200], + borderRadius: '50%', + top: '-105px', + right: '-96px' + } +})); + +// ==============================|| PROGRESS BAR WITH LABEL ||============================== // + +// function LinearProgressWithLabel({ value, ...others }) { +// const theme = useTheme(); + +// return ( +// +// +// +// +// +// Progress +// +// +// +// {`${Math.round(value)}%`} +// +// +// +// +// +// +// +// ); +// } + +// LinearProgressWithLabel.propTypes = { +// value: PropTypes.number +// }; + +// ==============================|| SIDEBAR MENU Card ||============================== // + +const MenuCard = () => { + const theme = useTheme(); + const account = useSelector((state) => state.account); + const navigate = useNavigate(); + + return ( + + + + + + navigate('/panel/profile')} + > + + + {account.user?.username} + + } + secondary={ 欢迎回来 } + /> + + + {/* */} + + + ); +}; + +export default MenuCard; diff --git a/web/berry/src/layout/MainLayout/Sidebar/MenuList/NavCollapse/index.js b/web/berry/src/layout/MainLayout/Sidebar/MenuList/NavCollapse/index.js new file mode 100644 index 00000000..0632d56f --- /dev/null +++ b/web/berry/src/layout/MainLayout/Sidebar/MenuList/NavCollapse/index.js @@ -0,0 +1,158 @@ +import PropTypes from 'prop-types'; +import { useEffect, useState } from 'react'; +import { useSelector } from 'react-redux'; +import { useLocation, useNavigate } from 'react-router'; + +// material-ui +import { useTheme } from '@mui/material/styles'; +import { Collapse, List, ListItemButton, ListItemIcon, ListItemText, Typography } from '@mui/material'; + +// project imports +import NavItem from '../NavItem'; + +// assets +import FiberManualRecordIcon from '@mui/icons-material/FiberManualRecord'; +import { IconChevronDown, IconChevronUp } from '@tabler/icons-react'; + +// ==============================|| SIDEBAR MENU LIST COLLAPSE ITEMS ||============================== // + +const NavCollapse = ({ menu, level }) => { + const theme = useTheme(); + const customization = useSelector((state) => state.customization); + const navigate = useNavigate(); + + const [open, setOpen] = useState(false); + const [selected, setSelected] = useState(null); + + const handleClick = () => { + setOpen(!open); + setSelected(!selected ? menu.id : null); + if (menu?.id !== 'authentication') { + navigate(menu.children[0]?.url); + } + }; + + const { pathname } = useLocation(); + const checkOpenForParent = (child, id) => { + child.forEach((item) => { + if (item.url === pathname) { + setOpen(true); + setSelected(id); + } + }); + }; + + // menu collapse for sub-levels + useEffect(() => { + setOpen(false); + setSelected(null); + if (menu.children) { + menu.children.forEach((item) => { + if (item.children?.length) { + checkOpenForParent(item.children, menu.id); + } + if (item.url === pathname) { + setSelected(menu.id); + setOpen(true); + } + }); + } + + // eslint-disable-next-line react-hooks/exhaustive-deps + }, [pathname, menu.children]); + + // menu collapse & item + const menus = menu.children?.map((item) => { + switch (item.type) { + case 'collapse': + return ; + case 'item': + return ; + default: + return ( + + Menu Items Error + + ); + } + }); + + const Icon = menu.icon; + const menuIcon = menu.icon ? ( + + ) : ( + 0 ? 'inherit' : 'medium'} + /> + ); + + return ( + <> + 1 ? 'transparent !important' : 'inherit', + py: level > 1 ? 1 : 1.25, + pl: `${level * 24}px` + }} + selected={selected === menu.id} + onClick={handleClick} + > + {menuIcon} + + {menu.title} + + } + secondary={ + menu.caption && ( + + {menu.caption} + + ) + } + /> + {open ? ( + + ) : ( + + )} + + + + {menus} + + + + ); +}; + +NavCollapse.propTypes = { + menu: PropTypes.object, + level: PropTypes.number +}; + +export default NavCollapse; diff --git a/web/berry/src/layout/MainLayout/Sidebar/MenuList/NavGroup/index.js b/web/berry/src/layout/MainLayout/Sidebar/MenuList/NavGroup/index.js new file mode 100644 index 00000000..b6479bc2 --- /dev/null +++ b/web/berry/src/layout/MainLayout/Sidebar/MenuList/NavGroup/index.js @@ -0,0 +1,61 @@ +import PropTypes from 'prop-types'; + +// material-ui +import { useTheme } from '@mui/material/styles'; +import { Divider, List, Typography } from '@mui/material'; + +// project imports +import NavItem from '../NavItem'; +import NavCollapse from '../NavCollapse'; + +// ==============================|| SIDEBAR MENU LIST GROUP ||============================== // + +const NavGroup = ({ item }) => { + const theme = useTheme(); + + // menu list collapse & items + const items = item.children?.map((menu) => { + switch (menu.type) { + case 'collapse': + return ; + case 'item': + return ; + default: + return ( + + Menu Items Error + + ); + } + }); + + return ( + <> + + {item.title} + {item.caption && ( + + {item.caption} + + )} + + ) + } + > + {items} + + + {/* group divider */} + + + ); +}; + +NavGroup.propTypes = { + item: PropTypes.object +}; + +export default NavGroup; diff --git a/web/berry/src/layout/MainLayout/Sidebar/MenuList/NavItem/index.js b/web/berry/src/layout/MainLayout/Sidebar/MenuList/NavItem/index.js new file mode 100644 index 00000000..ddce9cf4 --- /dev/null +++ b/web/berry/src/layout/MainLayout/Sidebar/MenuList/NavItem/index.js @@ -0,0 +1,115 @@ +import PropTypes from 'prop-types'; +import { forwardRef, useEffect } from 'react'; +import { Link, useLocation } from 'react-router-dom'; +import { useDispatch, useSelector } from 'react-redux'; + +// material-ui +import { useTheme } from '@mui/material/styles'; +import { Avatar, Chip, ListItemButton, ListItemIcon, ListItemText, Typography, useMediaQuery } from '@mui/material'; + +// project imports +import { MENU_OPEN, SET_MENU } from 'store/actions'; + +// assets +import FiberManualRecordIcon from '@mui/icons-material/FiberManualRecord'; + +// ==============================|| SIDEBAR MENU LIST ITEMS ||============================== // + +const NavItem = ({ item, level }) => { + const theme = useTheme(); + const dispatch = useDispatch(); + const { pathname } = useLocation(); + const customization = useSelector((state) => state.customization); + const matchesSM = useMediaQuery(theme.breakpoints.down('lg')); + + const Icon = item.icon; + const itemIcon = item?.icon ? ( + + ) : ( + id === item?.id) > -1 ? 8 : 6, + height: customization.isOpen.findIndex((id) => id === item?.id) > -1 ? 8 : 6 + }} + fontSize={level > 0 ? 'inherit' : 'medium'} + /> + ); + + let itemTarget = '_self'; + if (item.target) { + itemTarget = '_blank'; + } + + let listItemProps = { + component: forwardRef((props, ref) => ) + }; + if (item?.external) { + listItemProps = { component: 'a', href: item.url, target: itemTarget }; + } + + const itemHandler = (id) => { + dispatch({ type: MENU_OPEN, id }); + if (matchesSM) dispatch({ type: SET_MENU, opened: false }); + }; + + // active menu item on page load + useEffect(() => { + const currentIndex = document.location.pathname + .toString() + .split('/') + .findIndex((id) => id === item.id); + if (currentIndex > -1) { + dispatch({ type: MENU_OPEN, id: item.id }); + } + // eslint-disable-next-line + }, [pathname]); + + return ( + 1 ? 'transparent !important' : 'inherit', + py: level > 1 ? 1 : 1.25, + pl: `${level * 24}px` + }} + selected={customization.isOpen.findIndex((id) => id === item.id) > -1} + onClick={() => itemHandler(item.id)} + > + {itemIcon} + id === item.id) > -1 ? 'h5' : 'body1'} color="inherit"> + {item.title} + + } + secondary={ + item.caption && ( + + {item.caption} + + ) + } + /> + {item.chip && ( + {item.chip.avatar}} + /> + )} + + ); +}; + +NavItem.propTypes = { + item: PropTypes.object, + level: PropTypes.number +}; + +export default NavItem; diff --git a/web/berry/src/layout/MainLayout/Sidebar/MenuList/index.js b/web/berry/src/layout/MainLayout/Sidebar/MenuList/index.js new file mode 100644 index 00000000..4872057a --- /dev/null +++ b/web/berry/src/layout/MainLayout/Sidebar/MenuList/index.js @@ -0,0 +1,36 @@ +// material-ui +import { Typography } from '@mui/material'; + +// project imports +import NavGroup from './NavGroup'; +import menuItem from 'menu-items'; +import { isAdmin } from 'utils/common'; + +// ==============================|| SIDEBAR MENU LIST ||============================== // +const MenuList = () => { + const userIsAdmin = isAdmin(); + + return ( + <> + {menuItem.items.map((item) => { + if (item.type !== 'group') { + return ( + + Menu Items Error + + ); + } + + const filteredChildren = item.children.filter((child) => !child.isAdmin || userIsAdmin); + + if (filteredChildren.length === 0) { + return null; + } + + return ; + })} + + ); +}; + +export default MenuList; diff --git a/web/berry/src/layout/MainLayout/Sidebar/index.js b/web/berry/src/layout/MainLayout/Sidebar/index.js new file mode 100644 index 00000000..e3c6d12d --- /dev/null +++ b/web/berry/src/layout/MainLayout/Sidebar/index.js @@ -0,0 +1,94 @@ +import PropTypes from 'prop-types'; + +// material-ui +import { useTheme } from '@mui/material/styles'; +import { Box, Chip, Drawer, Stack, useMediaQuery } from '@mui/material'; + +// third-party +import PerfectScrollbar from 'react-perfect-scrollbar'; +import { BrowserView, MobileView } from 'react-device-detect'; + +// project imports +import MenuList from './MenuList'; +import LogoSection from '../LogoSection'; +import MenuCard from './MenuCard'; +import { drawerWidth } from 'store/constant'; + +// ==============================|| SIDEBAR DRAWER ||============================== // + +const Sidebar = ({ drawerOpen, drawerToggle, window }) => { + const theme = useTheme(); + const matchUpMd = useMediaQuery(theme.breakpoints.up('md')); + + const drawer = ( + <> + + + + + + + + + + + + + + + + + + + + + + + + + ); + + const container = window !== undefined ? () => window.document.body : undefined; + + return ( + + + {drawer} + + + ); +}; + +Sidebar.propTypes = { + drawerOpen: PropTypes.bool, + drawerToggle: PropTypes.func, + window: PropTypes.object +}; + +export default Sidebar; diff --git a/web/berry/src/layout/MainLayout/index.js b/web/berry/src/layout/MainLayout/index.js new file mode 100644 index 00000000..973a167b --- /dev/null +++ b/web/berry/src/layout/MainLayout/index.js @@ -0,0 +1,103 @@ +import { useDispatch, useSelector } from 'react-redux'; +import { Outlet } from 'react-router-dom'; +import AuthGuard from 'utils/route-guard/AuthGuard'; + +// material-ui +import { styled, useTheme } from '@mui/material/styles'; +import { AppBar, Box, CssBaseline, Toolbar, useMediaQuery } from '@mui/material'; +import AdminContainer from 'ui-component/AdminContainer'; + +// project imports +import Breadcrumbs from 'ui-component/extended/Breadcrumbs'; +import Header from './Header'; +import Sidebar from './Sidebar'; +import navigation from 'menu-items'; +import { drawerWidth } from 'store/constant'; +import { SET_MENU } from 'store/actions'; + +// assets +import { IconChevronRight } from '@tabler/icons-react'; + +// styles +const Main = styled('main', { shouldForwardProp: (prop) => prop !== 'open' })(({ theme, open }) => ({ + ...theme.typography.mainContent, + borderBottomLeftRadius: 0, + borderBottomRightRadius: 0, + transition: theme.transitions.create( + 'margin', + open + ? { + easing: theme.transitions.easing.easeOut, + duration: theme.transitions.duration.enteringScreen + } + : { + easing: theme.transitions.easing.sharp, + duration: theme.transitions.duration.leavingScreen + } + ), + [theme.breakpoints.up('md')]: { + marginLeft: open ? 0 : -(drawerWidth - 20), + width: `calc(100% - ${drawerWidth}px)` + }, + [theme.breakpoints.down('md')]: { + marginLeft: '20px', + width: `calc(100% - ${drawerWidth}px)`, + padding: '16px' + }, + [theme.breakpoints.down('sm')]: { + marginLeft: '10px', + width: `calc(100% - ${drawerWidth}px)`, + padding: '16px', + marginRight: '10px' + } +})); + +// ==============================|| MAIN LAYOUT ||============================== // + +const MainLayout = () => { + const theme = useTheme(); + const matchDownMd = useMediaQuery(theme.breakpoints.down('md')); + // Handle left drawer + const leftDrawerOpened = useSelector((state) => state.customization.opened); + const dispatch = useDispatch(); + const handleLeftDrawerToggle = () => { + dispatch({ type: SET_MENU, opened: !leftDrawerOpened }); + }; + + return ( + + + {/* header */} + + +
+ + + + {/* drawer */} + + + {/* main content */} +
+ {/* breadcrumb */} + + + + + + +
+ + ); +}; + +export default MainLayout; diff --git a/web/berry/src/layout/MinimalLayout/Header/index.js b/web/berry/src/layout/MinimalLayout/Header/index.js new file mode 100644 index 00000000..4f61da60 --- /dev/null +++ b/web/berry/src/layout/MinimalLayout/Header/index.js @@ -0,0 +1,75 @@ +// material-ui +import { useTheme } from "@mui/material/styles"; +import { Box, Button, Stack } from "@mui/material"; +import LogoSection from "layout/MainLayout/LogoSection"; +import { Link } from "react-router-dom"; +import { useLocation } from "react-router-dom"; +import { useSelector } from "react-redux"; + +// ==============================|| MAIN NAVBAR / HEADER ||============================== // + +const Header = () => { + const theme = useTheme(); + const { pathname } = useLocation(); + const account = useSelector((state) => state.account); + + return ( + <> + + + + + + + + + + + + {account.user ? ( + + ) : ( + + )} + + + ); +}; + +export default Header; diff --git a/web/berry/src/layout/MinimalLayout/index.js b/web/berry/src/layout/MinimalLayout/index.js new file mode 100644 index 00000000..c2919c6d --- /dev/null +++ b/web/berry/src/layout/MinimalLayout/index.js @@ -0,0 +1,39 @@ +import { Outlet } from 'react-router-dom'; +import { useTheme } from '@mui/material/styles'; +import { AppBar, Box, CssBaseline, Toolbar } from '@mui/material'; +import Header from './Header'; +import Footer from 'ui-component/Footer'; + +// ==============================|| MINIMAL LAYOUT ||============================== // + +const MinimalLayout = () => { + const theme = useTheme(); + + return ( + + + + +
+ + + + + + +