Put back the attention impl.

Nicolas Patry 2025-01-22 18:13:59 +01:00
parent 6fe37d61d0
commit 8d05d6a62c

@@ -37,10 +37,6 @@ def tgi_flash_attention_forward(
     **kwargs,  # This is needed to "absorb" other args passed by Transformers modeling
 ):
     kv_cache = kv_cache[module.layer_idx]
-    import ipdb
-
-    ipdb.set_trace()
-
     query_states = query_states.transpose(1, 2).squeeze(dim=0)
     key_states = key_states.transpose(1, 2).squeeze(dim=0)
     value_states = value_states.transpose(1, 2).squeeze(dim=0)
@@ -127,9 +123,9 @@ class TransformersFlashCausalLM(FlashCausalLM):
             model_id,
             revision=revision,
             torch_dtype=dtype,
-            # load_in_8bit=quantize == "bitsandbytes",
-            # trust_remote_code=trust_remote_code,
-            # attn_implementation="tgi",
+            load_in_8bit=quantize == "bitsandbytes",
+            trust_remote_code=trust_remote_code,
+            attn_implementation="tgi",
             device_map=device if world_size == 1 else None,
             tp_plan="auto" if world_size > 1 else None,
         )
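
For context on the re-enabled attn_implementation="tgi" argument: that string only resolves inside from_pretrained if the custom kernel has been registered with Transformers' attention registry beforehand. The sketch below is a minimal illustration of that wiring, not the commit's full code. It assumes a Transformers release that exposes ALL_ATTENTION_FUNCTIONS in transformers.modeling_utils; the function body is reduced to the shape handling visible in the first hunk (Transformers passes batched [batch, num_heads, seq_len, head_dim] tensors, while the TGI kernels work on tensors with the batch dimension flattened away), and the actual attention kernel call plus the surrounding KV-cache plumbing are elided. The model id and dtype are placeholders.

    # Illustrative sketch: register a custom attention function under the
    # name "tgi" so that attn_implementation="tgi" in from_pretrained()
    # resolves to it. Assumes a Transformers version that exposes
    # ALL_ATTENTION_FUNCTIONS in transformers.modeling_utils.
    import torch
    import transformers
    from transformers import AutoModelForCausalLM


    def tgi_flash_attention_forward(
        module,                       # attention layer instance; carries layer_idx
        query_states: torch.Tensor,   # [batch, num_heads, seq_len, head_dim]
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        attention_mask=None,
        kv_cache=None,                # per-layer caches, indexed by module.layer_idx
        **kwargs,                     # absorbs other args passed by Transformers modeling
    ):
        kv_cache = kv_cache[module.layer_idx]
        # Transformers hands over batched [B, H, S, D] tensors; the flash/paged
        # attention kernels expect the batch dimension flattened away, hence
        # the transpose + squeeze seen in the diff above.
        query_states = query_states.transpose(1, 2).squeeze(dim=0)
        key_states = key_states.transpose(1, 2).squeeze(dim=0)
        value_states = value_states.transpose(1, 2).squeeze(dim=0)
        ...  # paged/flash attention kernel call elided in this sketch


    # Make the name "tgi" known to Transformers' attention dispatch.
    transformers.modeling_utils.ALL_ATTENTION_FUNCTIONS["tgi"] = tgi_flash_attention_forward

    model = AutoModelForCausalLM.from_pretrained(
        "some/model-id",              # placeholder; the commit passes model_id
        torch_dtype=torch.float16,    # placeholder; the commit passes dtype
        attn_implementation="tgi",
    )

With the registration in place, any model whose attention layers dispatch through the registry will route its attention calls to tgi_flash_attention_forward, which is what the uncommented attn_implementation="tgi" line in the second hunk relies on.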