fix: update position ids so first dim is batch, simplify rotary and bump vlm default token limit

drbh 2025-01-28 19:25:23 +00:00
parent 68e3ee8e79
commit c75c01e9b9
4 changed files with 47 additions and 60 deletions

View File

@@ -2049,7 +2049,16 @@ fn main() -> Result<(), LauncherError> {
None => {
let compute_type = compute_type(num_shard);
let compute_optimal = compute_optimal(config.as_ref(), compute_type.as_ref());
let default = compute_optimal.unwrap_or(4096);
// TODO: remove this when we correctly estimate the flops for VLMs
// this is a short-term fix so that VLMs do not reject images
let default_optimal = match config {
Some(ref config) => match config.model_type.as_deref() {
Some("qwen2_vl") => 10_000,
_ => 4096,
},
None => 4096,
};
let default = compute_optimal.unwrap_or(default_optimal);
let vram_maximum = vram_maximum(
config.as_ref(),
compute_type.as_ref(),
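
As a quick illustration of the new fallback, here is a minimal Python sketch of the same selection (the logic above is Rust in the launcher; the function name and free-standing arguments below are made up for the example):

def default_max_total_tokens(model_type, compute_optimal):
    # qwen2_vl prompts with images expand into many tokens, so VLMs get a
    # larger fallback until their flops are estimated correctly
    default_optimal = 10_000 if model_type == "qwen2_vl" else 4096
    # prefer the computed optimum whenever it is available
    return compute_optimal if compute_optimal is not None else default_optimal

assert default_max_total_tokens("qwen2_vl", None) == 10_000
assert default_max_total_tokens("llama", None) == 4096
assert default_max_total_tokens("qwen2_vl", 8192) == 8192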

View File

@@ -568,9 +568,7 @@ def apply_llama3_scaling(
class RotaryPositionEmbeddingMultimodalSections(PositionRotaryEmbedding):
def __init__(self, inv_freq, scaling_factor, sections):
super().__init__(inv_freq, scaling_factor)
# expand the inv_freq for the 3 sections
self.inv_freq_exp = inv_freq[None, None, :, None].expand(3, -1, -1, 1)
self.sections = sections * 2
self.sections = sections
self._cos_cached = None
self._sin_cached = None
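
The dropped "* 2" on sections falls out of where the split now happens: the cached cos/sin cover only half the head dim, and get_cos_sin below now splits before duplicating, so the raw section sizes already line up. A rough sketch, assuming Qwen2-VL-like sizes (head_dim 128, mrope sections [16, 24, 24]):

sections = [16, 24, 24]                  # temporal, height, width; sums to 64 == head_dim // 2
assert sum(sections) == 64
old_style = [s * 2 for s in sections]    # [32, 48, 48]; only needed when splitting after doubling to 128
assert sum(old_style) == 128
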
@@ -582,7 +580,7 @@ class RotaryPositionEmbeddingMultimodalSections(PositionRotaryEmbedding):
sin: torch.Tensor,
):
# prepare input tensors
q, k = [x.transpose(0, 1).unsqueeze(0) for x in (query, key)]
q, k = [x.transpose(0, 1) for x in (query, key)]
rotary_dim = cos.shape[-1]
q1, k1 = q[..., :rotary_dim], k[..., :rotary_dim]
q2 = torch.cat((-q[..., rotary_dim // 2 :], q[..., : rotary_dim // 2]), dim=-1)
@@ -596,15 +594,14 @@ class RotaryPositionEmbeddingMultimodalSections(PositionRotaryEmbedding):
# recomputing if the sequence length is smaller than the cached one
if (
seqlen > self._seq_len_cached
or self._cos_cached_exp.device != device
or self._cos_cached_exp.dtype != dtype
or self._cos_cached.device != device
or self._cos_cached.dtype != dtype
):
self._seq_len_cached = seqlen
t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
freqs = torch.outer(t, self.inv_freq.to(device=t.device))
freqs = freqs.expand(3, -1, -1)
self._cos_cached_exp = freqs.cos().to(dtype)
self._sin_cached_exp = freqs.sin().to(dtype)
self._cos_cached = torch.cos(freqs).to(dtype)
self._sin_cached = torch.sin(freqs).to(dtype)
def get_cos_sin(
self,
@@ -613,23 +610,24 @@ class RotaryPositionEmbeddingMultimodalSections(PositionRotaryEmbedding):
dtype: torch.dtype,
):
self._update_cos_sin_cache(dtype, position_ids.device, max_s)
# expand the position_ids to match the shape of the cached cos/sin
indices = (
position_ids.squeeze(1)
.unsqueeze(-1)
.expand(-1, -1, self._cos_cached_exp.shape[-1])
# access freqs for each of the 3 sections and stack them
cos_c = torch.stack(
[self._cos_cached[position_ids[:, i]] for i in range(3)], dim=0
)
indices = indices.to(dtype=torch.int64)
cos_c = torch.gather(self._cos_cached_exp, 1, indices)
cos_c = torch.cat([cos_c, cos_c], dim=-1).unsqueeze(1)
sin_c = torch.stack(
[self._sin_cached[position_ids[:, i]] for i in range(3)], dim=0
)
# chunk based on sections
split_cos = torch.split(cos_c, self.sections, dim=-1)
cos_c = torch.cat([m[i % 3] for i, m in enumerate(split_cos)], dim=-1)
cos_c = cos_c.unsqueeze(1)
sin_c = torch.gather(self._sin_cached_exp, 1, indices)
sin_c = torch.cat([sin_c, sin_c], dim=-1).unsqueeze(1)
split_sin = torch.split(sin_c, self.sections, dim=-1)
sin_c = torch.cat([m[i % 3] for i, m in enumerate(split_sin)], dim=-1)
sin_c = sin_c.unsqueeze(1)
return cos_c, sin_c
# for each section, select the corresponding cos/sin (0, 1, 2, ...)
cos_sliced = torch.cat([m[i % 3] for i, m in enumerate(split_cos)], dim=-1)
sin_sliced = torch.cat([m[i % 3] for i, m in enumerate(split_sin)], dim=-1)
# double the size and add a batch dimension
cos = torch.cat([cos_sliced, cos_sliced], dim=-1).unsqueeze(0)
sin = torch.cat([sin_sliced, sin_sliced], dim=-1).unsqueeze(0)
return cos, sin
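
Putting the new path together, here is a self-contained sketch of the section-wise selection that get_cos_sin now performs (sin follows exactly the same steps as cos). All sizes are assumptions for illustration: head_dim 128, hence 64 inverse frequencies, sections [16, 24, 24] as in Qwen2-VL, and a plain text-only position ramp:

import torch

sections = [16, 24, 24]            # temporal, height, width chunks of the 64 frequencies
max_s, half_dim, seq_len = 32, 64, 5

# build a toy cos cache of shape (max_s, half_dim)
inv_freq = 1.0 / (10000.0 ** (torch.arange(0, half_dim, dtype=torch.float32) / half_dim))
freqs = torch.outer(torch.arange(max_s, dtype=torch.float32), inv_freq)
cos_cached = freqs.cos()

# batch-first position ids of shape (seq_len, 3): one position per mrope axis
position_ids = torch.arange(seq_len).unsqueeze(1).repeat(1, 3)

# gather per-axis cos, split along the frequency dim, keep chunk i from axis i % 3
cos_c = torch.stack([cos_cached[position_ids[:, i]] for i in range(3)], dim=0)  # (3, seq_len, 64)
split_cos = torch.split(cos_c, sections, dim=-1)
cos_sliced = torch.cat([m[i % 3] for i, m in enumerate(split_cos)], dim=-1)     # (seq_len, 64)

# duplicate to the full head dim and add the leading batch dimension
cos = torch.cat([cos_sliced, cos_sliced], dim=-1).unsqueeze(0)                  # (1, seq_len, 128)
print(cos.shape)  # torch.Size([1, 5, 128])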

View File

@@ -413,31 +413,17 @@ class Qwen2VLForConditionalGeneration(nn.Module):
def get_position_ids(
self,
input_ids: torch.Tensor,
image_grid_thw: torch.Tensor,
image_grid_thw: Optional[torch.Tensor] = None,
) -> torch.Tensor:
# TODO: avoid the early return and extra work in a more efficient way
if image_grid_thw is not None:
if input_ids.dim() == 1:
input_ids = input_ids.unsqueeze(0)
position_ids = torch.ones(
3,
1,
input_ids.shape[0],
dtype=input_ids.dtype,
device=input_ids.device,
if image_grid_thw is None:
# (batch_size, 3)
return (
torch.arange(input_ids.shape[0], device=input_ids.device)
.unsqueeze(1)
.repeat(1, 3)
)
position_ids = (
torch.arange(input_ids.shape[1], device=input_ids.device)
.view(1, 1, -1)
.repeat(3, input_ids.shape[0], 1)
)
return position_ids
# if an image grid is provided then we need to calculate the position ids
spatial_merge_size = self.spatial_merge_size
vision_start_token_id = self.vision_start_token_id
vision_end_token_id = self.vision_end_token_id
@@ -445,12 +431,6 @@ class Qwen2VLForConditionalGeneration(nn.Module):
device = input_ids.device
dtype = input_ids.dtype
input_ids_len = input_ids.shape[0]
position_ids = torch.ones(
3,
input_ids_len,
dtype=dtype,
device=device,
)
# capture vision segments
starts = torch.where(input_ids == vision_start_token_id)[0]
@@ -513,11 +493,11 @@ class Qwen2VLForConditionalGeneration(nn.Module):
m = torch.arange(final_text_len, device=device).view(1, -1).expand(3, -1)
full_llm_pos_ids_list.append(m + max_s)
# combine all the segments and reshape to (3, input_ids_len)
llm_positions = torch.cat(full_llm_pos_ids_list, dim=1).reshape(3, -1)
position_ids[..., :] = llm_positions.to(position_ids.device)
# TODO: avoid the extra dimension when updating the consumer of this function
return position_ids.unsqueeze(1)
# concat and reshape to (3, input_ids_len) then swap dimensions to (input_ids_len, 3)
position_ids = (
torch.cat(full_llm_pos_ids_list, dim=1).reshape(3, -1).transpose(0, 1)
)
return position_ids
def forward(
self,

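The shape change is easiest to see on the text-only path above: instead of the old (3, 1, seq_len) layout, position ids now come out batch-first as (seq_len, 3), one row per token holding its temporal, height and width position. A small sketch with a made-up prompt:

import torch

# hypothetical text-only prompt of 6 tokens (no image_grid_thw)
input_ids = torch.randint(0, 1000, (6,))

position_ids = (
    torch.arange(input_ids.shape[0], device=input_ids.device)
    .unsqueeze(1)
    .repeat(1, 3)
)
print(position_ids.shape)  # torch.Size([6, 3])
print(position_ids[:3])    # tensor([[0, 0, 0], [1, 1, 1], [2, 2, 2]])
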
View File

@@ -1431,7 +1431,7 @@ class FlashCausalLM(Model):
"Cuda graphs should be generated in decreasing order size to reduce VRAM usage"
)
input_ids = self.cuda_graphs[max_bs]["input_ids"][:bs]
position_ids = self.cuda_graphs[max_bs]["position_ids"][..., :bs]
position_ids = self.cuda_graphs[max_bs]["position_ids"][:bs]
if ATTENTION == "flashinfer":
block_tables = self.cuda_graphs[max_bs]["block_tables"][: bs * max_bt]
else:
@@ -2046,7 +2046,7 @@ class FlashCausalLM(Model):
# instantly become of shape [BATCH_SIZE]
if prefill and finished_prefilling:
indices = batch.cu_seqlen_prefill[1:] - 1
batch.position_ids = batch.position_ids[(..., indices)]
batch.position_ids = batch.position_ids[indices]
batch.slot_indices = batch.slot_indices[indices]
batch.adapter_meta.adapter_indices = batch.adapter_meta.adapter_indices[
indices
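
With the batch dimension first, slicing the multimodal position ids looks exactly like slicing the regular 1D position ids, which is what the simplified indexing in this file relies on. A rough illustration with two assumed prefill sequences:

import torch

# two prefill sequences of lengths 4 and 3, flattened: cu_seqlen_prefill = [0, 4, 7]
cu_seqlen_prefill = torch.tensor([0, 4, 7])
indices = cu_seqlen_prefill[1:] - 1        # last token of each sequence: [3, 6]

# batch-first mrope position ids, shape (total_tokens, 3)
position_ids = torch.arange(7).unsqueeze(1).repeat(1, 3)

# plain row indexing keeps one row per sequence; no (..., indices) needed
print(position_ids[indices])
# tensor([[3, 3, 3],
#         [6, 6, 6]])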