Mirror of https://github.com/huggingface/text-generation-inference.git
Working for TP, Llama + Mistral
Still unsolved:
- Rust parameter validation (to calculate the number of tokens).
- Integration test.
- Validate other text heads.
- Quantization.
Parent: df4c700828
Commit: 6c350f2f75
```diff
@@ -123,6 +123,7 @@ class CLIPAttention(nn.Module):
                 f" {self.num_heads})."
             )
         self.num_heads = self.num_heads // weights.process_group.size()
+        self.embed_dim = self.embed_dim // weights.process_group.size()
         self.scale = self.head_size**-0.5
         self.dropout = config.attention_dropout

```
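Under tensor parallelism the attention heads are split across ranks, so the per-rank width of the sharded q/k/v projections is the full hidden size divided by the world size; sharding `self.embed_dim` here is what lets the hunk at line 225 below reshape with `self.embed_dim` instead of the unsharded local `embed_dim`. A minimal sketch of the arithmetic, with assumed CLIP ViT-L/14-style sizes (hidden size 1024, 16 heads) and a world size of 2; these numbers are illustrative, not read from the diff:

```python
# Per-rank sizes under tensor parallelism; the numbers are assumed
# (CLIP ViT-L/14-style), not taken from the diff.
embed_dim = 1024                       # full hidden size
num_heads = 16                         # full number of attention heads
head_size = embed_dim // num_heads     # 64, unchanged by sharding
world_size = 2                         # number of tensor-parallel ranks

num_heads_per_rank = num_heads // world_size   # 8
embed_dim_per_rank = embed_dim // world_size   # 512

# The sharded attention output has width num_heads_per_rank * head_size,
# so the reshape must use the sharded self.embed_dim, not the unsharded
# value returned by hidden_states.size().
assert num_heads_per_rank * head_size == embed_dim_per_rank
```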
```diff
@@ -137,7 +138,7 @@ class CLIPAttention(nn.Module):
             config,
             prefix=f"{prefix}.out_proj",
             weights=weights,
-            bias=False,
+            bias=True,
         )

     def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
```
```diff
@@ -155,7 +156,7 @@ class CLIPAttention(nn.Module):
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         """Input shape: Batch x Time x Channel"""

-        bsz, tgt_len, embed_dim = hidden_states.size()
+        bsz, tgt_len, _ = hidden_states.size()

         # get query proj

```
```diff
@@ -225,7 +226,7 @@ class CLIPAttention(nn.Module):

         attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_size)
         attn_output = attn_output.transpose(1, 2)
-        attn_output = attn_output.reshape(bsz, tgt_len, embed_dim)
+        attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)

         attn_output = self.out_proj(attn_output)

```
```diff
@@ -342,12 +342,6 @@ class FlashLlamaModel(torch.nn.Module):
         process_group = weights.process_group
         self.tp_rank = process_group.rank()
         self.tp_world_size = process_group.size()
-        self.embed_tokens = TensorParallelEmbedding(
-            prefix=(
-                "model.embed_tokens" if not prefix else f"{prefix}.model.embed_tokens"
-            ),
-            weights=weights,
-        )
         self.layers = nn.ModuleList(
             [
                 FlashLlamaLayer(
```
```diff
@@ -384,6 +378,8 @@ class FlashLlamaModel(torch.nn.Module):
         slots: torch.Tensor,
         input_lengths: torch.Tensor,
         max_s: int,
+        true_max_s: int,
+        prefill_cache_indices: Optional[torch.Tensor],
     ) -> torch.Tensor:
         hidden_states = inputs_embeds

```
```diff
@@ -417,6 +413,12 @@ class FlashLlamaForCausalLM(torch.nn.Module):
     def __init__(self, prefix, config, weights):
         super().__init__()

+        self.embed_tokens = TensorParallelEmbedding(
+            prefix=(
+                "model.embed_tokens" if not prefix else f"{prefix}.model.embed_tokens"
+            ),
+            weights=weights,
+        )
         self.model = FlashLlamaModel(prefix, config, weights)
         self.lm_head = SpeculativeHead.load(
             config,
```
```diff
@@ -447,6 +449,8 @@ class FlashLlamaForCausalLM(torch.nn.Module):
             slots,
             input_lengths,
             max_s,
+            true_max_s=max_s,
+            prefill_cache_indices=prefill_cache_indices,
         )
         if lm_head_indices is not None:
             hidden_states = hidden_states[lm_head_indices]
```
```diff
@@ -348,9 +348,6 @@ class MistralModel(torch.nn.Module):
         process_group = weights.process_group
         self.tp_rank = process_group.rank()
         self.tp_world_size = process_group.size()
-        self.embed_tokens = TensorParallelEmbedding(
-            prefix=f"{prefix}.embed_tokens", weights=weights
-        )
         self.layers = nn.ModuleList(
             [
                 MistralLayer(
```
```diff
@@ -373,34 +370,7 @@ class MistralModel(torch.nn.Module):

     def forward(
         self,
-        input_ids: torch.Tensor,
-        position_ids: torch.Tensor,
-        cu_seqlen_prefill: Optional[torch.Tensor],
-        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
-        block_tables: torch.Tensor,
-        slots: torch.Tensor,
-        input_lengths: torch.Tensor,
-        max_s: int,
-        true_max_s: int,
-        prefill_cache_indices: Optional[torch.Tensor],
-    ) -> torch.Tensor:
-        hidden_states = self.embed_tokens(input_ids)
-        return self.with_hidden_states(
-            hidden_states,
-            position_ids,
-            cu_seqlen_prefill,
-            kv_cache,
-            block_tables,
-            slots,
-            input_lengths,
-            max_s,
-            true_max_s,
-            prefill_cache_indices,
-        )
-
-    def with_hidden_states(
-        self,
-        hidden_states: torch.Tensor,
+        inputs_embeds: torch.Tensor,
         position_ids: torch.Tensor,
         cu_seqlen_prefill: Optional[torch.Tensor],
         kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
```
```diff
@@ -411,6 +381,7 @@ class MistralModel(torch.nn.Module):
         true_max_s: int,
         prefill_cache_indices: Optional[torch.Tensor],
     ):
+        hidden_states = inputs_embeds
         # Get rotary cos and sin for this forward
         # Avoid to index in each layer
         cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(
```
```diff
@@ -441,6 +412,9 @@ class FlashMistralForCausalLM(torch.nn.Module):
     def __init__(self, prefix, config, weights):
         super().__init__()

+        self.embed_tokens = TensorParallelEmbedding(
+            prefix=f"{prefix}.model.embed_tokens", weights=weights
+        )
         self.model = MistralModel(
             prefix="model" if not prefix else f"{prefix}.model",
             config=config,
```
```diff
@@ -480,8 +454,9 @@ class FlashMistralForCausalLM(torch.nn.Module):
             # kernel requires the true values
             input_lengths = torch.clamp(input_lengths, max=self.max_past_tensor)

+        inputs_embeds = self.embed_tokens(input_ids)
         hidden_states = self.model(
-            input_ids,
+            inputs_embeds,
             position_ids,
             cu_seqlen_prefill,
             kv_cache,
```
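Moving `embed_tokens` out of `MistralModel` (and `FlashLlamaModel`) into the causal-LM wrappers, and making the inner `forward` take `inputs_embeds`, gives every caller the same entry point: a text-only wrapper just embeds the ids as above, while `LlavaNextForConditionalGeneration` can splice image features into the embedding sequence before the decoder runs. A rough sketch of that pattern, with made-up shapes; the masked overwrite only illustrates the idea and is not the repository's exact merging code:

```python
import torch

# Made-up sizes for illustration only (not the model's real dimensions).
vocab_size, hidden_size = 100, 16
image_token_id = 99                       # assumed id of the "<image>" placeholder
embed_tokens = torch.nn.Embedding(vocab_size, hidden_size)

input_ids = torch.tensor([1, 15, image_token_id, image_token_id, 42])

with torch.no_grad():
    # Text-only path: the wrapper embeds the ids, then calls model(inputs_embeds, ...).
    inputs_embeds = embed_tokens(input_ids)

    # Multimodal path: overwrite the placeholder positions with vision features
    # of the same hidden size before the decoder ever sees them.
    mask = input_ids == image_token_id
    image_features = torch.randn(int(mask.sum()), hidden_size)
    inputs_embeds[mask] = image_features

# hidden_states = self.model(inputs_embeds, position_ids, ...)
```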
```diff
@@ -142,7 +142,10 @@ class LlavaNextForConditionalGeneration(nn.Module):
         vision_config = config.vision_config
         # Instead of selecting in hidden_states[-2].
         # Instead compute only the n -2 + 1 layers and don't pool
-        vision_config.num_hidden_layers += config.vision_feature_layer + 1
+        if config.vision_feature_layer < 0:
+            vision_config.num_hidden_layers += config.vision_feature_layer + 1
+        else:
+            vision_config.num_hidden_layers = config.vision_feature_layer + 1
         self.vision_tower = load_vision_model(
             prefix="vision_tower" if not prefix else f"{prefix}.vision_tower",
             config=config.vision_config,
```
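The two comments above are terse; the intent is to truncate the vision tower so that its last computed layer is exactly the one `vision_feature_layer` would have selected from `hidden_states`, avoiding the cost of the remaining layers. A quick check of the arithmetic, assuming the common LLaVA setting `vision_feature_layer = -2` and a 24-layer CLIP tower (assumed values, not read from the config):

```python
# Assumed values: vision_feature_layer = -2 (common LLaVA setting),
# 24-layer CLIP vision tower.
num_hidden_layers = 24
vision_feature_layer = -2

if vision_feature_layer < 0:
    truncated = num_hidden_layers + vision_feature_layer + 1   # 24 - 2 + 1 = 23
else:
    truncated = vision_feature_layer + 1

# hidden_states[-2] of a 24-layer tower is the output of its 23rd layer, so
# running only `truncated` layers and taking their output (without pooling)
# gives the same features while skipping the unused final layer.
assert truncated == 23
```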
```diff
@@ -195,7 +198,7 @@ class LlavaNextForConditionalGeneration(nn.Module):
         pixel_values: torch.FloatTensor = None,
         image_sizes: Optional[torch.LongTensor] = None,
     ):
-        inputs_embeds = self.language_model.model.embed_tokens(input_ids)
+        inputs_embeds = self.language_model.embed_tokens(input_ids)
         if pixel_values is not None and len(pixel_values) > 0:
             # num_special_image_tokens = (input_ids == self.config.image_token_index).sum()
             # assert num_special_image_tokens == len(pixel_values), f"Received {num_special_image_tokens} for {len(pixel_values)} images, this is invalid"
```
```diff
@@ -290,6 +293,8 @@ class LlavaNextForConditionalGeneration(nn.Module):
             slots=slots,
             input_lengths=input_lengths,
             max_s=max_s,
+            true_max_s=max_s,
+            prefill_cache_indices=None,
         )
         if lm_head_indices is not None:
             hidden_states = hidden_states[lm_head_indices]
```
```diff
@@ -5,6 +5,7 @@ from opentelemetry import trace
 from typing import Optional, Tuple, List, Type, Dict

 from transformers import PreTrainedTokenizerBase
+from transformers.image_processing_utils import select_best_resolution
 from text_generation_server.pb import generate_pb2
 from text_generation_server.models.flash_mistral import (
     BaseFlashMistral,
```
```diff
@@ -36,14 +37,73 @@ def split(string) -> List[Dict[str, str]]:
     return parts


+def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
+    """
+    Calculate the shape of the image patch grid after the preprocessing for images of any resolution.
+
+    Args:
+        image_size (`tuple`):
+            The size of the input image in the format (width, height).
+        grid_pinpoints (`List`):
+            A list containing possible resolutions. Each item in the list should be a tuple or list
+            of the form `(height, width)`.
+        patch_size (`int`):
+            The size of each image patch.
+
+    Returns:
+        tuple: The shape of the image patch grid in the format (width, height).
+    """
+    if not isinstance(grid_pinpoints, list):
+        raise ValueError("grid_pinpoints should be a list of tuples or lists")
+
+    height, width = select_best_resolution(image_size, grid_pinpoints)
+    return height // patch_size, width // patch_size
+
+
+def get_number_of_features(height: int, width: int, config) -> int:
+    # From config
+    # Hardcoded for CLIP for now
+    # image_grid_pinpoints = [[336, 672], [672, 336], [672, 672], [1008, 336], [336, 1008]]
+    image_grid_pinpoints = config.image_grid_pinpoints
+    image_size = config.vision_config.image_size
+    patch_size = config.vision_config.patch_size
+
+    assert image_size % patch_size == 0
+
+    npatches = image_size // patch_size
+
+    num_patch_height, num_patch_width = get_anyres_image_grid_shape(
+        [height, width],
+        image_grid_pinpoints,
+        image_size,
+    )
+    import math
+
+    height_of_patch = math.ceil(height / width * npatches)
+
+    unpadded_features = npatches * height_of_patch * num_patch_height * num_patch_width
+    # They are only added after width
+    newline_features = height_of_patch * num_patch_width
+    # The base patch covers the entire image
+    base_features = npatches**2
+    return unpadded_features + newline_features + base_features
+    if height == 640 and width == 640:
+        return 2928
+    return 2634
+
+
+# assert get_number_of_features(889, 1024) == 2634, f"{get_number_of_features(889, 1024)}"
+# assert get_number_of_features(640, 640) == 2928
+
+
 class VlmCausalLMBatch(FlashMistralBatch):
     pixel_values: Optional[List[torch.Tensor]]
     image_sizes: Optional[List[Tuple[int, int]]]

     @classmethod
-    def batch_tokenized_inputs(cls, requests, tokenizer, processor):
+    def batch_tokenized_inputs(cls, requests, tokenizer, processor, config):
         batch_inputs = []
-        images = []
+        image_inputs = []
         max_truncation = 0
         for r in requests:
             chunks = split(r.inputs)
```
|
|||||||
if chunk["type"] == "text":
|
if chunk["type"] == "text":
|
||||||
full_text += chunk["content"]
|
full_text += chunk["content"]
|
||||||
elif chunk["type"] == "image":
|
elif chunk["type"] == "image":
|
||||||
full_text += "<image>" * 2928
|
image = chunk["content"]
|
||||||
images.append(chunk["content"])
|
image = processor.image_processor.fetch_images(image)
|
||||||
|
image_input = processor.image_processor(image, return_tensors="pt")
|
||||||
|
height, width = image_input["image_sizes"][0]
|
||||||
|
num_features = get_number_of_features(height, width, config)
|
||||||
|
full_text += "<image>" * num_features
|
||||||
|
image_inputs.append(image_input)
|
||||||
else:
|
else:
|
||||||
raise RuntimeError(f"Invalid chunk type {chunk['type']}")
|
raise RuntimeError(f"Invalid chunk type {chunk['type']}")
|
||||||
|
|
||||||
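Each image chunk is now fetched and preprocessed on the spot so that the number of `<image>` placeholders matches the feature count for that image's actual resolution, rather than the previous hard-coded 2928. A small self-contained sketch of the expansion, with a made-up request and a stand-in for `get_number_of_features`:

```python
# Hypothetical request chunks, mirroring the output format of split().
chunks = [
    {"type": "text", "content": "What is in this picture? "},
    {"type": "image", "content": "https://example.com/cat.png"},
    {"type": "text", "content": " Answer briefly."},
]

def fake_num_features(height, width):
    # Stand-in for get_number_of_features(height, width, config); the real
    # value depends on the image resolution and the model config.
    return 2928 if (height, width) == (640, 640) else 2634

full_text = ""
for chunk in chunks:
    if chunk["type"] == "text":
        full_text += chunk["content"]
    elif chunk["type"] == "image":
        # In the server, the image is downloaded and preprocessed here to get
        # its (height, width); we simply assume 640x640 for this sketch.
        full_text += "<image>" * fake_num_features(640, 640)

# full_text now contains 2928 "<image>" tokens where the image sat, so the
# tokenized prompt lines up with the image features inserted later.
```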
```diff
@@ -63,9 +128,13 @@ class VlmCausalLMBatch(FlashMistralBatch):
         batch_tokenized_inputs = tokenizer(
             batch_inputs, truncation=True, max_length=max_truncation
         )["input_ids"]
-        images = processor.image_processor.fetch_images(images)
-        if images:
-            image_inputs = processor.image_processor(images, return_tensors="pt")
+        if image_inputs:
+            image_inputs = {
+                "pixel_values": torch.cat(
+                    [img["pixel_values"] for img in image_inputs], dim=0
+                ),
+                "image_sizes": torch.cat([img["image_sizes"] for img in image_inputs]),
+            }
         else:
             image_inputs = None
         return batch_tokenized_inputs, image_inputs
```
```diff
@@ -76,11 +145,12 @@ class VlmCausalLMBatch(FlashMistralBatch):
         pb: generate_pb2.Batch,
         tokenizer: PreTrainedTokenizerBase,
         processor,
+        config,
         dtype: torch.dtype,
         device: torch.device,
     ) -> "VlmCausalLMBatch":
         batch_tokenized_inputs, image_inputs = cls.batch_tokenized_inputs(
-            pb.requests, tokenizer, processor
+            pb.requests, tokenizer, processor, config
         )
         batch = cls.from_tokenized(pb, tokenizer, batch_tokenized_inputs, dtype, device)
         if image_inputs is not None:
```
```diff
@@ -87,6 +87,7 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
                 request.batch,
                 self.model.tokenizer,
                 self.model.processor,
+                self.model.model.config,
                 self.model.dtype,
                 self.model.device,
             )
```
```diff
@@ -110,6 +111,7 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
                 request.batch,
                 self.model.tokenizer,
                 self.model.processor,
+                self.model.model.config,
                 self.model.dtype,
                 self.model.device,
             )
```