2023-09-17 07:39:46 +00:00
|
|
|
package middleware
|
|
|
|
|
|
|
|
import (
|
2024-04-03 18:08:18 +00:00
|
|
|
"fmt"
|
2023-09-17 07:39:46 +00:00
|
|
|
"github.com/gin-gonic/gin"
|
2024-04-03 18:08:18 +00:00
|
|
|
"github.com/songquanpeng/one-api/common"
|
2024-01-28 11:38:58 +00:00
|
|
|
"github.com/songquanpeng/one-api/common/helper"
|
|
|
|
"github.com/songquanpeng/one-api/common/logger"
|
2024-04-03 18:08:18 +00:00
|
|
|
"strings"
|
2023-09-17 07:39:46 +00:00
|
|
|
)
|
|
|
|
|
|
|
|
func abortWithMessage(c *gin.Context, statusCode int, message string) {
|
|
|
|
c.JSON(statusCode, gin.H{
|
|
|
|
"error": gin.H{
|
2024-04-26 15:05:48 +00:00
|
|
|
"message": helper.MessageWithRequestId(message, c.GetString(helper.RequestIdKey)),
|
2023-09-17 07:39:46 +00:00
|
|
|
"type": "one_api_error",
|
|
|
|
},
|
|
|
|
})
|
|
|
|
c.Abort()
|
2024-01-21 15:21:42 +00:00
|
|
|
logger.Error(c.Request.Context(), message)
|
2023-09-17 07:39:46 +00:00
|
|
|
}
|
2024-04-03 18:08:18 +00:00
|
|
|
|
|
|
|
func getRequestModel(c *gin.Context) (string, error) {
|
|
|
|
var modelRequest ModelRequest
|
|
|
|
err := common.UnmarshalBodyReusable(c, &modelRequest)
|
|
|
|
if err != nil {
|
|
|
|
return "", fmt.Errorf("common.UnmarshalBodyReusable failed: %w", err)
|
|
|
|
}
|
|
|
|
if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") {
|
|
|
|
if modelRequest.Model == "" {
|
|
|
|
modelRequest.Model = "text-moderation-stable"
|
|
|
|
}
|
|
|
|
}
|
|
|
|
if strings.HasSuffix(c.Request.URL.Path, "embeddings") {
|
|
|
|
if modelRequest.Model == "" {
|
|
|
|
modelRequest.Model = c.Param("model")
|
|
|
|
}
|
|
|
|
}
|
|
|
|
if strings.HasPrefix(c.Request.URL.Path, "/v1/images/generations") {
|
|
|
|
if modelRequest.Model == "" {
|
|
|
|
modelRequest.Model = "dall-e-2"
|
|
|
|
}
|
|
|
|
}
|
|
|
|
if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") || strings.HasPrefix(c.Request.URL.Path, "/v1/audio/translations") {
|
|
|
|
if modelRequest.Model == "" {
|
|
|
|
modelRequest.Model = "whisper-1"
|
|
|
|
}
|
|
|
|
}
|
|
|
|
return modelRequest.Model, nil
|
|
|
|
}
|
|
|
|
|
|
|
|
// isModelInList reports whether modelName appears in models, a
// comma-separated list of model names.
//
// Whitespace surrounding each list entry is ignored, so "gpt-3.5-turbo, gpt-4"
// contains "gpt-4". modelName itself is compared exactly. An empty modelName
// never matches; otherwise an empty list ("") or stray commas ("a,,b") would
// whitelist a request that names no model.
func isModelInList(modelName string, models string) bool {
	if modelName == "" {
		return false
	}
	for _, model := range strings.Split(models, ",") {
		if strings.TrimSpace(model) == modelName {
			return true
		}
	}
	return false
}
|