Move call to adapt_transformers_to_gaudi earlier in the code (#133)

This commit is contained in:
regisss 2024-04-26 11:07:27 +02:00 committed by GitHub
parent ae6215fcea
commit 37aabf8571
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 3 additions and 3 deletions

View File

@@ -10,6 +10,8 @@ from text_generation_server.models.causal_lm import CausalLM
from text_generation_server.models.bloom import BLOOM
from text_generation_server.models.santacoder import SantaCoder
+from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi
# Disable gradients
torch.set_grad_enabled(False)
@@ -20,6 +22,7 @@ def get_model(
revision: Optional[str],
dtype: Optional[torch.dtype] = None,
) -> Model:
+adapt_transformers_to_gaudi()
config = AutoConfig.from_pretrained(model_id, revision=revision)
model_type = config.model_type

View File

@@ -16,7 +16,6 @@ from opentelemetry import trace
import text_generation_server.habana_quantization_env as hq_env
import habana_frameworks.torch as htorch
from habana_frameworks.torch.hpu import wrap_in_hpu_graph
-from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi
from optimum.habana.utils import HabanaProfile
from optimum.habana.transformers.generation import MODELS_OPTIMIZED_WITH_STATIC_SHAPES
from optimum.habana.checkpoint_utils import (
@@ -572,8 +571,6 @@ class CausalLM(Model):
revision: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
):
-adapt_transformers_to_gaudi()
# Create tokenizer
tokenizer = AutoTokenizer.from_pretrained(
model_id,