chore: reorganize constant related package

This commit is contained in:
JustSong 2024-04-06 00:44:33 +08:00
parent 880e12c855
commit f9d914873f
30 changed files with 269 additions and 257 deletions

View File

@ -38,44 +38,6 @@ const (
ChannelStatusAutoDisabled = 3 ChannelStatusAutoDisabled = 3
) )
const (
ChannelTypeUnknown = iota
ChannelTypeOpenAI
ChannelTypeAPI2D
ChannelTypeAzure
ChannelTypeCloseAI
ChannelTypeOpenAISB
ChannelTypeOpenAIMax
ChannelTypeOhMyGPT
ChannelTypeCustom
ChannelTypeAILS
ChannelTypeAIProxy
ChannelTypePaLM
ChannelTypeAPI2GPT
ChannelTypeAIGC2D
ChannelTypeAnthropic
ChannelTypeBaidu
ChannelTypeZhipu
ChannelTypeAli
ChannelTypeXunfei
ChannelType360
ChannelTypeOpenRouter
ChannelTypeAIProxyLibrary
ChannelTypeFastGPT
ChannelTypeTencent
ChannelTypeGemini
ChannelTypeMoonshot
ChannelTypeBaichuan
ChannelTypeMinimax
ChannelTypeMistral
ChannelTypeGroq
ChannelTypeOllama
ChannelTypeLingYiWanWu
ChannelTypeStepFun
ChannelTypeDummy
)
var ChannelBaseURLs = []string{ var ChannelBaseURLs = []string{
"", // 0 "", // 0
"https://api.openai.com", // 1 "https://api.openai.com", // 1

View File

@ -9,6 +9,7 @@ import (
"github.com/songquanpeng/one-api/common/logger" "github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/model" "github.com/songquanpeng/one-api/model"
"github.com/songquanpeng/one-api/monitor" "github.com/songquanpeng/one-api/monitor"
"github.com/songquanpeng/one-api/relay/channeltype"
"github.com/songquanpeng/one-api/relay/util" "github.com/songquanpeng/one-api/relay/util"
"io" "io"
"net/http" "net/http"
@ -209,23 +210,23 @@ func updateChannelBalance(channel *model.Channel) (float64, error) {
channel.BaseURL = &baseURL channel.BaseURL = &baseURL
} }
switch channel.Type { switch channel.Type {
case common.ChannelTypeOpenAI: case channeltype.OpenAI:
if channel.GetBaseURL() != "" { if channel.GetBaseURL() != "" {
baseURL = channel.GetBaseURL() baseURL = channel.GetBaseURL()
} }
case common.ChannelTypeAzure: case channeltype.Azure:
return 0, errors.New("尚未实现") return 0, errors.New("尚未实现")
case common.ChannelTypeCustom: case channeltype.Custom:
baseURL = channel.GetBaseURL() baseURL = channel.GetBaseURL()
case common.ChannelTypeCloseAI: case channeltype.CloseAI:
return updateChannelCloseAIBalance(channel) return updateChannelCloseAIBalance(channel)
case common.ChannelTypeOpenAISB: case channeltype.OpenAISB:
return updateChannelOpenAISBBalance(channel) return updateChannelOpenAISBBalance(channel)
case common.ChannelTypeAIProxy: case channeltype.AIProxy:
return updateChannelAIProxyBalance(channel) return updateChannelAIProxyBalance(channel)
case common.ChannelTypeAPI2GPT: case channeltype.API2GPT:
return updateChannelAPI2GPTBalance(channel) return updateChannelAPI2GPTBalance(channel)
case common.ChannelTypeAIGC2D: case channeltype.AIGC2D:
return updateChannelAIGC2DBalance(channel) return updateChannelAIGC2DBalance(channel)
default: default:
return 0, errors.New("尚未实现") return 0, errors.New("尚未实现")
@ -305,7 +306,7 @@ func updateAllChannelsBalance() error {
continue continue
} }
// TODO: support Azure // TODO: support Azure
if channel.Type != common.ChannelTypeOpenAI && channel.Type != common.ChannelTypeCustom { if channel.Type != channeltype.OpenAI && channel.Type != channeltype.Custom {
continue continue
} }
balance, err := updateChannelBalance(channel) balance, err := updateChannelBalance(channel)

View File

@ -12,9 +12,10 @@ import (
"github.com/songquanpeng/one-api/middleware" "github.com/songquanpeng/one-api/middleware"
"github.com/songquanpeng/one-api/model" "github.com/songquanpeng/one-api/model"
"github.com/songquanpeng/one-api/monitor" "github.com/songquanpeng/one-api/monitor"
"github.com/songquanpeng/one-api/relay/constant" "github.com/songquanpeng/one-api/relay/channeltype"
"github.com/songquanpeng/one-api/relay/helper" "github.com/songquanpeng/one-api/relay/helper"
relaymodel "github.com/songquanpeng/one-api/relay/model" relaymodel "github.com/songquanpeng/one-api/relay/model"
"github.com/songquanpeng/one-api/relay/relaymode"
"github.com/songquanpeng/one-api/relay/util" "github.com/songquanpeng/one-api/relay/util"
"io" "io"
"net/http" "net/http"
@ -57,7 +58,7 @@ func testChannel(channel *model.Channel) (err error, openaiErr *relaymodel.Error
c.Set("base_url", channel.GetBaseURL()) c.Set("base_url", channel.GetBaseURL())
middleware.SetupContextForSelectedChannel(c, channel, "") middleware.SetupContextForSelectedChannel(c, channel, "")
meta := util.GetRelayMeta(c) meta := util.GetRelayMeta(c)
apiType := constant.ChannelType2APIType(channel.Type) apiType := channeltype.ToAPIType(channel.Type)
adaptor := helper.GetAdaptor(apiType) adaptor := helper.GetAdaptor(apiType)
if adaptor == nil { if adaptor == nil {
return fmt.Errorf("invalid api type: %d, adaptor is nil", apiType), nil return fmt.Errorf("invalid api type: %d, adaptor is nil", apiType), nil
@ -73,7 +74,7 @@ func testChannel(channel *model.Channel) (err error, openaiErr *relaymodel.Error
request := buildTestRequest() request := buildTestRequest()
request.Model = modelName request.Model = modelName
meta.OriginModelName, meta.ActualModelName = modelName, modelName meta.OriginModelName, meta.ActualModelName = modelName, modelName
convertedRequest, err := adaptor.ConvertRequest(c, constant.RelayModeChatCompletions, request) convertedRequest, err := adaptor.ConvertRequest(c, relaymode.ChatCompletions, request)
if err != nil { if err != nil {
return err, nil return err, nil
} }

View File

@ -3,10 +3,10 @@ package controller
import ( import (
"fmt" "fmt"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/model" "github.com/songquanpeng/one-api/model"
"github.com/songquanpeng/one-api/relay/apitype"
"github.com/songquanpeng/one-api/relay/channel/openai" "github.com/songquanpeng/one-api/relay/channel/openai"
"github.com/songquanpeng/one-api/relay/constant" "github.com/songquanpeng/one-api/relay/channeltype"
"github.com/songquanpeng/one-api/relay/helper" "github.com/songquanpeng/one-api/relay/helper"
relaymodel "github.com/songquanpeng/one-api/relay/model" relaymodel "github.com/songquanpeng/one-api/relay/model"
"github.com/songquanpeng/one-api/relay/util" "github.com/songquanpeng/one-api/relay/util"
@ -62,8 +62,8 @@ func init() {
IsBlocking: false, IsBlocking: false,
}) })
// https://platform.openai.com/docs/models/model-endpoint-compatibility // https://platform.openai.com/docs/models/model-endpoint-compatibility
for i := 0; i < constant.APITypeDummy; i++ { for i := 0; i < apitype.Dummy; i++ {
if i == constant.APITypeAIProxyLibrary { if i == apitype.AIProxyLibrary {
continue continue
} }
adaptor := helper.GetAdaptor(i) adaptor := helper.GetAdaptor(i)
@ -82,7 +82,7 @@ func init() {
} }
} }
for _, channelType := range openai.CompatibleChannels { for _, channelType := range openai.CompatibleChannels {
if channelType == common.ChannelTypeAzure { if channelType == channeltype.Azure {
continue continue
} }
channelName, channelModelList := openai.GetCompatibleChannelMeta(channelType) channelName, channelModelList := openai.GetCompatibleChannelMeta(channelType)
@ -103,8 +103,8 @@ func init() {
modelsMap[model.Id] = model modelsMap[model.Id] = model
} }
channelId2Models = make(map[int][]string) channelId2Models = make(map[int][]string)
for i := 1; i < common.ChannelTypeDummy; i++ { for i := 1; i < channeltype.Dummy; i++ {
adaptor := helper.GetAdaptor(constant.ChannelType2APIType(i)) adaptor := helper.GetAdaptor(channeltype.ToAPIType(i))
meta := &util.RelayMeta{ meta := &util.RelayMeta{
ChannelType: i, ChannelType: i,
} }

View File

@ -12,9 +12,9 @@ import (
"github.com/songquanpeng/one-api/middleware" "github.com/songquanpeng/one-api/middleware"
dbmodel "github.com/songquanpeng/one-api/model" dbmodel "github.com/songquanpeng/one-api/model"
"github.com/songquanpeng/one-api/monitor" "github.com/songquanpeng/one-api/monitor"
"github.com/songquanpeng/one-api/relay/constant"
"github.com/songquanpeng/one-api/relay/controller" "github.com/songquanpeng/one-api/relay/controller"
"github.com/songquanpeng/one-api/relay/model" "github.com/songquanpeng/one-api/relay/model"
"github.com/songquanpeng/one-api/relay/relaymode"
"github.com/songquanpeng/one-api/relay/util" "github.com/songquanpeng/one-api/relay/util"
"io" "io"
"net/http" "net/http"
@ -25,13 +25,13 @@ import (
func relay(c *gin.Context, relayMode int) *model.ErrorWithStatusCode { func relay(c *gin.Context, relayMode int) *model.ErrorWithStatusCode {
var err *model.ErrorWithStatusCode var err *model.ErrorWithStatusCode
switch relayMode { switch relayMode {
case constant.RelayModeImagesGenerations: case relaymode.ImagesGenerations:
err = controller.RelayImageHelper(c, relayMode) err = controller.RelayImageHelper(c, relayMode)
case constant.RelayModeAudioSpeech: case relaymode.AudioSpeech:
fallthrough fallthrough
case constant.RelayModeAudioTranslation: case relaymode.AudioTranslation:
fallthrough fallthrough
case constant.RelayModeAudioTranscription: case relaymode.AudioTranscription:
err = controller.RelayAudioHelper(c, relayMode) err = controller.RelayAudioHelper(c, relayMode)
default: default:
err = controller.RelayTextHelper(c) err = controller.RelayTextHelper(c)
@ -41,7 +41,7 @@ func relay(c *gin.Context, relayMode int) *model.ErrorWithStatusCode {
func Relay(c *gin.Context) { func Relay(c *gin.Context) {
ctx := c.Request.Context() ctx := c.Request.Context()
relayMode := constant.Path2RelayMode(c.Request.URL.Path) relayMode := relaymode.GetByPath(c.Request.URL.Path)
if config.DebugEnabled { if config.DebugEnabled {
requestBody, _ := common.GetRequestBody(c) requestBody, _ := common.GetRequestBody(c)
logger.Debugf(ctx, "request body: %s", string(requestBody)) logger.Debugf(ctx, "request body: %s", string(requestBody))

View File

@ -28,7 +28,7 @@ func main() {
gin.SetMode(gin.ReleaseMode) gin.SetMode(gin.ReleaseMode)
} }
if config.DebugEnabled { if config.DebugEnabled {
logger.SysLog("running in debug mode") logger.SysLog("running in debug relaymode")
} }
var err error var err error
// Initialize SQL Database // Initialize SQL Database

View File

@ -6,6 +6,7 @@ import (
"github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/logger" "github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/model" "github.com/songquanpeng/one-api/model"
"github.com/songquanpeng/one-api/relay/channeltype"
"net/http" "net/http"
"strconv" "strconv"
) )
@ -66,15 +67,15 @@ func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, mode
c.Set("base_url", channel.GetBaseURL()) c.Set("base_url", channel.GetBaseURL())
// this is for backward compatibility // this is for backward compatibility
switch channel.Type { switch channel.Type {
case common.ChannelTypeAzure: case channeltype.Azure:
c.Set(common.ConfigKeyAPIVersion, channel.Other) c.Set(common.ConfigKeyAPIVersion, channel.Other)
case common.ChannelTypeXunfei: case channeltype.Xunfei:
c.Set(common.ConfigKeyAPIVersion, channel.Other) c.Set(common.ConfigKeyAPIVersion, channel.Other)
case common.ChannelTypeGemini: case channeltype.Gemini:
c.Set(common.ConfigKeyAPIVersion, channel.Other) c.Set(common.ConfigKeyAPIVersion, channel.Other)
case common.ChannelTypeAIProxyLibrary: case channeltype.AIProxyLibrary:
c.Set(common.ConfigKeyLibraryID, channel.Other) c.Set(common.ConfigKeyLibraryID, channel.Other)
case common.ChannelTypeAli: case channeltype.Ali:
c.Set(common.ConfigKeyPlugin, channel.Other) c.Set(common.ConfigKeyPlugin, channel.Other)
} }
cfg, _ := channel.LoadConfig() cfg, _ := channel.LoadConfig()

17
relay/apitype/define.go Normal file
View File

@ -0,0 +1,17 @@
package apitype
const (
OpenAI = iota
Anthropic
PaLM
Baidu
Zhipu
Ali
Xunfei
AIProxyLibrary
Tencent
Gemini
Ollama
Dummy // this one is only for count, do not add any channel after this
)

View File

@ -6,8 +6,8 @@ import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/relay/channel" "github.com/songquanpeng/one-api/relay/channel"
"github.com/songquanpeng/one-api/relay/constant"
"github.com/songquanpeng/one-api/relay/model" "github.com/songquanpeng/one-api/relay/model"
"github.com/songquanpeng/one-api/relay/relaymode"
"github.com/songquanpeng/one-api/relay/util" "github.com/songquanpeng/one-api/relay/util"
"io" "io"
"net/http" "net/http"
@ -25,9 +25,9 @@ func (a *Adaptor) Init(meta *util.RelayMeta) {
func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) { func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) {
fullRequestURL := "" fullRequestURL := ""
switch meta.Mode { switch meta.Mode {
case constant.RelayModeEmbeddings: case relaymode.Embeddings:
fullRequestURL = fmt.Sprintf("%s/api/v1/services/embeddings/text-embedding/text-embedding", meta.BaseURL) fullRequestURL = fmt.Sprintf("%s/api/v1/services/embeddings/text-embedding/text-embedding", meta.BaseURL)
case constant.RelayModeImagesGenerations: case relaymode.ImagesGenerations:
fullRequestURL = fmt.Sprintf("%s/api/v1/services/aigc/text2image/image-synthesis", meta.BaseURL) fullRequestURL = fmt.Sprintf("%s/api/v1/services/aigc/text2image/image-synthesis", meta.BaseURL)
default: default:
fullRequestURL = fmt.Sprintf("%s/api/v1/services/aigc/text-generation/generation", meta.BaseURL) fullRequestURL = fmt.Sprintf("%s/api/v1/services/aigc/text-generation/generation", meta.BaseURL)
@ -44,7 +44,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *ut
} }
req.Header.Set("Authorization", "Bearer "+meta.APIKey) req.Header.Set("Authorization", "Bearer "+meta.APIKey)
if meta.Mode == constant.RelayModeImagesGenerations { if meta.Mode == relaymode.ImagesGenerations {
req.Header.Set("X-DashScope-Async", "enable") req.Header.Set("X-DashScope-Async", "enable")
} }
if c.GetString(common.ConfigKeyPlugin) != "" { if c.GetString(common.ConfigKeyPlugin) != "" {
@ -58,7 +58,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G
return nil, errors.New("request is nil") return nil, errors.New("request is nil")
} }
switch relayMode { switch relayMode {
case constant.RelayModeEmbeddings: case relaymode.Embeddings:
aliEmbeddingRequest := ConvertEmbeddingRequest(*request) aliEmbeddingRequest := ConvertEmbeddingRequest(*request)
return aliEmbeddingRequest, nil return aliEmbeddingRequest, nil
default: default:
@ -85,9 +85,9 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.Rel
err, usage = StreamHandler(c, resp) err, usage = StreamHandler(c, resp)
} else { } else {
switch meta.Mode { switch meta.Mode {
case constant.RelayModeEmbeddings: case relaymode.Embeddings:
err, usage = EmbeddingHandler(c, resp) err, usage = EmbeddingHandler(c, resp)
case constant.RelayModeImagesGenerations: case relaymode.ImagesGenerations:
err, usage = ImageHandler(c, resp) err, usage = ImageHandler(c, resp)
default: default:
err, usage = Handler(c, resp) err, usage = Handler(c, resp)

View File

@ -3,13 +3,13 @@ package baidu
import ( import (
"errors" "errors"
"fmt" "fmt"
"github.com/songquanpeng/one-api/relay/relaymode"
"io" "io"
"net/http" "net/http"
"strings" "strings"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/relay/channel" "github.com/songquanpeng/one-api/relay/channel"
"github.com/songquanpeng/one-api/relay/constant"
"github.com/songquanpeng/one-api/relay/model" "github.com/songquanpeng/one-api/relay/model"
"github.com/songquanpeng/one-api/relay/util" "github.com/songquanpeng/one-api/relay/util"
) )
@ -100,7 +100,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G
return nil, errors.New("request is nil") return nil, errors.New("request is nil")
} }
switch relayMode { switch relayMode {
case constant.RelayModeEmbeddings: case relaymode.Embeddings:
baiduEmbeddingRequest := ConvertEmbeddingRequest(*request) baiduEmbeddingRequest := ConvertEmbeddingRequest(*request)
return baiduEmbeddingRequest, nil return baiduEmbeddingRequest, nil
default: default:
@ -125,7 +125,7 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.Rel
err, usage = StreamHandler(c, resp) err, usage = StreamHandler(c, resp)
} else { } else {
switch meta.Mode { switch meta.Mode {
case constant.RelayModeEmbeddings: case relaymode.Embeddings:
err, usage = EmbeddingHandler(c, resp) err, usage = EmbeddingHandler(c, resp)
default: default:
err, usage = Handler(c, resp) err, usage = Handler(c, resp)

View File

@ -2,13 +2,13 @@ package minimax
import ( import (
"fmt" "fmt"
"github.com/songquanpeng/one-api/relay/constant" "github.com/songquanpeng/one-api/relay/relaymode"
"github.com/songquanpeng/one-api/relay/util" "github.com/songquanpeng/one-api/relay/util"
) )
func GetRequestURL(meta *util.RelayMeta) (string, error) { func GetRequestURL(meta *util.RelayMeta) (string, error) {
if meta.Mode == constant.RelayModeChatCompletions { if meta.Mode == relaymode.ChatCompletions {
return fmt.Sprintf("%s/v1/text/chatcompletion_v2", meta.BaseURL), nil return fmt.Sprintf("%s/v1/text/chatcompletion_v2", meta.BaseURL), nil
} }
return "", fmt.Errorf("unsupported relay mode %d for minimax", meta.Mode) return "", fmt.Errorf("unsupported relay relaymode %d for minimax", meta.Mode)
} }

View File

@ -3,12 +3,12 @@ package ollama
import ( import (
"errors" "errors"
"fmt" "fmt"
"github.com/songquanpeng/one-api/relay/relaymode"
"io" "io"
"net/http" "net/http"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/relay/channel" "github.com/songquanpeng/one-api/relay/channel"
"github.com/songquanpeng/one-api/relay/constant"
"github.com/songquanpeng/one-api/relay/model" "github.com/songquanpeng/one-api/relay/model"
"github.com/songquanpeng/one-api/relay/util" "github.com/songquanpeng/one-api/relay/util"
) )
@ -23,7 +23,7 @@ func (a *Adaptor) Init(meta *util.RelayMeta) {
func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) { func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) {
// https://github.com/ollama/ollama/blob/main/docs/api.md // https://github.com/ollama/ollama/blob/main/docs/api.md
fullRequestURL := fmt.Sprintf("%s/api/chat", meta.BaseURL) fullRequestURL := fmt.Sprintf("%s/api/chat", meta.BaseURL)
if meta.Mode == constant.RelayModeEmbeddings { if meta.Mode == relaymode.Embeddings {
fullRequestURL = fmt.Sprintf("%s/api/embeddings", meta.BaseURL) fullRequestURL = fmt.Sprintf("%s/api/embeddings", meta.BaseURL)
} }
return fullRequestURL, nil return fullRequestURL, nil
@ -40,7 +40,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G
return nil, errors.New("request is nil") return nil, errors.New("request is nil")
} }
switch relayMode { switch relayMode {
case constant.RelayModeEmbeddings: case relaymode.Embeddings:
ollamaEmbeddingRequest := ConvertEmbeddingRequest(*request) ollamaEmbeddingRequest := ConvertEmbeddingRequest(*request)
return ollamaEmbeddingRequest, nil return ollamaEmbeddingRequest, nil
default: default:
@ -64,7 +64,7 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.Rel
err, usage = StreamHandler(c, resp) err, usage = StreamHandler(c, resp)
} else { } else {
switch meta.Mode { switch meta.Mode {
case constant.RelayModeEmbeddings: case relaymode.Embeddings:
err, usage = EmbeddingHandler(c, resp) err, usage = EmbeddingHandler(c, resp)
default: default:
err, usage = Handler(c, resp) err, usage = Handler(c, resp)

View File

@ -4,11 +4,11 @@ import (
"errors" "errors"
"fmt" "fmt"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/relay/channel" "github.com/songquanpeng/one-api/relay/channel"
"github.com/songquanpeng/one-api/relay/channel/minimax" "github.com/songquanpeng/one-api/relay/channel/minimax"
"github.com/songquanpeng/one-api/relay/constant" "github.com/songquanpeng/one-api/relay/channeltype"
"github.com/songquanpeng/one-api/relay/model" "github.com/songquanpeng/one-api/relay/model"
"github.com/songquanpeng/one-api/relay/relaymode"
"github.com/songquanpeng/one-api/relay/util" "github.com/songquanpeng/one-api/relay/util"
"io" "io"
"net/http" "net/http"
@ -25,8 +25,8 @@ func (a *Adaptor) Init(meta *util.RelayMeta) {
func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) { func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) {
switch meta.ChannelType { switch meta.ChannelType {
case common.ChannelTypeAzure: case channeltype.Azure:
if meta.Mode == constant.RelayModeImagesGenerations { if meta.Mode == relaymode.ImagesGenerations {
// https://learn.microsoft.com/en-us/azure/ai-services/openai/dall-e-quickstart?tabs=dalle3%2Ccommand-line&pivots=rest-api // https://learn.microsoft.com/en-us/azure/ai-services/openai/dall-e-quickstart?tabs=dalle3%2Ccommand-line&pivots=rest-api
// https://{resource_name}.openai.azure.com/openai/deployments/dall-e-3/images/generations?api-version=2024-03-01-preview // https://{resource_name}.openai.azure.com/openai/deployments/dall-e-3/images/generations?api-version=2024-03-01-preview
fullRequestURL := fmt.Sprintf("%s/openai/deployments/%s/images/generations?api-version=%s", meta.BaseURL, meta.ActualModelName, meta.APIVersion) fullRequestURL := fmt.Sprintf("%s/openai/deployments/%s/images/generations?api-version=%s", meta.BaseURL, meta.ActualModelName, meta.APIVersion)
@ -43,7 +43,7 @@ func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) {
// {your endpoint}/openai/deployments/{your azure_model}/chat/completions?api-version={api_version} // {your endpoint}/openai/deployments/{your azure_model}/chat/completions?api-version={api_version}
requestURL = fmt.Sprintf("/openai/deployments/%s/%s", model_, task) requestURL = fmt.Sprintf("/openai/deployments/%s/%s", model_, task)
return util.GetFullRequestURL(meta.BaseURL, requestURL, meta.ChannelType), nil return util.GetFullRequestURL(meta.BaseURL, requestURL, meta.ChannelType), nil
case common.ChannelTypeMinimax: case channeltype.Minimax:
return minimax.GetRequestURL(meta) return minimax.GetRequestURL(meta)
default: default:
return util.GetFullRequestURL(meta.BaseURL, meta.RequestURLPath, meta.ChannelType), nil return util.GetFullRequestURL(meta.BaseURL, meta.RequestURLPath, meta.ChannelType), nil
@ -52,12 +52,12 @@ func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) {
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *util.RelayMeta) error { func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *util.RelayMeta) error {
channel.SetupCommonRequestHeader(c, req, meta) channel.SetupCommonRequestHeader(c, req, meta)
if meta.ChannelType == common.ChannelTypeAzure { if meta.ChannelType == channeltype.Azure {
req.Header.Set("api-key", meta.APIKey) req.Header.Set("api-key", meta.APIKey)
return nil return nil
} }
req.Header.Set("Authorization", "Bearer "+meta.APIKey) req.Header.Set("Authorization", "Bearer "+meta.APIKey)
if meta.ChannelType == common.ChannelTypeOpenRouter { if meta.ChannelType == channeltype.OpenRouter {
req.Header.Set("HTTP-Referer", "https://github.com/songquanpeng/one-api") req.Header.Set("HTTP-Referer", "https://github.com/songquanpeng/one-api")
req.Header.Set("X-Title", "One API") req.Header.Set("X-Title", "One API")
} }
@ -91,7 +91,7 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.Rel
} }
} else { } else {
switch meta.Mode { switch meta.Mode {
case constant.RelayModeImagesGenerations: case relaymode.ImagesGenerations:
err, _ = ImageHandler(c, resp) err, _ = ImageHandler(c, resp)
default: default:
err, usage = Handler(c, resp, meta.PromptTokens, meta.ActualModelName) err, usage = Handler(c, resp, meta.PromptTokens, meta.ActualModelName)

View File

@ -1,7 +1,6 @@
package openai package openai
import ( import (
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/relay/channel/ai360" "github.com/songquanpeng/one-api/relay/channel/ai360"
"github.com/songquanpeng/one-api/relay/channel/baichuan" "github.com/songquanpeng/one-api/relay/channel/baichuan"
"github.com/songquanpeng/one-api/relay/channel/groq" "github.com/songquanpeng/one-api/relay/channel/groq"
@ -10,39 +9,40 @@ import (
"github.com/songquanpeng/one-api/relay/channel/mistral" "github.com/songquanpeng/one-api/relay/channel/mistral"
"github.com/songquanpeng/one-api/relay/channel/moonshot" "github.com/songquanpeng/one-api/relay/channel/moonshot"
"github.com/songquanpeng/one-api/relay/channel/stepfun" "github.com/songquanpeng/one-api/relay/channel/stepfun"
"github.com/songquanpeng/one-api/relay/channeltype"
) )
var CompatibleChannels = []int{ var CompatibleChannels = []int{
common.ChannelTypeAzure, channeltype.Azure,
common.ChannelType360, channeltype.AI360,
common.ChannelTypeMoonshot, channeltype.Moonshot,
common.ChannelTypeBaichuan, channeltype.Baichuan,
common.ChannelTypeMinimax, channeltype.Minimax,
common.ChannelTypeMistral, channeltype.Mistral,
common.ChannelTypeGroq, channeltype.Groq,
common.ChannelTypeLingYiWanWu, channeltype.LingYiWanWu,
common.ChannelTypeStepFun, channeltype.StepFun,
} }
func GetCompatibleChannelMeta(channelType int) (string, []string) { func GetCompatibleChannelMeta(channelType int) (string, []string) {
switch channelType { switch channelType {
case common.ChannelTypeAzure: case channeltype.Azure:
return "azure", ModelList return "azure", ModelList
case common.ChannelType360: case channeltype.AI360:
return "360", ai360.ModelList return "360", ai360.ModelList
case common.ChannelTypeMoonshot: case channeltype.Moonshot:
return "moonshot", moonshot.ModelList return "moonshot", moonshot.ModelList
case common.ChannelTypeBaichuan: case channeltype.Baichuan:
return "baichuan", baichuan.ModelList return "baichuan", baichuan.ModelList
case common.ChannelTypeMinimax: case channeltype.Minimax:
return "minimax", minimax.ModelList return "minimax", minimax.ModelList
case common.ChannelTypeMistral: case channeltype.Mistral:
return "mistralai", mistral.ModelList return "mistralai", mistral.ModelList
case common.ChannelTypeGroq: case channeltype.Groq:
return "groq", groq.ModelList return "groq", groq.ModelList
case common.ChannelTypeLingYiWanWu: case channeltype.LingYiWanWu:
return "lingyiwanwu", lingyiwanwu.ModelList return "lingyiwanwu", lingyiwanwu.ModelList
case common.ChannelTypeStepFun: case channeltype.StepFun:
return "stepfun", stepfun.ModelList return "stepfun", stepfun.ModelList
default: default:
return "openai", ModelList return "openai", ModelList

View File

@ -8,8 +8,8 @@ import (
"github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/conv" "github.com/songquanpeng/one-api/common/conv"
"github.com/songquanpeng/one-api/common/logger" "github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/relay/constant"
"github.com/songquanpeng/one-api/relay/model" "github.com/songquanpeng/one-api/relay/model"
"github.com/songquanpeng/one-api/relay/relaymode"
"io" "io"
"net/http" "net/http"
"strings" "strings"
@ -46,7 +46,7 @@ func StreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*model.E
data = data[6:] data = data[6:]
if !strings.HasPrefix(data, "[DONE]") { if !strings.HasPrefix(data, "[DONE]") {
switch relayMode { switch relayMode {
case constant.RelayModeChatCompletions: case relaymode.ChatCompletions:
var streamResponse ChatCompletionsStreamResponse var streamResponse ChatCompletionsStreamResponse
err := json.Unmarshal([]byte(data), &streamResponse) err := json.Unmarshal([]byte(data), &streamResponse)
if err != nil { if err != nil {
@ -59,7 +59,7 @@ func StreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*model.E
if streamResponse.Usage != nil { if streamResponse.Usage != nil {
usage = streamResponse.Usage usage = streamResponse.Usage
} }
case constant.RelayModeCompletions: case relaymode.Completions:
var streamResponse CompletionsStreamResponse var streamResponse CompletionsStreamResponse
err := json.Unmarshal([]byte(data), &streamResponse) err := json.Unmarshal([]byte(data), &streamResponse)
if err != nil { if err != nil {

View File

@ -6,8 +6,8 @@ import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/relay/channel" "github.com/songquanpeng/one-api/relay/channel"
"github.com/songquanpeng/one-api/relay/channel/openai" "github.com/songquanpeng/one-api/relay/channel/openai"
"github.com/songquanpeng/one-api/relay/constant"
"github.com/songquanpeng/one-api/relay/model" "github.com/songquanpeng/one-api/relay/model"
"github.com/songquanpeng/one-api/relay/relaymode"
"github.com/songquanpeng/one-api/relay/util" "github.com/songquanpeng/one-api/relay/util"
"io" "io"
"math" "math"
@ -33,9 +33,9 @@ func (a *Adaptor) SetVersionByModeName(modelName string) {
func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) { func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) {
switch meta.Mode { switch meta.Mode {
case constant.RelayModeImagesGenerations: case relaymode.ImagesGenerations:
return fmt.Sprintf("%s/api/paas/v4/images/generations", meta.BaseURL), nil return fmt.Sprintf("%s/api/paas/v4/images/generations", meta.BaseURL), nil
case constant.RelayModeEmbeddings: case relaymode.Embeddings:
return fmt.Sprintf("%s/api/paas/v4/embeddings", meta.BaseURL), nil return fmt.Sprintf("%s/api/paas/v4/embeddings", meta.BaseURL), nil
} }
a.SetVersionByModeName(meta.ActualModelName) a.SetVersionByModeName(meta.ActualModelName)
@ -61,7 +61,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G
return nil, errors.New("request is nil") return nil, errors.New("request is nil")
} }
switch relayMode { switch relayMode {
case constant.RelayModeEmbeddings: case relaymode.Embeddings:
baiduEmbeddingRequest := ConvertEmbeddingRequest(*request) baiduEmbeddingRequest := ConvertEmbeddingRequest(*request)
return baiduEmbeddingRequest, nil return baiduEmbeddingRequest, nil
default: default:
@ -107,10 +107,10 @@ func (a *Adaptor) DoResponseV4(c *gin.Context, resp *http.Response, meta *util.R
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode) { func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
switch meta.Mode { switch meta.Mode {
case constant.RelayModeEmbeddings: case relaymode.Embeddings:
err, usage = EmbeddingsHandler(c, resp) err, usage = EmbeddingsHandler(c, resp)
return return
case constant.RelayModeImagesGenerations: case relaymode.ImagesGenerations:
err, usage = openai.ImageHandler(c, resp) err, usage = openai.ImageHandler(c, resp)
return return
} }
@ -120,7 +120,7 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.Rel
if meta.IsStream { if meta.IsStream {
err, usage = StreamHandler(c, resp) err, usage = StreamHandler(c, resp)
} else { } else {
if meta.Mode == constant.RelayModeEmbeddings { if meta.Mode == relaymode.Embeddings {
err, usage = EmbeddingsHandler(c, resp) err, usage = EmbeddingsHandler(c, resp)
} else { } else {
err, usage = Handler(c, resp) err, usage = Handler(c, resp)

View File

@ -0,0 +1,39 @@
package channeltype
const (
Unknown = iota
OpenAI
API2D
Azure
CloseAI
OpenAISB
OpenAIMax
OhMyGPT
Custom
Ails
AIProxy
PaLM
API2GPT
AIGC2D
Anthropic
Baidu
Zhipu
Ali
Xunfei
AI360
OpenRouter
AIProxyLibrary
FastGPT
Tencent
Gemini
Moonshot
Baichuan
Minimax
Mistral
Groq
Ollama
LingYiWanWu
StepFun
Dummy
)

View File

@ -0,0 +1,30 @@
package channeltype
import "github.com/songquanpeng/one-api/relay/apitype"
func ToAPIType(channelType int) int {
apiType := apitype.OpenAI
switch channelType {
case Anthropic:
apiType = apitype.Anthropic
case Baidu:
apiType = apitype.Baidu
case PaLM:
apiType = apitype.PaLM
case Zhipu:
apiType = apitype.Zhipu
case Ali:
apiType = apitype.Ali
case Xunfei:
apiType = apitype.Xunfei
case AIProxyLibrary:
apiType = apitype.AIProxyLibrary
case Tencent:
apiType = apitype.Tencent
case Gemini:
apiType = apitype.Gemini
case Ollama:
apiType = apitype.Ollama
}
return apiType
}

View File

@ -1,48 +0,0 @@
package constant
import (
"github.com/songquanpeng/one-api/common"
)
const (
APITypeOpenAI = iota
APITypeAnthropic
APITypePaLM
APITypeBaidu
APITypeZhipu
APITypeAli
APITypeXunfei
APITypeAIProxyLibrary
APITypeTencent
APITypeGemini
APITypeOllama
APITypeDummy // this one is only for count, do not add any channel after this
)
func ChannelType2APIType(channelType int) int {
apiType := APITypeOpenAI
switch channelType {
case common.ChannelTypeAnthropic:
apiType = APITypeAnthropic
case common.ChannelTypeBaidu:
apiType = APITypeBaidu
case common.ChannelTypePaLM:
apiType = APITypePaLM
case common.ChannelTypeZhipu:
apiType = APITypeZhipu
case common.ChannelTypeAli:
apiType = APITypeAli
case common.ChannelTypeXunfei:
apiType = APITypeXunfei
case common.ChannelTypeAIProxyLibrary:
apiType = APITypeAIProxyLibrary
case common.ChannelTypeTencent:
apiType = APITypeTencent
case common.ChannelTypeGemini:
apiType = APITypeGemini
case common.ChannelTypeOllama:
apiType = APITypeOllama
}
return apiType
}

View File

@ -1,42 +1 @@
package constant package constant
import "strings"
const (
RelayModeUnknown = iota
RelayModeChatCompletions
RelayModeCompletions
RelayModeEmbeddings
RelayModeModerations
RelayModeImagesGenerations
RelayModeEdits
RelayModeAudioSpeech
RelayModeAudioTranscription
RelayModeAudioTranslation
)
func Path2RelayMode(path string) int {
relayMode := RelayModeUnknown
if strings.HasPrefix(path, "/v1/chat/completions") {
relayMode = RelayModeChatCompletions
} else if strings.HasPrefix(path, "/v1/completions") {
relayMode = RelayModeCompletions
} else if strings.HasPrefix(path, "/v1/embeddings") {
relayMode = RelayModeEmbeddings
} else if strings.HasSuffix(path, "embeddings") {
relayMode = RelayModeEmbeddings
} else if strings.HasPrefix(path, "/v1/moderations") {
relayMode = RelayModeModerations
} else if strings.HasPrefix(path, "/v1/images/generations") {
relayMode = RelayModeImagesGenerations
} else if strings.HasPrefix(path, "/v1/edits") {
relayMode = RelayModeEdits
} else if strings.HasPrefix(path, "/v1/audio/speech") {
relayMode = RelayModeAudioSpeech
} else if strings.HasPrefix(path, "/v1/audio/transcriptions") {
relayMode = RelayModeAudioTranscription
} else if strings.HasPrefix(path, "/v1/audio/translations") {
relayMode = RelayModeAudioTranslation
}
return relayMode
}

View File

@ -13,8 +13,9 @@ import (
"github.com/songquanpeng/one-api/common/logger" "github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/model" "github.com/songquanpeng/one-api/model"
"github.com/songquanpeng/one-api/relay/channel/openai" "github.com/songquanpeng/one-api/relay/channel/openai"
"github.com/songquanpeng/one-api/relay/constant" "github.com/songquanpeng/one-api/relay/channeltype"
relaymodel "github.com/songquanpeng/one-api/relay/model" relaymodel "github.com/songquanpeng/one-api/relay/model"
"github.com/songquanpeng/one-api/relay/relaymode"
"github.com/songquanpeng/one-api/relay/util" "github.com/songquanpeng/one-api/relay/util"
"io" "io"
"net/http" "net/http"
@ -33,7 +34,7 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus
tokenName := c.GetString("token_name") tokenName := c.GetString("token_name")
var ttsRequest openai.TextToSpeechRequest var ttsRequest openai.TextToSpeechRequest
if relayMode == constant.RelayModeAudioSpeech { if relayMode == relaymode.AudioSpeech {
// Read JSON // Read JSON
err := common.UnmarshalBodyReusable(c, &ttsRequest) err := common.UnmarshalBodyReusable(c, &ttsRequest)
// Check if JSON is valid // Check if JSON is valid
@ -53,7 +54,7 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus
var quota int64 var quota int64
var preConsumedQuota int64 var preConsumedQuota int64
switch relayMode { switch relayMode {
case constant.RelayModeAudioSpeech: case relaymode.AudioSpeech:
preConsumedQuota = int64(float64(len(ttsRequest.Input)) * ratio) preConsumedQuota = int64(float64(len(ttsRequest.Input)) * ratio)
quota = preConsumedQuota quota = preConsumedQuota
default: default:
@ -122,12 +123,12 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus
} }
fullRequestURL := util.GetFullRequestURL(baseURL, requestURL, channelType) fullRequestURL := util.GetFullRequestURL(baseURL, requestURL, channelType)
if channelType == common.ChannelTypeAzure { if channelType == channeltype.Azure {
apiVersion := util.GetAzureAPIVersion(c) apiVersion := util.GetAzureAPIVersion(c)
if relayMode == constant.RelayModeAudioTranscription { if relayMode == relaymode.AudioTranscription {
// https://learn.microsoft.com/en-us/azure/ai-services/openai/whisper-quickstart?tabs=command-line#rest-api // https://learn.microsoft.com/en-us/azure/ai-services/openai/whisper-quickstart?tabs=command-line#rest-api
fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/audio/transcriptions?api-version=%s", baseURL, audioModel, apiVersion) fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/audio/transcriptions?api-version=%s", baseURL, audioModel, apiVersion)
} else if relayMode == constant.RelayModeAudioSpeech { } else if relayMode == relaymode.AudioSpeech {
// https://learn.microsoft.com/en-us/azure/ai-services/openai/text-to-speech-quickstart?tabs=command-line#rest-api // https://learn.microsoft.com/en-us/azure/ai-services/openai/text-to-speech-quickstart?tabs=command-line#rest-api
fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/audio/speech?api-version=%s", baseURL, audioModel, apiVersion) fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/audio/speech?api-version=%s", baseURL, audioModel, apiVersion)
} }
@ -146,7 +147,7 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus
return openai.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError) return openai.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError)
} }
if (relayMode == constant.RelayModeAudioTranscription || relayMode == constant.RelayModeAudioSpeech) && channelType == common.ChannelTypeAzure { if (relayMode == relaymode.AudioTranscription || relayMode == relaymode.AudioSpeech) && channelType == channeltype.Azure {
// https://learn.microsoft.com/en-us/azure/ai-services/openai/whisper-quickstart?tabs=command-line#rest-api // https://learn.microsoft.com/en-us/azure/ai-services/openai/whisper-quickstart?tabs=command-line#rest-api
apiKey := c.Request.Header.Get("Authorization") apiKey := c.Request.Header.Get("Authorization")
apiKey = strings.TrimPrefix(apiKey, "Bearer ") apiKey = strings.TrimPrefix(apiKey, "Bearer ")
@ -172,7 +173,7 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus
return openai.ErrorWrapper(err, "close_request_body_failed", http.StatusInternalServerError) return openai.ErrorWrapper(err, "close_request_body_failed", http.StatusInternalServerError)
} }
if relayMode != constant.RelayModeAudioSpeech { if relayMode != relaymode.AudioSpeech {
responseBody, err := io.ReadAll(resp.Body) responseBody, err := io.ReadAll(resp.Body)
if err != nil { if err != nil {
return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError) return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)

View File

@ -10,8 +10,10 @@ import (
"github.com/songquanpeng/one-api/common/logger" "github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/model" "github.com/songquanpeng/one-api/model"
"github.com/songquanpeng/one-api/relay/channel/openai" "github.com/songquanpeng/one-api/relay/channel/openai"
"github.com/songquanpeng/one-api/relay/channeltype"
"github.com/songquanpeng/one-api/relay/constant" "github.com/songquanpeng/one-api/relay/constant"
relaymodel "github.com/songquanpeng/one-api/relay/model" relaymodel "github.com/songquanpeng/one-api/relay/model"
"github.com/songquanpeng/one-api/relay/relaymode"
"github.com/songquanpeng/one-api/relay/util" "github.com/songquanpeng/one-api/relay/util"
"math" "math"
"net/http" "net/http"
@ -23,10 +25,10 @@ func getAndValidateTextRequest(c *gin.Context, relayMode int) (*relaymodel.Gener
if err != nil { if err != nil {
return nil, err return nil, err
} }
if relayMode == constant.RelayModeModerations && textRequest.Model == "" { if relayMode == relaymode.Moderations && textRequest.Model == "" {
textRequest.Model = "text-moderation-latest" textRequest.Model = "text-moderation-latest"
} }
if relayMode == constant.RelayModeEmbeddings && textRequest.Model == "" { if relayMode == relaymode.Embeddings && textRequest.Model == "" {
textRequest.Model = c.Param("model") textRequest.Model = c.Param("model")
} }
err = util.ValidateTextRequest(textRequest, relayMode) err = util.ValidateTextRequest(textRequest, relayMode)
@ -86,7 +88,7 @@ func validateImageRequest(imageRequest *relaymodel.ImageRequest, meta *util.Rela
// Number of generated images validation // Number of generated images validation
if !isWithinRange(imageRequest.Model, imageRequest.N) { if !isWithinRange(imageRequest.Model, imageRequest.N) {
// channel not azure // channel not azure
if meta.ChannelType != common.ChannelTypeAzure { if meta.ChannelType != channeltype.Azure {
return openai.ErrorWrapper(errors.New("invalid value of n"), "n_not_within_range", http.StatusBadRequest) return openai.ErrorWrapper(errors.New("invalid value of n"), "n_not_within_range", http.StatusBadRequest)
} }
} }
@ -110,11 +112,11 @@ func getImageCostRatio(imageRequest *relaymodel.ImageRequest) (float64, error) {
func getPromptTokens(textRequest *relaymodel.GeneralOpenAIRequest, relayMode int) int { func getPromptTokens(textRequest *relaymodel.GeneralOpenAIRequest, relayMode int) int {
switch relayMode { switch relayMode {
case constant.RelayModeChatCompletions: case relaymode.ChatCompletions:
return openai.CountTokenMessages(textRequest.Messages, textRequest.Model) return openai.CountTokenMessages(textRequest.Messages, textRequest.Model)
case constant.RelayModeCompletions: case relaymode.Completions:
return openai.CountTokenInput(textRequest.Prompt, textRequest.Model) return openai.CountTokenInput(textRequest.Prompt, textRequest.Model)
case constant.RelayModeModerations: case relaymode.Moderations:
return openai.CountTokenInput(textRequest.Input, textRequest.Model) return openai.CountTokenInput(textRequest.Input, textRequest.Model)
} }
return 0 return 0

View File

@ -11,6 +11,7 @@ import (
"github.com/songquanpeng/one-api/common/logger" "github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/model" "github.com/songquanpeng/one-api/model"
"github.com/songquanpeng/one-api/relay/channel/openai" "github.com/songquanpeng/one-api/relay/channel/openai"
"github.com/songquanpeng/one-api/relay/channeltype"
"github.com/songquanpeng/one-api/relay/constant" "github.com/songquanpeng/one-api/relay/constant"
"github.com/songquanpeng/one-api/relay/helper" "github.com/songquanpeng/one-api/relay/helper"
relaymodel "github.com/songquanpeng/one-api/relay/model" relaymodel "github.com/songquanpeng/one-api/relay/model"
@ -55,7 +56,7 @@ func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus
} }
var requestBody io.Reader var requestBody io.Reader
if isModelMapped || meta.ChannelType == common.ChannelTypeAzure { // make Azure channel request body if isModelMapped || meta.ChannelType == channeltype.Azure { // make Azure channel request body
jsonStr, err := json.Marshal(imageRequest) jsonStr, err := json.Marshal(imageRequest)
if err != nil { if err != nil {
return openai.ErrorWrapper(err, "marshal_image_request_failed", http.StatusInternalServerError) return openai.ErrorWrapper(err, "marshal_image_request_failed", http.StatusInternalServerError)
@ -71,11 +72,11 @@ func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus
} }
switch meta.ChannelType { switch meta.ChannelType {
case common.ChannelTypeAli: case channeltype.Ali:
fallthrough fallthrough
case common.ChannelTypeBaidu: case channeltype.Baidu:
fallthrough fallthrough
case common.ChannelTypeZhipu: case channeltype.Zhipu:
finalRequest, err := adaptor.ConvertImageRequest(imageRequest) finalRequest, err := adaptor.ConvertImageRequest(imageRequest)
if err != nil { if err != nil {
return openai.ErrorWrapper(err, "convert_image_request_failed", http.StatusInternalServerError) return openai.ErrorWrapper(err, "convert_image_request_failed", http.StatusInternalServerError)

View File

@ -7,8 +7,9 @@ import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/logger" "github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/relay/apitype"
"github.com/songquanpeng/one-api/relay/channel/openai" "github.com/songquanpeng/one-api/relay/channel/openai"
"github.com/songquanpeng/one-api/relay/constant" "github.com/songquanpeng/one-api/relay/channeltype"
"github.com/songquanpeng/one-api/relay/helper" "github.com/songquanpeng/one-api/relay/helper"
"github.com/songquanpeng/one-api/relay/model" "github.com/songquanpeng/one-api/relay/model"
"github.com/songquanpeng/one-api/relay/util" "github.com/songquanpeng/one-api/relay/util"
@ -53,9 +54,9 @@ func RelayTextHelper(c *gin.Context) *model.ErrorWithStatusCode {
// get request body // get request body
var requestBody io.Reader var requestBody io.Reader
if meta.APIType == constant.APITypeOpenAI { if meta.APIType == apitype.OpenAI {
// no need to convert request for openai // no need to convert request for openai
shouldResetRequestBody := isModelMapped || meta.ChannelType == common.ChannelTypeBaichuan // frequency_penalty 0 is not acceptable for baichuan shouldResetRequestBody := isModelMapped || meta.ChannelType == channeltype.Baichuan // frequency_penalty 0 is not acceptable for baichuan
if shouldResetRequestBody { if shouldResetRequestBody {
jsonStr, err := json.Marshal(textRequest) jsonStr, err := json.Marshal(textRequest)
if err != nil { if err != nil {

View File

@ -1,6 +1,7 @@
package helper package helper
import ( import (
"github.com/songquanpeng/one-api/relay/apitype"
"github.com/songquanpeng/one-api/relay/channel" "github.com/songquanpeng/one-api/relay/channel"
"github.com/songquanpeng/one-api/relay/channel/aiproxy" "github.com/songquanpeng/one-api/relay/channel/aiproxy"
"github.com/songquanpeng/one-api/relay/channel/ali" "github.com/songquanpeng/one-api/relay/channel/ali"
@ -13,32 +14,31 @@ import (
"github.com/songquanpeng/one-api/relay/channel/tencent" "github.com/songquanpeng/one-api/relay/channel/tencent"
"github.com/songquanpeng/one-api/relay/channel/xunfei" "github.com/songquanpeng/one-api/relay/channel/xunfei"
"github.com/songquanpeng/one-api/relay/channel/zhipu" "github.com/songquanpeng/one-api/relay/channel/zhipu"
"github.com/songquanpeng/one-api/relay/constant"
) )
func GetAdaptor(apiType int) channel.Adaptor { func GetAdaptor(apiType int) channel.Adaptor {
switch apiType { switch apiType {
case constant.APITypeAIProxyLibrary: case apitype.AIProxyLibrary:
return &aiproxy.Adaptor{} return &aiproxy.Adaptor{}
case constant.APITypeAli: case apitype.Ali:
return &ali.Adaptor{} return &ali.Adaptor{}
case constant.APITypeAnthropic: case apitype.Anthropic:
return &anthropic.Adaptor{} return &anthropic.Adaptor{}
case constant.APITypeBaidu: case apitype.Baidu:
return &baidu.Adaptor{} return &baidu.Adaptor{}
case constant.APITypeGemini: case apitype.Gemini:
return &gemini.Adaptor{} return &gemini.Adaptor{}
case constant.APITypeOpenAI: case apitype.OpenAI:
return &openai.Adaptor{} return &openai.Adaptor{}
case constant.APITypePaLM: case apitype.PaLM:
return &palm.Adaptor{} return &palm.Adaptor{}
case constant.APITypeTencent: case apitype.Tencent:
return &tencent.Adaptor{} return &tencent.Adaptor{}
case constant.APITypeXunfei: case apitype.Xunfei:
return &xunfei.Adaptor{} return &xunfei.Adaptor{}
case constant.APITypeZhipu: case apitype.Zhipu:
return &zhipu.Adaptor{} return &zhipu.Adaptor{}
case constant.APITypeOllama: case apitype.Ollama:
return &ollama.Adaptor{} return &ollama.Adaptor{}
} }
return nil return nil

14
relay/relaymode/define.go Normal file
View File

@ -0,0 +1,14 @@
package relaymode
const (
Unknown = iota
ChatCompletions
Completions
Embeddings
Moderations
ImagesGenerations
Edits
AudioSpeech
AudioTranscription
AudioTranslation
)

29
relay/relaymode/helper.go Normal file
View File

@ -0,0 +1,29 @@
package relaymode
import "strings"
func GetByPath(path string) int {
relayMode := Unknown
if strings.HasPrefix(path, "/v1/chat/completions") {
relayMode = ChatCompletions
} else if strings.HasPrefix(path, "/v1/completions") {
relayMode = Completions
} else if strings.HasPrefix(path, "/v1/embeddings") {
relayMode = Embeddings
} else if strings.HasSuffix(path, "embeddings") {
relayMode = Embeddings
} else if strings.HasPrefix(path, "/v1/moderations") {
relayMode = Moderations
} else if strings.HasPrefix(path, "/v1/images/generations") {
relayMode = ImagesGenerations
} else if strings.HasPrefix(path, "/v1/edits") {
relayMode = Edits
} else if strings.HasPrefix(path, "/v1/audio/speech") {
relayMode = AudioSpeech
} else if strings.HasPrefix(path, "/v1/audio/transcriptions") {
relayMode = AudioTranscription
} else if strings.HasPrefix(path, "/v1/audio/translations") {
relayMode = AudioTranslation
}
return relayMode
}

View File

@ -8,6 +8,7 @@ import (
"github.com/songquanpeng/one-api/common/config" "github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/logger" "github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/model" "github.com/songquanpeng/one-api/model"
"github.com/songquanpeng/one-api/relay/channeltype"
relaymodel "github.com/songquanpeng/one-api/relay/model" relaymodel "github.com/songquanpeng/one-api/relay/model"
"io" "io"
"net/http" "net/http"
@ -155,9 +156,9 @@ func GetFullRequestURL(baseURL string, requestURL string, channelType int) strin
if strings.HasPrefix(baseURL, "https://gateway.ai.cloudflare.com") { if strings.HasPrefix(baseURL, "https://gateway.ai.cloudflare.com") {
switch channelType { switch channelType {
case common.ChannelTypeOpenAI: case channeltype.OpenAI:
fullRequestURL = fmt.Sprintf("%s%s", baseURL, strings.TrimPrefix(requestURL, "/v1")) fullRequestURL = fmt.Sprintf("%s%s", baseURL, strings.TrimPrefix(requestURL, "/v1"))
case common.ChannelTypeAzure: case channeltype.Azure:
fullRequestURL = fmt.Sprintf("%s%s", baseURL, strings.TrimPrefix(requestURL, "/openai/deployments")) fullRequestURL = fmt.Sprintf("%s%s", baseURL, strings.TrimPrefix(requestURL, "/openai/deployments"))
} }
} }

View File

@ -3,7 +3,8 @@ package util
import ( import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/relay/constant" "github.com/songquanpeng/one-api/relay/channeltype"
"github.com/songquanpeng/one-api/relay/relaymode"
"strings" "strings"
) )
@ -30,7 +31,7 @@ type RelayMeta struct {
func GetRelayMeta(c *gin.Context) *RelayMeta { func GetRelayMeta(c *gin.Context) *RelayMeta {
meta := RelayMeta{ meta := RelayMeta{
Mode: constant.Path2RelayMode(c.Request.URL.Path), Mode: relaymode.GetByPath(c.Request.URL.Path),
ChannelType: c.GetInt("channel"), ChannelType: c.GetInt("channel"),
ChannelId: c.GetInt("channel_id"), ChannelId: c.GetInt("channel_id"),
TokenId: c.GetInt("token_id"), TokenId: c.GetInt("token_id"),
@ -44,12 +45,12 @@ func GetRelayMeta(c *gin.Context) *RelayMeta {
Config: nil, Config: nil,
RequestURLPath: c.Request.URL.String(), RequestURLPath: c.Request.URL.String(),
} }
if meta.ChannelType == common.ChannelTypeAzure { if meta.ChannelType == channeltype.Azure {
meta.APIVersion = GetAzureAPIVersion(c) meta.APIVersion = GetAzureAPIVersion(c)
} }
if meta.BaseURL == "" { if meta.BaseURL == "" {
meta.BaseURL = common.ChannelBaseURLs[meta.ChannelType] meta.BaseURL = common.ChannelBaseURLs[meta.ChannelType]
} }
meta.APIType = constant.ChannelType2APIType(meta.ChannelType) meta.APIType = channeltype.ToAPIType(meta.ChannelType)
return &meta return &meta
} }

View File

@ -2,8 +2,8 @@ package util
import ( import (
"errors" "errors"
"github.com/songquanpeng/one-api/relay/constant"
"github.com/songquanpeng/one-api/relay/model" "github.com/songquanpeng/one-api/relay/model"
"github.com/songquanpeng/one-api/relay/relaymode"
"math" "math"
) )
@ -15,20 +15,20 @@ func ValidateTextRequest(textRequest *model.GeneralOpenAIRequest, relayMode int)
return errors.New("model is required") return errors.New("model is required")
} }
switch relayMode { switch relayMode {
case constant.RelayModeCompletions: case relaymode.Completions:
if textRequest.Prompt == "" { if textRequest.Prompt == "" {
return errors.New("field prompt is required") return errors.New("field prompt is required")
} }
case constant.RelayModeChatCompletions: case relaymode.ChatCompletions:
if textRequest.Messages == nil || len(textRequest.Messages) == 0 { if textRequest.Messages == nil || len(textRequest.Messages) == 0 {
return errors.New("field messages is required") return errors.New("field messages is required")
} }
case constant.RelayModeEmbeddings: case relaymode.Embeddings:
case constant.RelayModeModerations: case relaymode.Moderations:
if textRequest.Input == "" { if textRequest.Input == "" {
return errors.New("field input is required") return errors.New("field input is required")
} }
case constant.RelayModeEdits: case relaymode.Edits:
if textRequest.Instruction == "" { if textRequest.Instruction == "" {
return errors.New("field instruction is required") return errors.New("field instruction is required")
} }