package util import ( "context" "encoding/json" "fmt" "github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common/config" "github.com/songquanpeng/one-api/common/logger" "github.com/songquanpeng/one-api/model" "github.com/songquanpeng/one-api/relay/channel/openai" "io" "net/http" "strconv" "strings" "github.com/gin-gonic/gin" ) func ShouldDisableChannel(err *openai.Error, statusCode int) bool { if !config.AutomaticDisableChannelEnabled { return false } if err == nil { return false } if statusCode == http.StatusUnauthorized { return true } if err.Type == "insufficient_quota" || err.Code == "invalid_api_key" || err.Code == "account_deactivated" { return true } return false } func ShouldEnableChannel(err error, openAIErr *openai.Error) bool { if !config.AutomaticEnableChannelEnabled { return false } if err != nil { return false } if openAIErr != nil { return false } return true } type GeneralErrorResponse struct { Error openai.Error `json:"error"` Message string `json:"message"` Msg string `json:"msg"` Err string `json:"err"` ErrorMsg string `json:"error_msg"` Header struct { Message string `json:"message"` } `json:"header"` Response struct { Error struct { Message string `json:"message"` } `json:"error"` } `json:"response"` } func (e GeneralErrorResponse) ToMessage() string { if e.Error.Message != "" { return e.Error.Message } if e.Message != "" { return e.Message } if e.Msg != "" { return e.Msg } if e.Err != "" { return e.Err } if e.ErrorMsg != "" { return e.ErrorMsg } if e.Header.Message != "" { return e.Header.Message } if e.Response.Error.Message != "" { return e.Response.Error.Message } return "" } func RelayErrorHandler(resp *http.Response) (ErrorWithStatusCode *openai.ErrorWithStatusCode) { ErrorWithStatusCode = &openai.ErrorWithStatusCode{ StatusCode: resp.StatusCode, Error: openai.Error{ Message: "", Type: "upstream_error", Code: "bad_response_status_code", Param: strconv.Itoa(resp.StatusCode), }, } responseBody, err := io.ReadAll(resp.Body) if err != nil { return } err = resp.Body.Close() if err != nil { return } var errResponse GeneralErrorResponse err = json.Unmarshal(responseBody, &errResponse) if err != nil { return } if errResponse.Error.Message != "" { // OpenAI format error, so we override the default one ErrorWithStatusCode.Error = errResponse.Error } else { ErrorWithStatusCode.Error.Message = errResponse.ToMessage() } if ErrorWithStatusCode.Error.Message == "" { ErrorWithStatusCode.Error.Message = fmt.Sprintf("bad response status code %d", resp.StatusCode) } return } func GetFullRequestURL(baseURL string, requestURL string, channelType int) string { fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL) if strings.HasPrefix(baseURL, "https://gateway.ai.cloudflare.com") { switch channelType { case common.ChannelTypeOpenAI: fullRequestURL = fmt.Sprintf("%s%s", baseURL, strings.TrimPrefix(requestURL, "/v1")) case common.ChannelTypeAzure: fullRequestURL = fmt.Sprintf("%s%s", baseURL, strings.TrimPrefix(requestURL, "/openai/deployments")) } } return fullRequestURL } func PostConsumeQuota(ctx context.Context, tokenId int, quotaDelta int, totalQuota int, userId int, channelId int, modelRatio float64, groupRatio float64, modelName string, tokenName string) { // quotaDelta is remaining quota to be consumed err := model.PostConsumeTokenQuota(tokenId, quotaDelta) if err != nil { logger.SysError("error consuming token remain quota: " + err.Error()) } err = model.CacheUpdateUserQuota(userId) if err != nil { logger.SysError("error update user quota cache: " + err.Error()) } // totalQuota is total quota consumed if totalQuota != 0 { logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio) model.RecordConsumeLog(ctx, userId, channelId, totalQuota, 0, modelName, tokenName, totalQuota, logContent) model.UpdateUserUsedQuotaAndRequestCount(userId, totalQuota) model.UpdateChannelUsedQuota(channelId, totalQuota) } if totalQuota <= 0 { logger.Error(ctx, fmt.Sprintf("totalQuota consumed is %d, something is wrong", totalQuota)) } } func GetAzureAPIVersion(c *gin.Context) string { query := c.Request.URL.Query() apiVersion := query.Get("api-version") if apiVersion == "" { apiVersion = c.GetString("api_version") } return apiVersion }