Working for TP, Llama + Mistral

Still unsolved:
- Rust parameter validation (to calculate number of tokens).
- Integration test.
- Validate other text heads.
- Quantization.
Nicolas Patry 2024-04-05 15:27:29 +00:00
parent df4c700828
commit 6c350f2f75
6 changed files with 108 additions and 51 deletions

View File

@@ -123,6 +123,7 @@ class CLIPAttention(nn.Module):
f" {self.num_heads})."
)
self.num_heads = self.num_heads // weights.process_group.size()
+ self.embed_dim = self.embed_dim // weights.process_group.size()
self.scale = self.head_size**-0.5
self.dropout = config.attention_dropout
@@ -137,7 +138,7 @@ class CLIPAttention(nn.Module):
config,
prefix=f"{prefix}.out_proj",
weights=weights,
- bias=False,
+ bias=True,
)
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
@@ -155,7 +156,7 @@ class CLIPAttention(nn.Module):
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
"""Input shape: Batch x Time x Channel"""
- bsz, tgt_len, embed_dim = hidden_states.size()
+ bsz, tgt_len, _ = hidden_states.size()
# get query proj
@@ -225,7 +226,7 @@ class CLIPAttention(nn.Module):
attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_size)
attn_output = attn_output.transpose(1, 2)
- attn_output = attn_output.reshape(bsz, tgt_len, embed_dim)
+ attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
attn_output = self.out_proj(attn_output)
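
These CLIP hunks shard the attention module for tensor parallelism: each rank keeps num_heads / world_size heads plus the matching slice of embed_dim, the output reshape uses the sharded self.embed_dim instead of the full hidden size, and out_proj switches to bias=True, matching the upstream CLIP checkpoint (its out_proj does carry a bias). Below is a rough, runnable sketch of that shape bookkeeping with example sizes; it is an illustration, not TGI's actual tensor-parallel layers.

```python
# Minimal sketch (not TGI's real layers) of the per-rank shape bookkeeping the
# hunks above rely on: each rank keeps num_heads / world_size heads and the
# matching slice of embed_dim, while head_size stays the same.
import torch

def shard_attention_dims(num_heads: int, embed_dim: int, world_size: int):
    assert num_heads % world_size == 0, "heads must divide evenly across ranks"
    head_size = embed_dim // num_heads              # unchanged per head
    local_heads = num_heads // world_size
    local_embed_dim = embed_dim // world_size       # == local_heads * head_size
    return local_heads, head_size, local_embed_dim

# CLIP-ViT-L/14-336 vision tower: 16 heads, 1024 hidden dim, sharded over 2 GPUs
local_heads, head_size, local_dim = shard_attention_dims(16, 1024, 2)
assert (local_heads, head_size, local_dim) == (8, 64, 512)

# After the sharded q/k/v projections, the per-rank attention output is reshaped
# with the *local* embed_dim (hence self.embed_dim in the reshape above); the
# row-parallel out_proj then reduces the partial results across ranks.
attn_output = torch.randn(1, 577, local_heads * head_size)   # 577 = 24*24 patches + CLS
assert attn_output.shape[-1] == local_dim
```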

View File

@@ -342,12 +342,6 @@ class FlashLlamaModel(torch.nn.Module):
process_group = weights.process_group
self.tp_rank = process_group.rank()
self.tp_world_size = process_group.size()
- self.embed_tokens = TensorParallelEmbedding(
- prefix=(
- "model.embed_tokens" if not prefix else f"{prefix}.model.embed_tokens"
- ),
- weights=weights,
- )
self.layers = nn.ModuleList(
[
FlashLlamaLayer(
@@ -384,6 +378,8 @@ class FlashLlamaModel(torch.nn.Module):
slots: torch.Tensor,
input_lengths: torch.Tensor,
max_s: int,
+ true_max_s: int,
+ prefill_cache_indices: Optional[torch.Tensor],
) -> torch.Tensor:
hidden_states = inputs_embeds
@@ -417,6 +413,12 @@ class FlashLlamaForCausalLM(torch.nn.Module):
def __init__(self, prefix, config, weights):
super().__init__()
+ self.embed_tokens = TensorParallelEmbedding(
+ prefix=(
+ "model.embed_tokens" if not prefix else f"{prefix}.model.embed_tokens"
+ ),
+ weights=weights,
+ )
self.model = FlashLlamaModel(prefix, config, weights)
self.lm_head = SpeculativeHead.load(
config,
@@ -447,6 +449,8 @@ class FlashLlamaForCausalLM(torch.nn.Module):
slots,
input_lengths,
max_s,
+ true_max_s=max_s,
+ prefill_cache_indices=prefill_cache_indices,
)
if lm_head_indices is not None:
hidden_states = hidden_states[lm_head_indices]
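
These hunks move the token embedding out of FlashLlamaModel and into the FlashLlamaForCausalLM wrapper, and thread true_max_s / prefill_cache_indices through so Llama shares one interface with the Mistral-based VLM path. The text model now only consumes inputs_embeds, which is what lets a multimodal wrapper build the embeddings (text plus image features) itself. A toy sketch of the resulting split, using stand-in modules rather than TGI's classes:

```python
# Toy illustration of the new split: the wrapper owns the embedding, the text model
# only ever sees dense embeddings. Module names and shapes are simplified stand-ins,
# not TGI's real classes.
import torch
import torch.nn as nn

class ToyTextModel(nn.Module):
    """Stands in for FlashLlamaModel after this commit: consumes inputs_embeds."""
    def __init__(self, hidden: int):
        super().__init__()
        self.proj = nn.Linear(hidden, hidden)

    def forward(self, inputs_embeds: torch.Tensor) -> torch.Tensor:
        return self.proj(inputs_embeds)

class ToyCausalLM(nn.Module):
    """Stands in for FlashLlamaForCausalLM: owns embed_tokens, embeds, then calls the model."""
    def __init__(self, vocab: int, hidden: int):
        super().__init__()
        self.embed_tokens = nn.Embedding(vocab, hidden)
        self.model = ToyTextModel(hidden)

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        inputs_embeds = self.embed_tokens(input_ids)   # text-only path
        return self.model(inputs_embeds)

lm = ToyCausalLM(vocab=32000, hidden=64)
out = lm(torch.randint(0, 32000, (1, 8)))
assert out.shape == (1, 8, 64)
# A VLM wrapper can now call lm.embed_tokens itself, splice image features into
# inputs_embeds, and invoke lm.model(inputs_embeds) directly.
```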

View File

@@ -348,9 +348,6 @@ class MistralModel(torch.nn.Module):
process_group = weights.process_group
self.tp_rank = process_group.rank()
self.tp_world_size = process_group.size()
- self.embed_tokens = TensorParallelEmbedding(
- prefix=f"{prefix}.embed_tokens", weights=weights
- )
self.layers = nn.ModuleList(
[
MistralLayer(
@@ -373,34 +370,7 @@ class MistralModel(torch.nn.Module):
def forward(
self,
- input_ids: torch.Tensor,
- position_ids: torch.Tensor,
- cu_seqlen_prefill: Optional[torch.Tensor],
- kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
- block_tables: torch.Tensor,
- slots: torch.Tensor,
- input_lengths: torch.Tensor,
- max_s: int,
- true_max_s: int,
- prefill_cache_indices: Optional[torch.Tensor],
- ) -> torch.Tensor:
- hidden_states = self.embed_tokens(input_ids)
- return self.with_hidden_states(
- hidden_states,
- position_ids,
- cu_seqlen_prefill,
- kv_cache,
- block_tables,
- slots,
- input_lengths,
- max_s,
- true_max_s,
- prefill_cache_indices,
- )
- def with_hidden_states(
- self,
- hidden_states: torch.Tensor,
+ inputs_embeds: torch.Tensor,
position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
@@ -411,6 +381,7 @@ class MistralModel(torch.nn.Module):
true_max_s: int,
prefill_cache_indices: Optional[torch.Tensor],
):
+ hidden_states = inputs_embeds
# Get rotary cos and sin for this forward
# Avoid to index in each layer
cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(
@@ -441,6 +412,9 @@ class FlashMistralForCausalLM(torch.nn.Module):
def __init__(self, prefix, config, weights):
super().__init__()
+ self.embed_tokens = TensorParallelEmbedding(
+ prefix=f"{prefix}.model.embed_tokens", weights=weights
+ )
self.model = MistralModel(
prefix="model" if not prefix else f"{prefix}.model",
config=config,
@@ -480,8 +454,9 @@ class FlashMistralForCausalLM(torch.nn.Module):
# kernel requires the true values
input_lengths = torch.clamp(input_lengths, max=self.max_past_tensor)
+ inputs_embeds = self.embed_tokens(input_ids)
hidden_states = self.model(
- input_ids,
+ inputs_embeds,
position_ids,
cu_seqlen_prefill,
kv_cache,
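
Mistral gets the same treatment, with forward and with_hidden_states collapsed into a single forward that takes inputs_embeds. The signature keeps carrying both max_s and true_max_s; the sketch below, assuming Mistral's 4096-token sliding window, shows why the clamped and true lengths diverge: the paged KV cache only holds the last sliding_window tokens, while the rotary tables and the prefill kernel still need the real lengths (the "kernel requires the true values" comment above).

```python
# Sketch of the clamped vs. true lengths handed to MistralModel.forward, assuming
# Mistral's 4096-token sliding window. The variable names mirror the diff; the
# real code clamps against a pre-built max_past_tensor.
import torch

sliding_window = 4096
input_lengths = torch.tensor([1000, 5000, 9000])

true_max_s = int(input_lengths.max())                            # 9000: sizes the rotary cos/sin tables
input_lengths = torch.clamp(input_lengths, max=sliding_window)   # what the paged KV cache actually holds
max_s = int(input_lengths.max())                                 # 4096: bounds attention over the cache

print(true_max_s, max_s, input_lengths.tolist())                 # 9000 4096 [1000, 4096, 4096]
```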

View File

@@ -142,7 +142,10 @@ class LlavaNextForConditionalGeneration(nn.Module):
vision_config = config.vision_config
# Instead of selecting in hidden_states[-2].
# Instead compute only the n -2 + 1 layers and don't pool
- vision_config.num_hidden_layers += config.vision_feature_layer + 1
+ if config.vision_feature_layer < 0:
+ vision_config.num_hidden_layers += config.vision_feature_layer + 1
+ else:
+ vision_config.num_hidden_layers = config.vision_feature_layer + 1
self.vision_tower = load_vision_model(
prefix="vision_tower" if not prefix else f"{prefix}.vision_tower",
config=config.vision_config,
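
The old line only made sense for negative vision_feature_layer values (a non-negative index would have increased the layer count past the tower's size); the new branch treats a non-negative index as an absolute count of vision_feature_layer + 1 layers. With LLaVA-NeXT's defaults, a 24-layer CLIP tower and vision_feature_layer = -2, this builds 24 - 2 + 1 = 23 layers, so the last hidden state computed is exactly the one that used to be read from hidden_states[-2]. A small worked example:

```python
# Worked example of the layer-truncation logic above. The 24-layer count matches
# CLIP-ViT-L/14-336 (LLaVA-NeXT's default tower); vision_feature_layer=-2 is the
# LLaVA default ("second to last layer").
def truncated_num_layers(num_hidden_layers: int, vision_feature_layer: int) -> int:
    if vision_feature_layer < 0:
        # e.g. 24 + (-2) + 1 = 23: build only the layers whose output is needed
        return num_hidden_layers + vision_feature_layer + 1
    # e.g. a non-negative index 5 -> keep 5 + 1 = 6 layers
    return vision_feature_layer + 1

assert truncated_num_layers(24, -2) == 23
assert truncated_num_layers(24, 5) == 6
```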
@@ -195,7 +198,7 @@ class LlavaNextForConditionalGeneration(nn.Module):
pixel_values: torch.FloatTensor = None,
image_sizes: Optional[torch.LongTensor] = None,
):
- inputs_embeds = self.language_model.model.embed_tokens(input_ids)
+ inputs_embeds = self.language_model.embed_tokens(input_ids)
if pixel_values is not None and len(pixel_values) > 0:
# num_special_image_tokens = (input_ids == self.config.image_token_index).sum()
# assert num_special_image_tokens == len(pixel_values), f"Received {num_special_image_tokens} for {len(pixel_values)} images, this is invalid"
@@ -290,6 +293,8 @@ class LlavaNextForConditionalGeneration(nn.Module):
slots=slots,
input_lengths=input_lengths,
max_s=max_s,
+ true_max_s=max_s,
+ prefill_cache_indices=None,
)
if lm_head_indices is not None:
hidden_states = hidden_states[lm_head_indices]
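
Since embed_tokens now lives directly on the language-model wrapper, LlavaNext embeds the whole prompt, including the repeated <image> placeholders, and then substitutes the projected vision features at those positions before running the text model. That substitution step is not part of this hunk; the toy sketch below only illustrates the idea and is not TGI's actual merge code.

```python
# Toy sketch of splicing image features into text embeddings at the <image>
# placeholder positions. image_token_index follows the commented-out asserts
# above; this is an illustration, not TGI's real merge implementation.
import torch

def merge_image_features(inputs_embeds, input_ids, image_features, image_token_index):
    """Overwrite embeddings of <image> placeholder tokens with vision features."""
    mask = input_ids == image_token_index                  # one True per placeholder token
    assert mask.sum() == image_features.shape[0], "placeholders must match feature rows"
    inputs_embeds[mask] = image_features.to(inputs_embeds.dtype)
    return inputs_embeds

hidden = 8
input_ids = torch.tensor([1, 32000, 32000, 5, 6])          # two <image> placeholders (flattened, no batch dim)
inputs_embeds = torch.zeros(5, hidden)
image_features = torch.ones(2, hidden)                     # one feature row per placeholder
merged = merge_image_features(inputs_embeds, input_ids, image_features, image_token_index=32000)
assert merged[1].sum() == hidden and merged[0].sum() == 0
```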

View File

@@ -5,6 +5,7 @@ from opentelemetry import trace
from typing import Optional, Tuple, List, Type, Dict
from transformers import PreTrainedTokenizerBase
+ from transformers.image_processing_utils import select_best_resolution
from text_generation_server.pb import generate_pb2
from text_generation_server.models.flash_mistral import (
BaseFlashMistral,
@@ -36,14 +37,73 @@ def split(string) -> List[Dict[str, str]]:
return parts
+ def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
+ """
+ Calculate the shape of the image patch grid after the preprocessing for images of any resolution.
+ Args:
+ image_size (`tuple`):
+ The size of the input image in the format (width, height).
+ grid_pinpoints (`List`):
+ A list containing possible resolutions. Each item in the list should be a tuple or list
+ of the form `(height, width)`.
+ patch_size (`int`):
+ The size of each image patch.
+ Returns:
+ tuple: The shape of the image patch grid in the format (width, height).
+ """
+ if not isinstance(grid_pinpoints, list):
+ raise ValueError("grid_pinpoints should be a list of tuples or lists")
+ height, width = select_best_resolution(image_size, grid_pinpoints)
+ return height // patch_size, width // patch_size
+ def get_number_of_features(height: int, width: int, config) -> int:
+ # From config
+ # Hardcoded for CLIP for now
+ # image_grid_pinpoints = [[336, 672], [672, 336], [672, 672], [1008, 336], [336, 1008]]
+ image_grid_pinpoints = config.image_grid_pinpoints
+ image_size = config.vision_config.image_size
+ patch_size = config.vision_config.patch_size
+ assert image_size % patch_size == 0
+ npatches = image_size // patch_size
+ num_patch_height, num_patch_width = get_anyres_image_grid_shape(
+ [height, width],
+ image_grid_pinpoints,
+ image_size,
+ )
+ import math
+ height_of_patch = math.ceil(height / width * npatches)
+ unpadded_features = npatches * height_of_patch * num_patch_height * num_patch_width
+ # They are only added after width
+ newline_features = height_of_patch * num_patch_width
+ # The base patch covers the entire image
+ base_features = npatches**2
+ return unpadded_features + newline_features + base_features
+ if height == 640 and width == 640:
+ return 2928
+ return 2634
+ # assert get_number_of_features(889, 1024) == 2634, f"{get_number_of_features(889, 1024)}"
+ # assert get_number_of_features(640, 640) == 2928
class VlmCausalLMBatch(FlashMistralBatch):
pixel_values: Optional[List[torch.Tensor]]
image_sizes: Optional[List[Tuple[int, int]]]
@classmethod
- def batch_tokenized_inputs(cls, requests, tokenizer, processor):
+ def batch_tokenized_inputs(cls, requests, tokenizer, processor, config):
batch_inputs = []
- images = []
+ image_inputs = []
max_truncation = 0
for r in requests:
chunks = split(r.inputs)
@@ -52,8 +112,13 @@ class VlmCausalLMBatch(FlashMistralBatch):
if chunk["type"] == "text":
full_text += chunk["content"]
elif chunk["type"] == "image":
- full_text += "<image>" * 2928
- images.append(chunk["content"])
+ image = chunk["content"]
+ image = processor.image_processor.fetch_images(image)
+ image_input = processor.image_processor(image, return_tensors="pt")
+ height, width = image_input["image_sizes"][0]
+ num_features = get_number_of_features(height, width, config)
+ full_text += "<image>" * num_features
+ image_inputs.append(image_input)
else:
raise RuntimeError(f"Invalid chunk type {chunk['type']}")
@@ -63,9 +128,13 @@ class VlmCausalLMBatch(FlashMistralBatch):
batch_tokenized_inputs = tokenizer(
batch_inputs, truncation=True, max_length=max_truncation
)["input_ids"]
- images = processor.image_processor.fetch_images(images)
- if images:
- image_inputs = processor.image_processor(images, return_tensors="pt")
+ if image_inputs:
+ image_inputs = {
+ "pixel_values": torch.cat(
+ [img["pixel_values"] for img in image_inputs], dim=0
+ ),
+ "image_sizes": torch.cat([img["image_sizes"] for img in image_inputs]),
+ }
else:
image_inputs = None
return batch_tokenized_inputs, image_inputs
@@ -76,11 +145,12 @@ class VlmCausalLMBatch(FlashMistralBatch):
pb: generate_pb2.Batch,
tokenizer: PreTrainedTokenizerBase,
processor,
+ config,
dtype: torch.dtype,
device: torch.device,
) -> "VlmCausalLMBatch":
batch_tokenized_inputs, image_inputs = cls.batch_tokenized_inputs(
- pb.requests, tokenizer, processor
+ pb.requests, tokenizer, processor, config
)
batch = cls.from_tokenized(pb, tokenizer, batch_tokenized_inputs, dtype, device)
if image_inputs is not None:
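
Instead of the old hard-coded "<image>" * 2928, the number of placeholder tokens per image is now derived from the image size via get_number_of_features, presumably the same computation the "Rust parameter validation (to calculate number of tokens)" item in the commit message still has to mirror. A worked example with the CLIP numbers from the comment above (336-px tiles, 14-px patches, the listed grid pinpoints) reproduces both commented-out asserts:

```python
# Worked example for get_number_of_features, using the CLIP-ViT-L/14-336 numbers
# from the comment above (image_size=336, patch_size=14, the listed grid_pinpoints).
import math

npatches = 336 // 14                                               # 24 patches per side per tile
# Per select_best_resolution, a 640x640 input maps to the (672, 672) pinpoint:
# a 2x2 grid of 336-px tiles.
num_patch_height, num_patch_width = 672 // 336, 672 // 336        # (2, 2)
height_of_patch = math.ceil(640 / 640 * npatches)                 # 24

unpadded_features = npatches * height_of_patch * num_patch_height * num_patch_width  # 2304
newline_features = height_of_patch * num_patch_width              # 48 (one newline token per row)
base_features = npatches**2                                        # 576 for the full-image base tile

assert unpadded_features + newline_features + base_features == 2928  # matches the assert above

# For an 889x1024 image the best pinpoint is again (672, 672), but
# height_of_patch = ceil(889 / 1024 * 24) = 21, giving 2016 + 42 + 576 = 2634.
```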

View File

@@ -87,6 +87,7 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
request.batch,
self.model.tokenizer,
self.model.processor,
+ self.model.model.config,
self.model.dtype,
self.model.device,
)
@@ -110,6 +111,7 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
request.batch,
self.model.tokenizer,
self.model.processor,
+ self.model.model.config,
self.model.dtype,
self.model.device,
)