mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-06-22 09:00:17 +00:00
fix: update position ids so first dim is batch, simplify rotary and bump vlm default token limit
This commit is contained in:
parent
68e3ee8e79
commit
c75c01e9b9
@ -2049,7 +2049,16 @@ fn main() -> Result<(), LauncherError> {
|
||||
None => {
|
||||
let compute_type = compute_type(num_shard);
|
||||
let compute_optimal = compute_optimal(config.as_ref(), compute_type.as_ref());
|
||||
let default = compute_optimal.unwrap_or(4096);
|
||||
// TODO: remove this when we correctly esimate the flops for VLMs
|
||||
// this is a short term temporary fix to enable vlms to avoid rejecting images
|
||||
let default_optimal = match config {
|
||||
Some(ref config) => match config.model_type.as_deref() {
|
||||
Some("qwen2_vl") => 10_000,
|
||||
_ => 4096,
|
||||
},
|
||||
None => 4096,
|
||||
};
|
||||
let default = compute_optimal.unwrap_or(default_optimal);
|
||||
let vram_maximum = vram_maximum(
|
||||
config.as_ref(),
|
||||
compute_type.as_ref(),
|
||||
|
@ -568,9 +568,7 @@ def apply_llama3_scaling(
|
||||
class RotaryPositionEmbeddingMultimodalSections(PositionRotaryEmbedding):
|
||||
def __init__(self, inv_freq, scaling_factor, sections):
|
||||
super().__init__(inv_freq, scaling_factor)
|
||||
# expand the inv_freq for the 3 sections
|
||||
self.inv_freq_exp = inv_freq[None, None, :, None].expand(3, -1, -1, 1)
|
||||
self.sections = sections * 2
|
||||
self.sections = sections
|
||||
self._cos_cached = None
|
||||
self._sin_cached = None
|
||||
|
||||
@ -582,7 +580,7 @@ class RotaryPositionEmbeddingMultimodalSections(PositionRotaryEmbedding):
|
||||
sin: torch.Tensor,
|
||||
):
|
||||
# prepare input tensors
|
||||
q, k = [x.transpose(0, 1).unsqueeze(0) for x in (query, key)]
|
||||
q, k = [x.transpose(0, 1) for x in (query, key)]
|
||||
rotary_dim = cos.shape[-1]
|
||||
q1, k1 = q[..., :rotary_dim], k[..., :rotary_dim]
|
||||
q2 = torch.cat((-q[..., rotary_dim // 2 :], q[..., : rotary_dim // 2]), dim=-1)
|
||||
@ -596,15 +594,14 @@ class RotaryPositionEmbeddingMultimodalSections(PositionRotaryEmbedding):
|
||||
# recomputing if the sequence length is smaller than the cached one
|
||||
if (
|
||||
seqlen > self._seq_len_cached
|
||||
or self._cos_cached_exp.device != device
|
||||
or self._cos_cached_exp.dtype != dtype
|
||||
or self._cos_cached.device != device
|
||||
or self._cos_cached.dtype != dtype
|
||||
):
|
||||
self._seq_len_cached = seqlen
|
||||
t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
|
||||
freqs = torch.outer(t, self.inv_freq.to(device=t.device))
|
||||
freqs = freqs.expand(3, -1, -1)
|
||||
self._cos_cached_exp = freqs.cos().to(dtype)
|
||||
self._sin_cached_exp = freqs.sin().to(dtype)
|
||||
self._cos_cached = torch.cos(freqs).to(dtype)
|
||||
self._sin_cached = torch.sin(freqs).to(dtype)
|
||||
|
||||
def get_cos_sin(
|
||||
self,
|
||||
@ -613,23 +610,24 @@ class RotaryPositionEmbeddingMultimodalSections(PositionRotaryEmbedding):
|
||||
dtype: torch.dtype,
|
||||
):
|
||||
self._update_cos_sin_cache(dtype, position_ids.device, max_s)
|
||||
# expand the position_ids to match the shape of the cached cos/sin
|
||||
indices = (
|
||||
position_ids.squeeze(1)
|
||||
.unsqueeze(-1)
|
||||
.expand(-1, -1, self._cos_cached_exp.shape[-1])
|
||||
|
||||
# access freqs for each of the 3 sections and stack them
|
||||
cos_c = torch.stack(
|
||||
[self._cos_cached[position_ids[:, i]] for i in range(3)], dim=0
|
||||
)
|
||||
indices = indices.to(dtype=torch.int64)
|
||||
cos_c = torch.gather(self._cos_cached_exp, 1, indices)
|
||||
cos_c = torch.cat([cos_c, cos_c], dim=-1).unsqueeze(1)
|
||||
sin_c = torch.stack(
|
||||
[self._sin_cached[position_ids[:, i]] for i in range(3)], dim=0
|
||||
)
|
||||
|
||||
# chunk based on sections
|
||||
split_cos = torch.split(cos_c, self.sections, dim=-1)
|
||||
cos_c = torch.cat([m[i % 3] for i, m in enumerate(split_cos)], dim=-1)
|
||||
cos_c = cos_c.unsqueeze(1)
|
||||
|
||||
sin_c = torch.gather(self._sin_cached_exp, 1, indices)
|
||||
sin_c = torch.cat([sin_c, sin_c], dim=-1).unsqueeze(1)
|
||||
split_sin = torch.split(sin_c, self.sections, dim=-1)
|
||||
sin_c = torch.cat([m[i % 3] for i, m in enumerate(split_sin)], dim=-1)
|
||||
sin_c = sin_c.unsqueeze(1)
|
||||
|
||||
return cos_c, sin_c
|
||||
# for each section, select the corresponding cos/sin (0, 1, 2, ...)
|
||||
cos_sliced = torch.cat([m[i % 3] for i, m in enumerate(split_cos)], dim=-1)
|
||||
sin_sliced = torch.cat([m[i % 3] for i, m in enumerate(split_sin)], dim=-1)
|
||||
|
||||
# double the size and add a batch dimension
|
||||
cos = torch.cat([cos_sliced, cos_sliced], dim=-1).unsqueeze(0)
|
||||
sin = torch.cat([sin_sliced, sin_sliced], dim=-1).unsqueeze(0)
|
||||
return cos, sin
|
||||
|
@ -413,31 +413,17 @@ class Qwen2VLForConditionalGeneration(nn.Module):
|
||||
def get_position_ids(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
image_grid_thw: torch.Tensor,
|
||||
image_grid_thw: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
|
||||
# TODO: avoid the early return and extra work in a more efficient way
|
||||
if image_grid_thw is not None:
|
||||
|
||||
if input_ids.dim() == 1:
|
||||
input_ids = input_ids.unsqueeze(0)
|
||||
|
||||
position_ids = torch.ones(
|
||||
3,
|
||||
1,
|
||||
input_ids.shape[0],
|
||||
dtype=input_ids.dtype,
|
||||
device=input_ids.device,
|
||||
if image_grid_thw is None:
|
||||
# (batch_size, 3)
|
||||
return (
|
||||
torch.arange(input_ids.shape[0], device=input_ids.device)
|
||||
.unsqueeze(1)
|
||||
.repeat(1, 3)
|
||||
)
|
||||
position_ids = (
|
||||
torch.arange(input_ids.shape[1], device=input_ids.device)
|
||||
.view(1, 1, -1)
|
||||
.repeat(3, input_ids.shape[0], 1)
|
||||
)
|
||||
return position_ids
|
||||
|
||||
# if image grid provided than we need to calculate the position ids
|
||||
|
||||
spatial_merge_size = self.spatial_merge_size
|
||||
vision_start_token_id = self.vision_start_token_id
|
||||
vision_end_token_id = self.vision_end_token_id
|
||||
@ -445,12 +431,6 @@ class Qwen2VLForConditionalGeneration(nn.Module):
|
||||
device = input_ids.device
|
||||
dtype = input_ids.dtype
|
||||
input_ids_len = input_ids.shape[0]
|
||||
position_ids = torch.ones(
|
||||
3,
|
||||
input_ids_len,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
)
|
||||
|
||||
# capture vision segments
|
||||
starts = torch.where(input_ids == vision_start_token_id)[0]
|
||||
@ -513,11 +493,11 @@ class Qwen2VLForConditionalGeneration(nn.Module):
|
||||
m = torch.arange(final_text_len, device=device).view(1, -1).expand(3, -1)
|
||||
full_llm_pos_ids_list.append(m + max_s)
|
||||
|
||||
# combine all the segments and reshape to (3, input_ids_len)
|
||||
llm_positions = torch.cat(full_llm_pos_ids_list, dim=1).reshape(3, -1)
|
||||
position_ids[..., :] = llm_positions.to(position_ids.device)
|
||||
# TODO: avoid the extra dimension when updating the consumer of this function
|
||||
return position_ids.unsqueeze(1)
|
||||
# concat and reshape to (3, input_ids_len) then swap dimensions to (input_ids_len, 3)
|
||||
position_ids = (
|
||||
torch.cat(full_llm_pos_ids_list, dim=1).reshape(3, -1).transpose(0, 1)
|
||||
)
|
||||
return position_ids
|
||||
|
||||
def forward(
|
||||
self,
|
||||
|
@ -1431,7 +1431,7 @@ class FlashCausalLM(Model):
|
||||
"Cuda graphs should be generated in decreasing order size to reduce VRAM usage"
|
||||
)
|
||||
input_ids = self.cuda_graphs[max_bs]["input_ids"][:bs]
|
||||
position_ids = self.cuda_graphs[max_bs]["position_ids"][..., :bs]
|
||||
position_ids = self.cuda_graphs[max_bs]["position_ids"][:bs]
|
||||
if ATTENTION == "flashinfer":
|
||||
block_tables = self.cuda_graphs[max_bs]["block_tables"][: bs * max_bt]
|
||||
else:
|
||||
@ -2046,7 +2046,7 @@ class FlashCausalLM(Model):
|
||||
# instantly become of shape [BATCH_SIZE]
|
||||
if prefill and finished_prefilling:
|
||||
indices = batch.cu_seqlen_prefill[1:] - 1
|
||||
batch.position_ids = batch.position_ids[(..., indices)]
|
||||
batch.position_ids = batch.position_ids[indices]
|
||||
batch.slot_indices = batch.slot_indices[indices]
|
||||
batch.adapter_meta.adapter_indices = batch.adapter_meta.adapter_indices[
|
||||
indices
|
||||
|
Loading…
Reference in New Issue
Block a user