Load later to make tests pass.

This commit is contained in:
Nicolas Patry 2023-08-15 14:50:35 +00:00
parent 5469316ed8
commit 4ff509948a

View File

@ -9,7 +9,6 @@ import json
from dataclasses import dataclass from dataclasses import dataclass
from opentelemetry import trace from opentelemetry import trace
from transformers import AutoProcessor, AutoTokenizer, AutoModelForCausalLM, PreTrainedTokenizerBase, ProcessorMixin from transformers import AutoProcessor, AutoTokenizer, AutoModelForCausalLM, PreTrainedTokenizerBase, ProcessorMixin
from text_generation_server.models.custom_modeling.idefics_modeling import IdeficsForVisionText2Text
from typing import Optional, Tuple, List, Type, Dict from typing import Optional, Tuple, List, Type, Dict
from text_generation_server.models import Model from text_generation_server.models import Model
@ -582,6 +581,8 @@ class IdeficsCausalLM(Model):
dtype: Optional[torch.dtype] = None, dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False, trust_remote_code: bool = False,
): ):
from text_generation_server.models.custom_modeling.idefics_modeling import IdeficsForVisionText2Text
if torch.cuda.is_available(): if torch.cuda.is_available():
device = torch.device("cuda") device = torch.device("cuda")
dtype = torch.float16 if dtype is None else dtype dtype = torch.float16 if dtype is None else dtype