Memory fragmentation handling added for Causal LM

ankit201 2023-07-09 03:35:47 +00:00
parent 15de7c7ac3
commit 20ca9cf0c3


@@ -89,6 +89,7 @@ class CausalLMBatch(Batch):
             )
             stopping_criterias.append(stopping_criteria)
             max_truncation = max(max_truncation, r.truncate)
+            stopping_criteria.max_new_tokens = stopping_criteria.max_new_tokens if stopping_criteria.max_new_tokens < 512 else 512
             max_decode_tokens += stopping_criteria.max_new_tokens
             padding_right_offset = max(
                 padding_right_offset, stopping_criteria.max_new_tokens
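
The added line caps each request's decode budget at 512 new tokens, which also bounds max_decode_tokens and the padding_right_offset derived from it. A minimal sketch of the same cap, written with min() for readability; cap_new_tokens and MAX_NEW_TOKENS_CAP are illustrative names rather than identifiers from the repository, and 512 is the value taken from the diff:

MAX_NEW_TOKENS_CAP = 512  # value used in the added line above

def cap_new_tokens(requested: int) -> int:
    # Same effect as the conditional expression in the diff: keep the
    # requested budget unless it exceeds the cap.
    return min(requested, MAX_NEW_TOKENS_CAP)

assert cap_new_tokens(128) == 128
assert cap_new_tokens(4096) == 512
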
@@ -423,6 +424,7 @@ class CausalLMBatch(Batch):
                 start_index = end_index
             past_key_values.append([padded_past_keys, padded_past_values])
+            torch.cuda.empty_cache()
         return cls(
             batch_id=batches[0].batch_id,
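
Here torch.cuda.empty_cache() runs after the per-batch past key/value tensors have been copied into the newly allocated padded tensors, so the allocator can hand the freed blocks back instead of keeping a fragmented cache. A hedged sketch of the pattern, assuming the source tensors are released before the call; merge_chunks is an illustrative helper, only torch.cuda.empty_cache() is the real API:

import torch

def merge_chunks(chunks: list) -> torch.Tensor:
    # Copy several smaller tensors into one padded tensor, loosely mirroring
    # how concatenate() rebuilds past_key_values for the merged batch.
    total = sum(c.shape[0] for c in chunks)
    merged = chunks[0].new_zeros((total, *chunks[0].shape[1:]))
    start = 0
    while chunks:
        c = chunks.pop(0)
        merged[start : start + c.shape[0]] = c
        start += c.shape[0]
        del c  # drop the reference so the source memory can be released
    if torch.cuda.is_available():
        # With the source tensors freed, return their cached blocks to the
        # driver rather than leaving a fragmented allocator cache behind.
        torch.cuda.empty_cache()
    return merged
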
@@ -537,14 +539,17 @@ class CausalLM(Model):
     ) -> Tuple[List[Generation], Optional[CausalLMBatch]]:
         # slice the attention mask to the correct shape
         attention_mask = batch.attention_mask[:, : -batch.padding_right_offset]
-        logits, past = self.forward(
-            batch.input_ids,
-            attention_mask,
-            batch.position_ids,
-            batch.past_key_values,
-        )
+        try:
+            logits, past = self.forward(
+                batch.input_ids,
+                attention_mask,
+                batch.position_ids,
+                batch.past_key_values,
+            )
+        except Exception as e:
+            del batch
+            torch.cuda.empty_cache()
+            raise e
         # Results
         generations: List[Generation] = []
         stopped = True
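
Wrapping the forward pass lets the server recover when a batch does not fit in memory: the local reference to the batch is dropped and the CUDA cache is emptied before the exception is re-raised, so later, smaller batches are not starved by leftover allocations. A minimal sketch of the same guard, with model and batch as stand-ins; torch.cuda.empty_cache() is the only real API used:

import torch

def guarded_forward(model, batch):
    try:
        return model(
            batch.input_ids,
            attention_mask=batch.attention_mask,
            past_key_values=batch.past_key_values,
        )
    except Exception:
        # Drop the local reference to the batch tensors, then return the
        # cached CUDA blocks so the next forward pass can still allocate.
        del batch
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        raise
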
@@ -659,6 +664,8 @@ class CausalLM(Model):
         # We finished all generations in the batch; there is no next batch
         if stopped:
+            del batch
+            torch.cuda.empty_cache()
             return generations, None
         # Slice unused values from prefill
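
When every request in the batch has stopped, the batch is deleted and the cache emptied before the results are returned, so the batch's tensors are reclaimed immediately rather than only once the caller drops its reference. A sketch of this early-release step; finish_step, generations and batch are illustrative names:

import torch

def finish_step(generations, batch, stopped: bool):
    if stopped:
        # No request will generate again: free the batch (and its cached past
        # key/values) now and compact the CUDA allocator cache.
        del batch
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        return generations, None
    # Otherwise keep the batch alive for the next decode step.
    return generations, batch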