diff --git a/server/text_generation_server/models/idefics_causal_lm.py b/server/text_generation_server/models/idefics_causal_lm.py index 2d1f418b..8fd4d2f8 100644 --- a/server/text_generation_server/models/idefics_causal_lm.py +++ b/server/text_generation_server/models/idefics_causal_lm.py @@ -9,7 +9,6 @@ import json from dataclasses import dataclass from opentelemetry import trace from transformers import AutoProcessor, AutoTokenizer, AutoModelForCausalLM, PreTrainedTokenizerBase, ProcessorMixin -from text_generation_server.models.custom_modeling.idefics_modeling import IdeficsForVisionText2Text from typing import Optional, Tuple, List, Type, Dict from text_generation_server.models import Model @@ -582,6 +581,8 @@ class IdeficsCausalLM(Model): dtype: Optional[torch.dtype] = None, trust_remote_code: bool = False, ): + from text_generation_server.models.custom_modeling.idefics_modeling import IdeficsForVisionText2Text + if torch.cuda.is_available(): device = torch.device("cuda") dtype = torch.float16 if dtype is None else dtype