diff --git a/.gitignore b/.gitignore
index 60abb13e..974fcf63 100644
--- a/.gitignore
+++ b/.gitignore
@@ -6,4 +6,5 @@ upload
build
*.db-journal
logs
-data
\ No newline at end of file
+data
+/web/node_modules
diff --git a/Dockerfile b/Dockerfile
index bbe45905..1a69f7a2 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -23,7 +23,7 @@ ADD go.mod go.sum ./
RUN go mod download
COPY . .
COPY --from=builder /web/build ./web/build
-RUN go build -ldflags "-s -w -X 'one-api/common.Version=$(cat VERSION)' -extldflags '-static'" -o one-api
+RUN go build -ldflags "-s -w -X 'github.com/songquanpeng/one-api/common.Version=$(cat VERSION)' -extldflags '-static'" -o one-api
FROM alpine
diff --git a/README.en.md b/README.en.md
index e7f254f7..eec0047b 100644
--- a/README.en.md
+++ b/README.en.md
@@ -134,12 +134,12 @@ The initial account username is `root` and password is `123456`.
git clone https://github.com/songquanpeng/one-api.git
# Build the frontend
- cd one-api/web
+ cd one-api/web/default
npm install
npm run build
# Build the backend
- cd ..
+ cd ../..
go mod download
go build -ldflags "-s -w" -o one-api
```
diff --git a/README.ja.md b/README.ja.md
index edfd2a28..e9149d71 100644
--- a/README.ja.md
+++ b/README.ja.md
@@ -135,12 +135,12 @@ sudo service nginx restart
git clone https://github.com/songquanpeng/one-api.git
# フロントエンドのビルド
- cd one-api/web
+ cd one-api/web/default
npm install
npm run build
# バックエンドのビルド
- cd ..
+ cd ../..
go mod download
go build -ldflags "-s -w" -o one-api
```
diff --git a/README.md b/README.md
index 02a62387..ff1fffd2 100644
--- a/README.md
+++ b/README.md
@@ -73,6 +73,9 @@ _✨ 通过标准的 OpenAI API 格式访问所有的大模型,开箱即用
+ [x] [智谱 ChatGLM 系列模型](https://bigmodel.cn)
+ [x] [360 智脑](https://ai.360.cn)
+ [x] [腾讯混元大模型](https://cloud.tencent.com/document/product/1729)
+ + [x] [Moonshot AI](https://platform.moonshot.cn/)
+ + [ ] [字节云雀大模型](https://www.volcengine.com/product/ark) (WIP)
+ + [ ] [MINIMAX](https://api.minimax.chat/) (WIP)
2. 支持配置镜像以及众多[第三方代理服务](https://iamazing.cn/page/openai-api-third-party-services)。
3. 支持通过**负载均衡**的方式访问多个渠道。
4. 支持 **stream 模式**,可以通过流式传输实现打字机效果。
@@ -174,12 +177,12 @@ docker-compose ps
git clone https://github.com/songquanpeng/one-api.git
# 构建前端
- cd one-api/web
+ cd one-api/web/default
npm install
npm run build
# 构建后端
- cd ..
+ cd ../..
go mod download
go build -ldflags "-s -w" -o one-api
````
diff --git a/common/config/config.go b/common/config/config.go
new file mode 100644
index 00000000..dd0236b4
--- /dev/null
+++ b/common/config/config.go
@@ -0,0 +1,127 @@
+package config
+
+import (
+ "github.com/songquanpeng/one-api/common/helper"
+ "os"
+ "strconv"
+ "sync"
+ "time"
+
+ "github.com/google/uuid"
+)
+
+var SystemName = "One API"
+var ServerAddress = "http://localhost:3000"
+var Footer = ""
+var Logo = ""
+var TopUpLink = ""
+var ChatLink = ""
+var QuotaPerUnit = 500 * 1000.0 // $0.002 / 1K tokens
+var DisplayInCurrencyEnabled = true
+var DisplayTokenStatEnabled = true
+
+// Any options with "Secret" or "Token" in their keys won't be returned by GetOptions
+
+var SessionSecret = uuid.New().String()
+
+var OptionMap map[string]string
+var OptionMapRWMutex sync.RWMutex
+
+var ItemsPerPage = 10
+var MaxRecentItems = 100
+
+var PasswordLoginEnabled = true
+var PasswordRegisterEnabled = true
+var EmailVerificationEnabled = false
+var GitHubOAuthEnabled = false
+var WeChatAuthEnabled = false
+var TurnstileCheckEnabled = false
+var RegisterEnabled = true
+
+var EmailDomainRestrictionEnabled = false
+var EmailDomainWhitelist = []string{
+ "gmail.com",
+ "163.com",
+ "126.com",
+ "qq.com",
+ "outlook.com",
+ "hotmail.com",
+ "icloud.com",
+ "yahoo.com",
+ "foxmail.com",
+}
+
+var DebugEnabled = os.Getenv("DEBUG") == "true"
+var MemoryCacheEnabled = os.Getenv("MEMORY_CACHE_ENABLED") == "true"
+
+var LogConsumeEnabled = true
+
+var SMTPServer = ""
+var SMTPPort = 587
+var SMTPAccount = ""
+var SMTPFrom = ""
+var SMTPToken = ""
+
+var GitHubClientId = ""
+var GitHubClientSecret = ""
+
+var WeChatServerAddress = ""
+var WeChatServerToken = ""
+var WeChatAccountQRCodeImageURL = ""
+
+var TurnstileSiteKey = ""
+var TurnstileSecretKey = ""
+
+var QuotaForNewUser = 0
+var QuotaForInviter = 0
+var QuotaForInvitee = 0
+var ChannelDisableThreshold = 5.0
+var AutomaticDisableChannelEnabled = false
+var AutomaticEnableChannelEnabled = false
+var QuotaRemindThreshold = 1000
+var PreConsumedQuota = 500
+var ApproximateTokenEnabled = false
+var RetryTimes = 0
+
+var RootUserEmail = ""
+
+var IsMasterNode = os.Getenv("NODE_TYPE") != "slave"
+
+var requestInterval, _ = strconv.Atoi(os.Getenv("POLLING_INTERVAL"))
+var RequestInterval = time.Duration(requestInterval) * time.Second
+
+var SyncFrequency = helper.GetOrDefaultEnvInt("SYNC_FREQUENCY", 10*60) // unit is second
+
+var BatchUpdateEnabled = false
+var BatchUpdateInterval = helper.GetOrDefaultEnvInt("BATCH_UPDATE_INTERVAL", 5)
+
+var RelayTimeout = helper.GetOrDefaultEnvInt("RELAY_TIMEOUT", 0) // unit is second
+
+var GeminiSafetySetting = helper.GetOrDefaultEnvString("GEMINI_SAFETY_SETTING", "BLOCK_NONE")
+
+var Theme = helper.GetOrDefaultEnvString("THEME", "default")
+var ValidThemes = map[string]bool{
+ "default": true,
+ "berry": true,
+}
+
+// All duration's unit is seconds
+// Shouldn't be larger than RateLimitKeyExpirationDuration
+var (
+ GlobalApiRateLimitNum = helper.GetOrDefaultEnvInt("GLOBAL_API_RATE_LIMIT", 180)
+ GlobalApiRateLimitDuration int64 = 3 * 60
+
+ GlobalWebRateLimitNum = helper.GetOrDefaultEnvInt("GLOBAL_WEB_RATE_LIMIT", 60)
+ GlobalWebRateLimitDuration int64 = 3 * 60
+
+ UploadRateLimitNum = 10
+ UploadRateLimitDuration int64 = 60
+
+ DownloadRateLimitNum = 10
+ DownloadRateLimitDuration int64 = 60
+
+ CriticalRateLimitNum = 20
+ CriticalRateLimitDuration int64 = 20 * 60
+)
+
+var RateLimitKeyExpirationDuration = 20 * time.Minute
diff --git a/common/constants.go b/common/constants.go
index 9ee791df..ccaa3560 100644
--- a/common/constants.go
+++ b/common/constants.go
@@ -1,114 +1,9 @@
package common
-import (
- "os"
- "strconv"
- "sync"
- "time"
-
- "github.com/google/uuid"
-)
+import "time"
var StartTime = time.Now().Unix() // unit: second
var Version = "v0.0.0" // this hard coding will be replaced automatically when building, no need to manually change
-var SystemName = "One API"
-var ServerAddress = "http://localhost:3000"
-var Footer = ""
-var Logo = ""
-var TopUpLink = ""
-var ChatLink = ""
-var QuotaPerUnit = 500 * 1000.0 // $0.002 / 1K tokens
-var DisplayInCurrencyEnabled = true
-var DisplayTokenStatEnabled = true
-
-// Any options with "Secret", "Token" in its key won't be return by GetOptions
-
-var SessionSecret = uuid.New().String()
-
-var OptionMap map[string]string
-var OptionMapRWMutex sync.RWMutex
-
-var ItemsPerPage = 10
-var MaxRecentItems = 100
-
-var PasswordLoginEnabled = true
-var PasswordRegisterEnabled = true
-var EmailVerificationEnabled = false
-var GitHubOAuthEnabled = false
-var WeChatAuthEnabled = false
-var TurnstileCheckEnabled = false
-var RegisterEnabled = true
-
-var EmailDomainRestrictionEnabled = false
-var EmailDomainWhitelist = []string{
- "gmail.com",
- "163.com",
- "126.com",
- "qq.com",
- "outlook.com",
- "hotmail.com",
- "icloud.com",
- "yahoo.com",
- "foxmail.com",
-}
-
-var DebugEnabled = os.Getenv("DEBUG") == "true"
-var MemoryCacheEnabled = os.Getenv("MEMORY_CACHE_ENABLED") == "true"
-
-var LogConsumeEnabled = true
-
-var SMTPServer = ""
-var SMTPPort = 587
-var SMTPAccount = ""
-var SMTPFrom = ""
-var SMTPToken = ""
-
-var GitHubClientId = ""
-var GitHubClientSecret = ""
-
-var WeChatServerAddress = ""
-var WeChatServerToken = ""
-var WeChatAccountQRCodeImageURL = ""
-
-var TurnstileSiteKey = ""
-var TurnstileSecretKey = ""
-
-var QuotaForNewUser = 0
-var QuotaForInviter = 0
-var QuotaForInvitee = 0
-var ChannelDisableThreshold = 5.0
-var AutomaticDisableChannelEnabled = false
-var AutomaticEnableChannelEnabled = false
-var QuotaRemindThreshold = 1000
-var PreConsumedQuota = 500
-var ApproximateTokenEnabled = false
-var RetryTimes = 0
-
-var RootUserEmail = ""
-
-var IsMasterNode = os.Getenv("NODE_TYPE") != "slave"
-
-var requestInterval, _ = strconv.Atoi(os.Getenv("POLLING_INTERVAL"))
-var RequestInterval = time.Duration(requestInterval) * time.Second
-
-var SyncFrequency = GetOrDefault("SYNC_FREQUENCY", 10*60) // unit is second
-
-var BatchUpdateEnabled = false
-var BatchUpdateInterval = GetOrDefault("BATCH_UPDATE_INTERVAL", 5)
-
-var RelayTimeout = GetOrDefault("RELAY_TIMEOUT", 0) // unit is second
-
-var GeminiSafetySetting = GetOrDefaultString("GEMINI_SAFETY_SETTING", "BLOCK_NONE")
-
-var Theme = GetOrDefaultString("THEME", "default")
-var ValidThemes = map[string]bool{
- "default": true,
- "berry": true,
-}
-
-const (
- RequestIdKey = "X-Oneapi-Request-Id"
-)
const (
RoleGuestUser = 0
@@ -117,34 +12,6 @@ const (
RoleRootUser = 100
)
-var (
- FileUploadPermission = RoleGuestUser
- FileDownloadPermission = RoleGuestUser
- ImageUploadPermission = RoleGuestUser
- ImageDownloadPermission = RoleGuestUser
-)
-
-// All duration's unit is seconds
-// Shouldn't larger then RateLimitKeyExpirationDuration
-var (
- GlobalApiRateLimitNum = GetOrDefault("GLOBAL_API_RATE_LIMIT", 180)
- GlobalApiRateLimitDuration int64 = 3 * 60
-
- GlobalWebRateLimitNum = GetOrDefault("GLOBAL_WEB_RATE_LIMIT", 60)
- GlobalWebRateLimitDuration int64 = 3 * 60
-
- UploadRateLimitNum = 10
- UploadRateLimitDuration int64 = 60
-
- DownloadRateLimitNum = 10
- DownloadRateLimitDuration int64 = 60
-
- CriticalRateLimitNum = 20
- CriticalRateLimitDuration int64 = 20 * 60
-)
-
-var RateLimitKeyExpirationDuration = 20 * time.Minute
-
const (
UserStatusEnabled = 1 // don't use 0, 0 is the default value!
UserStatusDisabled = 2 // also don't use 0
@@ -196,32 +63,42 @@ const (
ChannelTypeFastGPT = 22
ChannelTypeTencent = 23
ChannelTypeGemini = 24
+ ChannelTypeMoonshot = 25
)
var ChannelBaseURLs = []string{
- "", // 0
- "https://api.openai.com", // 1
- "https://oa.api2d.net", // 2
- "", // 3
- "https://api.closeai-proxy.xyz", // 4
- "https://api.openai-sb.com", // 5
- "https://api.openaimax.com", // 6
- "https://api.ohmygpt.com", // 7
- "", // 8
- "https://api.caipacity.com", // 9
- "https://api.aiproxy.io", // 10
- "", // 11
- "https://api.api2gpt.com", // 12
- "https://api.aigc2d.com", // 13
- "https://api.anthropic.com", // 14
- "https://aip.baidubce.com", // 15
- "https://open.bigmodel.cn", // 16
- "https://dashscope.aliyuncs.com", // 17
- "", // 18
- "https://ai.360.cn", // 19
- "https://openrouter.ai/api", // 20
- "https://api.aiproxy.io", // 21
- "https://fastgpt.run/api/openapi", // 22
- "https://hunyuan.cloud.tencent.com", //23
- "", //24
+ "", // 0
+ "https://api.openai.com", // 1
+ "https://oa.api2d.net", // 2
+ "", // 3
+ "https://api.closeai-proxy.xyz", // 4
+ "https://api.openai-sb.com", // 5
+ "https://api.openaimax.com", // 6
+ "https://api.ohmygpt.com", // 7
+ "", // 8
+ "https://api.caipacity.com", // 9
+ "https://api.aiproxy.io", // 10
+ "https://generativelanguage.googleapis.com", // 11
+ "https://api.api2gpt.com", // 12
+ "https://api.aigc2d.com", // 13
+ "https://api.anthropic.com", // 14
+ "https://aip.baidubce.com", // 15
+ "https://open.bigmodel.cn", // 16
+ "https://dashscope.aliyuncs.com", // 17
+ "", // 18
+ "https://ai.360.cn", // 19
+ "https://openrouter.ai/api", // 20
+ "https://api.aiproxy.io", // 21
+ "https://fastgpt.run/api/openapi", // 22
+ "https://hunyuan.cloud.tencent.com", // 23
+ "https://generativelanguage.googleapis.com", // 24
+ "https://api.moonshot.cn", // 25
}
+
+const (
+ ConfigKeyPrefix = "cfg_"
+
+ ConfigKeyAPIVersion = ConfigKeyPrefix + "api_version"
+ ConfigKeyLibraryID = ConfigKeyPrefix + "library_id"
+ ConfigKeyPlugin = ConfigKeyPrefix + "plugin"
+)
diff --git a/common/database.go b/common/database.go
index 76f2cd55..9b52a0d5 100644
--- a/common/database.go
+++ b/common/database.go
@@ -1,7 +1,9 @@
package common
+import "github.com/songquanpeng/one-api/common/helper"
+
var UsingSQLite = false
var UsingPostgreSQL = false
var SQLitePath = "one-api.db"
-var SQLiteBusyTimeout = GetOrDefault("SQLITE_BUSY_TIMEOUT", 3000)
+var SQLiteBusyTimeout = helper.GetOrDefaultEnvInt("SQLITE_BUSY_TIMEOUT", 3000)
diff --git a/common/email.go b/common/email.go
index b915f0f9..2689da6a 100644
--- a/common/email.go
+++ b/common/email.go
@@ -5,19 +5,20 @@ import (
"crypto/tls"
"encoding/base64"
"fmt"
+ "github.com/songquanpeng/one-api/common/config"
"net/smtp"
"strings"
"time"
)
func SendEmail(subject string, receiver string, content string) error {
- if SMTPFrom == "" { // for compatibility
- SMTPFrom = SMTPAccount
+ if config.SMTPFrom == "" { // for compatibility
+ config.SMTPFrom = config.SMTPAccount
}
encodedSubject := fmt.Sprintf("=?UTF-8?B?%s?=", base64.StdEncoding.EncodeToString([]byte(subject)))
// Extract domain from SMTPFrom
- parts := strings.Split(SMTPFrom, "@")
+ parts := strings.Split(config.SMTPFrom, "@")
var domain string
if len(parts) > 1 {
domain = parts[1]
@@ -36,21 +37,21 @@ func SendEmail(subject string, receiver string, content string) error {
"Message-ID: %s\r\n"+ // add Message-ID header to avoid being treated as spam, RFC 5322
"Date: %s\r\n"+
"Content-Type: text/html; charset=UTF-8\r\n\r\n%s\r\n",
- receiver, SystemName, SMTPFrom, encodedSubject, messageId, time.Now().Format(time.RFC1123Z), content))
- auth := smtp.PlainAuth("", SMTPAccount, SMTPToken, SMTPServer)
- addr := fmt.Sprintf("%s:%d", SMTPServer, SMTPPort)
+ receiver, config.SystemName, config.SMTPFrom, encodedSubject, messageId, time.Now().Format(time.RFC1123Z), content))
+ auth := smtp.PlainAuth("", config.SMTPAccount, config.SMTPToken, config.SMTPServer)
+ addr := fmt.Sprintf("%s:%d", config.SMTPServer, config.SMTPPort)
to := strings.Split(receiver, ";")
- if SMTPPort == 465 {
+ if config.SMTPPort == 465 {
tlsConfig := &tls.Config{
InsecureSkipVerify: true,
- ServerName: SMTPServer,
+ ServerName: config.SMTPServer,
}
- conn, err := tls.Dial("tcp", fmt.Sprintf("%s:%d", SMTPServer, SMTPPort), tlsConfig)
+ conn, err := tls.Dial("tcp", fmt.Sprintf("%s:%d", config.SMTPServer, config.SMTPPort), tlsConfig)
if err != nil {
return err
}
- client, err := smtp.NewClient(conn, SMTPServer)
+ client, err := smtp.NewClient(conn, config.SMTPServer)
if err != nil {
return err
}
@@ -58,7 +59,7 @@ func SendEmail(subject string, receiver string, content string) error {
if err = client.Auth(auth); err != nil {
return err
}
- if err = client.Mail(SMTPFrom); err != nil {
+ if err = client.Mail(config.SMTPFrom); err != nil {
return err
}
receiverEmails := strings.Split(receiver, ";")
@@ -80,7 +81,7 @@ func SendEmail(subject string, receiver string, content string) error {
return err
}
} else {
- err = smtp.SendMail(addr, auth, SMTPAccount, to, mail)
+ err = smtp.SendMail(addr, auth, config.SMTPAccount, to, mail)
}
return err
}
diff --git a/common/embed-file-system.go b/common/embed-file-system.go
index 3ea02cf8..7c0e4b4e 100644
--- a/common/embed-file-system.go
+++ b/common/embed-file-system.go
@@ -15,10 +15,7 @@ type embedFileSystem struct {
func (e embedFileSystem) Exists(prefix string, path string) bool {
_, err := e.Open(path)
- if err != nil {
- return false
- }
- return true
+ return err == nil
}
func EmbedFolder(fsEmbed embed.FS, targetPath string) static.ServeFileSystem {
diff --git a/common/group-ratio.go b/common/group-ratio.go
index 1ec73c78..2de6e810 100644
--- a/common/group-ratio.go
+++ b/common/group-ratio.go
@@ -1,6 +1,9 @@
package common
-import "encoding/json"
+import (
+ "encoding/json"
+ "github.com/songquanpeng/one-api/common/logger"
+)
var GroupRatio = map[string]float64{
"default": 1,
@@ -11,7 +14,7 @@ var GroupRatio = map[string]float64{
func GroupRatio2JSONString() string {
jsonBytes, err := json.Marshal(GroupRatio)
if err != nil {
- SysError("error marshalling model ratio: " + err.Error())
+ logger.SysError("error marshalling model ratio: " + err.Error())
}
return string(jsonBytes)
}
@@ -24,7 +27,7 @@ func UpdateGroupRatioByJSONString(jsonStr string) error {
func GetGroupRatio(name string) float64 {
ratio, ok := GroupRatio[name]
if !ok {
- SysError("group ratio not found: " + name)
+ logger.SysError("group ratio not found: " + name)
return 1
}
return ratio
diff --git a/common/helper/helper.go b/common/helper/helper.go
new file mode 100644
index 00000000..a0d88ec2
--- /dev/null
+++ b/common/helper/helper.go
@@ -0,0 +1,224 @@
+package helper
+
+import (
+ "fmt"
+ "github.com/google/uuid"
+ "github.com/songquanpeng/one-api/common/logger"
+ "html/template"
+ "log"
+ "math/rand"
+ "net"
+ "os"
+ "os/exec"
+ "runtime"
+ "strconv"
+ "strings"
+ "time"
+)
+
+func OpenBrowser(url string) {
+ var err error
+
+ switch runtime.GOOS {
+ case "linux":
+ err = exec.Command("xdg-open", url).Start()
+ case "windows":
+ err = exec.Command("rundll32", "url.dll,FileProtocolHandler", url).Start()
+ case "darwin":
+ err = exec.Command("open", url).Start()
+ }
+ if err != nil {
+ log.Println(err)
+ }
+}
+
+func GetIp() (ip string) {
+ ips, err := net.InterfaceAddrs()
+ if err != nil {
+ log.Println(err)
+ return ip
+ }
+
+ for _, a := range ips {
+ if ipNet, ok := a.(*net.IPNet); ok && !ipNet.IP.IsLoopback() {
+ if ipNet.IP.To4() != nil {
+ ip = ipNet.IP.String()
+ if strings.HasPrefix(ip, "10") {
+ return
+ }
+ if strings.HasPrefix(ip, "172") {
+ return
+ }
+ if strings.HasPrefix(ip, "192.168") {
+ return
+ }
+ ip = ""
+ }
+ }
+ }
+ return
+}
+
+var sizeKB = 1024
+var sizeMB = sizeKB * 1024
+var sizeGB = sizeMB * 1024
+
+func Bytes2Size(num int64) string {
+ numStr := ""
+ unit := "B"
+ if num/int64(sizeGB) > 1 {
+ numStr = fmt.Sprintf("%.2f", float64(num)/float64(sizeGB))
+ unit = "GB"
+ } else if num/int64(sizeMB) > 1 {
+ numStr = fmt.Sprintf("%d", int(float64(num)/float64(sizeMB)))
+ unit = "MB"
+ } else if num/int64(sizeKB) > 1 {
+ numStr = fmt.Sprintf("%d", int(float64(num)/float64(sizeKB)))
+ unit = "KB"
+ } else {
+ numStr = fmt.Sprintf("%d", num)
+ }
+ return numStr + " " + unit
+}
+
+func Seconds2Time(num int) (time string) {
+ if num/31104000 > 0 {
+ time += strconv.Itoa(num/31104000) + " 年 "
+ num %= 31104000
+ }
+ if num/2592000 > 0 {
+ time += strconv.Itoa(num/2592000) + " 个月 "
+ num %= 2592000
+ }
+ if num/86400 > 0 {
+ time += strconv.Itoa(num/86400) + " 天 "
+ num %= 86400
+ }
+ if num/3600 > 0 {
+ time += strconv.Itoa(num/3600) + " 小时 "
+ num %= 3600
+ }
+ if num/60 > 0 {
+ time += strconv.Itoa(num/60) + " 分钟 "
+ num %= 60
+ }
+ time += strconv.Itoa(num) + " 秒"
+ return
+}
+
+func Interface2String(inter interface{}) string {
+ switch inter := inter.(type) {
+ case string:
+ return inter
+ case int:
+ return fmt.Sprintf("%d", inter)
+ case float64:
+ return fmt.Sprintf("%f", inter)
+ }
+ return "Not Implemented"
+}
+
+func UnescapeHTML(x string) interface{} {
+ return template.HTML(x)
+}
+
+func IntMax(a int, b int) int {
+ if a >= b {
+ return a
+ } else {
+ return b
+ }
+}
+
+func GetUUID() string {
+ code := uuid.New().String()
+ code = strings.Replace(code, "-", "", -1)
+ return code
+}
+
+const keyChars = "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
+
+func init() {
+ rand.Seed(time.Now().UnixNano())
+}
+
+func GenerateKey() string {
+ rand.Seed(time.Now().UnixNano())
+ key := make([]byte, 48)
+ for i := 0; i < 16; i++ {
+ key[i] = keyChars[rand.Intn(len(keyChars))]
+ }
+ uuid_ := GetUUID()
+ for i := 0; i < 32; i++ {
+ c := uuid_[i]
+ if i%2 == 0 && c >= 'a' && c <= 'z' {
+ c = c - 'a' + 'A'
+ }
+ key[i+16] = c
+ }
+ return string(key)
+}
+
+func GetRandomString(length int) string {
+ rand.Seed(time.Now().UnixNano())
+ key := make([]byte, length)
+ for i := 0; i < length; i++ {
+ key[i] = keyChars[rand.Intn(len(keyChars))]
+ }
+ return string(key)
+}
+
+func GetTimestamp() int64 {
+ return time.Now().Unix()
+}
+
+func GetTimeString() string {
+ now := time.Now()
+ return fmt.Sprintf("%s%d", now.Format("20060102150405"), now.UnixNano()%1e9)
+}
+
+func Max(a int, b int) int {
+ if a >= b {
+ return a
+ } else {
+ return b
+ }
+}
+
+func GetOrDefaultEnvInt(env string, defaultValue int) int {
+ if env == "" || os.Getenv(env) == "" {
+ return defaultValue
+ }
+ num, err := strconv.Atoi(os.Getenv(env))
+ if err != nil {
+ logger.SysError(fmt.Sprintf("failed to parse %s: %s, using default value: %d", env, err.Error(), defaultValue))
+ return defaultValue
+ }
+ return num
+}
+
+func GetOrDefaultEnvString(env string, defaultValue string) string {
+ if env == "" || os.Getenv(env) == "" {
+ return defaultValue
+ }
+ return os.Getenv(env)
+}
+
+func AssignOrDefault(value string, defaultValue string) string {
+ if len(value) != 0 {
+ return value
+ }
+ return defaultValue
+}
+
+func MessageWithRequestId(message string, id string) string {
+ return fmt.Sprintf("%s (request id: %s)", message, id)
+}
+
+func String2Int(str string) int {
+ num, err := strconv.Atoi(str)
+ if err != nil {
+ return 0
+ }
+ return num
+}
diff --git a/common/image/image_test.go b/common/image/image_test.go
index 8e47b109..15ed78bc 100644
--- a/common/image/image_test.go
+++ b/common/image/image_test.go
@@ -12,7 +12,7 @@ import (
"strings"
"testing"
- img "one-api/common/image"
+ img "github.com/songquanpeng/one-api/common/image"
"github.com/stretchr/testify/assert"
_ "golang.org/x/image/webp"
diff --git a/common/init.go b/common/init.go
index 12df5f51..b392bfee 100644
--- a/common/init.go
+++ b/common/init.go
@@ -3,6 +3,8 @@ package common
import (
"flag"
"fmt"
+ "github.com/songquanpeng/one-api/common/config"
+ "github.com/songquanpeng/one-api/common/logger"
"log"
"os"
"path/filepath"
@@ -37,9 +39,9 @@ func init() {
if os.Getenv("SESSION_SECRET") != "" {
if os.Getenv("SESSION_SECRET") == "random_string" {
- SysError("SESSION_SECRET is set to an example value, please change it to a random string.")
+ logger.SysError("SESSION_SECRET is set to an example value, please change it to a random string.")
} else {
- SessionSecret = os.Getenv("SESSION_SECRET")
+ config.SessionSecret = os.Getenv("SESSION_SECRET")
}
}
if os.Getenv("SQLITE_PATH") != "" {
@@ -57,5 +59,6 @@ func init() {
log.Fatal(err)
}
}
+ logger.LogDir = *LogDir
}
}
diff --git a/common/logger/constants.go b/common/logger/constants.go
new file mode 100644
index 00000000..78d32062
--- /dev/null
+++ b/common/logger/constants.go
@@ -0,0 +1,7 @@
+package logger
+
+const (
+ RequestIdKey = "X-Oneapi-Request-Id"
+)
+
+var LogDir string
diff --git a/common/logger.go b/common/logger/logger.go
similarity index 76%
rename from common/logger.go
rename to common/logger/logger.go
index 61627217..f970ee61 100644
--- a/common/logger.go
+++ b/common/logger/logger.go
@@ -1,4 +1,4 @@
-package common
+package logger
import (
"context"
@@ -25,7 +25,7 @@ var setupLogLock sync.Mutex
var setupLogWorking bool
func SetupLogger() {
- if *LogDir != "" {
+ if LogDir != "" {
ok := setupLogLock.TryLock()
if !ok {
log.Println("setup log is already working")
@@ -35,7 +35,7 @@ func SetupLogger() {
setupLogLock.Unlock()
setupLogWorking = false
}()
- logPath := filepath.Join(*LogDir, fmt.Sprintf("oneapi-%s.log", time.Now().Format("20060102")))
+ logPath := filepath.Join(LogDir, fmt.Sprintf("oneapi-%s.log", time.Now().Format("20060102")))
fd, err := os.OpenFile(logPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
if err != nil {
log.Fatal("failed to open log file")
@@ -55,18 +55,30 @@ func SysError(s string) {
_, _ = fmt.Fprintf(gin.DefaultErrorWriter, "[SYS] %v | %s \n", t.Format("2006/01/02 - 15:04:05"), s)
}
-func LogInfo(ctx context.Context, msg string) {
+func Info(ctx context.Context, msg string) {
logHelper(ctx, loggerINFO, msg)
}
-func LogWarn(ctx context.Context, msg string) {
+func Warn(ctx context.Context, msg string) {
logHelper(ctx, loggerWarn, msg)
}
-func LogError(ctx context.Context, msg string) {
+func Error(ctx context.Context, msg string) {
logHelper(ctx, loggerError, msg)
}
+func Infof(ctx context.Context, format string, a ...any) {
+ Info(ctx, fmt.Sprintf(format, a...))
+}
+
+func Warnf(ctx context.Context, format string, a ...any) {
+ Warn(ctx, fmt.Sprintf(format, a...))
+}
+
+func Errorf(ctx context.Context, format string, a ...any) {
+ Error(ctx, fmt.Sprintf(format, a...))
+}
+
func logHelper(ctx context.Context, level string, msg string) {
writer := gin.DefaultErrorWriter
if level == loggerINFO {
@@ -90,11 +102,3 @@ func FatalLog(v ...any) {
_, _ = fmt.Fprintf(gin.DefaultErrorWriter, "[FATAL] %v | %v \n", t.Format("2006/01/02 - 15:04:05"), v)
os.Exit(1)
}
-
-func LogQuota(quota int) string {
- if DisplayInCurrencyEnabled {
- return fmt.Sprintf("$%.6f 额度", float64(quota)/QuotaPerUnit)
- } else {
- return fmt.Sprintf("%d 点额度", quota)
- }
-}
diff --git a/common/model-ratio.go b/common/model-ratio.go
index 35a96397..c4788f2c 100644
--- a/common/model-ratio.go
+++ b/common/model-ratio.go
@@ -2,6 +2,7 @@ package common
import (
"encoding/json"
+ "github.com/songquanpeng/one-api/common/logger"
"strings"
"time"
)
@@ -29,6 +30,12 @@ var DalleImagePromptLengthLimitations = map[string]int{
"dall-e-3": 4000,
}
+const (
+ USD2RMB = 7
+ USD = 500 // $0.002 = 1 -> $1 = 500
+ RMB = USD / USD2RMB
+)
+
// ModelRatio
// https://platform.openai.com/docs/models/model-endpoint-compatibility
// https://cloud.baidu.com/doc/WENXINWORKSHOP/s/Blfmc9dlf
@@ -37,53 +44,62 @@ var DalleImagePromptLengthLimitations = map[string]int{
// 1 === $0.002 / 1K tokens
// 1 === ¥0.014 / 1k tokens
var ModelRatio = map[string]float64{
- "gpt-4": 15,
- "gpt-4-0314": 15,
- "gpt-4-0613": 15,
- "gpt-4-32k": 30,
- "gpt-4-32k-0314": 30,
- "gpt-4-32k-0613": 30,
- "gpt-4-1106-preview": 5, // $0.01 / 1K tokens
- "gpt-4-gizmo": 30, // $0.06 / 1K tokens
- "gpt-4-vision-preview": 5, // $0.01 / 1K tokens
- "gpt-3.5-turbo": 0.75, // $0.0015 / 1K tokens
- "gpt-3.5-turbo-0301": 0.75,
- "gpt-3.5-turbo-0613": 0.75,
- "gpt-3.5-turbo-16k": 1.5, // $0.003 / 1K tokens
- "gpt-3.5-turbo-16k-0613": 1.5,
- "gpt-3.5-turbo-instruct": 0.75, // $0.0015 / 1K tokens
- "gpt-3.5-turbo-1106": 0.5, // $0.001 / 1K tokens
- "davinci-002": 1, // $0.002 / 1K tokens
- "babbage-002": 0.2, // $0.0004 / 1K tokens
- "text-ada-001": 0.2,
- "text-babbage-001": 0.25,
- "text-curie-001": 1,
- "text-davinci-002": 10,
- "text-davinci-003": 10,
- "text-davinci-edit-001": 10,
- "code-davinci-edit-001": 10,
- "whisper-1": 15, // $0.006 / minute -> $0.006 / 150 words -> $0.006 / 200 tokens -> $0.03 / 1k tokens
- "tts-1": 7.5, // $0.015 / 1K characters
- "tts-1-1106": 7.5,
- "tts-1-hd": 15, // $0.030 / 1K characters
- "tts-1-hd-1106": 15,
- "davinci": 10,
- "curie": 10,
- "babbage": 10,
- "ada": 10,
- "text-embedding-ada-002": 0.05,
- "text-search-ada-doc-001": 10,
- "text-moderation-stable": 0.1,
- "text-moderation-latest": 0.1,
- "dall-e-2": 8, // $0.016 - $0.020 / image
- "dall-e-3": 20, // $0.040 - $0.120 / image
- "claude-instant-1": 0.815, // $1.63 / 1M tokens
- "claude-2": 5.51, // $11.02 / 1M tokens
- "claude-2.0": 5.51, // $11.02 / 1M tokens
- "claude-2.1": 5.51, // $11.02 / 1M tokens
- "ERNIE-Bot": 0.8572, // ¥0.012 / 1k tokens
- "ERNIE-Bot-turbo": 0.5715, // ¥0.008 / 1k tokens
- "ERNIE-Bot-4": 8.572, // ¥0.12 / 1k tokens
+
+ // https://openai.com/pricing
+ "gpt-4": 15,
+ "gpt-4-0314": 15,
+ "gpt-4-0613": 15,
+ "gpt-4-gizmo": 15,
+ "gpt-4-32k": 30,
+ "gpt-4-32k-0314": 30,
+ "gpt-4-32k-0613": 30,
+ "gpt-4-1106-preview": 5, // $0.01 / 1K tokens
+ "gpt-4-0125-preview": 5, // $0.01 / 1K tokens
+ "gpt-4-turbo-preview": 5, // $0.01 / 1K tokens
+ "gpt-4-vision-preview": 5, // $0.01 / 1K tokens
+ "gpt-3.5-turbo": 0.75, // $0.0015 / 1K tokens
+ "gpt-3.5-turbo-0301": 0.75,
+ "gpt-3.5-turbo-0613": 0.75,
+ "gpt-3.5-turbo-16k": 1.5, // $0.003 / 1K tokens
+ "gpt-3.5-turbo-16k-0613": 1.5,
+ "gpt-3.5-turbo-instruct": 0.75, // $0.0015 / 1K tokens
+ "gpt-3.5-turbo-1106": 0.5, // $0.001 / 1K tokens
+ "gpt-3.5-turbo-0125": 0.25, // $0.0005 / 1K tokens
+ "davinci-002": 1, // $0.002 / 1K tokens
+ "babbage-002": 0.2, // $0.0004 / 1K tokens
+ "text-ada-001": 0.2,
+ "text-babbage-001": 0.25,
+ "text-curie-001": 1,
+ "text-davinci-002": 10,
+ "text-davinci-003": 10,
+ "text-davinci-edit-001": 10,
+ "code-davinci-edit-001": 10,
+ "whisper-1": 15, // $0.006 / minute -> $0.006 / 150 words -> $0.006 / 200 tokens -> $0.03 / 1k tokens
+ "tts-1": 7.5, // $0.015 / 1K characters
+ "tts-1-1106": 7.5,
+ "tts-1-hd": 15, // $0.030 / 1K characters
+ "tts-1-hd-1106": 15,
+ "davinci": 10,
+ "curie": 10,
+ "babbage": 10,
+ "ada": 10,
+ "text-embedding-ada-002": 0.05,
+ "text-embedding-3-small": 0.01,
+ "text-embedding-3-large": 0.065,
+ "text-search-ada-doc-001": 10,
+ "text-moderation-stable": 0.1,
+ "text-moderation-latest": 0.1,
+ "dall-e-2": 8, // $0.016 - $0.020 / image
+ "dall-e-3": 20, // $0.040 - $0.120 / image
+ "claude-instant-1": 0.815, // $1.63 / 1M tokens
+ "claude-2": 5.51, // $11.02 / 1M tokens
+ "claude-2.0": 5.51, // $11.02 / 1M tokens
+ "claude-2.1": 5.51, // $11.02 / 1M tokens
+ // https://cloud.baidu.com/doc/WENXINWORKSHOP/s/hlrk4akp7
+ "ERNIE-Bot": 0.8572, // ¥0.012 / 1k tokens
+ "ERNIE-Bot-turbo": 0.5715, // ¥0.008 / 1k tokens
+ "ERNIE-Bot-4": 0.12 * RMB, // ¥0.12 / 1k tokens
+ "ERNIE-Bot-8k": 0.024 * RMB,
"Embedding-V1": 0.1429, // ¥0.002 / 1k tokens
"PaLM-2": 1,
"gemini-pro": 1, // $0.00025 / 1k characters -> $0.001 / 1k tokens
@@ -98,17 +114,27 @@ var ModelRatio = map[string]float64{
"qwen-max-longcontext": 1.4286, // ¥0.02 / 1k tokens
"text-embedding-v1": 0.05, // ¥0.0007 / 1k tokens
"SparkDesk": 1.2858, // ¥0.018 / 1k tokens
+ "SparkDesk-v1.1": 1.2858, // ¥0.018 / 1k tokens
+ "SparkDesk-v2.1": 1.2858, // ¥0.018 / 1k tokens
+ "SparkDesk-v3.1": 1.2858, // ¥0.018 / 1k tokens
+ "SparkDesk-v3.5": 1.2858, // ¥0.018 / 1k tokens
"360GPT_S2_V9": 0.8572, // ¥0.012 / 1k tokens
"embedding-bert-512-v1": 0.0715, // ¥0.001 / 1k tokens
"embedding_s1_v1": 0.0715, // ¥0.001 / 1k tokens
"semantic_similarity_s1_v1": 0.0715, // ¥0.001 / 1k tokens
"hunyuan": 7.143, // ¥0.1 / 1k tokens // https://cloud.tencent.com/document/product/1729/97731#e0e6be58-60c8-469f-bdeb-6c264ce3b4d0
+ "ChatStd": 0.01 * RMB,
+ "ChatPro": 0.1 * RMB,
+ // https://platform.moonshot.cn/pricing
+ "moonshot-v1-8k": 0.012 * RMB,
+ "moonshot-v1-32k": 0.024 * RMB,
+ "moonshot-v1-128k": 0.06 * RMB,
}
func ModelRatio2JSONString() string {
jsonBytes, err := json.Marshal(ModelRatio)
if err != nil {
- SysError("error marshalling model ratio: " + err.Error())
+ logger.SysError("error marshalling model ratio: " + err.Error())
}
return string(jsonBytes)
}
@@ -127,14 +153,37 @@ func GetModelRatio(name string) float64 {
return ModelRatio["gpt-4-gizmo"]
}
if !ok {
- SysError("model ratio not found: " + name)
+ logger.SysError("model ratio not found: " + name)
return 30
}
return ratio
}
+var CompletionRatio = map[string]float64{}
+
+func CompletionRatio2JSONString() string {
+ jsonBytes, err := json.Marshal(CompletionRatio)
+ if err != nil {
+ logger.SysError("error marshalling completion ratio: " + err.Error())
+ }
+ return string(jsonBytes)
+}
+
+func UpdateCompletionRatioByJSONString(jsonStr string) error {
+ CompletionRatio = make(map[string]float64)
+ return json.Unmarshal([]byte(jsonStr), &CompletionRatio)
+}
+
func GetCompletionRatio(name string) float64 {
+ if ratio, ok := CompletionRatio[name]; ok {
+ return ratio
+ }
if strings.HasPrefix(name, "gpt-3.5") {
+ if strings.HasSuffix(name, "0125") {
+ // https://openai.com/blog/new-embedding-models-and-api-updates
+ // Updated GPT-3.5 Turbo model and lower pricing
+ return 3
+ }
if strings.HasSuffix(name, "1106") {
return 2
}
diff --git a/common/redis.go b/common/redis.go
index 12c477b8..f3205567 100644
--- a/common/redis.go
+++ b/common/redis.go
@@ -3,6 +3,7 @@ package common
import (
"context"
"github.com/go-redis/redis/v8"
+ "github.com/songquanpeng/one-api/common/logger"
"os"
"time"
)
@@ -14,18 +15,18 @@ var RedisEnabled = true
func InitRedisClient() (err error) {
if os.Getenv("REDIS_CONN_STRING") == "" {
RedisEnabled = false
- SysLog("REDIS_CONN_STRING not set, Redis is not enabled")
+ logger.SysLog("REDIS_CONN_STRING not set, Redis is not enabled")
return nil
}
if os.Getenv("SYNC_FREQUENCY") == "" {
RedisEnabled = false
- SysLog("SYNC_FREQUENCY not set, Redis is disabled")
+ logger.SysLog("SYNC_FREQUENCY not set, Redis is disabled")
return nil
}
- SysLog("Redis is enabled")
+ logger.SysLog("Redis is enabled")
opt, err := redis.ParseURL(os.Getenv("REDIS_CONN_STRING"))
if err != nil {
- FatalLog("failed to parse Redis connection string: " + err.Error())
+ logger.FatalLog("failed to parse Redis connection string: " + err.Error())
}
RDB = redis.NewClient(opt)
@@ -34,7 +35,7 @@ func InitRedisClient() (err error) {
_, err = RDB.Ping(ctx).Result()
if err != nil {
- FatalLog("Redis ping test failed: " + err.Error())
+ logger.FatalLog("Redis ping test failed: " + err.Error())
}
return err
}
@@ -42,7 +43,7 @@ func InitRedisClient() (err error) {
func ParseRedisOption() *redis.Options {
opt, err := redis.ParseURL(os.Getenv("REDIS_CONN_STRING"))
if err != nil {
- FatalLog("failed to parse Redis connection string: " + err.Error())
+ logger.FatalLog("failed to parse Redis connection string: " + err.Error())
}
return opt
}
diff --git a/common/utils.go b/common/utils.go
index 9a7038e2..24615225 100644
--- a/common/utils.go
+++ b/common/utils.go
@@ -2,215 +2,13 @@ package common
import (
"fmt"
- "github.com/google/uuid"
- "html/template"
- "log"
- "math/rand"
- "net"
- "os"
- "os/exec"
- "runtime"
- "strconv"
- "strings"
- "time"
+ "github.com/songquanpeng/one-api/common/config"
)
-func OpenBrowser(url string) {
- var err error
-
- switch runtime.GOOS {
- case "linux":
- err = exec.Command("xdg-open", url).Start()
- case "windows":
- err = exec.Command("rundll32", "url.dll,FileProtocolHandler", url).Start()
- case "darwin":
- err = exec.Command("open", url).Start()
- }
- if err != nil {
- log.Println(err)
- }
-}
-
-func GetIp() (ip string) {
- ips, err := net.InterfaceAddrs()
- if err != nil {
- log.Println(err)
- return ip
- }
-
- for _, a := range ips {
- if ipNet, ok := a.(*net.IPNet); ok && !ipNet.IP.IsLoopback() {
- if ipNet.IP.To4() != nil {
- ip = ipNet.IP.String()
- if strings.HasPrefix(ip, "10") {
- return
- }
- if strings.HasPrefix(ip, "172") {
- return
- }
- if strings.HasPrefix(ip, "192.168") {
- return
- }
- ip = ""
- }
- }
- }
- return
-}
-
-var sizeKB = 1024
-var sizeMB = sizeKB * 1024
-var sizeGB = sizeMB * 1024
-
-func Bytes2Size(num int64) string {
- numStr := ""
- unit := "B"
- if num/int64(sizeGB) > 1 {
- numStr = fmt.Sprintf("%.2f", float64(num)/float64(sizeGB))
- unit = "GB"
- } else if num/int64(sizeMB) > 1 {
- numStr = fmt.Sprintf("%d", int(float64(num)/float64(sizeMB)))
- unit = "MB"
- } else if num/int64(sizeKB) > 1 {
- numStr = fmt.Sprintf("%d", int(float64(num)/float64(sizeKB)))
- unit = "KB"
+func LogQuota(quota int) string {
+ if config.DisplayInCurrencyEnabled {
+ return fmt.Sprintf("$%.6f 额度", float64(quota)/config.QuotaPerUnit)
} else {
- numStr = fmt.Sprintf("%d", num)
- }
- return numStr + " " + unit
-}
-
-func Seconds2Time(num int) (time string) {
- if num/31104000 > 0 {
- time += strconv.Itoa(num/31104000) + " 年 "
- num %= 31104000
- }
- if num/2592000 > 0 {
- time += strconv.Itoa(num/2592000) + " 个月 "
- num %= 2592000
- }
- if num/86400 > 0 {
- time += strconv.Itoa(num/86400) + " 天 "
- num %= 86400
- }
- if num/3600 > 0 {
- time += strconv.Itoa(num/3600) + " 小时 "
- num %= 3600
- }
- if num/60 > 0 {
- time += strconv.Itoa(num/60) + " 分钟 "
- num %= 60
- }
- time += strconv.Itoa(num) + " 秒"
- return
-}
-
-func Interface2String(inter interface{}) string {
- switch inter.(type) {
- case string:
- return inter.(string)
- case int:
- return fmt.Sprintf("%d", inter.(int))
- case float64:
- return fmt.Sprintf("%f", inter.(float64))
- }
- return "Not Implemented"
-}
-
-func UnescapeHTML(x string) interface{} {
- return template.HTML(x)
-}
-
-func IntMax(a int, b int) int {
- if a >= b {
- return a
- } else {
- return b
+ return fmt.Sprintf("%d 点额度", quota)
}
}
-
-func GetUUID() string {
- code := uuid.New().String()
- code = strings.Replace(code, "-", "", -1)
- return code
-}
-
-const keyChars = "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
-
-func init() {
- rand.Seed(time.Now().UnixNano())
-}
-
-func GenerateKey() string {
- rand.Seed(time.Now().UnixNano())
- key := make([]byte, 48)
- for i := 0; i < 16; i++ {
- key[i] = keyChars[rand.Intn(len(keyChars))]
- }
- uuid_ := GetUUID()
- for i := 0; i < 32; i++ {
- c := uuid_[i]
- if i%2 == 0 && c >= 'a' && c <= 'z' {
- c = c - 'a' + 'A'
- }
- key[i+16] = c
- }
- return string(key)
-}
-
-func GetRandomString(length int) string {
- rand.Seed(time.Now().UnixNano())
- key := make([]byte, length)
- for i := 0; i < length; i++ {
- key[i] = keyChars[rand.Intn(len(keyChars))]
- }
- return string(key)
-}
-
-func GetTimestamp() int64 {
- return time.Now().Unix()
-}
-
-func GetTimeString() string {
- now := time.Now()
- return fmt.Sprintf("%s%d", now.Format("20060102150405"), now.UnixNano()%1e9)
-}
-
-func Max(a int, b int) int {
- if a >= b {
- return a
- } else {
- return b
- }
-}
-
-func GetOrDefault(env string, defaultValue int) int {
- if env == "" || os.Getenv(env) == "" {
- return defaultValue
- }
- num, err := strconv.Atoi(os.Getenv(env))
- if err != nil {
- SysError(fmt.Sprintf("failed to parse %s: %s, using default value: %d", env, err.Error(), defaultValue))
- return defaultValue
- }
- return num
-}
-
-func GetOrDefaultString(env string, defaultValue string) string {
- if env == "" || os.Getenv(env) == "" {
- return defaultValue
- }
- return os.Getenv(env)
-}
-
-func MessageWithRequestId(message string, id string) string {
- return fmt.Sprintf("%s (request id: %s)", message, id)
-}
-
-func String2Int(str string) int {
- num, err := strconv.Atoi(str)
- if err != nil {
- return 0
- }
- return num
-}
diff --git a/controller/billing.go b/controller/billing.go
index e27fd614..7317913d 100644
--- a/controller/billing.go
+++ b/controller/billing.go
@@ -2,9 +2,9 @@ package controller
import (
"github.com/gin-gonic/gin"
- "one-api/common"
- "one-api/model"
- "one-api/relay/channel/openai"
+ "github.com/songquanpeng/one-api/common/config"
+ "github.com/songquanpeng/one-api/model"
+ relaymodel "github.com/songquanpeng/one-api/relay/model"
)
func GetSubscription(c *gin.Context) {
@@ -13,7 +13,7 @@ func GetSubscription(c *gin.Context) {
var err error
var token *model.Token
var expiredTime int64
- if common.DisplayTokenStatEnabled {
+ if config.DisplayTokenStatEnabled {
tokenId := c.GetInt("token_id")
token, err = model.GetTokenById(tokenId)
expiredTime = token.ExpiredTime
@@ -22,13 +22,15 @@ func GetSubscription(c *gin.Context) {
} else {
userId := c.GetInt("id")
remainQuota, err = model.GetUserQuota(userId)
- usedQuota, err = model.GetUserUsedQuota(userId)
+		if err == nil {
+			usedQuota, err = model.GetUserUsedQuota(userId)
+		}
}
if expiredTime <= 0 {
expiredTime = 0
}
if err != nil {
- Error := openai.Error{
+ Error := relaymodel.Error{
Message: err.Error(),
Type: "upstream_error",
}
@@ -39,8 +41,8 @@ func GetSubscription(c *gin.Context) {
}
quota := remainQuota + usedQuota
amount := float64(quota)
- if common.DisplayInCurrencyEnabled {
- amount /= common.QuotaPerUnit
+ if config.DisplayInCurrencyEnabled {
+ amount /= config.QuotaPerUnit
}
if token != nil && token.UnlimitedQuota {
amount = 100000000
@@ -61,7 +63,7 @@ func GetUsage(c *gin.Context) {
var quota int
var err error
var token *model.Token
- if common.DisplayTokenStatEnabled {
+ if config.DisplayTokenStatEnabled {
tokenId := c.GetInt("token_id")
token, err = model.GetTokenById(tokenId)
quota = token.UsedQuota
@@ -70,7 +72,7 @@ func GetUsage(c *gin.Context) {
quota, err = model.GetUserUsedQuota(userId)
}
if err != nil {
- Error := openai.Error{
+ Error := relaymodel.Error{
Message: err.Error(),
Type: "one_api_error",
}
@@ -80,8 +82,8 @@ func GetUsage(c *gin.Context) {
return
}
amount := float64(quota)
- if common.DisplayInCurrencyEnabled {
- amount /= common.QuotaPerUnit
+ if config.DisplayInCurrencyEnabled {
+ amount /= config.QuotaPerUnit
}
usage := OpenAIUsageResponse{
Object: "list",
diff --git a/controller/channel-billing.go b/controller/channel-billing.go
index 29346cde..abeab26a 100644
--- a/controller/channel-billing.go
+++ b/controller/channel-billing.go
@@ -4,11 +4,13 @@ import (
"encoding/json"
"errors"
"fmt"
+ "github.com/songquanpeng/one-api/common"
+ "github.com/songquanpeng/one-api/common/config"
+ "github.com/songquanpeng/one-api/common/logger"
+ "github.com/songquanpeng/one-api/model"
+ "github.com/songquanpeng/one-api/relay/util"
"io"
"net/http"
- "one-api/common"
- "one-api/model"
- "one-api/relay/util"
"strconv"
"time"
@@ -314,7 +316,7 @@ func updateAllChannelsBalance() error {
disableChannel(channel.Id, channel.Name, "余额不足")
}
}
- time.Sleep(common.RequestInterval)
+ time.Sleep(config.RequestInterval)
}
return nil
}
@@ -339,8 +341,8 @@ func UpdateAllChannelsBalance(c *gin.Context) {
func AutomaticallyUpdateChannels(frequency int) {
for {
time.Sleep(time.Duration(frequency) * time.Minute)
- common.SysLog("updating all channels")
+ logger.SysLog("updating all channels")
_ = updateAllChannelsBalance()
- common.SysLog("channels update done")
+ logger.SysLog("channels update done")
}
}
diff --git a/controller/channel-test.go b/controller/channel-test.go
index f64f0ee3..b498f4f1 100644
--- a/controller/channel-test.go
+++ b/controller/channel-test.go
@@ -5,12 +5,18 @@ import (
"encoding/json"
"errors"
"fmt"
+ "github.com/songquanpeng/one-api/common"
+ "github.com/songquanpeng/one-api/common/config"
+ "github.com/songquanpeng/one-api/common/logger"
+ "github.com/songquanpeng/one-api/model"
+ "github.com/songquanpeng/one-api/relay/constant"
+ "github.com/songquanpeng/one-api/relay/helper"
+ relaymodel "github.com/songquanpeng/one-api/relay/model"
+ "github.com/songquanpeng/one-api/relay/util"
"io"
"net/http"
- "one-api/common"
- "one-api/model"
- "one-api/relay/channel/openai"
- "one-api/relay/util"
+ "net/http/httptest"
+ "net/url"
"strconv"
"sync"
"time"
@@ -18,87 +24,13 @@ import (
"github.com/gin-gonic/gin"
)
-func testChannel(channel *model.Channel, request openai.ChatRequest) (err error, openaiErr *openai.Error) {
- switch channel.Type {
- case common.ChannelTypePaLM:
- fallthrough
- case common.ChannelTypeGemini:
- fallthrough
- case common.ChannelTypeAnthropic:
- fallthrough
- case common.ChannelTypeBaidu:
- fallthrough
- case common.ChannelTypeZhipu:
- fallthrough
- case common.ChannelTypeAli:
- fallthrough
- case common.ChannelType360:
- fallthrough
- case common.ChannelTypeXunfei:
- return errors.New("该渠道类型当前版本不支持测试,请手动测试"), nil
- case common.ChannelTypeAzure:
- request.Model = "gpt-35-turbo"
- defer func() {
- if err != nil {
- err = errors.New("请确保已在 Azure 上创建了 gpt-35-turbo 模型,并且 apiVersion 已正确填写!")
- }
- }()
- default:
- request.Model = "gpt-3.5-turbo"
- }
- requestURL := common.ChannelBaseURLs[channel.Type]
- if channel.Type == common.ChannelTypeAzure {
- requestURL = util.GetFullRequestURL(channel.GetBaseURL(), fmt.Sprintf("/openai/deployments/%s/chat/completions?api-version=2023-03-15-preview", request.Model), channel.Type)
- } else {
- if baseURL := channel.GetBaseURL(); len(baseURL) > 0 {
- requestURL = baseURL
- }
-
- requestURL = util.GetFullRequestURL(requestURL, "/v1/chat/completions", channel.Type)
- }
- jsonData, err := json.Marshal(request)
- if err != nil {
- return err, nil
- }
- req, err := http.NewRequest("POST", requestURL, bytes.NewBuffer(jsonData))
- if err != nil {
- return err, nil
- }
- if channel.Type == common.ChannelTypeAzure {
- req.Header.Set("api-key", channel.Key)
- } else {
- req.Header.Set("Authorization", "Bearer "+channel.Key)
- }
- req.Header.Set("Content-Type", "application/json")
- resp, err := util.HTTPClient.Do(req)
- if err != nil {
- return err, nil
- }
- defer resp.Body.Close()
- var response openai.SlimTextResponse
- body, err := io.ReadAll(resp.Body)
- if err != nil {
- return err, nil
- }
- err = json.Unmarshal(body, &response)
- if err != nil {
- return fmt.Errorf("Error: %s\nResp body: %s", err, body), nil
- }
- if response.Usage.CompletionTokens == 0 {
- if response.Error.Message == "" {
- response.Error.Message = "补全 tokens 非预期返回 0"
- }
- return errors.New(fmt.Sprintf("type %s, code %v, message %s", response.Error.Type, response.Error.Code, response.Error.Message)), &response.Error
- }
- return nil, nil
-}
-
-func buildTestRequest() *openai.ChatRequest {
- testRequest := &openai.ChatRequest{
- Model: "", // this will be set later
+func buildTestRequest() *relaymodel.GeneralOpenAIRequest {
+ testRequest := &relaymodel.GeneralOpenAIRequest{
MaxTokens: 1,
+ Stream: false,
+ Model: "gpt-3.5-turbo",
}
- testMessage := openai.Message{
+ testMessage := relaymodel.Message{
Role: "user",
Content: "hi",
}
@@ -106,6 +38,65 @@ func buildTestRequest() *openai.ChatRequest {
return testRequest
}
+func testChannel(channel *model.Channel) (err error, openaiErr *relaymodel.Error) {
+ w := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(w)
+ c.Request = &http.Request{
+ Method: "POST",
+ URL: &url.URL{Path: "/v1/chat/completions"},
+ Body: nil,
+ Header: make(http.Header),
+ }
+ c.Request.Header.Set("Authorization", "Bearer "+channel.Key)
+ c.Request.Header.Set("Content-Type", "application/json")
+ c.Set("channel", channel.Type)
+ c.Set("base_url", channel.GetBaseURL())
+ meta := util.GetRelayMeta(c)
+ apiType := constant.ChannelType2APIType(channel.Type)
+ adaptor := helper.GetAdaptor(apiType)
+ if adaptor == nil {
+ return fmt.Errorf("invalid api type: %d, adaptor is nil", apiType), nil
+ }
+ adaptor.Init(meta)
+ modelName := adaptor.GetModelList()[0]
+ request := buildTestRequest()
+ request.Model = modelName
+ meta.OriginModelName, meta.ActualModelName = modelName, modelName
+ convertedRequest, err := adaptor.ConvertRequest(c, constant.RelayModeChatCompletions, request)
+ if err != nil {
+ return err, nil
+ }
+ jsonData, err := json.Marshal(convertedRequest)
+ if err != nil {
+ return err, nil
+ }
+ requestBody := bytes.NewBuffer(jsonData)
+ c.Request.Body = io.NopCloser(requestBody)
+ resp, err := adaptor.DoRequest(c, meta, requestBody)
+ if err != nil {
+ return err, nil
+ }
+ if resp.StatusCode != http.StatusOK {
+ err := util.RelayErrorHandler(resp)
+ return fmt.Errorf("status code %d: %s", resp.StatusCode, err.Error.Message), &err.Error
+ }
+ usage, respErr := adaptor.DoResponse(c, resp, meta)
+ if respErr != nil {
+ return fmt.Errorf("%s", respErr.Error.Message), &respErr.Error
+ }
+ if usage == nil {
+ return errors.New("usage is nil"), nil
+ }
+ result := w.Result()
+ // print result.Body
+ respBody, err := io.ReadAll(result.Body)
+ if err != nil {
+ return err, nil
+ }
+ logger.SysLog(fmt.Sprintf("testing channel #%d, response: \n%s", channel.Id, string(respBody)))
+ return nil, nil
+}
+
func TestChannel(c *gin.Context) {
id, err := strconv.Atoi(c.Param("id"))
if err != nil {
@@ -123,9 +114,8 @@ func TestChannel(c *gin.Context) {
})
return
}
- testRequest := buildTestRequest()
tik := time.Now()
- err, _ = testChannel(channel, *testRequest)
+ err, _ = testChannel(channel)
tok := time.Now()
milliseconds := tok.Sub(tik).Milliseconds()
go channel.UpdateResponseTime(milliseconds)
@@ -150,12 +140,12 @@ var testAllChannelsLock sync.Mutex
var testAllChannelsRunning bool = false
func notifyRootUser(subject string, content string) {
- if common.RootUserEmail == "" {
- common.RootUserEmail = model.GetRootUserEmail()
+ if config.RootUserEmail == "" {
+ config.RootUserEmail = model.GetRootUserEmail()
}
- err := common.SendEmail(subject, common.RootUserEmail, content)
+ err := common.SendEmail(subject, config.RootUserEmail, content)
if err != nil {
- common.SysError(fmt.Sprintf("failed to send email: %s", err.Error()))
+ logger.SysError(fmt.Sprintf("failed to send email: %s", err.Error()))
}
}
@@ -176,8 +166,8 @@ func enableChannel(channelId int, channelName string) {
}
func testAllChannels(notify bool) error {
- if common.RootUserEmail == "" {
- common.RootUserEmail = model.GetRootUserEmail()
+ if config.RootUserEmail == "" {
+ config.RootUserEmail = model.GetRootUserEmail()
}
testAllChannelsLock.Lock()
if testAllChannelsRunning {
@@ -190,8 +180,7 @@ func testAllChannels(notify bool) error {
if err != nil {
return err
}
- testRequest := buildTestRequest()
- var disableThreshold = int64(common.ChannelDisableThreshold * 1000)
+ var disableThreshold = int64(config.ChannelDisableThreshold * 1000)
if disableThreshold == 0 {
disableThreshold = 10000000 // a impossible value
}
@@ -199,7 +188,7 @@ func testAllChannels(notify bool) error {
for _, channel := range channels {
isChannelEnabled := channel.Status == common.ChannelStatusEnabled
tik := time.Now()
- err, openaiErr := testChannel(channel, *testRequest)
+ err, openaiErr := testChannel(channel)
tok := time.Now()
milliseconds := tok.Sub(tik).Milliseconds()
if isChannelEnabled && milliseconds > disableThreshold {
@@ -213,15 +202,15 @@ func testAllChannels(notify bool) error {
enableChannel(channel.Id, channel.Name)
}
channel.UpdateResponseTime(milliseconds)
- time.Sleep(common.RequestInterval)
+ time.Sleep(config.RequestInterval)
}
testAllChannelsLock.Lock()
testAllChannelsRunning = false
testAllChannelsLock.Unlock()
if notify {
- err := common.SendEmail("通道测试完成", common.RootUserEmail, "通道测试完成,如果没有收到禁用通知,说明所有通道都正常")
+ err := common.SendEmail("通道测试完成", config.RootUserEmail, "通道测试完成,如果没有收到禁用通知,说明所有通道都正常")
if err != nil {
- common.SysError(fmt.Sprintf("failed to send email: %s", err.Error()))
+ logger.SysError(fmt.Sprintf("failed to send email: %s", err.Error()))
}
}
}()
@@ -247,8 +236,8 @@ func TestAllChannels(c *gin.Context) {
func AutomaticallyTestChannels(frequency int) {
for {
time.Sleep(time.Duration(frequency) * time.Minute)
- common.SysLog("testing all channels")
+ logger.SysLog("testing all channels")
_ = testAllChannels(false)
- common.SysLog("channel test finished")
+ logger.SysLog("channel test finished")
}
}
diff --git a/controller/channel.go b/controller/channel.go
index 904abc23..bdfa00d9 100644
--- a/controller/channel.go
+++ b/controller/channel.go
@@ -2,9 +2,10 @@ package controller
import (
"github.com/gin-gonic/gin"
+ "github.com/songquanpeng/one-api/common/config"
+ "github.com/songquanpeng/one-api/common/helper"
+ "github.com/songquanpeng/one-api/model"
"net/http"
- "one-api/common"
- "one-api/model"
"strconv"
"strings"
)
@@ -14,7 +15,7 @@ func GetAllChannels(c *gin.Context) {
if p < 0 {
p = 0
}
- channels, err := model.GetAllChannels(p*common.ItemsPerPage, common.ItemsPerPage, false)
+ channels, err := model.GetAllChannels(p*config.ItemsPerPage, config.ItemsPerPage, false)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
@@ -83,7 +84,7 @@ func AddChannel(c *gin.Context) {
})
return
}
- channel.CreatedTime = common.GetTimestamp()
+ channel.CreatedTime = helper.GetTimestamp()
keys := strings.Split(channel.Key, "\n")
channels := make([]model.Channel, 0, len(keys))
for _, key := range keys {
diff --git a/controller/github.go b/controller/github.go
index ee995379..7d7fa106 100644
--- a/controller/github.go
+++ b/controller/github.go
@@ -7,9 +7,12 @@ import (
"fmt"
"github.com/gin-contrib/sessions"
"github.com/gin-gonic/gin"
+ "github.com/songquanpeng/one-api/common"
+ "github.com/songquanpeng/one-api/common/config"
+ "github.com/songquanpeng/one-api/common/helper"
+ "github.com/songquanpeng/one-api/common/logger"
+ "github.com/songquanpeng/one-api/model"
"net/http"
- "one-api/common"
- "one-api/model"
"strconv"
"time"
)
@@ -30,7 +33,7 @@ func getGitHubUserInfoByCode(code string) (*GitHubUser, error) {
if code == "" {
return nil, errors.New("无效的参数")
}
- values := map[string]string{"client_id": common.GitHubClientId, "client_secret": common.GitHubClientSecret, "code": code}
+ values := map[string]string{"client_id": config.GitHubClientId, "client_secret": config.GitHubClientSecret, "code": code}
jsonData, err := json.Marshal(values)
if err != nil {
return nil, err
@@ -46,7 +49,7 @@ func getGitHubUserInfoByCode(code string) (*GitHubUser, error) {
}
res, err := client.Do(req)
if err != nil {
- common.SysLog(err.Error())
+ logger.SysLog(err.Error())
return nil, errors.New("无法连接至 GitHub 服务器,请稍后重试!")
}
defer res.Body.Close()
@@ -62,7 +65,7 @@ func getGitHubUserInfoByCode(code string) (*GitHubUser, error) {
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", oAuthResponse.AccessToken))
res2, err := client.Do(req)
if err != nil {
- common.SysLog(err.Error())
+ logger.SysLog(err.Error())
return nil, errors.New("无法连接至 GitHub 服务器,请稍后重试!")
}
defer res2.Body.Close()
@@ -93,7 +96,7 @@ func GitHubOAuth(c *gin.Context) {
return
}
- if !common.GitHubOAuthEnabled {
+ if !config.GitHubOAuthEnabled {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "管理员未开启通过 GitHub 登录以及注册",
@@ -122,7 +125,7 @@ func GitHubOAuth(c *gin.Context) {
return
}
} else {
- if common.RegisterEnabled {
+ if config.RegisterEnabled {
user.Username = "github_" + strconv.Itoa(model.GetMaxUserId()+1)
if githubUser.Name != "" {
user.DisplayName = githubUser.Name
@@ -160,7 +163,7 @@ func GitHubOAuth(c *gin.Context) {
}
func GitHubBind(c *gin.Context) {
- if !common.GitHubOAuthEnabled {
+ if !config.GitHubOAuthEnabled {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "管理员未开启通过 GitHub 登录以及注册",
@@ -216,7 +219,7 @@ func GitHubBind(c *gin.Context) {
func GenerateOAuthCode(c *gin.Context) {
session := sessions.Default(c)
- state := common.GetRandomString(12)
+ state := helper.GetRandomString(12)
session.Set("oauth_state", state)
err := session.Save()
if err != nil {
diff --git a/controller/group.go b/controller/group.go
index 2b2f6006..128a3527 100644
--- a/controller/group.go
+++ b/controller/group.go
@@ -2,13 +2,13 @@ package controller
import (
"github.com/gin-gonic/gin"
+ "github.com/songquanpeng/one-api/common"
"net/http"
- "one-api/common"
)
func GetGroups(c *gin.Context) {
groupNames := make([]string, 0)
- for groupName, _ := range common.GroupRatio {
+ for groupName := range common.GroupRatio {
groupNames = append(groupNames, groupName)
}
c.JSON(http.StatusOK, gin.H{
diff --git a/controller/log.go b/controller/log.go
index 3265ce20..4e582982 100644
--- a/controller/log.go
+++ b/controller/log.go
@@ -1,12 +1,13 @@
package controller
import (
- "net/http"
- "one-api/common"
- "one-api/model"
- "strconv"
"github.com/gin-gonic/gin"
+ "github.com/songquanpeng/one-api/common/config"
+ "github.com/songquanpeng/one-api/model"
+ "net/http"
+ "strconv"
+
)
func GetAllLogs(c *gin.Context) {
@@ -21,7 +22,7 @@ func GetAllLogs(c *gin.Context) {
tokenName := c.Query("token_name")
modelName := c.Query("model_name")
channel, _ := strconv.Atoi(c.Query("channel"))
- logs, err := model.GetAllLogs(logType, startTimestamp, endTimestamp, modelName, username, tokenName, p*common.ItemsPerPage, common.ItemsPerPage, channel)
+ logs, err := model.GetAllLogs(logType, startTimestamp, endTimestamp, modelName, username, tokenName, p*config.ItemsPerPage, config.ItemsPerPage, channel)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
@@ -77,7 +78,7 @@ func GetUserLogs(c *gin.Context) {
endTimestamp, _ := strconv.ParseInt(c.Query("end_timestamp"), 10, 64)
tokenName := c.Query("token_name")
modelName := c.Query("model_name")
- logs, err := model.GetUserLogs(userId, logType, startTimestamp, endTimestamp, modelName, tokenName, p*common.ItemsPerPage, common.ItemsPerPage)
+ logs, err := model.GetUserLogs(userId, logType, startTimestamp, endTimestamp, modelName, tokenName, p*config.ItemsPerPage, config.ItemsPerPage)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
diff --git a/controller/misc.go b/controller/misc.go
index 2bcbb41f..036bdbd1 100644
--- a/controller/misc.go
+++ b/controller/misc.go
@@ -3,9 +3,10 @@ package controller
import (
"encoding/json"
"fmt"
+ "github.com/songquanpeng/one-api/common"
+ "github.com/songquanpeng/one-api/common/config"
+ "github.com/songquanpeng/one-api/model"
"net/http"
- "one-api/common"
- "one-api/model"
"strings"
"github.com/gin-gonic/gin"
@@ -18,55 +19,55 @@ func GetStatus(c *gin.Context) {
"data": gin.H{
"version": common.Version,
"start_time": common.StartTime,
- "email_verification": common.EmailVerificationEnabled,
- "github_oauth": common.GitHubOAuthEnabled,
- "github_client_id": common.GitHubClientId,
- "system_name": common.SystemName,
- "logo": common.Logo,
- "footer_html": common.Footer,
- "wechat_qrcode": common.WeChatAccountQRCodeImageURL,
- "wechat_login": common.WeChatAuthEnabled,
- "server_address": common.ServerAddress,
- "turnstile_check": common.TurnstileCheckEnabled,
- "turnstile_site_key": common.TurnstileSiteKey,
- "top_up_link": common.TopUpLink,
- "chat_link": common.ChatLink,
- "quota_per_unit": common.QuotaPerUnit,
- "display_in_currency": common.DisplayInCurrencyEnabled,
+ "email_verification": config.EmailVerificationEnabled,
+ "github_oauth": config.GitHubOAuthEnabled,
+ "github_client_id": config.GitHubClientId,
+ "system_name": config.SystemName,
+ "logo": config.Logo,
+ "footer_html": config.Footer,
+ "wechat_qrcode": config.WeChatAccountQRCodeImageURL,
+ "wechat_login": config.WeChatAuthEnabled,
+ "server_address": config.ServerAddress,
+ "turnstile_check": config.TurnstileCheckEnabled,
+ "turnstile_site_key": config.TurnstileSiteKey,
+ "top_up_link": config.TopUpLink,
+ "chat_link": config.ChatLink,
+ "quota_per_unit": config.QuotaPerUnit,
+ "display_in_currency": config.DisplayInCurrencyEnabled,
},
})
return
}
func GetNotice(c *gin.Context) {
- common.OptionMapRWMutex.RLock()
- defer common.OptionMapRWMutex.RUnlock()
+ config.OptionMapRWMutex.RLock()
+ defer config.OptionMapRWMutex.RUnlock()
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
- "data": common.OptionMap["Notice"],
+ "data": config.OptionMap["Notice"],
})
return
}
func GetAbout(c *gin.Context) {
- common.OptionMapRWMutex.RLock()
- defer common.OptionMapRWMutex.RUnlock()
+ config.OptionMapRWMutex.RLock()
+ defer config.OptionMapRWMutex.RUnlock()
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
- "data": common.OptionMap["About"],
+ "data": config.OptionMap["About"],
})
return
}
func GetHomePageContent(c *gin.Context) {
- common.OptionMapRWMutex.RLock()
- defer common.OptionMapRWMutex.RUnlock()
+ config.OptionMapRWMutex.RLock()
+ defer config.OptionMapRWMutex.RUnlock()
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
- "data": common.OptionMap["HomePageContent"],
+ "data": config.OptionMap["HomePageContent"],
})
return
}
@@ -80,9 +81,9 @@ func SendEmailVerification(c *gin.Context) {
})
return
}
- if common.EmailDomainRestrictionEnabled {
+ if config.EmailDomainRestrictionEnabled {
allowed := false
- for _, domain := range common.EmailDomainWhitelist {
+ for _, domain := range config.EmailDomainWhitelist {
if strings.HasSuffix(email, "@"+domain) {
allowed = true
break
@@ -105,10 +106,10 @@ func SendEmailVerification(c *gin.Context) {
}
code := common.GenerateVerificationCode(6)
common.RegisterVerificationCodeWithKey(email, code, common.EmailVerificationPurpose)
- subject := fmt.Sprintf("%s邮箱验证邮件", common.SystemName)
+ subject := fmt.Sprintf("%s邮箱验证邮件", config.SystemName)
 	content := fmt.Sprintf("<p>您好,你正在进行%s邮箱验证。</p>"+
 		"<p>您的验证码为: <strong>%s</strong></p>"+
-		"<p>验证码 %d 分钟内有效,如果不是本人操作,请忽略。</p>", common.SystemName, code, common.VerificationValidMinutes)
+		"<p>验证码 %d 分钟内有效,如果不是本人操作,请忽略。</p>", config.SystemName, code, common.VerificationValidMinutes)
err := common.SendEmail(subject, email, content)
if err != nil {
c.JSON(http.StatusOK, gin.H{
@@ -142,12 +143,12 @@ func SendPasswordResetEmail(c *gin.Context) {
}
code := common.GenerateVerificationCode(0)
common.RegisterVerificationCodeWithKey(email, code, common.PasswordResetPurpose)
- link := fmt.Sprintf("%s/user/reset?email=%s&token=%s", common.ServerAddress, email, code)
- subject := fmt.Sprintf("%s密码重置", common.SystemName)
+ link := fmt.Sprintf("%s/user/reset?email=%s&token=%s", config.ServerAddress, email, code)
+ subject := fmt.Sprintf("%s密码重置", config.SystemName)
 	content := fmt.Sprintf("<p>您好,你正在进行%s密码重置。</p>"+
 		"<p>点击 <a href='%s'>此处</a> 进行密码重置。</p>"+
 		"<p>如果链接无法点击,请尝试点击下面的链接或将其复制到浏览器中打开:<br> %s </p>"+
-		"<p>重置链接 %d 分钟内有效,如果不是本人操作,请忽略。</p>", common.SystemName, link, link, common.VerificationValidMinutes)
+		"<p>重置链接 %d 分钟内有效,如果不是本人操作,请忽略。</p>", config.SystemName, link, link, common.VerificationValidMinutes)
err := common.SendEmail(subject, email, content)
if err != nil {
c.JSON(http.StatusOK, gin.H{
diff --git a/controller/model.go b/controller/model.go
index b7ec1b6a..f5760901 100644
--- a/controller/model.go
+++ b/controller/model.go
@@ -3,7 +3,11 @@ package controller
import (
"fmt"
"github.com/gin-gonic/gin"
- "one-api/relay/channel/openai"
+ "github.com/songquanpeng/one-api/relay/channel/ai360"
+ "github.com/songquanpeng/one-api/relay/channel/moonshot"
+ "github.com/songquanpeng/one-api/relay/constant"
+ "github.com/songquanpeng/one-api/relay/helper"
+ relaymodel "github.com/songquanpeng/one-api/relay/model"
)
// https://platform.openai.com/docs/api-reference/models/list
@@ -53,547 +57,46 @@ func init() {
IsBlocking: false,
})
// https://platform.openai.com/docs/models/model-endpoint-compatibility
- openAIModels = []OpenAIModels{
- {
- Id: "dall-e-2",
+ for i := 0; i < constant.APITypeDummy; i++ {
+ if i == constant.APITypeAIProxyLibrary {
+ continue
+ }
+ adaptor := helper.GetAdaptor(i)
+ channelName := adaptor.GetChannelName()
+ modelNames := adaptor.GetModelList()
+ for _, modelName := range modelNames {
+ openAIModels = append(openAIModels, OpenAIModels{
+ Id: modelName,
+ Object: "model",
+ Created: 1626777600,
+ OwnedBy: channelName,
+ Permission: permission,
+ Root: modelName,
+ Parent: nil,
+ })
+ }
+ }
+ for _, modelName := range ai360.ModelList {
+ openAIModels = append(openAIModels, OpenAIModels{
+ Id: modelName,
Object: "model",
- Created: 1677649963,
- OwnedBy: "openai",
- Permission: permission,
- Root: "dall-e-2",
- Parent: nil,
- },
- {
- Id: "dall-e-3",
- Object: "model",
- Created: 1677649963,
- OwnedBy: "openai",
- Permission: permission,
- Root: "dall-e-3",
- Parent: nil,
- },
- {
- Id: "whisper-1",
- Object: "model",
- Created: 1677649963,
- OwnedBy: "openai",
- Permission: permission,
- Root: "whisper-1",
- Parent: nil,
- },
- {
- Id: "tts-1",
- Object: "model",
- Created: 1677649963,
- OwnedBy: "openai",
- Permission: permission,
- Root: "tts-1",
- Parent: nil,
- },
- {
- Id: "tts-1-1106",
- Object: "model",
- Created: 1677649963,
- OwnedBy: "openai",
- Permission: permission,
- Root: "tts-1-1106",
- Parent: nil,
- },
- {
- Id: "tts-1-hd",
- Object: "model",
- Created: 1677649963,
- OwnedBy: "openai",
- Permission: permission,
- Root: "tts-1-hd",
- Parent: nil,
- },
- {
- Id: "tts-1-hd-1106",
- Object: "model",
- Created: 1677649963,
- OwnedBy: "openai",
- Permission: permission,
- Root: "tts-1-hd-1106",
- Parent: nil,
- },
- {
- Id: "gpt-3.5-turbo",
- Object: "model",
- Created: 1677649963,
- OwnedBy: "openai",
- Permission: permission,
- Root: "gpt-3.5-turbo",
- Parent: nil,
- },
- {
- Id: "gpt-3.5-turbo-0301",
- Object: "model",
- Created: 1677649963,
- OwnedBy: "openai",
- Permission: permission,
- Root: "gpt-3.5-turbo-0301",
- Parent: nil,
- },
- {
- Id: "gpt-3.5-turbo-0613",
- Object: "model",
- Created: 1677649963,
- OwnedBy: "openai",
- Permission: permission,
- Root: "gpt-3.5-turbo-0613",
- Parent: nil,
- },
- {
- Id: "gpt-3.5-turbo-16k",
- Object: "model",
- Created: 1677649963,
- OwnedBy: "openai",
- Permission: permission,
- Root: "gpt-3.5-turbo-16k",
- Parent: nil,
- },
- {
- Id: "gpt-3.5-turbo-16k-0613",
- Object: "model",
- Created: 1677649963,
- OwnedBy: "openai",
- Permission: permission,
- Root: "gpt-3.5-turbo-16k-0613",
- Parent: nil,
- },
- {
- Id: "gpt-3.5-turbo-1106",
- Object: "model",
- Created: 1699593571,
- OwnedBy: "openai",
- Permission: permission,
- Root: "gpt-3.5-turbo-1106",
- Parent: nil,
- },
- {
- Id: "gpt-3.5-turbo-instruct",
- Object: "model",
- Created: 1677649963,
- OwnedBy: "openai",
- Permission: permission,
- Root: "gpt-3.5-turbo-instruct",
- Parent: nil,
- },
- {
- Id: "gpt-4",
- Object: "model",
- Created: 1677649963,
- OwnedBy: "openai",
- Permission: permission,
- Root: "gpt-4",
- Parent: nil,
- },
- {
- Id: "gpt-4-0314",
- Object: "model",
- Created: 1677649963,
- OwnedBy: "openai",
- Permission: permission,
- Root: "gpt-4-0314",
- Parent: nil,
- },
- {
- Id: "gpt-4-0613",
- Object: "model",
- Created: 1677649963,
- OwnedBy: "openai",
- Permission: permission,
- Root: "gpt-4-0613",
- Parent: nil,
- },
- {
- Id: "gpt-4-32k",
- Object: "model",
- Created: 1677649963,
- OwnedBy: "openai",
- Permission: permission,
- Root: "gpt-4-32k",
- Parent: nil,
- },
- {
- Id: "gpt-4-32k-0314",
- Object: "model",
- Created: 1677649963,
- OwnedBy: "openai",
- Permission: permission,
- Root: "gpt-4-32k-0314",
- Parent: nil,
- },
- {
- Id: "gpt-4-32k-0613",
- Object: "model",
- Created: 1677649963,
- OwnedBy: "openai",
- Permission: permission,
- Root: "gpt-4-32k-0613",
- Parent: nil,
- },
- {
- Id: "gpt-4-1106-preview",
- Object: "model",
- Created: 1699593571,
- OwnedBy: "openai",
- Permission: permission,
- Root: "gpt-4-1106-preview",
- Parent: nil,
- },
- {
- Id: "gpt-4-vision-preview",
- Object: "model",
- Created: 1699593571,
- OwnedBy: "openai",
- Permission: permission,
- Root: "gpt-4-vision-preview",
- Parent: nil,
- },
- {
- Id: "text-embedding-ada-002",
- Object: "model",
- Created: 1677649963,
- OwnedBy: "openai",
- Permission: permission,
- Root: "text-embedding-ada-002",
- Parent: nil,
- },
- {
- Id: "text-davinci-003",
- Object: "model",
- Created: 1677649963,
- OwnedBy: "openai",
- Permission: permission,
- Root: "text-davinci-003",
- Parent: nil,
- },
- {
- Id: "text-davinci-002",
- Object: "model",
- Created: 1677649963,
- OwnedBy: "openai",
- Permission: permission,
- Root: "text-davinci-002",
- Parent: nil,
- },
- {
- Id: "text-curie-001",
- Object: "model",
- Created: 1677649963,
- OwnedBy: "openai",
- Permission: permission,
- Root: "text-curie-001",
- Parent: nil,
- },
- {
- Id: "text-babbage-001",
- Object: "model",
- Created: 1677649963,
- OwnedBy: "openai",
- Permission: permission,
- Root: "text-babbage-001",
- Parent: nil,
- },
- {
- Id: "text-ada-001",
- Object: "model",
- Created: 1677649963,
- OwnedBy: "openai",
- Permission: permission,
- Root: "text-ada-001",
- Parent: nil,
- },
- {
- Id: "text-moderation-latest",
- Object: "model",
- Created: 1677649963,
- OwnedBy: "openai",
- Permission: permission,
- Root: "text-moderation-latest",
- Parent: nil,
- },
- {
- Id: "text-moderation-stable",
- Object: "model",
- Created: 1677649963,
- OwnedBy: "openai",
- Permission: permission,
- Root: "text-moderation-stable",
- Parent: nil,
- },
- {
- Id: "text-davinci-edit-001",
- Object: "model",
- Created: 1677649963,
- OwnedBy: "openai",
- Permission: permission,
- Root: "text-davinci-edit-001",
- Parent: nil,
- },
- {
- Id: "code-davinci-edit-001",
- Object: "model",
- Created: 1677649963,
- OwnedBy: "openai",
- Permission: permission,
- Root: "code-davinci-edit-001",
- Parent: nil,
- },
- {
- Id: "davinci-002",
- Object: "model",
- Created: 1677649963,
- OwnedBy: "openai",
- Permission: permission,
- Root: "davinci-002",
- Parent: nil,
- },
- {
- Id: "babbage-002",
- Object: "model",
- Created: 1677649963,
- OwnedBy: "openai",
- Permission: permission,
- Root: "babbage-002",
- Parent: nil,
- },
- {
- Id: "claude-instant-1",
- Object: "model",
- Created: 1677649963,
- OwnedBy: "anthropic",
- Permission: permission,
- Root: "claude-instant-1",
- Parent: nil,
- },
- {
- Id: "claude-2",
- Object: "model",
- Created: 1677649963,
- OwnedBy: "anthropic",
- Permission: permission,
- Root: "claude-2",
- Parent: nil,
- },
- {
- Id: "claude-2.1",
- Object: "model",
- Created: 1677649963,
- OwnedBy: "anthropic",
- Permission: permission,
- Root: "claude-2.1",
- Parent: nil,
- },
- {
- Id: "claude-2.0",
- Object: "model",
- Created: 1677649963,
- OwnedBy: "anthropic",
- Permission: permission,
- Root: "claude-2.0",
- Parent: nil,
- },
- {
- Id: "ERNIE-Bot",
- Object: "model",
- Created: 1677649963,
- OwnedBy: "baidu",
- Permission: permission,
- Root: "ERNIE-Bot",
- Parent: nil,
- },
- {
- Id: "ERNIE-Bot-turbo",
- Object: "model",
- Created: 1677649963,
- OwnedBy: "baidu",
- Permission: permission,
- Root: "ERNIE-Bot-turbo",
- Parent: nil,
- },
- {
- Id: "ERNIE-Bot-4",
- Object: "model",
- Created: 1677649963,
- OwnedBy: "baidu",
- Permission: permission,
- Root: "ERNIE-Bot-4",
- Parent: nil,
- },
- {
- Id: "Embedding-V1",
- Object: "model",
- Created: 1677649963,
- OwnedBy: "baidu",
- Permission: permission,
- Root: "Embedding-V1",
- Parent: nil,
- },
- {
- Id: "PaLM-2",
- Object: "model",
- Created: 1677649963,
- OwnedBy: "google palm",
- Permission: permission,
- Root: "PaLM-2",
- Parent: nil,
- },
- {
- Id: "gemini-pro",
- Object: "model",
- Created: 1677649963,
- OwnedBy: "google gemini",
- Permission: permission,
- Root: "gemini-pro",
- Parent: nil,
- },
- {
- Id: "gemini-pro-vision",
- Object: "model",
- Created: 1677649963,
- OwnedBy: "google gemini",
- Permission: permission,
- Root: "gemini-pro-vision",
- Parent: nil,
- },
- {
- Id: "chatglm_turbo",
- Object: "model",
- Created: 1677649963,
- OwnedBy: "zhipu",
- Permission: permission,
- Root: "chatglm_turbo",
- Parent: nil,
- },
- {
- Id: "chatglm_pro",
- Object: "model",
- Created: 1677649963,
- OwnedBy: "zhipu",
- Permission: permission,
- Root: "chatglm_pro",
- Parent: nil,
- },
- {
- Id: "chatglm_std",
- Object: "model",
- Created: 1677649963,
- OwnedBy: "zhipu",
- Permission: permission,
- Root: "chatglm_std",
- Parent: nil,
- },
- {
- Id: "chatglm_lite",
- Object: "model",
- Created: 1677649963,
- OwnedBy: "zhipu",
- Permission: permission,
- Root: "chatglm_lite",
- Parent: nil,
- },
- {
- Id: "qwen-turbo",
- Object: "model",
- Created: 1677649963,
- OwnedBy: "ali",
- Permission: permission,
- Root: "qwen-turbo",
- Parent: nil,
- },
- {
- Id: "qwen-plus",
- Object: "model",
- Created: 1677649963,
- OwnedBy: "ali",
- Permission: permission,
- Root: "qwen-plus",
- Parent: nil,
- },
- {
- Id: "qwen-max",
- Object: "model",
- Created: 1677649963,
- OwnedBy: "ali",
- Permission: permission,
- Root: "qwen-max",
- Parent: nil,
- },
- {
- Id: "qwen-max-longcontext",
- Object: "model",
- Created: 1677649963,
- OwnedBy: "ali",
- Permission: permission,
- Root: "qwen-max-longcontext",
- Parent: nil,
- },
- {
- Id: "text-embedding-v1",
- Object: "model",
- Created: 1677649963,
- OwnedBy: "ali",
- Permission: permission,
- Root: "text-embedding-v1",
- Parent: nil,
- },
- {
- Id: "SparkDesk",
- Object: "model",
- Created: 1677649963,
- OwnedBy: "xunfei",
- Permission: permission,
- Root: "SparkDesk",
- Parent: nil,
- },
- {
- Id: "360GPT_S2_V9",
- Object: "model",
- Created: 1677649963,
+ Created: 1626777600,
OwnedBy: "360",
Permission: permission,
- Root: "360GPT_S2_V9",
+ Root: modelName,
Parent: nil,
- },
- {
- Id: "embedding-bert-512-v1",
+ })
+ }
+ for _, modelName := range moonshot.ModelList {
+ openAIModels = append(openAIModels, OpenAIModels{
+ Id: modelName,
Object: "model",
- Created: 1677649963,
- OwnedBy: "360",
+ Created: 1626777600,
+ OwnedBy: "moonshot",
Permission: permission,
- Root: "embedding-bert-512-v1",
+ Root: modelName,
Parent: nil,
- },
- {
- Id: "embedding_s1_v1",
- Object: "model",
- Created: 1677649963,
- OwnedBy: "360",
- Permission: permission,
- Root: "embedding_s1_v1",
- Parent: nil,
- },
- {
- Id: "semantic_similarity_s1_v1",
- Object: "model",
- Created: 1677649963,
- OwnedBy: "360",
- Permission: permission,
- Root: "semantic_similarity_s1_v1",
- Parent: nil,
- },
- {
- Id: "hunyuan",
- Object: "model",
- Created: 1677649963,
- OwnedBy: "tencent",
- Permission: permission,
- Root: "hunyuan",
- Parent: nil,
- },
+ })
}
openAIModelsMap = make(map[string]OpenAIModels)
for _, model := range openAIModels {
@@ -613,7 +116,7 @@ func RetrieveModel(c *gin.Context) {
if model, ok := openAIModelsMap[modelId]; ok {
c.JSON(200, model)
} else {
- Error := openai.Error{
+ Error := relaymodel.Error{
Message: fmt.Sprintf("The model '%s' does not exist", modelId),
Type: "invalid_request_error",
Param: "model",
diff --git a/controller/option.go b/controller/option.go
index 3b1cbad2..f86e3a64 100644
--- a/controller/option.go
+++ b/controller/option.go
@@ -2,9 +2,10 @@ package controller
import (
"encoding/json"
+ "github.com/songquanpeng/one-api/common/config"
+ "github.com/songquanpeng/one-api/common/helper"
+ "github.com/songquanpeng/one-api/model"
"net/http"
- "one-api/common"
- "one-api/model"
"strings"
"github.com/gin-gonic/gin"
@@ -12,17 +13,17 @@ import (
func GetOptions(c *gin.Context) {
var options []*model.Option
- common.OptionMapRWMutex.Lock()
- for k, v := range common.OptionMap {
+ config.OptionMapRWMutex.Lock()
+ for k, v := range config.OptionMap {
if strings.HasSuffix(k, "Token") || strings.HasSuffix(k, "Secret") {
continue
}
options = append(options, &model.Option{
Key: k,
- Value: common.Interface2String(v),
+ Value: helper.Interface2String(v),
})
}
- common.OptionMapRWMutex.Unlock()
+ config.OptionMapRWMutex.Unlock()
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
@@ -43,7 +44,7 @@ func UpdateOption(c *gin.Context) {
}
switch option.Key {
case "Theme":
- if !common.ValidThemes[option.Value] {
+ if !config.ValidThemes[option.Value] {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "无效的主题",
@@ -51,7 +52,7 @@ func UpdateOption(c *gin.Context) {
return
}
case "GitHubOAuthEnabled":
- if option.Value == "true" && common.GitHubClientId == "" {
+ if option.Value == "true" && config.GitHubClientId == "" {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "无法启用 GitHub OAuth,请先填入 GitHub Client Id 以及 GitHub Client Secret!",
@@ -59,7 +60,7 @@ func UpdateOption(c *gin.Context) {
return
}
case "EmailDomainRestrictionEnabled":
- if option.Value == "true" && len(common.EmailDomainWhitelist) == 0 {
+ if option.Value == "true" && len(config.EmailDomainWhitelist) == 0 {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "无法启用邮箱域名限制,请先填入限制的邮箱域名!",
@@ -67,7 +68,7 @@ func UpdateOption(c *gin.Context) {
return
}
case "WeChatAuthEnabled":
- if option.Value == "true" && common.WeChatServerAddress == "" {
+ if option.Value == "true" && config.WeChatServerAddress == "" {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "无法启用微信登录,请先填入微信登录相关配置信息!",
@@ -75,7 +76,7 @@ func UpdateOption(c *gin.Context) {
return
}
case "TurnstileCheckEnabled":
- if option.Value == "true" && common.TurnstileSiteKey == "" {
+ if option.Value == "true" && config.TurnstileSiteKey == "" {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "无法启用 Turnstile 校验,请先填入 Turnstile 校验相关配置信息!",
diff --git a/controller/redemption.go b/controller/redemption.go
index 0f656be0..31c9348d 100644
--- a/controller/redemption.go
+++ b/controller/redemption.go
@@ -2,9 +2,10 @@ package controller
import (
"github.com/gin-gonic/gin"
+ "github.com/songquanpeng/one-api/common/config"
+ "github.com/songquanpeng/one-api/common/helper"
+ "github.com/songquanpeng/one-api/model"
"net/http"
- "one-api/common"
- "one-api/model"
"strconv"
)
@@ -13,7 +14,7 @@ func GetAllRedemptions(c *gin.Context) {
if p < 0 {
p = 0
}
- redemptions, err := model.GetAllRedemptions(p*common.ItemsPerPage, common.ItemsPerPage)
+ redemptions, err := model.GetAllRedemptions(p*config.ItemsPerPage, config.ItemsPerPage)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
@@ -105,12 +106,12 @@ func AddRedemption(c *gin.Context) {
}
var keys []string
for i := 0; i < redemption.Count; i++ {
- key := common.GetUUID()
+ key := helper.GetUUID()
cleanRedemption := model.Redemption{
UserId: c.GetInt("id"),
Name: redemption.Name,
Key: key,
- CreatedTime: common.GetTimestamp(),
+ CreatedTime: helper.GetTimestamp(),
Quota: redemption.Quota,
}
err = cleanRedemption.Insert()
diff --git a/controller/relay.go b/controller/relay.go
index ef5a4bbb..0d619813 100644
--- a/controller/relay.go
+++ b/controller/relay.go
@@ -2,44 +2,23 @@ package controller
import (
"fmt"
- "net/http"
- "one-api/common"
- "one-api/relay/channel/openai"
- "one-api/relay/constant"
- "one-api/relay/controller"
- "one-api/relay/util"
- "strconv"
- "strings"
-
"github.com/gin-gonic/gin"
+ "github.com/songquanpeng/one-api/common/config"
+ "github.com/songquanpeng/one-api/common/helper"
+ "github.com/songquanpeng/one-api/common/logger"
+ "github.com/songquanpeng/one-api/relay/constant"
+ "github.com/songquanpeng/one-api/relay/controller"
+ "github.com/songquanpeng/one-api/relay/model"
+ "github.com/songquanpeng/one-api/relay/util"
+ "net/http"
+ "strconv"
)
// https://platform.openai.com/docs/api-reference/chat
func Relay(c *gin.Context) {
- relayMode := constant.RelayModeUnknown
- if strings.HasPrefix(c.Request.URL.Path, "/v1/chat/completions") {
- relayMode = constant.RelayModeChatCompletions
- } else if strings.HasPrefix(c.Request.URL.Path, "/v1/completions") {
- relayMode = constant.RelayModeCompletions
- } else if strings.HasPrefix(c.Request.URL.Path, "/v1/embeddings") {
- relayMode = constant.RelayModeEmbeddings
- } else if strings.HasSuffix(c.Request.URL.Path, "embeddings") {
- relayMode = constant.RelayModeEmbeddings
- } else if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") {
- relayMode = constant.RelayModeModerations
- } else if strings.HasPrefix(c.Request.URL.Path, "/v1/images/generations") {
- relayMode = constant.RelayModeImagesGenerations
- } else if strings.HasPrefix(c.Request.URL.Path, "/v1/edits") {
- relayMode = constant.RelayModeEdits
- } else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/speech") {
- relayMode = constant.RelayModeAudioSpeech
- } else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") {
- relayMode = constant.RelayModeAudioTranscription
- } else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/translations") {
- relayMode = constant.RelayModeAudioTranslation
- }
- var err *openai.ErrorWithStatusCode
+ relayMode := constant.Path2RelayMode(c.Request.URL.Path)
+ var err *model.ErrorWithStatusCode
switch relayMode {
case constant.RelayModeImagesGenerations:
err = controller.RelayImageHelper(c, relayMode)
@@ -50,14 +29,14 @@ func Relay(c *gin.Context) {
case constant.RelayModeAudioTranscription:
err = controller.RelayAudioHelper(c, relayMode)
default:
- err = controller.RelayTextHelper(c, relayMode)
+ err = controller.RelayTextHelper(c)
}
if err != nil {
- requestId := c.GetString(common.RequestIdKey)
+ requestId := c.GetString(logger.RequestIdKey)
retryTimesStr := c.Query("retry")
retryTimes, _ := strconv.Atoi(retryTimesStr)
if retryTimesStr == "" {
- retryTimes = common.RetryTimes
+ retryTimes = config.RetryTimes
}
if retryTimes > 0 {
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s?retry=%d", c.Request.URL.Path, retryTimes-1))
@@ -66,14 +45,16 @@ func Relay(c *gin.Context) {
err.Error.Message = "当前分组上游负载已饱和,请稍后再试"
}
+
err.Error.Message = common.MessageWithRequestId("Request From https://api.adamchatbot.chat Error", requestId)
+
c.JSON(err.StatusCode, gin.H{
"error": err.Error,
})
}
channelId := c.GetInt("channel_id")
- common.LogError(c.Request.Context(), fmt.Sprintf("relay error (channel #%d): %s", channelId, err.Message))
+ logger.Error(c.Request.Context(), fmt.Sprintf("relay error (channel #%d): %s", channelId, err.Message))
// https://platform.openai.com/docs/guides/error-codes/api-errors
if util.ShouldDisableChannel(&err.Error, err.StatusCode) {
channelId := c.GetInt("channel_id")
@@ -84,7 +65,7 @@ func Relay(c *gin.Context) {
}
func RelayNotImplemented(c *gin.Context) {
- err := openai.Error{
+ err := model.Error{
Message: "API not implemented",
Type: "one_api_error",
Param: "",
@@ -96,7 +77,7 @@ func RelayNotImplemented(c *gin.Context) {
}
func RelayNotFound(c *gin.Context) {
- err := openai.Error{
+ err := model.Error{
Message: fmt.Sprintf("Invalid URL (%s %s)", c.Request.Method, c.Request.URL.Path),
Type: "invalid_request_error",
Param: "",
diff --git a/controller/token.go b/controller/token.go
index 6917053e..6012c482 100644
--- a/controller/token.go
+++ b/controller/token.go
@@ -2,9 +2,11 @@ package controller
import (
"github.com/gin-gonic/gin"
+ "github.com/songquanpeng/one-api/common"
+ "github.com/songquanpeng/one-api/common/config"
+ "github.com/songquanpeng/one-api/common/helper"
+ "github.com/songquanpeng/one-api/model"
"net/http"
- "one-api/common"
- "one-api/model"
"strconv"
)
@@ -14,7 +16,7 @@ func GetAllTokens(c *gin.Context) {
if p < 0 {
p = 0
}
- tokens, err := model.GetAllUserTokens(userId, p*common.ItemsPerPage, common.ItemsPerPage)
+ tokens, err := model.GetAllUserTokens(userId, p*config.ItemsPerPage, config.ItemsPerPage)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
@@ -135,9 +137,9 @@ func AddToken(c *gin.Context) {
cleanToken := model.Token{
UserId: c.GetInt("id"),
Name: token.Name,
- Key: common.GenerateKey(),
- CreatedTime: common.GetTimestamp(),
- AccessedTime: common.GetTimestamp(),
+ Key: helper.GenerateKey(),
+ CreatedTime: helper.GetTimestamp(),
+ AccessedTime: helper.GetTimestamp(),
ExpiredTime: token.ExpiredTime,
RemainQuota: token.RemainQuota,
UnlimitedQuota: token.UnlimitedQuota,
@@ -203,7 +205,7 @@ func UpdateToken(c *gin.Context) {
return
}
if token.Status == common.TokenStatusEnabled {
- if cleanToken.Status == common.TokenStatusExpired && cleanToken.ExpiredTime <= common.GetTimestamp() && cleanToken.ExpiredTime != -1 {
+ if cleanToken.Status == common.TokenStatusExpired && cleanToken.ExpiredTime <= helper.GetTimestamp() && cleanToken.ExpiredTime != -1 {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "令牌已过期,无法启用,请先修改令牌过期时间,或者设置为永不过期",
diff --git a/controller/user.go b/controller/user.go
index 174300ed..c11b940e 100644
--- a/controller/user.go
+++ b/controller/user.go
@@ -3,9 +3,11 @@ package controller
import (
"encoding/json"
"fmt"
+ "github.com/songquanpeng/one-api/common"
+ "github.com/songquanpeng/one-api/common/config"
+ "github.com/songquanpeng/one-api/common/helper"
+ "github.com/songquanpeng/one-api/model"
"net/http"
- "one-api/common"
- "one-api/model"
"strconv"
"time"
@@ -19,7 +21,7 @@ type LoginRequest struct {
}
func Login(c *gin.Context) {
- if !common.PasswordLoginEnabled {
+ if !config.PasswordLoginEnabled {
c.JSON(http.StatusOK, gin.H{
"message": "管理员关闭了密码登录",
"success": false,
@@ -106,14 +108,14 @@ func Logout(c *gin.Context) {
}
func Register(c *gin.Context) {
- if !common.RegisterEnabled {
+ if !config.RegisterEnabled {
c.JSON(http.StatusOK, gin.H{
"message": "管理员关闭了新用户注册",
"success": false,
})
return
}
- if !common.PasswordRegisterEnabled {
+ if !config.PasswordRegisterEnabled {
c.JSON(http.StatusOK, gin.H{
"message": "管理员关闭了通过密码进行注册,请使用第三方账户验证的形式进行注册",
"success": false,
@@ -136,7 +138,7 @@ func Register(c *gin.Context) {
})
return
}
- if common.EmailVerificationEnabled {
+ if config.EmailVerificationEnabled {
if user.Email == "" || user.VerificationCode == "" {
c.JSON(http.StatusOK, gin.H{
"success": false,
@@ -160,7 +162,7 @@ func Register(c *gin.Context) {
DisplayName: user.Username,
InviterId: inviterId,
}
- if common.EmailVerificationEnabled {
+ if config.EmailVerificationEnabled {
cleanUser.Email = user.Email
}
if err := cleanUser.Insert(inviterId); err != nil {
@@ -182,7 +184,7 @@ func GetAllUsers(c *gin.Context) {
if p < 0 {
p = 0
}
- users, err := model.GetAllUsers(p*common.ItemsPerPage, common.ItemsPerPage)
+ users, err := model.GetAllUsers(p*config.ItemsPerPage, config.ItemsPerPage)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
@@ -282,7 +284,7 @@ func GenerateAccessToken(c *gin.Context) {
})
return
}
- user.AccessToken = common.GetUUID()
+ user.AccessToken = helper.GetUUID()
if model.DB.Where("access_token = ?", user.AccessToken).First(user).RowsAffected != 0 {
c.JSON(http.StatusOK, gin.H{
@@ -319,7 +321,7 @@ func GetAffCode(c *gin.Context) {
return
}
if user.AffCode == "" {
- user.AffCode = common.GetRandomString(4)
+ user.AffCode = helper.GetRandomString(4)
if err := user.Update(false); err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
@@ -726,7 +728,7 @@ func EmailBind(c *gin.Context) {
return
}
if user.Role == common.RoleRootUser {
- common.RootUserEmail = email
+ config.RootUserEmail = email
}
c.JSON(http.StatusOK, gin.H{
"success": true,
diff --git a/controller/wechat.go b/controller/wechat.go
index ff4c9fb6..74be5604 100644
--- a/controller/wechat.go
+++ b/controller/wechat.go
@@ -5,9 +5,10 @@ import (
"errors"
"fmt"
"github.com/gin-gonic/gin"
+ "github.com/songquanpeng/one-api/common"
+ "github.com/songquanpeng/one-api/common/config"
+ "github.com/songquanpeng/one-api/model"
"net/http"
- "one-api/common"
- "one-api/model"
"strconv"
"time"
)
@@ -22,11 +23,11 @@ func getWeChatIdByCode(code string) (string, error) {
if code == "" {
return "", errors.New("无效的参数")
}
- req, err := http.NewRequest("GET", fmt.Sprintf("%s/api/wechat/user?code=%s", common.WeChatServerAddress, code), nil)
+ req, err := http.NewRequest("GET", fmt.Sprintf("%s/api/wechat/user?code=%s", config.WeChatServerAddress, code), nil)
if err != nil {
return "", err
}
- req.Header.Set("Authorization", common.WeChatServerToken)
+ req.Header.Set("Authorization", config.WeChatServerToken)
client := http.Client{
Timeout: 5 * time.Second,
}
@@ -50,7 +51,7 @@ func getWeChatIdByCode(code string) (string, error) {
}
func WeChatAuth(c *gin.Context) {
- if !common.WeChatAuthEnabled {
+ if !config.WeChatAuthEnabled {
c.JSON(http.StatusOK, gin.H{
"message": "管理员未开启通过微信登录以及注册",
"success": false,
@@ -79,7 +80,7 @@ func WeChatAuth(c *gin.Context) {
return
}
} else {
- if common.RegisterEnabled {
+ if config.RegisterEnabled {
user.Username = "wechat_" + strconv.Itoa(model.GetMaxUserId()+1)
user.DisplayName = "WeChat User"
user.Role = common.RoleCommonUser
@@ -112,7 +113,7 @@ func WeChatAuth(c *gin.Context) {
}
func WeChatBind(c *gin.Context) {
- if !common.WeChatAuthEnabled {
+ if !config.WeChatAuthEnabled {
c.JSON(http.StatusOK, gin.H{
"message": "管理员未开启通过微信登录以及注册",
"success": false,
diff --git a/go.mod b/go.mod
index 68dd5eb6..4ab23003 100644
--- a/go.mod
+++ b/go.mod
@@ -1,4 +1,4 @@
-module one-api
+module github.com/songquanpeng/one-api
// +heroku goVersion go1.18
go 1.18
diff --git a/main.go b/main.go
index 28a41287..1f43a45f 100644
--- a/main.go
+++ b/main.go
@@ -6,12 +6,14 @@ import (
"github.com/gin-contrib/sessions"
"github.com/gin-contrib/sessions/cookie"
"github.com/gin-gonic/gin"
- "one-api/common"
- "one-api/controller"
- "one-api/middleware"
- "one-api/model"
- "one-api/relay/channel/openai"
- "one-api/router"
+ "github.com/songquanpeng/one-api/common"
+ "github.com/songquanpeng/one-api/common/config"
+ "github.com/songquanpeng/one-api/common/logger"
+ "github.com/songquanpeng/one-api/controller"
+ "github.com/songquanpeng/one-api/middleware"
+ "github.com/songquanpeng/one-api/model"
+ "github.com/songquanpeng/one-api/relay/channel/openai"
+ "github.com/songquanpeng/one-api/router"
"os"
"strconv"
)
@@ -20,65 +22,65 @@ import (
var buildFS embed.FS
func main() {
- common.SetupLogger()
- common.SysLog(fmt.Sprintf("One API %s started", common.Version))
+ logger.SetupLogger()
+ logger.SysLog(fmt.Sprintf("One API %s started", common.Version))
if os.Getenv("GIN_MODE") != "debug" {
gin.SetMode(gin.ReleaseMode)
}
- if common.DebugEnabled {
- common.SysLog("running in debug mode")
+ if config.DebugEnabled {
+ logger.SysLog("running in debug mode")
}
// Initialize SQL Database
err := model.InitDB()
if err != nil {
- common.FatalLog("failed to initialize database: " + err.Error())
+ logger.FatalLog("failed to initialize database: " + err.Error())
}
defer func() {
err := model.CloseDB()
if err != nil {
- common.FatalLog("failed to close database: " + err.Error())
+ logger.FatalLog("failed to close database: " + err.Error())
}
}()
// Initialize Redis
err = common.InitRedisClient()
if err != nil {
- common.FatalLog("failed to initialize Redis: " + err.Error())
+ logger.FatalLog("failed to initialize Redis: " + err.Error())
}
// Initialize options
model.InitOptionMap()
- common.SysLog(fmt.Sprintf("using theme %s", common.Theme))
+ logger.SysLog(fmt.Sprintf("using theme %s", config.Theme))
if common.RedisEnabled {
// for compatibility with old versions
- common.MemoryCacheEnabled = true
+ config.MemoryCacheEnabled = true
}
- if common.MemoryCacheEnabled {
- common.SysLog("memory cache enabled")
- common.SysError(fmt.Sprintf("sync frequency: %d seconds", common.SyncFrequency))
+ if config.MemoryCacheEnabled {
+ logger.SysLog("memory cache enabled")
+ logger.SysError(fmt.Sprintf("sync frequency: %d seconds", config.SyncFrequency))
model.InitChannelCache()
}
- if common.MemoryCacheEnabled {
- go model.SyncOptions(common.SyncFrequency)
- go model.SyncChannelCache(common.SyncFrequency)
+ if config.MemoryCacheEnabled {
+ go model.SyncOptions(config.SyncFrequency)
+ go model.SyncChannelCache(config.SyncFrequency)
}
if os.Getenv("CHANNEL_UPDATE_FREQUENCY") != "" {
frequency, err := strconv.Atoi(os.Getenv("CHANNEL_UPDATE_FREQUENCY"))
if err != nil {
- common.FatalLog("failed to parse CHANNEL_UPDATE_FREQUENCY: " + err.Error())
+ logger.FatalLog("failed to parse CHANNEL_UPDATE_FREQUENCY: " + err.Error())
}
go controller.AutomaticallyUpdateChannels(frequency)
}
if os.Getenv("CHANNEL_TEST_FREQUENCY") != "" {
frequency, err := strconv.Atoi(os.Getenv("CHANNEL_TEST_FREQUENCY"))
if err != nil {
- common.FatalLog("failed to parse CHANNEL_TEST_FREQUENCY: " + err.Error())
+ logger.FatalLog("failed to parse CHANNEL_TEST_FREQUENCY: " + err.Error())
}
go controller.AutomaticallyTestChannels(frequency)
}
if os.Getenv("BATCH_UPDATE_ENABLED") == "true" {
- common.BatchUpdateEnabled = true
- common.SysLog("batch update enabled with interval " + strconv.Itoa(common.BatchUpdateInterval) + "s")
+ config.BatchUpdateEnabled = true
+ logger.SysLog("batch update enabled with interval " + strconv.Itoa(config.BatchUpdateInterval) + "s")
model.InitBatchUpdater()
}
openai.InitTokenEncoders()
@@ -91,7 +93,7 @@ func main() {
server.Use(middleware.RequestId())
middleware.SetUpLogger(server)
// Initialize session store
- store := cookie.NewStore([]byte(common.SessionSecret))
+ store := cookie.NewStore([]byte(config.SessionSecret))
server.Use(sessions.Sessions("session", store))
router.SetRouter(server, buildFS)
@@ -101,6 +103,6 @@ func main() {
}
err = server.Run(":" + port)
if err != nil {
- common.FatalLog("failed to start HTTP server: " + err.Error())
+ logger.FatalLog("failed to start HTTP server: " + err.Error())
}
}
diff --git a/middleware/auth.go b/middleware/auth.go
index ad7e64b7..42a599d0 100644
--- a/middleware/auth.go
+++ b/middleware/auth.go
@@ -3,9 +3,9 @@ package middleware
import (
"github.com/gin-contrib/sessions"
"github.com/gin-gonic/gin"
+ "github.com/songquanpeng/one-api/common"
+ "github.com/songquanpeng/one-api/model"
"net/http"
- "one-api/common"
- "one-api/model"
"strings"
)
diff --git a/middleware/distributor.go b/middleware/distributor.go
index 5bf4d43c..4b5bb965 100644
--- a/middleware/distributor.go
+++ b/middleware/distributor.go
@@ -2,9 +2,10 @@ package middleware
import (
"fmt"
+ "github.com/songquanpeng/one-api/common"
+ "github.com/songquanpeng/one-api/common/logger"
+ "github.com/songquanpeng/one-api/model"
"net/http"
- "one-api/common"
- "one-api/model"
"strconv"
"strings"
@@ -73,7 +74,7 @@ func Distribute() func(c *gin.Context) {
if err != nil {
message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", userGroup, modelRequest.Model)
if channel != nil {
- common.SysError(fmt.Sprintf("渠道不存在:%d", channel.Id))
+ logger.SysError(fmt.Sprintf("渠道不存在:%d", channel.Id))
message = "数据库一致性已被破坏,请联系管理员"
}
abortWithMessage(c, http.StatusServiceUnavailable, message)
@@ -86,17 +87,22 @@ func Distribute() func(c *gin.Context) {
c.Set("model_mapping", channel.GetModelMapping())
c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key))
c.Set("base_url", channel.GetBaseURL())
+ // this is for backward compatibility
switch channel.Type {
case common.ChannelTypeAzure:
- c.Set("api_version", channel.Other)
+ c.Set(common.ConfigKeyAPIVersion, channel.Other)
case common.ChannelTypeXunfei:
- c.Set("api_version", channel.Other)
+ c.Set(common.ConfigKeyAPIVersion, channel.Other)
case common.ChannelTypeGemini:
- c.Set("api_version", channel.Other)
+ c.Set(common.ConfigKeyAPIVersion, channel.Other)
case common.ChannelTypeAIProxyLibrary:
- c.Set("library_id", channel.Other)
+ c.Set(common.ConfigKeyLibraryID, channel.Other)
case common.ChannelTypeAli:
- c.Set("plugin", channel.Other)
+ c.Set(common.ConfigKeyPlugin, channel.Other)
+ }
+ cfg, _ := channel.LoadConfig()
+ for k, v := range cfg {
+ c.Set(common.ConfigKeyPrefix+k, v)
}
c.Next()
}
diff --git a/middleware/logger.go b/middleware/logger.go
index 02f2e0a9..6aae4f23 100644
--- a/middleware/logger.go
+++ b/middleware/logger.go
@@ -3,14 +3,14 @@ package middleware
import (
"fmt"
"github.com/gin-gonic/gin"
- "one-api/common"
+ "github.com/songquanpeng/one-api/common/logger"
)
func SetUpLogger(server *gin.Engine) {
server.Use(gin.LoggerWithFormatter(func(param gin.LogFormatterParams) string {
var requestID string
if param.Keys != nil {
- requestID = param.Keys[common.RequestIdKey].(string)
+ requestID = param.Keys[logger.RequestIdKey].(string)
}
return fmt.Sprintf("[GIN] %s | %s | %3d | %13v | %15s | %7s %s\n",
param.TimeStamp.Format("2006/01/02 - 15:04:05"),
diff --git a/middleware/rate-limit.go b/middleware/rate-limit.go
index 8e5cff6c..0f300f2b 100644
--- a/middleware/rate-limit.go
+++ b/middleware/rate-limit.go
@@ -4,8 +4,9 @@ import (
"context"
"fmt"
"github.com/gin-gonic/gin"
+ "github.com/songquanpeng/one-api/common"
+ "github.com/songquanpeng/one-api/common/config"
"net/http"
- "one-api/common"
"time"
)
@@ -26,7 +27,7 @@ func redisRateLimiter(c *gin.Context, maxRequestNum int, duration int64, mark st
}
if listLength < int64(maxRequestNum) {
rdb.LPush(ctx, key, time.Now().Format(timeFormat))
- rdb.Expire(ctx, key, common.RateLimitKeyExpirationDuration)
+ rdb.Expire(ctx, key, config.RateLimitKeyExpirationDuration)
} else {
oldTimeStr, _ := rdb.LIndex(ctx, key, -1).Result()
oldTime, err := time.Parse(timeFormat, oldTimeStr)
@@ -47,14 +48,14 @@ func redisRateLimiter(c *gin.Context, maxRequestNum int, duration int64, mark st
// time.Since will return negative number!
// See: https://stackoverflow.com/questions/50970900/why-is-time-since-returning-negative-durations-on-windows
if int64(nowTime.Sub(oldTime).Seconds()) < duration {
- rdb.Expire(ctx, key, common.RateLimitKeyExpirationDuration)
+ rdb.Expire(ctx, key, config.RateLimitKeyExpirationDuration)
c.Status(http.StatusTooManyRequests)
c.Abort()
return
} else {
rdb.LPush(ctx, key, time.Now().Format(timeFormat))
rdb.LTrim(ctx, key, 0, int64(maxRequestNum-1))
- rdb.Expire(ctx, key, common.RateLimitKeyExpirationDuration)
+ rdb.Expire(ctx, key, config.RateLimitKeyExpirationDuration)
}
}
}
@@ -75,7 +76,7 @@ func rateLimitFactory(maxRequestNum int, duration int64, mark string) func(c *gi
}
} else {
// It's safe to call multi times.
- inMemoryRateLimiter.Init(common.RateLimitKeyExpirationDuration)
+ inMemoryRateLimiter.Init(config.RateLimitKeyExpirationDuration)
return func(c *gin.Context) {
memoryRateLimiter(c, maxRequestNum, duration, mark)
}
@@ -83,21 +84,21 @@ func rateLimitFactory(maxRequestNum int, duration int64, mark string) func(c *gi
}
func GlobalWebRateLimit() func(c *gin.Context) {
- return rateLimitFactory(common.GlobalWebRateLimitNum, common.GlobalWebRateLimitDuration, "GW")
+ return rateLimitFactory(config.GlobalWebRateLimitNum, config.GlobalWebRateLimitDuration, "GW")
}
func GlobalAPIRateLimit() func(c *gin.Context) {
- return rateLimitFactory(common.GlobalApiRateLimitNum, common.GlobalApiRateLimitDuration, "GA")
+ return rateLimitFactory(config.GlobalApiRateLimitNum, config.GlobalApiRateLimitDuration, "GA")
}
func CriticalRateLimit() func(c *gin.Context) {
- return rateLimitFactory(common.CriticalRateLimitNum, common.CriticalRateLimitDuration, "CT")
+ return rateLimitFactory(config.CriticalRateLimitNum, config.CriticalRateLimitDuration, "CT")
}
func DownloadRateLimit() func(c *gin.Context) {
- return rateLimitFactory(common.DownloadRateLimitNum, common.DownloadRateLimitDuration, "DW")
+ return rateLimitFactory(config.DownloadRateLimitNum, config.DownloadRateLimitDuration, "DW")
}
func UploadRateLimit() func(c *gin.Context) {
- return rateLimitFactory(common.UploadRateLimitNum, common.UploadRateLimitDuration, "UP")
+ return rateLimitFactory(config.UploadRateLimitNum, config.UploadRateLimitDuration, "UP")
}
diff --git a/middleware/recover.go b/middleware/recover.go
index 8338a514..02e3e3bb 100644
--- a/middleware/recover.go
+++ b/middleware/recover.go
@@ -3,8 +3,8 @@ package middleware
import (
"fmt"
"github.com/gin-gonic/gin"
+ "github.com/songquanpeng/one-api/common/logger"
"net/http"
- "one-api/common"
"runtime/debug"
)
@@ -12,8 +12,8 @@ func RelayPanicRecover() gin.HandlerFunc {
return func(c *gin.Context) {
defer func() {
if err := recover(); err != nil {
- common.SysError(fmt.Sprintf("panic detected: %v", err))
- common.SysError(fmt.Sprintf("stacktrace from panic: %s", string(debug.Stack())))
+ logger.SysError(fmt.Sprintf("panic detected: %v", err))
+ logger.SysError(fmt.Sprintf("stacktrace from panic: %s", string(debug.Stack())))
c.JSON(http.StatusInternalServerError, gin.H{
"error": gin.H{
"message": fmt.Sprintf("Panic detected, error: %v. Please submit a issue here: https://github.com/songquanpeng/one-api", err),
diff --git a/middleware/request-id.go b/middleware/request-id.go
index e623be7a..7cb66e93 100644
--- a/middleware/request-id.go
+++ b/middleware/request-id.go
@@ -3,16 +3,17 @@ package middleware
import (
"context"
"github.com/gin-gonic/gin"
- "one-api/common"
+ "github.com/songquanpeng/one-api/common/helper"
+ "github.com/songquanpeng/one-api/common/logger"
)
func RequestId() func(c *gin.Context) {
return func(c *gin.Context) {
- id := common.GetTimeString() + common.GetRandomString(8)
- c.Set(common.RequestIdKey, id)
- ctx := context.WithValue(c.Request.Context(), common.RequestIdKey, id)
+ id := helper.GetTimeString() + helper.GetRandomString(8)
+ c.Set(logger.RequestIdKey, id)
+ ctx := context.WithValue(c.Request.Context(), logger.RequestIdKey, id)
c.Request = c.Request.WithContext(ctx)
- c.Header(common.RequestIdKey, id)
+ c.Header(logger.RequestIdKey, id)
c.Next()
}
}
diff --git a/middleware/turnstile-check.go b/middleware/turnstile-check.go
index 26688810..403bcb34 100644
--- a/middleware/turnstile-check.go
+++ b/middleware/turnstile-check.go
@@ -4,9 +4,10 @@ import (
"encoding/json"
"github.com/gin-contrib/sessions"
"github.com/gin-gonic/gin"
+ "github.com/songquanpeng/one-api/common/config"
+ "github.com/songquanpeng/one-api/common/logger"
"net/http"
"net/url"
- "one-api/common"
)
type turnstileCheckResponse struct {
@@ -15,7 +16,7 @@ type turnstileCheckResponse struct {
func TurnstileCheck() gin.HandlerFunc {
return func(c *gin.Context) {
- if common.TurnstileCheckEnabled {
+ if config.TurnstileCheckEnabled {
session := sessions.Default(c)
turnstileChecked := session.Get("turnstile")
if turnstileChecked != nil {
@@ -32,12 +33,12 @@ func TurnstileCheck() gin.HandlerFunc {
return
}
rawRes, err := http.PostForm("https://challenges.cloudflare.com/turnstile/v0/siteverify", url.Values{
- "secret": {common.TurnstileSecretKey},
+ "secret": {config.TurnstileSecretKey},
"response": {response},
"remoteip": {c.ClientIP()},
})
if err != nil {
- common.SysError(err.Error())
+ logger.SysError(err.Error())
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
@@ -49,7 +50,7 @@ func TurnstileCheck() gin.HandlerFunc {
var res turnstileCheckResponse
err = json.NewDecoder(rawRes.Body).Decode(&res)
if err != nil {
- common.SysError(err.Error())
+ logger.SysError(err.Error())
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
diff --git a/middleware/utils.go b/middleware/utils.go
index 536125cc..bc14c367 100644
--- a/middleware/utils.go
+++ b/middleware/utils.go
@@ -2,16 +2,17 @@ package middleware
import (
"github.com/gin-gonic/gin"
- "one-api/common"
+ "github.com/songquanpeng/one-api/common/helper"
+ "github.com/songquanpeng/one-api/common/logger"
)
func abortWithMessage(c *gin.Context, statusCode int, message string) {
c.JSON(statusCode, gin.H{
"error": gin.H{
- "message": common.MessageWithRequestId(message, c.GetString(common.RequestIdKey)),
+ "message": helper.MessageWithRequestId(message, c.GetString(logger.RequestIdKey)),
"type": "one_api_error",
},
})
c.Abort()
- common.LogError(c.Request.Context(), message)
+ logger.Error(c.Request.Context(), message)
}
diff --git a/model/ability.go b/model/ability.go
index 3da83be8..7127abc3 100644
--- a/model/ability.go
+++ b/model/ability.go
@@ -1,7 +1,7 @@
package model
import (
- "one-api/common"
+ "github.com/songquanpeng/one-api/common"
"strings"
)
diff --git a/model/cache.go b/model/cache.go
index c6d0c70a..297df153 100644
--- a/model/cache.go
+++ b/model/cache.go
@@ -4,8 +4,10 @@ import (
"encoding/json"
"errors"
"fmt"
+ "github.com/songquanpeng/one-api/common"
+ "github.com/songquanpeng/one-api/common/config"
+ "github.com/songquanpeng/one-api/common/logger"
"math/rand"
- "one-api/common"
"sort"
"strconv"
"strings"
@@ -14,10 +16,10 @@ import (
)
var (
- TokenCacheSeconds = common.SyncFrequency
- UserId2GroupCacheSeconds = common.SyncFrequency
- UserId2QuotaCacheSeconds = common.SyncFrequency
- UserId2StatusCacheSeconds = common.SyncFrequency
+ TokenCacheSeconds = config.SyncFrequency
+ UserId2GroupCacheSeconds = config.SyncFrequency
+ UserId2QuotaCacheSeconds = config.SyncFrequency
+ UserId2StatusCacheSeconds = config.SyncFrequency
)
func CacheGetTokenByKey(key string) (*Token, error) {
@@ -42,7 +44,7 @@ func CacheGetTokenByKey(key string) (*Token, error) {
}
err = common.RedisSet(fmt.Sprintf("token:%s", key), string(jsonBytes), time.Duration(TokenCacheSeconds)*time.Second)
if err != nil {
- common.SysError("Redis set token error: " + err.Error())
+ logger.SysError("Redis set token error: " + err.Error())
}
return &token, nil
}
@@ -62,7 +64,7 @@ func CacheGetUserGroup(id int) (group string, err error) {
}
err = common.RedisSet(fmt.Sprintf("user_group:%d", id), group, time.Duration(UserId2GroupCacheSeconds)*time.Second)
if err != nil {
- common.SysError("Redis set user group error: " + err.Error())
+ logger.SysError("Redis set user group error: " + err.Error())
}
}
return group, err
@@ -80,7 +82,7 @@ func CacheGetUserQuota(id int) (quota int, err error) {
}
err = common.RedisSet(fmt.Sprintf("user_quota:%d", id), fmt.Sprintf("%d", quota), time.Duration(UserId2QuotaCacheSeconds)*time.Second)
if err != nil {
- common.SysError("Redis set user quota error: " + err.Error())
+ logger.SysError("Redis set user quota error: " + err.Error())
}
return quota, err
}
@@ -127,7 +129,7 @@ func CacheIsUserEnabled(userId int) (bool, error) {
}
err = common.RedisSet(fmt.Sprintf("user_enabled:%d", userId), enabled, time.Duration(UserId2StatusCacheSeconds)*time.Second)
if err != nil {
- common.SysError("Redis set user enabled error: " + err.Error())
+ logger.SysError("Redis set user enabled error: " + err.Error())
}
return userEnabled, err
}
@@ -178,19 +180,19 @@ func InitChannelCache() {
channelSyncLock.Lock()
group2model2channels = newGroup2model2channels
channelSyncLock.Unlock()
- common.SysLog("channels synced from database")
+ logger.SysLog("channels synced from database")
}
func SyncChannelCache(frequency int) {
for {
time.Sleep(time.Duration(frequency) * time.Second)
- common.SysLog("syncing channels from database")
+ logger.SysLog("syncing channels from database")
InitChannelCache()
}
}
func CacheGetRandomSatisfiedChannel(group string, model string) (*Channel, error) {
- if !common.MemoryCacheEnabled {
+ if !config.MemoryCacheEnabled {
return GetRandomSatisfiedChannel(group, model)
}
channelSyncLock.RLock()
diff --git a/model/channel.go b/model/channel.go
index 7e7b42e6..19af2263 100644
--- a/model/channel.go
+++ b/model/channel.go
@@ -1,8 +1,13 @@
package model
import (
+ "encoding/json"
+ "fmt"
+ "github.com/songquanpeng/one-api/common"
+ "github.com/songquanpeng/one-api/common/config"
+ "github.com/songquanpeng/one-api/common/helper"
+ "github.com/songquanpeng/one-api/common/logger"
"gorm.io/gorm"
- "one-api/common"
)
type Channel struct {
@@ -16,7 +21,7 @@ type Channel struct {
TestTime int64 `json:"test_time" gorm:"bigint"`
ResponseTime int `json:"response_time"` // in milliseconds
BaseURL *string `json:"base_url" gorm:"column:base_url;default:''"`
- Other string `json:"other"`
+ Other string `json:"other"` // DEPRECATED: please save config to field Config
Balance float64 `json:"balance"` // in USD
BalanceUpdatedTime int64 `json:"balance_updated_time" gorm:"bigint"`
Models string `json:"models"`
@@ -24,6 +29,7 @@ type Channel struct {
UsedQuota int64 `json:"used_quota" gorm:"bigint;default:0"`
ModelMapping *string `json:"model_mapping" gorm:"type:varchar(1024);default:''"`
Priority *int64 `json:"priority" gorm:"bigint;default:0"`
+ Config string `json:"config"`
}
func GetAllChannels(startIdx int, num int, selectAll bool) ([]*Channel, error) {
@@ -42,7 +48,7 @@ func SearchChannels(keyword string) (channels []*Channel, err error) {
if common.UsingPostgreSQL {
keyCol = `"key"`
}
- err = DB.Omit("key").Where("id = ? or name LIKE ? or "+keyCol+" = ?", common.String2Int(keyword), keyword+"%", keyword).Find(&channels).Error
+ err = DB.Omit("key").Where("id = ? or name LIKE ? or "+keyCol+" = ?", helper.String2Int(keyword), keyword+"%", keyword).Find(&channels).Error
return channels, err
}
@@ -86,11 +92,17 @@ func (channel *Channel) GetBaseURL() string {
return *channel.BaseURL
}
-func (channel *Channel) GetModelMapping() string {
- if channel.ModelMapping == nil {
- return ""
+func (channel *Channel) GetModelMapping() map[string]string {
+ if channel.ModelMapping == nil || *channel.ModelMapping == "" || *channel.ModelMapping == "{}" {
+ return nil
}
- return *channel.ModelMapping
+ modelMapping := make(map[string]string)
+ err := json.Unmarshal([]byte(*channel.ModelMapping), &modelMapping)
+ if err != nil {
+ logger.SysError(fmt.Sprintf("failed to unmarshal model mapping for channel %d, error: %s", channel.Id, err.Error()))
+ return nil
+ }
+ return modelMapping
}
func (channel *Channel) Insert() error {
@@ -116,21 +128,21 @@ func (channel *Channel) Update() error {
func (channel *Channel) UpdateResponseTime(responseTime int64) {
err := DB.Model(channel).Select("response_time", "test_time").Updates(Channel{
- TestTime: common.GetTimestamp(),
+ TestTime: helper.GetTimestamp(),
ResponseTime: int(responseTime),
}).Error
if err != nil {
- common.SysError("failed to update response time: " + err.Error())
+ logger.SysError("failed to update response time: " + err.Error())
}
}
func (channel *Channel) UpdateBalance(balance float64) {
err := DB.Model(channel).Select("balance_updated_time", "balance").Updates(Channel{
- BalanceUpdatedTime: common.GetTimestamp(),
+ BalanceUpdatedTime: helper.GetTimestamp(),
Balance: balance,
}).Error
if err != nil {
- common.SysError("failed to update balance: " + err.Error())
+ logger.SysError("failed to update balance: " + err.Error())
}
}
@@ -144,19 +156,31 @@ func (channel *Channel) Delete() error {
return err
}
+func (channel *Channel) LoadConfig() (map[string]string, error) {
+ if channel.Config == "" {
+ return nil, nil
+ }
+ cfg := make(map[string]string)
+ err := json.Unmarshal([]byte(channel.Config), &cfg)
+ if err != nil {
+ return nil, err
+ }
+ return cfg, nil
+}
+
func UpdateChannelStatusById(id int, status int) {
err := UpdateAbilityStatus(id, status == common.ChannelStatusEnabled)
if err != nil {
- common.SysError("failed to update ability status: " + err.Error())
+ logger.SysError("failed to update ability status: " + err.Error())
}
err = DB.Model(&Channel{}).Where("id = ?", id).Update("status", status).Error
if err != nil {
- common.SysError("failed to update channel status: " + err.Error())
+ logger.SysError("failed to update channel status: " + err.Error())
}
}
func UpdateChannelUsedQuota(id int, quota int) {
- if common.BatchUpdateEnabled {
+ if config.BatchUpdateEnabled {
addNewRecord(BatchUpdateTypeChannelUsedQuota, id, quota)
return
}
@@ -166,7 +190,7 @@ func UpdateChannelUsedQuota(id int, quota int) {
func updateChannelUsedQuota(id int, quota int) {
err := DB.Model(&Channel{}).Where("id = ?", id).Update("used_quota", gorm.Expr("used_quota + ?", quota)).Error
if err != nil {
- common.SysError("failed to update channel used quota: " + err.Error())
+ logger.SysError("failed to update channel used quota: " + err.Error())
}
}
diff --git a/model/log.go b/model/log.go
index 53ffe696..b894a423 100644
--- a/model/log.go
+++ b/model/log.go
@@ -3,7 +3,10 @@ package model
import (
"context"
"fmt"
- "one-api/common"
+ "github.com/songquanpeng/one-api/common"
+ "github.com/songquanpeng/one-api/common/config"
+ "github.com/songquanpeng/one-api/common/helper"
+ "github.com/songquanpeng/one-api/common/logger"
"gorm.io/gorm"
)
@@ -32,31 +35,31 @@ const (
)
func RecordLog(userId int, logType int, content string) {
- if logType == LogTypeConsume && !common.LogConsumeEnabled {
+ if logType == LogTypeConsume && !config.LogConsumeEnabled {
return
}
log := &Log{
UserId: userId,
Username: GetUsernameById(userId),
- CreatedAt: common.GetTimestamp(),
+ CreatedAt: helper.GetTimestamp(),
Type: logType,
Content: content,
}
err := DB.Create(log).Error
if err != nil {
- common.SysError("failed to record log: " + err.Error())
+ logger.SysError("failed to record log: " + err.Error())
}
}
func RecordConsumeLog(ctx context.Context, userId int, channelId int, promptTokens int, completionTokens int, modelName string, tokenName string, quota int, content string) {
- common.LogInfo(ctx, fmt.Sprintf("record consume log: userId=%d, channelId=%d, promptTokens=%d, completionTokens=%d, modelName=%s, tokenName=%s, quota=%d, content=%s", userId, channelId, promptTokens, completionTokens, modelName, tokenName, quota, content))
- if !common.LogConsumeEnabled {
+ logger.Info(ctx, fmt.Sprintf("record consume log: userId=%d, channelId=%d, promptTokens=%d, completionTokens=%d, modelName=%s, tokenName=%s, quota=%d, content=%s", userId, channelId, promptTokens, completionTokens, modelName, tokenName, quota, content))
+ if !config.LogConsumeEnabled {
return
}
log := &Log{
UserId: userId,
Username: GetUsernameById(userId),
- CreatedAt: common.GetTimestamp(),
+ CreatedAt: helper.GetTimestamp(),
Type: LogTypeConsume,
Content: content,
PromptTokens: promptTokens,
@@ -68,7 +71,7 @@ func RecordConsumeLog(ctx context.Context, userId int, channelId int, promptToke
}
err := DB.Create(log).Error
if err != nil {
- common.LogError(ctx, "failed to record log: "+err.Error())
+ logger.Error(ctx, "failed to record log: "+err.Error())
}
}
func GetLogsByKey(logType int, startTimestamp int64, endTimestamp int64, key string, startIdx int, num int) (logs []*Log, err error) {
@@ -145,12 +148,12 @@ func GetUserLogs(userId int, logType int, startTimestamp int64, endTimestamp int
}
func SearchAllLogs(keyword string) (logs []*Log, err error) {
- err = DB.Where("type = ? or content LIKE ?", keyword, keyword+"%").Order("id desc").Limit(common.MaxRecentItems).Find(&logs).Error
+ err = DB.Where("type = ? or content LIKE ?", keyword, keyword+"%").Order("id desc").Limit(config.MaxRecentItems).Find(&logs).Error
return logs, err
}
func SearchUserLogs(userId int, keyword string) (logs []*Log, err error) {
- err = DB.Where("user_id = ? and type = ?", userId, keyword).Order("id desc").Limit(common.MaxRecentItems).Omit("id").Find(&logs).Error
+ err = DB.Where("user_id = ? and type = ?", userId, keyword).Order("id desc").Limit(config.MaxRecentItems).Omit("id").Find(&logs).Error
return logs, err
}
diff --git a/model/main.go b/model/main.go
index 9723e638..18ed01d0 100644
--- a/model/main.go
+++ b/model/main.go
@@ -2,11 +2,14 @@ package model
import (
"fmt"
+ "github.com/songquanpeng/one-api/common"
+ "github.com/songquanpeng/one-api/common/config"
+ "github.com/songquanpeng/one-api/common/helper"
+ "github.com/songquanpeng/one-api/common/logger"
"gorm.io/driver/mysql"
"gorm.io/driver/postgres"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
- "one-api/common"
"os"
"strings"
"time"
@@ -18,7 +21,7 @@ func createRootAccountIfNeed() error {
var user User
//if user.Status != util.UserStatusEnabled {
if err := DB.First(&user).Error; err != nil {
- common.SysLog("no user exists, create a root user for you: username is root, password is 123456")
+ logger.SysLog("no user exists, create a root user for you: username is root, password is 123456")
hashedPassword, err := common.Password2Hash("123456")
if err != nil {
return err
@@ -29,7 +32,7 @@ func createRootAccountIfNeed() error {
Role: common.RoleRootUser,
Status: common.UserStatusEnabled,
DisplayName: "Root User",
- AccessToken: common.GetUUID(),
+ AccessToken: helper.GetUUID(),
Quota: 100000000,
}
DB.Create(&rootUser)
@@ -42,7 +45,7 @@ func chooseDB() (*gorm.DB, error) {
dsn := os.Getenv("SQL_DSN")
if strings.HasPrefix(dsn, "postgres://") {
// Use PostgreSQL
- common.SysLog("using PostgreSQL as database")
+ logger.SysLog("using PostgreSQL as database")
common.UsingPostgreSQL = true
return gorm.Open(postgres.New(postgres.Config{
DSN: dsn,
@@ -52,13 +55,13 @@ func chooseDB() (*gorm.DB, error) {
})
}
// Use MySQL
- common.SysLog("using MySQL as database")
+ logger.SysLog("using MySQL as database")
return gorm.Open(mysql.Open(dsn), &gorm.Config{
PrepareStmt: true, // precompile SQL
})
}
// Use SQLite
- common.SysLog("SQL_DSN not set, using SQLite as database")
+ logger.SysLog("SQL_DSN not set, using SQLite as database")
common.UsingSQLite = true
config := fmt.Sprintf("?_busy_timeout=%d", common.SQLiteBusyTimeout)
return gorm.Open(sqlite.Open(common.SQLitePath+config), &gorm.Config{
@@ -69,7 +72,7 @@ func chooseDB() (*gorm.DB, error) {
func InitDB() (err error) {
db, err := chooseDB()
if err == nil {
- if common.DebugEnabled {
+ if config.DebugEnabled {
db = db.Debug()
}
DB = db
@@ -77,14 +80,14 @@ func InitDB() (err error) {
if err != nil {
return err
}
- sqlDB.SetMaxIdleConns(common.GetOrDefault("SQL_MAX_IDLE_CONNS", 100))
- sqlDB.SetMaxOpenConns(common.GetOrDefault("SQL_MAX_OPEN_CONNS", 1000))
- sqlDB.SetConnMaxLifetime(time.Second * time.Duration(common.GetOrDefault("SQL_MAX_LIFETIME", 60)))
+ sqlDB.SetMaxIdleConns(helper.GetOrDefaultEnvInt("SQL_MAX_IDLE_CONNS", 100))
+ sqlDB.SetMaxOpenConns(helper.GetOrDefaultEnvInt("SQL_MAX_OPEN_CONNS", 1000))
+ sqlDB.SetConnMaxLifetime(time.Second * time.Duration(helper.GetOrDefaultEnvInt("SQL_MAX_LIFETIME", 60)))
- if !common.IsMasterNode {
+ if !config.IsMasterNode {
return nil
}
- common.SysLog("database migration started")
+ logger.SysLog("database migration started")
err = db.AutoMigrate(&Channel{})
if err != nil {
return err
@@ -113,11 +116,11 @@ func InitDB() (err error) {
if err != nil {
return err
}
- common.SysLog("database migrated")
+ logger.SysLog("database migrated")
err = createRootAccountIfNeed()
return err
} else {
- common.FatalLog(err)
+ logger.FatalLog(err)
}
return err
}
diff --git a/model/option.go b/model/option.go
index 20575c9a..6002c795 100644
--- a/model/option.go
+++ b/model/option.go
@@ -1,7 +1,9 @@
package model
import (
- "one-api/common"
+ "github.com/songquanpeng/one-api/common"
+ "github.com/songquanpeng/one-api/common/config"
+ "github.com/songquanpeng/one-api/common/logger"
"strconv"
"strings"
"time"
@@ -20,60 +22,57 @@ func AllOption() ([]*Option, error) {
}
func InitOptionMap() {
- common.OptionMapRWMutex.Lock()
- common.OptionMap = make(map[string]string)
- common.OptionMap["FileUploadPermission"] = strconv.Itoa(common.FileUploadPermission)
- common.OptionMap["FileDownloadPermission"] = strconv.Itoa(common.FileDownloadPermission)
- common.OptionMap["ImageUploadPermission"] = strconv.Itoa(common.ImageUploadPermission)
- common.OptionMap["ImageDownloadPermission"] = strconv.Itoa(common.ImageDownloadPermission)
- common.OptionMap["PasswordLoginEnabled"] = strconv.FormatBool(common.PasswordLoginEnabled)
- common.OptionMap["PasswordRegisterEnabled"] = strconv.FormatBool(common.PasswordRegisterEnabled)
- common.OptionMap["EmailVerificationEnabled"] = strconv.FormatBool(common.EmailVerificationEnabled)
- common.OptionMap["GitHubOAuthEnabled"] = strconv.FormatBool(common.GitHubOAuthEnabled)
- common.OptionMap["WeChatAuthEnabled"] = strconv.FormatBool(common.WeChatAuthEnabled)
- common.OptionMap["TurnstileCheckEnabled"] = strconv.FormatBool(common.TurnstileCheckEnabled)
- common.OptionMap["RegisterEnabled"] = strconv.FormatBool(common.RegisterEnabled)
- common.OptionMap["AutomaticDisableChannelEnabled"] = strconv.FormatBool(common.AutomaticDisableChannelEnabled)
- common.OptionMap["AutomaticEnableChannelEnabled"] = strconv.FormatBool(common.AutomaticEnableChannelEnabled)
- common.OptionMap["ApproximateTokenEnabled"] = strconv.FormatBool(common.ApproximateTokenEnabled)
- common.OptionMap["LogConsumeEnabled"] = strconv.FormatBool(common.LogConsumeEnabled)
- common.OptionMap["DisplayInCurrencyEnabled"] = strconv.FormatBool(common.DisplayInCurrencyEnabled)
- common.OptionMap["DisplayTokenStatEnabled"] = strconv.FormatBool(common.DisplayTokenStatEnabled)
- common.OptionMap["ChannelDisableThreshold"] = strconv.FormatFloat(common.ChannelDisableThreshold, 'f', -1, 64)
- common.OptionMap["EmailDomainRestrictionEnabled"] = strconv.FormatBool(common.EmailDomainRestrictionEnabled)
- common.OptionMap["EmailDomainWhitelist"] = strings.Join(common.EmailDomainWhitelist, ",")
- common.OptionMap["SMTPServer"] = ""
- common.OptionMap["SMTPFrom"] = ""
- common.OptionMap["SMTPPort"] = strconv.Itoa(common.SMTPPort)
- common.OptionMap["SMTPAccount"] = ""
- common.OptionMap["SMTPToken"] = ""
- common.OptionMap["Notice"] = ""
- common.OptionMap["About"] = ""
- common.OptionMap["HomePageContent"] = ""
- common.OptionMap["Footer"] = common.Footer
- common.OptionMap["SystemName"] = common.SystemName
- common.OptionMap["Logo"] = common.Logo
- common.OptionMap["ServerAddress"] = ""
- common.OptionMap["GitHubClientId"] = ""
- common.OptionMap["GitHubClientSecret"] = ""
- common.OptionMap["WeChatServerAddress"] = ""
- common.OptionMap["WeChatServerToken"] = ""
- common.OptionMap["WeChatAccountQRCodeImageURL"] = ""
- common.OptionMap["TurnstileSiteKey"] = ""
- common.OptionMap["TurnstileSecretKey"] = ""
- common.OptionMap["QuotaForNewUser"] = strconv.Itoa(common.QuotaForNewUser)
- common.OptionMap["QuotaForInviter"] = strconv.Itoa(common.QuotaForInviter)
- common.OptionMap["QuotaForInvitee"] = strconv.Itoa(common.QuotaForInvitee)
- common.OptionMap["QuotaRemindThreshold"] = strconv.Itoa(common.QuotaRemindThreshold)
- common.OptionMap["PreConsumedQuota"] = strconv.Itoa(common.PreConsumedQuota)
- common.OptionMap["ModelRatio"] = common.ModelRatio2JSONString()
- common.OptionMap["GroupRatio"] = common.GroupRatio2JSONString()
- common.OptionMap["TopUpLink"] = common.TopUpLink
- common.OptionMap["ChatLink"] = common.ChatLink
- common.OptionMap["QuotaPerUnit"] = strconv.FormatFloat(common.QuotaPerUnit, 'f', -1, 64)
- common.OptionMap["RetryTimes"] = strconv.Itoa(common.RetryTimes)
- common.OptionMap["Theme"] = common.Theme
- common.OptionMapRWMutex.Unlock()
+ config.OptionMapRWMutex.Lock()
+ config.OptionMap = make(map[string]string)
+ config.OptionMap["PasswordLoginEnabled"] = strconv.FormatBool(config.PasswordLoginEnabled)
+ config.OptionMap["PasswordRegisterEnabled"] = strconv.FormatBool(config.PasswordRegisterEnabled)
+ config.OptionMap["EmailVerificationEnabled"] = strconv.FormatBool(config.EmailVerificationEnabled)
+ config.OptionMap["GitHubOAuthEnabled"] = strconv.FormatBool(config.GitHubOAuthEnabled)
+ config.OptionMap["WeChatAuthEnabled"] = strconv.FormatBool(config.WeChatAuthEnabled)
+ config.OptionMap["TurnstileCheckEnabled"] = strconv.FormatBool(config.TurnstileCheckEnabled)
+ config.OptionMap["RegisterEnabled"] = strconv.FormatBool(config.RegisterEnabled)
+ config.OptionMap["AutomaticDisableChannelEnabled"] = strconv.FormatBool(config.AutomaticDisableChannelEnabled)
+ config.OptionMap["AutomaticEnableChannelEnabled"] = strconv.FormatBool(config.AutomaticEnableChannelEnabled)
+ config.OptionMap["ApproximateTokenEnabled"] = strconv.FormatBool(config.ApproximateTokenEnabled)
+ config.OptionMap["LogConsumeEnabled"] = strconv.FormatBool(config.LogConsumeEnabled)
+ config.OptionMap["DisplayInCurrencyEnabled"] = strconv.FormatBool(config.DisplayInCurrencyEnabled)
+ config.OptionMap["DisplayTokenStatEnabled"] = strconv.FormatBool(config.DisplayTokenStatEnabled)
+ config.OptionMap["ChannelDisableThreshold"] = strconv.FormatFloat(config.ChannelDisableThreshold, 'f', -1, 64)
+ config.OptionMap["EmailDomainRestrictionEnabled"] = strconv.FormatBool(config.EmailDomainRestrictionEnabled)
+ config.OptionMap["EmailDomainWhitelist"] = strings.Join(config.EmailDomainWhitelist, ",")
+ config.OptionMap["SMTPServer"] = ""
+ config.OptionMap["SMTPFrom"] = ""
+ config.OptionMap["SMTPPort"] = strconv.Itoa(config.SMTPPort)
+ config.OptionMap["SMTPAccount"] = ""
+ config.OptionMap["SMTPToken"] = ""
+ config.OptionMap["Notice"] = ""
+ config.OptionMap["About"] = ""
+ config.OptionMap["HomePageContent"] = ""
+ config.OptionMap["Footer"] = config.Footer
+ config.OptionMap["SystemName"] = config.SystemName
+ config.OptionMap["Logo"] = config.Logo
+ config.OptionMap["ServerAddress"] = ""
+ config.OptionMap["GitHubClientId"] = ""
+ config.OptionMap["GitHubClientSecret"] = ""
+ config.OptionMap["WeChatServerAddress"] = ""
+ config.OptionMap["WeChatServerToken"] = ""
+ config.OptionMap["WeChatAccountQRCodeImageURL"] = ""
+ config.OptionMap["TurnstileSiteKey"] = ""
+ config.OptionMap["TurnstileSecretKey"] = ""
+ config.OptionMap["QuotaForNewUser"] = strconv.Itoa(config.QuotaForNewUser)
+ config.OptionMap["QuotaForInviter"] = strconv.Itoa(config.QuotaForInviter)
+ config.OptionMap["QuotaForInvitee"] = strconv.Itoa(config.QuotaForInvitee)
+ config.OptionMap["QuotaRemindThreshold"] = strconv.Itoa(config.QuotaRemindThreshold)
+ config.OptionMap["PreConsumedQuota"] = strconv.Itoa(config.PreConsumedQuota)
+ config.OptionMap["ModelRatio"] = common.ModelRatio2JSONString()
+ config.OptionMap["GroupRatio"] = common.GroupRatio2JSONString()
+ config.OptionMap["CompletionRatio"] = common.CompletionRatio2JSONString()
+ config.OptionMap["TopUpLink"] = config.TopUpLink
+ config.OptionMap["ChatLink"] = config.ChatLink
+ config.OptionMap["QuotaPerUnit"] = strconv.FormatFloat(config.QuotaPerUnit, 'f', -1, 64)
+ config.OptionMap["RetryTimes"] = strconv.Itoa(config.RetryTimes)
+ config.OptionMap["Theme"] = config.Theme
+ config.OptionMapRWMutex.Unlock()
loadOptionsFromDatabase()
}
@@ -82,7 +81,7 @@ func loadOptionsFromDatabase() {
for _, option := range options {
err := updateOptionMap(option.Key, option.Value)
if err != nil {
- common.SysError("failed to update option map: " + err.Error())
+ logger.SysError("failed to update option map: " + err.Error())
}
}
}
@@ -90,7 +89,7 @@ func loadOptionsFromDatabase() {
func SyncOptions(frequency int) {
for {
time.Sleep(time.Duration(frequency) * time.Second)
- common.SysLog("syncing options from database")
+ logger.SysLog("syncing options from database")
loadOptionsFromDatabase()
}
}
@@ -112,117 +111,106 @@ func UpdateOption(key string, value string) error {
}
func updateOptionMap(key string, value string) (err error) {
- common.OptionMapRWMutex.Lock()
- defer common.OptionMapRWMutex.Unlock()
- common.OptionMap[key] = value
- if strings.HasSuffix(key, "Permission") {
- intValue, _ := strconv.Atoi(value)
- switch key {
- case "FileUploadPermission":
- common.FileUploadPermission = intValue
- case "FileDownloadPermission":
- common.FileDownloadPermission = intValue
- case "ImageUploadPermission":
- common.ImageUploadPermission = intValue
- case "ImageDownloadPermission":
- common.ImageDownloadPermission = intValue
- }
- }
+ config.OptionMapRWMutex.Lock()
+ defer config.OptionMapRWMutex.Unlock()
+ config.OptionMap[key] = value
if strings.HasSuffix(key, "Enabled") {
boolValue := value == "true"
switch key {
case "PasswordRegisterEnabled":
- common.PasswordRegisterEnabled = boolValue
+ config.PasswordRegisterEnabled = boolValue
case "PasswordLoginEnabled":
- common.PasswordLoginEnabled = boolValue
+ config.PasswordLoginEnabled = boolValue
case "EmailVerificationEnabled":
- common.EmailVerificationEnabled = boolValue
+ config.EmailVerificationEnabled = boolValue
case "GitHubOAuthEnabled":
- common.GitHubOAuthEnabled = boolValue
+ config.GitHubOAuthEnabled = boolValue
case "WeChatAuthEnabled":
- common.WeChatAuthEnabled = boolValue
+ config.WeChatAuthEnabled = boolValue
case "TurnstileCheckEnabled":
- common.TurnstileCheckEnabled = boolValue
+ config.TurnstileCheckEnabled = boolValue
case "RegisterEnabled":
- common.RegisterEnabled = boolValue
+ config.RegisterEnabled = boolValue
case "EmailDomainRestrictionEnabled":
- common.EmailDomainRestrictionEnabled = boolValue
+ config.EmailDomainRestrictionEnabled = boolValue
case "AutomaticDisableChannelEnabled":
- common.AutomaticDisableChannelEnabled = boolValue
+ config.AutomaticDisableChannelEnabled = boolValue
case "AutomaticEnableChannelEnabled":
- common.AutomaticEnableChannelEnabled = boolValue
+ config.AutomaticEnableChannelEnabled = boolValue
case "ApproximateTokenEnabled":
- common.ApproximateTokenEnabled = boolValue
+ config.ApproximateTokenEnabled = boolValue
case "LogConsumeEnabled":
- common.LogConsumeEnabled = boolValue
+ config.LogConsumeEnabled = boolValue
case "DisplayInCurrencyEnabled":
- common.DisplayInCurrencyEnabled = boolValue
+ config.DisplayInCurrencyEnabled = boolValue
case "DisplayTokenStatEnabled":
- common.DisplayTokenStatEnabled = boolValue
+ config.DisplayTokenStatEnabled = boolValue
}
}
switch key {
case "EmailDomainWhitelist":
- common.EmailDomainWhitelist = strings.Split(value, ",")
+ config.EmailDomainWhitelist = strings.Split(value, ",")
case "SMTPServer":
- common.SMTPServer = value
+ config.SMTPServer = value
case "SMTPPort":
intValue, _ := strconv.Atoi(value)
- common.SMTPPort = intValue
+ config.SMTPPort = intValue
case "SMTPAccount":
- common.SMTPAccount = value
+ config.SMTPAccount = value
case "SMTPFrom":
- common.SMTPFrom = value
+ config.SMTPFrom = value
case "SMTPToken":
- common.SMTPToken = value
+ config.SMTPToken = value
case "ServerAddress":
- common.ServerAddress = value
+ config.ServerAddress = value
case "GitHubClientId":
- common.GitHubClientId = value
+ config.GitHubClientId = value
case "GitHubClientSecret":
- common.GitHubClientSecret = value
+ config.GitHubClientSecret = value
case "Footer":
- common.Footer = value
+ config.Footer = value
case "SystemName":
- common.SystemName = value
+ config.SystemName = value
case "Logo":
- common.Logo = value
+ config.Logo = value
case "WeChatServerAddress":
- common.WeChatServerAddress = value
+ config.WeChatServerAddress = value
case "WeChatServerToken":
- common.WeChatServerToken = value
+ config.WeChatServerToken = value
case "WeChatAccountQRCodeImageURL":
- common.WeChatAccountQRCodeImageURL = value
+ config.WeChatAccountQRCodeImageURL = value
case "TurnstileSiteKey":
- common.TurnstileSiteKey = value
+ config.TurnstileSiteKey = value
case "TurnstileSecretKey":
- common.TurnstileSecretKey = value
+ config.TurnstileSecretKey = value
case "QuotaForNewUser":
- common.QuotaForNewUser, _ = strconv.Atoi(value)
+ config.QuotaForNewUser, _ = strconv.Atoi(value)
case "QuotaForInviter":
- common.QuotaForInviter, _ = strconv.Atoi(value)
+ config.QuotaForInviter, _ = strconv.Atoi(value)
case "QuotaForInvitee":
- common.QuotaForInvitee, _ = strconv.Atoi(value)
+ config.QuotaForInvitee, _ = strconv.Atoi(value)
case "QuotaRemindThreshold":
- common.QuotaRemindThreshold, _ = strconv.Atoi(value)
+ config.QuotaRemindThreshold, _ = strconv.Atoi(value)
case "PreConsumedQuota":
- common.PreConsumedQuota, _ = strconv.Atoi(value)
+ config.PreConsumedQuota, _ = strconv.Atoi(value)
case "RetryTimes":
- common.RetryTimes, _ = strconv.Atoi(value)
+ config.RetryTimes, _ = strconv.Atoi(value)
case "ModelRatio":
err = common.UpdateModelRatioByJSONString(value)
case "GroupRatio":
err = common.UpdateGroupRatioByJSONString(value)
+ case "CompletionRatio":
+ err = common.UpdateCompletionRatioByJSONString(value)
case "TopUpLink":
- common.TopUpLink = value
+ config.TopUpLink = value
case "ChatLink":
- common.ChatLink = value
+ config.ChatLink = value
case "ChannelDisableThreshold":
- common.ChannelDisableThreshold, _ = strconv.ParseFloat(value, 64)
+ config.ChannelDisableThreshold, _ = strconv.ParseFloat(value, 64)
case "QuotaPerUnit":
- common.QuotaPerUnit, _ = strconv.ParseFloat(value, 64)
+ config.QuotaPerUnit, _ = strconv.ParseFloat(value, 64)
case "Theme":
- common.Theme = value
+ config.Theme = value
}
return err
}
diff --git a/model/redemption.go b/model/redemption.go
index f16412b5..2c5a4141 100644
--- a/model/redemption.go
+++ b/model/redemption.go
@@ -3,8 +3,9 @@ package model
import (
"errors"
"fmt"
+ "github.com/songquanpeng/one-api/common"
+ "github.com/songquanpeng/one-api/common/helper"
"gorm.io/gorm"
- "one-api/common"
)
type Redemption struct {
@@ -67,7 +68,7 @@ func Redeem(key string, userId int) (quota int, err error) {
if err != nil {
return err
}
- redemption.RedeemedTime = common.GetTimestamp()
+ redemption.RedeemedTime = helper.GetTimestamp()
redemption.Status = common.RedemptionCodeStatusUsed
err = tx.Save(redemption).Error
return err
diff --git a/model/token.go b/model/token.go
index 73705d96..995ff1a9 100644
--- a/model/token.go
+++ b/model/token.go
@@ -3,8 +3,11 @@ package model

import (
	"errors"
	"fmt"
-	"one-api/common"
+	"github.com/songquanpeng/one-api/common"
+	"github.com/songquanpeng/one-api/common/config"
+	"github.com/songquanpeng/one-api/common/helper"
+	"github.com/songquanpeng/one-api/common/logger"
	"gorm.io/gorm"
)
@@ -48,7 +55,7 @@ func ValidateUserToken(key string) (token *Token, err error) {
}
token, err = CacheGetTokenByKey(key)
if err != nil {
- common.SysError("CacheGetTokenByKey failed: " + err.Error())
+ logger.SysError("CacheGetTokenByKey failed: " + err.Error())
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, errors.New("无效的令牌")
}
@@ -62,12 +69,12 @@ func ValidateUserToken(key string) (token *Token, err error) {
if token.Status != common.TokenStatusEnabled {
return nil, errors.New("该令牌状态不可用")
}
- if token.ExpiredTime != -1 && token.ExpiredTime < common.GetTimestamp() {
+ if token.ExpiredTime != -1 && token.ExpiredTime < helper.GetTimestamp() {
if !common.RedisEnabled {
token.Status = common.TokenStatusExpired
err := token.SelectUpdate()
if err != nil {
- common.SysError("failed to update token status" + err.Error())
+ logger.SysError("failed to update token status: " + err.Error())
}
}
return nil, errors.New("该令牌已过期")
@@ -78,7 +85,7 @@ func ValidateUserToken(key string) (token *Token, err error) {
token.Status = common.TokenStatusExhausted
err := token.SelectUpdate()
if err != nil {
- common.SysError("failed to update token status" + err.Error())
+ logger.SysError("failed to update token status: " + err.Error())
}
}
return nil, errors.New("该令牌额度已用尽")
@@ -147,7 +154,7 @@ func IncreaseTokenQuota(id int, quota int) (err error) {
if quota < 0 {
return errors.New("quota 不能为负数!")
}
- if common.BatchUpdateEnabled {
+ if config.BatchUpdateEnabled {
addNewRecord(BatchUpdateTypeTokenQuota, id, quota)
return nil
}
@@ -159,7 +166,7 @@ func increaseTokenQuota(id int, quota int) (err error) {
map[string]interface{}{
"remain_quota": gorm.Expr("remain_quota + ?", quota),
"used_quota": gorm.Expr("used_quota - ?", quota),
- "accessed_time": common.GetTimestamp(),
+ "accessed_time": helper.GetTimestamp(),
},
).Error
return err
@@ -169,7 +176,7 @@ func DecreaseTokenQuota(id int, quota int) (err error) {
if quota < 0 {
return errors.New("quota 不能为负数!")
}
- if common.BatchUpdateEnabled {
+ if config.BatchUpdateEnabled {
addNewRecord(BatchUpdateTypeTokenQuota, id, -quota)
return nil
}
@@ -181,7 +188,7 @@ func decreaseTokenQuota(id int, quota int) (err error) {
map[string]interface{}{
"remain_quota": gorm.Expr("remain_quota - ?", quota),
"used_quota": gorm.Expr("used_quota + ?", quota),
- "accessed_time": common.GetTimestamp(),
+ "accessed_time": helper.GetTimestamp(),
},
).Error
return err
@@ -205,24 +212,24 @@ func PreConsumeTokenQuota(tokenId int, quota int) (err error) {
if userQuota < quota {
return errors.New("用户额度不足")
}
- quotaTooLow := userQuota >= common.QuotaRemindThreshold && userQuota-quota < common.QuotaRemindThreshold
+ quotaTooLow := userQuota >= config.QuotaRemindThreshold && userQuota-quota < config.QuotaRemindThreshold
noMoreQuota := userQuota-quota <= 0
if quotaTooLow || noMoreQuota {
go func() {
email, err := GetUserEmail(token.UserId)
if err != nil {
- common.SysError("failed to fetch user email: " + err.Error())
+ logger.SysError("failed to fetch user email: " + err.Error())
}
prompt := "您的额度即将用尽"
if noMoreQuota {
prompt = "您的额度已用尽"
}
if email != "" {
- topUpLink := fmt.Sprintf("%s/topup", common.ServerAddress)
+ topUpLink := fmt.Sprintf("%s/topup", config.ServerAddress)
err = common.SendEmail(prompt, email,
fmt.Sprintf("%s,当前剩余额度为 %d,为了不影响您的使用,请及时充值。<br/>充值链接:<a href='%s'>%s</a>", prompt, userQuota, topUpLink, topUpLink))
if err != nil {
- common.SysError("failed to send email" + err.Error())
+ logger.SysError("failed to send email" + err.Error())
}
}
}()
diff --git a/model/user.go b/model/user.go
index 1c2c0a75..6979c70b 100644
--- a/model/user.go
+++ b/model/user.go
@@ -3,8 +3,11 @@ package model
import (
"errors"
"fmt"
+ "github.com/songquanpeng/one-api/common"
+ "github.com/songquanpeng/one-api/common/config"
+ "github.com/songquanpeng/one-api/common/helper"
+ "github.com/songquanpeng/one-api/common/logger"
"gorm.io/gorm"
- "one-api/common"
"strings"
)
@@ -89,24 +92,24 @@ func (user *User) Insert(inviterId int) error {
return err
}
}
- user.Quota = common.QuotaForNewUser
- user.AccessToken = common.GetUUID()
- user.AffCode = common.GetRandomString(4)
+ user.Quota = config.QuotaForNewUser
+ user.AccessToken = helper.GetUUID()
+ user.AffCode = helper.GetRandomString(4)
result := DB.Create(user)
if result.Error != nil {
return result.Error
}
- if common.QuotaForNewUser > 0 {
- RecordLog(user.Id, LogTypeSystem, fmt.Sprintf("新用户注册赠送 %s", common.LogQuota(common.QuotaForNewUser)))
+ if config.QuotaForNewUser > 0 {
+ RecordLog(user.Id, LogTypeSystem, fmt.Sprintf("新用户注册赠送 %s", common.LogQuota(config.QuotaForNewUser)))
}
if inviterId != 0 {
- if common.QuotaForInvitee > 0 {
- _ = IncreaseUserQuota(user.Id, common.QuotaForInvitee)
- RecordLog(user.Id, LogTypeSystem, fmt.Sprintf("使用邀请码赠送 %s", common.LogQuota(common.QuotaForInvitee)))
+ if config.QuotaForInvitee > 0 {
+ _ = IncreaseUserQuota(user.Id, config.QuotaForInvitee)
+ RecordLog(user.Id, LogTypeSystem, fmt.Sprintf("使用邀请码赠送 %s", common.LogQuota(config.QuotaForInvitee)))
}
- if common.QuotaForInviter > 0 {
- _ = IncreaseUserQuota(inviterId, common.QuotaForInviter)
- RecordLog(inviterId, LogTypeSystem, fmt.Sprintf("邀请用户赠送 %s", common.LogQuota(common.QuotaForInviter)))
+ if config.QuotaForInviter > 0 {
+ _ = IncreaseUserQuota(inviterId, config.QuotaForInviter)
+ RecordLog(inviterId, LogTypeSystem, fmt.Sprintf("邀请用户赠送 %s", common.LogQuota(config.QuotaForInviter)))
}
}
return nil
@@ -232,7 +235,7 @@ func IsAdmin(userId int) bool {
var user User
err := DB.Where("id = ?", userId).Select("role").Find(&user).Error
if err != nil {
- common.SysError("no such user " + err.Error())
+ logger.SysError("no such user " + err.Error())
return false
}
return user.Role >= common.RoleAdminUser
@@ -291,7 +294,7 @@ func IncreaseUserQuota(id int, quota int) (err error) {
if quota < 0 {
return errors.New("quota 不能为负数!")
}
- if common.BatchUpdateEnabled {
+ if config.BatchUpdateEnabled {
addNewRecord(BatchUpdateTypeUserQuota, id, quota)
return nil
}
@@ -307,7 +310,7 @@ func DecreaseUserQuota(id int, quota int) (err error) {
if quota < 0 {
return errors.New("quota 不能为负数!")
}
- if common.BatchUpdateEnabled {
+ if config.BatchUpdateEnabled {
addNewRecord(BatchUpdateTypeUserQuota, id, -quota)
return nil
}
@@ -325,7 +328,7 @@ func GetRootUserEmail() (email string) {
}
func UpdateUserUsedQuotaAndRequestCount(id int, quota int) {
- if common.BatchUpdateEnabled {
+ if config.BatchUpdateEnabled {
addNewRecord(BatchUpdateTypeUsedQuota, id, quota)
addNewRecord(BatchUpdateTypeRequestCount, id, 1)
return
@@ -341,7 +344,7 @@ func updateUserUsedQuotaAndRequestCount(id int, quota int, count int) {
},
).Error
if err != nil {
- common.SysError("failed to update user used quota and request count: " + err.Error())
+ logger.SysError("failed to update user used quota and request count: " + err.Error())
}
}
@@ -352,14 +355,14 @@ func updateUserUsedQuota(id int, quota int) {
},
).Error
if err != nil {
- common.SysError("failed to update user used quota: " + err.Error())
+ logger.SysError("failed to update user used quota: " + err.Error())
}
}
func updateUserRequestCount(id int, count int) {
err := DB.Model(&User{}).Where("id = ?", id).Update("request_count", gorm.Expr("request_count + ?", count)).Error
if err != nil {
- common.SysError("failed to update user request count: " + err.Error())
+ logger.SysError("failed to update user request count: " + err.Error())
}
}
diff --git a/model/utils.go b/model/utils.go
index 1c28340b..d481973a 100644
--- a/model/utils.go
+++ b/model/utils.go
@@ -1,7 +1,8 @@
package model
import (
- "one-api/common"
+ "github.com/songquanpeng/one-api/common/config"
+ "github.com/songquanpeng/one-api/common/logger"
"sync"
"time"
)
@@ -28,7 +29,7 @@ func init() {
func InitBatchUpdater() {
go func() {
for {
- time.Sleep(time.Duration(common.BatchUpdateInterval) * time.Second)
+ time.Sleep(time.Duration(config.BatchUpdateInterval) * time.Second)
batchUpdate()
}
}()
@@ -45,7 +46,7 @@ func addNewRecord(type_ int, id int, value int) {
}
func batchUpdate() {
- common.SysLog("batch update started")
+ logger.SysLog("batch update started")
for i := 0; i < BatchUpdateTypeCount; i++ {
batchUpdateLocks[i].Lock()
store := batchUpdateStores[i]
@@ -57,12 +58,12 @@ func batchUpdate() {
case BatchUpdateTypeUserQuota:
err := increaseUserQuota(key, value)
if err != nil {
- common.SysError("failed to batch update user quota: " + err.Error())
+ logger.SysError("failed to batch update user quota: " + err.Error())
}
case BatchUpdateTypeTokenQuota:
err := increaseTokenQuota(key, value)
if err != nil {
- common.SysError("failed to batch update token quota: " + err.Error())
+ logger.SysError("failed to batch update token quota: " + err.Error())
}
case BatchUpdateTypeUsedQuota:
updateUserUsedQuota(key, value)
@@ -73,5 +74,5 @@ func batchUpdate() {
}
}
}
- common.SysLog("batch update finished")
+ logger.SysLog("batch update finished")
}
diff --git a/relay/channel/ai360/constants.go b/relay/channel/ai360/constants.go
new file mode 100644
index 00000000..cfc3cb28
--- /dev/null
+++ b/relay/channel/ai360/constants.go
@@ -0,0 +1,8 @@
+package ai360
+
+var ModelList = []string{
+ "360GPT_S2_V9",
+ "embedding-bert-512-v1",
+ "embedding_s1_v1",
+ "semantic_similarity_s1_v1",
+}
diff --git a/relay/channel/aiproxy/adaptor.go b/relay/channel/aiproxy/adaptor.go
new file mode 100644
index 00000000..2b4e3022
--- /dev/null
+++ b/relay/channel/aiproxy/adaptor.go
@@ -0,0 +1,60 @@
+package aiproxy
+
+import (
+ "errors"
+ "fmt"
+ "github.com/gin-gonic/gin"
+ "github.com/songquanpeng/one-api/common"
+ "github.com/songquanpeng/one-api/relay/channel"
+ "github.com/songquanpeng/one-api/relay/model"
+ "github.com/songquanpeng/one-api/relay/util"
+ "io"
+ "net/http"
+)
+
+type Adaptor struct {
+}
+
+func (a *Adaptor) Init(meta *util.RelayMeta) {
+
+}
+
+func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) {
+ return fmt.Sprintf("%s/api/library/ask", meta.BaseURL), nil
+}
+
+func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *util.RelayMeta) error {
+ channel.SetupCommonRequestHeader(c, req, meta)
+ req.Header.Set("Authorization", "Bearer "+meta.APIKey)
+ return nil
+}
+
+func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) {
+ if request == nil {
+ return nil, errors.New("request is nil")
+ }
+ aiProxyLibraryRequest := ConvertRequest(*request)
+ aiProxyLibraryRequest.LibraryId = c.GetString(common.ConfigKeyLibraryID)
+ return aiProxyLibraryRequest, nil
+}
+
+func (a *Adaptor) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error) {
+ return channel.DoRequestHelper(a, c, meta, requestBody)
+}
+
+func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
+ if meta.IsStream {
+ err, usage = StreamHandler(c, resp)
+ } else {
+ err, usage = Handler(c, resp)
+ }
+ return
+}
+
+func (a *Adaptor) GetModelList() []string {
+ return ModelList
+}
+
+func (a *Adaptor) GetChannelName() string {
+ return "aiproxy"
+}
diff --git a/relay/channel/aiproxy/constants.go b/relay/channel/aiproxy/constants.go
new file mode 100644
index 00000000..c4df51c4
--- /dev/null
+++ b/relay/channel/aiproxy/constants.go
@@ -0,0 +1,9 @@
+package aiproxy
+
+import "github.com/songquanpeng/one-api/relay/channel/openai"
+
+var ModelList = []string{""}
+
+func init() {
+ ModelList = openai.ModelList
+}
diff --git a/relay/channel/aiproxy/main.go b/relay/channel/aiproxy/main.go
index bee4d9d3..0d3d0b60 100644
--- a/relay/channel/aiproxy/main.go
+++ b/relay/channel/aiproxy/main.go
@@ -5,18 +5,21 @@ import (
"encoding/json"
"fmt"
"github.com/gin-gonic/gin"
+ "github.com/songquanpeng/one-api/common"
+ "github.com/songquanpeng/one-api/common/helper"
+ "github.com/songquanpeng/one-api/common/logger"
+ "github.com/songquanpeng/one-api/relay/channel/openai"
+ "github.com/songquanpeng/one-api/relay/constant"
+ "github.com/songquanpeng/one-api/relay/model"
"io"
"net/http"
- "one-api/common"
- "one-api/relay/channel/openai"
- "one-api/relay/constant"
"strconv"
"strings"
)
// https://docs.aiproxy.io/dev/library#使用已经定制好的知识库进行对话问答
-func ConvertRequest(request openai.GeneralOpenAIRequest) *LibraryRequest {
+func ConvertRequest(request model.GeneralOpenAIRequest) *LibraryRequest {
query := ""
if len(request.Messages) != 0 {
query = request.Messages[len(request.Messages)-1].StringContent()
@@ -43,16 +46,16 @@ func responseAIProxyLibrary2OpenAI(response *LibraryResponse) *openai.TextRespon
content := response.Answer + aiProxyDocuments2Markdown(response.Documents)
choice := openai.TextResponseChoice{
Index: 0,
- Message: openai.Message{
+ Message: model.Message{
Role: "assistant",
Content: content,
},
FinishReason: "stop",
}
fullTextResponse := openai.TextResponse{
- Id: common.GetUUID(),
+ Id: helper.GetUUID(),
Object: "chat.completion",
- Created: common.GetTimestamp(),
+ Created: helper.GetTimestamp(),
Choices: []openai.TextResponseChoice{choice},
}
return &fullTextResponse
@@ -63,9 +66,9 @@ func documentsAIProxyLibrary(documents []LibraryDocument) *openai.ChatCompletion
choice.Delta.Content = aiProxyDocuments2Markdown(documents)
choice.FinishReason = &constant.StopFinishReason
return &openai.ChatCompletionsStreamResponse{
- Id: common.GetUUID(),
+ Id: helper.GetUUID(),
Object: "chat.completion.chunk",
- Created: common.GetTimestamp(),
+ Created: helper.GetTimestamp(),
Model: "",
Choices: []openai.ChatCompletionsStreamResponseChoice{choice},
}
@@ -75,16 +78,16 @@ func streamResponseAIProxyLibrary2OpenAI(response *LibraryStreamResponse) *opena
var choice openai.ChatCompletionsStreamResponseChoice
choice.Delta.Content = response.Content
return &openai.ChatCompletionsStreamResponse{
- Id: common.GetUUID(),
+ Id: helper.GetUUID(),
Object: "chat.completion.chunk",
- Created: common.GetTimestamp(),
+ Created: helper.GetTimestamp(),
Model: response.Model,
Choices: []openai.ChatCompletionsStreamResponseChoice{choice},
}
}
-func StreamHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, *openai.Usage) {
- var usage openai.Usage
+func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
+ var usage model.Usage
scanner := bufio.NewScanner(resp.Body)
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
if atEOF && len(data) == 0 {
@@ -122,7 +125,7 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatus
var AIProxyLibraryResponse LibraryStreamResponse
err := json.Unmarshal([]byte(data), &AIProxyLibraryResponse)
if err != nil {
- common.SysError("error unmarshalling stream response: " + err.Error())
+ logger.SysError("error unmarshalling stream response: " + err.Error())
return true
}
if len(AIProxyLibraryResponse.Documents) != 0 {
@@ -131,7 +134,7 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatus
response := streamResponseAIProxyLibrary2OpenAI(&AIProxyLibraryResponse)
jsonResponse, err := json.Marshal(response)
if err != nil {
- common.SysError("error marshalling stream response: " + err.Error())
+ logger.SysError("error marshalling stream response: " + err.Error())
return true
}
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
@@ -140,7 +143,7 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatus
response := documentsAIProxyLibrary(documents)
jsonResponse, err := json.Marshal(response)
if err != nil {
- common.SysError("error marshalling stream response: " + err.Error())
+ logger.SysError("error marshalling stream response: " + err.Error())
return true
}
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
@@ -155,7 +158,7 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatus
return nil, &usage
}
-func Handler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, *openai.Usage) {
+func Handler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
var AIProxyLibraryResponse LibraryResponse
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
@@ -170,8 +173,8 @@ func Handler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode,
return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
if AIProxyLibraryResponse.ErrCode != 0 {
- return &openai.ErrorWithStatusCode{
- Error: openai.Error{
+ return &model.ErrorWithStatusCode{
+ Error: model.Error{
Message: AIProxyLibraryResponse.Message,
Type: strconv.Itoa(AIProxyLibraryResponse.ErrCode),
Code: AIProxyLibraryResponse.ErrCode,
@@ -187,5 +190,8 @@ func Handler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode,
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
_, err = c.Writer.Write(jsonResponse)
+ if err != nil {
+ return openai.ErrorWrapper(err, "write_response_body_failed", http.StatusInternalServerError), nil
+ }
return nil, &fullTextResponse.Usage
}
diff --git a/relay/channel/ali/adaptor.go b/relay/channel/ali/adaptor.go
new file mode 100644
index 00000000..6c6f433e
--- /dev/null
+++ b/relay/channel/ali/adaptor.go
@@ -0,0 +1,83 @@
+package ali
+
+import (
+ "errors"
+ "fmt"
+ "github.com/gin-gonic/gin"
+ "github.com/songquanpeng/one-api/common"
+ "github.com/songquanpeng/one-api/relay/channel"
+ "github.com/songquanpeng/one-api/relay/constant"
+ "github.com/songquanpeng/one-api/relay/model"
+ "github.com/songquanpeng/one-api/relay/util"
+ "io"
+ "net/http"
+)
+
+// https://help.aliyun.com/zh/dashscope/developer-reference/api-details
+
+type Adaptor struct {
+}
+
+func (a *Adaptor) Init(meta *util.RelayMeta) {
+
+}
+
+func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) {
+ fullRequestURL := fmt.Sprintf("%s/api/v1/services/aigc/text-generation/generation", meta.BaseURL)
+ if meta.Mode == constant.RelayModeEmbeddings {
+ fullRequestURL = fmt.Sprintf("%s/api/v1/services/embeddings/text-embedding/text-embedding", meta.BaseURL)
+ }
+ return fullRequestURL, nil
+}
+
+func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *util.RelayMeta) error {
+ channel.SetupCommonRequestHeader(c, req, meta)
+ req.Header.Set("Authorization", "Bearer "+meta.APIKey)
+ if meta.IsStream {
+ req.Header.Set("X-DashScope-SSE", "enable")
+ }
+ if c.GetString(common.ConfigKeyPlugin) != "" {
+ req.Header.Set("X-DashScope-Plugin", c.GetString(common.ConfigKeyPlugin))
+ }
+ return nil
+}
+
+func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) {
+ if request == nil {
+ return nil, errors.New("request is nil")
+ }
+ switch relayMode {
+ case constant.RelayModeEmbeddings:
+ baiduEmbeddingRequest := ConvertEmbeddingRequest(*request)
+ return baiduEmbeddingRequest, nil
+ default:
+ baiduRequest := ConvertRequest(*request)
+ return baiduRequest, nil
+ }
+}
+
+func (a *Adaptor) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error) {
+ return channel.DoRequestHelper(a, c, meta, requestBody)
+}
+
+func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
+ if meta.IsStream {
+ err, usage = StreamHandler(c, resp)
+ } else {
+ switch meta.Mode {
+ case constant.RelayModeEmbeddings:
+ err, usage = EmbeddingHandler(c, resp)
+ default:
+ err, usage = Handler(c, resp)
+ }
+ }
+ return
+}
+
+func (a *Adaptor) GetModelList() []string {
+ return ModelList
+}
+
+func (a *Adaptor) GetChannelName() string {
+ return "ali"
+}
diff --git a/relay/channel/ali/constants.go b/relay/channel/ali/constants.go
new file mode 100644
index 00000000..16bcfca4
--- /dev/null
+++ b/relay/channel/ali/constants.go
@@ -0,0 +1,6 @@
+package ali
+
+var ModelList = []string{
+ "qwen-turbo", "qwen-plus", "qwen-max", "qwen-max-longcontext",
+ "text-embedding-v1",
+}
diff --git a/relay/channel/ali/main.go b/relay/channel/ali/main.go
index f45a515a..b9625584 100644
--- a/relay/channel/ali/main.go
+++ b/relay/channel/ali/main.go
@@ -4,10 +4,13 @@ import (
"bufio"
"encoding/json"
"github.com/gin-gonic/gin"
+ "github.com/songquanpeng/one-api/common"
+ "github.com/songquanpeng/one-api/common/helper"
+ "github.com/songquanpeng/one-api/common/logger"
+ "github.com/songquanpeng/one-api/relay/channel/openai"
+ "github.com/songquanpeng/one-api/relay/model"
"io"
"net/http"
- "one-api/common"
- "one-api/relay/channel/openai"
"strings"
)
@@ -15,7 +18,7 @@ import (
const EnableSearchModelSuffix = "-internet"
-func ConvertRequest(request openai.GeneralOpenAIRequest) *ChatRequest {
+func ConvertRequest(request model.GeneralOpenAIRequest) *ChatRequest {
messages := make([]Message, 0, len(request.Messages))
for i := 0; i < len(request.Messages); i++ {
message := request.Messages[i]
@@ -38,11 +41,12 @@ func ConvertRequest(request openai.GeneralOpenAIRequest) *ChatRequest {
Parameters: Parameters{
EnableSearch: enableSearch,
IncrementalOutput: request.Stream,
+ Seed: uint64(request.Seed),
},
}
}
-func ConvertEmbeddingRequest(request openai.GeneralOpenAIRequest) *EmbeddingRequest {
+func ConvertEmbeddingRequest(request model.GeneralOpenAIRequest) *EmbeddingRequest {
return &EmbeddingRequest{
Model: "text-embedding-v1",
Input: struct {
@@ -53,7 +57,7 @@ func ConvertEmbeddingRequest(request openai.GeneralOpenAIRequest) *EmbeddingRequ
}
}
-func EmbeddingHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, *openai.Usage) {
+func EmbeddingHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
var aliResponse EmbeddingResponse
err := json.NewDecoder(resp.Body).Decode(&aliResponse)
if err != nil {
@@ -66,8 +70,8 @@ func EmbeddingHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithSta
}
if aliResponse.Code != "" {
- return &openai.ErrorWithStatusCode{
- Error: openai.Error{
+ return &model.ErrorWithStatusCode{
+ Error: model.Error{
Message: aliResponse.Message,
Type: aliResponse.Code,
Param: aliResponse.RequestId,
@@ -93,7 +97,7 @@ func embeddingResponseAli2OpenAI(response *EmbeddingResponse) *openai.EmbeddingR
Object: "list",
Data: make([]openai.EmbeddingResponseItem, 0, len(response.Output.Embeddings)),
Model: "text-embedding-v1",
- Usage: openai.Usage{TotalTokens: response.Usage.TotalTokens},
+ Usage: model.Usage{TotalTokens: response.Usage.TotalTokens},
}
for _, item := range response.Output.Embeddings {
@@ -109,7 +113,7 @@ func embeddingResponseAli2OpenAI(response *EmbeddingResponse) *openai.EmbeddingR
func responseAli2OpenAI(response *ChatResponse) *openai.TextResponse {
choice := openai.TextResponseChoice{
Index: 0,
- Message: openai.Message{
+ Message: model.Message{
Role: "assistant",
Content: response.Output.Text,
},
@@ -118,9 +122,9 @@ func responseAli2OpenAI(response *ChatResponse) *openai.TextResponse {
fullTextResponse := openai.TextResponse{
Id: response.RequestId,
Object: "chat.completion",
- Created: common.GetTimestamp(),
+ Created: helper.GetTimestamp(),
Choices: []openai.TextResponseChoice{choice},
- Usage: openai.Usage{
+ Usage: model.Usage{
PromptTokens: response.Usage.InputTokens,
CompletionTokens: response.Usage.OutputTokens,
TotalTokens: response.Usage.InputTokens + response.Usage.OutputTokens,
@@ -139,15 +143,15 @@ func streamResponseAli2OpenAI(aliResponse *ChatResponse) *openai.ChatCompletions
response := openai.ChatCompletionsStreamResponse{
Id: aliResponse.RequestId,
Object: "chat.completion.chunk",
- Created: common.GetTimestamp(),
+ Created: helper.GetTimestamp(),
Model: "qwen",
Choices: []openai.ChatCompletionsStreamResponseChoice{choice},
}
return &response
}
-func StreamHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, *openai.Usage) {
- var usage openai.Usage
+func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
+ var usage model.Usage
scanner := bufio.NewScanner(resp.Body)
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
if atEOF && len(data) == 0 {
@@ -185,7 +189,7 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatus
var aliResponse ChatResponse
err := json.Unmarshal([]byte(data), &aliResponse)
if err != nil {
- common.SysError("error unmarshalling stream response: " + err.Error())
+ logger.SysError("error unmarshalling stream response: " + err.Error())
return true
}
if aliResponse.Usage.OutputTokens != 0 {
@@ -198,7 +202,7 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatus
//lastResponseText = aliResponse.Output.Text
jsonResponse, err := json.Marshal(response)
if err != nil {
- common.SysError("error marshalling stream response: " + err.Error())
+ logger.SysError("error marshalling stream response: " + err.Error())
return true
}
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
@@ -215,7 +219,7 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatus
return nil, &usage
}
-func Handler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, *openai.Usage) {
+func Handler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
var aliResponse ChatResponse
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
@@ -230,8 +234,8 @@ func Handler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode,
return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
if aliResponse.Code != "" {
- return &openai.ErrorWithStatusCode{
- Error: openai.Error{
+ return &model.ErrorWithStatusCode{
+ Error: model.Error{
Message: aliResponse.Message,
Type: aliResponse.Code,
Param: aliResponse.RequestId,
diff --git a/relay/channel/anthropic/adaptor.go b/relay/channel/anthropic/adaptor.go
new file mode 100644
index 00000000..4b873715
--- /dev/null
+++ b/relay/channel/anthropic/adaptor.go
@@ -0,0 +1,65 @@
+package anthropic
+
+import (
+ "errors"
+ "fmt"
+ "github.com/gin-gonic/gin"
+ "github.com/songquanpeng/one-api/relay/channel"
+ "github.com/songquanpeng/one-api/relay/channel/openai"
+ "github.com/songquanpeng/one-api/relay/model"
+ "github.com/songquanpeng/one-api/relay/util"
+ "io"
+ "net/http"
+)
+
+type Adaptor struct {
+}
+
+func (a *Adaptor) Init(meta *util.RelayMeta) {
+
+}
+
+func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) {
+ return fmt.Sprintf("%s/v1/complete", meta.BaseURL), nil
+}
+
+func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *util.RelayMeta) error {
+ channel.SetupCommonRequestHeader(c, req, meta)
+ req.Header.Set("x-api-key", meta.APIKey)
+ anthropicVersion := c.Request.Header.Get("anthropic-version")
+ if anthropicVersion == "" {
+ anthropicVersion = "2023-06-01"
+ }
+ req.Header.Set("anthropic-version", anthropicVersion)
+ return nil
+}
+
+func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) {
+ if request == nil {
+ return nil, errors.New("request is nil")
+ }
+ return ConvertRequest(*request), nil
+}
+
+func (a *Adaptor) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error) {
+ return channel.DoRequestHelper(a, c, meta, requestBody)
+}
+
+func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
+ if meta.IsStream {
+ var responseText string
+ err, responseText = StreamHandler(c, resp)
+ usage = openai.ResponseText2Usage(responseText, meta.ActualModelName, meta.PromptTokens)
+ } else {
+ err, usage = Handler(c, resp, meta.PromptTokens, meta.ActualModelName)
+ }
+ return
+}
+
+func (a *Adaptor) GetModelList() []string {
+ return ModelList
+}
+
+func (a *Adaptor) GetChannelName() string {
+	return "anthropic"
+}
diff --git a/relay/channel/anthropic/constants.go b/relay/channel/anthropic/constants.go
new file mode 100644
index 00000000..b98c15c2
--- /dev/null
+++ b/relay/channel/anthropic/constants.go
@@ -0,0 +1,5 @@
+package anthropic
+
+var ModelList = []string{
+ "claude-instant-1", "claude-2", "claude-2.0", "claude-2.1",
+}
diff --git a/relay/channel/anthropic/main.go b/relay/channel/anthropic/main.go
index a4272d7b..e2c575fa 100644
--- a/relay/channel/anthropic/main.go
+++ b/relay/channel/anthropic/main.go
@@ -5,10 +5,13 @@ import (
"encoding/json"
"fmt"
"github.com/gin-gonic/gin"
+ "github.com/songquanpeng/one-api/common"
+ "github.com/songquanpeng/one-api/common/helper"
+ "github.com/songquanpeng/one-api/common/logger"
+ "github.com/songquanpeng/one-api/relay/channel/openai"
+ "github.com/songquanpeng/one-api/relay/model"
"io"
"net/http"
- "one-api/common"
- "one-api/relay/channel/openai"
"strings"
)
@@ -23,7 +26,7 @@ func stopReasonClaude2OpenAI(reason string) string {
}
}
-func ConvertRequest(textRequest openai.GeneralOpenAIRequest) *Request {
+func ConvertRequest(textRequest model.GeneralOpenAIRequest) *Request {
claudeRequest := Request{
Model: textRequest.Model,
Prompt: "",
@@ -70,7 +73,7 @@ func streamResponseClaude2OpenAI(claudeResponse *Response) *openai.ChatCompletio
func responseClaude2OpenAI(claudeResponse *Response) *openai.TextResponse {
choice := openai.TextResponseChoice{
Index: 0,
- Message: openai.Message{
+ Message: model.Message{
Role: "assistant",
Content: strings.TrimPrefix(claudeResponse.Completion, " "),
Name: nil,
@@ -78,18 +81,18 @@ func responseClaude2OpenAI(claudeResponse *Response) *openai.TextResponse {
FinishReason: stopReasonClaude2OpenAI(claudeResponse.StopReason),
}
fullTextResponse := openai.TextResponse{
- Id: fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
+ Id: fmt.Sprintf("chatcmpl-%s", helper.GetUUID()),
Object: "chat.completion",
- Created: common.GetTimestamp(),
+ Created: helper.GetTimestamp(),
Choices: []openai.TextResponseChoice{choice},
}
return &fullTextResponse
}
-func StreamHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, string) {
+func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, string) {
responseText := ""
- responseId := fmt.Sprintf("chatcmpl-%s", common.GetUUID())
- createdTime := common.GetTimestamp()
+ responseId := fmt.Sprintf("chatcmpl-%s", helper.GetUUID())
+ createdTime := helper.GetTimestamp()
scanner := bufio.NewScanner(resp.Body)
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
if atEOF && len(data) == 0 {
@@ -125,7 +128,7 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatus
var claudeResponse Response
err := json.Unmarshal([]byte(data), &claudeResponse)
if err != nil {
- common.SysError("error unmarshalling stream response: " + err.Error())
+ logger.SysError("error unmarshalling stream response: " + err.Error())
return true
}
responseText += claudeResponse.Completion
@@ -134,7 +137,7 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatus
response.Created = createdTime
jsonStr, err := json.Marshal(response)
if err != nil {
- common.SysError("error marshalling stream response: " + err.Error())
+ logger.SysError("error marshalling stream response: " + err.Error())
return true
}
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonStr)})
@@ -151,7 +154,7 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatus
return nil, responseText
}
-func Handler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*openai.ErrorWithStatusCode, *openai.Usage) {
+func Handler(c *gin.Context, resp *http.Response, promptTokens int, modelName string) (*model.ErrorWithStatusCode, *model.Usage) {
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
@@ -166,8 +169,8 @@ func Handler(c *gin.Context, resp *http.Response, promptTokens int, model string
return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
if claudeResponse.Error.Type != "" {
- return &openai.ErrorWithStatusCode{
- Error: openai.Error{
+ return &model.ErrorWithStatusCode{
+ Error: model.Error{
Message: claudeResponse.Error.Message,
Type: claudeResponse.Error.Type,
Param: "",
@@ -177,9 +180,9 @@ func Handler(c *gin.Context, resp *http.Response, promptTokens int, model string
}, nil
}
fullTextResponse := responseClaude2OpenAI(&claudeResponse)
- fullTextResponse.Model = model
- completionTokens := openai.CountTokenText(claudeResponse.Completion, model)
- usage := openai.Usage{
+ fullTextResponse.Model = modelName
+ completionTokens := openai.CountTokenText(claudeResponse.Completion, modelName)
+ usage := model.Usage{
PromptTokens: promptTokens,
CompletionTokens: completionTokens,
TotalTokens: promptTokens + completionTokens,
diff --git a/relay/channel/baidu/adaptor.go b/relay/channel/baidu/adaptor.go
new file mode 100644
index 00000000..d2d06ce0
--- /dev/null
+++ b/relay/channel/baidu/adaptor.go
@@ -0,0 +1,93 @@
+package baidu
+
+import (
+ "errors"
+ "github.com/gin-gonic/gin"
+ "github.com/songquanpeng/one-api/relay/channel"
+ "github.com/songquanpeng/one-api/relay/constant"
+ "github.com/songquanpeng/one-api/relay/model"
+ "github.com/songquanpeng/one-api/relay/util"
+ "io"
+ "net/http"
+)
+
+type Adaptor struct {
+}
+
+func (a *Adaptor) Init(meta *util.RelayMeta) {
+
+}
+
+func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) {
+ // https://cloud.baidu.com/doc/WENXINWORKSHOP/s/clntwmv7t
+ var fullRequestURL string
+ switch meta.ActualModelName {
+ case "ERNIE-Bot-4":
+ fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro"
+ case "ERNIE-Bot-8K":
+ fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie_bot_8k"
+ case "ERNIE-Bot":
+ fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions"
+ case "ERNIE-Speed":
+ fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie_speed"
+ case "ERNIE-Bot-turbo":
+ fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant"
+ case "BLOOMZ-7B":
+ fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/bloomz_7b1"
+ case "Embedding-V1":
+ fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings/embedding-v1"
+ }
+ var accessToken string
+ var err error
+ if accessToken, err = GetAccessToken(meta.APIKey); err != nil {
+ return "", err
+ }
+ fullRequestURL += "?access_token=" + accessToken
+ return fullRequestURL, nil
+}
+
+func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *util.RelayMeta) error {
+ channel.SetupCommonRequestHeader(c, req, meta)
+ req.Header.Set("Authorization", "Bearer "+meta.APIKey)
+ return nil
+}
+
+func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) {
+ if request == nil {
+ return nil, errors.New("request is nil")
+ }
+ switch relayMode {
+ case constant.RelayModeEmbeddings:
+ baiduEmbeddingRequest := ConvertEmbeddingRequest(*request)
+ return baiduEmbeddingRequest, nil
+ default:
+ baiduRequest := ConvertRequest(*request)
+ return baiduRequest, nil
+ }
+}
+
+func (a *Adaptor) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error) {
+ return channel.DoRequestHelper(a, c, meta, requestBody)
+}
+
+func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
+ if meta.IsStream {
+ err, usage = StreamHandler(c, resp)
+ } else {
+ switch meta.Mode {
+ case constant.RelayModeEmbeddings:
+ err, usage = EmbeddingHandler(c, resp)
+ default:
+ err, usage = Handler(c, resp)
+ }
+ }
+ return
+}
+
+func (a *Adaptor) GetModelList() []string {
+ return ModelList
+}
+
+func (a *Adaptor) GetChannelName() string {
+ return "baidu"
+}
diff --git a/relay/channel/baidu/constants.go b/relay/channel/baidu/constants.go
new file mode 100644
index 00000000..0fa8f2d6
--- /dev/null
+++ b/relay/channel/baidu/constants.go
@@ -0,0 +1,10 @@
+package baidu
+
+var ModelList = []string{
+ "ERNIE-Bot-4",
+ "ERNIE-Bot-8K",
+ "ERNIE-Bot",
+ "ERNIE-Speed",
+ "ERNIE-Bot-turbo",
+ "Embedding-V1",
+}
diff --git a/relay/channel/baidu/main.go b/relay/channel/baidu/main.go
index 47969492..4f2b13fc 100644
--- a/relay/channel/baidu/main.go
+++ b/relay/channel/baidu/main.go
@@ -6,12 +6,14 @@ import (
"errors"
"fmt"
"github.com/gin-gonic/gin"
+ "github.com/songquanpeng/one-api/common"
+ "github.com/songquanpeng/one-api/common/logger"
+ "github.com/songquanpeng/one-api/relay/channel/openai"
+ "github.com/songquanpeng/one-api/relay/constant"
+ "github.com/songquanpeng/one-api/relay/model"
+ "github.com/songquanpeng/one-api/relay/util"
"io"
"net/http"
- "one-api/common"
- "one-api/relay/channel/openai"
- "one-api/relay/constant"
- "one-api/relay/util"
"strings"
"sync"
"time"
@@ -19,49 +21,49 @@ import (
// https://cloud.baidu.com/doc/WENXINWORKSHOP/s/flfmc9do2
-type BaiduTokenResponse struct {
+type TokenResponse struct {
ExpiresIn int `json:"expires_in"`
AccessToken string `json:"access_token"`
}
-type BaiduMessage struct {
+type Message struct {
Role string `json:"role"`
Content string `json:"content"`
}
-type BaiduChatRequest struct {
- Messages []BaiduMessage `json:"messages"`
- Stream bool `json:"stream"`
- UserId string `json:"user_id,omitempty"`
+type ChatRequest struct {
+ Messages []Message `json:"messages"`
+ Stream bool `json:"stream"`
+ UserId string `json:"user_id,omitempty"`
}
-type BaiduError struct {
+type Error struct {
ErrorCode int `json:"error_code"`
ErrorMsg string `json:"error_msg"`
}
var baiduTokenStore sync.Map
-func ConvertRequest(request openai.GeneralOpenAIRequest) *BaiduChatRequest {
- messages := make([]BaiduMessage, 0, len(request.Messages))
+func ConvertRequest(request model.GeneralOpenAIRequest) *ChatRequest {
+ messages := make([]Message, 0, len(request.Messages))
for _, message := range request.Messages {
if message.Role == "system" {
- messages = append(messages, BaiduMessage{
+ messages = append(messages, Message{
Role: "user",
Content: message.StringContent(),
})
- messages = append(messages, BaiduMessage{
+ messages = append(messages, Message{
Role: "assistant",
Content: "Okay",
})
} else {
- messages = append(messages, BaiduMessage{
+ messages = append(messages, Message{
Role: message.Role,
Content: message.StringContent(),
})
}
}
- return &BaiduChatRequest{
+ return &ChatRequest{
Messages: messages,
Stream: request.Stream,
}
@@ -70,7 +72,7 @@ func ConvertRequest(request openai.GeneralOpenAIRequest) *BaiduChatRequest {
func responseBaidu2OpenAI(response *ChatResponse) *openai.TextResponse {
choice := openai.TextResponseChoice{
Index: 0,
- Message: openai.Message{
+ Message: model.Message{
Role: "assistant",
Content: response.Result,
},
@@ -102,7 +104,7 @@ func streamResponseBaidu2OpenAI(baiduResponse *ChatStreamResponse) *openai.ChatC
return &response
}
-func ConvertEmbeddingRequest(request openai.GeneralOpenAIRequest) *EmbeddingRequest {
+func ConvertEmbeddingRequest(request model.GeneralOpenAIRequest) *EmbeddingRequest {
return &EmbeddingRequest{
Input: request.ParseInput(),
}
@@ -125,8 +127,8 @@ func embeddingResponseBaidu2OpenAI(response *EmbeddingResponse) *openai.Embeddin
return &openAIEmbeddingResponse
}
-func StreamHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, *openai.Usage) {
- var usage openai.Usage
+func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
+ var usage model.Usage
scanner := bufio.NewScanner(resp.Body)
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
if atEOF && len(data) == 0 {
@@ -160,7 +162,7 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatus
var baiduResponse ChatStreamResponse
err := json.Unmarshal([]byte(data), &baiduResponse)
if err != nil {
- common.SysError("error unmarshalling stream response: " + err.Error())
+ logger.SysError("error unmarshalling stream response: " + err.Error())
return true
}
if baiduResponse.Usage.TotalTokens != 0 {
@@ -171,7 +173,7 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatus
response := streamResponseBaidu2OpenAI(&baiduResponse)
jsonResponse, err := json.Marshal(response)
if err != nil {
- common.SysError("error marshalling stream response: " + err.Error())
+ logger.SysError("error marshalling stream response: " + err.Error())
return true
}
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
@@ -188,7 +190,7 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatus
return nil, &usage
}
-func Handler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, *openai.Usage) {
+func Handler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
var baiduResponse ChatResponse
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
@@ -203,8 +205,8 @@ func Handler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode,
return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
if baiduResponse.ErrorMsg != "" {
- return &openai.ErrorWithStatusCode{
- Error: openai.Error{
+ return &model.ErrorWithStatusCode{
+ Error: model.Error{
Message: baiduResponse.ErrorMsg,
Type: "baidu_error",
Param: "",
@@ -225,7 +227,7 @@ func Handler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode,
return nil, &fullTextResponse.Usage
}
-func EmbeddingHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, *openai.Usage) {
+func EmbeddingHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
var baiduResponse EmbeddingResponse
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
@@ -240,8 +242,8 @@ func EmbeddingHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithSta
return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
if baiduResponse.ErrorMsg != "" {
- return &openai.ErrorWithStatusCode{
- Error: openai.Error{
+ return &model.ErrorWithStatusCode{
+ Error: model.Error{
Message: baiduResponse.ErrorMsg,
Type: "baidu_error",
Param: "",
diff --git a/relay/channel/baidu/model.go b/relay/channel/baidu/model.go
index caaebafb..cc1feb2f 100644
--- a/relay/channel/baidu/model.go
+++ b/relay/channel/baidu/model.go
@@ -1,19 +1,19 @@
package baidu
import (
- "one-api/relay/channel/openai"
+ "github.com/songquanpeng/one-api/relay/model"
"time"
)
type ChatResponse struct {
- Id string `json:"id"`
- Object string `json:"object"`
- Created int64 `json:"created"`
- Result string `json:"result"`
- IsTruncated bool `json:"is_truncated"`
- NeedClearHistory bool `json:"need_clear_history"`
- Usage openai.Usage `json:"usage"`
- BaiduError
+ Id string `json:"id"`
+ Object string `json:"object"`
+ Created int64 `json:"created"`
+ Result string `json:"result"`
+ IsTruncated bool `json:"is_truncated"`
+ NeedClearHistory bool `json:"need_clear_history"`
+ Usage model.Usage `json:"usage"`
+ Error
}
type ChatStreamResponse struct {
@@ -37,8 +37,8 @@ type EmbeddingResponse struct {
Object string `json:"object"`
Created int64 `json:"created"`
Data []EmbeddingData `json:"data"`
- Usage openai.Usage `json:"usage"`
- BaiduError
+ Usage model.Usage `json:"usage"`
+ Error
}
type AccessToken struct {
diff --git a/relay/channel/common.go b/relay/channel/common.go
new file mode 100644
index 00000000..c6e1abf2
--- /dev/null
+++ b/relay/channel/common.go
@@ -0,0 +1,51 @@
+package channel
+
+import (
+ "errors"
+ "fmt"
+ "github.com/gin-gonic/gin"
+ "github.com/songquanpeng/one-api/relay/util"
+ "io"
+ "net/http"
+)
+
+func SetupCommonRequestHeader(c *gin.Context, req *http.Request, meta *util.RelayMeta) {
+ req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
+ req.Header.Set("Accept", c.Request.Header.Get("Accept"))
+ if meta.IsStream && c.Request.Header.Get("Accept") == "" {
+ req.Header.Set("Accept", "text/event-stream")
+ }
+}
+
+func DoRequestHelper(a Adaptor, c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error) {
+ fullRequestURL, err := a.GetRequestURL(meta)
+ if err != nil {
+ return nil, fmt.Errorf("get request url failed: %w", err)
+ }
+ req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
+ if err != nil {
+ return nil, fmt.Errorf("new request failed: %w", err)
+ }
+ err = a.SetupRequestHeader(c, req, meta)
+ if err != nil {
+ return nil, fmt.Errorf("setup request header failed: %w", err)
+ }
+ resp, err := DoRequest(c, req)
+ if err != nil {
+ return nil, fmt.Errorf("do request failed: %w", err)
+ }
+ return resp, nil
+}
+
+func DoRequest(c *gin.Context, req *http.Request) (*http.Response, error) {
+ resp, err := util.HTTPClient.Do(req)
+ if err != nil {
+ return nil, err
+ }
+ if resp == nil {
+ return nil, errors.New("resp is nil")
+ }
+ _ = req.Body.Close()
+ _ = c.Request.Body.Close()
+ return resp, nil
+}
diff --git a/relay/channel/gemini/adaptor.go b/relay/channel/gemini/adaptor.go
new file mode 100644
index 00000000..f3305e5d
--- /dev/null
+++ b/relay/channel/gemini/adaptor.go
@@ -0,0 +1,66 @@
+package gemini
+
+import (
+ "errors"
+ "fmt"
+ "github.com/gin-gonic/gin"
+ "github.com/songquanpeng/one-api/common/helper"
+ channelhelper "github.com/songquanpeng/one-api/relay/channel"
+ "github.com/songquanpeng/one-api/relay/channel/openai"
+ "github.com/songquanpeng/one-api/relay/model"
+ "github.com/songquanpeng/one-api/relay/util"
+ "io"
+ "net/http"
+)
+
+type Adaptor struct {
+}
+
+func (a *Adaptor) Init(meta *util.RelayMeta) {
+
+}
+
+func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) {
+ version := helper.AssignOrDefault(meta.APIVersion, "v1")
+ action := "generateContent"
+ if meta.IsStream {
+ action = "streamGenerateContent"
+ }
+ return fmt.Sprintf("%s/%s/models/%s:%s", meta.BaseURL, version, meta.ActualModelName, action), nil
+}
+
+func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *util.RelayMeta) error {
+ channelhelper.SetupCommonRequestHeader(c, req, meta)
+ req.Header.Set("x-goog-api-key", meta.APIKey)
+ return nil
+}
+
+func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) {
+ if request == nil {
+ return nil, errors.New("request is nil")
+ }
+ return ConvertRequest(*request), nil
+}
+
+func (a *Adaptor) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error) {
+ return channelhelper.DoRequestHelper(a, c, meta, requestBody)
+}
+
+func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
+ if meta.IsStream {
+ var responseText string
+ err, responseText = StreamHandler(c, resp)
+ usage = openai.ResponseText2Usage(responseText, meta.ActualModelName, meta.PromptTokens)
+ } else {
+ err, usage = Handler(c, resp, meta.PromptTokens, meta.ActualModelName)
+ }
+ return
+}
+
+func (a *Adaptor) GetModelList() []string {
+ return ModelList
+}
+
+func (a *Adaptor) GetChannelName() string {
+ return "google gemini"
+}
diff --git a/relay/channel/gemini/constants.go b/relay/channel/gemini/constants.go
new file mode 100644
index 00000000..5bb0c168
--- /dev/null
+++ b/relay/channel/gemini/constants.go
@@ -0,0 +1,6 @@
+package gemini
+
+var ModelList = []string{
+ "gemini-pro",
+ "gemini-pro-vision",
+}
diff --git a/relay/channel/google/gemini.go b/relay/channel/gemini/main.go
similarity index 69%
rename from relay/channel/google/gemini.go
rename to relay/channel/gemini/main.go
index f49caadf..c24694c8 100644
--- a/relay/channel/google/gemini.go
+++ b/relay/channel/gemini/main.go
@@ -1,15 +1,19 @@
-package google
+package gemini
import (
"bufio"
"encoding/json"
"fmt"
+ "github.com/songquanpeng/one-api/common"
+ "github.com/songquanpeng/one-api/common/config"
+ "github.com/songquanpeng/one-api/common/helper"
+ "github.com/songquanpeng/one-api/common/image"
+ "github.com/songquanpeng/one-api/common/logger"
+ "github.com/songquanpeng/one-api/relay/channel/openai"
+ "github.com/songquanpeng/one-api/relay/constant"
+ "github.com/songquanpeng/one-api/relay/model"
"io"
"net/http"
- "one-api/common"
- "one-api/common/image"
- "one-api/relay/channel/openai"
- "one-api/relay/constant"
"strings"
"github.com/gin-gonic/gin"
@@ -18,39 +22,39 @@ import (
// https://ai.google.dev/docs/gemini_api_overview?hl=zh-cn
const (
- GeminiVisionMaxImageNum = 16
+ VisionMaxImageNum = 16
)
// Setting safety to the lowest possible values since Gemini is already powerless enough
-func ConvertGeminiRequest(textRequest openai.GeneralOpenAIRequest) *GeminiChatRequest {
- geminiRequest := GeminiChatRequest{
- Contents: make([]GeminiChatContent, 0, len(textRequest.Messages)),
- SafetySettings: []GeminiChatSafetySettings{
+func ConvertRequest(textRequest model.GeneralOpenAIRequest) *ChatRequest {
+ geminiRequest := ChatRequest{
+ Contents: make([]ChatContent, 0, len(textRequest.Messages)),
+ SafetySettings: []ChatSafetySettings{
{
Category: "HARM_CATEGORY_HARASSMENT",
- Threshold: common.GeminiSafetySetting,
+ Threshold: config.GeminiSafetySetting,
},
{
Category: "HARM_CATEGORY_HATE_SPEECH",
- Threshold: common.GeminiSafetySetting,
+ Threshold: config.GeminiSafetySetting,
},
{
Category: "HARM_CATEGORY_SEXUALLY_EXPLICIT",
- Threshold: common.GeminiSafetySetting,
+ Threshold: config.GeminiSafetySetting,
},
{
Category: "HARM_CATEGORY_DANGEROUS_CONTENT",
- Threshold: common.GeminiSafetySetting,
+ Threshold: config.GeminiSafetySetting,
},
},
- GenerationConfig: GeminiChatGenerationConfig{
+ GenerationConfig: ChatGenerationConfig{
Temperature: textRequest.Temperature,
TopP: textRequest.TopP,
MaxOutputTokens: textRequest.MaxTokens,
},
}
if textRequest.Functions != nil {
- geminiRequest.Tools = []GeminiChatTools{
+ geminiRequest.Tools = []ChatTools{
{
FunctionDeclarations: textRequest.Functions,
},
@@ -58,30 +62,30 @@ func ConvertGeminiRequest(textRequest openai.GeneralOpenAIRequest) *GeminiChatRe
}
shouldAddDummyModelMessage := false
for _, message := range textRequest.Messages {
- content := GeminiChatContent{
+ content := ChatContent{
Role: message.Role,
- Parts: []GeminiPart{
+ Parts: []Part{
{
Text: message.StringContent(),
},
},
}
openaiContent := message.ParseContent()
- var parts []GeminiPart
+ var parts []Part
imageNum := 0
for _, part := range openaiContent {
- if part.Type == openai.ContentTypeText {
- parts = append(parts, GeminiPart{
+ if part.Type == model.ContentTypeText {
+ parts = append(parts, Part{
Text: part.Text,
})
- } else if part.Type == openai.ContentTypeImageURL {
+ } else if part.Type == model.ContentTypeImageURL {
imageNum += 1
- if imageNum > GeminiVisionMaxImageNum {
+ if imageNum > VisionMaxImageNum {
continue
}
mimeType, data, _ := image.GetImageFromUrl(part.ImageURL.Url)
- parts = append(parts, GeminiPart{
- InlineData: &GeminiInlineData{
+ parts = append(parts, Part{
+ InlineData: &InlineData{
MimeType: mimeType,
Data: data,
},
@@ -103,9 +107,9 @@ func ConvertGeminiRequest(textRequest openai.GeneralOpenAIRequest) *GeminiChatRe
// If a system message is the last message, we need to add a dummy model message to make gemini happy
if shouldAddDummyModelMessage {
- geminiRequest.Contents = append(geminiRequest.Contents, GeminiChatContent{
+ geminiRequest.Contents = append(geminiRequest.Contents, ChatContent{
Role: "model",
- Parts: []GeminiPart{
+ Parts: []Part{
{
Text: "Okay",
},
@@ -118,12 +122,12 @@ func ConvertGeminiRequest(textRequest openai.GeneralOpenAIRequest) *GeminiChatRe
return &geminiRequest
}
-type GeminiChatResponse struct {
- Candidates []GeminiChatCandidate `json:"candidates"`
- PromptFeedback GeminiChatPromptFeedback `json:"promptFeedback"`
+type ChatResponse struct {
+ Candidates []ChatCandidate `json:"candidates"`
+ PromptFeedback ChatPromptFeedback `json:"promptFeedback"`
}
-func (g *GeminiChatResponse) GetResponseText() string {
+func (g *ChatResponse) GetResponseText() string {
if g == nil {
return ""
}
@@ -133,33 +137,33 @@ func (g *GeminiChatResponse) GetResponseText() string {
return ""
}
-type GeminiChatCandidate struct {
- Content GeminiChatContent `json:"content"`
- FinishReason string `json:"finishReason"`
- Index int64 `json:"index"`
- SafetyRatings []GeminiChatSafetyRating `json:"safetyRatings"`
+type ChatCandidate struct {
+ Content ChatContent `json:"content"`
+ FinishReason string `json:"finishReason"`
+ Index int64 `json:"index"`
+ SafetyRatings []ChatSafetyRating `json:"safetyRatings"`
}
-type GeminiChatSafetyRating struct {
+type ChatSafetyRating struct {
Category string `json:"category"`
Probability string `json:"probability"`
}
-type GeminiChatPromptFeedback struct {
- SafetyRatings []GeminiChatSafetyRating `json:"safetyRatings"`
+type ChatPromptFeedback struct {
+ SafetyRatings []ChatSafetyRating `json:"safetyRatings"`
}
-func responseGeminiChat2OpenAI(response *GeminiChatResponse) *openai.TextResponse {
+func responseGeminiChat2OpenAI(response *ChatResponse) *openai.TextResponse {
fullTextResponse := openai.TextResponse{
- Id: fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
+ Id: fmt.Sprintf("chatcmpl-%s", helper.GetUUID()),
Object: "chat.completion",
- Created: common.GetTimestamp(),
+ Created: helper.GetTimestamp(),
Choices: make([]openai.TextResponseChoice, 0, len(response.Candidates)),
}
for i, candidate := range response.Candidates {
choice := openai.TextResponseChoice{
Index: i,
- Message: openai.Message{
+ Message: model.Message{
Role: "assistant",
Content: "",
},
@@ -173,7 +177,7 @@ func responseGeminiChat2OpenAI(response *GeminiChatResponse) *openai.TextRespons
return &fullTextResponse
}
-func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) *openai.ChatCompletionsStreamResponse {
+func streamResponseGeminiChat2OpenAI(geminiResponse *ChatResponse) *openai.ChatCompletionsStreamResponse {
var choice openai.ChatCompletionsStreamResponseChoice
choice.Delta.Content = geminiResponse.GetResponseText()
choice.FinishReason = &constant.StopFinishReason
@@ -184,7 +188,7 @@ func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) *openai
return &response
}
-func StreamHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, string) {
+func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, string) {
responseText := ""
dataChan := make(chan string)
stopChan := make(chan bool)
@@ -229,15 +233,15 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatus
var choice openai.ChatCompletionsStreamResponseChoice
choice.Delta.Content = dummy.Content
response := openai.ChatCompletionsStreamResponse{
- Id: fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
+ Id: fmt.Sprintf("chatcmpl-%s", helper.GetUUID()),
Object: "chat.completion.chunk",
- Created: common.GetTimestamp(),
+ Created: helper.GetTimestamp(),
Model: "gemini-pro",
Choices: []openai.ChatCompletionsStreamResponseChoice{choice},
}
jsonResponse, err := json.Marshal(response)
if err != nil {
- common.SysError("error marshalling stream response: " + err.Error())
+ logger.SysError("error marshalling stream response: " + err.Error())
return true
}
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
@@ -254,7 +258,7 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatus
return nil, responseText
}
-func GeminiHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*openai.ErrorWithStatusCode, *openai.Usage) {
+func Handler(c *gin.Context, resp *http.Response, promptTokens int, modelName string) (*model.ErrorWithStatusCode, *model.Usage) {
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
@@ -263,14 +267,14 @@ func GeminiHandler(c *gin.Context, resp *http.Response, promptTokens int, model
if err != nil {
return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
- var geminiResponse GeminiChatResponse
+ var geminiResponse ChatResponse
err = json.Unmarshal(responseBody, &geminiResponse)
if err != nil {
return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
if len(geminiResponse.Candidates) == 0 {
- return &openai.ErrorWithStatusCode{
- Error: openai.Error{
+ return &model.ErrorWithStatusCode{
+ Error: model.Error{
Message: "No candidates returned",
Type: "server_error",
Param: "",
@@ -280,9 +284,9 @@ func GeminiHandler(c *gin.Context, resp *http.Response, promptTokens int, model
}, nil
}
fullTextResponse := responseGeminiChat2OpenAI(&geminiResponse)
- fullTextResponse.Model = model
- completionTokens := openai.CountTokenText(geminiResponse.GetResponseText(), model)
- usage := openai.Usage{
+ fullTextResponse.Model = modelName
+ completionTokens := openai.CountTokenText(geminiResponse.GetResponseText(), modelName)
+ usage := model.Usage{
PromptTokens: promptTokens,
CompletionTokens: completionTokens,
TotalTokens: promptTokens + completionTokens,
diff --git a/relay/channel/gemini/model.go b/relay/channel/gemini/model.go
new file mode 100644
index 00000000..d1e3c4fd
--- /dev/null
+++ b/relay/channel/gemini/model.go
@@ -0,0 +1,41 @@
+package gemini
+
+type ChatRequest struct {
+ Contents []ChatContent `json:"contents"`
+ SafetySettings []ChatSafetySettings `json:"safety_settings,omitempty"`
+ GenerationConfig ChatGenerationConfig `json:"generation_config,omitempty"`
+ Tools []ChatTools `json:"tools,omitempty"`
+}
+
+type InlineData struct {
+ MimeType string `json:"mimeType"`
+ Data string `json:"data"`
+}
+
+type Part struct {
+ Text string `json:"text,omitempty"`
+ InlineData *InlineData `json:"inlineData,omitempty"`
+}
+
+type ChatContent struct {
+ Role string `json:"role,omitempty"`
+ Parts []Part `json:"parts"`
+}
+
+type ChatSafetySettings struct {
+ Category string `json:"category"`
+ Threshold string `json:"threshold"`
+}
+
+type ChatTools struct {
+ FunctionDeclarations any `json:"functionDeclarations,omitempty"`
+}
+
+type ChatGenerationConfig struct {
+ Temperature float64 `json:"temperature,omitempty"`
+ TopP float64 `json:"topP,omitempty"`
+ TopK float64 `json:"topK,omitempty"`
+ MaxOutputTokens int `json:"maxOutputTokens,omitempty"`
+ CandidateCount int `json:"candidateCount,omitempty"`
+ StopSequences []string `json:"stopSequences,omitempty"`
+}
diff --git a/relay/channel/google/model.go b/relay/channel/google/model.go
deleted file mode 100644
index 694c2dd1..00000000
--- a/relay/channel/google/model.go
+++ /dev/null
@@ -1,80 +0,0 @@
-package google
-
-import (
- "one-api/relay/channel/openai"
-)
-
-type GeminiChatRequest struct {
- Contents []GeminiChatContent `json:"contents"`
- SafetySettings []GeminiChatSafetySettings `json:"safety_settings,omitempty"`
- GenerationConfig GeminiChatGenerationConfig `json:"generation_config,omitempty"`
- Tools []GeminiChatTools `json:"tools,omitempty"`
-}
-
-type GeminiInlineData struct {
- MimeType string `json:"mimeType"`
- Data string `json:"data"`
-}
-
-type GeminiPart struct {
- Text string `json:"text,omitempty"`
- InlineData *GeminiInlineData `json:"inlineData,omitempty"`
-}
-
-type GeminiChatContent struct {
- Role string `json:"role,omitempty"`
- Parts []GeminiPart `json:"parts"`
-}
-
-type GeminiChatSafetySettings struct {
- Category string `json:"category"`
- Threshold string `json:"threshold"`
-}
-
-type GeminiChatTools struct {
- FunctionDeclarations any `json:"functionDeclarations,omitempty"`
-}
-
-type GeminiChatGenerationConfig struct {
- Temperature float64 `json:"temperature,omitempty"`
- TopP float64 `json:"topP,omitempty"`
- TopK float64 `json:"topK,omitempty"`
- MaxOutputTokens int `json:"maxOutputTokens,omitempty"`
- CandidateCount int `json:"candidateCount,omitempty"`
- StopSequences []string `json:"stopSequences,omitempty"`
-}
-
-type PaLMChatMessage struct {
- Author string `json:"author"`
- Content string `json:"content"`
-}
-
-type PaLMFilter struct {
- Reason string `json:"reason"`
- Message string `json:"message"`
-}
-
-type PaLMPrompt struct {
- Messages []PaLMChatMessage `json:"messages"`
-}
-
-type PaLMChatRequest struct {
- Prompt PaLMPrompt `json:"prompt"`
- Temperature float64 `json:"temperature,omitempty"`
- CandidateCount int `json:"candidateCount,omitempty"`
- TopP float64 `json:"topP,omitempty"`
- TopK int `json:"topK,omitempty"`
-}
-
-type PaLMError struct {
- Code int `json:"code"`
- Message string `json:"message"`
- Status string `json:"status"`
-}
-
-type PaLMChatResponse struct {
- Candidates []PaLMChatMessage `json:"candidates"`
- Messages []openai.Message `json:"messages"`
- Filters []PaLMFilter `json:"filters"`
- Error PaLMError `json:"error"`
-}
diff --git a/relay/channel/interface.go b/relay/channel/interface.go
new file mode 100644
index 00000000..e25db83f
--- /dev/null
+++ b/relay/channel/interface.go
@@ -0,0 +1,20 @@
+package channel
+
+import (
+ "github.com/gin-gonic/gin"
+ "github.com/songquanpeng/one-api/relay/model"
+ "github.com/songquanpeng/one-api/relay/util"
+ "io"
+ "net/http"
+)
+
+type Adaptor interface {
+ Init(meta *util.RelayMeta)
+ GetRequestURL(meta *util.RelayMeta) (string, error)
+ SetupRequestHeader(c *gin.Context, req *http.Request, meta *util.RelayMeta) error
+ ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error)
+ DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error)
+ DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode)
+ GetModelList() []string
+ GetChannelName() string
+}
diff --git a/relay/channel/moonshot/constants.go b/relay/channel/moonshot/constants.go
new file mode 100644
index 00000000..1b86f0fa
--- /dev/null
+++ b/relay/channel/moonshot/constants.go
@@ -0,0 +1,7 @@
+package moonshot
+
+var ModelList = []string{
+ "moonshot-v1-8k",
+ "moonshot-v1-32k",
+ "moonshot-v1-128k",
+}
diff --git a/relay/channel/openai/adaptor.go b/relay/channel/openai/adaptor.go
new file mode 100644
index 00000000..1313e317
--- /dev/null
+++ b/relay/channel/openai/adaptor.go
@@ -0,0 +1,103 @@
+package openai
+
+import (
+ "errors"
+ "fmt"
+ "github.com/gin-gonic/gin"
+ "github.com/songquanpeng/one-api/common"
+ "github.com/songquanpeng/one-api/relay/channel"
+ "github.com/songquanpeng/one-api/relay/channel/ai360"
+ "github.com/songquanpeng/one-api/relay/channel/moonshot"
+ "github.com/songquanpeng/one-api/relay/model"
+ "github.com/songquanpeng/one-api/relay/util"
+ "io"
+ "net/http"
+ "strings"
+)
+
+type Adaptor struct {
+ ChannelType int
+}
+
+func (a *Adaptor) Init(meta *util.RelayMeta) {
+ a.ChannelType = meta.ChannelType
+}
+
+func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) {
+ if meta.ChannelType == common.ChannelTypeAzure {
+ // https://learn.microsoft.com/en-us/azure/cognitive-services/openai/chatgpt-quickstart?pivots=rest-api&tabs=command-line#rest-api
+ requestURL := strings.Split(meta.RequestURLPath, "?")[0]
+ requestURL = fmt.Sprintf("%s?api-version=%s", requestURL, meta.APIVersion)
+ task := strings.TrimPrefix(requestURL, "/v1/")
+ model_ := meta.ActualModelName
+ model_ = strings.Replace(model_, ".", "", -1)
+ // https://github.com/songquanpeng/one-api/issues/67
+ model_ = strings.TrimSuffix(model_, "-0301")
+ model_ = strings.TrimSuffix(model_, "-0314")
+ model_ = strings.TrimSuffix(model_, "-0613")
+
+ requestURL = fmt.Sprintf("/openai/deployments/%s/%s", model_, task)
+ return util.GetFullRequestURL(meta.BaseURL, requestURL, meta.ChannelType), nil
+ }
+ return util.GetFullRequestURL(meta.BaseURL, meta.RequestURLPath, meta.ChannelType), nil
+}
+
+func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *util.RelayMeta) error {
+ channel.SetupCommonRequestHeader(c, req, meta)
+ if meta.ChannelType == common.ChannelTypeAzure {
+ req.Header.Set("api-key", meta.APIKey)
+ return nil
+ }
+ req.Header.Set("Authorization", "Bearer "+meta.APIKey)
+ if meta.ChannelType == common.ChannelTypeOpenRouter {
+ req.Header.Set("HTTP-Referer", "https://github.com/songquanpeng/one-api")
+ req.Header.Set("X-Title", "One API")
+ }
+ return nil
+}
+
+func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) {
+ if request == nil {
+ return nil, errors.New("request is nil")
+ }
+ return request, nil
+}
+
+func (a *Adaptor) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error) {
+ return channel.DoRequestHelper(a, c, meta, requestBody)
+}
+
+func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
+ if meta.IsStream {
+ var responseText string
+ err, responseText = StreamHandler(c, resp, meta.Mode)
+ usage = ResponseText2Usage(responseText, meta.ActualModelName, meta.PromptTokens)
+ } else {
+ err, usage = Handler(c, resp, meta.PromptTokens, meta.ActualModelName)
+ }
+ return
+}
+
+func (a *Adaptor) GetModelList() []string {
+ switch a.ChannelType {
+ case common.ChannelType360:
+ return ai360.ModelList
+ case common.ChannelTypeMoonshot:
+ return moonshot.ModelList
+ default:
+ return ModelList
+ }
+}
+
+func (a *Adaptor) GetChannelName() string {
+ switch a.ChannelType {
+ case common.ChannelTypeAzure:
+ return "azure"
+ case common.ChannelType360:
+ return "360"
+ case common.ChannelTypeMoonshot:
+ return "moonshot"
+ default:
+ return "openai"
+ }
+}
diff --git a/relay/channel/openai/constants.go b/relay/channel/openai/constants.go
new file mode 100644
index 00000000..ea236ea1
--- /dev/null
+++ b/relay/channel/openai/constants.go
@@ -0,0 +1,19 @@
+package openai
+
+var ModelList = []string{
+ "gpt-3.5-turbo", "gpt-3.5-turbo-0301", "gpt-3.5-turbo-0613", "gpt-3.5-turbo-1106", "gpt-3.5-turbo-0125",
+ "gpt-3.5-turbo-16k", "gpt-3.5-turbo-16k-0613",
+ "gpt-3.5-turbo-instruct",
+ "gpt-4", "gpt-4-0314", "gpt-4-0613", "gpt-4-1106-preview", "gpt-4-0125-preview",
+ "gpt-4-32k", "gpt-4-32k-0314", "gpt-4-32k-0613",
+ "gpt-4-turbo-preview",
+ "gpt-4-vision-preview",
+ "text-embedding-ada-002", "text-embedding-3-small", "text-embedding-3-large",
+ "text-curie-001", "text-babbage-001", "text-ada-001", "text-davinci-002", "text-davinci-003",
+ "text-moderation-latest", "text-moderation-stable",
+ "text-davinci-edit-001",
+ "davinci-002", "babbage-002",
+ "dall-e-2", "dall-e-3",
+ "whisper-1",
+ "tts-1", "tts-1-1106", "tts-1-hd", "tts-1-hd-1106",
+}
diff --git a/relay/channel/openai/helper.go b/relay/channel/openai/helper.go
new file mode 100644
index 00000000..9bca8cab
--- /dev/null
+++ b/relay/channel/openai/helper.go
@@ -0,0 +1,11 @@
+package openai
+
+import "github.com/songquanpeng/one-api/relay/model"
+
+func ResponseText2Usage(responseText string, modeName string, promptTokens int) *model.Usage {
+ usage := &model.Usage{}
+ usage.PromptTokens = promptTokens
+ usage.CompletionTokens = CountTokenText(responseText, modeName)
+ usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
+ return usage
+}
diff --git a/relay/channel/openai/main.go b/relay/channel/openai/main.go
index 848a6fa4..fbe55cf9 100644
--- a/relay/channel/openai/main.go
+++ b/relay/channel/openai/main.go
@@ -5,14 +5,16 @@ import (
"bytes"
"encoding/json"
"github.com/gin-gonic/gin"
+ "github.com/songquanpeng/one-api/common"
+ "github.com/songquanpeng/one-api/common/logger"
+ "github.com/songquanpeng/one-api/relay/constant"
+ "github.com/songquanpeng/one-api/relay/model"
"io"
"net/http"
- "one-api/common"
- "one-api/relay/constant"
"strings"
)
-func StreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*ErrorWithStatusCode, string) {
+func StreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*model.ErrorWithStatusCode, string) {
responseText := ""
scanner := bufio.NewScanner(resp.Body)
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
@@ -46,7 +48,7 @@ func StreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*ErrorWi
var streamResponse ChatCompletionsStreamResponse
err := json.Unmarshal([]byte(data), &streamResponse)
if err != nil {
- common.SysError("error unmarshalling stream response: " + err.Error())
+ logger.SysError("error unmarshalling stream response: " + err.Error())
continue // just ignore the error
}
for _, choice := range streamResponse.Choices {
@@ -56,7 +58,7 @@ func StreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*ErrorWi
var streamResponse CompletionsStreamResponse
err := json.Unmarshal([]byte(data), &streamResponse)
if err != nil {
- common.SysError("error unmarshalling stream response: " + err.Error())
+ logger.SysError("error unmarshalling stream response: " + err.Error())
continue
}
for _, choice := range streamResponse.Choices {
@@ -89,7 +91,7 @@ func StreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*ErrorWi
return nil, responseText
}
-func Handler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*ErrorWithStatusCode, *Usage) {
+func Handler(c *gin.Context, resp *http.Response, promptTokens int, modelName string) (*model.ErrorWithStatusCode, *model.Usage) {
var textResponse SlimTextResponse
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
@@ -104,7 +106,7 @@ func Handler(c *gin.Context, resp *http.Response, promptTokens int, model string
return ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
if textResponse.Error.Type != "" {
- return &ErrorWithStatusCode{
+ return &model.ErrorWithStatusCode{
Error: textResponse.Error,
StatusCode: resp.StatusCode,
}, nil
@@ -132,9 +134,9 @@ func Handler(c *gin.Context, resp *http.Response, promptTokens int, model string
if textResponse.Usage.TotalTokens == 0 {
completionTokens := 0
for _, choice := range textResponse.Choices {
- completionTokens += CountTokenText(choice.Message.StringContent(), model)
+ completionTokens += CountTokenText(choice.Message.StringContent(), modelName)
}
- textResponse.Usage = Usage{
+ textResponse.Usage = model.Usage{
PromptTokens: promptTokens,
CompletionTokens: completionTokens,
TotalTokens: promptTokens + completionTokens,
diff --git a/relay/channel/openai/model.go b/relay/channel/openai/model.go
index c831ce19..c09f2334 100644
--- a/relay/channel/openai/model.go
+++ b/relay/channel/openai/model.go
@@ -1,15 +1,6 @@
package openai
-type Message struct {
- Role string `json:"role"`
- Content any `json:"content"`
- Name *string `json:"name,omitempty"`
-}
-
-type ImageURL struct {
- Url string `json:"url,omitempty"`
- Detail string `json:"detail,omitempty"`
-}
+import "github.com/songquanpeng/one-api/relay/model"
type TextContent struct {
Type string `json:"type,omitempty"`
@@ -17,142 +8,21 @@ type TextContent struct {
}
type ImageContent struct {
- Type string `json:"type,omitempty"`
- ImageURL *ImageURL `json:"image_url,omitempty"`
-}
-
-type OpenAIMessageContent struct {
- Type string `json:"type,omitempty"`
- Text string `json:"text"`
- ImageURL *ImageURL `json:"image_url,omitempty"`
-}
-
-func (m Message) IsStringContent() bool {
- _, ok := m.Content.(string)
- return ok
-}
-
-func (m Message) StringContent() string {
- content, ok := m.Content.(string)
- if ok {
- return content
- }
- contentList, ok := m.Content.([]any)
- if ok {
- var contentStr string
- for _, contentItem := range contentList {
- contentMap, ok := contentItem.(map[string]any)
- if !ok {
- continue
- }
- if contentMap["type"] == ContentTypeText {
- if subStr, ok := contentMap["text"].(string); ok {
- contentStr += subStr
- }
- }
- }
- return contentStr
- }
- return ""
-}
-
-func (m Message) ParseContent() []OpenAIMessageContent {
- var contentList []OpenAIMessageContent
- content, ok := m.Content.(string)
- if ok {
- contentList = append(contentList, OpenAIMessageContent{
- Type: ContentTypeText,
- Text: content,
- })
- return contentList
- }
- anyList, ok := m.Content.([]any)
- if ok {
- for _, contentItem := range anyList {
- contentMap, ok := contentItem.(map[string]any)
- if !ok {
- continue
- }
- switch contentMap["type"] {
- case ContentTypeText:
- if subStr, ok := contentMap["text"].(string); ok {
- contentList = append(contentList, OpenAIMessageContent{
- Type: ContentTypeText,
- Text: subStr,
- })
- }
- case ContentTypeImageURL:
- if subObj, ok := contentMap["image_url"].(map[string]any); ok {
- contentList = append(contentList, OpenAIMessageContent{
- Type: ContentTypeImageURL,
- ImageURL: &ImageURL{
- Url: subObj["url"].(string),
- },
- })
- }
- }
- }
- return contentList
- }
- return nil
-}
-
-type ResponseFormat struct {
- Type string `json:"type,omitempty"`
-}
-
-type GeneralOpenAIRequest struct {
- Model string `json:"model,omitempty"`
- Messages []Message `json:"messages,omitempty"`
- Prompt any `json:"prompt,omitempty"`
- Stream bool `json:"stream,omitempty"`
- MaxTokens int `json:"max_tokens,omitempty"`
- Temperature float64 `json:"temperature,omitempty"`
- TopP float64 `json:"top_p,omitempty"`
- N int `json:"n,omitempty"`
- Input any `json:"input,omitempty"`
- Instruction string `json:"instruction,omitempty"`
- Size string `json:"size,omitempty"`
- Functions any `json:"functions,omitempty"`
- FrequencyPenalty float64 `json:"frequency_penalty,omitempty"`
- PresencePenalty float64 `json:"presence_penalty,omitempty"`
- ResponseFormat *ResponseFormat `json:"response_format,omitempty"`
- Seed float64 `json:"seed,omitempty"`
- Tools any `json:"tools,omitempty"`
- ToolChoice any `json:"tool_choice,omitempty"`
- User string `json:"user,omitempty"`
-}
-
-func (r GeneralOpenAIRequest) ParseInput() []string {
- if r.Input == nil {
- return nil
- }
- var input []string
- switch r.Input.(type) {
- case string:
- input = []string{r.Input.(string)}
- case []any:
- input = make([]string, 0, len(r.Input.([]any)))
- for _, item := range r.Input.([]any) {
- if str, ok := item.(string); ok {
- input = append(input, str)
- }
- }
- }
- return input
+ Type string `json:"type,omitempty"`
+ ImageURL *model.ImageURL `json:"image_url,omitempty"`
}
type ChatRequest struct {
- Model string `json:"model"`
- Messages []Message `json:"messages"`
- MaxTokens int `json:"max_tokens"`
+ Model string `json:"model"`
+ Messages []model.Message `json:"messages"`
+ MaxTokens int `json:"max_tokens"`
}
type TextRequest struct {
- Model string `json:"model"`
- Messages []Message `json:"messages"`
- Prompt string `json:"prompt"`
- MaxTokens int `json:"max_tokens"`
+ Model string `json:"model"`
+ Messages []model.Message `json:"messages"`
+ Prompt string `json:"prompt"`
+ MaxTokens int `json:"max_tokens"`
//Stream bool `json:"stream"`
}
@@ -201,43 +71,30 @@ type TextToSpeechRequest struct {
ResponseFormat string `json:"response_format"`
}
-type Usage struct {
- PromptTokens int `json:"prompt_tokens"`
- CompletionTokens int `json:"completion_tokens"`
- TotalTokens int `json:"total_tokens"`
-}
-
-type Error struct {
- Message string `json:"message"`
- Type string `json:"type"`
- Param string `json:"param"`
- Code any `json:"code"`
-}
-
-type ErrorWithStatusCode struct {
- Error
- StatusCode int `json:"status_code"`
+type UsageOrResponseText struct {
+ *model.Usage
+ ResponseText string
}
type SlimTextResponse struct {
- Choices []TextResponseChoice `json:"choices"`
- Usage `json:"usage"`
- Error Error `json:"error"`
+ Choices []TextResponseChoice `json:"choices"`
+ model.Usage `json:"usage"`
+ Error model.Error `json:"error"`
}
type TextResponseChoice struct {
- Index int `json:"index"`
- Message `json:"message"`
- FinishReason string `json:"finish_reason"`
+ Index int `json:"index"`
+ model.Message `json:"message"`
+ FinishReason string `json:"finish_reason"`
}
type TextResponse struct {
- Id string `json:"id"`
- Model string `json:"model,omitempty"`
- Object string `json:"object"`
- Created int64 `json:"created"`
- Choices []TextResponseChoice `json:"choices"`
- Usage `json:"usage"`
+ Id string `json:"id"`
+ Model string `json:"model,omitempty"`
+ Object string `json:"object"`
+ Created int64 `json:"created"`
+ Choices []TextResponseChoice `json:"choices"`
+ model.Usage `json:"usage"`
}
type EmbeddingResponseItem struct {
@@ -247,10 +104,10 @@ type EmbeddingResponseItem struct {
}
type EmbeddingResponse struct {
- Object string `json:"object"`
- Data []EmbeddingResponseItem `json:"data"`
- Model string `json:"model"`
- Usage `json:"usage"`
+ Object string `json:"object"`
+ Data []EmbeddingResponseItem `json:"data"`
+ Model string `json:"model"`
+ model.Usage `json:"usage"`
}
type ImageResponse struct {
diff --git a/relay/channel/openai/token.go b/relay/channel/openai/token.go
index 4b40b228..0720425f 100644
--- a/relay/channel/openai/token.go
+++ b/relay/channel/openai/token.go
@@ -4,9 +4,12 @@ import (
"errors"
"fmt"
"github.com/pkoukk/tiktoken-go"
+ "github.com/songquanpeng/one-api/common"
+ "github.com/songquanpeng/one-api/common/config"
+ "github.com/songquanpeng/one-api/common/image"
+ "github.com/songquanpeng/one-api/common/logger"
+ "github.com/songquanpeng/one-api/relay/model"
"math"
- "one-api/common"
- "one-api/common/image"
"strings"
)
@@ -15,17 +18,17 @@ var tokenEncoderMap = map[string]*tiktoken.Tiktoken{}
var defaultTokenEncoder *tiktoken.Tiktoken
func InitTokenEncoders() {
- common.SysLog("initializing token encoders")
+ logger.SysLog("initializing token encoders")
gpt35TokenEncoder, err := tiktoken.EncodingForModel("gpt-3.5-turbo")
if err != nil {
- common.FatalLog(fmt.Sprintf("failed to get gpt-3.5-turbo token encoder: %s", err.Error()))
+ logger.FatalLog(fmt.Sprintf("failed to get gpt-3.5-turbo token encoder: %s", err.Error()))
}
defaultTokenEncoder = gpt35TokenEncoder
gpt4TokenEncoder, err := tiktoken.EncodingForModel("gpt-4")
if err != nil {
- common.FatalLog(fmt.Sprintf("failed to get gpt-4 token encoder: %s", err.Error()))
+ logger.FatalLog(fmt.Sprintf("failed to get gpt-4 token encoder: %s", err.Error()))
}
- for model, _ := range common.ModelRatio {
+ for model := range common.ModelRatio {
if strings.HasPrefix(model, "gpt-3.5") {
tokenEncoderMap[model] = gpt35TokenEncoder
} else if strings.HasPrefix(model, "gpt-4") {
@@ -34,7 +37,7 @@ func InitTokenEncoders() {
tokenEncoderMap[model] = nil
}
}
- common.SysLog("token encoders initialized")
+ logger.SysLog("token encoders initialized")
}
func getTokenEncoder(model string) *tiktoken.Tiktoken {
@@ -45,7 +48,7 @@ func getTokenEncoder(model string) *tiktoken.Tiktoken {
if ok {
tokenEncoder, err := tiktoken.EncodingForModel(model)
if err != nil {
- common.SysError(fmt.Sprintf("failed to get token encoder for model %s: %s, using encoder for gpt-3.5-turbo", model, err.Error()))
+ logger.SysError(fmt.Sprintf("failed to get token encoder for model %s: %s, using encoder for gpt-3.5-turbo", model, err.Error()))
tokenEncoder = defaultTokenEncoder
}
tokenEncoderMap[model] = tokenEncoder
@@ -55,13 +58,13 @@ func getTokenEncoder(model string) *tiktoken.Tiktoken {
}
func getTokenNum(tokenEncoder *tiktoken.Tiktoken, text string) int {
- if common.ApproximateTokenEnabled {
+ if config.ApproximateTokenEnabled {
return int(float64(len(text)) * 0.38)
}
return len(tokenEncoder.Encode(text, nil, nil))
}
-func CountTokenMessages(messages []Message, model string) int {
+func CountTokenMessages(messages []model.Message, model string) int {
tokenEncoder := getTokenEncoder(model)
// Reference:
// https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
@@ -99,7 +102,7 @@ func CountTokenMessages(messages []Message, model string) int {
}
imageTokens, err := countImageTokens(url, detail)
if err != nil {
- common.SysError("error counting image tokens: " + err.Error())
+ logger.SysError("error counting image tokens: " + err.Error())
} else {
tokenNum += imageTokens
}
diff --git a/relay/channel/openai/util.go b/relay/channel/openai/util.go
index 69ece6b3..ba0cab7d 100644
--- a/relay/channel/openai/util.go
+++ b/relay/channel/openai/util.go
@@ -1,12 +1,14 @@
package openai
-func ErrorWrapper(err error, code string, statusCode int) *ErrorWithStatusCode {
- Error := Error{
+import "github.com/songquanpeng/one-api/relay/model"
+
+func ErrorWrapper(err error, code string, statusCode int) *model.ErrorWithStatusCode {
+ Error := model.Error{
Message: err.Error(),
Type: "one_api_error",
Code: code,
}
- return &ErrorWithStatusCode{
+ return &model.ErrorWithStatusCode{
Error: Error,
StatusCode: statusCode,
}
diff --git a/relay/channel/palm/adaptor.go b/relay/channel/palm/adaptor.go
new file mode 100644
index 00000000..efd0620c
--- /dev/null
+++ b/relay/channel/palm/adaptor.go
@@ -0,0 +1,60 @@
+package palm
+
+import (
+ "errors"
+ "fmt"
+ "github.com/gin-gonic/gin"
+ "github.com/songquanpeng/one-api/relay/channel"
+ "github.com/songquanpeng/one-api/relay/channel/openai"
+ "github.com/songquanpeng/one-api/relay/model"
+ "github.com/songquanpeng/one-api/relay/util"
+ "io"
+ "net/http"
+)
+
+type Adaptor struct {
+}
+
+func (a *Adaptor) Init(meta *util.RelayMeta) {
+
+}
+
+func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) {
+ return fmt.Sprintf("%s/v1beta2/models/chat-bison-001:generateMessage", meta.BaseURL), nil
+}
+
+func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *util.RelayMeta) error {
+ channel.SetupCommonRequestHeader(c, req, meta)
+ req.Header.Set("x-goog-api-key", meta.APIKey)
+ return nil
+}
+
+func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) {
+ if request == nil {
+ return nil, errors.New("request is nil")
+ }
+ return ConvertRequest(*request), nil
+}
+
+func (a *Adaptor) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error) {
+ return channel.DoRequestHelper(a, c, meta, requestBody)
+}
+
+func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
+ if meta.IsStream {
+ var responseText string
+ err, responseText = StreamHandler(c, resp)
+ usage = openai.ResponseText2Usage(responseText, meta.ActualModelName, meta.PromptTokens)
+ } else {
+ err, usage = Handler(c, resp, meta.PromptTokens, meta.ActualModelName)
+ }
+ return
+}
+
+func (a *Adaptor) GetModelList() []string {
+ return ModelList
+}
+
+func (a *Adaptor) GetChannelName() string {
+ return "google palm"
+}
diff --git a/relay/channel/palm/constants.go b/relay/channel/palm/constants.go
new file mode 100644
index 00000000..a8349362
--- /dev/null
+++ b/relay/channel/palm/constants.go
@@ -0,0 +1,5 @@
+package palm
+
+var ModelList = []string{
+ "PaLM-2",
+}
diff --git a/relay/channel/palm/model.go b/relay/channel/palm/model.go
new file mode 100644
index 00000000..f653022c
--- /dev/null
+++ b/relay/channel/palm/model.go
@@ -0,0 +1,40 @@
+package palm
+
+import (
+ "github.com/songquanpeng/one-api/relay/model"
+)
+
+type ChatMessage struct {
+ Author string `json:"author"`
+ Content string `json:"content"`
+}
+
+type Filter struct {
+ Reason string `json:"reason"`
+ Message string `json:"message"`
+}
+
+type Prompt struct {
+ Messages []ChatMessage `json:"messages"`
+}
+
+type ChatRequest struct {
+ Prompt Prompt `json:"prompt"`
+ Temperature float64 `json:"temperature,omitempty"`
+ CandidateCount int `json:"candidateCount,omitempty"`
+ TopP float64 `json:"topP,omitempty"`
+ TopK int `json:"topK,omitempty"`
+}
+
+type Error struct {
+ Code int `json:"code"`
+ Message string `json:"message"`
+ Status string `json:"status"`
+}
+
+type ChatResponse struct {
+ Candidates []ChatMessage `json:"candidates"`
+ Messages []model.Message `json:"messages"`
+ Filters []Filter `json:"filters"`
+ Error Error `json:"error"`
+}
diff --git a/relay/channel/google/palm.go b/relay/channel/palm/palm.go
similarity index 73%
rename from relay/channel/google/palm.go
rename to relay/channel/palm/palm.go
index 77d8cbd6..56738544 100644
--- a/relay/channel/google/palm.go
+++ b/relay/channel/palm/palm.go
@@ -1,23 +1,26 @@
-package google
+package palm
import (
"encoding/json"
"fmt"
"github.com/gin-gonic/gin"
+ "github.com/songquanpeng/one-api/common"
+ "github.com/songquanpeng/one-api/common/helper"
+ "github.com/songquanpeng/one-api/common/logger"
+ "github.com/songquanpeng/one-api/relay/channel/openai"
+ "github.com/songquanpeng/one-api/relay/constant"
+ "github.com/songquanpeng/one-api/relay/model"
"io"
"net/http"
- "one-api/common"
- "one-api/relay/channel/openai"
- "one-api/relay/constant"
)
// https://developers.generativeai.google/api/rest/generativelanguage/models/generateMessage#request-body
// https://developers.generativeai.google/api/rest/generativelanguage/models/generateMessage#response-body
-func ConvertPaLMRequest(textRequest openai.GeneralOpenAIRequest) *PaLMChatRequest {
- palmRequest := PaLMChatRequest{
- Prompt: PaLMPrompt{
- Messages: make([]PaLMChatMessage, 0, len(textRequest.Messages)),
+func ConvertRequest(textRequest model.GeneralOpenAIRequest) *ChatRequest {
+ palmRequest := ChatRequest{
+ Prompt: Prompt{
+ Messages: make([]ChatMessage, 0, len(textRequest.Messages)),
},
Temperature: textRequest.Temperature,
CandidateCount: textRequest.N,
@@ -25,7 +28,7 @@ func ConvertPaLMRequest(textRequest openai.GeneralOpenAIRequest) *PaLMChatReques
TopK: textRequest.MaxTokens,
}
for _, message := range textRequest.Messages {
- palmMessage := PaLMChatMessage{
+ palmMessage := ChatMessage{
Content: message.StringContent(),
}
if message.Role == "user" {
@@ -38,14 +41,14 @@ func ConvertPaLMRequest(textRequest openai.GeneralOpenAIRequest) *PaLMChatReques
return &palmRequest
}
-func responsePaLM2OpenAI(response *PaLMChatResponse) *openai.TextResponse {
+func responsePaLM2OpenAI(response *ChatResponse) *openai.TextResponse {
fullTextResponse := openai.TextResponse{
Choices: make([]openai.TextResponseChoice, 0, len(response.Candidates)),
}
for i, candidate := range response.Candidates {
choice := openai.TextResponseChoice{
Index: i,
- Message: openai.Message{
+ Message: model.Message{
Role: "assistant",
Content: candidate.Content,
},
@@ -56,7 +59,7 @@ func responsePaLM2OpenAI(response *PaLMChatResponse) *openai.TextResponse {
return &fullTextResponse
}
-func streamResponsePaLM2OpenAI(palmResponse *PaLMChatResponse) *openai.ChatCompletionsStreamResponse {
+func streamResponsePaLM2OpenAI(palmResponse *ChatResponse) *openai.ChatCompletionsStreamResponse {
var choice openai.ChatCompletionsStreamResponseChoice
if len(palmResponse.Candidates) > 0 {
choice.Delta.Content = palmResponse.Candidates[0].Content
@@ -69,29 +72,29 @@ func streamResponsePaLM2OpenAI(palmResponse *PaLMChatResponse) *openai.ChatCompl
return &response
}
-func PaLMStreamHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, string) {
+func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, string) {
responseText := ""
- responseId := fmt.Sprintf("chatcmpl-%s", common.GetUUID())
- createdTime := common.GetTimestamp()
+ responseId := fmt.Sprintf("chatcmpl-%s", helper.GetUUID())
+ createdTime := helper.GetTimestamp()
dataChan := make(chan string)
stopChan := make(chan bool)
go func() {
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
- common.SysError("error reading stream response: " + err.Error())
+ logger.SysError("error reading stream response: " + err.Error())
stopChan <- true
return
}
err = resp.Body.Close()
if err != nil {
- common.SysError("error closing stream response: " + err.Error())
+ logger.SysError("error closing stream response: " + err.Error())
stopChan <- true
return
}
- var palmResponse PaLMChatResponse
+ var palmResponse ChatResponse
err = json.Unmarshal(responseBody, &palmResponse)
if err != nil {
- common.SysError("error unmarshalling stream response: " + err.Error())
+ logger.SysError("error unmarshalling stream response: " + err.Error())
stopChan <- true
return
}
@@ -103,7 +106,7 @@ func PaLMStreamHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithSt
}
jsonResponse, err := json.Marshal(fullTextResponse)
if err != nil {
- common.SysError("error marshalling stream response: " + err.Error())
+ logger.SysError("error marshalling stream response: " + err.Error())
stopChan <- true
return
}
@@ -128,7 +131,7 @@ func PaLMStreamHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithSt
return nil, responseText
}
-func PaLMHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*openai.ErrorWithStatusCode, *openai.Usage) {
+func Handler(c *gin.Context, resp *http.Response, promptTokens int, modelName string) (*model.ErrorWithStatusCode, *model.Usage) {
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
@@ -137,14 +140,14 @@ func PaLMHandler(c *gin.Context, resp *http.Response, promptTokens int, model st
if err != nil {
return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
- var palmResponse PaLMChatResponse
+ var palmResponse ChatResponse
err = json.Unmarshal(responseBody, &palmResponse)
if err != nil {
return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
if palmResponse.Error.Code != 0 || len(palmResponse.Candidates) == 0 {
- return &openai.ErrorWithStatusCode{
- Error: openai.Error{
+ return &model.ErrorWithStatusCode{
+ Error: model.Error{
Message: palmResponse.Error.Message,
Type: palmResponse.Error.Status,
Param: "",
@@ -154,9 +157,9 @@ func PaLMHandler(c *gin.Context, resp *http.Response, promptTokens int, model st
}, nil
}
fullTextResponse := responsePaLM2OpenAI(&palmResponse)
- fullTextResponse.Model = model
- completionTokens := openai.CountTokenText(palmResponse.Candidates[0].Content, model)
- usage := openai.Usage{
+ fullTextResponse.Model = modelName
+ completionTokens := openai.CountTokenText(palmResponse.Candidates[0].Content, modelName)
+ usage := model.Usage{
PromptTokens: promptTokens,
CompletionTokens: completionTokens,
TotalTokens: promptTokens + completionTokens,
diff --git a/relay/channel/tencent/adaptor.go b/relay/channel/tencent/adaptor.go
new file mode 100644
index 00000000..f348674e
--- /dev/null
+++ b/relay/channel/tencent/adaptor.go
@@ -0,0 +1,76 @@
+package tencent
+
+import (
+ "errors"
+ "fmt"
+ "github.com/gin-gonic/gin"
+ "github.com/songquanpeng/one-api/relay/channel"
+ "github.com/songquanpeng/one-api/relay/channel/openai"
+ "github.com/songquanpeng/one-api/relay/model"
+ "github.com/songquanpeng/one-api/relay/util"
+ "io"
+ "net/http"
+ "strings"
+)
+
+// https://cloud.tencent.com/document/api/1729/101837
+
+type Adaptor struct {
+ Sign string
+}
+
+func (a *Adaptor) Init(meta *util.RelayMeta) {
+
+}
+
+func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) {
+ return fmt.Sprintf("%s/hyllm/v1/chat/completions", meta.BaseURL), nil
+}
+
+func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *util.RelayMeta) error {
+ channel.SetupCommonRequestHeader(c, req, meta)
+ req.Header.Set("Authorization", a.Sign)
+ req.Header.Set("X-TC-Action", meta.ActualModelName)
+ return nil
+}
+
+func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) {
+ if request == nil {
+ return nil, errors.New("request is nil")
+ }
+ apiKey := c.Request.Header.Get("Authorization")
+ apiKey = strings.TrimPrefix(apiKey, "Bearer ")
+ appId, secretId, secretKey, err := ParseConfig(apiKey)
+ if err != nil {
+ return nil, err
+ }
+ tencentRequest := ConvertRequest(*request)
+ tencentRequest.AppId = appId
+ tencentRequest.SecretId = secretId
+ // we have to calculate the sign here
+ a.Sign = GetSign(*tencentRequest, secretKey)
+ return tencentRequest, nil
+}
+
+func (a *Adaptor) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error) {
+ return channel.DoRequestHelper(a, c, meta, requestBody)
+}
+
+func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
+ if meta.IsStream {
+ var responseText string
+ err, responseText = StreamHandler(c, resp)
+ usage = openai.ResponseText2Usage(responseText, meta.ActualModelName, meta.PromptTokens)
+ } else {
+ err, usage = Handler(c, resp)
+ }
+ return
+}
+
+func (a *Adaptor) GetModelList() []string {
+ return ModelList
+}
+
+func (a *Adaptor) GetChannelName() string {
+ return "tencent"
+}
diff --git a/relay/channel/tencent/constants.go b/relay/channel/tencent/constants.go
new file mode 100644
index 00000000..fe176c2c
--- /dev/null
+++ b/relay/channel/tencent/constants.go
@@ -0,0 +1,7 @@
+package tencent
+
+var ModelList = []string{
+ "ChatPro",
+ "ChatStd",
+ "hunyuan",
+}
diff --git a/relay/channel/tencent/main.go b/relay/channel/tencent/main.go
index 60e275a9..05edac20 100644
--- a/relay/channel/tencent/main.go
+++ b/relay/channel/tencent/main.go
@@ -9,11 +9,14 @@ import (
"errors"
"fmt"
"github.com/gin-gonic/gin"
+ "github.com/songquanpeng/one-api/common"
+ "github.com/songquanpeng/one-api/common/helper"
+ "github.com/songquanpeng/one-api/common/logger"
+ "github.com/songquanpeng/one-api/relay/channel/openai"
+ "github.com/songquanpeng/one-api/relay/constant"
+ "github.com/songquanpeng/one-api/relay/model"
"io"
"net/http"
- "one-api/common"
- "one-api/relay/channel/openai"
- "one-api/relay/constant"
"sort"
"strconv"
"strings"
@@ -21,7 +24,7 @@ import (
// https://cloud.tencent.com/document/product/1729/97732
-func ConvertRequest(request openai.GeneralOpenAIRequest) *ChatRequest {
+func ConvertRequest(request model.GeneralOpenAIRequest) *ChatRequest {
messages := make([]Message, 0, len(request.Messages))
for i := 0; i < len(request.Messages); i++ {
message := request.Messages[i]
@@ -46,9 +49,9 @@ func ConvertRequest(request openai.GeneralOpenAIRequest) *ChatRequest {
stream = 1
}
return &ChatRequest{
- Timestamp: common.GetTimestamp(),
- Expired: common.GetTimestamp() + 24*60*60,
- QueryID: common.GetUUID(),
+ Timestamp: helper.GetTimestamp(),
+ Expired: helper.GetTimestamp() + 24*60*60,
+ QueryID: helper.GetUUID(),
Temperature: request.Temperature,
TopP: request.TopP,
Stream: stream,
@@ -59,13 +62,13 @@ func ConvertRequest(request openai.GeneralOpenAIRequest) *ChatRequest {
func responseTencent2OpenAI(response *ChatResponse) *openai.TextResponse {
fullTextResponse := openai.TextResponse{
Object: "chat.completion",
- Created: common.GetTimestamp(),
+ Created: helper.GetTimestamp(),
Usage: response.Usage,
}
if len(response.Choices) > 0 {
choice := openai.TextResponseChoice{
Index: 0,
- Message: openai.Message{
+ Message: model.Message{
Role: "assistant",
Content: response.Choices[0].Messages.Content,
},
@@ -79,7 +82,7 @@ func responseTencent2OpenAI(response *ChatResponse) *openai.TextResponse {
func streamResponseTencent2OpenAI(TencentResponse *ChatResponse) *openai.ChatCompletionsStreamResponse {
response := openai.ChatCompletionsStreamResponse{
Object: "chat.completion.chunk",
- Created: common.GetTimestamp(),
+ Created: helper.GetTimestamp(),
Model: "tencent-hunyuan",
}
if len(TencentResponse.Choices) > 0 {
@@ -93,7 +96,7 @@ func streamResponseTencent2OpenAI(TencentResponse *ChatResponse) *openai.ChatCom
return &response
}
-func StreamHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, string) {
+func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, string) {
var responseText string
scanner := bufio.NewScanner(resp.Body)
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
@@ -131,7 +134,7 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatus
var TencentResponse ChatResponse
err := json.Unmarshal([]byte(data), &TencentResponse)
if err != nil {
- common.SysError("error unmarshalling stream response: " + err.Error())
+ logger.SysError("error unmarshalling stream response: " + err.Error())
return true
}
response := streamResponseTencent2OpenAI(&TencentResponse)
@@ -140,7 +143,7 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatus
}
jsonResponse, err := json.Marshal(response)
if err != nil {
- common.SysError("error marshalling stream response: " + err.Error())
+ logger.SysError("error marshalling stream response: " + err.Error())
return true
}
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
@@ -157,7 +160,7 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatus
return nil, responseText
}
-func Handler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, *openai.Usage) {
+func Handler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
var TencentResponse ChatResponse
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
@@ -172,8 +175,8 @@ func Handler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode,
return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
if TencentResponse.Error.Code != 0 {
- return &openai.ErrorWithStatusCode{
- Error: openai.Error{
+ return &model.ErrorWithStatusCode{
+ Error: model.Error{
Message: TencentResponse.Error.Message,
Code: TencentResponse.Error.Code,
},
@@ -189,6 +192,9 @@ func Handler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode,
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
_, err = c.Writer.Write(jsonResponse)
+ if err != nil {
+ return openai.ErrorWrapper(err, "write_response_body_failed", http.StatusInternalServerError), nil
+ }
return nil, &fullTextResponse.Usage
}
@@ -222,7 +228,7 @@ func GetSign(req ChatRequest, secretKey string) string {
messageStr = strings.TrimSuffix(messageStr, ",")
params = append(params, "messages=["+messageStr+"]")
- sort.Sort(sort.StringSlice(params))
+ sort.Strings(params)
url := "hunyuan.cloud.tencent.com/hyllm/v1/chat/completions?" + strings.Join(params, "&")
mac := hmac.New(sha1.New, []byte(secretKey))
signURL := url
diff --git a/relay/channel/tencent/model.go b/relay/channel/tencent/model.go
index 511f3d97..71286be9 100644
--- a/relay/channel/tencent/model.go
+++ b/relay/channel/tencent/model.go
@@ -1,7 +1,7 @@
package tencent
import (
- "one-api/relay/channel/openai"
+ "github.com/songquanpeng/one-api/relay/model"
)
type Message struct {
@@ -56,7 +56,7 @@ type ChatResponse struct {
Choices []ResponseChoices `json:"choices,omitempty"` // 结果
Created string `json:"created,omitempty"` // unix 时间戳的字符串
Id string `json:"id,omitempty"` // 会话 id
- Usage openai.Usage `json:"usage,omitempty"` // token 数量
+ Usage model.Usage `json:"usage,omitempty"` // token 数量
Error Error `json:"error,omitempty"` // 错误信息 注意:此字段可能返回 null,表示取不到有效值
Note string `json:"note,omitempty"` // 注释
ReqID string `json:"req_id,omitempty"` // 唯一请求 Id,每次请求都会返回。用于反馈接口入参
diff --git a/relay/channel/xunfei/adaptor.go b/relay/channel/xunfei/adaptor.go
new file mode 100644
index 00000000..92d9d7d6
--- /dev/null
+++ b/relay/channel/xunfei/adaptor.go
@@ -0,0 +1,70 @@
+package xunfei
+
+import (
+ "errors"
+ "github.com/gin-gonic/gin"
+ "github.com/songquanpeng/one-api/relay/channel"
+ "github.com/songquanpeng/one-api/relay/channel/openai"
+ "github.com/songquanpeng/one-api/relay/model"
+ "github.com/songquanpeng/one-api/relay/util"
+ "io"
+ "net/http"
+ "strings"
+)
+
+type Adaptor struct {
+ request *model.GeneralOpenAIRequest
+}
+
+func (a *Adaptor) Init(meta *util.RelayMeta) {
+
+}
+
+func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) {
+ return "", nil
+}
+
+func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *util.RelayMeta) error {
+ channel.SetupCommonRequestHeader(c, req, meta)
+ // check DoResponse for auth part
+ return nil
+}
+
+func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) {
+ if request == nil {
+ return nil, errors.New("request is nil")
+ }
+ a.request = request
+ return nil, nil
+}
+
+func (a *Adaptor) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error) {
+ // xunfei's request is not http request, so we don't need to do anything here
+ dummyResp := &http.Response{}
+ dummyResp.StatusCode = http.StatusOK
+ return dummyResp, nil
+}
+
+func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
+ splits := strings.Split(meta.APIKey, "|")
+ if len(splits) != 3 {
+ return nil, openai.ErrorWrapper(errors.New("invalid auth"), "invalid_auth", http.StatusBadRequest)
+ }
+ if a.request == nil {
+ return nil, openai.ErrorWrapper(errors.New("request is nil"), "request_is_nil", http.StatusBadRequest)
+ }
+ if meta.IsStream {
+ err, usage = StreamHandler(c, *a.request, splits[0], splits[1], splits[2])
+ } else {
+ err, usage = Handler(c, *a.request, splits[0], splits[1], splits[2])
+ }
+ return
+}
+
+func (a *Adaptor) GetModelList() []string {
+ return ModelList
+}
+
+func (a *Adaptor) GetChannelName() string {
+ return "xunfei"
+}
diff --git a/relay/channel/xunfei/constants.go b/relay/channel/xunfei/constants.go
new file mode 100644
index 00000000..31dcec71
--- /dev/null
+++ b/relay/channel/xunfei/constants.go
@@ -0,0 +1,9 @@
+package xunfei
+
+var ModelList = []string{
+ "SparkDesk",
+ "SparkDesk-v1.1",
+ "SparkDesk-v2.1",
+ "SparkDesk-v3.1",
+ "SparkDesk-v3.5",
+}
diff --git a/relay/channel/xunfei/main.go b/relay/channel/xunfei/main.go
index 1cc0b664..d064b11d 100644
--- a/relay/channel/xunfei/main.go
+++ b/relay/channel/xunfei/main.go
@@ -8,12 +8,15 @@ import (
"fmt"
"github.com/gin-gonic/gin"
"github.com/gorilla/websocket"
+ "github.com/songquanpeng/one-api/common"
+ "github.com/songquanpeng/one-api/common/helper"
+ "github.com/songquanpeng/one-api/common/logger"
+ "github.com/songquanpeng/one-api/relay/channel/openai"
+ "github.com/songquanpeng/one-api/relay/constant"
+ "github.com/songquanpeng/one-api/relay/model"
"io"
"net/http"
"net/url"
- "one-api/common"
- "one-api/relay/channel/openai"
- "one-api/relay/constant"
"strings"
"time"
)
@@ -21,7 +24,7 @@ import (
// https://console.xfyun.cn/services/cbm
// https://www.xfyun.cn/doc/spark/Web.html
-func requestOpenAI2Xunfei(request openai.GeneralOpenAIRequest, xunfeiAppId string, domain string) *ChatRequest {
+func requestOpenAI2Xunfei(request model.GeneralOpenAIRequest, xunfeiAppId string, domain string) *ChatRequest {
messages := make([]Message, 0, len(request.Messages))
for _, message := range request.Messages {
if message.Role == "system" {
@@ -60,7 +63,7 @@ func responseXunfei2OpenAI(response *ChatResponse) *openai.TextResponse {
}
choice := openai.TextResponseChoice{
Index: 0,
- Message: openai.Message{
+ Message: model.Message{
Role: "assistant",
Content: response.Payload.Choices.Text[0].Content,
},
@@ -68,7 +71,7 @@ func responseXunfei2OpenAI(response *ChatResponse) *openai.TextResponse {
}
fullTextResponse := openai.TextResponse{
Object: "chat.completion",
- Created: common.GetTimestamp(),
+ Created: helper.GetTimestamp(),
Choices: []openai.TextResponseChoice{choice},
Usage: response.Payload.Usage.Text,
}
@@ -90,7 +93,7 @@ func streamResponseXunfei2OpenAI(xunfeiResponse *ChatResponse) *openai.ChatCompl
}
response := openai.ChatCompletionsStreamResponse{
Object: "chat.completion.chunk",
- Created: common.GetTimestamp(),
+ Created: helper.GetTimestamp(),
Model: "SparkDesk",
Choices: []openai.ChatCompletionsStreamResponseChoice{choice},
}
@@ -123,14 +126,14 @@ func buildXunfeiAuthUrl(hostUrl string, apiKey, apiSecret string) string {
return callUrl
}
-func StreamHandler(c *gin.Context, textRequest openai.GeneralOpenAIRequest, appId string, apiSecret string, apiKey string) (*openai.ErrorWithStatusCode, *openai.Usage) {
- domain, authUrl := getXunfeiAuthUrl(c, apiKey, apiSecret)
+func StreamHandler(c *gin.Context, textRequest model.GeneralOpenAIRequest, appId string, apiSecret string, apiKey string) (*model.ErrorWithStatusCode, *model.Usage) {
+ domain, authUrl := getXunfeiAuthUrl(c, apiKey, apiSecret, textRequest.Model)
dataChan, stopChan, err := xunfeiMakeRequest(textRequest, domain, authUrl, appId)
if err != nil {
return openai.ErrorWrapper(err, "make xunfei request err", http.StatusInternalServerError), nil
}
common.SetEventStreamHeaders(c)
- var usage openai.Usage
+ var usage model.Usage
c.Stream(func(w io.Writer) bool {
select {
case xunfeiResponse := <-dataChan:
@@ -140,7 +143,7 @@ func StreamHandler(c *gin.Context, textRequest openai.GeneralOpenAIRequest, appI
response := streamResponseXunfei2OpenAI(&xunfeiResponse)
jsonResponse, err := json.Marshal(response)
if err != nil {
- common.SysError("error marshalling stream response: " + err.Error())
+ logger.SysError("error marshalling stream response: " + err.Error())
return true
}
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
@@ -153,13 +156,13 @@ func StreamHandler(c *gin.Context, textRequest openai.GeneralOpenAIRequest, appI
return nil, &usage
}
-func Handler(c *gin.Context, textRequest openai.GeneralOpenAIRequest, appId string, apiSecret string, apiKey string) (*openai.ErrorWithStatusCode, *openai.Usage) {
- domain, authUrl := getXunfeiAuthUrl(c, apiKey, apiSecret)
+func Handler(c *gin.Context, textRequest model.GeneralOpenAIRequest, appId string, apiSecret string, apiKey string) (*model.ErrorWithStatusCode, *model.Usage) {
+ domain, authUrl := getXunfeiAuthUrl(c, apiKey, apiSecret, textRequest.Model)
dataChan, stopChan, err := xunfeiMakeRequest(textRequest, domain, authUrl, appId)
if err != nil {
return openai.ErrorWrapper(err, "make xunfei request err", http.StatusInternalServerError), nil
}
- var usage openai.Usage
+ var usage model.Usage
var content string
var xunfeiResponse ChatResponse
stop := false
@@ -195,7 +198,7 @@ func Handler(c *gin.Context, textRequest openai.GeneralOpenAIRequest, appId stri
return nil, &usage
}
-func xunfeiMakeRequest(textRequest openai.GeneralOpenAIRequest, domain, authUrl, appId string) (chan ChatResponse, chan bool, error) {
+func xunfeiMakeRequest(textRequest model.GeneralOpenAIRequest, domain, authUrl, appId string) (chan ChatResponse, chan bool, error) {
d := websocket.Dialer{
HandshakeTimeout: 5 * time.Second,
}
@@ -215,20 +218,20 @@ func xunfeiMakeRequest(textRequest openai.GeneralOpenAIRequest, domain, authUrl,
for {
_, msg, err := conn.ReadMessage()
if err != nil {
- common.SysError("error reading stream response: " + err.Error())
+ logger.SysError("error reading stream response: " + err.Error())
break
}
var response ChatResponse
err = json.Unmarshal(msg, &response)
if err != nil {
- common.SysError("error unmarshalling stream response: " + err.Error())
+ logger.SysError("error unmarshalling stream response: " + err.Error())
break
}
dataChan <- response
if response.Payload.Choices.Status == 2 {
err := conn.Close()
if err != nil {
- common.SysError("error closing websocket connection: " + err.Error())
+ logger.SysError("error closing websocket connection: " + err.Error())
}
break
}
@@ -239,20 +242,45 @@ func xunfeiMakeRequest(textRequest openai.GeneralOpenAIRequest, domain, authUrl,
return dataChan, stopChan, nil
}
-func getXunfeiAuthUrl(c *gin.Context, apiKey string, apiSecret string) (string, string) {
+func getAPIVersion(c *gin.Context, modelName string) string {
query := c.Request.URL.Query()
apiVersion := query.Get("api-version")
- if apiVersion == "" {
- apiVersion = c.GetString("api_version")
+ if apiVersion != "" {
+ return apiVersion
}
- if apiVersion == "" {
- apiVersion = "v1.1"
- common.SysLog("api_version not found, use default: " + apiVersion)
+ parts := strings.Split(modelName, "-")
+ if len(parts) == 2 {
+ apiVersion = parts[1]
+ return apiVersion
+
}
- domain := "general"
- if apiVersion != "v1.1" {
- domain += strings.Split(apiVersion, ".")[0]
+ apiVersion = c.GetString(common.ConfigKeyAPIVersion)
+ if apiVersion != "" {
+ return apiVersion
}
+ apiVersion = "v1.1"
+ logger.SysLog("api_version not found, using default: " + apiVersion)
+ return apiVersion
+}
+
+// https://www.xfyun.cn/doc/spark/Web.html#_1-%E6%8E%A5%E5%8F%A3%E8%AF%B4%E6%98%8E
+func apiVersion2domain(apiVersion string) string {
+ switch apiVersion {
+ case "v1.1":
+ return "general"
+ case "v2.1":
+ return "generalv2"
+ case "v3.1":
+ return "generalv3"
+ case "v3.5":
+ return "generalv3.5"
+ }
+ return "general" + apiVersion
+}
+
+func getXunfeiAuthUrl(c *gin.Context, apiKey string, apiSecret string, modelName string) (string, string) {
+ apiVersion := getAPIVersion(c, modelName)
+ domain := apiVersion2domain(apiVersion)
authUrl := buildXunfeiAuthUrl(fmt.Sprintf("wss://spark-api.xf-yun.com/%s/chat", apiVersion), apiKey, apiSecret)
return domain, authUrl
}
diff --git a/relay/channel/xunfei/model.go b/relay/channel/xunfei/model.go
index 0ca42818..1266739d 100644
--- a/relay/channel/xunfei/model.go
+++ b/relay/channel/xunfei/model.go
@@ -1,7 +1,7 @@
package xunfei
import (
- "one-api/relay/channel/openai"
+ "github.com/songquanpeng/one-api/relay/model"
)
type Message struct {
@@ -55,7 +55,7 @@ type ChatResponse struct {
// CompletionTokens string `json:"completion_tokens"`
// TotalTokens string `json:"total_tokens"`
//} `json:"text"`
- Text openai.Usage `json:"text"`
+ Text model.Usage `json:"text"`
} `json:"usage"`
} `json:"payload"`
}
diff --git a/relay/channel/zhipu/adaptor.go b/relay/channel/zhipu/adaptor.go
new file mode 100644
index 00000000..7a822853
--- /dev/null
+++ b/relay/channel/zhipu/adaptor.go
@@ -0,0 +1,62 @@
+package zhipu
+
+import (
+ "errors"
+ "fmt"
+ "github.com/gin-gonic/gin"
+ "github.com/songquanpeng/one-api/relay/channel"
+ "github.com/songquanpeng/one-api/relay/model"
+ "github.com/songquanpeng/one-api/relay/util"
+ "io"
+ "net/http"
+)
+
+type Adaptor struct {
+}
+
+func (a *Adaptor) Init(meta *util.RelayMeta) {
+
+}
+
+func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) {
+ method := "invoke"
+ if meta.IsStream {
+ method = "sse-invoke"
+ }
+ return fmt.Sprintf("%s/api/paas/v3/model-api/%s/%s", meta.BaseURL, meta.ActualModelName, method), nil
+}
+
+func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *util.RelayMeta) error {
+ channel.SetupCommonRequestHeader(c, req, meta)
+ token := GetToken(meta.APIKey)
+ req.Header.Set("Authorization", token)
+ return nil
+}
+
+func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) {
+ if request == nil {
+ return nil, errors.New("request is nil")
+ }
+ return ConvertRequest(*request), nil
+}
+
+func (a *Adaptor) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error) {
+ return channel.DoRequestHelper(a, c, meta, requestBody)
+}
+
+func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
+ if meta.IsStream {
+ err, usage = StreamHandler(c, resp)
+ } else {
+ err, usage = Handler(c, resp)
+ }
+ return
+}
+
+func (a *Adaptor) GetModelList() []string {
+ return ModelList
+}
+
+func (a *Adaptor) GetChannelName() string {
+ return "zhipu"
+}
diff --git a/relay/channel/zhipu/constants.go b/relay/channel/zhipu/constants.go
new file mode 100644
index 00000000..f0367b82
--- /dev/null
+++ b/relay/channel/zhipu/constants.go
@@ -0,0 +1,5 @@
+package zhipu
+
+var ModelList = []string{
+ "chatglm_turbo", "chatglm_pro", "chatglm_std", "chatglm_lite",
+}
diff --git a/relay/channel/zhipu/main.go b/relay/channel/zhipu/main.go
index 3dc613a4..7c3e83f3 100644
--- a/relay/channel/zhipu/main.go
+++ b/relay/channel/zhipu/main.go
@@ -5,11 +5,14 @@ import (
"encoding/json"
"github.com/gin-gonic/gin"
"github.com/golang-jwt/jwt"
+ "github.com/songquanpeng/one-api/common"
+ "github.com/songquanpeng/one-api/common/helper"
+ "github.com/songquanpeng/one-api/common/logger"
+ "github.com/songquanpeng/one-api/relay/channel/openai"
+ "github.com/songquanpeng/one-api/relay/constant"
+ "github.com/songquanpeng/one-api/relay/model"
"io"
"net/http"
- "one-api/common"
- "one-api/relay/channel/openai"
- "one-api/relay/constant"
"strings"
"sync"
"time"
@@ -34,7 +37,7 @@ func GetToken(apikey string) string {
split := strings.Split(apikey, ".")
if len(split) != 2 {
- common.SysError("invalid zhipu key: " + apikey)
+ logger.SysError("invalid zhipu key: " + apikey)
return ""
}
@@ -70,7 +73,7 @@ func GetToken(apikey string) string {
return tokenString
}
-func ConvertRequest(request openai.GeneralOpenAIRequest) *Request {
+func ConvertRequest(request model.GeneralOpenAIRequest) *Request {
messages := make([]Message, 0, len(request.Messages))
for _, message := range request.Messages {
if message.Role == "system" {
@@ -101,14 +104,14 @@ func responseZhipu2OpenAI(response *Response) *openai.TextResponse {
fullTextResponse := openai.TextResponse{
Id: response.Data.TaskId,
Object: "chat.completion",
- Created: common.GetTimestamp(),
+ Created: helper.GetTimestamp(),
Choices: make([]openai.TextResponseChoice, 0, len(response.Data.Choices)),
Usage: response.Data.Usage,
}
for i, choice := range response.Data.Choices {
openaiChoice := openai.TextResponseChoice{
Index: i,
- Message: openai.Message{
+ Message: model.Message{
Role: choice.Role,
Content: strings.Trim(choice.Content, "\""),
},
@@ -127,29 +130,29 @@ func streamResponseZhipu2OpenAI(zhipuResponse string) *openai.ChatCompletionsStr
choice.Delta.Content = zhipuResponse
response := openai.ChatCompletionsStreamResponse{
Object: "chat.completion.chunk",
- Created: common.GetTimestamp(),
+ Created: helper.GetTimestamp(),
Model: "chatglm",
Choices: []openai.ChatCompletionsStreamResponseChoice{choice},
}
return &response
}
-func streamMetaResponseZhipu2OpenAI(zhipuResponse *StreamMetaResponse) (*openai.ChatCompletionsStreamResponse, *openai.Usage) {
+func streamMetaResponseZhipu2OpenAI(zhipuResponse *StreamMetaResponse) (*openai.ChatCompletionsStreamResponse, *model.Usage) {
var choice openai.ChatCompletionsStreamResponseChoice
choice.Delta.Content = ""
choice.FinishReason = &constant.StopFinishReason
response := openai.ChatCompletionsStreamResponse{
Id: zhipuResponse.RequestId,
Object: "chat.completion.chunk",
- Created: common.GetTimestamp(),
+ Created: helper.GetTimestamp(),
Model: "chatglm",
Choices: []openai.ChatCompletionsStreamResponseChoice{choice},
}
return &response, &zhipuResponse.Usage
}
-func StreamHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, *openai.Usage) {
- var usage *openai.Usage
+func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
+ var usage *model.Usage
scanner := bufio.NewScanner(resp.Body)
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
if atEOF && len(data) == 0 {
@@ -193,7 +196,7 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatus
response := streamResponseZhipu2OpenAI(data)
jsonResponse, err := json.Marshal(response)
if err != nil {
- common.SysError("error marshalling stream response: " + err.Error())
+ logger.SysError("error marshalling stream response: " + err.Error())
return true
}
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
@@ -202,13 +205,13 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatus
var zhipuResponse StreamMetaResponse
err := json.Unmarshal([]byte(data), &zhipuResponse)
if err != nil {
- common.SysError("error unmarshalling stream response: " + err.Error())
+ logger.SysError("error unmarshalling stream response: " + err.Error())
return true
}
response, zhipuUsage := streamMetaResponseZhipu2OpenAI(&zhipuResponse)
jsonResponse, err := json.Marshal(response)
if err != nil {
- common.SysError("error marshalling stream response: " + err.Error())
+ logger.SysError("error marshalling stream response: " + err.Error())
return true
}
usage = zhipuUsage
@@ -226,7 +229,7 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatus
return nil, usage
}
-func Handler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, *openai.Usage) {
+func Handler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
var zhipuResponse Response
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
@@ -241,8 +244,8 @@ func Handler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode,
return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
if !zhipuResponse.Success {
- return &openai.ErrorWithStatusCode{
- Error: openai.Error{
+ return &model.ErrorWithStatusCode{
+ Error: model.Error{
Message: zhipuResponse.Msg,
Type: "zhipu_error",
Param: "",
diff --git a/relay/channel/zhipu/model.go b/relay/channel/zhipu/model.go
index 08a5ec5f..b63e1d6f 100644
--- a/relay/channel/zhipu/model.go
+++ b/relay/channel/zhipu/model.go
@@ -1,7 +1,7 @@
package zhipu
import (
- "one-api/relay/channel/openai"
+ "github.com/songquanpeng/one-api/relay/model"
"time"
)
@@ -19,11 +19,11 @@ type Request struct {
}
type ResponseData struct {
- TaskId string `json:"task_id"`
- RequestId string `json:"request_id"`
- TaskStatus string `json:"task_status"`
- Choices []Message `json:"choices"`
- openai.Usage `json:"usage"`
+ TaskId string `json:"task_id"`
+ RequestId string `json:"request_id"`
+ TaskStatus string `json:"task_status"`
+ Choices []Message `json:"choices"`
+ model.Usage `json:"usage"`
}
type Response struct {
@@ -34,10 +34,10 @@ type Response struct {
}
type StreamMetaResponse struct {
- RequestId string `json:"request_id"`
- TaskId string `json:"task_id"`
- TaskStatus string `json:"task_status"`
- openai.Usage `json:"usage"`
+ RequestId string `json:"request_id"`
+ TaskId string `json:"task_id"`
+ TaskStatus string `json:"task_status"`
+ model.Usage `json:"usage"`
}
type tokenData struct {
diff --git a/relay/constant/api_type.go b/relay/constant/api_type.go
new file mode 100644
index 00000000..d2184dac
--- /dev/null
+++ b/relay/constant/api_type.go
@@ -0,0 +1,45 @@
+package constant
+
+import (
+ "github.com/songquanpeng/one-api/common"
+)
+
+const (
+ APITypeOpenAI = iota
+ APITypeAnthropic
+ APITypePaLM
+ APITypeBaidu
+ APITypeZhipu
+ APITypeAli
+ APITypeXunfei
+ APITypeAIProxyLibrary
+ APITypeTencent
+ APITypeGemini
+
+ APITypeDummy // this one is only for count, do not add any channel after this
+)
+
+func ChannelType2APIType(channelType int) int {
+ apiType := APITypeOpenAI
+ switch channelType {
+ case common.ChannelTypeAnthropic:
+ apiType = APITypeAnthropic
+ case common.ChannelTypeBaidu:
+ apiType = APITypeBaidu
+ case common.ChannelTypePaLM:
+ apiType = APITypePaLM
+ case common.ChannelTypeZhipu:
+ apiType = APITypeZhipu
+ case common.ChannelTypeAli:
+ apiType = APITypeAli
+ case common.ChannelTypeXunfei:
+ apiType = APITypeXunfei
+ case common.ChannelTypeAIProxyLibrary:
+ apiType = APITypeAIProxyLibrary
+ case common.ChannelTypeTencent:
+ apiType = APITypeTencent
+ case common.ChannelTypeGemini:
+ apiType = APITypeGemini
+ }
+ return apiType
+}
diff --git a/relay/constant/common.go b/relay/constant/common.go
new file mode 100644
index 00000000..b6606cc6
--- /dev/null
+++ b/relay/constant/common.go
@@ -0,0 +1,3 @@
+package constant
+
+var StopFinishReason = "stop"
diff --git a/relay/constant/main.go b/relay/constant/main.go
deleted file mode 100644
index b3aeaaff..00000000
--- a/relay/constant/main.go
+++ /dev/null
@@ -1,16 +0,0 @@
-package constant
-
-const (
- RelayModeUnknown = iota
- RelayModeChatCompletions
- RelayModeCompletions
- RelayModeEmbeddings
- RelayModeModerations
- RelayModeImagesGenerations
- RelayModeEdits
- RelayModeAudioSpeech
- RelayModeAudioTranscription
- RelayModeAudioTranslation
-)
-
-var StopFinishReason = "stop"
diff --git a/relay/constant/relay_mode.go b/relay/constant/relay_mode.go
new file mode 100644
index 00000000..5e2fe574
--- /dev/null
+++ b/relay/constant/relay_mode.go
@@ -0,0 +1,42 @@
+package constant
+
+import "strings"
+
+const (
+ RelayModeUnknown = iota
+ RelayModeChatCompletions
+ RelayModeCompletions
+ RelayModeEmbeddings
+ RelayModeModerations
+ RelayModeImagesGenerations
+ RelayModeEdits
+ RelayModeAudioSpeech
+ RelayModeAudioTranscription
+ RelayModeAudioTranslation
+)
+
+func Path2RelayMode(path string) int {
+ relayMode := RelayModeUnknown
+ if strings.HasPrefix(path, "/v1/chat/completions") {
+ relayMode = RelayModeChatCompletions
+ } else if strings.HasPrefix(path, "/v1/completions") {
+ relayMode = RelayModeCompletions
+ } else if strings.HasPrefix(path, "/v1/embeddings") {
+ relayMode = RelayModeEmbeddings
+ } else if strings.HasSuffix(path, "embeddings") {
+ relayMode = RelayModeEmbeddings
+ } else if strings.HasPrefix(path, "/v1/moderations") {
+ relayMode = RelayModeModerations
+ } else if strings.HasPrefix(path, "/v1/images/generations") {
+ relayMode = RelayModeImagesGenerations
+ } else if strings.HasPrefix(path, "/v1/edits") {
+ relayMode = RelayModeEdits
+ } else if strings.HasPrefix(path, "/v1/audio/speech") {
+ relayMode = RelayModeAudioSpeech
+ } else if strings.HasPrefix(path, "/v1/audio/transcriptions") {
+ relayMode = RelayModeAudioTranscription
+ } else if strings.HasPrefix(path, "/v1/audio/translations") {
+ relayMode = RelayModeAudioTranslation
+ }
+ return relayMode
+}
diff --git a/relay/controller/audio.go b/relay/controller/audio.go
index 08d9af2a..ee8771c9 100644
--- a/relay/controller/audio.go
+++ b/relay/controller/audio.go
@@ -8,17 +8,20 @@ import (
"errors"
"fmt"
"github.com/gin-gonic/gin"
+ "github.com/songquanpeng/one-api/common"
+ "github.com/songquanpeng/one-api/common/config"
+ "github.com/songquanpeng/one-api/common/logger"
+ "github.com/songquanpeng/one-api/model"
+ "github.com/songquanpeng/one-api/relay/channel/openai"
+ "github.com/songquanpeng/one-api/relay/constant"
+ relaymodel "github.com/songquanpeng/one-api/relay/model"
+ "github.com/songquanpeng/one-api/relay/util"
"io"
"net/http"
- "one-api/common"
- "one-api/model"
- "one-api/relay/channel/openai"
- "one-api/relay/constant"
- "one-api/relay/util"
"strings"
)
-func RelayAudioHelper(c *gin.Context, relayMode int) *openai.ErrorWithStatusCode {
+func RelayAudioHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatusCode {
audioModel := "whisper-1"
tokenId := c.GetInt("token_id")
@@ -53,7 +56,7 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *openai.ErrorWithStatusCode
preConsumedQuota = int(float64(len(ttsRequest.Input)) * ratio)
quota = preConsumedQuota
default:
- preConsumedQuota = int(float64(common.PreConsumedQuota) * ratio)
+ preConsumedQuota = int(float64(config.PreConsumedQuota) * ratio)
}
userQuota, err := model.CacheGetUserQuota(userId)
if err != nil {
@@ -102,7 +105,7 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *openai.ErrorWithStatusCode
fullRequestURL := util.GetFullRequestURL(baseURL, requestURL, channelType)
if relayMode == constant.RelayModeAudioTranscription && channelType == common.ChannelTypeAzure {
// https://learn.microsoft.com/en-us/azure/ai-services/openai/whisper-quickstart?tabs=command-line#rest-api
- apiVersion := util.GetAPIVersion(c)
+ apiVersion := util.GetAzureAPIVersion(c)
fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/audio/transcriptions?api-version=%s", baseURL, audioModel, apiVersion)
}
@@ -191,7 +194,7 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *openai.ErrorWithStatusCode
// negative means add quota back for token & user
err := model.PostConsumeTokenQuota(tokenId, -preConsumedQuota)
if err != nil {
- common.LogError(ctx, fmt.Sprintf("error rollback pre-consumed quota: %s", err.Error()))
+ logger.Error(ctx, fmt.Sprintf("error rollback pre-consumed quota: %s", err.Error()))
}
}()
}(c.Request.Context())
diff --git a/relay/controller/helper.go b/relay/controller/helper.go
new file mode 100644
index 00000000..a06b2768
--- /dev/null
+++ b/relay/controller/helper.go
@@ -0,0 +1,122 @@
+package controller
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "github.com/gin-gonic/gin"
+ "github.com/songquanpeng/one-api/common"
+ "github.com/songquanpeng/one-api/common/config"
+ "github.com/songquanpeng/one-api/common/logger"
+ "github.com/songquanpeng/one-api/model"
+ "github.com/songquanpeng/one-api/relay/channel/openai"
+ "github.com/songquanpeng/one-api/relay/constant"
+ relaymodel "github.com/songquanpeng/one-api/relay/model"
+ "github.com/songquanpeng/one-api/relay/util"
+ "math"
+ "net/http"
+)
+
+func getAndValidateTextRequest(c *gin.Context, relayMode int) (*relaymodel.GeneralOpenAIRequest, error) {
+ textRequest := &relaymodel.GeneralOpenAIRequest{}
+ err := common.UnmarshalBodyReusable(c, textRequest)
+ if err != nil {
+ return nil, err
+ }
+ if relayMode == constant.RelayModeModerations && textRequest.Model == "" {
+ textRequest.Model = "text-moderation-latest"
+ }
+ if relayMode == constant.RelayModeEmbeddings && textRequest.Model == "" {
+ textRequest.Model = c.Param("model")
+ }
+ err = util.ValidateTextRequest(textRequest, relayMode)
+ if err != nil {
+ return nil, err
+ }
+ return textRequest, nil
+}
+
+func getPromptTokens(textRequest *relaymodel.GeneralOpenAIRequest, relayMode int) int {
+ switch relayMode {
+ case constant.RelayModeChatCompletions:
+ return openai.CountTokenMessages(textRequest.Messages, textRequest.Model)
+ case constant.RelayModeCompletions:
+ return openai.CountTokenInput(textRequest.Prompt, textRequest.Model)
+ case constant.RelayModeModerations:
+ return openai.CountTokenInput(textRequest.Input, textRequest.Model)
+ }
+ return 0
+}
+
+func getPreConsumedQuota(textRequest *relaymodel.GeneralOpenAIRequest, promptTokens int, ratio float64) int {
+ preConsumedTokens := config.PreConsumedQuota
+ if textRequest.MaxTokens != 0 {
+ preConsumedTokens = promptTokens + textRequest.MaxTokens
+ }
+ return int(float64(preConsumedTokens) * ratio)
+}
+
+func preConsumeQuota(ctx context.Context, textRequest *relaymodel.GeneralOpenAIRequest, promptTokens int, ratio float64, meta *util.RelayMeta) (int, *relaymodel.ErrorWithStatusCode) {
+ preConsumedQuota := getPreConsumedQuota(textRequest, promptTokens, ratio)
+
+ userQuota, err := model.CacheGetUserQuota(meta.UserId)
+ if err != nil {
+ return preConsumedQuota, openai.ErrorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError)
+ }
+ if userQuota-preConsumedQuota < 0 {
+ return preConsumedQuota, openai.ErrorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
+ }
+ err = model.CacheDecreaseUserQuota(meta.UserId, preConsumedQuota)
+ if err != nil {
+ return preConsumedQuota, openai.ErrorWrapper(err, "decrease_user_quota_failed", http.StatusInternalServerError)
+ }
+ if userQuota > 100*preConsumedQuota {
+ // in this case, we do not pre-consume quota
+ // because the user has enough quota
+ preConsumedQuota = 0
+ logger.Info(ctx, fmt.Sprintf("user %d has enough quota %d, trusted and no need to pre-consume", meta.UserId, userQuota))
+ }
+ if preConsumedQuota > 0 {
+ err := model.PreConsumeTokenQuota(meta.TokenId, preConsumedQuota)
+ if err != nil {
+ return preConsumedQuota, openai.ErrorWrapper(err, "pre_consume_token_quota_failed", http.StatusForbidden)
+ }
+ }
+ return preConsumedQuota, nil
+}
+
+func postConsumeQuota(ctx context.Context, usage *relaymodel.Usage, meta *util.RelayMeta, textRequest *relaymodel.GeneralOpenAIRequest, ratio float64, preConsumedQuota int, modelRatio float64, groupRatio float64) {
+ if usage == nil {
+ logger.Error(ctx, "usage is nil, which is unexpected")
+ return
+ }
+ quota := 0
+ completionRatio := common.GetCompletionRatio(textRequest.Model)
+ promptTokens := usage.PromptTokens
+ completionTokens := usage.CompletionTokens
+ quota = int(math.Ceil((float64(promptTokens) + float64(completionTokens)*completionRatio) * ratio))
+ if ratio != 0 && quota <= 0 {
+ quota = 1
+ }
+ totalTokens := promptTokens + completionTokens
+ if totalTokens == 0 {
+		// in this case, some error must have happened
+ // we cannot just return, because we may have to return the pre-consumed quota
+ quota = 0
+ }
+ quotaDelta := quota - preConsumedQuota
+ err := model.PostConsumeTokenQuota(meta.TokenId, quotaDelta)
+ if err != nil {
+ logger.Error(ctx, "error consuming token remain quota: "+err.Error())
+ }
+ err = model.CacheUpdateUserQuota(meta.UserId)
+ if err != nil {
+ logger.Error(ctx, "error update user quota cache: "+err.Error())
+ }
+ if quota != 0 {
+ logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f,补全倍率 %.2f", modelRatio, groupRatio, completionRatio)
+ model.RecordConsumeLog(ctx, meta.UserId, meta.ChannelId, promptTokens, completionTokens, textRequest.Model, meta.TokenName, quota, logContent)
+ model.UpdateUserUsedQuotaAndRequestCount(meta.UserId, quota)
+ model.UpdateChannelUsedQuota(meta.ChannelId, quota)
+ }
+}
diff --git a/relay/controller/image.go b/relay/controller/image.go
index be5fc3dd..6ec368f5 100644
--- a/relay/controller/image.go
+++ b/relay/controller/image.go
@@ -6,12 +6,14 @@ import (
"encoding/json"
"errors"
"fmt"
+ "github.com/songquanpeng/one-api/common"
+ "github.com/songquanpeng/one-api/common/logger"
+ "github.com/songquanpeng/one-api/model"
+ "github.com/songquanpeng/one-api/relay/channel/openai"
+ relaymodel "github.com/songquanpeng/one-api/relay/model"
+ "github.com/songquanpeng/one-api/relay/util"
"io"
"net/http"
- "one-api/common"
- "one-api/model"
- "one-api/relay/channel/openai"
- "one-api/relay/util"
"strings"
"github.com/gin-gonic/gin"
@@ -27,7 +29,7 @@ func isWithinRange(element string, value int) bool {
return value >= min && value <= max
}
-func RelayImageHelper(c *gin.Context, relayMode int) *openai.ErrorWithStatusCode {
+func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatusCode {
imageModel := "dall-e-2"
imageSize := "1024x1024"
@@ -83,7 +85,7 @@ func RelayImageHelper(c *gin.Context, relayMode int) *openai.ErrorWithStatusCode
}
// Number of generated images validation
- if isWithinRange(imageModel, imageRequest.N) == false {
+ if !isWithinRange(imageModel, imageRequest.N) {
// channel not azure
if channelType != common.ChannelTypeAzure {
return openai.ErrorWrapper(errors.New("invalid value of n"), "n_not_within_range", http.StatusBadRequest)
@@ -112,7 +114,7 @@ func RelayImageHelper(c *gin.Context, relayMode int) *openai.ErrorWithStatusCode
fullRequestURL := util.GetFullRequestURL(baseURL, requestURL, channelType)
if channelType == common.ChannelTypeAzure {
// https://learn.microsoft.com/en-us/azure/ai-services/openai/dall-e-quickstart?tabs=dalle3%2Ccommand-line&pivots=rest-api
- apiVersion := util.GetAPIVersion(c)
+ apiVersion := util.GetAzureAPIVersion(c)
// https://{resource_name}.openai.azure.com/openai/deployments/dall-e-3/images/generations?api-version=2023-06-01-preview
fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/images/generations?api-version=%s", baseURL, imageModel, apiVersion)
}
@@ -175,11 +177,11 @@ func RelayImageHelper(c *gin.Context, relayMode int) *openai.ErrorWithStatusCode
}
err := model.PostConsumeTokenQuota(tokenId, quota)
if err != nil {
- common.SysError("error consuming token remain quota: " + err.Error())
+ logger.SysError("error consuming token remain quota: " + err.Error())
}
err = model.CacheUpdateUserQuota(userId)
if err != nil {
- common.SysError("error update user quota cache: " + err.Error())
+ logger.SysError("error update user quota cache: " + err.Error())
}
if quota != 0 {
tokenName := c.GetString("token_name")
diff --git a/relay/controller/text.go b/relay/controller/text.go
index b17ff950..cc460511 100644
--- a/relay/controller/text.go
+++ b/relay/controller/text.go
@@ -2,680 +2,100 @@ package controller
import (
"bytes"
- "context"
"encoding/json"
- "errors"
"fmt"
"github.com/gin-gonic/gin"
+ "github.com/songquanpeng/one-api/common"
+ "github.com/songquanpeng/one-api/common/logger"
+ "github.com/songquanpeng/one-api/relay/channel/openai"
+ "github.com/songquanpeng/one-api/relay/constant"
+ "github.com/songquanpeng/one-api/relay/helper"
+ "github.com/songquanpeng/one-api/relay/model"
+ "github.com/songquanpeng/one-api/relay/util"
"io"
- "math"
"net/http"
- "one-api/common"
- "one-api/model"
- "one-api/relay/channel/aiproxy"
- "one-api/relay/channel/ali"
- "one-api/relay/channel/anthropic"
- "one-api/relay/channel/baidu"
- "one-api/relay/channel/google"
- "one-api/relay/channel/openai"
- "one-api/relay/channel/tencent"
- "one-api/relay/channel/xunfei"
- "one-api/relay/channel/zhipu"
- "one-api/relay/constant"
- "one-api/relay/util"
"strings"
)
-const (
- APITypeOpenAI = iota
- APITypeClaude
- APITypePaLM
- APITypeBaidu
- APITypeZhipu
- APITypeAli
- APITypeXunfei
- APITypeAIProxyLibrary
- APITypeTencent
- APITypeGemini
-)
-
-func RelayTextHelper(c *gin.Context, relayMode int) *openai.ErrorWithStatusCode {
- channelType := c.GetInt("channel")
- channelId := c.GetInt("channel_id")
- tokenId := c.GetInt("token_id")
- userId := c.GetInt("id")
- group := c.GetString("group")
- var textRequest openai.GeneralOpenAIRequest
- err := common.UnmarshalBodyReusable(c, &textRequest)
+func RelayTextHelper(c *gin.Context) *model.ErrorWithStatusCode {
+ ctx := c.Request.Context()
+ meta := util.GetRelayMeta(c)
+ // get & validate textRequest
+ textRequest, err := getAndValidateTextRequest(c, meta.Mode)
if err != nil {
- return openai.ErrorWrapper(err, "bind_request_body_failed", http.StatusBadRequest)
- }
- if textRequest.MaxTokens < 0 || textRequest.MaxTokens > math.MaxInt32/2 {
- return openai.ErrorWrapper(errors.New("max_tokens is invalid"), "invalid_max_tokens", http.StatusBadRequest)
- }
- if relayMode == constant.RelayModeModerations && textRequest.Model == "" {
- textRequest.Model = "text-moderation-latest"
- }
- if relayMode == constant.RelayModeEmbeddings && textRequest.Model == "" {
- textRequest.Model = c.Param("model")
- }
- // request validation
- if textRequest.Model == "" {
- return openai.ErrorWrapper(errors.New("model is required"), "required_field_missing", http.StatusBadRequest)
- }
- switch relayMode {
- case constant.RelayModeCompletions:
- if textRequest.Prompt == "" {
- return openai.ErrorWrapper(errors.New("field prompt is required"), "required_field_missing", http.StatusBadRequest)
- }
- case constant.RelayModeChatCompletions:
- if textRequest.Messages == nil || len(textRequest.Messages) == 0 {
- return openai.ErrorWrapper(errors.New("field messages is required"), "required_field_missing", http.StatusBadRequest)
- }
- case constant.RelayModeEmbeddings:
- case constant.RelayModeModerations:
- if textRequest.Input == "" {
- return openai.ErrorWrapper(errors.New("field input is required"), "required_field_missing", http.StatusBadRequest)
- }
- case constant.RelayModeEdits:
- if textRequest.Instruction == "" {
- return openai.ErrorWrapper(errors.New("field instruction is required"), "required_field_missing", http.StatusBadRequest)
- }
+ logger.Errorf(ctx, "getAndValidateTextRequest failed: %s", err.Error())
+ return openai.ErrorWrapper(err, "invalid_text_request", http.StatusBadRequest)
}
+ meta.IsStream = textRequest.Stream
+
// map model name
- modelMapping := c.GetString("model_mapping")
- isModelMapped := false
- if modelMapping != "" && modelMapping != "{}" {
- modelMap := make(map[string]string)
- err := json.Unmarshal([]byte(modelMapping), &modelMap)
- if err != nil {
- return openai.ErrorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError)
- }
- if modelMap[textRequest.Model] != "" {
- textRequest.Model = modelMap[textRequest.Model]
- isModelMapped = true
- }
- }
- apiType := APITypeOpenAI
- switch channelType {
- case common.ChannelTypeAnthropic:
- apiType = APITypeClaude
- case common.ChannelTypeBaidu:
- apiType = APITypeBaidu
- case common.ChannelTypePaLM:
- apiType = APITypePaLM
- case common.ChannelTypeZhipu:
- apiType = APITypeZhipu
- case common.ChannelTypeAli:
- apiType = APITypeAli
- case common.ChannelTypeXunfei:
- apiType = APITypeXunfei
- case common.ChannelTypeAIProxyLibrary:
- apiType = APITypeAIProxyLibrary
- case common.ChannelTypeTencent:
- apiType = APITypeTencent
- case common.ChannelTypeGemini:
- apiType = APITypeGemini
- }
- baseURL := common.ChannelBaseURLs[channelType]
- requestURL := c.Request.URL.String()
- if c.GetString("base_url") != "" {
- baseURL = c.GetString("base_url")
- }
- fullRequestURL := util.GetFullRequestURL(baseURL, requestURL, channelType)
- switch apiType {
- case APITypeOpenAI:
- if channelType == common.ChannelTypeAzure {
- // https://learn.microsoft.com/en-us/azure/cognitive-services/openai/chatgpt-quickstart?pivots=rest-api&tabs=command-line#rest-api
- apiVersion := util.GetAPIVersion(c)
- requestURL := strings.Split(requestURL, "?")[0]
- requestURL = fmt.Sprintf("%s?api-version=%s", requestURL, apiVersion)
- baseURL = c.GetString("base_url")
- task := strings.TrimPrefix(requestURL, "/v1/")
- model_ := textRequest.Model
- model_ = strings.Replace(model_, ".", "", -1)
- // https://github.com/songquanpeng/one-api/issues/67
- model_ = strings.TrimSuffix(model_, "-0301")
- model_ = strings.TrimSuffix(model_, "-0314")
- model_ = strings.TrimSuffix(model_, "-0613")
-
- requestURL = fmt.Sprintf("/openai/deployments/%s/%s", model_, task)
- fullRequestURL = util.GetFullRequestURL(baseURL, requestURL, channelType)
- }
- case APITypeClaude:
- fullRequestURL = "https://api.anthropic.com/v1/complete"
- if baseURL != "" {
- fullRequestURL = fmt.Sprintf("%s/v1/complete", baseURL)
- }
- case APITypeBaidu:
- switch textRequest.Model {
- case "ERNIE-Bot":
- fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions"
- case "ERNIE-Bot-turbo":
- fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant"
- case "ERNIE-Bot-4":
- fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro"
- case "BLOOMZ-7B":
- fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/bloomz_7b1"
- case "Embedding-V1":
- fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings/embedding-v1"
- }
- apiKey := c.Request.Header.Get("Authorization")
- apiKey = strings.TrimPrefix(apiKey, "Bearer ")
- var err error
- if apiKey, err = baidu.GetAccessToken(apiKey); err != nil {
- return openai.ErrorWrapper(err, "invalid_baidu_config", http.StatusInternalServerError)
- }
- fullRequestURL += "?access_token=" + apiKey
- case APITypePaLM:
- fullRequestURL = "https://generativelanguage.googleapis.com/v1beta2/models/chat-bison-001:generateMessage"
- if baseURL != "" {
- fullRequestURL = fmt.Sprintf("%s/v1beta2/models/chat-bison-001:generateMessage", baseURL)
- }
- case APITypeGemini:
- requestBaseURL := "https://generativelanguage.googleapis.com"
- if baseURL != "" {
- requestBaseURL = baseURL
- }
- version := "v1"
- if c.GetString("api_version") != "" {
- version = c.GetString("api_version")
- }
- action := "generateContent"
- if textRequest.Stream {
- action = "streamGenerateContent"
- }
- fullRequestURL = fmt.Sprintf("%s/%s/models/%s:%s", requestBaseURL, version, textRequest.Model, action)
- case APITypeZhipu:
- method := "invoke"
- if textRequest.Stream {
- method = "sse-invoke"
- }
- fullRequestURL = fmt.Sprintf("https://open.bigmodel.cn/api/paas/v3/model-api/%s/%s", textRequest.Model, method)
- case APITypeAli:
- fullRequestURL = "https://dashscope.aliyuncs.com/api/v1/services/aigc/text-generation/generation"
- if relayMode == constant.RelayModeEmbeddings {
- fullRequestURL = "https://dashscope.aliyuncs.com/api/v1/services/embeddings/text-embedding/text-embedding"
- }
- case APITypeTencent:
- fullRequestURL = "https://hunyuan.cloud.tencent.com/hyllm/v1/chat/completions"
- case APITypeAIProxyLibrary:
- fullRequestURL = fmt.Sprintf("%s/api/library/ask", baseURL)
- }
- var promptTokens int
- var completionTokens int
- switch relayMode {
- case constant.RelayModeChatCompletions:
- promptTokens = openai.CountTokenMessages(textRequest.Messages, textRequest.Model)
- case constant.RelayModeCompletions:
- promptTokens = openai.CountTokenInput(textRequest.Prompt, textRequest.Model)
- case constant.RelayModeModerations:
- promptTokens = openai.CountTokenInput(textRequest.Input, textRequest.Model)
- }
- preConsumedTokens := common.PreConsumedQuota
- if textRequest.MaxTokens != 0 {
- preConsumedTokens = promptTokens + textRequest.MaxTokens
- }
+ var isModelMapped bool
+ meta.OriginModelName = textRequest.Model
+ textRequest.Model, isModelMapped = util.GetMappedModelName(textRequest.Model, meta.ModelMapping)
+ meta.ActualModelName = textRequest.Model
+ // get model ratio & group ratio
modelRatio := common.GetModelRatio(textRequest.Model)
- groupRatio := common.GetGroupRatio(group)
+ groupRatio := common.GetGroupRatio(meta.Group)
ratio := modelRatio * groupRatio
- preConsumedQuota := int(float64(preConsumedTokens) * ratio)
- userQuota, err := model.CacheGetUserQuota(userId)
- if err != nil {
- return openai.ErrorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError)
+ // pre-consume quota
+ promptTokens := getPromptTokens(textRequest, meta.Mode)
+ meta.PromptTokens = promptTokens
+ preConsumedQuota, bizErr := preConsumeQuota(ctx, textRequest, promptTokens, ratio, meta)
+ if bizErr != nil {
+ logger.Warnf(ctx, "preConsumeQuota failed: %+v", *bizErr)
+ return bizErr
}
- if userQuota-preConsumedQuota < 0 {
- return openai.ErrorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
- }
- err = model.CacheDecreaseUserQuota(userId, preConsumedQuota)
- if err != nil {
- return openai.ErrorWrapper(err, "decrease_user_quota_failed", http.StatusInternalServerError)
- }
- if userQuota > 100*preConsumedQuota {
- // in this case, we do not pre-consume quota
- // because the user has enough quota
- preConsumedQuota = 0
- common.LogInfo(c.Request.Context(), fmt.Sprintf("user %d has enough quota %d, trusted and no need to pre-consume", userId, userQuota))
- }
- if preConsumedQuota > 0 {
- err := model.PreConsumeTokenQuota(tokenId, preConsumedQuota)
- if err != nil {
- return openai.ErrorWrapper(err, "pre_consume_token_quota_failed", http.StatusForbidden)
- }
+
+ adaptor := helper.GetAdaptor(meta.APIType)
+ if adaptor == nil {
+ return openai.ErrorWrapper(fmt.Errorf("invalid api type: %d", meta.APIType), "invalid_api_type", http.StatusBadRequest)
}
+
+ // get request body
var requestBody io.Reader
- if isModelMapped {
- jsonStr, err := json.Marshal(textRequest)
- if err != nil {
- return openai.ErrorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
+ if meta.APIType == constant.APITypeOpenAI {
+		// no need to convert the request for OpenAI
+ if isModelMapped {
+ jsonStr, err := json.Marshal(textRequest)
+ if err != nil {
+ return openai.ErrorWrapper(err, "json_marshal_failed", http.StatusInternalServerError)
+ }
+ requestBody = bytes.NewBuffer(jsonStr)
+ } else {
+ requestBody = c.Request.Body
}
- requestBody = bytes.NewBuffer(jsonStr)
} else {
- requestBody = c.Request.Body
- }
- switch apiType {
- case APITypeClaude:
- claudeRequest := anthropic.ConvertRequest(textRequest)
- jsonStr, err := json.Marshal(claudeRequest)
+ convertedRequest, err := adaptor.ConvertRequest(c, meta.Mode, textRequest)
if err != nil {
- return openai.ErrorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
- }
- requestBody = bytes.NewBuffer(jsonStr)
- case APITypeBaidu:
- var jsonData []byte
- var err error
- switch relayMode {
- case constant.RelayModeEmbeddings:
- baiduEmbeddingRequest := baidu.ConvertEmbeddingRequest(textRequest)
- jsonData, err = json.Marshal(baiduEmbeddingRequest)
- default:
- baiduRequest := baidu.ConvertRequest(textRequest)
- jsonData, err = json.Marshal(baiduRequest)
+ return openai.ErrorWrapper(err, "convert_request_failed", http.StatusInternalServerError)
}
+ jsonData, err := json.Marshal(convertedRequest)
if err != nil {
- return openai.ErrorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
+ return openai.ErrorWrapper(err, "json_marshal_failed", http.StatusInternalServerError)
}
requestBody = bytes.NewBuffer(jsonData)
- case APITypePaLM:
- palmRequest := google.ConvertPaLMRequest(textRequest)
- jsonStr, err := json.Marshal(palmRequest)
- if err != nil {
- return openai.ErrorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
- }
- requestBody = bytes.NewBuffer(jsonStr)
- case APITypeGemini:
- geminiChatRequest := google.ConvertGeminiRequest(textRequest)
- jsonStr, err := json.Marshal(geminiChatRequest)
- if err != nil {
- return openai.ErrorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
- }
- requestBody = bytes.NewBuffer(jsonStr)
- case APITypeZhipu:
- zhipuRequest := zhipu.ConvertRequest(textRequest)
- jsonStr, err := json.Marshal(zhipuRequest)
- if err != nil {
- return openai.ErrorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
- }
- requestBody = bytes.NewBuffer(jsonStr)
- case APITypeAli:
- var jsonStr []byte
- var err error
- switch relayMode {
- case constant.RelayModeEmbeddings:
- aliEmbeddingRequest := ali.ConvertEmbeddingRequest(textRequest)
- jsonStr, err = json.Marshal(aliEmbeddingRequest)
- default:
- aliRequest := ali.ConvertRequest(textRequest)
- jsonStr, err = json.Marshal(aliRequest)
- }
- if err != nil {
- return openai.ErrorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
- }
- requestBody = bytes.NewBuffer(jsonStr)
- case APITypeTencent:
- apiKey := c.Request.Header.Get("Authorization")
- apiKey = strings.TrimPrefix(apiKey, "Bearer ")
- appId, secretId, secretKey, err := tencent.ParseConfig(apiKey)
- if err != nil {
- return openai.ErrorWrapper(err, "invalid_tencent_config", http.StatusInternalServerError)
- }
- tencentRequest := tencent.ConvertRequest(textRequest)
- tencentRequest.AppId = appId
- tencentRequest.SecretId = secretId
- jsonStr, err := json.Marshal(tencentRequest)
- if err != nil {
- return openai.ErrorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
- }
- sign := tencent.GetSign(*tencentRequest, secretKey)
- c.Request.Header.Set("Authorization", sign)
- requestBody = bytes.NewBuffer(jsonStr)
- case APITypeAIProxyLibrary:
- aiProxyLibraryRequest := aiproxy.ConvertRequest(textRequest)
- aiProxyLibraryRequest.LibraryId = c.GetString("library_id")
- jsonStr, err := json.Marshal(aiProxyLibraryRequest)
- if err != nil {
- return openai.ErrorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
- }
- requestBody = bytes.NewBuffer(jsonStr)
}
- var req *http.Request
- var resp *http.Response
- isStream := textRequest.Stream
-
- if apiType != APITypeXunfei { // cause xunfei use websocket
- req, err = http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
- if err != nil {
- return openai.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError)
- }
- apiKey := c.Request.Header.Get("Authorization")
- apiKey = strings.TrimPrefix(apiKey, "Bearer ")
- switch apiType {
- case APITypeOpenAI:
- if channelType == common.ChannelTypeAzure {
- req.Header.Set("api-key", apiKey)
- } else {
- req.Header.Set("Authorization", c.Request.Header.Get("Authorization"))
- if channelType == common.ChannelTypeOpenRouter {
- req.Header.Set("HTTP-Referer", "https://github.com/songquanpeng/one-api")
- req.Header.Set("X-Title", "One API")
- }
- }
- case APITypeClaude:
- req.Header.Set("x-api-key", apiKey)
- anthropicVersion := c.Request.Header.Get("anthropic-version")
- if anthropicVersion == "" {
- anthropicVersion = "2023-06-01"
- }
- req.Header.Set("anthropic-version", anthropicVersion)
- case APITypeZhipu:
- token := zhipu.GetToken(apiKey)
- req.Header.Set("Authorization", token)
- case APITypeAli:
- req.Header.Set("Authorization", "Bearer "+apiKey)
- if textRequest.Stream {
- req.Header.Set("X-DashScope-SSE", "enable")
- }
- if c.GetString("plugin") != "" {
- req.Header.Set("X-DashScope-Plugin", c.GetString("plugin"))
- }
- case APITypeTencent:
- req.Header.Set("Authorization", apiKey)
- case APITypePaLM:
- req.Header.Set("x-goog-api-key", apiKey)
- case APITypeGemini:
- req.Header.Set("x-goog-api-key", apiKey)
- default:
- req.Header.Set("Authorization", "Bearer "+apiKey)
- }
- req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
- req.Header.Set("Accept", c.Request.Header.Get("Accept"))
- if isStream && c.Request.Header.Get("Accept") == "" {
- req.Header.Set("Accept", "text/event-stream")
- }
- //req.Header.Set("Connection", c.Request.Header.Get("Connection"))
- resp, err = util.HTTPClient.Do(req)
- if err != nil {
- return openai.ErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
- }
- err = req.Body.Close()
- if err != nil {
- return openai.ErrorWrapper(err, "close_request_body_failed", http.StatusInternalServerError)
- }
- err = c.Request.Body.Close()
- if err != nil {
- return openai.ErrorWrapper(err, "close_request_body_failed", http.StatusInternalServerError)
- }
- isStream = isStream || strings.HasPrefix(resp.Header.Get("Content-Type"), "text/event-stream")
-
- if resp.StatusCode != http.StatusOK {
- if preConsumedQuota != 0 {
- go func(ctx context.Context) {
- // return pre-consumed quota
- err := model.PostConsumeTokenQuota(tokenId, -preConsumedQuota)
- if err != nil {
- common.LogError(ctx, "error return pre-consumed quota: "+err.Error())
- }
- }(c.Request.Context())
- }
- return util.RelayErrorHandler(resp)
- }
+ // do request
+ resp, err := adaptor.DoRequest(c, meta, requestBody)
+ if err != nil {
+ logger.Errorf(ctx, "DoRequest failed: %s", err.Error())
+ return openai.ErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
+ }
+ meta.IsStream = meta.IsStream || strings.HasPrefix(resp.Header.Get("Content-Type"), "text/event-stream")
+ if resp.StatusCode != http.StatusOK {
+ util.ReturnPreConsumedQuota(ctx, preConsumedQuota, meta.TokenId)
+ return util.RelayErrorHandler(resp)
}
- var textResponse openai.SlimTextResponse
- tokenName := c.GetString("token_name")
-
- defer func(ctx context.Context) {
- // c.Writer.Flush()
- go func() {
- quota := 0
- completionRatio := common.GetCompletionRatio(textRequest.Model)
- promptTokens = textResponse.Usage.PromptTokens
- completionTokens = textResponse.Usage.CompletionTokens
- quota = int(math.Ceil((float64(promptTokens) + float64(completionTokens)*completionRatio) * ratio))
- if ratio != 0 && quota <= 0 {
- quota = 1
- }
- totalTokens := promptTokens + completionTokens
- if totalTokens == 0 {
- // in this case, must be some error happened
- // we cannot just return, because we may have to return the pre-consumed quota
- quota = 0
- }
- quotaDelta := quota - preConsumedQuota
- err := model.PostConsumeTokenQuota(tokenId, quotaDelta)
- if err != nil {
- common.LogError(ctx, "error consuming token remain quota: "+err.Error())
- }
- err = model.CacheUpdateUserQuota(userId)
- if err != nil {
- common.LogError(ctx, "error update user quota cache: "+err.Error())
- }
- if quota != 0 {
- logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
- model.RecordConsumeLog(ctx, userId, channelId, promptTokens, completionTokens, textRequest.Model, tokenName, quota, logContent)
- model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
- model.UpdateChannelUsedQuota(channelId, quota)
- }
-
- }()
- }(c.Request.Context())
- switch apiType {
- case APITypeOpenAI:
- if isStream {
- err, responseText := openai.StreamHandler(c, resp, relayMode)
- if err != nil {
- return err
- }
- textResponse.Usage.PromptTokens = promptTokens
- textResponse.Usage.CompletionTokens = openai.CountTokenText(responseText, textRequest.Model)
- return nil
- } else {
- err, usage := openai.Handler(c, resp, promptTokens, textRequest.Model)
- if err != nil {
- return err
- }
- if usage != nil {
- textResponse.Usage = *usage
- }
- return nil
- }
- case APITypeClaude:
- if isStream {
- err, responseText := anthropic.StreamHandler(c, resp)
- if err != nil {
- return err
- }
- textResponse.Usage.PromptTokens = promptTokens
- textResponse.Usage.CompletionTokens = openai.CountTokenText(responseText, textRequest.Model)
- return nil
- } else {
- err, usage := anthropic.Handler(c, resp, promptTokens, textRequest.Model)
- if err != nil {
- return err
- }
- if usage != nil {
- textResponse.Usage = *usage
- }
- return nil
- }
- case APITypeBaidu:
- if isStream {
- err, usage := baidu.StreamHandler(c, resp)
- if err != nil {
- return err
- }
- if usage != nil {
- textResponse.Usage = *usage
- }
- return nil
- } else {
- var err *openai.ErrorWithStatusCode
- var usage *openai.Usage
- switch relayMode {
- case constant.RelayModeEmbeddings:
- err, usage = baidu.EmbeddingHandler(c, resp)
- default:
- err, usage = baidu.Handler(c, resp)
- }
- if err != nil {
- return err
- }
- if usage != nil {
- textResponse.Usage = *usage
- }
- return nil
- }
- case APITypePaLM:
- if textRequest.Stream { // PaLM2 API does not support stream
- err, responseText := google.PaLMStreamHandler(c, resp)
- if err != nil {
- return err
- }
- textResponse.Usage.PromptTokens = promptTokens
- textResponse.Usage.CompletionTokens = openai.CountTokenText(responseText, textRequest.Model)
- return nil
- } else {
- err, usage := google.PaLMHandler(c, resp, promptTokens, textRequest.Model)
- if err != nil {
- return err
- }
- if usage != nil {
- textResponse.Usage = *usage
- }
- return nil
- }
- case APITypeGemini:
- if textRequest.Stream {
- err, responseText := google.StreamHandler(c, resp)
- if err != nil {
- return err
- }
- textResponse.Usage.PromptTokens = promptTokens
- textResponse.Usage.CompletionTokens = openai.CountTokenText(responseText, textRequest.Model)
- return nil
- } else {
- err, usage := google.GeminiHandler(c, resp, promptTokens, textRequest.Model)
- if err != nil {
- return err
- }
- if usage != nil {
- textResponse.Usage = *usage
- }
- return nil
- }
- case APITypeZhipu:
- if isStream {
- err, usage := zhipu.StreamHandler(c, resp)
- if err != nil {
- return err
- }
- if usage != nil {
- textResponse.Usage = *usage
- }
- // zhipu's API does not return prompt tokens & completion tokens
- textResponse.Usage.PromptTokens = textResponse.Usage.TotalTokens
- return nil
- } else {
- err, usage := zhipu.Handler(c, resp)
- if err != nil {
- return err
- }
- if usage != nil {
- textResponse.Usage = *usage
- }
- // zhipu's API does not return prompt tokens & completion tokens
- textResponse.Usage.PromptTokens = textResponse.Usage.TotalTokens
- return nil
- }
- case APITypeAli:
- if isStream {
- err, usage := ali.StreamHandler(c, resp)
- if err != nil {
- return err
- }
- if usage != nil {
- textResponse.Usage = *usage
- }
- return nil
- } else {
- var err *openai.ErrorWithStatusCode
- var usage *openai.Usage
- switch relayMode {
- case constant.RelayModeEmbeddings:
- err, usage = ali.EmbeddingHandler(c, resp)
- default:
- err, usage = ali.Handler(c, resp)
- }
- if err != nil {
- return err
- }
- if usage != nil {
- textResponse.Usage = *usage
- }
- return nil
- }
- case APITypeXunfei:
- auth := c.Request.Header.Get("Authorization")
- auth = strings.TrimPrefix(auth, "Bearer ")
- splits := strings.Split(auth, "|")
- if len(splits) != 3 {
- return openai.ErrorWrapper(errors.New("invalid auth"), "invalid_auth", http.StatusBadRequest)
- }
- var err *openai.ErrorWithStatusCode
- var usage *openai.Usage
- if isStream {
- err, usage = xunfei.StreamHandler(c, textRequest, splits[0], splits[1], splits[2])
- } else {
- err, usage = xunfei.Handler(c, textRequest, splits[0], splits[1], splits[2])
- }
- if err != nil {
- return err
- }
- if usage != nil {
- textResponse.Usage = *usage
- }
- return nil
- case APITypeAIProxyLibrary:
- if isStream {
- err, usage := aiproxy.StreamHandler(c, resp)
- if err != nil {
- return err
- }
- if usage != nil {
- textResponse.Usage = *usage
- }
- return nil
- } else {
- err, usage := aiproxy.Handler(c, resp)
- if err != nil {
- return err
- }
- if usage != nil {
- textResponse.Usage = *usage
- }
- return nil
- }
- case APITypeTencent:
- if isStream {
- err, responseText := tencent.StreamHandler(c, resp)
- if err != nil {
- return err
- }
- textResponse.Usage.PromptTokens = promptTokens
- textResponse.Usage.CompletionTokens = openai.CountTokenText(responseText, textRequest.Model)
- return nil
- } else {
- err, usage := tencent.Handler(c, resp)
- if err != nil {
- return err
- }
- if usage != nil {
- textResponse.Usage = *usage
- }
- return nil
- }
- default:
- return openai.ErrorWrapper(errors.New("unknown api type"), "unknown_api_type", http.StatusInternalServerError)
+ // do response
+ usage, respErr := adaptor.DoResponse(c, resp, meta)
+ if respErr != nil {
+ logger.Errorf(ctx, "respErr is not nil: %+v", respErr)
+ util.ReturnPreConsumedQuota(ctx, preConsumedQuota, meta.TokenId)
+ return respErr
}
+ // post-consume quota
+ go postConsumeQuota(ctx, usage, meta, textRequest, ratio, preConsumedQuota, modelRatio, groupRatio)
+ return nil
}
diff --git a/relay/helper/main.go b/relay/helper/main.go
new file mode 100644
index 00000000..c2b6e6af
--- /dev/null
+++ b/relay/helper/main.go
@@ -0,0 +1,42 @@
+package helper
+
+import (
+ "github.com/songquanpeng/one-api/relay/channel"
+ "github.com/songquanpeng/one-api/relay/channel/aiproxy"
+ "github.com/songquanpeng/one-api/relay/channel/ali"
+ "github.com/songquanpeng/one-api/relay/channel/anthropic"
+ "github.com/songquanpeng/one-api/relay/channel/baidu"
+ "github.com/songquanpeng/one-api/relay/channel/gemini"
+ "github.com/songquanpeng/one-api/relay/channel/openai"
+ "github.com/songquanpeng/one-api/relay/channel/palm"
+ "github.com/songquanpeng/one-api/relay/channel/tencent"
+ "github.com/songquanpeng/one-api/relay/channel/xunfei"
+ "github.com/songquanpeng/one-api/relay/channel/zhipu"
+ "github.com/songquanpeng/one-api/relay/constant"
+)
+
+func GetAdaptor(apiType int) channel.Adaptor {
+ switch apiType {
+ case constant.APITypeAIProxyLibrary:
+ return &aiproxy.Adaptor{}
+ case constant.APITypeAli:
+ return &ali.Adaptor{}
+ case constant.APITypeAnthropic:
+ return &anthropic.Adaptor{}
+ case constant.APITypeBaidu:
+ return &baidu.Adaptor{}
+ case constant.APITypeGemini:
+ return &gemini.Adaptor{}
+ case constant.APITypeOpenAI:
+ return &openai.Adaptor{}
+ case constant.APITypePaLM:
+ return &palm.Adaptor{}
+ case constant.APITypeTencent:
+ return &tencent.Adaptor{}
+ case constant.APITypeXunfei:
+ return &xunfei.Adaptor{}
+ case constant.APITypeZhipu:
+ return &zhipu.Adaptor{}
+ }
+ return nil
+}
diff --git a/relay/channel/openai/constant.go b/relay/model/constant.go
similarity index 83%
rename from relay/channel/openai/constant.go
rename to relay/model/constant.go
index 000f72ee..f6cf1924 100644
--- a/relay/channel/openai/constant.go
+++ b/relay/model/constant.go
@@ -1,4 +1,4 @@
-package openai
+package model
const (
ContentTypeText = "text"
diff --git a/relay/model/general.go b/relay/model/general.go
new file mode 100644
index 00000000..fbcc04e8
--- /dev/null
+++ b/relay/model/general.go
@@ -0,0 +1,46 @@
+package model
+
+type ResponseFormat struct {
+ Type string `json:"type,omitempty"`
+}
+
+type GeneralOpenAIRequest struct {
+ Model string `json:"model,omitempty"`
+ Messages []Message `json:"messages,omitempty"`
+ Prompt any `json:"prompt,omitempty"`
+ Stream bool `json:"stream,omitempty"`
+ MaxTokens int `json:"max_tokens,omitempty"`
+ Temperature float64 `json:"temperature,omitempty"`
+ TopP float64 `json:"top_p,omitempty"`
+ N int `json:"n,omitempty"`
+ Input any `json:"input,omitempty"`
+ Instruction string `json:"instruction,omitempty"`
+ Size string `json:"size,omitempty"`
+ Functions any `json:"functions,omitempty"`
+ FrequencyPenalty float64 `json:"frequency_penalty,omitempty"`
+ PresencePenalty float64 `json:"presence_penalty,omitempty"`
+ ResponseFormat *ResponseFormat `json:"response_format,omitempty"`
+ Seed float64 `json:"seed,omitempty"`
+ Tools any `json:"tools,omitempty"`
+ ToolChoice any `json:"tool_choice,omitempty"`
+ User string `json:"user,omitempty"`
+}
+
+func (r GeneralOpenAIRequest) ParseInput() []string {
+ if r.Input == nil {
+ return nil
+ }
+ var input []string
+	switch v := r.Input.(type) {
+	case string:
+		input = []string{v}
+	case []any:
+		input = make([]string, 0, len(v))
+		for _, item := range v {
+			if str, ok := item.(string); ok {
+				input = append(input, str)
+			}
+		}
+ }
+ return input
+}
diff --git a/relay/model/message.go b/relay/model/message.go
new file mode 100644
index 00000000..c6c8a271
--- /dev/null
+++ b/relay/model/message.go
@@ -0,0 +1,88 @@
+package model
+
+type Message struct {
+ Role string `json:"role"`
+ Content any `json:"content"`
+ Name *string `json:"name,omitempty"`
+}
+
+func (m Message) IsStringContent() bool {
+ _, ok := m.Content.(string)
+ return ok
+}
+
+func (m Message) StringContent() string {
+ content, ok := m.Content.(string)
+ if ok {
+ return content
+ }
+ contentList, ok := m.Content.([]any)
+ if ok {
+ var contentStr string
+ for _, contentItem := range contentList {
+ contentMap, ok := contentItem.(map[string]any)
+ if !ok {
+ continue
+ }
+ if contentMap["type"] == ContentTypeText {
+ if subStr, ok := contentMap["text"].(string); ok {
+ contentStr += subStr
+ }
+ }
+ }
+ return contentStr
+ }
+ return ""
+}
+
+func (m Message) ParseContent() []MessageContent {
+ var contentList []MessageContent
+ content, ok := m.Content.(string)
+ if ok {
+ contentList = append(contentList, MessageContent{
+ Type: ContentTypeText,
+ Text: content,
+ })
+ return contentList
+ }
+ anyList, ok := m.Content.([]any)
+ if ok {
+ for _, contentItem := range anyList {
+ contentMap, ok := contentItem.(map[string]any)
+ if !ok {
+ continue
+ }
+ switch contentMap["type"] {
+ case ContentTypeText:
+ if subStr, ok := contentMap["text"].(string); ok {
+ contentList = append(contentList, MessageContent{
+ Type: ContentTypeText,
+ Text: subStr,
+ })
+ }
+ case ContentTypeImageURL:
+ if subObj, ok := contentMap["image_url"].(map[string]any); ok {
+ contentList = append(contentList, MessageContent{
+ Type: ContentTypeImageURL,
+ ImageURL: &ImageURL{
+ Url: subObj["url"].(string),
+ },
+ })
+ }
+ }
+ }
+ return contentList
+ }
+ return nil
+}
+
+type ImageURL struct {
+ Url string `json:"url,omitempty"`
+ Detail string `json:"detail,omitempty"`
+}
+
+type MessageContent struct {
+ Type string `json:"type,omitempty"`
+ Text string `json:"text"`
+ ImageURL *ImageURL `json:"image_url,omitempty"`
+}
diff --git a/relay/model/misc.go b/relay/model/misc.go
new file mode 100644
index 00000000..163bc398
--- /dev/null
+++ b/relay/model/misc.go
@@ -0,0 +1,19 @@
+package model
+
+type Usage struct {
+ PromptTokens int `json:"prompt_tokens"`
+ CompletionTokens int `json:"completion_tokens"`
+ TotalTokens int `json:"total_tokens"`
+}
+
+type Error struct {
+ Message string `json:"message"`
+ Type string `json:"type"`
+ Param string `json:"param"`
+ Code any `json:"code"`
+}
+
+type ErrorWithStatusCode struct {
+ Error
+ StatusCode int `json:"status_code"`
+}
diff --git a/relay/util/billing.go b/relay/util/billing.go
new file mode 100644
index 00000000..1e2b09ea
--- /dev/null
+++ b/relay/util/billing.go
@@ -0,0 +1,19 @@
+package util
+
+import (
+ "context"
+ "github.com/songquanpeng/one-api/common/logger"
+ "github.com/songquanpeng/one-api/model"
+)
+
+func ReturnPreConsumedQuota(ctx context.Context, preConsumedQuota int, tokenId int) {
+ if preConsumedQuota != 0 {
+ go func(ctx context.Context) {
+ // return pre-consumed quota
+ err := model.PostConsumeTokenQuota(tokenId, -preConsumedQuota)
+ if err != nil {
+ logger.Error(ctx, "error return pre-consumed quota: "+err.Error())
+ }
+ }(ctx)
+ }
+}
diff --git a/relay/util/common.go b/relay/util/common.go
index 9d13b12e..6d993378 100644
--- a/relay/util/common.go
+++ b/relay/util/common.go
@@ -4,19 +4,21 @@ import (
"context"
"encoding/json"
"fmt"
+ "github.com/songquanpeng/one-api/common"
+ "github.com/songquanpeng/one-api/common/config"
+ "github.com/songquanpeng/one-api/common/logger"
+ "github.com/songquanpeng/one-api/model"
+ relaymodel "github.com/songquanpeng/one-api/relay/model"
"io"
"net/http"
- "one-api/common"
- "one-api/model"
- "one-api/relay/channel/openai"
"strconv"
"strings"
"github.com/gin-gonic/gin"
)
-func ShouldDisableChannel(err *openai.Error, statusCode int) bool {
- if !common.AutomaticDisableChannelEnabled {
+func ShouldDisableChannel(err *relaymodel.Error, statusCode int) bool {
+ if !config.AutomaticDisableChannelEnabled {
return false
}
if err == nil {
@@ -31,8 +33,8 @@ func ShouldDisableChannel(err *openai.Error, statusCode int) bool {
return false
}
-func ShouldEnableChannel(err error, openAIErr *openai.Error) bool {
- if !common.AutomaticEnableChannelEnabled {
+func ShouldEnableChannel(err error, openAIErr *relaymodel.Error) bool {
+ if !config.AutomaticEnableChannelEnabled {
return false
}
if err != nil {
@@ -45,11 +47,11 @@ func ShouldEnableChannel(err error, openAIErr *openai.Error) bool {
}
type GeneralErrorResponse struct {
- Error openai.Error `json:"error"`
- Message string `json:"message"`
- Msg string `json:"msg"`
- Err string `json:"err"`
- ErrorMsg string `json:"error_msg"`
+ Error relaymodel.Error `json:"error"`
+ Message string `json:"message"`
+ Msg string `json:"msg"`
+ Err string `json:"err"`
+ ErrorMsg string `json:"error_msg"`
Header struct {
Message string `json:"message"`
} `json:"header"`
@@ -85,10 +87,10 @@ func (e GeneralErrorResponse) ToMessage() string {
return ""
}
-func RelayErrorHandler(resp *http.Response) (ErrorWithStatusCode *openai.ErrorWithStatusCode) {
- ErrorWithStatusCode = &openai.ErrorWithStatusCode{
+func RelayErrorHandler(resp *http.Response) (ErrorWithStatusCode *relaymodel.ErrorWithStatusCode) {
+ ErrorWithStatusCode = &relaymodel.ErrorWithStatusCode{
StatusCode: resp.StatusCode,
- Error: openai.Error{
+ Error: relaymodel.Error{
Message: "",
Type: "upstream_error",
Code: "bad_response_status_code",
@@ -138,11 +140,11 @@ func PostConsumeQuota(ctx context.Context, tokenId int, quotaDelta int, totalQuo
// quotaDelta is remaining quota to be consumed
err := model.PostConsumeTokenQuota(tokenId, quotaDelta)
if err != nil {
- common.SysError("error consuming token remain quota: " + err.Error())
+ logger.SysError("error consuming token remain quota: " + err.Error())
}
err = model.CacheUpdateUserQuota(userId)
if err != nil {
- common.SysError("error update user quota cache: " + err.Error())
+ logger.SysError("error update user quota cache: " + err.Error())
}
// totalQuota is total quota consumed
if totalQuota != 0 {
@@ -152,15 +154,15 @@ func PostConsumeQuota(ctx context.Context, tokenId int, quotaDelta int, totalQuo
model.UpdateChannelUsedQuota(channelId, totalQuota)
}
if totalQuota <= 0 {
- common.LogError(ctx, fmt.Sprintf("totalQuota consumed is %d, something is wrong", totalQuota))
+ logger.Error(ctx, fmt.Sprintf("totalQuota consumed is %d, something is wrong", totalQuota))
}
}
-func GetAPIVersion(c *gin.Context) string {
+func GetAzureAPIVersion(c *gin.Context) string {
query := c.Request.URL.Query()
apiVersion := query.Get("api-version")
if apiVersion == "" {
- apiVersion = c.GetString("api_version")
+ apiVersion = c.GetString(common.ConfigKeyAPIVersion)
}
return apiVersion
}
diff --git a/relay/util/init.go b/relay/util/init.go
index d308d900..03dad31b 100644
--- a/relay/util/init.go
+++ b/relay/util/init.go
@@ -1,8 +1,8 @@
package util
import (
+ "github.com/songquanpeng/one-api/common/config"
"net/http"
- "one-api/common"
"time"
)
@@ -10,11 +10,11 @@ var HTTPClient *http.Client
var ImpatientHTTPClient *http.Client
func init() {
- if common.RelayTimeout == 0 {
+ if config.RelayTimeout == 0 {
HTTPClient = &http.Client{}
} else {
HTTPClient = &http.Client{
- Timeout: time.Duration(common.RelayTimeout) * time.Second,
+ Timeout: time.Duration(config.RelayTimeout) * time.Second,
}
}
diff --git a/relay/util/model_mapping.go b/relay/util/model_mapping.go
new file mode 100644
index 00000000..39e062a1
--- /dev/null
+++ b/relay/util/model_mapping.go
@@ -0,0 +1,12 @@
+package util
+
+func GetMappedModelName(modelName string, mapping map[string]string) (string, bool) {
+ if mapping == nil {
+ return modelName, false
+ }
+ mappedModelName := mapping[modelName]
+ if mappedModelName != "" {
+ return mappedModelName, true
+ }
+ return modelName, false
+}
diff --git a/relay/util/relay_meta.go b/relay/util/relay_meta.go
new file mode 100644
index 00000000..31b9d2b4
--- /dev/null
+++ b/relay/util/relay_meta.go
@@ -0,0 +1,55 @@
+package util
+
+import (
+ "github.com/gin-gonic/gin"
+ "github.com/songquanpeng/one-api/common"
+ "github.com/songquanpeng/one-api/relay/constant"
+ "strings"
+)
+
+type RelayMeta struct {
+ Mode int
+ ChannelType int
+ ChannelId int
+ TokenId int
+ TokenName string
+ UserId int
+ Group string
+ ModelMapping map[string]string
+ BaseURL string
+ APIVersion string
+ APIKey string
+ APIType int
+ Config map[string]string
+ IsStream bool
+ OriginModelName string
+ ActualModelName string
+ RequestURLPath string
+ PromptTokens int // only for DoResponse
+}
+
+func GetRelayMeta(c *gin.Context) *RelayMeta {
+ meta := RelayMeta{
+ Mode: constant.Path2RelayMode(c.Request.URL.Path),
+ ChannelType: c.GetInt("channel"),
+ ChannelId: c.GetInt("channel_id"),
+ TokenId: c.GetInt("token_id"),
+ TokenName: c.GetString("token_name"),
+ UserId: c.GetInt("id"),
+ Group: c.GetString("group"),
+ ModelMapping: c.GetStringMapString("model_mapping"),
+ BaseURL: c.GetString("base_url"),
+ APIVersion: c.GetString(common.ConfigKeyAPIVersion),
+ APIKey: strings.TrimPrefix(c.Request.Header.Get("Authorization"), "Bearer "),
+ Config: nil,
+ RequestURLPath: c.Request.URL.String(),
+ }
+ if meta.ChannelType == common.ChannelTypeAzure {
+ meta.APIVersion = GetAzureAPIVersion(c)
+ }
+ if meta.BaseURL == "" {
+ meta.BaseURL = common.ChannelBaseURLs[meta.ChannelType]
+ }
+ meta.APIType = constant.ChannelType2APIType(meta.ChannelType)
+ return &meta
+}
diff --git a/relay/util/validation.go b/relay/util/validation.go
new file mode 100644
index 00000000..ef8d840c
--- /dev/null
+++ b/relay/util/validation.go
@@ -0,0 +1,37 @@
+package util
+
+import (
+ "errors"
+ "github.com/songquanpeng/one-api/relay/constant"
+ "github.com/songquanpeng/one-api/relay/model"
+ "math"
+)
+
+func ValidateTextRequest(textRequest *model.GeneralOpenAIRequest, relayMode int) error {
+ if textRequest.MaxTokens < 0 || textRequest.MaxTokens > math.MaxInt32/2 {
+ return errors.New("max_tokens is invalid")
+ }
+ if textRequest.Model == "" {
+ return errors.New("model is required")
+ }
+ switch relayMode {
+ case constant.RelayModeCompletions:
+ if textRequest.Prompt == "" {
+ return errors.New("field prompt is required")
+ }
+ case constant.RelayModeChatCompletions:
+		if len(textRequest.Messages) == 0 {
+ return errors.New("field messages is required")
+ }
+ case constant.RelayModeEmbeddings:
+ case constant.RelayModeModerations:
+		if textRequest.Input == nil || textRequest.Input == "" {
+ return errors.New("field input is required")
+ }
+ case constant.RelayModeEdits:
+ if textRequest.Instruction == "" {
+ return errors.New("field instruction is required")
+ }
+ }
+ return nil
+}
diff --git a/router/api-router.go b/router/api-router.go
index bd2574ab..3b1b78d6 100644
--- a/router/api-router.go
+++ b/router/api-router.go
@@ -1,8 +1,8 @@
package router
import (
- "one-api/controller"
- "one-api/middleware"
+ "github.com/songquanpeng/one-api/controller"
+ "github.com/songquanpeng/one-api/middleware"
"github.com/gin-contrib/gzip"
"github.com/gin-gonic/gin"
diff --git a/router/dashboard.go b/router/dashboard.go
index f0900d8f..ffe05066 100644
--- a/router/dashboard.go
+++ b/router/dashboard.go
@@ -1,9 +1,9 @@
 package router
 
 import (
-	"one-api/controller"
-	"one-api/middleware"
+	"github.com/songquanpeng/one-api/controller"
+	"github.com/songquanpeng/one-api/middleware"
 
 	"github.com/gin-contrib/gzip"
 	"github.com/gin-gonic/gin"
 )
func SetDashboardRouter(router *gin.Engine) {
diff --git a/router/main.go b/router/main.go
index 85127a1a..39d8c04f 100644
--- a/router/main.go
+++ b/router/main.go
@@ -4,8 +4,9 @@ import (
"embed"
"fmt"
"github.com/gin-gonic/gin"
+ "github.com/songquanpeng/one-api/common/config"
+ "github.com/songquanpeng/one-api/common/logger"
"net/http"
- "one-api/common"
"os"
"strings"
)
@@ -15,9 +16,9 @@ func SetRouter(router *gin.Engine, buildFS embed.FS) {
SetDashboardRouter(router)
SetRelayRouter(router)
frontendBaseUrl := os.Getenv("FRONTEND_BASE_URL")
- if common.IsMasterNode && frontendBaseUrl != "" {
+ if config.IsMasterNode && frontendBaseUrl != "" {
frontendBaseUrl = ""
- common.SysLog("FRONTEND_BASE_URL is ignored on master node")
+ logger.SysLog("FRONTEND_BASE_URL is ignored on master node")
}
if frontendBaseUrl == "" {
SetWebRouter(router, buildFS)
diff --git a/router/relay-router.go b/router/relay-router.go
index 56ab9b28..65072c86 100644
--- a/router/relay-router.go
+++ b/router/relay-router.go
@@ -1,8 +1,8 @@
package router
import (
- "one-api/controller"
- "one-api/middleware"
+ "github.com/songquanpeng/one-api/controller"
+ "github.com/songquanpeng/one-api/middleware"
"github.com/gin-gonic/gin"
)
diff --git a/router/web-router.go b/router/web-router.go
index 7328c7a3..3c9b4643 100644
--- a/router/web-router.go
+++ b/router/web-router.go
@@ -6,19 +6,20 @@ import (
"github.com/gin-contrib/gzip"
"github.com/gin-contrib/static"
"github.com/gin-gonic/gin"
+ "github.com/songquanpeng/one-api/common"
+ "github.com/songquanpeng/one-api/common/config"
+ "github.com/songquanpeng/one-api/controller"
+ "github.com/songquanpeng/one-api/middleware"
"net/http"
- "one-api/common"
- "one-api/controller"
- "one-api/middleware"
"strings"
)
func SetWebRouter(router *gin.Engine, buildFS embed.FS) {
- indexPageData, _ := buildFS.ReadFile(fmt.Sprintf("web/build/%s/index.html", common.Theme))
+ indexPageData, _ := buildFS.ReadFile(fmt.Sprintf("web/build/%s/index.html", config.Theme))
router.Use(gzip.Gzip(gzip.DefaultCompression))
router.Use(middleware.GlobalWebRateLimit())
router.Use(middleware.Cache())
- router.Use(static.Serve("/", common.EmbedFolder(buildFS, fmt.Sprintf("web/build/%s", common.Theme))))
+ router.Use(static.Serve("/", common.EmbedFolder(buildFS, fmt.Sprintf("web/build/%s", config.Theme))))
router.NoRoute(func(c *gin.Context) {
if strings.HasPrefix(c.Request.RequestURI, "/v1") || strings.HasPrefix(c.Request.RequestURI, "/api") {
controller.RelayNotFound(c)
diff --git a/web/berry/src/constants/ChannelConstants.js b/web/berry/src/constants/ChannelConstants.js
index 3ce27838..aeff5190 100644
--- a/web/berry/src/constants/ChannelConstants.js
+++ b/web/berry/src/constants/ChannelConstants.js
@@ -59,6 +59,12 @@ export const CHANNEL_OPTIONS = {
value: 19,
color: 'default'
},
+ 25: {
+ key: 25,
+ text: 'Moonshot AI',
+    value: 25,
+ color: 'default'
+ },
23: {
key: 23,
text: '腾讯混元',
diff --git a/web/berry/src/hooks/useRegister.js b/web/berry/src/hooks/useRegister.js
index d07dc43a..5377e96d 100644
--- a/web/berry/src/hooks/useRegister.js
+++ b/web/berry/src/hooks/useRegister.js
@@ -6,6 +6,10 @@ const useRegister = () => {
const navigate = useNavigate();
const register = async (input, turnstile) => {
try {
+ let affCode = localStorage.getItem('aff');
+ if (affCode) {
+ input = { ...input, aff_code: affCode };
+ }
const res = await API.post(`/api/user/register?turnstile=${turnstile}`, input);
const { success, message } = res.data;
if (success) {
diff --git a/web/berry/src/views/Channel/index.js b/web/berry/src/views/Channel/index.js
index 5b7f1722..39ab5d82 100644
--- a/web/berry/src/views/Channel/index.js
+++ b/web/berry/src/views/Channel/index.js
@@ -202,9 +202,7 @@ export default function ChannelPage() {
- 当前版本测试是通过按照 OpenAI API 格式使用 gpt-3.5-turbo
- 模型进行非流式请求实现的,因此测试报错并不一定代表通道不可用,该功能后续会修复。 另外,OpenAI 渠道已经不再支持通过 key
- 获取余额,因此余额显示为 0。对于支持的渠道类型,请点击余额进行刷新。
+ OpenAI 渠道已经不再支持通过 key 获取余额,因此余额显示为 0。对于支持的渠道类型,请点击余额进行刷新。
@@ -229,9 +227,9 @@ export default function ChannelPage() {
}>
测试启用渠道
- }>
- 更新启用余额
-
+ {/*}>*/}
+ {/* 更新启用余额*/}
+ {/**/}
}>
删除禁用渠道
diff --git a/web/berry/src/views/Channel/type/Config.js b/web/berry/src/views/Channel/type/Config.js
index d270f527..a091c8d6 100644
--- a/web/berry/src/views/Channel/type/Config.js
+++ b/web/berry/src/views/Channel/type/Config.js
@@ -94,7 +94,13 @@ const typeConfig = {
other: "版本号",
},
input: {
- models: ["SparkDesk"],
+ models: [
+ "SparkDesk",
+        "SparkDesk-v1.1",
+        "SparkDesk-v2.1",
+        "SparkDesk-v3.1",
+        "SparkDesk-v3.5"
+ ],
},
prompt: {
key: "按照如下格式输入:APPID|APISecret|APIKey",
diff --git a/web/berry/src/views/Setting/component/OperationSetting.js b/web/berry/src/views/Setting/component/OperationSetting.js
index 0d331d76..d91298b2 100644
--- a/web/berry/src/views/Setting/component/OperationSetting.js
+++ b/web/berry/src/views/Setting/component/OperationSetting.js
@@ -27,6 +27,7 @@ const OperationSetting = () => {
QuotaRemindThreshold: 0,
PreConsumedQuota: 0,
ModelRatio: "",
+ CompletionRatio: "",
GroupRatio: "",
TopUpLink: "",
ChatLink: "",
@@ -52,9 +53,12 @@ const OperationSetting = () => {
if (success) {
let newInputs = {};
data.forEach((item) => {
- if (item.key === "ModelRatio" || item.key === "GroupRatio") {
+ if (item.key === "ModelRatio" || item.key === "GroupRatio" || item.key === "CompletionRatio") {
item.value = JSON.stringify(JSON.parse(item.value), null, 2);
}
+ if (item.value === '{}') {
+ item.value = '';
+ }
newInputs[item.key] = item.value;
});
setInputs(newInputs);
@@ -133,6 +137,13 @@ const OperationSetting = () => {
}
await updateOption("GroupRatio", inputs.GroupRatio);
}
+ if (originInputs['CompletionRatio'] !== inputs.CompletionRatio) {
+ if (!verifyJSON(inputs.CompletionRatio)) {
+ showError('补全倍率不是合法的 JSON 字符串');
+ return;
+ }
+ await updateOption('CompletionRatio', inputs.CompletionRatio);
+ }
break;
case "quota":
if (originInputs["QuotaForNewUser"] !== inputs.QuotaForNewUser) {
@@ -500,7 +511,20 @@ const OperationSetting = () => {
placeholder="为一个 JSON 文本,键为模型名称,值为倍率"
/>
-
+
+
+
@@ -222,7 +222,7 @@ export default function TokensTableRow({ item, manageToken, handleOpenModal, set
-
+
diff --git a/web/default/src/components/ChannelsTable.js b/web/default/src/components/ChannelsTable.js
index a2adfd32..7117fe53 100644
--- a/web/default/src/components/ChannelsTable.js
+++ b/web/default/src/components/ChannelsTable.js
@@ -322,10 +322,7 @@ const ChannelsTable = () => {
setShowPrompt(false);
setPromptShown("channel-test");
}}>
- 当前版本测试是通过按照 OpenAI API 格式使用 gpt-3.5-turbo
- 模型进行非流式请求实现的,因此测试报错并不一定代表通道不可用,该功能后续会修复。
-
- 另外,OpenAI 渠道已经不再支持通过 key 获取余额,因此余额显示为 0。对于支持的渠道类型,请点击余额进行刷新。
+ OpenAI 渠道已经不再支持通过 key 获取余额,因此余额显示为 0。对于支持的渠道类型,请点击余额进行刷新。
)
}
@@ -525,8 +522,8 @@ const ChannelsTable = () => {
-
+ {/**/}
diff --git a/web/default/src/components/OperationSetting.js b/web/default/src/components/OperationSetting.js
index 3b52bb27..b823bb28 100644
--- a/web/default/src/components/OperationSetting.js
+++ b/web/default/src/components/OperationSetting.js
@@ -11,6 +11,7 @@ const OperationSetting = () => {
QuotaRemindThreshold: 0,
PreConsumedQuota: 0,
ModelRatio: '',
+ CompletionRatio: '',
GroupRatio: '',
TopUpLink: '',
ChatLink: '',
@@ -34,9 +35,12 @@ const OperationSetting = () => {
if (success) {
let newInputs = {};
data.forEach((item) => {
- if (item.key === 'ModelRatio' || item.key === 'GroupRatio') {
+ if (item.key === 'ModelRatio' || item.key === 'GroupRatio' || item.key === 'CompletionRatio') {
item.value = JSON.stringify(JSON.parse(item.value), null, 2);
}
+ if (item.value === '{}') {
+ item.value = '';
+ }
newInputs[item.key] = item.value;
});
setInputs(newInputs);
@@ -101,6 +105,13 @@ const OperationSetting = () => {
}
await updateOption('GroupRatio', inputs.GroupRatio);
}
+ if (originInputs['CompletionRatio'] !== inputs.CompletionRatio) {
+ if (!verifyJSON(inputs.CompletionRatio)) {
+ showError('补全倍率不是合法的 JSON 字符串');
+ return;
+ }
+ await updateOption('CompletionRatio', inputs.CompletionRatio);
+ }
break;
case 'quota':
if (originInputs['QuotaForNewUser'] !== inputs.QuotaForNewUser) {
@@ -271,10 +282,10 @@ const OperationSetting = () => {
onChange={handleInputChange}
/>
{
@@ -344,6 +355,17 @@ const OperationSetting = () => {
placeholder='为一个 JSON 文本,键为模型名称,值为倍率'
/>
+
+
+
{
localModels = ['chatglm_turbo', 'chatglm_pro', 'chatglm_std', 'chatglm_lite'];
break;
case 18:
- localModels = ['SparkDesk'];
+ localModels = [
+ 'SparkDesk',
+ 'SparkDesk-v1.1',
+ 'SparkDesk-v2.1',
+ 'SparkDesk-v3.1',
+ 'SparkDesk-v3.5'
+ ];
break;
case 19:
localModels = ['360GPT_S2_V9', 'embedding-bert-512-v1', 'embedding_s1_v1', 'semantic_similarity_s1_v1'];
@@ -93,6 +99,9 @@ const EditChannel = () => {
case 24:
localModels = ['gemini-pro', 'gemini-pro-vision'];
break;
+ case 25:
+ localModels = ['moonshot-v1-8k', 'moonshot-v1-32k', 'moonshot-v1-128k'];
+ break;
}
setInputs((inputs) => ({ ...inputs, models: localModels }));
}