diff --git a/server/text_generation_server/models/transformers_flash_causal_lm.py b/server/text_generation_server/models/transformers_flash_causal_lm.py
index 6f242ca4..c2669dba 100644
--- a/server/text_generation_server/models/transformers_flash_causal_lm.py
+++ b/server/text_generation_server/models/transformers_flash_causal_lm.py
@@ -37,10 +37,6 @@ def tgi_flash_attention_forward(
     **kwargs,  # This is needed to "absorb" other args passed by Transformers modeling
 ):
     kv_cache = kv_cache[module.layer_idx]
-
-    import ipdb
-
-    ipdb.set_trace()
     query_states = query_states.transpose(1, 2).squeeze(dim=0)
     key_states = key_states.transpose(1, 2).squeeze(dim=0)
     value_states = value_states.transpose(1, 2).squeeze(dim=0)
@@ -127,9 +123,9 @@ class TransformersFlashCausalLM(FlashCausalLM):
             model_id,
             revision=revision,
             torch_dtype=dtype,
-            # load_in_8bit=quantize == "bitsandbytes",
-            # trust_remote_code=trust_remote_code,
-            # attn_implementation="tgi",
+            load_in_8bit=quantize == "bitsandbytes",
+            trust_remote_code=trust_remote_code,
+            attn_implementation="tgi",
             device_map=device if world_size == 1 else None,
             tp_plan="auto" if world_size > 1 else None,
         )
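
Note on the re-enabled `attn_implementation="tgi"` argument: for that string to be accepted by `from_pretrained`, the custom attention function has to be registered with Transformers' attention dispatch table. The sketch below is an assumption of how that wiring could look, not a copy of this repository's code; the placeholder body and the exact signature are illustrative only, while `ALL_ATTENTION_FUNCTIONS` is the real registry exposed by recent `transformers` releases.

```python
# Minimal sketch (assumption): register a custom attention backend under the key
# "tgi" so that AutoModelForCausalLM.from_pretrained(..., attn_implementation="tgi")
# routes every attention call through it instead of a built-in implementation.
import transformers


def tgi_flash_attention_forward(module, query_states, key_states, value_states, attention_mask, **kwargs):
    # Placeholder body: the real function (defined in transformers_flash_causal_lm.py)
    # reshapes the projections and dispatches to TGI's paged flash-attention kernels.
    ...


# Registering under "tgi" makes the string a valid attn_implementation value.
transformers.modeling_utils.ALL_ATTENTION_FUNCTIONS["tgi"] = tgi_flash_attention_forward
```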