Merge branch 'main' into main

This commit is contained in:
JustSong 2024-03-03 23:57:00 +08:00 committed by GitHub
commit 4716cfbf12
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
20 changed files with 221 additions and 131 deletions

View File

@ -67,6 +67,7 @@ _✨ 通过标准的 OpenAI API 格式访问所有的大模型,开箱即用
+ [x] [OpenAI ChatGPT 系列模型](https://platform.openai.com/docs/guides/gpt/chat-completions-api)(支持 [Azure OpenAI API](https://learn.microsoft.com/en-us/azure/ai-services/openai/reference) + [x] [OpenAI ChatGPT 系列模型](https://platform.openai.com/docs/guides/gpt/chat-completions-api)(支持 [Azure OpenAI API](https://learn.microsoft.com/en-us/azure/ai-services/openai/reference)
+ [x] [Anthropic Claude 系列模型](https://anthropic.com) + [x] [Anthropic Claude 系列模型](https://anthropic.com)
+ [x] [Google PaLM2/Gemini 系列模型](https://developers.generativeai.google) + [x] [Google PaLM2/Gemini 系列模型](https://developers.generativeai.google)
+ [x] [Mistral 系列模型](https://mistral.ai/)
+ [x] [百度文心一言系列模型](https://cloud.baidu.com/doc/WENXINWORKSHOP/index.html) + [x] [百度文心一言系列模型](https://cloud.baidu.com/doc/WENXINWORKSHOP/index.html)
+ [x] [阿里通义千问系列模型](https://help.aliyun.com/document_detail/2400395.html) + [x] [阿里通义千问系列模型](https://help.aliyun.com/document_detail/2400395.html)
+ [x] [讯飞星火认知大模型](https://www.xfyun.cn/doc/spark/Web.html) + [x] [讯飞星火认知大模型](https://www.xfyun.cn/doc/spark/Web.html)

View File

@ -66,6 +66,7 @@ const (
ChannelTypeMoonshot = 25 ChannelTypeMoonshot = 25
ChannelTypeBaichuan = 26 ChannelTypeBaichuan = 26
ChannelTypeMinimax = 27 ChannelTypeMinimax = 27
ChannelTypeMistral = 28
) )
var ChannelBaseURLs = []string{ var ChannelBaseURLs = []string{
@ -97,6 +98,7 @@ var ChannelBaseURLs = []string{
"https://api.moonshot.cn", // 25 "https://api.moonshot.cn", // 25
"https://api.baichuan-ai.com", // 26 "https://api.baichuan-ai.com", // 26
"https://api.minimax.chat", // 27 "https://api.minimax.chat", // 27
"https://api.mistral.ai", // 28
} }
const ( const (

View File

@ -7,29 +7,6 @@ import (
"time" "time"
) )
var DalleSizeRatios = map[string]map[string]float64{
"dall-e-2": {
"256x256": 1,
"512x512": 1.125,
"1024x1024": 1.25,
},
"dall-e-3": {
"1024x1024": 1,
"1024x1792": 2,
"1792x1024": 2,
},
}
var DalleGenerationImageAmounts = map[string][2]int{
"dall-e-2": {1, 10},
"dall-e-3": {1, 1}, // OpenAI allows n=1 currently.
}
var DalleImagePromptLengthLimitations = map[string]int{
"dall-e-2": 1000,
"dall-e-3": 4000,
}
const ( const (
USD2RMB = 7 USD2RMB = 7
USD = 500 // $0.002 = 1 -> $1 = 500 USD = 500 // $0.002 = 1 -> $1 = 500
@ -40,7 +17,6 @@ const (
// https://platform.openai.com/docs/models/model-endpoint-compatibility // https://platform.openai.com/docs/models/model-endpoint-compatibility
// https://cloud.baidu.com/doc/WENXINWORKSHOP/s/Blfmc9dlf // https://cloud.baidu.com/doc/WENXINWORKSHOP/s/Blfmc9dlf
// https://openai.com/pricing // https://openai.com/pricing
// TODO: when a new api is enabled, check the pricing here
// 1 === $0.002 / 1K tokens // 1 === $0.002 / 1K tokens
// 1 === ¥0.014 / 1k tokens // 1 === ¥0.014 / 1k tokens
var ModelRatio = map[string]float64{ var ModelRatio = map[string]float64{
@ -141,15 +117,29 @@ var ModelRatio = map[string]float64{
"abab6-chat": 0.1 * RMB, "abab6-chat": 0.1 * RMB,
"abab5.5-chat": 0.015 * RMB, "abab5.5-chat": 0.015 * RMB,
"abab5.5s-chat": 0.005 * RMB, "abab5.5s-chat": 0.005 * RMB,
// https://docs.mistral.ai/platform/pricing/
"open-mistral-7b": 0.25 / 1000 * USD,
"open-mixtral-8x7b": 0.7 / 1000 * USD,
"mistral-small-latest": 2.0 / 1000 * USD,
"mistral-medium-latest": 2.7 / 1000 * USD,
"mistral-large-latest": 8.0 / 1000 * USD,
"mistral-embed": 0.1 / 1000 * USD,
} }
var CompletionRatio = map[string]float64{}
var DefaultModelRatio map[string]float64 var DefaultModelRatio map[string]float64
var DefaultCompletionRatio map[string]float64
func init() { func init() {
DefaultModelRatio = make(map[string]float64) DefaultModelRatio = make(map[string]float64)
for k, v := range ModelRatio { for k, v := range ModelRatio {
DefaultModelRatio[k] = v DefaultModelRatio[k] = v
} }
DefaultCompletionRatio = make(map[string]float64)
for k, v := range CompletionRatio {
DefaultCompletionRatio[k] = v
}
} }
func ModelRatio2JSONString() string { func ModelRatio2JSONString() string {
@ -180,8 +170,6 @@ func GetModelRatio(name string) float64 {
return ratio return ratio
} }
var CompletionRatio = map[string]float64{}
func CompletionRatio2JSONString() string { func CompletionRatio2JSONString() string {
jsonBytes, err := json.Marshal(CompletionRatio) jsonBytes, err := json.Marshal(CompletionRatio)
if err != nil { if err != nil {
@ -199,6 +187,9 @@ func GetCompletionRatio(name string) float64 {
if ratio, ok := CompletionRatio[name]; ok { if ratio, ok := CompletionRatio[name]; ok {
return ratio return ratio
} }
if ratio, ok := DefaultCompletionRatio[name]; ok {
return ratio
}
if strings.HasPrefix(name, "gpt-3.5") { if strings.HasPrefix(name, "gpt-3.5") {
if strings.HasSuffix(name, "0125") { if strings.HasSuffix(name, "0125") {
// https://openai.com/blog/new-embedding-models-and-api-updates // https://openai.com/blog/new-embedding-models-and-api-updates
@ -231,5 +222,8 @@ func GetCompletionRatio(name string) float64 {
if strings.HasPrefix(name, "claude-2") { if strings.HasPrefix(name, "claude-2") {
return 2.965517 return 2.965517
} }
if strings.HasPrefix(name, "mistral-") {
return 3
}
return 1 return 1
} }

8
common/random.go Normal file
View File

@ -0,0 +1,8 @@
package common
import "math/rand"
// RandRange returns a pseudo-random integer in the half-open interval
// [min, max): min is a possible result, max is not.
//
// When the interval is empty or inverted (max <= min) there is nothing to
// draw from, so min is returned unchanged. Without this guard rand.Intn would
// panic, because it requires a strictly positive argument.
func RandRange(min, max int) int {
	if max <= min {
		return min
	}
	return min + rand.Intn(max-min)
}

View File

@ -8,6 +8,7 @@ import (
"github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/config" "github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/logger" "github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/middleware"
"github.com/songquanpeng/one-api/model" "github.com/songquanpeng/one-api/model"
"github.com/songquanpeng/one-api/relay/constant" "github.com/songquanpeng/one-api/relay/constant"
"github.com/songquanpeng/one-api/relay/helper" "github.com/songquanpeng/one-api/relay/helper"
@ -18,6 +19,7 @@ import (
"net/http/httptest" "net/http/httptest"
"net/url" "net/url"
"strconv" "strconv"
"strings"
"sync" "sync"
"time" "time"
@ -51,6 +53,7 @@ func testChannel(channel *model.Channel) (err error, openaiErr *relaymodel.Error
c.Request.Header.Set("Content-Type", "application/json") c.Request.Header.Set("Content-Type", "application/json")
c.Set("channel", channel.Type) c.Set("channel", channel.Type)
c.Set("base_url", channel.GetBaseURL()) c.Set("base_url", channel.GetBaseURL())
middleware.SetupContextForSelectedChannel(c, channel, "")
meta := util.GetRelayMeta(c) meta := util.GetRelayMeta(c)
apiType := constant.ChannelType2APIType(channel.Type) apiType := constant.ChannelType2APIType(channel.Type)
adaptor := helper.GetAdaptor(apiType) adaptor := helper.GetAdaptor(apiType)
@ -59,6 +62,12 @@ func testChannel(channel *model.Channel) (err error, openaiErr *relaymodel.Error
} }
adaptor.Init(meta) adaptor.Init(meta)
modelName := adaptor.GetModelList()[0] modelName := adaptor.GetModelList()[0]
if !strings.Contains(channel.Models, modelName) {
modelNames := strings.Split(channel.Models, ",")
if len(modelNames) > 0 {
modelName = modelNames[0]
}
}
request := buildTestRequest() request := buildTestRequest()
request.Model = modelName request.Model = modelName
meta.OriginModelName, meta.ActualModelName = modelName, modelName meta.OriginModelName, meta.ActualModelName = modelName, modelName

View File

@ -6,6 +6,7 @@ import (
"github.com/songquanpeng/one-api/relay/channel/ai360" "github.com/songquanpeng/one-api/relay/channel/ai360"
"github.com/songquanpeng/one-api/relay/channel/baichuan" "github.com/songquanpeng/one-api/relay/channel/baichuan"
"github.com/songquanpeng/one-api/relay/channel/minimax" "github.com/songquanpeng/one-api/relay/channel/minimax"
"github.com/songquanpeng/one-api/relay/channel/mistral"
"github.com/songquanpeng/one-api/relay/channel/moonshot" "github.com/songquanpeng/one-api/relay/channel/moonshot"
"github.com/songquanpeng/one-api/relay/constant" "github.com/songquanpeng/one-api/relay/constant"
"github.com/songquanpeng/one-api/relay/helper" "github.com/songquanpeng/one-api/relay/helper"
@ -122,6 +123,17 @@ func init() {
Parent: nil, Parent: nil,
}) })
} }
for _, modelName := range mistral.ModelList {
openAIModels = append(openAIModels, OpenAIModels{
Id: modelName,
Object: "model",
Created: 1626777600,
OwnedBy: "mistralai",
Permission: permission,
Root: modelName,
Parent: nil,
})
}
openAIModelsMap = make(map[string]OpenAIModels) openAIModelsMap = make(map[string]OpenAIModels)
for _, model := range openAIModels { for _, model := range openAIModels {
openAIModelsMap[model.Id] = model openAIModelsMap[model.Id] = model

View File

@ -62,7 +62,7 @@ func Relay(c *gin.Context) {
retryTimes = 0 retryTimes = 0
} }
for i := retryTimes; i > 0; i-- { for i := retryTimes; i > 0; i-- {
channel, err := dbmodel.CacheGetRandomSatisfiedChannel(group, originalModel) channel, err := dbmodel.CacheGetRandomSatisfiedChannel(group, originalModel, i != retryTimes)
if err != nil { if err != nil {
logger.Errorf(ctx, "CacheGetRandomSatisfiedChannel failed: %w", err) logger.Errorf(ctx, "CacheGetRandomSatisfiedChannel failed: %w", err)
break break

View File

@ -68,7 +68,7 @@ func Distribute() func(c *gin.Context) {
} }
} }
requestModel = modelRequest.Model requestModel = modelRequest.Model
channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, modelRequest.Model) channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, modelRequest.Model, false)
if err != nil { if err != nil {
message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", userGroup, modelRequest.Model) message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", userGroup, modelRequest.Model)
if channel != nil { if channel != nil {

View File

@ -191,7 +191,7 @@ func SyncChannelCache(frequency int) {
} }
} }
func CacheGetRandomSatisfiedChannel(group string, model string) (*Channel, error) { func CacheGetRandomSatisfiedChannel(group string, model string, ignoreFirstPriority bool) (*Channel, error) {
if !config.MemoryCacheEnabled { if !config.MemoryCacheEnabled {
return GetRandomSatisfiedChannel(group, model) return GetRandomSatisfiedChannel(group, model)
} }
@ -213,5 +213,10 @@ func CacheGetRandomSatisfiedChannel(group string, model string) (*Channel, error
} }
} }
idx := rand.Intn(endIdx) idx := rand.Intn(endIdx)
if ignoreFirstPriority {
if endIdx < len(channels) { // which means there are more than one priority
idx = common.RandRange(endIdx, len(channels))
}
}
return channels[idx], nil return channels[idx], nil
} }

View File

@ -33,6 +33,9 @@ func ConvertRequest(request model.GeneralOpenAIRequest) *ChatRequest {
enableSearch = true enableSearch = true
aliModel = strings.TrimSuffix(aliModel, EnableSearchModelSuffix) aliModel = strings.TrimSuffix(aliModel, EnableSearchModelSuffix)
} }
if request.TopP >= 1 {
request.TopP = 0.9999
}
return &ChatRequest{ return &ChatRequest{
Model: aliModel, Model: aliModel,
Input: Input{ Input: Input{
@ -42,6 +45,9 @@ func ConvertRequest(request model.GeneralOpenAIRequest) *ChatRequest {
EnableSearch: enableSearch, EnableSearch: enableSearch,
IncrementalOutput: request.Stream, IncrementalOutput: request.Stream,
Seed: uint64(request.Seed), Seed: uint64(request.Seed),
MaxTokens: request.MaxTokens,
Temperature: request.Temperature,
TopP: request.TopP,
}, },
} }
} }

View File

@ -16,6 +16,8 @@ type Parameters struct {
Seed uint64 `json:"seed,omitempty"` Seed uint64 `json:"seed,omitempty"`
EnableSearch bool `json:"enable_search,omitempty"` EnableSearch bool `json:"enable_search,omitempty"`
IncrementalOutput bool `json:"incremental_output,omitempty"` IncrementalOutput bool `json:"incremental_output,omitempty"`
MaxTokens int `json:"max_tokens,omitempty"`
Temperature float64 `json:"temperature,omitempty"`
} }
type ChatRequest struct { type ChatRequest struct {

View File

@ -36,6 +36,8 @@ func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) {
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/bloomz_7b1" fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/bloomz_7b1"
case "Embedding-V1": case "Embedding-V1":
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings/embedding-v1" fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings/embedding-v1"
default:
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/" + meta.ActualModelName
} }
var accessToken string var accessToken string
var err error var err error

View File

@ -1,7 +1,7 @@
package gemini package gemini
var ModelList = []string{ var ModelList = []string{
"gemini-pro", "gemini-pro", "gemini-1.0-pro-001",
"gemini-pro-vision", "gemini-pro-vision", "gemini-1.0-pro-vision-001",
"embedding-001", "embedding-001",
} }

View File

@ -0,0 +1,10 @@
package mistral
// ModelList enumerates the model identifiers exposed for the Mistral channel.
// The order is significant to callers that pick the first entry as a default.
var ModelList = []string{
	"open-mistral-7b",
	"open-mixtral-8x7b",
	"mistral-small-latest",
	"mistral-medium-latest",
	"mistral-large-latest",
	"mistral-embed",
}

View File

@ -9,6 +9,7 @@ import (
"github.com/songquanpeng/one-api/relay/channel/ai360" "github.com/songquanpeng/one-api/relay/channel/ai360"
"github.com/songquanpeng/one-api/relay/channel/baichuan" "github.com/songquanpeng/one-api/relay/channel/baichuan"
"github.com/songquanpeng/one-api/relay/channel/minimax" "github.com/songquanpeng/one-api/relay/channel/minimax"
"github.com/songquanpeng/one-api/relay/channel/mistral"
"github.com/songquanpeng/one-api/relay/channel/moonshot" "github.com/songquanpeng/one-api/relay/channel/moonshot"
"github.com/songquanpeng/one-api/relay/model" "github.com/songquanpeng/one-api/relay/model"
"github.com/songquanpeng/one-api/relay/util" "github.com/songquanpeng/one-api/relay/util"
@ -94,6 +95,8 @@ func (a *Adaptor) GetModelList() []string {
return baichuan.ModelList return baichuan.ModelList
case common.ChannelTypeMinimax: case common.ChannelTypeMinimax:
return minimax.ModelList return minimax.ModelList
case common.ChannelTypeMistral:
return mistral.ModelList
default: default:
return ModelList return ModelList
} }
@ -111,6 +114,8 @@ func (a *Adaptor) GetChannelName() string {
return "baichuan" return "baichuan"
case common.ChannelTypeMinimax: case common.ChannelTypeMinimax:
return "minimax" return "minimax"
case common.ChannelTypeMistral:
return "mistralai"
default: default:
return "openai" return "openai"
} }

24
relay/constant/image.go Normal file
View File

@ -0,0 +1,24 @@
package constant
// DalleSizeRatios maps an image model name to its supported output sizes and
// the cost multiplier applied for each size. A size that is absent from a
// model's inner map is treated as unsupported for that model.
var DalleSizeRatios = map[string]map[string]float64{
	"dall-e-2": {
		"256x256":   1,
		"512x512":   1.125,
		"1024x1024": 1.25,
	},
	"dall-e-3": {
		"1024x1024": 1,
		"1024x1792": 2,
		"1792x1024": 2,
	},
}
// DalleGenerationImageAmounts maps an image model name to the inclusive
// {minimum, maximum} number of images that a single request may ask for.
var DalleGenerationImageAmounts = map[string][2]int{
	"dall-e-2": {1, 10},
	"dall-e-3": {1, 1}, // OpenAI allows n=1 currently.
}
// DalleImagePromptLengthLimitations maps an image model name to the maximum
// prompt length accepted for that model.
var DalleImagePromptLengthLimitations = map[string]int{
	"dall-e-2": 1000,
	"dall-e-3": 4000,
}

View File

@ -36,6 +36,65 @@ func getAndValidateTextRequest(c *gin.Context, relayMode int) (*relaymodel.Gener
return textRequest, nil return textRequest, nil
} }
// getImageRequest decodes the request body into an image request and fills in
// defaults for fields the client left unset: N becomes 1, Size becomes
// "1024x1024" and Model becomes "dall-e-2".
//
// It returns the decoding error unchanged when the body cannot be
// unmarshalled. relayMode is accepted but not consulted here.
func getImageRequest(c *gin.Context, relayMode int) (*openai.ImageRequest, error) {
	req := &openai.ImageRequest{}
	if err := common.UnmarshalBodyReusable(c, req); err != nil {
		return nil, err
	}

	// Apply defaults only to zero-valued fields so explicit client choices win.
	if req.N == 0 {
		req.N = 1
	}
	if req.Size == "" {
		req.Size = "1024x1024"
	}
	if req.Model == "" {
		req.Model = "dall-e-2"
	}
	return req, nil
}
// validateImageRequest checks an image request against the per-model limits in
// the constant package and returns a 400-style error for the first violation
// found, or nil when the request is acceptable.
//
// Checks, in order:
//  1. the requested size must be listed for the model in DalleSizeRatios;
//  2. the prompt must be non-empty;
//  3. the prompt must not exceed the model's length limit, measured in
//     characters (runes) rather than bytes so multi-byte text is not rejected
//     prematurely;
//  4. N must lie within the model's allowed range, except on Azure channels,
//     where this check is skipped.
func validateImageRequest(imageRequest *openai.ImageRequest, meta *util.RelayMeta) *relaymodel.ErrorWithStatusCode {
	// size validation (an unknown model yields a nil inner map, so this also
	// rejects models that have no size table)
	if _, ok := constant.DalleSizeRatios[imageRequest.Model][imageRequest.Size]; !ok {
		return openai.ErrorWrapper(errors.New("size not supported for this image model"), "size_not_supported", http.StatusBadRequest)
	}

	// prompt validation
	if imageRequest.Prompt == "" {
		return openai.ErrorWrapper(errors.New("prompt is required"), "prompt_missing", http.StatusBadRequest)
	}
	// len(string) counts bytes; convert to runes so the limit applies to characters.
	if len([]rune(imageRequest.Prompt)) > constant.DalleImagePromptLengthLimitations[imageRequest.Model] {
		return openai.ErrorWrapper(errors.New("prompt is too long"), "prompt_too_long", http.StatusBadRequest)
	}

	// number of generated images validation; Azure channels are exempt
	if meta.ChannelType != common.ChannelTypeAzure && !isWithinRange(imageRequest.Model, imageRequest.N) {
		return openai.ErrorWrapper(errors.New("invalid value of n"), "n_not_within_range", http.StatusBadRequest)
	}
	return nil
}
// getImageCostRatio returns the cost multiplier for an image request.
//
// The base value is looked up in constant.DalleSizeRatios by model and size.
// For "dall-e-3" requests with quality "hd" the base is further multiplied by
// 2 for "1024x1024" and by 1.5 for any other size.
//
// An error is returned when imageRequest is nil or when the model/size pair
// has no entry in the ratio table.
func getImageCostRatio(imageRequest *openai.ImageRequest) (float64, error) {
	if imageRequest == nil {
		return 0, errors.New("imageRequest is nil")
	}

	baseRatio, found := constant.DalleSizeRatios[imageRequest.Model][imageRequest.Size]
	if !found {
		return 0, fmt.Errorf("size not supported for this image model: %s", imageRequest.Size)
	}

	// Only dall-e-3 HD requests carry a surcharge.
	if imageRequest.Quality != "hd" || imageRequest.Model != "dall-e-3" {
		return baseRatio, nil
	}
	hdMultiplier := 1.5
	if imageRequest.Size == "1024x1024" {
		hdMultiplier = 2
	}
	return baseRatio * hdMultiplier, nil
}
func getPromptTokens(textRequest *relaymodel.GeneralOpenAIRequest, relayMode int) int { func getPromptTokens(textRequest *relaymodel.GeneralOpenAIRequest, relayMode int) int {
switch relayMode { switch relayMode {
case constant.RelayModeChatCompletions: case constant.RelayModeChatCompletions:
@ -113,10 +172,8 @@ func postConsumeQuota(ctx context.Context, usage *relaymodel.Usage, meta *util.R
if err != nil { if err != nil {
logger.Error(ctx, "error update user quota cache: "+err.Error()) logger.Error(ctx, "error update user quota cache: "+err.Error())
} }
if quota != 0 {
logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f,补全倍率 %.2f", modelRatio, groupRatio, completionRatio) logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f,补全倍率 %.2f", modelRatio, groupRatio, completionRatio)
model.RecordConsumeLog(ctx, meta.UserId, meta.ChannelId, promptTokens, completionTokens, textRequest.Model, meta.TokenName, quota, logContent) model.RecordConsumeLog(ctx, meta.UserId, meta.ChannelId, promptTokens, completionTokens, textRequest.Model, meta.TokenName, quota, logContent)
model.UpdateUserUsedQuotaAndRequestCount(meta.UserId, quota) model.UpdateUserUsedQuotaAndRequestCount(meta.UserId, quota)
model.UpdateChannelUsedQuota(meta.ChannelId, quota) model.UpdateChannelUsedQuota(meta.ChannelId, quota)
}
} }

View File

@ -10,6 +10,7 @@ import (
"github.com/songquanpeng/one-api/common/logger" "github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/model" "github.com/songquanpeng/one-api/model"
"github.com/songquanpeng/one-api/relay/channel/openai" "github.com/songquanpeng/one-api/relay/channel/openai"
"github.com/songquanpeng/one-api/relay/constant"
relaymodel "github.com/songquanpeng/one-api/relay/model" relaymodel "github.com/songquanpeng/one-api/relay/model"
"github.com/songquanpeng/one-api/relay/util" "github.com/songquanpeng/one-api/relay/util"
"io" "io"
@ -20,120 +21,65 @@ import (
) )
func isWithinRange(element string, value int) bool { func isWithinRange(element string, value int) bool {
if _, ok := common.DalleGenerationImageAmounts[element]; !ok { if _, ok := constant.DalleGenerationImageAmounts[element]; !ok {
return false return false
} }
min := common.DalleGenerationImageAmounts[element][0] min := constant.DalleGenerationImageAmounts[element][0]
max := common.DalleGenerationImageAmounts[element][1] max := constant.DalleGenerationImageAmounts[element][1]
return value >= min && value <= max return value >= min && value <= max
} }
func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatusCode { func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatusCode {
imageModel := "dall-e-2" ctx := c.Request.Context()
imageSize := "1024x1024" meta := util.GetRelayMeta(c)
imageRequest, err := getImageRequest(c, meta.Mode)
tokenId := c.GetInt("token_id")
channelType := c.GetInt("channel")
channelId := c.GetInt("channel_id")
userId := c.GetInt("id")
group := c.GetString("group")
var imageRequest openai.ImageRequest
err := common.UnmarshalBodyReusable(c, &imageRequest)
if err != nil { if err != nil {
return openai.ErrorWrapper(err, "bind_request_body_failed", http.StatusBadRequest) logger.Errorf(ctx, "getImageRequest failed: %s", err.Error())
} return openai.ErrorWrapper(err, "invalid_image_request", http.StatusBadRequest)
if imageRequest.N == 0 {
imageRequest.N = 1
}
// Size validation
if imageRequest.Size != "" {
imageSize = imageRequest.Size
}
// Model validation
if imageRequest.Model != "" {
imageModel = imageRequest.Model
}
imageCostRatio, hasValidSize := common.DalleSizeRatios[imageModel][imageSize]
// Check if model is supported
if hasValidSize {
if imageRequest.Quality == "hd" && imageModel == "dall-e-3" {
if imageSize == "1024x1024" {
imageCostRatio *= 2
} else {
imageCostRatio *= 1.5
}
}
} else {
return openai.ErrorWrapper(errors.New("size not supported for this image model"), "size_not_supported", http.StatusBadRequest)
}
// Prompt validation
if imageRequest.Prompt == "" {
return openai.ErrorWrapper(errors.New("prompt is required"), "prompt_missing", http.StatusBadRequest)
}
// Check prompt length
if len(imageRequest.Prompt) > common.DalleImagePromptLengthLimitations[imageModel] {
return openai.ErrorWrapper(errors.New("prompt is too long"), "prompt_too_long", http.StatusBadRequest)
}
// Number of generated images validation
if !isWithinRange(imageModel, imageRequest.N) {
// channel not azure
if channelType != common.ChannelTypeAzure {
return openai.ErrorWrapper(errors.New("invalid value of n"), "n_not_within_range", http.StatusBadRequest)
}
} }
// map model name // map model name
modelMapping := c.GetString("model_mapping") var isModelMapped bool
isModelMapped := false meta.OriginModelName = imageRequest.Model
if modelMapping != "" { imageRequest.Model, isModelMapped = util.GetMappedModelName(imageRequest.Model, meta.ModelMapping)
modelMap := make(map[string]string) meta.ActualModelName = imageRequest.Model
err := json.Unmarshal([]byte(modelMapping), &modelMap)
// model validation
bizErr := validateImageRequest(imageRequest, meta)
if bizErr != nil {
return bizErr
}
imageCostRatio, err := getImageCostRatio(imageRequest)
if err != nil { if err != nil {
return openai.ErrorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError) return openai.ErrorWrapper(err, "get_image_cost_ratio_failed", http.StatusInternalServerError)
} }
if modelMap[imageModel] != "" {
imageModel = modelMap[imageModel]
isModelMapped = true
}
}
baseURL := common.ChannelBaseURLs[channelType]
requestURL := c.Request.URL.String() requestURL := c.Request.URL.String()
if c.GetString("base_url") != "" { fullRequestURL := util.GetFullRequestURL(meta.BaseURL, requestURL, meta.ChannelType)
baseURL = c.GetString("base_url") if meta.ChannelType == common.ChannelTypeAzure {
}
fullRequestURL := util.GetFullRequestURL(baseURL, requestURL, channelType)
if channelType == common.ChannelTypeAzure {
// https://learn.microsoft.com/en-us/azure/ai-services/openai/dall-e-quickstart?tabs=dalle3%2Ccommand-line&pivots=rest-api // https://learn.microsoft.com/en-us/azure/ai-services/openai/dall-e-quickstart?tabs=dalle3%2Ccommand-line&pivots=rest-api
apiVersion := util.GetAzureAPIVersion(c) apiVersion := util.GetAzureAPIVersion(c)
// https://{resource_name}.openai.azure.com/openai/deployments/dall-e-3/images/generations?api-version=2023-06-01-preview // https://{resource_name}.openai.azure.com/openai/deployments/dall-e-3/images/generations?api-version=2023-06-01-preview
fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/images/generations?api-version=%s", baseURL, imageModel, apiVersion) fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/images/generations?api-version=%s", meta.BaseURL, imageRequest.Model, apiVersion)
} }
var requestBody io.Reader var requestBody io.Reader
if isModelMapped || channelType == common.ChannelTypeAzure { // make Azure channel request body if isModelMapped || meta.ChannelType == common.ChannelTypeAzure { // make Azure channel request body
jsonStr, err := json.Marshal(imageRequest) jsonStr, err := json.Marshal(imageRequest)
if err != nil { if err != nil {
return openai.ErrorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError) return openai.ErrorWrapper(err, "marshal_image_request_failed", http.StatusInternalServerError)
} }
requestBody = bytes.NewBuffer(jsonStr) requestBody = bytes.NewBuffer(jsonStr)
} else { } else {
requestBody = c.Request.Body requestBody = c.Request.Body
} }
modelRatio := common.GetModelRatio(imageModel) modelRatio := common.GetModelRatio(imageRequest.Model)
groupRatio := common.GetGroupRatio(group) groupRatio := common.GetGroupRatio(meta.Group)
ratio := modelRatio * groupRatio ratio := modelRatio * groupRatio
userQuota, err := model.CacheGetUserQuota(userId) userQuota, err := model.CacheGetUserQuota(meta.UserId)
quota := int(ratio*imageCostRatio*1000) * imageRequest.N quota := int(ratio*imageCostRatio*1000) * imageRequest.N
@ -146,7 +92,7 @@ func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus
return openai.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError) return openai.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError)
} }
token := c.Request.Header.Get("Authorization") token := c.Request.Header.Get("Authorization")
if channelType == common.ChannelTypeAzure { // Azure authentication if meta.ChannelType == common.ChannelTypeAzure { // Azure authentication
token = strings.TrimPrefix(token, "Bearer ") token = strings.TrimPrefix(token, "Bearer ")
req.Header.Set("api-key", token) req.Header.Set("api-key", token)
} else { } else {
@ -169,25 +115,25 @@ func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus
if err != nil { if err != nil {
return openai.ErrorWrapper(err, "close_request_body_failed", http.StatusInternalServerError) return openai.ErrorWrapper(err, "close_request_body_failed", http.StatusInternalServerError)
} }
var textResponse openai.ImageResponse var imageResponse openai.ImageResponse
defer func(ctx context.Context) { defer func(ctx context.Context) {
if resp.StatusCode != http.StatusOK { if resp.StatusCode != http.StatusOK {
return return
} }
err := model.PostConsumeTokenQuota(tokenId, quota) err := model.PostConsumeTokenQuota(meta.TokenId, quota)
if err != nil { if err != nil {
logger.SysError("error consuming token remain quota: " + err.Error()) logger.SysError("error consuming token remain quota: " + err.Error())
} }
err = model.CacheUpdateUserQuota(userId) err = model.CacheUpdateUserQuota(meta.UserId)
if err != nil { if err != nil {
logger.SysError("error update user quota cache: " + err.Error()) logger.SysError("error update user quota cache: " + err.Error())
} }
if quota != 0 { if quota != 0 {
tokenName := c.GetString("token_name") tokenName := c.GetString("token_name")
logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio) logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
model.RecordConsumeLog(ctx, userId, channelId, 0, 0, imageModel, tokenName, quota, logContent) model.RecordConsumeLog(ctx, meta.UserId, meta.ChannelId, 0, 0, imageRequest.Model, tokenName, quota, logContent)
model.UpdateUserUsedQuotaAndRequestCount(userId, quota) model.UpdateUserUsedQuotaAndRequestCount(meta.UserId, quota)
channelId := c.GetInt("channel_id") channelId := c.GetInt("channel_id")
model.UpdateChannelUsedQuota(channelId, quota) model.UpdateChannelUsedQuota(channelId, quota)
} }
@ -202,7 +148,7 @@ func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus
if err != nil { if err != nil {
return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError) return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError)
} }
err = json.Unmarshal(responseBody, &textResponse) err = json.Unmarshal(responseBody, &imageResponse)
if err != nil { if err != nil {
return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError) return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError)
} }

View File

@ -29,6 +29,12 @@ export const CHANNEL_OPTIONS = {
value: 24, value: 24,
color: 'orange' color: 'orange'
}, },
28: {
key: 28,
text: 'Mistral AI',
value: 28,
color: 'orange'
},
15: { 15: {
key: 15, key: 15,
text: '百度文心千帆', text: '百度文心千帆',

View File

@ -4,6 +4,7 @@ export const CHANNEL_OPTIONS = [
{ key: 3, text: 'Azure OpenAI', value: 3, color: 'olive' }, { key: 3, text: 'Azure OpenAI', value: 3, color: 'olive' },
{ key: 11, text: 'Google PaLM2', value: 11, color: 'orange' }, { key: 11, text: 'Google PaLM2', value: 11, color: 'orange' },
{ key: 24, text: 'Google Gemini', value: 24, color: 'orange' }, { key: 24, text: 'Google Gemini', value: 24, color: 'orange' },
{ key: 28, text: 'Mistral AI', value: 28, color: 'orange' },
{ key: 15, text: '百度文心千帆', value: 15, color: 'blue' }, { key: 15, text: '百度文心千帆', value: 15, color: 'blue' },
{ key: 17, text: '阿里通义千问', value: 17, color: 'orange' }, { key: 17, text: '阿里通义千问', value: 17, color: 'orange' },
{ key: 18, text: '讯飞星火认知', value: 18, color: 'blue' }, { key: 18, text: '讯飞星火认知', value: 18, color: 'blue' },