Mirror of https://github.com/huggingface/text-generation-inference.git
Move call to adapt_transformers_to_gaudi earlier in the code (#133)

commit 37aabf8571 (parent ae6215fcea)
server/text_generation_server/models/__init__.py

@@ -10,6 +10,8 @@ from text_generation_server.models.causal_lm import CausalLM
 from text_generation_server.models.bloom import BLOOM
 from text_generation_server.models.santacoder import SantaCoder
 
+from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi
+
 
 # Disable gradients
 torch.set_grad_enabled(False)
@@ -20,6 +22,7 @@ def get_model(
     revision: Optional[str],
     dtype: Optional[torch.dtype] = None,
 ) -> Model:
+    adapt_transformers_to_gaudi()
     config = AutoConfig.from_pretrained(model_id, revision=revision)
     model_type = config.model_type
server/text_generation_server/models/causal_lm.py

@@ -16,7 +16,6 @@ from opentelemetry import trace
 import text_generation_server.habana_quantization_env as hq_env
 import habana_frameworks.torch as htorch
 from habana_frameworks.torch.hpu import wrap_in_hpu_graph
-from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi
 from optimum.habana.utils import HabanaProfile
 from optimum.habana.transformers.generation import MODELS_OPTIMIZED_WITH_STATIC_SHAPES
 from optimum.habana.checkpoint_utils import (
@@ -572,8 +571,6 @@ class CausalLM(Model):
         revision: Optional[str] = None,
         dtype: Optional[torch.dtype] = None,
     ):
-        adapt_transformers_to_gaudi()
-
         # Create tokenizer
         tokenizer = AutoTokenizer.from_pretrained(
             model_id,
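
For context, below is a minimal sketch of get_model after this change, assembled from the hunks above. It is an illustration, not the full function: the real implementation returns a Model and goes on to dispatch on model_type to classes such as CausalLM, BLOOM, and SantaCoder, which are omitted here to keep the sketch self-contained.

from typing import Optional

import torch
from transformers import AutoConfig
from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi

# Disable gradients, as in the first hunk above.
torch.set_grad_enabled(False)


def get_model(
    model_id: str,
    revision: Optional[str],
    dtype: Optional[torch.dtype] = None,
):
    # After this commit, the Gaudi patching runs once here, before
    # AutoConfig.from_pretrained and before any model object is built,
    # instead of at the top of CausalLM.__init__.
    adapt_transformers_to_gaudi()

    config = AutoConfig.from_pretrained(model_id, revision=revision)
    model_type = config.model_type
    # ... dispatch on model_type to the appropriate model class ...

The practical effect is that the transformers library is adapted for Gaudi before any config or model is loaded, so every code path through get_model sees the patched implementations, not only the one that constructs a CausalLM.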