Mirror of https://github.com/huggingface/text-generation-inference.git, synced 2025-09-10 20:04:52 +00:00
Hardcode remove image tokens.
This commit is contained in:
parent 51348713d6
commit defc477a03
server/text_generation_server/models/idefics_causal_lm.py

@@ -119,7 +119,6 @@ class IdeficsCausalLMBatch(Batch):
         padding_right_offset = 0
         max_decode_tokens = 0
         for i, r in enumerate(pb.requests):
-            # from loguru import logger; logger.info(f"from_pb in idefics_causal_lm.py {i=} {r=}")
             requests_idx_mapping[r.id] = i
             inputs.append(r.inputs)
             next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
@@ -136,7 +135,6 @@ class IdeficsCausalLMBatch(Batch):
         prompts = []
         for inp in inputs:
             # Each input is encoded into a list, where each element of this input list is either a string or a URL
-            # from loguru import logger; logger.info(f"from_pb in idefics_causal_lm.py {inp=}")
             if isinstance(inp, str):
                 prompts.append([inp])
             elif isinstance(inp, list):
@@ -169,9 +167,6 @@ class IdeficsCausalLMBatch(Batch):
             max_length=max_truncation,
             add_end_of_utterance_token=False, # Already taken care of inside the prompts, so bypassing the processor's handling of this token
         ).to(device)
-        # from loguru import logger; logger.info(f"from_pb in idefics_causal_lm.py - {tokenized_inputs['input_ids']=}")
-        # from loguru import logger; logger.info({k: v.size() for k,v in processed_inputs.items()})
-        # {'input_ids': torch.Size([4, 5]), 'attention_mask': torch.Size([4, 5]), 'pixel_values': torch.Size([4, num_images, 3, 224, 224]), 'image_attention_mask': torch.Size([4, 5, num_images])}
         for _ in pb.requests:
             input_len = tokenized_inputs["input_ids"].shape[1]
             prefix_offsets.append(input_len - 5) # To decode without potential fallbacks errors
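Review note: the "input_len - 5" prefix offset drives incremental detokenization. Each decoding step re-decodes a short window that starts a few tokens before the new one, so byte-pair merges at the token boundary render the same way every time. A minimal sketch of the idea, with illustrative names (this is not TGI's exact helper):

    def decode_new_text(tokenizer, all_input_ids, prefix_offset, read_offset):
        # Decode a window starting a few tokens back so merges at the
        # boundary are rendered consistently across steps.
        prefix_text = tokenizer.decode(all_input_ids[prefix_offset:read_offset])
        new_text = tokenizer.decode(all_input_ids[prefix_offset:])
        if len(new_text) > len(prefix_text) and "\uFFFD" not in new_text:
            # Emit only the part that extends past the already-seen prefix.
            return new_text[len(prefix_text):], read_offset, len(all_input_ids)
        # Incomplete byte sequence at the tail: emit nothing, keep offsets.
        return "", prefix_offset, read_offset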
@@ -193,8 +188,6 @@ class IdeficsCausalLMBatch(Batch):
         image_attention_mask = input_ids.new_zeros(
             (pb.size, max_input_length + padding_right_offset, tokenized_inputs["pixel_values"].size(1))
         )
-        # image_attention_mask = tokenized_inputs["image_attention_mask"]
-        # from loguru import logger; logger.info(f"from_pb in idefics_causal_lm.py - image_attention_mask {image_attention_mask.size()}")


         position_ids = tokenized_inputs["attention_mask"].long().cumsum(-1) - 1
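Review note: with left padding, "attention_mask.long().cumsum(-1) - 1" numbers the real tokens 0, 1, 2, ... starting at the first non-pad slot, whatever the padding length. A standalone illustration:

    import torch

    attention_mask = torch.tensor([[0, 0, 1, 1, 1],
                                   [1, 1, 1, 1, 1]])
    position_ids = attention_mask.long().cumsum(-1) - 1
    print(position_ids)
    # tensor([[-1, -1,  0,  1,  2],
    #         [ 0,  1,  2,  3,  4]])
    # The -1 entries sit under padding; models typically overwrite them
    # (e.g. with masked_fill_) since those positions are never attended to.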
@@ -227,7 +220,6 @@ class IdeficsCausalLMBatch(Batch):

     @tracer.start_as_current_span("filter")
     def filter(self, request_ids: List[int]) -> Optional["IdeficsCausalLMBatch"]:
-        # from loguru import logger; logger.info(f"filter in idefics_causal_lm.py")
         # It deletes requests from the batch. For instance when client lost connection
         if len(request_ids) == 0:
             raise ValueError("Batch must have at least one request")
@@ -339,7 +331,6 @@ class IdeficsCausalLMBatch(Batch):
     @classmethod
     @tracer.start_as_current_span("concatenate")
     def concatenate(cls, batches: List["IdeficsCausalLMBatch"]) -> "IdeficsCausalLMBatch":
-        # from loguru import logger; logger.info(f"concatenate in idefics_causal_lm.py")
         # It adds new requests to the batch
         # Used for padding
         total_batch_size = 0
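Review note: concatenate merges running batches into one, re-padding every tensor to the largest input length. A simplified sketch of the size pass that "total_batch_size = 0" begins (illustrative only; the real method also merges masks, pixel values and KV caches):

    def combined_sizes(batches):
        # Combined batch size, longest prompt seen so far, and the widest
        # right-padding region still reserved for future tokens.
        total_batch_size = sum(len(b.requests) for b in batches)
        max_input_length = max(b.max_input_length for b in batches)
        padding_right_offset = max(b.padding_right_offset for b in batches)
        return total_batch_size, max_input_length, padding_right_offset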
@@ -436,8 +427,6 @@ class IdeficsCausalLMBatch(Batch):
                 :,
                 batch_left_offset : -batch.padding_right_offset,
             ]
-            # from loguru import logger; logger.info(f"concatenate in idefics_causal_lm.py - image_attention_mask {image_attention_mask.size()}")
-            # from loguru import logger; logger.info(f"concatenate in idefics_causal_lm.py - batch.image_attention_mask {batch.image_attention_mask.size()}")
             image_attention_mask[
                 start_index:end_index,
                 left_offset:-padding_right_offset,
@@ -678,7 +667,6 @@ class IdeficsCausalLM(Model):
     def generate_token(
         self, batch: IdeficsCausalLMBatch
     ) -> Tuple[List[Generation], Optional[IdeficsCausalLMBatch]]:
-        # from loguru import logger; logger.info("generate_token in idefics_causal_lm.py - enter")
         # slice the attention mask to the correct shape
         attention_mask = batch.attention_mask[:, : -batch.padding_right_offset]
         if batch.input_ids.size(1) == 1:
@@ -690,14 +678,6 @@ class IdeficsCausalLM(Model):
             image_attention_mask = batch.image_attention_mask[:, -batch.padding_right_offset].unsqueeze(1) #TODO: verify that index. i have a doubt whether there is +1 hanging around
         else:
             image_attention_mask = batch.image_attention_mask[:, : -batch.padding_right_offset]
-        # image_hidden_states = batch.image_hidden_states[:, :-batch.padding_right_offset]
-        # from loguru import logger; logger.info(f"generate_token in idefics_causal_lm.py - {batch.padding_right_offset=}")
-        # from loguru import logger; logger.info(f"generate_token in idefics_causal_lm.py - {batch.attention_mask.size()=}")
-        # from loguru import logger; logger.info(f"generate_token in idefics_causal_lm.py - {attention_mask.size()=}")
-        # from loguru import logger; logger.info(f"generate_token in idefics_causal_lm.py - {batch.image_attention_mask=}")
-        # from loguru import logger; logger.info(f"generate_token in idefics_causal_lm.py - {batch.image_attention_mask.size()=}")
-        # from loguru import logger; logger.info(f"generate_token in idefics_causal_lm.py - {image_attention_mask.size()=}")
-        # from loguru import logger; logger.info(f"generate_token in idefics_causal_lm.py - {image_attention_mask=}")

         logits, past, image_hidden_states = self.forward(
             input_ids=batch.input_ids,
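Review note: in the decode branch the batch carries a single new token, so only one column of image_attention_mask is taken and re-expanded with unsqueeze(1); in prefill the full mask minus the right-padding region is used. The author's TODO about a possible off-by-one is kept verbatim above. An illustrative shape check (sizes made up):

    import torch

    batch_size, seq_len, num_images, padding_right_offset = 2, 7, 3, 2
    image_attention_mask = torch.zeros(batch_size, seq_len, num_images)

    # Decode step: one column, with a length-1 sequence axis restored.
    decode_mask = image_attention_mask[:, -padding_right_offset].unsqueeze(1)
    assert decode_mask.shape == (batch_size, 1, num_images)

    # Prefill: everything up to the reserved right-padding region.
    prefill_mask = image_attention_mask[:, :-padding_right_offset]
    assert prefill_mask.shape == (batch_size, seq_len - padding_right_offset, num_images)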
@@ -708,6 +688,8 @@ class IdeficsCausalLM(Model):
             image_attention_mask=image_attention_mask,
             past_key_values=batch.past_key_values,
         )
+        # Hardcoded remove image tokens
+        logits[:, 32000:32001] = torch.finfo(logits.dtype).min

         # Results
         generations: List[Generation] = []
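Review note: this is the commit's actual change. Token id 32000 is presumably IDEFICS's "<image>" placeholder (hence "hardcoded"); pinning its logit to the dtype minimum gives it a softmax probability of zero, so it can never be sampled mid-generation. Standalone illustration of the pattern:

    import torch

    logits = torch.randn(2, 32002)   # (batch, vocab); sizes illustrative
    banned = 32000                   # the hardcoded image-token id above
    logits[:, banned:banned + 1] = torch.finfo(logits.dtype).min
    probs = torch.softmax(logits, dim=-1)
    assert torch.all(probs[:, banned] == 0)  # exp() underflows to exactly 0

A more durable fix would read the id from the tokenizer or model config instead of hardcoding 32000, which is presumably why the commit title flags this as a hardcode.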
@@ -815,7 +797,6 @@ class IdeficsCausalLM(Model):

             # Update values
             batch.input_ids[i, 0] = next_token_id
-            # from loguru import logger; logger.info(f"generate_token in idefics_causal_lm.py - batch.input_ids 1 {batch.input_ids.size()}")
             batch.all_input_ids[i] = all_input_ids
             batch.input_lengths[i] = new_input_length
             batch.prefix_offsets[i] = prefix_offset
@@ -828,7 +809,6 @@ class IdeficsCausalLM(Model):

         # Slice unused values from prefill
         batch.input_ids = batch.input_ids[:, :1]
-        # from loguru import logger; logger.info(f"generate_token in idefics_causal_lm.py - batch.input_ids 2 {batch.input_ids.size()}")

         # Update attention_mask as we added a new token to input_ids
         batch.attention_mask[:, -batch.padding_right_offset] = 1
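Review note: after a token is appended, the first slot of the reserved right-padding region is flipped to 1 so the next forward pass attends to it; padding_right_offset is then decremented (elsewhere in this method) so each step claims the next slot. Toy illustration:

    import torch

    attention_mask = torch.tensor([[1, 1, 1, 0, 0, 0]])  # 3 real, 3 reserved
    padding_right_offset = 3
    attention_mask[:, -padding_right_offset] = 1  # slot for the new token
    print(attention_mask)  # tensor([[1, 1, 1, 1, 0, 0]])
    # Next step: padding_right_offset == 2, claiming the following slot.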
@@ -843,5 +823,4 @@ class IdeficsCausalLM(Model):
         batch.past_key_values = past
         batch.image_hidden_states = image_hidden_states

-        # from loguru import logger; logger.info(f"generate_token in idefics_causal_lm.py - {stopped=}")
         return generations, batch