Mirror of https://github.com/huggingface/text-generation-inference.git

Enable Llama4 for Gaudi backend (#3223)

Signed-off-by: yuanwu <yuan.wu@intel.com>

parent 7e531f413d
commit 18cbecfb38
@@ -16,9 +16,6 @@ import enum
from text_generation_server.utils.speculate import get_speculate, set_speculate
from text_generation_server.models.model import Model
from text_generation_server.models.causal_lm import CausalLM
from text_generation_server.models.bloom import BLOOM
from text_generation_server.models.starcoder import StarCoder
from text_generation_server.models.custom_modeling.flash_phi_moe_modeling import (
    PhiMoEConfig,
)
@@ -32,7 +29,6 @@ from text_generation_server.utils.adapter import (
from text_generation_server.adapters.lora import LoraWeights

from text_generation_server.utils.log import log_master
from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi

__all__ = [
    "Model",
@@ -42,6 +38,7 @@ __all__ = [
]
from text_generation_server.models.globals import ATTENTION

VLM_BATCH_TYPES = set()
FLASH_ATT_ERROR_MESSAGE = "{} requires Flash Attention enabled models."

FLASH_ATTENTION = False
@@ -63,6 +60,9 @@ try:
    from text_generation_server.models.custom_modeling.flash_llama_modeling import (
        FlashLlamaForCausalLM,
    )
    from text_generation_server.models.custom_modeling.flash_llama4_modeling import (
        Llama4ForConditionalGeneration,
    )
    from text_generation_server.models.custom_modeling.flash_cohere_modeling import (
        FlashCohereForCausalLM,
    )
@@ -140,10 +140,24 @@ except ImportError as e:
    log_master(logger.warning, f"Could not import Flash Attention enabled models: {e}")
    SUPPORTS_WINDOWING = False
    FLASH_ATTENTION = False
    VLM_BATCH_TYPES = set()

if FLASH_ATTENTION:
    __all__.append(FlashCausalLM)

    from text_generation_server.models.flash_vlm_causal_lm import (
        FlashVlmCausalLMBatch,
    )

    VLM_BATCH_TYPES = {
        PaliGemmaBatch,
        FlashVlmCausalLMBatch,
        FlashMllamaCausalLMBatch,
    }

    __all__.append(VLM_BATCH_TYPES)


class ModelType(enum.Enum):
    DEEPSEEK_V2 = {
@@ -179,6 +193,11 @@ class ModelType(enum.Enum):
        "name": "Llama",
        "url": "https://huggingface.co/collections/meta-llama/llama-31-669fc079a0c406a149a5738f",
    }
    LLAMA4 = {
        "type": "llama4",
        "name": "Llama4",
        "url": "https://huggingface.co/collections/meta-llama/llama-31-669fc079a0c406a149a5738f",
    }
    PHI3 = {
        "type": "phi3",
        "name": "Phi 3",
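For context, a minimal sketch (not the actual TGI resolution code; the enum values and lookup helper below are illustrative) of how a config's model_type string such as "llama4" can be matched against ModelType entries like the ones above via their "type" field:

import enum

class ModelType(enum.Enum):
    # Illustrative subset of the entries shown in the hunk above.
    LLAMA = {"type": "llama", "name": "Llama"}
    LLAMA4 = {"type": "llama4", "name": "Llama4"}

def resolve(model_type: str):
    # Find the enum member whose "type" field matches the config string.
    for member in ModelType:
        if member.value["type"] == model_type:
            return member
    return None

assert resolve("llama4") is ModelType.LLAMA4
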
@@ -589,6 +608,19 @@ def get_model(
            trust_remote_code=trust_remote_code,
            lora_adapter_ids=lora_adapter_ids,
        )
    elif model_type == LLAMA4:
        print(f"Llama4 model detected: {model_id}")
        return FlashVlmCausalLM(
            model_id=model_id,
            model_class=Llama4ForConditionalGeneration,
            revision=revision,
            quantize=quantize,
            speculator=speculator,
            dtype=dtype,
            default_dtype=torch.bfloat16,
            trust_remote_code=trust_remote_code,
            lora_adapter_ids=lora_adapter_ids,
        )
    elif model_type == BAICHUAN:
        return FlashCausalLM(
            model_id=model_id,
@@ -823,6 +855,7 @@ def get_model(
        trust_remote_code=trust_remote_code,
    )

    from text_generation_server.models.causal_lm import CausalLM
    from text_generation_server.models.vlm_causal_lm import VlmCausalLM
    from text_generation_server.models.custom_modeling.mllama import (
        MllamaForConditionalGeneration,
@@ -830,13 +863,24 @@ def get_model(
    from text_generation_server.models.custom_modeling.llava_next import (
        LlavaNextForConditionalGeneration,
    )
    from text_generation_server.models.vlm_causal_lm import (
        VlmCausalLMBatch,
    )

    VLM_BATCH_TYPES.add(VlmCausalLMBatch)

    from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi

    adapt_transformers_to_gaudi()
    if SDP_ON_BF16 == 1:
        torch._C._set_math_sdp_allow_fp16_bf16_reduction(True)
    if model_type == "gpt_bigcode":
        from text_generation_server.models.starcoder import StarCoder

        return StarCoder(model_id=model_id, revision=revision, dtype=dtype)
    if model_type == "bloom":
        from text_generation_server.models.bloom import BLOOM

        return BLOOM(
            model_id=model_id,
            revision=revision,
File diff suppressed because it is too large
@@ -37,6 +37,33 @@ IDEFICS3_FAKE_IMAGE_TOKEN = "<fake_token_around_image>"
IDEFICS3_GLOBAL_IMG_TOKEN = "<global-img>"


def prompt_split_image_llama4(aspect_ratio, num_patches_per_chunk):
    """
    Create a structured string representation of image tokens

    Args:
        num_patches: Number of patches in the image

    Returns:
        String with appropriate image tokens
    """
    img_string = "<|image_start|>"
    ratio_h, ratio_w = aspect_ratio
    if ratio_h * ratio_w > 1:
        for yy in range(ratio_h):
            for xx in range(ratio_w):
                img_string += "<|patch|>" * num_patches_per_chunk
                if xx < ratio_w - 1:
                    img_string += "<|tile_x_separator|>"

            img_string += "<|tile_y_separator|>"
    img_string += "<|image|>"
    img_string += "<|patch|>" * num_patches_per_chunk
    img_string += "<|image_end|>"

    return img_string


# copied from: https://github.com/huggingface/transformers/blob/02ed609285c2448b3b54c31e362f2c389fa952ab/src/transformers/models/idefics3/processing_idefics3.py#L44-L60
def _prompt_split_image(
    *,
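For reference, a small standalone sketch (a hypothetical copy of the function added above, since only the diff is shown here) that prints the token layout produced for a 1x2 tiled image with 2 patches per chunk:

# Illustrative only: mirrors prompt_split_image_llama4 from the hunk above.
def prompt_split_image_llama4(aspect_ratio, num_patches_per_chunk):
    img_string = "<|image_start|>"
    ratio_h, ratio_w = aspect_ratio
    if ratio_h * ratio_w > 1:
        for yy in range(ratio_h):
            for xx in range(ratio_w):
                img_string += "<|patch|>" * num_patches_per_chunk
                if xx < ratio_w - 1:
                    img_string += "<|tile_x_separator|>"
            img_string += "<|tile_y_separator|>"
    img_string += "<|image|>"
    img_string += "<|patch|>" * num_patches_per_chunk
    img_string += "<|image_end|>"
    return img_string

# Prints (shown wrapped here):
# <|image_start|><|patch|><|patch|><|tile_x_separator|><|patch|><|patch|>
# <|tile_y_separator|><|image|><|patch|><|patch|><|image_end|>
print(prompt_split_image_llama4((1, 2), 2))
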
@@ -142,6 +169,23 @@ def image_text_replacement(processor, image_input, config, image_id: int) -> str
        num_pads = 256
        padding = "<image_soft_token>" * num_pads
        return f"\n\n<start_of_image>{padding}<end_of_image>\n\n"
    elif config.model_type == "llama4":
        patch_size = config.vision_config.patch_size
        pixel_shuffle_ratio = config.vision_config.pixel_shuffle_ratio
        downsample_ratio = int(round(1.0 / (pixel_shuffle_ratio**2)))
        aspect_ratios = image_input["aspect_ratios"][image_id]
        image_height, image_width = image_input["pixel_values"][image_id].shape[-2:]

        num_patches_per_chunk = int(
            (image_height // patch_size)
            * (image_width // patch_size)
            // downsample_ratio
        )
        tokens_for_this_image = prompt_split_image_llama4(
            aspect_ratios, num_patches_per_chunk
        )

        return tokens_for_this_image
    else:
        raise RuntimeError(f"Unknown config {config.model_type} for multimodal")
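To illustrate the arithmetic in the llama4 branch above, a short sketch with assumed example values (patch_size=14, pixel_shuffle_ratio=0.5, and a 336x336 pixel chunk; these are illustrative, not read from a real checkpoint config):

# Assumed example values, not taken from an actual Llama4 config.
patch_size = 14
pixel_shuffle_ratio = 0.5
image_height = image_width = 336

downsample_ratio = int(round(1.0 / (pixel_shuffle_ratio**2)))  # -> 4
num_patches_per_chunk = int(
    (image_height // patch_size) * (image_width // patch_size) // downsample_ratio
)
print(downsample_ratio, num_patches_per_chunk)  # 4 144
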
@@ -260,6 +304,8 @@ class FlashVlmCausalLMBatch(FlashCausalLMBatch):
                    images.append(image)
                elif config.model_type == "gemma3":
                    images.append(image)
                elif config.model_type == "llama4":
                    images.append(image)
                else:
                    images.append([image])
            else:
@@ -23,26 +23,8 @@ from text_generation_server.models.globals import set_adapter_to_index
from text_generation_server.utils.adapter import AdapterInfo
from text_generation_server.utils.tokens import make_tokenizer_optional
from text_generation_server.utils.prefill_chunking import set_max_prefill_tokens
from text_generation_server.models import VLM_BATCH_TYPES

try:
    from text_generation_server.models.pali_gemma import PaliGemmaBatch
    from text_generation_server.models.mllama_causal_lm import FlashMllamaCausalLMBatch
    from text_generation_server.models.vlm_causal_lm import (
        VlmCausalLMBatch,
    )
    from text_generation_server.models.flash_vlm_causal_lm import (
        FlashVlmCausalLMBatch,
    )

    VLM_BATCH_TYPES = {
        PaliGemmaBatch,
        VlmCausalLMBatch,
        FlashVlmCausalLMBatch,
        FlashMllamaCausalLMBatch,
    }
except (ImportError, NotImplementedError):
    # These imports can fail on CPU/Non flash.
    VLM_BATCH_TYPES = set()
from text_generation_server.utils.version import (
    is_driver_compatible,
    MIN_TGI_GAUDI_SYNAPSE_VERSION,
@@ -1,5 +1,30 @@
from optimum.habana.utils import get_driver_version
from packaging.version import Version
from packaging import version
import subprocess


def get_driver_version():
    """
    Returns the driver version.
    """
    # Enable console printing for `hl-smi` check
    output = subprocess.run(
        "hl-smi",
        shell=True,
        text=True,
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
        env={"ENABLE_CONSOLE": "true"},
    )
    if output.returncode == 0 and output.stdout:
        return version.parse(
            output.stdout.split("\n")[2]
            .replace(" ", "")
            .split(":")[1][:-1]
            .split("-")[0]
        )
    return None


MIN_TGI_GAUDI_SYNAPSE_VERSION = Version("1.19.0")
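As a rough illustration of the string parsing above, the following sketch runs the same chain on a made-up hl-smi header; the real hl-smi output format may differ, so treat the sample text as an assumption:

from packaging import version

# Made-up sample output; only the third line is parsed by the chain above.
fake_stdout = (
    "+------------------------------------------------+\n"
    "| HL-SMI Version   hl-1.19.0                      |\n"
    "| Driver Version:  1.19.1-abc123                  |\n"
)
parsed = version.parse(
    fake_stdout.split("\n")[2]
    .replace(" ", "")
    .split(":")[1][:-1]
    .split("-")[0]
)
print(parsed)  # 1.19.1
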
@@ -303,7 +303,7 @@ class Weights:
        world_size = self.process_group.size()
        rank = self.process_group.rank()

        tensors = []
        tensors_slices = []
        block_offset = 0
        for block_size in block_sizes:
            assert (
@@ -312,15 +312,18 @@
            shard_block_size = block_size // world_size
            start = rank * shard_block_size
            stop = (rank + 1) * shard_block_size
            if dim == 0:
                tensor = slice_[block_offset + start : block_offset + stop]
            elif dim == 1:
                tensor = slice_[:, block_offset + start : block_offset + stop]
            else:
                raise NotImplementedError("Currently only dim=0 or dim=1 is supported")
            tensors.append(tensor)
            tensors_slices += range(block_offset + start, block_offset + stop)
            block_offset += block_size
        tensor = torch.cat(tensors, dim=dim)

        if dim == 0:
            tensor = slice_[tensors_slices, ...]
        elif dim == 1 or dim == -2:
            tensor = slice_[:, tensors_slices, ...]
        elif dim == 2 or dim == -1:
            tensor = slice_[..., tensors_slices]
        else:
            raise ValueError(f"Unsupported dim {dim}, only dim 0, 1 or 2 are supported")

        tensor = tensor.to(device=self.device)

        # Avoid casting quantizer dtypes.
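The hunk above replaces per-block slicing plus torch.cat with a single fancy-indexing pass over collected indices. A minimal sketch (toy tensor and hypothetical block sizes, not TGI code) showing that the two strategies produce the same shard:

import torch

weight = torch.arange(12).reshape(12, 1)  # stand-in for slice_ with dim=0
block_sizes = [8, 4]                      # hypothetical fused block layout
world_size, rank = 2, 0
dim = 0

# Old approach: slice each block's shard, then concatenate.
tensors, block_offset = [], 0
for block_size in block_sizes:
    shard_block_size = block_size // world_size
    start = rank * shard_block_size
    stop = (rank + 1) * shard_block_size
    tensors.append(weight[block_offset + start : block_offset + stop])
    block_offset += block_size
old = torch.cat(tensors, dim=dim)

# New approach: collect all shard indices first, then index once.
tensors_slices, block_offset = [], 0
for block_size in block_sizes:
    shard_block_size = block_size // world_size
    start = rank * shard_block_size
    stop = (rank + 1) * shard_block_size
    tensors_slices += range(block_offset + start, block_offset + stop)
    block_offset += block_size
new = weight[tensors_slices, ...]

assert torch.equal(old, new)
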
@@ -7,5 +7,13 @@ if [[ "$*" == *"--sharded true"* ]]; then
    echo 'setting PT_HPU_ENABLE_LAZY_COLLECTIVES=1 for sharding'
    export PT_HPU_ENABLE_LAZY_COLLECTIVES=1
fi
# Check if ATTENTION environment variable is set to paged
if [[ "$ATTENTION" == "paged" ]]; then
    # Check if Llama-4 is in the command line arguments
    if [[ "$*" == *"Llama-4"* ]]; then
        echo 'ATTENTION=paged and Llama-4 detected'
        pip install git+https://github.com/huggingface/transformers.git@29338949
    fi
fi

text-generation-launcher $@