Fix TGI issues with ROCm (#1921)
Not all models were tested in https://github.com/huggingface/text-generation-inference/pull/1764. This fixes some more issues (notably with starcoder2); the full CI will come shortly, once we split `build.yml` in two.
This commit is contained in:
parent 05600c55a5
commit 14ed7c7b4a
@@ -230,11 +230,15 @@ class LlamaMLP(nn.Module):
             config.intermediate_size // weights.process_group.size()
         )
 
+        # TODO: This is a hotfix to be removed & properly refactored.
+        self.quantize = config.quantize
+
     def forward(self, hidden_states):
         if (
             SYSTEM == "rocm"
             and self.hidden_act == "silu"
             and hidden_states.shape[0] == 1
+            and not self.quantize
         ):
             out = torch.empty(
                 hidden_states.shape[0],
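The new `and not self.quantize` guard restricts the ROCm fused silu fast path to the unquantized, single-token decode case: the fused kernel consumes the raw `gate_up_proj` weight tensor, which a quantized linear layer typically does not expose in that form. Below is a minimal, runnable sketch of the dispatch pattern; `SketchMLP` and `fused_silu_matmul` are illustrative stand-ins (TGI's real kernel is a custom ROCm extension), not TGI's API.

import torch
import torch.nn as nn
import torch.nn.functional as F

# Stand-in for the platform check; TGI reads this from its own utils.
SYSTEM = "rocm"


def fused_silu_matmul(weight: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    # Stand-in for the fused ROCm kernel: silu(gate) * up in one shot.
    # `weight` packs the gate and up projections row-wise, as in gate_up_proj.
    gate_up = x @ weight.T
    gate, up = gate_up.chunk(2, dim=-1)
    return F.silu(gate) * up


class SketchMLP(nn.Module):
    def __init__(self, hidden_size: int, intermediate_size: int, quantize: bool):
        super().__init__()
        self.hidden_act = "silu"
        self.quantize = quantize  # mirrors the hotfix: config.quantize
        self.act = nn.SiLU()
        self.intermediate_size = intermediate_size
        self.gate_up_proj = nn.Linear(hidden_size, 2 * intermediate_size, bias=False)
        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        if (
            SYSTEM == "rocm"
            and self.hidden_act == "silu"
            and hidden_states.shape[0] == 1  # decode step: a single token
            and not self.quantize  # the fix: no fused path when quantized
        ):
            out = fused_silu_matmul(self.gate_up_proj.weight, hidden_states)
            return self.down_proj(out)
        # Generic path: works for any batch size and for quantized layers.
        gate_up = self.gate_up_proj(hidden_states).view(-1, 2, self.intermediate_size)
        return self.down_proj(self.act(gate_up[:, 0]) * gate_up[:, 1])


if __name__ == "__main__":
    mlp = SketchMLP(hidden_size=16, intermediate_size=32, quantize=False)
    print(mlp(torch.randn(1, 16)).shape)  # torch.Size([1, 16])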
@@ -290,11 +290,15 @@ class MistralMLP(nn.Module):
             config.intermediate_size // weights.process_group.size()
         )
 
+        # TODO: This is a hotfix to be removed & properly refactored.
+        self.quantize = config.quantize
+
     def forward(self, hidden_states):
         if (
             SYSTEM == "rocm"
             and self.hidden_act == "silu"
             and hidden_states.shape[0] == 1
+            and not self.quantize
         ):
             out = torch.empty(
                 hidden_states.shape[0],
@@ -890,6 +890,9 @@ class FlashCausalLM(Model):
         slots = torch.arange(seqlen, dtype=torch.int64, device=self.device)
         kv_cache = get_cache_manager().kv_cache
 
+        # Dummy value, some models (starcoder2) don't accept `None`.
+        input_lengths = torch.ones(seqlen, dtype=torch.int32, device=self.device)
+
         # We pass a `cu_seqlen_prefill` in order not to have to deal with paged attention cache allocation/deallocation.
         self.model.forward(
             input_ids=input_ids,
@@ -899,7 +902,7 @@ class FlashCausalLM(Model):
             ),
             kv_cache=get_cache_manager().kv_cache,
             block_tables=None,
-            input_lengths=None,
+            input_lengths=input_lengths,
             slots=slots,
             max_s=seqlen,
             lm_head_indices=None,
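The warmup forward previously passed `input_lengths=None`; a model whose attention code indexes into `input_lengths` (starcoder2 here) fails on `None`, hence the dummy tensor of ones. A small runnable sketch of the failure mode and the fix follows; `toy_forward` is an illustrative stand-in, not TGI's model interface.

import torch

device = "cpu"  # TGI uses the model's device; CPU keeps the sketch runnable


def toy_forward(input_ids, input_lengths):
    # Stand-in for a model (starcoder2-style) whose attention code
    # reads `input_lengths` instead of tolerating `None`.
    max_len = int(input_lengths.max())  # raises AttributeError on None
    return input_ids[:max_len]


seqlen = 8
input_ids = torch.zeros(seqlen, dtype=torch.int64, device=device)

# Before the fix: input_lengths=None blows up inside the model.
try:
    toy_forward(input_ids, None)
except AttributeError as e:
    print(f"warmup failed: {e}")

# After the fix: a dummy value satisfies models that don't accept `None`.
input_lengths = torch.ones(seqlen, dtype=torch.int32, device=device)
print(toy_forward(input_ids, input_lengths).shape)  # torch.Size([1])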
@@ -397,6 +397,9 @@ class BaseFlashMistral(FlashCausalLM):
         slots = torch.arange(seqlen, dtype=torch.int64, device=self.device)
         kv_cache = get_cache_manager().kv_cache
 
+        # Dummy value, some models (starcoder2) don't accept `None`.
+        input_lengths = torch.ones(seqlen, dtype=torch.int32, device=self.device)
+
         # We pass a `cu_seqlen_prefill` in order not to have to deal with paged attention cache allocation/deallocation.
         self.model.forward(
             input_ids=input_ids,
@@ -406,7 +409,7 @@ class BaseFlashMistral(FlashCausalLM):
             ),
             kv_cache=get_cache_manager().kv_cache,
             block_tables=None,
-            input_lengths=None,
+            input_lengths=input_lengths,
             slots=slots,
             max_s=seqlen,
             lm_head_indices=None,