Mirror of https://github.com/huggingface/text-generation-inference.git, synced 2025-04-21 23:12:07 +00:00
Fix TGI issues with ROCm (#1921)
Not all models were tested in https://github.com/huggingface/text-generation-inference/pull/1764. This fixes some more issues (notably starcoder2); the full CI will come shortly once we split `build.yml` in two.
This commit is contained in: parent 05600c55a5 · commit 14ed7c7b4a
@@ -230,11 +230,15 @@ class LlamaMLP(nn.Module):
             config.intermediate_size // weights.process_group.size()
         )
 
+        # TODO: This is a hotfix to be removed & properly refactored.
+        self.quantize = config.quantize
+
     def forward(self, hidden_states):
         if (
             SYSTEM == "rocm"
             and self.hidden_act == "silu"
             and hidden_states.shape[0] == 1
+            and not self.quantize
         ):
             out = torch.empty(
                 hidden_states.shape[0],
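Below is a minimal, hedged sketch of the dispatch pattern this hunk guards. It is not TGI's actual implementation: the fused ROCm silu-and-mul kernel is stood in by plain PyTorch, and the names (SketchMLP, system, quantize) are illustrative. The point is that the MLP now remembers `config.quantize` and only takes the fast single-token ROCm path when the layer is unquantized.

# Illustrative sketch only; plain torch stands in for the ROCm fused kernel.
import torch
import torch.nn.functional as F


class SketchMLP(torch.nn.Module):
    def __init__(self, hidden_size, intermediate_size, quantize=None):
        super().__init__()
        # Gate and up projections fused into one linear, as in Llama-style MLPs.
        self.gate_up_proj = torch.nn.Linear(hidden_size, 2 * intermediate_size, bias=False)
        self.down_proj = torch.nn.Linear(intermediate_size, hidden_size, bias=False)
        # Mirrors the hotfix in the hunk above: keep the quantization mode on the module.
        self.quantize = quantize

    def forward(self, hidden_states, system="rocm"):
        gate, up = self.gate_up_proj(hidden_states).chunk(2, dim=-1)
        if system == "rocm" and hidden_states.shape[0] == 1 and not self.quantize:
            # Fast path: TGI preallocates `out` and calls a fused silu-and-mul
            # kernel on ROCm; torch.mul(..., out=out) stands in for it here.
            out = torch.empty_like(up)
            torch.mul(F.silu(gate), up, out=out)
        else:
            # Generic path, also taken for quantized layers after this fix.
            out = F.silu(gate) * up
        return self.down_proj(out)


mlp = SketchMLP(hidden_size=16, intermediate_size=32)
print(mlp(torch.randn(1, 16)).shape)  # torch.Size([1, 16])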
@@ -290,11 +290,15 @@ class MistralMLP(nn.Module):
             config.intermediate_size // weights.process_group.size()
         )
 
+        # TODO: This is a hotfix to be removed & properly refactored.
+        self.quantize = config.quantize
+
     def forward(self, hidden_states):
         if (
             SYSTEM == "rocm"
             and self.hidden_act == "silu"
             and hidden_states.shape[0] == 1
+            and not self.quantize
         ):
             out = torch.empty(
                 hidden_states.shape[0],
@@ -890,6 +890,9 @@ class FlashCausalLM(Model):
         slots = torch.arange(seqlen, dtype=torch.int64, device=self.device)
         kv_cache = get_cache_manager().kv_cache
 
+        # Dummy value, some models (starcoder2) don't accept `None`.
+        input_lengths = torch.ones(seqlen, dtype=torch.int32, device=self.device)
+
         # We pass a `cu_seqlen_prefill` in order not to have to deal with paged attention cache allocation/deallocation.
         self.model.forward(
             input_ids=input_ids,
@@ -899,7 +902,7 @@ class FlashCausalLM(Model):
             ),
             kv_cache=get_cache_manager().kv_cache,
             block_tables=None,
-            input_lengths=None,
+            input_lengths=input_lengths,
             slots=slots,
             max_s=seqlen,
             lm_head_indices=None,
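A hedged illustration of why the dummy tensor is needed, not taken from TGI: models with sliding-window attention such as starcoder2 do arithmetic on `input_lengths` inside the forward pass, so passing `None` during the warmup run raises a TypeError, while a tensor of ones is enough to trace the graph. `window_offsets` below is a hypothetical stand-in for that kind of computation.

# Illustrative only; `window_offsets` is a hypothetical stand-in for the kind of
# arithmetic starcoder2-style models perform on `input_lengths`.
import torch

seqlen = 4
input_lengths = torch.ones(seqlen, dtype=torch.int32)  # dummy value, as in the diff


def window_offsets(max_s, input_lengths):
    # Fails with `input_lengths=None`: int - None is a TypeError.
    return torch.clamp(max_s - input_lengths, min=0)


print(window_offsets(seqlen, input_lengths))  # tensor([3, 3, 3, 3], dtype=torch.int32)
# window_offsets(seqlen, None)  # TypeError: unsupported operand type(s) for -: 'int' and 'NoneType'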
@@ -397,6 +397,9 @@ class BaseFlashMistral(FlashCausalLM):
         slots = torch.arange(seqlen, dtype=torch.int64, device=self.device)
         kv_cache = get_cache_manager().kv_cache
 
+        # Dummy value, some models (starcoder2) don't accept `None`.
+        input_lengths = torch.ones(seqlen, dtype=torch.int32, device=self.device)
+
         # We pass a `cu_seqlen_prefill` in order not to have to deal with paged attention cache allocation/deallocation.
         self.model.forward(
             input_ids=input_ids,
@@ -406,7 +409,7 @@ class BaseFlashMistral(FlashCausalLM):
             ),
             kv_cache=get_cache_manager().kv_cache,
             block_tables=None,
-            input_lengths=None,
+            input_lengths=input_lengths,
             slots=slots,
             max_s=seqlen,
             lm_head_indices=None,