import re
import torch
import os
import time
import math
from PIL import Image
from io import BytesIO
import base64
import numpy
from opentelemetry import trace
from loguru import logger
from typing import Iterable, Optional, Tuple, List, Type, Dict
import itertools
import tempfile
import copy
from text_generation_server.models import Model
from transformers import PreTrainedTokenizerBase
from transformers.image_processing_utils import select_best_resolution
from text_generation_server.utils.tokens import batch_top_tokens
from text_generation_server.pb import generate_pb2
from text_generation_server.models.causal_lm import (
    CausalLMBatch,
    CausalLMRequest,
    remove_kv_cache_from_output,
    biggest_single_chunk,
)

from transformers.models.llava_next.modeling_llava_next import (
    get_anyres_image_grid_shape,
)

from transformers import AutoProcessor
import text_generation_server.habana_quantization_env as hq_env
from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi
from text_generation_server.utils import (
    HeterogeneousNextTokenChooser,
    StoppingCriteria,
    make_tokenizer_optional,
    is_tokenizer_transparent,
    pad_next_token_chooser_parameters,
)
import habana_frameworks.torch as htorch
from optimum.habana.utils import HabanaProfile
from optimum.habana.transformers.generation import MODELS_OPTIMIZED_WITH_STATIC_SHAPES
from optimum.habana.utils import get_hpu_memory_stats
from optimum.habana.checkpoint_utils import get_ds_injection_policy

from transformers import (
    AutoTokenizer,
    AutoModel,
    PreTrainedTokenizerBase,
    AutoConfig,
)
from optimum.habana.checkpoint_utils import (
    get_repo_root,
    model_on_meta,
    write_checkpoints_json,
)

from text_generation_server.utils.speculate import get_speculate
from text_generation_server.models.types import (
    Batch,
    Tokens,
    Generation,
    GeneratedText,
)
from text_generation_server.utils.debug import dbg_trace

tracer = trace.get_tracer(__name__)

IDEFICS2_FAKE_TOKEN = "<fake_token_around_image>"
IDEFICS2_IMAGE_TOKEN = "<image>"


IMAGES = re.compile(r"!\[[^\]]*\]\((.*?)\s*(\"(?:.*[^\"])\")?\s*\)")
BASE_IMAGE_TOKENS = int(os.environ.get('BASE_IMAGE_TOKENS', 2048))
MAX_TOTAL_TOKENS = int(os.environ.get('MAX_TOTAL_TOKENS', 8192))
MAX_BATCH_TOTAL_TOKENS = int(os.environ.get('MAX_BATCH_TOTAL_TOKENS', 131072))
PAD_SEQUENCE_TO_MULTIPLE_OF = int(os.environ.get('PAD_SEQUENCE_TO_MULTIPLE_OF', 128))
CHUNK_SIZES = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048]
LAZY_MODE = int(os.environ.get('PT_HPU_LAZY_MODE', 1))

PREFILL_WARMUP_BATCH_SIZE_LIST = []
PREFILL_WARMUP_SEQLEN_LIST = []
DECODE_WARMUP_BATCH_SIZE_LIST = []
CROSS_ATTENTION_LAYERS = []

def round_up(warmup_list: list, num):
    i = 0
    for i in warmup_list:
        if num <= i:
            break
    return i if i > 0 else num

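# Example (illustrative): with PREFILL_WARMUP_BATCH_SIZE_LIST = [1, 2, 4, 8],
# round_up(PREFILL_WARMUP_BATCH_SIZE_LIST, 3) returns 4 (the smallest warmed-up bucket
# that fits), while round_up([], 3) falls back to 3 because the warmup list is empty.
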
def split(string) -> List[Dict[str, str]]:
    parts = []
    cursor = 0
    for pattern in IMAGES.finditer(string):
        start = pattern.start()
        if start != cursor:
            parts.append({"type": "text", "content": string[cursor:start]})

        parts.append({"type": "image", "content": pattern.group(1)})
        cursor = pattern.end()

    if cursor != len(string):
        parts.append({"type": "text", "content": string[cursor:]})

    return parts

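# Example (illustrative): split("hello ![](http://x/cat.png) world") yields
# [{"type": "text", "content": "hello "},
#  {"type": "image", "content": "http://x/cat.png"},
#  {"type": "text", "content": " world"}].
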
def image_text_replacement(config) -> str:
    if config.model_type == "idefics2":
        # NOTE: image_seq_len was previously undefined here; 64 matches the upstream
        # TGI default number of <image> placeholder tokens per image for Idefics2.
        image_seq_len = 64
        image_str = f"{IDEFICS2_FAKE_TOKEN}{IDEFICS2_IMAGE_TOKEN * image_seq_len}{IDEFICS2_FAKE_TOKEN}"
        return image_str
    elif config.model_type == "llava_next":
        return "<image>"
    elif config.model_type == "paligemma":
        return "<image>"
    elif config.model_type == "mllama":
        return "<|image|>"
    else:
        raise RuntimeError(f"Unknown config {config.model_type} for multimodal")

def image_text_replacement_fixup(config, text: str) -> str:
    if config.model_type == "idefics2":
        return text.replace(
            f"{IDEFICS2_FAKE_TOKEN}{IDEFICS2_FAKE_TOKEN}", IDEFICS2_FAKE_TOKEN
        )
    return text

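# Example (illustrative): for idefics2, two adjacent image placeholders would otherwise
# produce "...<fake_token_around_image><fake_token_around_image>..." at their boundary;
# image_text_replacement_fixup() collapses that pair back to a single separator token.
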
def get_unpadded_features(
    original_height: int,
    original_width: int,
    npatches: int,
    num_patch_height: int,
    num_patch_width: int,
) -> Tuple[int, int]:
    current_height = npatches * num_patch_height
    current_width = npatches * num_patch_width

    aspect_ratio: float = original_width / original_height
    current_aspect_ratio: float = current_width / current_height

    if aspect_ratio > current_aspect_ratio:
        new_height = (original_height * current_width) // original_width
        padding = (current_height - new_height) // 2
        current_height = current_height - (2 * padding)
    else:
        new_width = (original_width * current_height) // original_height
        padding = (current_width - new_width) // 2
        current_width = current_width - (2 * padding)

    unpadded_features = current_height * current_width
    newline_features = current_height
    return (unpadded_features, newline_features)

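# Worked example (illustrative): get_unpadded_features(400, 600, npatches=24,
# num_patch_height=1, num_patch_width=2) starts from a 24x48 patch grid; the image
# aspect ratio (1.5) is narrower than the grid's (2.0), so 6 padding columns are
# trimmed from each side, giving (24 * 36, 24) = (864, 24) unpadded/newline features.
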
def get_number_of_features(height: int, width: int, config) -> int:
    # From config
    # Hardcoded for CLIP for now
    # image_grid_pinpoints = [[336, 672], [672, 336], [672, 672], [1008, 336], [336, 1008]]
    image_grid_pinpoints = config.image_grid_pinpoints
    image_size = config.vision_config.image_size
    patch_size = config.vision_config.patch_size

    assert image_size % patch_size == 0

    npatches = image_size // patch_size

    # Dimensions are intentionally swapped to be bug-compatible with
    # upstream: https://github.com/LLaVA-VL/LLaVA-NeXT/issues/59
    num_patch_width, num_patch_height = get_anyres_image_grid_shape(
        [height, width],
        image_grid_pinpoints,
        image_size,
    )

    unpadded_features, newline_features = get_unpadded_features(
        height, width, npatches, num_patch_height, num_patch_width
    )
    # The base patch covers the entire image
    base_features = npatches**2
    return unpadded_features + newline_features + base_features

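# Note (illustrative): with a typical llava_next CLIP vision tower (image_size=336,
# patch_size=14) npatches is 24, so the base patch alone contributes 24**2 = 576
# features; the anyres grid then adds the unpadded and newline features on top.
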
class VlmCausalLMBatch(CausalLMBatch):
    pixel_values: Optional[List[torch.Tensor]]
    pixel_attention_mask: Optional[List[torch.Tensor]]
    image_sizes: Optional[List[Tuple[int, int]]]
    aspect_ratio_ids: Optional[torch.Tensor] = None
    aspect_ratio_mask: Optional[torch.Tensor] = None
    cross_attention_mask: Optional[torch.Tensor] = None
    prefilling: bool = True
    token_idx: torch.Tensor = None

    def __init__(
        self,
        batch_id,
        requests,
        input_ids,
        attention_mask,
        position_ids,
        past_key_values,
        merged_kv_cache,
        next_token_chooser,
        top_n_tokens,
        top_n_tokens_tensor,
        input_length,
        pixel_values: Optional[List[torch.Tensor]] = None,
        pixel_attention_mask: Optional[List[torch.Tensor]] = None,
        image_sizes: Optional[List[Tuple[int, int]]] = None,
        aspect_ratio_ids: Optional[torch.Tensor] = None,
        aspect_ratio_mask: Optional[torch.Tensor] = None,
        cross_attention_mask: Optional[torch.Tensor] = None,
        prefilling: Optional[bool] = True,
    ):
        super().__init__(
            batch_id=batch_id,
            requests=requests,
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            merged_kv_cache=merged_kv_cache,
            next_token_chooser=next_token_chooser,
            top_n_tokens=top_n_tokens,
            top_n_tokens_tensor=top_n_tokens_tensor,
            input_length=input_length,
        )

        self.pixel_values = pixel_values
        self.pixel_attention_mask = pixel_attention_mask
        self.image_sizes = image_sizes
        self.aspect_ratio_ids = aspect_ratio_ids
        self.aspect_ratio_mask = aspect_ratio_mask
        self.cross_attention_mask = cross_attention_mask
        self.prefilling = prefilling

    @property
    def token_idx(self):
        if self.prefilling:
            # no right padding for prefill
            token_idx_scalar = self.attention_mask.shape[-1] - 1
        else:
            token_idx_scalar = self.attention_mask.shape[-1] - self.right_padding
        return torch.tensor(token_idx_scalar).to(self.attention_mask.device)

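    # token_idx points at the next write position inside the statically padded buffers:
    # during prefill it is the last column (only left padding is used), and during decode
    # it is the first right-padded column, i.e. attention_mask width minus right_padding.
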
    def padding_process(self, pad_id: int):
        # self.input_ids = torch.index_select(self.input_ids, 1, self.token_idx - 1)
        right_padding = MAX_TOTAL_TOKENS - self.attention_mask.shape[1]
        self.input_ids = torch.nn.functional.pad(self.input_ids, (0, right_padding), value=pad_id)
        self.attention_mask = torch.nn.functional.pad(
            self.attention_mask, (0, right_padding), value=0
        )
        # if self.position_ids is not None:
        #     self.position_ids = torch.index_select(self.position_ids, 1, self.token_idx - 1) + 1
        if self.cross_attention_mask is not None:
            self.cross_attention_mask = torch.nn.functional.pad(
                self.cross_attention_mask, (0, 0, 0, 0, 0, right_padding), value=0
            )
        if self.past is not None:
            past_key_values_list = list(self.past_key_values)
            for layer_id in range(len(self.past)):
                past_key_value_list = list(self.past_key_values[layer_id])
                if layer_id not in CROSS_ATTENTION_LAYERS:
                    past_key_value_list[0] = torch.nn.functional.pad(
                        self.past_key_values[layer_id][0], (0, 0, 0, right_padding), value=0
                    )
                    past_key_value_list[1] = torch.nn.functional.pad(
                        self.past_key_values[layer_id][1], (0, 0, 0, right_padding), value=0
                    )
                past_key_values_list[layer_id] = tuple(past_key_value_list)
            self.past_key_values = tuple(past_key_values_list)

        self.prefilling = False
        self.input_length = self.input_length

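    # Note (illustrative): with MAX_TOTAL_TOKENS = 8192 and a prefill bucket of 512
    # columns, padding_process() right-pads input_ids, attention_mask and the
    # self-attention KV cache by 7680 positions so every later decode step runs with
    # the same static shape; cross-attention KV entries keep their image-defined length.
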
    @classmethod
    def from_tokenized(
        cls,
        pb: generate_pb2.Batch,
        tokenizer: PreTrainedTokenizerBase,
        batch_tokenized_inputs,
        dtype: torch.dtype,
        device: torch.device,
        is_warmup: bool = False,
    ) -> "VlmCausalLMBatch":
        dbg_trace('FROM_PB', f'num_reqs:{len(pb.requests)}')
        requests = [CausalLMRequest.from_pb(idx, req, tokenizer) for idx, req in enumerate(pb.requests)]

        max_input_length = max(r.data.truncate for r in requests)
        max_new_tokens = max(r.stopping_criteria.max_new_tokens for r in requests)
        # TODO: Add support for sparse batches
        top_n_tokens = [r.top_n_tokens for r in pb.requests]
        top_n_tokens_tensor = torch.tensor(top_n_tokens, device=device, dtype=torch.int64)

        # TODO: by tokenizing all inputs at once we lose information on actual input lengths
        # this means that we cannot shift inputs to the left after a long input sequence
        # was filtered out
        new_bs = round_up(PREFILL_WARMUP_BATCH_SIZE_LIST, len(requests))
        parameters = [r.parameters for r in pb.requests]
        # append dummy parameters for the padding (dummy) requests
        parameters = pad_next_token_chooser_parameters(parameters, new_bs)

        next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
            pb=parameters,
            dtype=dtype,
            device=device,
            tokenizer=tokenizer,
            quantization_enabled=hq_env.is_quantization_enabled,
        )
        tokenized_inputs = batch_tokenized_inputs
        input_len = tokenized_inputs["input_ids"].shape[1]

        bucket_size = max_input_length
        left_padding = max_input_length - input_len
        if is_warmup is False:
            rounded_seq_len = round_up(PREFILL_WARMUP_SEQLEN_LIST, input_len + 1)
            bucket_size = rounded_seq_len - 1
            left_padding = bucket_size - input_len

        input_ids = tokenized_inputs["input_ids"]
        attention_mask = tokenized_inputs["attention_mask"]
        cross_attention_mask = tokenized_inputs.get("cross_attention_mask", None)
        # Allocate space for the first generated token
        input_ids = torch.nn.functional.pad(
            input_ids, (left_padding, 1), value=tokenizer.pad_token_id
        )
        attention_mask = torch.nn.functional.pad(
            attention_mask, (left_padding, 1), value=0
        )
        if cross_attention_mask is not None:
            cross_attention_mask = torch.nn.functional.pad(
                cross_attention_mask, (0, 0, 0, 0, left_padding, 1), value=0
            )
        all_input_ids = torch.nn.functional.pad(
            input_ids, (0, max_new_tokens), value=tokenizer.pad_token_id
        ).T.split(1, dim=1)

        # New input length after left padding
        input_len = bucket_size
        for r in requests:
            r.input_length = input_len
            r.prefix_offset = input_len - 5
            r.read_offset = input_len
            r.all_input_ids = all_input_ids[r.idx]
        input_ids = input_ids.to(device)
        attention_mask = attention_mask.to(device)
        cross_attention_mask = cross_attention_mask.to(device) if cross_attention_mask is not None else None
        position_ids = attention_mask.long().cumsum(-1) - 1
        position_ids.masked_fill_(attention_mask == 0, 1)
        htorch.core.mark_step()

        return cls(
            batch_id=pb.id,
            requests=requests,
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=None,
            merged_kv_cache=False,
            next_token_chooser=next_token_chooser,
            top_n_tokens=top_n_tokens,
            top_n_tokens_tensor=top_n_tokens_tensor,
            input_length=input_len,
            cross_attention_mask=cross_attention_mask,
        )

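    # Note (illustrative): with PREFILL_WARMUP_SEQLEN_LIST = [256, 512, 1024] and a
    # tokenized length of 300, round_up(..., 301) picks 512, so bucket_size becomes 511
    # and the prompt is left-padded by 211 tokens; one extra slot on the right is
    # reserved for the first generated token.
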
    @classmethod
    def batch_tokenized_inputs(
        cls, requests: Iterable[generate_pb2.Request], tokenizer, processor, config, is_warmup
    ):
        image_inputs = {}
        texts = []
        images = []
        image_indices = []
        batch_tokenized_inputs = {}

        for i, r in enumerate(requests):
            # Each input is encoded into a list, where each element of this input list is either a string or a URL
            curr_text = ""
            curr_image = None
            curr_i = None
            for chunk in r.input_chunks.chunks:
                chunk_type = chunk.WhichOneof("chunk")
                if chunk_type == "text":
                    curr_text += chunk.text
                elif chunk_type == "image":
                    image = Image.open(BytesIO(chunk.image.data))
                    # TODO unsure about BOS
                    curr_image = image
                    curr_i = i
                else:
                    raise RuntimeError(f"Invalid chunk type {chunk_type}")

            if image_text_replacement(config) not in curr_text:
                if "<image>" in curr_text:
                    curr_text = curr_text.replace("<image>", image_text_replacement(config))
                else:
                    curr_text = image_text_replacement(config) + curr_text

            texts.append(curr_text)
            if curr_image is not None:
                if config.model_type == "mllama":
                    images.append([curr_image])
                else:
                    images.append(curr_image)

        missing_inputs = 0
        dummy_images = None
        if is_warmup is False:
            new_bs = round_up(PREFILL_WARMUP_BATCH_SIZE_LIST, len(requests))
            missing_inputs = new_bs - len(requests)
            if missing_inputs > 0:
                dummy_inputs = []
                if len(texts) > 0:
                    dummy_inputs = [texts[0]] * missing_inputs
                    dummy_images = [images[0]] * missing_inputs
                texts += dummy_inputs
                images += dummy_images

        processor_output = processor(
            images,
            texts,
            truncation=True,
            max_length=r.truncate,
            add_special_tokens=r.add_special_tokens,
            return_tensors="pt",
            padding_side="left",
            padding="longest",
        )
        if "input_ids" in processor_output:
            batch_tokenized_inputs.update({"input_ids": processor_output["input_ids"]})
        if "attention_mask" in processor_output:
            batch_tokenized_inputs.update({"attention_mask": processor_output["attention_mask"]})
        if "cross_attention_mask" in processor_output:
            batch_tokenized_inputs.update({"cross_attention_mask": processor_output["cross_attention_mask"]})
        if "pixel_values" in processor_output:
            image_inputs.update({"pixel_values": processor_output["pixel_values"]})
        if "pixel_attention_mask" in processor_output:
            image_inputs.update({"pixel_attention_mask": processor_output["pixel_attention_mask"]})
        if "aspect_ratio_ids" in processor_output:
            image_inputs.update({"aspect_ratio_ids": processor_output["aspect_ratio_ids"]})
        if "aspect_ratio_mask" in processor_output:
            image_inputs.update({"aspect_ratio_mask": processor_output["aspect_ratio_mask"]})
        if "image_sizes" in processor_output:
            image_inputs.update({"image_sizes": processor_output["image_sizes"]})

        return batch_tokenized_inputs, image_inputs

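    # Note (illustrative): when the incoming batch is smaller than the nearest warmed-up
    # bucket (e.g. 3 requests with PREFILL_WARMUP_BATCH_SIZE_LIST = [1, 2, 4]), the first
    # request's text/image are duplicated as dummy inputs so the processor always sees a
    # bucket-sized batch.
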
    @classmethod
    def from_pb_processor(
        cls,
        pb: generate_pb2.Batch,
        tokenizer: PreTrainedTokenizerBase,
        processor,
        config,
        dtype: torch.dtype,
        device: torch.device,
        is_warmup: bool = False,
    ) -> "VlmCausalLMBatch":
        batch_tokenized_inputs, image_inputs = cls.batch_tokenized_inputs(
            pb.requests, tokenizer, processor, config, is_warmup
        )
        batch = cls.from_tokenized(pb, tokenizer, batch_tokenized_inputs, dtype, device, is_warmup=is_warmup)
        # batch_tokenized_inputs always returns a dict; treat an empty one as "no images"
        if image_inputs:
            batch.pixel_values = image_inputs["pixel_values"].to(device=device)
            if "pixel_attention_mask" in image_inputs:
                batch.pixel_attention_mask = image_inputs["pixel_attention_mask"].to(
                    device=device
                )
            else:
                batch.pixel_attention_mask = None
            if "image_sizes" in image_inputs:
                batch.image_sizes = image_inputs["image_sizes"].to(device=device)
            else:
                batch.image_sizes = None
            if "aspect_ratio_ids" in image_inputs:
                batch.aspect_ratio_ids = image_inputs["aspect_ratio_ids"].to(device=device)
            else:
                batch.aspect_ratio_ids = None
            if "aspect_ratio_mask" in image_inputs:
                batch.aspect_ratio_mask = image_inputs["aspect_ratio_mask"].to(device=device)
            else:
                batch.aspect_ratio_mask = None
        else:
            batch.pixel_values = None
            batch.pixel_attention_mask = None
            batch.image_sizes = None
            batch.aspect_ratio_ids = None
            batch.aspect_ratio_mask = None
            batch.cross_attention_mask = None

        return batch

    @classmethod
    @tracer.start_as_current_span("concatenate")
    def concatenate(cls, batches: List["CausalLMBatch"], pad_token_id: int = 0, is_warmup: bool = False) -> "CausalLMBatch":
        return cls.recombine(batches, pad_token_id, is_warmup)

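    # recombine() handles two scenarios: "CONCAT" merges several already-prefilled batches
    # into one bucket-sized decode batch, while "SHIFT" right-pads a single freshly
    # prefilled batch out to MAX_TOTAL_TOKENS so decode steps can run with static shapes.
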
    @classmethod
    def recombine(cls, batches: List["VlmCausalLMBatch"], pad_token_id: int, is_warmup: bool = False) -> "VlmCausalLMBatch":
        if not all(b.past_key_values is not None for b in batches):
            raise ValueError("KV cache not allocated! Cannot recombine before prefill!")

        # Used for padding
        total_requests = sum(len(b) for b in batches)
        new_bs = total_requests
        if not is_warmup:
            new_bs = round_up(DECODE_WARMUP_BATCH_SIZE_LIST, total_requests)

        if len(batches) > 1:
            scenario = "CONCAT"
        elif batches[0].prefilling:
            scenario = "SHIFT"
        else:
            return batches[0]

        dbg_trace(
            scenario,
            f"bs:{[b.batch_size for b in batches]}->{new_bs}"
            f" reqs:{[len(b) for b in batches]}"
        )

        if scenario == "SHIFT":
            batch = batches[0]
            batch.padding_process(pad_token_id)
            return batch

        total_batch_size = 0
        max_input_length = 0
        for i, batch in enumerate(batches):
            total_batch_size += len(batch)
            max_input_length = max(max_input_length, batch.input_length)
        # Batch attributes
        requests = []
        input_lengths = []
        top_n_tokens = []
        max_tokens = 0
        parameters = []
        fsm_grammar_states = []

        # Batch tensors
        input_ids = None
        attention_mask = None
        position_ids = None
        past_key_values = []
        top_n_tokens_tensor = None
        cross_attention_mask = None
        # Used for slicing correctly inside the tensors
        # Equivalent to a cumsum on batch sizes
        start_index = 0
        for i, batch in enumerate(batches):
            keep_indices = []
            for req in batch.requests:
                keep_indices.append(req.idx)

            requests.extend(batch.requests)
            parameters.extend([r.data.parameters for r in batch.requests])
            fsm_grammar_states.extend([batch.next_token_chooser.fsm_grammar_states[i] for i in keep_indices])
            input_lengths.extend([batch.input_length])
            top_n_tokens.extend([batch.top_n_tokens[i] for i in keep_indices])

            # Slicing end index for this batch
            end_index = start_index + len(batch)

            # We only concatenate batches that did at least one step
            if batch.past_key_values is None:
                raise ValueError("only concatenate prefilled batches")

            # Create empty tensor
            # input_ids is always of shape [batch_size, 1]
            # We do not need to pad it
            if input_ids is None:
                input_ids = batch.input_ids.new_empty((new_bs, MAX_TOTAL_TOKENS))
            # Copy to correct indices
            left_offset = max_input_length - batch.input_length
            right_padding = MAX_TOTAL_TOKENS - max_input_length
            input_ids[start_index:end_index, left_offset:-right_padding] = batch.input_ids[keep_indices, :batch.input_length]

            # Create padded tensor
            if top_n_tokens_tensor is None:
                top_n_tokens_tensor = batches[0].top_n_tokens_tensor.new_zeros(
                    new_bs,
                )
            top_n_tokens_tensor[start_index:end_index] = batch.top_n_tokens_tensor[keep_indices]

            if attention_mask is None:
                attention_mask = batch.attention_mask.new_zeros(
                    (new_bs, MAX_TOTAL_TOKENS),
                )

            attention_mask[
                start_index:end_index,
                left_offset:-right_padding,
            ] = batch.attention_mask[
                keep_indices,
                :batch.input_length,
            ]

            if batch.cross_attention_mask is not None:
                cross_attention_mask_shape = list(batch.cross_attention_mask.shape)
                cross_attention_mask_shape[1] = MAX_TOTAL_TOKENS
                cross_attention_mask_shape[0] = new_bs
                cross_attention_mask_shape = torch.Size(cross_attention_mask_shape)
                if cross_attention_mask is None:
                    cross_attention_mask = batch.cross_attention_mask.new_zeros(
                        cross_attention_mask_shape,
                    )
                cross_attention_mask[
                    start_index:end_index,
                    left_offset:-right_padding,
                ] = batch.cross_attention_mask[
                    keep_indices,
                    :batch.input_length,
                ]

            # Create empty tensor
            # position_ids is always of shape [batch_size, 1]
            if position_ids is None:
                position_ids = batch.position_ids.new_empty((new_bs, 1))
            position_ids[start_index:end_index] = batch.position_ids[keep_indices, :]

            # Shenanigans to get dimensions because BLOOM outputs a past with a different shape
            # BLOOM Keys:   [batch_size * num_heads, head_dim, seq_length]
            # BLOOM Values: [batch_size * num_heads, seq_length, head_dim]
            # And ensure that we can update tensors in-place
            if isinstance(batch.past_key_values, tuple):
                batch.past_key_values = [
                    [t.view(batch.batch_size, -1, *t.shape[-2:]) for t in layer]
                    for layer in batch.past_key_values
                ]
            elif len(batch.past_key_values[0][0].shape) == 3:
                for layer in batch.past_key_values:
                    for k, t in enumerate(layer):
                        layer[k] = t.view(batch.batch_size, -1, *t.shape[-2:])

            start_index = end_index

        first_past_kvs = batches[0].past_key_values
        _, num_heads, padded_sequence_length, head_dim = first_past_kvs[0][1].shape
        past_key_values = []
        for layer_id in range(len(batches[0].past_key_values)):
            if layer_id in CROSS_ATTENTION_LAYERS:
                padded_past_keys_shape = list(batches[0].past_key_values[layer_id][0].shape)
                padded_past_keys_shape[0] = new_bs
                padded_past_keys_shape = torch.Size(padded_past_keys_shape)
            else:
                padded_past_keys_shape = (
                    new_bs,
                    num_heads,
                    MAX_TOTAL_TOKENS,
                    head_dim,
                )

            padded_past_keys = first_past_kvs[layer_id][0].new_zeros(padded_past_keys_shape)
            padded_past_values = first_past_kvs[layer_id][1].new_zeros(padded_past_keys_shape)
            start_index = 0
            for batch in batches:
                keep_indices = []
                for req in batch.requests:
                    keep_indices.append(req.idx)

                left_offset = max_input_length - batch.input_length
                right_padding = MAX_TOTAL_TOKENS - max_input_length
                past_keys = batch.past_key_values[layer_id][0]
                past_values = batch.past_key_values[layer_id][1]
                # Clear reference to the original tensor
                batch.past_key_values[layer_id] = None

                # Slicing end index for this batch
                end_index = start_index + len(batch)
                # We slice the keys to remove the padding from previous batches
                if layer_id in CROSS_ATTENTION_LAYERS:
                    padded_past_keys[start_index:end_index, :, :, :] = (
                        past_keys[keep_indices, :, :, :]
                    )
                    padded_past_values[start_index:end_index, :, :, :] = (
                        past_values[keep_indices, :, :, :]
                    )
                else:
                    padded_past_keys[start_index:end_index, :, left_offset:-right_padding, :] = (
                        past_keys[keep_indices, :, :batch.input_length, :]
                    )
                    padded_past_values[start_index:end_index, :, left_offset:-right_padding, :] = (
                        past_values[keep_indices, :, :batch.input_length, :]
                    )

                start_index = end_index

            past_key_values.append(tuple([padded_past_keys, padded_past_values]))
        past_key_values = tuple(past_key_values)

        batch_id = batches[0].batch_id
        top_n_tokens.extend([-1] * (new_bs - total_batch_size))
        fsm_grammar_states.extend([-1] * (new_bs - total_batch_size))

        for idx, req in enumerate(requests):
            req.idx = idx

        parameters = pad_next_token_chooser_parameters(parameters, new_bs)
        next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
            parameters,
            batches[0].next_token_chooser.dtype,
            batches[0].next_token_chooser.device,
            batches[0].next_token_chooser.tokenizer,
            fsm_grammar_states,
            quantization_enabled=hq_env.is_quantization_enabled,
        )
        input_length = max_input_length

        htorch.core.mark_step()

        return cls(
            batch_id=batch_id,
            requests=requests,
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            merged_kv_cache=False,
            next_token_chooser=next_token_chooser,
            top_n_tokens=top_n_tokens,
            top_n_tokens_tensor=top_n_tokens_tensor,
            input_length=input_length,
            pixel_values=None,
            pixel_attention_mask=None,
            image_sizes=None,
            aspect_ratio_ids=None,
            aspect_ratio_mask=None,
            cross_attention_mask=cross_attention_mask,
            prefilling=False,
        )


class VlmCausalLM(Model):
    def __init__(
        self,
        model_class,
        model_id: str,
        *,
        processor_class=AutoProcessor,
        processor_kwargs=None,
        batch_class=VlmCausalLMBatch,
        revision,
        quantize: Optional[str] = None,
        dtype,
        trust_remote_code: bool,
        **kwargs,
    ):
        adapt_transformers_to_gaudi()
        if processor_kwargs is None:
            processor_kwargs = {}
        self.processor = processor_class.from_pretrained(
            model_id,
            revision=revision,
            trust_remote_code=trust_remote_code,
            **processor_kwargs,
        )
        self.batch_class = batch_class
        self.prev_bs = 0
        self.quantize = quantize

        # Create tokenizer
        tokenizer = AutoTokenizer.from_pretrained(
            model_id,
            revision=revision,
            padding_side="left",
            truncation_side="left",
            trust_remote_code=trust_remote_code,
        )
        make_tokenizer_optional(tokenizer)

        # Create model
        world_size = int(os.getenv("WORLD_SIZE", "1"))
        rank = int(os.getenv("RANK", "0"))
        dtype = torch.bfloat16 if dtype is None else dtype
        device = torch.device("hpu")

        if hq_env.is_quantization_enabled:
            htorch.core.hpu_set_env()

        if world_size > 1:
            os.environ.setdefault("DEEPSPEED_USE_HABANA_FRAMEWORKS_DETERMINISTIC_API", "1")
            model = self.get_deepspeed_model(
                model_class, model_id, dtype, revision
            )
            model = hq_env.prepare_model_for_quantization(model)
        else:
            get_repo_root(model_id)

            # Check support for rope scaling
            model_kwargs = {}
            config = AutoConfig.from_pretrained(model_id)
            if hasattr(config, "rope_scaling"):
                model_kwargs["rope_scaling"] = self.get_rope_scaling()

            model = model_class.from_pretrained(
                model_id,
                revision=revision,
                torch_dtype=dtype,
                trust_remote_code=trust_remote_code,
                **model_kwargs,
            )
            model = hq_env.prepare_model_for_quantization(model)
            model = model.eval().to(device)

        self.enable_hpu_graph = os.getenv("ENABLE_HPU_GRAPH", "true").lower() == "true" and LAZY_MODE == 1
        self.limit_hpu_graph = os.getenv("LIMIT_HPU_GRAPH", "false").lower() == "true"
        model = remove_kv_cache_from_output(model)
        if self.enable_hpu_graph:
            from habana_frameworks.torch.hpu import wrap_in_hpu_graph
            model = wrap_in_hpu_graph(model, disable_tensor_cache=True)
        else:
            if LAZY_MODE == 0:
                # "keep_input_mutations" is reported to be safe for inference
                dbg_trace("TORCH COMPILE", "Torch compiling of model")
                model.model = torch.compile(model.model, backend="hpu_backend", options={"keep_input_mutations": True})

        model = hq_env.setup_quantization(model)

        if model.config.model_type not in MODELS_OPTIMIZED_WITH_STATIC_SHAPES:
            raise ValueError(f"Model type {model.config.model_type} is not supported!")

        if tokenizer.pad_token_id is None:
            if model.config.pad_token_id is not None:
                tokenizer.pad_token_id = model.config.pad_token_id
            elif model.config.eos_token_id is not None:
                if isinstance(model.config.eos_token_id, int):
                    tokenizer.pad_token_id = model.config.eos_token_id
                elif isinstance(model.config.eos_token_id, list):
                    tokenizer.pad_token_id = model.config.eos_token_id[0]
                else:
                    raise ValueError(
                        f"{type(model.config.eos_token_id)} type of eos_token_id in the model's config is not supported for tokenizer.pad_token_id"
                    )
            elif tokenizer.eos_token_id is not None:
                tokenizer.pad_token_id = tokenizer.eos_token_id
            else:
                tokenizer.add_special_tokens({"pad_token": "[PAD]"})

        self.kwargs = {
            "use_cache": True,
            "return_dict": True,
        }

        if model.config.model_type in ["llava_next"]:
            self.kwargs["attn_softmax_bf16"] = True
            self.kwargs["trim_logits"] = True

        if os.getenv("USE_FLASH_ATTENTION", "false").lower() == "true":
            self.kwargs["use_flash_attention"] = True
        if os.getenv("FLASH_ATTENTION_RECOMPUTE", "false").lower() == "true":
            self.kwargs["flash_attention_recompute"] = True

        self.speculate = get_speculate()
        if model.config.model_type == "mllama":
            global CROSS_ATTENTION_LAYERS, BASE_IMAGE_TOKENS
            CROSS_ATTENTION_LAYERS = model.config.text_config.cross_attention_layers
            BASE_IMAGE_TOKENS = 0

        super(VlmCausalLM, self).__init__(
            model_id=model_id,
            model=model,
            tokenizer=tokenizer,
            requires_padding=True,
            dtype=dtype,
            device=device,
            rank=rank,
        )

        # Create profiler
        ranks_to_profile = [int(val) for val in os.getenv("PROF_RANKS", "0").split(',')]
        record_shapes = os.getenv("PROF_RECORD_SHAPES", "false").lower() == "true"
        output_dir = os.getenv("PROF_PATH", "/tmp/hpu_profile")
        self.profiling_warmup_steps = int(os.getenv("PROF_WARMUPSTEP", "0")) if rank in ranks_to_profile else 0
        self.profiling_steps = int(os.getenv("PROF_STEP", "0")) if rank in ranks_to_profile else 0
        self.profiling_wait_steps = int(os.getenv("PROF_WAITSTEP", "0"))
        if self.profiling_steps > 0:
            self.hb_profiler = HabanaProfile(
                wait=self.profiling_wait_steps,
                warmup=self.profiling_warmup_steps,
                active=self.profiling_steps,
                output_dir=output_dir,
                record_shapes=record_shapes,
            )
            self.hb_profiler.start()
        else:
            self.hb_profiler = None
        self.step = 0

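    # Behaviour of this wrapper is tuned through environment variables read above:
    # ENABLE_HPU_GRAPH / LIMIT_HPU_GRAPH control HPU graph capture, PT_HPU_LAZY_MODE
    # selects lazy vs. compile mode, USE_FLASH_ATTENTION / FLASH_ATTENTION_RECOMPUTE
    # toggle flash attention, and PROF_RANKS / PROF_WAITSTEP / PROF_WARMUPSTEP /
    # PROF_STEP / PROF_RECORD_SHAPES / PROF_PATH configure the Habana profiler.
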
    @property
    def batch_type(self) -> Type[VlmCausalLMBatch]:
        return self.batch_class

    def max_past(self) -> Optional[int]:
        return getattr(self.model.text_model, "max_past", None)

    def get_deepspeed_model(
        self,
        model_class,
        model_id: str,
        dtype: torch.dtype,
        revision: Optional[str] = None
    ) -> torch.nn.Module:
        import deepspeed
        from habana_frameworks.torch.distributed.hccl import initialize_distributed_hpu

        world_size, rank, local_rank = initialize_distributed_hpu()
        model_kwargs = {
            "revision": revision
        }

        # Initialize process(es) for DeepSpeed
        deepspeed.init_distributed(dist_backend="hccl")
        logger.info(
            "DeepSpeed is enabled. world_size {} rank {} local_rank {}".format(world_size, rank, local_rank)
        )
        config = AutoConfig.from_pretrained(model_id, **model_kwargs)
        load_to_meta = model_on_meta(config)

        # Check support for rope scaling
        if hasattr(config, "rope_scaling"):
            config.rope_scaling = self.get_rope_scaling()
            model_kwargs["rope_scaling"] = self.get_rope_scaling()

        if load_to_meta:
            # Construct model with fake meta tensors, later will be replaced on devices during ds-inference ckpt load
            with deepspeed.OnDevice(dtype=dtype, device="meta"):
                model = model_class.from_config(config, torch_dtype=dtype)
        else:
            get_repo_root(model_id, local_rank=os.getenv("LOCAL_RANK"))
            # TODO: revisit placement on CPU when auto-injection is possible
            with deepspeed.OnDevice(dtype=dtype, device="cpu"):
                model = model_class.from_pretrained(model_id, torch_dtype=dtype, **model_kwargs)
        model = model.eval()

        # Initialize the model
        ds_inference_kwargs = {"dtype": dtype}
        ds_inference_kwargs["tensor_parallel"] = {"tp_size": world_size}
        ds_inference_kwargs["enable_cuda_graph"] = False
        ds_inference_kwargs["injection_policy"] = get_ds_injection_policy(model.language_model.config)

        if load_to_meta:
            # model loaded to meta is managed differently
            checkpoints_json = tempfile.NamedTemporaryFile(suffix=".json", mode="+w")
            write_checkpoints_json(model_id, local_rank, checkpoints_json)
            ds_inference_kwargs["checkpoint"] = checkpoints_json.name
        model = deepspeed.init_inference(model, **ds_inference_kwargs)

        return model.module

    def get_rope_scaling(self) -> Optional[Dict]:
        rope_scaling = os.getenv("ROPE_SCALING", None)
        if rope_scaling is None:
            return None

        rope_factor = float(os.getenv("ROPE_FACTOR", 1.0))
        return {
            'type': rope_scaling, 'factor': float(rope_factor)
        }

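    # Example (illustrative): running with ROPE_SCALING=linear ROPE_FACTOR=2.0 makes
    # get_rope_scaling() return {'type': 'linear', 'factor': 2.0}, which is then passed
    # through to the model config as `rope_scaling`.
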
    def decode(self, generated_ids: List[int]) -> str:
        return self.tokenizer.decode(generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)

    def decode_token(
        self,
        all_input_ids: List[int],
        prefix_offset: int = 0,
        read_offset: int = 0,
    ) -> Tuple[str, int, int]:
        if is_tokenizer_transparent(self.tokenizer):
            new_text = self.tokenizer.decode(all_input_ids[read_offset:], skip_special_tokens=False)
            return new_text, read_offset, len(all_input_ids)
        else:
            return super().decode_token(all_input_ids, prefix_offset, read_offset)

    def forward(
        self,
        batch: VlmCausalLMBatch,
        bypass_hpu_graph: Optional[bool] = None,
    ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
        # Model Forward
        kwargs = {
            "input_ids": batch.input_ids,
            "attention_mask": batch.attention_mask,
            "past_key_values": batch.past_key_values,
            "token_idx": batch.token_idx,
            "pixel_values": batch.pixel_values,
        }

        if self.model.config.model_type == "mllama":
            kwargs["aspect_ratio_ids"] = batch.aspect_ratio_ids
            kwargs["aspect_ratio_mask"] = batch.aspect_ratio_mask
            kwargs["cross_attention_mask"] = batch.cross_attention_mask
        else:
            kwargs["image_sizes"] = batch.image_sizes

        hpu_kwargs = {}
        # Optimum Habana supports the "lazy_mode" keyword only for llama-type models
        if self.model.config.model_type == "llama":
            hpu_kwargs["lazy_mode"] = LAZY_MODE == 1

        if self.has_position_ids:
            kwargs["position_ids"] = batch.position_ids
        if bypass_hpu_graph is not None:
            hpu_kwargs["bypass_hpu_graphs"] = bypass_hpu_graph

        kwargs.update(self.kwargs)
        model_inputs = self.model.prepare_inputs_for_generation(**kwargs)

        if batch.past_key_values is not None:
            return self.model.forward(**model_inputs, **hpu_kwargs)
        else:
            outputs = self.model.forward(**model_inputs, **hpu_kwargs)
            return outputs.logits, outputs.past_key_values

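    # Note: mllama-style models consume aspect_ratio_ids / aspect_ratio_mask /
    # cross_attention_mask, while llava_next-style models take image_sizes instead;
    # forward() above builds the kwargs accordingly before delegating to
    # prepare_inputs_for_generation().
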
    @tracer.start_as_current_span("generate_token")
    def generate_token(
        self, batches: list[VlmCausalLMBatch], is_warmup: bool = False
    ) -> Tuple[List[Generation], Optional[VlmCausalLMBatch], Tuple[int, int]]:
        start = time.time_ns()
        # Results
        generations: List[Generation] = []
        prev_batches = []
        requests_to_generate = []
        # In order to pipeline any actions on CPU we perform the operation in 3 main stages:
        # Stage 1. Collect next token ids of any previously started generations
        for batch_id, batch in enumerate(batches):
            if batch.logits is not None:
                logits = batch.logits
                past = batch.past
                prefill = batch.past_key_values is None
                if prefill:
                    # no right padding for prefill
                    token_idx_scalar = batch.attention_mask.shape[-1] - 1
                    token_idx = torch.tensor(token_idx_scalar).to(self.device)
                else:
                    token_idx_scalar = batch.attention_mask.shape[-1] - batch.right_padding
                    token_idx = torch.tensor(token_idx_scalar).to(self.device)

                # Select next token
                input_length = batch.input_length
                if logits.shape[-2] > 1:
                    next_token_ids, next_token_logprobs, logprobs, _, _ = batch.next_token_chooser(
                        batch.input_ids, logits[:, input_length - 1: input_length, :].squeeze(-2), self.speculate
                    )
                else:
                    next_token_ids, next_token_logprobs, logprobs, _, _ = batch.next_token_chooser(
                        batch.input_ids, logits.squeeze(-2), self.speculate
                    )
                # Speculation is not active for causal
                accepted_ids = torch.ones_like(batch.input_ids)[:, 0]
                batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
                    batch.top_n_tokens,
                    batch.top_n_tokens_tensor,
                    logprobs,
                    accepted_ids,
                )

                prev_batches.append({
                    'next_token_ids': next_token_ids,
                    'next_token_logprobs': next_token_logprobs,
                })

                for req_idx, req in enumerate(batch.requests):
                    requests_to_generate.append({
                        'req': req,
                        'prev_req_idx': req.idx,
                        'batch_id': batch_id,
                        'seed': batch.next_token_chooser.seeds[req_idx],
                        'do_sample': batch.next_token_chooser.do_sample[req_idx],
                        'top_n_tokens': batch.top_n_tokens[req_idx],
                        'top_token_ids': batch_top_token_ids[req_idx],
                        'top_token_logprobs': batch_top_token_logprobs[req_idx],
                        'grammar_state': batch.next_token_chooser.fsm_grammar_states[req.idx],
                    })

                htorch.core.mark_step()

                # Add new token into input_ids
                batch.input_ids.index_copy_(1, token_idx, next_token_ids.unsqueeze(1))

                # Update attention_mask as we added a new token to input_ids
                batch.attention_mask.index_fill_(1, token_idx, 1)

                # add cross-attn mask for new token
                if batch.cross_attention_mask is not None:
                    cross_attention_mask_prev = batch.cross_attention_mask
                    if token_idx is not None:
                        mask = cross_attention_mask_prev[:, token_idx - 2 : token_idx - 1, ...]
                        cross_attention_mask_prev.index_copy_(1, token_idx - 1, mask)
                        batch.cross_attention_mask = cross_attention_mask_prev

                # Adjust lengths
                batch.input_length += 1
                # Update position_ids
                if prefill:
                    batch.position_ids = torch.index_select(batch.position_ids, 1, token_idx - 1) + 1
                else:
                    batch.position_ids += 1
                # Update past key values
                if prefill:
                    batch.past_key_values = past

        htorch.core.mark_step()

        # Stage 2. Prepare new batch for speculative scheduling
        if len(batches) > 1:
            batch = self.batch_type.concatenate(batches, self.tokenizer.pad_token_id, is_warmup)
        else:
            batch = batches[0]

        prefill = batch.past_key_values is None

        # Check if we need to do any bookkeeping first
        if not prefill:
            batch = self.batch_type.recombine([batch], self.tokenizer.pad_token_id, is_warmup)

        scenario = 'PREFILL' if prefill else 'GENERATE'
        if self.enable_hpu_graph and self.limit_hpu_graph and round_up(DECODE_WARMUP_BATCH_SIZE_LIST, batch.batch_size) != self.prev_bs:
            self.model.clear_cache()
            self.prev_bs = round_up(DECODE_WARMUP_BATCH_SIZE_LIST, batch.batch_size)
        dbg_trace(
            scenario,
            f'bs:{batch.batch_size} num_reqs:{len(batch.requests)} seq_len:{batch.seq_length} padding:{batch.right_padding}'
        )
        # assert batch.right_padding > 0, 'No more room for next token!'

        # Execute batch
        if prefill:
            # no right padding for prefill
            # token_idx = torch.tensor(batch.attention_mask.shape[-1] - 1).to(self.device)
            batch.logits, batch.past = self.forward(
                batch,
                bypass_hpu_graph=prefill and self.limit_hpu_graph if self.enable_hpu_graph else None,
            )
        elif all([req.stopping_criteria.max_new_tokens == 1 for req in batch.requests]):
            # Don't schedule next forward if max_new_tokens for all requests equals 1
            # - we've already generated the first and only needed token in the prefill phase
            pass
        else:
            # token_idx = torch.tensor(batch.attention_mask.shape[-1] - batch.right_padding).to(self.device)
            batch.logits = self.forward(
                batch,
                bypass_hpu_graph=prefill and self.limit_hpu_graph if self.enable_hpu_graph else None,
            )

        if batch.pixel_values is not None:
            batch.pixel_values = None
        if batch.aspect_ratio_ids is not None:
            batch.aspect_ratio_ids = None
        if batch.aspect_ratio_mask is not None:
            batch.aspect_ratio_mask = None

        htorch.core.mark_step()

        start_decode = time.time_ns()

        # Stage 3. Finish and return previous generations
        stopped = len(requests_to_generate) > 0
        for prev_batch in prev_batches:
            prev_batch['next_token_logprobs'] = prev_batch['next_token_logprobs'].tolist()
            prev_batch['next_token_ids_cpu'] = prev_batch['next_token_ids'].cpu()
        htorch.core.mark_step()

        for req_data in requests_to_generate:
            req = req_data['req']
            i = req_data['prev_req_idx']
            prev_batch_id = req_data['batch_id']
            assert len(prev_batches) > prev_batch_id
            next_token_ids_cpu = prev_batches[prev_batch_id]['next_token_ids_cpu']
            next_token_logprobs = prev_batches[prev_batch_id]['next_token_logprobs']

            request = req.data
            input_length = req.input_length
            prefix_offset = req.prefix_offset
            read_offset = req.read_offset
            do_sample = req_data['do_sample']
            seed = req_data['seed']
            stopping_criteria = req.stopping_criteria
            all_input_ids = req.all_input_ids
            next_token_id = next_token_ids_cpu[i]
            next_token_logprob = next_token_logprobs[i]
            top_n_tokens = req_data['top_n_tokens']
            top_token_ids = req_data['top_token_ids']
            top_token_logprobs = req_data['top_token_logprobs']
            grammar_state = req_data['grammar_state']

            # Append next token to all tokens
            all_input_ids[input_length] = next_token_id
            new_input_length = input_length + 1

            # Generated token
            if is_tokenizer_transparent(self.tokenizer) and len(stopping_criteria.stop_sequence_criterias) == 0:
                next_token_text = ''
            else:
                next_token_text, prefix_offset, read_offset = self.decode_token(
                    all_input_ids[0:new_input_length, 0], prefix_offset, read_offset
                )

            # Evaluate stopping criteria
            stop, reason = stopping_criteria(
                next_token_id,
                next_token_text,
            )

            if not stop:
                stopped = False

            # Shard generations
            # All generations will be appended in the rust sharded client
            if i % self.world_size == self.rank:
                if stop:
                    # Decode generated tokens
                    if is_tokenizer_transparent(self.tokenizer):
                        output_text = None
                    else:
                        output_text = self.decode(
                            all_input_ids[new_input_length - stopping_criteria.current_tokens: new_input_length, 0]
                        )
                    generated_text = GeneratedText(
                        output_text,
                        stopping_criteria.current_tokens,
                        reason,
                        seed if do_sample else None,
                    )
                else:
                    generated_text = None

                # Prefill
                if stopping_criteria.current_tokens == 1 and request.prefill_logprobs:
                    # Remove generated token to only have prefill and add nan for first prompt token
                    prefill_logprobs = [float("nan")] + next_token_logprobs
                    prefill_token_ids = all_input_ids[0: new_input_length - 1]
                    prefill_texts = self.tokenizer.batch_decode(
                        prefill_token_ids,
                        clean_up_tokenization_spaces=False,
                        skip_special_tokens=False,
                    )
                    prefill_tokens = Tokens(
                        prefill_token_ids,
                        prefill_logprobs,
                        prefill_texts,
                        is_special=[],
                    )
                else:
                    prefill_tokens = None

                if top_n_tokens > 0:
                    all_top_tokens = []
                    for top_token_ids, top_token_logprobs in zip(
                        top_token_ids, top_token_logprobs
                    ):
                        toptoken_texts = self.tokenizer.batch_decode(
                            top_token_ids,
                            clean_up_tokenization_spaces=False,
                            skip_special_tokens=False,
                        )
                        special_toptokens = [
                            token_id in self.all_special_ids
                            for token_id in top_token_ids
                        ]
                        top_tokens = Tokens(
                            top_token_ids,
                            top_token_logprobs,
                            toptoken_texts,
                            special_toptokens,
                        )
                        all_top_tokens.append(top_tokens)
                    top_tokens = all_top_tokens
                else:
                    top_tokens = None

                generation = Generation(
                    request.id,
                    prefill_tokens,
                    Tokens(
                        [next_token_id],
                        [next_token_logprob],
                        [next_token_text],
                        [next_token_id in self.all_special_ids],
                    ),
                    generated_text,
                    top_tokens,
                )

                generations.append(generation)

            batch.next_token_chooser = (
                batch.next_token_chooser.advance_grammar_single_with_past_state(
                    req.idx, next_token_id, grammar_state
                )
            )

            req.all_input_ids = all_input_ids
            req.input_length = new_input_length
            req.prefix_offset = prefix_offset
            req.read_offset = read_offset

        htorch.core.mark_step()
        self.step = self.step + 1
        if self.hb_profiler is not None:
            if self.step > self.profiling_wait_steps + self.profiling_warmup_steps + self.profiling_steps:
                self.hb_profiler.stop()
            else:
                self.hb_profiler.step()

        forward_ns = start_decode - start
        decode_ns = time.time_ns() - start_decode
        return generations, batch if not stopped else None, (forward_ns, decode_ns)

    def batch_from_pb(self, batch, is_warmup):
        return self.batch_type.from_pb_processor(
            batch,
            self.tokenizer,
            self.processor,
            self.model.config,
            self.dtype,
            self.device,
            is_warmup,
        )

    def generate_warmup_batch(self, request, seq_len, batch_size, is_warmup):
        batch = copy.deepcopy(request.batch)
        for req in batch.requests:
            req.truncate = seq_len

        for i in range(len(batch.requests) - batch_size):
            batch.requests.pop()

        return self.batch_from_pb(batch, is_warmup)

    def warmup(self, request) -> int:
        global MAX_TOTAL_TOKENS, MAX_BATCH_TOTAL_TOKENS
        MAX_TOTAL_TOKENS = request.max_total_tokens
        MAX_BATCH_TOTAL_TOKENS = request.max_batch_total_tokens
        batch = self.batch_from_pb(request.batch, is_warmup=True)
        max_input_length = batch.input_ids.shape[1]
        max_prefill_batch_size = batch.input_ids.shape[0]

        try:
            # max prefill batch size warmup
            _, prefill_batch, _ = self.generate_token([batch], is_warmup=True)
        except Exception:
            raise RuntimeError(
                f"Not enough memory to handle {len(batch.input_ids)} prefill tokens. "
                f"You need to decrease `--max-batch-prefill-tokens`"
            )

        global BASE_IMAGE_TOKENS, PREFILL_WARMUP_BATCH_SIZE_LIST, PREFILL_WARMUP_SEQLEN_LIST, DECODE_WARMUP_BATCH_SIZE_LIST
        PREFILL_WARMUP_BATCH_SIZE_LIST = []
        batch_size = 1
        while batch_size <= max_prefill_batch_size:
            PREFILL_WARMUP_BATCH_SIZE_LIST.append(batch_size)
            batch_size = batch_size * 2
        if PREFILL_WARMUP_BATCH_SIZE_LIST[-1] < max_prefill_batch_size:
            PREFILL_WARMUP_BATCH_SIZE_LIST.append(max_prefill_batch_size)

        if self.model.config.model_type == "mllama":
            seq_len = PAD_SEQUENCE_TO_MULTIPLE_OF
        else:
            seq_len = BASE_IMAGE_TOKENS

        PREFILL_WARMUP_SEQLEN_LIST = []
        i = 0
        while seq_len <= max_input_length:
            PREFILL_WARMUP_SEQLEN_LIST.append(seq_len)
            seq_len += PAD_SEQUENCE_TO_MULTIPLE_OF * (2**i)
            i += 1
        if PREFILL_WARMUP_SEQLEN_LIST[-1] < max_input_length:
            PREFILL_WARMUP_SEQLEN_LIST.append(max_input_length)

        # Prefill and decode warmup
        DECODE_WARMUP_BATCH_SIZE_LIST = []
        prefill_batch = None
        decode_batch = None
        try:
            for batch_size in PREFILL_WARMUP_BATCH_SIZE_LIST:
                for seq_len in PREFILL_WARMUP_SEQLEN_LIST:
                    batch = self.generate_warmup_batch(request, seq_len, batch_size, is_warmup=True)
                    _, prefill_batch, _ = self.generate_token([batch], is_warmup=True)
                    _, decode_batch, _ = self.generate_token([prefill_batch], is_warmup=True)

                DECODE_WARMUP_BATCH_SIZE_LIST.append(batch_size)

        except Exception:
            raise RuntimeError(
                f"Not enough memory to handle the following prefill and decode warmup.\n"
                f"Prefill batch size list: {PREFILL_WARMUP_BATCH_SIZE_LIST}\n"
                f"Prefill sequence length list: {PREFILL_WARMUP_SEQLEN_LIST}\n"
                f"Decode batch size list: {DECODE_WARMUP_BATCH_SIZE_LIST}\n"
                f"You need to decrease `--max-batch-prefill-tokens`"
            )

        mem_stats = get_hpu_memory_stats(self.device)
        logger.info(
            f"\nPrefill and decode warmup completed successfully.\n"
            f"Prefill batch size list: {PREFILL_WARMUP_BATCH_SIZE_LIST}\n"
            f"Prefill sequence length list: {PREFILL_WARMUP_SEQLEN_LIST}\n"
            f"Decode batch size list: {DECODE_WARMUP_BATCH_SIZE_LIST}\n"
            f"Memory stats: {mem_stats} "
        )

        max_decode_batch_size = math.floor(MAX_BATCH_TOTAL_TOKENS / MAX_TOTAL_TOKENS)
        batch_size = max_prefill_batch_size * 2
        # Decode warmup with bigger batch_size
        try:
            if DECODE_WARMUP_BATCH_SIZE_LIST[-1] < max_decode_batch_size and batch_size <= max_decode_batch_size:
                batches = []
                while batch_size <= max_decode_batch_size:
                    for i in range(int(batch_size / max_prefill_batch_size)):
                        batch = self.generate_warmup_batch(request, PREFILL_WARMUP_SEQLEN_LIST[0] - 1, max_prefill_batch_size, is_warmup=False)
                        _, prefill_batch, _ = self.generate_token([batch], is_warmup=True)
                        batches.append(prefill_batch)

                    _, decode_batch, _ = self.generate_token(batches, is_warmup=True)
                    DECODE_WARMUP_BATCH_SIZE_LIST.append(batch_size)
                    batch_size = batch_size * 2
                    batches.clear()

                if DECODE_WARMUP_BATCH_SIZE_LIST[-1] < max_decode_batch_size:
                    max_decode_batch_size = math.floor(max_decode_batch_size / 2) * 2
                    batch_size = max_decode_batch_size
                    for i in range(int(max_decode_batch_size / 2)):
                        batch = self.generate_warmup_batch(request, PREFILL_WARMUP_SEQLEN_LIST[0] - 1, 2, is_warmup=False)
                        _, prefill_batch, _ = self.generate_token([batch], is_warmup=True)
                        batches.append(prefill_batch)
                    _, decode_batch, _ = self.generate_token(batches, is_warmup=True)
                    DECODE_WARMUP_BATCH_SIZE_LIST.append(max_decode_batch_size)

            MAX_BATCH_TOTAL_TOKENS = max_decode_batch_size * MAX_TOTAL_TOKENS
        except Exception:
            raise RuntimeError(
                f"Not enough memory to handle batch_size({batch_size}) decode warmup.\n"
                f"Decode batch size list: {DECODE_WARMUP_BATCH_SIZE_LIST}\n"
                f"max_decode_batch_size is {max_decode_batch_size}\n"
                f"You need to decrease env `MAX_BATCH_TOTAL_TOKENS` or '--max_batch_total_tokens'"
            )

        mem_stats = get_hpu_memory_stats(self.device)
        logger.info(
            f"\nDecode warmup completed successfully.\n"
            f"Decode batch size list: {DECODE_WARMUP_BATCH_SIZE_LIST}\n"
            f"Memory stats: {mem_stats}"
        )

        return MAX_BATCH_TOTAL_TOKENS