mirror of https://github.com/huggingface/text-generation-inference.git
synced 2025-09-12 04:44:52 +00:00

fix vlm and seq2seq

This commit is contained in:
parent 460e830444
commit 8188deac22
@@ -25,7 +25,6 @@ async def llama_grammar(llama_grammar_handle):
 @pytest.mark.release
 @pytest.mark.asyncio
 async def test_grammar_response_format_llama_json(llama_grammar, response_snapshot):
 
     class Weather(BaseModel):
         unit: str
         temperature: List[int]
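For reference, a self-contained sketch of the schema the test above builds (only the Pydantic model comes from the hunk; the client, endpoint, and snapshot fixtures are omitted and assumed): the model is converted to a JSON schema that a grammar-constrained response format can be asked to follow.

from typing import List

from pydantic import BaseModel


class Weather(BaseModel):
    unit: str
    temperature: List[int]


# Pydantic v2; with v1 this would be Weather.schema()
print(Weather.model_json_schema())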
@@ -699,7 +699,6 @@ def check_args(
 
 
 class _attention(torch.autograd.Function):
 
     @staticmethod
     def forward(
         ctx,
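The hunk above appears to remove only a blank line around the Triton attention wrapper. As background, a minimal torch.autograd.Function skeleton (illustrative only, not the attention kernel itself) showing the @staticmethod forward(ctx, ...) shape the class follows:

import torch


class _Square(torch.autograd.Function):

    @staticmethod
    def forward(ctx, x):
        # Stash inputs needed for the backward pass on ctx.
        ctx.save_for_backward(x)
        return x * x

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        return 2 * x * grad_output


x = torch.randn(4, requires_grad=True)
_Square.apply(x).sum().backward()
print(torch.allclose(x.grad, 2 * x))  # True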
@@ -493,9 +493,14 @@ class MllamaVisionModel(nn.Module):
         aspect_ratio_ids: torch.Tensor,
         attention_mask: torch.Tensor,
     ) -> torch.Tensor:
-        batch_size, num_concurrent_media, num_tiles, num_channels, height, width = (
-            pixel_values.shape
-        )
+        (
+            batch_size,
+            num_concurrent_media,
+            num_tiles,
+            num_channels,
+            height,
+            width,
+        ) = pixel_values.shape
 
         pixel_values = pixel_values.reshape(
             batch_size * num_concurrent_media * num_tiles, num_channels, height, width
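This hunk only reformats the unpacking of pixel_values.shape; behaviour is unchanged. A quick sketch with a dummy tensor (hypothetical sizes, not the repo's model code) showing the same unpack-and-reshape:

import torch

# (batch, concurrent media, tiles, channels, height, width) -- made-up sizes
pixel_values = torch.randn(2, 1, 4, 3, 14, 14)

(
    batch_size,
    num_concurrent_media,
    num_tiles,
    num_channels,
    height,
    width,
) = pixel_values.shape

flat = pixel_values.reshape(
    batch_size * num_concurrent_media * num_tiles, num_channels, height, width
)
print(flat.shape)  # torch.Size([8, 3, 14, 14])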
@@ -80,7 +80,7 @@ class Seq2SeqLMBatch(Batch):
             request_ids=[r.id for r in self.requests],
             size=len(self),
             max_tokens=self.max_tokens,
-            current_tokens=len(self.input_ids),
+            current_tokens=len(self.decoder_input_ids),
         )
 
     @classmethod
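The seq2seq fix reports current_tokens from the decoder side. A toy illustration (plain Python, not the repo's Seq2SeqLMBatch) of why that is the quantity that moves during generation for an encoder-decoder model:

encoder_input_ids = [101, 2054, 2003, 102]  # fixed prompt tokens (made-up ids)
decoder_input_ids = [0]                     # starts from the decoder start token

for step in range(3):
    decoder_input_ids.append(42)  # pretend the model emitted a token
    print(
        f"step={step} "
        f"len(input_ids)={len(encoder_input_ids)} "          # stays 4
        f"len(decoder_input_ids)={len(decoder_input_ids)}"   # grows 2, 3, 4
    )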
@@ -295,7 +295,7 @@ class VlmCausalLM(FlashCausalLM):
             block_tables = batch.block_tables_tensor
             slots = batch.slots[batch.slot_indices]
             postfix_lengths = batch.postfix_lengths_tensor
-            max_s = batch.max_seqlen
+            max_s = batch.max_current_length
             lm_head_indices = batch.prefill_head_indices
 
             speculative_ids = batch.speculative_ids
@@ -338,7 +338,7 @@ class VlmCausalLM(FlashCausalLM):
             slots = batch.slots[batch.slot_indices]
             postfix_lengths = batch.postfix_lengths_tensor
             prefix_lengths_tensor = batch.prefix_lengths_tensor
-            max_s = batch.max_seqlen
+            max_s = batch.max_current_length
             lm_head_indices = batch.prefill_head_indices
 
         if cu_seqlen_prefill is None and self.max_past() is not None:
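A toy illustration of the bookkeeping these two hunks point at (the numbers and the prefix + postfix relationship are assumptions for illustration, not taken from the repo): with prefix caching, a request's live length is its cached prefix plus the newly processed postfix, so the decode bound should come from the maximum current length rather than a prefill-oriented max_seqlen.

prefix_lengths = [512, 0, 128]   # tokens already sitting in the KV cache
postfix_lengths = [4, 37, 9]     # tokens processed in this forward pass

current_lengths = [p + q for p, q in zip(prefix_lengths, postfix_lengths)]
max_current_length = max(current_lengths)
print(current_lengths, max_current_length)  # [516, 37, 137] 516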
@@ -347,7 +347,6 @@ class VlmCausalLM(FlashCausalLM):
             # This makes sure the max_s for the decode pass is correct.
             max_s = min(self.max_past(), max_s)
 
-        bs = input_ids.shape[0]
         # Try to find an associated cuda graph
         bs = input_ids.shape[0]
         sorted_padded_bs = sorted([k for k in self.cuda_graphs.keys() if k >= bs])
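A toy version of the decode-path logic in this hunk (names and numbers made up): max_s is clamped by the sliding-window budget, and the smallest captured CUDA-graph batch size that can hold the live batch is looked up.

max_past = 4096
max_s = min(max_past, 8192)  # mirrors max_s = min(self.max_past(), max_s) -> 4096

cuda_graphs = {1: "g1", 2: "g2", 4: "g4", 8: "g8"}  # padded batch size -> captured graph
bs = 3
sorted_padded_bs = sorted([k for k in cuda_graphs.keys() if k >= bs])
print(sorted_padded_bs)                                   # [4, 8]
print(sorted_padded_bs[0] if sorted_padded_bs else None)  # 4 -> graph to reuse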
@@ -120,8 +120,12 @@ def _load_and_merge(
         if adapter.id == BASE_MODEL_ADAPTER_ID:
             raise ValueError("Base model adapter cannot be merged.")
 
-        module_map, adapter_config, adapter_weight_names, adapter_tokenizer = (
-            load_module_map(
+        (
+            module_map,
+            adapter_config,
+            adapter_weight_names,
+            adapter_tokenizer,
+        ) = load_module_map(
             model_id,
             adapter.revision,
             adapter.id,
@@ -129,7 +133,6 @@ def _load_and_merge(
             weight_names,
             trust_remote_code,
         )
-        )
 
         adapters_to_merge.append((module_map, adapter_config))
         merged_weight_names = merged_weight_names.union(adapter_weight_names)
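For context, an illustrative-only sketch of the loop bookkeeping around load_module_map (dummy data and a hypothetical adapter shape, not the repo's types): each adapter contributes a (module_map, config) pair, and its weight names are folded into one set.

adapters = [
    {"module_map": {"q_proj": "A"}, "config": {"r": 8}, "weight_names": {"q_proj"}},
    {"module_map": {"v_proj": "B"}, "config": {"r": 16}, "weight_names": {"v_proj"}},
]

adapters_to_merge = []
merged_weight_names = set()
for adapter in adapters:
    adapters_to_merge.append((adapter["module_map"], adapter["config"]))
    merged_weight_names = merged_weight_names.union(adapter["weight_names"])

print(len(adapters_to_merge), sorted(merged_weight_names))  # 2 ['q_proj', 'v_proj']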