✨ feat: Add image upload
This commit is contained in:
parent
8e8d4a3a84
commit
cfa68df4aa
90
common/storage/drives/sm.go
Normal file
90
common/storage/drives/sm.go
Normal file
@ -0,0 +1,90 @@
|
|||||||
|
package drives
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"one-api/common/requester"
|
||||||
|
)
|
||||||
|
|
||||||
|
var smUploadURL = "https://sm.ms/api/v2/upload"
|
||||||
|
|
||||||
|
type SMUpload struct {
|
||||||
|
Secret string
|
||||||
|
}
|
||||||
|
|
||||||
|
type SMData struct {
|
||||||
|
URL string `json:"url"`
|
||||||
|
// FileID int `json:"file_id"`
|
||||||
|
// Width int `json:"width"`
|
||||||
|
// Height int `json:"height"`
|
||||||
|
// Filename string `json:"filename"`
|
||||||
|
// Storename string `json:"storename"`
|
||||||
|
// Size int `json:"size"`
|
||||||
|
// Path string `json:"path"`
|
||||||
|
// Hash string `json:"hash"`
|
||||||
|
// Delete string `json:"delete"`
|
||||||
|
// Page string `json:"page"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type SMResponse struct {
|
||||||
|
Success bool `json:"success"`
|
||||||
|
Code string `json:"code"`
|
||||||
|
Message string `json:"message"`
|
||||||
|
Data SMData `json:"data"`
|
||||||
|
RequestID string `json:"RequestId"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewSMUpload(secret string) *SMUpload {
|
||||||
|
return &SMUpload{
|
||||||
|
Secret: secret,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (sm *SMUpload) Name() string {
|
||||||
|
return "SM.MS"
|
||||||
|
}
|
||||||
|
|
||||||
|
func (sm *SMUpload) Upload(data []byte, fileName string) (string, error) {
|
||||||
|
client := requester.NewHTTPRequester("", nil)
|
||||||
|
|
||||||
|
var formBody bytes.Buffer
|
||||||
|
builder := client.CreateFormBuilder(&formBody)
|
||||||
|
|
||||||
|
err := builder.CreateFormFileReader("smfile", bytes.NewReader(data), fileName)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("creating form file: %w", err)
|
||||||
|
}
|
||||||
|
builder.WriteField("format", "json")
|
||||||
|
|
||||||
|
headers := map[string]string{
|
||||||
|
"Content-type": "application/json",
|
||||||
|
"Authorization": sm.Secret,
|
||||||
|
}
|
||||||
|
|
||||||
|
req, err := client.NewRequest(
|
||||||
|
http.MethodPost,
|
||||||
|
smUploadURL,
|
||||||
|
client.WithBody(&formBody),
|
||||||
|
client.WithHeader(headers),
|
||||||
|
client.WithContentType(builder.FormDataContentType()))
|
||||||
|
req.ContentLength = int64(formBody.Len())
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("new request failed: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
defer req.Body.Close()
|
||||||
|
|
||||||
|
smResponse := &SMResponse{}
|
||||||
|
_, errWithCode := client.SendRequest(req, smResponse, false)
|
||||||
|
if errWithCode != nil {
|
||||||
|
return "", fmt.Errorf("%s", errWithCode.Message)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !smResponse.Success {
|
||||||
|
return "", fmt.Errorf("upload failed: %s", smResponse.Message)
|
||||||
|
}
|
||||||
|
|
||||||
|
return smResponse.Data.URL, nil
|
||||||
|
}
|
25
common/storage/storage.go
Normal file
25
common/storage/storage.go
Normal file
@ -0,0 +1,25 @@
|
|||||||
|
package storage
|
||||||
|
|
||||||
|
import (
|
||||||
|
"one-api/common/storage/drives"
|
||||||
|
|
||||||
|
"github.com/spf13/viper"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Storage struct {
|
||||||
|
drives map[string]StorageDrive
|
||||||
|
}
|
||||||
|
|
||||||
|
func InitStorage() {
|
||||||
|
InitSMStorage()
|
||||||
|
}
|
||||||
|
|
||||||
|
func InitSMStorage() {
|
||||||
|
smSecret := viper.GetString("storage.smms.secret")
|
||||||
|
if smSecret == "" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
smUpload := drives.NewSMUpload(smSecret)
|
||||||
|
AddStorageDrive(smUpload)
|
||||||
|
}
|
36
common/storage/storageDrive.go
Normal file
36
common/storage/storageDrive.go
Normal file
@ -0,0 +1,36 @@
|
|||||||
|
package storage
|
||||||
|
|
||||||
|
var storageDrives = New()
|
||||||
|
|
||||||
|
type StorageDrive interface {
|
||||||
|
Upload(data []byte, fileName string) (string, error)
|
||||||
|
Name() string
|
||||||
|
}
|
||||||
|
|
||||||
|
func New() *Storage {
|
||||||
|
storageDrive := &Storage{
|
||||||
|
drives: make(map[string]StorageDrive, 0),
|
||||||
|
}
|
||||||
|
|
||||||
|
return storageDrive
|
||||||
|
}
|
||||||
|
|
||||||
|
func AddStorageDrive(drives ...StorageDrive) {
|
||||||
|
storageDrives.addDrives(drives...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Storage) addDrives(drives ...StorageDrive) {
|
||||||
|
for _, d := range drives {
|
||||||
|
s.addDrive(d)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Storage) addDrive(drive StorageDrive) {
|
||||||
|
if drive != nil {
|
||||||
|
driveName := drive.Name()
|
||||||
|
if _, ok := s.drives[driveName]; ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
s.drives[driveName] = drive
|
||||||
|
}
|
||||||
|
}
|
50
common/storage/storage_test.go
Normal file
50
common/storage/storage_test.go
Normal file
File diff suppressed because one or more lines are too long
34
common/storage/upload.go
Normal file
34
common/storage/upload.go
Normal file
@ -0,0 +1,34 @@
|
|||||||
|
package storage
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"one-api/common"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (s *Storage) Upload(ctx context.Context, data []byte, fileName string) string {
|
||||||
|
if ctx == nil {
|
||||||
|
ctx = context.Background()
|
||||||
|
}
|
||||||
|
|
||||||
|
for driveName, drive := range s.drives {
|
||||||
|
if drive == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
url, err := drive.Upload(data, fileName)
|
||||||
|
if err != nil {
|
||||||
|
common.LogError(ctx, fmt.Sprintf("%s err: %s", driveName, err.Error()))
|
||||||
|
} else {
|
||||||
|
return url
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func Upload(data []byte, fileName string) string {
|
||||||
|
//lint:ignore SA1029 reason: 需要使用该类型作为错误处理
|
||||||
|
ctx := context.WithValue(context.Background(), common.RequestIdKey, "Upload")
|
||||||
|
|
||||||
|
return storageDrives.Upload(ctx, data, fileName)
|
||||||
|
}
|
@ -60,4 +60,7 @@ notify: # 通知设置, 配置了几个通知方式,就会同时发送几次
|
|||||||
pushkey: "" # pushkey
|
pushkey: "" # pushkey
|
||||||
telegram: # Telegram 通知
|
telegram: # Telegram 通知
|
||||||
bot_api_key: "" # 你的 Telegram bot 的 API 密钥
|
bot_api_key: "" # 你的 Telegram bot 的 API 密钥
|
||||||
chat_id: "" # 你的 Telegram chat_id
|
chat_id: "" # 你的 Telegram chat_id
|
||||||
|
storage: # 存储设置 (可选,主要用于图片生成,有些供应商不提供url,只能返回base64图片,设置后可以正常返回url格式的图片生成)
|
||||||
|
smms: # sm.ms 图床设置
|
||||||
|
secret: "" # 你的 sm.ms API 密钥
|
2
main.go
2
main.go
@ -7,6 +7,7 @@ import (
|
|||||||
"one-api/common/config"
|
"one-api/common/config"
|
||||||
"one-api/common/notify"
|
"one-api/common/notify"
|
||||||
"one-api/common/requester"
|
"one-api/common/requester"
|
||||||
|
"one-api/common/storage"
|
||||||
"one-api/common/telegram"
|
"one-api/common/telegram"
|
||||||
"one-api/controller"
|
"one-api/controller"
|
||||||
"one-api/cron"
|
"one-api/cron"
|
||||||
@ -50,6 +51,7 @@ func main() {
|
|||||||
controller.InitMidjourneyTask()
|
controller.InitMidjourneyTask()
|
||||||
notify.InitNotifier()
|
notify.InitNotifier()
|
||||||
cron.InitCron()
|
cron.InitCron()
|
||||||
|
storage.InitStorage()
|
||||||
|
|
||||||
initHttpServer()
|
initHttpServer()
|
||||||
}
|
}
|
||||||
|
@ -5,6 +5,7 @@ import (
|
|||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
|
"one-api/common/storage"
|
||||||
"one-api/types"
|
"one-api/types"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
@ -43,16 +44,25 @@ func (p *CloudflareAIProvider) CreateImageGenerations(request *types.ImageReques
|
|||||||
return nil, common.ErrorWrapper(err, "read_response_failed", http.StatusInternalServerError)
|
return nil, common.ErrorWrapper(err, "read_response_failed", http.StatusInternalServerError)
|
||||||
}
|
}
|
||||||
|
|
||||||
base64Image := base64.StdEncoding.EncodeToString(body)
|
url := ""
|
||||||
|
if request.ResponseFormat == "" || request.ResponseFormat == "url" {
|
||||||
|
url = storage.Upload(body, common.GetUUID()+".png")
|
||||||
|
}
|
||||||
|
|
||||||
openaiResponse := &types.ImageResponse{
|
openaiResponse := &types.ImageResponse{
|
||||||
Created: time.Now().Unix(),
|
Created: time.Now().Unix(),
|
||||||
Data: []types.ImageResponseDataInner{{B64JSON: base64Image}},
|
}
|
||||||
|
|
||||||
|
if url == "" {
|
||||||
|
base64Image := base64.StdEncoding.EncodeToString(body)
|
||||||
|
openaiResponse.Data = []types.ImageResponseDataInner{{B64JSON: base64Image}}
|
||||||
|
} else {
|
||||||
|
openaiResponse.Data = []types.ImageResponseDataInner{{URL: url}}
|
||||||
}
|
}
|
||||||
|
|
||||||
p.Usage.PromptTokens = 1000
|
p.Usage.PromptTokens = 1000
|
||||||
|
|
||||||
return openaiResponse, nil
|
return openaiResponse, nil
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func convertFromIamgeOpenai(request *types.ImageRequest) *ImageRequest {
|
func convertFromIamgeOpenai(request *types.ImageRequest) *ImageRequest {
|
||||||
|
Loading…
Reference in New Issue
Block a user