2024-01-14 11:21:03 +00:00
|
|
|
package util
|
|
|
|
|
|
|
|
import (
|
|
|
|
"context"
|
|
|
|
"encoding/json"
|
|
|
|
"fmt"
|
2024-01-28 11:38:58 +00:00
|
|
|
"github.com/songquanpeng/one-api/common"
|
|
|
|
"github.com/songquanpeng/one-api/common/config"
|
|
|
|
"github.com/songquanpeng/one-api/common/logger"
|
|
|
|
"github.com/songquanpeng/one-api/model"
|
2024-02-17 16:15:31 +00:00
|
|
|
relaymodel "github.com/songquanpeng/one-api/relay/model"
|
2024-01-14 11:21:03 +00:00
|
|
|
"io"
|
|
|
|
"net/http"
|
|
|
|
"strconv"
|
|
|
|
"strings"
|
|
|
|
|
|
|
|
"github.com/gin-gonic/gin"
|
|
|
|
)
|
|
|
|
|
2024-02-17 16:15:31 +00:00
|
|
|
func ShouldDisableChannel(err *relaymodel.Error, statusCode int) bool {
|
2024-01-21 15:21:42 +00:00
|
|
|
if !config.AutomaticDisableChannelEnabled {
|
2024-01-14 11:21:03 +00:00
|
|
|
return false
|
|
|
|
}
|
|
|
|
if err == nil {
|
|
|
|
return false
|
|
|
|
}
|
|
|
|
if statusCode == http.StatusUnauthorized {
|
|
|
|
return true
|
|
|
|
}
|
2024-03-10 12:39:55 +00:00
|
|
|
switch err.Type {
|
|
|
|
case "insufficient_quota":
|
|
|
|
return true
|
|
|
|
// https://docs.anthropic.com/claude/reference/errors
|
|
|
|
case "authentication_error":
|
|
|
|
return true
|
|
|
|
case "permission_error":
|
|
|
|
return true
|
2024-03-13 11:11:30 +00:00
|
|
|
case "forbidden":
|
|
|
|
return true
|
2024-03-10 12:39:55 +00:00
|
|
|
}
|
|
|
|
if err.Code == "invalid_api_key" || err.Code == "account_deactivated" {
|
2024-01-14 11:21:03 +00:00
|
|
|
return true
|
|
|
|
}
|
2024-03-13 11:11:30 +00:00
|
|
|
if strings.HasPrefix(err.Message, "Your credit balance is too low") { // anthropic
|
|
|
|
return true
|
|
|
|
} else if strings.HasPrefix(err.Message, "This organization has been disabled.") {
|
|
|
|
return true
|
|
|
|
}
|
2024-01-14 11:21:03 +00:00
|
|
|
return false
|
|
|
|
}
|
|
|
|
|
2024-02-17 16:15:31 +00:00
|
|
|
func ShouldEnableChannel(err error, openAIErr *relaymodel.Error) bool {
|
2024-01-21 15:21:42 +00:00
|
|
|
if !config.AutomaticEnableChannelEnabled {
|
2024-01-14 11:21:03 +00:00
|
|
|
return false
|
|
|
|
}
|
|
|
|
if err != nil {
|
|
|
|
return false
|
|
|
|
}
|
|
|
|
if openAIErr != nil {
|
|
|
|
return false
|
|
|
|
}
|
|
|
|
return true
|
|
|
|
}
|
|
|
|
|
|
|
|
type GeneralErrorResponse struct {
|
2024-02-17 16:15:31 +00:00
|
|
|
Error relaymodel.Error `json:"error"`
|
|
|
|
Message string `json:"message"`
|
|
|
|
Msg string `json:"msg"`
|
|
|
|
Err string `json:"err"`
|
|
|
|
ErrorMsg string `json:"error_msg"`
|
2024-01-14 11:21:03 +00:00
|
|
|
Header struct {
|
|
|
|
Message string `json:"message"`
|
|
|
|
} `json:"header"`
|
|
|
|
Response struct {
|
|
|
|
Error struct {
|
|
|
|
Message string `json:"message"`
|
|
|
|
} `json:"error"`
|
|
|
|
} `json:"response"`
|
|
|
|
}
|
|
|
|
|
|
|
|
func (e GeneralErrorResponse) ToMessage() string {
|
|
|
|
if e.Error.Message != "" {
|
|
|
|
return e.Error.Message
|
|
|
|
}
|
|
|
|
if e.Message != "" {
|
|
|
|
return e.Message
|
|
|
|
}
|
|
|
|
if e.Msg != "" {
|
|
|
|
return e.Msg
|
|
|
|
}
|
|
|
|
if e.Err != "" {
|
|
|
|
return e.Err
|
|
|
|
}
|
|
|
|
if e.ErrorMsg != "" {
|
|
|
|
return e.ErrorMsg
|
|
|
|
}
|
|
|
|
if e.Header.Message != "" {
|
|
|
|
return e.Header.Message
|
|
|
|
}
|
|
|
|
if e.Response.Error.Message != "" {
|
|
|
|
return e.Response.Error.Message
|
|
|
|
}
|
|
|
|
return ""
|
|
|
|
}
|
|
|
|
|
2024-02-17 16:15:31 +00:00
|
|
|
func RelayErrorHandler(resp *http.Response) (ErrorWithStatusCode *relaymodel.ErrorWithStatusCode) {
|
|
|
|
ErrorWithStatusCode = &relaymodel.ErrorWithStatusCode{
|
2024-01-14 11:21:03 +00:00
|
|
|
StatusCode: resp.StatusCode,
|
2024-02-17 16:15:31 +00:00
|
|
|
Error: relaymodel.Error{
|
2024-01-14 11:21:03 +00:00
|
|
|
Message: "",
|
|
|
|
Type: "upstream_error",
|
|
|
|
Code: "bad_response_status_code",
|
|
|
|
Param: strconv.Itoa(resp.StatusCode),
|
|
|
|
},
|
|
|
|
}
|
|
|
|
responseBody, err := io.ReadAll(resp.Body)
|
|
|
|
if err != nil {
|
|
|
|
return
|
|
|
|
}
|
2024-03-10 12:39:55 +00:00
|
|
|
if config.DebugEnabled {
|
|
|
|
logger.SysLog(fmt.Sprintf("error happened, status code: %d, response: \n%s", resp.StatusCode, string(responseBody)))
|
|
|
|
}
|
2024-01-14 11:21:03 +00:00
|
|
|
err = resp.Body.Close()
|
|
|
|
if err != nil {
|
|
|
|
return
|
|
|
|
}
|
|
|
|
var errResponse GeneralErrorResponse
|
|
|
|
err = json.Unmarshal(responseBody, &errResponse)
|
|
|
|
if err != nil {
|
|
|
|
return
|
|
|
|
}
|
|
|
|
if errResponse.Error.Message != "" {
|
|
|
|
// OpenAI format error, so we override the default one
|
|
|
|
ErrorWithStatusCode.Error = errResponse.Error
|
|
|
|
} else {
|
|
|
|
ErrorWithStatusCode.Error.Message = errResponse.ToMessage()
|
|
|
|
}
|
|
|
|
if ErrorWithStatusCode.Error.Message == "" {
|
|
|
|
ErrorWithStatusCode.Error.Message = fmt.Sprintf("bad response status code %d", resp.StatusCode)
|
|
|
|
}
|
|
|
|
return
|
|
|
|
}
|
|
|
|
|
|
|
|
func GetFullRequestURL(baseURL string, requestURL string, channelType int) string {
|
|
|
|
fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL)
|
|
|
|
|
|
|
|
if strings.HasPrefix(baseURL, "https://gateway.ai.cloudflare.com") {
|
|
|
|
switch channelType {
|
|
|
|
case common.ChannelTypeOpenAI:
|
|
|
|
fullRequestURL = fmt.Sprintf("%s%s", baseURL, strings.TrimPrefix(requestURL, "/v1"))
|
|
|
|
case common.ChannelTypeAzure:
|
|
|
|
fullRequestURL = fmt.Sprintf("%s%s", baseURL, strings.TrimPrefix(requestURL, "/openai/deployments"))
|
|
|
|
}
|
|
|
|
}
|
|
|
|
return fullRequestURL
|
|
|
|
}
|
|
|
|
|
|
|
|
func PostConsumeQuota(ctx context.Context, tokenId int, quotaDelta int, totalQuota int, userId int, channelId int, modelRatio float64, groupRatio float64, modelName string, tokenName string) {
|
|
|
|
// quotaDelta is remaining quota to be consumed
|
|
|
|
err := model.PostConsumeTokenQuota(tokenId, quotaDelta)
|
|
|
|
if err != nil {
|
2024-01-21 15:21:42 +00:00
|
|
|
logger.SysError("error consuming token remain quota: " + err.Error())
|
2024-01-14 11:21:03 +00:00
|
|
|
}
|
2024-03-13 11:38:44 +00:00
|
|
|
err = model.CacheUpdateUserQuota(ctx, userId)
|
2024-01-14 11:21:03 +00:00
|
|
|
if err != nil {
|
2024-01-21 15:21:42 +00:00
|
|
|
logger.SysError("error update user quota cache: " + err.Error())
|
2024-01-14 11:21:03 +00:00
|
|
|
}
|
|
|
|
// totalQuota is total quota consumed
|
|
|
|
if totalQuota != 0 {
|
|
|
|
logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
|
|
|
|
model.RecordConsumeLog(ctx, userId, channelId, totalQuota, 0, modelName, tokenName, totalQuota, logContent)
|
|
|
|
model.UpdateUserUsedQuotaAndRequestCount(userId, totalQuota)
|
|
|
|
model.UpdateChannelUsedQuota(channelId, totalQuota)
|
|
|
|
}
|
|
|
|
if totalQuota <= 0 {
|
2024-01-21 15:21:42 +00:00
|
|
|
logger.Error(ctx, fmt.Sprintf("totalQuota consumed is %d, something is wrong", totalQuota))
|
2024-01-14 11:21:03 +00:00
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2024-01-21 15:21:42 +00:00
|
|
|
func GetAzureAPIVersion(c *gin.Context) string {
|
2024-01-14 11:21:03 +00:00
|
|
|
query := c.Request.URL.Query()
|
|
|
|
apiVersion := query.Get("api-version")
|
|
|
|
if apiVersion == "" {
|
2024-02-17 18:22:50 +00:00
|
|
|
apiVersion = c.GetString(common.ConfigKeyAPIVersion)
|
2024-01-14 11:21:03 +00:00
|
|
|
}
|
|
|
|
return apiVersion
|
|
|
|
}
|