package controller

import (
	"bytes"
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"net/http"
	"strings"

	"github.com/gin-gonic/gin"

	"github.com/songquanpeng/one-api/common"
	"github.com/songquanpeng/one-api/common/logger"
	"github.com/songquanpeng/one-api/model"
	"github.com/songquanpeng/one-api/relay/channel/openai"
	"github.com/songquanpeng/one-api/relay/constant"
	relaymodel "github.com/songquanpeng/one-api/relay/model"
	"github.com/songquanpeng/one-api/relay/util"
)

// isWithinRange reports whether the requested image count is within the
// configured [min, max] range for the given model.
func isWithinRange(element string, value int) bool {
	if _, ok := constant.DalleGenerationImageAmounts[element]; !ok {
		return false
	}
	min := constant.DalleGenerationImageAmounts[element][0]
	max := constant.DalleGenerationImageAmounts[element][1]
	return value >= min && value <= max
}

// RelayImageHelper relays an image generation request to the upstream channel,
// handles Azure-specific URL and authentication differences, and settles the
// user's quota once the upstream response status is known.
func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatusCode {
	ctx := c.Request.Context()
	meta := util.GetRelayMeta(c)
	imageRequest, err := getImageRequest(c, meta.Mode)
	if err != nil {
		logger.Errorf(ctx, "getImageRequest failed: %s", err.Error())
		return openai.ErrorWrapper(err, "invalid_image_request", http.StatusBadRequest)
	}

	// map model name
	var isModelMapped bool
	meta.OriginModelName = imageRequest.Model
	imageRequest.Model, isModelMapped = util.GetMappedModelName(imageRequest.Model, meta.ModelMapping)
	meta.ActualModelName = imageRequest.Model

	// model validation
	bizErr := validateImageRequest(imageRequest, meta)
	if bizErr != nil {
		return bizErr
	}

	imageCostRatio, err := getImageCostRatio(imageRequest)
	if err != nil {
		return openai.ErrorWrapper(err, "get_image_cost_ratio_failed", http.StatusInternalServerError)
	}

	requestURL := c.Request.URL.String()
	fullRequestURL := util.GetFullRequestURL(meta.BaseURL, requestURL, meta.ChannelType)
	if meta.ChannelType == common.ChannelTypeAzure {
		// https://learn.microsoft.com/en-us/azure/ai-services/openai/dall-e-quickstart?tabs=dalle3%2Ccommand-line&pivots=rest-api
		apiVersion := util.GetAzureAPIVersion(c)
		// https://{resource_name}.openai.azure.com/openai/deployments/dall-e-3/images/generations?api-version=2024-03-01-preview
		fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/images/generations?api-version=%s", meta.BaseURL, imageRequest.Model, apiVersion)
	}

	var requestBody io.Reader
	if isModelMapped || meta.ChannelType == common.ChannelTypeAzure {
		// Re-marshal the body when the model name was rewritten or an Azure channel
		// request body is required; otherwise pass the original body through.
		jsonStr, err := json.Marshal(imageRequest)
		if err != nil {
			return openai.ErrorWrapper(err, "marshal_image_request_failed", http.StatusInternalServerError)
		}
		requestBody = bytes.NewBuffer(jsonStr)
	} else {
		requestBody = c.Request.Body
	}

	modelRatio := common.GetModelRatio(imageRequest.Model)
	groupRatio := common.GetGroupRatio(meta.Group)
	ratio := modelRatio * groupRatio
	userQuota, err := model.CacheGetUserQuota(ctx, meta.UserId)
	if err != nil {
		return openai.ErrorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError)
	}

	quota := int64(ratio*imageCostRatio*1000) * int64(imageRequest.N)
	if userQuota-quota < 0 {
		return openai.ErrorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
	}

	req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
	if err != nil {
		return openai.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError)
	}
	token := c.Request.Header.Get("Authorization")
	if meta.ChannelType == common.ChannelTypeAzure {
		// Azure authentication uses an api-key header instead of a Bearer token.
		token = strings.TrimPrefix(token, "Bearer ")
		req.Header.Set("api-key", token)
	} else {
		req.Header.Set("Authorization", token)
	}
	req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
	req.Header.Set("Accept", c.Request.Header.Get("Accept"))
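	// Send the rewritten request upstream. From here on: release both request
	// bodies, defer quota settlement until the response status code is known,
	// then relay the upstream headers, status, and body back to the client.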
	resp, err := util.HTTPClient.Do(req)
	if err != nil {
		return openai.ErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
	}

	err = req.Body.Close()
	if err != nil {
		return openai.ErrorWrapper(err, "close_request_body_failed", http.StatusInternalServerError)
	}
	err = c.Request.Body.Close()
	if err != nil {
		return openai.ErrorWrapper(err, "close_request_body_failed", http.StatusInternalServerError)
	}

	var imageResponse openai.ImageResponse

	// Quota is settled only after the upstream call returns: nothing is charged
	// unless the upstream responded with 200 OK.
	defer func(ctx context.Context) {
		if resp.StatusCode != http.StatusOK {
			return
		}
		err := model.PostConsumeTokenQuota(meta.TokenId, quota)
		if err != nil {
			logger.SysError("error consuming token remain quota: " + err.Error())
		}
		err = model.CacheUpdateUserQuota(ctx, meta.UserId)
		if err != nil {
			logger.SysError("error update user quota cache: " + err.Error())
		}
		if quota != 0 {
			tokenName := c.GetString("token_name")
			// Log content is in Chinese: "model ratio %.2f, group ratio %.2f".
			logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
			model.RecordConsumeLog(ctx, meta.UserId, meta.ChannelId, 0, 0, imageRequest.Model, tokenName, quota, logContent)
			model.UpdateUserUsedQuotaAndRequestCount(meta.UserId, quota)
			channelId := c.GetInt("channel_id")
			model.UpdateChannelUsedQuota(channelId, quota)
		}
	}(c.Request.Context())

	responseBody, err := io.ReadAll(resp.Body)
	if err != nil {
		return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
	}
	err = resp.Body.Close()
	if err != nil {
		return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError)
	}
	err = json.Unmarshal(responseBody, &imageResponse)
	if err != nil {
		return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError)
	}

	// Restore the body so the upstream response can be streamed back verbatim.
	resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))

	for k, v := range resp.Header {
		c.Writer.Header().Set(k, v[0])
	}
	c.Writer.WriteHeader(resp.StatusCode)

	_, err = io.Copy(c.Writer, resp.Body)
	if err != nil {
		return openai.ErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError)
	}
	err = resp.Body.Close()
	if err != nil {
		return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError)
	}
	return nil
}
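
// relayImageExample is a hypothetical sketch, not part of the original file: it only
// illustrates how RelayImageHelper might be invoked from a gin handler. The handler
// name and the error-response shape are assumptions for illustration; the actual
// dispatch logic in this repository lives in the relay controller.
func relayImageExample(c *gin.Context, relayMode int) {
	if relayErr := RelayImageHelper(c, relayMode); relayErr != nil {
		// Surface the wrapped error with the status code chosen by the helper.
		c.JSON(relayErr.StatusCode, gin.H{"error": relayErr.Error})
	}
}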