Starting to get there.

Nicolas Patry 2024-09-28 00:38:23 +02:00
parent 85771989d6
commit ef4fa3ea7c
3 changed files with 101 additions and 265 deletions

View File

@ -22,7 +22,6 @@ from torch import nn
import flash_attn_2_cuda
from transformers.activations import ACT2FN
from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS
import torch.nn.functional as F
from text_generation_server.layers import (
@ -734,6 +733,7 @@ class MllamaTextCrossAttention(nn.Module):
cu_seqlen_k,
max_q,
max_k,
indices,
) = cross_attention_states
key_states = self.k_proj(cross_attention_states)
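The cross-attention states are threaded through the text model as a packed tuple, and this hunk extends it with the flattened token indices the image states belong to. A minimal sketch of the packing convention the unpacking above assumes (illustrative helper, not part of the commit):

import torch
from typing import List, Tuple

def pack_cross_attention_states(
    states: torch.Tensor,       # (num_images, seqlen_k, hidden_size)
    cu_seqlen_q: torch.Tensor,  # (num_segments + 1,) int32 prefix sums of query lengths
    cu_seqlen_k: torch.Tensor,  # (num_images + 1,) int32 prefix sums of key lengths
    max_q: int,
    max_k: int,
    indices: List[int],         # flattened token positions that attend to the images
) -> Tuple:
    # Mirrors the unpacking order used in MllamaTextCrossAttention.forward above.
    return (states, cu_seqlen_q, cu_seqlen_k, max_q, max_k, indices)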
@ -862,6 +862,8 @@ class FlashLlamaCrossLayer(torch.nn.Module):
return hidden_states, residual
if residual is not None:
hidden_states += residual
# indices = cross_attention_states[-1]
if cu_seqlen_prefill is not None:
out_hidden_states = hidden_states[:]
hidden_states = hidden_states[:]
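During prefill only the rows of the flattened hidden states that belong to image-bearing sequences need cross attention; the commented-out indices line hints at the gather/scatter that a later revision performs, while this commit appears to still run every row through. A hedged sketch of that gather/scatter pattern, assuming indices lists the selected token rows:

import torch

def cross_attend_selected_rows(hidden_states, indices, cross_attn):
    # Run cross attention only on the selected token rows and scatter the
    # results back; untouched rows pass through unchanged. Illustrative only.
    out_hidden_states = hidden_states.clone()
    out_hidden_states[indices] = cross_attn(hidden_states[indices])
    return out_hidden_states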
@ -892,115 +894,6 @@ class FlashLlamaCrossLayer(torch.nn.Module):
return hidden_states, None
class MllamaTextSelfAttention(nn.Module):
def __init__(self, *, prefix, config, weights, layer_idx):
super().__init__()
self.config = config
self.num_heads = config.num_attention_heads
self.dropout = config.dropout
self.hidden_size = config.hidden_size
self.num_key_value_heads = config.num_key_value_heads
self.head_dim = config.hidden_size // self.num_heads
self.num_key_value_groups = self.num_heads // self.num_key_value_heads
self.num_heads = self.num_heads // weights.process_group.size()
self.num_key_value_heads = (
self.num_key_value_heads // weights.process_group.size()
)
self.layer_idx = layer_idx
self.qkv_proj = TensorParallelColumnLinear.load_multi(
config,
prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
dim=0,
weights=weights,
bias=False,
)
self.o_proj = TensorParallelRowLinear.load(
config,
prefix=f"{prefix}.o_proj",
weights=weights,
bias=False,
)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: torch.Tensor,
position_embeddings: torch.Tensor,
past_key_value=None,
cache_position=None,
**kwargs,
):
bsz, q_len, _ = hidden_states.size()
qkv = self.qkv_proj(hidden_states)
query_states, key_states, value_states = qkv.split(
[
self.head_dim * self.num_heads,
self.head_dim * self.num_key_value_heads,
self.head_dim * self.num_key_value_heads,
],
dim=2,
)
query_states = query_states.view(
bsz, q_len, self.num_heads, self.head_dim
).transpose(1, 2)
key_states = key_states.view(
bsz, q_len, self.num_key_value_heads, self.head_dim
).transpose(1, 2)
value_states = value_states.view(
bsz, q_len, self.num_key_value_heads, self.head_dim
).transpose(1, 2)
cos, sin = position_embeddings
query_states, key_states = apply_rotary_pos_emb(
query_states, key_states, cos, sin
)
if past_key_value is not None:
# sin and cos are specific to RoPE models; cache_position needed for the static cache
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
key_states, value_states = past_key_value.update(
key_states, value_states, self.layer_idx, cache_kwargs
)
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
causal_mask = attention_mask
if attention_mask is not None:
causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]
# SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
# Reference: https://github.com/pytorch/pytorch/issues/112577.
if query_states.device.type == "cuda" and causal_mask is not None:
query_states = query_states.contiguous()
key_states = key_states.contiguous()
value_states = value_states.contiguous()
# We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
# in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
is_causal = True if causal_mask is None and q_len > 1 else False
attn_output = torch.nn.functional.scaled_dot_product_attention(
query_states,
key_states,
value_states,
# TODO
# attn_mask=causal_mask,
dropout_p=self.dropout if self.training else 0.0,
is_causal=is_causal,
)
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.view(bsz, q_len, -1)
attn_output = self.o_proj(attn_output)
return attn_output, None, past_key_value
# Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->MllamaText
class MllamaTextRMSNorm(nn.Module):
def __init__(self, weight, eps):
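The MllamaTextSelfAttention class removed above implements grouped-query attention: a smaller set of key/value heads is expanded with repeat_kv so scaled_dot_product_attention sees one key/value head per query head. For reference, a minimal version of that helper under the same (batch, heads, seq, head_dim) layout used in the deleted forward:

import torch

def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    # (batch, num_kv_heads, seq_len, head_dim) -> (batch, num_kv_heads * n_rep, seq_len, head_dim)
    batch, num_kv_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    expanded = hidden_states[:, :, None, :, :].expand(batch, num_kv_heads, n_rep, slen, head_dim)
    return expanded.reshape(batch, num_kv_heads * n_rep, slen, head_dim)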
@ -1026,144 +919,6 @@ class MllamaTextRMSNorm(nn.Module):
return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
# Copied from transformers.models.llama.modeling_llama.LlamaDecoderLayer with LlamaDecoder->MllamaSelfAttentionDecoder, Llama->MllamaText, LLAMA->MLLAMA_TEXT
class MllamaSelfAttentionDecoderLayer(nn.Module):
def __init__(self, *, prefix, config, weights, layer_idx):
super().__init__()
self.hidden_size = config.hidden_size
self.self_attn = MllamaTextSelfAttention(
prefix=f"{prefix}.self_attn",
config=config,
weights=weights,
layer_idx=layer_idx,
)
self.mlp = MllamaTextMLP(prefix=f"{prefix}.mlp", config=config, weights=weights)
self.input_layernorm = MllamaTextRMSNorm.load(
prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps
)
self.post_attention_layernorm = MllamaTextRMSNorm.load(
prefix=f"{prefix}.post_attention_layernorm",
weights=weights,
eps=config.rms_norm_eps,
)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value=None,
cache_position: Optional[torch.LongTensor] = None,
position_embeddings: Optional[
Tuple[torch.Tensor, torch.Tensor]
] = None, # will become mandatory in v4.45
image_indices: Optional[torch.LongTensor] = None,
**kwargs,
) -> Tuple[
torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]
]:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
# Self Attention
hidden_states, self_attn_weights, present_key_value = self.self_attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
cache_position=cache_position,
position_embeddings=position_embeddings,
**kwargs,
)
hidden_states = residual + hidden_states
# Fully Connected
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
return hidden_states
class MllamaRotaryEmbedding(nn.Module):
def __init__(
self,
*,
config,
weights,
):
super().__init__()
device = weights.device
self.rope_type = config.rope_scaling["rope_type"]
self.max_seq_len_cached = config.max_position_embeddings
self.original_max_seq_len = config.max_position_embeddings
self.config = config
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
inv_freq.to(device=device)
self.register_buffer("inv_freq", inv_freq, persistent=False)
self.original_inv_freq = self.inv_freq
def _dynamic_frequency_update(self, position_ids, device):
"""
dynamic RoPE layers should recompute `inv_freq` in the following situations:
1 - growing beyond the cached sequence length (allow scaling)
2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
"""
seq_len = torch.max(position_ids) + 1
if seq_len > self.max_seq_len_cached: # growth
inv_freq, self.attention_scaling = self.rope_init_fn(
self.config, device, seq_len=seq_len, **self.rope_kwargs
)
self.register_buffer(
"inv_freq", inv_freq, persistent=False
) # TODO joao: may break with compilation
self.max_seq_len_cached = seq_len
if (
seq_len < self.original_max_seq_len
and self.max_seq_len_cached > self.original_max_seq_len
): # reset
self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
self.max_seq_len_cached = self.original_max_seq_len
@torch.no_grad()
def forward(self, x, position_ids):
if "dynamic" in self.rope_type:
self._dynamic_frequency_update(position_ids, device=x.device)
# Core RoPE block
inv_freq_expanded = (
self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
)
position_ids_expanded = position_ids[:, None, :].float()
# Force float32 (see https://github.com/huggingface/transformers/pull/29285)
device_type = x.device.type
device_type = (
device_type
if isinstance(device_type, str) and device_type != "mps"
else "cpu"
)
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(
1, 2
)
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos()
sin = emb.sin()
# Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention
cos = cos * self.attention_scaling
sin = sin * self.attention_scaling
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
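The MllamaRotaryEmbedding class deleted above boils down to turning inverse frequencies and position ids into cos/sin tables, with an optional attention-scaling factor for YaRN-style variants. A condensed sketch of that core computation, kept in float32 as the deleted code does:

import torch

def rope_cos_sin(inv_freq: torch.Tensor, position_ids: torch.Tensor, attention_scaling: float = 1.0):
    # inv_freq: (head_dim // 2,), position_ids: (batch, seq_len)
    inv_freq_expanded = inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
    position_ids_expanded = position_ids[:, None, :].float()
    freqs = (inv_freq_expanded @ position_ids_expanded).transpose(1, 2)  # (batch, seq_len, head_dim // 2)
    emb = torch.cat((freqs, freqs), dim=-1)                              # (batch, seq_len, head_dim)
    return emb.cos() * attention_scaling, emb.sin() * attention_scaling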
class MllamaForConditionalGeneration(nn.Module):
def __init__(self, prefix, config, weights):
super().__init__()
@ -1192,14 +947,15 @@ class MllamaForConditionalGeneration(nn.Module):
"`aspect_ratio_ids` must be provided if `pixel_values` is provided"
)
# logger.info(f"PIxel values {pixel_values.shape}")
batch_size = pixel_values.shape[0]
vision_states = self.vision_model(
pixel_values, aspect_ratio_ids, aspect_ratio_mask
)
cross_attention_states = self.multi_modal_projector(vision_states).reshape(
-1, vision_states.shape[-2], self.hidden_size
)
- n, m, h = cross_attention_states.shape
+ _, _, h = cross_attention_states.shape
- cross_attention_states = cross_attention_states.view(1, n * m, h)
+ cross_attention_states = cross_attention_states.view(batch_size, -1, h)
# logger.info(f"cross {cross_attention_states.shape}")
return cross_attention_states
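The reshape change above stops flattening every image into a single row of batch size 1 and instead keeps one row per entry of the pixel_values batch, with that entry's tiles concatenated along the sequence axis. A small worked example of the two layouts (shapes are deliberately tiny and illustrative, not Mllama's real tile counts):

import torch

batch_size, tiles_per_image, seq_per_tile, hidden = 2, 4, 5, 8
projected = torch.zeros(batch_size * tiles_per_image, seq_per_tile, hidden)

old_layout = projected.view(1, batch_size * tiles_per_image * seq_per_tile, hidden)
new_layout = projected.view(batch_size, -1, hidden)
print(old_layout.shape)  # torch.Size([1, 40, 8])
print(new_layout.shape)  # torch.Size([2, 20, 8])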
@ -1219,7 +975,7 @@ class MllamaForConditionalGeneration(nn.Module):
cross_attention_states: Optional[torch.Tensor],
image_indices,
):
# if cross_attention_mask is not None:
# cross_attention_mask, full_text_row_masked_out_mask = (
# _prepare_cross_attention_mask(
# cross_attention_mask,
@ -1237,24 +993,59 @@ class MllamaForConditionalGeneration(nn.Module):
# ]
if cross_attention_states is not None:
- seqlen_q = input_ids.shape[0]
+ seqlen_q = len(image_indices)
n_images = cross_attention_states.shape[0]
seqlen_k = cross_attention_states.shape[1]
device = cross_attention_states.device
if cu_seqlen_prefill is not None:
# raise RuntimeError("TODO")
offset = 0
cu_q = []
indices = []
for index in image_indices:
cu_q.append(offset)
length = seqlen.input_lengths[index]
input_ids_offset = seqlen.cu_seqlen_q[index]
indices.extend(range(input_ids_offset, input_ids_offset + length))
offset += length
cu_q.append(offset)
cu_seqlen_q = torch.Tensor(cu_q).to(device=device, dtype=torch.int32)
cu_seqlen_k = (
torch.arange(
n_images + 1,
device=device,
dtype=torch.int32,
)
* seqlen_k
)
max_q = cu_seqlen_q[-1].item()
max_k = seqlen_k
else:
cu_seqlen_q = torch.arange(
seqlen_q + 1, device=device, dtype=torch.int32
)
seqlen_k = cross_attention_states.shape[1]
n_images = cross_attention_states.shape[0]
cu_seqlen_k = (
torch.arange(
n_images + 1,
device=device,
dtype=torch.int32,
)
* seqlen_k
)
max_q = seqlen_q
max_k = seqlen_k
indices = image_indices[:]
device = input_ids.device
cu_seqlen_q = torch.Tensor([0, seqlen_q]).to(
dtype=torch.int32, device=device
)
cu_seqlen_k = torch.Tensor([0, seqlen_k]).to(
dtype=torch.int32, device=device
)
max_q = seqlen_q
max_k = seqlen_k
cross_attention_states = (
cross_attention_states,
cu_seqlen_q,
cu_seqlen_k,
max_q,
max_k,
indices,
)
outputs = self.text_model(
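Taken together, the prefill branch above builds the ragged-batch bookkeeping a varlen attention kernel expects: cu_seqlen_q holds prefix sums of the query lengths of the image-bearing sequences, cu_seqlen_k holds prefix sums of the fixed per-image key length, and indices records which rows of the flattened hidden states those queries occupy. A self-contained sketch of the same construction (names mirror the diff; the Seqlen fields are assumed to behave as used above):

import torch

def build_cross_attention_seqlens(image_indices, input_lengths, cu_seqlen_q_text, seqlen_k, n_images, device):
    # Illustrative re-implementation of the prefill bookkeeping, not the committed code.
    offset, cu_q, indices = 0, [0], []
    for index in image_indices:
        length = int(input_lengths[index])
        start = int(cu_seqlen_q_text[index])
        indices.extend(range(start, start + length))
        offset += length
        cu_q.append(offset)
    cu_seqlen_q = torch.tensor(cu_q, device=device, dtype=torch.int32)
    cu_seqlen_k = torch.arange(n_images + 1, device=device, dtype=torch.int32) * seqlen_k
    return cu_seqlen_q, cu_seqlen_k, indices

# Two prompts of lengths 5 and 3, each attending to one image of 6 vision tokens:
cu_q, cu_k, idx = build_cross_attention_seqlens(
    image_indices=[0, 1],
    input_lengths=[5, 3],
    cu_seqlen_q_text=[0, 5],
    seqlen_k=6,
    n_images=2,
    device="cpu",
)
# cu_q -> tensor([0, 5, 8]), cu_k -> tensor([0, 6, 12]), idx -> [0, 1, ..., 7]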

View File

@ -1,7 +1,7 @@
from io import BytesIO
from PIL import Image
import torch
- from typing import Iterable
+ from typing import Iterable, List
from text_generation_server.pb.generate_pb2 import Request
from dataclasses import dataclass
@ -20,10 +20,54 @@ tracer = trace.get_tracer(__name__)
@dataclass
class MllamaCausalLMBatch(VlmCausalLMBatch):
image_indices: List[int] = 42
aspect_ratio_ids: Optional[torch.Tensor] = None
aspect_ratio_mask: Optional[torch.Tensor] = None
cross_attention_states: Optional[torch.Tensor] = None
@classmethod
@tracer.start_as_current_span("concatenate")
def concatenate(cls, batches):
batch = super().concatenate(batches)
batch.pixel_values = None
batch.pixel_attention_mask = None
offset = 0
image_indices = []
attention_states = []
for b in batches:
attention_states.append(b.cross_attention_states)
image_indices.extend([i + offset for i in b.image_indices])
offset += len(b.image_indices)
batch.cross_attention_states = torch.cat(attention_states, dim=0)
batch.image_indices = image_indices
return batch
@tracer.start_as_current_span("filter")
def filter(self, request_ids: List[int]):
assert self.image_indices is not None
batch = super().filter(request_ids)
assert self.image_indices is not None
indices = []
for i, request_id in enumerate(request_ids):
idx = self.requests_idx_mapping[request_id]
indices.append(idx)
offset = 0
new_image_indices = []
prev_i = None
for i in self.image_indices:
if i in indices:
new_image_indices.append(offset)
if i != prev_i:
offset += 1
prev_i = i
batch.image_indices = new_image_indices
batch.cross_attention_states = self.cross_attention_states[indices]
assert offset <= batch.cross_attention_states.shape[0]
return batch
@classmethod
def batch_tokenized_inputs(
cls, requests: Iterable[Request], tokenizer, processor, config
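The concatenate override added above has to reconcile per-batch numbering: each batch's image_indices point into its own cross_attention_states, so they are shifted by the number of image entries accumulated so far before the state tensors are concatenated along dim 0. A tiny worked example of that offsetting (values invented):

# Batch A carries two image entries, batch B carries one.
a_indices, b_indices = [0, 1], [0]
offset, merged = 0, []
for indices in (a_indices, b_indices):
    merged.extend(i + offset for i in indices)
    offset += len(indices)
# merged == [0, 1, 2]; batch B's image 0 now refers to row 2 of the stacked states.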
@ -115,5 +159,6 @@ class MllamaCausalLMBatch(VlmCausalLMBatch):
batch.pixel_values = None
batch.aspect_ratio_ids = None
batch.aspect_ratio_mask = None
- batch.image_indices = None
+ batch.image_indices = []
assert batch.image_indices is not None
return batch

View File

@ -141,7 +141,7 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
@classmethod
@tracer.start_as_current_span("concatenate")
def concatenate(cls, batches):
- batch = super(VlmCausalLMBatch, cls).concatenate(batches)
+ batch = super().concatenate(batches)
batch.pixel_values = None
batch.pixel_attention_mask = None
batch.image_sizes = None
@ -402,7 +402,7 @@ class VlmCausalLM(FlashCausalLM):
lm_head_indices=lm_head_indices,
cross_attention_states=cross_attention_states,
adapter_data=adapter_data,
- image_indices=batch.image_indices,
+ image_indices=batch.image_indices[:],
)
if batch.prefill_cache_indices is not None:
batch.prefill_cache_indices = None
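Passing batch.image_indices[:] rather than the list itself hands the forward pass a shallow copy, presumably so that any in-place consumption of the list inside the model cannot corrupt the batch's own bookkeeping between steps. A minimal illustration (the mutating forward below is hypothetical):

indices = [0, 1, 2]

def forward_that_consumes(image_indices):
    # Hypothetical model-side mutation of the argument.
    while image_indices:
        image_indices.pop()

forward_that_consumes(indices[:])  # copy is consumed, batch state intact: indices == [0, 1, 2]
forward_that_consumes(indices)     # batch state clobbered: indices == []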