Hardcode remove image tokens.

Nicolas Patry 2023-08-15 15:24:36 +00:00
parent 51348713d6
commit defc477a03


@@ -119,7 +119,6 @@ class IdeficsCausalLMBatch(Batch):
padding_right_offset = 0
max_decode_tokens = 0
for i, r in enumerate(pb.requests):
-# from loguru import logger; logger.info(f"from_pb in idefics_causal_lm.py {i=} {r=}")
requests_idx_mapping[r.id] = i
inputs.append(r.inputs)
next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
@@ -136,7 +135,6 @@ class IdeficsCausalLMBatch(Batch):
prompts = []
for inp in inputs:
# Each input is encoded into a list, where each element of this input list is either a string or a URL
-# from loguru import logger; logger.info(f"from_pb in idefics_causal_lm.py {inp=}")
if isinstance(inp, str):
prompts.append([inp])
elif isinstance(inp, list):
@@ -169,9 +167,6 @@ class IdeficsCausalLMBatch(Batch):
max_length=max_truncation,
add_end_of_utterance_token=False, # Already taken care of inside the prompts, so bypassing the processor's handling of this token
).to(device)
-# from loguru import logger; logger.info(f"from_pb in idefics_causal_lm.py - {tokenized_inputs['input_ids']=}")
-# from loguru import logger; logger.info({k: v.size() for k,v in processed_inputs.items()})
-# {'input_ids': torch.Size([4, 5]), 'attention_mask': torch.Size([4, 5]), 'pixel_values': torch.Size([4, num_images, 3, 224, 224]), 'image_attention_mask': torch.Size([4, 5, num_images])}
for _ in pb.requests:
input_len = tokenized_inputs["input_ids"].shape[1]
prefix_offsets.append(input_len - 5) # To decode without potential fallbacks errors
@@ -193,8 +188,6 @@ class IdeficsCausalLMBatch(Batch):
image_attention_mask = input_ids.new_zeros(
(pb.size, max_input_length + padding_right_offset, tokenized_inputs["pixel_values"].size(1))
)
-# image_attention_mask = tokenized_inputs["image_attention_mask"]
-# from loguru import logger; logger.info(f"from_pb in idefics_causal_lm.py - image_attention_mask {image_attention_mask.size()}")
position_ids = tokenized_inputs["attention_mask"].long().cumsum(-1) - 1
@@ -227,7 +220,6 @@ class IdeficsCausalLMBatch(Batch):
@tracer.start_as_current_span("filter")
def filter(self, request_ids: List[int]) -> Optional["IdeficsCausalLMBatch"]:
-# from loguru import logger; logger.info(f"filter in idefics_causal_lm.py")
# It deletes requests from the batch. For instance when client lost connection
if len(request_ids) == 0:
raise ValueError("Batch must have at least one request")
@@ -339,7 +331,6 @@ class IdeficsCausalLMBatch(Batch):
@classmethod
@tracer.start_as_current_span("concatenate")
def concatenate(cls, batches: List["IdeficsCausalLMBatch"]) -> "IdeficsCausalLMBatch":
-# from loguru import logger; logger.info(f"concatenate in idefics_causal_lm.py")
# It adds new requests to the batch
# Used for padding
total_batch_size = 0
@@ -436,8 +427,6 @@ class IdeficsCausalLMBatch(Batch):
:,
batch_left_offset : -batch.padding_right_offset,
]
-# from loguru import logger; logger.info(f"concatenate in idefics_causal_lm.py - image_attention_mask {image_attention_mask.size()}")
-# from loguru import logger; logger.info(f"concatenate in idefics_causal_lm.py - batch.image_attention_mask {batch.image_attention_mask.size()}")
image_attention_mask[
start_index:end_index,
left_offset:-padding_right_offset,
@@ -678,7 +667,6 @@ class IdeficsCausalLM(Model):
def generate_token(
self, batch: IdeficsCausalLMBatch
) -> Tuple[List[Generation], Optional[IdeficsCausalLMBatch]]:
-# from loguru import logger; logger.info("generate_token in idefics_causal_lm.py - enter")
# slice the attention mask to the correct shape
attention_mask = batch.attention_mask[:, : -batch.padding_right_offset]
if batch.input_ids.size(1) == 1:
@@ -690,14 +678,6 @@ class IdeficsCausalLM(Model):
image_attention_mask = batch.image_attention_mask[:, -batch.padding_right_offset].unsqueeze(1) #TODO: verify that index. i have a doubt whether there is +1 hanging around
else:
image_attention_mask = batch.image_attention_mask[:, : -batch.padding_right_offset]
-# image_hidden_states = batch.image_hidden_states[:, :-batch.padding_right_offset]
-# from loguru import logger; logger.info(f"generate_token in idefics_causal_lm.py - {batch.padding_right_offset=}")
-# from loguru import logger; logger.info(f"generate_token in idefics_causal_lm.py - {batch.attention_mask.size()=}")
-# from loguru import logger; logger.info(f"generate_token in idefics_causal_lm.py - {attention_mask.size()=}")
-# from loguru import logger; logger.info(f"generate_token in idefics_causal_lm.py - {batch.image_attention_mask=}")
-# from loguru import logger; logger.info(f"generate_token in idefics_causal_lm.py - {batch.image_attention_mask.size()=}")
-# from loguru import logger; logger.info(f"generate_token in idefics_causal_lm.py - {image_attention_mask.size()=}")
-# from loguru import logger; logger.info(f"generate_token in idefics_causal_lm.py - {image_attention_mask=}")
logits, past, image_hidden_states = self.forward(
input_ids=batch.input_ids,
@@ -708,6 +688,8 @@ class IdeficsCausalLM(Model):
image_attention_mask=image_attention_mask,
past_key_values=batch.past_key_values,
)
+# Hardcoded remove image tokens
+logits[:, 32000:32001] = torch.finfo(logits.dtype).min
# Results
generations: List[Generation] = []
@@ -815,7 +797,6 @@ class IdeficsCausalLM(Model):
# Update values
batch.input_ids[i, 0] = next_token_id
-# from loguru import logger; logger.info(f"generate_token in idefics_causal_lm.py - batch.input_ids 1 {batch.input_ids.size()}")
batch.all_input_ids[i] = all_input_ids
batch.input_lengths[i] = new_input_length
batch.prefix_offsets[i] = prefix_offset
@@ -828,7 +809,6 @@ class IdeficsCausalLM(Model):
# Slice unused values from prefill
batch.input_ids = batch.input_ids[:, :1]
-# from loguru import logger; logger.info(f"generate_token in idefics_causal_lm.py - batch.input_ids 2 {batch.input_ids.size()}")
# Update attention_mask as we added a new token to input_ids
batch.attention_mask[:, -batch.padding_right_offset] = 1
@@ -843,5 +823,4 @@ class IdeficsCausalLM(Model):
batch.past_key_values = past
batch.image_hidden_states = image_hidden_states
-# from loguru import logger; logger.info(f"generate_token in idefics_causal_lm.py - {stopped=}")
return generations, batch
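
The only functional change in this diff is the pair of added lines in generate_token: right after the forward pass, the logit for token id 32000 is forced to the smallest finite value representable in the logits dtype, so neither greedy decoding nor sampling can ever emit that token again. Below is a minimal, self-contained sketch of the same masking trick; the ban_token helper, the toy vocabulary size, and the reading of id 32000 as the IDEFICS image placeholder token are illustrative assumptions, not part of the commit.

import torch

# Assumption for illustration: token id 32000 is the image placeholder token
# added to the IDEFICS vocabulary; the real id comes from the tokenizer in use.
IMAGE_TOKEN_ID = 32000

def ban_token(logits: torch.Tensor, token_id: int) -> torch.Tensor:
    # Force the banned token's logit to the smallest finite value of the logits
    # dtype, so neither argmax nor softmax sampling can ever select it.
    logits[:, token_id] = torch.finfo(logits.dtype).min
    return logits

# Toy usage: a batch of 2 sequences over a made-up 32_002-token vocabulary.
logits = torch.randn(2, 32_002).half()
logits = ban_token(logits, IMAGE_TOKEN_ID)
probs = torch.softmax(logits.float(), dim=-1)
assert probs[:, IMAGE_TOKEN_ID].max().item() < 1e-6  # effectively zero probability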