Mirror of https://github.com/huggingface/text-generation-inference.git, synced 2025-06-22 17:10:18 +00:00

fix: update position ids so first dim is batch, simplify rotary and bump vlm default token limit

This commit is contained in:
parent 68e3ee8e79
commit c75c01e9b9
@@ -2049,7 +2049,16 @@ fn main() -> Result<(), LauncherError> {
         None => {
             let compute_type = compute_type(num_shard);
             let compute_optimal = compute_optimal(config.as_ref(), compute_type.as_ref());
-            let default = compute_optimal.unwrap_or(4096);
+            // TODO: remove this when we correctly estimate the flops for VLMs
+            // this is a short term temporary fix to enable vlms to avoid rejecting images
+            let default_optimal = match config {
+                Some(ref config) => match config.model_type.as_deref() {
+                    Some("qwen2_vl") => 10_000,
+                    _ => 4096,
+                },
+                None => 4096,
+            };
+            let default = compute_optimal.unwrap_or(default_optimal);
             let vram_maximum = vram_maximum(
                 config.as_ref(),
                 compute_type.as_ref(),
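A minimal standalone sketch of the fallback logic this hunk adds, rendered in Python rather than the launcher's Rust (the helper name and config shape are invented for illustration; the `qwen2_vl` string and the 10_000/4096 values come from the diff):

```python
from typing import Optional


def default_max_total_tokens(
    compute_optimal: Optional[int], model_type: Optional[str]
) -> int:
    # VLMs like qwen2_vl need headroom for image tokens, so the fallback
    # default is bumped from 4096 to 10_000 for them.
    default_optimal = 10_000 if model_type == "qwen2_vl" else 4096
    return compute_optimal if compute_optimal is not None else default_optimal


assert default_max_total_tokens(None, "qwen2_vl") == 10_000
assert default_max_total_tokens(None, "llama") == 4096
# an estimated optimum, when available, still wins
assert default_max_total_tokens(8192, "qwen2_vl") == 8192
```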
@@ -568,9 +568,7 @@ def apply_llama3_scaling(
 class RotaryPositionEmbeddingMultimodalSections(PositionRotaryEmbedding):
     def __init__(self, inv_freq, scaling_factor, sections):
         super().__init__(inv_freq, scaling_factor)
-        # expand the inv_freq for the 3 sections
-        self.inv_freq_exp = inv_freq[None, None, :, None].expand(3, -1, -1, 1)
-        self.sections = sections * 2
+        self.sections = sections
         self._cos_cached = None
         self._sin_cached = None
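A small sketch of the bookkeeping behind this simplification, assuming Qwen2-VL-style values (mrope_section = [16, 24, 24] with head_dim = 128; these numbers are assumptions, not part of the diff). The cached cos/sin tables now hold rotary_dim = head_dim // 2 columns, so `sections` can be stored raw and the doubling is deferred to `get_cos_sin`:

```python
# assumed Qwen2-VL-like section sizes (illustrative only)
sections = [16, 24, 24]
rotary_dim = sum(sections)

# the raw sections already split a half-width (rotary_dim) table;
# get_cos_sin later doubles it via torch.cat([x, x], dim=-1)
assert rotary_dim == 64  # == head_dim // 2 for head_dim = 128
```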
@@ -582,7 +580,7 @@ class RotaryPositionEmbeddingMultimodalSections(PositionRotaryEmbedding):
         sin: torch.Tensor,
     ):
         # prepare input tensors
-        q, k = [x.transpose(0, 1).unsqueeze(0) for x in (query, key)]
+        q, k = [x.transpose(0, 1) for x in (query, key)]
         rotary_dim = cos.shape[-1]
         q1, k1 = q[..., :rotary_dim], k[..., :rotary_dim]
         q2 = torch.cat((-q[..., rotary_dim // 2 :], q[..., : rotary_dim // 2]), dim=-1)
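The extra `unsqueeze(0)` on q/k can go because, after this commit, the batch dimension is carried by the cos/sin tensors instead. A toy shape check (standalone sketch with made-up sizes, not TGI code):

```python
import torch

seq_len, num_heads, head_dim = 5, 2, 8
query = torch.randn(seq_len, num_heads, head_dim)
cos = torch.randn(seq_len, head_dim).unsqueeze(0)  # (1, seq_len, head_dim)

q = query.transpose(0, 1)  # (num_heads, seq_len, head_dim)
rotated = q * cos          # cos broadcasts over the heads dimension
assert rotated.shape == (num_heads, seq_len, head_dim)
```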
@@ -596,15 +594,14 @@ class RotaryPositionEmbeddingMultimodalSections(PositionRotaryEmbedding):
         # recomputing if the cached sequence length is smaller than the requested one
         if (
             seqlen > self._seq_len_cached
-            or self._cos_cached_exp.device != device
-            or self._cos_cached_exp.dtype != dtype
+            or self._cos_cached.device != device
+            or self._cos_cached.dtype != dtype
         ):
             self._seq_len_cached = seqlen
             t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
             freqs = torch.outer(t, self.inv_freq.to(device=t.device))
-            freqs = freqs.expand(3, -1, -1)
-            self._cos_cached_exp = freqs.cos().to(dtype)
-            self._sin_cached_exp = freqs.sin().to(dtype)
+            self._cos_cached = torch.cos(freqs).to(dtype)
+            self._sin_cached = torch.sin(freqs).to(dtype)

     def get_cos_sin(
         self,
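What the cache change buys: instead of pre-expanding the frequency table to (3, seqlen, rotary_dim) for the three mrope sections, a single (seqlen, rotary_dim) table is cached and the sections index into it later. A standalone sketch with toy sizes (not TGI source):

```python
import torch

seqlen, rotary_dim = 10, 4
inv_freq = torch.rand(rotary_dim)
t = torch.arange(seqlen, dtype=inv_freq.dtype)
freqs = torch.outer(t, inv_freq)  # (seqlen, rotary_dim)
cos_cached = torch.cos(freqs)     # one table, shared by all 3 sections
assert cos_cached.shape == (seqlen, rotary_dim)
```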
@@ -613,23 +610,24 @@ class RotaryPositionEmbeddingMultimodalSections(PositionRotaryEmbedding):
         dtype: torch.dtype,
     ):
         self._update_cos_sin_cache(dtype, position_ids.device, max_s)
-        # expand the position_ids to match the shape of the cached cos/sin
-        indices = (
-            position_ids.squeeze(1)
-            .unsqueeze(-1)
-            .expand(-1, -1, self._cos_cached_exp.shape[-1])
-        )
-        indices = indices.to(dtype=torch.int64)
-        cos_c = torch.gather(self._cos_cached_exp, 1, indices)
-        cos_c = torch.cat([cos_c, cos_c], dim=-1).unsqueeze(1)
+
+        # access freqs for each of the 3 sections and stack them
+        cos_c = torch.stack(
+            [self._cos_cached[position_ids[:, i]] for i in range(3)], dim=0
+        )
+        sin_c = torch.stack(
+            [self._sin_cached[position_ids[:, i]] for i in range(3)], dim=0
+        )
+
+        # chunk based on sections
         split_cos = torch.split(cos_c, self.sections, dim=-1)
-        cos_c = torch.cat([m[i % 3] for i, m in enumerate(split_cos)], dim=-1)
-        cos_c = cos_c.unsqueeze(1)
-
-        sin_c = torch.gather(self._sin_cached_exp, 1, indices)
-        sin_c = torch.cat([sin_c, sin_c], dim=-1).unsqueeze(1)
         split_sin = torch.split(sin_c, self.sections, dim=-1)
-        sin_c = torch.cat([m[i % 3] for i, m in enumerate(split_sin)], dim=-1)
-        sin_c = sin_c.unsqueeze(1)
-
-        return cos_c, sin_c
+
+        # for each section, select the corresponding cos/sin (0, 1, 2, ...)
+        cos_sliced = torch.cat([m[i % 3] for i, m in enumerate(split_cos)], dim=-1)
+        sin_sliced = torch.cat([m[i % 3] for i, m in enumerate(split_sin)], dim=-1)
+
+        # double the size and add a batch dimension
+        cos = torch.cat([cos_sliced, cos_sliced], dim=-1).unsqueeze(0)
+        sin = torch.cat([sin_sliced, sin_sliced], dim=-1).unsqueeze(0)
+        return cos, sin
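A worked shape example for the new indexing (standalone sketch; the section split and sizes are toy values, not the real Qwen2-VL ones). `position_ids` is now batch-first with shape (seq_len, 3), one column per mrope section, and each column simply gathers rows from the shared cos/sin table:

```python
import torch

seq_len, rotary_dim = 6, 8
sections = [2, 3, 3]                     # assumption; must sum to rotary_dim
cos_cached = torch.rand(32, rotary_dim)  # shared table, max_s = 32
position_ids = torch.randint(0, 32, (seq_len, 3))

# (3, seq_len, rotary_dim): per-section position lookups
cos_c = torch.stack([cos_cached[position_ids[:, i]] for i in range(3)], dim=0)

# take section 0 from the first chunk, 1 from the second, 2 from the third
split_cos = torch.split(cos_c, sections, dim=-1)
cos_sliced = torch.cat([m[i % 3] for i, m in enumerate(split_cos)], dim=-1)

# double to head_dim and add the batch dim, as the diff does
cos = torch.cat([cos_sliced, cos_sliced], dim=-1).unsqueeze(0)
assert cos.shape == (1, seq_len, 2 * rotary_dim)
```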
@@ -413,31 +413,17 @@ class Qwen2VLForConditionalGeneration(nn.Module):
     def get_position_ids(
         self,
         input_ids: torch.Tensor,
-        image_grid_thw: torch.Tensor,
+        image_grid_thw: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
-        # TODO: avoid the early return and extra work in a more efficient way
-        if image_grid_thw is not None:
-            if input_ids.dim() == 1:
-                input_ids = input_ids.unsqueeze(0)
-
-            position_ids = torch.ones(
-                3,
-                1,
-                input_ids.shape[0],
-                dtype=input_ids.dtype,
-                device=input_ids.device,
-            )
-            position_ids = (
-                torch.arange(input_ids.shape[1], device=input_ids.device)
-                .view(1, 1, -1)
-                .repeat(3, input_ids.shape[0], 1)
-            )
-            return position_ids
+        if image_grid_thw is None:
+            # (batch_size, 3)
+            return (
+                torch.arange(input_ids.shape[0], device=input_ids.device)
+                .unsqueeze(1)
+                .repeat(1, 3)
+            )

         # if image grid provided then we need to calculate the position ids
         spatial_merge_size = self.spatial_merge_size
         vision_start_token_id = self.vision_start_token_id
         vision_end_token_id = self.vision_end_token_id
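The new text-only fast path in isolation (runnable sketch; the token ids are arbitrary): for a prompt with no images, every mrope section shares the same 0..N-1 positions, and the result is batch-first with shape (seq_len, 3) instead of the old (3, 1, seq_len):

```python
import torch

input_ids = torch.tensor([101, 7592, 2088, 102])  # any token ids
position_ids = (
    torch.arange(input_ids.shape[0], device=input_ids.device)
    .unsqueeze(1)
    .repeat(1, 3)
)
assert position_ids.shape == (4, 3)
assert position_ids[:, 0].tolist() == [0, 1, 2, 3]  # same in all 3 columns
```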
@@ -445,12 +431,6 @@ class Qwen2VLForConditionalGeneration(nn.Module):
         device = input_ids.device
         dtype = input_ids.dtype
         input_ids_len = input_ids.shape[0]
-        position_ids = torch.ones(
-            3,
-            input_ids_len,
-            dtype=dtype,
-            device=device,
-        )

         # capture vision segments
         starts = torch.where(input_ids == vision_start_token_id)[0]
@@ -513,11 +493,11 @@ class Qwen2VLForConditionalGeneration(nn.Module):
         m = torch.arange(final_text_len, device=device).view(1, -1).expand(3, -1)
         full_llm_pos_ids_list.append(m + max_s)

-        # combine all the segments and reshape to (3, input_ids_len)
-        llm_positions = torch.cat(full_llm_pos_ids_list, dim=1).reshape(3, -1)
-        position_ids[..., :] = llm_positions.to(position_ids.device)
-        # TODO: avoid the extra dimension when updating the consumer of this function
-        return position_ids.unsqueeze(1)
+        # concat and reshape to (3, input_ids_len) then swap dimensions to (input_ids_len, 3)
+        position_ids = (
+            torch.cat(full_llm_pos_ids_list, dim=1).reshape(3, -1).transpose(0, 1)
+        )
+        return position_ids

     def forward(
         self,
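The final reshape for the image case, demonstrated standalone (toy segment sizes, not real vision/text segments): blocks are built as (3, n), concatenated along dim=1, then transposed so the consumer sees (input_ids_len, 3), matching the text-only path above:

```python
import torch

# two fake segments: 4 positions, then 2 more (each built as (3, n))
full_llm_pos_ids_list = [
    torch.zeros(3, 4, dtype=torch.long),
    torch.ones(3, 2, dtype=torch.long),
]
position_ids = (
    torch.cat(full_llm_pos_ids_list, dim=1).reshape(3, -1).transpose(0, 1)
)
assert position_ids.shape == (6, 3)  # (input_ids_len, 3), batch-first
```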
@@ -1431,7 +1431,7 @@ class FlashCausalLM(Model):
                 "Cuda graphs should be generated in decreasing order size to reduce VRAM usage"
             )
         input_ids = self.cuda_graphs[max_bs]["input_ids"][:bs]
-        position_ids = self.cuda_graphs[max_bs]["position_ids"][..., :bs]
+        position_ids = self.cuda_graphs[max_bs]["position_ids"][:bs]
         if ATTENTION == "flashinfer":
             block_tables = self.cuda_graphs[max_bs]["block_tables"][: bs * max_bt]
         else:
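Why the slice gets simpler (sketch with toy sizes): with batch-first position ids, taking the first `bs` entries of a cached CUDA-graph buffer is a plain leading-dimension slice instead of the old trailing-dimension `[..., :bs]`:

```python
import torch

cached_position_ids = torch.zeros(8, 3, dtype=torch.long)  # max_bs = 8
bs = 3
assert cached_position_ids[:bs].shape == (3, 3)  # first bs rows
```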
@@ -2046,7 +2046,7 @@ class FlashCausalLM(Model):
             # instantly become of shape [BATCH_SIZE]
             if prefill and finished_prefilling:
                 indices = batch.cu_seqlen_prefill[1:] - 1
-                batch.position_ids = batch.position_ids[(..., indices)]
+                batch.position_ids = batch.position_ids[indices]
                 batch.slot_indices = batch.slot_indices[indices]
                 batch.adapter_meta.adapter_indices = batch.adapter_meta.adapter_indices[
                     indices
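The same batch-first benefit at the end of prefill (standalone sketch, toy sizes): selecting the last position of each sequence becomes a plain row gather rather than indexing the trailing dimension:

```python
import torch

position_ids = torch.arange(10).unsqueeze(1).repeat(1, 3)  # (10, 3)
cu_seqlen_prefill = torch.tensor([0, 4, 10])               # two sequences
indices = cu_seqlen_prefill[1:] - 1                        # last token of each
assert position_ids[indices].shape == (2, 3)
```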