diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py
index efe9b62a..ce252ba1 100644
--- a/server/text_generation_server/models/__init__.py
+++ b/server/text_generation_server/models/__init__.py
@@ -10,6 +10,8 @@
 from text_generation_server.models.causal_lm import CausalLM
 from text_generation_server.models.bloom import BLOOM
 from text_generation_server.models.santacoder import SantaCoder
 
+from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi
+
 # Disable gradients
 torch.set_grad_enabled(False)
@@ -20,6 +22,7 @@ def get_model(
     revision: Optional[str],
     dtype: Optional[torch.dtype] = None,
 ) -> Model:
+    adapt_transformers_to_gaudi()
     config = AutoConfig.from_pretrained(model_id, revision=revision)
     model_type = config.model_type
 
diff --git a/server/text_generation_server/models/causal_lm.py b/server/text_generation_server/models/causal_lm.py
index bdc0b4c5..97a9fd6f 100644
--- a/server/text_generation_server/models/causal_lm.py
+++ b/server/text_generation_server/models/causal_lm.py
@@ -16,7 +16,6 @@ from opentelemetry import trace
 import text_generation_server.habana_quantization_env as hq_env
 import habana_frameworks.torch as htorch
 from habana_frameworks.torch.hpu import wrap_in_hpu_graph
-from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi
 from optimum.habana.utils import HabanaProfile
 from optimum.habana.transformers.generation import MODELS_OPTIMIZED_WITH_STATIC_SHAPES
 from optimum.habana.checkpoint_utils import (
@@ -572,8 +571,6 @@ class CausalLM(Model):
         revision: Optional[str] = None,
         dtype: Optional[torch.dtype] = None,
     ):
-        adapt_transformers_to_gaudi()
-
         # Create tokenizer
         tokenizer = AutoTokenizer.from_pretrained(
             model_id,
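
Net effect, as a minimal sketch (names taken from the hunks above; the surrounding TGI-on-Gaudi code and the model-type dispatch are elided): adapt_transformers_to_gaudi() now runs once inside get_model(), before the config is loaded and before any model class is constructed, instead of per-instance inside CausalLM.__init__.

    from typing import Optional

    import torch
    from transformers import AutoConfig

    from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi


    def get_model(
        model_id: str,
        revision: Optional[str],
        dtype: Optional[torch.dtype] = None,
    ):
        # Patch transformers for Gaudi exactly once, before any model class
        # (CausalLM, BLOOM, SantaCoder, ...) is instantiated, so the adaptation
        # applies to every model type rather than only CausalLM.
        adapt_transformers_to_gaudi()
        config = AutoConfig.from_pretrained(model_id, revision=revision)
        model_type = config.model_type
        # ... dispatch on model_type as before ...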