✨ feat: add Stability AI
This commit is contained in:
parent
b20659dfcc
commit
303fe3360b
@ -82,6 +82,8 @@ _本项目是基于[one-api](https://github.com/songquanpeng/one-api)二次开
|
|||||||
| [零一万物](https://platform.lingyiwanwu.com/details) | ✅ | - | - | - | - |
|
| [零一万物](https://platform.lingyiwanwu.com/details) | ✅ | - | - | - | - |
|
||||||
| [Cloudflare AI](https://ai.cloudflare.com/) | ✅ | - | ⚠️ stt | ⚠️ 图片生成 | - |
|
| [Cloudflare AI](https://ai.cloudflare.com/) | ✅ | - | ⚠️ stt | ⚠️ 图片生成 | - |
|
||||||
| [Midjourney](https://www.midjourney.com/) | - | - | - | - | [midjourney-proxy](https://github.com/novicezk/midjourney-proxy) |
|
| [Midjourney](https://www.midjourney.com/) | - | - | - | - | [midjourney-proxy](https://github.com/novicezk/midjourney-proxy) |
|
||||||
|
| [Cohere](https://cohere.com/) | ✅ | - | - | - | - |
|
||||||
|
| [Stability AI](https://platform.stability.ai/account/credits) | - | - | - | ⚠️ 图片生成 | - |
|
||||||
|
|
||||||
## 感谢
|
## 感谢
|
||||||
|
|
||||||
|
@ -174,6 +174,7 @@ const (
|
|||||||
ChannelTypeMidjourney = 34
|
ChannelTypeMidjourney = 34
|
||||||
ChannelTypeCloudflareAI = 35
|
ChannelTypeCloudflareAI = 35
|
||||||
ChannelTypeCohere = 36
|
ChannelTypeCohere = 36
|
||||||
|
ChannelTypeStabilityAI = 37
|
||||||
)
|
)
|
||||||
|
|
||||||
var ChannelBaseURLs = []string{
|
var ChannelBaseURLs = []string{
|
||||||
@ -214,6 +215,7 @@ var ChannelBaseURLs = []string{
|
|||||||
"", //34
|
"", //34
|
||||||
"", //35
|
"", //35
|
||||||
"https://api.cohere.ai/v1", //36
|
"https://api.cohere.ai/v1", //36
|
||||||
|
"https://api.stability.ai/v2beta", //37
|
||||||
}
|
}
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
@ -299,6 +299,13 @@ func GetDefaultPrice() []*Price {
|
|||||||
"command-r": {[]float64{0.25, 0.75}, common.ChannelTypeCohere},
|
"command-r": {[]float64{0.25, 0.75}, common.ChannelTypeCohere},
|
||||||
//$3 /1M TOKENS $15/1M TOKENS
|
//$3 /1M TOKENS $15/1M TOKENS
|
||||||
"command-r-plus": {[]float64{1.5, 7.5}, common.ChannelTypeCohere},
|
"command-r-plus": {[]float64{1.5, 7.5}, common.ChannelTypeCohere},
|
||||||
|
|
||||||
|
// 0.065
|
||||||
|
"sd3": {[]float64{32.5, 32.5}, common.ChannelTypeStabilityAI},
|
||||||
|
// 0.04
|
||||||
|
"sd3-turbo": {[]float64{20, 20}, common.ChannelTypeStabilityAI},
|
||||||
|
// 0.03
|
||||||
|
"stable-image-core": {[]float64{15, 15}, common.ChannelTypeStabilityAI},
|
||||||
}
|
}
|
||||||
|
|
||||||
var prices []*Price
|
var prices []*Price
|
||||||
|
@ -21,6 +21,7 @@ import (
|
|||||||
"one-api/providers/mistral"
|
"one-api/providers/mistral"
|
||||||
"one-api/providers/openai"
|
"one-api/providers/openai"
|
||||||
"one-api/providers/palm"
|
"one-api/providers/palm"
|
||||||
|
"one-api/providers/stabilityAI"
|
||||||
"one-api/providers/tencent"
|
"one-api/providers/tencent"
|
||||||
"one-api/providers/xunfei"
|
"one-api/providers/xunfei"
|
||||||
"one-api/providers/zhipu"
|
"one-api/providers/zhipu"
|
||||||
@ -58,6 +59,7 @@ func init() {
|
|||||||
providerFactories[common.ChannelTypeMidjourney] = midjourney.MidjourneyProviderFactory{}
|
providerFactories[common.ChannelTypeMidjourney] = midjourney.MidjourneyProviderFactory{}
|
||||||
providerFactories[common.ChannelTypeCloudflareAI] = cloudflareAI.CloudflareAIProviderFactory{}
|
providerFactories[common.ChannelTypeCloudflareAI] = cloudflareAI.CloudflareAIProviderFactory{}
|
||||||
providerFactories[common.ChannelTypeCohere] = cohere.CohereProviderFactory{}
|
providerFactories[common.ChannelTypeCohere] = cohere.CohereProviderFactory{}
|
||||||
|
providerFactories[common.ChannelTypeStabilityAI] = stabilityAI.StabilityAIProviderFactory{}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
79
providers/stabilityAI/base.go
Normal file
79
providers/stabilityAI/base.go
Normal file
@ -0,0 +1,79 @@
|
|||||||
|
package stabilityAI
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"one-api/common/requester"
|
||||||
|
"one-api/model"
|
||||||
|
"one-api/providers/base"
|
||||||
|
"one-api/types"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
type StabilityAIProviderFactory struct{}
|
||||||
|
|
||||||
|
// 创建 StabilityAIProvider
|
||||||
|
func (f StabilityAIProviderFactory) Create(channel *model.Channel) base.ProviderInterface {
|
||||||
|
return &StabilityAIProvider{
|
||||||
|
BaseProvider: base.BaseProvider{
|
||||||
|
Config: getConfig(),
|
||||||
|
Channel: channel,
|
||||||
|
Requester: requester.NewHTTPRequester(*channel.Proxy, requestErrorHandle),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type StabilityAIProvider struct {
|
||||||
|
base.BaseProvider
|
||||||
|
}
|
||||||
|
|
||||||
|
func getConfig() base.ProviderConfig {
|
||||||
|
return base.ProviderConfig{
|
||||||
|
BaseURL: "https://api.stability.ai/v2beta",
|
||||||
|
ImagesGenerations: "/stable-image/generate",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 请求错误处理
|
||||||
|
func requestErrorHandle(resp *http.Response) *types.OpenAIError {
|
||||||
|
stabilityAIError := &StabilityAIError{}
|
||||||
|
err := json.NewDecoder(resp.Body).Decode(stabilityAIError)
|
||||||
|
if err != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return errorHandle(stabilityAIError)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 错误处理
|
||||||
|
func errorHandle(stabilityAIError *StabilityAIError) *types.OpenAIError {
|
||||||
|
openaiError := &types.OpenAIError{
|
||||||
|
Type: "stabilityAI_error",
|
||||||
|
}
|
||||||
|
|
||||||
|
if stabilityAIError.Name != "" {
|
||||||
|
openaiError.Message = stabilityAIError.String()
|
||||||
|
openaiError.Code = stabilityAIError.Name
|
||||||
|
} else {
|
||||||
|
openaiError.Message = stabilityAIError.Message
|
||||||
|
openaiError.Code = "stabilityAI_error"
|
||||||
|
}
|
||||||
|
|
||||||
|
return openaiError
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *StabilityAIProvider) GetFullRequestURL(requestURL string, modelName string) string {
|
||||||
|
baseURL := strings.TrimSuffix(p.GetBaseURL(), "/")
|
||||||
|
|
||||||
|
return fmt.Sprintf("%s%s/%s", baseURL, requestURL, modelName)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 获取请求头
|
||||||
|
func (p *StabilityAIProvider) GetRequestHeaders() (headers map[string]string) {
|
||||||
|
headers = make(map[string]string)
|
||||||
|
p.CommonRequestHeaders(headers)
|
||||||
|
headers["Authorization"] = "Bearer " + p.Channel.Key
|
||||||
|
|
||||||
|
return headers
|
||||||
|
}
|
87
providers/stabilityAI/image_generations.go
Normal file
87
providers/stabilityAI/image_generations.go
Normal file
@ -0,0 +1,87 @@
|
|||||||
|
package stabilityAI
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/base64"
|
||||||
|
"net/http"
|
||||||
|
"one-api/common"
|
||||||
|
"one-api/common/storage"
|
||||||
|
"one-api/types"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func convertModelName(modelName string) string {
|
||||||
|
if modelName == "stable-image-core" {
|
||||||
|
return "core"
|
||||||
|
}
|
||||||
|
|
||||||
|
return "sd3"
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *StabilityAIProvider) CreateImageGenerations(request *types.ImageRequest) (*types.ImageResponse, *types.OpenAIErrorWithStatusCode) {
|
||||||
|
url, errWithCode := p.GetSupportedAPIUri(common.RelayModeImagesGenerations)
|
||||||
|
if errWithCode != nil {
|
||||||
|
return nil, errWithCode
|
||||||
|
}
|
||||||
|
|
||||||
|
// 获取请求地址
|
||||||
|
fullRequestURL := p.GetFullRequestURL(url, convertModelName(request.Model))
|
||||||
|
if fullRequestURL == "" {
|
||||||
|
return nil, common.ErrorWrapper(nil, "invalid_stabilityAI_config", http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 获取请求头
|
||||||
|
headers := p.GetRequestHeaders()
|
||||||
|
headers["Accept"] = "application/json; type=image/png"
|
||||||
|
|
||||||
|
var formBody bytes.Buffer
|
||||||
|
builder := p.Requester.CreateFormBuilder(&formBody)
|
||||||
|
builder.WriteField("prompt", request.Prompt)
|
||||||
|
builder.WriteField("output_format", "png")
|
||||||
|
if request.Model != "stable-image-core" {
|
||||||
|
builder.WriteField("model", request.Model)
|
||||||
|
}
|
||||||
|
builder.Close()
|
||||||
|
|
||||||
|
req, err := p.Requester.NewRequest(
|
||||||
|
http.MethodPost,
|
||||||
|
fullRequestURL,
|
||||||
|
p.Requester.WithBody(&formBody),
|
||||||
|
p.Requester.WithHeader(headers),
|
||||||
|
p.Requester.WithContentType(builder.FormDataContentType()))
|
||||||
|
req.ContentLength = int64(formBody.Len())
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return nil, common.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
|
||||||
|
stabilityAIResponse := &generateResponse{}
|
||||||
|
|
||||||
|
// 发送请求
|
||||||
|
_, errWithCode = p.Requester.SendRequest(req, stabilityAIResponse, false)
|
||||||
|
if errWithCode != nil {
|
||||||
|
return nil, errWithCode
|
||||||
|
}
|
||||||
|
|
||||||
|
openaiResponse := &types.ImageResponse{
|
||||||
|
Created: time.Now().Unix(),
|
||||||
|
}
|
||||||
|
|
||||||
|
imgUrl := ""
|
||||||
|
if request.ResponseFormat == "" || request.ResponseFormat == "url" {
|
||||||
|
body, err := base64.StdEncoding.DecodeString(stabilityAIResponse.Image)
|
||||||
|
if err == nil {
|
||||||
|
imgUrl = storage.Upload(body, common.GetUUID()+".png")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if imgUrl == "" {
|
||||||
|
openaiResponse.Data = []types.ImageResponseDataInner{{B64JSON: stabilityAIResponse.Image}}
|
||||||
|
} else {
|
||||||
|
openaiResponse.Data = []types.ImageResponseDataInner{{URL: imgUrl}}
|
||||||
|
}
|
||||||
|
|
||||||
|
p.Usage.PromptTokens = 1000
|
||||||
|
|
||||||
|
return openaiResponse, nil
|
||||||
|
}
|
20
providers/stabilityAI/type.go
Normal file
20
providers/stabilityAI/type.go
Normal file
@ -0,0 +1,20 @@
|
|||||||
|
package stabilityAI
|
||||||
|
|
||||||
|
import "strings"
|
||||||
|
|
||||||
|
type StabilityAIError struct {
|
||||||
|
Name string `json:"name,omitempty"`
|
||||||
|
Errors []string `json:"errors,omitempty"`
|
||||||
|
Success bool `json:"success,omitempty"`
|
||||||
|
Message string `json:"message,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e StabilityAIError) String() string {
|
||||||
|
return strings.Join(e.Errors, ", ")
|
||||||
|
}
|
||||||
|
|
||||||
|
type generateResponse struct {
|
||||||
|
Image string `json:"image"`
|
||||||
|
FinishReason string `json:"finish_reason,omitempty"`
|
||||||
|
Seed int `json:"seed,omitempty"`
|
||||||
|
}
|
@ -27,5 +27,6 @@ func init() {
|
|||||||
common.ChannelTypeMidjourney: "Midjourney",
|
common.ChannelTypeMidjourney: "Midjourney",
|
||||||
common.ChannelTypeCloudflareAI: "Cloudflare AI",
|
common.ChannelTypeCloudflareAI: "Cloudflare AI",
|
||||||
common.ChannelTypeCohere: "Cohere",
|
common.ChannelTypeCohere: "Cohere",
|
||||||
|
common.ChannelTypeStabilityAI: "Stability AI",
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -153,6 +153,13 @@ export const CHANNEL_OPTIONS = {
|
|||||||
color: 'default',
|
color: 'default',
|
||||||
url: ''
|
url: ''
|
||||||
},
|
},
|
||||||
|
37: {
|
||||||
|
key: 37,
|
||||||
|
text: 'Stability AI',
|
||||||
|
value: 37,
|
||||||
|
color: 'default',
|
||||||
|
url: ''
|
||||||
|
},
|
||||||
24: {
|
24: {
|
||||||
key: 24,
|
key: 24,
|
||||||
text: 'Azure Speech',
|
text: 'Azure Speech',
|
||||||
|
@ -291,6 +291,15 @@ const typeConfig = {
|
|||||||
test_model: 'command-r'
|
test_model: 'command-r'
|
||||||
},
|
},
|
||||||
modelGroup: 'Cohere'
|
modelGroup: 'Cohere'
|
||||||
|
},
|
||||||
|
37: {
|
||||||
|
input: {
|
||||||
|
models: ['sd3', 'sd3-turbo', 'stable-image-core']
|
||||||
|
},
|
||||||
|
prompt: {
|
||||||
|
test_model: ''
|
||||||
|
},
|
||||||
|
modelGroup: 'Stability AI'
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user