Mirror of https://github.com/huggingface/text-generation-inference.git, synced 2025-09-11 12:24:53 +00:00
Put back the attention impl.
parent 6fe37d61d0
commit 8d05d6a62c
@@ -37,10 +37,6 @@ def tgi_flash_attention_forward(
     **kwargs, # This is needed to "absorb" other args passed by Transformers modeling
 ):
     kv_cache = kv_cache[module.layer_idx]
-
-    import ipdb
-
-    ipdb.set_trace()
     query_states = query_states.transpose(1, 2).squeeze(dim=0)
     key_states = key_states.transpose(1, 2).squeeze(dim=0)
     value_states = value_states.transpose(1, 2).squeeze(dim=0)
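For context on the transposes kept in this hunk: Transformers passes the custom attention function tensors laid out as `[batch, num_heads, seq_len, head_dim]`, while TGI's flash-attention path works on a flattened token dimension. A minimal sketch of the reshaping, with assumed example shapes (not taken from the diff):

```python
import torch

# Assumed example shapes: batch of 1, 32 heads, 7 tokens, head_dim 128.
batch, num_heads, seq_len, head_dim = 1, 32, 7, 128
query_states = torch.randn(batch, num_heads, seq_len, head_dim)

# transpose(1, 2): [batch, num_heads, seq_len, head_dim] -> [batch, seq_len, num_heads, head_dim]
# squeeze(dim=0): drop the singleton batch dimension -> [seq_len, num_heads, head_dim]
query_states = query_states.transpose(1, 2).squeeze(dim=0)
print(query_states.shape)  # torch.Size([7, 32, 128])
```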
@@ -127,9 +123,9 @@ class TransformersFlashCausalLM(FlashCausalLM):
             model_id,
             revision=revision,
             torch_dtype=dtype,
-            # load_in_8bit=quantize == "bitsandbytes",
-            # trust_remote_code=trust_remote_code,
-            # attn_implementation="tgi",
+            load_in_8bit=quantize == "bitsandbytes",
+            trust_remote_code=trust_remote_code,
+            attn_implementation="tgi",
             device_map=device if world_size == 1 else None,
             tp_plan="auto" if world_size > 1 else None,
         )
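The re-enabled `attn_implementation="tgi"` only resolves if a callable has been registered under that name with Transformers' attention-function registry. A hedged sketch of that wiring, assuming `transformers.modeling_utils.ALL_ATTENTION_FUNCTIONS` is available in the installed transformers version; the model id and the placeholder forward below are illustrative only:

```python
import transformers
from transformers import AutoModelForCausalLM


def tgi_flash_attention_forward(module, query_states, key_states, value_states, attention_mask, **kwargs):
    # Placeholder body; the real forward is the one defined earlier in this file.
    raise NotImplementedError


# Register the callable so attn_implementation="tgi" resolves to it.
transformers.modeling_utils.ALL_ATTENTION_FUNCTIONS["tgi"] = tgi_flash_attention_forward

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.1-8B-Instruct",  # placeholder model id
    torch_dtype="auto",
    attn_implementation="tgi",
)
```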