"Fix" for rw-1b.

- New "falcon" layout on this repo
- No alibi
- `transformers` is already modifying the cache layout on our behalf (it
  applies the same modifications).
- Output is garbage. Not sure why.
This commit is contained in:
Nicolas Patry 2023-08-16 19:58:30 +00:00
parent d9bceb8e6b
commit 7e20b8cb50

View File

@ -67,18 +67,6 @@ class RW(CausalLM):
self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
# Model Forward
if past_key_values is not None:
reshaped_past_key_values = []
for layer in past_key_values:
past_keys, past_values = layer
reshaped_past_key_values.append(
(
past_keys.view(-1, *past_keys.shape[-2:]),
past_values.view(-1, *past_values.shape[-2:]),
)
)
past_key_values = reshaped_past_key_values
outputs = self.model.forward(
input_ids=input_ids,
attention_mask=attention_mask,