♻️ refactor: 重构moderation接口
This commit is contained in:
parent
455269c145
commit
1c7c2d40bb
3
.gitignore
vendored
3
.gitignore
vendored
@ -6,4 +6,5 @@ upload
|
|||||||
build
|
build
|
||||||
*.db-journal
|
*.db-journal
|
||||||
logs
|
logs
|
||||||
data
|
data
|
||||||
|
tmp/
|
@ -214,3 +214,16 @@ var ChannelBaseURLs = []string{
|
|||||||
"https://fastgpt.run/api/openapi", // 22
|
"https://fastgpt.run/api/openapi", // 22
|
||||||
"https://hunyuan.cloud.tencent.com", //23
|
"https://hunyuan.cloud.tencent.com", //23
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const (
|
||||||
|
RelayModeUnknown = iota
|
||||||
|
RelayModeChatCompletions
|
||||||
|
RelayModeCompletions
|
||||||
|
RelayModeEmbeddings
|
||||||
|
RelayModeModerations
|
||||||
|
RelayModeImagesGenerations
|
||||||
|
RelayModeEdits
|
||||||
|
RelayModeAudioSpeech
|
||||||
|
RelayModeAudioTranscription
|
||||||
|
RelayModeAudioTranslation
|
||||||
|
)
|
||||||
|
@ -24,7 +24,11 @@ func relayTextHelper(c *gin.Context, relayMode int) *types.OpenAIErrorWithStatus
|
|||||||
// 获取 Provider
|
// 获取 Provider
|
||||||
provider := providers.GetProvider(channelType, c)
|
provider := providers.GetProvider(channelType, c)
|
||||||
if provider == nil {
|
if provider == nil {
|
||||||
return types.ErrorWrapper(errors.New("channel not implemented"), "channel_not_implemented", http.StatusNotImplemented)
|
return types.ErrorWrapper(errors.New("channel not found"), "channel_not_found", http.StatusNotImplemented)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !provider.SupportAPI(relayMode) {
|
||||||
|
return types.ErrorWrapper(errors.New("channel does not support this API"), "channel_not_support_api", http.StatusNotImplemented)
|
||||||
}
|
}
|
||||||
|
|
||||||
modelMap, err := parseModelMapping(c.GetString("model_mapping"))
|
modelMap, err := parseModelMapping(c.GetString("model_mapping"))
|
||||||
@ -45,12 +49,14 @@ func relayTextHelper(c *gin.Context, relayMode int) *types.OpenAIErrorWithStatus
|
|||||||
var openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode
|
var openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode
|
||||||
|
|
||||||
switch relayMode {
|
switch relayMode {
|
||||||
case RelayModeChatCompletions:
|
case common.RelayModeChatCompletions:
|
||||||
usage, openAIErrorWithStatusCode = handleChatCompletions(c, provider, modelMap, quotaInfo, group)
|
usage, openAIErrorWithStatusCode = handleChatCompletions(c, provider, modelMap, quotaInfo, group)
|
||||||
case RelayModeCompletions:
|
case common.RelayModeCompletions:
|
||||||
usage, openAIErrorWithStatusCode = handleCompletions(c, provider, modelMap, quotaInfo, group)
|
usage, openAIErrorWithStatusCode = handleCompletions(c, provider, modelMap, quotaInfo, group)
|
||||||
case RelayModeEmbeddings:
|
case common.RelayModeEmbeddings:
|
||||||
usage, openAIErrorWithStatusCode = handleEmbeddings(c, provider, modelMap, quotaInfo, group)
|
usage, openAIErrorWithStatusCode = handleEmbeddings(c, provider, modelMap, quotaInfo, group)
|
||||||
|
case common.RelayModeModerations:
|
||||||
|
usage, openAIErrorWithStatusCode = handleModerations(c, provider, modelMap, quotaInfo, group)
|
||||||
default:
|
default:
|
||||||
return types.ErrorWrapper(errors.New("invalid relay mode"), "invalid_relay_mode", http.StatusBadRequest)
|
return types.ErrorWrapper(errors.New("invalid relay mode"), "invalid_relay_mode", http.StatusBadRequest)
|
||||||
}
|
}
|
||||||
@ -84,14 +90,21 @@ func relayTextHelper(c *gin.Context, relayMode int) *types.OpenAIErrorWithStatus
|
|||||||
func handleChatCompletions(c *gin.Context, provider providers_base.ProviderInterface, modelMap map[string]string, quotaInfo *QuotaInfo, group string) (*types.Usage, *types.OpenAIErrorWithStatusCode) {
|
func handleChatCompletions(c *gin.Context, provider providers_base.ProviderInterface, modelMap map[string]string, quotaInfo *QuotaInfo, group string) (*types.Usage, *types.OpenAIErrorWithStatusCode) {
|
||||||
var chatRequest types.ChatCompletionRequest
|
var chatRequest types.ChatCompletionRequest
|
||||||
isModelMapped := false
|
isModelMapped := false
|
||||||
|
|
||||||
chatProvider, ok := provider.(providers_base.ChatInterface)
|
chatProvider, ok := provider.(providers_base.ChatInterface)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, types.ErrorWrapper(errors.New("channel not implemented"), "channel_not_implemented", http.StatusNotImplemented)
|
return nil, types.ErrorWrapper(errors.New("channel not implemented"), "channel_not_implemented", http.StatusNotImplemented)
|
||||||
}
|
}
|
||||||
|
|
||||||
err := common.UnmarshalBodyReusable(c, &chatRequest)
|
err := common.UnmarshalBodyReusable(c, &chatRequest)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, types.ErrorWrapper(err, "bind_request_body_failed", http.StatusBadRequest)
|
return nil, types.ErrorWrapper(err, "bind_request_body_failed", http.StatusBadRequest)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if chatRequest.Messages == nil || len(chatRequest.Messages) == 0 {
|
||||||
|
return nil, types.ErrorWrapper(errors.New("field messages is required"), "required_field_missing", http.StatusBadRequest)
|
||||||
|
}
|
||||||
|
|
||||||
if modelMap != nil && modelMap[chatRequest.Model] != "" {
|
if modelMap != nil && modelMap[chatRequest.Model] != "" {
|
||||||
chatRequest.Model = modelMap[chatRequest.Model]
|
chatRequest.Model = modelMap[chatRequest.Model]
|
||||||
isModelMapped = true
|
isModelMapped = true
|
||||||
@ -114,10 +127,16 @@ func handleCompletions(c *gin.Context, provider providers_base.ProviderInterface
|
|||||||
if !ok {
|
if !ok {
|
||||||
return nil, types.ErrorWrapper(errors.New("channel not implemented"), "channel_not_implemented", http.StatusNotImplemented)
|
return nil, types.ErrorWrapper(errors.New("channel not implemented"), "channel_not_implemented", http.StatusNotImplemented)
|
||||||
}
|
}
|
||||||
|
|
||||||
err := common.UnmarshalBodyReusable(c, &completionRequest)
|
err := common.UnmarshalBodyReusable(c, &completionRequest)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, types.ErrorWrapper(err, "bind_request_body_failed", http.StatusBadRequest)
|
return nil, types.ErrorWrapper(err, "bind_request_body_failed", http.StatusBadRequest)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if completionRequest.Prompt == "" {
|
||||||
|
return nil, types.ErrorWrapper(errors.New("field prompt is required"), "required_field_missing", http.StatusBadRequest)
|
||||||
|
}
|
||||||
|
|
||||||
if modelMap != nil && modelMap[completionRequest.Model] != "" {
|
if modelMap != nil && modelMap[completionRequest.Model] != "" {
|
||||||
completionRequest.Model = modelMap[completionRequest.Model]
|
completionRequest.Model = modelMap[completionRequest.Model]
|
||||||
isModelMapped = true
|
isModelMapped = true
|
||||||
@ -140,10 +159,16 @@ func handleEmbeddings(c *gin.Context, provider providers_base.ProviderInterface,
|
|||||||
if !ok {
|
if !ok {
|
||||||
return nil, types.ErrorWrapper(errors.New("channel not implemented"), "channel_not_implemented", http.StatusNotImplemented)
|
return nil, types.ErrorWrapper(errors.New("channel not implemented"), "channel_not_implemented", http.StatusNotImplemented)
|
||||||
}
|
}
|
||||||
|
|
||||||
err := common.UnmarshalBodyReusable(c, &embeddingsRequest)
|
err := common.UnmarshalBodyReusable(c, &embeddingsRequest)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, types.ErrorWrapper(err, "bind_request_body_failed", http.StatusBadRequest)
|
return nil, types.ErrorWrapper(err, "bind_request_body_failed", http.StatusBadRequest)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if embeddingsRequest.Input == "" {
|
||||||
|
return nil, types.ErrorWrapper(errors.New("field input is required"), "required_field_missing", http.StatusBadRequest)
|
||||||
|
}
|
||||||
|
|
||||||
if modelMap != nil && modelMap[embeddingsRequest.Model] != "" {
|
if modelMap != nil && modelMap[embeddingsRequest.Model] != "" {
|
||||||
embeddingsRequest.Model = modelMap[embeddingsRequest.Model]
|
embeddingsRequest.Model = modelMap[embeddingsRequest.Model]
|
||||||
isModelMapped = true
|
isModelMapped = true
|
||||||
@ -158,3 +183,39 @@ func handleEmbeddings(c *gin.Context, provider providers_base.ProviderInterface,
|
|||||||
}
|
}
|
||||||
return embeddingsProvider.EmbeddingsAction(&embeddingsRequest, isModelMapped, promptTokens)
|
return embeddingsProvider.EmbeddingsAction(&embeddingsRequest, isModelMapped, promptTokens)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func handleModerations(c *gin.Context, provider providers_base.ProviderInterface, modelMap map[string]string, quotaInfo *QuotaInfo, group string) (*types.Usage, *types.OpenAIErrorWithStatusCode) {
|
||||||
|
var moderationRequest types.ModerationRequest
|
||||||
|
isModelMapped := false
|
||||||
|
moderationProvider, ok := provider.(providers_base.ModerationInterface)
|
||||||
|
if !ok {
|
||||||
|
return nil, types.ErrorWrapper(errors.New("channel not implemented"), "channel_not_implemented", http.StatusNotImplemented)
|
||||||
|
}
|
||||||
|
|
||||||
|
err := common.UnmarshalBodyReusable(c, &moderationRequest)
|
||||||
|
if err != nil {
|
||||||
|
return nil, types.ErrorWrapper(err, "bind_request_body_failed", http.StatusBadRequest)
|
||||||
|
}
|
||||||
|
|
||||||
|
if moderationRequest.Input == "" {
|
||||||
|
return nil, types.ErrorWrapper(errors.New("field input is required"), "required_field_missing", http.StatusBadRequest)
|
||||||
|
}
|
||||||
|
|
||||||
|
if moderationRequest.Model == "" {
|
||||||
|
moderationRequest.Model = "text-moderation-latest"
|
||||||
|
}
|
||||||
|
|
||||||
|
if modelMap != nil && modelMap[moderationRequest.Model] != "" {
|
||||||
|
moderationRequest.Model = modelMap[moderationRequest.Model]
|
||||||
|
isModelMapped = true
|
||||||
|
}
|
||||||
|
promptTokens := common.CountTokenInput(moderationRequest.Input, moderationRequest.Model)
|
||||||
|
|
||||||
|
quotaInfo.modelName = moderationRequest.Model
|
||||||
|
quotaInfo.initQuotaInfo(group)
|
||||||
|
quota_err := quotaInfo.preQuotaConsumption()
|
||||||
|
if quota_err != nil {
|
||||||
|
return nil, quota_err
|
||||||
|
}
|
||||||
|
return moderationProvider.ModerationAction(&moderationRequest, isModelMapped, promptTokens)
|
||||||
|
}
|
||||||
|
@ -56,19 +56,6 @@ func (m Message) StringContent() string {
|
|||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
const (
|
|
||||||
RelayModeUnknown = iota
|
|
||||||
RelayModeChatCompletions
|
|
||||||
RelayModeCompletions
|
|
||||||
RelayModeEmbeddings
|
|
||||||
RelayModeModerations
|
|
||||||
RelayModeImagesGenerations
|
|
||||||
RelayModeEdits
|
|
||||||
RelayModeAudioSpeech
|
|
||||||
RelayModeAudioTranscription
|
|
||||||
RelayModeAudioTranslation
|
|
||||||
)
|
|
||||||
|
|
||||||
// https://platform.openai.com/docs/api-reference/chat
|
// https://platform.openai.com/docs/api-reference/chat
|
||||||
|
|
||||||
type ResponseFormat struct {
|
type ResponseFormat struct {
|
||||||
@ -237,21 +224,18 @@ type CompletionsStreamResponse struct {
|
|||||||
func Relay(c *gin.Context) {
|
func Relay(c *gin.Context) {
|
||||||
var err *types.OpenAIErrorWithStatusCode
|
var err *types.OpenAIErrorWithStatusCode
|
||||||
|
|
||||||
relayMode := RelayModeUnknown
|
relayMode := common.RelayModeUnknown
|
||||||
if strings.HasPrefix(c.Request.URL.Path, "/v1/chat/completions") {
|
if strings.HasPrefix(c.Request.URL.Path, "/v1/chat/completions") {
|
||||||
// err = relayChatHelper(c)
|
relayMode = common.RelayModeChatCompletions
|
||||||
relayMode = RelayModeChatCompletions
|
|
||||||
} else if strings.HasPrefix(c.Request.URL.Path, "/v1/completions") {
|
} else if strings.HasPrefix(c.Request.URL.Path, "/v1/completions") {
|
||||||
// err = relayCompletionHelper(c)
|
relayMode = common.RelayModeCompletions
|
||||||
relayMode = RelayModeCompletions
|
|
||||||
} else if strings.HasPrefix(c.Request.URL.Path, "/v1/embeddings") {
|
} else if strings.HasPrefix(c.Request.URL.Path, "/v1/embeddings") {
|
||||||
// err = relayEmbeddingsHelper(c)
|
relayMode = common.RelayModeEmbeddings
|
||||||
relayMode = RelayModeEmbeddings
|
|
||||||
} else if strings.HasSuffix(c.Request.URL.Path, "embeddings") {
|
} else if strings.HasSuffix(c.Request.URL.Path, "embeddings") {
|
||||||
relayMode = RelayModeEmbeddings
|
relayMode = common.RelayModeEmbeddings
|
||||||
|
} else if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") {
|
||||||
|
relayMode = common.RelayModeModerations
|
||||||
}
|
}
|
||||||
// } else if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") {
|
|
||||||
// relayMode = RelayModeModerations
|
|
||||||
// } else if strings.HasPrefix(c.Request.URL.Path, "/v1/images/generations") {
|
// } else if strings.HasPrefix(c.Request.URL.Path, "/v1/images/generations") {
|
||||||
// relayMode = RelayModeImagesGenerations
|
// relayMode = RelayModeImagesGenerations
|
||||||
// } else if strings.HasPrefix(c.Request.URL.Path, "/v1/edits") {
|
// } else if strings.HasPrefix(c.Request.URL.Path, "/v1/edits") {
|
||||||
|
@ -21,6 +21,7 @@ type BaseProvider struct {
|
|||||||
ChatCompletions string
|
ChatCompletions string
|
||||||
Embeddings string
|
Embeddings string
|
||||||
AudioSpeech string
|
AudioSpeech string
|
||||||
|
Moderation string
|
||||||
AudioTranscriptions string
|
AudioTranscriptions string
|
||||||
AudioTranslations string
|
AudioTranslations string
|
||||||
Proxy string
|
Proxy string
|
||||||
@ -125,3 +126,24 @@ func (p *BaseProvider) HandleErrorResp(resp *http.Response) (openAIErrorWithStat
|
|||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (p *BaseProvider) SupportAPI(relayMode int) bool {
|
||||||
|
switch relayMode {
|
||||||
|
case common.RelayModeChatCompletions:
|
||||||
|
return p.ChatCompletions != ""
|
||||||
|
case common.RelayModeCompletions:
|
||||||
|
return p.Completions != ""
|
||||||
|
case common.RelayModeEmbeddings:
|
||||||
|
return p.Embeddings != ""
|
||||||
|
case common.RelayModeAudioSpeech:
|
||||||
|
return p.AudioSpeech != ""
|
||||||
|
case common.RelayModeAudioTranscription:
|
||||||
|
return p.AudioTranscriptions != ""
|
||||||
|
case common.RelayModeAudioTranslation:
|
||||||
|
return p.AudioTranslations != ""
|
||||||
|
case common.RelayModeModerations:
|
||||||
|
return p.Moderation != ""
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@ -11,6 +11,7 @@ type ProviderInterface interface {
|
|||||||
GetBaseURL() string
|
GetBaseURL() string
|
||||||
GetFullRequestURL(requestURL string, modelName string) string
|
GetFullRequestURL(requestURL string, modelName string) string
|
||||||
GetRequestHeaders() (headers map[string]string)
|
GetRequestHeaders() (headers map[string]string)
|
||||||
|
SupportAPI(relayMode int) bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// 完成接口
|
// 完成接口
|
||||||
@ -31,6 +32,12 @@ type EmbeddingsInterface interface {
|
|||||||
EmbeddingsAction(request *types.EmbeddingRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode)
|
EmbeddingsAction(request *types.EmbeddingRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 审查接口
|
||||||
|
type ModerationInterface interface {
|
||||||
|
ProviderInterface
|
||||||
|
ModerationAction(request *types.ModerationRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
// 余额接口
|
// 余额接口
|
||||||
type BalanceInterface interface {
|
type BalanceInterface interface {
|
||||||
BalanceAction(channel *model.Channel) (float64, error)
|
BalanceAction(channel *model.Channel) (float64, error)
|
||||||
|
@ -34,6 +34,7 @@ func CreateOpenAIProvider(c *gin.Context, baseURL string) *OpenAIProvider {
|
|||||||
Completions: "/v1/completions",
|
Completions: "/v1/completions",
|
||||||
ChatCompletions: "/v1/chat/completions",
|
ChatCompletions: "/v1/chat/completions",
|
||||||
Embeddings: "/v1/embeddings",
|
Embeddings: "/v1/embeddings",
|
||||||
|
Moderation: "/v1/moderations",
|
||||||
AudioSpeech: "/v1/audio/speech",
|
AudioSpeech: "/v1/audio/speech",
|
||||||
AudioTranscriptions: "/v1/audio/transcriptions",
|
AudioTranscriptions: "/v1/audio/transcriptions",
|
||||||
AudioTranslations: "/v1/audio/translations",
|
AudioTranslations: "/v1/audio/translations",
|
||||||
|
49
providers/openai/moderation.go
Normal file
49
providers/openai/moderation.go
Normal file
@ -0,0 +1,49 @@
|
|||||||
|
package openai
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"one-api/common"
|
||||||
|
"one-api/types"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (c *OpenAIProviderModerationResponse) responseHandler(resp *http.Response) (errWithCode *types.OpenAIErrorWithStatusCode) {
|
||||||
|
if c.Error.Type != "" {
|
||||||
|
errWithCode = &types.OpenAIErrorWithStatusCode{
|
||||||
|
OpenAIError: c.Error,
|
||||||
|
StatusCode: resp.StatusCode,
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *OpenAIProvider) ModerationAction(request *types.ModerationRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) {
|
||||||
|
|
||||||
|
requestBody, err := p.getRequestBody(&request, isModelMapped)
|
||||||
|
if err != nil {
|
||||||
|
return nil, types.ErrorWrapper(err, "json_marshal_failed", http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
|
||||||
|
fullRequestURL := p.GetFullRequestURL(p.Moderation, request.Model)
|
||||||
|
headers := p.GetRequestHeaders()
|
||||||
|
|
||||||
|
client := common.NewClient()
|
||||||
|
req, err := client.NewRequest(p.Context.Request.Method, fullRequestURL, common.WithBody(requestBody), common.WithHeader(headers))
|
||||||
|
if err != nil {
|
||||||
|
return nil, types.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
|
||||||
|
openAIProviderModerationResponse := &OpenAIProviderModerationResponse{}
|
||||||
|
errWithCode = p.sendRequest(req, openAIProviderModerationResponse)
|
||||||
|
if errWithCode != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
usage = &types.Usage{
|
||||||
|
PromptTokens: promptTokens,
|
||||||
|
CompletionTokens: 0,
|
||||||
|
TotalTokens: promptTokens,
|
||||||
|
}
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
@ -21,3 +21,8 @@ type OpenAIProviderEmbeddingsResponse struct {
|
|||||||
types.EmbeddingResponse
|
types.EmbeddingResponse
|
||||||
types.OpenAIErrorResponse
|
types.OpenAIErrorResponse
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type OpenAIProviderModerationResponse struct {
|
||||||
|
types.ModerationResponse
|
||||||
|
types.OpenAIErrorResponse
|
||||||
|
}
|
||||||
|
12
types/moderation.go
Normal file
12
types/moderation.go
Normal file
@ -0,0 +1,12 @@
|
|||||||
|
package types
|
||||||
|
|
||||||
|
type ModerationRequest struct {
|
||||||
|
Input string `json:"input,omitempty"`
|
||||||
|
Model string `json:"model,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type ModerationResponse struct {
|
||||||
|
ID string `json:"id"`
|
||||||
|
Model string `json:"model"`
|
||||||
|
Results any `json:"results"`
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user