Merge remote-tracking branch 'origin/main'
# Conflicts: # controller/relay.go # main.go # middleware/distributor.go
This commit is contained in:
commit
9c08d78349
15
README.md
15
README.md
@ -68,6 +68,7 @@ _✨ 通过标准的 OpenAI API 格式访问所有的大模型,开箱即用
|
|||||||
+ [x] [阿里通义千问系列模型](https://help.aliyun.com/document_detail/2400395.html)
|
+ [x] [阿里通义千问系列模型](https://help.aliyun.com/document_detail/2400395.html)
|
||||||
+ [x] [讯飞星火认知大模型](https://www.xfyun.cn/doc/spark/Web.html)
|
+ [x] [讯飞星火认知大模型](https://www.xfyun.cn/doc/spark/Web.html)
|
||||||
+ [x] [智谱 ChatGLM 系列模型](https://bigmodel.cn)
|
+ [x] [智谱 ChatGLM 系列模型](https://bigmodel.cn)
|
||||||
|
+ [x] [360 智脑](https://ai.360.cn)
|
||||||
2. 支持配置镜像以及众多第三方代理服务:
|
2. 支持配置镜像以及众多第三方代理服务:
|
||||||
+ [x] [OpenAI-SB](https://openai-sb.com)
|
+ [x] [OpenAI-SB](https://openai-sb.com)
|
||||||
+ [x] [API2D](https://api2d.com/r/197971)
|
+ [x] [API2D](https://api2d.com/r/197971)
|
||||||
@ -108,6 +109,8 @@ _✨ 通过标准的 OpenAI API 格式访问所有的大模型,开箱即用
|
|||||||
|
|
||||||
数据将会保存在宿主机的 `/home/ubuntu/data/one-api` 目录,请确保该目录存在且具有写入权限,或者更改为合适的目录。
|
数据将会保存在宿主机的 `/home/ubuntu/data/one-api` 目录,请确保该目录存在且具有写入权限,或者更改为合适的目录。
|
||||||
|
|
||||||
|
如果启动失败,请添加 `--privileged=true`,具体参考 https://github.com/songquanpeng/one-api/issues/482 。
|
||||||
|
|
||||||
如果上面的镜像无法拉取,可以尝试使用 GitHub 的 Docker 镜像,将上面的 `justsong/one-api` 替换为 `ghcr.io/songquanpeng/one-api` 即可。
|
如果上面的镜像无法拉取,可以尝试使用 GitHub 的 Docker 镜像,将上面的 `justsong/one-api` 替换为 `ghcr.io/songquanpeng/one-api` 即可。
|
||||||
|
|
||||||
如果你的并发量较大,**务必**设置 `SQL_DSN`,详见下面[环境变量](#环境变量)一节。
|
如果你的并发量较大,**务必**设置 `SQL_DSN`,详见下面[环境变量](#环境变量)一节。
|
||||||
@ -274,8 +277,9 @@ graph LR
|
|||||||
不加的话将会使用负载均衡的方式使用多个渠道。
|
不加的话将会使用负载均衡的方式使用多个渠道。
|
||||||
|
|
||||||
### 环境变量
|
### 环境变量
|
||||||
1. `REDIS_CONN_STRING`:设置之后将使用 Redis 作为请求频率限制的存储,而非使用内存存储。
|
1. `REDIS_CONN_STRING`:设置之后将使用 Redis 作为缓存使用。
|
||||||
+ 例子:`REDIS_CONN_STRING=redis://default:redispw@localhost:49153`
|
+ 例子:`REDIS_CONN_STRING=redis://default:redispw@localhost:49153`
|
||||||
|
+ 如果数据库访问延迟很低,没有必要启用 Redis,启用后反而会出现数据滞后的问题。
|
||||||
2. `SESSION_SECRET`:设置之后将使用固定的会话密钥,这样系统重新启动后已登录用户的 cookie 将依旧有效。
|
2. `SESSION_SECRET`:设置之后将使用固定的会话密钥,这样系统重新启动后已登录用户的 cookie 将依旧有效。
|
||||||
+ 例子:`SESSION_SECRET=random_string`
|
+ 例子:`SESSION_SECRET=random_string`
|
||||||
3. `SQL_DSN`:设置之后将使用指定数据库而非 SQLite,请使用 MySQL 或 PostgreSQL。
|
3. `SQL_DSN`:设置之后将使用指定数据库而非 SQLite,请使用 MySQL 或 PostgreSQL。
|
||||||
@ -302,6 +306,14 @@ graph LR
|
|||||||
+ 例子:`CHANNEL_TEST_FREQUENCY=1440`
|
+ 例子:`CHANNEL_TEST_FREQUENCY=1440`
|
||||||
9. `POLLING_INTERVAL`:批量更新渠道余额以及测试可用性时的请求间隔,单位为秒,默认无间隔。
|
9. `POLLING_INTERVAL`:批量更新渠道余额以及测试可用性时的请求间隔,单位为秒,默认无间隔。
|
||||||
+ 例子:`POLLING_INTERVAL=5`
|
+ 例子:`POLLING_INTERVAL=5`
|
||||||
|
10. `BATCH_UPDATE_ENABLED`:启用数据库批量更新聚合,会导致用户额度的更新存在一定的延迟可选值为 `true` 和 `false`,未设置则默认为 `false`。
|
||||||
|
+ 例子:`BATCH_UPDATE_ENABLED=true`
|
||||||
|
+ 如果你遇到了数据库连接数过多的问题,可以尝试启用该选项。
|
||||||
|
11. `BATCH_UPDATE_INTERVAL=5`:批量更新聚合的时间间隔,单位为秒,默认为 `5`。
|
||||||
|
+ 例子:`BATCH_UPDATE_INTERVAL=5`
|
||||||
|
12. 请求频率限制:
|
||||||
|
+ `GLOBAL_API_RATE_LIMIT`:全局 API 速率限制(除中继请求外),单 ip 三分钟内的最大请求数,默认为 `180`。
|
||||||
|
+ `GLOBAL_WEB_RATE_LIMIT`:全局 Web 速率限制,单 ip 三分钟内的最大请求数,默认为 `60`。
|
||||||
|
|
||||||
### 命令行参数
|
### 命令行参数
|
||||||
1. `--port <port_number>`: 指定服务器监听的端口号,默认为 `3000`。
|
1. `--port <port_number>`: 指定服务器监听的端口号,默认为 `3000`。
|
||||||
@ -338,6 +350,7 @@ https://openai.justsong.cn
|
|||||||
5. ChatGPT Next Web 报错:`Failed to fetch`
|
5. ChatGPT Next Web 报错:`Failed to fetch`
|
||||||
+ 部署的时候不要设置 `BASE_URL`。
|
+ 部署的时候不要设置 `BASE_URL`。
|
||||||
+ 检查你的接口地址和 API Key 有没有填对。
|
+ 检查你的接口地址和 API Key 有没有填对。
|
||||||
|
+ 检查是否启用了 HTTPS,浏览器会拦截 HTTPS 域名下的 HTTP 请求。
|
||||||
6. 报错:`当前分组负载已饱和,请稍后再试`
|
6. 报错:`当前分组负载已饱和,请稍后再试`
|
||||||
+ 上游通道 429 了。
|
+ 上游通道 429 了。
|
||||||
|
|
||||||
|
@ -98,6 +98,9 @@ var RequestInterval = time.Duration(requestInterval) * time.Second
|
|||||||
|
|
||||||
var SyncFrequency = 10 * 60 // unit is second, will be overwritten by SYNC_FREQUENCY
|
var SyncFrequency = 10 * 60 // unit is second, will be overwritten by SYNC_FREQUENCY
|
||||||
|
|
||||||
|
var BatchUpdateEnabled = false
|
||||||
|
var BatchUpdateInterval = GetOrDefault("BATCH_UPDATE_INTERVAL", 5)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
RoleGuestUser = 0
|
RoleGuestUser = 0
|
||||||
RoleCommonUser = 1
|
RoleCommonUser = 1
|
||||||
@ -115,10 +118,10 @@ var (
|
|||||||
// All duration's unit is seconds
|
// All duration's unit is seconds
|
||||||
// Shouldn't larger then RateLimitKeyExpirationDuration
|
// Shouldn't larger then RateLimitKeyExpirationDuration
|
||||||
var (
|
var (
|
||||||
GlobalApiRateLimitNum = 180
|
GlobalApiRateLimitNum = GetOrDefault("GLOBAL_API_RATE_LIMIT", 180)
|
||||||
GlobalApiRateLimitDuration int64 = 3 * 60
|
GlobalApiRateLimitDuration int64 = 3 * 60
|
||||||
|
|
||||||
GlobalWebRateLimitNum = 60
|
GlobalWebRateLimitNum = GetOrDefault("GLOBAL_WEB_RATE_LIMIT", 60)
|
||||||
GlobalWebRateLimitDuration int64 = 3 * 60
|
GlobalWebRateLimitDuration int64 = 3 * 60
|
||||||
|
|
||||||
UploadRateLimitNum = 10
|
UploadRateLimitNum = 10
|
||||||
@ -158,45 +161,53 @@ const (
|
|||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
ChannelTypeUnknown = 0
|
ChannelTypeUnknown = 0
|
||||||
ChannelTypeOpenAI = 1
|
ChannelTypeOpenAI = 1
|
||||||
ChannelTypeAPI2D = 2
|
ChannelTypeAPI2D = 2
|
||||||
ChannelTypeAzure = 3
|
ChannelTypeAzure = 3
|
||||||
ChannelTypeCloseAI = 4
|
ChannelTypeCloseAI = 4
|
||||||
ChannelTypeOpenAISB = 5
|
ChannelTypeOpenAISB = 5
|
||||||
ChannelTypeOpenAIMax = 6
|
ChannelTypeOpenAIMax = 6
|
||||||
ChannelTypeOhMyGPT = 7
|
ChannelTypeOhMyGPT = 7
|
||||||
ChannelTypeCustom = 8
|
ChannelTypeCustom = 8
|
||||||
ChannelTypeAILS = 9
|
ChannelTypeAILS = 9
|
||||||
ChannelTypeAIProxy = 10
|
ChannelTypeAIProxy = 10
|
||||||
ChannelTypePaLM = 11
|
ChannelTypePaLM = 11
|
||||||
ChannelTypeAPI2GPT = 12
|
ChannelTypeAPI2GPT = 12
|
||||||
ChannelTypeAIGC2D = 13
|
ChannelTypeAIGC2D = 13
|
||||||
ChannelTypeAnthropic = 14
|
ChannelTypeAnthropic = 14
|
||||||
ChannelTypeBaidu = 15
|
ChannelTypeBaidu = 15
|
||||||
ChannelTypeZhipu = 16
|
ChannelTypeZhipu = 16
|
||||||
ChannelTypeAli = 17
|
ChannelTypeAli = 17
|
||||||
ChannelTypeXunfei = 18
|
ChannelTypeXunfei = 18
|
||||||
|
ChannelType360 = 19
|
||||||
|
ChannelTypeOpenRouter = 20
|
||||||
|
ChannelTypeAIProxyLibrary = 21
|
||||||
|
ChannelTypeFastGPT = 22
|
||||||
)
|
)
|
||||||
|
|
||||||
var ChannelBaseURLs = []string{
|
var ChannelBaseURLs = []string{
|
||||||
"", // 0
|
"", // 0
|
||||||
"https://api.openai.com", // 1
|
"https://api.openai.com", // 1
|
||||||
"https://oa.api2d.net", // 2
|
"https://oa.api2d.net", // 2
|
||||||
"", // 3
|
"", // 3
|
||||||
"https://api.closeai-proxy.xyz", // 4
|
"https://api.closeai-proxy.xyz", // 4
|
||||||
"https://api.openai-sb.com", // 5
|
"https://api.openai-sb.com", // 5
|
||||||
"https://api.openaimax.com", // 6
|
"https://api.openaimax.com", // 6
|
||||||
"https://api.ohmygpt.com", // 7
|
"https://api.ohmygpt.com", // 7
|
||||||
"", // 8
|
"", // 8
|
||||||
"https://api.caipacity.com", // 9
|
"https://api.caipacity.com", // 9
|
||||||
"https://api.aiproxy.io", // 10
|
"https://api.aiproxy.io", // 10
|
||||||
"", // 11
|
"", // 11
|
||||||
"https://api.api2gpt.com", // 12
|
"https://api.api2gpt.com", // 12
|
||||||
"https://api.aigc2d.com", // 13
|
"https://api.aigc2d.com", // 13
|
||||||
"https://api.anthropic.com", // 14
|
"https://api.anthropic.com", // 14
|
||||||
"https://aip.baidubce.com", // 15
|
"https://aip.baidubce.com", // 15
|
||||||
"https://open.bigmodel.cn", // 16
|
"https://open.bigmodel.cn", // 16
|
||||||
"https://dashscope.aliyuncs.com", // 17
|
"https://dashscope.aliyuncs.com", // 17
|
||||||
"", // 18
|
"", // 18
|
||||||
|
"https://ai.360.cn", // 19
|
||||||
|
"https://openrouter.ai/api", // 20
|
||||||
|
"https://api.aiproxy.io", // 21
|
||||||
|
"https://fastgpt.run/api/openapi", // 22
|
||||||
}
|
}
|
||||||
|
@ -13,46 +13,52 @@ import (
|
|||||||
// 1 === $0.002 / 1K tokens
|
// 1 === $0.002 / 1K tokens
|
||||||
// 1 === ¥0.014 / 1k tokens
|
// 1 === ¥0.014 / 1k tokens
|
||||||
var ModelRatio = map[string]float64{
|
var ModelRatio = map[string]float64{
|
||||||
"gpt-4": 15,
|
"gpt-4": 15,
|
||||||
"gpt-4-0314": 15,
|
"gpt-4-0314": 15,
|
||||||
"gpt-4-0613": 15,
|
"gpt-4-0613": 15,
|
||||||
"gpt-4-32k": 30,
|
"gpt-4-32k": 30,
|
||||||
"gpt-4-32k-0314": 30,
|
"gpt-4-32k-0314": 30,
|
||||||
"gpt-4-32k-0613": 30,
|
"gpt-4-32k-0613": 30,
|
||||||
"gpt-3.5-turbo": 0.75, // $0.0015 / 1K tokens
|
"gpt-3.5-turbo": 0.75, // $0.0015 / 1K tokens
|
||||||
"gpt-3.5-turbo-0301": 0.75,
|
"gpt-3.5-turbo-0301": 0.75,
|
||||||
"gpt-3.5-turbo-0613": 0.75,
|
"gpt-3.5-turbo-0613": 0.75,
|
||||||
"gpt-3.5-turbo-16k": 1.5, // $0.003 / 1K tokens
|
"gpt-3.5-turbo-16k": 1.5, // $0.003 / 1K tokens
|
||||||
"gpt-3.5-turbo-16k-0613": 1.5,
|
"gpt-3.5-turbo-16k-0613": 1.5,
|
||||||
"text-ada-001": 0.2,
|
"text-ada-001": 0.2,
|
||||||
"text-babbage-001": 0.25,
|
"text-babbage-001": 0.25,
|
||||||
"text-curie-001": 1,
|
"text-curie-001": 1,
|
||||||
"text-davinci-002": 10,
|
"text-davinci-002": 10,
|
||||||
"text-davinci-003": 10,
|
"text-davinci-003": 10,
|
||||||
"text-davinci-edit-001": 10,
|
"text-davinci-edit-001": 10,
|
||||||
"code-davinci-edit-001": 10,
|
"code-davinci-edit-001": 10,
|
||||||
"whisper-1": 10,
|
"whisper-1": 15, // $0.006 / minute -> $0.006 / 150 words -> $0.006 / 200 tokens -> $0.03 / 1k tokens
|
||||||
"davinci": 10,
|
"davinci": 10,
|
||||||
"curie": 10,
|
"curie": 10,
|
||||||
"babbage": 10,
|
"babbage": 10,
|
||||||
"ada": 10,
|
"ada": 10,
|
||||||
"text-embedding-ada-002": 0.05,
|
"text-embedding-ada-002": 0.05,
|
||||||
"text-search-ada-doc-001": 10,
|
"text-search-ada-doc-001": 10,
|
||||||
"text-moderation-stable": 0.1,
|
"text-moderation-stable": 0.1,
|
||||||
"text-moderation-latest": 0.1,
|
"text-moderation-latest": 0.1,
|
||||||
"dall-e": 8,
|
"dall-e": 8,
|
||||||
"claude-instant-1": 0.815, // $1.63 / 1M tokens
|
"claude-instant-1": 0.815, // $1.63 / 1M tokens
|
||||||
"claude-2": 5.51, // $11.02 / 1M tokens
|
"claude-2": 5.51, // $11.02 / 1M tokens
|
||||||
"ERNIE-Bot": 0.8572, // ¥0.012 / 1k tokens
|
"ERNIE-Bot": 0.8572, // ¥0.012 / 1k tokens
|
||||||
"ERNIE-Bot-turbo": 0.5715, // ¥0.008 / 1k tokens
|
"ERNIE-Bot-turbo": 0.5715, // ¥0.008 / 1k tokens
|
||||||
"Embedding-V1": 0.1429, // ¥0.002 / 1k tokens
|
"Embedding-V1": 0.1429, // ¥0.002 / 1k tokens
|
||||||
"PaLM-2": 1,
|
"PaLM-2": 1,
|
||||||
"chatglm_pro": 0.7143, // ¥0.01 / 1k tokens
|
"chatglm_pro": 0.7143, // ¥0.01 / 1k tokens
|
||||||
"chatglm_std": 0.3572, // ¥0.005 / 1k tokens
|
"chatglm_std": 0.3572, // ¥0.005 / 1k tokens
|
||||||
"chatglm_lite": 0.1429, // ¥0.002 / 1k tokens
|
"chatglm_lite": 0.1429, // ¥0.002 / 1k tokens
|
||||||
"qwen-v1": 0.8572, // TBD: https://help.aliyun.com/document_detail/2399482.html?spm=a2c4g.2399482.0.0.1ad347feilAgag
|
"qwen-v1": 0.8572, // ¥0.012 / 1k tokens
|
||||||
"qwen-plus-v1": 0.5715, // Same as above
|
"qwen-plus-v1": 1, // ¥0.014 / 1k tokens
|
||||||
"SparkDesk": 0.8572, // TBD
|
"text-embedding-v1": 0.05, // ¥0.0007 / 1k tokens
|
||||||
|
"SparkDesk": 1.2858, // ¥0.018 / 1k tokens
|
||||||
|
"360GPT_S2_V9": 0.8572, // ¥0.012 / 1k tokens
|
||||||
|
"embedding-bert-512-v1": 0.0715, // ¥0.001 / 1k tokens
|
||||||
|
"embedding_s1_v1": 0.0715, // ¥0.001 / 1k tokens
|
||||||
|
"semantic_similarity_s1_v1": 0.0715, // ¥0.001 / 1k tokens
|
||||||
|
"360GPT_S2_V9.4": 0.8572, // ¥0.012 / 1k tokens
|
||||||
}
|
}
|
||||||
|
|
||||||
func ModelRatio2JSONString() string {
|
func ModelRatio2JSONString() string {
|
||||||
|
@ -14,7 +14,7 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
func testChannel(channel *model.Channel, request ChatRequest) (error, *OpenAIError) {
|
func testChannel(channel *model.Channel, request ChatRequest) (err error, openaiErr *OpenAIError) {
|
||||||
switch channel.Type {
|
switch channel.Type {
|
||||||
case common.ChannelTypePaLM:
|
case common.ChannelTypePaLM:
|
||||||
fallthrough
|
fallthrough
|
||||||
@ -24,10 +24,19 @@ func testChannel(channel *model.Channel, request ChatRequest) (error, *OpenAIErr
|
|||||||
fallthrough
|
fallthrough
|
||||||
case common.ChannelTypeZhipu:
|
case common.ChannelTypeZhipu:
|
||||||
fallthrough
|
fallthrough
|
||||||
|
case common.ChannelTypeAli:
|
||||||
|
fallthrough
|
||||||
|
case common.ChannelType360:
|
||||||
|
fallthrough
|
||||||
case common.ChannelTypeXunfei:
|
case common.ChannelTypeXunfei:
|
||||||
return errors.New("该渠道类型当前版本不支持测试,请手动测试"), nil
|
return errors.New("该渠道类型当前版本不支持测试,请手动测试"), nil
|
||||||
case common.ChannelTypeAzure:
|
case common.ChannelTypeAzure:
|
||||||
request.Model = "gpt-35-turbo"
|
request.Model = "gpt-35-turbo"
|
||||||
|
defer func() {
|
||||||
|
if err != nil {
|
||||||
|
err = errors.New("请确保已在 Azure 上创建了 gpt-35-turbo 模型,并且 apiVersion 已正确填写!")
|
||||||
|
}
|
||||||
|
}()
|
||||||
default:
|
default:
|
||||||
request.Model = "gpt-3.5-turbo"
|
request.Model = "gpt-3.5-turbo"
|
||||||
}
|
}
|
||||||
|
@ -85,7 +85,7 @@ func AddChannel(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
channel.CreatedTime = common.GetTimestamp()
|
channel.CreatedTime = common.GetTimestamp()
|
||||||
keys := strings.Split(channel.Key, "\n")
|
keys := strings.Split(channel.Key, "\n")
|
||||||
channels := make([]model.Channel, 0)
|
channels := make([]model.Channel, 0, len(keys))
|
||||||
for _, key := range keys {
|
for _, key := range keys {
|
||||||
if key == "" {
|
if key == "" {
|
||||||
continue
|
continue
|
||||||
|
@ -63,6 +63,15 @@ func init() {
|
|||||||
Root: "dall-e",
|
Root: "dall-e",
|
||||||
Parent: nil,
|
Parent: nil,
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
Id: "whisper-1",
|
||||||
|
Object: "model",
|
||||||
|
Created: 1677649963,
|
||||||
|
OwnedBy: "openai",
|
||||||
|
Permission: permission,
|
||||||
|
Root: "whisper-1",
|
||||||
|
Parent: nil,
|
||||||
|
},
|
||||||
{
|
{
|
||||||
Id: "gpt-3.5-turbo",
|
Id: "gpt-3.5-turbo",
|
||||||
Object: "model",
|
Object: "model",
|
||||||
@ -351,6 +360,15 @@ func init() {
|
|||||||
Root: "qwen-plus-v1",
|
Root: "qwen-plus-v1",
|
||||||
Parent: nil,
|
Parent: nil,
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
Id: "text-embedding-v1",
|
||||||
|
Object: "model",
|
||||||
|
Created: 1677649963,
|
||||||
|
OwnedBy: "ali",
|
||||||
|
Permission: permission,
|
||||||
|
Root: "text-embedding-v1",
|
||||||
|
Parent: nil,
|
||||||
|
},
|
||||||
{
|
{
|
||||||
Id: "SparkDesk",
|
Id: "SparkDesk",
|
||||||
Object: "model",
|
Object: "model",
|
||||||
@ -360,6 +378,51 @@ func init() {
|
|||||||
Root: "SparkDesk",
|
Root: "SparkDesk",
|
||||||
Parent: nil,
|
Parent: nil,
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
Id: "360GPT_S2_V9",
|
||||||
|
Object: "model",
|
||||||
|
Created: 1677649963,
|
||||||
|
OwnedBy: "360",
|
||||||
|
Permission: permission,
|
||||||
|
Root: "360GPT_S2_V9",
|
||||||
|
Parent: nil,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Id: "embedding-bert-512-v1",
|
||||||
|
Object: "model",
|
||||||
|
Created: 1677649963,
|
||||||
|
OwnedBy: "360",
|
||||||
|
Permission: permission,
|
||||||
|
Root: "embedding-bert-512-v1",
|
||||||
|
Parent: nil,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Id: "embedding_s1_v1",
|
||||||
|
Object: "model",
|
||||||
|
Created: 1677649963,
|
||||||
|
OwnedBy: "360",
|
||||||
|
Permission: permission,
|
||||||
|
Root: "embedding_s1_v1",
|
||||||
|
Parent: nil,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Id: "semantic_similarity_s1_v1",
|
||||||
|
Object: "model",
|
||||||
|
Created: 1677649963,
|
||||||
|
OwnedBy: "360",
|
||||||
|
Permission: permission,
|
||||||
|
Root: "semantic_similarity_s1_v1",
|
||||||
|
Parent: nil,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Id: "360GPT_S2_V9.4",
|
||||||
|
Object: "model",
|
||||||
|
Created: 1677649963,
|
||||||
|
OwnedBy: "360",
|
||||||
|
Permission: permission,
|
||||||
|
Root: "360GPT_S2_V9.4",
|
||||||
|
Parent: nil,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
openAIModelsMap = make(map[string]OpenAIModels)
|
openAIModelsMap = make(map[string]OpenAIModels)
|
||||||
for _, model := range openAIModels {
|
for _, model := range openAIModels {
|
||||||
|
220
controller/relay-aiproxy.go
Normal file
220
controller/relay-aiproxy.go
Normal file
@ -0,0 +1,220 @@
|
|||||||
|
package controller
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bufio"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"one-api/common"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
// https://docs.aiproxy.io/dev/library#使用已经定制好的知识库进行对话问答
|
||||||
|
|
||||||
|
type AIProxyLibraryRequest struct {
|
||||||
|
Model string `json:"model"`
|
||||||
|
Query string `json:"query"`
|
||||||
|
LibraryId string `json:"libraryId"`
|
||||||
|
Stream bool `json:"stream"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type AIProxyLibraryError struct {
|
||||||
|
ErrCode int `json:"errCode"`
|
||||||
|
Message string `json:"message"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type AIProxyLibraryDocument struct {
|
||||||
|
Title string `json:"title"`
|
||||||
|
URL string `json:"url"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type AIProxyLibraryResponse struct {
|
||||||
|
Success bool `json:"success"`
|
||||||
|
Answer string `json:"answer"`
|
||||||
|
Documents []AIProxyLibraryDocument `json:"documents"`
|
||||||
|
AIProxyLibraryError
|
||||||
|
}
|
||||||
|
|
||||||
|
type AIProxyLibraryStreamResponse struct {
|
||||||
|
Content string `json:"content"`
|
||||||
|
Finish bool `json:"finish"`
|
||||||
|
Model string `json:"model"`
|
||||||
|
Documents []AIProxyLibraryDocument `json:"documents"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func requestOpenAI2AIProxyLibrary(request GeneralOpenAIRequest) *AIProxyLibraryRequest {
|
||||||
|
query := ""
|
||||||
|
if len(request.Messages) != 0 {
|
||||||
|
query = request.Messages[len(request.Messages)-1].Content
|
||||||
|
}
|
||||||
|
return &AIProxyLibraryRequest{
|
||||||
|
Model: request.Model,
|
||||||
|
Stream: request.Stream,
|
||||||
|
Query: query,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func aiProxyDocuments2Markdown(documents []AIProxyLibraryDocument) string {
|
||||||
|
if len(documents) == 0 {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
content := "\n\n参考文档:\n"
|
||||||
|
for i, document := range documents {
|
||||||
|
content += fmt.Sprintf("%d. [%s](%s)\n", i+1, document.Title, document.URL)
|
||||||
|
}
|
||||||
|
return content
|
||||||
|
}
|
||||||
|
|
||||||
|
func responseAIProxyLibrary2OpenAI(response *AIProxyLibraryResponse) *OpenAITextResponse {
|
||||||
|
content := response.Answer + aiProxyDocuments2Markdown(response.Documents)
|
||||||
|
choice := OpenAITextResponseChoice{
|
||||||
|
Index: 0,
|
||||||
|
Message: Message{
|
||||||
|
Role: "assistant",
|
||||||
|
Content: content,
|
||||||
|
},
|
||||||
|
FinishReason: "stop",
|
||||||
|
}
|
||||||
|
fullTextResponse := OpenAITextResponse{
|
||||||
|
Id: common.GetUUID(),
|
||||||
|
Object: "chat.completion",
|
||||||
|
Created: common.GetTimestamp(),
|
||||||
|
Choices: []OpenAITextResponseChoice{choice},
|
||||||
|
}
|
||||||
|
return &fullTextResponse
|
||||||
|
}
|
||||||
|
|
||||||
|
func documentsAIProxyLibrary(documents []AIProxyLibraryDocument) *ChatCompletionsStreamResponse {
|
||||||
|
var choice ChatCompletionsStreamResponseChoice
|
||||||
|
choice.Delta.Content = aiProxyDocuments2Markdown(documents)
|
||||||
|
choice.FinishReason = &stopFinishReason
|
||||||
|
return &ChatCompletionsStreamResponse{
|
||||||
|
Id: common.GetUUID(),
|
||||||
|
Object: "chat.completion.chunk",
|
||||||
|
Created: common.GetTimestamp(),
|
||||||
|
Model: "",
|
||||||
|
Choices: []ChatCompletionsStreamResponseChoice{choice},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func streamResponseAIProxyLibrary2OpenAI(response *AIProxyLibraryStreamResponse) *ChatCompletionsStreamResponse {
|
||||||
|
var choice ChatCompletionsStreamResponseChoice
|
||||||
|
choice.Delta.Content = response.Content
|
||||||
|
return &ChatCompletionsStreamResponse{
|
||||||
|
Id: common.GetUUID(),
|
||||||
|
Object: "chat.completion.chunk",
|
||||||
|
Created: common.GetTimestamp(),
|
||||||
|
Model: response.Model,
|
||||||
|
Choices: []ChatCompletionsStreamResponseChoice{choice},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func aiProxyLibraryStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
|
||||||
|
var usage Usage
|
||||||
|
scanner := bufio.NewScanner(resp.Body)
|
||||||
|
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
|
||||||
|
if atEOF && len(data) == 0 {
|
||||||
|
return 0, nil, nil
|
||||||
|
}
|
||||||
|
if i := strings.Index(string(data), "\n"); i >= 0 {
|
||||||
|
return i + 1, data[0:i], nil
|
||||||
|
}
|
||||||
|
if atEOF {
|
||||||
|
return len(data), data, nil
|
||||||
|
}
|
||||||
|
return 0, nil, nil
|
||||||
|
})
|
||||||
|
dataChan := make(chan string)
|
||||||
|
stopChan := make(chan bool)
|
||||||
|
go func() {
|
||||||
|
for scanner.Scan() {
|
||||||
|
data := scanner.Text()
|
||||||
|
if len(data) < 5 { // ignore blank line or wrong format
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if data[:5] != "data:" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
data = data[5:]
|
||||||
|
dataChan <- data
|
||||||
|
}
|
||||||
|
stopChan <- true
|
||||||
|
}()
|
||||||
|
setEventStreamHeaders(c)
|
||||||
|
var documents []AIProxyLibraryDocument
|
||||||
|
c.Stream(func(w io.Writer) bool {
|
||||||
|
select {
|
||||||
|
case data := <-dataChan:
|
||||||
|
var AIProxyLibraryResponse AIProxyLibraryStreamResponse
|
||||||
|
err := json.Unmarshal([]byte(data), &AIProxyLibraryResponse)
|
||||||
|
if err != nil {
|
||||||
|
common.SysError("error unmarshalling stream response: " + err.Error())
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if len(AIProxyLibraryResponse.Documents) != 0 {
|
||||||
|
documents = AIProxyLibraryResponse.Documents
|
||||||
|
}
|
||||||
|
response := streamResponseAIProxyLibrary2OpenAI(&AIProxyLibraryResponse)
|
||||||
|
jsonResponse, err := json.Marshal(response)
|
||||||
|
if err != nil {
|
||||||
|
common.SysError("error marshalling stream response: " + err.Error())
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
|
||||||
|
return true
|
||||||
|
case <-stopChan:
|
||||||
|
response := documentsAIProxyLibrary(documents)
|
||||||
|
jsonResponse, err := json.Marshal(response)
|
||||||
|
if err != nil {
|
||||||
|
common.SysError("error marshalling stream response: " + err.Error())
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
|
||||||
|
c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
})
|
||||||
|
err := resp.Body.Close()
|
||||||
|
if err != nil {
|
||||||
|
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||||
|
}
|
||||||
|
return nil, &usage
|
||||||
|
}
|
||||||
|
|
||||||
|
func aiProxyLibraryHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
|
||||||
|
var AIProxyLibraryResponse AIProxyLibraryResponse
|
||||||
|
responseBody, err := io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
||||||
|
}
|
||||||
|
err = resp.Body.Close()
|
||||||
|
if err != nil {
|
||||||
|
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||||
|
}
|
||||||
|
err = json.Unmarshal(responseBody, &AIProxyLibraryResponse)
|
||||||
|
if err != nil {
|
||||||
|
return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
||||||
|
}
|
||||||
|
if AIProxyLibraryResponse.ErrCode != 0 {
|
||||||
|
return &OpenAIErrorWithStatusCode{
|
||||||
|
OpenAIError: OpenAIError{
|
||||||
|
Message: AIProxyLibraryResponse.Message,
|
||||||
|
Type: strconv.Itoa(AIProxyLibraryResponse.ErrCode),
|
||||||
|
Code: AIProxyLibraryResponse.ErrCode,
|
||||||
|
},
|
||||||
|
StatusCode: resp.StatusCode,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
fullTextResponse := responseAIProxyLibrary2OpenAI(&AIProxyLibraryResponse)
|
||||||
|
jsonResponse, err := json.Marshal(fullTextResponse)
|
||||||
|
if err != nil {
|
||||||
|
return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
||||||
|
}
|
||||||
|
c.Writer.Header().Set("Content-Type", "application/json")
|
||||||
|
c.Writer.WriteHeader(resp.StatusCode)
|
||||||
|
_, err = c.Writer.Write(jsonResponse)
|
||||||
|
return nil, &fullTextResponse.Usage
|
||||||
|
}
|
@ -35,6 +35,29 @@ type AliChatRequest struct {
|
|||||||
Parameters AliParameters `json:"parameters,omitempty"`
|
Parameters AliParameters `json:"parameters,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type AliEmbeddingRequest struct {
|
||||||
|
Model string `json:"model"`
|
||||||
|
Input struct {
|
||||||
|
Texts []string `json:"texts"`
|
||||||
|
} `json:"input"`
|
||||||
|
Parameters *struct {
|
||||||
|
TextType string `json:"text_type,omitempty"`
|
||||||
|
} `json:"parameters,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type AliEmbedding struct {
|
||||||
|
Embedding []float64 `json:"embedding"`
|
||||||
|
TextIndex int `json:"text_index"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type AliEmbeddingResponse struct {
|
||||||
|
Output struct {
|
||||||
|
Embeddings []AliEmbedding `json:"embeddings"`
|
||||||
|
} `json:"output"`
|
||||||
|
Usage AliUsage `json:"usage"`
|
||||||
|
AliError
|
||||||
|
}
|
||||||
|
|
||||||
type AliError struct {
|
type AliError struct {
|
||||||
Code string `json:"code"`
|
Code string `json:"code"`
|
||||||
Message string `json:"message"`
|
Message string `json:"message"`
|
||||||
@ -44,6 +67,7 @@ type AliError struct {
|
|||||||
type AliUsage struct {
|
type AliUsage struct {
|
||||||
InputTokens int `json:"input_tokens"`
|
InputTokens int `json:"input_tokens"`
|
||||||
OutputTokens int `json:"output_tokens"`
|
OutputTokens int `json:"output_tokens"`
|
||||||
|
TotalTokens int `json:"total_tokens"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type AliOutput struct {
|
type AliOutput struct {
|
||||||
@ -95,6 +119,70 @@ func requestOpenAI2Ali(request GeneralOpenAIRequest) *AliChatRequest {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func embeddingRequestOpenAI2Ali(request GeneralOpenAIRequest) *AliEmbeddingRequest {
|
||||||
|
return &AliEmbeddingRequest{
|
||||||
|
Model: "text-embedding-v1",
|
||||||
|
Input: struct {
|
||||||
|
Texts []string `json:"texts"`
|
||||||
|
}{
|
||||||
|
Texts: request.ParseInput(),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func aliEmbeddingHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
|
||||||
|
var aliResponse AliEmbeddingResponse
|
||||||
|
err := json.NewDecoder(resp.Body).Decode(&aliResponse)
|
||||||
|
if err != nil {
|
||||||
|
return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
err = resp.Body.Close()
|
||||||
|
if err != nil {
|
||||||
|
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if aliResponse.Code != "" {
|
||||||
|
return &OpenAIErrorWithStatusCode{
|
||||||
|
OpenAIError: OpenAIError{
|
||||||
|
Message: aliResponse.Message,
|
||||||
|
Type: aliResponse.Code,
|
||||||
|
Param: aliResponse.RequestId,
|
||||||
|
Code: aliResponse.Code,
|
||||||
|
},
|
||||||
|
StatusCode: resp.StatusCode,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
fullTextResponse := embeddingResponseAli2OpenAI(&aliResponse)
|
||||||
|
jsonResponse, err := json.Marshal(fullTextResponse)
|
||||||
|
if err != nil {
|
||||||
|
return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
||||||
|
}
|
||||||
|
c.Writer.Header().Set("Content-Type", "application/json")
|
||||||
|
c.Writer.WriteHeader(resp.StatusCode)
|
||||||
|
_, err = c.Writer.Write(jsonResponse)
|
||||||
|
return nil, &fullTextResponse.Usage
|
||||||
|
}
|
||||||
|
|
||||||
|
func embeddingResponseAli2OpenAI(response *AliEmbeddingResponse) *OpenAIEmbeddingResponse {
|
||||||
|
openAIEmbeddingResponse := OpenAIEmbeddingResponse{
|
||||||
|
Object: "list",
|
||||||
|
Data: make([]OpenAIEmbeddingResponseItem, 0, len(response.Output.Embeddings)),
|
||||||
|
Model: "text-embedding-v1",
|
||||||
|
Usage: Usage{TotalTokens: response.Usage.TotalTokens},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, item := range response.Output.Embeddings {
|
||||||
|
openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, OpenAIEmbeddingResponseItem{
|
||||||
|
Object: `embedding`,
|
||||||
|
Index: item.TextIndex,
|
||||||
|
Embedding: item.Embedding,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
return &openAIEmbeddingResponse
|
||||||
|
}
|
||||||
|
|
||||||
func responseAli2OpenAI(response *AliChatResponse) *OpenAITextResponse {
|
func responseAli2OpenAI(response *AliChatResponse) *OpenAITextResponse {
|
||||||
choice := OpenAITextResponseChoice{
|
choice := OpenAITextResponseChoice{
|
||||||
Index: 0,
|
Index: 0,
|
||||||
|
147
controller/relay-audio.go
Normal file
147
controller/relay-audio.go
Normal file
@ -0,0 +1,147 @@
|
|||||||
|
package controller
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"one-api/common"
|
||||||
|
"one-api/model"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
)
|
||||||
|
|
||||||
|
func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
|
||||||
|
audioModel := "whisper-1"
|
||||||
|
|
||||||
|
tokenId := c.GetInt("token_id")
|
||||||
|
channelType := c.GetInt("channel")
|
||||||
|
userId := c.GetInt("id")
|
||||||
|
group := c.GetString("group")
|
||||||
|
|
||||||
|
preConsumedTokens := common.PreConsumedQuota
|
||||||
|
modelRatio := common.GetModelRatio(audioModel)
|
||||||
|
groupRatio := common.GetGroupRatio(group)
|
||||||
|
ratio := modelRatio * groupRatio
|
||||||
|
preConsumedQuota := int(float64(preConsumedTokens) * ratio)
|
||||||
|
userQuota, err := model.CacheGetUserQuota(userId)
|
||||||
|
if err != nil {
|
||||||
|
return errorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
err = model.CacheDecreaseUserQuota(userId, preConsumedQuota)
|
||||||
|
if err != nil {
|
||||||
|
return errorWrapper(err, "decrease_user_quota_failed", http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
if userQuota > 100*preConsumedQuota {
|
||||||
|
// in this case, we do not pre-consume quota
|
||||||
|
// because the user has enough quota
|
||||||
|
preConsumedQuota = 0
|
||||||
|
}
|
||||||
|
if preConsumedQuota > 0 {
|
||||||
|
err := model.PreConsumeTokenQuota(tokenId, preConsumedQuota)
|
||||||
|
if err != nil {
|
||||||
|
return errorWrapper(err, "pre_consume_token_quota_failed", http.StatusForbidden)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// map model name
|
||||||
|
modelMapping := c.GetString("model_mapping")
|
||||||
|
if modelMapping != "" {
|
||||||
|
modelMap := make(map[string]string)
|
||||||
|
err := json.Unmarshal([]byte(modelMapping), &modelMap)
|
||||||
|
if err != nil {
|
||||||
|
return errorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
if modelMap[audioModel] != "" {
|
||||||
|
audioModel = modelMap[audioModel]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
baseURL := common.ChannelBaseURLs[channelType]
|
||||||
|
requestURL := c.Request.URL.String()
|
||||||
|
|
||||||
|
if c.GetString("base_url") != "" {
|
||||||
|
baseURL = c.GetString("base_url")
|
||||||
|
}
|
||||||
|
|
||||||
|
fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL)
|
||||||
|
requestBody := c.Request.Body
|
||||||
|
|
||||||
|
req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
|
||||||
|
if err != nil {
|
||||||
|
return errorWrapper(err, "new_request_failed", http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
req.Header.Set("Authorization", c.Request.Header.Get("Authorization"))
|
||||||
|
req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
|
||||||
|
req.Header.Set("Accept", c.Request.Header.Get("Accept"))
|
||||||
|
|
||||||
|
resp, err := httpClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return errorWrapper(err, "do_request_failed", http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = req.Body.Close()
|
||||||
|
if err != nil {
|
||||||
|
return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
err = c.Request.Body.Close()
|
||||||
|
if err != nil {
|
||||||
|
return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
var audioResponse AudioResponse
|
||||||
|
|
||||||
|
defer func() {
|
||||||
|
go func() {
|
||||||
|
quota := countTokenText(audioResponse.Text, audioModel)
|
||||||
|
quotaDelta := quota - preConsumedQuota
|
||||||
|
err := model.PostConsumeTokenQuota(tokenId, quotaDelta)
|
||||||
|
if err != nil {
|
||||||
|
common.SysError("error consuming token remain quota: " + err.Error())
|
||||||
|
}
|
||||||
|
err = model.CacheUpdateUserQuota(userId)
|
||||||
|
if err != nil {
|
||||||
|
common.SysError("error update user quota cache: " + err.Error())
|
||||||
|
}
|
||||||
|
if quota != 0 {
|
||||||
|
tokenName := c.GetString("token_name")
|
||||||
|
logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
|
||||||
|
model.RecordConsumeLog(userId, 0, 0, audioModel, tokenName, quota, logContent)
|
||||||
|
model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
|
||||||
|
channelId := c.GetInt("channel_id")
|
||||||
|
model.UpdateChannelUsedQuota(channelId, quota)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}()
|
||||||
|
|
||||||
|
responseBody, err := io.ReadAll(resp.Body)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
err = resp.Body.Close()
|
||||||
|
if err != nil {
|
||||||
|
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
err = json.Unmarshal(responseBody, &audioResponse)
|
||||||
|
if err != nil {
|
||||||
|
return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
|
||||||
|
resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
|
||||||
|
|
||||||
|
for k, v := range resp.Header {
|
||||||
|
c.Writer.Header().Set(k, v[0])
|
||||||
|
}
|
||||||
|
c.Writer.WriteHeader(resp.StatusCode)
|
||||||
|
|
||||||
|
_, err = io.Copy(c.Writer, resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
return errorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
err = resp.Body.Close()
|
||||||
|
if err != nil {
|
||||||
|
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
@ -144,20 +144,9 @@ func streamResponseBaidu2OpenAI(baiduResponse *BaiduChatStreamResponse) *ChatCom
|
|||||||
}
|
}
|
||||||
|
|
||||||
func embeddingRequestOpenAI2Baidu(request GeneralOpenAIRequest) *BaiduEmbeddingRequest {
|
func embeddingRequestOpenAI2Baidu(request GeneralOpenAIRequest) *BaiduEmbeddingRequest {
|
||||||
baiduEmbeddingRequest := BaiduEmbeddingRequest{
|
return &BaiduEmbeddingRequest{
|
||||||
Input: nil,
|
Input: request.ParseInput(),
|
||||||
}
|
}
|
||||||
switch request.Input.(type) {
|
|
||||||
case string:
|
|
||||||
baiduEmbeddingRequest.Input = []string{request.Input.(string)}
|
|
||||||
case []any:
|
|
||||||
for _, item := range request.Input.([]any) {
|
|
||||||
if str, ok := item.(string); ok {
|
|
||||||
baiduEmbeddingRequest.Input = append(baiduEmbeddingRequest.Input, str)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return &baiduEmbeddingRequest
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func embeddingResponseBaidu2OpenAI(response *BaiduEmbeddingResponse) *OpenAIEmbeddingResponse {
|
func embeddingResponseBaidu2OpenAI(response *BaiduEmbeddingResponse) *OpenAIEmbeddingResponse {
|
||||||
|
@ -22,6 +22,7 @@ const (
|
|||||||
APITypeZhipu
|
APITypeZhipu
|
||||||
APITypeAli
|
APITypeAli
|
||||||
APITypeXunfei
|
APITypeXunfei
|
||||||
|
APITypeAIProxyLibrary
|
||||||
)
|
)
|
||||||
|
|
||||||
var httpClient *http.Client
|
var httpClient *http.Client
|
||||||
@ -104,6 +105,8 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
|
|||||||
apiType = APITypeAli
|
apiType = APITypeAli
|
||||||
case common.ChannelTypeXunfei:
|
case common.ChannelTypeXunfei:
|
||||||
apiType = APITypeXunfei
|
apiType = APITypeXunfei
|
||||||
|
case common.ChannelTypeAIProxyLibrary:
|
||||||
|
apiType = APITypeAIProxyLibrary
|
||||||
}
|
}
|
||||||
|
|
||||||
baseURL := common.ChannelBaseURLs[channelType]
|
baseURL := common.ChannelBaseURLs[channelType]
|
||||||
@ -172,6 +175,11 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
|
|||||||
fullRequestURL = fmt.Sprintf("https://open.bigmodel.cn/api/paas/v3/model-api/%s/%s", textRequest.Model, method)
|
fullRequestURL = fmt.Sprintf("https://open.bigmodel.cn/api/paas/v3/model-api/%s/%s", textRequest.Model, method)
|
||||||
case APITypeAli:
|
case APITypeAli:
|
||||||
fullRequestURL = "https://dashscope.aliyuncs.com/api/v1/services/aigc/text-generation/generation"
|
fullRequestURL = "https://dashscope.aliyuncs.com/api/v1/services/aigc/text-generation/generation"
|
||||||
|
if relayMode == RelayModeEmbeddings {
|
||||||
|
fullRequestURL = "https://dashscope.aliyuncs.com/api/v1/services/embeddings/text-embedding/text-embedding"
|
||||||
|
}
|
||||||
|
case APITypeAIProxyLibrary:
|
||||||
|
fullRequestURL = fmt.Sprintf("%s/api/library/ask", baseURL)
|
||||||
}
|
}
|
||||||
var promptTokens int
|
var promptTokens int
|
||||||
var completionTokens int
|
var completionTokens int
|
||||||
@ -258,8 +266,24 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
|
|||||||
}
|
}
|
||||||
requestBody = bytes.NewBuffer(jsonStr)
|
requestBody = bytes.NewBuffer(jsonStr)
|
||||||
case APITypeAli:
|
case APITypeAli:
|
||||||
aliRequest := requestOpenAI2Ali(textRequest)
|
var jsonStr []byte
|
||||||
jsonStr, err := json.Marshal(aliRequest)
|
var err error
|
||||||
|
switch relayMode {
|
||||||
|
case RelayModeEmbeddings:
|
||||||
|
aliEmbeddingRequest := embeddingRequestOpenAI2Ali(textRequest)
|
||||||
|
jsonStr, err = json.Marshal(aliEmbeddingRequest)
|
||||||
|
default:
|
||||||
|
aliRequest := requestOpenAI2Ali(textRequest)
|
||||||
|
jsonStr, err = json.Marshal(aliRequest)
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
requestBody = bytes.NewBuffer(jsonStr)
|
||||||
|
case APITypeAIProxyLibrary:
|
||||||
|
aiProxyLibraryRequest := requestOpenAI2AIProxyLibrary(textRequest)
|
||||||
|
aiProxyLibraryRequest.LibraryId = c.GetString("library_id")
|
||||||
|
jsonStr, err := json.Marshal(aiProxyLibraryRequest)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
|
return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
|
||||||
}
|
}
|
||||||
@ -287,6 +311,10 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
|
|||||||
req.Header.Set("api-key", apiKey)
|
req.Header.Set("api-key", apiKey)
|
||||||
} else {
|
} else {
|
||||||
req.Header.Set("Authorization", c.Request.Header.Get("Authorization"))
|
req.Header.Set("Authorization", c.Request.Header.Get("Authorization"))
|
||||||
|
if channelType == common.ChannelTypeOpenRouter {
|
||||||
|
req.Header.Set("HTTP-Referer", "https://github.com/songquanpeng/one-api")
|
||||||
|
req.Header.Set("X-Title", "One API")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
case APITypeClaude:
|
case APITypeClaude:
|
||||||
req.Header.Set("x-api-key", apiKey)
|
req.Header.Set("x-api-key", apiKey)
|
||||||
@ -303,6 +331,8 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
|
|||||||
if textRequest.Stream {
|
if textRequest.Stream {
|
||||||
req.Header.Set("X-DashScope-SSE", "enable")
|
req.Header.Set("X-DashScope-SSE", "enable")
|
||||||
}
|
}
|
||||||
|
default:
|
||||||
|
req.Header.Set("Authorization", "Bearer "+apiKey)
|
||||||
}
|
}
|
||||||
req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
|
req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
|
||||||
req.Header.Set("Accept", c.Request.Header.Get("Accept"))
|
req.Header.Set("Accept", c.Request.Header.Get("Accept"))
|
||||||
@ -365,7 +395,6 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
|
|||||||
logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
|
logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
|
||||||
model.RecordConsumeLog(userId, promptTokens, completionTokens, textRequest.Model, tokenName, quota, logContent, tokenId)
|
model.RecordConsumeLog(userId, promptTokens, completionTokens, textRequest.Model, tokenName, quota, logContent, tokenId)
|
||||||
model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
|
model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
|
||||||
|
|
||||||
model.UpdateChannelUsedQuota(channelId, quota)
|
model.UpdateChannelUsedQuota(channelId, quota)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -491,7 +520,14 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
|
|||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
} else {
|
} else {
|
||||||
err, usage := aliHandler(c, resp)
|
var err *OpenAIErrorWithStatusCode
|
||||||
|
var usage *Usage
|
||||||
|
switch relayMode {
|
||||||
|
case RelayModeEmbeddings:
|
||||||
|
err, usage = aliEmbeddingHandler(c, resp)
|
||||||
|
default:
|
||||||
|
err, usage = aliHandler(c, resp)
|
||||||
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -519,6 +555,26 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
|
|||||||
} else {
|
} else {
|
||||||
return errorWrapper(errors.New("xunfei api does not support non-stream mode"), "invalid_api_type", http.StatusBadRequest)
|
return errorWrapper(errors.New("xunfei api does not support non-stream mode"), "invalid_api_type", http.StatusBadRequest)
|
||||||
}
|
}
|
||||||
|
case APITypeAIProxyLibrary:
|
||||||
|
if isStream {
|
||||||
|
err, usage := aiProxyLibraryStreamHandler(c, resp)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if usage != nil {
|
||||||
|
textResponse.Usage = *usage
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
} else {
|
||||||
|
err, usage := aiProxyLibraryHandler(c, resp)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if usage != nil {
|
||||||
|
textResponse.Usage = *usage
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
default:
|
default:
|
||||||
return errorWrapper(errors.New("unknown api type"), "unknown_api_type", http.StatusInternalServerError)
|
return errorWrapper(errors.New("unknown api type"), "unknown_api_type", http.StatusInternalServerError)
|
||||||
}
|
}
|
||||||
|
@ -15,6 +15,24 @@ var stopFinishReason = "stop"
|
|||||||
|
|
||||||
var tokenEncoderMap = map[string]*tiktoken.Tiktoken{}
|
var tokenEncoderMap = map[string]*tiktoken.Tiktoken{}
|
||||||
|
|
||||||
|
func InitTokenEncoders() {
|
||||||
|
common.SysLog("initializing token encoders")
|
||||||
|
fallbackTokenEncoder, err := tiktoken.EncodingForModel("gpt-3.5-turbo")
|
||||||
|
if err != nil {
|
||||||
|
common.FatalLog(fmt.Sprintf("failed to get fallback token encoder: %s", err.Error()))
|
||||||
|
}
|
||||||
|
for model, _ := range common.ModelRatio {
|
||||||
|
tokenEncoder, err := tiktoken.EncodingForModel(model)
|
||||||
|
if err != nil {
|
||||||
|
common.SysError(fmt.Sprintf("using fallback encoder for model %s", model))
|
||||||
|
tokenEncoderMap[model] = fallbackTokenEncoder
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
tokenEncoderMap[model] = tokenEncoder
|
||||||
|
}
|
||||||
|
common.SysLog("token encoders initialized")
|
||||||
|
}
|
||||||
|
|
||||||
func getTokenEncoder(model string) *tiktoken.Tiktoken {
|
func getTokenEncoder(model string) *tiktoken.Tiktoken {
|
||||||
if tokenEncoder, ok := tokenEncoderMap[model]; ok {
|
if tokenEncoder, ok := tokenEncoderMap[model]; ok {
|
||||||
return tokenEncoder
|
return tokenEncoder
|
||||||
|
@ -2,7 +2,6 @@ package controller
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"log"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
"strconv"
|
"strconv"
|
||||||
@ -29,6 +28,7 @@ const (
|
|||||||
RelayModeMidjourneyChange
|
RelayModeMidjourneyChange
|
||||||
RelayModeMidjourneyNotify
|
RelayModeMidjourneyNotify
|
||||||
RelayModeMidjourneyTaskFetch
|
RelayModeMidjourneyTaskFetch
|
||||||
|
RelayModeAudio
|
||||||
)
|
)
|
||||||
|
|
||||||
// https://platform.openai.com/docs/api-reference/chat
|
// https://platform.openai.com/docs/api-reference/chat
|
||||||
@ -45,6 +45,26 @@ type GeneralOpenAIRequest struct {
|
|||||||
Input any `json:"input,omitempty"`
|
Input any `json:"input,omitempty"`
|
||||||
Instruction string `json:"instruction,omitempty"`
|
Instruction string `json:"instruction,omitempty"`
|
||||||
Size string `json:"size,omitempty"`
|
Size string `json:"size,omitempty"`
|
||||||
|
Functions any `json:"functions,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r GeneralOpenAIRequest) ParseInput() []string {
|
||||||
|
if r.Input == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
var input []string
|
||||||
|
switch r.Input.(type) {
|
||||||
|
case string:
|
||||||
|
input = []string{r.Input.(string)}
|
||||||
|
case []any:
|
||||||
|
input = make([]string, 0, len(r.Input.([]any)))
|
||||||
|
for _, item := range r.Input.([]any) {
|
||||||
|
if str, ok := item.(string); ok {
|
||||||
|
input = append(input, str)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return input
|
||||||
}
|
}
|
||||||
|
|
||||||
type ChatRequest struct {
|
type ChatRequest struct {
|
||||||
@ -67,6 +87,10 @@ type ImageRequest struct {
|
|||||||
Size string `json:"size"`
|
Size string `json:"size"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type AudioResponse struct {
|
||||||
|
Text string `json:"text,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
type Usage struct {
|
type Usage struct {
|
||||||
PromptTokens int `json:"prompt_tokens"`
|
PromptTokens int `json:"prompt_tokens"`
|
||||||
CompletionTokens int `json:"completion_tokens"`
|
CompletionTokens int `json:"completion_tokens"`
|
||||||
@ -147,23 +171,6 @@ type CompletionsStreamResponse struct {
|
|||||||
} `json:"choices"`
|
} `json:"choices"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type MidjourneyRequest struct {
|
|
||||||
Prompt string `json:"prompt"`
|
|
||||||
NotifyHook string `json:"notifyHook"`
|
|
||||||
Action string `json:"action"`
|
|
||||||
Index int `json:"index"`
|
|
||||||
State string `json:"state"`
|
|
||||||
TaskId string `json:"taskId"`
|
|
||||||
Base64Array []string `json:"base64Array"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type MidjourneyResponse struct {
|
|
||||||
Code int `json:"code"`
|
|
||||||
Description string `json:"description"`
|
|
||||||
Properties interface{} `json:"properties"`
|
|
||||||
Result string `json:"result"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func Relay(c *gin.Context) {
|
func Relay(c *gin.Context) {
|
||||||
relayMode := RelayModeUnknown
|
relayMode := RelayModeUnknown
|
||||||
if strings.HasPrefix(c.Request.URL.Path, "/v1/chat/completions") {
|
if strings.HasPrefix(c.Request.URL.Path, "/v1/chat/completions") {
|
||||||
|
@ -523,5 +523,6 @@
|
|||||||
"按照如下格式输入:": "Enter in the following format:",
|
"按照如下格式输入:": "Enter in the following format:",
|
||||||
"模型版本": "Model version",
|
"模型版本": "Model version",
|
||||||
"请输入星火大模型版本,注意是接口地址中的版本号,例如:v2.1": "Please enter the version of the Starfire model, note that it is the version number in the interface address, for example: v2.1",
|
"请输入星火大模型版本,注意是接口地址中的版本号,例如:v2.1": "Please enter the version of the Starfire model, note that it is the version number in the interface address, for example: v2.1",
|
||||||
"点击查看": "click to view"
|
"点击查看": "click to view",
|
||||||
|
"请确保已在 Azure 上创建了 gpt-35-turbo 模型,并且 apiVersion 已正确填写!": "Please make sure that the gpt-35-turbo model has been created on Azure, and the apiVersion has been filled in correctly!"
|
||||||
}
|
}
|
||||||
|
6
main.go
6
main.go
@ -78,6 +78,12 @@ func main() {
|
|||||||
go controller.AutomaticallyTestChannels(frequency)
|
go controller.AutomaticallyTestChannels(frequency)
|
||||||
}
|
}
|
||||||
go controller.UpdateMidjourneyTask()
|
go controller.UpdateMidjourneyTask()
|
||||||
|
if os.Getenv("BATCH_UPDATE_ENABLED") == "true" {
|
||||||
|
common.BatchUpdateEnabled = true
|
||||||
|
common.SysLog("batch update enabled with interval " + strconv.Itoa(common.BatchUpdateInterval) + "s")
|
||||||
|
model.InitBatchUpdater()
|
||||||
|
}
|
||||||
|
controller.InitTokenEncoders()
|
||||||
|
|
||||||
// Initialize HTTP server
|
// Initialize HTTP server
|
||||||
server := gin.Default()
|
server := gin.Default()
|
||||||
|
@ -109,7 +109,18 @@ func TokenAuth() func(c *gin.Context) {
|
|||||||
c.Abort()
|
c.Abort()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if !model.CacheIsUserEnabled(token.UserId) {
|
userEnabled, err := model.IsUserEnabled(token.UserId)
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{
|
||||||
|
"error": gin.H{
|
||||||
|
"message": err.Error(),
|
||||||
|
"type": "one_api_error",
|
||||||
|
},
|
||||||
|
})
|
||||||
|
c.Abort()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if !userEnabled {
|
||||||
c.JSON(http.StatusForbidden, gin.H{
|
c.JSON(http.StatusForbidden, gin.H{
|
||||||
"error": gin.H{
|
"error": gin.H{
|
||||||
"message": "用户已被封禁",
|
"message": "用户已被封禁",
|
||||||
|
@ -2,7 +2,6 @@ package middleware
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"log"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
"one-api/model"
|
"one-api/model"
|
||||||
@ -22,7 +21,6 @@ func Distribute() func(c *gin.Context) {
|
|||||||
userGroup, _ := model.CacheGetUserGroup(userId)
|
userGroup, _ := model.CacheGetUserGroup(userId)
|
||||||
c.Set("group", userGroup)
|
c.Set("group", userGroup)
|
||||||
var channel *model.Channel
|
var channel *model.Channel
|
||||||
var err error
|
|
||||||
channelId, ok := c.Get("channelId")
|
channelId, ok := c.Get("channelId")
|
||||||
if ok {
|
if ok {
|
||||||
id, err := strconv.Atoi(channelId.(string))
|
id, err := strconv.Atoi(channelId.(string))
|
||||||
@ -58,7 +56,6 @@ func Distribute() func(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
|
|
||||||
// Select a channel for the user
|
// Select a channel for the user
|
||||||
var modelRequest ModelRequest
|
var modelRequest ModelRequest
|
||||||
if strings.HasPrefix(c.Request.URL.Path, "/mj") {
|
if strings.HasPrefix(c.Request.URL.Path, "/mj") {
|
||||||
@ -79,7 +76,6 @@ func Distribute() func(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") {
|
if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") {
|
||||||
if modelRequest.Model == "" {
|
if modelRequest.Model == "" {
|
||||||
modelRequest.Model = "text-moderation-stable"
|
modelRequest.Model = "text-moderation-stable"
|
||||||
@ -95,6 +91,11 @@ func Distribute() func(c *gin.Context) {
|
|||||||
modelRequest.Model = "dall-e"
|
modelRequest.Model = "dall-e"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if strings.HasPrefix(c.Request.URL.Path, "/v1/audio") {
|
||||||
|
if modelRequest.Model == "" {
|
||||||
|
modelRequest.Model = "whisper-1"
|
||||||
|
}
|
||||||
|
}
|
||||||
channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, modelRequest.Model)
|
channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, modelRequest.Model)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", userGroup, modelRequest.Model)
|
message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", userGroup, modelRequest.Model)
|
||||||
@ -118,8 +119,13 @@ func Distribute() func(c *gin.Context) {
|
|||||||
c.Set("model_mapping", channel.ModelMapping)
|
c.Set("model_mapping", channel.ModelMapping)
|
||||||
c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key))
|
c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key))
|
||||||
c.Set("base_url", channel.BaseURL)
|
c.Set("base_url", channel.BaseURL)
|
||||||
if channel.Type == common.ChannelTypeAzure || channel.Type == common.ChannelTypeXunfei {
|
switch channel.Type {
|
||||||
|
case common.ChannelTypeAzure:
|
||||||
c.Set("api_version", channel.Other)
|
c.Set("api_version", channel.Other)
|
||||||
|
case common.ChannelTypeXunfei:
|
||||||
|
c.Set("api_version", channel.Other)
|
||||||
|
case common.ChannelTypeAIProxyLibrary:
|
||||||
|
c.Set("library_id", channel.Other)
|
||||||
}
|
}
|
||||||
c.Next()
|
c.Next()
|
||||||
}
|
}
|
||||||
|
@ -103,23 +103,28 @@ func CacheDecreaseUserQuota(id int, quota int) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func CacheIsUserEnabled(userId int) bool {
|
func CacheIsUserEnabled(userId int) (bool, error) {
|
||||||
if !common.RedisEnabled {
|
if !common.RedisEnabled {
|
||||||
return IsUserEnabled(userId)
|
return IsUserEnabled(userId)
|
||||||
}
|
}
|
||||||
enabled, err := common.RedisGet(fmt.Sprintf("user_enabled:%d", userId))
|
enabled, err := common.RedisGet(fmt.Sprintf("user_enabled:%d", userId))
|
||||||
if err != nil {
|
if err == nil {
|
||||||
status := common.UserStatusDisabled
|
return enabled == "1", nil
|
||||||
if IsUserEnabled(userId) {
|
|
||||||
status = common.UserStatusEnabled
|
|
||||||
}
|
|
||||||
enabled = fmt.Sprintf("%d", status)
|
|
||||||
err = common.RedisSet(fmt.Sprintf("user_enabled:%d", userId), enabled, time.Duration(UserId2StatusCacheSeconds)*time.Second)
|
|
||||||
if err != nil {
|
|
||||||
common.SysError("Redis set user enabled error: " + err.Error())
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
return enabled == "1"
|
|
||||||
|
userEnabled, err := IsUserEnabled(userId)
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
enabled = "0"
|
||||||
|
if userEnabled {
|
||||||
|
enabled = "1"
|
||||||
|
}
|
||||||
|
err = common.RedisSet(fmt.Sprintf("user_enabled:%d", userId), enabled, time.Duration(UserId2StatusCacheSeconds)*time.Second)
|
||||||
|
if err != nil {
|
||||||
|
common.SysError("Redis set user enabled error: " + err.Error())
|
||||||
|
}
|
||||||
|
return userEnabled, err
|
||||||
}
|
}
|
||||||
|
|
||||||
var group2model2channels map[string]map[string][]*Channel
|
var group2model2channels map[string]map[string][]*Channel
|
||||||
|
@ -141,6 +141,14 @@ func UpdateChannelStatusById(id int, status int) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func UpdateChannelUsedQuota(id int, quota int) {
|
func UpdateChannelUsedQuota(id int, quota int) {
|
||||||
|
if common.BatchUpdateEnabled {
|
||||||
|
addNewRecord(BatchUpdateTypeChannelUsedQuota, id, quota)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
updateChannelUsedQuota(id, quota)
|
||||||
|
}
|
||||||
|
|
||||||
|
func updateChannelUsedQuota(id int, quota int) {
|
||||||
err := DB.Model(&Channel{}).Where("id = ?", id).Update("used_quota", gorm.Expr("used_quota + ?", quota)).Error
|
err := DB.Model(&Channel{}).Where("id = ?", id).Update("used_quota", gorm.Expr("used_quota + ?", quota)).Error
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.SysError("failed to update channel used quota: " + err.Error())
|
common.SysError("failed to update channel used quota: " + err.Error())
|
||||||
|
@ -39,32 +39,35 @@ func ValidateUserToken(key string) (token *Token, err error) {
|
|||||||
}
|
}
|
||||||
token, err = CacheGetTokenByKey(key)
|
token, err = CacheGetTokenByKey(key)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
|
if token.Status == common.TokenStatusExhausted {
|
||||||
|
return nil, errors.New("该令牌额度已用尽")
|
||||||
|
} else if token.Status == common.TokenStatusExpired {
|
||||||
|
return nil, errors.New("该令牌已过期")
|
||||||
|
}
|
||||||
if token.Status != common.TokenStatusEnabled {
|
if token.Status != common.TokenStatusEnabled {
|
||||||
return nil, errors.New("该令牌状态不可用")
|
return nil, errors.New("该令牌状态不可用")
|
||||||
}
|
}
|
||||||
if token.ExpiredTime != -1 && token.ExpiredTime < common.GetTimestamp() {
|
if token.ExpiredTime != -1 && token.ExpiredTime < common.GetTimestamp() {
|
||||||
token.Status = common.TokenStatusExpired
|
if !common.RedisEnabled {
|
||||||
err := token.SelectUpdate()
|
token.Status = common.TokenStatusExpired
|
||||||
if err != nil {
|
err := token.SelectUpdate()
|
||||||
common.SysError("failed to update token status" + err.Error())
|
if err != nil {
|
||||||
|
common.SysError("failed to update token status" + err.Error())
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return nil, errors.New("该令牌已过期")
|
return nil, errors.New("该令牌已过期")
|
||||||
}
|
}
|
||||||
if !token.UnlimitedQuota && token.RemainQuota <= 0 {
|
if !token.UnlimitedQuota && token.RemainQuota <= 0 {
|
||||||
token.Status = common.TokenStatusExhausted
|
if !common.RedisEnabled {
|
||||||
err := token.SelectUpdate()
|
// in this case, we can make sure the token is exhausted
|
||||||
if err != nil {
|
token.Status = common.TokenStatusExhausted
|
||||||
common.SysError("failed to update token status" + err.Error())
|
err := token.SelectUpdate()
|
||||||
|
if err != nil {
|
||||||
|
common.SysError("failed to update token status" + err.Error())
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return nil, errors.New("该令牌额度已用尽")
|
return nil, errors.New("该令牌额度已用尽")
|
||||||
}
|
}
|
||||||
go func() {
|
|
||||||
token.AccessedTime = common.GetTimestamp()
|
|
||||||
err := token.SelectUpdate()
|
|
||||||
if err != nil {
|
|
||||||
common.SysError("failed to update token" + err.Error())
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
return token, nil
|
return token, nil
|
||||||
}
|
}
|
||||||
return nil, errors.New("无效的令牌")
|
return nil, errors.New("无效的令牌")
|
||||||
@ -131,10 +134,19 @@ func IncreaseTokenQuota(id int, quota int) (err error) {
|
|||||||
if quota < 0 {
|
if quota < 0 {
|
||||||
return errors.New("quota 不能为负数!")
|
return errors.New("quota 不能为负数!")
|
||||||
}
|
}
|
||||||
|
if common.BatchUpdateEnabled {
|
||||||
|
addNewRecord(BatchUpdateTypeTokenQuota, id, quota)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return increaseTokenQuota(id, quota)
|
||||||
|
}
|
||||||
|
|
||||||
|
func increaseTokenQuota(id int, quota int) (err error) {
|
||||||
err = DB.Model(&Token{}).Where("id = ?", id).Updates(
|
err = DB.Model(&Token{}).Where("id = ?", id).Updates(
|
||||||
map[string]interface{}{
|
map[string]interface{}{
|
||||||
"remain_quota": gorm.Expr("remain_quota + ?", quota),
|
"remain_quota": gorm.Expr("remain_quota + ?", quota),
|
||||||
"used_quota": gorm.Expr("used_quota - ?", quota),
|
"used_quota": gorm.Expr("used_quota - ?", quota),
|
||||||
|
"accessed_time": common.GetTimestamp(),
|
||||||
},
|
},
|
||||||
).Error
|
).Error
|
||||||
return err
|
return err
|
||||||
@ -144,10 +156,19 @@ func DecreaseTokenQuota(id int, quota int) (err error) {
|
|||||||
if quota < 0 {
|
if quota < 0 {
|
||||||
return errors.New("quota 不能为负数!")
|
return errors.New("quota 不能为负数!")
|
||||||
}
|
}
|
||||||
|
if common.BatchUpdateEnabled {
|
||||||
|
addNewRecord(BatchUpdateTypeTokenQuota, id, -quota)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return decreaseTokenQuota(id, quota)
|
||||||
|
}
|
||||||
|
|
||||||
|
func decreaseTokenQuota(id int, quota int) (err error) {
|
||||||
err = DB.Model(&Token{}).Where("id = ?", id).Updates(
|
err = DB.Model(&Token{}).Where("id = ?", id).Updates(
|
||||||
map[string]interface{}{
|
map[string]interface{}{
|
||||||
"remain_quota": gorm.Expr("remain_quota - ?", quota),
|
"remain_quota": gorm.Expr("remain_quota - ?", quota),
|
||||||
"used_quota": gorm.Expr("used_quota + ?", quota),
|
"used_quota": gorm.Expr("used_quota + ?", quota),
|
||||||
|
"accessed_time": common.GetTimestamp(),
|
||||||
},
|
},
|
||||||
).Error
|
).Error
|
||||||
return err
|
return err
|
||||||
|
@ -235,17 +235,16 @@ func IsAdmin(userId int) bool {
|
|||||||
return user.Role >= common.RoleAdminUser
|
return user.Role >= common.RoleAdminUser
|
||||||
}
|
}
|
||||||
|
|
||||||
func IsUserEnabled(userId int) bool {
|
func IsUserEnabled(userId int) (bool, error) {
|
||||||
if userId == 0 {
|
if userId == 0 {
|
||||||
return false
|
return false, errors.New("user id is empty")
|
||||||
}
|
}
|
||||||
var user User
|
var user User
|
||||||
err := DB.Where("id = ?", userId).Select("status").Find(&user).Error
|
err := DB.Where("id = ?", userId).Select("status").Find(&user).Error
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.SysError("no such user " + err.Error())
|
return false, err
|
||||||
return false
|
|
||||||
}
|
}
|
||||||
return user.Status == common.UserStatusEnabled
|
return user.Status == common.UserStatusEnabled, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func ValidateAccessToken(token string) (user *User) {
|
func ValidateAccessToken(token string) (user *User) {
|
||||||
@ -284,6 +283,14 @@ func IncreaseUserQuota(id int, quota int) (err error) {
|
|||||||
if quota < 0 {
|
if quota < 0 {
|
||||||
return errors.New("quota 不能为负数!")
|
return errors.New("quota 不能为负数!")
|
||||||
}
|
}
|
||||||
|
if common.BatchUpdateEnabled {
|
||||||
|
addNewRecord(BatchUpdateTypeUserQuota, id, quota)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return increaseUserQuota(id, quota)
|
||||||
|
}
|
||||||
|
|
||||||
|
func increaseUserQuota(id int, quota int) (err error) {
|
||||||
err = DB.Model(&User{}).Where("id = ?", id).Update("quota", gorm.Expr("quota + ?", quota)).Error
|
err = DB.Model(&User{}).Where("id = ?", id).Update("quota", gorm.Expr("quota + ?", quota)).Error
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -292,6 +299,14 @@ func DecreaseUserQuota(id int, quota int) (err error) {
|
|||||||
if quota < 0 {
|
if quota < 0 {
|
||||||
return errors.New("quota 不能为负数!")
|
return errors.New("quota 不能为负数!")
|
||||||
}
|
}
|
||||||
|
if common.BatchUpdateEnabled {
|
||||||
|
addNewRecord(BatchUpdateTypeUserQuota, id, -quota)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return decreaseUserQuota(id, quota)
|
||||||
|
}
|
||||||
|
|
||||||
|
func decreaseUserQuota(id int, quota int) (err error) {
|
||||||
err = DB.Model(&User{}).Where("id = ?", id).Update("quota", gorm.Expr("quota - ?", quota)).Error
|
err = DB.Model(&User{}).Where("id = ?", id).Update("quota", gorm.Expr("quota - ?", quota)).Error
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -302,10 +317,18 @@ func GetRootUserEmail() (email string) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func UpdateUserUsedQuotaAndRequestCount(id int, quota int) {
|
func UpdateUserUsedQuotaAndRequestCount(id int, quota int) {
|
||||||
|
if common.BatchUpdateEnabled {
|
||||||
|
addNewRecord(BatchUpdateTypeUsedQuotaAndRequestCount, id, quota)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
updateUserUsedQuotaAndRequestCount(id, quota, 1)
|
||||||
|
}
|
||||||
|
|
||||||
|
func updateUserUsedQuotaAndRequestCount(id int, quota int, count int) {
|
||||||
err := DB.Model(&User{}).Where("id = ?", id).Updates(
|
err := DB.Model(&User{}).Where("id = ?", id).Updates(
|
||||||
map[string]interface{}{
|
map[string]interface{}{
|
||||||
"used_quota": gorm.Expr("used_quota + ?", quota),
|
"used_quota": gorm.Expr("used_quota + ?", quota),
|
||||||
"request_count": gorm.Expr("request_count + ?", 1),
|
"request_count": gorm.Expr("request_count + ?", count),
|
||||||
},
|
},
|
||||||
).Error
|
).Error
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
75
model/utils.go
Normal file
75
model/utils.go
Normal file
@ -0,0 +1,75 @@
|
|||||||
|
package model
|
||||||
|
|
||||||
|
import (
|
||||||
|
"one-api/common"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
const BatchUpdateTypeCount = 4 // if you add a new type, you need to add a new map and a new lock
|
||||||
|
|
||||||
|
const (
|
||||||
|
BatchUpdateTypeUserQuota = iota
|
||||||
|
BatchUpdateTypeTokenQuota
|
||||||
|
BatchUpdateTypeUsedQuotaAndRequestCount
|
||||||
|
BatchUpdateTypeChannelUsedQuota
|
||||||
|
)
|
||||||
|
|
||||||
|
var batchUpdateStores []map[int]int
|
||||||
|
var batchUpdateLocks []sync.Mutex
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
for i := 0; i < BatchUpdateTypeCount; i++ {
|
||||||
|
batchUpdateStores = append(batchUpdateStores, make(map[int]int))
|
||||||
|
batchUpdateLocks = append(batchUpdateLocks, sync.Mutex{})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func InitBatchUpdater() {
|
||||||
|
go func() {
|
||||||
|
for {
|
||||||
|
time.Sleep(time.Duration(common.BatchUpdateInterval) * time.Second)
|
||||||
|
batchUpdate()
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
func addNewRecord(type_ int, id int, value int) {
|
||||||
|
batchUpdateLocks[type_].Lock()
|
||||||
|
defer batchUpdateLocks[type_].Unlock()
|
||||||
|
if _, ok := batchUpdateStores[type_][id]; !ok {
|
||||||
|
batchUpdateStores[type_][id] = value
|
||||||
|
} else {
|
||||||
|
batchUpdateStores[type_][id] += value
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func batchUpdate() {
|
||||||
|
common.SysLog("batch update started")
|
||||||
|
for i := 0; i < BatchUpdateTypeCount; i++ {
|
||||||
|
batchUpdateLocks[i].Lock()
|
||||||
|
store := batchUpdateStores[i]
|
||||||
|
batchUpdateStores[i] = make(map[int]int)
|
||||||
|
batchUpdateLocks[i].Unlock()
|
||||||
|
|
||||||
|
for key, value := range store {
|
||||||
|
switch i {
|
||||||
|
case BatchUpdateTypeUserQuota:
|
||||||
|
err := increaseUserQuota(key, value)
|
||||||
|
if err != nil {
|
||||||
|
common.SysError("failed to batch update user quota: " + err.Error())
|
||||||
|
}
|
||||||
|
case BatchUpdateTypeTokenQuota:
|
||||||
|
err := increaseTokenQuota(key, value)
|
||||||
|
if err != nil {
|
||||||
|
common.SysError("failed to batch update token quota: " + err.Error())
|
||||||
|
}
|
||||||
|
case BatchUpdateTypeUsedQuotaAndRequestCount:
|
||||||
|
updateUserUsedQuotaAndRequestCount(key, value, 1) // TODO: count is incorrect
|
||||||
|
case BatchUpdateTypeChannelUsedQuota:
|
||||||
|
updateChannelUsedQuota(key, value)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
common.SysLog("batch update finished")
|
||||||
|
}
|
@ -26,8 +26,8 @@ func SetRelayRouter(router *gin.Engine) {
|
|||||||
relayV1Router.POST("/images/variations", controller.RelayNotImplemented)
|
relayV1Router.POST("/images/variations", controller.RelayNotImplemented)
|
||||||
relayV1Router.POST("/embeddings", controller.Relay)
|
relayV1Router.POST("/embeddings", controller.Relay)
|
||||||
relayV1Router.POST("/engines/:model/embeddings", controller.Relay)
|
relayV1Router.POST("/engines/:model/embeddings", controller.Relay)
|
||||||
relayV1Router.POST("/audio/transcriptions", controller.RelayNotImplemented)
|
relayV1Router.POST("/audio/transcriptions", controller.Relay)
|
||||||
relayV1Router.POST("/audio/translations", controller.RelayNotImplemented)
|
relayV1Router.POST("/audio/translations", controller.Relay)
|
||||||
relayV1Router.GET("/files", controller.RelayNotImplemented)
|
relayV1Router.GET("/files", controller.RelayNotImplemented)
|
||||||
relayV1Router.POST("/files", controller.RelayNotImplemented)
|
relayV1Router.POST("/files", controller.RelayNotImplemented)
|
||||||
relayV1Router.DELETE("/files/:id", controller.RelayNotImplemented)
|
relayV1Router.DELETE("/files/:id", controller.RelayNotImplemented)
|
||||||
|
@ -324,7 +324,7 @@ const LogsTable = () => {
|
|||||||
.map((log, idx) => {
|
.map((log, idx) => {
|
||||||
if (log.deleted) return <></>;
|
if (log.deleted) return <></>;
|
||||||
return (
|
return (
|
||||||
<Table.Row key={log.created_at}>
|
<Table.Row key={log.id}>
|
||||||
<Table.Cell>{renderTimestamp(log.created_at)}</Table.Cell>
|
<Table.Cell>{renderTimestamp(log.created_at)}</Table.Cell>
|
||||||
{
|
{
|
||||||
isAdminUser && (
|
isAdminUser && (
|
||||||
|
@ -7,7 +7,11 @@ export const CHANNEL_OPTIONS = [
|
|||||||
{ key: 17, text: '阿里通义千问', value: 17, color: 'orange' },
|
{ key: 17, text: '阿里通义千问', value: 17, color: 'orange' },
|
||||||
{ key: 18, text: '讯飞星火认知', value: 18, color: 'blue' },
|
{ key: 18, text: '讯飞星火认知', value: 18, color: 'blue' },
|
||||||
{ key: 16, text: '智谱 ChatGLM', value: 16, color: 'violet' },
|
{ key: 16, text: '智谱 ChatGLM', value: 16, color: 'violet' },
|
||||||
|
{ key: 19, text: '360 智脑', value: 19, color: 'blue' },
|
||||||
{ key: 8, text: '自定义渠道', value: 8, color: 'pink' },
|
{ key: 8, text: '自定义渠道', value: 8, color: 'pink' },
|
||||||
|
{ key: 22, text: '知识库:FastGPT', value: 22, color: 'blue' },
|
||||||
|
{ key: 21, text: '知识库:AI Proxy', value: 21, color: 'purple' },
|
||||||
|
{ key: 20, text: '代理:OpenRouter', value: 20, color: 'black' },
|
||||||
{ key: 2, text: '代理:API2D', value: 2, color: 'blue' },
|
{ key: 2, text: '代理:API2D', value: 2, color: 'blue' },
|
||||||
{ key: 5, text: '代理:OpenAI-SB', value: 5, color: 'brown' },
|
{ key: 5, text: '代理:OpenAI-SB', value: 5, color: 'brown' },
|
||||||
{ key: 7, text: '代理:OhMyGPT', value: 7, color: 'purple' },
|
{ key: 7, text: '代理:OhMyGPT', value: 7, color: 'purple' },
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
import React, { useEffect, useState } from 'react';
|
import React, { useEffect, useState } from 'react';
|
||||||
import { Button, Form, Header, Input, Message, Segment } from 'semantic-ui-react';
|
import { Button, Form, Header, Input, Message, Segment } from 'semantic-ui-react';
|
||||||
import { useParams, useNavigate } from 'react-router-dom';
|
import { useNavigate, useParams } from 'react-router-dom';
|
||||||
import { API, showError, showInfo, showSuccess, verifyJSON } from '../../helpers';
|
import { API, showError, showInfo, showSuccess, verifyJSON } from '../../helpers';
|
||||||
import { CHANNEL_OPTIONS } from '../../constants';
|
import { CHANNEL_OPTIONS } from '../../constants';
|
||||||
|
|
||||||
@ -10,6 +10,20 @@ const MODEL_MAPPING_EXAMPLE = {
|
|||||||
'gpt-4-32k-0314': 'gpt-4-32k'
|
'gpt-4-32k-0314': 'gpt-4-32k'
|
||||||
};
|
};
|
||||||
|
|
||||||
|
function type2secretPrompt(type) {
|
||||||
|
// inputs.type === 15 ? '按照如下格式输入:APIKey|SecretKey' : (inputs.type === 18 ? '按照如下格式输入:APPID|APISecret|APIKey' : '请输入渠道对应的鉴权密钥')
|
||||||
|
switch (type) {
|
||||||
|
case 15:
|
||||||
|
return '按照如下格式输入:APIKey|SecretKey';
|
||||||
|
case 18:
|
||||||
|
return '按照如下格式输入:APPID|APISecret|APIKey';
|
||||||
|
case 22:
|
||||||
|
return '按照如下格式输入:APIKey-AppId,例如:fastgpt-0sp2gtvfdgyi4k30jwlgwf1i-64f335d84283f05518e9e041';
|
||||||
|
default:
|
||||||
|
return '请输入渠道对应的鉴权密钥';
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
const EditChannel = () => {
|
const EditChannel = () => {
|
||||||
const params = useParams();
|
const params = useParams();
|
||||||
const navigate = useNavigate();
|
const navigate = useNavigate();
|
||||||
@ -53,7 +67,7 @@ const EditChannel = () => {
|
|||||||
localModels = ['ERNIE-Bot', 'ERNIE-Bot-turbo', 'Embedding-V1'];
|
localModels = ['ERNIE-Bot', 'ERNIE-Bot-turbo', 'Embedding-V1'];
|
||||||
break;
|
break;
|
||||||
case 17:
|
case 17:
|
||||||
localModels = ['qwen-v1', 'qwen-plus-v1'];
|
localModels = ['qwen-v1', 'qwen-plus-v1', 'text-embedding-v1'];
|
||||||
break;
|
break;
|
||||||
case 16:
|
case 16:
|
||||||
localModels = ['chatglm_pro', 'chatglm_std', 'chatglm_lite'];
|
localModels = ['chatglm_pro', 'chatglm_std', 'chatglm_lite'];
|
||||||
@ -61,6 +75,9 @@ const EditChannel = () => {
|
|||||||
case 18:
|
case 18:
|
||||||
localModels = ['SparkDesk'];
|
localModels = ['SparkDesk'];
|
||||||
break;
|
break;
|
||||||
|
case 19:
|
||||||
|
localModels = ['360GPT_S2_V9', 'embedding-bert-512-v1', 'embedding_s1_v1', 'semantic_similarity_s1_v1', '360GPT_S2_V9.4'];
|
||||||
|
break;
|
||||||
}
|
}
|
||||||
setInputs((inputs) => ({ ...inputs, models: localModels }));
|
setInputs((inputs) => ({ ...inputs, models: localModels }));
|
||||||
}
|
}
|
||||||
@ -190,6 +207,24 @@ const EditChannel = () => {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
const addCustomModel = () => {
|
||||||
|
if (customModel.trim() === '') return;
|
||||||
|
if (inputs.models.includes(customModel)) return;
|
||||||
|
let localModels = [...inputs.models];
|
||||||
|
localModels.push(customModel);
|
||||||
|
let localModelOptions = [];
|
||||||
|
localModelOptions.push({
|
||||||
|
key: customModel,
|
||||||
|
text: customModel,
|
||||||
|
value: customModel
|
||||||
|
});
|
||||||
|
setModelOptions(modelOptions => {
|
||||||
|
return [...modelOptions, ...localModelOptions];
|
||||||
|
});
|
||||||
|
setCustomModel('');
|
||||||
|
handleInputChange(null, { name: 'models', value: localModels });
|
||||||
|
};
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<>
|
<>
|
||||||
<Segment loading={loading}>
|
<Segment loading={loading}>
|
||||||
@ -292,6 +327,20 @@ const EditChannel = () => {
|
|||||||
</Form.Field>
|
</Form.Field>
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
{
|
||||||
|
inputs.type === 21 && (
|
||||||
|
<Form.Field>
|
||||||
|
<Form.Input
|
||||||
|
label='知识库 ID'
|
||||||
|
name='other'
|
||||||
|
placeholder={'请输入知识库 ID,例如:123456'}
|
||||||
|
onChange={handleInputChange}
|
||||||
|
value={inputs.other}
|
||||||
|
autoComplete='new-password'
|
||||||
|
/>
|
||||||
|
</Form.Field>
|
||||||
|
)
|
||||||
|
}
|
||||||
<Form.Field>
|
<Form.Field>
|
||||||
<Form.Dropdown
|
<Form.Dropdown
|
||||||
label='模型'
|
label='模型'
|
||||||
@ -319,29 +368,19 @@ const EditChannel = () => {
|
|||||||
}}>清除所有模型</Button>
|
}}>清除所有模型</Button>
|
||||||
<Input
|
<Input
|
||||||
action={
|
action={
|
||||||
<Button type={'button'} onClick={() => {
|
<Button type={'button'} onClick={addCustomModel}>填入</Button>
|
||||||
if (customModel.trim() === '') return;
|
|
||||||
if (inputs.models.includes(customModel)) return;
|
|
||||||
let localModels = [...inputs.models];
|
|
||||||
localModels.push(customModel);
|
|
||||||
let localModelOptions = [];
|
|
||||||
localModelOptions.push({
|
|
||||||
key: customModel,
|
|
||||||
text: customModel,
|
|
||||||
value: customModel
|
|
||||||
});
|
|
||||||
setModelOptions(modelOptions => {
|
|
||||||
return [...modelOptions, ...localModelOptions];
|
|
||||||
});
|
|
||||||
setCustomModel('');
|
|
||||||
handleInputChange(null, { name: 'models', value: localModels });
|
|
||||||
}}>填入</Button>
|
|
||||||
}
|
}
|
||||||
placeholder='输入自定义模型名称'
|
placeholder='输入自定义模型名称'
|
||||||
value={customModel}
|
value={customModel}
|
||||||
onChange={(e, { value }) => {
|
onChange={(e, { value }) => {
|
||||||
setCustomModel(value);
|
setCustomModel(value);
|
||||||
}}
|
}}
|
||||||
|
onKeyDown={(e) => {
|
||||||
|
if (e.key === 'Enter') {
|
||||||
|
addCustomModel();
|
||||||
|
e.preventDefault();
|
||||||
|
}
|
||||||
|
}}
|
||||||
/>
|
/>
|
||||||
</div>
|
</div>
|
||||||
<Form.Field>
|
<Form.Field>
|
||||||
@ -372,7 +411,7 @@ const EditChannel = () => {
|
|||||||
label='密钥'
|
label='密钥'
|
||||||
name='key'
|
name='key'
|
||||||
required
|
required
|
||||||
placeholder={inputs.type === 15 ? '按照如下格式输入:APIKey|SecretKey' : (inputs.type === 18 ? '按照如下格式输入:APPID|APISecret|APIKey' : '请输入渠道对应的鉴权密钥')}
|
placeholder={type2secretPrompt(inputs.type)}
|
||||||
onChange={handleInputChange}
|
onChange={handleInputChange}
|
||||||
value={inputs.key}
|
value={inputs.key}
|
||||||
autoComplete='new-password'
|
autoComplete='new-password'
|
||||||
@ -390,7 +429,7 @@ const EditChannel = () => {
|
|||||||
)
|
)
|
||||||
}
|
}
|
||||||
{
|
{
|
||||||
inputs.type !== 3 && inputs.type !== 8 && (
|
inputs.type !== 3 && inputs.type !== 8 && inputs.type !== 22 && (
|
||||||
<Form.Field>
|
<Form.Field>
|
||||||
<Form.Input
|
<Form.Input
|
||||||
label='代理'
|
label='代理'
|
||||||
@ -403,6 +442,20 @@ const EditChannel = () => {
|
|||||||
</Form.Field>
|
</Form.Field>
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
{
|
||||||
|
inputs.type === 22 && (
|
||||||
|
<Form.Field>
|
||||||
|
<Form.Input
|
||||||
|
label='私有部署地址'
|
||||||
|
name='base_url'
|
||||||
|
placeholder={'请输入私有部署地址,格式为:https://fastgpt.run/api/openapi'}
|
||||||
|
onChange={handleInputChange}
|
||||||
|
value={inputs.base_url}
|
||||||
|
autoComplete='new-password'
|
||||||
|
/>
|
||||||
|
</Form.Field>
|
||||||
|
)
|
||||||
|
}
|
||||||
<Button onClick={handleCancel}>取消</Button>
|
<Button onClick={handleCancel}>取消</Button>
|
||||||
<Button type={isEdit ? 'button' : 'submit'} positive onClick={submit}>提交</Button>
|
<Button type={isEdit ? 'button' : 'submit'} positive onClick={submit}>提交</Button>
|
||||||
</Form>
|
</Form>
|
||||||
|
Loading…
Reference in New Issue
Block a user