Mirror of https://github.com/huggingface/text-generation-inference.git
Synced 2025-09-11 04:14:52 +00:00
Memory fragmentation added for Causal LM
parent 15de7c7ac3
commit 20ca9cf0c3
@@ -89,6 +89,7 @@ class CausalLMBatch(Batch):
             )
             stopping_criterias.append(stopping_criteria)
             max_truncation = max(max_truncation, r.truncate)
+            stopping_criteria.max_new_tokens = stopping_criteria.max_new_tokens if stopping_criteria.max_new_tokens < 512 else 512
             max_decode_tokens += stopping_criteria.max_new_tokens
             padding_right_offset = max(
                 padding_right_offset, stopping_criteria.max_new_tokens
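The added line caps each request's max_new_tokens at 512 before it feeds max_decode_tokens and padding_right_offset, so a single long-running request cannot force the whole batch to reserve a huge right-padded decode buffer. Below is a minimal runnable sketch of the same clamping idea; StoppingCriteria here is a simplified stand-in, not TGI's class.

# Minimal sketch of the clamp above; StoppingCriteria is a simplified
# stand-in, and 512 is the cap used in the diff.
from dataclasses import dataclass
from typing import List, Tuple

MAX_NEW_TOKENS_CAP = 512


@dataclass
class StoppingCriteria:
    max_new_tokens: int


def plan_decode_budget(criterias: List[StoppingCriteria]) -> Tuple[int, int]:
    """Return (max_decode_tokens, padding_right_offset) for a batch."""
    max_decode_tokens = 0
    padding_right_offset = 0
    for sc in criterias:
        # Equivalent to the diff's conditional expression.
        sc.max_new_tokens = min(sc.max_new_tokens, MAX_NEW_TOKENS_CAP)
        max_decode_tokens += sc.max_new_tokens
        padding_right_offset = max(padding_right_offset, sc.max_new_tokens)
    return max_decode_tokens, padding_right_offset


# One oversized request no longer forces 2048 columns of right-padding.
print(plan_decode_budget([StoppingCriteria(2048), StoppingCriteria(64)]))
# -> (576, 512)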
@@ -423,6 +424,7 @@ class CausalLMBatch(Batch):
                 start_index = end_index
 
             past_key_values.append([padded_past_keys, padded_past_values])
+        torch.cuda.empty_cache()
 
         return cls(
             batch_id=batches[0].batch_id,
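concatenate() re-pads every incoming batch's past key/value tensors into freshly allocated, larger buffers, which leaves the old per-batch tensors as dead blocks inside PyTorch's caching allocator. The added torch.cuda.empty_cache() hands those cached-but-unused blocks back to the driver. A sketch of how to observe the effect with PyTorch's memory introspection APIs; the tensor shapes are arbitrary and this is not TGI code.

# Sketch: what empty_cache() reclaims after re-padding tensors.
import torch


def report(tag: str) -> None:
    allocated = torch.cuda.memory_allocated() / 2**20  # bytes held by live tensors
    reserved = torch.cuda.memory_reserved() / 2**20    # bytes cached by the allocator
    print(f"{tag}: allocated={allocated:.0f} MiB, reserved={reserved:.0f} MiB")


if torch.cuda.is_available():
    past = [torch.empty(64, 512, 128, device="cuda") for _ in range(16)]
    report("before re-pad")

    # Re-pad into larger buffers, as concatenate() does, then drop the old ones.
    padded = [torch.zeros(64, 768, 128, device="cuda") for _ in range(16)]
    del past
    report("after re-pad")        # reserved stays high: freed blocks are cached

    torch.cuda.empty_cache()      # the call this hunk adds
    report("after empty_cache")   # reserved drops back toward allocated

Note that empty_cache() does not reduce what live tensors use; it only shrinks the reserved pool, which matters when fragmentation of cached blocks would otherwise block a later large allocation.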
@@ -537,14 +539,17 @@ class CausalLM(Model):
     ) -> Tuple[List[Generation], Optional[CausalLMBatch]]:
         # slice the attention mask to the correct shape
         attention_mask = batch.attention_mask[:, : -batch.padding_right_offset]
-
-        logits, past = self.forward(
-            batch.input_ids,
-            attention_mask,
-            batch.position_ids,
-            batch.past_key_values,
-        )
-
+        try:
+            logits, past = self.forward(
+                batch.input_ids,
+                attention_mask,
+                batch.position_ids,
+                batch.past_key_values,
+            )
+        except Exception as e:
+            del batch
+            torch.cuda.empty_cache()
+            raise e
         # Results
         generations: List[Generation] = []
         stopped = True
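Wrapping the forward pass in try/except means that when it raises (most importantly on a CUDA out-of-memory error), the method drops its reference to the batch and releases the cached blocks before propagating the error, so only the affected requests fail instead of the process being left with a fragmented pool. A minimal sketch of the same guard pattern; model and batch are placeholders, not TGI objects.

# Minimal sketch of the guard pattern in the hunk above.
import torch


def guarded_forward(model: torch.nn.Module, batch: torch.Tensor) -> torch.Tensor:
    try:
        return model(batch)
    except Exception:
        # Drop this frame's reference to the batch; once no caller holds
        # one either, its tensors become collectible.
        del batch
        # Return the now-unused cached blocks to the driver so a smaller
        # retry batch sees the memory as actually free.
        torch.cuda.empty_cache()
        # Re-raise so the caller can fail these requests cleanly (the diff
        # writes `raise e`; a bare `raise` is equivalent here).
        raise


# Runs on CPU too; empty_cache() is a no-op when CUDA is not initialized.
out = guarded_forward(torch.nn.Linear(4, 4), torch.randn(2, 4))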
@@ -659,6 +664,8 @@ class CausalLM(Model):
 
         # We finished all generations in the batch; there is no next batch
         if stopped:
+            del batch
+            torch.cuda.empty_cache()
             return generations, None
 
         # Slice unused values from prefill
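The final hunk applies the same cleanup on the normal exit path: once every request in the batch has stopped, its past_key_values will never be read again, so the batch reference is dropped and the cache released before returning. A toy runnable version of that completion path; Batch and the generation bookkeeping are illustrative stand-ins.

# Toy version of the completion path in the hunk above.
from typing import List, Optional, Tuple

import torch


class Batch:
    """Stand-in holding only per-request stop flags."""

    def __init__(self, stopped_flags: List[bool]) -> None:
        self.stopped_flags = stopped_flags


def finish_step(batch: Batch) -> Tuple[List[str], Optional[Batch]]:
    generations = ["generation"] * len(batch.stopped_flags)
    stopped = all(batch.stopped_flags)

    # We finished all generations in the batch; there is no next batch.
    if stopped:
        del batch                 # drop the last reference to its KV cache
        torch.cuda.empty_cache()  # hand the cached blocks back to the driver
        return generations, None

    return generations, batch


print(finish_step(Batch([True, True])))  # (['generation', 'generation'], None)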