diff --git a/server/text_generation_server/models/idefics_causal_lm.py b/server/text_generation_server/models/idefics_causal_lm.py
index 8fd4d2f8..bdf1969c 100644
--- a/server/text_generation_server/models/idefics_causal_lm.py
+++ b/server/text_generation_server/models/idefics_causal_lm.py
@@ -119,7 +119,6 @@ class IdeficsCausalLMBatch(Batch):
         padding_right_offset = 0
         max_decode_tokens = 0
         for i, r in enumerate(pb.requests):
-            # from loguru import logger; logger.info(f"from_pb in idefics_causal_lm.py {i=} {r=}")
             requests_idx_mapping[r.id] = i
             inputs.append(r.inputs)
             next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
@@ -136,7 +135,6 @@ class IdeficsCausalLMBatch(Batch):
         prompts = []
         for inp in inputs:
             # Each input is encoded into a list, where each element of this input list is either a string or a URL
-            # from loguru import logger; logger.info(f"from_pb in idefics_causal_lm.py {inp=}")
             if isinstance(inp, str):
                 prompts.append([inp])
             elif isinstance(inp, list):
@@ -169,9 +167,6 @@ class IdeficsCausalLMBatch(Batch):
             max_length=max_truncation,
             add_end_of_utterance_token=False, # Already taken care of inside the prompts, so bypassing the processor's handling of this token
         ).to(device)
-        # from loguru import logger; logger.info(f"from_pb in idefics_causal_lm.py - {tokenized_inputs['input_ids']=}")
-        # from loguru import logger; logger.info({k: v.size() for k,v in processed_inputs.items()})
-        # {'input_ids': torch.Size([4, 5]), 'attention_mask': torch.Size([4, 5]), 'pixel_values': torch.Size([4, num_images, 3, 224, 224]), 'image_attention_mask': torch.Size([4, 5, num_images])}
         for _ in pb.requests:
             input_len = tokenized_inputs["input_ids"].shape[1]
             prefix_offsets.append(input_len - 5) # To decode without potential fallbacks errors
@@ -193,8 +188,6 @@ class IdeficsCausalLMBatch(Batch):
         image_attention_mask = input_ids.new_zeros(
             (pb.size, max_input_length + padding_right_offset, tokenized_inputs["pixel_values"].size(1))
         )
-        # image_attention_mask = tokenized_inputs["image_attention_mask"]
-        # from loguru import logger; logger.info(f"from_pb in idefics_causal_lm.py - image_attention_mask {image_attention_mask.size()}")
 
         position_ids = tokenized_inputs["attention_mask"].long().cumsum(-1) - 1
 
@@ -227,7 +220,6 @@ class IdeficsCausalLMBatch(Batch):
 
     @tracer.start_as_current_span("filter")
     def filter(self, request_ids: List[int]) -> Optional["IdeficsCausalLMBatch"]:
-        # from loguru import logger; logger.info(f"filter in idefics_causal_lm.py")
         # It deletes requests from the batch. For instance when client lost connection
         if len(request_ids) == 0:
             raise ValueError("Batch must have at least one request")
@@ -339,7 +331,6 @@ class IdeficsCausalLMBatch(Batch):
     @classmethod
     @tracer.start_as_current_span("concatenate")
     def concatenate(cls, batches: List["IdeficsCausalLMBatch"]) -> "IdeficsCausalLMBatch":
-        # from loguru import logger; logger.info(f"concatenate in idefics_causal_lm.py")
         # It adds new requests to the batch
         # Used for padding
         total_batch_size = 0
@@ -436,8 +427,6 @@ class IdeficsCausalLMBatch(Batch):
                 :,
                 batch_left_offset : -batch.padding_right_offset,
             ]
-            # from loguru import logger; logger.info(f"concatenate in idefics_causal_lm.py - image_attention_mask {image_attention_mask.size()}")
-            # from loguru import logger; logger.info(f"concatenate in idefics_causal_lm.py - batch.image_attention_mask {batch.image_attention_mask.size()}")
             image_attention_mask[
                 start_index:end_index,
                 left_offset:-padding_right_offset,
@@ -678,7 +667,6 @@ class IdeficsCausalLM(Model):
     def generate_token(
         self, batch: IdeficsCausalLMBatch
     ) -> Tuple[List[Generation], Optional[IdeficsCausalLMBatch]]:
-        # from loguru import logger; logger.info("generate_token in idefics_causal_lm.py - enter")
         # slice the attention mask to the correct shape
         attention_mask = batch.attention_mask[:, : -batch.padding_right_offset]
         if batch.input_ids.size(1) == 1:
@@ -690,14 +678,6 @@ class IdeficsCausalLM(Model):
             image_attention_mask = batch.image_attention_mask[:, -batch.padding_right_offset].unsqueeze(1) #TODO: verify that index. i have a doubt whether there is +1 hanging around
         else:
             image_attention_mask = batch.image_attention_mask[:, : -batch.padding_right_offset]
-        # image_hidden_states = batch.image_hidden_states[:, :-batch.padding_right_offset]
-        # from loguru import logger; logger.info(f"generate_token in idefics_causal_lm.py - {batch.padding_right_offset=}")
-        # from loguru import logger; logger.info(f"generate_token in idefics_causal_lm.py - {batch.attention_mask.size()=}")
-        # from loguru import logger; logger.info(f"generate_token in idefics_causal_lm.py - {attention_mask.size()=}")
-        # from loguru import logger; logger.info(f"generate_token in idefics_causal_lm.py - {batch.image_attention_mask=}")
-        # from loguru import logger; logger.info(f"generate_token in idefics_causal_lm.py - {batch.image_attention_mask.size()=}")
-        # from loguru import logger; logger.info(f"generate_token in idefics_causal_lm.py - {image_attention_mask.size()=}")
-        # from loguru import logger; logger.info(f"generate_token in idefics_causal_lm.py - {image_attention_mask=}")
 
         logits, past, image_hidden_states = self.forward(
             input_ids=batch.input_ids,
@@ -708,6 +688,8 @@ class IdeficsCausalLM(Model):
             image_attention_mask=image_attention_mask,
             past_key_values=batch.past_key_values,
         )
+        # Hardcoded remove image tokens
+        logits[:, 32000:32001] = torch.finfo(logits.dtype).min
 
         # Results
         generations: List[Generation] = []
@@ -815,7 +797,6 @@ class IdeficsCausalLM(Model):
 
                 # Update values
                 batch.input_ids[i, 0] = next_token_id
-                # from loguru import logger; logger.info(f"generate_token in idefics_causal_lm.py - batch.input_ids 1 {batch.input_ids.size()}")
                 batch.all_input_ids[i] = all_input_ids
                 batch.input_lengths[i] = new_input_length
                 batch.prefix_offsets[i] = prefix_offset
@@ -828,7 +809,6 @@ class IdeficsCausalLM(Model):
 
         # Slice unused values from prefill
         batch.input_ids = batch.input_ids[:, :1]
-        # from loguru import logger; logger.info(f"generate_token in idefics_causal_lm.py - batch.input_ids 2 {batch.input_ids.size()}")
 
         # Update attention_mask as we added a new token to input_ids
         batch.attention_mask[:, -batch.padding_right_offset] = 1
@@ -843,5 +823,4 @@ class IdeficsCausalLM(Model):
         batch.past_key_values = past
         batch.image_hidden_states = image_hidden_states
-        # from loguru import logger; logger.info(f"generate_token in idefics_causal_lm.py - {stopped=}")
 
         return generations, batch
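Aside from deleting commented-out debug logging, the only functional change in this diff is the pair of added lines after the forward pass: they suppress what the comment calls the image token (id 32000, taken from the hardcoded slice) by setting its logit to the smallest representable value of the logits dtype, so greedy search or sampling can never emit it during text decoding. Below is a minimal, self-contained sketch of that masking technique; the 2-D [batch, vocab] logits shape, the vocabulary size, and the token id are illustrative assumptions for the sketch, not a claim about the exact tensor shapes in generate_token.

import torch

# Assumed for illustration: the token id mirrors the hardcoded slice in the diff,
# the vocab size and 2-D [batch, vocab] logits shape are made up for this sketch.
IMAGE_TOKEN_ID = 32000
batch_size, vocab_size = 2, 32002

logits = torch.randn(batch_size, vocab_size, dtype=torch.float16)

# Same masking trick as the added lines: push the token's logit down to the
# lowest representable value of the dtype so argmax/sampling cannot select it.
logits[:, IMAGE_TOKEN_ID : IMAGE_TOKEN_ID + 1] = torch.finfo(logits.dtype).min

probs = torch.softmax(logits.float(), dim=-1)
assert (logits.argmax(dim=-1) != IMAGE_TOKEN_ID).all()
print(probs[:, IMAGE_TOKEN_ID])  # ~0 probability for the masked token

Using the dtype's finfo minimum rather than a fixed constant keeps the mask valid for both float16 and float32 logits.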