mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 12:24:53 +00:00
Put back the attention impl.
This commit is contained in:
parent
6fe37d61d0
commit
8d05d6a62c
@ -37,10 +37,6 @@ def tgi_flash_attention_forward(
|
||||
**kwargs, # This is needed to "absorb" other args passed by Transformers modeling
|
||||
):
|
||||
kv_cache = kv_cache[module.layer_idx]
|
||||
|
||||
import ipdb
|
||||
|
||||
ipdb.set_trace()
|
||||
query_states = query_states.transpose(1, 2).squeeze(dim=0)
|
||||
key_states = key_states.transpose(1, 2).squeeze(dim=0)
|
||||
value_states = value_states.transpose(1, 2).squeeze(dim=0)
|
||||
@ -127,9 +123,9 @@ class TransformersFlashCausalLM(FlashCausalLM):
|
||||
model_id,
|
||||
revision=revision,
|
||||
torch_dtype=dtype,
|
||||
# load_in_8bit=quantize == "bitsandbytes",
|
||||
# trust_remote_code=trust_remote_code,
|
||||
# attn_implementation="tgi",
|
||||
load_in_8bit=quantize == "bitsandbytes",
|
||||
trust_remote_code=trust_remote_code,
|
||||
attn_implementation="tgi",
|
||||
device_map=device if world_size == 1 else None,
|
||||
tp_plan="auto" if world_size > 1 else None,
|
||||
)
|
||||
|
Loading…
Reference in New Issue
Block a user