ai-gateway/relay/util/common.go

181 lines
4.8 KiB
Go
Raw Normal View History

2024-01-14 11:21:03 +00:00
package util
import (
"context"
"encoding/json"
"fmt"
2024-01-28 11:38:58 +00:00
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/model"
relaymodel "github.com/songquanpeng/one-api/relay/model"
2024-01-14 11:21:03 +00:00
"io"
"net/http"
"strconv"
"strings"
"github.com/gin-gonic/gin"
)
func ShouldDisableChannel(err *relaymodel.Error, statusCode int) bool {
if !config.AutomaticDisableChannelEnabled {
2024-01-14 11:21:03 +00:00
return false
}
if err == nil {
return false
}
if statusCode == http.StatusUnauthorized {
return true
}
2024-03-10 12:39:55 +00:00
switch err.Type {
case "insufficient_quota":
return true
// https://docs.anthropic.com/claude/reference/errors
case "authentication_error":
return true
case "permission_error":
return true
}
if err.Code == "invalid_api_key" || err.Code == "account_deactivated" {
2024-01-14 11:21:03 +00:00
return true
}
return false
}
func ShouldEnableChannel(err error, openAIErr *relaymodel.Error) bool {
if !config.AutomaticEnableChannelEnabled {
2024-01-14 11:21:03 +00:00
return false
}
if err != nil {
return false
}
if openAIErr != nil {
return false
}
return true
}
type GeneralErrorResponse struct {
Error relaymodel.Error `json:"error"`
Message string `json:"message"`
Msg string `json:"msg"`
Err string `json:"err"`
ErrorMsg string `json:"error_msg"`
2024-01-14 11:21:03 +00:00
Header struct {
Message string `json:"message"`
} `json:"header"`
Response struct {
Error struct {
Message string `json:"message"`
} `json:"error"`
} `json:"response"`
}
// ToMessage returns the first non-empty error message found in e, or "" when
// none of the known locations is populated.
//
// Locations are checked in this fixed order: Error.Message, Message, Msg, Err,
// ErrorMsg, Header.Message, Response.Error.Message.
func (e GeneralErrorResponse) ToMessage() string {
	// Order matters: earlier entries take precedence over later ones.
	candidates := [...]string{
		e.Error.Message,
		e.Message,
		e.Msg,
		e.Err,
		e.ErrorMsg,
		e.Header.Message,
		e.Response.Error.Message,
	}
	for _, msg := range candidates {
		if msg != "" {
			return msg
		}
	}
	return ""
}
func RelayErrorHandler(resp *http.Response) (ErrorWithStatusCode *relaymodel.ErrorWithStatusCode) {
ErrorWithStatusCode = &relaymodel.ErrorWithStatusCode{
2024-01-14 11:21:03 +00:00
StatusCode: resp.StatusCode,
Error: relaymodel.Error{
2024-01-14 11:21:03 +00:00
Message: "",
Type: "upstream_error",
Code: "bad_response_status_code",
Param: strconv.Itoa(resp.StatusCode),
},
}
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return
}
2024-03-10 12:39:55 +00:00
if config.DebugEnabled {
logger.SysLog(fmt.Sprintf("error happened, status code: %d, response: \n%s", resp.StatusCode, string(responseBody)))
}
2024-01-14 11:21:03 +00:00
err = resp.Body.Close()
if err != nil {
return
}
var errResponse GeneralErrorResponse
err = json.Unmarshal(responseBody, &errResponse)
if err != nil {
return
}
if errResponse.Error.Message != "" {
// OpenAI format error, so we override the default one
ErrorWithStatusCode.Error = errResponse.Error
} else {
ErrorWithStatusCode.Error.Message = errResponse.ToMessage()
}
if ErrorWithStatusCode.Error.Message == "" {
ErrorWithStatusCode.Error.Message = fmt.Sprintf("bad response status code %d", resp.StatusCode)
}
return
}
func GetFullRequestURL(baseURL string, requestURL string, channelType int) string {
fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL)
if strings.HasPrefix(baseURL, "https://gateway.ai.cloudflare.com") {
switch channelType {
case common.ChannelTypeOpenAI:
fullRequestURL = fmt.Sprintf("%s%s", baseURL, strings.TrimPrefix(requestURL, "/v1"))
case common.ChannelTypeAzure:
fullRequestURL = fmt.Sprintf("%s%s", baseURL, strings.TrimPrefix(requestURL, "/openai/deployments"))
}
}
return fullRequestURL
}
func PostConsumeQuota(ctx context.Context, tokenId int, quotaDelta int, totalQuota int, userId int, channelId int, modelRatio float64, groupRatio float64, modelName string, tokenName string) {
// quotaDelta is remaining quota to be consumed
err := model.PostConsumeTokenQuota(tokenId, quotaDelta)
if err != nil {
logger.SysError("error consuming token remain quota: " + err.Error())
2024-01-14 11:21:03 +00:00
}
err = model.CacheUpdateUserQuota(userId)
if err != nil {
logger.SysError("error update user quota cache: " + err.Error())
2024-01-14 11:21:03 +00:00
}
// totalQuota is total quota consumed
if totalQuota != 0 {
logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
model.RecordConsumeLog(ctx, userId, channelId, totalQuota, 0, modelName, tokenName, totalQuota, logContent)
model.UpdateUserUsedQuotaAndRequestCount(userId, totalQuota)
model.UpdateChannelUsedQuota(channelId, totalQuota)
}
if totalQuota <= 0 {
logger.Error(ctx, fmt.Sprintf("totalQuota consumed is %d, something is wrong", totalQuota))
2024-01-14 11:21:03 +00:00
}
}
func GetAzureAPIVersion(c *gin.Context) string {
2024-01-14 11:21:03 +00:00
query := c.Request.URL.Query()
apiVersion := query.Get("api-version")
if apiVersion == "" {
apiVersion = c.GetString(common.ConfigKeyAPIVersion)
2024-01-14 11:21:03 +00:00
}
return apiVersion
}