From de7b9710a52a943e9b7dd54837916949cbda1663 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E7=AE=A1=E5=AE=9C=E5=B0=A7?=
Date: Fri, 17 Nov 2023 19:40:59 +0800
Subject: [PATCH 01/17] fix: fix PaLM not working issue (#667)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
* bugfix for #515: the latest Google PaLM model could not be used
* update
* chore: remove unrelated file
* chore: add comment
---------
Co-authored-by: JustSong
---
controller/relay-text.go | 2 ++
1 file changed, 2 insertions(+)
diff --git a/controller/relay-text.go b/controller/relay-text.go
index a61c6f7c..b9a300b4 100644
--- a/controller/relay-text.go
+++ b/controller/relay-text.go
@@ -367,6 +367,8 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
}
case APITypeTencent:
req.Header.Set("Authorization", apiKey)
+ case APITypePaLM:
+ // do not set Authorization header
default:
req.Header.Set("Authorization", "Bearer "+apiKey)
}
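Why the empty case fixes PaLM: Google's PaLM API authenticates with an API key passed as a query parameter, so the Bearer header added by the default branch was breaking upstream requests (#515). A minimal sketch of the expected shape (the endpoint URL here is an assumption for illustration, not taken from this patch):

// Sketch: PaLM-style authentication via a `key` query parameter.
// The URL is illustrative; the relay builds its own from the channel's base URL.
url := "https://generativelanguage.googleapis.com/v1beta2/models/chat-bison-001:generateMessage?key=" + apiKey
req, err := http.NewRequest(http.MethodPost, url, bytes.NewReader(jsonData))
if err != nil {
	return err
}
// Note: no req.Header.Set("Authorization", ...) here, matching the new empty case above.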
From 1d15157f7d2f3d2b2e98c965a383a06c6535e665 Mon Sep 17 00:00:00 2001
From: ckt1031 <65409152+ckt1031@users.noreply.github.com>
Date: Fri, 17 Nov 2023 20:03:16 +0800
Subject: [PATCH 02/17] feat: keep sync with dall-e updates (#679)
* Updated ImageRequest struct and OpenAIModels,
added new Dall-E models and size ratios
* Fixed suspect `or`
* Refactored size ratio calculation in
relayImageHelper function
* Updated the format of resolution keys in
DalleSizeRatios map
* Added error handling for unsupported image size in
relayImageHelper function
* Added validation for number of generated images
and defined image generation ratios
* Refactored variable name from
DalleGenerationImageAmountRatios to
DalleGenerationImageAmounts
* Added validation for prompt length in
relayImageHelper function
* Updated model validation and removed size not
supported error in relayImageHelper function
* Refactored image size and model validation in
relayImageHelper function
* chore: discard binary file
* chore: update impl
---------
Co-authored-by: cktsun1031 <65409152+cktsun1031@users.noreply.github.com>
Co-authored-by: JustSong
---
common/model-ratio.go | 26 ++++++++++++++-
controller/model.go | 13 ++++++--
controller/relay-image.go | 67 ++++++++++++++++++++++++++++-----------
controller/relay.go | 12 +++++--
4 files changed, 93 insertions(+), 25 deletions(-)
diff --git a/common/model-ratio.go b/common/model-ratio.go
index 681f0ae7..b4a471dc 100644
--- a/common/model-ratio.go
+++ b/common/model-ratio.go
@@ -6,6 +6,29 @@ import (
"time"
)
+var DalleSizeRatios = map[string]map[string]float64{
+ "dall-e-2": {
+ "256x256": 1,
+ "512x512": 1.125,
+ "1024x1024": 1.25,
+ },
+ "dall-e-3": {
+ "1024x1024": 1,
+ "1024x1792": 2,
+ "1792x1024": 2,
+ },
+}
+
+var DalleGenerationImageAmounts = map[string][2]int{
+ "dall-e-2": {1, 10},
+ "dall-e-3": {1, 1}, // OpenAI allows n=1 currently.
+}
+
+var DalleImagePromptLengthLimitations = map[string]int{
+ "dall-e-2": 1000,
+ "dall-e-3": 4000,
+}
+
// ModelRatio
// https://platform.openai.com/docs/models/model-endpoint-compatibility
// https://cloud.baidu.com/doc/WENXINWORKSHOP/s/Blfmc9dlf
@@ -45,7 +68,8 @@ var ModelRatio = map[string]float64{
"text-search-ada-doc-001": 10,
"text-moderation-stable": 0.1,
"text-moderation-latest": 0.1,
- "dall-e": 8,
+ "dall-e-2": 8, // $0.016 - $0.020 / image
+ "dall-e-3": 20, // $0.040 - $0.120 / image
"claude-instant-1": 0.815, // $1.63 / 1M tokens
"claude-2": 5.51, // $11.02 / 1M tokens
"ERNIE-Bot": 0.8572, // ¥0.012 / 1k tokens
diff --git a/controller/model.go b/controller/model.go
index 7bd9d097..f9904330 100644
--- a/controller/model.go
+++ b/controller/model.go
@@ -55,12 +55,21 @@ func init() {
// https://platform.openai.com/docs/models/model-endpoint-compatibility
openAIModels = []OpenAIModels{
{
- Id: "dall-e",
+ Id: "dall-e-2",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
- Root: "dall-e",
+ Root: "dall-e-2",
+ Parent: nil,
+ },
+ {
+ Id: "dall-e-3",
+ Object: "model",
+ Created: 1677649963,
+ OwnedBy: "openai",
+ Permission: permission,
+ Root: "dall-e-3",
Parent: nil,
},
{
diff --git a/controller/relay-image.go b/controller/relay-image.go
index ccd52dce..1d1b71ba 100644
--- a/controller/relay-image.go
+++ b/controller/relay-image.go
@@ -6,15 +6,28 @@ import (
"encoding/json"
"errors"
"fmt"
- "github.com/gin-gonic/gin"
"io"
"net/http"
"one-api/common"
"one-api/model"
+
+ "github.com/gin-gonic/gin"
)
+func isWithinRange(element string, value int) bool {
+ if _, ok := common.DalleGenerationImageAmounts[element]; !ok {
+ return false
+ }
+
+ min := common.DalleGenerationImageAmounts[element][0]
+ max := common.DalleGenerationImageAmounts[element][1]
+
+ return value >= min && value <= max
+}
+
func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
- imageModel := "dall-e"
+ imageModel := "dall-e-2"
+ imageSize := "1024x1024"
tokenId := c.GetInt("token_id")
channelType := c.GetInt("channel")
@@ -31,19 +44,44 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
}
}
+ // Size validation
+ if imageRequest.Size != "" {
+ imageSize = imageRequest.Size
+ }
+
+ // Model validation
+ if imageRequest.Model != "" {
+ imageModel = imageRequest.Model
+ }
+
+ imageCostRatio, hasValidSize := common.DalleSizeRatios[imageModel][imageSize]
+
+ // Check if model is supported
+ if hasValidSize {
+ if imageRequest.Quality == "hd" && imageModel == "dall-e-3" {
+ if imageSize == "1024x1024" {
+ imageCostRatio *= 2
+ } else {
+ imageCostRatio *= 1.5
+ }
+ }
+ } else {
+ return errorWrapper(errors.New("size not supported for this image model"), "size_not_supported", http.StatusBadRequest)
+ }
+
// Prompt validation
if imageRequest.Prompt == "" {
- return errorWrapper(errors.New("prompt is required"), "required_field_missing", http.StatusBadRequest)
+ return errorWrapper(errors.New("prompt is required"), "prompt_missing", http.StatusBadRequest)
}
- // Not "256x256", "512x512", or "1024x1024"
- if imageRequest.Size != "" && imageRequest.Size != "256x256" && imageRequest.Size != "512x512" && imageRequest.Size != "1024x1024" {
- return errorWrapper(errors.New("size must be one of 256x256, 512x512, or 1024x1024"), "invalid_field_value", http.StatusBadRequest)
+ // Check prompt length
+ if len(imageRequest.Prompt) > common.DalleImagePromptLengthLimitations[imageModel] {
+ return errorWrapper(errors.New("prompt is too long"), "prompt_too_long", http.StatusBadRequest)
}
- // N should between 1 and 10
- if imageRequest.N != 0 && (imageRequest.N < 1 || imageRequest.N > 10) {
- return errorWrapper(errors.New("n must be between 1 and 10"), "invalid_field_value", http.StatusBadRequest)
+ // Number of generated images validation
+ if !isWithinRange(imageModel, imageRequest.N) {
+ return errorWrapper(errors.New("invalid value of n"), "n_not_within_range", http.StatusBadRequest)
}
// map model name
@@ -82,16 +120,7 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
ratio := modelRatio * groupRatio
userQuota, err := model.CacheGetUserQuota(userId)
- sizeRatio := 1.0
- // Size
- if imageRequest.Size == "256x256" {
- sizeRatio = 1
- } else if imageRequest.Size == "512x512" {
- sizeRatio = 1.125
- } else if imageRequest.Size == "1024x1024" {
- sizeRatio = 1.25
- }
- quota := int(ratio*sizeRatio*1000) * imageRequest.N
+ quota := int(ratio*imageCostRatio*1000) * imageRequest.N
if consumeQuota && userQuota-quota < 0 {
return errorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
diff --git a/controller/relay.go b/controller/relay.go
index 1926110e..9cff887b 100644
--- a/controller/relay.go
+++ b/controller/relay.go
@@ -77,10 +77,16 @@ type TextRequest struct {
//Stream bool `json:"stream"`
}
+// ImageRequest docs: https://platform.openai.com/docs/api-reference/images/create
type ImageRequest struct {
- Prompt string `json:"prompt"`
- N int `json:"n"`
- Size string `json:"size"`
+ Model string `json:"model"`
+ Prompt string `json:"prompt" binding:"required"`
+ N int `json:"n"`
+ Size string `json:"size"`
+ Quality string `json:"quality"`
+ ResponseFormat string `json:"response_format"`
+ Style string `json:"style"`
+ User string `json:"user"`
}
type AudioResponse struct {
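A worked example of the new cost formula quota := int(ratio*imageCostRatio*1000) * imageRequest.N (a sketch with illustrative numbers, assuming groupRatio is 1 and n is 1):

// dall-e-3 at 1024x1792 with quality "hd"
imageCostRatio := common.DalleSizeRatios["dall-e-3"]["1024x1792"] // 2
imageCostRatio *= 1.5                                            // hd multiplier for non-square sizes
ratio := common.ModelRatio["dall-e-3"] * 1.0                     // 20 * groupRatio(1) = 20
quota := int(ratio*imageCostRatio*1000) * 1                      // 20 * 3 * 1000 = 60000
_ = quota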
From ddcaf95f5faddce75c095395744526d2d5713343 Mon Sep 17 00:00:00 2001
From: ckt1031 <65409152+ckt1031@users.noreply.github.com>
Date: Fri, 17 Nov 2023 21:18:51 +0800
Subject: [PATCH 03/17] feat: support tts model (#713)
* Added support for Text-to-Speech models and
endpoints
* chore: update impl
---------
Co-authored-by: JustSong
---
common/model-ratio.go | 6 +-
controller/model.go | 36 ++++++++++++
controller/relay-audio.go | 118 +++++++++++++++++++++-----------------
controller/relay-utils.go | 19 ++++++
controller/relay.go | 28 +++++++--
middleware/distributor.go | 9 +--
router/relay-router.go | 1 +
7 files changed, 151 insertions(+), 66 deletions(-)
diff --git a/common/model-ratio.go b/common/model-ratio.go
index b4a471dc..74c74a90 100644
--- a/common/model-ratio.go
+++ b/common/model-ratio.go
@@ -59,7 +59,11 @@ var ModelRatio = map[string]float64{
"text-davinci-003": 10,
"text-davinci-edit-001": 10,
"code-davinci-edit-001": 10,
- "whisper-1": 15, // $0.006 / minute -> $0.006 / 150 words -> $0.006 / 200 tokens -> $0.03 / 1k tokens
+ "whisper-1": 15, // $0.006 / minute -> $0.006 / 150 words -> $0.006 / 200 tokens -> $0.03 / 1k tokens
+ "tts-1": 7.5, // $0.015 / 1K characters
+ "tts-1-1106": 7.5,
+ "tts-1-hd": 15, // $0.030 / 1K characters
+ "tts-1-hd-1106": 15,
"davinci": 10,
"curie": 10,
"babbage": 10,
diff --git a/controller/model.go b/controller/model.go
index f9904330..59ea22e8 100644
--- a/controller/model.go
+++ b/controller/model.go
@@ -81,6 +81,42 @@ func init() {
Root: "whisper-1",
Parent: nil,
},
+ {
+ Id: "tts-1",
+ Object: "model",
+ Created: 1677649963,
+ OwnedBy: "openai",
+ Permission: permission,
+ Root: "tts-1",
+ Parent: nil,
+ },
+ {
+ Id: "tts-1-1106",
+ Object: "model",
+ Created: 1677649963,
+ OwnedBy: "openai",
+ Permission: permission,
+ Root: "tts-1-1106",
+ Parent: nil,
+ },
+ {
+ Id: "tts-1-hd",
+ Object: "model",
+ Created: 1677649963,
+ OwnedBy: "openai",
+ Permission: permission,
+ Root: "tts-1-hd",
+ Parent: nil,
+ },
+ {
+ Id: "tts-1-hd-1106",
+ Object: "model",
+ Created: 1677649963,
+ OwnedBy: "openai",
+ Permission: permission,
+ Root: "tts-1-hd-1106",
+ Parent: nil,
+ },
{
Id: "gpt-3.5-turbo",
Object: "model",
diff --git a/controller/relay-audio.go b/controller/relay-audio.go
index 53833108..01267fbf 100644
--- a/controller/relay-audio.go
+++ b/controller/relay-audio.go
@@ -5,7 +5,6 @@ import (
"context"
"encoding/json"
"errors"
- "fmt"
"github.com/gin-gonic/gin"
"io"
"net/http"
@@ -21,6 +20,22 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
channelId := c.GetInt("channel_id")
userId := c.GetInt("id")
group := c.GetString("group")
+ tokenName := c.GetString("token_name")
+
+ var ttsRequest TextToSpeechRequest
+ if relayMode == RelayModeAudioSpeech {
+ // Read JSON
+ err := common.UnmarshalBodyReusable(c, &ttsRequest)
+ // Check if JSON is valid
+ if err != nil {
+ return errorWrapper(err, "invalid_json", http.StatusBadRequest)
+ }
+ audioModel = ttsRequest.Model
+ // Check if text is too long 4096
+ if len(ttsRequest.Input) > 4096 {
+ return errorWrapper(errors.New("input is too long (over 4096 characters)"), "text_too_long", http.StatusBadRequest)
+ }
+ }
preConsumedTokens := common.PreConsumedQuota
modelRatio := common.GetModelRatio(audioModel)
@@ -31,22 +46,32 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
if err != nil {
return errorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError)
}
- if userQuota-preConsumedQuota < 0 {
- return errorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
- }
- err = model.CacheDecreaseUserQuota(userId, preConsumedQuota)
- if err != nil {
- return errorWrapper(err, "decrease_user_quota_failed", http.StatusInternalServerError)
- }
- if userQuota > 100*preConsumedQuota {
- // in this case, we do not pre-consume quota
- // because the user has enough quota
- preConsumedQuota = 0
- }
- if preConsumedQuota > 0 {
- err := model.PreConsumeTokenQuota(tokenId, preConsumedQuota)
+
+ quota := 0
+ // Check if user quota is enough
+ if relayMode == RelayModeAudioSpeech {
+ quota = int(float64(len(ttsRequest.Input)) * modelRatio * groupRatio)
+ if quota > userQuota {
+ return errorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
+ }
+ } else {
+ if userQuota-preConsumedQuota < 0 {
+ return errorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
+ }
+ err = model.CacheDecreaseUserQuota(userId, preConsumedQuota)
if err != nil {
- return errorWrapper(err, "pre_consume_token_quota_failed", http.StatusForbidden)
+ return errorWrapper(err, "decrease_user_quota_failed", http.StatusInternalServerError)
+ }
+ if userQuota > 100*preConsumedQuota {
+ // in this case, we do not pre-consume quota
+ // because the user has enough quota
+ preConsumedQuota = 0
+ }
+ if preConsumedQuota > 0 {
+ err := model.PreConsumeTokenQuota(tokenId, preConsumedQuota)
+ if err != nil {
+ return errorWrapper(err, "pre_consume_token_quota_failed", http.StatusForbidden)
+ }
}
}
@@ -93,47 +118,32 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
if err != nil {
return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError)
}
- var audioResponse AudioResponse
- defer func(ctx context.Context) {
- go func() {
- quota := countTokenText(audioResponse.Text, audioModel)
+ if relayMode == RelayModeAudioSpeech {
+ defer func(ctx context.Context) {
+ go postConsumeQuota(ctx, tokenId, quota, userId, channelId, modelRatio, groupRatio, audioModel, tokenName)
+ }(c.Request.Context())
+ } else {
+ responseBody, err := io.ReadAll(resp.Body)
+ if err != nil {
+ return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
+ }
+ err = resp.Body.Close()
+ if err != nil {
+ return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError)
+ }
+ var whisperResponse WhisperResponse
+ err = json.Unmarshal(responseBody, &whisperResponse)
+ if err != nil {
+ return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError)
+ }
+ defer func(ctx context.Context) {
+ quota := countTokenText(whisperResponse.Text, audioModel)
quotaDelta := quota - preConsumedQuota
- err := model.PostConsumeTokenQuota(tokenId, quotaDelta)
- if err != nil {
- common.SysError("error consuming token remain quota: " + err.Error())
- }
- err = model.CacheUpdateUserQuota(userId)
- if err != nil {
- common.SysError("error update user quota cache: " + err.Error())
- }
- if quota != 0 {
- tokenName := c.GetString("token_name")
- logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
- model.RecordConsumeLog(ctx, userId, channelId, 0, 0, audioModel, tokenName, quota, logContent)
- model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
- channelId := c.GetInt("channel_id")
- model.UpdateChannelUsedQuota(channelId, quota)
- }
- }()
- }(c.Request.Context())
-
- responseBody, err := io.ReadAll(resp.Body)
-
- if err != nil {
- return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
+ go postConsumeQuota(ctx, tokenId, quotaDelta, userId, channelId, modelRatio, groupRatio, audioModel, tokenName)
+ }(c.Request.Context())
+ resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
}
- err = resp.Body.Close()
- if err != nil {
- return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError)
- }
- err = json.Unmarshal(responseBody, &audioResponse)
- if err != nil {
- return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError)
- }
-
- resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
-
for k, v := range resp.Header {
c.Writer.Header().Set(k, v[0])
}
diff --git a/controller/relay-utils.go b/controller/relay-utils.go
index cf5d9b69..888187cb 100644
--- a/controller/relay-utils.go
+++ b/controller/relay-utils.go
@@ -1,6 +1,7 @@
package controller
import (
+ "context"
"encoding/json"
"fmt"
"github.com/gin-gonic/gin"
@@ -8,6 +9,7 @@ import (
"io"
"net/http"
"one-api/common"
+ "one-api/model"
"strconv"
"strings"
)
@@ -186,3 +188,20 @@ func getFullRequestURL(baseURL string, requestURL string, channelType int) strin
}
return fullRequestURL
}
+
+func postConsumeQuota(ctx context.Context, tokenId int, quota int, userId int, channelId int, modelRatio float64, groupRatio float64, modelName string, tokenName string) {
+ err := model.PostConsumeTokenQuota(tokenId, quota)
+ if err != nil {
+ common.SysError("error consuming token remain quota: " + err.Error())
+ }
+ err = model.CacheUpdateUserQuota(userId)
+ if err != nil {
+ common.SysError("error update user quota cache: " + err.Error())
+ }
+ if quota != 0 {
+ logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
+ model.RecordConsumeLog(ctx, userId, channelId, 0, 0, modelName, tokenName, quota, logContent)
+ model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
+ model.UpdateChannelUsedQuota(channelId, quota)
+ }
+}
diff --git a/controller/relay.go b/controller/relay.go
index 9cff887b..863267b4 100644
--- a/controller/relay.go
+++ b/controller/relay.go
@@ -24,7 +24,9 @@ const (
RelayModeModerations
RelayModeImagesGenerations
RelayModeEdits
- RelayModeAudio
+ RelayModeAudioSpeech
+ RelayModeAudioTranscription
+ RelayModeAudioTranslation
)
// https://platform.openai.com/docs/api-reference/chat
@@ -89,10 +91,18 @@ type ImageRequest struct {
User string `json:"user"`
}
-type AudioResponse struct {
+type WhisperResponse struct {
Text string `json:"text,omitempty"`
}
+type TextToSpeechRequest struct {
+ Model string `json:"model" binding:"required"`
+ Input string `json:"input" binding:"required"`
+ Voice string `json:"voice" binding:"required"`
+ Speed float64 `json:"speed"`
+ ResponseFormat string `json:"response_format"`
+}
+
type Usage struct {
PromptTokens int `json:"prompt_tokens"`
CompletionTokens int `json:"completion_tokens"`
@@ -189,14 +199,22 @@ func Relay(c *gin.Context) {
relayMode = RelayModeImagesGenerations
} else if strings.HasPrefix(c.Request.URL.Path, "/v1/edits") {
relayMode = RelayModeEdits
- } else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio") {
- relayMode = RelayModeAudio
+ } else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/speech") {
+ relayMode = RelayModeAudioSpeech
+ } else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcription") {
+ relayMode = RelayModeAudioTranscription
+ } else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/translation") {
+ relayMode = RelayModeAudioTranslation
}
var err *OpenAIErrorWithStatusCode
switch relayMode {
case RelayModeImagesGenerations:
err = relayImageHelper(c, relayMode)
- case RelayModeAudio:
+ case RelayModeAudioSpeech:
+ fallthrough
+ case RelayModeAudioTranslation:
+ fallthrough
+ case RelayModeAudioTranscription:
err = relayAudioHelper(c, relayMode)
default:
err = relayTextHelper(c, relayMode)
diff --git a/middleware/distributor.go b/middleware/distributor.go
index d80945fc..c4ddc3a0 100644
--- a/middleware/distributor.go
+++ b/middleware/distributor.go
@@ -40,10 +40,7 @@ func Distribute() func(c *gin.Context) {
} else {
// Select a channel for the user
var modelRequest ModelRequest
- var err error
- if !strings.HasPrefix(c.Request.URL.Path, "/v1/audio") {
- err = common.UnmarshalBodyReusable(c, &modelRequest)
- }
+ err := common.UnmarshalBodyReusable(c, &modelRequest)
if err != nil {
abortWithMessage(c, http.StatusBadRequest, "无效的请求")
return
@@ -60,10 +57,10 @@ func Distribute() func(c *gin.Context) {
}
if strings.HasPrefix(c.Request.URL.Path, "/v1/images/generations") {
if modelRequest.Model == "" {
- modelRequest.Model = "dall-e"
+ modelRequest.Model = "dall-e-2"
}
}
- if strings.HasPrefix(c.Request.URL.Path, "/v1/audio") {
+ if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") || strings.HasPrefix(c.Request.URL.Path, "/v1/audio/translations") {
if modelRequest.Model == "" {
modelRequest.Model = "whisper-1"
}
diff --git a/router/relay-router.go b/router/relay-router.go
index e84f02db..912f4989 100644
--- a/router/relay-router.go
+++ b/router/relay-router.go
@@ -29,6 +29,7 @@ func SetRelayRouter(router *gin.Engine) {
relayV1Router.POST("/engines/:model/embeddings", controller.Relay)
relayV1Router.POST("/audio/transcriptions", controller.Relay)
relayV1Router.POST("/audio/translations", controller.Relay)
+ relayV1Router.POST("/audio/speech", controller.Relay)
relayV1Router.GET("/files", controller.RelayNotImplemented)
relayV1Router.POST("/files", controller.RelayNotImplemented)
relayV1Router.DELETE("/files/:id", controller.RelayNotImplemented)
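The new speech path bills by input length rather than by tokens: quota = int(float64(len(ttsRequest.Input)) * modelRatio * groupRatio). A small sketch (illustrative values; note that Go's len() counts bytes, so multi-byte UTF-8 input is billed per byte, not per character):

input := "Hello, world!"                    // 13 bytes
modelRatio := common.GetModelRatio("tts-1") // 7.5, i.e. $0.015 / 1K characters
quota := int(float64(len(input)) * modelRatio * 1.0) // 13 * 7.5 = 97 with groupRatio 1
_ = quota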
From 34d517cfa233d90b470e4d6ac0208051de4b4fea Mon Sep 17 00:00:00 2001
From: Mikey
Date: Fri, 17 Nov 2023 05:45:55 -0800
Subject: [PATCH 04/17] fix: cloudflare test & expose detailed info about test
failures (#715)
* fix: cloudflare test & expose detailed info about test failures
---------
Co-authored-by: JustSong
---
controller/channel-test.go | 16 +++++++++-------
1 file changed, 9 insertions(+), 7 deletions(-)
diff --git a/controller/channel-test.go b/controller/channel-test.go
index 3c6c8f43..b47a44b9 100644
--- a/controller/channel-test.go
+++ b/controller/channel-test.go
@@ -6,11 +6,11 @@ import (
"errors"
"fmt"
"github.com/gin-gonic/gin"
+ "io"
"net/http"
"one-api/common"
"one-api/model"
"strconv"
- "strings"
"sync"
"time"
)
@@ -45,13 +45,11 @@ func testChannel(channel *model.Channel, request ChatRequest) (err error, openai
if channel.Type == common.ChannelTypeAzure {
requestURL = fmt.Sprintf("%s/openai/deployments/%s/chat/completions?api-version=2023-03-15-preview", channel.GetBaseURL(), request.Model)
} else {
- if channel.GetBaseURL() != "" {
- requestURL = channel.GetBaseURL()
+ if baseURL := channel.GetBaseURL(); len(baseURL) > 0 {
+ requestURL = baseURL
}
- requestURL += "/v1/chat/completions"
+ requestURL = getFullRequestURL(requestURL, "/v1/chat/completions", channel.Type)
}
- // for Cloudflare AI gateway: https://github.com/songquanpeng/one-api/pull/639
- requestURL = strings.Replace(requestURL, "/v1/v1", "/v1", 1)
jsonData, err := json.Marshal(request)
if err != nil {
@@ -73,10 +71,14 @@ func testChannel(channel *model.Channel, request ChatRequest) (err error, openai
}
defer resp.Body.Close()
var response TextResponse
- err = json.NewDecoder(resp.Body).Decode(&response)
+ body, err := io.ReadAll(resp.Body)
if err != nil {
return err, nil
}
+ err = json.Unmarshal(body, &response)
+ if err != nil {
+ return fmt.Errorf("Error: %s\nResp body: %s", err, body), nil
+ }
if response.Usage.CompletionTokens == 0 {
return errors.New(fmt.Sprintf("type %s, code %v, message %s", response.Error.Type, response.Error.Code, response.Error.Message)), &response.Error
}
From 54e5f8ecd296a6d6d0662f6a16792ab3aa2cb9ab Mon Sep 17 00:00:00 2001
From: Buer <42402987+MartialBE@users.noreply.github.com>
Date: Sun, 19 Nov 2023 15:52:35 +0800
Subject: [PATCH 05/17] feat: support cloudflare gateway for azure (#666)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
* 🐛 Fix cloudflare gateway request failure
* 🐛 fix channel test url error
---
controller/channel-test.go | 7 ++++---
controller/relay-text.go | 4 +++-
controller/relay-utils.go | 14 ++++++++++----
3 files changed, 17 insertions(+), 8 deletions(-)
diff --git a/controller/channel-test.go b/controller/channel-test.go
index b47a44b9..1b0b745a 100644
--- a/controller/channel-test.go
+++ b/controller/channel-test.go
@@ -5,7 +5,6 @@ import (
"encoding/json"
"errors"
"fmt"
- "github.com/gin-gonic/gin"
"io"
"net/http"
"one-api/common"
@@ -13,6 +12,8 @@ import (
"strconv"
"sync"
"time"
+
+ "github.com/gin-gonic/gin"
)
func testChannel(channel *model.Channel, request ChatRequest) (err error, openaiErr *OpenAIError) {
@@ -43,14 +44,14 @@ func testChannel(channel *model.Channel, request ChatRequest) (err error, openai
}
requestURL := common.ChannelBaseURLs[channel.Type]
if channel.Type == common.ChannelTypeAzure {
- requestURL = fmt.Sprintf("%s/openai/deployments/%s/chat/completions?api-version=2023-03-15-preview", channel.GetBaseURL(), request.Model)
+ requestURL = getFullRequestURL(channel.GetBaseURL(), fmt.Sprintf("/openai/deployments/%s/chat/completions?api-version=2023-03-15-preview", request.Model), channel.Type)
} else {
if baseURL := channel.GetBaseURL(); len(baseURL) > 0 {
requestURL = baseURL
}
+
requestURL = getFullRequestURL(requestURL, "/v1/chat/completions", channel.Type)
}
-
jsonData, err := json.Marshal(request)
if err != nil {
return err, nil
diff --git a/controller/relay-text.go b/controller/relay-text.go
index b9a300b4..018c8d8a 100644
--- a/controller/relay-text.go
+++ b/controller/relay-text.go
@@ -147,7 +147,9 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
model_ = strings.TrimSuffix(model_, "-0301")
model_ = strings.TrimSuffix(model_, "-0314")
model_ = strings.TrimSuffix(model_, "-0613")
- fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/%s", baseURL, model_, task)
+
+ requestURL = fmt.Sprintf("/openai/deployments/%s/%s", model_, task)
+ fullRequestURL = getFullRequestURL(baseURL, requestURL, channelType)
}
case APITypeClaude:
fullRequestURL = "https://api.anthropic.com/v1/complete"
diff --git a/controller/relay-utils.go b/controller/relay-utils.go
index 888187cb..e2b77a97 100644
--- a/controller/relay-utils.go
+++ b/controller/relay-utils.go
@@ -4,14 +4,15 @@ import (
"context"
"encoding/json"
"fmt"
- "github.com/gin-gonic/gin"
- "github.com/pkoukk/tiktoken-go"
"io"
"net/http"
"one-api/common"
"one-api/model"
"strconv"
"strings"
+
+ "github.com/gin-gonic/gin"
+ "github.com/pkoukk/tiktoken-go"
)
var stopFinishReason = "stop"
@@ -181,11 +182,16 @@ func relayErrorHandler(resp *http.Response) (openAIErrorWithStatusCode *OpenAIEr
func getFullRequestURL(baseURL string, requestURL string, channelType int) string {
fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL)
- if channelType == common.ChannelTypeOpenAI {
- if strings.HasPrefix(baseURL, "https://gateway.ai.cloudflare.com") {
+
+ if strings.HasPrefix(baseURL, "https://gateway.ai.cloudflare.com") {
+ switch channelType {
+ case common.ChannelTypeOpenAI:
fullRequestURL = fmt.Sprintf("%s%s", baseURL, strings.TrimPrefix(requestURL, "/v1"))
+ case common.ChannelTypeAzure:
+ fullRequestURL = fmt.Sprintf("%s%s", baseURL, strings.TrimPrefix(requestURL, "/openai/deployments"))
}
}
+
return fullRequestURL
}
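With the switch in getFullRequestURL, the Cloudflare gateway prefix now works for Azure channels too. The resulting mappings look like this (account, gateway, resource, and deployment names are made up for illustration):

// OpenAI channel:
//   base    https://gateway.ai.cloudflare.com/v1/ACCT/GW/openai
//   request /v1/chat/completions
//   result  https://gateway.ai.cloudflare.com/v1/ACCT/GW/openai/chat/completions
// Azure channel:
//   base    https://gateway.ai.cloudflare.com/v1/ACCT/GW/azure-openai/RESOURCE
//   request /openai/deployments/gpt-35/chat/completions?api-version=2023-03-15-preview
//   result  https://gateway.ai.cloudflare.com/v1/ACCT/GW/azure-openai/RESOURCE/gpt-35/chat/completions?api-version=2023-03-15-preview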
From 969f539777971fde12bbc9c495b9dbea93bf2708 Mon Sep 17 00:00:00 2001
From: Ian Li
Date: Sun, 19 Nov 2023 16:11:39 +0800
Subject: [PATCH 06/17] fix: skip JSON deserialization when accessing
transcriptions and translations (#718)
* fix: Skip JSON deserialization when accessing transcriptions and translations.
* chore: update impl
---------
Co-authored-by: JustSong
---
common/gin.go | 9 ++++++++-
controller/relay.go | 4 ++--
2 files changed, 10 insertions(+), 3 deletions(-)
diff --git a/common/gin.go b/common/gin.go
index ffa1e218..f5012688 100644
--- a/common/gin.go
+++ b/common/gin.go
@@ -5,6 +5,7 @@ import (
"encoding/json"
"github.com/gin-gonic/gin"
"io"
+ "strings"
)
func UnmarshalBodyReusable(c *gin.Context, v any) error {
@@ -16,7 +17,13 @@ func UnmarshalBodyReusable(c *gin.Context, v any) error {
if err != nil {
return err
}
- err = json.Unmarshal(requestBody, &v)
+ contentType := c.Request.Header.Get("Content-Type")
+ if strings.HasPrefix(contentType, "application/json") {
+ err = json.Unmarshal(requestBody, &v)
+ } else {
+ // skip for now
+ // TODO: if a non-JSON request ever carries a model field, implement parsing for it here
+ }
if err != nil {
return err
}
diff --git a/controller/relay.go b/controller/relay.go
index 863267b4..504ee8ca 100644
--- a/controller/relay.go
+++ b/controller/relay.go
@@ -201,9 +201,9 @@ func Relay(c *gin.Context) {
relayMode = RelayModeEdits
} else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/speech") {
relayMode = RelayModeAudioSpeech
- } else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcription") {
+ } else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") {
relayMode = RelayModeAudioTranscription
- } else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/translation") {
+ } else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/translations") {
relayMode = RelayModeAudioTranslation
}
var err *OpenAIErrorWithStatusCode
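The Content-Type guard is what makes this work: multipart uploads no longer go through JSON binding, so the distributor can call UnmarshalBodyReusable unconditionally. A sketch of the behavior (hypothetical request shapes, error message simplified):

// Content-Type: application/json        -> body is unmarshaled into v as before
// Content-Type: multipart/form-data;... -> unmarshal is skipped, v keeps its zero value,
//                                          and the caller falls back to defaults (e.g. "whisper-1")
var modelRequest ModelRequest
if err := common.UnmarshalBodyReusable(c, &modelRequest); err != nil {
	abortWithMessage(c, http.StatusBadRequest, "invalid request")
	return
}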
From 915d13fdd4946145cfff87cacddad2dda2552ef9 Mon Sep 17 00:00:00 2001
From: JustSong
Date: Sun, 19 Nov 2023 17:22:35 +0800
Subject: [PATCH 07/17] docs: update readme (#724)
---
README.md | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/README.md b/README.md
index 4ef6505c..904946d7 100644
--- a/README.md
+++ b/README.md
@@ -99,7 +99,7 @@ _✨ 通过标准的 OpenAI API 格式访问所有的大模型,开箱即用
19. 支持丰富的**自定义**设置,
1. 支持自定义系统名称,logo 以及页脚。
2. 支持自定义首页和关于页面,可以选择使用 HTML & Markdown 代码进行自定义,或者使用一个单独的网页通过 iframe 嵌入。
-20. 支持通过系统访问令牌访问管理 API。
+20. 支持通过系统访问令牌访问管理 API(bearer token,用以替代 cookie,你可以自行抓包来查看 API 的用法)。
21. 支持 Cloudflare Turnstile 用户校验。
22. 支持用户管理,支持**多种用户登录注册方式**:
+ 邮箱登录注册(支持注册邮箱白名单)以及通过邮箱进行密码重置。
From 76f9288c341dc3c3d7b527f18b4b86d513d3e3fb Mon Sep 17 00:00:00 2001
From: JustSong
Date: Sun, 19 Nov 2023 17:50:30 +0800
Subject: [PATCH 08/17] feat: update request struct (close #708)
---
README.md | 2 +-
controller/relay.go | 35 +++++++++++++++++++++++------------
2 files changed, 24 insertions(+), 13 deletions(-)
diff --git a/README.md b/README.md
index 904946d7..20c81361 100644
--- a/README.md
+++ b/README.md
@@ -92,7 +92,7 @@ _✨ 通过标准的 OpenAI API 格式访问所有的大模型,开箱即用
12. 支持**用户邀请奖励**。
13. 支持以美元为单位显示额度。
14. 支持发布公告,设置充值链接,设置新用户初始额度。
-15. 支持模型映射,重定向用户的请求模型。
+15. 支持模型映射,重定向用户的请求模型,如无必要请不要设置,设置之后会导致请求体被重新构造而非直接透传,会导致部分还未正式支持的字段无法传递成功。
16. 支持失败自动重试。
17. 支持绘图接口。
18. 支持 [Cloudflare AI Gateway](https://developers.cloudflare.com/ai-gateway/providers/openai/),渠道设置的代理部分填写 `https://gateway.ai.cloudflare.com/v1/ACCOUNT_TAG/GATEWAY/openai` 即可。
diff --git a/controller/relay.go b/controller/relay.go
index 504ee8ca..5837c0b8 100644
--- a/controller/relay.go
+++ b/controller/relay.go
@@ -31,19 +31,30 @@ const (
// https://platform.openai.com/docs/api-reference/chat
+type ResponseFormat struct {
+ Type string `json:"type,omitempty"`
+}
+
type GeneralOpenAIRequest struct {
- Model string `json:"model,omitempty"`
- Messages []Message `json:"messages,omitempty"`
- Prompt any `json:"prompt,omitempty"`
- Stream bool `json:"stream,omitempty"`
- MaxTokens int `json:"max_tokens,omitempty"`
- Temperature float64 `json:"temperature,omitempty"`
- TopP float64 `json:"top_p,omitempty"`
- N int `json:"n,omitempty"`
- Input any `json:"input,omitempty"`
- Instruction string `json:"instruction,omitempty"`
- Size string `json:"size,omitempty"`
- Functions any `json:"functions,omitempty"`
+ Model string `json:"model,omitempty"`
+ Messages []Message `json:"messages,omitempty"`
+ Prompt any `json:"prompt,omitempty"`
+ Stream bool `json:"stream,omitempty"`
+ MaxTokens int `json:"max_tokens,omitempty"`
+ Temperature float64 `json:"temperature,omitempty"`
+ TopP float64 `json:"top_p,omitempty"`
+ N int `json:"n,omitempty"`
+ Input any `json:"input,omitempty"`
+ Instruction string `json:"instruction,omitempty"`
+ Size string `json:"size,omitempty"`
+ Functions any `json:"functions,omitempty"`
+ FrequencyPenalty float64 `json:"frequency_penalty,omitempty"`
+ PresencePenalty float64 `json:"presence_penalty,omitempty"`
+ ResponseFormat *ResponseFormat `json:"response_format,omitempty"`
+ Seed float64 `json:"seed,omitempty"`
+ Tools any `json:"tools,omitempty"`
+ ToolChoice any `json:"tool_choice,omitempty"`
+ User string `json:"user,omitempty"`
}
func (r GeneralOpenAIRequest) ParseInput() []string {
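With the widened struct, these parameters now survive the re-marshaling that model mapping triggers instead of being silently dropped. A sketch of a request exercising the new fields (values illustrative):

req := GeneralOpenAIRequest{
	Model:            "gpt-4-1106-preview",
	Messages:         []Message{{Role: "user", Content: "Reply in JSON."}},
	ResponseFormat:   &ResponseFormat{Type: "json_object"},
	Seed:             42,
	FrequencyPenalty: 0.5,
	User:             "user-1234",
}
_ = req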
From 495fc628e432da0f400c93c8b9195c5f315aaaa7 Mon Sep 17 00:00:00 2001
From: JustSong
Date: Sun, 19 Nov 2023 18:38:54 +0800
Subject: [PATCH 09/17] feat: support gpt-4 with vision (#683, #714)
---
controller/relay-aiproxy.go | 2 +-
controller/relay-ali.go | 8 ++++----
controller/relay-baidu.go | 4 ++--
controller/relay-openai.go | 2 +-
controller/relay-palm.go | 2 +-
controller/relay-tencent.go | 4 ++--
controller/relay-utils.go | 2 +-
controller/relay-xunfei.go | 4 ++--
controller/relay-zhipu.go | 4 ++--
controller/relay.go | 41 ++++++++++++++++++++++++++++++++++++-
10 files changed, 56 insertions(+), 17 deletions(-)
diff --git a/controller/relay-aiproxy.go b/controller/relay-aiproxy.go
index d0159ce8..543954f7 100644
--- a/controller/relay-aiproxy.go
+++ b/controller/relay-aiproxy.go
@@ -48,7 +48,7 @@ type AIProxyLibraryStreamResponse struct {
func requestOpenAI2AIProxyLibrary(request GeneralOpenAIRequest) *AIProxyLibraryRequest {
query := ""
if len(request.Messages) != 0 {
- query = request.Messages[len(request.Messages)-1].Content
+ query = request.Messages[len(request.Messages)-1].StringContent()
}
return &AIProxyLibraryRequest{
Model: request.Model,
diff --git a/controller/relay-ali.go b/controller/relay-ali.go
index 50dc743c..b41ca327 100644
--- a/controller/relay-ali.go
+++ b/controller/relay-ali.go
@@ -88,18 +88,18 @@ func requestOpenAI2Ali(request GeneralOpenAIRequest) *AliChatRequest {
message := request.Messages[i]
if message.Role == "system" {
messages = append(messages, AliMessage{
- User: message.Content,
+ User: message.StringContent(),
Bot: "Okay",
})
continue
} else {
if i == len(request.Messages)-1 {
- prompt = message.Content
+ prompt = message.StringContent()
break
}
messages = append(messages, AliMessage{
- User: message.Content,
- Bot: request.Messages[i+1].Content,
+ User: message.StringContent(),
+ Bot: request.Messages[i+1].StringContent(),
})
i++
}
diff --git a/controller/relay-baidu.go b/controller/relay-baidu.go
index ed08ac04..c75ec09a 100644
--- a/controller/relay-baidu.go
+++ b/controller/relay-baidu.go
@@ -89,7 +89,7 @@ func requestOpenAI2Baidu(request GeneralOpenAIRequest) *BaiduChatRequest {
if message.Role == "system" {
messages = append(messages, BaiduMessage{
Role: "user",
- Content: message.Content,
+ Content: message.StringContent(),
})
messages = append(messages, BaiduMessage{
Role: "assistant",
@@ -98,7 +98,7 @@ func requestOpenAI2Baidu(request GeneralOpenAIRequest) *BaiduChatRequest {
} else {
messages = append(messages, BaiduMessage{
Role: message.Role,
- Content: message.Content,
+ Content: message.StringContent(),
})
}
}
diff --git a/controller/relay-openai.go b/controller/relay-openai.go
index 6bdfbc08..dcd20115 100644
--- a/controller/relay-openai.go
+++ b/controller/relay-openai.go
@@ -132,7 +132,7 @@ func openaiHandler(c *gin.Context, resp *http.Response, consumeQuota bool, promp
if textResponse.Usage.TotalTokens == 0 {
completionTokens := 0
for _, choice := range textResponse.Choices {
- completionTokens += countTokenText(choice.Message.Content, model)
+ completionTokens += countTokenText(choice.Message.StringContent(), model)
}
textResponse.Usage = Usage{
PromptTokens: promptTokens,
diff --git a/controller/relay-palm.go b/controller/relay-palm.go
index a705b318..2bd0bcd8 100644
--- a/controller/relay-palm.go
+++ b/controller/relay-palm.go
@@ -59,7 +59,7 @@ func requestOpenAI2PaLM(textRequest GeneralOpenAIRequest) *PaLMChatRequest {
}
for _, message := range textRequest.Messages {
palmMessage := PaLMChatMessage{
- Content: message.Content,
+ Content: message.StringContent(),
}
if message.Role == "user" {
palmMessage.Author = "0"
diff --git a/controller/relay-tencent.go b/controller/relay-tencent.go
index 024468bc..f66bf38f 100644
--- a/controller/relay-tencent.go
+++ b/controller/relay-tencent.go
@@ -84,7 +84,7 @@ func requestOpenAI2Tencent(request GeneralOpenAIRequest) *TencentChatRequest {
if message.Role == "system" {
messages = append(messages, TencentMessage{
Role: "user",
- Content: message.Content,
+ Content: message.StringContent(),
})
messages = append(messages, TencentMessage{
Role: "assistant",
@@ -93,7 +93,7 @@ func requestOpenAI2Tencent(request GeneralOpenAIRequest) *TencentChatRequest {
continue
}
messages = append(messages, TencentMessage{
- Content: message.Content,
+ Content: message.StringContent(),
Role: message.Role,
})
}
diff --git a/controller/relay-utils.go b/controller/relay-utils.go
index e2b77a97..c7cd4766 100644
--- a/controller/relay-utils.go
+++ b/controller/relay-utils.go
@@ -87,7 +87,7 @@ func countTokenMessages(messages []Message, model string) int {
tokenNum := 0
for _, message := range messages {
tokenNum += tokensPerMessage
- tokenNum += getTokenNum(tokenEncoder, message.Content)
+ tokenNum += getTokenNum(tokenEncoder, message.StringContent())
tokenNum += getTokenNum(tokenEncoder, message.Role)
if message.Name != nil {
tokenNum += tokensPerName
diff --git a/controller/relay-xunfei.go b/controller/relay-xunfei.go
index 91fb6042..00ec8981 100644
--- a/controller/relay-xunfei.go
+++ b/controller/relay-xunfei.go
@@ -81,7 +81,7 @@ func requestOpenAI2Xunfei(request GeneralOpenAIRequest, xunfeiAppId string, doma
if message.Role == "system" {
messages = append(messages, XunfeiMessage{
Role: "user",
- Content: message.Content,
+ Content: message.StringContent(),
})
messages = append(messages, XunfeiMessage{
Role: "assistant",
@@ -90,7 +90,7 @@ func requestOpenAI2Xunfei(request GeneralOpenAIRequest, xunfeiAppId string, doma
} else {
messages = append(messages, XunfeiMessage{
Role: message.Role,
- Content: message.Content,
+ Content: message.StringContent(),
})
}
}
diff --git a/controller/relay-zhipu.go b/controller/relay-zhipu.go
index 7a4a582d..2e345ab5 100644
--- a/controller/relay-zhipu.go
+++ b/controller/relay-zhipu.go
@@ -114,7 +114,7 @@ func requestOpenAI2Zhipu(request GeneralOpenAIRequest) *ZhipuRequest {
if message.Role == "system" {
messages = append(messages, ZhipuMessage{
Role: "system",
- Content: message.Content,
+ Content: message.StringContent(),
})
messages = append(messages, ZhipuMessage{
Role: "user",
@@ -123,7 +123,7 @@ func requestOpenAI2Zhipu(request GeneralOpenAIRequest) *ZhipuRequest {
} else {
messages = append(messages, ZhipuMessage{
Role: message.Role,
- Content: message.Content,
+ Content: message.StringContent(),
})
}
}
diff --git a/controller/relay.go b/controller/relay.go
index 5837c0b8..f91ba6da 100644
--- a/controller/relay.go
+++ b/controller/relay.go
@@ -12,10 +12,49 @@ import (
type Message struct {
Role string `json:"role"`
- Content string `json:"content"`
+ Content any `json:"content"`
Name *string `json:"name,omitempty"`
}
+type ImageURL struct {
+ Url string `json:"url,omitempty"`
+ Detail string `json:"detail,omitempty"`
+}
+
+type TextContent struct {
+ Type string `json:"type,omitempty"`
+ Text string `json:"text,omitempty"`
+}
+
+type ImageContent struct {
+ Type string `json:"type,omitempty"`
+ ImageURL *ImageURL `json:"image_url,omitempty"`
+}
+
+func (m Message) StringContent() string {
+ content, ok := m.Content.(string)
+ if ok {
+ return content
+ }
+ contentList, ok := m.Content.([]any)
+ if ok {
+ var contentStr string
+ for _, contentItem := range contentList {
+ contentMap, ok := contentItem.(map[string]any)
+ if !ok {
+ continue
+ }
+ if contentMap["type"] == "text" {
+ if subStr, ok := contentMap["text"].(string); ok {
+ contentStr += subStr
+ }
+ }
+ }
+ return contentStr
+ }
+ return ""
+}
+
const (
RelayModeUnknown = iota
RelayModeChatCompletions
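Content is now `any` so a GPT-4-with-vision array of content parts deserializes cleanly, while StringContent flattens it for adapters that only accept plain text. A quick sketch of the two shapes it handles:

plain := Message{Role: "user", Content: "hello"}
vision := Message{Role: "user", Content: []any{
	map[string]any{"type": "text", "text": "What is in this image?"},
	map[string]any{"type": "image_url", "image_url": map[string]any{"url": "https://example.com/cat.png"}},
}}
fmt.Println(plain.StringContent())  // "hello"
fmt.Println(vision.StringContent()) // "What is in this image?" (image parts contribute nothing)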
From d85e356b6e6a7b22750f1aa93c3024862e82c58f Mon Sep 17 00:00:00 2001
From: igophper <34326532+igophper@users.noreply.github.com>
Date: Fri, 24 Nov 2023 20:42:29 +0800
Subject: [PATCH 10/17] refactor: remove consumeQuota related logic (#738)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
* feat: remove the consumeQuota variable in relay-text
This variable is always true, so it can be removed
* chore: remove useless code
---------
Co-authored-by: JustSong
---
controller/relay-image.go | 71 +++++++++++++++++--------------------
controller/relay-openai.go | 45 ++++++++++++------------
controller/relay-text.go | 72 ++++++++++++++++++--------------------
middleware/auth.go | 6 ----
4 files changed, 88 insertions(+), 106 deletions(-)
diff --git a/controller/relay-image.go b/controller/relay-image.go
index 1d1b71ba..0ff18309 100644
--- a/controller/relay-image.go
+++ b/controller/relay-image.go
@@ -33,15 +33,12 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
channelType := c.GetInt("channel")
channelId := c.GetInt("channel_id")
userId := c.GetInt("id")
- consumeQuota := c.GetBool("consume_quota")
group := c.GetString("group")
var imageRequest ImageRequest
- if consumeQuota {
- err := common.UnmarshalBodyReusable(c, &imageRequest)
- if err != nil {
- return errorWrapper(err, "bind_request_body_failed", http.StatusBadRequest)
- }
+ err := common.UnmarshalBodyReusable(c, &imageRequest)
+ if err != nil {
+ return errorWrapper(err, "bind_request_body_failed", http.StatusBadRequest)
}
// Size validation
@@ -122,7 +119,7 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
quota := int(ratio*imageCostRatio*1000) * imageRequest.N
- if consumeQuota && userQuota-quota < 0 {
+ if userQuota-quota < 0 {
return errorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
}
@@ -151,43 +148,39 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
var textResponse ImageResponse
defer func(ctx context.Context) {
- if consumeQuota {
- err := model.PostConsumeTokenQuota(tokenId, quota)
- if err != nil {
- common.SysError("error consuming token remain quota: " + err.Error())
- }
- err = model.CacheUpdateUserQuota(userId)
- if err != nil {
- common.SysError("error update user quota cache: " + err.Error())
- }
- if quota != 0 {
- tokenName := c.GetString("token_name")
- logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
- model.RecordConsumeLog(ctx, userId, channelId, 0, 0, imageModel, tokenName, quota, logContent)
- model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
- channelId := c.GetInt("channel_id")
- model.UpdateChannelUsedQuota(channelId, quota)
- }
+ err := model.PostConsumeTokenQuota(tokenId, quota)
+ if err != nil {
+ common.SysError("error consuming token remain quota: " + err.Error())
+ }
+ err = model.CacheUpdateUserQuota(userId)
+ if err != nil {
+ common.SysError("error update user quota cache: " + err.Error())
+ }
+ if quota != 0 {
+ tokenName := c.GetString("token_name")
+ logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
+ model.RecordConsumeLog(ctx, userId, channelId, 0, 0, imageModel, tokenName, quota, logContent)
+ model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
+ channelId := c.GetInt("channel_id")
+ model.UpdateChannelUsedQuota(channelId, quota)
}
}(c.Request.Context())
- if consumeQuota {
- responseBody, err := io.ReadAll(resp.Body)
+ responseBody, err := io.ReadAll(resp.Body)
- if err != nil {
- return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
- }
- err = resp.Body.Close()
- if err != nil {
- return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError)
- }
- err = json.Unmarshal(responseBody, &textResponse)
- if err != nil {
- return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError)
- }
-
- resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
+ if err != nil {
+ return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
}
+ err = resp.Body.Close()
+ if err != nil {
+ return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError)
+ }
+ err = json.Unmarshal(responseBody, &textResponse)
+ if err != nil {
+ return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError)
+ }
+
+ resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
for k, v := range resp.Header {
c.Writer.Header().Set(k, v[0])
diff --git a/controller/relay-openai.go b/controller/relay-openai.go
index dcd20115..37867843 100644
--- a/controller/relay-openai.go
+++ b/controller/relay-openai.go
@@ -88,30 +88,29 @@ func openaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*O
return nil, responseText
}
-func openaiHandler(c *gin.Context, resp *http.Response, consumeQuota bool, promptTokens int, model string) (*OpenAIErrorWithStatusCode, *Usage) {
+func openaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*OpenAIErrorWithStatusCode, *Usage) {
var textResponse TextResponse
- if consumeQuota {
- responseBody, err := io.ReadAll(resp.Body)
- if err != nil {
- return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
- }
- err = resp.Body.Close()
- if err != nil {
- return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
- }
- err = json.Unmarshal(responseBody, &textResponse)
- if err != nil {
- return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
- }
- if textResponse.Error.Type != "" {
- return &OpenAIErrorWithStatusCode{
- OpenAIError: textResponse.Error,
- StatusCode: resp.StatusCode,
- }, nil
- }
- // Reset response body
- resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
+ responseBody, err := io.ReadAll(resp.Body)
+ if err != nil {
+ return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
}
+ err = resp.Body.Close()
+ if err != nil {
+ return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
+ }
+ err = json.Unmarshal(responseBody, &textResponse)
+ if err != nil {
+ return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
+ }
+ if textResponse.Error.Type != "" {
+ return &OpenAIErrorWithStatusCode{
+ OpenAIError: textResponse.Error,
+ StatusCode: resp.StatusCode,
+ }, nil
+ }
+ // Reset response body
+ resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
+
// We shouldn't set the header before we parse the response body, because the parse part may fail.
// And then we will have to send an error response, but in this case, the header has already been set.
// So the httpClient will be confused by the response.
@@ -120,7 +119,7 @@ func openaiHandler(c *gin.Context, resp *http.Response, consumeQuota bool, promp
c.Writer.Header().Set(k, v[0])
}
c.Writer.WriteHeader(resp.StatusCode)
- _, err := io.Copy(c.Writer, resp.Body)
+ _, err = io.Copy(c.Writer, resp.Body)
if err != nil {
return errorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil
}
diff --git a/controller/relay-text.go b/controller/relay-text.go
index 018c8d8a..dd9e7153 100644
--- a/controller/relay-text.go
+++ b/controller/relay-text.go
@@ -51,14 +51,11 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
channelId := c.GetInt("channel_id")
tokenId := c.GetInt("token_id")
userId := c.GetInt("id")
- consumeQuota := c.GetBool("consume_quota")
group := c.GetString("group")
var textRequest GeneralOpenAIRequest
- if consumeQuota || channelType == common.ChannelTypeAzure || channelType == common.ChannelTypePaLM {
- err := common.UnmarshalBodyReusable(c, &textRequest)
- if err != nil {
- return errorWrapper(err, "bind_request_body_failed", http.StatusBadRequest)
- }
+ err := common.UnmarshalBodyReusable(c, &textRequest)
+ if err != nil {
+ return errorWrapper(err, "bind_request_body_failed", http.StatusBadRequest)
}
if relayMode == RelayModeModerations && textRequest.Model == "" {
textRequest.Model = "text-moderation-latest"
@@ -235,7 +232,7 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
preConsumedQuota = 0
common.LogInfo(c.Request.Context(), fmt.Sprintf("user %d has enough quota %d, trusted and no need to pre-consume", userId, userQuota))
}
- if consumeQuota && preConsumedQuota > 0 {
+ if preConsumedQuota > 0 {
err := model.PreConsumeTokenQuota(tokenId, preConsumedQuota)
if err != nil {
return errorWrapper(err, "pre_consume_token_quota_failed", http.StatusForbidden)
@@ -414,37 +411,36 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
defer func(ctx context.Context) {
// c.Writer.Flush()
go func() {
- if consumeQuota {
- quota := 0
- completionRatio := common.GetCompletionRatio(textRequest.Model)
- promptTokens = textResponse.Usage.PromptTokens
- completionTokens = textResponse.Usage.CompletionTokens
- quota = int(math.Ceil((float64(promptTokens) + float64(completionTokens)*completionRatio) * ratio))
- if ratio != 0 && quota <= 0 {
- quota = 1
- }
- totalTokens := promptTokens + completionTokens
- if totalTokens == 0 {
- // in this case, must be some error happened
- // we cannot just return, because we may have to return the pre-consumed quota
- quota = 0
- }
- quotaDelta := quota - preConsumedQuota
- err := model.PostConsumeTokenQuota(tokenId, quotaDelta)
- if err != nil {
- common.LogError(ctx, "error consuming token remain quota: "+err.Error())
- }
- err = model.CacheUpdateUserQuota(userId)
- if err != nil {
- common.LogError(ctx, "error update user quota cache: "+err.Error())
- }
- if quota != 0 {
- logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
- model.RecordConsumeLog(ctx, userId, channelId, promptTokens, completionTokens, textRequest.Model, tokenName, quota, logContent)
- model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
- model.UpdateChannelUsedQuota(channelId, quota)
- }
+ quota := 0
+ completionRatio := common.GetCompletionRatio(textRequest.Model)
+ promptTokens = textResponse.Usage.PromptTokens
+ completionTokens = textResponse.Usage.CompletionTokens
+ quota = int(math.Ceil((float64(promptTokens) + float64(completionTokens)*completionRatio) * ratio))
+ if ratio != 0 && quota <= 0 {
+ quota = 1
}
+ totalTokens := promptTokens + completionTokens
+ if totalTokens == 0 {
+ // in this case, must be some error happened
+ // we cannot just return, because we may have to return the pre-consumed quota
+ quota = 0
+ }
+ quotaDelta := quota - preConsumedQuota
+ err := model.PostConsumeTokenQuota(tokenId, quotaDelta)
+ if err != nil {
+ common.LogError(ctx, "error consuming token remain quota: "+err.Error())
+ }
+ err = model.CacheUpdateUserQuota(userId)
+ if err != nil {
+ common.LogError(ctx, "error update user quota cache: "+err.Error())
+ }
+ if quota != 0 {
+ logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
+ model.RecordConsumeLog(ctx, userId, channelId, promptTokens, completionTokens, textRequest.Model, tokenName, quota, logContent)
+ model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
+ model.UpdateChannelUsedQuota(channelId, quota)
+ }
+
}()
}(c.Request.Context())
switch apiType {
@@ -458,7 +454,7 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
textResponse.Usage.CompletionTokens = countTokenText(responseText, textRequest.Model)
return nil
} else {
- err, usage := openaiHandler(c, resp, consumeQuota, promptTokens, textRequest.Model)
+ err, usage := openaiHandler(c, resp, promptTokens, textRequest.Model)
if err != nil {
return err
}
diff --git a/middleware/auth.go b/middleware/auth.go
index b0803612..ad7e64b7 100644
--- a/middleware/auth.go
+++ b/middleware/auth.go
@@ -106,12 +106,6 @@ func TokenAuth() func(c *gin.Context) {
c.Set("id", token.UserId)
c.Set("token_id", token.Id)
c.Set("token_name", token.Name)
- requestURL := c.Request.URL.String()
- consumeQuota := true
- if strings.HasPrefix(requestURL, "/v1/models") {
- consumeQuota = false
- }
- c.Set("consume_quota", consumeQuota)
if len(parts) > 1 {
if model.IsAdmin(token.UserId) {
c.Set("channelId", parts[1])
From b4d67ca6144eff90a2f8bf8f7f1b262af67d2e0e Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?ShinChven=20=E2=9C=A8?=
Date: Fri, 24 Nov 2023 20:52:59 +0800
Subject: [PATCH 11/17] fix: add Message-ID header for email (#732)
* feat: Add Message-ID to email headers to comply with RFC 5322
- Extract domain from SMTPFrom
- Generate a unique Message-ID
- Add Message-ID to email headers
* chore: check slice length
---------
Co-authored-by: JustSong
---
common/email.go | 22 ++++++++++++++++++++--
1 file changed, 20 insertions(+), 2 deletions(-)
diff --git a/common/email.go b/common/email.go
index 74f4cccd..7d6963cc 100644
--- a/common/email.go
+++ b/common/email.go
@@ -1,6 +1,7 @@
package common
import (
+ "crypto/rand"
"crypto/tls"
"encoding/base64"
"fmt"
@@ -13,15 +14,32 @@ func SendEmail(subject string, receiver string, content string) error {
SMTPFrom = SMTPAccount
}
encodedSubject := fmt.Sprintf("=?UTF-8?B?%s?=", base64.StdEncoding.EncodeToString([]byte(subject)))
+
+ // Extract domain from SMTPFrom
+ parts := strings.Split(SMTPFrom, "@")
+ var domain string
+ if len(parts) > 1 {
+ domain = parts[1]
+ }
+ // Generate a unique Message-ID
+ buf := make([]byte, 16)
+ _, err := rand.Read(buf)
+ if err != nil {
+ return err
+ }
+ messageId := fmt.Sprintf("<%x@%s>", buf, domain)
+
mail := []byte(fmt.Sprintf("To: %s\r\n"+
"From: %s<%s>\r\n"+
"Subject: %s\r\n"+
+ "Message-ID: %s\r\n"+ // add Message-ID header to avoid being treated as spam, RFC 5322
"Content-Type: text/html; charset=UTF-8\r\n\r\n%s\r\n",
- receiver, SystemName, SMTPFrom, encodedSubject, content))
+ receiver, SystemName, SMTPFrom, encodedSubject, messageId, content))
+
auth := smtp.PlainAuth("", SMTPAccount, SMTPToken, SMTPServer)
addr := fmt.Sprintf("%s:%d", SMTPServer, SMTPPort)
to := strings.Split(receiver, ";")
- var err error
+
if SMTPPort == 465 {
tlsConfig := &tls.Config{
InsecureSkipVerify: true,
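The generated identifier is 16 random bytes, hex-encoded, at the sender's domain. A sketch of the output (random bytes fixed here for illustration):

buf := []byte{0xde, 0xad, 0xbe, 0xef, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c}
messageId := fmt.Sprintf("<%x@%s>", buf, "example.com")
// messageId == "<deadbeef0102030405060708090a0b0c@example.com>"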
From 923e24534b4626f667479da4efc9a2867cdeff5a Mon Sep 17 00:00:00 2001
From: Tillman Bailee <51190972+YOMIkio@users.noreply.github.com>
Date: Fri, 24 Nov 2023 20:56:53 +0800
Subject: [PATCH 12/17] fix: add Date header for email (#742)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
* fix self-hosted mailbox sending error: INVALID HEADER Missing required header field: "Date"
* chore: fix style
---------
Co-authored-by: liyujie <29959257@qq.com>
Co-authored-by: JustSong <39998050+songquanpeng@users.noreply.github.com>
Co-authored-by: JustSong
---
common/email.go | 5 +++--
1 file changed, 3 insertions(+), 2 deletions(-)
diff --git a/common/email.go b/common/email.go
index 7d6963cc..b915f0f9 100644
--- a/common/email.go
+++ b/common/email.go
@@ -7,6 +7,7 @@ import (
"fmt"
"net/smtp"
"strings"
+ "time"
)
func SendEmail(subject string, receiver string, content string) error {
@@ -33,9 +34,9 @@ func SendEmail(subject string, receiver string, content string) error {
"From: %s<%s>\r\n"+
"Subject: %s\r\n"+
"Message-ID: %s\r\n"+ // add Message-ID header to avoid being treated as spam, RFC 5322
+ "Date: %s\r\n"+
"Content-Type: text/html; charset=UTF-8\r\n\r\n%s\r\n",
- receiver, SystemName, SMTPFrom, encodedSubject, messageId, content))
-
+ receiver, SystemName, SMTPFrom, encodedSubject, messageId, time.Now().Format(time.RFC1123Z), content))
auth := smtp.PlainAuth("", SMTPAccount, SMTPToken, SMTPServer)
addr := fmt.Sprintf("%s:%d", SMTPServer, SMTPPort)
to := strings.Split(receiver, ";")
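time.RFC1123Z produces the numeric-zone date form RFC 5322 expects, e.g.:

date := time.Now().Format(time.RFC1123Z) // "Fri, 24 Nov 2023 20:56:53 +0800"
_ = date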
From 3347a44023e8b1313ba640c222fad9d2a9ce755b Mon Sep 17 00:00:00 2001
From: Ian Li
Date: Fri, 24 Nov 2023 21:10:18 +0800
Subject: [PATCH 13/17] feat: support Azure's Whisper model (#720)
---
controller/relay-audio.go | 24 +++++++++++++++++++++++-
1 file changed, 23 insertions(+), 1 deletion(-)
diff --git a/controller/relay-audio.go b/controller/relay-audio.go
index 01267fbf..89a311a0 100644
--- a/controller/relay-audio.go
+++ b/controller/relay-audio.go
@@ -5,11 +5,13 @@ import (
"context"
"encoding/json"
"errors"
+ "fmt"
"github.com/gin-gonic/gin"
"io"
"net/http"
"one-api/common"
"one-api/model"
+ "strings"
)
func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
@@ -95,13 +97,33 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
}
fullRequestURL := getFullRequestURL(baseURL, requestURL, channelType)
+ if relayMode == RelayModeAudioTranscription && channelType == common.ChannelTypeAzure {
+ // https://learn.microsoft.com/en-us/azure/ai-services/openai/whisper-quickstart?tabs=command-line#rest-api
+ query := c.Request.URL.Query()
+ apiVersion := query.Get("api-version")
+ if apiVersion == "" {
+ apiVersion = c.GetString("api_version")
+ }
+ baseURL = c.GetString("base_url")
+ fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/audio/transcriptions?api-version=%s", baseURL, audioModel, apiVersion)
+ }
+
requestBody := c.Request.Body
req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
if err != nil {
return errorWrapper(err, "new_request_failed", http.StatusInternalServerError)
}
- req.Header.Set("Authorization", c.Request.Header.Get("Authorization"))
+
+ if relayMode == RelayModeAudioTranscription && channelType == common.ChannelTypeAzure {
+ // https://learn.microsoft.com/en-us/azure/ai-services/openai/whisper-quickstart?tabs=command-line#rest-api
+ apiKey := c.Request.Header.Get("Authorization")
+ apiKey = strings.TrimPrefix(apiKey, "Bearer ")
+ req.Header.Set("api-key", apiKey)
+ req.ContentLength = c.Request.ContentLength
+ } else {
+ req.Header.Set("Authorization", c.Request.Header.Get("Authorization"))
+ }
req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
req.Header.Set("Accept", c.Request.Header.Get("Accept"))
From b4e43d97fd11d7f9cbe18e5c1239de8b6ab5b6a5 Mon Sep 17 00:00:00 2001
From: JustSong
Date: Fri, 24 Nov 2023 21:21:03 +0800
Subject: [PATCH 14/17] docs: add pr template
---
pull_request_template.md | 3 +++
1 file changed, 3 insertions(+)
create mode 100644 pull_request_template.md
diff --git a/pull_request_template.md b/pull_request_template.md
new file mode 100644
index 00000000..bbcd969c
--- /dev/null
+++ b/pull_request_template.md
@@ -0,0 +1,3 @@
+close #issue_number
+
+I have confirmed that this PR has passed my own testing; relevant screenshots are attached below:
\ No newline at end of file
From b273464e777632bd45c4502df4fe12e6fdd264f2 Mon Sep 17 00:00:00 2001
From: JustSong
Date: Fri, 24 Nov 2023 21:23:16 +0800
Subject: [PATCH 15/17] docs: update readme
---
README.md | 6 +++---
1 file changed, 3 insertions(+), 3 deletions(-)
diff --git a/README.md b/README.md
index 20c81361..7e6a7b38 100644
--- a/README.md
+++ b/README.md
@@ -51,15 +51,15 @@ _✨ Access all large models through the standard OpenAI API format, ready to use out of the box ✨_
Sponsor & Support
-> **Note**
+> [!NOTE]
> This project is open source; users must comply with OpenAI's [Terms of Use](https://openai.com/policies/terms-of-use) and all applicable **laws and regulations**, and must not use it for illegal purposes.
>
> Per the [Interim Measures for the Administration of Generative AI Services](http://www.cac.gov.cn/2023-07/13/c_1690898327029107.htm), do not provide any unregistered generative AI services to the public in mainland China.
-> **Warning**
+> [!WARNING]
> The latest Docker image may be an `alpha` build; pin a specific version manually if you need stability.
-> **Warning**
+> [!WARNING]
> After logging in as the root user for the first time, be sure to change the default password `123456`!
## Features
From 9889377f0e9260e852fb121d886ef3d9517ff8f9 Mon Sep 17 00:00:00 2001
From: JustSong
Date: Fri, 24 Nov 2023 21:39:44 +0800
Subject: [PATCH 16/17] feat: support claude-2.x (close #736)
---
common/model-ratio.go | 2 ++
controller/model.go | 18 ++++++++++++++++++
controller/relay-claude.go | 4 +++-
web/src/pages/Channel/EditChannel.js | 2 +-
4 files changed, 24 insertions(+), 2 deletions(-)
diff --git a/common/model-ratio.go b/common/model-ratio.go
index 74c74a90..ccbc05dd 100644
--- a/common/model-ratio.go
+++ b/common/model-ratio.go
@@ -76,6 +76,8 @@ var ModelRatio = map[string]float64{
"dall-e-3": 20, // $0.040 - $0.120 / image
"claude-instant-1": 0.815, // $1.63 / 1M tokens
"claude-2": 5.51, // $11.02 / 1M tokens
+ "claude-2.0": 5.51, // $11.02 / 1M tokens
+ "claude-2.1": 5.51, // $11.02 / 1M tokens
"ERNIE-Bot": 0.8572, // ¥0.012 / 1k tokens
"ERNIE-Bot-turbo": 0.5715, // ¥0.008 / 1k tokens
"ERNIE-Bot-4": 8.572, // ¥0.12 / 1k tokens
diff --git a/controller/model.go b/controller/model.go
index 59ea22e8..8f79524d 100644
--- a/controller/model.go
+++ b/controller/model.go
@@ -360,6 +360,24 @@ func init() {
Root: "claude-2",
Parent: nil,
},
+ {
+ Id: "claude-2.1",
+ Object: "model",
+ Created: 1677649963,
+ OwnedBy: "anthropic",
+ Permission: permission,
+ Root: "claude-2.1",
+ Parent: nil,
+ },
+ {
+ Id: "claude-2.0",
+ Object: "model",
+ Created: 1677649963,
+ OwnedBy: "anthropic",
+ Permission: permission,
+ Root: "claude-2.0",
+ Parent: nil,
+ },
{
Id: "ERNIE-Bot",
Object: "model",
diff --git a/controller/relay-claude.go b/controller/relay-claude.go
index 1f4a3e7b..1b72b47d 100644
--- a/controller/relay-claude.go
+++ b/controller/relay-claude.go
@@ -70,7 +70,9 @@ func requestOpenAI2Claude(textRequest GeneralOpenAIRequest) *ClaudeRequest {
} else if message.Role == "assistant" {
prompt += fmt.Sprintf("\n\nAssistant: %s", message.Content)
} else if message.Role == "system" {
- prompt += fmt.Sprintf("\n\nSystem: %s", message.Content)
+ if prompt == "" {
+ prompt = message.StringContent()
+ }
}
}
prompt += "\n\nAssistant:"
diff --git a/web/src/pages/Channel/EditChannel.js b/web/src/pages/Channel/EditChannel.js
index 654a5d51..bc3886a0 100644
--- a/web/src/pages/Channel/EditChannel.js
+++ b/web/src/pages/Channel/EditChannel.js
@@ -60,7 +60,7 @@ const EditChannel = () => {
let localModels = [];
switch (value) {
case 14:
- localModels = ['claude-instant-1', 'claude-2'];
+ localModels = ['claude-instant-1', 'claude-2', 'claude-2.0', 'claude-2.1'];
break;
case 11:
localModels = ['PaLM-2'];
From 0e73418cdfef809fec7c8a2b6bb632a3b207eb88 Mon Sep 17 00:00:00 2001
From: JustSong
Date: Sun, 26 Nov 2023 12:05:16 +0800
Subject: [PATCH 17/17] fix: fix log recording & error handling for relay-audio
---
controller/relay-audio.go | 81 ++++++++++++++++++++++-----------------
controller/relay-utils.go | 17 +++++---
2 files changed, 57 insertions(+), 41 deletions(-)
diff --git a/controller/relay-audio.go b/controller/relay-audio.go
index 89a311a0..5b8898a7 100644
--- a/controller/relay-audio.go
+++ b/controller/relay-audio.go
@@ -39,41 +39,40 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
}
}
- preConsumedTokens := common.PreConsumedQuota
modelRatio := common.GetModelRatio(audioModel)
groupRatio := common.GetGroupRatio(group)
ratio := modelRatio * groupRatio
- preConsumedQuota := int(float64(preConsumedTokens) * ratio)
+ var quota int
+ var preConsumedQuota int
+ switch relayMode {
+ case RelayModeAudioSpeech:
+ preConsumedQuota = int(float64(len(ttsRequest.Input)) * ratio)
+ quota = preConsumedQuota
+ default:
+ preConsumedQuota = int(float64(common.PreConsumedQuota) * ratio)
+ }
userQuota, err := model.CacheGetUserQuota(userId)
if err != nil {
return errorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError)
}
- quota := 0
// Check if user quota is enough
- if relayMode == RelayModeAudioSpeech {
- quota = int(float64(len(ttsRequest.Input)) * modelRatio * groupRatio)
- if quota > userQuota {
- return errorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
- }
- } else {
- if userQuota-preConsumedQuota < 0 {
- return errorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
- }
- err = model.CacheDecreaseUserQuota(userId, preConsumedQuota)
+ if userQuota-preConsumedQuota < 0 {
+ return errorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
+ }
+ err = model.CacheDecreaseUserQuota(userId, preConsumedQuota)
+ if err != nil {
+ return errorWrapper(err, "decrease_user_quota_failed", http.StatusInternalServerError)
+ }
+ if userQuota > 100*preConsumedQuota {
+ // in this case, we do not pre-consume quota
+ // because the user has enough quota
+ preConsumedQuota = 0
+ }
+ if preConsumedQuota > 0 {
+ err := model.PreConsumeTokenQuota(tokenId, preConsumedQuota)
if err != nil {
- return errorWrapper(err, "decrease_user_quota_failed", http.StatusInternalServerError)
- }
- if userQuota > 100*preConsumedQuota {
- // in this case, we do not pre-consume quota
- // because the user has enough quota
- preConsumedQuota = 0
- }
- if preConsumedQuota > 0 {
- err := model.PreConsumeTokenQuota(tokenId, preConsumedQuota)
- if err != nil {
- return errorWrapper(err, "pre_consume_token_quota_failed", http.StatusForbidden)
- }
+ return errorWrapper(err, "pre_consume_token_quota_failed", http.StatusForbidden)
}
}
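The hunk above collapses the two quota paths into one: speech (TTS) cost is known up front from the input length, transcription takes a flat pre-hold, and the hold is waived when the user's balance is far larger than it. A condensed sketch of that policy with illustrative names; the real code also decrements a quota cache and pre-consumes token quota between these steps.

    // Sketch: unified pre-consumption policy for audio relays.
    package main

    import (
        "errors"
        "fmt"
    )

    func preConsume(isSpeech bool, inputLen, flatHold int, ratio float64, userQuota int) (hold int, err error) {
        if isSpeech {
            hold = int(float64(inputLen) * ratio) // TTS cost is known before the request
        } else {
            hold = int(float64(flatHold) * ratio) // transcription cost is known only after the response
        }
        if userQuota < hold {
            return 0, errors.New("user quota is not enough")
        }
        if userQuota > 100*hold {
            hold = 0 // ample balance: skip the hold to save a write
        }
        return hold, nil
    }

    func main() {
        h, err := preConsume(false, 0, 500, 1.5, 100000)
        fmt.Println(h, err) // hold waived: balance dwarfs the would-be hold
    }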
@@ -141,11 +140,7 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError)
}
- if relayMode == RelayModeAudioSpeech {
- defer func(ctx context.Context) {
- go postConsumeQuota(ctx, tokenId, quota, userId, channelId, modelRatio, groupRatio, audioModel, tokenName)
- }(c.Request.Context())
- } else {
+ if relayMode != RelayModeAudioSpeech {
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
@@ -159,13 +154,29 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
if err != nil {
return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError)
}
- defer func(ctx context.Context) {
- quota := countTokenText(whisperResponse.Text, audioModel)
- quotaDelta := quota - preConsumedQuota
- go postConsumeQuota(ctx, tokenId, quotaDelta, userId, channelId, modelRatio, groupRatio, audioModel, tokenName)
- }(c.Request.Context())
+ quota = countTokenText(whisperResponse.Text, audioModel)
resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
}
+ if resp.StatusCode != http.StatusOK {
+ if preConsumedQuota > 0 {
+ // we need to roll back the pre-consumed quota
+ defer func(ctx context.Context) {
+ go func() {
+ // negative means add quota back for token & user
+ err := model.PostConsumeTokenQuota(tokenId, -preConsumedQuota)
+ if err != nil {
+ common.LogError(ctx, fmt.Sprintf("error rolling back pre-consumed quota: %s", err.Error()))
+ }
+ }()
+ }(c.Request.Context())
+ }
+ return relayErrorHandler(resp)
+ }
+ quotaDelta := quota - preConsumedQuota
+ defer func(ctx context.Context) {
+ go postConsumeQuota(ctx, tokenId, quotaDelta, quota, userId, channelId, modelRatio, groupRatio, audioModel, tokenName)
+ }(c.Request.Context())
+
for k, v := range resp.Header {
c.Writer.Header().Set(k, v[0])
}
diff --git a/controller/relay-utils.go b/controller/relay-utils.go
index c7cd4766..391f28b4 100644
--- a/controller/relay-utils.go
+++ b/controller/relay-utils.go
@@ -195,8 +195,9 @@ func getFullRequestURL(baseURL string, requestURL string, channelType int) strin
return fullRequestURL
}
-func postConsumeQuota(ctx context.Context, tokenId int, quota int, userId int, channelId int, modelRatio float64, groupRatio float64, modelName string, tokenName string) {
- err := model.PostConsumeTokenQuota(tokenId, quota)
+func postConsumeQuota(ctx context.Context, tokenId int, quotaDelta int, totalQuota int, userId int, channelId int, modelRatio float64, groupRatio float64, modelName string, tokenName string) {
+ // quotaDelta is remaining quota to be consumed
+ err := model.PostConsumeTokenQuota(tokenId, quotaDelta)
if err != nil {
common.SysError("error consuming token remain quota: " + err.Error())
}
@@ -204,10 +205,14 @@ func postConsumeQuota(ctx context.Context, tokenId int, quota int, userId int, c
if err != nil {
common.SysError("error update user quota cache: " + err.Error())
}
- if quota != 0 {
+ // totalQuota is total quota consumed
+ if totalQuota != 0 {
logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
- model.RecordConsumeLog(ctx, userId, channelId, 0, 0, modelName, tokenName, quota, logContent)
- model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
- model.UpdateChannelUsedQuota(channelId, quota)
+ model.RecordConsumeLog(ctx, userId, channelId, totalQuota, 0, modelName, tokenName, totalQuota, logContent)
+ model.UpdateUserUsedQuotaAndRequestCount(userId, totalQuota)
+ model.UpdateChannelUsedQuota(channelId, totalQuota)
+ }
+ if totalQuota <= 0 {
+ common.LogError(ctx, fmt.Sprintf("totalQuota consumed is %d, something is wrong", totalQuota))
}
}
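After this refactor, postConsumeQuota separates two numbers the old signature conflated: quotaDelta (actual cost minus the pre-consumed hold, possibly negative, which settles the token and user balance) and totalQuota (the full cost, which is what gets logged and counted in usage statistics). A hedged sketch of that contract with illustrative names:

    // Sketch: the delta-vs-total settlement split introduced above.
    package main

    import "fmt"

    func settle(preConsumed, actual int) (quotaDelta, totalQuota int) {
        quotaDelta = actual - preConsumed // negative delta refunds the over-held amount
        totalQuota = actual              // recorded in consume logs and usage counters
        return
    }

    func main() {
        d, t := settle(750, 420)
        fmt.Printf("delta=%d (refund), total=%d\n", d, t)
    }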