Starting to get there.

This commit is contained in:
Nicolas Patry 2024-09-28 00:38:23 +02:00
parent 85771989d6
commit ef4fa3ea7c
No known key found for this signature in database
GPG Key ID: 64AF4752B2967863
3 changed files with 101 additions and 265 deletions

View File

@ -22,7 +22,6 @@ from torch import nn
import flash_attn_2_cuda
from transformers.activations import ACT2FN
from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS
import torch.nn.functional as F
from text_generation_server.layers import (
@ -734,6 +733,7 @@ class MllamaTextCrossAttention(nn.Module):
cu_seqlen_k,
max_q,
max_k,
indices,
) = cross_attention_states
key_states = self.k_proj(cross_attention_states)
@ -862,6 +862,8 @@ class FlashLlamaCrossLayer(torch.nn.Module):
return hidden_states, residual
if residual is not None:
hidden_states += residual
# indices = cross_attention_states[-1]
if cu_seqlen_prefill is not None:
out_hidden_states = hidden_states[:]
hidden_states = hidden_states[:]
@ -892,115 +894,6 @@ class FlashLlamaCrossLayer(torch.nn.Module):
return hidden_states, None
class MllamaTextSelfAttention(nn.Module):
def __init__(self, *, prefix, config, weights, layer_idx):
super().__init__()
self.config = config
self.num_heads = config.num_attention_heads
self.dropout = config.dropout
self.hidden_size = config.hidden_size
self.num_key_value_heads = config.num_key_value_heads
self.head_dim = config.hidden_size // self.num_heads
self.num_key_value_groups = self.num_heads // self.num_key_value_heads
self.num_heads = self.num_heads // weights.process_group.size()
self.num_key_value_heads = (
self.num_key_value_heads // weights.process_group.size()
)
self.layer_idx = layer_idx
self.qkv_proj = TensorParallelColumnLinear.load_multi(
config,
prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
dim=0,
weights=weights,
bias=False,
)
self.o_proj = TensorParallelRowLinear.load(
config,
prefix=f"{prefix}.o_proj",
weights=weights,
bias=False,
)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: torch.Tensor,
position_embeddings: torch.Tensor,
past_key_value=None,
cache_position=None,
**kwargs,
):
bsz, q_len, _ = hidden_states.size()
qkv = self.qkv_proj(hidden_states)
query_states, key_states, value_states = qkv.split(
[
self.head_dim * self.num_heads,
self.head_dim * self.num_key_value_heads,
self.head_dim * self.num_key_value_heads,
],
dim=2,
)
query_states = query_states.view(
bsz, q_len, self.num_heads, self.head_dim
).transpose(1, 2)
key_states = key_states.view(
bsz, q_len, self.num_key_value_heads, self.head_dim
).transpose(1, 2)
value_states = value_states.view(
bsz, q_len, self.num_key_value_heads, self.head_dim
).transpose(1, 2)
cos, sin = position_embeddings
query_states, key_states = apply_rotary_pos_emb(
query_states, key_states, cos, sin
)
if past_key_value is not None:
# sin and cos are specific to RoPE models; cache_position needed for the static cache
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
key_states, value_states = past_key_value.update(
key_states, value_states, self.layer_idx, cache_kwargs
)
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
causal_mask = attention_mask
if attention_mask is not None:
causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]
# SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
# Reference: https://github.com/pytorch/pytorch/issues/112577.
if query_states.device.type == "cuda" and causal_mask is not None:
query_states = query_states.contiguous()
key_states = key_states.contiguous()
value_states = value_states.contiguous()
# We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
# in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
is_causal = True if causal_mask is None and q_len > 1 else False
attn_output = torch.nn.functional.scaled_dot_product_attention(
query_states,
key_states,
value_states,
# TODO
# attn_mask=causal_mask,
dropout_p=self.dropout if self.training else 0.0,
is_causal=is_causal,
)
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.view(bsz, q_len, -1)
attn_output = self.o_proj(attn_output)
return attn_output, None, past_key_value
# Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->MllamaText
class MllamaTextRMSNorm(nn.Module):
def __init__(self, weight, eps):
@ -1026,144 +919,6 @@ class MllamaTextRMSNorm(nn.Module):
return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
# Copied from transformers.models.llama.modeling_llama.LlamaDecoderLayer with LlamaDecoder->MllamaSelfAttentionDecoder, Llama->MllamaText, LLAMA->MLLAMA_TEXT
class MllamaSelfAttentionDecoderLayer(nn.Module):
def __init__(self, *, prefix, config, weights, layer_idx):
super().__init__()
self.hidden_size = config.hidden_size
self.self_attn = MllamaTextSelfAttention(
prefix=f"{prefix}.self_attn",
config=config,
weights=weights,
layer_idx=layer_idx,
)
self.mlp = MllamaTextMLP(prefix=f"{prefix}.mlp", config=config, weights=weights)
self.input_layernorm = MllamaTextRMSNorm.load(
prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps
)
self.post_attention_layernorm = MllamaTextRMSNorm.load(
prefix=f"{prefix}.post_attention_layernorm",
weights=weights,
eps=config.rms_norm_eps,
)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value=None,
cache_position: Optional[torch.LongTensor] = None,
position_embeddings: Optional[
Tuple[torch.Tensor, torch.Tensor]
] = None, # will become mandatory in v4.45
image_indices: Optional[torch.LongTensor] = None,
**kwargs,
) -> Tuple[
torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]
]:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
# Self Attention
hidden_states, self_attn_weights, present_key_value = self.self_attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
cache_position=cache_position,
position_embeddings=position_embeddings,
**kwargs,
)
hidden_states = residual + hidden_states
# Fully Connected
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
return hidden_states
class MllamaRotaryEmbedding(nn.Module):
def __init__(
self,
*,
config,
weights,
):
super().__init__()
device = weights.device
self.rope_type = config.rope_scaling["rope_type"]
self.max_seq_len_cached = config.max_position_embeddings
self.original_max_seq_len = config.max_position_embeddings
self.config = config
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
inv_freq.to(device=device)
self.register_buffer("inv_freq", inv_freq, persistent=False)
self.original_inv_freq = self.inv_freq
def _dynamic_frequency_update(self, position_ids, device):
"""
dynamic RoPE layers should recompute `inv_freq` in the following situations:
1 - growing beyond the cached sequence length (allow scaling)
2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
"""
seq_len = torch.max(position_ids) + 1
if seq_len > self.max_seq_len_cached: # growth
inv_freq, self.attention_scaling = self.rope_init_fn(
self.config, device, seq_len=seq_len, **self.rope_kwargs
)
self.register_buffer(
"inv_freq", inv_freq, persistent=False
) # TODO joao: may break with compilation
self.max_seq_len_cached = seq_len
if (
seq_len < self.original_max_seq_len
and self.max_seq_len_cached > self.original_max_seq_len
): # reset
self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
self.max_seq_len_cached = self.original_max_seq_len
@torch.no_grad()
def forward(self, x, position_ids):
if "dynamic" in self.rope_type:
self._dynamic_frequency_update(position_ids, device=x.device)
# Core RoPE block
inv_freq_expanded = (
self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
)
position_ids_expanded = position_ids[:, None, :].float()
# Force float32 (see https://github.com/huggingface/transformers/pull/29285)
device_type = x.device.type
device_type = (
device_type
if isinstance(device_type, str) and device_type != "mps"
else "cpu"
)
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(
1, 2
)
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos()
sin = emb.sin()
# Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention
cos = cos * self.attention_scaling
sin = sin * self.attention_scaling
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
class MllamaForConditionalGeneration(nn.Module):
def __init__(self, prefix, config, weights):
super().__init__()
@ -1192,14 +947,15 @@ class MllamaForConditionalGeneration(nn.Module):
"`aspect_ratio_ids` must be provided if `pixel_values` is provided"
)
# logger.info(f"PIxel values {pixel_values.shape}")
batch_size = pixel_values.shape[0]
vision_states = self.vision_model(
pixel_values, aspect_ratio_ids, aspect_ratio_mask
)
cross_attention_states = self.multi_modal_projector(vision_states).reshape(
-1, vision_states.shape[-2], self.hidden_size
)
n, m, h = cross_attention_states.shape
cross_attention_states = cross_attention_states.view(1, n * m, h)
_, _, h = cross_attention_states.shape
cross_attention_states = cross_attention_states.view(batch_size, -1, h)
# logger.info(f"cross {cross_attention_states.shape}")
return cross_attention_states
@ -1219,7 +975,7 @@ class MllamaForConditionalGeneration(nn.Module):
cross_attention_states: Optional[torch.Tensor],
image_indices,
):
# if cross_attention_mask is not None:
# if cross_att_sention_mask is not None:
# cross_attention_mask, full_text_row_masked_out_mask = (
# _prepare_cross_attention_mask(
# cross_attention_mask,
@ -1237,24 +993,59 @@ class MllamaForConditionalGeneration(nn.Module):
# ]
if cross_attention_states is not None:
seqlen_q = input_ids.shape[0]
seqlen_q = len(image_indices)
n_images = cross_attention_states.shape[0]
seqlen_k = cross_attention_states.shape[1]
device = cross_attention_states.device
if cu_seqlen_prefill is not None:
# raise RuntimeError("TODO")
offset = 0
cu_q = []
indices = []
for index in image_indices:
cu_q.append(offset)
length = seqlen.input_lengths[index]
input_ids_offset = seqlen.cu_seqlen_q[index]
indices.extend(range(input_ids_offset, input_ids_offset + length))
offset += length
cu_q.append(offset)
cu_seqlen_q = torch.Tensor(cu_q).to(device=device, dtype=torch.int32)
cu_seqlen_k = (
torch.arange(
n_images + 1,
device=device,
dtype=torch.int32,
)
* seqlen_k
)
max_q = cu_seqlen_q[-1].item()
max_k = seqlen_k
else:
cu_seqlen_q = torch.arange(
seqlen_q + 1, device=device, dtype=torch.int32
)
seqlen_k = cross_attention_states.shape[1]
n_images = cross_attention_states.shape[0]
cu_seqlen_k = (
torch.arange(
n_images + 1,
device=device,
dtype=torch.int32,
)
* seqlen_k
)
max_q = seqlen_q
max_k = seqlen_k
indices = image_indices[:]
device = input_ids.device
cu_seqlen_q = torch.Tensor([0, seqlen_q]).to(
dtype=torch.int32, device=device
)
cu_seqlen_k = torch.Tensor([0, seqlen_k]).to(
dtype=torch.int32, device=device
)
max_q = seqlen_q
max_k = seqlen_k
cross_attention_states = (
cross_attention_states,
cu_seqlen_q,
cu_seqlen_k,
max_q,
max_k,
indices,
)
outputs = self.text_model(

View File

@ -1,7 +1,7 @@
from io import BytesIO
from PIL import Image
import torch
from typing import Iterable
from typing import Iterable, List
from text_generation_server.pb.generate_pb2 import Request
from dataclasses import dataclass
@ -20,10 +20,54 @@ tracer = trace.get_tracer(__name__)
@dataclass
class MllamaCausalLMBatch(VlmCausalLMBatch):
image_indices: List[int] = 42
aspect_ratio_ids: Optional[torch.Tensor] = None
aspect_ratio_mask: Optional[torch.Tensor] = None
cross_attention_states: Optional[torch.Tensor] = None
@classmethod
@tracer.start_as_current_span("concatenate")
def concatenate(cls, batches):
batch = super().concatenate(batches)
batch.pixel_values = None
batch.pixel_attention_mask = None
offset = 0
image_indices = []
attention_states = []
for b in batches:
attention_states.append(b.cross_attention_states)
image_indices.extend([i + offset for i in b.image_indices])
offset += len(b.image_indices)
batch.cross_attention_states = torch.cat(attention_states, dim=0)
batch.image_indices = image_indices
return batch
@tracer.start_as_current_span("filter")
def filter(self, request_ids: List[int]):
assert self.image_indices is not None
batch = super().filter(request_ids)
assert self.image_indices is not None
indices = []
for i, request_id in enumerate(request_ids):
idx = self.requests_idx_mapping[request_id]
indices.append(idx)
offset = 0
new_image_indices = []
prev_i = None
for i in self.image_indices:
if i in indices:
new_image_indices.append(offset)
if i != prev_i:
offset += 1
prev_i = i
batch.image_indices = new_image_indices
batch.cross_attention_states = self.cross_attention_states[indices]
assert offset <= batch.cross_attention_states.shape[0]
return batch
@classmethod
def batch_tokenized_inputs(
cls, requests: Iterable[Request], tokenizer, processor, config
@ -115,5 +159,6 @@ class MllamaCausalLMBatch(VlmCausalLMBatch):
batch.pixel_values = None
batch.aspect_ratio_ids = None
batch.aspect_ratio_mask = None
batch.image_indices = None
batch.image_indices = []
assert batch.image_indices is not None
return batch

View File

@ -141,7 +141,7 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
@classmethod
@tracer.start_as_current_span("concatenate")
def concatenate(cls, batches):
batch = super(VlmCausalLMBatch, cls).concatenate(batches)
batch = super().concatenate(batches)
batch.pixel_values = None
batch.pixel_attention_mask = None
batch.image_sizes = None
@ -402,7 +402,7 @@ class VlmCausalLM(FlashCausalLM):
lm_head_indices=lm_head_indices,
cross_attention_states=cross_attention_states,
adapter_data=adapter_data,
image_indices=batch.image_indices,
image_indices=batch.image_indices[:],
)
if batch.prefill_cache_indices is not None:
batch.prefill_cache_indices = None