wip not faster

OlivierDehaene 2024-01-25 15:26:51 +01:00
parent bf700e7eef
commit ef99678798
5 changed files with 147 additions and 96 deletions


@@ -1,4 +1,4 @@
-flash_att_v2_commit_cuda := 02ac572f3ffc4f402e4183aaa6824b45859d3ed3
+flash_att_v2_commit_cuda := 54e80a3829c6d2337570d01e78ebd9529c02d342
 flash_att_v2_commit_rocm := 8736558c287ff2ef28b24878e42828c595ac3e69


@@ -1,9 +1,11 @@
 import math
 import torch
+import os
 from typing import Optional, List, Tuple
-BLOCK_SIZE: int = 16
+USE_VLLM = os.getenv("USE_VLLM", "False") == "True"
+BLOCK_SIZE: int = 256 if not USE_VLLM else 16
 # Will be set in warmup
 CACHE_MANAGER: Optional["CacheManager"] = None
@@ -26,15 +28,22 @@ class CacheManager:
         element_size = torch.tensor([], dtype=dtype).element_size()
         x = self.block_size // element_size
+        if USE_VLLM:
+            k_shape = (num_blocks, num_heads, head_size // x, self.block_size, x)
+            v_shape = (num_blocks, num_heads, head_size, self.block_size)
+        else:
+            k_shape = (num_blocks, BLOCK_SIZE, num_heads, head_size)
+            v_shape = (num_blocks, BLOCK_SIZE, num_heads, head_size)
         self.kv_cache = [
             (
                 torch.empty(
-                    (num_blocks, num_heads, head_size // x, self.block_size, x),
+                    k_shape,
                     dtype=dtype,
                     device=device,
                 ),
                 torch.empty(
-                    (num_blocks, num_heads, head_size, self.block_size),
+                    v_shape,
                     dtype=dtype,
                     device=device,
                 ),
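The hunk above switches the KV-cache layout on the USE_VLLM environment toggle: the vLLM path keeps the paged layout in which the key cache is split along the head dimension into chunks of x elements, while the non-vLLM path allocates flat (num_blocks, BLOCK_SIZE, num_heads, head_size) pages that flash-attention's kv-cache kernels can index directly. A minimal standalone sketch of the two shapes; the sizes below are illustrative assumptions, not values from the repository:

    import torch

    USE_VLLM = False                                 # toggle mirrored from the hunk above
    num_blocks, num_heads, head_size = 64, 8, 128    # illustrative sizes
    dtype, device = torch.float16, "cpu"
    block_size = 16 if USE_VLLM else 256             # BLOCK_SIZE in the hunk above

    element_size = torch.tensor([], dtype=dtype).element_size()  # 2 bytes for fp16
    x = block_size // element_size                   # vLLM packs the key head dim in chunks of x

    if USE_VLLM:
        # vLLM paged layout: the key head dimension is split for coalesced reads
        k_cache = torch.empty(num_blocks, num_heads, head_size // x, block_size, x, dtype=dtype, device=device)
        v_cache = torch.empty(num_blocks, num_heads, head_size, block_size, dtype=dtype, device=device)
    else:
        # flat layout: one (num_heads, head_size) row per cache slot
        k_cache = torch.empty(num_blocks, block_size, num_heads, head_size, dtype=dtype, device=device)
        v_cache = torch.empty(num_blocks, block_size, num_heads, head_size, dtype=dtype, device=device)

    print(k_cache.shape, v_cache.shape)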


@@ -17,6 +17,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import os
 import torch
 import torch.distributed
@@ -156,7 +157,7 @@ class FlashLlamaAttention(torch.nn.Module):
             device=weights.device,
         )
-        self.softmax_scale = self.head_size**-0.5
+        self.softmax_scale = self.head_size ** -0.5
         if self.num_heads % weights.process_group.size() != 0:
             raise ValueError(
@@ -204,17 +205,24 @@ class FlashLlamaAttention(torch.nn.Module):
         query = query.view(-1, self.num_heads, self.head_size)
         kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size)
-        self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
+        use_vllm = os.getenv("USE_VLLM", "False") == "True"
+        if use_vllm:
             paged_attention.reshape_and_cache(
                 kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots
             )
+            self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
         # output tensor
         attn_output = torch.empty_like(query)
         # Prefill
         if cu_seqlen_prefill is not None:
+            if not use_vllm:
+                self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
+                kv_cache[0].view(-1, self.num_key_value_heads, self.head_size)[slots] = kv[:, 0]
+                kv_cache[1].view(-1, self.num_key_value_heads, self.head_size)[slots] = kv[:, 1]
             # flash attention
             flash_attn.attention(
                 query,
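In the prefill branch above, the non-vLLM path skips paged_attention.reshape_and_cache and scatters the fresh key/value states straight into the cache: with the flat layout a slot is simply a row index into the cache viewed as (total_slots, kv_heads, head_size). A small self-contained sketch of that write, with illustrative shapes:

    import torch

    num_blocks, block_size, kv_heads, head_size = 4, 256, 8, 64  # illustrative sizes
    k_cache = torch.zeros(num_blocks, block_size, kv_heads, head_size)
    v_cache = torch.zeros(num_blocks, block_size, kv_heads, head_size)

    num_tokens = 10
    k_new = torch.randn(num_tokens, kv_heads, head_size)
    v_new = torch.randn(num_tokens, kv_heads, head_size)
    slots = torch.arange(100, 100 + num_tokens)  # flat slot ids handed out by the cache manager

    # same pattern as `kv_cache[0].view(-1, heads, head_size)[slots] = kv[:, 0]` in the hunk above
    k_cache.view(-1, kv_heads, head_size)[slots] = k_new
    v_cache.view(-1, kv_heads, head_size)[slots] = v_new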
@@ -226,6 +234,30 @@ class FlashLlamaAttention(torch.nn.Module):
                 self.softmax_scale,
             )
         # Decode
         else:
+            if not use_vllm:
+                import flash_attn_2_cuda
+                flash_attn_2_cuda.fwd_kvcache(
+                    query.unsqueeze(1),  # q
+                    kv_cache[0],  # kcache
+                    kv_cache[1],  # vcache
+                    torch.select(kv, dim=1, index=0).unsqueeze(1),  # k
+                    torch.select(kv, dim=1, index=1).unsqueeze(1),  # v
+                    input_lengths,  # seqlens_k
+                    self.rotary_emb._cos_cached,  # rotary_cos
+                    self.rotary_emb._sin_cached,  # rotary_sin
+                    # None,None,
+                    None,  # cache_batch_idx
+                    block_tables,  # block_tables
+                    None,  # alibi_slopes
+                    attn_output.unsqueeze(1),  # out
+                    self.softmax_scale,  # softmax_scale
+                    True,  # is_causal
+                    -1,  # window_size_left
+                    0,  # window_size_right
+                    False,  # is_rotary_interleaved
+                    0,  # num_splits
+                )
+            else:
                 paged_attention.attention(
                     attn_output,
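In the decode branch above, the non-vLLM path drops reshape_and_cache plus the paged kernel and instead makes a single call into flash-attention's fused kv-cache kernel, which appends the new key/value to the cache, applies rotary embeddings to the query and the new key, and runs causal attention against everything cached so far. The raw flash_attn_2_cuda.fwd_kvcache binding used here belongs to the pinned fork, so its exact argument list may differ; upstream flash-attn exposes the same kernel as flash_attn_with_kvcache. A hedged sketch against that public wrapper, assuming flash-attn >= 2.3 (for the rotary arguments) and contiguous per-sequence caches without block tables:

    import torch
    from flash_attn import flash_attn_with_kvcache  # public wrapper over the fwd_kvcache kernel

    def decode_step(query, new_k, new_v, k_cache, v_cache, cache_seqlens, cos, sin, softmax_scale):
        # query, new_k, new_v: [batch, 1, (kv_)heads, head_size] for single-token decode
        # k_cache, v_cache:    [batch, max_seqlen, kv_heads, head_size], updated in place
        # cache_seqlens:       int32 [batch], current cached length of each sequence
        return flash_attn_with_kvcache(
            query,
            k_cache,
            v_cache,
            k=new_k,                      # appended to k_cache at position cache_seqlens
            v=new_v,
            rotary_cos=cos,               # rotary applied inside the kernel, as in the hunk above
            rotary_sin=sin,
            cache_seqlens=cache_seqlens,
            softmax_scale=softmax_scale,
            causal=True,
        )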
@@ -381,9 +413,13 @@ class FlashLlamaModel(torch.nn.Module):
         # Get rotary cos and sin for this forward
         # Avoid to index in each layer
+        use_vllm = os.getenv("USE_VLLM", "False") == "True"
+        if cu_seqlen_prefill is not None or use_vllm:
             cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(
                 position_ids, max_s, hidden_states.dtype
             )
+        else:
+            cos, sin = None, None
         residual = None
         for i, layer in enumerate(self.layers):
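The model-level hunk above only builds the rotary cos/sin tables when some Python-side code will consume them: during prefill, or on the vLLM path where rotary is applied before caching. On the plain decode path the fused kernel applies rotary itself, so the tables can stay None. A tiny sketch of that gating; the helper name and arguments are illustrative, not the repository's API:

    from typing import Optional, Tuple
    import torch

    def maybe_get_rotary_tables(
        rotary_emb,                 # object exposing get_cos_sin(position_ids, max_s, dtype)
        position_ids: torch.Tensor,
        max_s: int,
        dtype: torch.dtype,
        prefilling: bool,
        use_vllm: bool,
    ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
        if prefilling or use_vllm:
            # rotary is applied in Python before attention / before caching
            return rotary_emb.get_cos_sin(position_ids, max_s, dtype)
        # decode without vLLM: fwd_kvcache applies rotary inside the kernel
        return None, None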


@@ -127,6 +127,8 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
         )
     async def Decode(self, request, context):
+        from torch.profiler import profile, ProfilerActivity
         start = time.time_ns()
         if len(request.batches) == 0:
             raise ValueError("Must provide at least one batch")
@@ -149,7 +151,9 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
             batch = batches[0]
             concat_ns = None
+        with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prefill_prof:
             generations, next_batch, timings = self.model.generate_token(batch)
+        prefill_prof.export_chrome_trace("new_decode.json")
         self.cache.set(next_batch)
         return generate_pb2.DecodeResponse(
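The server hunk above wraps the decode step in torch.profiler and dumps a Chrome trace (new_decode.json is the file name from the diff). A standalone sketch of the same pattern that can run outside the server; the profiled function is a stand-in, not the repository's generate_token:

    import torch
    from torch.profiler import profile, ProfilerActivity

    def fake_decode_step():
        # stand-in for self.model.generate_token(batch)
        device = "cuda" if torch.cuda.is_available() else "cpu"
        a = torch.randn(1024, 1024, device=device)
        return a @ a

    activities = [ProfilerActivity.CPU]
    if torch.cuda.is_available():
        activities.append(ProfilerActivity.CUDA)

    with profile(activities=activities) as prof:
        fake_decode_step()

    # open the trace in chrome://tracing or https://ui.perfetto.dev
    prof.export_chrome_trace("new_decode.json")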


@@ -82,13 +82,15 @@ def attention(
     if HAS_FLASH_ATTN_V2_CUDA:
         return flash_attn_2_cuda.varlen_fwd(
-            q,
-            k,
-            v,
-            out,
-            cu_seqlens,
-            cu_seqlens,
-            max_s,
+            q,  # q
+            k,  # k
+            v,  # v
+            out,  # out
+            cu_seqlens,  # cu_seqlens_q
+            cu_seqlens,  # cu_seqlens_k
+            None,
+            None,
+            max_s,  #
             max_s,
             0.0,
             softmax_scale,
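The last hunk annotates the raw varlen_fwd arguments and inserts two additional positional None arguments, apparently to match the argument list of the pinned flash-attention commit (the raw binding changes between versions). For reference, upstream flash-attn fronts the same kernel with flash_attn_varlen_func; a hedged sketch of the equivalent prefill call through that wrapper, assuming flash-attn 2.x:

    import torch
    from flash_attn import flash_attn_varlen_func  # public wrapper over varlen_fwd

    def prefill_attention(q, k, v, cu_seqlens, max_s, softmax_scale):
        # q, k, v: [total_tokens, (kv_)heads, head_size], all sequences packed together
        # cu_seqlens: int32 [batch + 1] cumulative sequence lengths, shared by q and k here
        return flash_attn_varlen_func(
            q,
            k,
            v,
            cu_seqlens_q=cu_seqlens,
            cu_seqlens_k=cu_seqlens,
            max_seqlen_q=max_s,
            max_seqlen_k=max_s,
            dropout_p=0.0,
            softmax_scale=softmax_scale,
            causal=True,
        )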