Mirror of https://github.com/huggingface/text-generation-inference.git (last synced 2025-06-16 22:32:07 +00:00)
Only apply tensor parallelism (TP) to language_model (#219)
Signed-off-by: yuanwu <yuan.wu@intel.com>
This commit is contained in:
parent: 73d93bdd93
commit: 2299b739fe
@ -46,6 +46,7 @@ import habana_frameworks.torch as htorch
|
|||||||
from optimum.habana.utils import HabanaProfile
|
from optimum.habana.utils import HabanaProfile
|
||||||
from optimum.habana.transformers.generation import MODELS_OPTIMIZED_WITH_STATIC_SHAPES
|
from optimum.habana.transformers.generation import MODELS_OPTIMIZED_WITH_STATIC_SHAPES
|
||||||
from optimum.habana.utils import get_hpu_memory_stats
|
from optimum.habana.utils import get_hpu_memory_stats
|
||||||
|
from optimum.habana.checkpoint_utils import get_ds_injection_policy
|
||||||
|
|
||||||
from transformers import (
|
from transformers import (
|
||||||
AutoTokenizer,
|
AutoTokenizer,
|
||||||
@ -177,7 +178,6 @@ def load_data_uri(image_uri: str) -> Image.Image:
|
|||||||
image = Image.open(BytesIO(content))
|
image = Image.open(BytesIO(content))
|
||||||
return image
|
return image
|
||||||
|
|
||||||
|
|
||||||
class VlmCausalLMBatch(CausalLMBatch):
|
class VlmCausalLMBatch(CausalLMBatch):
|
||||||
pixel_values: Optional[List[torch.Tensor]]
|
pixel_values: Optional[List[torch.Tensor]]
|
||||||
pixel_attention_mask: Optional[List[torch.Tensor]]
|
pixel_attention_mask: Optional[List[torch.Tensor]]
|
||||||
@ -682,6 +682,7 @@ class VlmCausalLM(Model):
|
|||||||
ds_inference_kwargs = {"dtype": dtype}
|
ds_inference_kwargs = {"dtype": dtype}
|
||||||
ds_inference_kwargs["tensor_parallel"] = {"tp_size": world_size}
|
ds_inference_kwargs["tensor_parallel"] = {"tp_size": world_size}
|
||||||
ds_inference_kwargs["enable_cuda_graph"] = False
|
ds_inference_kwargs["enable_cuda_graph"] = False
|
||||||
|
ds_inference_kwargs["injection_policy"] = get_ds_injection_policy(model.language_model.config)
|
||||||
|
|
||||||
if load_to_meta:
|
if load_to_meta:
|
||||||
# model loaded to meta is managed differently
|
# model loaded to meta is managed differently
|
||||||
|
Loading…
Reference in New Issue
Block a user