diff --git a/launcher/src/main.rs b/launcher/src/main.rs
index 05ed0202..3c9ee850 100644
--- a/launcher/src/main.rs
+++ b/launcher/src/main.rs
@@ -2049,7 +2049,16 @@ fn main() -> Result<(), LauncherError> {
         None => {
             let compute_type = compute_type(num_shard);
             let compute_optimal = compute_optimal(config.as_ref(), compute_type.as_ref());
-            let default = compute_optimal.unwrap_or(4096);
+            // TODO: remove this when we correctly estimate the flops for VLMs
+            // this is a short term temporary fix to enable vlms to avoid rejecting images
+            let default_optimal = match config {
+                Some(ref config) => match config.model_type.as_deref() {
+                    Some("qwen2_vl") => 10_000,
+                    _ => 4096,
+                },
+                None => 4096,
+            };
+            let default = compute_optimal.unwrap_or(default_optimal);
             let vram_maximum = vram_maximum(
                 config.as_ref(),
                 compute_type.as_ref(),
diff --git a/server/text_generation_server/layers/rotary.py b/server/text_generation_server/layers/rotary.py
index 7b3500e3..c0baaf59 100644
--- a/server/text_generation_server/layers/rotary.py
+++ b/server/text_generation_server/layers/rotary.py
@@ -568,9 +568,7 @@ def apply_llama3_scaling(
 class RotaryPositionEmbeddingMultimodalSections(PositionRotaryEmbedding):
     def __init__(self, inv_freq, scaling_factor, sections):
         super().__init__(inv_freq, scaling_factor)
-        # expand the inv_freq for the 3 sections
-        self.inv_freq_exp = inv_freq[None, None, :, None].expand(3, -1, -1, 1)
-        self.sections = sections * 2
+        self.sections = sections
         self._cos_cached = None
         self._sin_cached = None
 
@@ -582,7 +580,7 @@ class RotaryPositionEmbeddingMultimodalSections(PositionRotaryEmbedding):
         sin: torch.Tensor,
     ):
         # prepare input tensors
-        q, k = [x.transpose(0, 1).unsqueeze(0) for x in (query, key)]
+        q, k = [x.transpose(0, 1) for x in (query, key)]
         rotary_dim = cos.shape[-1]
         q1, k1 = q[..., :rotary_dim], k[..., :rotary_dim]
         q2 = torch.cat((-q[..., rotary_dim // 2 :], q[..., : rotary_dim // 2]), dim=-1)
@@ -596,15 +594,14 @@ class RotaryPositionEmbeddingMultimodalSections(PositionRotaryEmbedding):
         # recomputing if the sequence length is smaller than the cached one
         if (
             seqlen > self._seq_len_cached
-            or self._cos_cached_exp.device != device
-            or self._cos_cached_exp.dtype != dtype
+            or self._cos_cached.device != device
+            or self._cos_cached.dtype != dtype
         ):
             self._seq_len_cached = seqlen
             t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
             freqs = torch.outer(t, self.inv_freq.to(device=t.device))
-            freqs = freqs.expand(3, -1, -1)
-            self._cos_cached_exp = freqs.cos().to(dtype)
-            self._sin_cached_exp = freqs.sin().to(dtype)
+            self._cos_cached = torch.cos(freqs).to(dtype)
+            self._sin_cached = torch.sin(freqs).to(dtype)
 
     def get_cos_sin(
         self,
@@ -613,23 +610,24 @@ class RotaryPositionEmbeddingMultimodalSections(PositionRotaryEmbedding):
         dtype: torch.dtype,
     ):
         self._update_cos_sin_cache(dtype, position_ids.device, max_s)
-        # expand the position_ids to match the shape of the cached cos/sin
-        indices = (
-            position_ids.squeeze(1)
-            .unsqueeze(-1)
-            .expand(-1, -1, self._cos_cached_exp.shape[-1])
+
+        # access freqs for each of the 3 sections and stack them
+        cos_c = torch.stack(
+            [self._cos_cached[position_ids[:, i]] for i in range(3)], dim=0
         )
-        indices = indices.to(dtype=torch.int64)
-        cos_c = torch.gather(self._cos_cached_exp, 1, indices)
-        cos_c = torch.cat([cos_c, cos_c], dim=-1).unsqueeze(1)
+        sin_c = torch.stack(
+            [self._sin_cached[position_ids[:, i]] for i in range(3)], dim=0
+        )
+
+        # chunk based on sections
         split_cos = torch.split(cos_c, self.sections, dim=-1)
-        cos_c = torch.cat([m[i % 3] for i, m in enumerate(split_cos)], dim=-1)
-        cos_c = cos_c.unsqueeze(1)
-
-        sin_c = torch.gather(self._sin_cached_exp, 1, indices)
-        sin_c = torch.cat([sin_c, sin_c], dim=-1).unsqueeze(1)
         split_sin = torch.split(sin_c, self.sections, dim=-1)
-        sin_c = torch.cat([m[i % 3] for i, m in enumerate(split_sin)], dim=-1)
-        sin_c = sin_c.unsqueeze(1)
-        return cos_c, sin_c
+        # for each section, select the corresponding cos/sin (0, 1, 2, ...)
+        cos_sliced = torch.cat([m[i % 3] for i, m in enumerate(split_cos)], dim=-1)
+        sin_sliced = torch.cat([m[i % 3] for i, m in enumerate(split_sin)], dim=-1)
+
+        # double the size and add a batch dimension
+        cos = torch.cat([cos_sliced, cos_sliced], dim=-1).unsqueeze(0)
+        sin = torch.cat([sin_sliced, sin_sliced], dim=-1).unsqueeze(0)
+        return cos, sin
 
diff --git a/server/text_generation_server/models/custom_modeling/qwen2_vl.py b/server/text_generation_server/models/custom_modeling/qwen2_vl.py
index fdc426bc..65d18963 100644
--- a/server/text_generation_server/models/custom_modeling/qwen2_vl.py
+++ b/server/text_generation_server/models/custom_modeling/qwen2_vl.py
@@ -413,31 +413,17 @@ class Qwen2VLForConditionalGeneration(nn.Module):
     def get_position_ids(
         self,
         input_ids: torch.Tensor,
-        image_grid_thw: torch.Tensor,
+        image_grid_thw: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
-
-        # TODO: avoid the early return and extra work in a more efficient way
-        if image_grid_thw is not None:
-
-            if input_ids.dim() == 1:
-                input_ids = input_ids.unsqueeze(0)
-
-            position_ids = torch.ones(
-                3,
-                1,
-                input_ids.shape[0],
-                dtype=input_ids.dtype,
-                device=input_ids.device,
+        if image_grid_thw is None:
+            # (batch_size, 3)
+            return (
+                torch.arange(input_ids.shape[0], device=input_ids.device)
+                .unsqueeze(1)
+                .repeat(1, 3)
             )
-            position_ids = (
-                torch.arange(input_ids.shape[1], device=input_ids.device)
-                .view(1, 1, -1)
-                .repeat(3, input_ids.shape[0], 1)
-            )
-            return position_ids
 
         # if image grid provided than we need to calculate the position ids
-
         spatial_merge_size = self.spatial_merge_size
         vision_start_token_id = self.vision_start_token_id
         vision_end_token_id = self.vision_end_token_id
@@ -445,12 +431,6 @@ class Qwen2VLForConditionalGeneration(nn.Module):
         device = input_ids.device
         dtype = input_ids.dtype
         input_ids_len = input_ids.shape[0]
-        position_ids = torch.ones(
-            3,
-            input_ids_len,
-            dtype=dtype,
-            device=device,
-        )
 
         # capture vision segments
         starts = torch.where(input_ids == vision_start_token_id)[0]
@@ -513,11 +493,11 @@ class Qwen2VLForConditionalGeneration(nn.Module):
             m = torch.arange(final_text_len, device=device).view(1, -1).expand(3, -1)
             full_llm_pos_ids_list.append(m + max_s)
 
-        # combine all the segments and reshape to (3, input_ids_len)
-        llm_positions = torch.cat(full_llm_pos_ids_list, dim=1).reshape(3, -1)
-        position_ids[..., :] = llm_positions.to(position_ids.device)
-        # TODO: avoid the extra dimension when updating the consumer of this function
-        return position_ids.unsqueeze(1)
+        # concat and reshape to (3, input_ids_len) then swap dimensions to (input_ids_len, 3)
+        position_ids = (
+            torch.cat(full_llm_pos_ids_list, dim=1).reshape(3, -1).transpose(0, 1)
+        )
+        return position_ids
 
     def forward(
         self,
diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index a7d7f711..c5d80bc5 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -1431,7 +1431,7 @@ class FlashCausalLM(Model):
"Cuda graphs should be generated in decreasing order size to reduce VRAM usage" ) input_ids = self.cuda_graphs[max_bs]["input_ids"][:bs] - position_ids = self.cuda_graphs[max_bs]["position_ids"][..., :bs] + position_ids = self.cuda_graphs[max_bs]["position_ids"][:bs] if ATTENTION == "flashinfer": block_tables = self.cuda_graphs[max_bs]["block_tables"][: bs * max_bt] else: @@ -2046,7 +2046,7 @@ class FlashCausalLM(Model): # instantly become of shape [BATCH_SIZE] if prefill and finished_prefilling: indices = batch.cu_seqlen_prefill[1:] - 1 - batch.position_ids = batch.position_ids[(..., indices)] + batch.position_ids = batch.position_ids[indices] batch.slot_indices = batch.slot_indices[indices] batch.adapter_meta.adapter_indices = batch.adapter_meta.adapter_indices[ indices