From e07056ab3f0a8a6e748bcaf766508385fcd4a7fa Mon Sep 17 00:00:00 2001
From: Yuan Wu
Date: Fri, 13 Jun 2025 04:35:36 +0800
Subject: [PATCH] [Gaudi] Remove optimum-habana (#3261)

Signed-off-by: yuanwu
---
 Dockerfile_gaudi | 2 +-
 backends/gaudi/server/pyproject.toml | 5 +-
 backends/gaudi/server/requirements.txt | 5 +-
 .../server/text_generation_server/cli.py | 89 +-
 .../habana_quantization_env.py | 53 -
 .../text_generation_server/models/__init__.py | 73 +-
 .../text_generation_server/models/bloom.py | 52 -
 .../models/causal_lm.py | 1444 ---------------
 .../models/custom_modeling/llava_next.py | 467 -----
 .../models/custom_modeling/mllama.py | 292 ---
 .../models/custom_modeling/qwen2_5_vl.py | 3 +-
 .../models/galactica.py | 156 --
 .../text_generation_server/models/globals.py | 4 +-
 .../models/idefics_causal_lm.py | 882 ---------
 .../text_generation_server/models/mamba.py | 814 ---------
 .../models/starcoder.py | 47 -
 .../models/vlm_causal_lm.py | 1609 -----------------
 backends/gaudi/tgi-entrypoint.sh | 8 -
 launcher/src/env_runtime.rs | 4 -
 launcher/src/main.rs | 9 -
 20 files changed, 23 insertions(+), 5995 deletions(-)
 delete mode 100644 backends/gaudi/server/text_generation_server/habana_quantization_env.py
 delete mode 100644 backends/gaudi/server/text_generation_server/models/bloom.py
 delete mode 100644 backends/gaudi/server/text_generation_server/models/causal_lm.py
 delete mode 100644 backends/gaudi/server/text_generation_server/models/custom_modeling/llava_next.py
 delete mode 100644 backends/gaudi/server/text_generation_server/models/custom_modeling/mllama.py
 delete mode 100644 backends/gaudi/server/text_generation_server/models/galactica.py
 delete mode 100644 backends/gaudi/server/text_generation_server/models/idefics_causal_lm.py
 delete mode 100644 backends/gaudi/server/text_generation_server/models/mamba.py
 delete mode 100644 backends/gaudi/server/text_generation_server/models/starcoder.py
 delete mode 100644 backends/gaudi/server/text_generation_server/models/vlm_causal_lm.py

diff --git a/Dockerfile_gaudi b/Dockerfile_gaudi
index 442eb6b7..02885405 100644
--- a/Dockerfile_gaudi
+++ b/Dockerfile_gaudi
@@ -57,7 +57,7 @@ ARG PYTORCH_VERSION
 
 FROM vault.habana.ai/gaudi-docker/${HABANA_VERSION}/ubuntu22.04/habanalabs/pytorch-installer-${PYTORCH_VERSION}:latest AS base
 
-ENV ATTENTION=default
+ENV ATTENTION=paged
 ENV PREFIX_CACHING=0
 ENV PREFILL_CHUNKING=0
 ENV PT_HPU_LAZY_MODE=1
diff --git a/backends/gaudi/server/pyproject.toml b/backends/gaudi/server/pyproject.toml
index 3f2676cb..fa2c2697 100644
--- a/backends/gaudi/server/pyproject.toml
+++ b/backends/gaudi/server/pyproject.toml
@@ -22,10 +22,9 @@ opentelemetry-instrumentation-grpc = "^0.53b0"
 hf-transfer = "^0.1.9"
 sentencepiece = "^0.2.0"
 peft = "^0.15"
-optimum-habana = "1.17"
-transformers = "^4.49"
+transformers = "^4.52.4"
 numpy = "^1.26"
-accelerate = "^0.33"
+accelerate = "^1.7.0"
 outlines= { version = "^0.0.36", optional = true }
 prometheus-client = "^0.21.1"
 py-cpuinfo = "^9.0.0"
diff --git a/backends/gaudi/server/requirements.txt b/backends/gaudi/server/requirements.txt
index 6f897722..e6c9abf2 100644
--- a/backends/gaudi/server/requirements.txt
+++ b/backends/gaudi/server/requirements.txt
@@ -1,4 +1,4 @@
-accelerate==0.33.0 ; python_version >= "3.9" and python_version < "3.13"
+accelerate==1.7.0 ; python_version >= "3.9" and python_version < "3.13"
 annotated-types==0.7.0 ; python_version >= "3.9" and python_version < "3.13"
 attrs==25.3.0 ; python_version >= "3.9" and python_version < "3.13"
 certifi==2025.1.31 ; python_version >= "3.9" and python_version < "3.13"
@@ -46,7 +46,6 @@ opentelemetry-instrumentation==0.53b0 ; python_version >= "3.9" and python_versi
 opentelemetry-proto==1.32.0 ; python_version >= "3.9" and python_version < "3.13"
 opentelemetry-sdk==1.32.0 ; python_version >= "3.9" and python_version < "3.13"
 opentelemetry-semantic-conventions==0.53b0 ; python_version >= "3.9" and python_version < "3.13"
-optimum-habana==1.17.0 ; python_version >= "3.9" and python_version < "3.13"
 optimum==1.24.0 ; python_version >= "3.9" and python_version < "3.13"
 outlines==0.0.36 ; python_version >= "3.9" and python_version < "3.13"
 packaging==24.2 ; python_version >= "3.9" and python_version < "3.13"
@@ -76,7 +75,7 @@ sympy==1.13.1 ; python_version >= "3.9" and python_version < "3.13"
 threadpoolctl==3.6.0 ; python_version >= "3.9" and python_version < "3.13"
 tokenizers==0.21.1 ; python_version >= "3.9" and python_version < "3.13"
 tqdm==4.67.1 ; python_version >= "3.9" and python_version < "3.13"
-transformers==4.49.0 ; python_version >= "3.9" and python_version < "3.13"
+transformers==4.52.4 ; python_version >= "3.9" and python_version < "3.13"
 triton==3.2.0 ; python_version >= "3.9" and python_version < "3.13" and platform_system == "Linux" and platform_machine == "x86_64"
 typer==0.15.2 ; python_version >= "3.9" and python_version < "3.13"
 typing-extensions==4.13.2 ; python_version >= "3.9" and python_version < "3.13"
diff --git a/backends/gaudi/server/text_generation_server/cli.py b/backends/gaudi/server/text_generation_server/cli.py
index d4445a13..dc31ab2f 100644
--- a/backends/gaudi/server/text_generation_server/cli.py
+++ b/backends/gaudi/server/text_generation_server/cli.py
@@ -1,6 +1,4 @@
 import os
-import psutil
-import signal
 import sys
 
 import typer
@@ -115,80 +113,19 @@ def serve(
         raise RuntimeError(
             "Only 1 can be set between `dtype` and `quantize`, as they both decide how goes the final model."
) - - logger.info("CLI SHARDED = {} DTYPE = {}".format(sharded, dtype)) - - if sharded and os.getenv("ATTENTION", "default") not in {"paged"}: - tgi_file = Path(__file__).resolve().parent / "tgi_service.py" - num_shard = int(os.getenv("WORLD_SIZE", "1")) - logger.info("CLI SHARDED = {}".format(num_shard)) - import subprocess - - cmd = ( - f"deepspeed --num_nodes 1 --num_gpus {num_shard} --no_local_rank {tgi_file}" - ) - cmd += f" --model_id {model_id} --revision {revision} --sharded {sharded}" - cmd += f" --dtype {dtype} --trust_remote_code {trust_remote_code} --uds_path {uds_path}" - cmd += f" --quantize {quantize} --max_input_tokens {max_input_tokens}" - if speculate is not None: - cmd += f"--speculate {speculate}" - logger.info("CLI server start deepspeed ={} ".format(cmd)) - sys.stdout.flush() - sys.stderr.flush() - with subprocess.Popen(cmd, shell=True, executable="/bin/bash") as proc: - do_terminate = False - current_handler = signal.getsignal(signal.SIGTERM) - - def terminate_handler(sig, frame): - nonlocal do_terminate - do_terminate = True - if callable(current_handler): - current_handler(sig, frame) - - signal.signal(signal.SIGTERM, terminate_handler) - - finished = False - while not finished: - try: - if do_terminate: - parent = psutil.Process(proc.pid) - all_procs = parent.children(recursive=True) + [parent] - for p in all_procs: - try: - p.terminate() - except psutil.NoSuchProcess: - pass - _, alive = psutil.wait_procs(all_procs, timeout=30) - for p in alive: - p.kill() - - do_terminate = False - - proc.wait(timeout=3) - except subprocess.TimeoutExpired: - pass - else: - finished = True - - sys.stdout.flush() - sys.stderr.flush() - if proc.returncode != 0: - logger.error(f"{cmd} exited with status = {proc.returncode}") - return proc.returncode - else: - server.serve( - model_id, - lora_adapters, - revision, - sharded, - quantize, - speculate, - dtype, - kv_cache_dtype, - trust_remote_code, - uds_path, - max_input_tokens, - ) + server.serve( + model_id, + lora_adapters, + revision, + sharded, + quantize, + speculate, + dtype, + kv_cache_dtype, + trust_remote_code, + uds_path, + max_input_tokens, + ) @app.command() diff --git a/backends/gaudi/server/text_generation_server/habana_quantization_env.py b/backends/gaudi/server/text_generation_server/habana_quantization_env.py deleted file mode 100644 index b03b7e26..00000000 --- a/backends/gaudi/server/text_generation_server/habana_quantization_env.py +++ /dev/null @@ -1,53 +0,0 @@ -# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company. 
- -import os -import habana_frameworks.torch as htorch - -quant_config = os.getenv("QUANT_CONFIG", "") -is_quantization_enabled = quant_config != "" - -if is_quantization_enabled: - os.environ.setdefault("ENABLE_EXPERIMENTAL_FLAGS", "true") - os.environ.setdefault("USE_DEFAULT_QUANT_PARAM", "true") - os.environ.setdefault("UPDATE_GRAPH_OUTPUT_MME", "false") - os.environ.setdefault("ENABLE_CALC_DYNAMIC_RANGE", "false") - os.environ.setdefault("UPDATE_MME_OUTPUT_PRECISION_FILTER", "v_proj,matmul_av") - os.environ.setdefault("EXPERIMENTAL_WEIGHT_SHARING", "FALSE") - - -def patch_scoped_linear_all_reduce(model): - from deepspeed.module_inject.layers import LinearAllreduce - from optimum.habana.transformers.models.modeling_all_models import ( - ScopedLinearAllReduce, - ) - - for name, module in model.named_children(): - if type(module) is LinearAllreduce: - SL = ScopedLinearAllReduce(mod=module) - setattr(model, name, SL) - patch_scoped_linear_all_reduce(module) - - -def setup_quantization(model): - if is_quantization_enabled: - htorch.core.quantization._mark_params_as_const(model) - htorch.core.quantization._check_params_as_const(model) - htorch.core.hpu_initialize(model) - return model - - -def prepare_model_for_quantization(model): - if is_quantization_enabled: - if model.config.model_type in [ - "llama", - "falcon", - "qwen2", - "starcoder2", - "gemma", - ]: - patch_scoped_linear_all_reduce(model) - from neural_compressor.torch.quantization import FP8Config, convert - - config = FP8Config.from_json_file(quant_config) - model = convert(model, config) - return model diff --git a/backends/gaudi/server/text_generation_server/models/__init__.py b/backends/gaudi/server/text_generation_server/models/__init__.py index 18396e8d..c4943463 100644 --- a/backends/gaudi/server/text_generation_server/models/__init__.py +++ b/backends/gaudi/server/text_generation_server/models/__init__.py @@ -5,7 +5,6 @@ import os from loguru import logger from transformers.configuration_utils import PretrainedConfig -from transformers.models.auto import modeling_auto from huggingface_hub import hf_hub_download, HfApi from typing import Optional from pathlib import Path @@ -36,14 +35,10 @@ __all__ = [ "Seq2SeqLM", "get_model_with_lora_adapters", ] -from text_generation_server.models.globals import ATTENTION VLM_BATCH_TYPES = set() -FLASH_ATT_ERROR_MESSAGE = "{} requires Flash Attention enabled models." 
-FLASH_ATTENTION = False -if ATTENTION == "paged": - FLASH_ATTENTION = True +FLASH_ATTENTION = True try: from text_generation_server.models.flash_causal_lm import FlashCausalLM @@ -883,72 +878,6 @@ def get_model( trust_remote_code=trust_remote_code, ) - from text_generation_server.models.causal_lm import CausalLM - from text_generation_server.models.vlm_causal_lm import VlmCausalLM - from text_generation_server.models.custom_modeling.mllama import ( - MllamaForConditionalGeneration, - ) - from text_generation_server.models.custom_modeling.llava_next import ( - LlavaNextForConditionalGeneration, - ) - from text_generation_server.models.vlm_causal_lm import ( - VlmCausalLMBatch, - ) - - VLM_BATCH_TYPES.add(VlmCausalLMBatch) - - from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi - - adapt_transformers_to_gaudi() - if SDP_ON_BF16 == 1: - torch._C._set_math_sdp_allow_fp16_bf16_reduction(True) - if model_type == "gpt_bigcode": - from text_generation_server.models.starcoder import StarCoder - - return StarCoder(model_id=model_id, revision=revision, dtype=dtype) - if model_type == "bloom": - from text_generation_server.models.bloom import BLOOM - - return BLOOM( - model_id=model_id, - revision=revision, - speculator=speculator, - dtype=dtype, - trust_remote_code=trust_remote_code, - ) - - if model_type == "llava_next": - return VlmCausalLM( - model_class=LlavaNextForConditionalGeneration, - model_id=model_id, - revision=revision, - quantize=None, - speculator=speculator, - dtype=dtype, - trust_remote_code=trust_remote_code, - ) - - if model_type == "mllama": - return VlmCausalLM( - model_class=MllamaForConditionalGeneration, - model_id=model_id, - revision=revision, - quantize=None, - speculator=speculator, - dtype=dtype, - trust_remote_code=trust_remote_code, - ) - - if model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES: - return CausalLM( - model_id, - revision, - quantize=quantize, - speculator=speculator, - dtype=dtype, - trust_remote_code=trust_remote_code, - ) - raise ValueError(f"Unsupported model type {model_type}") diff --git a/backends/gaudi/server/text_generation_server/models/bloom.py b/backends/gaudi/server/text_generation_server/models/bloom.py deleted file mode 100644 index 6fe64374..00000000 --- a/backends/gaudi/server/text_generation_server/models/bloom.py +++ /dev/null @@ -1,52 +0,0 @@ -# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company. 
- -import torch - -from typing import Optional, Type - -from transformers import PreTrainedTokenizerBase - -from text_generation_server.models import CausalLM -from text_generation_server.models.causal_lm import CausalLMBatch -from text_generation_server.pb import generate_pb2 - - -class BloomCausalLMBatch(CausalLMBatch): - @classmethod - def from_pb( - cls, - pb: generate_pb2.Batch, - tokenizer: PreTrainedTokenizerBase, - dtype: torch.dtype, - device: torch.device, - ) -> "CausalLMBatch": - batch = super().from_pb( - pb=pb, - tokenizer=tokenizer, - dtype=dtype, - device=device, - ) - batch.keys_head_dim_last = False - return batch - - -class BLOOM(CausalLM): - def __init__( - self, - model_id: str, - revision: Optional[str] = None, - speculator: Optional[str] = None, - dtype: Optional[torch.dtype] = None, - trust_remote_code: bool = False, - ): - super(BLOOM, self).__init__( - model_id=model_id, - revision=revision, - speculator=speculator, - dtype=dtype, - trust_remote_code=trust_remote_code, - ) - - @property - def batch_type(self) -> Type[CausalLMBatch]: - return BloomCausalLMBatch diff --git a/backends/gaudi/server/text_generation_server/models/causal_lm.py b/backends/gaudi/server/text_generation_server/models/causal_lm.py deleted file mode 100644 index dd6e070d..00000000 --- a/backends/gaudi/server/text_generation_server/models/causal_lm.py +++ /dev/null @@ -1,1444 +0,0 @@ -# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company. - -import bisect -from dataclasses import dataclass -from functools import wraps -import itertools -import json -import math -import os -import tempfile -import time -import copy -from typing import Dict, List, Optional, Tuple, Type - -import torch -import torch._dynamo -from loguru import logger -from opentelemetry import trace - -import text_generation_server.habana_quantization_env as hq_env -from text_generation_server.utils import weight_files -import habana_frameworks.torch as htorch -from optimum.habana.utils import HabanaProfile -from optimum.habana.transformers.generation import MODELS_OPTIMIZED_WITH_STATIC_SHAPES -from text_generation_server.utils.chunks import concat_text_chunks -from optimum.habana.checkpoint_utils import model_on_meta -from transformers import ( - AutoTokenizer, - AutoModelForCausalLM, - PreTrainedTokenizerBase, - AutoConfig, -) - -from text_generation_server.utils.tokens import batch_top_tokens -from text_generation_server.models import Model -from text_generation_server.models.types import ( - Batch, - Tokens, - Generation, - GeneratedText, -) -from text_generation_server.pb import generate_pb2 -from text_generation_server.utils import ( - HeterogeneousNextTokenChooser, - StoppingCriteria, - is_tokenizer_transparent, - pad_next_token_chooser_parameters, -) -from optimum.habana.utils import get_hpu_memory_stats -from text_generation_server.utils.debug import dbg_trace -from text_generation_server.utils.speculate import get_speculate - -tracer = trace.get_tracer(__name__) -MAX_TOTAL_TOKENS = int(os.environ.get("MAX_TOTAL_TOKENS", 2048)) -PAD_SEQUENCE_TO_MULTIPLE_OF = int(os.environ.get("PAD_SEQUENCE_TO_MULTIPLE_OF", 256)) -CHUNK_SIZES = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048] -LAZY_MODE = int(os.environ.get("PT_HPU_LAZY_MODE", 1)) -BATCH_SIZE_EXPONENT_BASE = int(os.environ.get("BATCH_SIZE_EXPONENT_BASE", 2)) -SEQ_LEN_EXPONENT_BASE = int(os.environ.get("SEQ_LEN_EXPONENT_BASE", 2)) -MAX_BATCH_SIZE = ( - int(os.environ.get("MAX_BATCH_SIZE")) - if os.environ.get("MAX_BATCH_SIZE") is not None - else None -) - - -def 
torch_compile_for_eager(func): - if LAZY_MODE == 1: - return func - return torch.compile( - func, backend="hpu_backend", options={"keep_input_mutations": True} - ) - - -def round_up_seq(number, k, base): - exponent = max(0, math.ceil(math.log(number / k, base))) - return int(k * (base**exponent)) - - -def iterate_powers_of_base(max_value, start, base): - current = start - result = [] - assert ( - max_value >= start - ), f"max_value {max_value} must be greater than start {start}" - while current < max_value: - result.append(current) - current *= base - return result - - -def round_up_batch(number): - return BATCH_SIZE_EXPONENT_BASE ** ( - math.ceil(math.log(number, BATCH_SIZE_EXPONENT_BASE)) - ) - - -def to_tensor_indices(indices, device): - return torch.tensor(indices, dtype=torch.long, device=device) - - -def calculate_chunks(offset): - result = [] - while offset != 0: - sign = 1 if offset > 0 else -1 - best_chunk = min((abs(offset - sign * c), sign * c) for c in CHUNK_SIZES)[1] - result.append(best_chunk) - offset = offset - best_chunk - return result - - -def biggest_single_chunk(offset): - if offset != 0: - idx = bisect.bisect(CHUNK_SIZES, abs(offset)) - return int(math.copysign(CHUNK_SIZES[idx - 1], offset)) - else: - return 0 - - -@torch_compile_for_eager -def grouped_pad(tensor_groups, dims, values): - grouped_result = [] - for tensors, dim, value in zip(tensor_groups, dims, values): - padding = MAX_TOTAL_TOKENS - tensors[0].size(dim) if dim is not None else 0 - if padding > 0: - assert dim in [-1, -2], f"Only dims -1 and -2 are supported! {dim}" - pad_shape = (0, 0, 0, padding) if dim == -2 else (0, padding) - result = [ - torch.nn.functional.pad(t, pad_shape, value=value) for t in tensors - ] - else: - result = [t for t in tensors] - grouped_result.append(result) - htorch.core.mark_step() - return grouped_result - - -@torch_compile_for_eager -def roll(tensor, chunk, dim, merge_graphs): - if dim is None: - return tensor - tensor = torch.roll(tensor, chunk, dim) - if not merge_graphs: - htorch.core.mark_step() - return tensor - - -def grouped_roll(tensor_groups, chunk, dims, merge_graphs): - tensor_groups = [ - [roll(t, chunk, dim, merge_graphs) for t in tensors] - for tensors, dim in zip(tensor_groups, dims) - ] - if merge_graphs: - htorch.core.mark_step() - return tensor_groups - - -@torch_compile_for_eager -def grouped_shift(tensor_groups, dims, offset, merge_graphs): - chunks = calculate_chunks(offset) - for c in chunks: - tensor_groups = grouped_roll(tensor_groups, c, dims, merge_graphs) - return tensor_groups - - -def move(dst_tensors, dst_indices, src_tensors): - bs_dim = 0 - num_indices = dst_indices.size(0) - for i, (dst_t, src_t) in enumerate(zip(dst_tensors, src_tensors)): - if src_t.size(bs_dim) != num_indices: - src_t = torch.narrow(src_t, bs_dim, 0, num_indices) - dst_t.index_copy_(bs_dim, dst_indices, src_t) - htorch.core.mark_step() - - -def grouped_move(dst_tensor_groups, dst_indices, src_tensor_groups): - for dst_tensors, src_tensors in zip(dst_tensor_groups, src_tensor_groups): - move(dst_tensors, dst_indices, src_tensors) - - -@torch_compile_for_eager -def extend_tensor(tensor, padding, dim): - result = torch.cat([tensor, padding], dim=dim) - htorch.core.mark_step() - return result - - -@torch_compile_for_eager -def extend_batch(tensors, target_bs, dim): - diff = target_bs - tensors[0].size(dim) - # TODO: add support for shrinking bs - if diff <= 0: - return tensors - shape = list(tensors[0].shape) - shape[dim] = diff - padding = torch.empty(shape, 
device=tensors[0].device, dtype=tensors[0].dtype) - tensors = [extend_tensor(t, padding, dim) for t in tensors] - return tensors - - -def grouped_extend_batch(tensor_groups, target_bs, bs_dims): - tensor_groups = [ - extend_batch(tensors, target_bs, dim) - for tensors, dim in zip(tensor_groups, bs_dims) - ] - return tensor_groups - - -@torch_compile_for_eager -def merge(tensor_group): - tensor_group = [torch.stack(tensor_group)] - htorch.core.mark_step() - return tensor_group - - -@torch_compile_for_eager -def split(tensor_group, clone_data): - tensor_group = [t.squeeze(0) for t in torch.split(tensor_group[0], 1)] - if clone_data: - tensor_group = [t.clone() for t in tensor_group] - htorch.core.mark_step() - return tensor_group - - -def remove_kv_cache_from_output(module): - orig_fwd = module.forward - - @wraps(orig_fwd) - def forward(*args, **kwargs): - if kwargs["past_key_values"] is not None: - kwargs["return_dict"] = False - output = orig_fwd(*args, **kwargs) - first_value, second_value, *_ = output - if first_value.nelement() < 2: - return second_value - else: - return first_value - else: - kwargs["return_dict"] = True - return orig_fwd(*args, **kwargs) - - module.forward = forward - return module - - -@dataclass -class CausalLMRequest: - idx: int - data: generate_pb2.Request - input_length: int - prefix_offset: int - read_offset: int - stopping_criteria: StoppingCriteria - - all_input_ids: torch.Tensor - - @classmethod - def from_pb( - cls, idx: int, data: generate_pb2.Request, tokenizer: PreTrainedTokenizerBase - ): - return cls( - idx=idx, - data=data, - input_length=None, - prefix_offset=None, - read_offset=None, - stopping_criteria=StoppingCriteria.from_pb( - data.stopping_parameters, tokenizer - ), - all_input_ids=None, - ) - - def update_idx(self, new_idx): - prev = self.idx - self.idx = new_idx - return (new_idx, prev) - - -@dataclass -class CausalLMBatch(Batch): - batch_id: int - requests: List[CausalLMRequest] - - # Decoder values - input_ids: torch.Tensor - attention_mask: torch.Tensor - position_ids: torch.Tensor - past_key_values: Optional[List[Tuple]] - merged_kv_cache: bool - - # Lengths of all generations present in the batch - input_length: int - - # Generation helpers - next_token_chooser: HeterogeneousNextTokenChooser - top_n_tokens: List[int] - top_n_tokens_tensor: torch.Tensor - - input_length: int - - # Past metadata - logits = None - past = None - - keys_head_dim_last: bool = True - - def to_pb(self) -> generate_pb2.CachedBatch: - return generate_pb2.CachedBatch( - id=self.batch_id, - request_ids=[r.data.id for r in self.requests], - size=len(self), - max_tokens=self.max_tokens, - ) - - def detach_kv_cache(self): - past_keys = [past[0] for past in self.past_key_values] - past_values = [past[1] for past in self.past_key_values] - del self.past_key_values - return past_keys, past_values - - def attach_kv_cache(self, past_keys, past_values): - # TODO: Add support for models that don't store kv_cache in a list - self.past_key_values = list(zip(past_keys, past_values)) - - def merge_kv_cache_if_needed(self, target_bs, offset): - pad_needed = self.seq_length < MAX_TOTAL_TOKENS - shift_needed = offset != 0 - expand_needed = target_bs > self.batch_size - # Very simple heuristic to determine whether we should merge tensors - # this needs tuning for other models/scenarios - small_bs = len(self.past_key_values) > self.batch_size - if ( - not self.merged_kv_cache - and small_bs - and (pad_needed or shift_needed or expand_needed) - ): - past_keys, past_values = 
self.detach_kv_cache() - past_keys = merge(past_keys) - past_values = merge(past_values) - self.attach_kv_cache(past_keys, past_values) - self.merged_kv_cache = True - - def split_kv_cache_if_needed(self, clone_data): - if self.merged_kv_cache: - past_keys, past_values = self.detach_kv_cache() - past_keys = split(past_keys, clone_data) - past_values = split(past_values, clone_data) - self.attach_kv_cache(past_keys, past_values) - self.merged_kv_cache = False - - def get_tensor_groups(self): - past_keys, past_values = self.detach_kv_cache() - seq_dim = -1 - key_dim = -2 if self.keys_head_dim_last else -1 - value_dim = -2 - tensors = [ - [self.input_ids], - [self.attention_mask], - [self.position_ids], - past_keys, - past_values, - ] - # We don't need to align position_ids - seq_dims = [seq_dim, seq_dim, None, key_dim, value_dim] - bs_dims = [0, 0, 0] + ([1, 1] if self.merged_kv_cache else [0, 0]) - return tensors, seq_dims, bs_dims - - def set_tensor_groups(self, tensors): - self.input_ids = tensors.pop(0)[0] - self.attention_mask = tensors.pop(0)[0] - self.position_ids = tensors.pop(0)[0] - past_keys = tensors.pop(0) - past_values = tensors.pop(0) - self.attach_kv_cache(past_keys, past_values) - - def realign(self, target_bs, offset, pad_token_id): - tensors, seq_dims, _ = self.get_tensor_groups() - tensors = grouped_pad(tensors, seq_dims, [pad_token_id, 0, 0, 0, 0]) - tensors = grouped_shift(tensors, seq_dims, offset, self.merged_kv_cache) - self.set_tensor_groups(tensors) - - def expand_bs(self, target_bs): - tensors, _, bs_dims = self.get_tensor_groups() - tensors = grouped_extend_batch(tensors, target_bs, bs_dims) - self.set_tensor_groups(tensors) - - def used_indices(self): - return [req.idx for req in self.requests] - - def update_indices(self, new_indices): - for req, new_idx in zip(self.requests, new_indices): - req.idx = new_idx - return self.used_indices() - - def free_indices_generator(self): - used = set(req.idx for req in self.requests) - return (i for i in range(self.batch_size) if i not in used) - - def move_data(self, src_batches): - dst_tensors, _, dst_dims = self.get_tensor_groups() - free_indices_gen = self.free_indices_generator() - for src_b in src_batches: - dst_indices = to_tensor_indices( - src_b.update_indices(free_indices_gen), self.input_ids.device - ) - src_tensors, _, src_dims = src_b.get_tensor_groups() - grouped_move(dst_tensors, dst_indices, src_tensors) - self.set_tensor_groups(dst_tensors) - - @classmethod - def recombine( - cls, batches: List["CausalLMBatch"], pad_token_id: int - ) -> "CausalLMBatch": - if not all(b.past_key_values is not None for b in batches): - raise ValueError("KV cache not allocated! 
Cannot recombine before prefill!") - - total_requests = sum(len(b) for b in batches) - new_bs = total_requests - new_bs = round_up_batch(total_requests) - - batch_id = batches[0].batch_id - device = batches[0].input_ids.device - - input_lengths = [b.input_length for b in batches] - max_input_length = max(input_lengths) - offsets = [max_input_length - b.input_length for b in batches] - - cur_padding = [b.right_padding for b in batches] - # For prefill there is a space allocated only for first token - # Need to add padding to the max total tokens before first decode - - moves_needed = [ - total_requests - len(b) if b.batch_size == new_bs else total_requests - for b in batches - ] - dst_batch_idx = min(enumerate(moves_needed), key=lambda idx_val: idx_val[1])[0] - reshape = batches[dst_batch_idx].batch_size < new_bs - - # TODO: Add support for changing max seq len, i.e. due to output length bucketing - # FIXME: max_seq_len for non optimized code - if len(batches) > 1: - scenario = "CONCAT" - elif reshape: - scenario = "RESHAPE" - elif cur_padding[dst_batch_idx] <= 0: - scenario = "SHIFT" - offsets = [ - biggest_single_chunk(b.max_input_length - max_input_length) - for b in batches - ] - max_input_length = max_input_length + offsets[dst_batch_idx] - else: - # Nothing to do - return batches[0] - - dbg_trace( - scenario, - f"bs:{[b.batch_size for b in batches]}->{new_bs}" - f" reqs:{[len(b) for b in batches]}" - f" offsets:{offsets}" - f" input_lengths:{input_lengths}" - f" cur_padding:{cur_padding}" - f" dst_batch:{dst_batch_idx}", - ) - - grouped_requests = [[req for req in batch.requests] for batch in batches] - flat_requests = list(itertools.chain(*grouped_requests)) - - for i in range(len(batches)): - target_bs = new_bs if i == dst_batch_idx else batches[i].batch_size - batches[i].merge_kv_cache_if_needed(target_bs, offsets[i]) - batches[i].realign(target_bs, offsets[i], pad_token_id) - batches[i].split_kv_cache_if_needed(i == dst_batch_idx) - batches[dst_batch_idx].expand_bs(new_bs) - batches[dst_batch_idx].move_data( - [batches[i] for i in range(len(batches)) if i != dst_batch_idx] - ) - - top_n_tokens = [r.data.top_n_tokens for r in flat_requests] - top_n_tokens.extend([-1] * (new_bs - total_requests)) - top_n_tokens_tensor = torch.tensor( - top_n_tokens, device=device, dtype=torch.int64 - ) - - parameters = [r.data.parameters for r in flat_requests] - # append the dummy parameters for dummy requests - batch_size = batches[dst_batch_idx].batch_size - parameters = pad_next_token_chooser_parameters(parameters, batch_size) - - # update past grammar states - fsm_grammar_states = [0] * batch_size - for batch in batches: - for i, req in enumerate(batch.requests): - fsm_grammar_states[req.idx] = ( - batch.next_token_chooser.fsm_grammar_states[i] - ) - - next_token_chooser = HeterogeneousNextTokenChooser.from_pb( - parameters, - batches[dst_batch_idx].next_token_chooser.dtype, - batches[dst_batch_idx].next_token_chooser.device, - batches[dst_batch_idx].next_token_chooser.tokenizer, - fsm_grammar_states, - quantization_enabled=hq_env.is_quantization_enabled, - ) - - input_ids = batches[dst_batch_idx].input_ids - attention_mask = batches[dst_batch_idx].attention_mask - position_ids = batches[dst_batch_idx].position_ids - past_key_values = batches[dst_batch_idx].past_key_values - input_length = max_input_length - - htorch.core.mark_step() - - return cls( - batch_id=batch_id, - requests=flat_requests, - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - 
past_key_values=past_key_values, - merged_kv_cache=False, - next_token_chooser=next_token_chooser, - top_n_tokens=top_n_tokens, - top_n_tokens_tensor=top_n_tokens_tensor, - input_length=input_length, - ) - - @classmethod - def from_pb( - cls, - pb: generate_pb2.Batch, - tokenizer: PreTrainedTokenizerBase, - dtype: torch.dtype, - device: torch.device, - ) -> "CausalLMBatch": - dbg_trace("FROM_PB", f"num_reqs:{len(pb.requests)}") - requests = [ - CausalLMRequest.from_pb(idx, req, tokenizer) - for idx, req in enumerate(pb.requests) - ] - inputs = [] - top_n_tokens = [] - - # Parse batch - max_truncation = 0 - for i, r in enumerate(pb.requests): - inputs.append(concat_text_chunks(r.input_chunks.chunks)) - top_n_tokens.append(r.top_n_tokens) - max_truncation = max(max_truncation, r.truncate) - - max_input_length = max_truncation - if max_input_length < PAD_SEQUENCE_TO_MULTIPLE_OF: - max_input_length = PAD_SEQUENCE_TO_MULTIPLE_OF - max_new_tokens = max(r.stopping_criteria.max_new_tokens for r in requests) - - # TODO: by tokenizing all inputs at once we loose information on actual input lengths - # this means that we cannot shift inputs to the left after a long input sequence - # was filtered out - new_bs = round_up_batch(len(requests)) - missing_inputs = new_bs - len(inputs) - dummy_inputs = ["?"] * missing_inputs - parameters = [r.parameters for r in pb.requests] - # append the dummy parameters for dummy request - parameters = pad_next_token_chooser_parameters(parameters, new_bs) - - next_token_chooser = HeterogeneousNextTokenChooser.from_pb( - pb=parameters, - dtype=dtype, - device=device, - tokenizer=tokenizer, - quantization_enabled=hq_env.is_quantization_enabled, - ) - - tokenized_inputs = tokenizer( - inputs + dummy_inputs, - return_tensors="pt", - padding="longest", - return_token_type_ids=False, - truncation=True, - max_length=max_truncation, - ) - - input_len = tokenized_inputs["input_ids"].shape[1] - # Round up sequence length - bucket_size = max_input_length - left_padding = max_input_length - input_len - if input_len < max_input_length and PAD_SEQUENCE_TO_MULTIPLE_OF != 0: - assert ( - PAD_SEQUENCE_TO_MULTIPLE_OF <= max_input_length - ), "PAD_SEQUENCE_TO_MULTIPLE_OF cannot be higher than max_input_length" - rounded_seq_len = round_up_seq( - input_len + 1, PAD_SEQUENCE_TO_MULTIPLE_OF, SEQ_LEN_EXPONENT_BASE - ) - if rounded_seq_len <= max_input_length: - bucket_size = rounded_seq_len - 1 - else: - bucket_size = max_input_length - 1 - left_padding = bucket_size - input_len - - input_ids = tokenized_inputs["input_ids"] - attention_mask = tokenized_inputs["attention_mask"] - - # Allocate space for first token - input_ids = torch.nn.functional.pad( - input_ids, (left_padding, 1), value=tokenizer.pad_token_id - ) - attention_mask = torch.nn.functional.pad( - attention_mask, (left_padding, 1), value=0 - ) - all_input_ids = torch.nn.functional.pad( - input_ids, (0, max_new_tokens), value=tokenizer.pad_token_id - ).T.split(1, dim=1) - input_len = bucket_size - for r in requests: - r.input_length = input_len - r.prefix_offset = input_len - 5 - r.read_offset = input_len - r.all_input_ids = all_input_ids[r.idx] - - input_ids = input_ids.to(device) - attention_mask = attention_mask.to(device) - position_ids = attention_mask.long().cumsum(-1) - 1 - position_ids.masked_fill_(attention_mask == 0, 1) - - old_bs = len(requests) - top_n_tokens.extend([-1] * (new_bs - old_bs)) - top_n_tokens_tensor = torch.tensor( - top_n_tokens, device=device, dtype=torch.int64 - ) - htorch.core.mark_step() - return 
cls( - batch_id=pb.id, - requests=requests, - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=None, - merged_kv_cache=False, - next_token_chooser=next_token_chooser, - top_n_tokens=top_n_tokens, - top_n_tokens_tensor=top_n_tokens_tensor, - input_length=input_len, - ) - - @tracer.start_as_current_span("filter") - def filter(self, request_ids: List[int]) -> Optional["CausalLMBatch"]: - dbg_trace("FILTER", f"num_reqs:{len(self.requests)} -> {len(request_ids)}") - request_ids = set(request_ids) - self.requests = [req for req in self.requests if req.data.id in request_ids] - return self - - @classmethod - @tracer.start_as_current_span("concatenate") - def concatenate( - cls, batches: List["CausalLMBatch"], pad_token_id: int = 0 - ) -> "CausalLMBatch": - return cls.recombine(batches, pad_token_id) - - def __len__(self): - return len(self.requests) - - @property - def max_input_length(self): - return max(req.input_length for req in self.requests) - - @property - def batch_size(self): - return self.attention_mask.size(0) - - @property - def seq_length(self): - return self.attention_mask.size(1) - - @property - def right_padding(self): - return self.seq_length - self.input_length - - # Maximum number of tokens this batch will grow to - @property - def max_tokens(self): - max_total_tokens = self.attention_mask.size(1) - return len(self.requests) * max_total_tokens - - -class CausalLM(Model): - def __init__( - self, - model_id: str, - model_class: Optional[Type[torch.nn.Module]] = None, - revision: Optional[str] = None, - quantize: Optional[str] = None, - speculator: Optional[str] = None, - dtype: Optional[torch.dtype] = None, - default_dtype=torch.float16, - trust_remote_code: bool = False, - tokenizer_class=AutoTokenizer, - config_class=AutoConfig, - batch_class=CausalLMBatch, - ): - if speculator: - raise RuntimeError("Speculator decoding is not enabled for AutoModel") - - self.prev_bs = 0 - self.quantize = quantize - - # Create tokenizer - tokenizer = AutoTokenizer.from_pretrained( - model_id, - revision=revision, - padding_side="left", - truncation_side="left", - trust_remote_code=trust_remote_code, - ) - - # Create model - world_size = int(os.getenv("WORLD_SIZE", "1")) - rank = int(os.getenv("RANK", "0")) - dtype = torch.bfloat16 if dtype is None else dtype - device = torch.device("hpu") - - if hq_env.is_quantization_enabled: - htorch.core.hpu_set_env() - - # Get weight files - weight_files(model_id, revision=revision, extension=".safetensors") - - if world_size > 1: - os.environ.setdefault( - "DEEPSPEED_USE_HABANA_FRAMEWORKS_DETERMINISTIC_API", "1" - ) - model = self.get_deepspeed_model(model_id, dtype, revision) - model = hq_env.prepare_model_for_quantization(model) - else: - # Check support for rope scaling - model_kwargs = {} - config = AutoConfig.from_pretrained(model_id) - if hasattr(config, "rope_scaling"): - model_kwargs["rope_scaling"] = self.get_rope_scaling() - - model = AutoModelForCausalLM.from_pretrained( - model_id, - revision=revision, - torch_dtype=dtype, - trust_remote_code=trust_remote_code, - **model_kwargs, - ) - model = hq_env.prepare_model_for_quantization(model) - model = model.eval().to(device) - - self.enable_hpu_graph = ( - os.getenv("ENABLE_HPU_GRAPH", "true").lower() == "true" and LAZY_MODE == 1 - ) - self.limit_hpu_graph = os.getenv("LIMIT_HPU_GRAPH", "true").lower() == "true" - - if model.config.model_type not in [ - "gpt_bigcode" - ]: # gpt_bigcode/starcoderbase-3b skips remove_kv_cache_from_output() - model = 
remove_kv_cache_from_output(model) - - if self.enable_hpu_graph: - from habana_frameworks.torch.hpu import wrap_in_hpu_graph - - model = wrap_in_hpu_graph(model, disable_tensor_cache=True) - else: - if LAZY_MODE == 0: - # It is said that "keep_input_mutations" is safe for inference to be done - dbg_trace("TORCH COMPILE", "Torch compiling of model") - model.model = torch.compile( - model.model, - backend="hpu_backend", - options={"keep_input_mutations": True}, - ) - - model = hq_env.setup_quantization(model) - - if model.config.model_type not in MODELS_OPTIMIZED_WITH_STATIC_SHAPES: - raise ValueError(f"Model type {model.config.model_type} is not supported!") - - if tokenizer.pad_token_id is None: - if model.config.pad_token_id is not None: - tokenizer.pad_token_id = model.config.pad_token_id - elif model.config.eos_token_id is not None: - if isinstance(model.config.eos_token_id, int): - tokenizer.pad_token_id = model.config.eos_token_id - elif isinstance(model.config.eos_token_id, list): - tokenizer.pad_token_id = model.config.eos_token_id[0] - else: - raise ValueError( - f"{type(model.config.eos_token_id)} type of eos_token_id in the model's config is not supported for tokenizer.pad_token_id" - ) - elif tokenizer.eos_token_id is not None: - tokenizer.pad_token_id = tokenizer.eos_token_id - else: - tokenizer.add_special_tokens({"pad_token": "[PAD]"}) - - self.kwargs = { - "use_cache": True, - "return_dict": True, - } - - if model.config.model_type in [ - "llama", - "mistral", - "starcoder2", - "qwen2", - "falcon", - "gpt_bigcode", - ]: - if model.config.model_type not in ["falcon", "gpt_bigcode"]: - self.kwargs["attn_softmax_bf16"] = True - - if model.config.model_type not in ["gpt_bigcode"]: - self.kwargs["trim_logits"] = True - - if os.getenv("USE_FLASH_ATTENTION", "true").lower() == "true": - self.kwargs["use_flash_attention"] = True - if os.getenv("FLASH_ATTENTION_RECOMPUTE", "true").lower() == "true": - self.kwargs["flash_attention_recompute"] = True - - self.speculate = get_speculate() - - super(CausalLM, self).__init__( - model_id=model_id, - model=model, - tokenizer=tokenizer, - requires_padding=True, - dtype=dtype, - device=device, - rank=rank, - ) - - # Create profiler - ranks_to_profile = [int(val) for val in os.getenv("PROF_RANKS", "0").split(",")] - record_shapes = os.getenv("PROF_RECORD_SHAPES", "false").lower() == "true" - output_dir = os.getenv("PROF_PATH", "/tmp/hpu_profile") - self.profiling_warmup_steps = ( - int(os.getenv("PROF_WARMUPSTEP", "0")) if rank in ranks_to_profile else 0 - ) - self.profiling_steps = ( - int(os.getenv("PROF_STEP", "0")) if rank in ranks_to_profile else 0 - ) - self.profiling_wait_steps = int(os.getenv("PROF_WAITSTEP", "0")) - if self.profiling_steps > 0: - self.hb_profiler = HabanaProfile( - wait=self.profiling_wait_steps, - warmup=self.profiling_warmup_steps, - active=self.profiling_steps, - output_dir=output_dir, - record_shapes=record_shapes, - ) - self.hb_profiler.start() - else: - self.hb_profiler = None - self.step = 0 - - def get_deepspeed_model( - self, model_id: str, dtype: torch.dtype, revision: Optional[str] = None - ) -> torch.nn.Module: - import deepspeed - from habana_frameworks.torch.distributed.hccl import initialize_distributed_hpu - - world_size, rank, local_rank = initialize_distributed_hpu() - model_kwargs = {"revision": revision} - - # Initialize process(es) for DeepSpeed - deepspeed.init_distributed(dist_backend="hccl") - logger.info( - "DeepSpeed is enabled. 
world_size {} rank {} local_rank {}".format( - world_size, rank, local_rank - ) - ) - config = AutoConfig.from_pretrained(model_id, **model_kwargs) - load_to_meta = model_on_meta(config) - - # Check support for rope scaling - if hasattr(config, "rope_scaling"): - config.rope_scaling = self.get_rope_scaling() - model_kwargs["rope_scaling"] = self.get_rope_scaling() - - if load_to_meta: - # Construct model with fake meta tensors, later will be replaced on devices during ds-inference ckpt load - with deepspeed.OnDevice(dtype=dtype, device="meta"): - model = AutoModelForCausalLM.from_config(config, torch_dtype=dtype) - else: - # TODO: revisit placement on CPU when auto-injection is possible - with deepspeed.OnDevice(dtype=dtype, device="cpu"): - model = AutoModelForCausalLM.from_pretrained( - model_id, torch_dtype=dtype, **model_kwargs - ) - model = model.eval() - - # Initialize the model - ds_inference_kwargs = {"dtype": dtype} - ds_inference_kwargs["tensor_parallel"] = {"tp_size": world_size} - ds_inference_kwargs["enable_cuda_graph"] = False - - if load_to_meta: - # model loaded to meta is managed differently - checkpoints_json = tempfile.NamedTemporaryFile(suffix=".json", mode="+w") - checkpoint_files = [ - str(f) - for f in weight_files( - model_id, revision=revision, extension=".safetensors" - ) - ] - data = {"type": "ds_model", "checkpoints": checkpoint_files, "version": 1.0} - json.dump(data, checkpoints_json) - checkpoints_json.flush() - - ds_inference_kwargs["checkpoint"] = checkpoints_json.name - model = deepspeed.init_inference(model, **ds_inference_kwargs) - - return model.module - - def get_rope_scaling(self) -> Optional[Dict]: - rope_scaling = os.getenv("ROPE_SCALING", None) - if rope_scaling is None: - return None - - rope_factor = float(os.getenv("ROPE_FACTOR", 1.0)) - return {"type": rope_scaling, "factor": float(rope_factor)} - - @property - def batch_type(self) -> Type[CausalLMBatch]: - return CausalLMBatch - - def decode(self, generated_ids: List[int]) -> str: - return self.tokenizer.decode( - generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False - ) - - def decode_token( - self, - all_input_ids: List[int], - prefix_offset: int = 0, - read_offset: int = 0, - ) -> Tuple[str, int, int]: - if is_tokenizer_transparent(self.tokenizer): - new_text = self.tokenizer.decode( - all_input_ids[read_offset:], skip_special_tokens=False - ) - return new_text, read_offset, len(all_input_ids) - else: - return super().decode_token(all_input_ids, prefix_offset, read_offset) - - def forward( - self, - input_ids, - attention_mask, - position_ids, - token_idx, - past_key_values: Optional[List[Tuple]] = None, - bypass_hpu_graph: Optional[bool] = None, - ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]: - # Model Forward - kwargs = { - "input_ids": input_ids, - "attention_mask": attention_mask, - "past_key_values": past_key_values, - "token_idx": token_idx, - } - - # Optimum Habana got "lazy_mode" key-val only supported for llama type of models - if self.model.config.model_type == "llama": - kwargs["lazy_mode"] = LAZY_MODE == 1 - - if self.has_position_ids: - kwargs["position_ids"] = position_ids - - if bypass_hpu_graph is not None: - kwargs["bypass_hpu_graphs"] = bypass_hpu_graph - - kwargs.update(self.kwargs) - - if past_key_values is not None and self.model.config.model_type not in [ - "gpt_bigcode" - ]: - return self.model.forward(**kwargs) - else: - outputs = self.model.forward(**kwargs) - return outputs.logits, outputs.past_key_values - - 
@tracer.start_as_current_span("generate_token") - def generate_token( - self, batches: List[CausalLMBatch] - ) -> Tuple[List[Generation], Optional[CausalLMBatch], Tuple[int, int]]: - start = time.time_ns() - # Results - generations: List[Generation] = [] - prev_batches = [] - requests_to_generate = [] - # In order to pipeline any actions on CPU we perform the operation in 3 main stages: - # Stage 1. Collect next token ids of any previously started generations - for batch_id, batch in enumerate(batches): - if batch.logits is not None: - logits = batch.logits - past = batch.past - prefill = batch.past_key_values is None - if prefill: - # no right padding for prefill - token_idx_scalar = batch.attention_mask.shape[-1] - 1 - token_idx = torch.tensor(token_idx_scalar).to(self.device) - else: - token_idx_scalar = ( - batch.attention_mask.shape[-1] - batch.right_padding - ) - token_idx = torch.tensor(token_idx_scalar).to(self.device) - - # Select next token - input_length = batch.input_length - if logits.shape[-2] > 1: - next_token_ids, next_token_logprobs, logprobs, _, _ = ( - batch.next_token_chooser( - batch.input_ids, - logits[:, input_length - 1 : input_length, :].squeeze(-2), - self.speculate, - ) - ) - else: - next_token_ids, next_token_logprobs, logprobs, _, _ = ( - batch.next_token_chooser( - batch.input_ids, logits.squeeze(-2), self.speculate - ) - ) - # Speculation is not active for causal - accepted_ids = torch.ones_like(batch.input_ids)[:, 0] - batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens( - batch.top_n_tokens, - batch.top_n_tokens_tensor, - logprobs, - accepted_ids, - ) - - prev_batches.append( - { - "next_token_ids": next_token_ids, - "next_token_logprobs": next_token_logprobs, - } - ) - - for req_idx, req in enumerate(batch.requests): - requests_to_generate.append( - { - "req": req, - "prev_req_idx": req.idx, - "batch_id": batch_id, - "seed": batch.next_token_chooser.seeds[req_idx], - "do_sample": batch.next_token_chooser.do_sample[req_idx], - "top_n_tokens": batch.top_n_tokens[req_idx], - "top_token_ids": batch_top_token_ids[req_idx], - "top_token_logprobs": batch_top_token_logprobs[req_idx], - "grammar_state": batch.next_token_chooser.fsm_grammar_states[ - req.idx - ], - } - ) - - htorch.core.mark_step() - - # Add new token into input_ids - batch.input_ids.index_copy_(1, token_idx, next_token_ids.unsqueeze(1)) - - # Update attention_mask as we added a new token to input_ids - batch.attention_mask.index_fill_(1, token_idx, 1) - - # Adjust lengths - batch.input_length += 1 - - # Update position_ids - if prefill: - batch.position_ids = ( - torch.index_select(batch.position_ids, 1, token_idx - 1) + 1 - ) - else: - batch.position_ids += 1 - # Update past key values - if prefill or self.model.config.model_type in ["gpt_bigcode"]: - batch.past_key_values = past - - htorch.core.mark_step() - - # Stage 2. 
Prepare new batch for speculative scheduling - if len(batches) > 1: - batch = self.batch_type.concatenate(batches, self.tokenizer.pad_token_id) - else: - batch = batches[0] - - prefill = batch.past_key_values is None - - # Check if we need to do any bookkeeping first - if not prefill: - batch = batch.__class__.recombine([batch], self.tokenizer.pad_token_id) - - scenario = "PREFILL" if prefill else "GENERATE" - if ( - self.enable_hpu_graph - and self.limit_hpu_graph - and round_up_batch(batch.batch_size) != self.prev_bs - ): - self.model.clear_cache() - self.prev_bs = round_up_batch(batch.batch_size) - dbg_trace( - scenario, - f"bs:{batch.batch_size} num_reqs:{len(batch.requests)} seq_len:{batch.seq_length} padding:{batch.right_padding}", - ) - assert batch.right_padding > 0, "No more room for next token!" - - # Execute batch - if prefill: - # no right padding for prefill - token_idx = torch.tensor(batch.attention_mask.shape[-1] - 1).to(self.device) - batch.logits, batch.past = self.forward( - batch.input_ids, - batch.attention_mask, - batch.position_ids, - token_idx, - batch.past_key_values, - bypass_hpu_graph=( - prefill and self.limit_hpu_graph if self.enable_hpu_graph else None - ), - ) - elif all([req.stopping_criteria.max_new_tokens == 1 for req in batch.requests]): - # Don't schedule next forward if max_new_tokens for all requests equals 1 - # - we've already generated the first and only needed token in the prefill phase - pass - else: - token_idx = torch.tensor( - batch.attention_mask.shape[-1] - batch.right_padding - ).to(self.device) - input_ids = torch.index_select(batch.input_ids, 1, token_idx - 1) - logits = self.forward( - input_ids, - batch.attention_mask, - batch.position_ids, - token_idx, - batch.past_key_values, - bypass_hpu_graph=( - prefill and self.limit_hpu_graph if self.enable_hpu_graph else None - ), - ) - if self.model.config.model_type in ["gpt_bigcode"]: - batch.logits, batch.past = logits - else: - batch.logits = logits - - htorch.core.mark_step() - - start_decode = time.time_ns() - - # Stage 3. 
Finish and return previous generations - stopped = len(requests_to_generate) > 0 - for prev_batch in prev_batches: - prev_batch["next_token_logprobs"] = prev_batch[ - "next_token_logprobs" - ].tolist() - prev_batch["next_token_ids_cpu"] = prev_batch["next_token_ids"].cpu() - htorch.core.mark_step() - - for req_data in requests_to_generate: - req = req_data["req"] - i = req_data["prev_req_idx"] - prev_batch_id = req_data["batch_id"] - assert len(prev_batches) > prev_batch_id - next_token_ids_cpu = prev_batches[prev_batch_id]["next_token_ids_cpu"] - next_token_logprobs = prev_batches[prev_batch_id]["next_token_logprobs"] - - request = req.data - input_length = req.input_length - prefix_offset = req.prefix_offset - read_offset = req.read_offset - do_sample = req_data["do_sample"] - seed = req_data["seed"] - stopping_criteria = req.stopping_criteria - all_input_ids = req.all_input_ids - next_token_id = next_token_ids_cpu[i] - next_token_logprob = next_token_logprobs[i] - top_n_tokens = req_data["top_n_tokens"] - top_token_ids = req_data["top_token_ids"] - top_token_logprobs = req_data["top_token_logprobs"] - grammar_state = req_data["grammar_state"] - - # Append next token to all tokens - all_input_ids[input_length] = next_token_id - new_input_length = input_length + 1 - - # Generated token - if ( - is_tokenizer_transparent(self.tokenizer) - and len(stopping_criteria.stop_sequence_criterias) == 0 - ): - next_token_text = "" - else: - next_token_text, prefix_offset, read_offset = self.decode_token( - all_input_ids[0:new_input_length, 0], prefix_offset, read_offset - ) - - # Evaluate stopping criteria - stop, reason = stopping_criteria( - next_token_id, - next_token_text, - ) - - if not stop: - stopped = False - - # Shard generations - # All generations will be appended in the rust sharded client - if i % self.world_size == self.rank: - if stop: - # Decode generated tokens - if is_tokenizer_transparent(self.tokenizer): - output_text = None - else: - output_text = self.decode( - all_input_ids[ - new_input_length - - stopping_criteria.current_tokens : new_input_length, - 0, - ] - ) - generated_text = GeneratedText( - output_text, - stopping_criteria.current_tokens, - reason, - seed if do_sample else None, - ) - else: - generated_text = None - - # Prefill - if stopping_criteria.current_tokens == 1 and request.prefill_logprobs: - # Remove generated token to only have prefill and add nan for first prompt token - prefill_logprobs = [float("nan")] + next_token_logprobs - prefill_token_ids = all_input_ids[0 : new_input_length - 1] - prefill_texts = self.tokenizer.batch_decode( - prefill_token_ids, - clean_up_tokenization_spaces=False, - skip_special_tokens=False, - ) - prefill_tokens = Tokens( - prefill_token_ids, - prefill_logprobs, - prefill_texts, - is_special=[], - ) - else: - prefill_tokens = None - - if top_n_tokens > 0: - all_top_tokens = [] - for top_token_ids, top_token_logprobs in zip( - top_token_ids, top_token_logprobs - ): - toptoken_texts = self.tokenizer.batch_decode( - top_token_ids, - clean_up_tokenization_spaces=False, - skip_special_tokens=False, - ) - special_toptokens = [ - token_id in self.all_special_ids - for token_id in top_token_ids - ] - top_tokens = Tokens( - top_token_ids, - top_token_logprobs, - toptoken_texts, - special_toptokens, - ) - all_top_tokens.append(top_tokens) - top_tokens = all_top_tokens - else: - top_tokens = None - - generation = Generation( - request.id, - prefill_tokens, - Tokens( - [next_token_id], - [next_token_logprob], - [next_token_text], - 
[next_token_id in self.all_special_ids], - ), - generated_text, - top_tokens, - ) - - generations.append(generation) - - batch.next_token_chooser = ( - batch.next_token_chooser.advance_grammar_single_with_past_state( - req.idx, next_token_id, grammar_state - ) - ) - - req.all_input_ids = all_input_ids - req.input_length = new_input_length - req.prefix_offset = prefix_offset - req.read_offset = read_offset - - htorch.core.mark_step() - self.step = self.step + 1 - if self.hb_profiler is not None: - if ( - self.step - > self.profiling_wait_steps - + self.profiling_warmup_steps - + self.profiling_steps - ): - self.hb_profiler.stop() - else: - self.hb_profiler.step() - - forward_ns = start_decode - start - decode_ns = time.time_ns() - start_decode - return generations, batch if not stopped else None, (forward_ns, decode_ns) - - def generate_warmup_batch(self, request, seq_len, batch_size): - batch = copy.deepcopy(request.batch) - for req in batch.requests: - req.truncate = seq_len - - for i in range(len(batch.requests) - batch_size): - batch.requests.pop() - - return self.batch_type.from_pb(batch, self.tokenizer, self.dtype, self.device) - - def warmup( - self, request: generate_pb2.WarmupRequest - ) -> Tuple[Optional[int], Optional[int], Optional[int]]: - assert ( - MAX_BATCH_SIZE is not None - ), "MAX_BATCH_SIZE is not set, it should be set in the launcher" - MAX_BATCH_TOTAL_TOKENS = MAX_BATCH_SIZE * request.max_total_tokens - logger.info(f"MAX_BATCH_SIZE: {MAX_BATCH_SIZE}") - logger.info(f"MAX_BATCH_TOTAL_TOKENS: {MAX_BATCH_TOTAL_TOKENS}") - MAX_TOTAL_TOKENS = request.max_total_tokens - - batch = self.batch_type.from_pb( - request.batch, self.tokenizer, self.dtype, self.device - ) - max_prefill_batch_size = batch.input_ids.shape[0] - try: - # max prefill batch size warmup - _, prefill_batch, _ = self.generate_token([batch]) - except Exception: - raise RuntimeError( - f"Not enough memory to handle {len(batch.input_ids)} prefill tokens. " - f"You need to decrease `--max-batch-prefill-tokens`" - ) - - del prefill_batch - - # Warmup prefill batch_size - max_input_tokens = request.max_input_tokens - max_exp = math.ceil(math.log(max_prefill_batch_size, BATCH_SIZE_EXPONENT_BASE)) - prefill_batch_size_list = [ - BATCH_SIZE_EXPONENT_BASE**exp - for exp in range( - 0, - max_exp + 1, - ) - ] - prefill_seqlen_list = iterate_powers_of_base( - max_input_tokens, PAD_SEQUENCE_TO_MULTIPLE_OF, SEQ_LEN_EXPONENT_BASE - ) - prefill_seqlen_list.append(max_input_tokens) - prefill_batch_size_list.sort(reverse=True) - prefill_seqlen_list.sort(reverse=True) - try: - for batch_size in prefill_batch_size_list: - for seq_len in prefill_seqlen_list: - batch = self.generate_warmup_batch(request, seq_len - 1, batch_size) - _, prefill_batch, _ = self.generate_token([batch]) - except Exception: - prefill_batch_size_list.sort() - prefill_seqlen_list.sort() - raise RuntimeError( - f"Not enough memory to run following prefill batch_size." 
- f"Prefill batch size list:{prefill_batch_size_list}" - f"Prefill sequence length list:{prefill_seqlen_list}" - f"You need to decrease `--max-batch-prefill-tokens`" - ) - prefill_seqlen_list.sort() - prefill_batch_size_list.sort() - mem_stats = get_hpu_memory_stats(self.device) - logger.info( - f"\nFollowing prefill warmup successfully.\n" - f"Prefill batch size list:{prefill_batch_size_list}\n" - f"Prefill sequence length list:{prefill_seqlen_list}\n" - f"Memory stats: {mem_stats} " - ) - - max_decode_batch_size = math.floor(MAX_BATCH_TOTAL_TOKENS / MAX_TOTAL_TOKENS) - max_exp = math.ceil(math.log(max_decode_batch_size, BATCH_SIZE_EXPONENT_BASE)) - decode_batch_size_list = [ - BATCH_SIZE_EXPONENT_BASE**exp for exp in range(0, max_exp + 1) - ] - decode_batch_size_list.sort(reverse=True) - - try: - for batch_size in decode_batch_size_list: - batches = [] - iters = math.floor(batch_size / max_prefill_batch_size) - for i in range(iters): - batch = self.generate_warmup_batch( - request, PAD_SEQUENCE_TO_MULTIPLE_OF - 1, max_prefill_batch_size - ) - _, prefill_batch, _ = self.generate_token([batch]) - batches.append(prefill_batch) - - if batch_size % max_prefill_batch_size != 0: - batch = self.generate_warmup_batch( - request, - PAD_SEQUENCE_TO_MULTIPLE_OF - 1, - batch_size % max_prefill_batch_size, - ) - _, prefill_batch, _ = self.generate_token([batch]) - batches.append(prefill_batch) - - _, decode_batch, _ = self.generate_token(batches) - _, decode_batch, _ = self.generate_token([decode_batch]) - del decode_batch - batches.clear() - - except Exception: - raise RuntimeError( - f"Not enough memory to warmup decode batch_sizes({decode_batch_size_list})." - f"You need to decrease `--max-batch-total-tokens`" - ) - - decode_batch_size_list.sort() - max_supported_total_tokens = MAX_TOTAL_TOKENS * decode_batch_size_list[-1] - mem_stats = get_hpu_memory_stats(self.device) - logger.info( - f"\nFollowing decode warmup successfully.\n" - f"Decode batch size list:{decode_batch_size_list}\n" - f"Memory stats: {mem_stats} " - ) - - max_input_tokens = max_input_tokens - max_total_tokens = MAX_TOTAL_TOKENS - - return max_supported_total_tokens, max_input_tokens, max_total_tokens diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/llava_next.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/llava_next.py deleted file mode 100644 index 00ecdf95..00000000 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/llava_next.py +++ /dev/null @@ -1,467 +0,0 @@ -# coding=utf-8 -# Copyright 2024 the HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-""" PyTorch Llava-NeXT model.""" - -from typing import List, Optional, Union - -import torch -import torch.utils.checkpoint -import numpy as np - -from loguru import logger -from transformers.models.llava_next.modeling_llava_next import ( - unpad_image, -) -from optimum.habana.transformers.models import GaudiLlavaNextForConditionalGeneration -from transformers.image_processing_utils import select_best_resolution - - -def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size): - """ - Calculate the shape of the image patch grid after the preprocessing for images of any resolution. - - Args: - image_size (`tuple`): - The size of the input image in the format (width, height). - grid_pinpoints (`List`): - A list containing possible resolutions. Each item in the list should be a tuple or list - of the form `(height, width)`. - patch_size (`int`): - The size of each image patch. - - Returns: - tuple: The shape of the image patch grid in the format (width, height). - """ - if not isinstance(grid_pinpoints, list): - raise ValueError("grid_pinpoints should be a list of tuples or lists") - - height, width = select_best_resolution(image_size, grid_pinpoints) - return height // patch_size, width // patch_size - - -# Copied from https://github.com/huggingface/transformers/blob/6966fa190172b48b2fb46fe4552a13b943e692cf/src/transformers/models/llava_next/modeling_llava_next.py#L79 -def image_size_to_num_patches(image_size, grid_pinpoints, patch_size: int): - """ - Calculate the number of patches after the preprocessing for images of any resolution. - - Args: - image_size (`torch.LongTensor` or `np.ndarray` or `Tuple[int, int]`): - The size of the input image in the format (height, width). ? - grid_pinpoints (`List`): - A list containing possible resolutions. Each item in the list should be a tuple or list - of the form `(height, width)`. - patch_size (`int`): - The size of each image patch. - - Returns: - int: the number of patches - """ - if not isinstance(grid_pinpoints, list): - raise TypeError("grid_pinpoints should be a list of tuples or lists") - - # ! 
VERY IMPORTANT if image_size is tensor, must convert to into tuple, otherwise it will cause wrong calculate - if not isinstance(image_size, (list, tuple)): - if not isinstance(image_size, (torch.Tensor, np.ndarray)): - raise TypeError( - f"image_size invalid type {type(image_size)} with value {image_size}" - ) - image_size = image_size.tolist() - - best_resolution = select_best_resolution(image_size, grid_pinpoints) - height, width = best_resolution - num_patches = 0 - # consider change to ceil(height/patch_size)*ceil(width/patch_size) + 1 - for i in range(0, height, patch_size): - for j in range(0, width, patch_size): - num_patches += 1 - # add the base patch - num_patches += 1 - return num_patches - - -class LlavaNextForConditionalGeneration(GaudiLlavaNextForConditionalGeneration): - - def forward( - self, - input_ids: torch.LongTensor = None, - pixel_values: torch.FloatTensor = None, - image_sizes: Optional[torch.LongTensor] = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - vision_feature_layer: Optional[int] = None, - vision_feature_select_strategy: Optional[str] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - token_idx: Optional[torch.Tensor] = None, - use_flash_attention: Optional[bool] = True, - flash_attention_recompute: Optional[bool] = True, - ): - - if token_idx is not None: - output_attentions = ( - output_attentions - if output_attentions is not None - else self.config.output_attentions - ) - output_hidden_states = ( - output_hidden_states - if output_hidden_states is not None - else self.config.output_hidden_states - ) - return_dict = ( - return_dict if return_dict is not None else self.config.use_return_dict - ) - if inputs_embeds is None: - inputs_embeds = self.get_input_embeddings()(input_ids) - - outputs = self.language_model( - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - token_idx=token_idx, - use_flash_attention=use_flash_attention, - flash_attention_recompute=flash_attention_recompute, - ) - - logits = outputs[0] - - if not return_dict: - output = (logits,) + outputs[1:] - return output - - return outputs - - # Copied from https://github.com/huggingface/transformers/blob/6966fa190172b48b2fb46fe4552a13b943e692cf/src/transformers/models/llava_next/modeling_llava_next.py#L411 - def pack_image_features( - self, - image_features, - image_sizes, - vision_feature_select_strategy, - image_newline=None, - ): - """ - Reshape, unpad and then pack each image_feature into a single image_features tensor containing all visual vectors. - - Args: - image_features (`List[torch.Tensor]` of length num_images, each of shape `(num_patches, image_length, embed_dim)`) - List of image feature tensor, each contains all the visual feature of all patches. - image_sizes (`torch.Tensor` of shape `(num_images, 2)`) - Actual image size of each images (H, W). - vision_feature_select_strategy (`str`) - The feature selection strategy used to select the vision feature from the vision backbone. 
- image_newline (`torch.Tensor` of shape `(embed_dim)`) - New line embedding vector. - Returns: - image_features (`torch.Tensor` of shape `(all_feat_len, embed_dim)`) - feature_lens (`List[int]`) - token length of each image in image_features - """ - new_image_features = [] - feature_lens = [] - for image_idx, image_feature in enumerate(image_features): - if image_feature.shape[0] > 1: - base_image_feature = image_feature[0] - image_feature = image_feature[1:] - height = width = ( - self.config.vision_config.image_size - // self.config.vision_config.patch_size - ) - - num_patch_height, num_patch_width = get_anyres_image_grid_shape( - image_sizes[image_idx], - self.config.image_grid_pinpoints, - self.config.vision_config.image_size, - ) - - if ( - np.prod(image_feature.shape) - % (num_patch_height * num_patch_width * height * width) - != 0 - and vision_feature_select_strategy == "default" - ): - logger.warning_once( - "Image feature shape does not line up with the provided patch size. " - "You may be using the `default` vision_feature_select_strategy with a" - " visual encoder that does not have CLS." - ) - - image_feature = image_feature.view( - num_patch_height, num_patch_width, height, width, -1 - ) - image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous() - image_feature = image_feature.flatten(1, 2).flatten(2, 3) - image_feature = unpad_image(image_feature, image_sizes[image_idx]) - if image_newline is not None: - image_feature = torch.cat( - ( - image_feature, - image_newline[:, None, None] - .expand(*image_feature.shape[:-1], 1) - .to(image_feature.device, image_feature.dtype), - ), - dim=-1, - ) - image_feature = image_feature.flatten(1, 2).transpose(0, 1) - image_feature = torch.cat((base_image_feature, image_feature), dim=0) - else: - image_feature = image_feature[0] - if image_newline is not None: - image_feature = torch.cat( - (image_feature, image_newline[None].to(image_feature)), dim=0 - ) - new_image_features.append(image_feature) - feature_lens.append(image_feature.size(0)) - image_features = torch.cat(new_image_features, dim=0) - feature_lens = torch.tensor( - feature_lens, dtype=torch.long, device=image_features.device - ) - return image_features, feature_lens - - # Copied from https://github.com/huggingface/transformers/blob/6966fa190172b48b2fb46fe4552a13b943e692cf/src/transformers/models/llava_next/modeling_llava_next.py#L479 - def get_image_features( - self, - pixel_values: torch.FloatTensor, - image_sizes: torch.Tensor, - vision_feature_layer: Union[int, List[int]], - vision_feature_select_strategy: str, - ): - """ - Obtains image last hidden states from the vision tower and apply multimodal projection. - - Args: - pixel_values (`torch.FloatTensor]` of shape `(batch_size, num_patches, channels, height, width)`) - The tensors corresponding to the input images. - image_sizes (`torch.Tensor` of shape `(num_images, 2)`) - Actual image size of each images (H, W). - vision_feature_layer (`Union[int, List[int]]`): - The index of the layer to select the vision feature. If multiple indices are provided, - the vision feature of the corresponding indices will be concatenated to form the - vision features. - vision_feature_select_strategy (`str`): - The feature selection strategy used to select the vision feature from the vision backbone. - Can be one of `"default"` or `"full"` - Returns: - image_features (List[`torch.Tensor`]): List of image feature tensor, each contains all the visual feature of all patches - and are of shape `(num_patches, image_length, embed_dim)`). 
- """ - # ! infer image_num_patches from image_sizes - image_num_patches = [ - image_size_to_num_patches( - image_size=imsize, - grid_pinpoints=self.config.image_grid_pinpoints, - patch_size=self.config.vision_config.image_size, - ) - for imsize in image_sizes - ] - if pixel_values.dim() == 5: - # stacked if input is (batch_size, num_patches, num_channels, height, width) - _pixel_values_list = [ - pix_val[:num_patch] - for pix_val, num_patch in zip(pixel_values, image_num_patches) - ] - pixel_values = torch.cat(_pixel_values_list, dim=0) - elif pixel_values.dim() != 4: - # otherwise has to be stacked from list of (num_patches, num_channels, height, width) - raise ValueError( - f"pixel_values of shape {pixel_values.shape}, expect to be of 4 or 5 dimensions" - ) - - image_features = self.vision_tower(pixel_values, output_hidden_states=True) - # If we have one vision feature layer, return the corresponding hidden states, - # otherwise, select the hidden states of each feature layer and concatenate them - if isinstance(vision_feature_layer, int): - selected_image_feature = image_features.hidden_states[vision_feature_layer] - else: - hs_pool = [ - image_features.hidden_states[layer_idx] - for layer_idx in vision_feature_layer - ] - selected_image_feature = torch.cat(hs_pool, dim=-1) - - if vision_feature_select_strategy == "default": - selected_image_feature = selected_image_feature[:, 1:] - elif vision_feature_select_strategy == "full": - selected_image_feature = selected_image_feature - - image_features = self.multi_modal_projector(selected_image_feature) - image_features = torch.split(image_features, image_num_patches, dim=0) - return image_features - - def prepare_inputs_for_generation( - self, - input_ids, - past_key_values=None, - inputs_embeds=None, - pixel_values=None, - image_sizes=None, - attention_mask=None, - **kwargs, - ): - """ - Inherits from LlavaForConditionalGeneration: https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/models/llava_next/modeling_llava_next.py#L635 - The only differences are: - - add new args token_idx - - add the process of merging images into inputs_embeds - """ - token_idx = kwargs.get("token_idx", None) - if token_idx is None: - return super().prepare_inputs_for_generation( - input_ids=input_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - pixel_values=pixel_values, - image_sizes=image_sizes, - attention_mask=attention_mask, - **kwargs, - ) - else: - use_flash_attention = kwargs.get("use_flash_attention", True) - flash_attention_recompute = kwargs.get("flash_attention_recompute", True) - - position_ids = kwargs.get("position_ids", None) - labels = kwargs.get("labels", None) - if ( - past_key_values is None - and pixel_values is not None - and input_ids.shape[1] != 1 - ): - vision_feature_select_strategy = kwargs.get( - "vision_feature_select_strategy", None - ) - vision_feature_layer = kwargs.get("vision_feature_layer", None) - vision_feature_select_strategy = ( - vision_feature_select_strategy - if vision_feature_select_strategy is not None - else self.config.vision_feature_select_strategy - ) - vision_feature_layer = ( - vision_feature_layer - if vision_feature_layer is not None - else self.config.vision_feature_layer - ) - - # 1. Extract the input embeddings - inputs_embeds = self.get_input_embeddings()(input_ids) - # 2. 
Merge text and images - image_features = self.get_image_features( - pixel_values, - image_sizes, - vision_feature_layer=vision_feature_layer, - vision_feature_select_strategy=vision_feature_select_strategy, - ) - - # NOTE we only support multimodal_patch_merge_type == "spatial_unpad" - image_features, feature_lens = self.pack_image_features( - image_features, - image_sizes, - vision_feature_select_strategy=vision_feature_select_strategy, - image_newline=self.image_newline, - ) - - special_image_mask = ( - input_ids == self.config.image_token_index - ).unsqueeze(-1) - special_image_mask = special_image_mask.expand_as(inputs_embeds).to( - inputs_embeds.device - ) - if inputs_embeds[special_image_mask].numel() != image_features.numel(): - n_image_tokens = (input_ids == self.config.image_token_index).sum() - n_image_features = image_features.shape[0] - raise ValueError( - f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" - ) - - image_features = image_features.to( - inputs_embeds.device, inputs_embeds.dtype - ) - inputs_embeds = inputs_embeds.masked_scatter( - special_image_mask, image_features - ) - - # In case input_ids.shape[1] == 1 & pixel_values==None & past_key_values != None, we are in the case of - # generation with cache - elif past_key_values is not None: - seq_len = input_ids.shape[1] - pad_len = seq_len - token_idx.item() - input_ids = torch.index_select(input_ids, 1, token_idx - 1) - # Retrieve the first layer to inspect the logits and mask out the hidden states - # that are set to 0 - first_layer_past_key_value = past_key_values[0][0][:, :, :, 0] - # Sum all dimensions of head_dim (-2) to avoid random errors such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941 - batch_index, non_attended_tokens = torch.where( - first_layer_past_key_value.float().sum(-2) == 0 - ) - # Get the target length - past_length = first_layer_past_key_value.shape[-1] - extended_attention_mask = torch.ones( - (attention_mask.shape[0], past_length), - dtype=attention_mask.dtype, - device=attention_mask.device, - ) - # Filter out only the tokens that can be un-attended, this can happen - # if one uses Llava + Fused modules where the cache on the - # first iteration is already big enough, or if one passes custom cache - valid_indices = non_attended_tokens < extended_attention_mask.size(-1) - new_batch_index = batch_index[valid_indices] - new_non_attended_tokens = non_attended_tokens[valid_indices] - - # Zero-out the places where we don't need to attend - extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0 - - attention_mask = extended_attention_mask - attention_mask[:, -pad_len:] = 0 - - if attention_mask is not None and position_ids is None: - # create position_ids on the fly for batch generation - position_ids = attention_mask.long().cumsum(-1) - 1 - position_ids.masked_fill_(attention_mask == 0, 1) - if past_key_values: - if token_idx is not None: - position_ids = ( - torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1 - ) - else: - position_ids = position_ids[:, -input_ids.shape[1] :] - - # if `inputs_embeds` are passed, we only want to use them in the 1st generation step - if inputs_embeds is not None and past_key_values is None: - model_inputs = {"inputs_embeds": inputs_embeds} - else: - model_inputs = {"input_ids": input_ids} - - model_inputs.update( - { - "position_ids": position_ids, - "past_key_values": past_key_values, - "use_cache": kwargs.get("use_cache"), - "attention_mask": attention_mask, - 
"token_idx": token_idx, - "labels": labels, - "use_flash_attention": use_flash_attention, - "flash_attention_recompute": flash_attention_recompute, - } - ) - - return model_inputs diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/mllama.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/mllama.py deleted file mode 100644 index 6ba0ffff..00000000 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/mllama.py +++ /dev/null @@ -1,292 +0,0 @@ -# coding=utf-8 -# Copyright 2024 the HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""PyTorch Mllama model.""" - -from typing import Optional, Tuple, List, Union - -import torch -import torch.utils.checkpoint - -from optimum.habana.transformers.models import GaudiMllamaForConditionalGeneration -from optimum.habana.transformers.models.mllama.modeling_mllama import ( - _prepare_cross_attention_mask, -) -from transformers.modeling_outputs import CausalLMOutputWithPast - - -class MllamaForConditionalGeneration(GaudiMllamaForConditionalGeneration): - - def forward( - self, - input_ids: Optional[torch.LongTensor] = None, - pixel_values: Optional[torch.FloatTensor] = None, - aspect_ratio_mask: Optional[torch.Tensor] = None, - aspect_ratio_ids: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - cross_attention_mask: Optional[torch.Tensor] = None, - cross_attention_states: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - cache_position: Optional[torch.LongTensor] = None, - num_logits_to_keep: int = 0, - token_idx: Optional[torch.Tensor] = None, - use_flash_attention: Optional[bool] = True, - flash_attention_recompute: Optional[bool] = True, - **kwargs, - ) -> Union[Tuple, CausalLMOutputWithPast]: - """ - Copied from MllamaForConditionalGeneration::forward: https://github.com/huggingface/transformers/blob/v4.45.2/src/transformers/models/mllama/modeling_mllama.py#L2077 - The only differences are: - - add token_idx input - - add use_flash_attention and flash_attention_recompute - """ - full_text_row_masked_out_mask = kwargs.get( - "full_text_row_masked_out_mask", None - ) - if (input_ids is None) ^ (inputs_embeds is not None): - raise ValueError( - "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one" - ) - - outputs = self.language_model( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - cross_attention_states=cross_attention_states, - cross_attention_mask=cross_attention_mask, - full_text_row_masked_out_mask=full_text_row_masked_out_mask, - past_key_values=past_key_values, - 
use_cache=use_cache, - inputs_embeds=inputs_embeds, - labels=labels, - output_hidden_states=output_hidden_states, - output_attentions=output_attentions, - return_dict=return_dict, - cache_position=cache_position, - num_logits_to_keep=num_logits_to_keep, - token_idx=token_idx, - use_flash_attention=use_flash_attention, - flash_attention_recompute=flash_attention_recompute, - ) - - logits = outputs[0] - if not return_dict: - output = (logits,) + outputs[1:] - return output - - return outputs - - def prepare_inputs_for_generation( - self, - input_ids=None, - inputs_embeds=None, - attention_mask=None, - position_ids=None, - pixel_values=None, - aspect_ratio_ids=None, - aspect_ratio_mask=None, - cross_attention_mask=None, - past_key_values=None, - use_cache=False, - cache_position=None, - num_logits_to_keep=None, - **kwargs, - ): - """ - Copied from MllamaForConditionalGeneration::prepare_inputs_for_generation: https://github.com/huggingface/transformers/blob/v4.45.2/src/transformers/models/mllama/modeling_mllama.py#L2208 - The only differences are: - - add token_idx handling - - add bucket_internal handling - - add use_flash_attention and flash_attention_recompute - """ - - token_idx = kwargs.get("token_idx", None) - if token_idx is None: - return super().prepare_inputs_for_generation( - input_ids=input_ids, - inputs_embeds=inputs_embeds, - attention_mask=attention_mask, - position_ids=position_ids, - pixel_values=pixel_values, - aspect_ratio_ids=aspect_ratio_ids, - aspect_ratio_mask=aspect_ratio_mask, - cross_attention_mask=cross_attention_mask, - past_key_values=past_key_values, - use_cache=use_cache, - cache_position=cache_position, - **kwargs, - ) - else: - use_flash_attention = kwargs.get("use_flash_attention", True) - flash_attention_recompute = kwargs.get("flash_attention_recompute", True) - position_ids = kwargs.get("position_ids", None) - output_attentions = kwargs.get("output_attentions", None) - output_hidden_states = kwargs.get("output_hidden_states", None) - return_dict = kwargs.get("return_dict", None) - labels = kwargs.get("labels", None) - cross_attention_states = kwargs.get("cross_attention_states", None) - - output_attentions = ( - output_attentions - if output_attentions is not None - else self.config.output_attentions - ) - output_hidden_states = ( - output_hidden_states - if output_hidden_states is not None - else self.config.output_hidden_states - ) - return_dict = ( - return_dict if return_dict is not None else self.config.use_return_dict - ) - bucket_internal = kwargs.get("bucket_internal", None) - - if past_key_values is not None: - if token_idx is not None: - input_ids = torch.index_select(input_ids, 1, token_idx - 1) - elif inputs_embeds is not None: # Exception 1 - input_ids = input_ids[:, -cache_position.shape[0] :] - elif ( - input_ids.shape[1] != cache_position.shape[0] - ): # Default case (the "else", a no op, is Exception 2) - input_ids = input_ids[:, cache_position] - elif bucket_internal and token_idx is not None: - # for the 1st token we can slice the inputs till token idx for the fwd pass. - input_ids = input_ids[:, :token_idx] - attention_mask = attention_mask[:, :token_idx] - if cross_attention_mask is not None: - cross_attention_mask = cross_attention_mask[:, :token_idx, ...] 
- - # TODO: we have no attention_mask so this won't work, check if we really won't need attention mask and find another way - if attention_mask is not None and position_ids is None: - # create position_ids on the fly for batch generation - position_ids = attention_mask.long().cumsum(-1) - 1 - position_ids.masked_fill_(attention_mask == 0, 1) - if past_key_values: - if token_idx is not None: - position_ids = torch.index_select( - position_ids, 1, token_idx - 1 - ) - else: - position_ids = position_ids[:, -input_ids.shape[1] :] - - # This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s `mode="reduce-overhead`, as otherwise the input `position_ids` would have various stride during the decoding. Here, simply using `.contiguous()` is not sufficient as in the batch size = 1 case, `position_ids` is already contiguous but with varying stride which retriggers a capture. - position_ids = position_ids.clone( - memory_format=torch.contiguous_format - ) - - if pixel_values is not None and inputs_embeds is not None: - raise ValueError( - "You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one" - ) - - if pixel_values is not None and cross_attention_states is not None: - raise ValueError( - "`pixel_values` and `cross_attention_states` cannot be provided simultaneously" - ) - - if pixel_values is not None: - if aspect_ratio_ids is None: - raise ValueError( - "`aspect_ratio_ids` must be provided if `pixel_values` is provided" - ) - # get vision tokens from vision model - vision_outputs = self.vision_model( - pixel_values=pixel_values, - aspect_ratio_ids=aspect_ratio_ids, - aspect_ratio_mask=aspect_ratio_mask, - output_hidden_states=output_hidden_states, - output_attentions=output_attentions, - return_dict=return_dict, - use_flash_attention=use_flash_attention, - ) - cross_attention_states = vision_outputs[0] - cross_attention_states = self.multi_modal_projector( - cross_attention_states - ).reshape(-1, cross_attention_states.shape[-2], self.hidden_size) - - if cross_attention_mask is not None: - cross_attention_mask, full_text_row_masked_out_mask = ( - _prepare_cross_attention_mask( - cross_attention_mask, - num_vision_tokens=self.vision_model.num_patches, - dtype=self.dtype, - token_idx=token_idx, - ) - ) - else: - full_text_row_masked_out_mask = None - - if cross_attention_mask is not None: - if cache_position is not None: - cross_attention_mask = cross_attention_mask[:, :, cache_position] - full_text_row_masked_out_mask = full_text_row_masked_out_mask[ - :, :, cache_position - ] - elif past_key_values is not None: - if token_idx is not None: - cross_attention_mask = torch.index_select( - cross_attention_mask, -2, token_idx - 1 - ) - full_text_row_masked_out_mask = torch.index_select( - full_text_row_masked_out_mask, -2, token_idx - 1 - ) - else: - cross_attention_mask = cross_attention_mask[:, :, -1:] - full_text_row_masked_out_mask = full_text_row_masked_out_mask[ - :, :, -1: - ] - - # if `inputs_embeds` are passed, we only want to use them in the 1st generation step - if inputs_embeds is not None and past_key_values is None: - model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None} - else: - # The clone here is for the same reason as for `position_ids`. 
- model_inputs = { - "input_ids": input_ids.clone(memory_format=torch.contiguous_format), - "inputs_embeds": None, - } - - if num_logits_to_keep is not None: - model_inputs["num_logits_to_keep"] = num_logits_to_keep - - # keep cache_position implementation as None for HPU - cache_position = None - - model_inputs.update( - { - "position_ids": position_ids, - "past_key_values": past_key_values, - "use_cache": kwargs.get("use_cache"), - "attention_mask": attention_mask, - "token_idx": token_idx, - "labels": labels, - "return_dict": kwargs.get("return_dict"), - "full_text_row_masked_out_mask": full_text_row_masked_out_mask, - "use_flash_attention": use_flash_attention, - "cross_attention_mask": cross_attention_mask, - "cross_attention_states": cross_attention_states, - "output_attentions": output_attentions, - "flash_attention_recompute": flash_attention_recompute, - } - ) - - return model_inputs diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/qwen2_5_vl.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/qwen2_5_vl.py index a80a86a7..ac1578e9 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/qwen2_5_vl.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/qwen2_5_vl.py @@ -54,7 +54,8 @@ import habana_frameworks.torch as htorch # Copied from: https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen2_5_vl/processing_qwen2_5_vl.py from typing import Union from transformers.feature_extraction_utils import BatchFeature -from transformers.image_utils import ImageInput, VideoInput +from transformers.image_utils import ImageInput +from transformers.video_utils import VideoInput from transformers.processing_utils import ( ProcessingKwargs, ProcessorMixin, diff --git a/backends/gaudi/server/text_generation_server/models/galactica.py b/backends/gaudi/server/text_generation_server/models/galactica.py deleted file mode 100644 index 7c4e462c..00000000 --- a/backends/gaudi/server/text_generation_server/models/galactica.py +++ /dev/null @@ -1,156 +0,0 @@ -import re -import torch -import torch.distributed - - -from transformers import ( - PreTrainedTokenizerBase, -) -from text_generation_server.models.causal_lm import CausalLMBatch -from text_generation_server.pb import generate_pb2 -from text_generation_server.utils import ( - NextTokenChooser, - StoppingCriteria, -) -from text_generation_server.utils.chunks import concat_text_chunks - -# CREDIT: Papers with code => https://github.com/paperswithcode/galai/blob/main/galai/utils.py - -# we split individual characters inside special tokens like [START_DNA] -CUSTOM_SEQ_RE = re.compile(r"(\[START_(DNA|SMILES|I_SMILES|AMINO)])(.*?)(\[END_\2])") - -# token added to implement a custom sequence tokenization. This token is added at -# corpus cleaning step and removed in pretokenization. The digits are added to increase the chance -# that they do not occur in the corpus. The digits are escaped so that the token does not appear -# literally in the source code in case we ever include it in the training data. -SPLIT_MARKER = f"SPL{1}T-TH{1}S-Pl3A5E" - - -def _insert_split_marker(m: re.Match): - """ - Applies split marker based on a regex match of special tokens such as - [START_DNA]. 
- Parameters - ---------- - n : str - Input text to split - Returns - ---------- - str - the text with the split token added - """ - start_token, _, sequence, end_token = m.groups() - sequence = re.sub(r"(.)", rf"{SPLIT_MARKER}\1", sequence, flags=re.DOTALL) - return f"{start_token}{sequence}{SPLIT_MARKER}{end_token}" - - -def escape_custom_split_sequence(text): - """ - Applies custom splitting to the text for GALILEO's tokenization - Parameters - ---------- - text : str - Input text to split - Returns - ---------- - str - the text with the split token added - """ - return CUSTOM_SEQ_RE.sub(_insert_split_marker, text) - - -# END CREDIT - - -class GalacticaCausalLMBatch(CausalLMBatch): - @classmethod - def from_pb( - cls, - pb: generate_pb2.Batch, - tokenizer: PreTrainedTokenizerBase, - dtype: torch.dtype, - device: torch.device, - ) -> "GalacticaCausalLMBatch": - inputs = [] - next_token_choosers = [] - stopping_criterias = [] - prefix_offsets = [] - top_n_tokens = [] - read_offsets = [] - requests_idx_mapping = {} - - # Parse batch - max_truncation = 0 - padding_right_offset = 0 - max_decode_tokens = 0 - for i, r in enumerate(pb.requests): - requests_idx_mapping[r.id] = i - # Add escape_custom_split_sequence to the CausalLMBatch logic - inputs.append( - escape_custom_split_sequence(concat_text_chunks(r.input_chunks.chunks)) - ) - next_token_choosers.append( - NextTokenChooser.from_pb(r.parameters, device, tokenizer) - ) - stopping_criteria = StoppingCriteria.from_pb( - r.stopping_parameters, tokenizer - ) - stopping_criterias.append(stopping_criteria) - top_n_tokens.append(r.top_n_tokens) - max_truncation = max(max_truncation, r.truncate) - max_decode_tokens += stopping_criteria.max_new_tokens - padding_right_offset = max( - padding_right_offset, stopping_criteria.max_new_tokens - ) - - tokenized_inputs = tokenizer( - inputs, - return_tensors="pt", - padding=True, - return_token_type_ids=False, - truncation=True, - max_length=max_truncation, - ).to(device) - for _ in pb.requests: - input_len = tokenized_inputs["input_ids"].shape[1] - prefix_offsets.append(0) - read_offsets.append(input_len) - - input_lengths = tokenized_inputs["attention_mask"].sum(1) - max_input_length = input_lengths.max() - - input_ids = tokenized_inputs["input_ids"] - # Allocate maximum attention_mask - attention_mask = input_ids.new_zeros( - (pb.size, max_input_length + padding_right_offset) - ) - # Copy tokenizer attention_mask into fully allocated attention_mask - attention_mask[:, :max_input_length] = tokenized_inputs["attention_mask"] - - position_ids = tokenized_inputs["attention_mask"].long().cumsum(-1) - 1 - position_ids.masked_fill_(tokenized_inputs["attention_mask"] == 0, 1) - all_input_ids = tokenized_inputs["input_ids"].T.split(1, dim=1) - top_n_tokens_tensor = torch.tensor( - top_n_tokens, device=device, dtype=torch.int64 - ) - - max_tokens = len(inputs) * max_input_length + max_decode_tokens - - return cls( - batch_id=pb.id, - requests=pb.requests, - requests_idx_mapping=requests_idx_mapping, - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=None, - all_input_ids=list(all_input_ids), - input_lengths=input_lengths.tolist(), - prefix_offsets=prefix_offsets, - read_offsets=read_offsets, - next_token_choosers=next_token_choosers, - stopping_criterias=stopping_criterias, - top_n_tokens=top_n_tokens, - top_n_tokens_tensor=top_n_tokens_tensor, - max_input_length=max_input_length.item(), - padding_right_offset=padding_right_offset, - max_tokens=max_tokens, - ) 
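For reference, the character-level splitting that galactica.py applied inside special sequences such as [START_DNA]...[END_DNA] (deleted above) can be exercised on its own. The sketch below re-creates the removed regex helpers purely for illustration; the demo call at the end and the wording of its comments are assumptions for this example, not part of the patch itself.

    import re

    # Split individual characters inside special sequences like [START_DNA]...[END_DNA].
    CUSTOM_SEQ_RE = re.compile(r"(\[START_(DNA|SMILES|I_SMILES|AMINO)])(.*?)(\[END_\2])")

    # Marker inserted before every character of the wrapped sequence; the digits are
    # interpolated so the literal marker never appears verbatim in source code or corpora.
    SPLIT_MARKER = f"SPL{1}T-TH{1}S-Pl3A5E"

    def _insert_split_marker(m: re.Match) -> str:
        # Prefix each character of the matched sequence with the split marker so the
        # tokenizer falls back to character-level pieces inside the span.
        start_token, _, sequence, end_token = m.groups()
        sequence = re.sub(r"(.)", rf"{SPLIT_MARKER}\1", sequence, flags=re.DOTALL)
        return f"{start_token}{sequence}{SPLIT_MARKER}{end_token}"

    def escape_custom_split_sequence(text: str) -> str:
        return CUSTOM_SEQ_RE.sub(_insert_split_marker, text)

    # Illustrative usage: every base in the DNA span gets a marker prefix.
    print(escape_custom_split_sequence("[START_DNA]ACGT[END_DNA]"))

This is the preprocessing step that GalacticaCausalLMBatch.from_pb ran on each request's concatenated input chunks before handing them to the tokenizer.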
diff --git a/backends/gaudi/server/text_generation_server/models/globals.py b/backends/gaudi/server/text_generation_server/models/globals.py index cd221e14..cdde67ca 100644 --- a/backends/gaudi/server/text_generation_server/models/globals.py +++ b/backends/gaudi/server/text_generation_server/models/globals.py @@ -4,14 +4,14 @@ from loguru import logger from text_generation_server.utils.log import log_master REQUEST_LOGPROBS = os.getenv("REQUEST_LOGPROBS", "0").lower() in {"1", "true"} -ATTENTION = os.getenv("ATTENTION", "default") +ATTENTION = os.getenv("ATTENTION", "paged") # default_prefix_caching = "1" if ATTENTION in {"flashinfer", "flashdecoding"} else "0" PREFIX_CACHING = os.getenv("PREFIX_CACHING", "0").lower() in { "1", "true", } log_master(logger.info, f"Using prefix caching = {PREFIX_CACHING}") -_expected = {"paged", "default"} +_expected = {"paged"} assert ( ATTENTION in _expected ), f"Attention is not valid {ATTENTION}, expected {_expected}" diff --git a/backends/gaudi/server/text_generation_server/models/idefics_causal_lm.py b/backends/gaudi/server/text_generation_server/models/idefics_causal_lm.py deleted file mode 100644 index 98d7352a..00000000 --- a/backends/gaudi/server/text_generation_server/models/idefics_causal_lm.py +++ /dev/null @@ -1,882 +0,0 @@ -from io import BytesIO -from PIL import Image -import torch -import time - -from dataclasses import dataclass -from opentelemetry import trace -from transformers import ( - AutoConfig, - AutoProcessor, - AutoTokenizer, - PreTrainedTokenizerBase, - ProcessorMixin, -) -from typing import Optional, Tuple, List, Type, Dict - -from text_generation_server.models import Model -from text_generation_server.models.types import ( - Batch, - Tokens, - Generation, - GeneratedText, -) -from text_generation_server.pb import generate_pb2 -from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling -import torch.distributed -from text_generation_server.models.custom_modeling.idefics_modeling import ( - IdeficsForVisionText2Text, -) -from text_generation_server.utils import ( - initialize_torch_distributed, - weight_files, - Weights, -) -from text_generation_server.utils.quantization import get_loader - -tracer = trace.get_tracer(__name__) - - -@dataclass -class IdeficsCausalLMBatch(Batch): - batch_id: int - requests: List[generate_pb2.Request] - requests_idx_mapping: Dict[int, int] - - # Decoder values - input_ids: torch.Tensor - attention_mask: torch.Tensor - position_ids: torch.Tensor - pixel_values: Optional[torch.Tensor] - image_hidden_states: Optional[torch.Tensor] - image_attention_mask: Optional[torch.Tensor] - past_key_values: Optional[List[Tuple]] - - # All tokens - all_input_ids: List[torch.Tensor] - - # Lengths of all generations present in the batch - input_lengths: List[int] - prefix_offsets: List[int] - read_offsets: List[int] - - # Generation helpers - next_token_choosers: List[NextTokenChooser] - stopping_criterias: List[StoppingCriteria] - - # Metadata used for padding - max_input_length: int - padding_right_offset: int - - # Maximum number of tokens this batch will grow to - max_tokens: int - - # Past metadata - keys_head_dim_last: bool = True - - def to_pb(self) -> generate_pb2.CachedBatch: - return generate_pb2.CachedBatch( - id=self.batch_id, - request_ids=[r.id for r in self.requests], - size=len(self), - max_tokens=self.max_tokens, - ) - - @classmethod - def from_pb( - cls, - pb: generate_pb2.Batch, - tokenizer: PreTrainedTokenizerBase, - dtype: torch.dtype, - device: torch.device, - ) -> 
"IdeficsCausalLMBatch": - raise NotImplementedError - - @classmethod - def from_pb_processor( - cls, - pb: generate_pb2.Batch, - tokenizer: PreTrainedTokenizerBase, - processor: ProcessorMixin, # Hack - config, - dtype: torch.dtype, - device: torch.device, - ) -> "IdeficsCausalLMBatch": - inputs = [] - next_token_choosers = [] - stopping_criterias = [] - prefix_offsets = [] - read_offsets = [] - requests_idx_mapping = {} - - # Parse batch - max_truncation = 0 - padding_right_offset = 0 - max_decode_tokens = 0 - for i, r in enumerate(pb.requests): - requests_idx_mapping[r.id] = i - inputs.append(r.input_chunks.chunks) - next_token_choosers.append( - NextTokenChooser.from_pb(r.parameters, device, tokenizer) - ) - stopping_criteria = StoppingCriteria.from_pb( - r.stopping_parameters, tokenizer - ) - stopping_criterias.append(stopping_criteria) - max_truncation = max(max_truncation, r.truncate) - max_decode_tokens += stopping_criteria.max_new_tokens - padding_right_offset = max( - padding_right_offset, stopping_criteria.max_new_tokens - ) - - # TODO Check impact on idefics - prompts = [] - for inp in inputs: - # Each input is encoded into a list, where each element of this input list is either a string or a URL - prompt = [] - for chunk in inp: - chunk_type = chunk.WhichOneof("chunk") - if chunk_type == "text": - prompt.append(chunk.text) - elif chunk_type == "image": - image = Image.open(BytesIO(chunk.image.data)) - prompt.append(image) - else: - raise RuntimeError(f"Invalid chunk type {chunk_type}") - prompts.append(prompt) - - # The processor replaces the call to tokenizer, and - # a/ takes care of fetching images from the URL - # b/ generate the correct input_ids, attention_mask, pixel_values, image_attention_mask to feed to the model - tokenized_inputs = processor( - prompts, - return_tensors="pt", - padding=True, - truncation=True, - max_length=max_truncation, - # TODO Check impact on idefics - # add_end_of_utterance_token=False, # Already taken care of inside the prompts, so bypassing the processor's handling of this token - ).to(device) - for _ in pb.requests: - input_len = tokenized_inputs["input_ids"].shape[1] - prefix_offsets.append( - input_len - 5 - ) # To decode without potential fallbacks errors - read_offsets.append( - input_len - ) # To decode without potential fallbacks errors - - input_lengths = tokenized_inputs["attention_mask"].sum(1) - max_input_length = input_lengths.max() - - input_ids = tokenized_inputs["input_ids"] - pixel_values = tokenized_inputs.get("pixel_values", None) - image_hidden_states = None - # Allocate maximum attention_mask - attention_mask = input_ids.new_zeros( - (pb.size, max_input_length + padding_right_offset) - ) - # Copy tokenizer attention_mask into fully allocated attention_mask - attention_mask[:, :max_input_length] = tokenized_inputs["attention_mask"] - # Do the same for image_attention_mask - if pixel_values is None: - image_attention_mask = None - else: - image_attention_mask = input_ids.new_zeros( - ( - pb.size, - max_input_length + padding_right_offset, - pixel_values.size(1), - ) - ) - image_attention_mask[:, :max_input_length, :] = tokenized_inputs[ - "image_attention_mask" - ] - - position_ids = tokenized_inputs["attention_mask"].long().cumsum(-1) - 1 - position_ids.masked_fill_(tokenized_inputs["attention_mask"] == 0, 1) - all_input_ids = tokenized_inputs["input_ids"].T.split( - 1, dim=1 - ) # It's input_ids but splitted into a tuple of tensors where each tensor is (seq_len, 1) size. 
It is then transformed into a list - - max_tokens = len(inputs) * (max_input_length + max_decode_tokens) - - return cls( - batch_id=pb.id, - requests=pb.requests, - requests_idx_mapping=requests_idx_mapping, - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - pixel_values=pixel_values, - image_hidden_states=image_hidden_states, - image_attention_mask=image_attention_mask, - past_key_values=None, - all_input_ids=list(all_input_ids), - input_lengths=input_lengths.tolist(), - prefix_offsets=prefix_offsets, - read_offsets=read_offsets, - next_token_choosers=next_token_choosers, - stopping_criterias=stopping_criterias, - max_input_length=max_input_length.item(), - padding_right_offset=padding_right_offset, - max_tokens=max_tokens, - ) - - @tracer.start_as_current_span("filter") - def filter(self, request_ids: List[int]) -> Optional["IdeficsCausalLMBatch"]: - # It deletes requests from the batch. For instance when client lost connection - if len(request_ids) == 0: - raise ValueError("Batch must have at least one request") - if len(request_ids) == len(self): - return self - - keep_indices = [] - - # New values after filtering - requests_idx_mapping = {} - requests = [] - input_lengths = [] - prefix_offsets = [] - read_offsets = [] - all_input_ids = [] - max_input_length = 0 - - next_token_choosers = [] - stopping_criterias = [] - - total_remaining_decode_tokens = 0 - new_padding_right_offset = 0 - - for i, request_id in enumerate(request_ids): - idx = self.requests_idx_mapping[request_id] - requests_idx_mapping[request_id] = i - keep_indices.append(idx) - - requests.append(self.requests[idx]) - prefix_offsets.append(self.prefix_offsets[idx]) - read_offsets.append(self.read_offsets[idx]) - all_input_ids.append(self.all_input_ids[idx]) - - request_input_length = self.input_lengths[idx] - input_lengths.append(request_input_length) - max_input_length = max(max_input_length, request_input_length) - - next_token_choosers.append(self.next_token_choosers[idx]) - stopping_criteria = self.stopping_criterias[idx] - stopping_criterias.append(stopping_criteria) - remaining_decode_tokens = ( - stopping_criteria.max_new_tokens - stopping_criteria.current_tokens - ) - total_remaining_decode_tokens += remaining_decode_tokens - new_padding_right_offset = max( - new_padding_right_offset, remaining_decode_tokens - ) - - # Apply indices to input_ids, attention mask, past key values and other items that need to be cached - input_ids = self.input_ids[keep_indices] - position_ids = self.position_ids[keep_indices] - self.attention_mask = self.attention_mask[ - keep_indices, - -(self.padding_right_offset + max_input_length) : ( - self.attention_mask.shape[1] - self.padding_right_offset - ) - + new_padding_right_offset, - ] - # Do the same for pixel_values and image_attention_mask - pixel_values = self.pixel_values[keep_indices] - self.image_attention_mask = self.image_attention_mask[ - keep_indices, - -(self.padding_right_offset + max_input_length) : ( - self.image_attention_mask.shape[1] - self.padding_right_offset - ) - + new_padding_right_offset, - :, - ] - if self.image_hidden_states is None: - image_hidden_states = None - else: - image_hidden_states = self.image_hidden_states[keep_indices] - - # Ensure that past_key_values tensors can be updated in-place - if type(self.past_key_values[0]) is tuple: - self.past_key_values = [list(layer) for layer in self.past_key_values] - - # Update tensors in-place to allow incremental garbage collection - past_kv_length = max_input_length - 1 - 
for layer in self.past_key_values: - past_keys, past_values = layer - if len(past_keys.shape) == 3: - # Force past to be of dim [self_size, num_heads, ...] for easy indexing - past_keys = past_keys.view(len(self), -1, *past_keys.shape[-2:]) - past_values = past_values.view(len(self), -1, *past_values.shape[-2:]) - if self.keys_head_dim_last: - layer[0] = past_keys[keep_indices, :, -past_kv_length:, :] - else: - layer[0] = past_keys[keep_indices, :, :, -past_kv_length:] - del past_keys - layer[1] = past_values[keep_indices, :, -past_kv_length:, :] - del past_values - - max_tokens = len(request_ids) * max_input_length + total_remaining_decode_tokens - - self.requests = requests - self.requests_idx_mapping = requests_idx_mapping - self.input_ids = input_ids - self.pixel_values = pixel_values - self.image_hidden_states = image_hidden_states - self.position_ids = position_ids - self.all_input_ids = all_input_ids - self.input_lengths = input_lengths - self.prefix_offsets = prefix_offsets - self.read_offsets = read_offsets - self.next_token_choosers = next_token_choosers - self.stopping_criterias = stopping_criterias - self.max_input_length = max_input_length - self.padding_right_offset = new_padding_right_offset - self.max_tokens = max_tokens - - return self - - @classmethod - @tracer.start_as_current_span("concatenate") - def concatenate( - cls, batches: List["IdeficsCausalLMBatch"] - ) -> "IdeficsCausalLMBatch": - # It adds new requests to the batch - # Used for padding - total_batch_size = 0 - max_input_length = 0 - max_num_images = 0 - padding_right_offset = 0 - for batch in batches: - total_batch_size += len(batch) - max_input_length = max(max_input_length, batch.max_input_length) - max_num_images = max(max_num_images, batch.pixel_values.size(1)) - padding_right_offset = max(padding_right_offset, batch.padding_right_offset) - - # Batch attributes - requests = [] - requests_idx_mapping = {} - input_lengths = [] - prefix_offsets = [] - read_offsets = [] - all_input_ids = [] - next_token_choosers = [] - stopping_criterias = [] - max_tokens = 0 - - # Batch tensors - input_ids = None - attention_mask = None - position_ids = None - pixel_values = None - image_hidden_states = None - image_attention_mask = None - past_key_values = [] - - # Used for slicing correctly inside the tensors - # Equivalent to a cumsum on batch sizes - start_index = 0 - for i, batch in enumerate(batches): - requests.extend(batch.requests) - input_lengths.extend(batch.input_lengths) - prefix_offsets.extend(batch.prefix_offsets) - read_offsets.extend(batch.read_offsets) - all_input_ids.extend(batch.all_input_ids) - next_token_choosers.extend(batch.next_token_choosers) - stopping_criterias.extend(batch.stopping_criterias) - - if i == 0: - requests_idx_mapping = batch.requests_idx_mapping - else: - # We need to offset the mapping for each batch by the cumulative batch size - for k, v in batch.requests_idx_mapping.items(): - requests_idx_mapping[k] = v + start_index - - # Slicing end index for this batch - end_index = start_index + len(batch) - - # We only concatenate batches that did at least one step - if batch.past_key_values is None: - raise ValueError("only concatenate prefilled batches") - - # Create empty tensor - # input_ids is always of shape [batch_size, 1] - # We do not need to pad it - if input_ids is None: - input_ids = batch.input_ids.new_empty((total_batch_size, 1)) - # Copy to correct indices - input_ids[start_index:end_index] = batch.input_ids - - # Create padded tensor - if attention_mask is None: - 
attention_mask = batch.attention_mask.new_zeros( - (total_batch_size, max_input_length + padding_right_offset), - ) - - curr_batch_max_num_images = batch.pixel_values.size(1) - if pixel_values is None: - pixel_values = batch.pixel_values.new_zeros( - (total_batch_size, max_num_images, 3, 224, 224) - ) - pixel_values[start_index:end_index, :curr_batch_max_num_images] = ( - batch.pixel_values - ) - - if image_attention_mask is None: - image_attention_mask = batch.image_attention_mask.new_zeros( - ( - total_batch_size, - max_input_length + padding_right_offset, - max_num_images, - ) - ) - - # We need to slice the attention mask to remove padding from previous steps - # and to remove unused allocated space - left_offset = max_input_length - batch.max_input_length - batch_left_offset = ( - batch.attention_mask.shape[1] - - batch.max_input_length - - batch.padding_right_offset - ) - attention_mask[ - start_index:end_index, - left_offset:-padding_right_offset, - ] = batch.attention_mask[ - :, - batch_left_offset : -batch.padding_right_offset, - ] - image_attention_mask[ - start_index:end_index, - left_offset:-padding_right_offset, - :curr_batch_max_num_images, - ] = batch.image_attention_mask[ - :, batch_left_offset : -batch.padding_right_offset, : - ] - - # Create empty tensor - # position_ids is always of shape [batch_size, 1] - if position_ids is None: - position_ids = batch.position_ids.new_empty((total_batch_size, 1)) - position_ids[start_index:end_index] = batch.position_ids - - # Shenanigans to get dimensions because BLOOM outputs a past with a different shape - # BLOOM Keys: [batch_size * num_heads, head_dim, seq_length] - # BLOOM Values: [batch_size * num_heads, seq_length, head_dim] - # And ensure that we can update tensors in-place - if isinstance(batch.past_key_values[0], tuple): - batch.past_key_values = [ - [t.view(len(batch), -1, *t.shape[-2:]) for t in layer] - for layer in batch.past_key_values - ] - elif len(batch.past_key_values[0][0].shape) == 3: - for layer in batch.past_key_values: - for k, t in enumerate(layer): - layer[k] = t.view(len(batch), -1, *t.shape[-2:]) - - # Add eventual padding tokens that were added while concatenating - max_tokens += batch.max_tokens + ( - max_input_length - batch.max_input_length - ) * len(batch) - - start_index = end_index - - first_past_kvs = batches[0].past_key_values - _, num_heads, padded_sequence_length, head_dim = first_past_kvs[0][1].shape - - padded_past_values_shape = ( - total_batch_size, - num_heads, - max_input_length - 1, - head_dim, - ) - - if batches[0].keys_head_dim_last: - padded_past_keys_shape = padded_past_values_shape - else: - # seq_length is last for BLOOM - padded_past_keys_shape = ( - total_batch_size, - num_heads, - head_dim, - max_input_length - 1, - ) - - # Iterate over attention layers - # Concatenate past key values layer by layer to allow incremental garbage collection - for j in range(len(first_past_kvs)): - padded_past_keys = first_past_kvs[j][0].new_zeros(padded_past_keys_shape) - start_index = 0 - for batch in batches: - past_keys = batch.past_key_values[j][0] - # Clear reference to the original tensor - batch.past_key_values[j][0] = None - - # Slicing end index for this batch - end_index = start_index + len(batch) - # We slice the keys to remove the padding from previous batches - past_seq_len = batch.max_input_length - 1 - if batch.keys_head_dim_last: - padded_past_keys[start_index:end_index, :, -past_seq_len:, :] = ( - past_keys[:, :, -past_seq_len:, :] - ) - else: - # BLOOM case - 
padded_past_keys[start_index:end_index, :, :, -past_seq_len:] = ( - past_keys[:, :, :, -past_seq_len:] - ) - del past_keys - - start_index = end_index - - padded_past_values = first_past_kvs[j][1].new_zeros( - padded_past_values_shape - ) - start_index = 0 - for batch in batches: - past_values = batch.past_key_values[j][1] - # Clear reference to the original tensor - batch.past_key_values[j][1] = None - - # Slicing end index for this batch - end_index = start_index + len(batch) - # We slice the past values to remove the padding from previous batches - past_seq_len = batch.max_input_length - 1 - padded_past_values[start_index:end_index, :, -past_seq_len:, :] = ( - past_values[:, :, -past_seq_len:, :] - ) - del past_values - - # Update values - start_index = end_index - - past_key_values.append([padded_past_keys, padded_past_values]) - - return cls( - batch_id=batches[0].batch_id, - requests=requests, - requests_idx_mapping=requests_idx_mapping, - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - pixel_values=pixel_values, - image_hidden_states=image_hidden_states, - image_attention_mask=image_attention_mask, - past_key_values=past_key_values, - all_input_ids=all_input_ids, - input_lengths=input_lengths, - prefix_offsets=prefix_offsets, - read_offsets=read_offsets, - next_token_choosers=next_token_choosers, - stopping_criterias=stopping_criterias, - max_input_length=max_input_length, - padding_right_offset=padding_right_offset, - keys_head_dim_last=batches[0].keys_head_dim_last, - max_tokens=max_tokens, - ) - - def __len__(self): - return len(self.requests) - - -class IdeficsCausalLM(Model): - def __init__( - self, - model_id: str, - revision: Optional[str] = None, - quantize: Optional[str] = None, - speculator: Optional[str] = None, - dtype: Optional[torch.dtype] = None, - trust_remote_code: bool = False, - ): - self.quantize = quantize - self.process_group, rank, world_size = initialize_torch_distributed() - device = torch.device("hpu") - dtype = torch.bfloat16 if dtype is None else dtype - self.device, self.dtype = device, dtype - - config = AutoConfig.from_pretrained( - model_id, - revision=revision, - trust_remote_code=trust_remote_code, - ) - config.quantize = quantize - config.speculator = speculator - config.vision_config.quantize = quantize - - tokenizer = AutoTokenizer.from_pretrained( - model_id, - revision=revision, - padding_side="left", - truncation_side="left", - trust_remote_code=trust_remote_code, - ) - self.processor = AutoProcessor.from_pretrained( - model_id, - revision=revision, - padding_side="left", - truncation_side="left", - trust_remote_code=trust_remote_code, - ) - - weights_loader = get_loader( - quantize=quantize, model_id=model_id, revision=revision - ) - torch.distributed.barrier(group=self.process_group) - filenames = weight_files(model_id, revision=revision, extension=".safetensors") - weights = Weights( - filenames, - device=device, - dtype=dtype, - process_group=self.process_group, - weights_loader=weights_loader, - ) - - model = IdeficsForVisionText2Text(config, weights) - - self.config = config - - torch.distributed.barrier(group=self.process_group) - super().__init__( - model_id=model_id, - model=model, - tokenizer=tokenizer, - requires_padding=True, - dtype=dtype, - device=device, - rank=rank, - world_size=world_size, - ) - - @property - def batch_type(self) -> Type[IdeficsCausalLMBatch]: - return IdeficsCausalLMBatch - - def forward( - self, - input_ids, - attention_mask, - position_ids, - pixel_values, - 
image_hidden_states, - image_attention_mask, - past_key_values: Optional = None, - ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]: - # Model Forward - kwargs = { - "input_ids": input_ids, - "attention_mask": attention_mask, - "pixel_values": pixel_values, - "image_hidden_states": image_hidden_states, - "image_attention_mask": image_attention_mask, - "past_key_values": past_key_values, - "use_cache": True, - "return_dict": True, - } - if self.has_position_ids: - kwargs["position_ids"] = position_ids - - outputs, speculative_logits = self.model.forward(**kwargs) - return ( - outputs.logits, - speculative_logits, - outputs.past_key_values, - outputs.image_hidden_states, - ) - - @tracer.start_as_current_span("generate_token") - def generate_token( - self, batch: IdeficsCausalLMBatch - ) -> Tuple[List[Generation], Optional[IdeficsCausalLMBatch], Tuple[int, int]]: - start = time.time_ns() - # slice the attention mask to the correct shape - attention_mask = batch.attention_mask[:, : -batch.padding_right_offset] - if batch.image_attention_mask is None: - image_attention_mask = None - else: - if batch.input_ids.size(1) == 1: - # THIS is a hack: when calling idefics.generate, the first time, we need the whole image_attention_mask (size bs x max_seq_len x max_num_images), - # but the subsequent times, we only need the last attention mask along the `max_seq_len` dimension - # this is due to the nature IDEFICS: it's an encoder decoder, and so when decoding, only the currently generated - # token need to attend to the encoder hidden states (i.e. the vision encoder) - # Also see seq2seq_lm.Seq2SeqLM.generate_token which has roughly the same logic - image_attention_mask = batch.image_attention_mask[ - :, -(batch.padding_right_offset + 1) - ].unsqueeze(1) - else: - image_attention_mask = batch.image_attention_mask[ - :, : -batch.padding_right_offset - ] - - logits, speculative_logits, past, image_hidden_states = self.forward( - input_ids=batch.input_ids, - attention_mask=attention_mask, - position_ids=batch.position_ids, - pixel_values=batch.pixel_values, - image_hidden_states=batch.image_hidden_states, - image_attention_mask=image_attention_mask, - past_key_values=batch.past_key_values, - ) - # Hardcoded remove image tokens - logits[:, 32000:32001] = torch.finfo(logits.dtype).min - - start_decode = time.time_ns() - - # Results - generations: List[Generation] = [] - stopped = True - - # Zipped iterator - iterator = zip( - batch.requests, - batch.input_lengths, - batch.prefix_offsets, - batch.read_offsets, - logits, - batch.next_token_choosers, - batch.stopping_criterias, - batch.all_input_ids, - ) - - # For each member of the batch - for i, ( - request, - input_length, - prefix_offset, - read_offset, - logits, - next_token_chooser, - stopping_criteria, - all_input_ids, - ) in enumerate(iterator): - # Select next token - next_token_id, logprobs = next_token_chooser( - all_input_ids.view(1, -1), logits[-1:, :] - ) - - # Append next token to all tokens - all_input_ids = torch.cat([all_input_ids, next_token_id]) - new_input_length = input_length + 1 - - # Generated token - next_token_logprob = logprobs[-1, next_token_id] - next_token_id_squeezed = next_token_id.squeeze() - next_token_text, prefix_offset, read_offset = self.decode_token( - all_input_ids[:, 0], prefix_offset, read_offset - ) - - # Evaluate stopping criteria - stop, reason = stopping_criteria( - next_token_id_squeezed, - next_token_text, - ) - - if not stop: - stopped = False - - # Shard generations - # All generations will be 
appended in the rust sharded client - if i % self.world_size == self.rank: - if stop: - # Decode generated tokens - output_text, _, _ = self.decode_token( - all_input_ids[:, 0], - prefix_offset=len(all_input_ids) - - stopping_criteria.current_tokens - - 1, - read_offset=len(all_input_ids) - - stopping_criteria.current_tokens, - skip_special_tokens=True, - ) - # Get seed - if isinstance(next_token_chooser.choice, Sampling): - seed = next_token_chooser.choice.seed - else: - seed = None - - generated_text = GeneratedText( - output_text, stopping_criteria.current_tokens, reason, seed - ) - else: - generated_text = None - - # Prefill - if stopping_criteria.current_tokens == 1 and request.prefill_logprobs: - # Remove generated token to only have prefill and add nan for first prompt token - prefill_logprobs = [float("nan")] + torch.log_softmax( - logits, -1 - ).gather(1, all_input_ids[1:]).squeeze(1)[ - -new_input_length:-1 - ].tolist() - prefill_token_ids = all_input_ids[-new_input_length:-1] - prefill_texts = self.tokenizer.batch_decode( - prefill_token_ids, - clean_up_tokenization_spaces=False, - skip_special_tokens=False, - ) - prefill_tokens = Tokens( - prefill_token_ids, - prefill_logprobs, - prefill_texts, - is_special=[], - ) - else: - prefill_tokens = None - - top_tokens = None - - generation = Generation( - request.id, - prefill_tokens, - Tokens( - [next_token_id_squeezed], - [next_token_logprob], - [next_token_text], - [next_token_id_squeezed.item() in self.all_special_ids], - ), - generated_text, - top_tokens, - ) - - generations.append(generation) - - # Update values - batch.next_token_choosers[i] = batch.next_token_choosers[i].advance_grammar( - next_token_id_squeezed.item() - ) - batch.input_ids[i, 0] = next_token_id - batch.all_input_ids[i] = all_input_ids - batch.input_lengths[i] = new_input_length - batch.prefix_offsets[i] = prefix_offset - batch.read_offsets[i] = read_offset - batch.max_input_length = max(batch.max_input_length, new_input_length) - - # We finished all generations in the batch; there is no next batch - if stopped: - forward_ns = start_decode - start - decode_ns = time.time_ns() - start_decode - return generations, None, (forward_ns, decode_ns) - - # Slice unused values from prefill - batch.input_ids = batch.input_ids[:, :1] - - # Update attention_mask as we added a new token to input_ids - batch.attention_mask[:, -batch.padding_right_offset] = 1 - batch.image_attention_mask[:, -batch.padding_right_offset, :] = ( - batch.image_attention_mask[:, -(batch.padding_right_offset + 1), :] - ) - # Decrease right offset - batch.padding_right_offset -= 1 - - # Update position_ids - batch.position_ids = batch.position_ids[:, -1:] + 1 - - # Update past key values - batch.past_key_values = past - batch.image_hidden_states = image_hidden_states - - forward_ns = start_decode - start - decode_ns = time.time_ns() - start_decode - return generations, batch, (forward_ns, decode_ns) diff --git a/backends/gaudi/server/text_generation_server/models/mamba.py b/backends/gaudi/server/text_generation_server/models/mamba.py deleted file mode 100644 index f6dcde68..00000000 --- a/backends/gaudi/server/text_generation_server/models/mamba.py +++ /dev/null @@ -1,814 +0,0 @@ -import torch -import torch.distributed -from transformers import AutoTokenizer, PreTrainedTokenizerBase -from typing import Optional -from text_generation_server.models.custom_modeling.mamba_modeling import ( - MambaConfig, -) -from loguru import logger -from text_generation_server.pb import generate_pb2 -from 
text_generation_server.utils import ( - initialize_torch_distributed, - weight_files, - Weights, -) -from text_generation_server.models.globals import CUDA_GRAPHS, MEM_POOL -import time -from text_generation_server.models.custom_modeling.mamba_modeling import ( - MambaModel, - InferenceParams, -) -from text_generation_server.models import Model -from typing import Any, List, Tuple, Type, Dict -from text_generation_server.models.types import ( - Batch, - Tokens, - Generation, - GeneratedText, -) -from text_generation_server.utils.chunks import concat_text_chunks -from text_generation_server.utils.quantization import get_loader -from text_generation_server.utils.tokens import batch_top_tokens, Sampling -from dataclasses import dataclass -from text_generation_server.utils import NextTokenChooser, StoppingCriteria - - -def new_inference_params( - n_blocks: int, - batch_size: int, - d_inner: int, - d_conv: int, - d_state: int, - seqlen_offset: int, - dtype: torch.dtype, - device: torch.device, -): - max_seqlen = 0 - conv_states = torch.zeros( - ( - n_blocks, - batch_size, - d_inner, - d_conv, - ), - device=device, - dtype=dtype, - ) - ssm_states = torch.zeros( - ( - n_blocks, - batch_size, - d_inner, - d_state, - ), - device=device, - dtype=dtype, - ) - inference_params = InferenceParams( - max_seqlen=max_seqlen, - max_batch_size=batch_size, - seqlen_offset=seqlen_offset, - conv_states=conv_states, - ssm_states=ssm_states, - ) - return inference_params - - -@dataclass -class MambaBatch(Batch): - batch_id: int - requests: List[generate_pb2.Request] - requests_idx_mapping: Dict[int, int] - - # Decoder values - input_ids: torch.Tensor - - # All tokens - all_input_ids: List[torch.Tensor] - - # Lengths of all generations present in the batch - input_lengths: List[int] - prefix_offsets: List[int] - read_offsets: List[int] - - # Generation helpers - next_token_choosers: List[NextTokenChooser] - stopping_criterias: List[StoppingCriteria] - top_n_tokens: List[int] - top_n_tokens_tensor: torch.Tensor - - # Metadata used for padding - max_input_length: int - padding_right_offset: int - - # Maximum number of tokens this batch will grow to - max_tokens: int - - # Past metadata - keys_head_dim_last: bool = True - - # Inference params - inference_params: Optional[Dict[str, Any]] = None - - def to_pb(self) -> generate_pb2.CachedBatch: - return generate_pb2.CachedBatch( - id=self.batch_id, - request_ids=[r.id for r in self.requests], - size=len(self), - max_tokens=self.max_tokens, - ) - - @classmethod - def from_pb( - cls, - pb: generate_pb2.Batch, - tokenizer: PreTrainedTokenizerBase, - dtype: torch.dtype, - device: torch.device, - ) -> "MambaBatch": - inputs = [] - next_token_choosers = [] - stopping_criterias = [] - top_n_tokens = [] - prefix_offsets = [] - read_offsets = [] - requests_idx_mapping = {} - - # Parse batch - max_truncation = 0 - padding_right_offset = 0 - max_decode_tokens = 0 - for i, r in enumerate(pb.requests): - requests_idx_mapping[r.id] = i - inputs.append(concat_text_chunks(r.input_chunks.chunks)) - next_token_choosers.append( - NextTokenChooser.from_pb(r.parameters, device, tokenizer) - ) - stopping_criteria = StoppingCriteria.from_pb( - r.stopping_parameters, tokenizer - ) - stopping_criterias.append(stopping_criteria) - top_n_tokens.append(r.top_n_tokens) - max_truncation = max(max_truncation, r.truncate) - max_decode_tokens += stopping_criteria.max_new_tokens - padding_right_offset = max( - padding_right_offset, stopping_criteria.max_new_tokens - ) - - tokenized_inputs = tokenizer( - 
inputs, - return_tensors="pt", - padding=True, - return_token_type_ids=False, - truncation=True, - max_length=max_truncation, - ).to(device) - for _ in pb.requests: - input_len = tokenized_inputs["input_ids"].shape[1] - prefix_offsets.append(input_len - 5) - read_offsets.append(input_len) - - input_lengths = tokenized_inputs["attention_mask"].sum(1) - max_input_length = input_lengths.max() - input_ids = tokenized_inputs["input_ids"] - all_input_ids = tokenized_inputs["input_ids"].T.split(1, dim=1) - top_n_tokens_tensor = torch.tensor( - top_n_tokens, device=device, dtype=torch.int64 - ) - max_tokens = len(inputs) * (max_input_length + max_decode_tokens) - return cls( - batch_id=pb.id, - requests=pb.requests, - requests_idx_mapping=requests_idx_mapping, - input_ids=input_ids, - # past_input_ids=None, - all_input_ids=list(all_input_ids), - input_lengths=input_lengths.tolist(), - prefix_offsets=prefix_offsets, - read_offsets=read_offsets, - next_token_choosers=next_token_choosers, - stopping_criterias=stopping_criterias, - top_n_tokens=top_n_tokens, - top_n_tokens_tensor=top_n_tokens_tensor, - max_input_length=max_input_length.item(), - padding_right_offset=padding_right_offset, - max_tokens=max_tokens, - ) - - def filter(self, request_ids: List[int]) -> Optional["MambaBatch"]: - if len(request_ids) == 0: - raise ValueError("Batch must have at least one request") - if len(request_ids) == len(self): - return self - - keep_indices = [] - - # New values after filtering - requests_idx_mapping = {} - requests = [] - input_lengths = [] - prefix_offsets = [] - read_offsets = [] - all_input_ids = [] - max_input_length = 0 - - next_token_choosers = [] - stopping_criterias = [] - top_n_tokens = [] - - total_remaining_decode_tokens = 0 - new_padding_right_offset = 0 - - indices = [] - for i, request_id in enumerate(request_ids): - idx = self.requests_idx_mapping[request_id] - requests_idx_mapping[request_id] = i - keep_indices.append(idx) - - requests.append(self.requests[idx]) - prefix_offsets.append(self.prefix_offsets[idx]) - read_offsets.append(self.read_offsets[idx]) - all_input_ids.append(self.all_input_ids[idx]) - - request_input_length = self.input_lengths[idx] - input_lengths.append(request_input_length) - max_input_length = max(max_input_length, request_input_length) - indices.append(idx) - - next_token_choosers.append(self.next_token_choosers[idx]) - stopping_criteria = self.stopping_criterias[idx] - stopping_criterias.append(stopping_criteria) - top_n_tokens.append(self.top_n_tokens[idx]) - remaining_decode_tokens = ( - stopping_criteria.max_new_tokens - stopping_criteria.current_tokens - ) - total_remaining_decode_tokens += remaining_decode_tokens - new_padding_right_offset = max( - new_padding_right_offset, remaining_decode_tokens - ) - - # Apply indices to input_ids, attention mask, past key values and other items that need to be cached - input_ids = self.input_ids[keep_indices] - - top_n_tokens_tensor = self.top_n_tokens_tensor[keep_indices] - max_tokens = len(request_ids) * max_input_length + total_remaining_decode_tokens - - self.requests = requests - self.requests_idx_mapping = requests_idx_mapping - self.input_ids = input_ids - self.all_input_ids = all_input_ids - self.input_lengths = input_lengths - self.prefix_offsets = prefix_offsets - self.read_offsets = read_offsets - self.next_token_choosers = next_token_choosers - self.stopping_criterias = stopping_criterias - self.top_n_tokens = top_n_tokens - self.top_n_tokens_tensor = top_n_tokens_tensor - self.max_input_length = 
max_input_length - self.padding_right_offset = new_padding_right_offset - self.max_tokens = max_tokens - - # TODO - # Kept it simple by just updating the state, maybe updating the other CPU values is necessary. - self.inference_params.conv_states = self.inference_params.conv_states[ - :, indices - ] - self.inference_params.ssm_states = self.inference_params.ssm_states[:, indices] - return self - - @classmethod - def concatenate(cls, batches: List["MambaBatch"]) -> "MambaBatch": - # Used for padding - total_batch_size = 0 - max_input_length = 0 - padding_right_offset = 0 - for batch in batches: - total_batch_size += len(batch) - max_input_length = max(max_input_length, batch.max_input_length) - padding_right_offset = max(padding_right_offset, batch.padding_right_offset) - - # Batch attributes - requests = [] - requests_idx_mapping = {} - input_lengths = [] - prefix_offsets = [] - read_offsets = [] - all_input_ids = [] - next_token_choosers = [] - stopping_criterias = [] - top_n_tokens = [] - max_tokens = 0 - seqlen_offset = 0 - - (n_blocks, _, d_inner, d_conv) = batches[0].inference_params.conv_states.shape - (_, _, _, d_state) = batches[0].inference_params.ssm_states.shape - dtype = batches[0].inference_params.conv_states.dtype - device = batches[0].inference_params.conv_states.device - inference_params = new_inference_params( - n_blocks=n_blocks, - batch_size=total_batch_size, - d_state=d_state, - d_conv=d_conv, - d_inner=d_inner, - seqlen_offset=seqlen_offset, - device=device, - dtype=dtype, - ) - - # Batch tensors - input_ids = None - top_n_tokens_tensor = None - - # Used for slicing correctly inside the tensors - # Equivalent to a cumsum on batch sizes - start_index = 0 - for i, batch in enumerate(batches): - requests.extend(batch.requests) - input_lengths.extend(batch.input_lengths) - prefix_offsets.extend(batch.prefix_offsets) - read_offsets.extend(batch.read_offsets) - all_input_ids.extend(batch.all_input_ids) - next_token_choosers.extend(batch.next_token_choosers) - stopping_criterias.extend(batch.stopping_criterias) - top_n_tokens.extend(batch.top_n_tokens) - - if i == 0: - requests_idx_mapping = batch.requests_idx_mapping - else: - # We need to offset the mapping for each batch by the cumulative batch size - for k, v in batch.requests_idx_mapping.items(): - requests_idx_mapping[k] = v + start_index - - # Slicing end index for this batch - end_index = start_index + len(batch) - - # Create empty tensor - # input_ids is always of shape [batch_size, 1] - # We do not need to pad it - if input_ids is None: - input_ids = batch.input_ids.new_empty((total_batch_size, 1)) - # Copy to correct indices - input_ids[start_index:end_index] = batch.input_ids - - if top_n_tokens_tensor is None: - top_n_tokens_tensor = batches[0].top_n_tokens_tensor.new_zeros( - total_batch_size, - ) - top_n_tokens_tensor[start_index:end_index] = batch.top_n_tokens_tensor - - # Add eventual padding tokens that were added while concatenating - max_tokens += batch.max_tokens + ( - max_input_length - batch.max_input_length - ) * len(batch) - - inference_params.max_seqlen = max( - inference_params.max_seqlen, batch.inference_params.max_seqlen - ) - assert batch.inference_params.seqlen_offset != 0, "Invalid seqlen offset" - inference_params.seqlen_offset = max( - inference_params.seqlen_offset, batch.inference_params.seqlen_offset - ) - - inference_params.conv_states[:, start_index:end_index] = ( - batch.inference_params.conv_states - ) - inference_params.ssm_states[:, start_index:end_index] = ( - 
batch.inference_params.ssm_states - ) - - start_index = end_index - - return cls( - batch_id=batches[0].batch_id, - requests=requests, - requests_idx_mapping=requests_idx_mapping, - input_ids=input_ids, - all_input_ids=all_input_ids, - input_lengths=input_lengths, - prefix_offsets=prefix_offsets, - read_offsets=read_offsets, - next_token_choosers=next_token_choosers, - stopping_criterias=stopping_criterias, - top_n_tokens=top_n_tokens, - top_n_tokens_tensor=top_n_tokens_tensor, - max_input_length=max_input_length, - padding_right_offset=padding_right_offset, - keys_head_dim_last=batches[0].keys_head_dim_last, - max_tokens=max_tokens, - inference_params=inference_params, - ) - - def __len__(self): - return len(self.requests) - - -class Mamba(Model): - def __init__( - self, - model_id: str, - revision: Optional[str] = None, - quantize: Optional[str] = None, - speculator: Optional[str] = None, - dtype: Optional[torch.dtype] = None, - trust_remote_code: bool = False, - ): - self.quantize = quantize - self.process_group, _rank, world_size = initialize_torch_distributed() - if world_size > 1: - raise RuntimeError("Mamba does not support Tensor Parallelism (TP)") - self.cuda_graphs = {} - if torch.cuda.is_available(): - device = torch.device("cuda") - # Bf16 is important. In f16 accumulations in the matmul are causing - # differences while the server is under load. - # This is detectable by the integration load test - dtype = torch.bfloat16 if dtype is None else dtype - else: - if quantize: - raise ValueError("quantization is not available on CPU") - - device = torch.device("cpu") - dtype = torch.float32 if dtype is None else dtype - - tokenizer = AutoTokenizer.from_pretrained( - "EleutherAI/gpt-neox-20b", - revision=revision, - padding_side="left", - truncation_side="left", - trust_remote_code=trust_remote_code, - ) - config = MambaConfig.from_pretrained( - model_id, revision=revision, trust_remote_code=trust_remote_code - ) - - tokenizer.bos_token_id = config.bos_token_id - tokenizer.eos_token_id = config.eos_token_id - tokenizer.pad_token = tokenizer.eos_token - - config.quantize = quantize - config.speculator = speculator - torch.distributed.barrier(group=self.process_group) - weights_loader = get_loader( - quantize=quantize, model_id=model_id, revision=revision - ) - filenames = weight_files(model_id, revision=revision, extension=".safetensors") - weights = Weights( - filenames, - device, - dtype, - process_group=self.process_group, - weights_loader=weights_loader, - ) - model = MambaModel(config, weights) - torch.distributed.barrier(group=self.process_group) - super(Mamba, self).__init__( - model_id=model_id, - model=model, - tokenizer=tokenizer, - requires_padding=True, - dtype=dtype, - device=device, - ) - - @property - def batch_type(self) -> Type[MambaBatch]: - return MambaBatch - - def warmup(self, batch) -> Optional[int]: - # TODO: implement warmup for Mamba if needed - if CUDA_GRAPHS: - if self.speculate is None or self.speculate == 0: - try: - logger.info(f"Cuda Graphs are enabled for sizes {CUDA_GRAPHS}") - # Warmup cuda graphs - for bs in CUDA_GRAPHS: - self.cuda_graph_warmup(bs) - except Exception: - logger.exception("Decode cuda graph warmup failed") - else: - logger.info(f"Cuda Graphs are disabled (CUDA_GRAPHS={CUDA_GRAPHS}).") - - return None - - def cuda_graph_warmup(self, batch_size: int): - input_ids = torch.zeros((batch_size, 1), dtype=torch.int64, device=self.device) - n_blocks = len(self.model.blocks) - - d_state = self.model.config.d_state - d_conv = 
self.model.config.d_conv - # Inner takes the expand multiplication - d_inner = self.model.config.d_inner - - # Important seqlen_offset to go through the update mecanism with the state - seqlen_offset = 1 - inference_params = new_inference_params( - n_blocks=n_blocks, - batch_size=batch_size, - d_state=d_state, - d_conv=d_conv, - d_inner=d_inner, - seqlen_offset=seqlen_offset, - device=self.device, - dtype=self.dtype, - ) - - graph = torch.cuda.CUDAGraph() - - torch.cuda.synchronize() - # Run once outside to warmup - self.model.forward(input_ids=input_ids, inference_params=inference_params) - torch.cuda.synchronize() - - with torch.cuda.graph(graph, pool=MEM_POOL): - logits, speculative_logits = self.model.forward( - input_ids=input_ids, inference_params=inference_params - ) - torch.cuda.synchronize() - graph_dict = { - "input_ids": input_ids, - "inference_params": inference_params, - "graph": graph, - "logits": logits, - "speculative_logits": speculative_logits, - } - self.cuda_graphs[batch_size] = graph_dict - - def tunableop_warmup(self, batch_size: int, seqlen: int): - input_ids = torch.zeros((batch_size, 1), dtype=torch.int64, device=self.device) - n_blocks = len(self.model.blocks) - - d_state = self.model.config.d_state - d_conv = self.model.config.d_conv - # Inner takes the expand multiplication - d_inner = self.model.config.d_inner - - # Important seqlen_offset to go through the update mecanism with the state - seqlen_offset = 1 - inference_params = new_inference_params( - n_blocks=n_blocks, - batch_size=seqlen, - d_state=d_state, - d_conv=d_conv, - d_inner=d_inner, - seqlen_offset=seqlen_offset, - device=self.device, - dtype=self.dtype, - ) - - self.model.forward(input_ids=input_ids, inference_params=inference_params) - - def forward( - self, input_ids: torch.Tensor, inference_params: Any - ) -> Tuple[torch.Tensor, torch.Tensor]: - bs = input_ids.shape[0] - padded_bs = bs - if bs == 3: - padded_bs = 4 - elif 3 < bs <= 8: - padded_bs = 8 - elif bs > 8: - padded_bs = (bs + 7) // 8 * 8 - - # Try to find an associated cuda graph - cuda_graph = self.cuda_graphs.get(padded_bs, None) - is_prefill = inference_params is None or inference_params.seqlen_offset == 0 - - if is_prefill or cuda_graph is None: - return self.model( - input_ids, - inference_params=inference_params, - ) - - # Copy inputs to the static inputs of the cuda graph - # Static inputs are potentially padded - cuda_graph["input_ids"][:bs] = input_ids - cuda_graph["inference_params"].conv_states[ - :, :bs - ] = inference_params.conv_states - cuda_graph["inference_params"].ssm_states[:, :bs] = inference_params.ssm_states - - # Replay the graph - cuda_graph["graph"].replay() - - inference_params.conv_states.copy_( - cuda_graph["inference_params"].conv_states[:, :bs] - ) - inference_params.ssm_states.copy_( - cuda_graph["inference_params"].ssm_states[:, :bs] - ) - # Slice output to the correct shape - speculative_logits = ( - cuda_graph["speculative_logits"][:bs] - if cuda_graph["speculative_logits"] is not None - else None - ) - logits = cuda_graph["logits"][:bs] - return logits, speculative_logits - - def generate_token(self, batch) -> Tuple[List[Any], Optional[Any], Tuple[int, int]]: - start = time.time_ns() - input_ids = ( - batch.input_ids - ) # batch.past_input_ids if batch.past_input_ids is not None else batch.input_ids - - batch_size, max_seqlen = input_ids.shape - # Inference params - - if batch.inference_params is None: - # 0 is important here - seqlen_offset = 0 - n_blocks = len(self.model.blocks) - d_state = 
self.model.config.d_state - d_conv = self.model.config.d_conv - d_inner = self.model.config.d_inner - inference_params = new_inference_params( - n_blocks=n_blocks, - batch_size=batch_size, - d_state=d_state, - d_conv=d_conv, - d_inner=d_inner, - seqlen_offset=seqlen_offset, - device=self.device, - dtype=self.dtype, - ) - batch.inference_params = inference_params - - # Forward pass - logits, speculative_logits = self.forward( - input_ids, inference_params=batch.inference_params - ) - - # batch.inference_params = new_inference_params - # Results - generations: List[Generation] = [] - stopped = True - - # Speculation is not active for causal - accepted_ids = torch.ones_like(batch.input_ids)[:, 0] - batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens( - batch.top_n_tokens, - batch.top_n_tokens_tensor, - torch.log_softmax(logits[:, -1], -1), - accepted_ids, - ) - - start_decode = time.time_ns() - - # Zipped iterator - iterator = zip( - batch.requests, - batch.input_lengths, - batch.prefix_offsets, - batch.read_offsets, - logits, - batch.next_token_choosers, - batch.stopping_criterias, - batch.all_input_ids, - batch.top_n_tokens, - batch_top_token_ids, - batch_top_token_logprobs, - ) - - # For each member of the batch - for i, ( - request, - input_length, - prefix_offset, - read_offset, - logits, - next_token_chooser, - stopping_criteria, - all_input_ids, - top_n_tokens, - top_token_ids, - top_token_logprobs, - ) in enumerate(iterator): - # Select next token - next_token_id, logprobs = next_token_chooser( - all_input_ids.view(1, -1), logits[-1:, :] - ) - - # Append next token to all tokens - all_input_ids = torch.cat([all_input_ids, next_token_id]) - new_input_length = input_length + 1 - - # Generated token - next_token_logprob = logprobs[-1, next_token_id] - next_token_id_squeezed = next_token_id.squeeze() - next_token_text, prefix_offset, read_offset = self.decode_token( - all_input_ids[:, 0], prefix_offset, read_offset - ) - - # Evaluate stopping criteria - stop, reason = stopping_criteria( - next_token_id_squeezed, - next_token_text, - ) - - if not stop: - stopped = False - - # Shard generations - # All generations will be appended in the rust sharded client - if i % self.world_size == self.rank: - if stop: - # Decode generated tokens - output_text, _, _ = self.decode_token( - all_input_ids[:, 0], - prefix_offset=len(all_input_ids) - - stopping_criteria.current_tokens - - 1, - read_offset=len(all_input_ids) - - stopping_criteria.current_tokens, - skip_special_tokens=True, - ) - # Get seed - if isinstance(next_token_chooser.choice, Sampling): - seed = next_token_chooser.choice.seed - else: - seed = None - - generated_text = GeneratedText( - output_text, stopping_criteria.current_tokens, reason, seed - ) - else: - generated_text = None - - if stopping_criteria.current_tokens == 1 and request.prefill_logprobs: - # Remove generated token to only have prefill and add nan for first prompt token - prefill_logprobs = [float("nan")] + torch.log_softmax( - logits, -1 - ).gather(1, all_input_ids[1:]).squeeze(1)[ - -new_input_length:-1 - ].tolist() - prefill_token_ids = all_input_ids[-new_input_length:-1] - prefill_texts = self.tokenizer.batch_decode( - prefill_token_ids, - clean_up_tokenization_spaces=False, - skip_special_tokens=False, - ) - prefill_tokens = Tokens( - prefill_token_ids, - prefill_logprobs, - prefill_texts, - is_special=[], - ) - else: - prefill_tokens = None - - if top_n_tokens > 0: - toptoken_texts = self.tokenizer.batch_decode( - top_token_ids, - 
clean_up_tokenization_spaces=False, - skip_special_tokens=False, - ) - special_toptokens = [ - token_id in self.all_special_ids for token_id in top_token_ids - ] - top_tokens = Tokens( - top_token_ids, - top_token_logprobs, - toptoken_texts, - special_toptokens, - ) - else: - top_tokens = None - - generation = Generation( - request.id, - prefill_tokens, - Tokens( - [next_token_id_squeezed], - [next_token_logprob], - [next_token_text], - [next_token_id_squeezed.item() in self.all_special_ids], - ), - generated_text, - top_tokens, - ) - - generations.append(generation) - - # Update values - batch.next_token_choosers[i] = batch.next_token_choosers[ - i - ].advance_grammar(next_token_id_squeezed.item()) - batch.input_ids[i, 0] = next_token_id - batch.all_input_ids[i] = all_input_ids - batch.input_lengths[i] = new_input_length - batch.prefix_offsets[i] = prefix_offset - batch.read_offsets[i] = read_offset - batch.max_input_length = max(batch.max_input_length, new_input_length) - - # We finished all generations in the batch; there is no next batch - if stopped: - forward_ns = start_decode - start - decode_ns = time.time_ns() - start_decode - return generations, None, (forward_ns, decode_ns) - - # Slice unused values from prefill - batch.input_ids = batch.input_ids[:, :1] - - forward_ns = start_decode - start - decode_ns = time.time_ns() - start_decode - return generations, batch, (forward_ns, decode_ns) diff --git a/backends/gaudi/server/text_generation_server/models/starcoder.py b/backends/gaudi/server/text_generation_server/models/starcoder.py deleted file mode 100644 index 6c6ca2cf..00000000 --- a/backends/gaudi/server/text_generation_server/models/starcoder.py +++ /dev/null @@ -1,47 +0,0 @@ -import torch -from dataclasses import dataclass -from typing import List, Optional, Type - -from text_generation_server.models import CausalLM -from text_generation_server.models.causal_lm import CausalLMBatch - - -@dataclass -class StarCoderCausalLMBatch(CausalLMBatch): - past_key_values: Optional[List[torch.Tensor]] - - def detach_kv_cache(self): - past_keys = [] - past_values = [] - last_dim = int(self.past_key_values[0].size(dim=-1) / 2) - for key_value in self.past_key_values: - past_keys.append(key_value.split((last_dim, last_dim), dim=-1)[0]) - past_values.append(key_value.split((last_dim, last_dim), dim=-1)[1]) - del self.past_key_values - - return past_keys, past_values - - def attach_kv_cache(self, past_keys, past_values): - self.past_key_values = [ - torch.cat((key, value), dim=-1) - for key, value in zip(past_keys, past_values) - ] - - -class StarCoder(CausalLM): - def __init__( - self, - model_id: str, - revision: Optional[str] = None, - dtype: Optional[torch.dtype] = None, - ): - - super(StarCoder, self).__init__( - model_id=model_id, - revision=revision, - dtype=dtype, - ) - - @property - def batch_type(self) -> Type[CausalLMBatch]: - return StarCoderCausalLMBatch diff --git a/backends/gaudi/server/text_generation_server/models/vlm_causal_lm.py b/backends/gaudi/server/text_generation_server/models/vlm_causal_lm.py deleted file mode 100644 index 6929b2ef..00000000 --- a/backends/gaudi/server/text_generation_server/models/vlm_causal_lm.py +++ /dev/null @@ -1,1609 +0,0 @@ -import json -import re -import torch -import os -import time -import math -from PIL import Image -from io import BytesIO -from opentelemetry import trace -from loguru import logger -from typing import Iterable, Optional, Tuple, List, Type, Dict -import tempfile -import copy -from text_generation_server.models import Model 
-from transformers import PreTrainedTokenizerBase -from text_generation_server.utils import weight_files -from text_generation_server.utils.tokens import batch_top_tokens -from text_generation_server.pb import generate_pb2 -from text_generation_server.models.causal_lm import ( - CausalLMBatch, - CausalLMRequest, - remove_kv_cache_from_output, -) - -from transformers.models.llava_next.modeling_llava_next import ( - get_anyres_image_grid_shape, -) - -from transformers import AutoProcessor -import text_generation_server.habana_quantization_env as hq_env -from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi -from text_generation_server.utils import ( - HeterogeneousNextTokenChooser, - make_tokenizer_optional, - is_tokenizer_transparent, - pad_next_token_chooser_parameters, -) -import habana_frameworks.torch as htorch -from optimum.habana.utils import HabanaProfile -from optimum.habana.transformers.generation import MODELS_OPTIMIZED_WITH_STATIC_SHAPES -from optimum.habana.utils import get_hpu_memory_stats -from optimum.habana.checkpoint_utils import get_ds_injection_policy - -from transformers import ( - AutoTokenizer, - AutoConfig, -) -from optimum.habana.checkpoint_utils import model_on_meta - -from text_generation_server.utils.speculate import get_speculate -from text_generation_server.models.types import ( - Tokens, - Generation, - GeneratedText, -) -from text_generation_server.utils.debug import dbg_trace - -tracer = trace.get_tracer(__name__) - -IDEFICS2_FAKE_TOKEN = "" -IDEFICS2_IMAGE_TOKEN = "" - - -IMAGES = re.compile(r"!\[[^\]]*\]\((.*?)\s*(\"(?:.*[^\"])\")?\s*\)") -BASE_IMAGE_TOKENS = int(os.environ.get("BASE_IMAGE_TOKENS", 2048)) -MAX_TOTAL_TOKENS = int(os.environ.get("MAX_TOTAL_TOKENS", 8192)) -PAD_SEQUENCE_TO_MULTIPLE_OF = int(os.environ.get("PAD_SEQUENCE_TO_MULTIPLE_OF", 128)) -CHUNK_SIZES = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048] -LAZY_MODE = int(os.environ.get("PT_HPU_LAZY_MODE", 1)) - - -PREFILL_WARMUP_BATCH_SIZE_LIST = [] -PREFILL_WARMUP_SEQLEN_LIST = [] -DECODE_WARMUP_BATCH_SIZE_LIST = [] -CROSS_ATTENTION_LAYERS = [] - - -def round_up(warmup_list: list, num): - i = 0 - for i in warmup_list: - if num <= i: - break - return i if i > 0 else num - - -def split(string) -> List[Dict[str, str]]: - parts = [] - cursor = 0 - for pattern in IMAGES.finditer(string): - start = pattern.start() - if start != cursor: - parts.append({"type": "text", "content": string[cursor:start]}) - - parts.append({"type": "image", "content": pattern.group(1)}) - cursor = pattern.end() - - if cursor != len(string): - parts.append({"type": "text", "content": string[cursor:]}) - - return parts - - -def image_text_replacement(config) -> str: - if config.model_type == "idefics2": - image_seq_len = 64 - image_str = f"{IDEFICS2_FAKE_TOKEN}{IDEFICS2_IMAGE_TOKEN * image_seq_len}{IDEFICS2_FAKE_TOKEN}" - return image_str - elif config.model_type == "llava_next": - return "" - elif config.model_type == "paligemma": - return "" - elif config.model_type == "mllama": - return "<|image|>" - else: - raise RuntimeError(f"Unknown config {config.model_type} for multimodal") - - -def image_text_replacement_fixup(config, text: str) -> str: - if config.model_type == "idefics2": - return text.replace( - f"{IDEFICS2_FAKE_TOKEN}{IDEFICS2_FAKE_TOKEN}", IDEFICS2_FAKE_TOKEN - ) - return text - - -def get_unpadded_features( - original_height: int, - original_width: int, - npatches: int, - num_patch_height: int, - num_patch_width: int, -) -> Tuple[int, int]: - current_height = npatches * 
num_patch_height - current_width = npatches * num_patch_width - - aspect_ratio: float = original_width / original_height - current_aspect_ratio: float = current_width / current_height - - if aspect_ratio > current_aspect_ratio: - new_height = (original_height * current_width) // original_width - padding = (current_height - new_height) // 2 - current_height = current_height - (2 * padding) - else: - new_width = (original_width * current_height) // original_height - padding = (current_width - new_width) // 2 - current_width = current_width - (2 * padding) - - unpadded_features = current_height * current_width - newline_features = current_height - return (unpadded_features, newline_features) - - -def get_number_of_features(height: int, width: int, config) -> int: - # From config - # Hardcoded for CLIP for now - # image_grid_pinpoints = [[336, 672], [672, 336], [672, 672], [1008, 336], [336, 1008]] - image_grid_pinpoints = config.image_grid_pinpoints - image_size = config.vision_config.image_size - patch_size = config.vision_config.patch_size - - assert image_size % patch_size == 0 - - npatches = image_size // patch_size - - # Dimensions are intentionally swapped to be bug-compatible with - # upstream: https://github.com/LLaVA-VL/LLaVA-NeXT/issues/59 - num_patch_width, num_patch_height = get_anyres_image_grid_shape( - [height, width], - image_grid_pinpoints, - image_size, - ) - - unpadded_features, newline_features = get_unpadded_features( - height, width, npatches, num_patch_height, num_patch_width - ) - # The base patch covers the entire image - base_features = npatches**2 - return unpadded_features + newline_features + base_features - - -class VlmCausalLMBatch(CausalLMBatch): - pixel_values: Optional[List[torch.Tensor]] - pixel_attention_mask: Optional[List[torch.Tensor]] - image_sizes: Optional[List[Tuple[int, int]]] - aspect_ratio_ids: Optional[torch.Tensor] = None - aspect_ratio_mask: Optional[torch.Tensor] = None - cross_attention_mask: Optional[torch.Tensor] = None - prefilling: bool = True - token_idx: torch.Tensor = None - - def __init__( - self, - batch_id, - requests, - input_ids, - attention_mask, - position_ids, - past_key_values, - merged_kv_cache, - next_token_chooser, - top_n_tokens, - top_n_tokens_tensor, - input_length, - pixel_values: Optional[List[torch.Tensor]] = None, - pixel_attention_mask: Optional[List[torch.Tensor]] = None, - image_sizes: Optional[List[Tuple[int, int]]] = None, - aspect_ratio_ids: Optional[torch.Tensor] = None, - aspect_ratio_mask: Optional[torch.Tensor] = None, - cross_attention_mask: Optional[torch.Tensor] = None, - prefilling: Optional[bool] = True, - ): - super().__init__( - batch_id=batch_id, - requests=requests, - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - merged_kv_cache=merged_kv_cache, - next_token_chooser=next_token_chooser, - top_n_tokens=top_n_tokens, - top_n_tokens_tensor=top_n_tokens_tensor, - input_length=input_length, - ) - - self.pixel_values = pixel_values - self.pixel_attention_mask = pixel_attention_mask - self.image_sizes = image_sizes - self.aspect_ratio_ids = aspect_ratio_ids - self.aspect_ratio_mask = aspect_ratio_mask - self.cross_attention_mask = cross_attention_mask - self.prefilling = prefilling - - @property - def token_idx(self): - if self.prefilling: - # no right padding for prefill - token_idx_scalar = self.attention_mask.shape[-1] - 1 - return torch.tensor(token_idx_scalar).to(self.attention_mask.device) - else: - token_idx_scalar = 
self.attention_mask.shape[-1] - self.right_padding - return torch.tensor(token_idx_scalar).to(self.attention_mask.device) - - def padding_process(self, pad_id: int): - # self.input_ids = torch.index_select(self.input_ids, 1, self.token_idx - 1) - right_padding = MAX_TOTAL_TOKENS - self.attention_mask.shape[1] - self.input_ids = torch.nn.functional.pad( - self.input_ids, (0, right_padding), value=pad_id - ) - self.attention_mask = torch.nn.functional.pad( - self.attention_mask, (0, right_padding), value=0 - ) - # if self.position_ids is not None: - # self.position_ids = torch.index_select(self.position_ids, 1, self.token_idx - 1) + 1 - if self.cross_attention_mask is not None: - self.cross_attention_mask = torch.nn.functional.pad( - self.cross_attention_mask, (0, 0, 0, 0, 0, right_padding), value=0 - ) - if self.past is not None: - past_key_values_list = list(self.past_key_values) - for layer_id in range(len(self.past)): - past_key_value_list = list(self.past_key_values[layer_id]) - if layer_id not in CROSS_ATTENTION_LAYERS: - past_key_value_list[0] = torch.nn.functional.pad( - self.past_key_values[layer_id][0], - (0, 0, 0, right_padding), - value=0, - ) - past_key_value_list[1] = torch.nn.functional.pad( - self.past_key_values[layer_id][1], - (0, 0, 0, right_padding), - value=0, - ) - past_key_values_list[layer_id] = tuple(past_key_value_list) - self.past_key_values = tuple(past_key_values_list) - - self.prefilling = False - self.input_length = self.input_length - - @classmethod - def from_tokenized( - cls, - pb: generate_pb2.Batch, - tokenizer: PreTrainedTokenizerBase, - batch_tokenized_inputs, - dtype: torch.dtype, - device: torch.device, - is_warmup: bool = False, - ) -> "VlmCausalLMBatch": - - dbg_trace("FROM_PB", f"num_reqs:{len(pb.requests)}") - requests = [ - CausalLMRequest.from_pb(idx, req, tokenizer) - for idx, req in enumerate(pb.requests) - ] - - max_input_length = max(r.data.truncate for r in requests) - max_new_tokens = max(r.stopping_criteria.max_new_tokens for r in requests) - # TODO: Add support for sparse batches - top_n_tokens = [r.top_n_tokens for r in pb.requests] - top_n_tokens_tensor = torch.tensor( - top_n_tokens, device=device, dtype=torch.int64 - ) - - # TODO: by tokenizing all inputs at once we loose information on actual input lengths - # this means that we cannot shift inputs to the left after a long input sequence - # was filtered out - new_bs = round_up(PREFILL_WARMUP_BATCH_SIZE_LIST, len(requests)) - parameters = [r.parameters for r in pb.requests] - # append the dummy parameters for dummy request - parameters = pad_next_token_chooser_parameters(parameters, new_bs) - - next_token_chooser = HeterogeneousNextTokenChooser.from_pb( - pb=parameters, - dtype=dtype, - device=device, - tokenizer=tokenizer, - quantization_enabled=hq_env.is_quantization_enabled, - ) - tokenized_inputs = batch_tokenized_inputs - input_len = tokenized_inputs["input_ids"].shape[1] - - bucket_size = max_input_length - left_padding = max_input_length - input_len - if is_warmup is False: - rounded_seq_len = round_up(PREFILL_WARMUP_SEQLEN_LIST, input_len + 1) - bucket_size = rounded_seq_len - 1 - left_padding = bucket_size - input_len - - input_ids = tokenized_inputs["input_ids"] - attention_mask = tokenized_inputs["attention_mask"] - cross_attention_mask = tokenized_inputs.get("cross_attention_mask", None) - # Allocate space for first token - input_ids = torch.nn.functional.pad( - input_ids, (left_padding, 1), value=tokenizer.pad_token_id - ) - attention_mask = torch.nn.functional.pad( - 
attention_mask, (left_padding, 1), value=0 - ) - if cross_attention_mask is not None: - cross_attention_mask = torch.nn.functional.pad( - cross_attention_mask, (0, 0, 0, 0, left_padding, 1), value=0 - ) - all_input_ids = torch.nn.functional.pad( - input_ids, (0, max_new_tokens), value=tokenizer.pad_token_id - ).T.split(1, dim=1) - - # New input length after left padding - input_len = bucket_size - for r in requests: - r.input_length = input_len - r.prefix_offset = input_len - 5 - r.read_offset = input_len - r.all_input_ids = all_input_ids[r.idx] - input_ids = input_ids.to(device) - attention_mask = attention_mask.to(device) - cross_attention_mask = ( - cross_attention_mask.to(device) - if cross_attention_mask is not None - else None - ) - position_ids = attention_mask.long().cumsum(-1) - 1 - position_ids.masked_fill_(attention_mask == 0, 1) - htorch.core.mark_step() - - return cls( - batch_id=pb.id, - requests=requests, - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=None, - merged_kv_cache=False, - next_token_chooser=next_token_chooser, - top_n_tokens=top_n_tokens, - top_n_tokens_tensor=top_n_tokens_tensor, - input_length=input_len, - cross_attention_mask=cross_attention_mask, - ) - - @classmethod - def batch_tokenized_inputs( - cls, - requests: Iterable[generate_pb2.Request], - tokenizer, - processor, - config, - is_warmup, - ): - image_inputs = {} - texts = [] - images = [] - batch_tokenized_inputs = {} - - for i, r in enumerate(requests): - # Each input is encoded into a list, where each element of this input list is either a string or a URL - curr_text = "" - curr_image = None - for chunk in r.input_chunks.chunks: - chunk_type = chunk.WhichOneof("chunk") - if chunk_type == "text": - curr_text += chunk.text - elif chunk_type == "image": - image = Image.open(BytesIO(chunk.image.data)) - # TODO unsure about BOS - curr_image = image - else: - raise RuntimeError(f"Invalid chunk type {chunk_type}") - - if image_text_replacement(config) not in curr_text: - if "" in curr_text: - curr_text = curr_text.replace( - "", image_text_replacement(config) - ) - else: - curr_text = image_text_replacement(config) + curr_text - - texts.append(curr_text) - if curr_image is not None: - if config.model_type == "mllama": - images.append([curr_image]) - else: - images.append(curr_image) - - if is_warmup is True: - images += [images[0]] * (len(texts) - len(images)) - - missing_inputs = 0 - dummy_images = None - if is_warmup is False: - new_bs = round_up(PREFILL_WARMUP_BATCH_SIZE_LIST, len(requests)) - missing_inputs = new_bs - len(requests) - if missing_inputs > 0: - dummy_inputs = [] - if len(texts) > 0: - dummy_inputs = [texts[0]] * missing_inputs - dummy_images = [images[0]] * missing_inputs - texts += dummy_inputs - images += dummy_images - - processor_output = processor( - images, - texts, - truncation=True, - max_length=r.truncate, - add_special_tokens=r.add_special_tokens, - return_tensors="pt", - padding_side="left", - padding="longest", - ) - if "input_ids" in processor_output: - batch_tokenized_inputs.update({"input_ids": processor_output["input_ids"]}) - if "attention_mask" in processor_output: - batch_tokenized_inputs.update( - {"attention_mask": processor_output["attention_mask"]} - ) - if "cross_attention_mask" in processor_output: - batch_tokenized_inputs.update( - {"cross_attention_mask": processor_output["cross_attention_mask"]} - ) - if "pixel_values" in processor_output: - image_inputs.update({"pixel_values": 
processor_output["pixel_values"]}) - if "pixel_attention_mask" in processor_output: - image_inputs.update( - {"pixel_attention_mask": processor_output["pixel_attention_mask"]} - ) - if "aspect_ratio_ids" in processor_output: - image_inputs.update( - {"aspect_ratio_ids": processor_output["aspect_ratio_ids"]} - ) - if "aspect_ratio_mask" in processor_output: - image_inputs.update( - {"aspect_ratio_mask": processor_output["aspect_ratio_mask"]} - ) - if "image_sizes" in processor_output: - image_inputs.update({"image_sizes": processor_output["image_sizes"]}) - - return batch_tokenized_inputs, image_inputs - - @classmethod - def from_pb_processor( - cls, - pb: generate_pb2.Batch, - tokenizer: PreTrainedTokenizerBase, - processor, - config, - dtype: torch.dtype, - device: torch.device, - is_warmup: bool = False, - ) -> "VlmCausalLMBatch": - batch_tokenized_inputs, image_inputs = cls.batch_tokenized_inputs( - pb.requests, tokenizer, processor, config, is_warmup - ) - batch = cls.from_tokenized( - pb, tokenizer, batch_tokenized_inputs, dtype, device, is_warmup=is_warmup - ) - if image_inputs is not None: - batch.pixel_values = image_inputs["pixel_values"].to(device=device) - if "pixel_attention_mask" in image_inputs: - batch.pixel_attention_mask = image_inputs["pixel_attention_mask"].to( - device=device - ) - else: - batch.pixel_attention_mask = None - if "image_sizes" in image_inputs: - batch.image_sizes = image_inputs["image_sizes"].to(device=device) - else: - batch.image_sizes = None - if "aspect_ratio_ids" in image_inputs: - batch.aspect_ratio_ids = image_inputs["aspect_ratio_ids"].to( - device=device - ) - else: - batch.aspect_ratio_ids = None - if "aspect_ratio_mask" in image_inputs: - batch.aspect_ratio_mask = image_inputs["aspect_ratio_mask"].to( - device=device - ) - else: - batch.aspect_ratio_mask = None - else: - batch.pixel_values = None - batch.pixel_attention_mask = None - batch.image_sizes = None - batch.aspect_ratio_ids = None - batch.aspect_ratio_mask = None - batch.cross_attention_mask = None - - return batch - - @classmethod - @tracer.start_as_current_span("concatenate") - def concatenate( - cls, - batches: List["CausalLMBatch"], - pad_token_id: int = 0, - is_warmup: bool = False, - ) -> "CausalLMBatch": - return cls.recombine(batches, pad_token_id, is_warmup) - - @classmethod - def recombine( - cls, - batches: List["VlmCausalLMBatch"], - pad_token_id: int, - is_warmup: bool = False, - ) -> "VlmCausalLMBatch": - if not all(b.past_key_values is not None for b in batches): - raise ValueError("KV cache not allocated! 
Cannot recombine before prefill!") - # Used for padding - - total_requests = sum(len(b) for b in batches) - new_bs = total_requests - if not is_warmup: - new_bs = round_up(DECODE_WARMUP_BATCH_SIZE_LIST, total_requests) - - if len(batches) > 1: - scenario = "CONCAT" - elif batches[0].prefilling: - scenario = "SHIFT" - else: - return batches[0] - - dbg_trace( - scenario, - f"bs:{[b.batch_size for b in batches]}->{new_bs}" - f" reqs:{[len(b) for b in batches]}", - ) - - if scenario == "SHIFT": - batch = batches[0] - batch.padding_process(pad_token_id) - return batch - - total_batch_size = 0 - max_input_length = 0 - for i, batch in enumerate(batches): - total_batch_size += len(batch) - max_input_length = max(max_input_length, batch.input_length) - # Batch attributes - requests = [] - input_lengths = [] - top_n_tokens = [] - parameters = [] - fsm_grammar_states = [] - - # Batch tensors - input_ids = None - attention_mask = None - position_ids = None - past_key_values = [] - top_n_tokens_tensor = None - cross_attention_mask = None - # Used for slicing correctly inside the tensors - # Equivalent to a cumsum on batch sizes - start_index = 0 - for i, batch in enumerate(batches): - keep_indices = [] - for req in batch.requests: - keep_indices.append(req.idx) - - requests.extend(batch.requests) - parameters.extend([r.data.parameters for r in batch.requests]) - fsm_grammar_states.extend( - [batch.next_token_chooser.fsm_grammar_states[i] for i in keep_indices] - ) - input_lengths.extend([batch.input_length]) - top_n_tokens.extend([batch.top_n_tokens[i] for i in keep_indices]) - - # Slicing end index for this batch - end_index = start_index + len(batch) - - # We only concatenate batches that did at least one step - if batch.past_key_values is None: - raise ValueError("only concatenate prefilled batches") - - # Create empty tensor - # input_ids is always of shape [batch_size, 1] - # We do not need to pad it - if input_ids is None: - input_ids = batch.input_ids.new_empty((new_bs, MAX_TOTAL_TOKENS)) - # # Copy to correct indices - - left_offset = max_input_length - batch.input_length - right_padding = MAX_TOTAL_TOKENS - max_input_length - input_ids[start_index:end_index, left_offset:-right_padding] = ( - batch.input_ids[keep_indices, : batch.input_length] - ) - - # Create padded tensor - if top_n_tokens_tensor is None: - top_n_tokens_tensor = batches[0].top_n_tokens_tensor.new_zeros( - new_bs, - ) - top_n_tokens_tensor[start_index:end_index] = batch.top_n_tokens_tensor[ - keep_indices - ] - - if attention_mask is None: - attention_mask = batch.attention_mask.new_zeros( - (new_bs, MAX_TOTAL_TOKENS), - ) - - attention_mask[ - start_index:end_index, - left_offset:-right_padding, - ] = batch.attention_mask[ - keep_indices, - : batch.input_length, - ] - - if batch.cross_attention_mask is not None: - cross_attention_mask_shape = list(batch.cross_attention_mask.shape) - cross_attention_mask_shape[1] = MAX_TOTAL_TOKENS - cross_attention_mask_shape[0] = new_bs - cross_attention_mask_shape = torch.Size(cross_attention_mask_shape) - if cross_attention_mask is None: - cross_attention_mask = batch.cross_attention_mask.new_zeros( - cross_attention_mask_shape, - ) - cross_attention_mask[ - start_index:end_index, - left_offset:-right_padding, - ] = batch.cross_attention_mask[ - keep_indices, - : batch.input_length, - ] - - # Create empty tensor - # position_ids is always of shape [batch_size, 1] - if position_ids is None: - position_ids = batch.position_ids.new_empty((new_bs, 1)) - position_ids[start_index:end_index] = 
batch.position_ids[keep_indices, :] - - # Shenanigans to get dimensions because BLOOM outputs a past with a different shape - # BLOOM Keys: [batch_size * num_heads, head_dim, seq_length] - # BLOOM Values: [batch_size * num_heads, seq_length, head_dim] - # And ensure that we can update tensors in-place - if isinstance(batch.past_key_values, tuple): - batch.past_key_values = [ - [t.view(batch.batch_size, -1, *t.shape[-2:]) for t in layer] - for layer in batch.past_key_values - ] - elif len(batch.past_key_values[0][0].shape) == 3: - for layer in batch.past_key_values: - for k, t in enumerate(layer): - layer[k] = t.view(batch.batch_size, -1, *t.shape[-2:]) - - start_index = end_index - - first_past_kvs = batches[0].past_key_values - _, num_heads, padded_sequence_length, head_dim = first_past_kvs[0][1].shape - past_key_values = [] - for layer_id in range(len(batches[0].past_key_values)): - if layer_id in CROSS_ATTENTION_LAYERS: - padded_past_keys_shape = list( - batches[0].past_key_values[layer_id][0].shape - ) - padded_past_keys_shape[0] = new_bs - padded_past_keys_shape = torch.Size(padded_past_keys_shape) - else: - padded_past_keys_shape = ( - new_bs, - num_heads, - MAX_TOTAL_TOKENS, - head_dim, - ) - - padded_past_keys = first_past_kvs[layer_id][0].new_zeros( - padded_past_keys_shape - ) - padded_past_values = first_past_kvs[layer_id][1].new_zeros( - padded_past_keys_shape - ) - start_index = 0 - for batch in batches: - keep_indices = [] - for req in batch.requests: - keep_indices.append(req.idx) - - left_offset = max_input_length - batch.input_length - right_padding = MAX_TOTAL_TOKENS - max_input_length - past_keys = batch.past_key_values[layer_id][0] - past_values = batch.past_key_values[layer_id][1] - # Clear reference to the original tensor - batch.past_key_values[layer_id] = None - - # Slicing end index for this batch - end_index = start_index + len(batch) - # We slice the keys to remove the padding from previous batches - if layer_id in CROSS_ATTENTION_LAYERS: - padded_past_keys[start_index:end_index, :, :, :] = past_keys[ - keep_indices, :, :, : - ] - padded_past_values[start_index:end_index, :, :, :] = past_values[ - keep_indices, :, :, : - ] - - else: - padded_past_keys[ - start_index:end_index, :, left_offset:-right_padding, : - ] = past_keys[keep_indices, :, : batch.input_length, :] - padded_past_values[ - start_index:end_index, :, left_offset:-right_padding, : - ] = past_values[keep_indices, :, : batch.input_length, :] - - start_index = end_index - - past_key_values.append(tuple([padded_past_keys, padded_past_values])) - past_key_values = tuple(past_key_values) - - batch_id = batches[0].batch_id - top_n_tokens.extend([-1] * (new_bs - total_batch_size)) - fsm_grammar_states.extend([-1] * (new_bs - total_batch_size)) - - for idx, req in enumerate(requests): - req.idx = idx - - parameters = pad_next_token_chooser_parameters(parameters, new_bs) - next_token_chooser = HeterogeneousNextTokenChooser.from_pb( - parameters, - batches[0].next_token_chooser.dtype, - batches[0].next_token_chooser.device, - batches[0].next_token_chooser.tokenizer, - fsm_grammar_states, - quantization_enabled=hq_env.is_quantization_enabled, - ) - input_length = max_input_length - - htorch.core.mark_step() - - return cls( - batch_id=batch_id, - requests=requests, - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - merged_kv_cache=False, - next_token_chooser=next_token_chooser, - top_n_tokens=top_n_tokens, - 
top_n_tokens_tensor=top_n_tokens_tensor, - input_length=input_length, - pixel_values=None, - pixel_attention_mask=None, - image_sizes=None, - aspect_ratio_ids=None, - aspect_ratio_mask=None, - cross_attention_mask=cross_attention_mask, - prefilling=False, - ) - - -class VlmCausalLM(Model): - def __init__( - self, - model_class, - model_id: str, - *, - processor_class=AutoProcessor, - processor_kwargs=None, - batch_class=VlmCausalLMBatch, - revision, - quantize: Optional[str] = None, - dtype, - trust_remote_code: bool, - **kwargs, - ): - adapt_transformers_to_gaudi() - if processor_kwargs is None: - processor_kwargs = {} - self.processor = processor_class.from_pretrained( - model_id, - revision=revision, - trust_remote_code=trust_remote_code, - **processor_kwargs, - ) - self.batch_class = batch_class - self.prev_bs = 0 - self.quantize = quantize - - # Create tokenizer - tokenizer = AutoTokenizer.from_pretrained( - model_id, - revision=revision, - padding_side="left", - truncation_side="left", - trust_remote_code=trust_remote_code, - ) - make_tokenizer_optional(tokenizer) - - # Create model - world_size = int(os.getenv("WORLD_SIZE", "1")) - rank = int(os.getenv("RANK", "0")) - dtype = torch.bfloat16 if dtype is None else dtype - device = torch.device("hpu") - - if hq_env.is_quantization_enabled: - htorch.core.hpu_set_env() - - # Get weight files - weight_files(model_id, revision=revision, extension=".safetensors") - - if world_size > 1: - os.environ.setdefault( - "DEEPSPEED_USE_HABANA_FRAMEWORKS_DETERMINISTIC_API", "1" - ) - model = self.get_deepspeed_model(model_class, model_id, dtype, revision) - model = hq_env.prepare_model_for_quantization(model) - else: - # Check support for rope scaling - model_kwargs = {} - config = AutoConfig.from_pretrained(model_id) - if hasattr(config, "rope_scaling"): - model_kwargs["rope_scaling"] = self.get_rope_scaling() - - model = model_class.from_pretrained( - model_id, - revision=revision, - torch_dtype=dtype, - trust_remote_code=trust_remote_code, - **model_kwargs, - ) - model = hq_env.prepare_model_for_quantization(model) - model = model.eval().to(device) - - self.enable_hpu_graph = ( - os.getenv("ENABLE_HPU_GRAPH", "true").lower() == "true" and LAZY_MODE == 1 - ) - self.limit_hpu_graph = os.getenv("LIMIT_HPU_GRAPH", "true").lower() == "true" - model = remove_kv_cache_from_output(model) - if self.enable_hpu_graph: - from habana_frameworks.torch.hpu import wrap_in_hpu_graph - - model = wrap_in_hpu_graph(model, disable_tensor_cache=True) - else: - if LAZY_MODE == 0: - # It is said that "keep_input_mutations" is safe for inference to be done - dbg_trace("TORCH COMPILE", "Torch compiling of model") - model.model = torch.compile( - model.model, - backend="hpu_backend", - options={"keep_input_mutations": True}, - ) - - model = hq_env.setup_quantization(model) - - if model.config.model_type not in MODELS_OPTIMIZED_WITH_STATIC_SHAPES: - raise ValueError(f"Model type {model.config.model_type} is not supported!") - - if tokenizer.pad_token_id is None: - if model.config.pad_token_id is not None: - tokenizer.pad_token_id = model.config.pad_token_id - elif model.config.eos_token_id is not None: - if isinstance(model.config.eos_token_id, int): - tokenizer.pad_token_id = model.config.eos_token_id - elif isinstance(model.config.eos_token_id, list): - tokenizer.pad_token_id = model.config.eos_token_id[0] - else: - raise ValueError( - f"{type(model.config.eos_token_id)} type of eos_token_id in the model's config is not supported for tokenizer.pad_token_id" - ) - elif 
tokenizer.eos_token_id is not None: - tokenizer.pad_token_id = tokenizer.eos_token_id - else: - tokenizer.add_special_tokens({"pad_token": "[PAD]"}) - - self.kwargs = { - "use_cache": True, - "return_dict": True, - } - - if model.config.model_type in ["llava_next"]: - self.kwargs["attn_softmax_bf16"] = True - self.kwargs["trim_logits"] = True - - if os.getenv("USE_FLASH_ATTENTION", "true").lower() == "true": - self.kwargs["use_flash_attention"] = True - if os.getenv("FLASH_ATTENTION_RECOMPUTE", "true").lower() == "true": - self.kwargs["flash_attention_recompute"] = True - - self.speculate = get_speculate() - if model.config.model_type == "mllama": - global CROSS_ATTENTION_LAYERS, BASE_IMAGE_TOKENS - CROSS_ATTENTION_LAYERS = model.config.text_config.cross_attention_layers - BASE_IMAGE_TOKENS = 0 - - super(VlmCausalLM, self).__init__( - model_id=model_id, - model=model, - tokenizer=tokenizer, - requires_padding=True, - dtype=dtype, - device=device, - rank=rank, - ) - - # Create profiler - ranks_to_profile = [int(val) for val in os.getenv("PROF_RANKS", "0").split(",")] - record_shapes = os.getenv("PROF_RECORD_SHAPES", "false").lower() == "true" - output_dir = os.getenv("PROF_PATH", "/tmp/hpu_profile") - self.profiling_warmup_steps = ( - int(os.getenv("PROF_WARMUPSTEP", "0")) if rank in ranks_to_profile else 0 - ) - self.profiling_steps = ( - int(os.getenv("PROF_STEP", "0")) if rank in ranks_to_profile else 0 - ) - self.profiling_wait_steps = int(os.getenv("PROF_WAITSTEP", "0")) - if self.profiling_steps > 0: - self.hb_profiler = HabanaProfile( - wait=self.profiling_wait_steps, - warmup=self.profiling_warmup_steps, - active=self.profiling_steps, - output_dir=output_dir, - record_shapes=record_shapes, - ) - self.hb_profiler.start() - else: - self.hb_profiler = None - self.step = 0 - - @property - def batch_type(self) -> Type[VlmCausalLMBatch]: - return self.batch_class - - def max_past(self) -> Optional[int]: - return getattr(self.model.text_model, "max_past", None) - - def get_deepspeed_model( - self, - model_class, - model_id: str, - dtype: torch.dtype, - revision: Optional[str] = None, - ) -> torch.nn.Module: - import deepspeed - from habana_frameworks.torch.distributed.hccl import initialize_distributed_hpu - - world_size, rank, local_rank = initialize_distributed_hpu() - model_kwargs = {"revision": revision} - - # Initialize process(es) for DeepSpeed - deepspeed.init_distributed(dist_backend="hccl") - logger.info( - "DeepSpeed is enabled. 
world_size {} rank {} local_rank {}".format( - world_size, rank, local_rank - ) - ) - config = AutoConfig.from_pretrained(model_id, **model_kwargs) - load_to_meta = model_on_meta(config) - - # Check support for rope scaling - if hasattr(config, "rope_scaling"): - config.rope_scaling = self.get_rope_scaling() - model_kwargs["rope_scaling"] = self.get_rope_scaling() - - if load_to_meta: - # Construct model with fake meta tensors, later will be replaced on devices during ds-inference ckpt load - with deepspeed.OnDevice(dtype=dtype, device="meta"): - model = model_class.from_config(config, torch_dtype=dtype) - else: - # TODO: revisit placement on CPU when auto-injection is possible - with deepspeed.OnDevice(dtype=dtype, device="cpu"): - model = model_class.from_pretrained( - model_id, torch_dtype=dtype, **model_kwargs - ) - model = model.eval() - - # Initialize the model - ds_inference_kwargs = {"dtype": dtype} - ds_inference_kwargs["tensor_parallel"] = {"tp_size": world_size} - ds_inference_kwargs["enable_cuda_graph"] = False - ds_inference_kwargs["injection_policy"] = get_ds_injection_policy( - model.language_model.config - ) - - if load_to_meta: - # model loaded to meta is managed differently - checkpoints_json = tempfile.NamedTemporaryFile(suffix=".json", mode="+w") - checkpoint_files = [ - str(f) - for f in weight_files( - model_id, revision=revision, extension=".safetensors" - ) - ] - data = {"type": "ds_model", "checkpoints": checkpoint_files, "version": 1.0} - json.dump(data, checkpoints_json) - checkpoints_json.flush() - ds_inference_kwargs["checkpoint"] = checkpoints_json.name - model = deepspeed.init_inference(model, **ds_inference_kwargs) - - return model.module - - def get_rope_scaling(self) -> Optional[Dict]: - rope_scaling = os.getenv("ROPE_SCALING", None) - if rope_scaling is None: - return None - - rope_factor = float(os.getenv("ROPE_FACTOR", 1.0)) - return {"type": rope_scaling, "factor": float(rope_factor)} - - def decode(self, generated_ids: List[int]) -> str: - return self.tokenizer.decode( - generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False - ) - - def decode_token( - self, - all_input_ids: List[int], - prefix_offset: int = 0, - read_offset: int = 0, - ) -> Tuple[str, int, int]: - if is_tokenizer_transparent(self.tokenizer): - new_text = self.tokenizer.decode( - all_input_ids[read_offset:], skip_special_tokens=False - ) - return new_text, read_offset, len(all_input_ids) - else: - return super().decode_token(all_input_ids, prefix_offset, read_offset) - - def forward( - self, - batch: VlmCausalLMBatch, - bypass_hpu_graph: Optional[bool] = None, - ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]: - # Model Forward - kwargs = { - "input_ids": batch.input_ids, - "attention_mask": batch.attention_mask, - "past_key_values": batch.past_key_values, - "token_idx": batch.token_idx, - "pixel_values": batch.pixel_values, - } - - if self.model.config.model_type == "mllama": - kwargs["aspect_ratio_ids"] = batch.aspect_ratio_ids - kwargs["aspect_ratio_mask"] = batch.aspect_ratio_mask - kwargs["cross_attention_mask"] = batch.cross_attention_mask - else: - kwargs["image_sizes"] = batch.image_sizes - - hpu_kwargs = {} - # Optimum Habana got "lazy_mode" key-val only supported for llama type of models - if self.model.config.model_type == "llama": - hpu_kwargs["lazy_mode"] = LAZY_MODE == 1 - - if self.has_position_ids: - kwargs["position_ids"] = batch.position_ids - if bypass_hpu_graph is not None: - hpu_kwargs["bypass_hpu_graphs"] = bypass_hpu_graph 
- - kwargs.update(self.kwargs) - model_inputs = self.model.prepare_inputs_for_generation(**kwargs) - - if batch.past_key_values is not None: - return self.model.forward(**model_inputs, **hpu_kwargs) - else: - outputs = self.model.forward(**model_inputs, **hpu_kwargs) - return outputs.logits, outputs.past_key_values - - @tracer.start_as_current_span("generate_token") - def generate_token( - self, batches: list[VlmCausalLMBatch], is_warmup: bool = False - ) -> Tuple[List[Generation], Optional[VlmCausalLMBatch], Tuple[int, int]]: - - start = time.time_ns() - # Results - generations: List[Generation] = [] - prev_batches = [] - requests_to_generate = [] - # In order to pipeline any actions on CPU we perform the operation in 3 main stages: - # Stage 1. Collect next token ids of any previously started generations - for batch_id, batch in enumerate(batches): - if batch.logits is not None: - logits = batch.logits - past = batch.past - prefill = batch.past_key_values is None - if prefill: - # no right padding for prefill - token_idx_scalar = batch.attention_mask.shape[-1] - 1 - token_idx = torch.tensor(token_idx_scalar).to(self.device) - else: - token_idx_scalar = ( - batch.attention_mask.shape[-1] - batch.right_padding - ) - token_idx = torch.tensor(token_idx_scalar).to(self.device) - - # Select next token - input_length = batch.input_length - if logits.shape[-2] > 1: - next_token_ids, next_token_logprobs, logprobs, _, _ = ( - batch.next_token_chooser( - batch.input_ids, - logits[:, input_length - 1 : input_length, :].squeeze(-2), - self.speculate, - ) - ) - else: - next_token_ids, next_token_logprobs, logprobs, _, _ = ( - batch.next_token_chooser( - batch.input_ids, logits.squeeze(-2), self.speculate - ) - ) - # Speculation is not active for causal - accepted_ids = torch.ones_like(batch.input_ids)[:, 0] - batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens( - batch.top_n_tokens, - batch.top_n_tokens_tensor, - logprobs, - accepted_ids, - ) - - prev_batches.append( - { - "next_token_ids": next_token_ids, - "next_token_logprobs": next_token_logprobs, - } - ) - - for req_idx, req in enumerate(batch.requests): - requests_to_generate.append( - { - "req": req, - "prev_req_idx": req.idx, - "batch_id": batch_id, - "seed": batch.next_token_chooser.seeds[req_idx], - "do_sample": batch.next_token_chooser.do_sample[req_idx], - "top_n_tokens": batch.top_n_tokens[req_idx], - "top_token_ids": batch_top_token_ids[req_idx], - "top_token_logprobs": batch_top_token_logprobs[req_idx], - "grammar_state": batch.next_token_chooser.fsm_grammar_states[ - req.idx - ], - } - ) - - htorch.core.mark_step() - - # Add new token into input_ids - batch.input_ids.index_copy_(1, token_idx, next_token_ids.unsqueeze(1)) - - # Update attention_mask as we added a new token to input_ids - batch.attention_mask.index_fill_(1, token_idx, 1) - - # add cross-attn mask for new token - if batch.cross_attention_mask is not None: - cross_attention_mask_prev = batch.cross_attention_mask - if token_idx is not None: - mask = cross_attention_mask_prev[ - :, token_idx - 2 : token_idx - 1, ... - ] - cross_attention_mask_prev.index_copy_(1, token_idx - 1, mask) - batch.cross_attention_mask = cross_attention_mask_prev - - # Adjust lengths - batch.input_length += 1 - # Update position_ids - if prefill: - batch.position_ids = ( - torch.index_select(batch.position_ids, 1, token_idx - 1) + 1 - ) - else: - batch.position_ids += 1 - # Update past key values - if prefill: - batch.past_key_values = past - - htorch.core.mark_step() - - # Stage 2. 
Prepare new batch for speculative scheduling - if len(batches) > 1: - batch = self.batch_type.concatenate( - batches, self.tokenizer.pad_token_id, is_warmup - ) - else: - batch = batches[0] - - prefill = batch.past_key_values is None - - # Check if we need to do any bookkeeping first - if not prefill: - batch = self.batch_type.recombine( - [batch], self.tokenizer.pad_token_id, is_warmup - ) - - scenario = "PREFILL" if prefill else "GENERATE" - if ( - self.enable_hpu_graph - and self.limit_hpu_graph - and round_up(DECODE_WARMUP_BATCH_SIZE_LIST, batch.batch_size) - != self.prev_bs - ): - self.model.clear_cache() - self.prev_bs = round_up(DECODE_WARMUP_BATCH_SIZE_LIST, batch.batch_size) - dbg_trace( - scenario, - f"bs:{batch.batch_size} num_reqs:{len(batch.requests)} seq_len:{batch.seq_length} padding:{batch.right_padding}", - ) - # assert batch.right_padding > 0, 'No more room for next token!' - - # Execute batch - if prefill: - # no right padding for prefill - # token_idx = torch.tensor(batch.attention_mask.shape[-1] - 1).to(self.device) - batch.logits, batch.past = self.forward( - batch, - bypass_hpu_graph=( - prefill and self.limit_hpu_graph if self.enable_hpu_graph else None - ), - ) - - elif all([req.stopping_criteria.max_new_tokens == 1 for req in batch.requests]): - # Don't schedule next forward if max_new_tokens for all requests equals 1 - # - we've already generated the first and only needed token in the prefill phase - pass - else: - # token_idx = torch.tensor(batch.attention_mask.shape[-1] - batch.right_padding).to(self.device) - batch.logits = self.forward( - batch, - bypass_hpu_graph=( - prefill and self.limit_hpu_graph if self.enable_hpu_graph else None - ), - ) - - if batch.pixel_values is not None: - batch.pixel_values = None - if batch.aspect_ratio_ids is not None: - batch.aspect_ratio_ids = None - if batch.aspect_ratio_mask is not None: - batch.aspect_ratio_mask = None - - htorch.core.mark_step() - - start_decode = time.time_ns() - - # Stage 3. 
Finish and return previous generations - stopped = len(requests_to_generate) > 0 - for prev_batch in prev_batches: - prev_batch["next_token_logprobs"] = prev_batch[ - "next_token_logprobs" - ].tolist() - prev_batch["next_token_ids_cpu"] = prev_batch["next_token_ids"].cpu() - htorch.core.mark_step() - - for req_data in requests_to_generate: - req = req_data["req"] - i = req_data["prev_req_idx"] - prev_batch_id = req_data["batch_id"] - assert len(prev_batches) > prev_batch_id - next_token_ids_cpu = prev_batches[prev_batch_id]["next_token_ids_cpu"] - next_token_logprobs = prev_batches[prev_batch_id]["next_token_logprobs"] - - request = req.data - input_length = req.input_length - prefix_offset = req.prefix_offset - read_offset = req.read_offset - do_sample = req_data["do_sample"] - seed = req_data["seed"] - stopping_criteria = req.stopping_criteria - all_input_ids = req.all_input_ids - next_token_id = next_token_ids_cpu[i] - next_token_logprob = next_token_logprobs[i] - top_n_tokens = req_data["top_n_tokens"] - top_token_ids = req_data["top_token_ids"] - top_token_logprobs = req_data["top_token_logprobs"] - grammar_state = req_data["grammar_state"] - - # Append next token to all tokens - all_input_ids[input_length] = next_token_id - new_input_length = input_length + 1 - - # Generated token - if ( - is_tokenizer_transparent(self.tokenizer) - and len(stopping_criteria.stop_sequence_criterias) == 0 - ): - next_token_text = "" - else: - next_token_text, prefix_offset, read_offset = self.decode_token( - all_input_ids[0:new_input_length, 0], prefix_offset, read_offset - ) - - # Evaluate stopping criteria - stop, reason = stopping_criteria( - next_token_id, - next_token_text, - ) - - if not stop: - stopped = False - - # Shard generations - # All generations will be appended in the rust sharded client - if i % self.world_size == self.rank: - if stop: - # Decode generated tokens - if is_tokenizer_transparent(self.tokenizer): - output_text = None - else: - output_text = self.decode( - all_input_ids[ - new_input_length - - stopping_criteria.current_tokens : new_input_length, - 0, - ] - ) - generated_text = GeneratedText( - output_text, - stopping_criteria.current_tokens, - reason, - seed if do_sample else None, - ) - else: - generated_text = None - - # Prefill - if stopping_criteria.current_tokens == 1 and request.prefill_logprobs: - # Remove generated token to only have prefill and add nan for first prompt token - prefill_logprobs = [float("nan")] + next_token_logprobs - prefill_token_ids = all_input_ids[0 : new_input_length - 1] - prefill_texts = self.tokenizer.batch_decode( - prefill_token_ids, - clean_up_tokenization_spaces=False, - skip_special_tokens=False, - ) - prefill_tokens = Tokens( - prefill_token_ids, - prefill_logprobs, - prefill_texts, - is_special=[], - ) - else: - prefill_tokens = None - - if top_n_tokens > 0: - all_top_tokens = [] - for top_token_ids, top_token_logprobs in zip( - top_token_ids, top_token_logprobs - ): - toptoken_texts = self.tokenizer.batch_decode( - top_token_ids, - clean_up_tokenization_spaces=False, - skip_special_tokens=False, - ) - special_toptokens = [ - token_id in self.all_special_ids - for token_id in top_token_ids - ] - top_tokens = Tokens( - top_token_ids, - top_token_logprobs, - toptoken_texts, - special_toptokens, - ) - all_top_tokens.append(top_tokens) - top_tokens = all_top_tokens - else: - top_tokens = None - - generation = Generation( - request.id, - prefill_tokens, - Tokens( - [next_token_id], - [next_token_logprob], - [next_token_text], - 
[next_token_id in self.all_special_ids], - ), - generated_text, - top_tokens, - ) - - generations.append(generation) - - batch.next_token_chooser = ( - batch.next_token_chooser.advance_grammar_single_with_past_state( - req.idx, next_token_id, grammar_state - ) - ) - - req.all_input_ids = all_input_ids - req.input_length = new_input_length - req.prefix_offset = prefix_offset - req.read_offset = read_offset - - htorch.core.mark_step() - self.step = self.step + 1 - if self.hb_profiler is not None: - if ( - self.step - > self.profiling_wait_steps - + self.profiling_warmup_steps - + self.profiling_steps - ): - self.hb_profiler.stop() - else: - self.hb_profiler.step() - - forward_ns = start_decode - start - decode_ns = time.time_ns() - start_decode - return generations, batch if not stopped else None, (forward_ns, decode_ns) - - def batch_from_pb(self, batch, is_warmup): - return self.batch_type.from_pb_processor( - batch, - self.tokenizer, - self.processor, - self.model.config, - self.dtype, - self.device, - is_warmup, - ) - - def generate_warmup_batch(self, request, seq_len, batch_size, is_warmup): - batch = copy.deepcopy(request.batch) - for req in batch.requests: - req.truncate = seq_len - - for i in range(len(batch.requests) - batch_size): - batch.requests.pop() - - return self.batch_from_pb(batch, is_warmup) - - def warmup( - self, request: generate_pb2.WarmupRequest - ) -> Tuple[Optional[int], Optional[int], Optional[int]]: - global MAX_TOTAL_TOKENS - MAX_TOTAL_TOKENS = request.max_total_tokens - batch = self.batch_from_pb(request.batch, is_warmup=True) - max_input_tokens = request.max_input_tokens - max_prefill_batch_size = batch.input_ids.shape[0] - max_batch_size_str = os.environ.get("MAX_BATCH_SIZE") - if max_batch_size_str is not None: - MAX_BATCH_SIZE = int(max_batch_size_str) - else: - raise ValueError("MAX_BATCH_SIZE is not set") - - try: - # max prefill batch size warmup - _, prefill_batch, _ = self.generate_token([batch], is_warmup=True) - except Exception: - raise RuntimeError( - f"Not enough memory to handle {len(batch.input_ids)} prefill tokens. 
" - f"You need to decrease `--max-batch-prefill-tokens`" - ) - - global BASE_IMAGE_TOKENS, PREFILL_WARMUP_BATCH_SIZE_LIST, PREFILL_WARMUP_SEQLEN_LIST, DECODE_WARMUP_BATCH_SIZE_LIST - PREFILL_WARMUP_BATCH_SIZE_LIST = [] - batch_size = 1 - while batch_size <= max_prefill_batch_size: - PREFILL_WARMUP_BATCH_SIZE_LIST.append(batch_size) - batch_size = batch_size * 2 - if PREFILL_WARMUP_BATCH_SIZE_LIST[-1] < max_prefill_batch_size: - PREFILL_WARMUP_BATCH_SIZE_LIST.append(max_prefill_batch_size) - - if self.model.config.model_type == "mllama": - seq_len = PAD_SEQUENCE_TO_MULTIPLE_OF - else: - seq_len = BASE_IMAGE_TOKENS - - PREFILL_WARMUP_SEQLEN_LIST = [] - i = 0 - while seq_len <= max_input_tokens: - PREFILL_WARMUP_SEQLEN_LIST.append(seq_len) - seq_len += PAD_SEQUENCE_TO_MULTIPLE_OF * (2**i) - i += 1 - if PREFILL_WARMUP_SEQLEN_LIST[-1] < max_input_tokens: - PREFILL_WARMUP_SEQLEN_LIST.append(max_input_tokens) - - # Prefill and decode warmup - DECODE_WARMUP_BATCH_SIZE_LIST = [] - prefill_batch = None - decode_batch = None - try: - for batch_size in PREFILL_WARMUP_BATCH_SIZE_LIST: - for seq_len in PREFILL_WARMUP_SEQLEN_LIST: - batch = self.generate_warmup_batch( - request, seq_len, batch_size, is_warmup=True - ) - _, prefill_batch, _ = self.generate_token([batch], is_warmup=True) - assert prefill_batch is not None - _, decode_batch, _ = self.generate_token( - [prefill_batch], is_warmup=True - ) - - DECODE_WARMUP_BATCH_SIZE_LIST.append(batch_size) - - except Exception: - raise RuntimeError( - f"Not enough memory to handle following prefill and decode warmup." - f"Prefill batch size list:{PREFILL_WARMUP_BATCH_SIZE_LIST}" - f"Prefill sequence length list:{PREFILL_WARMUP_SEQLEN_LIST}" - f"Decode batch size list:{DECODE_WARMUP_BATCH_SIZE_LIST}" - f"You need to decrease `--max-batch-prefill-tokens`" - ) - - mem_stats = get_hpu_memory_stats(self.device) - logger.info( - f"\nFollowing prefill and decode warmup successfully.\n" - f"Prefill batch size list:{PREFILL_WARMUP_BATCH_SIZE_LIST}\n" - f"Prefill sequence length list:{PREFILL_WARMUP_SEQLEN_LIST}\n" - f"Decode batch size list:{DECODE_WARMUP_BATCH_SIZE_LIST}\n" - f"Memory stats: {mem_stats} " - ) - - max_decode_batch_size = MAX_BATCH_SIZE - batch_size = max_prefill_batch_size * 2 - # Decode warmup with bigger batch_size - try: - if ( - DECODE_WARMUP_BATCH_SIZE_LIST[-1] < max_decode_batch_size - and batch_size <= max_decode_batch_size - ): - batches = [] - while batch_size <= max_decode_batch_size: - for i in range(int(batch_size / max_prefill_batch_size)): - batch = self.generate_warmup_batch( - request, - PREFILL_WARMUP_SEQLEN_LIST[0] - 1, - max_prefill_batch_size, - is_warmup=True, - ) - _, prefill_batch, _ = self.generate_token( - [batch], is_warmup=True - ) - batches.append(prefill_batch) - - _, decode_batch, _ = self.generate_token(batches, is_warmup=True) - DECODE_WARMUP_BATCH_SIZE_LIST.append(batch_size) - batch_size = batch_size * 2 - batches.clear() - - if DECODE_WARMUP_BATCH_SIZE_LIST[-1] < max_decode_batch_size: - max_decode_batch_size = math.floor(max_decode_batch_size / 2) * 2 - batch_size = max_decode_batch_size - for i in range(int(max_decode_batch_size / 2)): - batch = self.generate_warmup_batch( - request, - PREFILL_WARMUP_SEQLEN_LIST[0] - 1, - 2, - is_warmup=True, - ) - _, prefill_batch, _ = self.generate_token( - [batch], is_warmup=True - ) - batches.append(prefill_batch) - _, decode_batch, _ = self.generate_token(batches, is_warmup=True) - DECODE_WARMUP_BATCH_SIZE_LIST.append(max_decode_batch_size) - - except Exception: - raise 
RuntimeError(
-                f"Not enough memory to handle batch_size({batch_size}) decode warmup."
-                f"Decode batch size list:{DECODE_WARMUP_BATCH_SIZE_LIST}"
-                f"max_decode_batch_size is {max_decode_batch_size}"
-                f"You need to decrease env `MAX_BATCH_SIZE` or '--max_batch_size'"
-            )
-
-        mem_stats = get_hpu_memory_stats(self.device)
-        logger.info(
-            f"\nFollowing decode warmup successfully.\n"
-            f"Decode batch size list:{DECODE_WARMUP_BATCH_SIZE_LIST}\n"
-            f"Memory stats: {mem_stats}"
-        )
-
-        max_supported_total_tokens = MAX_BATCH_SIZE * MAX_TOTAL_TOKENS
-        max_input_tokens = max_input_tokens
-        max_total_tokens = MAX_TOTAL_TOKENS
-
-        return max_supported_total_tokens, max_input_tokens, max_total_tokens
diff --git a/backends/gaudi/tgi-entrypoint.sh b/backends/gaudi/tgi-entrypoint.sh
index d787ea8e..a5c3f5e1 100644
--- a/backends/gaudi/tgi-entrypoint.sh
+++ b/backends/gaudi/tgi-entrypoint.sh
@@ -7,13 +7,5 @@ if [[ "$*" == *"--sharded true"* ]]; then
     echo 'setting PT_HPU_ENABLE_LAZY_COLLECTIVES=1 for sharding'
     export PT_HPU_ENABLE_LAZY_COLLECTIVES=1
 fi
-# Check if ATTENTION environment variable is set to paged
-if [[ "$ATTENTION" == "paged" ]]; then
-    # Check if Llama-4 is in the command line arguments
-    if [[ "$*" == *"Llama-4"* || "$*" == *"Qwen3"* ]]; then
-        echo 'ATTENTION=paged and Llama-4 or Qwen3 detected'
-        pip install git+https://github.com/huggingface/transformers.git@29338949
-    fi
-fi
 
 text-generation-launcher $@
diff --git a/launcher/src/env_runtime.rs b/launcher/src/env_runtime.rs
index d9056e41..cd4ee290 100644
--- a/launcher/src/env_runtime.rs
+++ b/launcher/src/env_runtime.rs
@@ -27,10 +27,6 @@ impl Env {
             docker_label: option_env!("DOCKER_LABEL").unwrap_or("N/A"),
         }
     }
-
-    pub fn should_start_a_single_hpu_shard(&self) -> bool {
-        self.hpu_env != "N/A" && std::env::var("ATTENTION").as_deref() != Ok("paged")
-    }
 }
 
 impl fmt::Display for Env {
diff --git a/launcher/src/main.rs b/launcher/src/main.rs
index ee80eb00..c727623c 100644
--- a/launcher/src/main.rs
+++ b/launcher/src/main.rs
@@ -1590,11 +1590,6 @@ fn spawn_shards(
 ) -> Result<(), LauncherError> {
     // Start shard processes
     for rank in 0..num_shard {
-        if rank != 0 && env_runtime::Env::new().should_start_a_single_hpu_shard() {
-            tracing::info!("Running on HPU, the launcher will not do any sharding as actual sharding is done in the server");
-            break;
-        }
-
         let model_id = args.model_id.clone();
         let revision = args.revision.clone();
         let uds_path = args.shard_uds_path.clone();
@@ -1670,10 +1665,6 @@ fn spawn_shards(
                 if shard_ready == num_shard {
                     break;
                 }
-                if env_runtime::Env::new().should_start_a_single_hpu_shard() {
-                    tracing::info!("HPU detected, shard is ready");
-                    break;
-                }
             }
             Err(TryRecvError::Empty) => {
                 sleep(Duration::from_millis(100));