diff --git a/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py b/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py
index 51adffc7d..a2cbf30c0 100644
--- a/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py
+++ b/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py
@@ -1512,9 +1512,10 @@ class FlashCausalLM(Model):
         self.bucketing_ctx = HPUBucketingContext(
             os.getenv("DECODE_MAX_BS", 128),  # self.max_num_seqs, #TODO
-            os.getenv("PREFILL_MAX_BS", 16),  # self.max_num_prefill_seqs, #TODO
+            os.getenv("PREFILL_MAX_BS", 64),  # self.max_num_prefill_seqs, #TODO
             BLOCK_SIZE,
             num_blocks * BLOCK_SIZE,
+            False,
         )
         self.bucketing_ctx.num_hpu_blocks = num_blocks
         if os.getenv("VLLM_SKIP_WARMUP", "false").lower() == "true":