mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-10 20:04:52 +00:00
"Fix" for rw-1b.
- New "falcon" layout on this repo.
- No alibi.
- `transformers` is already modifying the cache layout in our stead (same modifications).
- Output is garbage. Not sure why.
This commit is contained in:
parent
d9bceb8e6b
commit
7e20b8cb50
@ -67,18 +67,6 @@ class RW(CausalLM):
|
|||||||
self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
|
self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
|
||||||
) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
|
) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
|
||||||
# Model Forward
|
# Model Forward
|
||||||
if past_key_values is not None:
|
|
||||||
reshaped_past_key_values = []
|
|
||||||
for layer in past_key_values:
|
|
||||||
past_keys, past_values = layer
|
|
||||||
reshaped_past_key_values.append(
|
|
||||||
(
|
|
||||||
past_keys.view(-1, *past_keys.shape[-2:]),
|
|
||||||
past_values.view(-1, *past_values.shape[-2:]),
|
|
||||||
)
|
|
||||||
)
|
|
||||||
past_key_values = reshaped_past_key_values
|
|
||||||
|
|
||||||
outputs = self.model.forward(
|
outputs = self.model.forward(
|
||||||
input_ids=input_ids,
|
input_ids=input_ids,
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
|
Loading…
Reference in New Issue
Block a user