From 99771cfad58dc65c0cfd57df353e849d51b1ce31 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Mon, 8 Apr 2024 09:56:37 +0000 Subject: [PATCH] Upgrade tests (still missing load tests for some reason). --- .../test_flash_llava_next_all_params.json | 65 ++ .../test_flash_llava_next_simple.json | 73 ++ integration-tests/models/test_idefics.py | 8 +- integration-tests/models/test_llava_next.py | 10 +- .../custom_modeling/flash_mistral_modeling.py | 5 +- .../models/flash_causal_lm.py | 7 +- .../text_generation_server/models/idefics.py | 2 +- .../models/idefics_causal_lm.py | 848 +----------------- .../models/vlm_causal_lm.py | 31 + 9 files changed, 193 insertions(+), 856 deletions(-) create mode 100644 integration-tests/models/__snapshots__/test_llava_next/test_flash_llava_next_all_params.json create mode 100644 integration-tests/models/__snapshots__/test_llava_next/test_flash_llava_next_simple.json diff --git a/integration-tests/models/__snapshots__/test_llava_next/test_flash_llava_next_all_params.json b/integration-tests/models/__snapshots__/test_llava_next/test_flash_llava_next_all_params.json new file mode 100644 index 00000000..e9d3e5ef --- /dev/null +++ b/integration-tests/models/__snapshots__/test_llava_next/test_flash_llava_next_all_params.json @@ -0,0 +1,65 @@ +{ + "details": { + "best_of_sequences": null, + "finish_reason": "stop_sequence", + "generated_tokens": 6, + "prefill": [ + { + "id": 1, + "logprob": null, + "text": "" + }, + { + "id": 3735, + "logprob": -10.5, + "text": "Test" + }, + { + "id": 2159, + "logprob": -12.140625, + "text": "request" + } + ], + "seed": 0, + "tokens": [ + { + "id": 13, + "logprob": -1.0654297, + "special": false, + "text": "\n" + }, + { + "id": 1014, + "logprob": -2.7460938, + "special": false, + "text": "The" + }, + { + "id": 6032, + "logprob": -1.359375, + "special": false, + "text": " purpose" + }, + { + "id": 302, + "logprob": 0.0, + "special": false, + "text": " of" + }, + { + "id": 456, + "logprob": 0.0, + "special": false, + "text": " this" + }, + { + "id": 1369, + "logprob": -0.40063477, + "special": false, + "text": " test" + } + ], + "top_tokens": null + }, + "generated_text": "Test request\nThe purpose of this test" +} diff --git a/integration-tests/models/__snapshots__/test_llava_next/test_flash_llava_next_simple.json b/integration-tests/models/__snapshots__/test_llava_next/test_flash_llava_next_simple.json new file mode 100644 index 00000000..f0f2ee9e --- /dev/null +++ b/integration-tests/models/__snapshots__/test_llava_next/test_flash_llava_next_simple.json @@ -0,0 +1,73 @@ +{ + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [], + "seed": null, + "tokens": [ + { + "id": 13, + "logprob": -0.00756073, + "special": false, + "text": "\n" + }, + { + "id": 13, + "logprob": -0.20117188, + "special": false, + "text": "\n" + }, + { + "id": 16114, + "logprob": -1.2597656, + "special": false, + "text": "Once" + }, + { + "id": 3714, + "logprob": -0.20825195, + "special": false, + "text": " upon" + }, + { + "id": 264, + "logprob": -0.00178051, + "special": false, + "text": " a" + }, + { + "id": 727, + "logprob": -0.011955261, + "special": false, + "text": " time" + }, + { + "id": 28725, + "logprob": -0.17541504, + "special": false, + "text": "," + }, + { + "id": 736, + "logprob": -0.91308594, + "special": false, + "text": " there" + }, + { + "id": 403, + "logprob": -0.058410645, + "special": false, + "text": " was" + }, + { + "id": 264, + "logprob": -0.009689331, + "special": false, + 
"text": " a" + } + ], + "top_tokens": null + }, + "generated_text": "\n\nOnce upon a time, there was a" +} diff --git a/integration-tests/models/test_idefics.py b/integration-tests/models/test_idefics.py index 7fb70a8f..aeeaffa1 100644 --- a/integration-tests/models/test_idefics.py +++ b/integration-tests/models/test_idefics.py @@ -33,7 +33,9 @@ async def test_idefics(idefics, response_snapshot): ) assert response.details.generated_tokens == 10 - assert response.generated_text == "\n\nDeep learning is a new type of machine" + assert ( + response.generated_text == " \nAssistant: A rooster stands" + ), f"{repr(response.generated_text)}" assert response == response_snapshot @@ -49,7 +51,9 @@ async def test_idefics_load(idefics, generate_load, response_snapshot): generated_texts = [r.generated_text for r in responses] - assert generated_texts[0] == "\n\nDeep learning is a new type of machine" + assert ( + generated_texts[0] == " \nAssistant: A rooster stands" + ), f"{response.generated_text}" assert len(generated_texts) == 4 assert generated_texts, all( [text == generated_texts[0] for text in generated_texts] diff --git a/integration-tests/models/test_llava_next.py b/integration-tests/models/test_llava_next.py index be968a4b..5deaafed 100644 --- a/integration-tests/models/test_llava_next.py +++ b/integration-tests/models/test_llava_next.py @@ -13,7 +13,7 @@ def get_chicken(): def flash_llava_next_handle(launcher): with launcher( "llava-hf/llava-v1.6-mistral-7b-hf", - num_shard=4, + num_shard=1, max_input_length=4000, max_total_tokens=4096, ) as handle: @@ -34,7 +34,9 @@ async def test_flash_llava_next_simple(flash_llava_next, response_snapshot): f"User:![]({chicken})Can you tell me a very short story based on the image?", max_new_tokens=10, ) - assert response.generated_text == "toto" + assert ( + response.generated_text == "\n\nOnce upon a time, there was a" + ), f"{repr(response.generated_text)}" assert response.details.generated_tokens == 10 assert response == response_snapshot @@ -58,7 +60,7 @@ async def test_flash_llava_next_all_params(flash_llava_next, response_snapshot): seed=0, ) - assert response.details.generated_tokens == 5 + assert response.details.generated_tokens == 6 assert response == response_snapshot @@ -75,7 +77,7 @@ async def test_flash_llava_next_load( n=4, ) generated_texts = [r.generated_text for r in responses] - assert generated_texts[0] == "\n\nDeep learning is a new type of machine" + assert generated_texts[0] == "\n\nOnce upon a time, there was a" assert len(generated_texts) == 4 assert all([r.generated_text == generated_texts[0] for r in responses]) diff --git a/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py b/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py index cab72f63..ffaa0c32 100644 --- a/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py @@ -413,7 +413,10 @@ class FlashMistralForCausalLM(torch.nn.Module): super().__init__() self.embed_tokens = TensorParallelEmbedding( - prefix=f"{prefix}.model.embed_tokens", weights=weights + prefix=( + "model.embed_tokens" if not prefix else f"{prefix}.model.embed_tokens" + ), + weights=weights, ) self.model = MistralModel( prefix="model" if not prefix else f"{prefix}.model", diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index 94a7f023..be513511 100644 --- 
a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -1047,12 +1047,7 @@ class FlashCausalLM(Model): batch.input_ids = next_input_ids[accepted_ids.cumsum(dim=-1) - 1] batch.speculative_ids = speculative_ids batch.position_ids = next_position_ids + accepted_ids - try: - batch.input_lengths_tensor += accepted_ids - except Exception: - import ipdb - - ipdb.set_trace() + batch.input_lengths_tensor += accepted_ids batch.slot_indices += accepted_ids if prefill and prefill_logprobs: diff --git a/server/text_generation_server/models/idefics.py b/server/text_generation_server/models/idefics.py index a90a0d96..30bf4aa6 100644 --- a/server/text_generation_server/models/idefics.py +++ b/server/text_generation_server/models/idefics.py @@ -82,7 +82,7 @@ class IDEFICSSharded(IdeficsCausalLM): model = IdeficsForVisionText2Text(config, weights) torch.distributed.barrier(group=self.process_group) - super(VlmCausalLM, self).__init__( + super(IdeficsCausalLM, self).__init__( model=model, tokenizer=tokenizer, requires_padding=True, diff --git a/server/text_generation_server/models/idefics_causal_lm.py b/server/text_generation_server/models/idefics_causal_lm.py index 211b425d..e78a9655 100644 --- a/server/text_generation_server/models/idefics_causal_lm.py +++ b/server/text_generation_server/models/idefics_causal_lm.py @@ -81,850 +81,10 @@ class IdeficsCausalLMBatch(Batch): cls, pb: generate_pb2.Batch, tokenizer: PreTrainedTokenizerBase, - processor: ProcessorMixin, # Hack dtype: torch.dtype, device: torch.device, ) -> "IdeficsCausalLMBatch": - inputs = [] - next_token_choosers = [] - stopping_criterias = [] - prefix_offsets = [] - read_offsets = [] - requests_idx_mapping = {} - - # Parse batch - max_truncation = 0 - padding_right_offset = 0 - max_decode_tokens = 0 - for i, r in enumerate(pb.requests): - requests_idx_mapping[r.id] = i - inputs.append(r.inputs) - next_token_choosers.append( - NextTokenChooser.from_pb(r.parameters, device, tokenizer) - ) - stopping_criteria = StoppingCriteria.from_pb( - r.stopping_parameters, tokenizer - ) - stopping_criterias.append(stopping_criteria) - max_truncation = max(max_truncation, r.truncate) - max_decode_tokens += stopping_criteria.max_new_tokens - padding_right_offset = max( - padding_right_offset, stopping_criteria.max_new_tokens - ) - - prompts = [] - for inp in inputs: - # Each input is encoded into a list, where each element of this input list is either a string or a URL - prompts.append(split(inp)) - - # The processor replaces the call to tokenizer, and - # a/ takes care of fetching images from the URL - # b/ generate the correct input_ids, attention_mask, pixel_values, image_attention_mask to feed to the model - tokenized_inputs = processor( - prompts, - return_tensors="pt", - padding=True, - truncation=True, - max_length=max_truncation, - add_end_of_utterance_token=False, # Already taken care of inside the prompts, so bypassing the processor's handling of this token - ).to(device) - for _ in pb.requests: - input_len = tokenized_inputs["input_ids"].shape[1] - prefix_offsets.append( - input_len - 5 - ) # To decode without potential fallbacks errors - read_offsets.append( - input_len - ) # To decode without potential fallbacks errors - - input_lengths = tokenized_inputs["attention_mask"].sum(1) - max_input_length = input_lengths.max() - - input_ids = tokenized_inputs["input_ids"] - pixel_values = tokenized_inputs.get("pixel_values", None) - image_hidden_states = None - # Allocate maximum 
attention_mask - attention_mask = input_ids.new_zeros( - (pb.size, max_input_length + padding_right_offset) - ) - # Copy tokenizer attention_mask into fully allocated attention_mask - attention_mask[:, :max_input_length] = tokenized_inputs["attention_mask"] - # Do the same for image_attention_mask - if pixel_values is None: - image_attention_mask = None - else: - image_attention_mask = input_ids.new_zeros( - ( - pb.size, - max_input_length + padding_right_offset, - pixel_values.size(1), - ) - ) - image_attention_mask[:, :max_input_length, :] = tokenized_inputs[ - "image_attention_mask" - ] - - position_ids = tokenized_inputs["attention_mask"].long().cumsum(-1) - 1 - position_ids.masked_fill_(tokenized_inputs["attention_mask"] == 0, 1) - all_input_ids = tokenized_inputs["input_ids"].T.split( - 1, dim=1 - ) # It's input_ids but splitted into a tuple of tensors where each tensor is (seq_len, 1) size. It is then transformed into a list - - max_tokens = len(inputs) * (max_input_length + max_decode_tokens) - - return cls( - batch_id=pb.id, - requests=pb.requests, - requests_idx_mapping=requests_idx_mapping, - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - pixel_values=pixel_values, - image_hidden_states=image_hidden_states, - image_attention_mask=image_attention_mask, - past_key_values=None, - all_input_ids=list(all_input_ids), - input_lengths=input_lengths.tolist(), - prefix_offsets=prefix_offsets, - read_offsets=read_offsets, - next_token_choosers=next_token_choosers, - stopping_criterias=stopping_criterias, - max_input_length=max_input_length.item(), - padding_right_offset=padding_right_offset, - max_tokens=max_tokens, - ) - - @tracer.start_as_current_span("filter") - def filter(self, request_ids: List[int]) -> Optional["IdeficsCausalLMBatch"]: - # It deletes requests from the batch. 
For instance when client lost connection - if len(request_ids) == 0: - raise ValueError("Batch must have at least one request") - if len(request_ids) == len(self): - return self - - keep_indices = [] - - # New values after filtering - requests_idx_mapping = {} - requests = [] - input_lengths = [] - prefix_offsets = [] - read_offsets = [] - all_input_ids = [] - max_input_length = 0 - - next_token_choosers = [] - stopping_criterias = [] - - total_remaining_decode_tokens = 0 - new_padding_right_offset = 0 - - for i, request_id in enumerate(request_ids): - idx = self.requests_idx_mapping[request_id] - requests_idx_mapping[request_id] = i - keep_indices.append(idx) - - requests.append(self.requests[idx]) - prefix_offsets.append(self.prefix_offsets[idx]) - read_offsets.append(self.read_offsets[idx]) - all_input_ids.append(self.all_input_ids[idx]) - - request_input_length = self.input_lengths[idx] - input_lengths.append(request_input_length) - max_input_length = max(max_input_length, request_input_length) - - next_token_choosers.append(self.next_token_choosers[idx]) - stopping_criteria = self.stopping_criterias[idx] - stopping_criterias.append(stopping_criteria) - remaining_decode_tokens = ( - stopping_criteria.max_new_tokens - stopping_criteria.current_tokens - ) - total_remaining_decode_tokens += remaining_decode_tokens - new_padding_right_offset = max( - new_padding_right_offset, remaining_decode_tokens - ) - - # Apply indices to input_ids, attention mask, past key values and other items that need to be cached - input_ids = self.input_ids[keep_indices] - position_ids = self.position_ids[keep_indices] - self.attention_mask = self.attention_mask[ - keep_indices, - -(self.padding_right_offset + max_input_length) : ( - self.attention_mask.shape[1] - self.padding_right_offset - ) - + new_padding_right_offset, - ] - # Do the same for pixel_values and image_attention_mask - pixel_values = self.pixel_values[keep_indices] - self.image_attention_mask = self.image_attention_mask[ - keep_indices, - -(self.padding_right_offset + max_input_length) : ( - self.image_attention_mask.shape[1] - self.padding_right_offset - ) - + new_padding_right_offset, - :, - ] - if self.image_hidden_states is None: - image_hidden_states = None - else: - image_hidden_states = self.image_hidden_states[keep_indices] - - # Ensure that past_key_values tensors can be updated in-place - if type(self.past_key_values[0]) == tuple: - self.past_key_values = [list(layer) for layer in self.past_key_values] - - # Update tensors in-place to allow incremental garbage collection - past_kv_length = max_input_length - 1 - for layer in self.past_key_values: - past_keys, past_values = layer - if len(past_keys.shape) == 3: - # Force past to be of dim [self_size, num_heads, ...] 
for easy indexing - past_keys = past_keys.view(len(self), -1, *past_keys.shape[-2:]) - past_values = past_values.view(len(self), -1, *past_values.shape[-2:]) - if self.keys_head_dim_last: - layer[0] = past_keys[keep_indices, :, -past_kv_length:, :] - else: - layer[0] = past_keys[keep_indices, :, :, -past_kv_length:] - del past_keys - layer[1] = past_values[keep_indices, :, -past_kv_length:, :] - del past_values - - max_tokens = len(request_ids) * max_input_length + total_remaining_decode_tokens - - self.requests = requests - self.requests_idx_mapping = requests_idx_mapping - self.input_ids = input_ids - self.pixel_values = pixel_values - self.image_hidden_states = image_hidden_states - self.position_ids = position_ids - self.all_input_ids = all_input_ids - self.input_lengths = input_lengths - self.prefix_offsets = prefix_offsets - self.read_offsets = read_offsets - self.next_token_choosers = next_token_choosers - self.stopping_criterias = stopping_criterias - self.max_input_length = max_input_length - self.padding_right_offset = new_padding_right_offset - self.max_tokens = max_tokens - - return self - - @classmethod - @tracer.start_as_current_span("concatenate") - def concatenate( - cls, batches: List["IdeficsCausalLMBatch"] - ) -> "IdeficsCausalLMBatch": - # It adds new requests to the batch - # Used for padding - total_batch_size = 0 - max_input_length = 0 - max_num_images = 0 - padding_right_offset = 0 - for batch in batches: - total_batch_size += len(batch) - max_input_length = max(max_input_length, batch.max_input_length) - max_num_images = max(max_num_images, batch.pixel_values.size(1)) - padding_right_offset = max(padding_right_offset, batch.padding_right_offset) - - # Batch attributes - requests = [] - requests_idx_mapping = {} - input_lengths = [] - prefix_offsets = [] - read_offsets = [] - all_input_ids = [] - next_token_choosers = [] - stopping_criterias = [] - max_tokens = 0 - - # Batch tensors - input_ids = None - attention_mask = None - position_ids = None - pixel_values = None - image_hidden_states = None - image_attention_mask = None - past_key_values = [] - - # Used for slicing correctly inside the tensors - # Equivalent to a cumsum on batch sizes - start_index = 0 - for i, batch in enumerate(batches): - requests.extend(batch.requests) - input_lengths.extend(batch.input_lengths) - prefix_offsets.extend(batch.prefix_offsets) - read_offsets.extend(batch.read_offsets) - all_input_ids.extend(batch.all_input_ids) - next_token_choosers.extend(batch.next_token_choosers) - stopping_criterias.extend(batch.stopping_criterias) - - if i == 0: - requests_idx_mapping = batch.requests_idx_mapping - else: - # We need to offset the mapping for each batch by the cumulative batch size - for k, v in batch.requests_idx_mapping.items(): - requests_idx_mapping[k] = v + start_index - - # Slicing end index for this batch - end_index = start_index + len(batch) - - # We only concatenate batches that did at least one step - if batch.past_key_values is None: - raise ValueError("only concatenate prefilled batches") - - # Create empty tensor - # input_ids is always of shape [batch_size, 1] - # We do not need to pad it - if input_ids is None: - input_ids = batch.input_ids.new_empty((total_batch_size, 1)) - # Copy to correct indices - input_ids[start_index:end_index] = batch.input_ids - - # Create padded tensor - if attention_mask is None: - attention_mask = batch.attention_mask.new_zeros( - (total_batch_size, max_input_length + padding_right_offset), - ) - - curr_batch_max_num_images = 
batch.pixel_values.size(1) - if pixel_values is None: - pixel_values = batch.pixel_values.new_zeros( - (total_batch_size, max_num_images, 3, 224, 224) - ) - pixel_values[start_index:end_index, :curr_batch_max_num_images] = ( - batch.pixel_values - ) - - if image_attention_mask is None: - image_attention_mask = batch.image_attention_mask.new_zeros( - ( - total_batch_size, - max_input_length + padding_right_offset, - max_num_images, - ) - ) - - # We need to slice the attention mask to remove padding from previous steps - # and to remove unused allocated space - left_offset = max_input_length - batch.max_input_length - batch_left_offset = ( - batch.attention_mask.shape[1] - - batch.max_input_length - - batch.padding_right_offset - ) - attention_mask[ - start_index:end_index, - left_offset:-padding_right_offset, - ] = batch.attention_mask[ - :, - batch_left_offset : -batch.padding_right_offset, - ] - image_attention_mask[ - start_index:end_index, - left_offset:-padding_right_offset, - :curr_batch_max_num_images, - ] = batch.image_attention_mask[ - :, batch_left_offset : -batch.padding_right_offset, : - ] - - # Create empty tensor - # position_ids is always of shape [batch_size, 1] - if position_ids is None: - position_ids = batch.position_ids.new_empty((total_batch_size, 1)) - position_ids[start_index:end_index] = batch.position_ids - - # Shenanigans to get dimensions because BLOOM outputs a past with a different shape - # BLOOM Keys: [batch_size * num_heads, head_dim, seq_length] - # BLOOM Values: [batch_size * num_heads, seq_length, head_dim] - # And ensure that we can update tensors in-place - if type(batch.past_key_values[0]) == tuple: - batch.past_key_values = [ - [t.view(len(batch), -1, *t.shape[-2:]) for t in layer] - for layer in batch.past_key_values - ] - elif len(batch.past_key_values[0][0].shape) == 3: - for layer in batch.past_key_values: - for k, t in enumerate(layer): - layer[k] = t.view(len(batch), -1, *t.shape[-2:]) - - # Add eventual padding tokens that were added while concatenating - max_tokens += batch.max_tokens + ( - max_input_length - batch.max_input_length - ) * len(batch) - - start_index = end_index - - first_past_kvs = batches[0].past_key_values - _, num_heads, padded_sequence_length, head_dim = first_past_kvs[0][1].shape - - padded_past_values_shape = ( - total_batch_size, - num_heads, - max_input_length - 1, - head_dim, - ) - - if batches[0].keys_head_dim_last: - padded_past_keys_shape = padded_past_values_shape - else: - # seq_length is last for BLOOM - padded_past_keys_shape = ( - total_batch_size, - num_heads, - head_dim, - max_input_length - 1, - ) - - # Iterate over attention layers - # Concatenate past key values layer by layer to allow incremental garbage collection - for j in range(len(first_past_kvs)): - padded_past_keys = first_past_kvs[j][0].new_zeros(padded_past_keys_shape) - start_index = 0 - for batch in batches: - past_keys = batch.past_key_values[j][0] - # Clear reference to the original tensor - batch.past_key_values[j][0] = None - - # Slicing end index for this batch - end_index = start_index + len(batch) - # We slice the keys to remove the padding from previous batches - past_seq_len = batch.max_input_length - 1 - if batch.keys_head_dim_last: - padded_past_keys[start_index:end_index, :, -past_seq_len:, :] = ( - past_keys[:, :, -past_seq_len:, :] - ) - else: - # BLOOM case - padded_past_keys[start_index:end_index, :, :, -past_seq_len:] = ( - past_keys[:, :, :, -past_seq_len:] - ) - del past_keys - - start_index = end_index - - padded_past_values = 
first_past_kvs[j][1].new_zeros( - padded_past_values_shape - ) - start_index = 0 - for batch in batches: - past_values = batch.past_key_values[j][1] - # Clear reference to the original tensor - batch.past_key_values[j][1] = None - - # Slicing end index for this batch - end_index = start_index + len(batch) - # We slice the past values to remove the padding from previous batches - past_seq_len = batch.max_input_length - 1 - padded_past_values[start_index:end_index, :, -past_seq_len:, :] = ( - past_values[:, :, -past_seq_len:, :] - ) - del past_values - - # Update values - start_index = end_index - - past_key_values.append([padded_past_keys, padded_past_values]) - - return cls( - batch_id=batches[0].batch_id, - requests=requests, - requests_idx_mapping=requests_idx_mapping, - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - pixel_values=pixel_values, - image_hidden_states=image_hidden_states, - image_attention_mask=image_attention_mask, - past_key_values=past_key_values, - all_input_ids=all_input_ids, - input_lengths=input_lengths, - prefix_offsets=prefix_offsets, - read_offsets=read_offsets, - next_token_choosers=next_token_choosers, - stopping_criterias=stopping_criterias, - max_input_length=max_input_length, - padding_right_offset=padding_right_offset, - keys_head_dim_last=batches[0].keys_head_dim_last, - max_tokens=max_tokens, - ) - - def __len__(self): - return len(self.requests) - - -class IdeficsCausalLM(Model): - def __init__( - self, - model_id: str, - revision: Optional[str] = None, - quantize: Optional[str] = None, - dtype: Optional[torch.dtype] = None, - trust_remote_code: bool = False, - ): - from text_generation_server.models.custom_modeling.idefics_modeling import ( - IdeficsForVisionText2Text, - ) - - if torch.cuda.is_available(): - device = torch.device("cuda") - dtype = torch.bfloat16 if dtype is None else dtype - else: - if quantize: - raise ValueError("quantization is not available on CPU") - - device = torch.device("cpu") - dtype = torch.float32 if dtype is None else dtype - - tokenizer = AutoTokenizer.from_pretrained( - model_id, - revision=revision, - padding_side="left", - truncation_side="left", - trust_remote_code=trust_remote_code, - ) - self.processor = AutoProcessor.from_pretrained( - model_id, - revision=revision, - padding_side="left", - truncation_side="left", - trust_remote_code=trust_remote_code, - ) - model = IdeficsForVisionText2Text.from_pretrained( - model_id, - revision=revision, - torch_dtype=dtype, - device_map=( - "auto" - if torch.cuda.is_available() and torch.cuda.device_count() > 1 - else None - ), - load_in_8bit=quantize == "bitsandbytes", - trust_remote_code=trust_remote_code, - ) - if torch.cuda.is_available() and torch.cuda.device_count() == 1: - model = model.cuda() - - if tokenizer.pad_token_id is None: - if model.config.pad_token_id is not None: - tokenizer.pad_token_id = model.config.pad_token_id - elif model.config.eos_token_id is not None: - tokenizer.pad_token_id = model.config.eos_token_id - elif tokenizer.eos_token_id is not None: - tokenizer.pad_token_id = tokenizer.eos_token_id - else: - tokenizer.add_special_tokens({"pad_token": ""}) - - super(IdeficsCausalLM, self).__init__( - model=model, - tokenizer=tokenizer, - requires_padding=True, - dtype=dtype, - device=device, - ) - - @property - def batch_type(self) -> Type[IdeficsCausalLMBatch]: - return IdeficsCausalLMBatch - - def forward( - self, - input_ids, - attention_mask, - position_ids, - pixel_values, - image_hidden_states, - 
image_attention_mask, - past_key_values: Optional = None, - ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]: - # Model Forward - kwargs = { - "input_ids": input_ids, - "attention_mask": attention_mask, - "pixel_values": pixel_values, - "image_hidden_states": image_hidden_states, - "image_attention_mask": image_attention_mask, - "past_key_values": past_key_values, - "use_cache": True, - "return_dict": True, - } - if self.has_position_ids: - kwargs["position_ids"] = position_ids - - outputs, speculative_logits = self.model.forward(**kwargs) - return ( - outputs.logits, - speculative_logits, - outputs.past_key_values, - outputs.image_hidden_states, - ) - - @tracer.start_as_current_span("generate_token") - def generate_token( - self, batch: IdeficsCausalLMBatch - ) -> Tuple[List[Generation], Optional[IdeficsCausalLMBatch], Tuple[int, int]]: - start = time.time_ns() - # slice the attention mask to the correct shape - attention_mask = batch.attention_mask[:, : -batch.padding_right_offset] - if batch.image_attention_mask is None: - image_attention_mask = None - else: - if batch.input_ids.size(1) == 1: - # THIS is a hack: when calling idefics.generate, the first time, we need the whole image_attention_mask (size bs x max_seq_len x max_num_images), - # but the subsequent times, we only need the last attention mask along the `max_seq_len` dimension - # this is due to the nature IDEFICS: it's an encoder decoder, and so when decoding, only the currently generated - # token need to attend to the encoder hidden states (i.e. the vision encoder) - # Also see seq2seq_lm.Seq2SeqLM.generate_token which has roughly the same logic - image_attention_mask = batch.image_attention_mask[ - :, -(batch.padding_right_offset + 1) - ].unsqueeze(1) - else: - image_attention_mask = batch.image_attention_mask[ - :, : -batch.padding_right_offset - ] - - logits, speculative_logits, past, image_hidden_states = self.forward( - input_ids=batch.input_ids, - attention_mask=attention_mask, - position_ids=batch.position_ids, - pixel_values=batch.pixel_values, - image_hidden_states=batch.image_hidden_states, - image_attention_mask=image_attention_mask, - past_key_values=batch.past_key_values, - ) - # Hardcoded remove image tokens - logits[:, 32000:32001] = torch.finfo(logits.dtype).min - - start_decode = time.time_ns() - - # Results - generations: List[Generation] = [] - stopped = True - - # Zipped iterator - iterator = zip( - batch.requests, - batch.input_lengths, - batch.prefix_offsets, - batch.read_offsets, - logits, - batch.next_token_choosers, - batch.stopping_criterias, - batch.all_input_ids, - ) - - # For each member of the batch - for i, ( - request, - input_length, - prefix_offset, - read_offset, - logits, - next_token_chooser, - stopping_criteria, - all_input_ids, - ) in enumerate(iterator): - # Select next token - next_token_id, logprobs = next_token_chooser( - all_input_ids.view(1, -1), logits[-1:, :] - ) - - # Append next token to all tokens - all_input_ids = torch.cat([all_input_ids, next_token_id]) - new_input_length = input_length + 1 - - # Generated token - next_token_logprob = logprobs[-1, next_token_id] - next_token_id_squeezed = next_token_id.squeeze() - next_token_text, prefix_offset, read_offset = self.decode_token( - all_input_ids[:, 0], prefix_offset, read_offset - ) - - # Evaluate stopping criteria - stop, reason = stopping_criteria( - next_token_id_squeezed, - next_token_text, - ) - - if not stop: - stopped = False - - # Shard generations - # All generations will be appended in the rust 
sharded client - if i % self.world_size == self.rank: - if stop: - # Decode generated tokens - output_text, _, _ = self.decode_token( - all_input_ids[:, 0], - prefix_offset=len(all_input_ids) - - stopping_criteria.current_tokens - - 1, - read_offset=len(all_input_ids) - - stopping_criteria.current_tokens, - skip_special_tokens=True, - ) - # Get seed - if isinstance(next_token_chooser.choice, Sampling): - seed = next_token_chooser.choice.seed - else: - seed = None - - generated_text = GeneratedText( - output_text, stopping_criteria.current_tokens, reason, seed - ) - else: - generated_text = None - - # Prefill - if stopping_criteria.current_tokens == 1 and request.prefill_logprobs: - # Remove generated token to only have prefill and add nan for first prompt token - prefill_logprobs = [float("nan")] + torch.log_softmax( - logits, -1 - ).gather(1, all_input_ids[1:]).squeeze(1)[ - -new_input_length:-1 - ].tolist() - prefill_token_ids = all_input_ids[-new_input_length:-1] - prefill_texts = self.tokenizer.batch_decode( - prefill_token_ids, - clean_up_tokenization_spaces=False, - skip_special_tokens=False, - ) - prefill_tokens = Tokens( - prefill_token_ids, - prefill_logprobs, - prefill_texts, - is_special=[], - ) - else: - prefill_tokens = None - - top_tokens = None - - generation = Generation( - request.id, - prefill_tokens, - Tokens( - [next_token_id_squeezed], - [next_token_logprob], - [next_token_text], - [next_token_id_squeezed.item() in self.all_special_ids], - ), - generated_text, - top_tokens, - ) - - generations.append(generation) - - # Update values - batch.next_token_choosers[i] = batch.next_token_choosers[i].advance_grammar( - next_token_id_squeezed.item() - ) - batch.input_ids[i, 0] = next_token_id - batch.all_input_ids[i] = all_input_ids - batch.input_lengths[i] = new_input_length - batch.prefix_offsets[i] = prefix_offset - batch.read_offsets[i] = read_offset - batch.max_input_length = max(batch.max_input_length, new_input_length) - - # We finished all generations in the batch; there is no next batch - if stopped: - forward_ns = start_decode - start - decode_ns = time.time_ns() - start_decode - return generations, None, (forward_ns, decode_ns) - - # Slice unused values from prefill - batch.input_ids = batch.input_ids[:, :1] - - # Update attention_mask as we added a new token to input_ids - batch.attention_mask[:, -batch.padding_right_offset] = 1 - batch.image_attention_mask[:, -batch.padding_right_offset, :] = ( - batch.image_attention_mask[:, -(batch.padding_right_offset + 1), :] - ) - # Decrease right offset - batch.padding_right_offset -= 1 - - # Update position_ids - batch.position_ids = batch.position_ids[:, -1:] + 1 - - # Update past key values - batch.past_key_values = past - batch.image_hidden_states = image_hidden_states - - forward_ns = start_decode - start - decode_ns = time.time_ns() - start_decode - return generations, batch, (forward_ns, decode_ns) - - -import time - -from dataclasses import dataclass -from opentelemetry import trace -from transformers import ( - AutoProcessor, - AutoTokenizer, - PreTrainedTokenizerBase, - ProcessorMixin, -) -from typing import Optional, Tuple, List, Type, Dict - -from text_generation_server.models import Model -from text_generation_server.models.types import ( - Batch, - Tokens, - Generation, - GeneratedText, -) -from text_generation_server.pb import generate_pb2 -from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling - -import re - -tracer = trace.get_tracer(__name__) - - -@dataclass -class 
IdeficsCausalLMBatch(Batch): - batch_id: int - requests: List[generate_pb2.Request] - requests_idx_mapping: Dict[int, int] - - # Decoder values - input_ids: torch.Tensor - attention_mask: torch.Tensor - position_ids: torch.Tensor - pixel_values: Optional[torch.Tensor] - image_hidden_states: Optional[torch.Tensor] - image_attention_mask: Optional[torch.Tensor] - past_key_values: Optional[List[Tuple]] - - # All tokens - all_input_ids: List[torch.Tensor] - - # Lengths of all generations present in the batch - input_lengths: List[int] - prefix_offsets: List[int] - read_offsets: List[int] - - # Generation helpers - next_token_choosers: List[NextTokenChooser] - stopping_criterias: List[StoppingCriteria] - - # Metadata used for padding - max_input_length: int - padding_right_offset: int - - # Maximum number of tokens this batch will grow to - max_tokens: int - - # Past metadata - keys_head_dim_last: bool = True - - def to_pb(self) -> generate_pb2.CachedBatch: - return generate_pb2.CachedBatch( - id=self.batch_id, - request_ids=[r.id for r in self.requests], - size=len(self), - max_tokens=self.max_tokens, - ) + raise NotImplementedError @classmethod def from_pb_processor( @@ -932,6 +92,7 @@ class IdeficsCausalLMBatch(Batch): pb: generate_pb2.Batch, tokenizer: PreTrainedTokenizerBase, processor: ProcessorMixin, # Hack + config, dtype: torch.dtype, device: torch.device, ) -> "IdeficsCausalLMBatch": @@ -966,7 +127,10 @@ class IdeficsCausalLMBatch(Batch): prompts = [] for inp in inputs: # Each input is encoded into a list, where each element of this input list is either a string or a URL - prompts.append(split(inp)) + prompt = [] + for chunk in split(inp): + prompt.append(chunk["content"]) + prompts.append(prompt) # The processor replaces the call to tokenizer, and # a/ takes care of fetching images from the URL diff --git a/server/text_generation_server/models/vlm_causal_lm.py b/server/text_generation_server/models/vlm_causal_lm.py index 4ffce225..c965bea8 100644 --- a/server/text_generation_server/models/vlm_causal_lm.py +++ b/server/text_generation_server/models/vlm_causal_lm.py @@ -1,5 +1,8 @@ import re import torch +from PIL import Image +from io import BytesIO +import base64 from opentelemetry import trace from typing import Optional, Tuple, List, Type, Dict @@ -92,6 +95,13 @@ def get_number_of_features(height: int, width: int, config) -> int: return 2634 +def load_data_uri(image_uri: str) -> Image.Image: + image_uri = image_uri.split(",")[-1] + content = base64.b64decode(image_uri) + image = Image.open(BytesIO(content)) + return image + + # assert get_number_of_features(889, 1024) == 2634, f"{get_number_of_features(889, 1024)}" # assert get_number_of_features(640, 640) == 2928 @@ -100,6 +110,21 @@ class VlmCausalLMBatch(FlashMistralBatch): pixel_values: Optional[List[torch.Tensor]] image_sizes: Optional[List[Tuple[int, int]]] + @classmethod + @tracer.start_as_current_span("concatenate") + def concatenate(cls, batches): + batch = super(VlmCausalLMBatch, cls).concatenate(batches) + batch.pixel_values = None + batch.image_sizes = None + return batch + + @tracer.start_as_current_span("filter") + def filter(self, request_ids: List[int]): + batch = super().filter(request_ids) + batch.pixel_values = None + batch.image_sizes = None + return batch + @classmethod def batch_tokenized_inputs(cls, requests, tokenizer, processor, config): batch_inputs = [] @@ -115,6 +140,12 @@ class VlmCausalLMBatch(FlashMistralBatch): image = chunk["content"] if image.startswith("https://") or image.startswith("http://"): 
image = processor.image_processor.fetch_images(image) + elif image.startswith("data:"): + image = load_data_uri(image) + else: + raise RuntimeError( + "Cannot process input image not starting with http(s):// nor data:" + ) image_input = processor.image_processor(image, return_tensors="pt") height, width = image_input["image_sizes"][0] num_features = get_number_of_features(height, width, config)
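
Note (not part of the patch): the new `data:` branch in `VlmCausalLMBatch.batch_tokenized_inputs` relies on the `load_data_uri` helper added above in vlm_causal_lm.py. Below is a minimal, self-contained sketch of how a client-side round trip through that path could look. The `load_data_uri` body is copied from the patch; `to_data_uri` and the file name "chicken_on_money.png" are hypothetical illustrations, assuming only that Pillow is installed.

    # Illustrative sketch of the data-URI image path added by this patch.
    import base64
    from io import BytesIO

    from PIL import Image


    def load_data_uri(image_uri: str) -> Image.Image:
        # Same helper as added in vlm_causal_lm.py: drop the "data:<mime>;base64,"
        # prefix, decode the base64 payload and open it as a PIL image.
        image_uri = image_uri.split(",")[-1]
        content = base64.b64decode(image_uri)
        return Image.open(BytesIO(content))


    def to_data_uri(path: str) -> str:
        # Hypothetical client-side helper: encode a local PNG as the data URI that
        # the integration tests embed in the "![](...)" image syntax of the prompt.
        with open(path, "rb") as f:
            encoded = base64.b64encode(f.read()).decode("utf-8")
        return f"data:image/png;base64,{encoded}"


    if __name__ == "__main__":
        uri = to_data_uri("chicken_on_money.png")  # assumed local test image
        image = load_data_uri(uri)
        print(image.size)  # decoded image matches the original file's dimensions

Any image reference that is neither http(s) nor a data URI now raises a RuntimeError, as shown in the hunk above.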