From 07b2fd58d646817e597bd6edceabb1c12e57ed38 Mon Sep 17 00:00:00 2001
From: Buer <42402987+MartialBE@users.noreply.github.com>
Date: Tue, 28 May 2024 01:22:40 +0800
Subject: [PATCH 1/8] feat: berry theme update & bug fix (#1471)
* feat: load channel models from server
* chore: support AWS Claude/Cloudflare/Coze
* fix: Popup message when copying fails
* chore: Optimize tips
---
web/berry/src/constants/ChannelConstants.js | 30 +-
web/berry/src/constants/SnackbarConstants.js | 42 +-
web/berry/src/utils/common.js | 37 ++
.../AuthForms/ResetPasswordForm.js | 43 +-
.../src/views/Channel/component/EditModal.js | 502 +++++++-----------
.../src/views/Channel/component/NameLabel.js | 27 +-
web/berry/src/views/Channel/index.js | 8 +-
web/berry/src/views/Channel/type/Config.js | 196 ++++---
web/berry/src/views/Profile/index.js | 5 +-
.../views/Redemption/component/TableRow.js | 5 +-
.../src/views/Token/component/TableRow.js | 14 +-
.../src/views/Topup/component/InviteCard.js | 8 +-
12 files changed, 448 insertions(+), 469 deletions(-)
diff --git a/web/berry/src/constants/ChannelConstants.js b/web/berry/src/constants/ChannelConstants.js
index e6b0aed5..589ef1fb 100644
--- a/web/berry/src/constants/ChannelConstants.js
+++ b/web/berry/src/constants/ChannelConstants.js
@@ -11,12 +11,18 @@ export const CHANNEL_OPTIONS = {
value: 14,
color: 'primary'
},
- // 33: {
- // key: 33,
- // text: 'AWS Claude',
- // value: 33,
- // color: 'primary'
- // },
+ 33: {
+ key: 33,
+ text: 'AWS Claude',
+ value: 33,
+ color: 'primary'
+ },
+ 37: {
+ key: 37,
+ text: 'Cloudflare',
+ value: 37,
+ color: 'success'
+ },
3: {
key: 3,
text: 'Azure OpenAI',
@@ -119,12 +125,12 @@ export const CHANNEL_OPTIONS = {
value: 32,
color: 'primary'
},
- // 34: {
- // key: 34,
- // text: 'Coze',
- // value: 34,
- // color: 'primary'
- // },
+ 34: {
+ key: 34,
+ text: 'Coze',
+ value: 34,
+ color: 'primary'
+ },
35: {
key: 35,
text: 'Cohere',
diff --git a/web/berry/src/constants/SnackbarConstants.js b/web/berry/src/constants/SnackbarConstants.js
index 19523da1..05f79231 100644
--- a/web/berry/src/constants/SnackbarConstants.js
+++ b/web/berry/src/constants/SnackbarConstants.js
@@ -1,24 +1,56 @@
+import { closeSnackbar } from 'notistack';
+import { IconX } from '@tabler/icons-react';
+import { IconButton } from '@mui/material';
+const action = (snackbarId) => (
+ <>
+ {
+ closeSnackbar(snackbarId);
+ }}
+ >
+
+
+ >
+);
+
export const snackbarConstants = {
Common: {
ERROR: {
variant: 'error',
- autoHideDuration: 5000
+ autoHideDuration: 5000,
+ preventDuplicate: true,
+ action
},
WARNING: {
variant: 'warning',
- autoHideDuration: 10000
+ autoHideDuration: 10000,
+ preventDuplicate: true,
+ action
},
SUCCESS: {
variant: 'success',
- autoHideDuration: 1500
+ autoHideDuration: 1500,
+ preventDuplicate: true,
+ action
},
INFO: {
variant: 'info',
- autoHideDuration: 3000
+ autoHideDuration: 3000,
+ preventDuplicate: true,
+ action
},
NOTICE: {
variant: 'info',
- autoHideDuration: 7000
+ autoHideDuration: 20000,
+ preventDuplicate: true,
+ action
+ },
+ COPY: {
+ variant: 'copy',
+ persist: true,
+ preventDuplicate: true,
+ allowDownload: true,
+ action
}
},
Mobile: {
diff --git a/web/berry/src/utils/common.js b/web/berry/src/utils/common.js
index 947df3bf..d74d032e 100644
--- a/web/berry/src/utils/common.js
+++ b/web/berry/src/utils/common.js
@@ -193,3 +193,40 @@ export function removeTrailingSlash(url) {
return url;
}
}
+
+let channelModels = undefined;
+export async function loadChannelModels() {
+ const res = await API.get('/api/models');
+ const { success, data } = res.data;
+ if (!success) {
+ return;
+ }
+ channelModels = data;
+ localStorage.setItem('channel_models', JSON.stringify(data));
+}
+
+export function getChannelModels(type) {
+ if (channelModels !== undefined && type in channelModels) {
+ return channelModels[type];
+ }
+ let models = localStorage.getItem('channel_models');
+ if (!models) {
+ return [];
+ }
+ channelModels = JSON.parse(models);
+ if (type in channelModels) {
+ return channelModels[type];
+ }
+ return [];
+}
+
+export function copy(text, name = '') {
+ try {
+ navigator.clipboard.writeText(text);
+ } catch (error) {
+ text = `复制${name}失败,请手动复制:
${text}`;
+ enqueueSnackbar(, getSnackbarOptions('COPY'));
+ return;
+ }
+ showSuccess(`复制${name}成功!`);
+}
diff --git a/web/berry/src/views/Authentication/AuthForms/ResetPasswordForm.js b/web/berry/src/views/Authentication/AuthForms/ResetPasswordForm.js
index eaa8dc95..a9f0f9e3 100644
--- a/web/berry/src/views/Authentication/AuthForms/ResetPasswordForm.js
+++ b/web/berry/src/views/Authentication/AuthForms/ResetPasswordForm.js
@@ -1,22 +1,22 @@
-import { useState, useEffect } from "react";
-import { useSearchParams } from "react-router-dom";
+import { useState, useEffect } from 'react';
+import { useSearchParams } from 'react-router-dom';
// material-ui
-import { Button, Stack, Typography, Alert } from "@mui/material";
+import { Button, Stack, Typography, Alert } from '@mui/material';
// assets
-import { showError, showInfo } from "utils/common";
-import { API } from "utils/api";
+import { showError, copy } from 'utils/common';
+import { API } from 'utils/api';
// ===========================|| FIREBASE - REGISTER ||=========================== //
const ResetPasswordForm = () => {
const [searchParams] = useSearchParams();
const [inputs, setInputs] = useState({
- email: "",
- token: "",
+ email: '',
+ token: ''
});
- const [newPassword, setNewPassword] = useState("");
+ const [newPassword, setNewPassword] = useState('');
const submit = async () => {
const res = await API.post(`/api/user/reset`, inputs);
@@ -24,31 +24,25 @@ const ResetPasswordForm = () => {
if (success) {
let password = res.data.data;
setNewPassword(password);
- navigator.clipboard.writeText(password);
- showInfo(`新密码已复制到剪贴板:${password}`);
+ copy(password, '新密码');
} else {
showError(message);
}
};
useEffect(() => {
- let email = searchParams.get("email");
- let token = searchParams.get("token");
+ let email = searchParams.get('email');
+ let token = searchParams.get('token');
setInputs({
token,
- email,
+ email
});
}, []);
return (
-
+
{!inputs.email || !inputs.token ? (
-
+
无效的链接
) : newPassword ? (
@@ -57,14 +51,7 @@ const ResetPasswordForm = () => {
请登录后及时修改密码
) : (
-
-
+
{matchUpMd ? (
-
+
}>
刷新
diff --git a/web/berry/src/views/Channel/type/Config.js b/web/berry/src/views/Channel/type/Config.js
index 7e42ca8d..88e1ea92 100644
--- a/web/berry/src/views/Channel/type/Config.js
+++ b/web/berry/src/views/Channel/type/Config.js
@@ -1,177 +1,209 @@
const defaultConfig = {
input: {
- name: "",
+ name: '',
type: 1,
- key: "",
- base_url: "",
- other: "",
- model_mapping: "",
+ key: '',
+ base_url: '',
+ other: '',
+ model_mapping: '',
models: [],
- groups: ["default"],
+ groups: ['default'],
+ config: {}
},
inputLabel: {
- name: "渠道名称",
- type: "渠道类型",
- base_url: "渠道API地址",
- key: "密钥",
- other: "其他参数",
- models: "模型",
- model_mapping: "模型映射关系",
- groups: "用户组",
+ name: '渠道名称',
+ type: '渠道类型',
+ base_url: '渠道API地址',
+ key: '密钥',
+ other: '其他参数',
+ models: '模型',
+ model_mapping: '模型映射关系',
+ groups: '用户组',
+ config: null
},
prompt: {
- type: "请选择渠道类型",
- name: "请为渠道命名",
- base_url: "可空,请输入中转API地址,例如通过cloudflare中转",
- key: "请输入渠道对应的鉴权密钥",
- other: "",
- models: "请选择该渠道所支持的模型",
+ type: '请选择渠道类型',
+ name: '请为渠道命名',
+ base_url: '可空,请输入中转API地址,例如通过cloudflare中转',
+ key: '请输入渠道对应的鉴权密钥',
+ other: '',
+ models: '请选择该渠道所支持的模型',
model_mapping:
'请输入要修改的模型映射关系,格式为:api请求模型ID:实际转发给渠道的模型ID,使用JSON数组表示,例如:{"gpt-3.5": "gpt-35"}',
- groups: "请选择该渠道所支持的用户组",
+ groups: '请选择该渠道所支持的用户组',
+ config: null
},
- modelGroup: "openai",
+ modelGroup: 'openai'
};
const typeConfig = {
3: {
inputLabel: {
- base_url: "AZURE_OPENAI_ENDPOINT",
- other: "默认 API 版本",
+ base_url: 'AZURE_OPENAI_ENDPOINT',
+ other: '默认 API 版本'
},
prompt: {
- base_url: "请填写AZURE_OPENAI_ENDPOINT",
- other: "请输入默认API版本,例如:2024-03-01-preview",
- },
+ base_url: '请填写AZURE_OPENAI_ENDPOINT',
+ other: '请输入默认API版本,例如:2024-03-01-preview'
+ }
},
11: {
input: {
- models: ["PaLM-2"],
+ models: ['PaLM-2']
},
- modelGroup: "google palm",
+ modelGroup: 'google palm'
},
14: {
input: {
- models: ["claude-instant-1", "claude-2", "claude-2.0", "claude-2.1"],
+ models: ['claude-instant-1', 'claude-2', 'claude-2.0', 'claude-2.1']
},
- modelGroup: "anthropic",
+ modelGroup: 'anthropic'
},
15: {
input: {
- models: ["ERNIE-Bot", "ERNIE-Bot-turbo", "ERNIE-Bot-4", "Embedding-V1"],
+ models: ['ERNIE-Bot', 'ERNIE-Bot-turbo', 'ERNIE-Bot-4', 'Embedding-V1']
},
prompt: {
- key: "按照如下格式输入:APIKey|SecretKey",
+ key: '按照如下格式输入:APIKey|SecretKey'
},
- modelGroup: "baidu",
+ modelGroup: 'baidu'
},
16: {
input: {
- models: ["glm-4", "glm-4v", "glm-3-turbo", "chatglm_turbo", "chatglm_pro", "chatglm_std", "chatglm_lite"],
+ models: ['glm-4', 'glm-4v', 'glm-3-turbo', 'chatglm_turbo', 'chatglm_pro', 'chatglm_std', 'chatglm_lite']
},
- modelGroup: "zhipu",
+ modelGroup: 'zhipu'
},
17: {
inputLabel: {
- other: "插件参数",
+ other: '插件参数'
},
input: {
- models: [
- "qwen-turbo",
- "qwen-plus",
- "qwen-max",
- "qwen-max-longcontext",
- "text-embedding-v1",
- ],
+ models: ['qwen-turbo', 'qwen-plus', 'qwen-max', 'qwen-max-longcontext', 'text-embedding-v1']
},
prompt: {
- other: "请输入插件参数,即 X-DashScope-Plugin 请求头的取值",
+ other: '请输入插件参数,即 X-DashScope-Plugin 请求头的取值'
},
- modelGroup: "ali",
+ modelGroup: 'ali'
},
18: {
inputLabel: {
- other: "版本号",
+ other: '版本号'
},
input: {
- models: [
- "SparkDesk",
- 'SparkDesk-v1.1',
- 'SparkDesk-v2.1',
- 'SparkDesk-v3.1',
- 'SparkDesk-v3.5'
- ],
+ models: ['SparkDesk', 'SparkDesk-v1.1', 'SparkDesk-v2.1', 'SparkDesk-v3.1', 'SparkDesk-v3.5']
},
prompt: {
- key: "按照如下格式输入:APPID|APISecret|APIKey",
- other: "请输入版本号,例如:v3.1",
+ key: '按照如下格式输入:APPID|APISecret|APIKey',
+ other: '请输入版本号,例如:v3.1'
},
- modelGroup: "xunfei",
+ modelGroup: 'xunfei'
},
19: {
input: {
- models: [
- "360GPT_S2_V9",
- "embedding-bert-512-v1",
- "embedding_s1_v1",
- "semantic_similarity_s1_v1",
- ],
+ models: ['360GPT_S2_V9', 'embedding-bert-512-v1', 'embedding_s1_v1', 'semantic_similarity_s1_v1']
},
- modelGroup: "360",
+ modelGroup: '360'
},
22: {
prompt: {
- key: "按照如下格式输入:APIKey-AppId,例如:fastgpt-0sp2gtvfdgyi4k30jwlgwf1i-64f335d84283f05518e9e041",
- },
+ key: '按照如下格式输入:APIKey-AppId,例如:fastgpt-0sp2gtvfdgyi4k30jwlgwf1i-64f335d84283f05518e9e041'
+ }
},
23: {
input: {
- models: ["hunyuan"],
+ models: ['hunyuan']
},
prompt: {
- key: "按照如下格式输入:AppId|SecretId|SecretKey",
+ key: '按照如下格式输入:AppId|SecretId|SecretKey'
},
- modelGroup: "tencent",
+ modelGroup: 'tencent'
},
24: {
inputLabel: {
- other: "版本号",
+ other: '版本号'
},
input: {
- models: ["gemini-pro"],
+ models: ['gemini-pro']
},
prompt: {
- other: "请输入版本号,例如:v1",
+ other: '请输入版本号,例如:v1'
},
- modelGroup: "google gemini",
+ modelGroup: 'google gemini'
},
25: {
input: {
- models: ['moonshot-v1-8k', 'moonshot-v1-32k', 'moonshot-v1-128k'],
+ models: ['moonshot-v1-8k', 'moonshot-v1-32k', 'moonshot-v1-128k']
},
- modelGroup: "moonshot",
+ modelGroup: 'moonshot'
},
26: {
input: {
- models: ['Baichuan2-Turbo', 'Baichuan2-Turbo-192k', 'Baichuan-Text-Embedding'],
+ models: ['Baichuan2-Turbo', 'Baichuan2-Turbo-192k', 'Baichuan-Text-Embedding']
},
- modelGroup: "baichuan",
+ modelGroup: 'baichuan'
},
27: {
input: {
- models: ['abab5.5s-chat', 'abab5.5-chat', 'abab6-chat'],
+ models: ['abab5.5s-chat', 'abab5.5-chat', 'abab6-chat']
},
- modelGroup: "minimax",
+ modelGroup: 'minimax'
},
29: {
- modelGroup: "groq",
+ modelGroup: 'groq'
},
30: {
- modelGroup: "ollama",
+ modelGroup: 'ollama'
},
31: {
- modelGroup: "lingyiwanwu",
+ modelGroup: 'lingyiwanwu'
},
+ 33: {
+ inputLabel: {
+ key: '',
+ config: {
+ region: 'Region',
+ ak: 'Access Key',
+ sk: 'Secret Key'
+ }
+ },
+ prompt: {
+ key: '',
+ config: {
+ region: 'region,e.g. us-west-2',
+ ak: 'AWS IAM Access Key',
+ sk: 'AWS IAM Secret Key'
+ }
+ },
+ modelGroup: 'anthropic'
+ },
+ 37: {
+ inputLabel: {
+ config: {
+ user_id: 'Account ID'
+ }
+ },
+ prompt: {
+ config: {
+ user_id: '请输入 Account ID,例如:d8d7c61dbc334c32d3ced580e4bf42b4'
+ }
+ },
+ modelGroup: 'Cloudflare'
+ },
+ 34: {
+ inputLabel: {
+ config: {
+ user_id: 'User ID'
+ }
+ },
+ prompt: {
+ models: '对于 Coze 而言,模型名称即 Bot ID,你可以添加一个前缀 `bot-`,例如:`bot-123456`',
+ config: {
+ user_id: '生成该密钥的用户 ID'
+ }
+ },
+ modelGroup: 'Coze'
+ }
};
export { defaultConfig, typeConfig };
diff --git a/web/berry/src/views/Profile/index.js b/web/berry/src/views/Profile/index.js
index 483e3141..4705d8af 100644
--- a/web/berry/src/views/Profile/index.js
+++ b/web/berry/src/views/Profile/index.js
@@ -21,7 +21,7 @@ import { IconBrandWechat, IconBrandGithub, IconMail } from '@tabler/icons-react'
import Label from 'ui-component/Label';
import { API } from 'utils/api';
import { showError, showSuccess } from 'utils/common';
-import { onGitHubOAuthClicked, onLarkOAuthClicked } from 'utils/common';
+import { onGitHubOAuthClicked, onLarkOAuthClicked, copy } from 'utils/common';
import * as Yup from 'yup';
import WechatModal from 'views/Authentication/AuthForms/WechatModal';
import { useSelector } from 'react-redux';
@@ -90,8 +90,7 @@ export default function Profile() {
const { success, message, data } = res.data;
if (success) {
setInputs((inputs) => ({ ...inputs, access_token: data }));
- navigator.clipboard.writeText(data);
- showSuccess(`令牌已重置并已复制到剪贴板`);
+ copy(data, '访问令牌');
} else {
showError(message);
}
diff --git a/web/berry/src/views/Redemption/component/TableRow.js b/web/berry/src/views/Redemption/component/TableRow.js
index 68c9a505..380af037 100644
--- a/web/berry/src/views/Redemption/component/TableRow.js
+++ b/web/berry/src/views/Redemption/component/TableRow.js
@@ -18,7 +18,7 @@ import {
import Label from 'ui-component/Label';
import TableSwitch from 'ui-component/Switch';
-import { timestamp2string, renderQuota, showSuccess } from 'utils/common';
+import { timestamp2string, renderQuota, copy } from 'utils/common';
import { IconDotsVertical, IconEdit, IconTrash } from '@tabler/icons-react';
@@ -83,8 +83,7 @@ export default function RedemptionTableRow({ item, manageRedemption, handleOpenM
variant="contained"
color="primary"
onClick={() => {
- navigator.clipboard.writeText(item.key);
- showSuccess('已复制到剪贴板!');
+ copy(item.key, '兑换码');
}}
>
复制
diff --git a/web/berry/src/views/Token/component/TableRow.js b/web/berry/src/views/Token/component/TableRow.js
index 51ab0d4b..6a197e69 100644
--- a/web/berry/src/views/Token/component/TableRow.js
+++ b/web/berry/src/views/Token/component/TableRow.js
@@ -20,7 +20,7 @@ import {
} from '@mui/material';
import TableSwitch from 'ui-component/Switch';
-import { renderQuota, showSuccess, timestamp2string } from 'utils/common';
+import { renderQuota, timestamp2string, copy } from 'utils/common';
import { IconDotsVertical, IconEdit, IconTrash, IconCaretDownFilled } from '@tabler/icons-react';
@@ -141,8 +141,7 @@ export default function TokensTableRow({ item, manageToken, handleOpenModal, set
if (type === 'link') {
window.open(text);
} else {
- navigator.clipboard.writeText(text);
- showSuccess('已复制到剪贴板!');
+ copy(text);
}
handleCloseMenu();
};
@@ -192,7 +191,7 @@ export default function TokensTableRow({ item, manageToken, handleOpenModal, set
id={`switch-${item.id}`}
checked={statusSwitch === 1}
onChange={handleStatus}
- // disabled={statusSwitch !== 1 && statusSwitch !== 2}
+ // disabled={statusSwitch !== 1 && statusSwitch !== 2}
/>
@@ -211,8 +210,7 @@ export default function TokensTableRow({ item, manageToken, handleOpenModal, set
{
- navigator.clipboard.writeText(`sk-${item.key}`);
- showSuccess('已复制到剪贴板!');
+ copy(`sk-${item.key}`);
}}
>
复制
@@ -222,7 +220,9 @@ export default function TokensTableRow({ item, manageToken, handleOpenModal, set
- handleCopy(COPY_OPTIONS[0], 'link')}>聊天
+ handleCopy(COPY_OPTIONS[0], 'link')}>
+ 聊天
+
handleOpenMenu(e, 'link')}>
diff --git a/web/berry/src/views/Topup/component/InviteCard.js b/web/berry/src/views/Topup/component/InviteCard.js
index a95f85e5..73c9670f 100644
--- a/web/berry/src/views/Topup/component/InviteCard.js
+++ b/web/berry/src/views/Topup/component/InviteCard.js
@@ -4,7 +4,7 @@ import SubCard from 'ui-component/cards/SubCard';
import inviteImage from 'assets/images/invite/cwok_casual_19.webp';
import { useState } from 'react';
import { API } from 'utils/api';
-import { showError, showSuccess } from 'utils/common';
+import { showError, copy } from 'utils/common';
const InviteCard = () => {
const theme = useTheme();
@@ -12,8 +12,7 @@ const InviteCard = () => {
const handleInviteUrl = async () => {
if (inviteUl) {
- navigator.clipboard.writeText(inviteUl);
- showSuccess(`邀请链接已复制到剪切板`);
+ copy(inviteUl, '邀请链接');
return;
}
const res = await API.get('/api/user/aff');
@@ -21,8 +20,7 @@ const InviteCard = () => {
if (success) {
let link = `${window.location.origin}/register?aff=${data}`;
setInviteUrl(link);
- navigator.clipboard.writeText(link);
- showSuccess(`邀请链接已复制到剪切板`);
+ copy(link, '邀请链接');
} else {
showError(message);
}
From a9211d66f6717989c634e4de595c84436e2d4ac5 Mon Sep 17 00:00:00 2001
From: Dafei Zhao
Date: Mon, 27 May 2024 13:26:07 -0400
Subject: [PATCH 2/8] fix: fix gpt-4o token encoding (#1446)
---
relay/adaptor/openai/token.go | 6 ++++++
1 file changed, 6 insertions(+)
diff --git a/relay/adaptor/openai/token.go b/relay/adaptor/openai/token.go
index bb9c38a9..ddbfad86 100644
--- a/relay/adaptor/openai/token.go
+++ b/relay/adaptor/openai/token.go
@@ -24,6 +24,10 @@ func InitTokenEncoders() {
logger.FatalLog(fmt.Sprintf("failed to get gpt-3.5-turbo token encoder: %s", err.Error()))
}
defaultTokenEncoder = gpt35TokenEncoder
+ gpt4oTokenEncoder, err := tiktoken.EncodingForModel("gpt-4o")
+ if err != nil {
+ logger.FatalLog(fmt.Sprintf("failed to get gpt-4o token encoder: %s", err.Error()))
+ }
gpt4TokenEncoder, err := tiktoken.EncodingForModel("gpt-4")
if err != nil {
logger.FatalLog(fmt.Sprintf("failed to get gpt-4 token encoder: %s", err.Error()))
@@ -31,6 +35,8 @@ func InitTokenEncoders() {
for model := range billingratio.ModelRatio {
if strings.HasPrefix(model, "gpt-3.5") {
tokenEncoderMap[model] = gpt35TokenEncoder
+ } else if strings.HasPrefix(model, "gpt-4o") {
+ tokenEncoderMap[model] = gpt4oTokenEncoder
} else if strings.HasPrefix(model, "gpt-4") {
tokenEncoderMap[model] = gpt4TokenEncoder
} else {
From fa74ba0eaa4e608f128fb0b4af564f67b32565e2 Mon Sep 17 00:00:00 2001
From: Ghostz <137054651+ye4293@users.noreply.github.com>
Date: Tue, 28 May 2024 01:30:51 +0800
Subject: [PATCH 3/8] chore: print user id when relay error happened (#1447)
* add userid when relay error
* chore: update log format
---------
Co-authored-by: JustSong
---
controller/relay.go | 14 ++++++++------
1 file changed, 8 insertions(+), 6 deletions(-)
diff --git a/controller/relay.go b/controller/relay.go
index aba4cd94..5d8ac690 100644
--- a/controller/relay.go
+++ b/controller/relay.go
@@ -4,6 +4,9 @@ import (
"bytes"
"context"
"fmt"
+ "io"
+ "net/http"
+
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/config"
@@ -16,8 +19,6 @@ import (
"github.com/songquanpeng/one-api/relay/controller"
"github.com/songquanpeng/one-api/relay/model"
"github.com/songquanpeng/one-api/relay/relaymode"
- "io"
- "net/http"
)
// https://platform.openai.com/docs/api-reference/chat
@@ -47,6 +48,7 @@ func Relay(c *gin.Context) {
logger.Debugf(ctx, "request body: %s", string(requestBody))
}
channelId := c.GetInt(ctxkey.ChannelId)
+ userId := c.GetInt("id")
bizErr := relayHelper(c, relayMode)
if bizErr == nil {
monitor.Emit(channelId, true)
@@ -56,7 +58,7 @@ func Relay(c *gin.Context) {
channelName := c.GetString(ctxkey.ChannelName)
group := c.GetString(ctxkey.Group)
originalModel := c.GetString(ctxkey.OriginalModel)
- go processChannelRelayError(ctx, channelId, channelName, bizErr)
+ go processChannelRelayError(ctx, userId, channelId, channelName, bizErr)
requestId := c.GetString(helper.RequestIdKey)
retryTimes := config.RetryTimes
if !shouldRetry(c, bizErr.StatusCode) {
@@ -83,7 +85,7 @@ func Relay(c *gin.Context) {
channelId := c.GetInt(ctxkey.ChannelId)
lastFailedChannelId = channelId
channelName := c.GetString(ctxkey.ChannelName)
- go processChannelRelayError(ctx, channelId, channelName, bizErr)
+ go processChannelRelayError(ctx, userId, channelId, channelName, bizErr)
}
if bizErr != nil {
if bizErr.StatusCode == http.StatusTooManyRequests {
@@ -115,8 +117,8 @@ func shouldRetry(c *gin.Context, statusCode int) bool {
return true
}
-func processChannelRelayError(ctx context.Context, channelId int, channelName string, err *model.ErrorWithStatusCode) {
- logger.Errorf(ctx, "relay error (channel #%d): %s", channelId, err.Message)
+func processChannelRelayError(ctx context.Context, userId int, channelId int, channelName string, err *model.ErrorWithStatusCode) {
+ logger.Errorf(ctx, "relay error (channel id %d, user id: %d): %s", channelId, userId, err.Message)
// https://platform.openai.com/docs/guides/error-codes/api-errors
if monitor.ShouldDisableChannel(&err.Error, err.StatusCode) {
monitor.DisableChannel(channelId, channelName, err.Message)
From 3be28da57b1bc466cd67b037ab38739032af80d6 Mon Sep 17 00:00:00 2001
From: fatwang2 <134143178+fatwang2@users.noreply.github.com>
Date: Tue, 28 May 2024 01:31:08 +0800
Subject: [PATCH 4/8] Update package.json (#1465)
---
web/berry/package.json | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/web/berry/package.json b/web/berry/package.json
index 2edb2355..f8265ef7 100644
--- a/web/berry/package.json
+++ b/web/berry/package.json
@@ -26,7 +26,7 @@
"notistack": "^3.0.1",
"prop-types": "^15.8.1",
"react": "^18.2.0",
- "react-apexcharts": "^1.4.0",
+ "react-apexcharts": "1.4.0",
"react-device-detect": "^2.2.2",
"react-dom": "^18.2.0",
"react-perfect-scrollbar": "^1.5.8",
From 332c8db0b31ef117a5ea090d2875f15111da3a27 Mon Sep 17 00:00:00 2001
From: Mo
Date: Tue, 28 May 2024 01:32:57 +0800
Subject: [PATCH 5/8] fix: add prefixes to image models to solve the problem of
duplicate models (#1469)
* Add prefixes to image models to solve the problem of duplicate models
* Fix the issue that response_format is not set, causing the b64_json parameter to be ignored.
---
relay/billing/ratio/image.go | 5 +++++
relay/controller/image.go | 7 ++++++-
2 files changed, 11 insertions(+), 1 deletion(-)
diff --git a/relay/billing/ratio/image.go b/relay/billing/ratio/image.go
index 5a29cddc..ced0c667 100644
--- a/relay/billing/ratio/image.go
+++ b/relay/billing/ratio/image.go
@@ -49,3 +49,8 @@ var ImagePromptLengthLimitations = map[string]int{
"wanx-v1": 4000,
"cogview-3": 833,
}
+
+var ImageOriginModelName = map[string]string{
+ "ali-stable-diffusion-xl": "stable-diffusion-xl",
+ "ali-stable-diffusion-v1.5": "stable-diffusion-v1.5",
+}
diff --git a/relay/controller/image.go b/relay/controller/image.go
index 6620bef5..691c7c0e 100644
--- a/relay/controller/image.go
+++ b/relay/controller/image.go
@@ -55,6 +55,11 @@ func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus
return openai.ErrorWrapper(err, "get_image_cost_ratio_failed", http.StatusInternalServerError)
}
+ imageModel := imageRequest.Model
+ // Convert the original image model
+ imageRequest.Model, _ = getMappedModelName(imageRequest.Model, billingratio.ImageOriginModelName)
+ c.Set("response_format", imageRequest.ResponseFormat)
+
var requestBody io.Reader
if isModelMapped || meta.ChannelType == channeltype.Azure { // make Azure channel request body
jsonStr, err := json.Marshal(imageRequest)
@@ -89,7 +94,7 @@ func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus
requestBody = bytes.NewBuffer(jsonStr)
}
- modelRatio := billingratio.GetModelRatio(imageRequest.Model)
+ modelRatio := billingratio.GetModelRatio(imageModel)
groupRatio := billingratio.GetGroupRatio(meta.Group)
ratio := modelRatio * groupRatio
userQuota, err := model.CacheGetUserQuota(ctx, meta.UserId)
From b53e00a9b375b44c7c6b54ad2a2e7c4b868823f9 Mon Sep 17 00:00:00 2001
From: carey036 <45777074+carey036@users.noreply.github.com>
Date: Tue, 28 May 2024 01:44:38 +0800
Subject: [PATCH 6/8] feat: generate default token after register (#1401)
* feat: generate default token after register
* chore: use go routine to create default token for new user
---------
Co-authored-by: JustSong
---
controller/user.go | 25 +++++++++++++++++++++++++
1 file changed, 25 insertions(+)
diff --git a/controller/user.go b/controller/user.go
index af90acf6..9ab37b5a 100644
--- a/controller/user.go
+++ b/controller/user.go
@@ -6,6 +6,8 @@ import (
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/ctxkey"
+ "github.com/songquanpeng/one-api/common/helper"
+ "github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/common/random"
"github.com/songquanpeng/one-api/model"
"net/http"
@@ -109,6 +111,7 @@ func Logout(c *gin.Context) {
}
func Register(c *gin.Context) {
+ ctx := c.Request.Context()
if !config.RegisterEnabled {
c.JSON(http.StatusOK, gin.H{
"message": "管理员关闭了新用户注册",
@@ -173,6 +176,28 @@ func Register(c *gin.Context) {
})
return
}
+ go func() {
+ err := user.ValidateAndFill()
+ if err != nil {
+ logger.Errorf(ctx, "user.ValidateAndFill failed: %w", err)
+ return
+ }
+ cleanToken := model.Token{
+ UserId: user.Id,
+ Name: "default",
+ Key: random.GenerateKey(),
+ CreatedTime: helper.GetTimestamp(),
+ AccessedTime: helper.GetTimestamp(),
+ ExpiredTime: -1,
+ RemainQuota: -1,
+ UnlimitedQuota: true,
+ }
+ err = cleanToken.Insert()
+ if err != nil {
+ logger.Errorf(ctx, "cleanToken.Insert failed: %w", err)
+ return
+ }
+ }()
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
From ceea4c6d4a3de6970bab4381316f4cd19c6e7b28 Mon Sep 17 00:00:00 2001
From: JustSong
Date: Wed, 29 May 2024 01:14:00 +0800
Subject: [PATCH 7/8] feat: support user content download proxy & relay proxy
now
---
README.md | 19 ++++++-----
common/client/init.go | 60 +++++++++++++++++++++++++++++++++++
common/config/config.go | 4 +++
common/image/image.go | 5 +--
controller/channel-billing.go | 2 +-
main.go | 2 ++
relay/adaptor/baidu/main.go | 2 +-
relay/adaptor/common.go | 2 +-
relay/client/init.go | 24 --------------
relay/controller/audio.go | 2 +-
10 files changed, 84 insertions(+), 38 deletions(-)
create mode 100644 common/client/init.go
delete mode 100644 relay/client/init.go
diff --git a/README.md b/README.md
index 40f6e4e0..167fe5ba 100644
--- a/README.md
+++ b/README.md
@@ -384,14 +384,17 @@ graph LR
+ `TIKTOKEN_CACHE_DIR`:默认程序启动时会联网下载一些通用的词元的编码,如:`gpt-3.5-turbo`,在一些网络环境不稳定,或者离线情况,可能会导致启动有问题,可以配置此目录缓存数据,可迁移到离线环境。
+ `DATA_GYM_CACHE_DIR`:目前该配置作用与 `TIKTOKEN_CACHE_DIR` 一致,但是优先级没有它高。
17. `RELAY_TIMEOUT`:中继超时设置,单位为秒,默认不设置超时时间。
-18. `SQLITE_BUSY_TIMEOUT`:SQLite 锁等待超时设置,单位为毫秒,默认 `3000`。
-19. `GEMINI_SAFETY_SETTING`:Gemini 的安全设置,默认 `BLOCK_NONE`。
-20. `GEMINI_VERSION`:One API 所使用的 Gemini 版本,默认为 `v1`。
-21. `THEME`:系统的主题设置,默认为 `default`,具体可选值参考[此处](./web/README.md)。
-22. `ENABLE_METRIC`:是否根据请求成功率禁用渠道,默认不开启,可选值为 `true` 和 `false`。
-23. `METRIC_QUEUE_SIZE`:请求成功率统计队列大小,默认为 `10`。
-24. `METRIC_SUCCESS_RATE_THRESHOLD`:请求成功率阈值,默认为 `0.8`。
-25. `INITIAL_ROOT_TOKEN`:如果设置了该值,则在系统首次启动时会自动创建一个值为该环境变量值的 root 用户令牌。
+18. `RELAY_PROXY`:设置后使用该代理来请求 API。
+19. `USER_CONTENT_REQUEST_TIMEOUT`:用户上传内容下载超时时间,单位为秒。
+20. `USER_CONTENT_REQUEST_PROXY`:设置后使用该代理来请求用户上传的内容,例如图片。
+21. `SQLITE_BUSY_TIMEOUT`:SQLite 锁等待超时设置,单位为毫秒,默认 `3000`。
+22. `GEMINI_SAFETY_SETTING`:Gemini 的安全设置,默认 `BLOCK_NONE`。
+23. `GEMINI_VERSION`:One API 所使用的 Gemini 版本,默认为 `v1`。
+24. `THEME`:系统的主题设置,默认为 `default`,具体可选值参考[此处](./web/README.md)。
+25. `ENABLE_METRIC`:是否根据请求成功率禁用渠道,默认不开启,可选值为 `true` 和 `false`。
+26. `METRIC_QUEUE_SIZE`:请求成功率统计队列大小,默认为 `10`。
+27. `METRIC_SUCCESS_RATE_THRESHOLD`:请求成功率阈值,默认为 `0.8`。
+28. `INITIAL_ROOT_TOKEN`:如果设置了该值,则在系统首次启动时会自动创建一个值为该环境变量值的 root 用户令牌。
### 命令行参数
1. `--port `: 指定服务器监听的端口号,默认为 `3000`。
diff --git a/common/client/init.go b/common/client/init.go
new file mode 100644
index 00000000..f803cbf8
--- /dev/null
+++ b/common/client/init.go
@@ -0,0 +1,60 @@
+package client
+
+import (
+ "fmt"
+ "github.com/songquanpeng/one-api/common/config"
+ "github.com/songquanpeng/one-api/common/logger"
+ "net/http"
+ "net/url"
+ "time"
+)
+
+var HTTPClient *http.Client
+var ImpatientHTTPClient *http.Client
+var UserContentRequestHTTPClient *http.Client
+
+func Init() {
+ if config.UserContentRequestProxy != "" {
+ logger.SysLog(fmt.Sprintf("using %s as proxy to fetch user content", config.UserContentRequestProxy))
+ proxyURL, err := url.Parse(config.UserContentRequestProxy)
+ if err != nil {
+ logger.FatalLog(fmt.Sprintf("USER_CONTENT_REQUEST_PROXY set but invalid: %s", config.UserContentRequestProxy))
+ }
+ transport := &http.Transport{
+ Proxy: http.ProxyURL(proxyURL),
+ }
+ UserContentRequestHTTPClient = &http.Client{
+ Transport: transport,
+ Timeout: time.Second * time.Duration(config.UserContentRequestTimeout),
+ }
+ } else {
+ UserContentRequestHTTPClient = &http.Client{}
+ }
+ var transport http.RoundTripper
+ if config.RelayProxy != "" {
+ logger.SysLog(fmt.Sprintf("using %s as api relay proxy", config.RelayProxy))
+ proxyURL, err := url.Parse(config.RelayProxy)
+ if err != nil {
+ logger.FatalLog(fmt.Sprintf("USER_CONTENT_REQUEST_PROXY set but invalid: %s", config.UserContentRequestProxy))
+ }
+ transport = &http.Transport{
+ Proxy: http.ProxyURL(proxyURL),
+ }
+ }
+
+ if config.RelayTimeout == 0 {
+ HTTPClient = &http.Client{
+ Transport: transport,
+ }
+ } else {
+ HTTPClient = &http.Client{
+ Timeout: time.Duration(config.RelayTimeout) * time.Second,
+ Transport: transport,
+ }
+ }
+
+ ImpatientHTTPClient = &http.Client{
+ Timeout: 5 * time.Second,
+ Transport: transport,
+ }
+}
diff --git a/common/config/config.go b/common/config/config.go
index 0864d844..539eeef4 100644
--- a/common/config/config.go
+++ b/common/config/config.go
@@ -144,3 +144,7 @@ var MetricFailChanSize = env.Int("METRIC_FAIL_CHAN_SIZE", 128)
var InitialRootToken = os.Getenv("INITIAL_ROOT_TOKEN")
var GeminiVersion = env.String("GEMINI_VERSION", "v1")
+
+var RelayProxy = env.String("RELAY_PROXY", "")
+var UserContentRequestProxy = env.String("USER_CONTENT_REQUEST_PROXY", "")
+var UserContentRequestTimeout = env.Int("USER_CONTENT_REQUEST_TIMEOUT", 30)
diff --git a/common/image/image.go b/common/image/image.go
index 12f0adff..beebd0c6 100644
--- a/common/image/image.go
+++ b/common/image/image.go
@@ -3,6 +3,7 @@ package image
import (
"bytes"
"encoding/base64"
+ "github.com/songquanpeng/one-api/common/client"
"image"
_ "image/gif"
_ "image/jpeg"
@@ -19,7 +20,7 @@ import (
var dataURLPattern = regexp.MustCompile(`data:image/([^;]+);base64,(.*)`)
func IsImageUrl(url string) (bool, error) {
- resp, err := http.Head(url)
+ resp, err := client.UserContentRequestHTTPClient.Head(url)
if err != nil {
return false, err
}
@@ -34,7 +35,7 @@ func GetImageSizeFromUrl(url string) (width int, height int, err error) {
if !isImage {
return
}
- resp, err := http.Get(url)
+ resp, err := client.UserContentRequestHTTPClient.Get(url)
if err != nil {
return
}
diff --git a/controller/channel-billing.go b/controller/channel-billing.go
index b7ac61fd..53592744 100644
--- a/controller/channel-billing.go
+++ b/controller/channel-billing.go
@@ -4,12 +4,12 @@ import (
"encoding/json"
"errors"
"fmt"
+ "github.com/songquanpeng/one-api/common/client"
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/model"
"github.com/songquanpeng/one-api/monitor"
"github.com/songquanpeng/one-api/relay/channeltype"
- "github.com/songquanpeng/one-api/relay/client"
"io"
"net/http"
"strconv"
diff --git a/main.go b/main.go
index bdcdcd61..eb6f368c 100644
--- a/main.go
+++ b/main.go
@@ -7,6 +7,7 @@ import (
"github.com/gin-contrib/sessions/cookie"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common"
+ "github.com/songquanpeng/one-api/common/client"
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/controller"
@@ -94,6 +95,7 @@ func main() {
logger.SysLog("metric enabled, will disable channel if too much request failed")
}
openai.InitTokenEncoders()
+ client.Init()
// Initialize HTTP server
server := gin.New()
diff --git a/relay/adaptor/baidu/main.go b/relay/adaptor/baidu/main.go
index 6df5ce84..b816e0f4 100644
--- a/relay/adaptor/baidu/main.go
+++ b/relay/adaptor/baidu/main.go
@@ -7,9 +7,9 @@ import (
"fmt"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common"
+ "github.com/songquanpeng/one-api/common/client"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/relay/adaptor/openai"
- "github.com/songquanpeng/one-api/relay/client"
"github.com/songquanpeng/one-api/relay/constant"
"github.com/songquanpeng/one-api/relay/model"
"io"
diff --git a/relay/adaptor/common.go b/relay/adaptor/common.go
index 82a5160e..8953d7a3 100644
--- a/relay/adaptor/common.go
+++ b/relay/adaptor/common.go
@@ -4,7 +4,7 @@ import (
"errors"
"fmt"
"github.com/gin-gonic/gin"
- "github.com/songquanpeng/one-api/relay/client"
+ "github.com/songquanpeng/one-api/common/client"
"github.com/songquanpeng/one-api/relay/meta"
"io"
"net/http"
diff --git a/relay/client/init.go b/relay/client/init.go
deleted file mode 100644
index 4b59cba7..00000000
--- a/relay/client/init.go
+++ /dev/null
@@ -1,24 +0,0 @@
-package client
-
-import (
- "github.com/songquanpeng/one-api/common/config"
- "net/http"
- "time"
-)
-
-var HTTPClient *http.Client
-var ImpatientHTTPClient *http.Client
-
-func init() {
- if config.RelayTimeout == 0 {
- HTTPClient = &http.Client{}
- } else {
- HTTPClient = &http.Client{
- Timeout: time.Duration(config.RelayTimeout) * time.Second,
- }
- }
-
- ImpatientHTTPClient = &http.Client{
- Timeout: 5 * time.Second,
- }
-}
diff --git a/relay/controller/audio.go b/relay/controller/audio.go
index 15e74290..8f9708d0 100644
--- a/relay/controller/audio.go
+++ b/relay/controller/audio.go
@@ -9,6 +9,7 @@ import (
"fmt"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common"
+ "github.com/songquanpeng/one-api/common/client"
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/ctxkey"
"github.com/songquanpeng/one-api/common/logger"
@@ -17,7 +18,6 @@ import (
"github.com/songquanpeng/one-api/relay/billing"
billingratio "github.com/songquanpeng/one-api/relay/billing/ratio"
"github.com/songquanpeng/one-api/relay/channeltype"
- "github.com/songquanpeng/one-api/relay/client"
"github.com/songquanpeng/one-api/relay/meta"
relaymodel "github.com/songquanpeng/one-api/relay/model"
"github.com/songquanpeng/one-api/relay/relaymode"
From 9321427c6ef4d703f4e315c15d334c88359ec8e3 Mon Sep 17 00:00:00 2001
From: Wei Tingjiang
Date: Wed, 29 May 2024 01:17:32 +0800
Subject: [PATCH 8/8] feat: support gemini embeddings
(text-embedding-004,embedding-001) (#1475)
* Refactor Gemini Adaptor to Support Embeddings
* Add new models to ModelList
---
relay/adaptor/gemini/adaptor.go | 26 +++++++++--
relay/adaptor/gemini/constants.go | 2 +-
relay/adaptor/gemini/main.go | 76 +++++++++++++++++++++++++++++++
relay/adaptor/gemini/model.go | 27 +++++++++++
4 files changed, 127 insertions(+), 4 deletions(-)
diff --git a/relay/adaptor/gemini/adaptor.go b/relay/adaptor/gemini/adaptor.go
index a4dcae93..12f48c71 100644
--- a/relay/adaptor/gemini/adaptor.go
+++ b/relay/adaptor/gemini/adaptor.go
@@ -13,6 +13,7 @@ import (
"github.com/songquanpeng/one-api/relay/adaptor/openai"
"github.com/songquanpeng/one-api/relay/meta"
"github.com/songquanpeng/one-api/relay/model"
+ "github.com/songquanpeng/one-api/relay/relaymode"
)
type Adaptor struct {
@@ -24,7 +25,14 @@ func (a *Adaptor) Init(meta *meta.Meta) {
func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) {
version := helper.AssignOrDefault(meta.Config.APIVersion, config.GeminiVersion)
- action := "generateContent"
+ action := ""
+ switch meta.Mode {
+ case relaymode.Embeddings:
+ action = "batchEmbedContents"
+ default:
+ action = "generateContent"
+ }
+
if meta.IsStream {
action = "streamGenerateContent?alt=sse"
}
@@ -41,7 +49,14 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G
if request == nil {
return nil, errors.New("request is nil")
}
- return ConvertRequest(*request), nil
+ switch relayMode {
+ case relaymode.Embeddings:
+ geminiEmbeddingRequest := ConvertEmbeddingRequest(*request)
+ return geminiEmbeddingRequest, nil
+ default:
+ geminiRequest := ConvertRequest(*request)
+ return geminiRequest, nil
+ }
}
func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {
@@ -61,7 +76,12 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Met
err, responseText = StreamHandler(c, resp)
usage = openai.ResponseText2Usage(responseText, meta.ActualModelName, meta.PromptTokens)
} else {
- err, usage = Handler(c, resp, meta.PromptTokens, meta.ActualModelName)
+ switch meta.Mode {
+ case relaymode.Embeddings:
+ err, usage = EmbeddingHandler(c, resp)
+ default:
+ err, usage = Handler(c, resp, meta.PromptTokens, meta.ActualModelName)
+ }
}
return
}
diff --git a/relay/adaptor/gemini/constants.go b/relay/adaptor/gemini/constants.go
index 32e7c240..f65e6bfc 100644
--- a/relay/adaptor/gemini/constants.go
+++ b/relay/adaptor/gemini/constants.go
@@ -4,5 +4,5 @@ package gemini
var ModelList = []string{
"gemini-pro", "gemini-1.0-pro-001", "gemini-1.5-pro",
- "gemini-pro-vision", "gemini-1.0-pro-vision-001",
+ "gemini-pro-vision", "gemini-1.0-pro-vision-001", "embedding-001", "text-embedding-004",
}
diff --git a/relay/adaptor/gemini/main.go b/relay/adaptor/gemini/main.go
index faccc4cb..534b2708 100644
--- a/relay/adaptor/gemini/main.go
+++ b/relay/adaptor/gemini/main.go
@@ -134,6 +134,29 @@ func ConvertRequest(textRequest model.GeneralOpenAIRequest) *ChatRequest {
return &geminiRequest
}
+func ConvertEmbeddingRequest(request model.GeneralOpenAIRequest) *BatchEmbeddingRequest {
+ inputs := request.ParseInput()
+ requests := make([]EmbeddingRequest, len(inputs))
+ model := fmt.Sprintf("models/%s", request.Model)
+
+ for i, input := range inputs {
+ requests[i] = EmbeddingRequest{
+ Model: model,
+ Content: ChatContent{
+ Parts: []Part{
+ {
+ Text: input,
+ },
+ },
+ },
+ }
+ }
+
+ return &BatchEmbeddingRequest{
+ Requests: requests,
+ }
+}
+
type ChatResponse struct {
Candidates []ChatCandidate `json:"candidates"`
PromptFeedback ChatPromptFeedback `json:"promptFeedback"`
@@ -230,6 +253,23 @@ func streamResponseGeminiChat2OpenAI(geminiResponse *ChatResponse) *openai.ChatC
return &response
}
+func embeddingResponseGemini2OpenAI(response *EmbeddingResponse) *openai.EmbeddingResponse {
+ openAIEmbeddingResponse := openai.EmbeddingResponse{
+ Object: "list",
+ Data: make([]openai.EmbeddingResponseItem, 0, len(response.Embeddings)),
+ Model: "gemini-embedding",
+ Usage: model.Usage{TotalTokens: 0},
+ }
+ for _, item := range response.Embeddings {
+ openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, openai.EmbeddingResponseItem{
+ Object: `embedding`,
+ Index: 0,
+ Embedding: item.Values,
+ })
+ }
+ return &openAIEmbeddingResponse
+}
+
func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, string) {
responseText := ""
scanner := bufio.NewScanner(resp.Body)
@@ -337,3 +377,39 @@ func Handler(c *gin.Context, resp *http.Response, promptTokens int, modelName st
_, err = c.Writer.Write(jsonResponse)
return nil, &usage
}
+
+func EmbeddingHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
+ var geminiEmbeddingResponse EmbeddingResponse
+ responseBody, err := io.ReadAll(resp.Body)
+ if err != nil {
+ return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
+ }
+ err = resp.Body.Close()
+ if err != nil {
+ return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
+ }
+ err = json.Unmarshal(responseBody, &geminiEmbeddingResponse)
+ if err != nil {
+ return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
+ }
+ if geminiEmbeddingResponse.Error != nil {
+ return &model.ErrorWithStatusCode{
+ Error: model.Error{
+ Message: geminiEmbeddingResponse.Error.Message,
+ Type: "gemini_error",
+ Param: "",
+ Code: geminiEmbeddingResponse.Error.Code,
+ },
+ StatusCode: resp.StatusCode,
+ }, nil
+ }
+ fullTextResponse := embeddingResponseGemini2OpenAI(&geminiEmbeddingResponse)
+ jsonResponse, err := json.Marshal(fullTextResponse)
+ if err != nil {
+ return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
+ }
+ c.Writer.Header().Set("Content-Type", "application/json")
+ c.Writer.WriteHeader(resp.StatusCode)
+ _, err = c.Writer.Write(jsonResponse)
+ return nil, &fullTextResponse.Usage
+}
diff --git a/relay/adaptor/gemini/model.go b/relay/adaptor/gemini/model.go
index 47b74fbc..f7179ea4 100644
--- a/relay/adaptor/gemini/model.go
+++ b/relay/adaptor/gemini/model.go
@@ -7,6 +7,33 @@ type ChatRequest struct {
Tools []ChatTools `json:"tools,omitempty"`
}
+type EmbeddingRequest struct {
+ Model string `json:"model"`
+ Content ChatContent `json:"content"`
+ TaskType string `json:"taskType,omitempty"`
+ Title string `json:"title,omitempty"`
+ OutputDimensionality int `json:"outputDimensionality,omitempty"`
+}
+
+type BatchEmbeddingRequest struct {
+ Requests []EmbeddingRequest `json:"requests"`
+}
+
+type EmbeddingData struct {
+ Values []float64 `json:"values"`
+}
+
+type EmbeddingResponse struct {
+ Embeddings []EmbeddingData `json:"embeddings"`
+ Error *Error `json:"error,omitempty"`
+}
+
+type Error struct {
+ Code int `json:"code,omitempty"`
+ Message string `json:"message,omitempty"`
+ Status string `json:"status,omitempty"`
+}
+
type InlineData struct {
MimeType string `json:"mimeType"`
Data string `json:"data"`