mirror of https://github.com/huggingface/text-generation-inference.git
synced 2025-09-10 20:04:52 +00:00

commit defc477a03 (parent 51348713d6)

    Hardcode remove image tokens.
@@ -119,7 +119,6 @@ class IdeficsCausalLMBatch(Batch):
         padding_right_offset = 0
         max_decode_tokens = 0
         for i, r in enumerate(pb.requests):
-            # from loguru import logger; logger.info(f"from_pb in idefics_causal_lm.py {i=} {r=}")
             requests_idx_mapping[r.id] = i
             inputs.append(r.inputs)
             next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
@@ -136,7 +135,6 @@ class IdeficsCausalLMBatch(Batch):
         prompts = []
         for inp in inputs:
             # Each input is encoded into a list, where each element of this input list is either a string or a URL
-            # from loguru import logger; logger.info(f"from_pb in idefics_causal_lm.py {inp=}")
             if isinstance(inp, str):
                 prompts.append([inp])
             elif isinstance(inp, list):
@@ -169,9 +167,6 @@ class IdeficsCausalLMBatch(Batch):
             max_length=max_truncation,
             add_end_of_utterance_token=False,  # Already taken care of inside the prompts, so bypassing the processor's handling of this token
         ).to(device)
-        # from loguru import logger; logger.info(f"from_pb in idefics_causal_lm.py - {tokenized_inputs['input_ids']=}")
-        # from loguru import logger; logger.info({k: v.size() for k,v in processed_inputs.items()})
-        # {'input_ids': torch.Size([4, 5]), 'attention_mask': torch.Size([4, 5]), 'pixel_values': torch.Size([4, num_images, 3, 224, 224]), 'image_attention_mask': torch.Size([4, 5, num_images])}
         for _ in pb.requests:
             input_len = tokenized_inputs["input_ids"].shape[1]
             prefix_offsets.append(input_len - 5)  # To decode without potential fallbacks errors
@@ -193,8 +188,6 @@ class IdeficsCausalLMBatch(Batch):
         image_attention_mask = input_ids.new_zeros(
             (pb.size, max_input_length + padding_right_offset, tokenized_inputs["pixel_values"].size(1))
         )
-        # image_attention_mask = tokenized_inputs["image_attention_mask"]
-        # from loguru import logger; logger.info(f"from_pb in idefics_causal_lm.py - image_attention_mask {image_attention_mask.size()}")
 
 
         position_ids = tokenized_inputs["attention_mask"].long().cumsum(-1) - 1
@@ -227,7 +220,6 @@ class IdeficsCausalLMBatch(Batch):
 
     @tracer.start_as_current_span("filter")
     def filter(self, request_ids: List[int]) -> Optional["IdeficsCausalLMBatch"]:
-        # from loguru import logger; logger.info(f"filter in idefics_causal_lm.py")
         # It deletes requests from the batch. For instance when client lost connection
         if len(request_ids) == 0:
             raise ValueError("Batch must have at least one request")
@@ -339,7 +331,6 @@ class IdeficsCausalLMBatch(Batch):
     @classmethod
     @tracer.start_as_current_span("concatenate")
     def concatenate(cls, batches: List["IdeficsCausalLMBatch"]) -> "IdeficsCausalLMBatch":
-        # from loguru import logger; logger.info(f"concatenate in idefics_causal_lm.py")
         # It adds new requests to the batch
         # Used for padding
         total_batch_size = 0
@@ -436,8 +427,6 @@ class IdeficsCausalLMBatch(Batch):
                 :,
                 batch_left_offset : -batch.padding_right_offset,
             ]
-            # from loguru import logger; logger.info(f"concatenate in idefics_causal_lm.py - image_attention_mask {image_attention_mask.size()}")
-            # from loguru import logger; logger.info(f"concatenate in idefics_causal_lm.py - batch.image_attention_mask {batch.image_attention_mask.size()}")
             image_attention_mask[
                 start_index:end_index,
                 left_offset:-padding_right_offset,
@@ -678,7 +667,6 @@ class IdeficsCausalLM(Model):
     def generate_token(
         self, batch: IdeficsCausalLMBatch
     ) -> Tuple[List[Generation], Optional[IdeficsCausalLMBatch]]:
-        # from loguru import logger; logger.info("generate_token in idefics_causal_lm.py - enter")
         # slice the attention mask to the correct shape
         attention_mask = batch.attention_mask[:, : -batch.padding_right_offset]
         if batch.input_ids.size(1) == 1:
@@ -690,14 +678,6 @@ class IdeficsCausalLM(Model):
             image_attention_mask = batch.image_attention_mask[:, -batch.padding_right_offset].unsqueeze(1) #TODO: verify that index. i have a doubt whether there is +1 hanging around
         else:
             image_attention_mask = batch.image_attention_mask[:, : -batch.padding_right_offset]
-        # image_hidden_states = batch.image_hidden_states[:, :-batch.padding_right_offset]
-        # from loguru import logger; logger.info(f"generate_token in idefics_causal_lm.py - {batch.padding_right_offset=}")
-        # from loguru import logger; logger.info(f"generate_token in idefics_causal_lm.py - {batch.attention_mask.size()=}")
-        # from loguru import logger; logger.info(f"generate_token in idefics_causal_lm.py - {attention_mask.size()=}")
-        # from loguru import logger; logger.info(f"generate_token in idefics_causal_lm.py - {batch.image_attention_mask=}")
-        # from loguru import logger; logger.info(f"generate_token in idefics_causal_lm.py - {batch.image_attention_mask.size()=}")
-        # from loguru import logger; logger.info(f"generate_token in idefics_causal_lm.py - {image_attention_mask.size()=}")
-        # from loguru import logger; logger.info(f"generate_token in idefics_causal_lm.py - {image_attention_mask=}")
 
         logits, past, image_hidden_states = self.forward(
             input_ids=batch.input_ids,
@@ -708,6 +688,8 @@ class IdeficsCausalLM(Model):
             image_attention_mask=image_attention_mask,
             past_key_values=batch.past_key_values,
         )
+        # Hardcoded remove image tokens
+        logits[:, 32000:32001] = torch.finfo(logits.dtype).min
 
         # Results
         generations: List[Generation] = []
@@ -815,7 +797,6 @@ class IdeficsCausalLM(Model):
 
             # Update values
             batch.input_ids[i, 0] = next_token_id
-            # from loguru import logger; logger.info(f"generate_token in idefics_causal_lm.py - batch.input_ids 1 {batch.input_ids.size()}")
             batch.all_input_ids[i] = all_input_ids
             batch.input_lengths[i] = new_input_length
             batch.prefix_offsets[i] = prefix_offset
@@ -828,7 +809,6 @@ class IdeficsCausalLM(Model):
 
         # Slice unused values from prefill
         batch.input_ids = batch.input_ids[:, :1]
-        # from loguru import logger; logger.info(f"generate_token in idefics_causal_lm.py - batch.input_ids 2 {batch.input_ids.size()}")
 
         # Update attention_mask as we added a new token to input_ids
         batch.attention_mask[:, -batch.padding_right_offset] = 1
@@ -843,5 +823,4 @@ class IdeficsCausalLM(Model):
         batch.past_key_values = past
         batch.image_hidden_states = image_hidden_states
 
-        # from loguru import logger; logger.info(f"generate_token in idefics_causal_lm.py - {stopped=}")
         return generations, batch
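
Aside from deleting commented-out loguru debug lines, the only behavioral change is the pair of added lines in the last generate_token hunk: they mask token id 32000 (one of the image-related special tokens added on top of the Llama vocabulary, per the commit title) by setting its logit to the dtype's minimum, so it can never be selected during decoding. Below is a minimal sketch of the same logit-masking technique outside the TGI code path; the helper name ban_token_ids, the vocabulary size, and the banned-id list are illustrative assumptions, not part of the commit.

import torch

def ban_token_ids(logits: torch.Tensor, banned_ids: list) -> torch.Tensor:
    # Setting a logit to the dtype's minimum drives its softmax probability to
    # effectively zero, so neither greedy argmax nor sampling can pick that id.
    logits[:, banned_ids] = torch.finfo(logits.dtype).min
    return logits

# Illustrative usage: a fake batch of 4 logit rows over a 32002-entry vocabulary
# (Llama's 32000 ids plus two added special tokens); sizes here are assumptions.
logits = torch.randn(4, 32002)
logits = ban_token_ids(logits, banned_ids=[32000])
assert (logits.argmax(dim=-1) != 32000).all()  # the banned id is never chosen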