feat: add Stability AI

This commit is contained in:
Martial BE 2024-04-18 18:53:49 +08:00
parent b20659dfcc
commit 303fe3360b
No known key found for this signature in database
GPG Key ID: D06C32DF0EDB9084
10 changed files with 216 additions and 0 deletions

View File

@ -82,6 +82,8 @@ _本项目是基于[one-api](https://github.com/songquanpeng/one-api)二次开
| [零一万物](https://platform.lingyiwanwu.com/details) | ✅ | - | - | - | - |
| [Cloudflare AI](https://ai.cloudflare.com/) | ✅ | - | ⚠️ stt | ⚠️ 图片生成 | - |
| [Midjourney](https://www.midjourney.com/) | - | - | - | - | [midjourney-proxy](https://github.com/novicezk/midjourney-proxy) |
| [Cohere](https://cohere.com/) | ✅ | - | - | - | - |
| [Stability AI](https://platform.stability.ai/account/credits) | - | - | - | ⚠️ 图片生成 | - |
## 感谢

View File

@ -174,6 +174,7 @@ const (
ChannelTypeMidjourney = 34
ChannelTypeCloudflareAI = 35
ChannelTypeCohere = 36
ChannelTypeStabilityAI = 37
)
var ChannelBaseURLs = []string{
@ -214,6 +215,7 @@ var ChannelBaseURLs = []string{
"", //34
"", //35
"https://api.cohere.ai/v1", //36
"https://api.stability.ai/v2beta", //37
}
const (

View File

@ -299,6 +299,13 @@ func GetDefaultPrice() []*Price {
"command-r": {[]float64{0.25, 0.75}, common.ChannelTypeCohere},
//$3 /1M TOKENS $15/1M TOKENS
"command-r-plus": {[]float64{1.5, 7.5}, common.ChannelTypeCohere},
// 0.065
"sd3": {[]float64{32.5, 32.5}, common.ChannelTypeStabilityAI},
// 0.04
"sd3-turbo": {[]float64{20, 20}, common.ChannelTypeStabilityAI},
// 0.03
"stable-image-core": {[]float64{15, 15}, common.ChannelTypeStabilityAI},
}
var prices []*Price

View File

@ -21,6 +21,7 @@ import (
"one-api/providers/mistral"
"one-api/providers/openai"
"one-api/providers/palm"
"one-api/providers/stabilityAI"
"one-api/providers/tencent"
"one-api/providers/xunfei"
"one-api/providers/zhipu"
@ -58,6 +59,7 @@ func init() {
providerFactories[common.ChannelTypeMidjourney] = midjourney.MidjourneyProviderFactory{}
providerFactories[common.ChannelTypeCloudflareAI] = cloudflareAI.CloudflareAIProviderFactory{}
providerFactories[common.ChannelTypeCohere] = cohere.CohereProviderFactory{}
providerFactories[common.ChannelTypeStabilityAI] = stabilityAI.StabilityAIProviderFactory{}
}

View File

@ -0,0 +1,79 @@
package stabilityAI
import (
"encoding/json"
"fmt"
"net/http"
"one-api/common/requester"
"one-api/model"
"one-api/providers/base"
"one-api/types"
"strings"
)
type StabilityAIProviderFactory struct{}
// 创建 StabilityAIProvider
func (f StabilityAIProviderFactory) Create(channel *model.Channel) base.ProviderInterface {
return &StabilityAIProvider{
BaseProvider: base.BaseProvider{
Config: getConfig(),
Channel: channel,
Requester: requester.NewHTTPRequester(*channel.Proxy, requestErrorHandle),
},
}
}
type StabilityAIProvider struct {
base.BaseProvider
}
func getConfig() base.ProviderConfig {
return base.ProviderConfig{
BaseURL: "https://api.stability.ai/v2beta",
ImagesGenerations: "/stable-image/generate",
}
}
// 请求错误处理
func requestErrorHandle(resp *http.Response) *types.OpenAIError {
stabilityAIError := &StabilityAIError{}
err := json.NewDecoder(resp.Body).Decode(stabilityAIError)
if err != nil {
return nil
}
return errorHandle(stabilityAIError)
}
// 错误处理
func errorHandle(stabilityAIError *StabilityAIError) *types.OpenAIError {
openaiError := &types.OpenAIError{
Type: "stabilityAI_error",
}
if stabilityAIError.Name != "" {
openaiError.Message = stabilityAIError.String()
openaiError.Code = stabilityAIError.Name
} else {
openaiError.Message = stabilityAIError.Message
openaiError.Code = "stabilityAI_error"
}
return openaiError
}
func (p *StabilityAIProvider) GetFullRequestURL(requestURL string, modelName string) string {
baseURL := strings.TrimSuffix(p.GetBaseURL(), "/")
return fmt.Sprintf("%s%s/%s", baseURL, requestURL, modelName)
}
// 获取请求头
func (p *StabilityAIProvider) GetRequestHeaders() (headers map[string]string) {
headers = make(map[string]string)
p.CommonRequestHeaders(headers)
headers["Authorization"] = "Bearer " + p.Channel.Key
return headers
}

View File

@ -0,0 +1,87 @@
package stabilityAI
import (
"bytes"
"encoding/base64"
"net/http"
"one-api/common"
"one-api/common/storage"
"one-api/types"
"time"
)
func convertModelName(modelName string) string {
if modelName == "stable-image-core" {
return "core"
}
return "sd3"
}
func (p *StabilityAIProvider) CreateImageGenerations(request *types.ImageRequest) (*types.ImageResponse, *types.OpenAIErrorWithStatusCode) {
url, errWithCode := p.GetSupportedAPIUri(common.RelayModeImagesGenerations)
if errWithCode != nil {
return nil, errWithCode
}
// 获取请求地址
fullRequestURL := p.GetFullRequestURL(url, convertModelName(request.Model))
if fullRequestURL == "" {
return nil, common.ErrorWrapper(nil, "invalid_stabilityAI_config", http.StatusInternalServerError)
}
// 获取请求头
headers := p.GetRequestHeaders()
headers["Accept"] = "application/json; type=image/png"
var formBody bytes.Buffer
builder := p.Requester.CreateFormBuilder(&formBody)
builder.WriteField("prompt", request.Prompt)
builder.WriteField("output_format", "png")
if request.Model != "stable-image-core" {
builder.WriteField("model", request.Model)
}
builder.Close()
req, err := p.Requester.NewRequest(
http.MethodPost,
fullRequestURL,
p.Requester.WithBody(&formBody),
p.Requester.WithHeader(headers),
p.Requester.WithContentType(builder.FormDataContentType()))
req.ContentLength = int64(formBody.Len())
if err != nil {
return nil, common.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError)
}
stabilityAIResponse := &generateResponse{}
// 发送请求
_, errWithCode = p.Requester.SendRequest(req, stabilityAIResponse, false)
if errWithCode != nil {
return nil, errWithCode
}
openaiResponse := &types.ImageResponse{
Created: time.Now().Unix(),
}
imgUrl := ""
if request.ResponseFormat == "" || request.ResponseFormat == "url" {
body, err := base64.StdEncoding.DecodeString(stabilityAIResponse.Image)
if err == nil {
imgUrl = storage.Upload(body, common.GetUUID()+".png")
}
}
if imgUrl == "" {
openaiResponse.Data = []types.ImageResponseDataInner{{B64JSON: stabilityAIResponse.Image}}
} else {
openaiResponse.Data = []types.ImageResponseDataInner{{URL: imgUrl}}
}
p.Usage.PromptTokens = 1000
return openaiResponse, nil
}

View File

@ -0,0 +1,20 @@
package stabilityAI
import "strings"
type StabilityAIError struct {
Name string `json:"name,omitempty"`
Errors []string `json:"errors,omitempty"`
Success bool `json:"success,omitempty"`
Message string `json:"message,omitempty"`
}
func (e StabilityAIError) String() string {
return strings.Join(e.Errors, ", ")
}
type generateResponse struct {
Image string `json:"image"`
FinishReason string `json:"finish_reason,omitempty"`
Seed int `json:"seed,omitempty"`
}

View File

@ -27,5 +27,6 @@ func init() {
common.ChannelTypeMidjourney: "Midjourney",
common.ChannelTypeCloudflareAI: "Cloudflare AI",
common.ChannelTypeCohere: "Cohere",
common.ChannelTypeStabilityAI: "Stability AI",
}
}

View File

@ -153,6 +153,13 @@ export const CHANNEL_OPTIONS = {
color: 'default',
url: ''
},
37: {
key: 37,
text: 'Stability AI',
value: 37,
color: 'default',
url: ''
},
24: {
key: 24,
text: 'Azure Speech',

View File

@ -291,6 +291,15 @@ const typeConfig = {
test_model: 'command-r'
},
modelGroup: 'Cohere'
},
37: {
input: {
models: ['sd3', 'sd3-turbo', 'stable-image-core']
},
prompt: {
test_model: ''
},
modelGroup: 'Stability AI'
}
};