From aac547dd8283ecebd9945813b5d5f79c6257f2c5 Mon Sep 17 00:00:00 2001
From: BaihuiJin
Date: Thu, 11 Jul 2024 21:19:17 +0800
Subject: [PATCH] Clear previous hpu_graphs when graph shape changed to save memory (#176)

---
 server/text_generation_server/models/causal_lm.py | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/server/text_generation_server/models/causal_lm.py b/server/text_generation_server/models/causal_lm.py
index 37d7479b..ad2270ab 100644
--- a/server/text_generation_server/models/causal_lm.py
+++ b/server/text_generation_server/models/causal_lm.py
@@ -602,6 +602,7 @@ class CausalLM(Model):
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
+        self.prev_bs = 0
         if use_medusa:
             raise RuntimeError("Medusa decoding is not enabled for AutoModel")
 
@@ -965,6 +966,9 @@ class CausalLM(Model):
             batch = batch.__class__.recombine([batch], self.tokenizer.pad_token_id)
 
         scenario = 'PREFILL' if prefill else 'GENERATE'
+        if self.enable_hpu_graph and batch.batch_size != self.prev_bs:
+            self.model.clear_cache()
+            self.prev_bs = batch.batch_size
         dbg_trace(
             scenario, f'bs:{batch.batch_size} num_reqs:{len(batch.requests)} seq_len:{batch.seq_length} padding:{batch.right_padding}')
         assert batch.right_padding > 0, 'No more room for next token!'
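
Note: the patch keeps one extra piece of state, prev_bs (the batch size used for the last executed step), and, when HPU graphs are enabled and the incoming batch size differs, drops the previously captured graphs before running with the new shape, so graphs compiled for the old shape no longer hold device memory. Below is a minimal standalone sketch of that pattern, not the actual CausalLM code; GraphCacheGuard and before_forward are hypothetical names introduced here for illustration, and clear_cache() is assumed to be exposed by the HPU-graph-wrapped model, as in the diff above.

# Minimal sketch: clear cached HPU graphs whenever the batch size,
# and therefore the captured graph shape, changes.
class GraphCacheGuard:  # hypothetical helper name, for illustration only
    def __init__(self, model, enable_hpu_graph: bool):
        self.model = model                        # assumed to expose clear_cache()
        self.enable_hpu_graph = enable_hpu_graph  # mirrors the flag used in the diff
        self.prev_bs = 0                          # batch size of the last executed step

    def before_forward(self, batch_size: int) -> None:
        # A new batch size means a new graph shape; release the old graphs
        # first so their memory is freed before new ones are captured.
        if self.enable_hpu_graph and batch_size != self.prev_bs:
            self.model.clear_cache()
            self.prev_bs = batch_size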