diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index 9c1d2225..581cbde8 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -1346,6 +1346,9 @@ class FlashCausalLM(Model): def batch_type(self) -> Type[FlashCausalLMBatch]: return FlashCausalLMBatch + def get_input_embeddings(self, batch): + batch.inputs_embeds = None + def init_kv_cache( self, num_blocks: int,