OlivierDehaene 2023-05-05 15:27:08 +02:00
parent f6df8db680
commit 1cbc5c633e
6 changed files with 180 additions and 111 deletions

View File

@@ -554,6 +554,7 @@ class FlashLlamaModel(torch.nn.Module):
input_ids,
position_ids,
cu_seqlens,
cu_seqlens_q,
max_s,
past_key_values: Optional[torch.Tensor] = None,
pre_allocate_past_size: Optional[int] = None,
@@ -575,15 +576,11 @@ class FlashLlamaModel(torch.nn.Module):
)
)
layer_past_present_indices = None
cu_seqlens_q = None
slice_past_index = len(hidden_states)
# Decode
else:
# Create indices from cumulative sequence lengths
layer_past_present_indices = cu_seqlens[1:] - 1
cu_seqlens_q = torch.arange(
cu_seqlens.shape[0], dtype=torch.int32, device=hidden_states.device
)
slice_past_index = None
# Get rotary cos and sin for this forward
@@ -650,6 +647,7 @@ class FlashLlamaForCausalLM(torch.nn.Module):
input_ids,
position_ids,
cu_seqlens,
cu_seqlens_q,
max_s,
past_key_values: Optional[torch.Tensor] = None,
pre_allocate_past_size: Optional[int] = None,
@@ -658,6 +656,7 @@ class FlashLlamaForCausalLM(torch.nn.Module):
input_ids,
position_ids,
cu_seqlens,
cu_seqlens_q,
max_s,
past_key_values,
pre_allocate_past_size,
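
The change above threads `cu_seqlens_q` through the model's forward instead of rebuilding it with `torch.arange` on every decode step. As a minimal standalone sketch (not part of the diff): for a packed batch, the key/value offsets are a prefix sum of the sequence lengths, while in decode each query is exactly one token, so the query offsets are just an `arange`:

```python
import torch

# Packed batch of three prompts with lengths [5, 3, 4].
lengths = [5, 3, 4]
cu = [0]
for l in lengths:
    cu.append(cu[-1] + l)
cu_seqlens = torch.tensor(cu, dtype=torch.int32)  # [0, 5, 8, 12]

# In decode every sequence contributes exactly one query token,
# so the query offsets are simply 0..batch_size.
cu_seqlens_q = torch.arange(len(lengths) + 1, dtype=torch.int32)  # [0, 1, 2, 3]
```

Hoisting this out of the forward means the tensor is built once per batch and reused across layers and decode steps, instead of being allocated on every forward call.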

View File

@@ -617,6 +617,7 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
input_ids,
position_ids,
cu_seqlens,
cu_seqlens_q,
max_s,
past_key_values=None,
pre_allocate_past_size: Optional[int] = None,
@@ -638,15 +639,11 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
)
)
layer_past_present_indices = None
cu_seqlens_q = None
slice_past_index = len(hidden_states)
# Decode
else:
# Create indices from cumulative sequence lengths
layer_past_present_indices = cu_seqlens[1:] - 1
cu_seqlens_q = torch.arange(
cu_seqlens.shape[0], dtype=torch.int32, device=hidden_states.device
)
slice_past_index = None
# Get rotary cos and sin for this forward
@@ -726,6 +723,7 @@ class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel):
input_ids,
position_ids,
cu_seqlens,
cu_seqlens_q,
max_s,
past_key_values: Optional[torch.Tensor] = None,
pre_allocate_past_size: Optional[int] = None,
@@ -734,6 +732,7 @@ class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel):
input_ids,
position_ids,
cu_seqlens,
cu_seqlens_q,
max_s,
past_key_values,
pre_allocate_past_size,

View File

@@ -37,11 +37,14 @@ class FlashCausalLMBatch(Batch):
# Decoder values
input_ids: torch.Tensor
position_ids: torch.Tensor
# cumulative sequence lengths
cu_seqlens: torch.Tensor
# cumulative query sequence lengths, only used in decode
cu_seqlens_q: Optional[torch.Tensor]
max_seqlen: int
# past key values, only used in decode
past_key_values: Optional[torch.Tensor]
max_seqlen: int
# All tokens
all_input_ids: List[List[int]]
@@ -128,8 +131,9 @@ class FlashCausalLMBatch(Batch):
cumulative_length += input_length
max_tokens += input_length + max_new_tokens
# Create tensors on device
input_ids = torch.tensor(
np.concatenate(all_input_ids), dtype=torch.int32, device=device
np.concatenate(all_input_ids), dtype=torch.int64, device=device
)
position_ids = torch.tensor(
np.concatenate(position_ids), dtype=torch.int32, device=device
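
The `input_ids` dtype switches from `int32` to `int64` here. A plausible motivation, illustrated in this standalone sketch, is that `torch.gather` (used later in this commit to pick per-token logprobs) requires int64 indices, so storing the ids as int64 up front removes `.to(torch.int64)` casts from the hot path:

```python
import torch

logprobs = torch.randn(4, 8).log_softmax(-1)
token_ids = torch.tensor([1, 2, 3, 0])  # int64 by default
picked = torch.gather(logprobs, 1, token_ids.unsqueeze(1)).squeeze(1)

# torch.gather(logprobs, 1, token_ids.to(torch.int32).unsqueeze(1))
# would raise: gather(): Expected dtype int64 for index
```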
@@ -172,9 +176,13 @@ class FlashCausalLMBatch(Batch):
# New values after filtering
requests_idx_mapping = {}
input_ids = []
position_ids = []
cu_seqlens = [0]
input_ids = self.input_ids.new_empty(len(requests))
position_ids = self.position_ids.new_empty(len(requests))
# Create on CPU to only move to GPU once instead of at every copy
cu_seqlens = torch.zeros(len(requests) + 1, dtype=torch.int32)
cu_seqlens_q = torch.arange(
0, len(requests) + 1, device=self.cu_seqlens_q.device, dtype=torch.int32
)
max_seqlen = 0
past_key_values = []
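
`filter` now preallocates the output tensors with `new_empty` and builds `cu_seqlens` on the CPU, paying for a single host-to-device transfer at the end instead of one per request. A rough illustration of the pattern (device selection assumed, not from the diff):

```python
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
kept_lengths = [5, 3, 4, 2]

# Fill on CPU...
cu_seqlens = torch.zeros(len(kept_lengths) + 1, dtype=torch.int32)
for i, length in enumerate(kept_lengths):
    cu_seqlens[i + 1] = cu_seqlens[i] + length

# ...then move to the GPU once.
cu_seqlens = cu_seqlens.to(device)
```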
@@ -197,16 +205,18 @@ class FlashCausalLMBatch(Batch):
# Get length
request_input_length = self.input_lengths[idx]
input_ids.append(self.input_ids[idx])
position_ids.append(self.position_ids[idx])
cu_seqlens.append(cumulative_length + request_input_length)
max_seqlen = max(max_seqlen, request_input_length)
# True index for past
past_key_values.append(self.past_key_values[2 * idx])
# Copy tensors (GPU)
input_ids[i] = self.input_ids[idx]
position_ids[i] = self.position_ids[idx]
if not single_request:
# Add one padding
past_key_values.append(self.past_pad)
# Copy to tensor (CPU)
cu_seqlens[i + 1] = cumulative_length + request_input_length
max_seqlen = max(max_seqlen, request_input_length)
# Slice from past
past_key_values.append(
self.past_key_values[:, self.cu_seqlens[idx] : self.cu_seqlens[idx + 1]]
)
all_input_ids.append(self.all_input_ids[idx])
all_input_ids_tensor.append(self.all_input_ids_tensor[idx])
@@ -227,7 +237,7 @@ class FlashCausalLMBatch(Batch):
if single_request:
# Preallocate tensor for bs = 1 case
past_key_values = torch.nn.functional.pad(
past_key_values = F.pad(
past_key_values[0],
(
0,
@@ -241,15 +251,21 @@ class FlashCausalLMBatch(Batch):
- stopping_criterias[0].current_tokens,
),
)
else:
# Cat all past
past_key_values = torch.cat(past_key_values, dim=1)
# Move to GPU now that we have the whole tensor
cu_seqlens = cu_seqlens.to(self.cu_seqlens.device)
return FlashCausalLMBatch(
batch_id=self.batch_id,
past_pad=self.past_pad,
requests=requests,
requests_idx_mapping=requests_idx_mapping,
input_ids=input_ids,
position_ids=position_ids,
cu_seqlens=cu_seqlens,
cu_seqlens_q=cu_seqlens_q,
max_seqlen=max_seqlen,
past_key_values=past_key_values,
input_lengths=input_lengths,
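
For the single-request case the filtered past is padded up to its maximum size so later decode steps can write into it in place. A self-contained sketch of how `F.pad` grows the sequence dimension (the 3-D shape here is assumed for illustration, not taken from the model):

```python
import torch
import torch.nn.functional as F

past = torch.randn(2, 5, 8)       # assumed [kv, seq_len, head_dim]
remaining_tokens = 3              # e.g. max_new_tokens - current_tokens

# Pad pairs run from the last dimension backwards:
# (0, 0) leaves head_dim alone, (0, remaining_tokens) grows seq_len.
past = F.pad(past, (0, 0, 0, remaining_tokens))
print(past.shape)                 # torch.Size([2, 8, 8])
```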
@@ -269,9 +285,16 @@ class FlashCausalLMBatch(Batch):
requests = []
requests_idx_mapping = {}
input_ids = []
position_ids = []
total_batch_size = sum([len(b) for b in batches])
device = batches[0].input_ids.device
input_ids = batches[0].input_ids.new_empty(total_batch_size)
position_ids = batches[0].position_ids.new_empty(total_batch_size)
cu_seqlens = [0]
cu_seqlens_q = torch.arange(
0, total_batch_size + 1, device=device, dtype=torch.int32
)
max_seqlen = 0
past_key_values = []
@@ -300,22 +323,25 @@ class FlashCausalLMBatch(Batch):
for k, v in batch.requests_idx_mapping.items():
requests_idx_mapping[k] = v + cumulative_batch_size
input_ids.extend(batch.input_ids)
position_ids.extend(batch.position_ids)
start_index = cumulative_batch_size
end_index = cumulative_batch_size + len(batch)
# Copy tensors (GPU)
input_ids[start_index:end_index] = batch.input_ids
position_ids[start_index:end_index] = batch.position_ids
# Add cumulative lengths of all previous inputs
cu_seqlens.extend([l + cumulative_length for l in batch.cu_seqlens[1:]])
max_seqlen = max(max_seqlen, batch.max_seqlen)
if len(batch) != 1:
past_key_values.extend(batch.past_key_values)
past_key_values.append(batch.past_key_values)
else:
# past was pre-allocated for this batch
# We need to slice to remove the padding
past_key_values.append(
batch.past_key_values[:, : batch.input_lengths[0]]
)
# Add one padding
past_key_values.append(batch.past_pad)
all_input_ids.extend(batch.all_input_ids)
all_input_ids_tensor.extend(batch.all_input_ids_tensor)
@@ -332,14 +358,19 @@ class FlashCausalLMBatch(Batch):
cumulative_batch_size += len(batch)
max_tokens += batch.max_tokens
# Cat past
past_key_values = torch.cat(past_key_values, dim=1)
# Create final tensor on GPU
cu_seqlens = torch.tensor(cu_seqlens, dtype=torch.int32, device=device)
return FlashCausalLMBatch(
batch_id=batches[0].batch_id,
past_pad=batches[0].past_pad,
requests=requests,
requests_idx_mapping=requests_idx_mapping,
input_ids=input_ids,
position_ids=position_ids,
cu_seqlens=cu_seqlens,
cu_seqlens_q=cu_seqlens_q,
max_seqlen=max_seqlen,
past_key_values=past_key_values,
input_lengths=input_lengths,
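
`concatenate` follows the same idea: instead of extending Python lists and re-tensorizing downstream, it preallocates one tensor of `total_batch_size` and copies each batch's slice in. A minimal sketch:

```python
import torch

batches = [torch.tensor([3, 5]), torch.tensor([7]), torch.tensor([2, 9])]
total = sum(len(b) for b in batches)

out = batches[0].new_empty(total)  # same dtype/device as the source batches
offset = 0
for b in batches:
    out[offset : offset + len(b)] = b  # one copy per batch, no torch.cat
    offset += len(b)
```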
@@ -367,7 +398,7 @@ class FlashCausalLM(Model):
):
if torch.cuda.is_available():
device = torch.device("cuda")
dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
dtype = torch.float16
else:
raise NotImplementedError("FlashCausalLM is only available on GPU")
@@ -429,7 +460,6 @@ class FlashCausalLM(Model):
) -> Tuple[List[Generation], Optional[FlashCausalLMBatch]]:
prefill = batch.past_key_values is None
# Shortcut when batch_size == 1
if prefill and len(batch) == 1:
# Ask to pre-allocate kv to its max size
# == number of tokens + max_new_tokens
@@ -450,15 +480,62 @@ class FlashCausalLM(Model):
)
if prefill:
# Compute logprobs for the whole batch
prefill_logprobs_tensor = torch.log_softmax(out, -1)
else:
prefill_logprobs_tensor = None
if len(batch) > 1:
# We create the prefill_tokens_indices tensor that will be used to gather prefill logprobs
# When batch == 1, we will just use the batch.input_ids values directly
prefill_tokens_indices = batch.input_ids.new_zeros(len(batch.input_ids))
# Used to slice next batch past
past_indices = []
prefill_logprobs = []
next_token_logprobs = []
# Create batch.cu_seqlens_q for decode
batch.cu_seqlens_q = torch.arange(
0, len(batch) + 1, device=self.device, dtype=torch.int32
)
next_input_ids = batch.input_ids.new_empty(len(batch))
next_position_ids = batch.position_ids.new_empty(len(batch))
else:
prefill_logprobs = None
next_input_ids = batch.input_ids
next_position_ids = batch.position_ids
next_token_logprobs = out.new_empty(len(batch))
# Prepare past for next decode
if len(batch) > 1:
# Used to slice next batch past
past_indices = torch.empty(
present.shape[1], dtype=torch.int64, device=self.device
)
batch.past_key_values = present.new_empty(
(
present.shape[0],
present.shape[1] + len(batch.requests),
*present.shape[2:],
)
)
# It is actually faster to do this in a separate for loop, as the copy from present to past is fairly slow
# and will run asynchronously while we execute the next for loop
cumulative_length = 0
for i, input_length in enumerate(batch.input_lengths):
# Indexing metadata
start_index = cumulative_length
end_index = cumulative_length + input_length
# Indices to copy present at the correct place in past_key_values
past_indices[start_index:end_index] = torch.arange(
start_index + i,
end_index + i,
dtype=torch.int64,
device=self.device,
)
cumulative_length += input_length
# Copy from present to past_key_values
batch.past_key_values[:, past_indices] = present
# Initialize past_key_values in prefill for len(batch) == 1
elif prefill:
# present is already pre-padded
batch.past_key_values = present
# Cumulative length
cumulative_length = 0
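
The new decode-preparation block reserves one extra slot per sequence in `past_key_values` and scatters `present` into it via `past_indices`: sequence `i` is shifted right by `i`, so a free padding slot follows each sequence. The equivalent logic on a toy 1-D tensor (a stand-in for the packed KV rows):

```python
import torch

lengths = [2, 3]
present = torch.arange(sum(lengths))                 # packed rows: [0..4]
past = present.new_zeros(sum(lengths) + len(lengths))

start = 0
for i, length in enumerate(lengths):
    # Shift sequence i right by i positions.
    past[start + i : start + i + length] = present[start : start + length]
    start += length

print(past.tolist())  # [0, 1, 0, 2, 3, 4, 0]: a padding slot after each sequence
```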
@@ -475,6 +552,10 @@ class FlashCausalLM(Model):
batch.all_input_ids,
)
# We use two for loops: the first can run completely asynchronously from the GPU, while the second
# one needs a GPU <-> CPU sync first
# It is faster if we delay this sync as long as possible
# For each member of the batch
for i, (
input_length,
@@ -491,23 +572,32 @@ class FlashCausalLM(Model):
# out is of shape [cumulative_sequence_lengths, vocab_size]
# only take last token logit
logits = out[end_index - 1 : end_index]
all_input_ids_tensor = F.pad(
batch.input_ids[start_index:end_index],
(0, stopping_criteria.max_new_tokens),
# Create all_input_ids_tensor that will be used by token warpers (for example, RepetitionPenalty)
all_input_ids_tensor = batch.input_ids.new_empty(
input_length + stopping_criteria.max_new_tokens
)
# Copy from batch.input_ids to all_input_ids_tensor
all_input_ids_tensor[:input_length] = batch.input_ids[
start_index:end_index
]
batch.all_input_ids_tensor.append(all_input_ids_tensor)
batch.position_ids[i] = input_length
prefill_logprobs.append(
prefill_logprobs_tensor[start_index:end_index]
.gather(
1,
all_input_ids_tensor[1:input_length]
.unsqueeze(1)
.to(torch.int64),
)
.squeeze(1)[:-1]
)
# Initialize position_ids
# In decode, we do not need this as we can just increment position ids
next_position_ids[i] = batch.position_ids[end_index - 1]
# Used to gather prefill logprobs
# Copy batch.input_ids to prefill_token_indices
if len(batch) > 1:
prefill_tokens_indices[
start_index : end_index - 1
] = batch.input_ids[start_index + 1 : end_index]
else:
# Set prefill_tokens_indices to the correct slice
prefill_tokens_indices = batch.input_ids[
start_index + 1 : end_index
]
else:
# Decode mode
# out is of shape [batch_size, vocab_size]
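
In the prefill branch above, `prefill_tokens_indices` holds the ids used to gather prefill logprobs: logit row `t` of a prompt scores token `t + 1`, so the index tensor is the prompt shifted left by one (the first prompt token gets `nan` later). Sketched in isolation, under those assumptions:

```python
import torch

input_ids = torch.tensor([11, 12, 13, 14])  # one prompt
logprobs = torch.randn(len(input_ids), 100).log_softmax(-1)

indices = input_ids[1:]                      # shift by one
prompt_logprobs = torch.gather(
    logprobs[:-1], 1, indices.unsqueeze(1)
).squeeze(1)                                 # logprob of each next prompt token
```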
@@ -519,54 +609,36 @@ class FlashCausalLM(Model):
next_token_id, logprob = next_token_chooser(
all_input_ids_tensor[None, :input_length], logits
)
# Add to all_input_ids_tensor
next_token_id_squeezed = next_token_id.squeeze()
all_input_ids_tensor[input_length] = next_token_id_squeezed
past_indices.extend([j for j in range(start_index + i, end_index + i)])
batch.input_ids[i] = next_token_id_squeezed
next_token_logprobs.append(logprob[-1, next_token_id])
# Set values
next_input_ids[i] = next_token_id_squeezed
next_token_logprobs[i] = logprob[-1, next_token_id].view(1)
cumulative_length += input_length
if prefill:
batch.input_ids = batch.input_ids[: len(batch)]
batch.position_ids = batch.position_ids[: len(batch)]
batch.cu_seqlens_q = torch.arange(
0, len(batch) + 1, device=self.device, dtype=torch.int32
)
else:
batch.position_ids += 1
# Initialize past_key_values in prefill
if prefill and len(batch) == 1:
# present is already pre-padded
batch.past_key_values = present
# Set values in batch
batch.input_ids = next_input_ids
batch.position_ids = next_position_ids + 1
batch.cu_seqlens = batch.cu_seqlens + batch.cu_seqlens_q
if len(batch) > 1:
prefill_logprobs = torch.cat(prefill_logprobs) if prefill else None
next_token_logprobs = torch.cat(next_token_logprobs)
batch.past_key_values = present.new_empty(
(
present.shape[0],
present.shape[1] + len(batch.requests),
*present.shape[2:],
)
if prefill:
# Get prefill logprobs
prefill_logprobs_tensor = torch.log_softmax(out, -1)
prefill_logprobs = torch.gather(
prefill_logprobs_tensor, 1, prefill_tokens_indices.unsqueeze(1)
)
batch.past_key_values[:, past_indices] = present
# GPU <-> CPU sync
prefill_logprobs = prefill_logprobs.squeeze(1).to("cpu").numpy()
prefill_logprobs = prefill_logprobs.to("cpu") if prefill else None
next_token_logprobs = next_token_logprobs.to("cpu")
else:
prefill_logprobs = prefill_logprobs[0] if prefill else None
next_token_logprobs = next_token_logprobs[0]
# GPU <-> CPU sync
next_token_logprobs = next_token_logprobs.to("cpu").numpy()
next_token_ids = batch.input_ids.to("cpu").numpy()
next_token_ids = batch.input_ids.to("cpu")
prefill_logprobs_cumulative_length = 0
cumulative_length = 0
# Zipped iterator
iterator = zip(
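
`batch.cu_seqlens + batch.cu_seqlens_q` is a neat trick: after a decode step every sequence grows by one token, so the cumulative offsets shift by exactly `[0, 1, 2, ...]`, which is precisely what `cu_seqlens_q` contains. Checked on a small example:

```python
import torch

cu_seqlens = torch.tensor([0, 5, 8, 12], dtype=torch.int32)  # lengths [5, 3, 4]
cu_seqlens_q = torch.arange(4, dtype=torch.int32)            # [0, 1, 2, 3]

cu_seqlens = cu_seqlens + cu_seqlens_q                       # [0, 6, 10, 15]
# New lengths are [6, 4, 5]: every sequence grew by one token.
```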
@@ -595,10 +667,11 @@ class FlashCausalLM(Model):
next_token_id,
next_token_logprob,
) in enumerate(iterator):
next_token_id_item = next_token_id.item()
start_index = cumulative_length
end_index = cumulative_length + input_length
# Append next token to all tokens
all_input_ids.append(next_token_id_item)
all_input_ids.append(next_token_id)
# Generated token
next_token_text, offset, token_offset = self.decode_token(
@@ -609,7 +682,7 @@ class FlashCausalLM(Model):
# Evaluate stopping criteria
stop, reason = stopping_criteria(
next_token_id_item,
next_token_id,
next_token_text,
)
@@ -633,11 +706,10 @@ class FlashCausalLM(Model):
# Prefill
if prefill:
start_index = prefill_logprobs_cumulative_length
end_index = prefill_logprobs_cumulative_length + input_length - 1
# Remove generated token to only have prefill and add nan for first prompt token
request_prefill_logprobs = [float("nan")] + prefill_logprobs[start_index:end_index].tolist()
request_prefill_logprobs = [float("nan")] + prefill_logprobs[
start_index : end_index - 1
].tolist()
prefill_token_ids = all_input_ids[:-1]
prefill_texts = self.tokenizer.batch_decode(
prefill_token_ids,
@ -647,18 +719,16 @@ class FlashCausalLM(Model):
prefill_tokens = PrefillTokens(
prefill_token_ids, request_prefill_logprobs, prefill_texts
)
prefill_logprobs_cumulative_length += input_length - 1
else:
prefill_tokens = None
generation = Generation(
request.id,
prefill_tokens,
next_token_id_item,
next_token_logprob.item(),
next_token_id,
next_token_logprob,
next_token_text,
next_token_id_item in self.all_special_ids,
next_token_id in self.all_special_ids,
generated_text,
)
@@ -670,7 +740,8 @@ class FlashCausalLM(Model):
batch.offsets[i] = offset
batch.token_offsets[i] = token_offset
batch.all_input_ids[i] = all_input_ids
batch.max_seqlen = max(batch.max_seqlen, new_input_length)
batch.max_seqlen = batch.max_seqlen + 1
cumulative_length += input_length
# No need to return a batch if we know that all requests stopped
return generations, batch if not stopped else None
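
The comments about delaying the `GPU <-> CPU sync` describe a general pattern: every `.item()` or `.to("cpu")` forces the host to wait for the device, so the loop queues all device-side work first and transfers the results once per step. A rough sketch of the idea (names are illustrative, not from the diff):

```python
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
logits = torch.randn(8, 100, device=device)

# Queue all device-side work; nothing here blocks the host.
next_ids = logits.argmax(-1)
next_logprobs = logits.log_softmax(-1).gather(1, next_ids.unsqueeze(1)).squeeze(1)

# One synchronizing transfer per step, instead of one .item() per token.
next_ids = next_ids.to("cpu")
next_logprobs = next_logprobs.to("cpu")
```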

View File

@@ -32,7 +32,7 @@ class FlashLlama(FlashCausalLM):
self.past_pad = None
if torch.cuda.is_available():
device = torch.device("cuda")
dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
dtype = torch.float16
else:
raise NotImplementedError("FlashLlama is only available on GPU")
@@ -152,7 +152,7 @@ class FlashLlamaSharded(FlashLlama):
self.master = self.rank == 0
if torch.cuda.is_available():
device = torch.device(f"cuda:{self.rank}")
dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
dtype = torch.float16
else:
raise NotImplementedError("FlashLlama is only available on GPU")

View File

@@ -38,7 +38,7 @@ class FlashNeoXSharded(FlashNeoX):
self.master = self.rank == 0
if torch.cuda.is_available():
device = torch.device(f"cuda:{self.rank}")
dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
dtype = torch.float16
else:
raise NotImplementedError("FlashNeoX is only available on GPU")

View File

@@ -31,7 +31,7 @@ class FlashSantacoder(FlashCausalLM):
self.past_pad = None
if torch.cuda.is_available():
device = torch.device("cuda")
dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
dtype = torch.float16
else:
raise NotImplementedError("FlashSantacoder is only available on GPU")
@@ -178,7 +178,7 @@ class FlashSantacoderSharded(FlashSantacoder):
self.master = self.rank == 0
if torch.cuda.is_available():
device = torch.device(f"cuda:{self.rank}")
dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
dtype = torch.float16
else:
raise NotImplementedError("FlashSantacoderSharded is only available on GPU")