feat: support cogview-3

JustSong 2024-04-06 00:30:08 +08:00
parent 0cb224e62e
commit 880e12c855
9 changed files with 127 additions and 73 deletions

View File

@@ -62,8 +62,8 @@ var ModelRatio = map[string]float64{
"text-search-ada-doc-001": 10,
"text-moderation-stable": 0.1,
"text-moderation-latest": 0.1,
"dall-e-2": 8, // $0.016 - $0.020 / image
"dall-e-3": 20, // $0.040 - $0.120 / image
"dall-e-2": 0.02 * USD, // $0.016 - $0.020 / image
"dall-e-3": 0.04 * USD, // $0.040 - $0.120 / image
// https://www.anthropic.com/api#pricing
"claude-instant-1.2": 0.8 / 1000 * USD,
"claude-2.0": 8.0 / 1000 * USD,
@@ -96,14 +96,15 @@ var ModelRatio = map[string]float64{
"gemini-1.0-pro-001": 1,
"gemini-1.5-pro": 1,
// https://open.bigmodel.cn/pricing
"glm-4": 0.1 * RMB,
"glm-4v": 0.1 * RMB,
"glm-3-turbo": 0.005 * RMB,
"embedding-2": 0.0005 * RMB,
"chatglm_turbo": 0.3572, // ¥0.005 / 1k tokens
"chatglm_pro": 0.7143, // ¥0.01 / 1k tokens
"chatglm_std": 0.3572, // ¥0.005 / 1k tokens
"chatglm_lite": 0.1429, // ¥0.002 / 1k tokens
"glm-4": 0.1 * RMB,
"glm-4v": 0.1 * RMB,
"glm-3-turbo": 0.005 * RMB,
"embedding-2": 0.0005 * RMB,
"chatglm_turbo": 0.3572, // ¥0.005 / 1k tokens
"chatglm_pro": 0.7143, // ¥0.01 / 1k tokens
"chatglm_std": 0.3572, // ¥0.005 / 1k tokens
"chatglm_lite": 0.1429, // ¥0.002 / 1k tokens
"cogview-3": 0.25 * RMB,
// https://help.aliyun.com/zh/dashscope/developer-reference/tongyi-thousand-questions-metering-and-billing
"qwen-turbo": 0.5715, // ¥0.008 / 1k tokens
"qwen-plus": 1.4286, // ¥0.02 / 1k tokens

View File

@@ -0,0 +1,44 @@
package openai
import (
"bytes"
"encoding/json"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/relay/model"
"io"
"net/http"
)
// ImageHandler relays an upstream image-generation response to the client
// verbatim. The body is read in full and parsed as JSON first, so malformed
// upstream responses surface as errors; the decoded imageResponse is not
// otherwise used, and image requests report no token usage.
func ImageHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
var imageResponse ImageResponse
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
return ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
err = json.Unmarshal(responseBody, &imageResponse)
if err != nil {
return ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
// Re-wrap the consumed body so it can be copied out to the client.
resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
// Mirror the upstream headers (first value only) and status code.
for k, v := range resp.Header {
c.Writer.Header().Set(k, v[0])
}
c.Writer.WriteHeader(resp.StatusCode)
_, err = io.Copy(c.Writer, resp.Body)
if err != nil {
return ErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
return ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
return nil, nil
}

View File

@@ -149,37 +149,3 @@ func Handler(c *gin.Context, resp *http.Response, promptTokens int, modelName st
}
return nil, &textResponse.Usage
}
func ImageHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
var imageResponse ImageResponse
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
return ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
err = json.Unmarshal(responseBody, &imageResponse)
if err != nil {
return ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
for k, v := range resp.Header {
c.Writer.Header().Set(k, v[0])
}
c.Writer.WriteHeader(resp.StatusCode)
_, err = io.Copy(c.Writer, resp.Body)
if err != nil {
return ErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
return ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
return nil, nil
}

View File

@@ -32,13 +32,16 @@ func (a *Adaptor) SetVersionByModeName(modelName string) {
}
func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) {
switch meta.Mode {
case constant.RelayModeImagesGenerations:
return fmt.Sprintf("%s/api/paas/v4/images/generations", meta.BaseURL), nil
case constant.RelayModeEmbeddings:
return fmt.Sprintf("%s/api/paas/v4/embeddings", meta.BaseURL), nil
}
a.SetVersionByModeName(meta.ActualModelName)
if a.APIVersion == "v4" {
return fmt.Sprintf("%s/api/paas/v4/chat/completions", meta.BaseURL), nil
}
if meta.Mode == constant.RelayModeEmbeddings {
return fmt.Sprintf("%s/api/paas/v4/embeddings", meta.BaseURL), nil
}
method := "invoke"
if meta.IsStream {
method = "sse-invoke"
@@ -81,7 +84,12 @@ func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error)
if request == nil {
return nil, errors.New("request is nil")
}
return request, nil
newRequest := ImageRequest{
Model: request.Model,
Prompt: request.Prompt,
UserId: request.User,
}
return newRequest, nil
}
func (a *Adaptor) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error) {
@@ -98,10 +106,17 @@ func (a *Adaptor) DoResponseV4(c *gin.Context, resp *http.Response, meta *util.R
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
switch meta.Mode {
case constant.RelayModeEmbeddings:
err, usage = EmbeddingsHandler(c, resp)
return
case constant.RelayModeImagesGenerations:
err, usage = openai.ImageHandler(c, resp)
return
}
if a.APIVersion == "v4" {
return a.DoResponseV4(c, resp, meta)
}
if meta.IsStream {
err, usage = StreamHandler(c, resp)
} else {

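Read together, the adaptor's new routing can be sketched standalone as below; the legacy v3 path format and the glm- prefix check are assumptions inferred from SetVersionByModeName and the surrounding code, which this hunk does not show in full:

    package main

    import (
        "fmt"
        "strings"
    )

    const (
        RelayModeImagesGenerations = iota
        RelayModeEmbeddings
    )

    // requestURL mirrors GetRequestURL above: image and embedding requests always
    // go to the v4 endpoints; v4 chat models go to v4 chat completions; legacy
    // chatglm_* models keep the old invoke / sse-invoke endpoints (path assumed).
    func requestURL(baseURL, modelName string, mode int, stream bool) string {
        switch mode {
        case RelayModeImagesGenerations:
            return baseURL + "/api/paas/v4/images/generations"
        case RelayModeEmbeddings:
            return baseURL + "/api/paas/v4/embeddings"
        }
        if strings.HasPrefix(modelName, "glm-") { // assumed v4 detection
            return baseURL + "/api/paas/v4/chat/completions"
        }
        method := "invoke"
        if stream {
            method = "sse-invoke"
        }
        return fmt.Sprintf("%s/api/paas/v3/model-api/%s/%s", baseURL, modelName, method)
    }

    func main() {
        fmt.Println(requestURL("https://open.bigmodel.cn", "cogview-3", RelayModeImagesGenerations, false))
        // https://open.bigmodel.cn/api/paas/v4/images/generations
    }

So a cogview-3 request is routed to the v4 images endpoint regardless of the version detection used for chat models.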
View File

@@ -3,4 +3,5 @@ package zhipu
var ModelList = []string{
"chatglm_turbo", "chatglm_pro", "chatglm_std", "chatglm_lite",
"glm-4", "glm-4v", "glm-3-turbo", "embedding-2",
"cogview-3",
}

View File

@@ -62,3 +62,9 @@ type EmbeddingData struct {
Object string `json:"object"`
Embedding []float64 `json:"embedding"`
}
type ImageRequest struct {
Model string `json:"model"`
Prompt string `json:"prompt"`
UserId string `json:"user_id,omitempty"`
}
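ConvertImageRequest in the adaptor above maps OpenAI's request onto this struct: Model and Prompt copy over, user is renamed to user_id, and fields such as size and n are dropped. A minimal sketch of the resulting wire format:

    package main

    import (
        "encoding/json"
        "fmt"
    )

    type ImageRequest struct {
        Model  string `json:"model"`
        Prompt string `json:"prompt"`
        UserId string `json:"user_id,omitempty"`
    }

    func main() {
        req := ImageRequest{Model: "cogview-3", Prompt: "a red panda in watercolor"}
        b, _ := json.Marshal(req)
        fmt.Println(string(b)) // {"model":"cogview-3","prompt":"a red panda in watercolor"}
    }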

View File

@@ -1,6 +1,6 @@
package constant
var DalleSizeRatios = map[string]map[string]float64{
var ImageSizeRatios = map[string]map[string]float64{
"dall-e-2": {
"256x256": 1,
"512x512": 1.125,
@@ -11,7 +11,14 @@ var DalleSizeRatios = map[string]map[string]float64{
"1024x1792": 2,
"1792x1024": 2,
},
"stable-diffusion-xl": {
"ali-stable-diffusion-xl": {
"512x1024": 1,
"1024x768": 1,
"1024x1024": 1,
"576x1024": 1,
"1024x576": 1,
},
"ali-stable-diffusion-v1.5": {
"512x1024": 1,
"1024x768": 1,
"1024x1024": 1,
@@ -25,17 +32,20 @@ var DalleSizeRatios = map[string]map[string]float64{
},
}
var DalleGenerationImageAmounts = map[string][2]int{
"dall-e-2": {1, 10},
"dall-e-3": {1, 1}, // OpenAI allows n=1 currently.
"stable-diffusion-xl": {1, 4}, // Ali
"wanx-v1": {1, 4}, // Ali
var ImageGenerationAmounts = map[string][2]int{
"dall-e-2": {1, 10},
"dall-e-3": {1, 1}, // OpenAI allows n=1 currently.
"ali-stable-diffusion-xl": {1, 4}, // Ali
"ali-stable-diffusion-v1.5": {1, 4}, // Ali
"wanx-v1": {1, 4}, // Ali
"cogview-3": {1, 1},
}
var DalleImagePromptLengthLimitations = map[string]int{
"dall-e-2": 1000,
"dall-e-3": 4000,
"stable-diffusion-xl": 4000,
"wanx-v1": 4000,
"cogview-3": 833,
var ImagePromptLengthLimitations = map[string]int{
"dall-e-2": 1000,
"dall-e-3": 4000,
"ali-stable-diffusion-xl": 4000,
"ali-stable-diffusion-v1.5": 4000,
"wanx-v1": 4000,
"cogview-3": 833,
}
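
A short sketch of how the relay code consults these tables (same package and an fmt import assumed; note that cogview-3 deliberately has no ImageSizeRatios entry, so its size check and cost ratio are handled by the fallbacks added in the next file):

    // Example lookups against the three tables above:
    func exampleLookups() {
        ratio, ok := ImageSizeRatios["dall-e-3"]["1792x1024"] // 2, true
        minMax := ImageGenerationAmounts["cogview-3"]         // [2]int{1, 1}
        limit := ImagePromptLengthLimitations["cogview-3"]    // 833
        fmt.Println(ratio, ok, minMax, limit)
    }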

View File

@@ -54,9 +54,25 @@ func getImageRequest(c *gin.Context, relayMode int) (*relaymodel.ImageRequest, e
return imageRequest, nil
}
func isValidImageSize(model string, size string) bool {
if model == "cogview-3" {
return true
}
_, ok := constant.ImageSizeRatios[model][size]
return ok
}
func getImageSizeRatio(model string, size string) float64 {
ratio, ok := constant.ImageSizeRatios[model][size]
if !ok {
return 1
}
return ratio
}
func validateImageRequest(imageRequest *relaymodel.ImageRequest, meta *util.RelayMeta) *relaymodel.ErrorWithStatusCode {
// model validation
_, hasValidSize := constant.DalleSizeRatios[imageRequest.Model][imageRequest.Size]
hasValidSize := isValidImageSize(imageRequest.Model, imageRequest.Size)
if !hasValidSize {
return openai.ErrorWrapper(errors.New("size not supported for this image model"), "size_not_supported", http.StatusBadRequest)
}
@@ -64,7 +80,7 @@ func validateImageRequest(imageRequest *relaymodel.ImageRequest, meta *util.Rela
if imageRequest.Prompt == "" {
return openai.ErrorWrapper(errors.New("prompt is required"), "prompt_missing", http.StatusBadRequest)
}
if len(imageRequest.Prompt) > constant.DalleImagePromptLengthLimitations[imageRequest.Model] {
if len(imageRequest.Prompt) > constant.ImagePromptLengthLimitations[imageRequest.Model] {
return openai.ErrorWrapper(errors.New("prompt is too long"), "prompt_too_long", http.StatusBadRequest)
}
// Number of generated images validation
@@ -81,10 +97,7 @@ func getImageCostRatio(imageRequest *relaymodel.ImageRequest) (float64, error) {
if imageRequest == nil {
return 0, errors.New("imageRequest is nil")
}
imageCostRatio, hasValidSize := constant.DalleSizeRatios[imageRequest.Model][imageRequest.Size]
if !hasValidSize {
return 0, fmt.Errorf("size not supported for this image model: %s", imageRequest.Size)
}
imageCostRatio := getImageSizeRatio(imageRequest.Model, imageRequest.Size)
if imageRequest.Quality == "hd" && imageRequest.Model == "dall-e-3" {
if imageRequest.Size == "1024x1024" {
imageCostRatio *= 2

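To make the hd branch concrete, a hedged sketch of the effective cost ratio; the hunk is cut off just past this point, and the non-square hd branch is assumed here to multiply by 1.5:

    // Sketch of the effective cost ratio, assuming the branch shown above
    // continues with "else { imageCostRatio *= 1.5 }" beyond this hunk.
    func effectiveImageCostRatio(model, size, quality string) float64 {
        r := getImageSizeRatio(model, size) // 1 for cogview-3 (no size-table entry)
        if quality == "hd" && model == "dall-e-3" {
            if size == "1024x1024" {
                r *= 2
            } else {
                r *= 1.5
            }
        }
        return r
    }

    // effectiveImageCostRatio("dall-e-3", "1792x1024", "hd") == 3 (2 × 1.5)
    // effectiveImageCostRatio("cogview-3", "1024x1024", "")  == 1
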
View File

@@ -20,12 +20,11 @@ import (
)
func isWithinRange(element string, value int) bool {
if _, ok := constant.DalleGenerationImageAmounts[element]; !ok {
if _, ok := constant.ImageGenerationAmounts[element]; !ok {
return false
}
min := constant.DalleGenerationImageAmounts[element][0]
max := constant.DalleGenerationImageAmounts[element][1]
min := constant.ImageGenerationAmounts[element][0]
max := constant.ImageGenerationAmounts[element][1]
return value >= min && value <= max
}
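With the renamed table, validation of n for the new model behaves as follows (sketch; same package and an fmt import assumed):

    // Examples of the n-range check against ImageGenerationAmounts:
    func exampleRanges() {
        fmt.Println(isWithinRange("cogview-3", 1))     // true: cogview-3 allows exactly one image
        fmt.Println(isWithinRange("cogview-3", 2))     // false: above the max of 1
        fmt.Println(isWithinRange("no-such-model", 1)) // false: model missing from the table
    }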
@@ -81,7 +80,6 @@ func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus
if err != nil {
return openai.ErrorWrapper(err, "convert_image_request_failed", http.StatusInternalServerError)
}
jsonStr, err := json.Marshal(finalRequest)
if err != nil {
return openai.ErrorWrapper(err, "marshal_image_request_failed", http.StatusInternalServerError)