import torch
import inspect
import re
from io import BytesIO
import base64

from PIL import Image
from dataclasses import dataclass
from opentelemetry import trace
from transformers import (
    AutoProcessor,
    AutoTokenizer,
    AutoModelForCausalLM,
    PreTrainedTokenizerBase,
    ProcessorMixin,
)
from typing import Optional, Tuple, List, Type, Dict

from text_generation_server.models import Model
from text_generation_server.models.types import (
    Batch,
    PrefillTokens,
    Generation,
    GeneratedText,
)
from text_generation_server.pb import generate_pb2
from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling

IMAGES = re.compile(r'!\[[^\]]*\]\((.*?)\s*(\"(?:.*[^\"])\")?\s*\)')


def split(string):
    parts = []
    cursor = 0
    for pattern in IMAGES.finditer(string):
        start = pattern.start()
        if start != cursor:
            parts.append(string[cursor:start])

        parts.append(pattern.group(1))
        cursor = pattern.end()

    if cursor != len(string):
        parts.append(string[cursor:])

    return parts
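
# Illustrative example of split() (the URL below is hypothetical):
#   split("User:![](https://example.com/cat.png)Describe this image.")
#   -> ["User:", "https://example.com/cat.png", "Describe this image."]
# Plain text segments are kept as-is, while markdown image links are reduced to
# their URL so the processor can fetch the images.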

tracer = trace.get_tracer(__name__)


@dataclass
class IdeficsCausalLMBatch(Batch):
    batch_id: int
    requests: List[generate_pb2.Request]
    requests_idx_mapping: Dict[int, int]

    # Decoder values
    input_ids: torch.Tensor
    attention_mask: torch.Tensor
    position_ids: torch.Tensor
    pixel_values: Optional[torch.Tensor]
    image_hidden_states: Optional[torch.Tensor]
    image_attention_mask: Optional[torch.Tensor]
    past_key_values: Optional[List[Tuple]]

    # All tokens
    all_input_ids: List[torch.Tensor]

    # Lengths of all generations present in the batch
    input_lengths: List[int]
    prefix_offsets: List[int]
    read_offsets: List[int]

    # Generation helpers
    next_token_choosers: List[NextTokenChooser]
    stopping_criterias: List[StoppingCriteria]

    # Metadata used for padding
    max_input_length: int
    padding_right_offset: int

    # Maximum number of tokens this batch will grow to
    max_tokens: int

    # Past metadata
    keys_head_dim_last: bool = True

    def to_pb(self) -> generate_pb2.CachedBatch:
        return generate_pb2.CachedBatch(
            id=self.batch_id,
            request_ids=[r.id for r in self.requests],
            size=len(self),
            max_tokens=self.max_tokens,
        )

    @classmethod
    def from_pb(
        cls,
        pb: generate_pb2.Batch,
        tokenizer: PreTrainedTokenizerBase,
        processor: ProcessorMixin,  # Hack
        dtype: torch.dtype,
        device: torch.device,
    ) -> "IdeficsCausalLMBatch":
        inputs = []
        next_token_choosers = []
        stopping_criterias = []
        prefix_offsets = []
        read_offsets = []
        requests_idx_mapping = {}

        # Parse batch
        max_truncation = 0
        padding_right_offset = 0
        max_decode_tokens = 0
        for i, r in enumerate(pb.requests):
            requests_idx_mapping[r.id] = i
            inputs.append(r.inputs)
            next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
            stopping_criteria = StoppingCriteria.from_pb(
                r.stopping_parameters, tokenizer
            )
            stopping_criterias.append(stopping_criteria)
            max_truncation = max(max_truncation, r.truncate)
            max_decode_tokens += stopping_criteria.max_new_tokens
            padding_right_offset = max(
                padding_right_offset, stopping_criteria.max_new_tokens
            )

        # Each input is encoded into a list where each element is either a string or a URL
        prompts = []
        for inp in inputs:
            prompts.append(split(inp))

        # The processor replaces the call to the tokenizer, and
        # a/ takes care of fetching images from the URLs
        # b/ generates the correct input_ids, attention_mask, pixel_values and
        #    image_attention_mask to feed to the model
        tokenized_inputs = processor(
            prompts,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=max_truncation,
            add_end_of_utterance_token=False,  # Already taken care of inside the prompts, so bypassing the processor's handling of this token
        ).to(device)
        for _ in pb.requests:
            input_len = tokenized_inputs["input_ids"].shape[1]
            prefix_offsets.append(input_len - 5)  # To decode without potential fallback errors
            read_offsets.append(input_len)  # To decode without potential fallback errors

        input_lengths = tokenized_inputs["attention_mask"].sum(1)
        max_input_length = input_lengths.max()

        input_ids = tokenized_inputs["input_ids"]
        pixel_values = tokenized_inputs["pixel_values"]
        image_hidden_states = None
        # Allocate maximum attention_mask
        attention_mask = input_ids.new_zeros(
            (pb.size, max_input_length + padding_right_offset)
        )
        # Copy tokenizer attention_mask into fully allocated attention_mask
        attention_mask[:, :max_input_length] = tokenized_inputs["attention_mask"]
        # Do the same for image_attention_mask
        image_attention_mask = input_ids.new_zeros(
            (
                pb.size,
                max_input_length + padding_right_offset,
                tokenized_inputs["pixel_values"].size(1),
            )
        )
        image_attention_mask[:, :max_input_length, :] = tokenized_inputs[
            "image_attention_mask"
        ]

        position_ids = tokenized_inputs["attention_mask"].long().cumsum(-1) - 1
        position_ids.masked_fill_(tokenized_inputs["attention_mask"] == 0, 1)
        all_input_ids = tokenized_inputs["input_ids"].T.split(
            1, dim=1
        )  # input_ids split into a tuple of (seq_len, 1) tensors, one per request; converted to a list below

        max_tokens = len(inputs) * (max_input_length + max_decode_tokens)

        return cls(
            batch_id=pb.id,
            requests=pb.requests,
            requests_idx_mapping=requests_idx_mapping,
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            pixel_values=pixel_values,
            image_hidden_states=image_hidden_states,
            image_attention_mask=image_attention_mask,
            past_key_values=None,
            all_input_ids=list(all_input_ids),
            input_lengths=input_lengths.tolist(),
            prefix_offsets=prefix_offsets,
            read_offsets=read_offsets,
            next_token_choosers=next_token_choosers,
            stopping_criterias=stopping_criterias,
            max_input_length=max_input_length.item(),
            padding_right_offset=padding_right_offset,
            max_tokens=max_tokens,
        )
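
    # Padding layout sketch (illustrative, hypothetical numbers): with two requests
    # whose tokenized lengths are [3, 5] and whose max_new_tokens are [4, 2],
    # max_input_length == 5 and padding_right_offset == 4, so attention_mask is
    # allocated with shape (2, 5 + 4). Only the first max_input_length columns are
    # filled at prefill time; generate_token then switches on one extra column per
    # decoding step and decrements padding_right_offset.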

    @tracer.start_as_current_span("filter")
    def filter(self, request_ids: List[int]) -> Optional["IdeficsCausalLMBatch"]:
        # Remove requests from the batch, for instance when the client dropped the connection
        if len(request_ids) == 0:
            raise ValueError("Batch must have at least one request")
        if len(request_ids) == len(self):
            return self

        keep_indices = []

        # New values after filtering
        requests_idx_mapping = {}
        requests = []
        input_lengths = []
        prefix_offsets = []
        read_offsets = []
        all_input_ids = []
        max_input_length = 0

        next_token_choosers = []
        stopping_criterias = []

        total_remaining_decode_tokens = 0
        new_padding_right_offset = 0

        for i, request_id in enumerate(request_ids):
            idx = self.requests_idx_mapping[request_id]
            requests_idx_mapping[request_id] = i
            keep_indices.append(idx)

            requests.append(self.requests[idx])
            prefix_offsets.append(self.prefix_offsets[idx])
            read_offsets.append(self.read_offsets[idx])
            all_input_ids.append(self.all_input_ids[idx])

            request_input_length = self.input_lengths[idx]
            input_lengths.append(request_input_length)
            max_input_length = max(max_input_length, request_input_length)

            next_token_choosers.append(self.next_token_choosers[idx])
            stopping_criteria = self.stopping_criterias[idx]
            stopping_criterias.append(stopping_criteria)
            remaining_decode_tokens = (
                stopping_criteria.max_new_tokens - stopping_criteria.current_tokens
            )
            total_remaining_decode_tokens += remaining_decode_tokens
            new_padding_right_offset = max(
                new_padding_right_offset, remaining_decode_tokens
            )

        # Apply indices to input_ids, attention mask, past key values and other items that need to be cached
        input_ids = self.input_ids[keep_indices]
        position_ids = self.position_ids[keep_indices]
        self.attention_mask = self.attention_mask[
            keep_indices,
            -(self.padding_right_offset + max_input_length) : (
                self.attention_mask.shape[1] - self.padding_right_offset
            )
            + new_padding_right_offset,
        ]
        # Do the same for pixel_values and image_attention_mask
        pixel_values = self.pixel_values[keep_indices]
        self.image_attention_mask = self.image_attention_mask[
            keep_indices,
            -(self.padding_right_offset + max_input_length) : (
                self.image_attention_mask.shape[1] - self.padding_right_offset
            )
            + new_padding_right_offset,
            :,
        ]
        if self.image_hidden_states is None:
            image_hidden_states = None
        else:
            image_hidden_states = self.image_hidden_states[keep_indices]

        # Ensure that past_key_values tensors can be updated in-place
        if type(self.past_key_values[0]) == tuple:
            self.past_key_values = [list(layer) for layer in self.past_key_values]

        # Update tensors in-place to allow incremental garbage collection
        past_kv_length = max_input_length - 1
        for layer in self.past_key_values:
            past_keys, past_values = layer
            if len(past_keys.shape) == 3:
                # Force past to be of dim [self_size, num_heads, ...] for easy indexing
                past_keys = past_keys.view(len(self), -1, *past_keys.shape[-2:])
                past_values = past_values.view(len(self), -1, *past_values.shape[-2:])
            if self.keys_head_dim_last:
                layer[0] = past_keys[keep_indices, :, -past_kv_length:, :]
            else:
                layer[0] = past_keys[keep_indices, :, :, -past_kv_length:]
            del past_keys
            layer[1] = past_values[keep_indices, :, -past_kv_length:, :]
            del past_values

        max_tokens = len(request_ids) * max_input_length + total_remaining_decode_tokens

        self.requests = requests
        self.requests_idx_mapping = requests_idx_mapping
        self.input_ids = input_ids
        self.pixel_values = pixel_values
        self.image_hidden_states = image_hidden_states
        self.position_ids = position_ids
        self.all_input_ids = all_input_ids
        self.input_lengths = input_lengths
        self.prefix_offsets = prefix_offsets
        self.read_offsets = read_offsets
        self.next_token_choosers = next_token_choosers
        self.stopping_criterias = stopping_criterias
        self.max_input_length = max_input_length
        self.padding_right_offset = new_padding_right_offset
        self.max_tokens = max_tokens

        return self

    @classmethod
    @tracer.start_as_current_span("concatenate")
    def concatenate(
        cls, batches: List["IdeficsCausalLMBatch"]
    ) -> "IdeficsCausalLMBatch":
        # Merge several batches into one, e.g. to add new requests to a running batch

        # Used for padding
        total_batch_size = 0
        max_input_length = 0
        max_num_images = 0
        padding_right_offset = 0
        for batch in batches:
            total_batch_size += len(batch)
            max_input_length = max(max_input_length, batch.max_input_length)
            max_num_images = max(max_num_images, batch.pixel_values.size(1))
            padding_right_offset = max(padding_right_offset, batch.padding_right_offset)

        # Batch attributes
        requests = []
        requests_idx_mapping = {}
        input_lengths = []
        prefix_offsets = []
        read_offsets = []
        all_input_ids = []
        next_token_choosers = []
        stopping_criterias = []
        max_tokens = 0

        # Batch tensors
        input_ids = None
        attention_mask = None
        position_ids = None
        pixel_values = None
        image_hidden_states = None
        image_attention_mask = None
        past_key_values = []

        # Used for slicing correctly inside the tensors
        # Equivalent to a cumsum on batch sizes
        start_index = 0
        for i, batch in enumerate(batches):
            requests.extend(batch.requests)
            input_lengths.extend(batch.input_lengths)
            prefix_offsets.extend(batch.prefix_offsets)
            read_offsets.extend(batch.read_offsets)
            all_input_ids.extend(batch.all_input_ids)
            next_token_choosers.extend(batch.next_token_choosers)
            stopping_criterias.extend(batch.stopping_criterias)

            if i == 0:
                requests_idx_mapping = batch.requests_idx_mapping
            else:
                # We need to offset the mapping for each batch by the cumulative batch size
                for k, v in batch.requests_idx_mapping.items():
                    requests_idx_mapping[k] = v + start_index

            # Slicing end index for this batch
            end_index = start_index + len(batch)

            # We only concatenate batches that did at least one step
            if batch.past_key_values is None:
                raise ValueError("only concatenate prefilled batches")

            # Create empty tensor
            # input_ids is always of shape [batch_size, 1]
            # We do not need to pad it
            if input_ids is None:
                input_ids = batch.input_ids.new_empty((total_batch_size, 1))
            # Copy to correct indices
            input_ids[start_index:end_index] = batch.input_ids

            # Create padded tensor
            if attention_mask is None:
                attention_mask = batch.attention_mask.new_zeros(
                    (total_batch_size, max_input_length + padding_right_offset),
                )

            curr_batch_max_num_images = batch.pixel_values.size(1)
            if pixel_values is None:
                pixel_values = batch.pixel_values.new_zeros(
                    (total_batch_size, max_num_images, 3, 224, 224)
                )
            pixel_values[
                start_index:end_index, :curr_batch_max_num_images
            ] = batch.pixel_values

            if image_attention_mask is None:
                image_attention_mask = batch.image_attention_mask.new_zeros(
                    (
                        total_batch_size,
                        max_input_length + padding_right_offset,
                        max_num_images,
                    )
                )

            # We need to slice the attention mask to remove padding from previous steps
            # and to remove unused allocated space
            left_offset = max_input_length - batch.max_input_length
            batch_left_offset = (
                batch.attention_mask.shape[1]
                - batch.max_input_length
                - batch.padding_right_offset
            )
            attention_mask[
                start_index:end_index,
                left_offset:-padding_right_offset,
            ] = batch.attention_mask[
                :,
                batch_left_offset : -batch.padding_right_offset,
            ]
            image_attention_mask[
                start_index:end_index,
                left_offset:-padding_right_offset,
                :curr_batch_max_num_images,
            ] = batch.image_attention_mask[
                :, batch_left_offset : -batch.padding_right_offset, :
            ]

            # Create empty tensor
            # position_ids is always of shape [batch_size, 1]
            if position_ids is None:
                position_ids = batch.position_ids.new_empty((total_batch_size, 1))
            position_ids[start_index:end_index] = batch.position_ids

            # Shenanigans to get dimensions because BLOOM outputs a past with a different shape
            # BLOOM Keys:   [batch_size * num_heads, head_dim, seq_length]
            # BLOOM Values: [batch_size * num_heads, seq_length, head_dim]
            # And ensure that we can update tensors in-place
            if type(batch.past_key_values[0]) == tuple:
                batch.past_key_values = [
                    [t.view(len(batch), -1, *t.shape[-2:]) for t in layer]
                    for layer in batch.past_key_values
                ]
            elif len(batch.past_key_values[0][0].shape) == 3:
                for layer in batch.past_key_values:
                    for k, t in enumerate(layer):
                        layer[k] = t.view(len(batch), -1, *t.shape[-2:])

            # Add eventual padding tokens that were added while concatenating
            max_tokens += batch.max_tokens + (
                max_input_length - batch.max_input_length
            ) * len(batch)

            start_index = end_index

        first_past_kvs = batches[0].past_key_values
        _, num_heads, padded_sequence_length, head_dim = first_past_kvs[0][1].shape

        padded_past_values_shape = (
            total_batch_size,
            num_heads,
            max_input_length - 1,
            head_dim,
        )

        if batches[0].keys_head_dim_last:
            padded_past_keys_shape = padded_past_values_shape
        else:
            # seq_length is last for BLOOM
            padded_past_keys_shape = (
                total_batch_size,
                num_heads,
                head_dim,
                max_input_length - 1,
            )

        # Iterate over attention layers
        # Concatenate past key values layer by layer to allow incremental garbage collection
        for j in range(len(first_past_kvs)):
            padded_past_keys = first_past_kvs[j][0].new_zeros(padded_past_keys_shape)
            start_index = 0
            for batch in batches:
                past_keys = batch.past_key_values[j][0]
                # Clear reference to the original tensor
                batch.past_key_values[j][0] = None

                # Slicing end index for this batch
                end_index = start_index + len(batch)
                # We slice the keys to remove the padding from previous batches
                past_seq_len = batch.max_input_length - 1
                if batch.keys_head_dim_last:
                    padded_past_keys[
                        start_index:end_index, :, -past_seq_len:, :
                    ] = past_keys[:, :, -past_seq_len:, :]
                else:
                    # BLOOM case
                    padded_past_keys[
                        start_index:end_index, :, :, -past_seq_len:
                    ] = past_keys[:, :, :, -past_seq_len:]
                del past_keys

                start_index = end_index

            padded_past_values = first_past_kvs[j][1].new_zeros(
                padded_past_values_shape
            )
            start_index = 0
            for batch in batches:
                past_values = batch.past_key_values[j][1]
                # Clear reference to the original tensor
                batch.past_key_values[j][1] = None

                # Slicing end index for this batch
                end_index = start_index + len(batch)
                # We slice the past values to remove the padding from previous batches
                past_seq_len = batch.max_input_length - 1
                padded_past_values[
                    start_index:end_index, :, -past_seq_len:, :
                ] = past_values[:, :, -past_seq_len:, :]
                del past_values

                # Update values
                start_index = end_index

            past_key_values.append([padded_past_keys, padded_past_values])

        return cls(
            batch_id=batches[0].batch_id,
            requests=requests,
            requests_idx_mapping=requests_idx_mapping,
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            pixel_values=pixel_values,
            image_hidden_states=image_hidden_states,
            image_attention_mask=image_attention_mask,
            past_key_values=past_key_values,
            all_input_ids=all_input_ids,
            input_lengths=input_lengths,
            prefix_offsets=prefix_offsets,
            read_offsets=read_offsets,
            next_token_choosers=next_token_choosers,
            stopping_criterias=stopping_criterias,
            max_input_length=max_input_length,
            padding_right_offset=padding_right_offset,
            keys_head_dim_last=batches[0].keys_head_dim_last,
            max_tokens=max_tokens,
        )

    def __len__(self):
        return len(self.requests)
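
# Cache layout kept by IdeficsCausalLMBatch (summary of the code above): each
# past_key_values layer holds tensors of shape
# [batch_size, num_heads, past_seq_len, head_dim] (seq_length comes last for
# BLOOM-style keys when keys_head_dim_last is False), aligned to the right of the
# sequence dimension, while attention_mask and image_attention_mask reserve
# padding_right_offset spare columns on the right for tokens still to be generated.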


class IdeficsCausalLM(Model):
    def __init__(
        self,
        model_id: str,
        revision: Optional[str] = None,
        quantize: Optional[str] = None,
        dtype: Optional[torch.dtype] = None,
        trust_remote_code: bool = False,
    ):
        from text_generation_server.models.custom_modeling.idefics_modeling import (
            IdeficsForVisionText2Text,
        )

        if torch.cuda.is_available():
            device = torch.device("cuda")
            dtype = torch.float16 if dtype is None else dtype
        else:
            if quantize:
                raise ValueError("quantization is not available on CPU")

            device = torch.device("cpu")
            dtype = torch.float32

        tokenizer = AutoTokenizer.from_pretrained(
            model_id,
            revision=revision,
            padding_side="left",
            truncation_side="left",
            trust_remote_code=trust_remote_code,
        )
        self.processor = AutoProcessor.from_pretrained(
            model_id,
            revision=revision,
            padding_side="left",
            truncation_side="left",
            trust_remote_code=trust_remote_code,
        )

        model = IdeficsForVisionText2Text.from_pretrained(
            model_id,
            revision=revision,
            torch_dtype=dtype,
            device_map="auto"
            if torch.cuda.is_available() and torch.cuda.device_count() > 1
            else None,
            load_in_8bit=quantize == "bitsandbytes",
            trust_remote_code=trust_remote_code,
        )
        if torch.cuda.is_available() and torch.cuda.device_count() == 1:
            model = model.cuda()

        if tokenizer.pad_token_id is None:
            if model.config.pad_token_id is not None:
                tokenizer.pad_token_id = model.config.pad_token_id
            elif model.config.eos_token_id is not None:
                tokenizer.pad_token_id = model.config.eos_token_id
            elif tokenizer.eos_token_id is not None:
                tokenizer.pad_token_id = tokenizer.eos_token_id
            else:
                # Fall back to a generic pad token
                tokenizer.add_special_tokens({"pad_token": "<unk>"})

        super(IdeficsCausalLM, self).__init__(
            model=model,
            tokenizer=tokenizer,
            requires_padding=True,
            dtype=dtype,
            device=device,
        )

    @property
    def batch_type(self) -> Type[IdeficsCausalLMBatch]:
        return IdeficsCausalLMBatch

    def decode(self, generated_ids: List[int]) -> str:
        return self.tokenizer.decode(
            generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )

    def forward(
        self,
        input_ids,
        attention_mask,
        position_ids,
        pixel_values,
        image_hidden_states,
        image_attention_mask,
        past_key_values: Optional[List[Tuple]] = None,
    ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]], Optional[torch.Tensor]]:
        # Model Forward
        kwargs = {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "pixel_values": pixel_values,
            "image_hidden_states": image_hidden_states,
            "image_attention_mask": image_attention_mask,
            "past_key_values": past_key_values,
            "use_cache": True,
            "return_dict": True,
        }
        if self.has_position_ids:
            kwargs["position_ids"] = position_ids

        outputs = self.model.forward(**kwargs)
        return outputs.logits, outputs.past_key_values, outputs.image_hidden_states
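
    # Note (assumption based on the IDEFICS modeling code, not verified here): once
    # forward returns image_hidden_states, generate_token caches them on the batch
    # and passes them back on later steps, so pixel_values should only go through
    # the vision encoder once, during the prefill pass.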
@tracer.start_as_current_span("generate_token") def generate_token( self, batch: IdeficsCausalLMBatch ) -> Tuple[List[Generation], Optional[IdeficsCausalLMBatch]]: # slice the attention mask to the correct shape attention_mask = batch.attention_mask[:, : -batch.padding_right_offset] if batch.input_ids.size(1) == 1: # THIS is a hack: when calling idefics.generate, the first time, we need the whole image_attention_mask (size bs x max_seq_len x max_num_images), # but the subsequent times, we only need the last attention mask along the `max_seq_len` dimension # this is due to the nature IDEFICS: it's an encoder decoder, and so when decoding, only the currently generated # token need to attend to the encoder hidden states (i.e. the vision encoder) # Also see seq2seq_lm.Seq2SeqLM.generate_token which has roughly the same logic image_attention_mask = batch.image_attention_mask[:, -(batch.padding_right_offset+1)].unsqueeze(1) else: image_attention_mask = batch.image_attention_mask[:, : -batch.padding_right_offset] logits, past, image_hidden_states = self.forward( input_ids=batch.input_ids, attention_mask=attention_mask, position_ids=batch.position_ids, pixel_values=batch.pixel_values, image_hidden_states=batch.image_hidden_states, image_attention_mask=image_attention_mask, past_key_values=batch.past_key_values, ) # Hardcoded remove image tokens logits[:, 32000:32001] = torch.finfo(logits.dtype).min # Results generations: List[Generation] = [] stopped = True # Zipped iterator iterator = zip( batch.requests, batch.input_lengths, batch.prefix_offsets, batch.read_offsets, logits, batch.next_token_choosers, batch.stopping_criterias, batch.all_input_ids, ) # For each member of the batch for i, ( request, input_length, prefix_offset, read_offset, logits, next_token_chooser, stopping_criteria, all_input_ids, ) in enumerate(iterator): # Select next token next_token_id, logprobs = next_token_chooser( all_input_ids.view(1, -1), logits[-1:, :] ) # Append next token to all tokens all_input_ids = torch.cat([all_input_ids, next_token_id]) new_input_length = input_length + 1 # Generated token next_token_logprob = logprobs[-1, next_token_id] next_token_id_squeezed = next_token_id.squeeze() next_token_text, prefix_offset, read_offset = self.decode_token( all_input_ids[:, 0], prefix_offset, read_offset ) # Evaluate stopping criteria stop, reason = stopping_criteria( next_token_id_squeezed, next_token_text, ) if not stop: stopped = False # Shard generations # All generations will be appended in the rust sharded client if i % self.world_size == self.rank: if stop: # Decode generated tokens output_text = self.decode( all_input_ids[-stopping_criteria.current_tokens :, 0] ) # Get seed if isinstance(next_token_chooser.choice, Sampling): seed = next_token_chooser.choice.seed else: seed = None generated_text = GeneratedText( output_text, stopping_criteria.current_tokens, reason, seed ) else: generated_text = None # Prefill if stopping_criteria.current_tokens == 1 and request.prefill_logprobs: # Remove generated token to only have prefill and add nan for first prompt token prefill_logprobs = [float("nan")] + torch.log_softmax( logits, -1 ).gather(1, all_input_ids[1:]).squeeze(1)[ -new_input_length:-1 ].tolist() prefill_token_ids = all_input_ids[-new_input_length:-1] prefill_texts = self.tokenizer.batch_decode( prefill_token_ids, clean_up_tokenization_spaces=False, skip_special_tokens=False, ) prefill_tokens = PrefillTokens( prefill_token_ids, prefill_logprobs, prefill_texts ) else: prefill_tokens = None top_tokens=None 

                generation = Generation(
                    request.id,
                    prefill_tokens,
                    next_token_id_squeezed,
                    next_token_logprob,
                    next_token_text,
                    next_token_id_squeezed.item() in self.all_special_ids,
                    generated_text,
                    top_tokens,
                )

                generations.append(generation)

            # Update values
            batch.input_ids[i, 0] = next_token_id
            batch.all_input_ids[i] = all_input_ids
            batch.input_lengths[i] = new_input_length
            batch.prefix_offsets[i] = prefix_offset
            batch.read_offsets[i] = read_offset
            batch.max_input_length = max(batch.max_input_length, new_input_length)

        # We finished all generations in the batch; there is no next batch
        if stopped:
            return generations, None

        # Slice unused values from prefill
        batch.input_ids = batch.input_ids[:, :1]

        # Update attention_mask as we added a new token to input_ids
        batch.attention_mask[:, -batch.padding_right_offset] = 1
        batch.image_attention_mask[:, -batch.padding_right_offset, :] = (
            batch.image_attention_mask[:, -(batch.padding_right_offset + 1), :]
        )
        # Decrease right offset
        batch.padding_right_offset -= 1

        # Update position_ids
        batch.position_ids = batch.position_ids[:, -1:] + 1

        # Update past key values
        batch.past_key_values = past
        batch.image_hidden_states = image_hidden_states

        return generations, batch
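

# Typical lifecycle (illustrative): the router sends a generate_pb2.Batch, the
# server builds an IdeficsCausalLMBatch via from_pb, then repeatedly calls
# IdeficsCausalLM.generate_token on it; requests that finish or disconnect are
# dropped with filter(), and batches that have already been prefilled can be
# merged with concatenate().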