Mirror of https://github.com/huggingface/text-generation-inference.git
(synced 2025-09-10 11:54:52 +00:00)

commit bfd6928c3e (parent c9e7471742)

    working
--- a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py
@@ -136,7 +136,7 @@ class FlashRWAttention(torch.nn.Module):
         end_seq_q,
         max_s,
         layer_past,
-        layer_past_present_indices,
+        past_present_indices,
         prefill,
     ):
         qkv = self.query_key_value(hidden_states)
@@ -153,12 +153,12 @@ class FlashRWAttention(torch.nn.Module):
 
         # Inplace rotary
         self.rotary_emb(query, cos, sin)
-        self.rotary_emb(kv[:, 0], cos, sin)
+        self.rotary_emb(torch.select(kv, dim=1, index=0), cos, sin)
 
         # Prefill
         if prefill:
             # Copy to layer past
-            layer_past[layer_past_present_indices] = kv
+            layer_past[past_present_indices] = kv
             # Expand to query shape
             kv = kv.expand(-1, 2, self.num_heads, self.head_size)
 
@@ -167,8 +167,8 @@ class FlashRWAttention(torch.nn.Module):
             # flash attention
             flash_attn_cuda_modif.fwd(
                 query,
-                kv[:, 0],
-                kv[:, 1],
+                torch.select(kv, dim=1, index=0),
+                torch.select(kv, dim=1, index=1),
                 attn_output,
                 start_seq,
                 end_seq,
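Note on the kv[:, 0] -> torch.select(kv, dim=1, index=0) substitutions that recur through this commit: both forms return a view of the same storage without copying, so this reads as an interface cleanup rather than a semantic change; the commit itself does not state a motivation. A small sketch, not part of the commit, with illustrative shapes:

    import torch

    # Illustrative kv slab: (tokens, 2, num_heads, head_size)
    kv = torch.randn(16, 2, 4, 64)

    key_a = kv[:, 0]                          # basic slicing
    key_b = torch.select(kv, dim=1, index=0)  # explicit select

    assert key_a.shape == key_b.shape == (16, 4, 64)
    assert torch.equal(key_a, key_b)
    # Both are views: no data is copied
    assert key_b.data_ptr() == kv.data_ptr()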
@@ -187,7 +187,7 @@ class FlashRWAttention(torch.nn.Module):
         # Decode
         else:
             # Add present to the layer_past tensor at the correct indices
-            layer_past[layer_past_present_indices] = kv
+            layer_past[past_present_indices] = kv
             # Expand to query shape
             kv = layer_past.expand(-1, 2, self.num_heads, self.head_size)
 
@@ -196,8 +196,8 @@ class FlashRWAttention(torch.nn.Module):
             # flash attention
             flash_attn_cuda_modif.fwd(
                 query,
-                kv[:, 0],
-                kv[:, 1],
+                torch.select(kv, dim=1, index=0),
+                torch.select(kv, dim=1, index=1),
                 attn_output,
                 start_seq_q,
                 end_seq_q,
@@ -271,7 +271,7 @@ class FlashRWLargeAttention(torch.nn.Module):
         cu_seqlens,
         max_s,
         layer_past,
-        layer_past_present_indices,
+        past_present_indices,
         cu_seqlens_q,
     ):
         qkv = self.query_key_value(hidden_states)
@@ -290,7 +290,7 @@ class FlashRWLargeAttention(torch.nn.Module):
         self.rotary_emb(kv[:, :, 0], cos, sin)
 
         # Prefill
-        if layer_past_present_indices is None:
+        if past_present_indices is None:
             # Copy to layer past
             layer_past[...] = kv
             # Expand to query shape
@@ -323,7 +323,7 @@ class FlashRWLargeAttention(torch.nn.Module):
         # Decode
         else:
             # Add present to the layer_past tensor at the correct indices
-            layer_past[layer_past_present_indices] = kv
+            layer_past[past_present_indices] = kv
             # Expand to query shape
             kv = (
                 layer_past.unsqueeze(2)
@@ -430,7 +430,7 @@ class FlashRWLayer(nn.Module):
         end_seq_q,
         max_s,
         layer_past,
-        layer_past_present_indices,
+        past_present_indices,
         prefill,
     ):
         if self.parallel_attn:
@@ -446,7 +446,7 @@ class FlashRWLayer(nn.Module):
                 end_seq_q,
                 max_s,
                 layer_past,
-                layer_past_present_indices,
+                past_present_indices,
                 prefill,
             )
 
@@ -469,7 +469,7 @@ class FlashRWLayer(nn.Module):
                 end_seq_q,
                 max_s,
                 layer_past,
-                layer_past_present_indices,
+                past_present_indices,
                 prefill,
             )
 
@@ -517,7 +517,7 @@ class FlashRWLargeLayer(nn.Module):
         cu_seqlens,
         max_s,
         layer_past,
-        layer_past_present_indices,
+        past_present_indices,
         cu_seqlens_q,
     ):
         ln_attn, residual = self.ln_attn(hidden_states, residual)
@@ -531,7 +531,7 @@ class FlashRWLargeLayer(nn.Module):
             cu_seqlens,
             max_s,
             layer_past,
-            layer_past_present_indices,
+            past_present_indices,
             cu_seqlens_q,
         )
 
@@ -619,8 +619,8 @@ class FlashRWModel(FlashRWPreTrainedModel):
             # Create past tensor
             past_key_values = hidden_states.new_zeros(
                 (
-                    len(self.h),
                     pre_allocate_past_size,
+                    len(self.h),
                     *self.cache_size,
                 )
             )
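Note on the dimension swap above: together with past_key_values[i] -> past_key_values[:, i] in the next hunk and past_key_values[past_indices] later in flash_causal_lm.py, the cache appears to move from a layer-major layout (num_layers, num_tokens, ...) to a token-major one (num_tokens, num_layers, ...), so that a single index on dim 0 addresses one token slot across every layer at once. A sketch, not from the commit, with illustrative sizes:

    import torch

    num_layers, num_tokens = 4, 128
    cache_dims = (2, 8, 64)  # kv pair, heads, head size

    # Token-major layout introduced here
    cache = torch.zeros(num_tokens, num_layers, *cache_dims)

    # Keep a subset of token slots (e.g. when filtering finished requests):
    # one advanced index on dim 0, no per-layer loop
    keep = torch.tensor([0, 1, 5, 9])
    assert cache[keep].shape == (4, num_layers, *cache_dims)

    # Per-layer view handed to layer i (matches past_key_values[:, i])
    layer_view = torch.select(cache, dim=1, index=2)
    assert layer_view.shape == (num_tokens, *cache_dims)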
@@ -646,7 +646,7 @@ class FlashRWModel(FlashRWPreTrainedModel):
                 start_seq_q,
                 end_seq_q,
                 max_s,
-                past_key_values[i],
+                past_key_values[:, i],
                 past_present_indices,
                 prefill,
             )
--- a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py
@@ -7,6 +7,7 @@ from typing import Optional
 
 # Flash attention imports
 import flash_attn_cuda_modif
+
 from text_generation_server.utils.layers import (
     TensorParallelRowLinear,
     TensorParallelColumnLinear,
@@ -148,11 +149,14 @@ class FlashMQAttention(torch.nn.Module):
     def forward(
         self,
         hidden_states,
-        cu_seqlens,
+        start_seq,
+        end_seq,
+        start_seq_q,
+        end_seq_q,
         max_s,
         layer_past,
-        layer_past_present_indices,
-        cu_seqlens_q,
+        past_present_indices,
+        prefill,
     ):
         qkv = self.c_attn(hidden_states)
 
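Note on replacing cu_seqlens with start_seq/end_seq: a single cumulative-lengths array forces sequences to be densely packed (sequence i spans cu_seqlens[i]..cu_seqlens[i+1]), whereas explicit starts and ends allow gaps between sequences, which the pre-allocated decode slots in this commit rely on. A sketch of the relationship for densely packed tokens, not from the commit:

    import torch

    lengths = torch.tensor([5, 3, 7], dtype=torch.int32)

    # Old interface: one cumulative array of length batch + 1
    cu_seqlens = torch.zeros(len(lengths) + 1, dtype=torch.int32)
    cu_seqlens[1:] = torch.cumsum(lengths, 0)
    # tensor([ 0,  5,  8, 15], dtype=torch.int32)

    # New interface: equivalent when dense, but starts may also be spaced
    # out so each request can own extra slots for the tokens it will decode
    start_seq = cu_seqlens[:-1]
    end_seq = cu_seqlens[1:]
    assert torch.equal(end_seq - start_seq, lengths)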
@@ -166,9 +170,9 @@ class FlashMQAttention(torch.nn.Module):
         key_value = key_value.view(-1, 2, 1, self.head_size)
 
         # Prefill
-        if layer_past_present_indices is None:
+        if prefill:
             # Copy to layer past
-            layer_past[...] = key_value
+            layer_past[past_present_indices] = key_value
             # Expand from 1 to num_heads
             key_value = key_value.expand(-1, 2, self.num_heads, self.head_size)
 
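Note: with prefill now writing layer_past[past_present_indices] = key_value instead of layer_past[...] = key_value, both branches share one write path, an index_put_ along dim 0, with many indices per request in prefill and one per request in decode. A toy sketch, not from the commit:

    import torch

    # Token-major per-layer cache: (slots, 2, 1, head_size)
    layer_past = torch.zeros(16, 2, 1, 8)

    # Prefill: all prompt positions of a request at once (slots 0..4)
    past_present_indices = torch.arange(0, 5)
    layer_past[past_present_indices] = torch.randn(5, 2, 1, 8)

    # Decode: one new token per request, each at its next free slot
    past_present_indices = torch.tensor([5, 12])
    layer_past[past_present_indices] = torch.randn(2, 2, 1, 8)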
@@ -177,11 +181,13 @@ class FlashMQAttention(torch.nn.Module):
             # flash attention
             flash_attn_cuda_modif.fwd(
                 query,
-                key_value[:, 0],
-                key_value[:, 1],
+                torch.select(key_value, dim=1, index=0),
+                torch.select(key_value, dim=1, index=1),
                 attn_output,
-                cu_seqlens,
-                cu_seqlens,
+                start_seq,
+                end_seq,
+                start_seq,
+                end_seq,
                 max_s,
                 max_s,
                 0.0,
@@ -195,7 +201,7 @@ class FlashMQAttention(torch.nn.Module):
         # Decode
         else:
             # Add present to the layer_past tensor at the correct indices
-            layer_past[layer_past_present_indices] = key_value
+            layer_past[past_present_indices] = key_value
             # Expand from 1 to num_heads
             key_value = layer_past.expand(-1, 2, self.num_heads, self.head_size)
 
@@ -204,11 +210,13 @@ class FlashMQAttention(torch.nn.Module):
             # flash attention
             flash_attn_cuda_modif.fwd(
                 query,
-                key_value[:, 0],
-                key_value[:, 1],
+                torch.select(key_value, dim=1, index=0),
+                torch.select(key_value, dim=1, index=1),
                 attn_output,
-                cu_seqlens_q,
-                cu_seqlens,
+                start_seq_q,
+                end_seq_q,
+                start_seq,
+                end_seq,
                 1,
                 max_s,
                 0.0,
@@ -277,21 +285,27 @@ class Block(nn.Module):
         self,
         hidden_states,
         residual,
-        cu_seqlens,
+        start_seq,
+        end_seq,
+        start_seq_q,
+        end_seq_q,
         max_s,
         layer_past,
-        layer_past_present_indices,
-        cu_seqlens_q,
+        past_present_indices,
+        prefill,
     ):
         hidden_states, residual = self.ln_1(hidden_states, residual)
 
         hidden_states = self.attn(
             hidden_states,
-            cu_seqlens,
+            start_seq,
+            end_seq,
+            start_seq_q,
+            end_seq_q,
             max_s,
             layer_past,
-            layer_past_present_indices,
-            cu_seqlens_q,
+            past_present_indices,
+            prefill,
         )
 
         hidden_states, residual = self.ln_2(hidden_states, residual)
@@ -339,10 +353,13 @@ class FlashSantacoderModel(nn.Module):
         self,
         input_ids,
         position_ids,
-        cu_seqlens,
-        cu_seqlens_q,
+        start_seq,
+        end_seq,
+        start_seq_q,
+        end_seq_q,
         max_s,
-        past_key_values: Optional[torch.Tensor] = None,
+        past_present_indices,
+        past_key_values=None,
         pre_allocate_past_size: Optional[int] = None,
     ):
         hidden_states = self.wte(input_ids) + self.wpe(position_ids)
@@ -350,43 +367,37 @@ class FlashSantacoderModel(nn.Module):
 
         # Prefill
         if past_key_values is None:
+            assert pre_allocate_past_size is not None
+
+            prefill = True
+
             # Create past tensor
-            past_key_values = hidden_states.new_empty(
+            past_key_values = hidden_states.new_zeros(
                 (
+                    pre_allocate_past_size,
                     len(self.h),
-                    len(hidden_states)
-                    if pre_allocate_past_size is None
-                    else pre_allocate_past_size,
                     2,
                     1,
-                    self.head_size,
+                    self.head_size
                 )
             )
-            layer_past_present_indices = None
-            slice_past_index = len(hidden_states)
         # Decode
         else:
-            # Create indices from cumulative sequence lengths
-            layer_past_present_indices = cu_seqlens[1:] - 1
-            slice_past_index = None
+            prefill = False
 
         residual = None
         for i, layer in enumerate(self.h):
-            # We added padding that we now need to slice
-            layer_past_key_values = (
-                past_key_values[i]
-                if slice_past_index is None
-                else past_key_values[i, :slice_past_index]
-            )
-
             hidden_states, residual = layer(
                 hidden_states,
                 residual,
-                cu_seqlens,
+                start_seq,
+                end_seq,
+                start_seq_q,
+                end_seq_q,
                 max_s,
-                layer_past_key_values,
-                layer_past_present_indices,
-                cu_seqlens_q,
+                torch.select(past_key_values, dim=1, index=i),
+                past_present_indices,
+                prefill,
             )
 
         hidden_states, _ = self.ln_f(hidden_states, residual)
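Note on the new assert: the model can no longer fall back to len(hidden_states) for the cache size, so the caller must always pass pre_allocate_past_size during prefill. Judging from the batch code in this commit (each request reserves input_length + max_new_tokens - 1 slots), the expected value is presumably the sum of those reservations; a hypothetical sizing helper, not present in the commit:

    # Hypothetical helper: total slots if each request reserves room for its
    # prompt plus every token it may still generate (the last needs no slot)
    def pre_allocate_past_size(requests):
        return sum(
            input_length + max_new_tokens - 1
            for input_length, max_new_tokens in requests
        )

    assert pre_allocate_past_size([(5, 4), (3, 2)]) == 12  # 8 + 4 slots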
@@ -406,9 +417,12 @@ class FlashSantacoderForCausalLM(nn.Module):
         self,
         input_ids,
         position_ids,
-        cu_seqlens,
-        cu_seqlens_q,
+        start_seq,
+        end_seq,
+        start_seq_q,
+        end_seq_q,
         max_s,
+        past_present_indices,
         past_key_values: Optional[torch.Tensor] = None,
         pre_allocate_past_size: Optional[int] = None,
         lm_head_indices: Optional[torch.Tensor] = None,
@@ -416,9 +430,12 @@ class FlashSantacoderForCausalLM(nn.Module):
         hidden_states, present = self.transformer(
             input_ids,
             position_ids,
-            cu_seqlens,
-            cu_seqlens_q,
+            start_seq,
+            end_seq,
+            start_seq_q,
+            end_seq_q,
             max_s,
+            past_present_indices,
             past_key_values,
             pre_allocate_past_size,
         )
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -186,8 +186,7 @@ class FlashCausalLMBatch(Batch):
                 prefill_cu_outlens.append(prefill_out_cumulative_length + 1)
                 prefill_out_cumulative_length += 1
 
-            request_past_present_indices = np.zeros(input_length + max_new_tokens - 1)
-            request_past_present_indices[:input_length] = 1
+            request_past_present_indices = torch.arange(cumulative_max_length, cumulative_max_length + input_length, dtype=torch.int64)
             past_present_indices.append(request_past_present_indices)
 
             # Update
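Note: past_present_indices changes meaning here, from a per-request boolean mask over input_length + max_new_tokens - 1 positions to absolute int64 slot numbers in the shared cache, each request owning a contiguous slab starting at cumulative_max_length. A simplified sketch of the new indexing, not from the commit (torch-only where the real code mixes numpy and torch):

    import torch

    requests = [(5, 4), (3, 2)]  # (input_length, max_new_tokens)

    cumulative_max_length = 0
    past_present_indices = []
    for input_length, max_new_tokens in requests:
        # Prompt tokens map to the first input_length slots of the slab
        past_present_indices.append(
            torch.arange(
                cumulative_max_length,
                cumulative_max_length + input_length,
                dtype=torch.int64,
            )
        )
        cumulative_max_length += input_length + max_new_tokens - 1

    flat = torch.cat(past_present_indices)
    # tensor([0, 1, 2, 3, 4, 8, 9, 10]): the second request starts at slot 8,
    # after the first request's 5 + 4 - 1 = 8 reserved slots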
@@ -210,10 +209,20 @@ class FlashCausalLMBatch(Batch):
         if len(pb.requests) > 1:
             input_ids = np.concatenate(all_input_ids, dtype=np.int64)
             position_ids = torch.cat(position_ids)
+
+            past_present_indices = np.concatenate(past_present_indices, dtype=np.int64)
+
+            start_seq_prefill = torch.tensor(start_seq_prefill, device=device, dtype=torch.int32)
+            end_seq_prefill = torch.tensor(end_seq_prefill, device=device, dtype=torch.int32)
         else:
             input_ids = all_input_ids[0]
             position_ids = position_ids[0]
+
+            past_present_indices = past_present_indices[0]
+
+            start_seq_prefill = start_seq
+            end_seq_prefill = end_seq
 
         # Create tensors on device
         input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device)
         all_input_ids_tensor = torch.tensor(
@@ -222,19 +231,7 @@ class FlashCausalLMBatch(Batch):
         position_ids = torch.tensor(position_ids, dtype=torch.int32, device=device)
         start_seq = torch.tensor(start_seq, device=device, dtype=torch.int32)
         end_seq = torch.tensor(end_seq, device=device, dtype=torch.int32)
-
-        if len(pb.requests) > 1:
-            past_present_indices = np.concatenate(past_present_indices)
-
-            start_seq_prefill = torch.tensor(start_seq_prefill, device=device, dtype=torch.int32)
-            end_seq_prefill = torch.tensor(end_seq_prefill, device=device, dtype=torch.int32)
-        else:
-            past_present_indices = past_present_indices[0]
-
-            start_seq_prefill = start_seq
-            end_seq_prefill = end_seq
-
-        past_present_indices = torch.tensor(past_present_indices, device=device, dtype=torch.bool)
+        past_present_indices = torch.tensor(past_present_indices, device=device, dtype=torch.int64)
 
         if all_prefill_logprobs:
             prefill_head_indices = None
@@ -298,7 +295,7 @@ class FlashCausalLMBatch(Batch):
         indices = []
 
         # past indices to keep
-        past_indices = torch.zeros(self.past_key_values.shape[1], dtype=torch.bool, device=device)
+        past_indices = torch.zeros(self.past_key_values.shape[0], dtype=torch.bool, device=device)
 
         # Create on CPU to only move to GPU once instead of at every copy
        start_seq = torch.empty(len(request_ids), dtype=torch.int32)
@@ -352,7 +349,7 @@ class FlashCausalLMBatch(Batch):
         position_ids = self.position_ids[indices]
         all_input_ids_tensor = self.all_input_ids_tensor[indices]
         next_token_chooser = self.next_token_chooser.filter(indices)
-        past_key_values = self.past_key_values[:, past_indices]
+        past_key_values = self.past_key_values[past_indices]
 
         # Move to GPU now that we have the whole tensor
         start_seq = start_seq.to(device)
@@ -409,11 +406,7 @@ class FlashCausalLMBatch(Batch):
         )
         end_seq_q = start_seq_q + 1
         max_seqlen = 0
-        past_key_values = batches[0].past_key_values.new_empty((
-            batches[0].past_key_values.shape[0],
-            total_tokens,
-            *batches[0].past_key_values.shape[2:]
-        ))
+        past_key_values = []
 
         all_input_ids = []
 
@@ -449,11 +442,6 @@ class FlashCausalLMBatch(Batch):
             start_seq[start_index:end_index] = batch.start_seq + max_tokens
             end_seq[start_index:end_index] = batch.end_seq + max_tokens
 
-            past_key_values[
-                :,
-                max_tokens: max_tokens + batch.max_tokens
-            ] = batch.past_key_values
-
             max_seqlen = max(max_seqlen, batch.max_seqlen)
 
             all_input_ids.extend(batch.all_input_ids)
@@ -464,6 +452,7 @@ class FlashCausalLMBatch(Batch):
 
             next_token_chooser_parameters.extend([r.parameters for r in batch.requests])
             stopping_criterias.extend(batch.stopping_criterias)
+            past_key_values.append(batch.past_key_values)
 
             # Update
             cumulative_batch_size += len(batch)
@@ -480,6 +469,7 @@ class FlashCausalLMBatch(Batch):
             ),
         )
 
+        past_key_values = torch.cat(past_key_values, dim=0)
         past_present_indices = end_seq - 1
 
         all_input_ids_tensor = torch.zeros(
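Note: with token slots on dim 0, concatenating batches reduces to stacking their cache rows; the removed new_empty slab and slice copies above become a single torch.cat, and the "+ max_tokens" offsets applied to start_seq/end_seq during the merge keep each batch's indices pointing at its own rows. A sketch with illustrative shapes, not from the commit:

    import torch

    num_layers, dims = 4, (2, 1, 64)

    # Per-batch caches: (tokens_i, num_layers, 2, 1, head_size)
    cache_a = torch.randn(10, num_layers, *dims)
    cache_b = torch.randn(6, num_layers, *dims)

    merged = torch.cat([cache_a, cache_b], dim=0)
    assert merged.shape == (16, num_layers, *dims)
    # cache_b's slot i is now merged slot 10 + i, which is why the second
    # batch's start_seq/end_seq are shifted by the preceding token count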
@@ -726,8 +716,8 @@ class FlashCausalLM(Model):
         # Set values in batch
         batch.input_ids = next_input_ids
         batch.position_ids = next_position_ids + 1
-        batch.past_present_indices = torch.clone(batch.end_seq)
-        batch.end_seq += 1
+        batch.past_present_indices = batch.end_seq
+        batch.end_seq = batch.end_seq + 1
 
         if prefill and prefill_logprobs:
             # Get prefill logprobs
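Note on dropping torch.clone: batch.end_seq += 1 mutated the tensor in place, so past_present_indices needed its own copy; batch.end_seq = batch.end_seq + 1 allocates a new tensor instead, leaving the one now referenced by past_present_indices untouched. The aliasing difference, as a sketch (not from the commit):

    import torch

    # In-place: the alias observes the increment, so a clone was required
    end_seq = torch.tensor([5, 8, 15])
    alias = end_seq
    end_seq += 1
    assert torch.equal(alias, torch.tensor([6, 9, 16]))

    # Out-of-place: the alias keeps the pre-increment values, no clone needed
    end_seq = torch.tensor([5, 8, 15])
    past_present_indices = end_seq
    end_seq = end_seq + 1
    assert torch.equal(past_present_indices, torch.tensor([5, 8, 15]))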