chore: reorganize constant related package

JustSong 2024-04-06 00:44:33 +08:00
parent 880e12c855
commit f9d914873f
30 changed files with 269 additions and 257 deletions
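
A minimal sketch (not part of the commit) of how a call site reads after this reorganization, assuming the packages introduced below (relay/channeltype, relay/apitype, relay/relaymode) under the module path shown in the imports:

package main

import (
	"fmt"

	"github.com/songquanpeng/one-api/relay/apitype"
	"github.com/songquanpeng/one-api/relay/channeltype"
	"github.com/songquanpeng/one-api/relay/relaymode"
)

func main() {
	// Formerly common.ChannelTypeAzure:
	fmt.Println(channeltype.Azure) // 3
	// Formerly constant.APITypeOpenAI:
	fmt.Println(apitype.OpenAI) // 0
	// Formerly constant.RelayModeChatCompletions:
	fmt.Println(relaymode.ChatCompletions) // 1
}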


@ -38,44 +38,6 @@ const (
ChannelStatusAutoDisabled = 3
)
const (
ChannelTypeUnknown = iota
ChannelTypeOpenAI
ChannelTypeAPI2D
ChannelTypeAzure
ChannelTypeCloseAI
ChannelTypeOpenAISB
ChannelTypeOpenAIMax
ChannelTypeOhMyGPT
ChannelTypeCustom
ChannelTypeAILS
ChannelTypeAIProxy
ChannelTypePaLM
ChannelTypeAPI2GPT
ChannelTypeAIGC2D
ChannelTypeAnthropic
ChannelTypeBaidu
ChannelTypeZhipu
ChannelTypeAli
ChannelTypeXunfei
ChannelType360
ChannelTypeOpenRouter
ChannelTypeAIProxyLibrary
ChannelTypeFastGPT
ChannelTypeTencent
ChannelTypeGemini
ChannelTypeMoonshot
ChannelTypeBaichuan
ChannelTypeMinimax
ChannelTypeMistral
ChannelTypeGroq
ChannelTypeOllama
ChannelTypeLingYiWanWu
ChannelTypeStepFun
ChannelTypeDummy
)
var ChannelBaseURLs = []string{
"", // 0
"https://api.openai.com", // 1


@ -9,6 +9,7 @@ import (
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/model"
"github.com/songquanpeng/one-api/monitor"
"github.com/songquanpeng/one-api/relay/channeltype"
"github.com/songquanpeng/one-api/relay/util"
"io"
"net/http"
@ -209,23 +210,23 @@ func updateChannelBalance(channel *model.Channel) (float64, error) {
channel.BaseURL = &baseURL
}
switch channel.Type {
case common.ChannelTypeOpenAI:
case channeltype.OpenAI:
if channel.GetBaseURL() != "" {
baseURL = channel.GetBaseURL()
}
case common.ChannelTypeAzure:
case channeltype.Azure:
return 0, errors.New("尚未实现")
case common.ChannelTypeCustom:
case channeltype.Custom:
baseURL = channel.GetBaseURL()
case common.ChannelTypeCloseAI:
case channeltype.CloseAI:
return updateChannelCloseAIBalance(channel)
case common.ChannelTypeOpenAISB:
case channeltype.OpenAISB:
return updateChannelOpenAISBBalance(channel)
case common.ChannelTypeAIProxy:
case channeltype.AIProxy:
return updateChannelAIProxyBalance(channel)
case common.ChannelTypeAPI2GPT:
case channeltype.API2GPT:
return updateChannelAPI2GPTBalance(channel)
case common.ChannelTypeAIGC2D:
case channeltype.AIGC2D:
return updateChannelAIGC2DBalance(channel)
default:
return 0, errors.New("尚未实现")
@ -305,7 +306,7 @@ func updateAllChannelsBalance() error {
continue
}
// TODO: support Azure
if channel.Type != common.ChannelTypeOpenAI && channel.Type != common.ChannelTypeCustom {
if channel.Type != channeltype.OpenAI && channel.Type != channeltype.Custom {
continue
}
balance, err := updateChannelBalance(channel)


@ -12,9 +12,10 @@ import (
"github.com/songquanpeng/one-api/middleware"
"github.com/songquanpeng/one-api/model"
"github.com/songquanpeng/one-api/monitor"
"github.com/songquanpeng/one-api/relay/constant"
"github.com/songquanpeng/one-api/relay/channeltype"
"github.com/songquanpeng/one-api/relay/helper"
relaymodel "github.com/songquanpeng/one-api/relay/model"
"github.com/songquanpeng/one-api/relay/relaymode"
"github.com/songquanpeng/one-api/relay/util"
"io"
"net/http"
@ -57,7 +58,7 @@ func testChannel(channel *model.Channel) (err error, openaiErr *relaymodel.Error
c.Set("base_url", channel.GetBaseURL())
middleware.SetupContextForSelectedChannel(c, channel, "")
meta := util.GetRelayMeta(c)
apiType := constant.ChannelType2APIType(channel.Type)
apiType := channeltype.ToAPIType(channel.Type)
adaptor := helper.GetAdaptor(apiType)
if adaptor == nil {
return fmt.Errorf("invalid api type: %d, adaptor is nil", apiType), nil
@ -73,7 +74,7 @@ func testChannel(channel *model.Channel) (err error, openaiErr *relaymodel.Error
request := buildTestRequest()
request.Model = modelName
meta.OriginModelName, meta.ActualModelName = modelName, modelName
convertedRequest, err := adaptor.ConvertRequest(c, constant.RelayModeChatCompletions, request)
convertedRequest, err := adaptor.ConvertRequest(c, relaymode.ChatCompletions, request)
if err != nil {
return err, nil
}


@ -3,10 +3,10 @@ package controller
import (
"fmt"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/model"
"github.com/songquanpeng/one-api/relay/apitype"
"github.com/songquanpeng/one-api/relay/channel/openai"
"github.com/songquanpeng/one-api/relay/constant"
"github.com/songquanpeng/one-api/relay/channeltype"
"github.com/songquanpeng/one-api/relay/helper"
relaymodel "github.com/songquanpeng/one-api/relay/model"
"github.com/songquanpeng/one-api/relay/util"
@ -62,8 +62,8 @@ func init() {
IsBlocking: false,
})
// https://platform.openai.com/docs/models/model-endpoint-compatibility
for i := 0; i < constant.APITypeDummy; i++ {
if i == constant.APITypeAIProxyLibrary {
for i := 0; i < apitype.Dummy; i++ {
if i == apitype.AIProxyLibrary {
continue
}
adaptor := helper.GetAdaptor(i)
@ -82,7 +82,7 @@ func init() {
}
}
for _, channelType := range openai.CompatibleChannels {
if channelType == common.ChannelTypeAzure {
if channelType == channeltype.Azure {
continue
}
channelName, channelModelList := openai.GetCompatibleChannelMeta(channelType)
@ -103,8 +103,8 @@ func init() {
modelsMap[model.Id] = model
}
channelId2Models = make(map[int][]string)
for i := 1; i < common.ChannelTypeDummy; i++ {
adaptor := helper.GetAdaptor(constant.ChannelType2APIType(i))
for i := 1; i < channeltype.Dummy; i++ {
adaptor := helper.GetAdaptor(channeltype.ToAPIType(i))
meta := &util.RelayMeta{
ChannelType: i,
}


@ -12,9 +12,9 @@ import (
"github.com/songquanpeng/one-api/middleware"
dbmodel "github.com/songquanpeng/one-api/model"
"github.com/songquanpeng/one-api/monitor"
"github.com/songquanpeng/one-api/relay/constant"
"github.com/songquanpeng/one-api/relay/controller"
"github.com/songquanpeng/one-api/relay/model"
"github.com/songquanpeng/one-api/relay/relaymode"
"github.com/songquanpeng/one-api/relay/util"
"io"
"net/http"
@ -25,13 +25,13 @@ import (
func relay(c *gin.Context, relayMode int) *model.ErrorWithStatusCode {
var err *model.ErrorWithStatusCode
switch relayMode {
case constant.RelayModeImagesGenerations:
case relaymode.ImagesGenerations:
err = controller.RelayImageHelper(c, relayMode)
case constant.RelayModeAudioSpeech:
case relaymode.AudioSpeech:
fallthrough
case constant.RelayModeAudioTranslation:
case relaymode.AudioTranslation:
fallthrough
case constant.RelayModeAudioTranscription:
case relaymode.AudioTranscription:
err = controller.RelayAudioHelper(c, relayMode)
default:
err = controller.RelayTextHelper(c)
@ -41,7 +41,7 @@ func relay(c *gin.Context, relayMode int) *model.ErrorWithStatusCode {
func Relay(c *gin.Context) {
ctx := c.Request.Context()
relayMode := constant.Path2RelayMode(c.Request.URL.Path)
relayMode := relaymode.GetByPath(c.Request.URL.Path)
if config.DebugEnabled {
requestBody, _ := common.GetRequestBody(c)
logger.Debugf(ctx, "request body: %s", string(requestBody))


@ -28,7 +28,7 @@ func main() {
gin.SetMode(gin.ReleaseMode)
}
if config.DebugEnabled {
logger.SysLog("running in debug mode")
logger.SysLog("running in debug relaymode")
}
var err error
// Initialize SQL Database


@ -6,6 +6,7 @@ import (
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/model"
"github.com/songquanpeng/one-api/relay/channeltype"
"net/http"
"strconv"
)
@ -66,15 +67,15 @@ func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, mode
c.Set("base_url", channel.GetBaseURL())
// this is for backward compatibility
switch channel.Type {
case common.ChannelTypeAzure:
case channeltype.Azure:
c.Set(common.ConfigKeyAPIVersion, channel.Other)
case common.ChannelTypeXunfei:
case channeltype.Xunfei:
c.Set(common.ConfigKeyAPIVersion, channel.Other)
case common.ChannelTypeGemini:
case channeltype.Gemini:
c.Set(common.ConfigKeyAPIVersion, channel.Other)
case common.ChannelTypeAIProxyLibrary:
case channeltype.AIProxyLibrary:
c.Set(common.ConfigKeyLibraryID, channel.Other)
case common.ChannelTypeAli:
case channeltype.Ali:
c.Set(common.ConfigKeyPlugin, channel.Other)
}
cfg, _ := channel.LoadConfig()

relay/apitype/define.go (new file)

@ -0,0 +1,17 @@
package apitype
const (
OpenAI = iota
Anthropic
PaLM
Baidu
Zhipu
Ali
Xunfei
AIProxyLibrary
Tencent
Gemini
Ollama
Dummy // this one is only for count, do not add any channel after this
)
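
Not part of the diff: a small sketch of the pattern the controller uses with this sentinel (see the controller/model.go hunk above), assuming the relay/helper package shown in this commit. Dummy is only a loop bound, so every real API type below it can be enumerated:

package main

import (
	"fmt"

	"github.com/songquanpeng/one-api/relay/apitype"
	"github.com/songquanpeng/one-api/relay/helper"
)

func main() {
	registered := 0
	for i := 0; i < apitype.Dummy; i++ {
		// GetAdaptor returns nil for an unknown API type.
		if helper.GetAdaptor(i) != nil {
			registered++
		}
	}
	fmt.Println(registered, "adaptors registered for", apitype.Dummy, "API types")
}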


@ -6,8 +6,8 @@ import (
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/relay/channel"
"github.com/songquanpeng/one-api/relay/constant"
"github.com/songquanpeng/one-api/relay/model"
"github.com/songquanpeng/one-api/relay/relaymode"
"github.com/songquanpeng/one-api/relay/util"
"io"
"net/http"
@ -25,9 +25,9 @@ func (a *Adaptor) Init(meta *util.RelayMeta) {
func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) {
fullRequestURL := ""
switch meta.Mode {
case constant.RelayModeEmbeddings:
case relaymode.Embeddings:
fullRequestURL = fmt.Sprintf("%s/api/v1/services/embeddings/text-embedding/text-embedding", meta.BaseURL)
case constant.RelayModeImagesGenerations:
case relaymode.ImagesGenerations:
fullRequestURL = fmt.Sprintf("%s/api/v1/services/aigc/text2image/image-synthesis", meta.BaseURL)
default:
fullRequestURL = fmt.Sprintf("%s/api/v1/services/aigc/text-generation/generation", meta.BaseURL)
@ -44,7 +44,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *ut
}
req.Header.Set("Authorization", "Bearer "+meta.APIKey)
if meta.Mode == constant.RelayModeImagesGenerations {
if meta.Mode == relaymode.ImagesGenerations {
req.Header.Set("X-DashScope-Async", "enable")
}
if c.GetString(common.ConfigKeyPlugin) != "" {
@ -58,7 +58,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G
return nil, errors.New("request is nil")
}
switch relayMode {
case constant.RelayModeEmbeddings:
case relaymode.Embeddings:
aliEmbeddingRequest := ConvertEmbeddingRequest(*request)
return aliEmbeddingRequest, nil
default:
@ -85,9 +85,9 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.Rel
err, usage = StreamHandler(c, resp)
} else {
switch meta.Mode {
case constant.RelayModeEmbeddings:
case relaymode.Embeddings:
err, usage = EmbeddingHandler(c, resp)
case constant.RelayModeImagesGenerations:
case relaymode.ImagesGenerations:
err, usage = ImageHandler(c, resp)
default:
err, usage = Handler(c, resp)


@ -3,13 +3,13 @@ package baidu
import (
"errors"
"fmt"
"github.com/songquanpeng/one-api/relay/relaymode"
"io"
"net/http"
"strings"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/relay/channel"
"github.com/songquanpeng/one-api/relay/constant"
"github.com/songquanpeng/one-api/relay/model"
"github.com/songquanpeng/one-api/relay/util"
)
@ -100,7 +100,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G
return nil, errors.New("request is nil")
}
switch relayMode {
case constant.RelayModeEmbeddings:
case relaymode.Embeddings:
baiduEmbeddingRequest := ConvertEmbeddingRequest(*request)
return baiduEmbeddingRequest, nil
default:
@ -125,7 +125,7 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.Rel
err, usage = StreamHandler(c, resp)
} else {
switch meta.Mode {
case constant.RelayModeEmbeddings:
case relaymode.Embeddings:
err, usage = EmbeddingHandler(c, resp)
default:
err, usage = Handler(c, resp)


@ -2,13 +2,13 @@ package minimax
import (
"fmt"
"github.com/songquanpeng/one-api/relay/constant"
"github.com/songquanpeng/one-api/relay/relaymode"
"github.com/songquanpeng/one-api/relay/util"
)
func GetRequestURL(meta *util.RelayMeta) (string, error) {
if meta.Mode == constant.RelayModeChatCompletions {
if meta.Mode == relaymode.ChatCompletions {
return fmt.Sprintf("%s/v1/text/chatcompletion_v2", meta.BaseURL), nil
}
return "", fmt.Errorf("unsupported relay mode %d for minimax", meta.Mode)
return "", fmt.Errorf("unsupported relay relaymode %d for minimax", meta.Mode)
}


@ -3,12 +3,12 @@ package ollama
import (
"errors"
"fmt"
"github.com/songquanpeng/one-api/relay/relaymode"
"io"
"net/http"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/relay/channel"
"github.com/songquanpeng/one-api/relay/constant"
"github.com/songquanpeng/one-api/relay/model"
"github.com/songquanpeng/one-api/relay/util"
)
@ -23,7 +23,7 @@ func (a *Adaptor) Init(meta *util.RelayMeta) {
func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) {
// https://github.com/ollama/ollama/blob/main/docs/api.md
fullRequestURL := fmt.Sprintf("%s/api/chat", meta.BaseURL)
if meta.Mode == constant.RelayModeEmbeddings {
if meta.Mode == relaymode.Embeddings {
fullRequestURL = fmt.Sprintf("%s/api/embeddings", meta.BaseURL)
}
return fullRequestURL, nil
@ -40,7 +40,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G
return nil, errors.New("request is nil")
}
switch relayMode {
case constant.RelayModeEmbeddings:
case relaymode.Embeddings:
ollamaEmbeddingRequest := ConvertEmbeddingRequest(*request)
return ollamaEmbeddingRequest, nil
default:
@ -64,7 +64,7 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.Rel
err, usage = StreamHandler(c, resp)
} else {
switch meta.Mode {
case constant.RelayModeEmbeddings:
case relaymode.Embeddings:
err, usage = EmbeddingHandler(c, resp)
default:
err, usage = Handler(c, resp)


@ -4,11 +4,11 @@ import (
"errors"
"fmt"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/relay/channel"
"github.com/songquanpeng/one-api/relay/channel/minimax"
"github.com/songquanpeng/one-api/relay/constant"
"github.com/songquanpeng/one-api/relay/channeltype"
"github.com/songquanpeng/one-api/relay/model"
"github.com/songquanpeng/one-api/relay/relaymode"
"github.com/songquanpeng/one-api/relay/util"
"io"
"net/http"
@ -25,8 +25,8 @@ func (a *Adaptor) Init(meta *util.RelayMeta) {
func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) {
switch meta.ChannelType {
case common.ChannelTypeAzure:
if meta.Mode == constant.RelayModeImagesGenerations {
case channeltype.Azure:
if meta.Mode == relaymode.ImagesGenerations {
// https://learn.microsoft.com/en-us/azure/ai-services/openai/dall-e-quickstart?tabs=dalle3%2Ccommand-line&pivots=rest-api
// https://{resource_name}.openai.azure.com/openai/deployments/dall-e-3/images/generations?api-version=2024-03-01-preview
fullRequestURL := fmt.Sprintf("%s/openai/deployments/%s/images/generations?api-version=%s", meta.BaseURL, meta.ActualModelName, meta.APIVersion)
@ -43,7 +43,7 @@ func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) {
// {your endpoint}/openai/deployments/{your azure_model}/chat/completions?api-version={api_version}
requestURL = fmt.Sprintf("/openai/deployments/%s/%s", model_, task)
return util.GetFullRequestURL(meta.BaseURL, requestURL, meta.ChannelType), nil
case common.ChannelTypeMinimax:
case channeltype.Minimax:
return minimax.GetRequestURL(meta)
default:
return util.GetFullRequestURL(meta.BaseURL, meta.RequestURLPath, meta.ChannelType), nil
@ -52,12 +52,12 @@ func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) {
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *util.RelayMeta) error {
channel.SetupCommonRequestHeader(c, req, meta)
if meta.ChannelType == common.ChannelTypeAzure {
if meta.ChannelType == channeltype.Azure {
req.Header.Set("api-key", meta.APIKey)
return nil
}
req.Header.Set("Authorization", "Bearer "+meta.APIKey)
if meta.ChannelType == common.ChannelTypeOpenRouter {
if meta.ChannelType == channeltype.OpenRouter {
req.Header.Set("HTTP-Referer", "https://github.com/songquanpeng/one-api")
req.Header.Set("X-Title", "One API")
}
@ -91,7 +91,7 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.Rel
}
} else {
switch meta.Mode {
case constant.RelayModeImagesGenerations:
case relaymode.ImagesGenerations:
err, _ = ImageHandler(c, resp)
default:
err, usage = Handler(c, resp, meta.PromptTokens, meta.ActualModelName)


@ -1,7 +1,6 @@
package openai
import (
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/relay/channel/ai360"
"github.com/songquanpeng/one-api/relay/channel/baichuan"
"github.com/songquanpeng/one-api/relay/channel/groq"
@ -10,39 +9,40 @@ import (
"github.com/songquanpeng/one-api/relay/channel/mistral"
"github.com/songquanpeng/one-api/relay/channel/moonshot"
"github.com/songquanpeng/one-api/relay/channel/stepfun"
"github.com/songquanpeng/one-api/relay/channeltype"
)
var CompatibleChannels = []int{
common.ChannelTypeAzure,
common.ChannelType360,
common.ChannelTypeMoonshot,
common.ChannelTypeBaichuan,
common.ChannelTypeMinimax,
common.ChannelTypeMistral,
common.ChannelTypeGroq,
common.ChannelTypeLingYiWanWu,
common.ChannelTypeStepFun,
channeltype.Azure,
channeltype.AI360,
channeltype.Moonshot,
channeltype.Baichuan,
channeltype.Minimax,
channeltype.Mistral,
channeltype.Groq,
channeltype.LingYiWanWu,
channeltype.StepFun,
}
func GetCompatibleChannelMeta(channelType int) (string, []string) {
switch channelType {
case common.ChannelTypeAzure:
case channeltype.Azure:
return "azure", ModelList
case common.ChannelType360:
case channeltype.AI360:
return "360", ai360.ModelList
case common.ChannelTypeMoonshot:
case channeltype.Moonshot:
return "moonshot", moonshot.ModelList
case common.ChannelTypeBaichuan:
case channeltype.Baichuan:
return "baichuan", baichuan.ModelList
case common.ChannelTypeMinimax:
case channeltype.Minimax:
return "minimax", minimax.ModelList
case common.ChannelTypeMistral:
case channeltype.Mistral:
return "mistralai", mistral.ModelList
case common.ChannelTypeGroq:
case channeltype.Groq:
return "groq", groq.ModelList
case common.ChannelTypeLingYiWanWu:
case channeltype.LingYiWanWu:
return "lingyiwanwu", lingyiwanwu.ModelList
case common.ChannelTypeStepFun:
case channeltype.StepFun:
return "stepfun", stepfun.ModelList
default:
return "openai", ModelList


@ -8,8 +8,8 @@ import (
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/conv"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/relay/constant"
"github.com/songquanpeng/one-api/relay/model"
"github.com/songquanpeng/one-api/relay/relaymode"
"io"
"net/http"
"strings"
@ -46,7 +46,7 @@ func StreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*model.E
data = data[6:]
if !strings.HasPrefix(data, "[DONE]") {
switch relayMode {
case constant.RelayModeChatCompletions:
case relaymode.ChatCompletions:
var streamResponse ChatCompletionsStreamResponse
err := json.Unmarshal([]byte(data), &streamResponse)
if err != nil {
@ -59,7 +59,7 @@ func StreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*model.E
if streamResponse.Usage != nil {
usage = streamResponse.Usage
}
case constant.RelayModeCompletions:
case relaymode.Completions:
var streamResponse CompletionsStreamResponse
err := json.Unmarshal([]byte(data), &streamResponse)
if err != nil {


@ -6,8 +6,8 @@ import (
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/relay/channel"
"github.com/songquanpeng/one-api/relay/channel/openai"
"github.com/songquanpeng/one-api/relay/constant"
"github.com/songquanpeng/one-api/relay/model"
"github.com/songquanpeng/one-api/relay/relaymode"
"github.com/songquanpeng/one-api/relay/util"
"io"
"math"
@ -33,9 +33,9 @@ func (a *Adaptor) SetVersionByModeName(modelName string) {
func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) {
switch meta.Mode {
case constant.RelayModeImagesGenerations:
case relaymode.ImagesGenerations:
return fmt.Sprintf("%s/api/paas/v4/images/generations", meta.BaseURL), nil
case constant.RelayModeEmbeddings:
case relaymode.Embeddings:
return fmt.Sprintf("%s/api/paas/v4/embeddings", meta.BaseURL), nil
}
a.SetVersionByModeName(meta.ActualModelName)
@ -61,7 +61,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G
return nil, errors.New("request is nil")
}
switch relayMode {
case constant.RelayModeEmbeddings:
case relaymode.Embeddings:
baiduEmbeddingRequest := ConvertEmbeddingRequest(*request)
return baiduEmbeddingRequest, nil
default:
@ -107,10 +107,10 @@ func (a *Adaptor) DoResponseV4(c *gin.Context, resp *http.Response, meta *util.R
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
switch meta.Mode {
case constant.RelayModeEmbeddings:
case relaymode.Embeddings:
err, usage = EmbeddingsHandler(c, resp)
return
case constant.RelayModeImagesGenerations:
case relaymode.ImagesGenerations:
err, usage = openai.ImageHandler(c, resp)
return
}
@ -120,7 +120,7 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.Rel
if meta.IsStream {
err, usage = StreamHandler(c, resp)
} else {
if meta.Mode == constant.RelayModeEmbeddings {
if meta.Mode == relaymode.Embeddings {
err, usage = EmbeddingsHandler(c, resp)
} else {
err, usage = Handler(c, resp)


@ -0,0 +1,39 @@
package channeltype
const (
Unknown = iota
OpenAI
API2D
Azure
CloseAI
OpenAISB
OpenAIMax
OhMyGPT
Custom
Ails
AIProxy
PaLM
API2GPT
AIGC2D
Anthropic
Baidu
Zhipu
Ali
Xunfei
AI360
OpenRouter
AIProxyLibrary
FastGPT
Tencent
Gemini
Moonshot
Baichuan
Minimax
Mistral
Groq
Ollama
LingYiWanWu
StepFun
Dummy
)
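
Not part of the diff: the iota order above mirrors the common.ChannelType* block removed from the common package, so existing channel type integers keep their meaning and still index common.ChannelBaseURLs (see the relay/util hunk below). A minimal sketch under that assumption:

package main

import (
	"fmt"

	"github.com/songquanpeng/one-api/common"
	"github.com/songquanpeng/one-api/relay/channeltype"
)

func main() {
	// channeltype.OpenAI is 1, the same value as the old common.ChannelTypeOpenAI,
	// so the base URL lookup is unchanged by this commit.
	fmt.Println(common.ChannelBaseURLs[channeltype.OpenAI]) // https://api.openai.com
}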


@ -0,0 +1,30 @@
package channeltype
import "github.com/songquanpeng/one-api/relay/apitype"
func ToAPIType(channelType int) int {
apiType := apitype.OpenAI
switch channelType {
case Anthropic:
apiType = apitype.Anthropic
case Baidu:
apiType = apitype.Baidu
case PaLM:
apiType = apitype.PaLM
case Zhipu:
apiType = apitype.Zhipu
case Ali:
apiType = apitype.Ali
case Xunfei:
apiType = apitype.Xunfei
case AIProxyLibrary:
apiType = apitype.AIProxyLibrary
case Tencent:
apiType = apitype.Tencent
case Gemini:
apiType = apitype.Gemini
case Ollama:
apiType = apitype.Ollama
}
return apiType
}
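
Not part of the diff: channel types without a case in the switch above fall through to apitype.OpenAI, which is how the OpenAI-compatible channels listed in openai.CompatibleChannels reach the OpenAI adaptor. A quick sketch, assuming the definitions shown above:

package main

import (
	"fmt"

	"github.com/songquanpeng/one-api/relay/apitype"
	"github.com/songquanpeng/one-api/relay/channeltype"
)

func main() {
	// Dedicated adaptor: mapped explicitly.
	fmt.Println(channeltype.ToAPIType(channeltype.Xunfei) == apitype.Xunfei) // true
	// No dedicated adaptor: Azure and Moonshot both resolve to the OpenAI API type.
	fmt.Println(channeltype.ToAPIType(channeltype.Azure) == apitype.OpenAI)    // true
	fmt.Println(channeltype.ToAPIType(channeltype.Moonshot) == apitype.OpenAI) // true
}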


@ -1,48 +0,0 @@
package constant
import (
"github.com/songquanpeng/one-api/common"
)
const (
APITypeOpenAI = iota
APITypeAnthropic
APITypePaLM
APITypeBaidu
APITypeZhipu
APITypeAli
APITypeXunfei
APITypeAIProxyLibrary
APITypeTencent
APITypeGemini
APITypeOllama
APITypeDummy // this one is only for count, do not add any channel after this
)
func ChannelType2APIType(channelType int) int {
apiType := APITypeOpenAI
switch channelType {
case common.ChannelTypeAnthropic:
apiType = APITypeAnthropic
case common.ChannelTypeBaidu:
apiType = APITypeBaidu
case common.ChannelTypePaLM:
apiType = APITypePaLM
case common.ChannelTypeZhipu:
apiType = APITypeZhipu
case common.ChannelTypeAli:
apiType = APITypeAli
case common.ChannelTypeXunfei:
apiType = APITypeXunfei
case common.ChannelTypeAIProxyLibrary:
apiType = APITypeAIProxyLibrary
case common.ChannelTypeTencent:
apiType = APITypeTencent
case common.ChannelTypeGemini:
apiType = APITypeGemini
case common.ChannelTypeOllama:
apiType = APITypeOllama
}
return apiType
}


@ -1,42 +1 @@
package constant
import "strings"
const (
RelayModeUnknown = iota
RelayModeChatCompletions
RelayModeCompletions
RelayModeEmbeddings
RelayModeModerations
RelayModeImagesGenerations
RelayModeEdits
RelayModeAudioSpeech
RelayModeAudioTranscription
RelayModeAudioTranslation
)
func Path2RelayMode(path string) int {
relayMode := RelayModeUnknown
if strings.HasPrefix(path, "/v1/chat/completions") {
relayMode = RelayModeChatCompletions
} else if strings.HasPrefix(path, "/v1/completions") {
relayMode = RelayModeCompletions
} else if strings.HasPrefix(path, "/v1/embeddings") {
relayMode = RelayModeEmbeddings
} else if strings.HasSuffix(path, "embeddings") {
relayMode = RelayModeEmbeddings
} else if strings.HasPrefix(path, "/v1/moderations") {
relayMode = RelayModeModerations
} else if strings.HasPrefix(path, "/v1/images/generations") {
relayMode = RelayModeImagesGenerations
} else if strings.HasPrefix(path, "/v1/edits") {
relayMode = RelayModeEdits
} else if strings.HasPrefix(path, "/v1/audio/speech") {
relayMode = RelayModeAudioSpeech
} else if strings.HasPrefix(path, "/v1/audio/transcriptions") {
relayMode = RelayModeAudioTranscription
} else if strings.HasPrefix(path, "/v1/audio/translations") {
relayMode = RelayModeAudioTranslation
}
return relayMode
}


@ -13,8 +13,9 @@ import (
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/model"
"github.com/songquanpeng/one-api/relay/channel/openai"
"github.com/songquanpeng/one-api/relay/constant"
"github.com/songquanpeng/one-api/relay/channeltype"
relaymodel "github.com/songquanpeng/one-api/relay/model"
"github.com/songquanpeng/one-api/relay/relaymode"
"github.com/songquanpeng/one-api/relay/util"
"io"
"net/http"
@ -33,7 +34,7 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus
tokenName := c.GetString("token_name")
var ttsRequest openai.TextToSpeechRequest
if relayMode == constant.RelayModeAudioSpeech {
if relayMode == relaymode.AudioSpeech {
// Read JSON
err := common.UnmarshalBodyReusable(c, &ttsRequest)
// Check if JSON is valid
@ -53,7 +54,7 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus
var quota int64
var preConsumedQuota int64
switch relayMode {
case constant.RelayModeAudioSpeech:
case relaymode.AudioSpeech:
preConsumedQuota = int64(float64(len(ttsRequest.Input)) * ratio)
quota = preConsumedQuota
default:
@ -122,12 +123,12 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus
}
fullRequestURL := util.GetFullRequestURL(baseURL, requestURL, channelType)
if channelType == common.ChannelTypeAzure {
if channelType == channeltype.Azure {
apiVersion := util.GetAzureAPIVersion(c)
if relayMode == constant.RelayModeAudioTranscription {
if relayMode == relaymode.AudioTranscription {
// https://learn.microsoft.com/en-us/azure/ai-services/openai/whisper-quickstart?tabs=command-line#rest-api
fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/audio/transcriptions?api-version=%s", baseURL, audioModel, apiVersion)
} else if relayMode == constant.RelayModeAudioSpeech {
} else if relayMode == relaymode.AudioSpeech {
// https://learn.microsoft.com/en-us/azure/ai-services/openai/text-to-speech-quickstart?tabs=command-line#rest-api
fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/audio/speech?api-version=%s", baseURL, audioModel, apiVersion)
}
@ -146,7 +147,7 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus
return openai.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError)
}
if (relayMode == constant.RelayModeAudioTranscription || relayMode == constant.RelayModeAudioSpeech) && channelType == common.ChannelTypeAzure {
if (relayMode == relaymode.AudioTranscription || relayMode == relaymode.AudioSpeech) && channelType == channeltype.Azure {
// https://learn.microsoft.com/en-us/azure/ai-services/openai/whisper-quickstart?tabs=command-line#rest-api
apiKey := c.Request.Header.Get("Authorization")
apiKey = strings.TrimPrefix(apiKey, "Bearer ")
@ -172,7 +173,7 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus
return openai.ErrorWrapper(err, "close_request_body_failed", http.StatusInternalServerError)
}
if relayMode != constant.RelayModeAudioSpeech {
if relayMode != relaymode.AudioSpeech {
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)


@ -10,8 +10,10 @@ import (
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/model"
"github.com/songquanpeng/one-api/relay/channel/openai"
"github.com/songquanpeng/one-api/relay/channeltype"
"github.com/songquanpeng/one-api/relay/constant"
relaymodel "github.com/songquanpeng/one-api/relay/model"
"github.com/songquanpeng/one-api/relay/relaymode"
"github.com/songquanpeng/one-api/relay/util"
"math"
"net/http"
@ -23,10 +25,10 @@ func getAndValidateTextRequest(c *gin.Context, relayMode int) (*relaymodel.Gener
if err != nil {
return nil, err
}
if relayMode == constant.RelayModeModerations && textRequest.Model == "" {
if relayMode == relaymode.Moderations && textRequest.Model == "" {
textRequest.Model = "text-moderation-latest"
}
if relayMode == constant.RelayModeEmbeddings && textRequest.Model == "" {
if relayMode == relaymode.Embeddings && textRequest.Model == "" {
textRequest.Model = c.Param("model")
}
err = util.ValidateTextRequest(textRequest, relayMode)
@ -86,7 +88,7 @@ func validateImageRequest(imageRequest *relaymodel.ImageRequest, meta *util.Rela
// Number of generated images validation
if !isWithinRange(imageRequest.Model, imageRequest.N) {
// channel not azure
if meta.ChannelType != common.ChannelTypeAzure {
if meta.ChannelType != channeltype.Azure {
return openai.ErrorWrapper(errors.New("invalid value of n"), "n_not_within_range", http.StatusBadRequest)
}
}
@ -110,11 +112,11 @@ func getImageCostRatio(imageRequest *relaymodel.ImageRequest) (float64, error) {
func getPromptTokens(textRequest *relaymodel.GeneralOpenAIRequest, relayMode int) int {
switch relayMode {
case constant.RelayModeChatCompletions:
case relaymode.ChatCompletions:
return openai.CountTokenMessages(textRequest.Messages, textRequest.Model)
case constant.RelayModeCompletions:
case relaymode.Completions:
return openai.CountTokenInput(textRequest.Prompt, textRequest.Model)
case constant.RelayModeModerations:
case relaymode.Moderations:
return openai.CountTokenInput(textRequest.Input, textRequest.Model)
}
return 0


@ -11,6 +11,7 @@ import (
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/model"
"github.com/songquanpeng/one-api/relay/channel/openai"
"github.com/songquanpeng/one-api/relay/channeltype"
"github.com/songquanpeng/one-api/relay/constant"
"github.com/songquanpeng/one-api/relay/helper"
relaymodel "github.com/songquanpeng/one-api/relay/model"
@ -55,7 +56,7 @@ func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus
}
var requestBody io.Reader
if isModelMapped || meta.ChannelType == common.ChannelTypeAzure { // make Azure channel request body
if isModelMapped || meta.ChannelType == channeltype.Azure { // make Azure channel request body
jsonStr, err := json.Marshal(imageRequest)
if err != nil {
return openai.ErrorWrapper(err, "marshal_image_request_failed", http.StatusInternalServerError)
@ -71,11 +72,11 @@ func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus
}
switch meta.ChannelType {
case common.ChannelTypeAli:
case channeltype.Ali:
fallthrough
case common.ChannelTypeBaidu:
case channeltype.Baidu:
fallthrough
case common.ChannelTypeZhipu:
case channeltype.Zhipu:
finalRequest, err := adaptor.ConvertImageRequest(imageRequest)
if err != nil {
return openai.ErrorWrapper(err, "convert_image_request_failed", http.StatusInternalServerError)


@ -7,8 +7,9 @@ import (
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/relay/apitype"
"github.com/songquanpeng/one-api/relay/channel/openai"
"github.com/songquanpeng/one-api/relay/constant"
"github.com/songquanpeng/one-api/relay/channeltype"
"github.com/songquanpeng/one-api/relay/helper"
"github.com/songquanpeng/one-api/relay/model"
"github.com/songquanpeng/one-api/relay/util"
@ -53,9 +54,9 @@ func RelayTextHelper(c *gin.Context) *model.ErrorWithStatusCode {
// get request body
var requestBody io.Reader
if meta.APIType == constant.APITypeOpenAI {
if meta.APIType == apitype.OpenAI {
// no need to convert request for openai
shouldResetRequestBody := isModelMapped || meta.ChannelType == common.ChannelTypeBaichuan // frequency_penalty 0 is not acceptable for baichuan
shouldResetRequestBody := isModelMapped || meta.ChannelType == channeltype.Baichuan // frequency_penalty 0 is not acceptable for baichuan
if shouldResetRequestBody {
jsonStr, err := json.Marshal(textRequest)
if err != nil {


@ -1,6 +1,7 @@
package helper
import (
"github.com/songquanpeng/one-api/relay/apitype"
"github.com/songquanpeng/one-api/relay/channel"
"github.com/songquanpeng/one-api/relay/channel/aiproxy"
"github.com/songquanpeng/one-api/relay/channel/ali"
@ -13,32 +14,31 @@ import (
"github.com/songquanpeng/one-api/relay/channel/tencent"
"github.com/songquanpeng/one-api/relay/channel/xunfei"
"github.com/songquanpeng/one-api/relay/channel/zhipu"
"github.com/songquanpeng/one-api/relay/constant"
)
func GetAdaptor(apiType int) channel.Adaptor {
switch apiType {
case constant.APITypeAIProxyLibrary:
case apitype.AIProxyLibrary:
return &aiproxy.Adaptor{}
case constant.APITypeAli:
case apitype.Ali:
return &ali.Adaptor{}
case constant.APITypeAnthropic:
case apitype.Anthropic:
return &anthropic.Adaptor{}
case constant.APITypeBaidu:
case apitype.Baidu:
return &baidu.Adaptor{}
case constant.APITypeGemini:
case apitype.Gemini:
return &gemini.Adaptor{}
case constant.APITypeOpenAI:
case apitype.OpenAI:
return &openai.Adaptor{}
case constant.APITypePaLM:
case apitype.PaLM:
return &palm.Adaptor{}
case constant.APITypeTencent:
case apitype.Tencent:
return &tencent.Adaptor{}
case constant.APITypeXunfei:
case apitype.Xunfei:
return &xunfei.Adaptor{}
case constant.APITypeZhipu:
case apitype.Zhipu:
return &zhipu.Adaptor{}
case constant.APITypeOllama:
case apitype.Ollama:
return &ollama.Adaptor{}
}
return nil

relay/relaymode/define.go (new file)

@ -0,0 +1,14 @@
package relaymode
const (
Unknown = iota
ChatCompletions
Completions
Embeddings
Moderations
ImagesGenerations
Edits
AudioSpeech
AudioTranscription
AudioTranslation
)

relay/relaymode/helper.go (new file)

@ -0,0 +1,29 @@
package relaymode
import "strings"
func GetByPath(path string) int {
relayMode := Unknown
if strings.HasPrefix(path, "/v1/chat/completions") {
relayMode = ChatCompletions
} else if strings.HasPrefix(path, "/v1/completions") {
relayMode = Completions
} else if strings.HasPrefix(path, "/v1/embeddings") {
relayMode = Embeddings
} else if strings.HasSuffix(path, "embeddings") {
relayMode = Embeddings
} else if strings.HasPrefix(path, "/v1/moderations") {
relayMode = Moderations
} else if strings.HasPrefix(path, "/v1/images/generations") {
relayMode = ImagesGenerations
} else if strings.HasPrefix(path, "/v1/edits") {
relayMode = Edits
} else if strings.HasPrefix(path, "/v1/audio/speech") {
relayMode = AudioSpeech
} else if strings.HasPrefix(path, "/v1/audio/transcriptions") {
relayMode = AudioTranscription
} else if strings.HasPrefix(path, "/v1/audio/translations") {
relayMode = AudioTranslation
}
return relayMode
}
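
Not part of the diff: a short sketch of the path-to-mode mapping this helper provides (it replaces constant.Path2RelayMode, as the call sites above show):

package main

import (
	"fmt"

	"github.com/songquanpeng/one-api/relay/relaymode"
)

func main() {
	for _, path := range []string{
		"/v1/chat/completions",
		"/v1/embeddings",
		"/v1/audio/transcriptions",
		"/unknown",
	} {
		// Prints ChatCompletions, Embeddings, AudioTranscription, Unknown as ints (1, 3, 8, 0).
		fmt.Println(path, relaymode.GetByPath(path))
	}
}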


@ -8,6 +8,7 @@ import (
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/model"
"github.com/songquanpeng/one-api/relay/channeltype"
relaymodel "github.com/songquanpeng/one-api/relay/model"
"io"
"net/http"
@ -155,9 +156,9 @@ func GetFullRequestURL(baseURL string, requestURL string, channelType int) strin
if strings.HasPrefix(baseURL, "https://gateway.ai.cloudflare.com") {
switch channelType {
case common.ChannelTypeOpenAI:
case channeltype.OpenAI:
fullRequestURL = fmt.Sprintf("%s%s", baseURL, strings.TrimPrefix(requestURL, "/v1"))
case common.ChannelTypeAzure:
case channeltype.Azure:
fullRequestURL = fmt.Sprintf("%s%s", baseURL, strings.TrimPrefix(requestURL, "/openai/deployments"))
}
}


@ -3,7 +3,8 @@ package util
import (
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/relay/constant"
"github.com/songquanpeng/one-api/relay/channeltype"
"github.com/songquanpeng/one-api/relay/relaymode"
"strings"
)
@ -30,7 +31,7 @@ type RelayMeta struct {
func GetRelayMeta(c *gin.Context) *RelayMeta {
meta := RelayMeta{
Mode: constant.Path2RelayMode(c.Request.URL.Path),
Mode: relaymode.GetByPath(c.Request.URL.Path),
ChannelType: c.GetInt("channel"),
ChannelId: c.GetInt("channel_id"),
TokenId: c.GetInt("token_id"),
@ -44,12 +45,12 @@ func GetRelayMeta(c *gin.Context) *RelayMeta {
Config: nil,
RequestURLPath: c.Request.URL.String(),
}
if meta.ChannelType == common.ChannelTypeAzure {
if meta.ChannelType == channeltype.Azure {
meta.APIVersion = GetAzureAPIVersion(c)
}
if meta.BaseURL == "" {
meta.BaseURL = common.ChannelBaseURLs[meta.ChannelType]
}
meta.APIType = constant.ChannelType2APIType(meta.ChannelType)
meta.APIType = channeltype.ToAPIType(meta.ChannelType)
return &meta
}


@ -2,8 +2,8 @@ package util
import (
"errors"
"github.com/songquanpeng/one-api/relay/constant"
"github.com/songquanpeng/one-api/relay/model"
"github.com/songquanpeng/one-api/relay/relaymode"
"math"
)
@ -15,20 +15,20 @@ func ValidateTextRequest(textRequest *model.GeneralOpenAIRequest, relayMode int)
return errors.New("model is required")
}
switch relayMode {
case constant.RelayModeCompletions:
case relaymode.Completions:
if textRequest.Prompt == "" {
return errors.New("field prompt is required")
}
case constant.RelayModeChatCompletions:
case relaymode.ChatCompletions:
if textRequest.Messages == nil || len(textRequest.Messages) == 0 {
return errors.New("field messages is required")
}
case constant.RelayModeEmbeddings:
case constant.RelayModeModerations:
case relaymode.Embeddings:
case relaymode.Moderations:
if textRequest.Input == "" {
return errors.New("field input is required")
}
case constant.RelayModeEdits:
case relaymode.Edits:
if textRequest.Instruction == "" {
return errors.New("field instruction is required")
}