diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index 7e93a90e..356ca668 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -953,12 +953,13 @@ class FlashCausalLM(Model): total_cache_size = self.num_layers * cache_block_size * 2 * dtype_size free_memory = get_free_memory(self.device, MEMORY_FRACTION) + batch_num_blocks = batch.num_blocks if batch is not None else 0 num_blocks = ( # Leave 5% for some wiggle room int((free_memory * 0.95) // total_cache_size) # Add batch.num_blocks as we allocated it above, so it is included in the peak memory. - + batch.num_blocks + + batch_num_blocks ) del batch diff --git a/server/text_generation_server/models/flash_cohere.py b/server/text_generation_server/models/flash_cohere.py index b907ee08..6738a681 100644 --- a/server/text_generation_server/models/flash_cohere.py +++ b/server/text_generation_server/models/flash_cohere.py @@ -62,6 +62,7 @@ class FlashCohere(FlashCausalLM): torch.distributed.barrier(group=self.process_group) super(FlashCohere, self).__init__( + model_id=model_id, model=model, tokenizer=tokenizer, num_layers=len(model.model.layers), diff --git a/server/text_generation_server/models/flash_dbrx.py b/server/text_generation_server/models/flash_dbrx.py index d5eb1a6e..0e9c913c 100644 --- a/server/text_generation_server/models/flash_dbrx.py +++ b/server/text_generation_server/models/flash_dbrx.py @@ -87,6 +87,7 @@ class FlashDbrx(FlashCausalLM): torch.distributed.barrier(group=self.process_group) super(FlashDbrx, self).__init__( + model_id=model_id, model=model, tokenizer=tokenizer, num_layers=len(model.model.layers), diff --git a/server/text_generation_server/models/flash_gemma.py b/server/text_generation_server/models/flash_gemma.py index 358883e6..c635ccfa 100644 --- a/server/text_generation_server/models/flash_gemma.py +++ b/server/text_generation_server/models/flash_gemma.py @@ -62,6 +62,7 @@ class FlashGemma(FlashCausalLM): torch.distributed.barrier(group=self.process_group) super(FlashGemma, self).__init__( + model_id=model_id, model=model, tokenizer=tokenizer, num_layers=len(model.model.layers), diff --git a/server/text_generation_server/models/flash_gpt2.py b/server/text_generation_server/models/flash_gpt2.py index ef129e92..e9fc471e 100644 --- a/server/text_generation_server/models/flash_gpt2.py +++ b/server/text_generation_server/models/flash_gpt2.py @@ -65,6 +65,7 @@ class FlashGPT2(FlashCausalLM): model = FlashGPT2ForCausalLM(prefix, config, weights) torch.distributed.barrier(group=self.process_group) super(FlashGPT2, self).__init__( + model_id=model_id, model=model, tokenizer=tokenizer, num_layers=len(model.model.layers), diff --git a/server/text_generation_server/models/flash_mistral.py b/server/text_generation_server/models/flash_mistral.py index 081c2e2c..e66a1c3d 100644 --- a/server/text_generation_server/models/flash_mistral.py +++ b/server/text_generation_server/models/flash_mistral.py @@ -79,6 +79,7 @@ class BaseFlashMistral(FlashCausalLM): torch.distributed.barrier(group=self.process_group) num_layers, num_kv_heads, head_size = self.get_layer_config(model) super().__init__( + model_id=model_id, model=model, tokenizer=tokenizer, num_layers=num_layers, @@ -110,6 +111,7 @@ class FlashMistral(BaseFlashMistral): trust_remote_code: bool = False, ): super(FlashMistral, self).__init__( + model_id=model_id, config_cls=MistralConfig, model_cls=FlashMistralForCausalLM, model_id=model_id, diff --git a/server/text_generation_server/models/flash_mixtral.py b/server/text_generation_server/models/flash_mixtral.py index 587d423f..216279a8 100644 --- a/server/text_generation_server/models/flash_mixtral.py +++ b/server/text_generation_server/models/flash_mixtral.py @@ -20,6 +20,7 @@ class FlashMixtral(BaseFlashMistral): trust_remote_code: bool = False, ): super(FlashMixtral, self).__init__( + model_id=model_id, config_cls=MixtralConfig, model_cls=FlashMixtralForCausalLM, model_id=model_id, diff --git a/server/text_generation_server/models/flash_neox.py b/server/text_generation_server/models/flash_neox.py index adefaeb2..f14f20f5 100644 --- a/server/text_generation_server/models/flash_neox.py +++ b/server/text_generation_server/models/flash_neox.py @@ -65,6 +65,7 @@ class FlashNeoXSharded(FlashCausalLM): torch.distributed.barrier(group=self.process_group) super(FlashNeoXSharded, self).__init__( + model_id=model_id, model=model.to(device), tokenizer=tokenizer, num_layers=len(model.gpt_neox.layers), diff --git a/server/text_generation_server/models/flash_phi.py b/server/text_generation_server/models/flash_phi.py index 32b573a9..709c15c5 100644 --- a/server/text_generation_server/models/flash_phi.py +++ b/server/text_generation_server/models/flash_phi.py @@ -91,6 +91,7 @@ class FlashPhi(FlashCausalLM): torch.distributed.barrier(group=self.process_group) super(FlashPhi, self).__init__( + model_id=model_id, model=model, tokenizer=tokenizer, num_layers=len(model.model.layers), diff --git a/server/text_generation_server/models/flash_qwen2.py b/server/text_generation_server/models/flash_qwen2.py index 75285863..b984c7ca 100644 --- a/server/text_generation_server/models/flash_qwen2.py +++ b/server/text_generation_server/models/flash_qwen2.py @@ -71,6 +71,7 @@ class FlashQwen2(BaseFlashMistral): torch.distributed.barrier(group=self.process_group) super(BaseFlashMistral, self).__init__( + model_id=model_id, model=model, tokenizer=tokenizer, num_layers=len(model.model.layers), diff --git a/server/text_generation_server/models/flash_rw.py b/server/text_generation_server/models/flash_rw.py index e6350611..b1867988 100644 --- a/server/text_generation_server/models/flash_rw.py +++ b/server/text_generation_server/models/flash_rw.py @@ -74,6 +74,7 @@ class FlashRWSharded(FlashCausalLM): torch.distributed.barrier(group=self.process_group) super(FlashRWSharded, self).__init__( + model_id=model_id, model=model.to(device), tokenizer=tokenizer, num_layers=len(model.transformer.h), diff --git a/server/text_generation_server/models/flash_santacoder.py b/server/text_generation_server/models/flash_santacoder.py index 2ad36b93..8ee66292 100644 --- a/server/text_generation_server/models/flash_santacoder.py +++ b/server/text_generation_server/models/flash_santacoder.py @@ -76,6 +76,7 @@ class FlashSantacoderSharded(FlashCausalLM): torch.distributed.barrier(group=self.process_group) super(FlashSantacoderSharded, self).__init__( + model_id=model_id, model=model.to(device), tokenizer=tokenizer, num_layers=len(model.transformer.h), diff --git a/server/text_generation_server/models/flash_starcoder2.py b/server/text_generation_server/models/flash_starcoder2.py index 5533c9d9..9eb88816 100644 --- a/server/text_generation_server/models/flash_starcoder2.py +++ b/server/text_generation_server/models/flash_starcoder2.py @@ -70,6 +70,7 @@ class FlashStarcoder2(BaseFlashMistral): torch.distributed.barrier(group=self.process_group) super(BaseFlashMistral, self).__init__( + model_id=model_id, model=model, tokenizer=tokenizer, num_layers=len(model.model.layers),