mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-04-24 00:12:08 +00:00
Only Apply the TP in language_model (#219)
Signed-off-by: yuanwu <yuan.wu@intel.com>
This commit is contained in:
parent
73d93bdd93
commit
2299b739fe
@ -46,6 +46,7 @@ import habana_frameworks.torch as htorch
|
||||
from optimum.habana.utils import HabanaProfile
|
||||
from optimum.habana.transformers.generation import MODELS_OPTIMIZED_WITH_STATIC_SHAPES
|
||||
from optimum.habana.utils import get_hpu_memory_stats
|
||||
from optimum.habana.checkpoint_utils import get_ds_injection_policy
|
||||
|
||||
from transformers import (
|
||||
AutoTokenizer,
|
||||
@ -177,7 +178,6 @@ def load_data_uri(image_uri: str) -> Image.Image:
|
||||
image = Image.open(BytesIO(content))
|
||||
return image
|
||||
|
||||
|
||||
class VlmCausalLMBatch(CausalLMBatch):
|
||||
pixel_values: Optional[List[torch.Tensor]]
|
||||
pixel_attention_mask: Optional[List[torch.Tensor]]
|
||||
@ -682,6 +682,7 @@ class VlmCausalLM(Model):
|
||||
ds_inference_kwargs = {"dtype": dtype}
|
||||
ds_inference_kwargs["tensor_parallel"] = {"tp_size": world_size}
|
||||
ds_inference_kwargs["enable_cuda_graph"] = False
|
||||
ds_inference_kwargs["injection_policy"] = get_ds_injection_policy(model.language_model.config)
|
||||
|
||||
if load_to_meta:
|
||||
# model loaded to meta is managed differently
|
||||
|
Loading…
Reference in New Issue
Block a user