fix vlm and seq2seq

OlivierDehaene 2024-10-07 15:08:30 +02:00
parent 460e830444
commit 8188deac22
6 changed files with 23 additions and 18 deletions

@@ -25,7 +25,6 @@ async def llama_grammar(llama_grammar_handle):
 @pytest.mark.release
 @pytest.mark.asyncio
 async def test_grammar_response_format_llama_json(llama_grammar, response_snapshot):
     class Weather(BaseModel):
         unit: str
         temperature: List[int]

@@ -699,7 +699,6 @@ def check_args(
 class _attention(torch.autograd.Function):
     @staticmethod
     def forward(
         ctx,

@@ -493,9 +493,14 @@ class MllamaVisionModel(nn.Module):
         aspect_ratio_ids: torch.Tensor,
         attention_mask: torch.Tensor,
     ) -> torch.Tensor:
-        batch_size, num_concurrent_media, num_tiles, num_channels, height, width = (
-            pixel_values.shape
-        )
+        (
+            batch_size,
+            num_concurrent_media,
+            num_tiles,
+            num_channels,
+            height,
+            width,
+        ) = pixel_values.shape
         pixel_values = pixel_values.reshape(
             batch_size * num_concurrent_media * num_tiles, num_channels, height, width
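Note (not part of the diff): this hunk, like the `_load_and_merge` hunk further down, is a pure reformat of a tuple unpacking; the names move into a parenthesized target list and the wrapping parentheses around the right-hand side go away. A toy sketch of the equivalence:

```python
import torch

pixel_values = torch.zeros(2, 1, 4, 3, 224, 224)

# old style: names on one line, right-hand side wrapped in parentheses
batch_size, num_concurrent_media, num_tiles, num_channels, height, width = (
    pixel_values.shape
)

# new style: parenthesized target list, right-hand side unwrapped
(
    batch_size,
    num_concurrent_media,
    num_tiles,
    num_channels,
    height,
    width,
) = pixel_values.shape

assert (batch_size, height) == (2, 224)  # identical behavior either way
```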

@@ -80,7 +80,7 @@ class Seq2SeqLMBatch(Batch):
             request_ids=[r.id for r in self.requests],
             size=len(self),
             max_tokens=self.max_tokens,
-            current_tokens=len(self.input_ids),
+            current_tokens=len(self.decoder_input_ids),
         )

     @classmethod
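Note (not part of the diff): the likely failure mode here is that for an encoder-decoder batch the encoder-side `input_ids` is only needed for the first forward pass and may be dropped afterwards, while `decoder_input_ids` stays populated for the life of the batch, so it is the safe thing to measure. A minimal sketch of that idea, with hypothetical names beyond the two fields in the hunk:

```python
from typing import Optional
import torch

class Seq2SeqBatchSketch:
    """Hypothetical stand-in for Seq2SeqLMBatch, for illustration only."""

    def __init__(self) -> None:
        # Encoder prompt: assumed dropped once the encoder has run.
        self.input_ids: Optional[torch.Tensor] = None
        # Decoder tokens: one row per sequence, present while decoding.
        self.decoder_input_ids = torch.zeros(4, 7, dtype=torch.long)

    def current_tokens(self) -> int:
        # len(self.input_ids) would raise TypeError once it is None;
        # the decoder tensor is the reliable thing to measure.
        return len(self.decoder_input_ids)

print(Seq2SeqBatchSketch().current_tokens())  # 4
```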

@@ -295,7 +295,7 @@ class VlmCausalLM(FlashCausalLM):
         block_tables = batch.block_tables_tensor
         slots = batch.slots[batch.slot_indices]
         postfix_lengths = batch.postfix_lengths_tensor
-        max_s = batch.max_seqlen
+        max_s = batch.max_current_length
         lm_head_indices = batch.prefill_head_indices
         speculative_ids = batch.speculative_ids
@@ -338,7 +338,7 @@ class VlmCausalLM(FlashCausalLM):
         slots = batch.slots[batch.slot_indices]
         postfix_lengths = batch.postfix_lengths_tensor
         prefix_lengths_tensor = batch.prefix_lengths_tensor
-        max_s = batch.max_seqlen
+        max_s = batch.max_current_length
         lm_head_indices = batch.prefill_head_indices

         if cu_seqlen_prefill is None and self.max_past() is not None:
@@ -347,7 +347,6 @@ class VlmCausalLM(FlashCausalLM):
             # This makes sure the max_s for the decode pass is correct.
             max_s = min(self.max_past(), max_s)

-        bs = input_ids.shape[0]
         # Try to find an associated cuda graph
         bs = input_ids.shape[0]
         sorted_padded_bs = sorted([k for k in self.cuda_graphs.keys() if k >= bs])
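Note (not part of the diff): `max_seqlen` appears to be a stale attribute name here; with prefix caching the batch tracks the cached prefix lengths and the newly fed postfix lengths separately, and `max_current_length` presumably bounds the full current sequence (prefix + postfix), which is what the decode path and the `max_past()` clamp need. The last hunk simply drops a duplicated `bs = input_ids.shape[0]` assignment. A hedged sketch with made-up numbers:

```python
# Made-up lengths for three sequences in a batch.
prefix_lengths = [96, 0, 32]    # tokens already sitting in the KV cache
postfix_lengths = [4, 100, 8]   # tokens fed in this forward pass

# The attention bound must cover prefix + postfix, not just the new part.
current_lengths = [p + q for p, q in zip(prefix_lengths, postfix_lengths)]
max_s = max(current_lengths)
assert max_s == 100
```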

@@ -120,8 +120,12 @@ def _load_and_merge(
         if adapter.id == BASE_MODEL_ADAPTER_ID:
             raise ValueError("Base model adapter cannot be merged.")

-        module_map, adapter_config, adapter_weight_names, adapter_tokenizer = (
-            load_module_map(
+        (
+            module_map,
+            adapter_config,
+            adapter_weight_names,
+            adapter_tokenizer,
+        ) = load_module_map(
             model_id,
             adapter.revision,
             adapter.id,
@@ -129,7 +133,6 @@ def _load_and_merge(
             weight_names,
             trust_remote_code,
         )
-        )

         adapters_to_merge.append((module_map, adapter_config))
         merged_weight_names = merged_weight_names.union(adapter_weight_names)