| 
									
										
										
										
											2024-10-02 09:22:13 +00:00
import torch

import numpy as np

from typing import Iterable, Optional, Tuple, List, Dict
from text_generation_server.pb.generate_pb2 import Request
from io import BytesIO
from PIL import Image
from dataclasses import dataclass, field
from opentelemetry import trace
from transformers import (
    PreTrainedTokenizerBase,
)

from text_generation_server.models.vlm_causal_lm import VlmCausalLMBatch, VlmCausalLM
from text_generation_server.pb import generate_pb2
from text_generation_server.models.globals import PREFIX_CACHING, ATTENTION
from text_generation_server.layers.attention import Seqlen
from text_generation_server.models.metadata_kernels import block_tables_to_ragged
					
						
							| 
									
										
										
										
											2024-10-02 09:22:13 +00:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
# Module-level OpenTelemetry tracer; the span decorators on
# MllamaCausalLMBatch.concatenate and MllamaCausalLMBatch.filter use it.
tracer = trace.get_tracer(__name__)
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | @dataclass | 
					
						
							|  |  |  | class MllamaCausalLMBatch(VlmCausalLMBatch): | 
					
						
							|  |  |  |     image_indices: List[int] = 42 | 
					
						
							|  |  |  |     aspect_ratio_ids: Optional[torch.Tensor] = None | 
					
						
							|  |  |  |     aspect_ratio_mask: Optional[torch.Tensor] = None | 
					
						
							|  |  |  |     cross_attention_states: Optional[torch.Tensor] = None | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     @classmethod | 
					
						
							|  |  |  |     @tracer.start_as_current_span("concatenate") | 
					
						
							|  |  |  |     def concatenate(cls, batches): | 
					
						
							|  |  |  |         batch = super().concatenate(batches) | 
					
						
							|  |  |  |         batch.pixel_values = None | 
					
						
							|  |  |  |         batch.pixel_attention_mask = None | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         offset = 0 | 
					
						
							|  |  |  |         image_indices = [] | 
					
						
							|  |  |  |         attention_states = [] | 
					
						
							|  |  |  |         for b in batches: | 
					
						
							|  |  |  |             if b.cross_attention_states is not None: | 
					
						
							|  |  |  |                 attention_states.append(b.cross_attention_states) | 
					
						
							|  |  |  |             image_indices.extend([i + offset for i in b.image_indices]) | 
					
						
							|  |  |  |             offset += len(b.image_indices) | 
					
						
							|  |  |  |         if len(attention_states) > 0: | 
					
						
							|  |  |  |             assert len(image_indices) > 0 | 
					
						
							|  |  |  |             batch.cross_attention_states = torch.cat(attention_states, dim=0) | 
					
						
							|  |  |  |             batch.image_indices = image_indices | 
					
						
							|  |  |  |         else: | 
					
						
							|  |  |  |             batch.cross_attention_states = None | 
					
						
							|  |  |  |             batch.image_indices = [] | 
					
						
							|  |  |  |         return batch | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     @tracer.start_as_current_span("filter") | 
					
						
							|  |  |  |     def filter(self, request_ids: List[int]): | 
					
						
							|  |  |  |         assert self.image_indices is not None | 
					
						
							|  |  |  |         batch = super().filter(request_ids) | 
					
						
							|  |  |  |         assert self.image_indices is not None | 
					
						
							|  |  |  |         indices = [] | 
					
						
							|  |  |  |         for i, request_id in enumerate(request_ids): | 
					
						
							|  |  |  |             idx = self.requests_idx_mapping[request_id] | 
					
						
							|  |  |  |             indices.append(idx) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         offset = 0 | 
					
						
							|  |  |  |         new_image_indices = [] | 
					
						
							|  |  |  |         prev_i = None | 
					
						
							|  |  |  |         for i in self.image_indices: | 
					
						
							|  |  |  |             if i in indices: | 
					
						
							|  |  |  |                 new_image_indices.append(offset) | 
					
						
							|  |  |  |                 if i != prev_i: | 
					
						
							|  |  |  |                     offset += 1 | 
					
						
							|  |  |  |                 prev_i = i | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         batch.image_indices = new_image_indices | 
					
						
							|  |  |  |         if len(new_image_indices) > 0: | 
					
						
							|  |  |  |             assert max(new_image_indices) < self.cross_attention_states.shape[0] | 
					
						
							|  |  |  |             assert offset <= self.cross_attention_states.shape[0] | 
					
						
							|  |  |  |             batch.cross_attention_states = self.cross_attention_states[ | 
					
						
							|  |  |  |                 new_image_indices | 
					
						
							|  |  |  |             ] | 
					
						
							|  |  |  |         else: | 
					
						
							|  |  |  |             batch.cross_attention_states = None | 
					
						
							|  |  |  |         return batch | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     @classmethod | 
					
						
							|  |  |  |     def batch_tokenized_inputs( | 
					
						
							|  |  |  |         cls, requests: Iterable[Request], tokenizer, processor, config | 
					
						
							|  |  |  |     ): | 
					
						
							|  |  |  |         image_inputs = [] | 
					
						
							|  |  |  |         texts = [] | 
					
						
							|  |  |  |         image_indices = [] | 
					
						
							|  |  |  |         batch_tokenized_inputs = [] | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         for i, r in enumerate(requests): | 
					
						
							|  |  |  |             # Each input is encoded into a list, where each element of this input list is either a string or a URL | 
					
						
							|  |  |  |             curr_text = "" | 
					
						
							|  |  |  |             curr_image = None | 
					
						
							|  |  |  |             curr_i = None | 
					
						
							|  |  |  |             for chunk in r.input_chunks.chunks: | 
					
						
							|  |  |  |                 chunk_type = chunk.WhichOneof("chunk") | 
					
						
							|  |  |  |                 if chunk_type == "text": | 
					
						
							|  |  |  |                     curr_text += chunk.text | 
					
						
							|  |  |  |                 elif chunk_type == "image": | 
					
						
							|  |  |  |                     image = Image.open(BytesIO(chunk.image.data)) | 
					
						
							|  |  |  |                     # TODO unsure about BOS | 
					
						
							|  |  |  |                     curr_text += "<|image|>" | 
					
						
							|  |  |  |                     image_input = processor.image_processor(image, return_tensors="pt") | 
					
						
							|  |  |  |                     curr_image = image_input | 
					
						
							|  |  |  |                     curr_i = i | 
					
						
							|  |  |  |                     # image_inputs.append(image_input) | 
					
						
							|  |  |  |                     # image_indices.append(i) | 
					
						
							|  |  |  |                 else: | 
					
						
							|  |  |  |                     raise RuntimeError(f"Invalid chunk type {chunk_type}") | 
					
						
							|  |  |  |             texts.append(curr_text) | 
					
						
							|  |  |  |             if curr_image is not None: | 
					
						
							|  |  |  |                 image_inputs.append(curr_image) | 
					
						
							|  |  |  |                 image_indices.append(curr_i) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             input_ids = tokenizer( | 
					
						
							|  |  |  |                 curr_text, | 
					
						
							|  |  |  |                 truncation=True, | 
					
						
							|  |  |  |                 max_length=r.truncate, | 
					
						
							|  |  |  |                 add_special_tokens=r.add_special_tokens, | 
					
						
							|  |  |  |             )["input_ids"] | 
					
						
							|  |  |  |             batch_tokenized_inputs.append(input_ids) | 
					
						
							|  |  |  |         if image_inputs: | 
					
						
							|  |  |  |             image_input = image_inputs[0] | 
					
						
							|  |  |  |             new_image_inputs = { | 
					
						
							|  |  |  |                 "pixel_values": torch.cat( | 
					
						
							|  |  |  |                     [img["pixel_values"] for img in image_inputs], dim=0 | 
					
						
							|  |  |  |                 ), | 
					
						
							|  |  |  |             } | 
					
						
							|  |  |  |             if "aspect_ratio_ids" in image_input: | 
					
						
							|  |  |  |                 new_image_inputs["aspect_ratio_ids"] = torch.cat( | 
					
						
							|  |  |  |                     [img["aspect_ratio_ids"] for img in image_inputs], dim=0 | 
					
						
							|  |  |  |                 ) | 
					
						
							|  |  |  |             if "aspect_ratio_mask" in image_input: | 
					
						
							|  |  |  |                 new_image_inputs["aspect_ratio_mask"] = torch.cat( | 
					
						
							|  |  |  |                     [img["aspect_ratio_mask"] for img in image_inputs], dim=0 | 
					
						
							|  |  |  |                 ) | 
					
						
							|  |  |  |             image_inputs = new_image_inputs | 
					
						
							|  |  |  |             image_inputs["image_indices"] = image_indices | 
					
						
							|  |  |  |         else: | 
					
						
							|  |  |  |             image_inputs = None | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         if image_inputs is not None: | 
					
						
							|  |  |  |             assert len(image_indices) == image_inputs["pixel_values"].shape[0] | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         return batch_tokenized_inputs, image_inputs | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     @classmethod | 
					
						
							|  |  |  |     def from_pb_processor( | 
					
						
							|  |  |  |         cls, | 
					
						
							|  |  |  |         pb: generate_pb2.Batch, | 
					
						
							|  |  |  |         tokenizer: PreTrainedTokenizerBase, | 
					
						
							|  |  |  |         processor, | 
					
						
							|  |  |  |         config, | 
					
						
							|  |  |  |         dtype: torch.dtype, | 
					
						
							|  |  |  |         device: torch.device, | 
					
						
							|  |  |  |     ) -> "VlmCausalLMBatch": | 
					
						
							|  |  |  |         batch_tokenized_inputs, image_inputs = cls.batch_tokenized_inputs( | 
					
						
							|  |  |  |             pb.requests, tokenizer, processor, config | 
					
						
							|  |  |  |         ) | 
					
						
							|  |  |  |         batch = cls.from_tokenized(pb, tokenizer, batch_tokenized_inputs, dtype, device) | 
					
						
							|  |  |  |         # XXX: <|image|> token is actually out of bounds and bugs out the logit processors. | 
					
						
							|  |  |  |         batch.all_input_ids_tensor = batch.all_input_ids_tensor.clamp( | 
					
						
							|  |  |  |             max=config.text_config.vocab_size - 1 | 
					
						
							|  |  |  |         ) | 
					
						
							| 
									
										
										
										
											2024-10-16 10:49:33 +00:00
										 |  |  |         if isinstance(batch.input_ids, list): | 
					
						
							|  |  |  |             if len(batch) > 1: | 
					
						
							|  |  |  |                 input_ids = np.concatenate(batch.input_ids, dtype=np.int64) | 
					
						
							|  |  |  |             else: | 
					
						
							|  |  |  |                 input_ids = batch.input_ids[0] | 
					
						
							|  |  |  |             batch.input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-10-02 09:22:13 +00:00
										 |  |  |         batch.input_ids = batch.input_ids.clamp(max=config.text_config.vocab_size - 1) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         if image_inputs is not None: | 
					
						
							|  |  |  |             batch.pixel_values = image_inputs["pixel_values"].to( | 
					
						
							|  |  |  |                 device=device, dtype=dtype | 
					
						
							|  |  |  |             ) | 
					
						
							|  |  |  |             batch.aspect_ratio_ids = image_inputs["aspect_ratio_ids"].to(device=device) | 
					
						
							|  |  |  |             batch.aspect_ratio_mask = image_inputs["aspect_ratio_mask"].to( | 
					
						
							|  |  |  |                 device=device | 
					
						
							|  |  |  |             ) | 
					
						
							|  |  |  |             batch.image_indices = image_inputs["image_indices"] | 
					
						
							|  |  |  |         else: | 
					
						
							|  |  |  |             batch.pixel_values = None | 
					
						
							|  |  |  |             batch.aspect_ratio_ids = None | 
					
						
							|  |  |  |             batch.aspect_ratio_mask = None | 
					
						
							|  |  |  |             batch.image_indices = [] | 
					
						
							|  |  |  |         assert batch.image_indices is not None | 
					
						
							|  |  |  |         return batch | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | class MllamaCausalLM(VlmCausalLM): | 
					
						
							|  |  |  |     def forward( | 
					
						
							|  |  |  |         self, | 
					
						
							| 
									
										
										
										
											2024-10-16 10:49:33 +00:00
										 |  |  |         batch: MllamaCausalLMBatch, | 
					
						
							| 
									
										
										
										
											2024-10-02 09:22:13 +00:00
										 |  |  |         adapter_data: Optional[Dict[str, torch.Tensor]] = None, | 
					
						
							|  |  |  |     ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: | 
					
						
							|  |  |  |         # Model Forward | 
					
						
							|  |  |  |         if batch.speculative_ids is not None: | 
					
						
							|  |  |  |             input_ids = batch.input_ids | 
					
						
							|  |  |  |             position_ids = batch.position_ids | 
					
						
							|  |  |  |             cu_seqlen_prefill = batch.cu_seqlen_prefill | 
					
						
							|  |  |  |             kv_cache = self.kv_cache | 
					
						
							|  |  |  |             block_tables = batch.block_tables_tensor | 
					
						
							|  |  |  |             slots = batch.slots[batch.slot_indices] | 
					
						
							|  |  |  |             input_lengths = batch.input_lengths_tensor | 
					
						
							| 
									
										
										
										
											2024-10-16 10:49:33 +00:00
										 |  |  |             max_s = batch.max_current_length | 
					
						
							| 
									
										
										
										
											2024-10-02 09:22:13 +00:00
										 |  |  |             lm_head_indices = batch.prefill_head_indices | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             speculative_ids = batch.speculative_ids | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             B, speculative_length = speculative_ids.shape | 
					
						
							|  |  |  |             new_length = speculative_length + 1 | 
					
						
							|  |  |  |             new_input_ids = torch.cat( | 
					
						
							|  |  |  |                 [input_ids.unsqueeze(-1), speculative_ids], dim=1 | 
					
						
							|  |  |  |             ).reshape(-1) | 
					
						
							|  |  |  |             arange = torch.arange(new_length, device=position_ids.device).unsqueeze(0) | 
					
						
							|  |  |  |             arange_int = arange.to(dtype=torch.int32) | 
					
						
							|  |  |  |             new_position_ids = ( | 
					
						
							|  |  |  |                 position_ids.unsqueeze(-1).expand(B, new_length) + arange | 
					
						
							|  |  |  |             ).view(-1) | 
					
						
							|  |  |  |             slots = (slots.unsqueeze(-1).expand(B, new_length) + arange_int).view(-1) | 
					
						
							|  |  |  |             input_lengths = ( | 
					
						
							|  |  |  |                 input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int | 
					
						
							|  |  |  |             ).view(-1) | 
					
						
							| 
									
										
										
										
											2024-10-16 10:49:33 +00:00
										 |  |  |             cache_lengths_tensor = ( | 
					
						
							|  |  |  |                 batch.cache_lengths_tensor.unsqueeze(-1).expand(B, new_length) | 
					
						
							| 
									
										
										
										
											2024-10-02 09:22:13 +00:00
										 |  |  |             ).reshape(-1) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             # Add Copy the block tables for all members | 
					
						
							|  |  |  |             block_tables = ( | 
					
						
							|  |  |  |                 block_tables.unsqueeze(1) | 
					
						
							|  |  |  |                 .expand(B, new_length, -1) | 
					
						
							|  |  |  |                 .reshape(B * new_length, -1) | 
					
						
							|  |  |  |                 .contiguous() | 
					
						
							|  |  |  |             ) | 
					
						
							|  |  |  |             max_s = max_s + speculative_length | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             input_ids = new_input_ids | 
					
						
							|  |  |  |             position_ids = new_position_ids | 
					
						
							|  |  |  |         else: | 
					
						
							|  |  |  |             input_ids = batch.input_ids | 
					
						
							|  |  |  |             position_ids = batch.position_ids | 
					
						
							|  |  |  |             cu_seqlen_prefill = batch.cu_seqlen_prefill | 
					
						
							|  |  |  |             kv_cache = self.kv_cache | 
					
						
							|  |  |  |             block_tables = batch.block_tables_tensor | 
					
						
							|  |  |  |             slots = batch.slots[batch.slot_indices] | 
					
						
							|  |  |  |             input_lengths = batch.input_lengths_tensor | 
					
						
							| 
									
										
										
										
											2024-10-16 10:49:33 +00:00
										 |  |  |             cache_lengths_tensor = batch.cache_lengths_tensor | 
					
						
							|  |  |  |             max_s = batch.max_current_length | 
					
						
							| 
									
										
										
										
											2024-10-02 09:22:13 +00:00
										 |  |  |             lm_head_indices = batch.prefill_head_indices | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         # Try to find an associated cuda graph | 
					
						
							|  |  |  |         bs = input_ids.shape[0] | 
					
						
							|  |  |  |         sorted_padded_bs = sorted([k for k in self.cuda_graphs.keys() if k >= bs]) | 
					
						
							|  |  |  |         if sorted_padded_bs: | 
					
						
							|  |  |  |             # Get associated cuda graph | 
					
						
							|  |  |  |             cuda_graph = self.cuda_graphs[sorted_padded_bs[0]] | 
					
						
							|  |  |  |         else: | 
					
						
							|  |  |  |             cuda_graph = None | 
					
						
							|  |  |  |         if ( | 
					
						
							|  |  |  |             cu_seqlen_prefill is not None | 
					
						
							|  |  |  |             or cuda_graph is None | 
					
						
							|  |  |  |             # Only run cuda graphs when there's no images. | 
					
						
							|  |  |  |             or batch.cross_attention_states is not None | 
					
						
							|  |  |  |         ): | 
					
						
							|  |  |  |             if PREFIX_CACHING: | 
					
						
							|  |  |  |                 block_tables = block_tables_to_ragged( | 
					
						
							|  |  |  |                     block_tables=block_tables, | 
					
						
							|  |  |  |                     input_lengths=batch.input_lengths, | 
					
						
							| 
									
										
										
										
											2024-10-16 10:49:33 +00:00
										 |  |  |                     cache_lengths=batch.cache_lengths, | 
					
						
							| 
									
										
										
										
											2024-10-25 21:10:00 +00:00
										 |  |  |                     input_lengths_tensor=batch.input_lengths_tensor, | 
					
						
							|  |  |  |                     cache_lengths_tensor=batch.cache_lengths_tensor, | 
					
						
							|  |  |  |                     max_current_length=batch.max_current_length, | 
					
						
							| 
									
										
										
										
											2024-10-02 09:22:13 +00:00
										 |  |  |                 ) | 
					
						
							|  |  |  |             with self._forward_context( | 
					
						
							|  |  |  |                 block_tables=block_tables, | 
					
						
							|  |  |  |                 cu_seqlen_prefill=cu_seqlen_prefill, | 
					
						
							|  |  |  |                 input_lengths_tensor=input_lengths, | 
					
						
							| 
									
										
										
										
											2024-10-16 10:49:33 +00:00
										 |  |  |                 cache_lengths_tensor=cache_lengths_tensor, | 
					
						
							| 
									
										
										
										
											2024-10-02 09:22:13 +00:00
										 |  |  |             ): | 
					
						
							|  |  |  |                 seqlen = Seqlen( | 
					
						
							|  |  |  |                     input_lengths=input_lengths, | 
					
						
							| 
									
										
										
										
											2024-10-16 10:49:33 +00:00
										 |  |  |                     cache_lengths=cache_lengths_tensor, | 
					
						
							| 
									
										
										
										
											2024-10-02 09:22:13 +00:00
										 |  |  |                     cu_seqlen_q=cu_seqlen_prefill, | 
					
						
							| 
									
										
										
										
											2024-10-16 10:49:33 +00:00
										 |  |  |                     max_q=batch.max_input_length, | 
					
						
							|  |  |  |                     max_k=batch.max_current_length, | 
					
						
							| 
									
										
										
										
											2024-10-02 09:22:13 +00:00
										 |  |  |                 ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |                 if batch.pixel_values is not None: | 
					
						
							|  |  |  |                     cross_attention_states = self.model.vision_forward( | 
					
						
							|  |  |  |                         pixel_values=batch.pixel_values, | 
					
						
							|  |  |  |                         aspect_ratio_ids=batch.aspect_ratio_ids, | 
					
						
							|  |  |  |                         aspect_ratio_mask=batch.aspect_ratio_mask, | 
					
						
							|  |  |  |                     ) | 
					
						
							|  |  |  |                     batch.cross_attention_states = cross_attention_states | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |                 cross_attention_states = batch.cross_attention_states | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |                 logits, speculative_logits = self.model.forward( | 
					
						
							|  |  |  |                     input_ids=input_ids, | 
					
						
							|  |  |  |                     position_ids=position_ids, | 
					
						
							|  |  |  |                     cu_seqlen_prefill=cu_seqlen_prefill, | 
					
						
							|  |  |  |                     kv_cache=kv_cache, | 
					
						
							|  |  |  |                     block_tables=block_tables, | 
					
						
							|  |  |  |                     slots=slots, | 
					
						
							|  |  |  |                     seqlen=seqlen, | 
					
						
							|  |  |  |                     max_s=max_s, | 
					
						
							|  |  |  |                     prefill_cache_indices=batch.prefill_cache_indices, | 
					
						
							|  |  |  |                     lm_head_indices=lm_head_indices, | 
					
						
							|  |  |  |                     cross_attention_states=cross_attention_states, | 
					
						
							|  |  |  |                     adapter_data=adapter_data, | 
					
						
							|  |  |  |                     image_indices=batch.image_indices[:], | 
					
						
							|  |  |  |                 ) | 
					
						
							|  |  |  |                 if batch.prefill_cache_indices is not None: | 
					
						
							|  |  |  |                     batch.prefill_cache_indices = None | 
					
						
							|  |  |  |                 if batch.pixel_values is not None: | 
					
						
							|  |  |  |                     batch.pixel_values = None | 
					
						
							|  |  |  |                 return logits, speculative_logits | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         # Copy inputs to the static inputs of the cuda graph | 
					
						
							|  |  |  |         # Static inputs are potentially padded | 
					
						
							|  |  |  |         cuda_graph["input_ids"][: input_ids.shape[0]] = input_ids | 
					
						
							|  |  |  |         cuda_graph["position_ids"][: position_ids.shape[0]] = position_ids | 
					
						
							|  |  |  |         if ATTENTION == "flashinfer": | 
					
						
							|  |  |  |             block_tables = block_tables_to_ragged( | 
					
						
							|  |  |  |                 block_tables=block_tables, | 
					
						
							|  |  |  |                 input_lengths=batch.input_lengths, | 
					
						
							| 
									
										
										
										
											2024-10-16 10:49:33 +00:00
										 |  |  |                 cache_lengths=batch.cache_lengths, | 
					
						
							| 
									
										
										
										
											2024-10-25 21:10:00 +00:00
										 |  |  |                 input_lengths_tensor=batch.input_lengths_tensor, | 
					
						
							|  |  |  |                 cache_lengths_tensor=batch.cache_lengths_tensor, | 
					
						
							|  |  |  |                 max_current_length=batch.max_current_length, | 
					
						
							| 
									
										
										
										
											2024-10-02 09:22:13 +00:00
										 |  |  |             ) | 
					
						
							|  |  |  |             cuda_graph["block_tables"][: block_tables.shape[0]] = block_tables | 
					
						
							|  |  |  |         else: | 
					
						
							|  |  |  |             cuda_graph["block_tables"][ | 
					
						
							|  |  |  |                 : block_tables.shape[0], : block_tables.shape[1] | 
					
						
							|  |  |  |             ] = block_tables | 
					
						
							| 
									
										
										
										
											2024-10-16 10:49:33 +00:00
										 |  |  | 
 | 
					
						
							|  |  |  |         # XXX: This is working only because block 0 is reserved for the healthcheck | 
					
						
							|  |  |  |         # so it doesn't matter if we override it with bogus values. | 
					
						
							| 
									
										
										
										
											2024-10-02 09:22:13 +00:00
										 |  |  |         cuda_graph["slots"].fill_(0) | 
					
						
							|  |  |  |         cuda_graph["slots"][: slots.shape[0]] = slots | 
					
						
							|  |  |  |         cuda_graph["input_lengths"].zero_() | 
					
						
							| 
									
										
										
										
											2024-10-16 10:49:33 +00:00
										 |  |  |         cuda_graph["input_lengths"][: input_lengths.shape[0]] = input_lengths | 
					
						
							|  |  |  |         cuda_graph["cache_lengths"].zero_() | 
					
						
							|  |  |  |         cuda_graph["cache_lengths"][ | 
					
						
							|  |  |  |             : cache_lengths_tensor.shape[0] | 
					
						
							|  |  |  |         ] = cache_lengths_tensor | 
					
						
							| 
									
										
										
										
											2024-10-02 09:22:13 +00:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-10-16 10:49:33 +00:00
										 |  |  |         with self._forward_context( | 
					
						
							|  |  |  |             block_tables=cuda_graph["block_tables"], | 
					
						
							|  |  |  |             cu_seqlen_prefill=None, | 
					
						
							|  |  |  |             input_lengths_tensor=cuda_graph["input_lengths"], | 
					
						
							|  |  |  |             cache_lengths_tensor=cuda_graph["cache_lengths"], | 
					
						
							|  |  |  |             state=cuda_graph["state"], | 
					
						
							|  |  |  |         ): | 
					
						
							|  |  |  |             # Replay the graph | 
					
						
							|  |  |  |             cuda_graph["graph"].replay() | 
					
						
							| 
									
										
										
										
											2024-10-02 09:22:13 +00:00
										 |  |  | 
 | 
					
						
							|  |  |  |         # Slice output to the correct shape | 
					
						
							|  |  |  |         speculative_logits = ( | 
					
						
							|  |  |  |             cuda_graph["speculative_logits"][:bs] | 
					
						
							|  |  |  |             if cuda_graph["speculative_logits"] is not None | 
					
						
							|  |  |  |             else None | 
					
						
							|  |  |  |         ) | 
					
						
							|  |  |  |         logits = cuda_graph["logits"][:bs] | 
					
						
							|  |  |  |         return logits, speculative_logits |