♻️ refactor: 重构moderation接口

This commit is contained in:
Martial BE 2023-11-29 16:54:37 +08:00
parent 455269c145
commit 1c7c2d40bb
No known key found for this signature in database
GPG Key ID: D06C32DF0EDB9084
10 changed files with 183 additions and 28 deletions

3
.gitignore vendored
View File

@ -6,4 +6,5 @@ upload
build
*.db-journal
logs
data
data
tmp/

View File

@ -214,3 +214,16 @@ var ChannelBaseURLs = []string{
"https://fastgpt.run/api/openapi", // 22
"https://hunyuan.cloud.tencent.com", //23
}
const (
RelayModeUnknown = iota
RelayModeChatCompletions
RelayModeCompletions
RelayModeEmbeddings
RelayModeModerations
RelayModeImagesGenerations
RelayModeEdits
RelayModeAudioSpeech
RelayModeAudioTranscription
RelayModeAudioTranslation
)

View File

@ -24,7 +24,11 @@ func relayTextHelper(c *gin.Context, relayMode int) *types.OpenAIErrorWithStatus
// 获取 Provider
provider := providers.GetProvider(channelType, c)
if provider == nil {
return types.ErrorWrapper(errors.New("channel not implemented"), "channel_not_implemented", http.StatusNotImplemented)
return types.ErrorWrapper(errors.New("channel not found"), "channel_not_found", http.StatusNotImplemented)
}
if !provider.SupportAPI(relayMode) {
return types.ErrorWrapper(errors.New("channel does not support this API"), "channel_not_support_api", http.StatusNotImplemented)
}
modelMap, err := parseModelMapping(c.GetString("model_mapping"))
@ -45,12 +49,14 @@ func relayTextHelper(c *gin.Context, relayMode int) *types.OpenAIErrorWithStatus
var openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode
switch relayMode {
case RelayModeChatCompletions:
case common.RelayModeChatCompletions:
usage, openAIErrorWithStatusCode = handleChatCompletions(c, provider, modelMap, quotaInfo, group)
case RelayModeCompletions:
case common.RelayModeCompletions:
usage, openAIErrorWithStatusCode = handleCompletions(c, provider, modelMap, quotaInfo, group)
case RelayModeEmbeddings:
case common.RelayModeEmbeddings:
usage, openAIErrorWithStatusCode = handleEmbeddings(c, provider, modelMap, quotaInfo, group)
case common.RelayModeModerations:
usage, openAIErrorWithStatusCode = handleModerations(c, provider, modelMap, quotaInfo, group)
default:
return types.ErrorWrapper(errors.New("invalid relay mode"), "invalid_relay_mode", http.StatusBadRequest)
}
@ -84,14 +90,21 @@ func relayTextHelper(c *gin.Context, relayMode int) *types.OpenAIErrorWithStatus
func handleChatCompletions(c *gin.Context, provider providers_base.ProviderInterface, modelMap map[string]string, quotaInfo *QuotaInfo, group string) (*types.Usage, *types.OpenAIErrorWithStatusCode) {
var chatRequest types.ChatCompletionRequest
isModelMapped := false
chatProvider, ok := provider.(providers_base.ChatInterface)
if !ok {
return nil, types.ErrorWrapper(errors.New("channel not implemented"), "channel_not_implemented", http.StatusNotImplemented)
}
err := common.UnmarshalBodyReusable(c, &chatRequest)
if err != nil {
return nil, types.ErrorWrapper(err, "bind_request_body_failed", http.StatusBadRequest)
}
if chatRequest.Messages == nil || len(chatRequest.Messages) == 0 {
return nil, types.ErrorWrapper(errors.New("field messages is required"), "required_field_missing", http.StatusBadRequest)
}
if modelMap != nil && modelMap[chatRequest.Model] != "" {
chatRequest.Model = modelMap[chatRequest.Model]
isModelMapped = true
@ -114,10 +127,16 @@ func handleCompletions(c *gin.Context, provider providers_base.ProviderInterface
if !ok {
return nil, types.ErrorWrapper(errors.New("channel not implemented"), "channel_not_implemented", http.StatusNotImplemented)
}
err := common.UnmarshalBodyReusable(c, &completionRequest)
if err != nil {
return nil, types.ErrorWrapper(err, "bind_request_body_failed", http.StatusBadRequest)
}
if completionRequest.Prompt == "" {
return nil, types.ErrorWrapper(errors.New("field prompt is required"), "required_field_missing", http.StatusBadRequest)
}
if modelMap != nil && modelMap[completionRequest.Model] != "" {
completionRequest.Model = modelMap[completionRequest.Model]
isModelMapped = true
@ -140,10 +159,16 @@ func handleEmbeddings(c *gin.Context, provider providers_base.ProviderInterface,
if !ok {
return nil, types.ErrorWrapper(errors.New("channel not implemented"), "channel_not_implemented", http.StatusNotImplemented)
}
err := common.UnmarshalBodyReusable(c, &embeddingsRequest)
if err != nil {
return nil, types.ErrorWrapper(err, "bind_request_body_failed", http.StatusBadRequest)
}
if embeddingsRequest.Input == "" {
return nil, types.ErrorWrapper(errors.New("field input is required"), "required_field_missing", http.StatusBadRequest)
}
if modelMap != nil && modelMap[embeddingsRequest.Model] != "" {
embeddingsRequest.Model = modelMap[embeddingsRequest.Model]
isModelMapped = true
@ -158,3 +183,39 @@ func handleEmbeddings(c *gin.Context, provider providers_base.ProviderInterface,
}
return embeddingsProvider.EmbeddingsAction(&embeddingsRequest, isModelMapped, promptTokens)
}
func handleModerations(c *gin.Context, provider providers_base.ProviderInterface, modelMap map[string]string, quotaInfo *QuotaInfo, group string) (*types.Usage, *types.OpenAIErrorWithStatusCode) {
var moderationRequest types.ModerationRequest
isModelMapped := false
moderationProvider, ok := provider.(providers_base.ModerationInterface)
if !ok {
return nil, types.ErrorWrapper(errors.New("channel not implemented"), "channel_not_implemented", http.StatusNotImplemented)
}
err := common.UnmarshalBodyReusable(c, &moderationRequest)
if err != nil {
return nil, types.ErrorWrapper(err, "bind_request_body_failed", http.StatusBadRequest)
}
if moderationRequest.Input == "" {
return nil, types.ErrorWrapper(errors.New("field input is required"), "required_field_missing", http.StatusBadRequest)
}
if moderationRequest.Model == "" {
moderationRequest.Model = "text-moderation-latest"
}
if modelMap != nil && modelMap[moderationRequest.Model] != "" {
moderationRequest.Model = modelMap[moderationRequest.Model]
isModelMapped = true
}
promptTokens := common.CountTokenInput(moderationRequest.Input, moderationRequest.Model)
quotaInfo.modelName = moderationRequest.Model
quotaInfo.initQuotaInfo(group)
quota_err := quotaInfo.preQuotaConsumption()
if quota_err != nil {
return nil, quota_err
}
return moderationProvider.ModerationAction(&moderationRequest, isModelMapped, promptTokens)
}

View File

@ -56,19 +56,6 @@ func (m Message) StringContent() string {
return ""
}
const (
RelayModeUnknown = iota
RelayModeChatCompletions
RelayModeCompletions
RelayModeEmbeddings
RelayModeModerations
RelayModeImagesGenerations
RelayModeEdits
RelayModeAudioSpeech
RelayModeAudioTranscription
RelayModeAudioTranslation
)
// https://platform.openai.com/docs/api-reference/chat
type ResponseFormat struct {
@ -237,21 +224,18 @@ type CompletionsStreamResponse struct {
func Relay(c *gin.Context) {
var err *types.OpenAIErrorWithStatusCode
relayMode := RelayModeUnknown
relayMode := common.RelayModeUnknown
if strings.HasPrefix(c.Request.URL.Path, "/v1/chat/completions") {
// err = relayChatHelper(c)
relayMode = RelayModeChatCompletions
relayMode = common.RelayModeChatCompletions
} else if strings.HasPrefix(c.Request.URL.Path, "/v1/completions") {
// err = relayCompletionHelper(c)
relayMode = RelayModeCompletions
relayMode = common.RelayModeCompletions
} else if strings.HasPrefix(c.Request.URL.Path, "/v1/embeddings") {
// err = relayEmbeddingsHelper(c)
relayMode = RelayModeEmbeddings
relayMode = common.RelayModeEmbeddings
} else if strings.HasSuffix(c.Request.URL.Path, "embeddings") {
relayMode = RelayModeEmbeddings
relayMode = common.RelayModeEmbeddings
} else if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") {
relayMode = common.RelayModeModerations
}
// } else if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") {
// relayMode = RelayModeModerations
// } else if strings.HasPrefix(c.Request.URL.Path, "/v1/images/generations") {
// relayMode = RelayModeImagesGenerations
// } else if strings.HasPrefix(c.Request.URL.Path, "/v1/edits") {

View File

@ -21,6 +21,7 @@ type BaseProvider struct {
ChatCompletions string
Embeddings string
AudioSpeech string
Moderation string
AudioTranscriptions string
AudioTranslations string
Proxy string
@ -125,3 +126,24 @@ func (p *BaseProvider) HandleErrorResp(resp *http.Response) (openAIErrorWithStat
}
return
}
func (p *BaseProvider) SupportAPI(relayMode int) bool {
switch relayMode {
case common.RelayModeChatCompletions:
return p.ChatCompletions != ""
case common.RelayModeCompletions:
return p.Completions != ""
case common.RelayModeEmbeddings:
return p.Embeddings != ""
case common.RelayModeAudioSpeech:
return p.AudioSpeech != ""
case common.RelayModeAudioTranscription:
return p.AudioTranscriptions != ""
case common.RelayModeAudioTranslation:
return p.AudioTranslations != ""
case common.RelayModeModerations:
return p.Moderation != ""
default:
return false
}
}

View File

@ -11,6 +11,7 @@ type ProviderInterface interface {
GetBaseURL() string
GetFullRequestURL(requestURL string, modelName string) string
GetRequestHeaders() (headers map[string]string)
SupportAPI(relayMode int) bool
}
// 完成接口
@ -31,6 +32,12 @@ type EmbeddingsInterface interface {
EmbeddingsAction(request *types.EmbeddingRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode)
}
// 审查接口
type ModerationInterface interface {
ProviderInterface
ModerationAction(request *types.ModerationRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode)
}
// 余额接口
type BalanceInterface interface {
BalanceAction(channel *model.Channel) (float64, error)

View File

@ -34,6 +34,7 @@ func CreateOpenAIProvider(c *gin.Context, baseURL string) *OpenAIProvider {
Completions: "/v1/completions",
ChatCompletions: "/v1/chat/completions",
Embeddings: "/v1/embeddings",
Moderation: "/v1/moderations",
AudioSpeech: "/v1/audio/speech",
AudioTranscriptions: "/v1/audio/transcriptions",
AudioTranslations: "/v1/audio/translations",

View File

@ -0,0 +1,49 @@
package openai
import (
"net/http"
"one-api/common"
"one-api/types"
)
func (c *OpenAIProviderModerationResponse) responseHandler(resp *http.Response) (errWithCode *types.OpenAIErrorWithStatusCode) {
if c.Error.Type != "" {
errWithCode = &types.OpenAIErrorWithStatusCode{
OpenAIError: c.Error,
StatusCode: resp.StatusCode,
}
return
}
return nil
}
func (p *OpenAIProvider) ModerationAction(request *types.ModerationRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) {
requestBody, err := p.getRequestBody(&request, isModelMapped)
if err != nil {
return nil, types.ErrorWrapper(err, "json_marshal_failed", http.StatusInternalServerError)
}
fullRequestURL := p.GetFullRequestURL(p.Moderation, request.Model)
headers := p.GetRequestHeaders()
client := common.NewClient()
req, err := client.NewRequest(p.Context.Request.Method, fullRequestURL, common.WithBody(requestBody), common.WithHeader(headers))
if err != nil {
return nil, types.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError)
}
openAIProviderModerationResponse := &OpenAIProviderModerationResponse{}
errWithCode = p.sendRequest(req, openAIProviderModerationResponse)
if errWithCode != nil {
return
}
usage = &types.Usage{
PromptTokens: promptTokens,
CompletionTokens: 0,
TotalTokens: promptTokens,
}
return
}

View File

@ -21,3 +21,8 @@ type OpenAIProviderEmbeddingsResponse struct {
types.EmbeddingResponse
types.OpenAIErrorResponse
}
type OpenAIProviderModerationResponse struct {
types.ModerationResponse
types.OpenAIErrorResponse
}

12
types/moderation.go Normal file
View File

@ -0,0 +1,12 @@
package types
type ModerationRequest struct {
Input string `json:"input,omitempty"`
Model string `json:"model,omitempty"`
}
type ModerationResponse struct {
ID string `json:"id"`
Model string `json:"model"`
Results any `json:"results"`
}