diff --git a/README.md b/README.md index 97e1ec8c..7db2db3b 100644 --- a/README.md +++ b/README.md @@ -82,6 +82,8 @@ _本项目是基于[one-api](https://github.com/songquanpeng/one-api)二次开 | [零一万物](https://platform.lingyiwanwu.com/details) | ✅ | - | - | - | - | | [Cloudflare AI](https://ai.cloudflare.com/) | ✅ | - | ⚠️ stt | ⚠️ 图片生成 | - | | [Midjourney](https://www.midjourney.com/) | - | - | - | - | [midjourney-proxy](https://github.com/novicezk/midjourney-proxy) | +| [Cohere](https://cohere.com/) | ✅ | - | - | - | - | +| [Stability AI](https://platform.stability.ai/account/credits) | - | - | - | ⚠️ 图片生成 | - | ## 感谢 diff --git a/common/constants.go b/common/constants.go index 99e94f42..82dbb90c 100644 --- a/common/constants.go +++ b/common/constants.go @@ -174,6 +174,7 @@ const ( ChannelTypeMidjourney = 34 ChannelTypeCloudflareAI = 35 ChannelTypeCohere = 36 + ChannelTypeStabilityAI = 37 ) var ChannelBaseURLs = []string{ @@ -214,6 +215,7 @@ var ChannelBaseURLs = []string{ "", //34 "", //35 "https://api.cohere.ai/v1", //36 + "https://api.stability.ai/v2beta", //37 } const ( diff --git a/model/price.go b/model/price.go index 4937bc89..f0cf1582 100644 --- a/model/price.go +++ b/model/price.go @@ -299,6 +299,13 @@ func GetDefaultPrice() []*Price { "command-r": {[]float64{0.25, 0.75}, common.ChannelTypeCohere}, //$3 /1M TOKENS $15/1M TOKENS "command-r-plus": {[]float64{1.5, 7.5}, common.ChannelTypeCohere}, + + // 0.065 + "sd3": {[]float64{32.5, 32.5}, common.ChannelTypeStabilityAI}, + // 0.04 + "sd3-turbo": {[]float64{20, 20}, common.ChannelTypeStabilityAI}, + // 0.03 + "stable-image-core": {[]float64{15, 15}, common.ChannelTypeStabilityAI}, } var prices []*Price diff --git a/providers/providers.go b/providers/providers.go index 149460c1..bce39afd 100644 --- a/providers/providers.go +++ b/providers/providers.go @@ -21,6 +21,7 @@ import ( "one-api/providers/mistral" "one-api/providers/openai" "one-api/providers/palm" + "one-api/providers/stabilityAI" "one-api/providers/tencent" "one-api/providers/xunfei" "one-api/providers/zhipu" @@ -58,6 +59,7 @@ func init() { providerFactories[common.ChannelTypeMidjourney] = midjourney.MidjourneyProviderFactory{} providerFactories[common.ChannelTypeCloudflareAI] = cloudflareAI.CloudflareAIProviderFactory{} providerFactories[common.ChannelTypeCohere] = cohere.CohereProviderFactory{} + providerFactories[common.ChannelTypeStabilityAI] = stabilityAI.StabilityAIProviderFactory{} } diff --git a/providers/stabilityAI/base.go b/providers/stabilityAI/base.go new file mode 100644 index 00000000..c339a99e --- /dev/null +++ b/providers/stabilityAI/base.go @@ -0,0 +1,79 @@ +package stabilityAI + +import ( + "encoding/json" + "fmt" + "net/http" + "one-api/common/requester" + "one-api/model" + "one-api/providers/base" + "one-api/types" + "strings" +) + +type StabilityAIProviderFactory struct{} + +// 创建 StabilityAIProvider +func (f StabilityAIProviderFactory) Create(channel *model.Channel) base.ProviderInterface { + return &StabilityAIProvider{ + BaseProvider: base.BaseProvider{ + Config: getConfig(), + Channel: channel, + Requester: requester.NewHTTPRequester(*channel.Proxy, requestErrorHandle), + }, + } +} + +type StabilityAIProvider struct { + base.BaseProvider +} + +func getConfig() base.ProviderConfig { + return base.ProviderConfig{ + BaseURL: "https://api.stability.ai/v2beta", + ImagesGenerations: "/stable-image/generate", + } +} + +// 请求错误处理 +func requestErrorHandle(resp *http.Response) *types.OpenAIError { + stabilityAIError := &StabilityAIError{} + err := json.NewDecoder(resp.Body).Decode(stabilityAIError) + if err != nil { + return nil + } + + return errorHandle(stabilityAIError) +} + +// 错误处理 +func errorHandle(stabilityAIError *StabilityAIError) *types.OpenAIError { + openaiError := &types.OpenAIError{ + Type: "stabilityAI_error", + } + + if stabilityAIError.Name != "" { + openaiError.Message = stabilityAIError.String() + openaiError.Code = stabilityAIError.Name + } else { + openaiError.Message = stabilityAIError.Message + openaiError.Code = "stabilityAI_error" + } + + return openaiError +} + +func (p *StabilityAIProvider) GetFullRequestURL(requestURL string, modelName string) string { + baseURL := strings.TrimSuffix(p.GetBaseURL(), "/") + + return fmt.Sprintf("%s%s/%s", baseURL, requestURL, modelName) +} + +// 获取请求头 +func (p *StabilityAIProvider) GetRequestHeaders() (headers map[string]string) { + headers = make(map[string]string) + p.CommonRequestHeaders(headers) + headers["Authorization"] = "Bearer " + p.Channel.Key + + return headers +} diff --git a/providers/stabilityAI/image_generations.go b/providers/stabilityAI/image_generations.go new file mode 100644 index 00000000..52987c95 --- /dev/null +++ b/providers/stabilityAI/image_generations.go @@ -0,0 +1,87 @@ +package stabilityAI + +import ( + "bytes" + "encoding/base64" + "net/http" + "one-api/common" + "one-api/common/storage" + "one-api/types" + "time" +) + +func convertModelName(modelName string) string { + if modelName == "stable-image-core" { + return "core" + } + + return "sd3" +} + +func (p *StabilityAIProvider) CreateImageGenerations(request *types.ImageRequest) (*types.ImageResponse, *types.OpenAIErrorWithStatusCode) { + url, errWithCode := p.GetSupportedAPIUri(common.RelayModeImagesGenerations) + if errWithCode != nil { + return nil, errWithCode + } + + // 获取请求地址 + fullRequestURL := p.GetFullRequestURL(url, convertModelName(request.Model)) + if fullRequestURL == "" { + return nil, common.ErrorWrapper(nil, "invalid_stabilityAI_config", http.StatusInternalServerError) + } + + // 获取请求头 + headers := p.GetRequestHeaders() + headers["Accept"] = "application/json; type=image/png" + + var formBody bytes.Buffer + builder := p.Requester.CreateFormBuilder(&formBody) + builder.WriteField("prompt", request.Prompt) + builder.WriteField("output_format", "png") + if request.Model != "stable-image-core" { + builder.WriteField("model", request.Model) + } + builder.Close() + + req, err := p.Requester.NewRequest( + http.MethodPost, + fullRequestURL, + p.Requester.WithBody(&formBody), + p.Requester.WithHeader(headers), + p.Requester.WithContentType(builder.FormDataContentType())) + req.ContentLength = int64(formBody.Len()) + + if err != nil { + return nil, common.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError) + } + + stabilityAIResponse := &generateResponse{} + + // 发送请求 + _, errWithCode = p.Requester.SendRequest(req, stabilityAIResponse, false) + if errWithCode != nil { + return nil, errWithCode + } + + openaiResponse := &types.ImageResponse{ + Created: time.Now().Unix(), + } + + imgUrl := "" + if request.ResponseFormat == "" || request.ResponseFormat == "url" { + body, err := base64.StdEncoding.DecodeString(stabilityAIResponse.Image) + if err == nil { + imgUrl = storage.Upload(body, common.GetUUID()+".png") + } + } + + if imgUrl == "" { + openaiResponse.Data = []types.ImageResponseDataInner{{B64JSON: stabilityAIResponse.Image}} + } else { + openaiResponse.Data = []types.ImageResponseDataInner{{URL: imgUrl}} + } + + p.Usage.PromptTokens = 1000 + + return openaiResponse, nil +} diff --git a/providers/stabilityAI/type.go b/providers/stabilityAI/type.go new file mode 100644 index 00000000..03f314ba --- /dev/null +++ b/providers/stabilityAI/type.go @@ -0,0 +1,20 @@ +package stabilityAI + +import "strings" + +type StabilityAIError struct { + Name string `json:"name,omitempty"` + Errors []string `json:"errors,omitempty"` + Success bool `json:"success,omitempty"` + Message string `json:"message,omitempty"` +} + +func (e StabilityAIError) String() string { + return strings.Join(e.Errors, ", ") +} + +type generateResponse struct { + Image string `json:"image"` + FinishReason string `json:"finish_reason,omitempty"` + Seed int `json:"seed,omitempty"` +} diff --git a/relay/util/type.go b/relay/util/type.go index 6606dfad..1cea1065 100644 --- a/relay/util/type.go +++ b/relay/util/type.go @@ -27,5 +27,6 @@ func init() { common.ChannelTypeMidjourney: "Midjourney", common.ChannelTypeCloudflareAI: "Cloudflare AI", common.ChannelTypeCohere: "Cohere", + common.ChannelTypeStabilityAI: "Stability AI", } } diff --git a/web/src/constants/ChannelConstants.js b/web/src/constants/ChannelConstants.js index 818df56c..179ea070 100644 --- a/web/src/constants/ChannelConstants.js +++ b/web/src/constants/ChannelConstants.js @@ -153,6 +153,13 @@ export const CHANNEL_OPTIONS = { color: 'default', url: '' }, + 37: { + key: 37, + text: 'Stability AI', + value: 37, + color: 'default', + url: '' + }, 24: { key: 24, text: 'Azure Speech', diff --git a/web/src/views/Channel/type/Config.js b/web/src/views/Channel/type/Config.js index 7a1a0a31..5c6416a1 100644 --- a/web/src/views/Channel/type/Config.js +++ b/web/src/views/Channel/type/Config.js @@ -291,6 +291,15 @@ const typeConfig = { test_model: 'command-r' }, modelGroup: 'Cohere' + }, + 37: { + input: { + models: ['sd3', 'sd3-turbo', 'stable-image-core'] + }, + prompt: { + test_model: '' + }, + modelGroup: 'Stability AI' } };