✨ feat: add Stability AI
This commit is contained in:
parent
b20659dfcc
commit
303fe3360b
@ -82,6 +82,8 @@ _本项目是基于[one-api](https://github.com/songquanpeng/one-api)二次开
|
||||
| [零一万物](https://platform.lingyiwanwu.com/details) | ✅ | - | - | - | - |
|
||||
| [Cloudflare AI](https://ai.cloudflare.com/) | ✅ | - | ⚠️ stt | ⚠️ 图片生成 | - |
|
||||
| [Midjourney](https://www.midjourney.com/) | - | - | - | - | [midjourney-proxy](https://github.com/novicezk/midjourney-proxy) |
|
||||
| [Cohere](https://cohere.com/) | ✅ | - | - | - | - |
|
||||
| [Stability AI](https://platform.stability.ai/account/credits) | - | - | - | ⚠️ 图片生成 | - |
|
||||
|
||||
## 感谢
|
||||
|
||||
|
@ -174,6 +174,7 @@ const (
|
||||
ChannelTypeMidjourney = 34
|
||||
ChannelTypeCloudflareAI = 35
|
||||
ChannelTypeCohere = 36
|
||||
ChannelTypeStabilityAI = 37
|
||||
)
|
||||
|
||||
var ChannelBaseURLs = []string{
|
||||
@ -214,6 +215,7 @@ var ChannelBaseURLs = []string{
|
||||
"", //34
|
||||
"", //35
|
||||
"https://api.cohere.ai/v1", //36
|
||||
"https://api.stability.ai/v2beta", //37
|
||||
}
|
||||
|
||||
const (
|
||||
|
@ -299,6 +299,13 @@ func GetDefaultPrice() []*Price {
|
||||
"command-r": {[]float64{0.25, 0.75}, common.ChannelTypeCohere},
|
||||
//$3 /1M TOKENS $15/1M TOKENS
|
||||
"command-r-plus": {[]float64{1.5, 7.5}, common.ChannelTypeCohere},
|
||||
|
||||
// 0.065
|
||||
"sd3": {[]float64{32.5, 32.5}, common.ChannelTypeStabilityAI},
|
||||
// 0.04
|
||||
"sd3-turbo": {[]float64{20, 20}, common.ChannelTypeStabilityAI},
|
||||
// 0.03
|
||||
"stable-image-core": {[]float64{15, 15}, common.ChannelTypeStabilityAI},
|
||||
}
|
||||
|
||||
var prices []*Price
|
||||
|
@ -21,6 +21,7 @@ import (
|
||||
"one-api/providers/mistral"
|
||||
"one-api/providers/openai"
|
||||
"one-api/providers/palm"
|
||||
"one-api/providers/stabilityAI"
|
||||
"one-api/providers/tencent"
|
||||
"one-api/providers/xunfei"
|
||||
"one-api/providers/zhipu"
|
||||
@ -58,6 +59,7 @@ func init() {
|
||||
providerFactories[common.ChannelTypeMidjourney] = midjourney.MidjourneyProviderFactory{}
|
||||
providerFactories[common.ChannelTypeCloudflareAI] = cloudflareAI.CloudflareAIProviderFactory{}
|
||||
providerFactories[common.ChannelTypeCohere] = cohere.CohereProviderFactory{}
|
||||
providerFactories[common.ChannelTypeStabilityAI] = stabilityAI.StabilityAIProviderFactory{}
|
||||
|
||||
}
|
||||
|
||||
|
79
providers/stabilityAI/base.go
Normal file
79
providers/stabilityAI/base.go
Normal file
@ -0,0 +1,79 @@
|
||||
package stabilityAI
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"one-api/common/requester"
|
||||
"one-api/model"
|
||||
"one-api/providers/base"
|
||||
"one-api/types"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type StabilityAIProviderFactory struct{}
|
||||
|
||||
// 创建 StabilityAIProvider
|
||||
func (f StabilityAIProviderFactory) Create(channel *model.Channel) base.ProviderInterface {
|
||||
return &StabilityAIProvider{
|
||||
BaseProvider: base.BaseProvider{
|
||||
Config: getConfig(),
|
||||
Channel: channel,
|
||||
Requester: requester.NewHTTPRequester(*channel.Proxy, requestErrorHandle),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
type StabilityAIProvider struct {
|
||||
base.BaseProvider
|
||||
}
|
||||
|
||||
func getConfig() base.ProviderConfig {
|
||||
return base.ProviderConfig{
|
||||
BaseURL: "https://api.stability.ai/v2beta",
|
||||
ImagesGenerations: "/stable-image/generate",
|
||||
}
|
||||
}
|
||||
|
||||
// 请求错误处理
|
||||
func requestErrorHandle(resp *http.Response) *types.OpenAIError {
|
||||
stabilityAIError := &StabilityAIError{}
|
||||
err := json.NewDecoder(resp.Body).Decode(stabilityAIError)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return errorHandle(stabilityAIError)
|
||||
}
|
||||
|
||||
// 错误处理
|
||||
func errorHandle(stabilityAIError *StabilityAIError) *types.OpenAIError {
|
||||
openaiError := &types.OpenAIError{
|
||||
Type: "stabilityAI_error",
|
||||
}
|
||||
|
||||
if stabilityAIError.Name != "" {
|
||||
openaiError.Message = stabilityAIError.String()
|
||||
openaiError.Code = stabilityAIError.Name
|
||||
} else {
|
||||
openaiError.Message = stabilityAIError.Message
|
||||
openaiError.Code = "stabilityAI_error"
|
||||
}
|
||||
|
||||
return openaiError
|
||||
}
|
||||
|
||||
func (p *StabilityAIProvider) GetFullRequestURL(requestURL string, modelName string) string {
|
||||
baseURL := strings.TrimSuffix(p.GetBaseURL(), "/")
|
||||
|
||||
return fmt.Sprintf("%s%s/%s", baseURL, requestURL, modelName)
|
||||
}
|
||||
|
||||
// 获取请求头
|
||||
func (p *StabilityAIProvider) GetRequestHeaders() (headers map[string]string) {
|
||||
headers = make(map[string]string)
|
||||
p.CommonRequestHeaders(headers)
|
||||
headers["Authorization"] = "Bearer " + p.Channel.Key
|
||||
|
||||
return headers
|
||||
}
|
87
providers/stabilityAI/image_generations.go
Normal file
87
providers/stabilityAI/image_generations.go
Normal file
@ -0,0 +1,87 @@
|
||||
package stabilityAI
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/base64"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/common/storage"
|
||||
"one-api/types"
|
||||
"time"
|
||||
)
|
||||
|
||||
func convertModelName(modelName string) string {
|
||||
if modelName == "stable-image-core" {
|
||||
return "core"
|
||||
}
|
||||
|
||||
return "sd3"
|
||||
}
|
||||
|
||||
func (p *StabilityAIProvider) CreateImageGenerations(request *types.ImageRequest) (*types.ImageResponse, *types.OpenAIErrorWithStatusCode) {
|
||||
url, errWithCode := p.GetSupportedAPIUri(common.RelayModeImagesGenerations)
|
||||
if errWithCode != nil {
|
||||
return nil, errWithCode
|
||||
}
|
||||
|
||||
// 获取请求地址
|
||||
fullRequestURL := p.GetFullRequestURL(url, convertModelName(request.Model))
|
||||
if fullRequestURL == "" {
|
||||
return nil, common.ErrorWrapper(nil, "invalid_stabilityAI_config", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
// 获取请求头
|
||||
headers := p.GetRequestHeaders()
|
||||
headers["Accept"] = "application/json; type=image/png"
|
||||
|
||||
var formBody bytes.Buffer
|
||||
builder := p.Requester.CreateFormBuilder(&formBody)
|
||||
builder.WriteField("prompt", request.Prompt)
|
||||
builder.WriteField("output_format", "png")
|
||||
if request.Model != "stable-image-core" {
|
||||
builder.WriteField("model", request.Model)
|
||||
}
|
||||
builder.Close()
|
||||
|
||||
req, err := p.Requester.NewRequest(
|
||||
http.MethodPost,
|
||||
fullRequestURL,
|
||||
p.Requester.WithBody(&formBody),
|
||||
p.Requester.WithHeader(headers),
|
||||
p.Requester.WithContentType(builder.FormDataContentType()))
|
||||
req.ContentLength = int64(formBody.Len())
|
||||
|
||||
if err != nil {
|
||||
return nil, common.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
stabilityAIResponse := &generateResponse{}
|
||||
|
||||
// 发送请求
|
||||
_, errWithCode = p.Requester.SendRequest(req, stabilityAIResponse, false)
|
||||
if errWithCode != nil {
|
||||
return nil, errWithCode
|
||||
}
|
||||
|
||||
openaiResponse := &types.ImageResponse{
|
||||
Created: time.Now().Unix(),
|
||||
}
|
||||
|
||||
imgUrl := ""
|
||||
if request.ResponseFormat == "" || request.ResponseFormat == "url" {
|
||||
body, err := base64.StdEncoding.DecodeString(stabilityAIResponse.Image)
|
||||
if err == nil {
|
||||
imgUrl = storage.Upload(body, common.GetUUID()+".png")
|
||||
}
|
||||
}
|
||||
|
||||
if imgUrl == "" {
|
||||
openaiResponse.Data = []types.ImageResponseDataInner{{B64JSON: stabilityAIResponse.Image}}
|
||||
} else {
|
||||
openaiResponse.Data = []types.ImageResponseDataInner{{URL: imgUrl}}
|
||||
}
|
||||
|
||||
p.Usage.PromptTokens = 1000
|
||||
|
||||
return openaiResponse, nil
|
||||
}
|
20
providers/stabilityAI/type.go
Normal file
20
providers/stabilityAI/type.go
Normal file
@ -0,0 +1,20 @@
|
||||
package stabilityAI
|
||||
|
||||
import "strings"
|
||||
|
||||
type StabilityAIError struct {
|
||||
Name string `json:"name,omitempty"`
|
||||
Errors []string `json:"errors,omitempty"`
|
||||
Success bool `json:"success,omitempty"`
|
||||
Message string `json:"message,omitempty"`
|
||||
}
|
||||
|
||||
func (e StabilityAIError) String() string {
|
||||
return strings.Join(e.Errors, ", ")
|
||||
}
|
||||
|
||||
type generateResponse struct {
|
||||
Image string `json:"image"`
|
||||
FinishReason string `json:"finish_reason,omitempty"`
|
||||
Seed int `json:"seed,omitempty"`
|
||||
}
|
@ -27,5 +27,6 @@ func init() {
|
||||
common.ChannelTypeMidjourney: "Midjourney",
|
||||
common.ChannelTypeCloudflareAI: "Cloudflare AI",
|
||||
common.ChannelTypeCohere: "Cohere",
|
||||
common.ChannelTypeStabilityAI: "Stability AI",
|
||||
}
|
||||
}
|
||||
|
@ -153,6 +153,13 @@ export const CHANNEL_OPTIONS = {
|
||||
color: 'default',
|
||||
url: ''
|
||||
},
|
||||
37: {
|
||||
key: 37,
|
||||
text: 'Stability AI',
|
||||
value: 37,
|
||||
color: 'default',
|
||||
url: ''
|
||||
},
|
||||
24: {
|
||||
key: 24,
|
||||
text: 'Azure Speech',
|
||||
|
@ -291,6 +291,15 @@ const typeConfig = {
|
||||
test_model: 'command-r'
|
||||
},
|
||||
modelGroup: 'Cohere'
|
||||
},
|
||||
37: {
|
||||
input: {
|
||||
models: ['sd3', 'sd3-turbo', 'stable-image-core']
|
||||
},
|
||||
prompt: {
|
||||
test_model: ''
|
||||
},
|
||||
modelGroup: 'Stability AI'
|
||||
}
|
||||
};
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user