Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-09 19:34:53 +00:00)
Enable llama4
Signed-off-by: yuanwu <yuan.wu@intel.com>
parent 39cfe232fd
commit 3482d7ca82
@@ -122,5 +122,5 @@ ENV OMPI_MCA_btl_vader_single_copy_mechanism NONE
 COPY backends/gaudi/tgi-entrypoint.sh /tgi-entrypoint.sh
 RUN chmod +x /tgi-entrypoint.sh
 
-ENTRYPOINT ["/tgi-entrypoint.sh"]
-CMD ["--json-output"]
+#ENTRYPOINT ["/tgi-entrypoint.sh"]
+#CMD ["--json-output"]
@@ -8,7 +8,7 @@ PYTORCH_VERSION := 2.6.0
 .PHONY: image run-local-dev-container install-dependencies install-server install-router install-launcher local-dev-install
 
 image:
-	docker build -t tgi-gaudi -f ${root_dir}/Dockerfile_gaudi ${root_dir} --build-arg HABANA_VERSION=$(HABANA_VERSION) --build-arg PYTORCH_VERSION=$(PYTORCH_VERSION)
+	docker build -t tgi-gaudi -f ${root_dir}/Dockerfile_gaudi ${root_dir} --build-arg HABANA_VERSION=$(HABANA_VERSION) --build-arg PYTORCH_VERSION=$(PYTORCH_VERSION) --build-arg http_proxy=${http_proxy} --build-arg https_proxy=${https_proxy} --build-arg no_proxy=${no_proxy}
 
 run-local-dev-container:
 	docker run -it \
@@ -57,7 +57,7 @@ def serve(
         ), "MASTER_PORT must be set when sharded is True"
 
     # Remove default handler
-    logger.remove()
+    #logger.remove()
     logger.add(
         sys.stdout,
         format="{message}",
@@ -193,7 +193,7 @@ def download_weights(
     merge_lora: bool = False,
 ):
     # Remove default handler
-    logger.remove()
+    #logger.remove()
     logger.add(
         sys.stdout,
         format="{message}",
@@ -25,6 +25,7 @@ class FastLinear(torch.nn.Module):
         return cls(weight, bias)
 
     def forward(self, input: torch.Tensor) -> torch.Tensor:
+        print(f"input.shape={input.shape}, self.weight={self.weight.shape}")
         return F.linear(input, self.weight, self.bias)
 
 
@@ -16,9 +16,9 @@ import enum
 
 from text_generation_server.utils.speculate import get_speculate, set_speculate
 from text_generation_server.models.model import Model
-from text_generation_server.models.causal_lm import CausalLM
-from text_generation_server.models.bloom import BLOOM
-from text_generation_server.models.starcoder import StarCoder
+#from text_generation_server.models.causal_lm import CausalLM
+#from text_generation_server.models.bloom import BLOOM
+#from text_generation_server.models.starcoder import StarCoder
 from text_generation_server.models.custom_modeling.flash_phi_moe_modeling import (
     PhiMoEConfig,
 )
@@ -32,7 +32,7 @@ from text_generation_server.utils.adapter import (
 from text_generation_server.adapters.lora import LoraWeights
 
 from text_generation_server.utils.log import log_master
-from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi
+#from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi
 
 __all__ = [
     "Model",
@@ -47,7 +47,7 @@ FLASH_ATT_ERROR_MESSAGE = "{} requires Flash Attention enabled models."
 FLASH_ATTENTION = False
 if ATTENTION == "paged":
     FLASH_ATTENTION = True
+print(f"Flash Attention enabled models: {FLASH_ATTENTION}")
 try:
     from text_generation_server.models.flash_causal_lm import FlashCausalLM
     from text_generation_server.models.flash_vlm_causal_lm import FlashVlmCausalLM
@@ -63,6 +63,9 @@ try:
     from text_generation_server.models.custom_modeling.flash_llama_modeling import (
         FlashLlamaForCausalLM,
     )
+    from text_generation_server.models.custom_modeling.flash_llama4_modeling import (
+        Llama4ForConditionalGeneration,
+    )
     from text_generation_server.models.custom_modeling.flash_cohere_modeling import (
         FlashCohereForCausalLM,
     )
@@ -179,6 +182,11 @@ class ModelType(enum.Enum):
         "name": "Llama",
         "url": "https://huggingface.co/collections/meta-llama/llama-31-669fc079a0c406a149a5738f",
     }
+    LLAMA4 = {
+        "type": "llama4",
+        "name": "Llama4",
+        "url": "https://huggingface.co/collections/meta-llama/llama-31-669fc079a0c406a149a5738f",
+    }
     PHI3 = {
         "type": "phi3",
         "name": "Phi 3",
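
For orientation, the "type" string in this table is what gets compared against the model_type field of a checkpoint's config. A hedged sketch, not the module's own resolution logic (file name and lookup are assumptions):

import json

# Read the model_type string that the dispatch in get_model compares against.
with open("config.json") as f:          # path is illustrative
    model_type = json.load(f).get("model_type")
# For a Llama 4 checkpoint this would be "llama4", matching the LLAMA4 entry above.
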
@@ -451,7 +459,9 @@ def get_model(
 
     kv_cache_dtype = dtype
 
+    print(f"Model type: {model_type}")
     if FLASH_ATTENTION:
+        print(f"Flash Attention enabled models: {model_type}")
         if model_type == DEEPSEEK_V2:
             head_size = max(
                 config_dict.get("qk_nope_dim", 128)
@@ -589,6 +599,19 @@ def get_model(
                 trust_remote_code=trust_remote_code,
                 lora_adapter_ids=lora_adapter_ids,
             )
+        elif model_type == LLAMA4:
+            print(f"Llama4 model detected: {model_id}")
+            return FlashVlmCausalLM(
+                model_id=model_id,
+                model_class=Llama4ForConditionalGeneration,
+                revision=revision,
+                quantize=quantize,
+                speculator=speculator,
+                dtype=dtype,
+                default_dtype=torch.bfloat16,
+                trust_remote_code=trust_remote_code,
+                lora_adapter_ids=lora_adapter_ids,
+            )
         elif model_type == BAICHUAN:
             return FlashCausalLM(
                 model_id=model_id,
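
Once a Llama 4 checkpoint is routed to FlashVlmCausalLM, it is served through the regular TGI HTTP API. A minimal client sketch, assuming a server already listening on localhost:8080 (port, prompt and parameters are illustrative, not part of the commit):

import requests

# Hedged sketch: call the standard TGI /generate endpoint of a running server.
resp = requests.post(
    "http://localhost:8080/generate",
    json={"inputs": "What is deep learning?", "parameters": {"max_new_tokens": 32}},
    timeout=60,
)
print(resp.json()["generated_text"])
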
@@ -823,6 +846,7 @@ def get_model(
             trust_remote_code=trust_remote_code,
         )
 
+    from text_generation_server.models.causal_lm import CausalLM
     from text_generation_server.models.vlm_causal_lm import VlmCausalLM
     from text_generation_server.models.custom_modeling.mllama import (
         MllamaForConditionalGeneration,
@@ -831,12 +855,15 @@ def get_model(
         LlavaNextForConditionalGeneration,
     )
 
+    from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi
     adapt_transformers_to_gaudi()
     if SDP_ON_BF16 == 1:
         torch._C._set_math_sdp_allow_fp16_bf16_reduction(True)
     if model_type == "gpt_bigcode":
+        from text_generation_server.models.starcoder import StarCoder
         return StarCoder(model_id=model_id, revision=revision, dtype=dtype)
     if model_type == "bloom":
+        from text_generation_server.models.bloom import BLOOM
         return BLOOM(
             model_id=model_id,
             revision=revision,
(File diff suppressed because it is too large.)
@@ -34,6 +34,33 @@ IDEFICS3_FAKE_IMAGE_TOKEN = "<fake_token_around_image>"
 IDEFICS3_GLOBAL_IMG_TOKEN = "<global-img>"
 
 
+def prompt_split_image_llama4(aspect_ratio, num_patches_per_chunk):
+    """
+    Create a structured string representation of image tokens
+
+    Args:
+        num_patches: Number of patches in the image
+
+    Returns:
+        String with appropriate image tokens
+    """
+    img_string = "<|image_start|>"
+    ratio_h, ratio_w = aspect_ratio
+    if ratio_h * ratio_w > 1:
+        for yy in range(ratio_h):
+            for xx in range(ratio_w):
+                img_string += "<|patch|>" * num_patches_per_chunk
+                if xx < ratio_w - 1:
+                    img_string += "<|tile_x_separator|>"
+
+            img_string += "<|tile_y_separator|>"
+    img_string += "<|image|>"
+    img_string += "<|patch|>" * num_patches_per_chunk
+    img_string += "<|image_end|>"
+
+    return img_string
+
+
 # copied from: https://github.com/huggingface/transformers/blob/02ed609285c2448b3b54c31e362f2c389fa952ab/src/transformers/models/idefics3/processing_idefics3.py#L44-L60
 def _prompt_split_image(
     *,
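
To illustrate the token layout this helper produces, a small check with made-up numbers (a 2 x 2 tiling with 4 patch tokens per chunk; real values come from the vision config):

# Hedged example, not part of the commit: 2 x 2 tiles, 4 <|patch|> tokens per chunk.
s = prompt_split_image_llama4((2, 2), 4)
assert s == (
    "<|image_start|>"
    + ("<|patch|>" * 4 + "<|tile_x_separator|>" + "<|patch|>" * 4 + "<|tile_y_separator|>") * 2
    + "<|image|>"
    + "<|patch|>" * 4
    + "<|image_end|>"
)
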
@@ -139,6 +166,23 @@ def image_text_replacement(processor, image_input, config, image_id: int) -> str
         num_pads = 256
         padding = "<image_soft_token>" * num_pads
         return f"\n\n<start_of_image>{padding}<end_of_image>\n\n"
+    elif config.model_type == "llama4":
+        patch_size = config.vision_config.patch_size
+        pixel_shuffle_ratio = config.vision_config.pixel_shuffle_ratio
+        downsample_ratio = int(round(1.0 / (pixel_shuffle_ratio**2)))
+        aspect_ratios = image_input["aspect_ratios"][image_id]
+        image_height, image_width = image_input["pixel_values"][image_id].shape[-2:]
+
+        num_patches_per_chunk = int(
+            (image_height // patch_size)
+            * (image_width // patch_size)
+            // downsample_ratio
+        )
+        tokens_for_this_image = prompt_split_image_llama4(
+            aspect_ratios, num_patches_per_chunk
+        )
+
+        return tokens_for_this_image
     else:
         raise RuntimeError(f"Unknown config {config.model_type} for multimodal")
 
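
To make the patch arithmetic concrete, here is the same computation with illustrative numbers (a hypothetical 336 x 336 image, patch_size 14, pixel_shuffle_ratio 0.5; real Llama 4 configs may use different values):

patch_size = 14                # assumed vision_config.patch_size
pixel_shuffle_ratio = 0.5      # assumed vision_config.pixel_shuffle_ratio
downsample_ratio = int(round(1.0 / (pixel_shuffle_ratio**2)))  # 1 / 0.25 -> 4
image_height = image_width = 336
num_patches_per_chunk = (
    (image_height // patch_size) * (image_width // patch_size) // downsample_ratio
)
# (336 // 14) * (336 // 14) // 4 = 24 * 24 // 4 = 144 <|patch|> tokens per chunk
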
@@ -257,6 +301,8 @@ class FlashVlmCausalLMBatch(FlashCausalLMBatch):
                         images.append(image)
                     elif config.model_type == "gemma3":
                         images.append(image)
+                    elif config.model_type == "llama4":
+                        images.append(image)
                     else:
                         images.append([image])
                 else:
@@ -24,25 +24,25 @@ from text_generation_server.utils.adapter import AdapterInfo
 from text_generation_server.utils.tokens import make_tokenizer_optional
 from text_generation_server.utils.prefill_chunking import set_max_prefill_tokens
 
-try:
-    from text_generation_server.models.pali_gemma import PaliGemmaBatch
-    from text_generation_server.models.mllama_causal_lm import FlashMllamaCausalLMBatch
-    from text_generation_server.models.vlm_causal_lm import (
-        VlmCausalLMBatch,
-    )
-    from text_generation_server.models.flash_vlm_causal_lm import (
-        FlashVlmCausalLMBatch,
-    )
-
-    VLM_BATCH_TYPES = {
-        PaliGemmaBatch,
-        VlmCausalLMBatch,
-        FlashVlmCausalLMBatch,
-        FlashMllamaCausalLMBatch,
-    }
-except (ImportError, NotImplementedError):
-    # These imports can fail on CPU/Non flash.
-    VLM_BATCH_TYPES = set()
+#try:
+from text_generation_server.models.pali_gemma import PaliGemmaBatch
+from text_generation_server.models.mllama_causal_lm import FlashMllamaCausalLMBatch
+# from text_generation_server.models.vlm_causal_lm import (
+#     VlmCausalLMBatch,
+# )
+from text_generation_server.models.flash_vlm_causal_lm import (
+    FlashVlmCausalLMBatch,
+)
+
+VLM_BATCH_TYPES = {
+    PaliGemmaBatch,
+    FlashVlmCausalLMBatch,
+    FlashMllamaCausalLMBatch,
+}
+#except (ImportError, NotImplementedError):
+    # These imports can fail on CPU/Non flash.
+    # print(f"importError: {ImportError}")
+    # VLM_BATCH_TYPES = set()
 
 from text_generation_server.utils.version import (
     is_driver_compatible,
     MIN_TGI_GAUDI_SYNAPSE_VERSION,
@@ -110,6 +110,7 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
     async def Warmup(self, request, context):
         if ATTENTION == "paged":
             set_max_prefill_tokens(request.max_prefill_tokens)
+        print(f"VLM_BATCH_TYPES: {VLM_BATCH_TYPES}")
         if (
             self.model.batch_type in VLM_BATCH_TYPES
         ):  # Hack, i would rather use kwargs in the `from_pb` call
@@ -1,5 +1,6 @@
 import torch
 from loguru import logger
+from text_generation_server.utils.log import log_master
 
 
 def get_hpu_free_memory(device, memory_fraction):
@@ -7,7 +8,7 @@ def get_hpu_free_memory(device, memory_fraction):
 
     device_id = device.index
     mem_stats = memory_stats(device_id)
-    logger.info(f"mem_stats: {mem_stats}")
+    log_master(logger.debug, f"mem_stats: {mem_stats}")
     total_free_memory = mem_stats["Limit"] - mem_stats["MaxInUse"]
     free_memory = max(
         0, int(total_free_memory - (1 - memory_fraction) * mem_stats["Limit"])
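
A worked example of the free-memory formula above with round, made-up numbers (Limit 96 GiB, MaxInUse 16 GiB, memory_fraction 0.8; not real HPU statistics):

GiB = 1024**3
mem_stats = {"Limit": 96 * GiB, "MaxInUse": 16 * GiB}  # illustrative values only
memory_fraction = 0.8
total_free_memory = mem_stats["Limit"] - mem_stats["MaxInUse"]  # 80 GiB
free_memory = max(
    0, int(total_free_memory - (1 - memory_fraction) * mem_stats["Limit"])
)
# 80 GiB - 0.2 * 96 GiB = 60.8 GiB reported as usable
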
|
@ -1,5 +1,17 @@
|
|||||||
from optimum.habana.utils import get_driver_version
|
|
||||||
from packaging.version import Version
|
from packaging.version import Version
|
||||||
|
from packaging import version
|
||||||
|
import subprocess
|
||||||
|
def get_driver_version():
|
||||||
|
"""
|
||||||
|
Returns the driver version.
|
||||||
|
"""
|
||||||
|
# Enable console printing for `hl-smi` check
|
||||||
|
output = subprocess.run(
|
||||||
|
"hl-smi", shell=True, text=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, env={"ENABLE_CONSOLE": "true"}
|
||||||
|
)
|
||||||
|
if output.returncode == 0 and output.stdout:
|
||||||
|
return version.parse(output.stdout.split("\n")[2].replace(" ", "").split(":")[1][:-1].split("-")[0])
|
||||||
|
return None
|
||||||
|
|
||||||
MIN_TGI_GAUDI_SYNAPSE_VERSION = Version("1.19.0")
|
MIN_TGI_GAUDI_SYNAPSE_VERSION = Version("1.19.0")
|
||||||
|
|
||||||
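
A hedged usage sketch of the new helper together with the constant defined in this module; this is not the module's own compatibility logic (that lives in is_driver_compatible), just an illustration of how the parsed version compares:

# Sketch only: get_driver_version() returns a packaging Version or None when hl-smi is absent.
driver = get_driver_version()
if driver is None:
    print("hl-smi not available; skipping driver version check")
elif driver < MIN_TGI_GAUDI_SYNAPSE_VERSION:
    print(f"Driver {driver} is older than the supported minimum {MIN_TGI_GAUDI_SYNAPSE_VERSION}")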