✨ feat: support other OpenAI APIs (#165)
* ✨ feat: support other OpenAI APIs * 🔖 chore: Update English translation
This commit is contained in:
parent
f91b9856d4
commit
628df97f96
@ -1122,5 +1122,7 @@
|
||||
"模型名称映射, 你可以取一个容易记忆的名字来代替coze-{bot_id},例如:": "Model name mapping, you can take an easy-to-remember name to replace coze-{bot_id}, for example: ",
|
||||
",注意:如果使用了模型映射,那么上面的模型名称必须使用映射前的名称,上述例子中,你应该在模型中填入coze-translate(如果已经使用了coze-*,可以忽略)。": ", Note: If a model mapping is used, then the model name above must use the name before the mapping. In the example above, you should fill in coze-translate in the model (if coze-* has been used, it can be ignored).",
|
||||
"位置/区域": "Location/Region",
|
||||
"请输入你 Speech Studio 的位置/区域,例如:eastasia": "Please enter the location/region of your Speech Studio, for example: eastasia"
|
||||
"请输入你 Speech Studio 的位置/区域,例如:eastasia": "Please enter the location/region of your Speech Studio, for example: eastasia",
|
||||
"必须指定渠道": "Channel must be specified",
|
||||
"中继": "Relay"
|
||||
}
|
||||
|
@ -114,6 +114,10 @@ func tokenAuth(c *gin.Context, key string) {
|
||||
return
|
||||
}
|
||||
c.Set("specific_channel_id", channelId)
|
||||
if len(parts) == 3 && parts[2] == "ignore" {
|
||||
c.Set("specific_channel_id_ignore", true)
|
||||
}
|
||||
|
||||
} else {
|
||||
abortWithMessage(c, http.StatusForbidden, "普通用户不支持指定渠道")
|
||||
return
|
||||
@ -135,3 +139,16 @@ func MjAuth() func(c *gin.Context) {
|
||||
tokenAuth(c, key)
|
||||
}
|
||||
}
|
||||
|
||||
func SpecifiedChannel() func(c *gin.Context) {
|
||||
return func(c *gin.Context) {
|
||||
channelId := c.GetInt("specific_channel_id")
|
||||
c.Set("specific_channel_id_ignore", false)
|
||||
|
||||
if channelId <= 0 {
|
||||
abortWithMessage(c, http.StatusForbidden, "必须指定渠道")
|
||||
return
|
||||
}
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
@ -143,3 +143,7 @@ func (p *BaseProvider) GetSupportedAPIUri(relayMode int) (url string, err *types
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (p *BaseProvider) GetRequester() *requester.HTTPRequester {
|
||||
return p.Requester
|
||||
}
|
||||
|
@ -20,7 +20,7 @@ type ProviderInterface interface {
|
||||
// 获取完整请求URL
|
||||
// GetFullRequestURL(requestURL string, modelName string) string
|
||||
// 获取请求头
|
||||
// GetRequestHeaders() (headers map[string]string)
|
||||
GetRequestHeaders() map[string]string
|
||||
// 获取用量
|
||||
GetUsage() *types.Usage
|
||||
// 设置用量
|
||||
@ -35,6 +35,7 @@ type ProviderInterface interface {
|
||||
// SupportAPI(relayMode int) bool
|
||||
GetChannel() *model.Channel
|
||||
ModelMappingHandler(modelName string) (string, error)
|
||||
GetRequester() *requester.HTTPRequester
|
||||
}
|
||||
|
||||
// 完成接口
|
||||
|
@ -85,22 +85,28 @@ func (p *OpenAIProvider) GetFullRequestURL(requestURL string, modelName string)
|
||||
|
||||
if p.IsAzure {
|
||||
apiVersion := p.Channel.Other
|
||||
// 检测模型是是否包含 . 如果有则直接去掉
|
||||
modelName = strings.Replace(modelName, ".", "", -1)
|
||||
if modelName != "" {
|
||||
// 检测模型是是否包含 . 如果有则直接去掉
|
||||
modelName = strings.Replace(modelName, ".", "", -1)
|
||||
|
||||
if modelName == "dall-e-2" {
|
||||
// 因为dall-e-3需要api-version=2023-12-01-preview,但是该版本
|
||||
// 已经没有dall-e-2了,所以暂时写死
|
||||
requestURL = fmt.Sprintf("/openai/%s:submit?api-version=2023-09-01-preview", requestURL)
|
||||
if modelName == "dall-e-2" {
|
||||
// 因为dall-e-3需要api-version=2023-12-01-preview,但是该版本
|
||||
// 已经没有dall-e-2了,所以暂时写死
|
||||
requestURL = fmt.Sprintf("/openai/%s:submit?api-version=2023-09-01-preview", requestURL)
|
||||
} else {
|
||||
requestURL = fmt.Sprintf("/openai/deployments/%s%s?api-version=%s", modelName, requestURL, apiVersion)
|
||||
}
|
||||
} else {
|
||||
requestURL = fmt.Sprintf("/openai/deployments/%s%s?api-version=%s", modelName, requestURL, apiVersion)
|
||||
requestURL = strings.TrimPrefix(requestURL, "/v1")
|
||||
requestURL = fmt.Sprintf("/openai%s?api-version=%s", requestURL, apiVersion)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
if strings.HasPrefix(baseURL, "https://gateway.ai.cloudflare.com") {
|
||||
if p.IsAzure {
|
||||
requestURL = strings.TrimPrefix(requestURL, "/openai/deployments")
|
||||
requestURL = strings.TrimPrefix(requestURL, "/openai")
|
||||
requestURL = strings.TrimPrefix(requestURL, "/deployments")
|
||||
} else {
|
||||
requestURL = strings.TrimPrefix(requestURL, "/v1")
|
||||
}
|
||||
|
@ -78,7 +78,8 @@ func GetProvider(c *gin.Context, modeName string) (provider providersBase.Provid
|
||||
|
||||
func fetchChannel(c *gin.Context, modelName string) (channel *model.Channel, fail error) {
|
||||
channelId := c.GetInt("specific_channel_id")
|
||||
if channelId > 0 {
|
||||
ignore := c.GetBool("specific_channel_id_ignore")
|
||||
if channelId > 0 && !ignore {
|
||||
return fetchChannelById(channelId)
|
||||
}
|
||||
|
||||
@ -206,7 +207,8 @@ func responseCache(c *gin.Context, response string) {
|
||||
|
||||
func shouldRetry(c *gin.Context, statusCode int) bool {
|
||||
channelId := c.GetInt("specific_channel_id")
|
||||
if channelId > 0 {
|
||||
ignore := c.GetBool("specific_channel_id_ignore")
|
||||
if channelId > 0 && !ignore {
|
||||
return false
|
||||
}
|
||||
if statusCode == http.StatusTooManyRequests {
|
||||
@ -230,3 +232,11 @@ func processChannelRelayError(ctx context.Context, channelId int, channelName st
|
||||
controller.DisableChannel(channelId, channelName, err.Message, true)
|
||||
}
|
||||
}
|
||||
|
||||
func relayResponseWithErr(c *gin.Context, err *types.OpenAIErrorWithStatusCode) {
|
||||
requestId := c.GetString(common.RequestIdKey)
|
||||
err.OpenAIError.Message = common.MessageWithRequestId(err.OpenAIError.Message, requestId)
|
||||
c.JSON(err.StatusCode, gin.H{
|
||||
"error": err.OpenAIError,
|
||||
})
|
||||
}
|
||||
|
@ -75,15 +75,10 @@ func Relay(c *gin.Context) {
|
||||
}
|
||||
|
||||
if apiErr != nil {
|
||||
requestId := c.GetString(common.RequestIdKey)
|
||||
if apiErr.StatusCode == http.StatusTooManyRequests {
|
||||
apiErr.OpenAIError.Message = "当前分组上游负载已饱和,请稍后再试"
|
||||
}
|
||||
apiErr.OpenAIError.Message = common.MessageWithRequestId(apiErr.OpenAIError.Message, requestId)
|
||||
c.JSON(apiErr.StatusCode, gin.H{
|
||||
"error": apiErr.OpenAIError,
|
||||
})
|
||||
|
||||
relayResponseWithErr(c, apiErr)
|
||||
}
|
||||
}
|
||||
|
||||
|
85
relay/relay.go
Normal file
85
relay/relay.go
Normal file
@ -0,0 +1,85 @@
|
||||
package relay
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/model"
|
||||
"one-api/providers/azure"
|
||||
"one-api/providers/openai"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func RelayOnly(c *gin.Context) {
|
||||
provider, _, fail := GetProvider(c, "")
|
||||
if fail != nil {
|
||||
common.AbortWithMessage(c, http.StatusServiceUnavailable, fail.Error())
|
||||
return
|
||||
}
|
||||
|
||||
channel := provider.GetChannel()
|
||||
if channel.Type != common.ChannelTypeOpenAI && channel.Type != common.ChannelTypeAzure {
|
||||
common.AbortWithMessage(c, http.StatusServiceUnavailable, "provider must be of type azureopenai or openai")
|
||||
return
|
||||
}
|
||||
|
||||
// 获取请求的path
|
||||
url := ""
|
||||
path := c.Request.URL.Path
|
||||
openAIProvider, ok := provider.(*openai.OpenAIProvider)
|
||||
if !ok {
|
||||
azureProvider, ok := provider.(*azure.AzureProvider)
|
||||
if !ok {
|
||||
common.AbortWithMessage(c, http.StatusServiceUnavailable, "provider must be of type openai")
|
||||
return
|
||||
}
|
||||
url = azureProvider.GetFullRequestURL(path, "")
|
||||
} else {
|
||||
url = openAIProvider.GetFullRequestURL(path, "")
|
||||
}
|
||||
|
||||
headers := c.Request.Header
|
||||
mapHeaders := provider.GetRequestHeaders()
|
||||
// 设置请求头
|
||||
for k, v := range headers {
|
||||
if _, ok := mapHeaders[k]; ok {
|
||||
continue
|
||||
}
|
||||
mapHeaders[k] = strings.Join(v, ", ")
|
||||
}
|
||||
|
||||
requester := provider.GetRequester()
|
||||
req, err := requester.NewRequest(c.Request.Method, url, requester.WithBody(c.Request.Body), requester.WithHeader(mapHeaders))
|
||||
if err != nil {
|
||||
common.AbortWithMessage(c, http.StatusBadRequest, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
defer req.Body.Close()
|
||||
|
||||
response, errWithCode := requester.SendRequestRaw(req)
|
||||
if errWithCode != nil {
|
||||
relayResponseWithErr(c, errWithCode)
|
||||
return
|
||||
}
|
||||
|
||||
errWithCode = responseMultipart(c, response)
|
||||
|
||||
if errWithCode != nil {
|
||||
relayResponseWithErr(c, errWithCode)
|
||||
return
|
||||
}
|
||||
|
||||
requestTime := 0
|
||||
requestStartTimeValue := c.Request.Context().Value("requestStartTime")
|
||||
if requestStartTimeValue != nil {
|
||||
requestStartTime, ok := requestStartTimeValue.(time.Time)
|
||||
if ok {
|
||||
requestTime = int(time.Since(requestStartTime).Milliseconds())
|
||||
}
|
||||
}
|
||||
model.RecordConsumeLog(c.Request.Context(), c.GetInt("id"), c.GetInt("channel_id"), 0, 0, "", c.GetString("token_name"), 0, "中继:"+path, requestTime)
|
||||
|
||||
}
|
@ -1,7 +1,6 @@
|
||||
package router
|
||||
|
||||
import (
|
||||
"one-api/controller"
|
||||
"one-api/middleware"
|
||||
"one-api/relay"
|
||||
"one-api/relay/midjourney"
|
||||
@ -38,43 +37,20 @@ func setOpenAIRouter(router *gin.Engine) {
|
||||
relayV1Router.POST("/audio/translations", relay.Relay)
|
||||
relayV1Router.POST("/audio/speech", relay.Relay)
|
||||
relayV1Router.POST("/moderations", relay.Relay)
|
||||
relayV1Router.GET("/files", controller.RelayNotImplemented)
|
||||
relayV1Router.POST("/files", controller.RelayNotImplemented)
|
||||
relayV1Router.DELETE("/files/:id", controller.RelayNotImplemented)
|
||||
relayV1Router.GET("/files/:id", controller.RelayNotImplemented)
|
||||
relayV1Router.GET("/files/:id/content", controller.RelayNotImplemented)
|
||||
relayV1Router.POST("/fine_tuning/jobs", controller.RelayNotImplemented)
|
||||
relayV1Router.GET("/fine_tuning/jobs", controller.RelayNotImplemented)
|
||||
relayV1Router.GET("/fine_tuning/jobs/:id", controller.RelayNotImplemented)
|
||||
relayV1Router.POST("/fine_tuning/jobs/:id/cancel", controller.RelayNotImplemented)
|
||||
relayV1Router.GET("/fine_tuning/jobs/:id/events", controller.RelayNotImplemented)
|
||||
relayV1Router.DELETE("/models/:model", controller.RelayNotImplemented)
|
||||
relayV1Router.POST("/assistants", controller.RelayNotImplemented)
|
||||
relayV1Router.GET("/assistants/:id", controller.RelayNotImplemented)
|
||||
relayV1Router.POST("/assistants/:id", controller.RelayNotImplemented)
|
||||
relayV1Router.DELETE("/assistants/:id", controller.RelayNotImplemented)
|
||||
relayV1Router.GET("/assistants", controller.RelayNotImplemented)
|
||||
relayV1Router.POST("/assistants/:id/files", controller.RelayNotImplemented)
|
||||
relayV1Router.GET("/assistants/:id/files/:fileId", controller.RelayNotImplemented)
|
||||
relayV1Router.DELETE("/assistants/:id/files/:fileId", controller.RelayNotImplemented)
|
||||
relayV1Router.GET("/assistants/:id/files", controller.RelayNotImplemented)
|
||||
relayV1Router.POST("/threads", controller.RelayNotImplemented)
|
||||
relayV1Router.GET("/threads/:id", controller.RelayNotImplemented)
|
||||
relayV1Router.POST("/threads/:id", controller.RelayNotImplemented)
|
||||
relayV1Router.DELETE("/threads/:id", controller.RelayNotImplemented)
|
||||
relayV1Router.POST("/threads/:id/messages", controller.RelayNotImplemented)
|
||||
relayV1Router.GET("/threads/:id/messages/:messageId", controller.RelayNotImplemented)
|
||||
relayV1Router.POST("/threads/:id/messages/:messageId", controller.RelayNotImplemented)
|
||||
relayV1Router.GET("/threads/:id/messages/:messageId/files/:filesId", controller.RelayNotImplemented)
|
||||
relayV1Router.GET("/threads/:id/messages/:messageId/files", controller.RelayNotImplemented)
|
||||
relayV1Router.POST("/threads/:id/runs", controller.RelayNotImplemented)
|
||||
relayV1Router.GET("/threads/:id/runs/:runsId", controller.RelayNotImplemented)
|
||||
relayV1Router.POST("/threads/:id/runs/:runsId", controller.RelayNotImplemented)
|
||||
relayV1Router.GET("/threads/:id/runs", controller.RelayNotImplemented)
|
||||
relayV1Router.POST("/threads/:id/runs/:runsId/submit_tool_outputs", controller.RelayNotImplemented)
|
||||
relayV1Router.POST("/threads/:id/runs/:runsId/cancel", controller.RelayNotImplemented)
|
||||
relayV1Router.GET("/threads/:id/runs/:runsId/steps/:stepId", controller.RelayNotImplemented)
|
||||
relayV1Router.GET("/threads/:id/runs/:runsId/steps", controller.RelayNotImplemented)
|
||||
|
||||
relayV1Router.Use(middleware.SpecifiedChannel())
|
||||
{
|
||||
relayV1Router.Any("/files", relay.RelayOnly)
|
||||
relayV1Router.Any("/files/*any", relay.RelayOnly)
|
||||
relayV1Router.Any("/fine_tuning/*any", relay.RelayOnly)
|
||||
relayV1Router.Any("/assistants", relay.RelayOnly)
|
||||
relayV1Router.Any("/assistants/*any", relay.RelayOnly)
|
||||
relayV1Router.Any("/threads", relay.RelayOnly)
|
||||
relayV1Router.Any("/threads/*any", relay.RelayOnly)
|
||||
relayV1Router.Any("/batches/*any", relay.RelayOnly)
|
||||
relayV1Router.Any("/vector_stores/*any", relay.RelayOnly)
|
||||
relayV1Router.DELETE("/models/:model", relay.RelayOnly)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user