Remove unneeded Model.num_heads field

Nick Hill 2022-12-04 11:31:08 -08:00
parent a172430d8b
commit 1747365e25
5 changed files with 2 additions and 7 deletions

@@ -82,7 +82,6 @@ class BLOOMSharded(CausalLM):
         torch.distributed.barrier(group=self.process_group)
         super(CausalLM, self).__init__(
             tokenizer=tokenizer,
-            num_heads=config.n_head // self.process_group.size(),
             device=device,
         )

@@ -251,7 +251,6 @@ class CausalLM(Model):
         super(CausalLM, self).__init__(
             tokenizer=tokenizer,
-            num_heads=self.model.config.num_attention_heads,
             device=device,
         )
@@ -358,7 +357,7 @@ class CausalLM(Model):
             # Force past to be of dim [batch_size, num_heads, ...] for easy indexing
             next_batch_past_key_values = [
                 [
-                    t.view(-1, self.num_heads, *t.shape[-2:])[next_batch_keep_indices]
+                    t.view(batch.size, -1, *t.shape[-2:])[next_batch_keep_indices]
                     for t in layer
                 ]
                 for layer in past
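
For readers wondering why the `num_heads` factor can be dropped from the reshape above, the sketch below (not part of the commit; tensor sizes and variable names are made up for illustration) shows that grouping the flattened past key/value tensor by batch size with an inferred head dimension yields the same `[batch_size, num_heads, ...]` layout as the old call, so indexing along dim 0 still selects whole batch entries:

```python
import torch

# Hypothetical sizes, for illustration only.
batch_size, num_heads, seq_len, head_dim = 3, 4, 7, 16

# Past key/value tensors are stored with batch and heads flattened together.
t = torch.randn(batch_size * num_heads, seq_len, head_dim)

# Old reshape: group by num_heads and let PyTorch infer the batch dimension.
old = t.view(-1, num_heads, *t.shape[-2:])

# New reshape: group by the known batch size and infer the head dimension.
new = t.view(batch_size, -1, *t.shape[-2:])

# Both yield [batch_size, num_heads, seq_len, head_dim], so indexing dim 0
# with next_batch_keep_indices keeps whole batch entries either way.
assert old.shape == new.shape == (batch_size, num_heads, seq_len, head_dim)
assert torch.equal(old, new)
```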

@@ -185,7 +185,6 @@ class GalacticaSharded(Galactica):
         torch.distributed.barrier(group=self.process_group)
         super(CausalLM, self).__init__(
             tokenizer=tokenizer,
-            num_heads=config.num_attention_heads // self.process_group.size(),
             device=device,
         )

@ -10,9 +10,8 @@ B = TypeVar("B", bound=Batch)
class Model(ABC): class Model(ABC):
def __init__(self, tokenizer: Tokenizer, num_heads: int, device: torch.device): def __init__(self, tokenizer: Tokenizer, device: torch.device):
self.tokenizer = tokenizer self.tokenizer = tokenizer
self.num_heads = num_heads
self.device = device self.device = device
@property @property
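
As a rough sketch of how the simplified base class is used after this change (the real `Tokenizer` alias is assumed to be a Hugging Face tokenizer type, and the subclass name is made up), subclasses now pass only the tokenizer and device to `Model.__init__` and keep any head-count arithmetic local:

```python
import torch
from abc import ABC
from transformers import PreTrainedTokenizerBase


class Model(ABC):
    # Simplified constructor after this commit: no num_heads field.
    def __init__(self, tokenizer: PreTrainedTokenizerBase, device: torch.device):
        self.tokenizer = tokenizer
        self.device = device


class ExampleCausalLM(Model):
    def __init__(self, tokenizer: PreTrainedTokenizerBase, device: torch.device):
        # Only the tokenizer and device are forwarded to the base class;
        # anything head-related stays inside the subclass that needs it.
        super().__init__(tokenizer=tokenizer, device=device)
```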

@@ -319,7 +319,6 @@ class Seq2SeqLM(Model):
         super(Seq2SeqLM, self).__init__(
             tokenizer=tokenizer,
-            num_heads=self.model.config.num_attention_heads,
             device=device,
         )