wip not faster

OlivierDehaene 2024-01-25 15:26:51 +01:00
parent bf700e7eef
commit ef99678798
5 changed files with 147 additions and 96 deletions

File 1 of 5

@@ -1,4 +1,4 @@
-flash_att_v2_commit_cuda := 02ac572f3ffc4f402e4183aaa6824b45859d3ed3
+flash_att_v2_commit_cuda := 54e80a3829c6d2337570d01e78ebd9529c02d342
 flash_att_v2_commit_rocm := 8736558c287ff2ef28b24878e42828c595ac3e69

File 2 of 5

@@ -1,9 +1,11 @@
 import math
 import torch
+import os

 from typing import Optional, List, Tuple

-BLOCK_SIZE: int = 16
+USE_VLLM = os.getenv("USE_VLLM", "False") == "True"
+BLOCK_SIZE: int = 256 if not USE_VLLM else 16

 # Will be set in warmup
 CACHE_MANAGER: Optional["CacheManager"] = None
@@ -26,15 +28,22 @@ class CacheManager:
         element_size = torch.tensor([], dtype=dtype).element_size()
         x = self.block_size // element_size

+        if USE_VLLM:
+            k_shape = (num_blocks, num_heads, head_size // x, self.block_size, x)
+            v_shape = (num_blocks, num_heads, head_size, self.block_size)
+        else:
+            k_shape = (num_blocks, BLOCK_SIZE, num_heads, head_size)
+            v_shape = (num_blocks, BLOCK_SIZE, num_heads, head_size)
+
         self.kv_cache = [
             (
                 torch.empty(
-                    (num_blocks, num_heads, head_size // x, self.block_size, x),
+                    k_shape,
                     dtype=dtype,
                     device=device,
                 ),
                 torch.empty(
-                    (num_blocks, num_heads, head_size, self.block_size),
+                    v_shape,
                     dtype=dtype,
                     device=device,
                 ),
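Note on the cache layout change above: with USE_VLLM unset, the cache manager drops vLLM's packed key layout (16-token blocks with an extra packing axis x) in favor of a flat, token-major layout with 256-token blocks, which appears to match the paged KV-cache layout consumed by the flash-attention decode kernel used later in this commit. A rough comparison of the two layouts; the model dimensions below are illustrative assumptions, not values from the commit:

import torch

# Hypothetical Llama-7B-like dimensions, for illustration only.
num_blocks, num_heads, head_size = 1024, 32, 128
dtype = torch.float16
element_size = torch.tensor([], dtype=dtype).element_size()  # 2 bytes in fp16

# vLLM layout (USE_VLLM=True): 16-token blocks, keys packed along an extra axis `x`.
block_size = 16
x = block_size // element_size  # 8
k_vllm = (num_blocks, num_heads, head_size // x, block_size, x)
v_vllm = (num_blocks, num_heads, head_size, block_size)

# Flat layout (USE_VLLM unset): 256-token blocks, identical token-major shape for K and V.
block_size = 256
k_flat = v_flat = (num_blocks, block_size, num_heads, head_size)

# Per cached token the cost is identical (num_heads * head_size elements each for K and V);
# only the block granularity and the axis order differ.
print(k_vllm, v_vllm, k_flat, num_heads * head_size * element_size)  # 8192 bytes per token per tensor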

File 3 of 5

@@ -17,6 +17,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import os

 import torch
 import torch.distributed
@@ -40,26 +41,26 @@ from text_generation_server.utils.layers import (
 class LlamaConfig(PretrainedConfig):
     def __init__(
         self,
         vocab_size=32000,
         hidden_size=4096,
         intermediate_size=11008,
         num_hidden_layers=32,
         num_attention_heads=32,
         num_key_value_heads=None,
         hidden_act="silu",
         max_position_embeddings=2048,
         initializer_range=0.02,
         rms_norm_eps=1e-6,
         use_cache=True,
         pad_token_id=0,
         bos_token_id=1,
         eos_token_id=2,
         pretraining_tp=1,
         tie_word_embeddings=False,
         rope_scaling=None,
         rope_theta=10000.0,
         **kwargs,
     ):
         self.vocab_size = vocab_size
         self.max_position_embeddings = max_position_embeddings
@@ -139,10 +140,10 @@ def _load_gqa(config, prefix: str, weights):
 class FlashLlamaAttention(torch.nn.Module):
     def __init__(
         self,
         prefix: str,
         config,
         weights,
     ):
         super().__init__()
         self.num_heads = config.num_attention_heads
@@ -156,7 +157,7 @@ class FlashLlamaAttention(torch.nn.Module):
             device=weights.device,
         )
-        self.softmax_scale = self.head_size**-0.5
+        self.softmax_scale = self.head_size ** -0.5

         if self.num_heads % weights.process_group.size() != 0:
             raise ValueError(
@@ -165,7 +166,7 @@ class FlashLlamaAttention(torch.nn.Module):
             )
         self.num_heads = self.num_heads // weights.process_group.size()
         self.num_key_value_heads = (
             config.num_key_value_heads // weights.process_group.size()
         )
         self.query_key_value = load_attention(config, prefix, weights)
@@ -182,16 +183,16 @@ class FlashLlamaAttention(torch.nn.Module):
         ).repeat_interleave(self.num_groups)

     def forward(
         self,
         hidden_states,
         cos,
         sin,
         cu_seqlen_prefill,
         kv_cache,
         block_tables,
         slots,
         input_lengths,
         max_s,
     ):
         qkv = self.query_key_value(hidden_states)
         query, kv = qkv.split(
@@ -204,17 +205,24 @@ class FlashLlamaAttention(torch.nn.Module):
         query = query.view(-1, self.num_heads, self.head_size)
         kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size)

-        self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
-        paged_attention.reshape_and_cache(
-            kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots
-        )
+        use_vllm = os.getenv("USE_VLLM", "False") == "True"
+        if use_vllm:
+            paged_attention.reshape_and_cache(
+                kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots
+            )
+            self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)

         # output tensor
         attn_output = torch.empty_like(query)

         # Prefill
         if cu_seqlen_prefill is not None:
+            if not use_vllm:
+                self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
+                kv_cache[0].view(-1, self.num_key_value_heads, self.head_size)[slots] = kv[:, 0]
+                kv_cache[1].view(-1, self.num_key_value_heads, self.head_size)[slots] = kv[:, 1]
             # flash attention
             flash_attn.attention(
                 query,
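A note on the new prefill path above: because the flat layout stores each block's tokens contiguously, viewing the cache as (num_blocks * BLOCK_SIZE, num_heads, head_size) lets prefill scatter the fresh keys and values with plain indexing instead of the reshape_and_cache kernel. A minimal sketch of that equivalence, assuming, as the write above implies, that each entry of slots is a flat index of the form block_id * BLOCK_SIZE + offset:

import torch

# Toy sizes, for illustration only.
num_blocks, BLOCK_SIZE, num_heads, head_size = 4, 256, 8, 64
k_cache = torch.zeros(num_blocks, BLOCK_SIZE, num_heads, head_size)

# A token placed in block `b` at offset `o` occupies flat slot b * BLOCK_SIZE + o.
b, o = 2, 17
slot = b * BLOCK_SIZE + o
new_k = torch.randn(num_heads, head_size)

# The write used in the prefill branch above...
k_cache.view(-1, num_heads, head_size)[slot] = new_k
# ...lands exactly on the (block, offset) cell of the paged cache.
assert torch.equal(k_cache[b, o], new_k)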
@@ -227,17 +235,41 @@ class FlashLlamaAttention(torch.nn.Module):
             )
         # Decode
         else:
-            paged_attention.attention(
-                attn_output,
-                query,
-                kv_cache[0],
-                kv_cache[1],
-                self.kv_head_mapping,
-                self.softmax_scale,
-                block_tables,
-                input_lengths,
-                max_s,
-            )
+            if not use_vllm:
+                import flash_attn_2_cuda
+                flash_attn_2_cuda.fwd_kvcache(
+                    query.unsqueeze(1),  # q
+                    kv_cache[0],  # kcache
+                    kv_cache[1],  # vcache
+                    torch.select(kv, dim=1, index=0).unsqueeze(1),  # k
+                    torch.select(kv, dim=1, index=1).unsqueeze(1),  # v
+                    input_lengths,  # seqlens_k
+                    self.rotary_emb._cos_cached,  # rotary_cos
+                    self.rotary_emb._sin_cached,  # rotary_sin
+                    # None,None,
+                    None,  # cache_batch_idx
+                    block_tables,  # block_tables
+                    None,  # alibi_slopes
+                    attn_output.unsqueeze(1),  # out
+                    self.softmax_scale,  # softmax_scale
+                    True,  # is_causal
+                    -1,  # window_size_left
+                    0,  # window_size_right
+                    False,  # is_rotary_interleaved
+                    0,  # num_splits
+                )
+            else:
+                paged_attention.attention(
+                    attn_output,
+                    query,
+                    kv_cache[0],
+                    kv_cache[1],
+                    self.kv_head_mapping,
+                    self.softmax_scale,
+                    block_tables,
+                    input_lengths,
+                    max_s,
+                )

         return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))
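The decode branch above hands three jobs to a single fused kernel when USE_VLLM is off: judging by the argument comments, flash_attn_2_cuda.fwd_kvcache applies the rotary embedding to the query and new key (rotary_cos / rotary_sin), appends the current token's K/V to the paged cache addressed by block_tables, and computes single-query attention over everything cached so far, which is why the separate rotary_emb and reshape_and_cache calls are skipped on this path. As a reading aid, here is a slow pure-PyTorch reference of what that fused decode step computes for one sequence (conceptual sketch only; it ignores rotary application and grouped-query head mapping):

import math
import torch

def decode_attention_reference(q, k_cache, v_cache, block_table, seqlen_k, softmax_scale):
    # q: (num_heads, head_size); k_cache/v_cache: (num_blocks, BLOCK_SIZE, num_heads, head_size)
    # block_table: int tensor of block ids for this sequence, in logical order
    # seqlen_k: number of tokens already cached (including the one just appended)
    block_size = k_cache.shape[1]
    num_pages = math.ceil(seqlen_k / block_size)

    # Gather this sequence's pages and trim to its true length.
    k = k_cache[block_table[:num_pages]].reshape(-1, *k_cache.shape[2:])[:seqlen_k]
    v = v_cache[block_table[:num_pages]].reshape(-1, *v_cache.shape[2:])[:seqlen_k]

    # Single-query attention, one softmax per head.
    scores = torch.einsum("hd,shd->hs", q, k) * softmax_scale
    probs = torch.softmax(scores, dim=-1)
    return torch.einsum("hs,shd->hd", probs, v)  # (num_heads, head_size)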
@@ -271,7 +303,7 @@ class LlamaMLP(nn.Module):
             bias=False,
         )
         self.intermediate_size = (
             config.intermediate_size // weights.process_group.size()
         )

     def forward(self, hidden_states):
@@ -299,17 +331,17 @@ class FlashLlamaLayer(nn.Module):
         )

     def forward(
         self,
         hidden_states,
         residual,
         cos,
         sin,
         cu_seqlen_prefill,
         kv_cache,
         block_tables,
         slots,
         input_lengths,
         max_s,
     ):
         normed_hidden_states, res = self.input_layernorm(hidden_states, residual)
@@ -367,23 +399,27 @@ class FlashLlamaModel(torch.nn.Module):
         self.num_key_value_heads = self.layers[0].self_attn.num_key_value_heads

     def forward(
         self,
         input_ids: torch.Tensor,
         position_ids: torch.Tensor,
         cu_seqlen_prefill: Optional[torch.Tensor],
         kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
         block_tables: torch.Tensor,
         slots: torch.Tensor,
         input_lengths: torch.Tensor,
         max_s: int,
     ) -> torch.Tensor:
         hidden_states = self.embed_tokens(input_ids)

         # Get rotary cos and sin for this forward
         # Avoid to index in each layer
-        cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(
-            position_ids, max_s, hidden_states.dtype
-        )
+        use_vllm = os.getenv("USE_VLLM", "False") == "True"
+        if cu_seqlen_prefill is not None or use_vllm:
+            cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(
+                position_ids, max_s, hidden_states.dtype
+            )
+        else:
+            cos, sin = None, None

         residual = None
         for i, layer in enumerate(self.layers):

File 4 of 5

@@ -127,6 +127,8 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
         )

     async def Decode(self, request, context):
+        from torch.profiler import profile, ProfilerActivity
+
         start = time.time_ns()
         if len(request.batches) == 0:
             raise ValueError("Must provide at least one batch")
@@ -149,7 +151,9 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
             batch = batches[0]
             concat_ns = None

-        generations, next_batch, timings = self.model.generate_token(batch)
+        with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prefill_prof:
+            generations, next_batch, timings = self.model.generate_token(batch)
+        prefill_prof.export_chrome_trace("new_decode.json")
         self.cache.set(next_batch)

         return generate_pb2.DecodeResponse(
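The Decode handler is temporarily wrapped in a torch.profiler context that re-exports new_decode.json on every call; this is WIP profiling scaffolding rather than something to ship, since tracing each decode step adds overhead and the trace file is overwritten each time. The resulting JSON can be opened in chrome://tracing or Perfetto, and the same profile object can also print an aggregated kernel summary, e.g. (sketch mirroring the handler above, not part of the commit):

from torch.profiler import profile, ProfilerActivity

with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
    generations, next_batch, timings = self.model.generate_token(batch)

# Chrome trace for chrome://tracing / Perfetto, plus a table of the hottest CUDA ops.
prof.export_chrome_trace("new_decode.json")
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=20))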

File 5 of 5

@@ -57,7 +57,7 @@ except ImportError as e:
     elif IS_ROCM_SYSTEM:
         for idx in range(torch.cuda.device_count()):
             if "MI210" not in torch.cuda.get_device_name(
                 idx
             ) and "MI250" not in torch.cuda.get_device_name(idx):
                 raise ImportError(
                     f"AMD GPU {torch.cuda.get_device_name(idx)} does not support flash-attention"
@@ -68,27 +68,29 @@ except ImportError as e:
 def attention(
     q,
     k,
     v,
     out,
     cu_seqlens,
     max_s,
     softmax_scale,
     window_size_left=-1,
 ):
     if window_size_left <= 0 and window_size_left != -1:
         raise ValueError("`window_size_left` must be > 0 or -1")

     if HAS_FLASH_ATTN_V2_CUDA:
         return flash_attn_2_cuda.varlen_fwd(
-            q,
-            k,
-            v,
-            out,
-            cu_seqlens,
-            cu_seqlens,
-            max_s,
+            q,  # q
+            k,  # k
+            v,  # v
+            out,  # out
+            cu_seqlens,  # cu_seqlens_q
+            cu_seqlens,  # cu_seqlens_k
+            None,
+            None,
+            max_s,  #
             max_s,
             0.0,
             softmax_scale,