Fix TGI issues with ROCm (#1921)

Not all models were tested in
https://github.com/huggingface/text-generation-inference/pull/1764.

Fixing some more issues (notably starcoder2) here; the full CI will come shortly once we split `build.yml` in two.
Author: fxmarty, 2024-05-17 19:50:52 +02:00 (committed by yuanwu)
Parent: 05600c55a5
Commit: 14ed7c7b4a
4 changed files with 16 additions and 2 deletions

@@ -230,11 +230,15 @@ class LlamaMLP(nn.Module):
             config.intermediate_size // weights.process_group.size()
         )

+        # TODO: This is a hotfix to be removed & properly refactored.
+        self.quantize = config.quantize
+
     def forward(self, hidden_states):
         if (
             SYSTEM == "rocm"
             and self.hidden_act == "silu"
             and hidden_states.shape[0] == 1
+            and not self.quantize
         ):
             out = torch.empty(
                 hidden_states.shape[0],
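
For reference, the condition patched above is what decides whether the MLP takes the fused ROCm SiLU kernel or the generic gate/up projection path; with the new `self.quantize` attribute, quantized weights now always fall back to the generic path (the fused kernel presumably expects plain unquantized weights). The snippet below is a minimal standalone sketch of that gating, not the actual TGI module: `GatedMLP` and `fused_silu_linear` are hypothetical names, and the "fused" helper is implemented in plain PyTorch so the example runs anywhere.

    import torch
    import torch.nn.functional as F

    def fused_silu_linear(weight, x):
        # Hypothetical stand-in for the fused ROCm GEMM+SiLU kernel: computes the
        # same result unfused so the sketch runs on any device.
        gate_up = x @ weight.t()
        gate, up = gate_up.chunk(2, dim=-1)
        return F.silu(gate) * up

    class GatedMLP(torch.nn.Module):
        # Minimal illustration of the gating logic above (not the real LlamaMLP).
        def __init__(self, hidden_size, intermediate_size, quantize=None, system="rocm"):
            super().__init__()
            self.gate_up_proj = torch.nn.Linear(hidden_size, 2 * intermediate_size, bias=False)
            self.down_proj = torch.nn.Linear(intermediate_size, hidden_size, bias=False)
            self.intermediate_size = intermediate_size
            self.hidden_act = "silu"
            self.quantize = quantize  # e.g. None, "gptq", "awq"
            self.system = system

        def forward(self, hidden_states):
            if (
                self.system == "rocm"
                and self.hidden_act == "silu"
                and hidden_states.shape[0] == 1
                and not self.quantize  # quantized weights must skip the fused kernel
            ):
                return self.down_proj(fused_silu_linear(self.gate_up_proj.weight, hidden_states))
            # Generic path: one GEMM for gate/up, SiLU(gate) * up, then the down projection.
            gate_up = self.gate_up_proj(hidden_states)
            gate, up = gate_up.split(self.intermediate_size, dim=-1)
            return self.down_proj(F.silu(gate) * up)

    mlp = GatedMLP(hidden_size=16, intermediate_size=32, quantize="gptq")
    print(mlp(torch.randn(1, 16)).shape)  # quantize set -> generic path, torch.Size([1, 16])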

@@ -290,11 +290,15 @@ class MistralMLP(nn.Module):
             config.intermediate_size // weights.process_group.size()
         )

+        # TODO: This is a hotfix to be removed & properly refactored.
+        self.quantize = config.quantize
+
     def forward(self, hidden_states):
         if (
             SYSTEM == "rocm"
             and self.hidden_act == "silu"
             and hidden_states.shape[0] == 1
+            and not self.quantize
         ):
             out = torch.empty(
                 hidden_states.shape[0],

@@ -890,6 +890,9 @@ class FlashCausalLM(Model):
         slots = torch.arange(seqlen, dtype=torch.int64, device=self.device)
         kv_cache = get_cache_manager().kv_cache

+        # Dummy value, some models (starcoder2) don't accept `None`.
+        input_lengths = torch.ones(seqlen, dtype=torch.int32, device=self.device)
+
         # We pass a `cu_seqlen_prefill` in order not to have to deal with paged attention cache allocation/deallocation.
         self.model.forward(
             input_ids=input_ids,
@@ -899,7 +902,7 @@ class FlashCausalLM(Model):
             ),
             kv_cache=get_cache_manager().kv_cache,
             block_tables=None,
-            input_lengths=None,
+            input_lengths=input_lengths,
             slots=slots,
             max_s=seqlen,
             lm_head_indices=None,
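
The `input_lengths=None` replaced above feeds a dummy warmup forward pass; most models ignore the argument, but starcoder2 does not accept `None` for it, as the comment notes. Below is a toy reproduction of that failure mode and of the dummy-tensor fix; `toy_forward` is a made-up stand-in, not the real model signature.

    import torch

    def toy_forward(input_ids, input_lengths, max_s):
        # Made-up model forward: its attention path reads input_lengths (for example
        # to clamp a sliding window), so passing None raises instead of being ignored.
        window = torch.clamp(input_lengths, max=max_s)
        return input_ids.unsqueeze(-1).float() * window.unsqueeze(-1).float()

    seqlen = 8
    device = "cpu"  # the real warmup runs on the ROCm device

    input_ids = torch.zeros(seqlen, dtype=torch.int64, device=device)

    # toy_forward(input_ids, None, seqlen)  # old warmup call: fails, clamp() rejects None
    input_lengths = torch.ones(seqlen, dtype=torch.int32, device=device)  # dummy value
    out = toy_forward(input_ids, input_lengths, seqlen)  # patched warmup call succeeds
    print(out.shape)  # torch.Size([8, 1])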

@@ -397,6 +397,9 @@ class BaseFlashMistral(FlashCausalLM):
         slots = torch.arange(seqlen, dtype=torch.int64, device=self.device)
         kv_cache = get_cache_manager().kv_cache

+        # Dummy value, some models (starcoder2) don't accept `None`.
+        input_lengths = torch.ones(seqlen, dtype=torch.int32, device=self.device)
+
         # We pass a `cu_seqlen_prefill` in order not to have to deal with paged attention cache allocation/deallocation.
         self.model.forward(
             input_ids=input_ids,
@@ -406,7 +409,7 @@ class BaseFlashMistral(FlashCausalLM):
             ),
             kv_cache=get_cache_manager().kv_cache,
             block_tables=None,
-            input_lengths=None,
+            input_lengths=input_lengths,
             slots=slots,
             max_s=seqlen,
             lm_head_indices=None,