169 lines
4.4 KiB
Go
169 lines
4.4 KiB
Go
package util
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"github.com/songquanpeng/one-api/common"
|
|
"github.com/songquanpeng/one-api/common/config"
|
|
"github.com/songquanpeng/one-api/common/logger"
|
|
"github.com/songquanpeng/one-api/model"
|
|
"github.com/songquanpeng/one-api/relay/channel/openai"
|
|
"io"
|
|
"net/http"
|
|
"strconv"
|
|
"strings"
|
|
|
|
"github.com/gin-gonic/gin"
|
|
)
|
|
|
|
func ShouldDisableChannel(err *openai.Error, statusCode int) bool {
|
|
if !config.AutomaticDisableChannelEnabled {
|
|
return false
|
|
}
|
|
if err == nil {
|
|
return false
|
|
}
|
|
if statusCode == http.StatusUnauthorized {
|
|
return true
|
|
}
|
|
if err.Type == "insufficient_quota" || err.Code == "invalid_api_key" || err.Code == "account_deactivated" {
|
|
return true
|
|
}
|
|
return false
|
|
}
|
|
|
|
func ShouldEnableChannel(err error, openAIErr *openai.Error) bool {
|
|
if !config.AutomaticEnableChannelEnabled {
|
|
return false
|
|
}
|
|
if err != nil {
|
|
return false
|
|
}
|
|
if openAIErr != nil {
|
|
return false
|
|
}
|
|
return true
|
|
}
|
|
|
|
// GeneralErrorResponse is a catch-all decoding target for upstream error
// bodies. It declares several alternative places an error message may live,
// so a single json.Unmarshal fills whichever of them the upstream actually
// sent. Use ToMessage to pick the first non-empty one.
type GeneralErrorResponse struct {
	// Error holds an OpenAI-style `{"error": {...}}` object.
	Error openai.Error `json:"error"`
	// Message, Msg, Err and ErrorMsg hold a top-level string message under
	// the JSON keys "message", "msg", "err" and "error_msg" respectively.
	// NOTE(review): which upstreams use which key is not visible here.
	Message  string `json:"message"`
	Msg      string `json:"msg"`
	Err      string `json:"err"`
	ErrorMsg string `json:"error_msg"`
	// Header holds a message nested as `{"header": {"message": "..."}}`.
	Header struct {
		Message string `json:"message"`
	} `json:"header"`
	// Response holds a message nested as
	// `{"response": {"error": {"message": "..."}}}`.
	Response struct {
		Error struct {
			Message string `json:"message"`
		} `json:"error"`
	} `json:"response"`
}
|
|
|
|
func (e GeneralErrorResponse) ToMessage() string {
|
|
if e.Error.Message != "" {
|
|
return e.Error.Message
|
|
}
|
|
if e.Message != "" {
|
|
return e.Message
|
|
}
|
|
if e.Msg != "" {
|
|
return e.Msg
|
|
}
|
|
if e.Err != "" {
|
|
return e.Err
|
|
}
|
|
if e.ErrorMsg != "" {
|
|
return e.ErrorMsg
|
|
}
|
|
if e.Header.Message != "" {
|
|
return e.Header.Message
|
|
}
|
|
if e.Response.Error.Message != "" {
|
|
return e.Response.Error.Message
|
|
}
|
|
return ""
|
|
}
|
|
|
|
func RelayErrorHandler(resp *http.Response) (ErrorWithStatusCode *openai.ErrorWithStatusCode) {
|
|
ErrorWithStatusCode = &openai.ErrorWithStatusCode{
|
|
StatusCode: resp.StatusCode,
|
|
Error: openai.Error{
|
|
Message: "",
|
|
Type: "upstream_error",
|
|
Code: "bad_response_status_code",
|
|
Param: strconv.Itoa(resp.StatusCode),
|
|
},
|
|
}
|
|
responseBody, err := io.ReadAll(resp.Body)
|
|
if err != nil {
|
|
return
|
|
}
|
|
err = resp.Body.Close()
|
|
if err != nil {
|
|
return
|
|
}
|
|
var errResponse GeneralErrorResponse
|
|
err = json.Unmarshal(responseBody, &errResponse)
|
|
if err != nil {
|
|
return
|
|
}
|
|
if errResponse.Error.Message != "" {
|
|
// OpenAI format error, so we override the default one
|
|
ErrorWithStatusCode.Error = errResponse.Error
|
|
} else {
|
|
ErrorWithStatusCode.Error.Message = errResponse.ToMessage()
|
|
}
|
|
if ErrorWithStatusCode.Error.Message == "" {
|
|
ErrorWithStatusCode.Error.Message = fmt.Sprintf("bad response status code %d", resp.StatusCode)
|
|
}
|
|
return
|
|
}
|
|
|
|
func GetFullRequestURL(baseURL string, requestURL string, channelType int) string {
|
|
fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL)
|
|
|
|
if strings.HasPrefix(baseURL, "https://gateway.ai.cloudflare.com") {
|
|
switch channelType {
|
|
case common.ChannelTypeOpenAI:
|
|
fullRequestURL = fmt.Sprintf("%s%s", baseURL, strings.TrimPrefix(requestURL, "/v1"))
|
|
case common.ChannelTypeAzure:
|
|
fullRequestURL = fmt.Sprintf("%s%s", baseURL, strings.TrimPrefix(requestURL, "/openai/deployments"))
|
|
}
|
|
}
|
|
return fullRequestURL
|
|
}
|
|
|
|
func PostConsumeQuota(ctx context.Context, tokenId int, quotaDelta int, totalQuota int, userId int, channelId int, modelRatio float64, groupRatio float64, modelName string, tokenName string) {
|
|
// quotaDelta is remaining quota to be consumed
|
|
err := model.PostConsumeTokenQuota(tokenId, quotaDelta)
|
|
if err != nil {
|
|
logger.SysError("error consuming token remain quota: " + err.Error())
|
|
}
|
|
err = model.CacheUpdateUserQuota(userId)
|
|
if err != nil {
|
|
logger.SysError("error update user quota cache: " + err.Error())
|
|
}
|
|
// totalQuota is total quota consumed
|
|
if totalQuota != 0 {
|
|
logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
|
|
model.RecordConsumeLog(ctx, userId, channelId, totalQuota, 0, modelName, tokenName, totalQuota, logContent)
|
|
model.UpdateUserUsedQuotaAndRequestCount(userId, totalQuota)
|
|
model.UpdateChannelUsedQuota(channelId, totalQuota)
|
|
}
|
|
if totalQuota <= 0 {
|
|
logger.Error(ctx, fmt.Sprintf("totalQuota consumed is %d, something is wrong", totalQuota))
|
|
}
|
|
}
|
|
|
|
func GetAzureAPIVersion(c *gin.Context) string {
|
|
query := c.Request.URL.Query()
|
|
apiVersion := query.Get("api-version")
|
|
if apiVersion == "" {
|
|
apiVersion = c.GetString("api_version")
|
|
}
|
|
return apiVersion
|
|
}
|