Load later to make tests pass.

This commit is contained in:
Nicolas Patry 2023-08-15 14:50:35 +00:00
parent 5469316ed8
commit 4ff509948a

View File

@ -9,7 +9,6 @@ import json
from dataclasses import dataclass
from opentelemetry import trace
from transformers import AutoProcessor, AutoTokenizer, AutoModelForCausalLM, PreTrainedTokenizerBase, ProcessorMixin
from text_generation_server.models.custom_modeling.idefics_modeling import IdeficsForVisionText2Text
from typing import Optional, Tuple, List, Type, Dict
from text_generation_server.models import Model
@ -582,6 +581,8 @@ class IdeficsCausalLM(Model):
dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False,
):
from text_generation_server.models.custom_modeling.idefics_modeling import IdeficsForVisionText2Text
if torch.cuda.is_available():
device = torch.device("cuda")
dtype = torch.float16 if dtype is None else dtype