feat: Add image upload

This commit is contained in:
Martial BE 2024-04-18 15:29:25 +08:00
parent 8e8d4a3a84
commit cfa68df4aa
No known key found for this signature in database
GPG Key ID: D06C32DF0EDB9084
8 changed files with 254 additions and 4 deletions

View File

@ -0,0 +1,90 @@
package drives
import (
"bytes"
"fmt"
"net/http"
"one-api/common/requester"
)
var smUploadURL = "https://sm.ms/api/v2/upload"
type SMUpload struct {
Secret string
}
type SMData struct {
URL string `json:"url"`
// FileID int `json:"file_id"`
// Width int `json:"width"`
// Height int `json:"height"`
// Filename string `json:"filename"`
// Storename string `json:"storename"`
// Size int `json:"size"`
// Path string `json:"path"`
// Hash string `json:"hash"`
// Delete string `json:"delete"`
// Page string `json:"page"`
}
type SMResponse struct {
Success bool `json:"success"`
Code string `json:"code"`
Message string `json:"message"`
Data SMData `json:"data"`
RequestID string `json:"RequestId"`
}
func NewSMUpload(secret string) *SMUpload {
return &SMUpload{
Secret: secret,
}
}
func (sm *SMUpload) Name() string {
return "SM.MS"
}
func (sm *SMUpload) Upload(data []byte, fileName string) (string, error) {
client := requester.NewHTTPRequester("", nil)
var formBody bytes.Buffer
builder := client.CreateFormBuilder(&formBody)
err := builder.CreateFormFileReader("smfile", bytes.NewReader(data), fileName)
if err != nil {
return "", fmt.Errorf("creating form file: %w", err)
}
builder.WriteField("format", "json")
headers := map[string]string{
"Content-type": "application/json",
"Authorization": sm.Secret,
}
req, err := client.NewRequest(
http.MethodPost,
smUploadURL,
client.WithBody(&formBody),
client.WithHeader(headers),
client.WithContentType(builder.FormDataContentType()))
req.ContentLength = int64(formBody.Len())
if err != nil {
return "", fmt.Errorf("new request failed: %w", err)
}
defer req.Body.Close()
smResponse := &SMResponse{}
_, errWithCode := client.SendRequest(req, smResponse, false)
if errWithCode != nil {
return "", fmt.Errorf("%s", errWithCode.Message)
}
if !smResponse.Success {
return "", fmt.Errorf("upload failed: %s", smResponse.Message)
}
return smResponse.Data.URL, nil
}

25
common/storage/storage.go Normal file
View File

@ -0,0 +1,25 @@
package storage
import (
"one-api/common/storage/drives"
"github.com/spf13/viper"
)
type Storage struct {
drives map[string]StorageDrive
}
func InitStorage() {
InitSMStorage()
}
func InitSMStorage() {
smSecret := viper.GetString("storage.smms.secret")
if smSecret == "" {
return
}
smUpload := drives.NewSMUpload(smSecret)
AddStorageDrive(smUpload)
}

View File

@ -0,0 +1,36 @@
package storage
var storageDrives = New()
type StorageDrive interface {
Upload(data []byte, fileName string) (string, error)
Name() string
}
func New() *Storage {
storageDrive := &Storage{
drives: make(map[string]StorageDrive, 0),
}
return storageDrive
}
func AddStorageDrive(drives ...StorageDrive) {
storageDrives.addDrives(drives...)
}
func (s *Storage) addDrives(drives ...StorageDrive) {
for _, d := range drives {
s.addDrive(d)
}
}
func (s *Storage) addDrive(drive StorageDrive) {
if drive != nil {
driveName := drive.Name()
if _, ok := s.drives[driveName]; ok {
return
}
s.drives[driveName] = drive
}
}

File diff suppressed because one or more lines are too long

34
common/storage/upload.go Normal file
View File

@ -0,0 +1,34 @@
package storage
import (
"context"
"fmt"
"one-api/common"
)
func (s *Storage) Upload(ctx context.Context, data []byte, fileName string) string {
if ctx == nil {
ctx = context.Background()
}
for driveName, drive := range s.drives {
if drive == nil {
continue
}
url, err := drive.Upload(data, fileName)
if err != nil {
common.LogError(ctx, fmt.Sprintf("%s err: %s", driveName, err.Error()))
} else {
return url
}
}
return ""
}
func Upload(data []byte, fileName string) string {
//lint:ignore SA1029 reason: 需要使用该类型作为错误处理
ctx := context.WithValue(context.Background(), common.RequestIdKey, "Upload")
return storageDrives.Upload(ctx, data, fileName)
}

View File

@ -60,4 +60,7 @@ notify: # 通知设置, 配置了几个通知方式,就会同时发送几次
pushkey: "" # pushkey
telegram: # Telegram 通知
bot_api_key: "" # 你的 Telegram bot 的 API 密钥
chat_id: "" # 你的 Telegram chat_id
chat_id: "" # 你的 Telegram chat_id
storage: # 存储设置 (可选,主要用于图片生成有些供应商不提供url只能返回base64图片设置后可以正常返回url格式的图片生成)
smms: # sm.ms 图床设置
secret: "" # 你的 sm.ms API 密钥

View File

@ -7,6 +7,7 @@ import (
"one-api/common/config"
"one-api/common/notify"
"one-api/common/requester"
"one-api/common/storage"
"one-api/common/telegram"
"one-api/controller"
"one-api/cron"
@ -50,6 +51,7 @@ func main() {
controller.InitMidjourneyTask()
notify.InitNotifier()
cron.InitCron()
storage.InitStorage()
initHttpServer()
}

View File

@ -5,6 +5,7 @@ import (
"io"
"net/http"
"one-api/common"
"one-api/common/storage"
"one-api/types"
"time"
)
@ -43,16 +44,25 @@ func (p *CloudflareAIProvider) CreateImageGenerations(request *types.ImageReques
return nil, common.ErrorWrapper(err, "read_response_failed", http.StatusInternalServerError)
}
base64Image := base64.StdEncoding.EncodeToString(body)
url := ""
if request.ResponseFormat == "" || request.ResponseFormat == "url" {
url = storage.Upload(body, common.GetUUID()+".png")
}
openaiResponse := &types.ImageResponse{
Created: time.Now().Unix(),
Data: []types.ImageResponseDataInner{{B64JSON: base64Image}},
}
if url == "" {
base64Image := base64.StdEncoding.EncodeToString(body)
openaiResponse.Data = []types.ImageResponseDataInner{{B64JSON: base64Image}}
} else {
openaiResponse.Data = []types.ImageResponseDataInner{{URL: url}}
}
p.Usage.PromptTokens = 1000
return openaiResponse, nil
}
func convertFromIamgeOpenai(request *types.ImageRequest) *ImageRequest {