diff --git a/server/text_generation_server/models/vlm_causal_lm.py b/server/text_generation_server/models/vlm_causal_lm.py
index 632f38cb..93cb179a 100644
--- a/server/text_generation_server/models/vlm_causal_lm.py
+++ b/server/text_generation_server/models/vlm_causal_lm.py
@@ -46,6 +46,7 @@ import habana_frameworks.torch as htorch
 from optimum.habana.utils import HabanaProfile
 from optimum.habana.transformers.generation import MODELS_OPTIMIZED_WITH_STATIC_SHAPES
 from optimum.habana.utils import get_hpu_memory_stats
+from optimum.habana.checkpoint_utils import get_ds_injection_policy
 
 from transformers import (
     AutoTokenizer,
@@ -177,7 +178,6 @@ def load_data_uri(image_uri: str) -> Image.Image:
     image = Image.open(BytesIO(content))
     return image
 
-
 class VlmCausalLMBatch(CausalLMBatch):
     pixel_values: Optional[List[torch.Tensor]]
     pixel_attention_mask: Optional[List[torch.Tensor]]
@@ -682,6 +682,7 @@ class VlmCausalLM(Model):
             ds_inference_kwargs = {"dtype": dtype}
             ds_inference_kwargs["tensor_parallel"] = {"tp_size": world_size}
             ds_inference_kwargs["enable_cuda_graph"] = False
+            ds_inference_kwargs["injection_policy"] = get_ds_injection_policy(model.language_model.config)
 
             if load_to_meta:
                 # model loaded to meta is managed differently
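
For context, not part of the patch: a minimal standalone sketch of how the kwargs assembled above reach DeepSpeed. It assumes Habana's DeepSpeed fork is installed; the llava-hf/llava-1.5-7b-hf checkpoint and the world_size value are purely illustrative. The point of the change is that the injection policy is derived from the nested language model's config rather than the top-level multimodal config.

import torch
import deepspeed
from transformers import AutoModelForVision2Seq
from optimum.habana.checkpoint_utils import get_ds_injection_policy

dtype = torch.bfloat16
world_size = 8  # illustrative TP degree; the real server gets this from the launcher

# Illustrative checkpoint; any VLM exposing a `language_model` submodule
# is handled the same way.
model = AutoModelForVision2Seq.from_pretrained(
    "llava-hf/llava-1.5-7b-hf", torch_dtype=dtype
)

ds_inference_kwargs = {"dtype": dtype}
ds_inference_kwargs["tensor_parallel"] = {"tp_size": world_size}
ds_inference_kwargs["enable_cuda_graph"] = False
# Derive the policy from the nested text decoder's config so DeepSpeed's
# kernel injection targets the decoder layers of the language model.
ds_inference_kwargs["injection_policy"] = get_ds_injection_policy(
    model.language_model.config
)

# init_inference wraps the model in an InferenceEngine; .module unwraps it.
model = deepspeed.init_inference(model, **ds_inference_kwargs).module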