From 1747365e25d431c117ccc7268c0beed300503f39 Mon Sep 17 00:00:00 2001
From: Nick Hill
Date: Sun, 4 Dec 2022 11:31:08 -0800
Subject: [PATCH] Remove unneeded Model.num_heads field

---
 server/text_generation/models/bloom.py      | 1 -
 server/text_generation/models/causal_lm.py  | 3 +--
 server/text_generation/models/galactica.py  | 1 -
 server/text_generation/models/model.py      | 3 +--
 server/text_generation/models/seq2seq_lm.py | 1 -
 5 files changed, 2 insertions(+), 7 deletions(-)

diff --git a/server/text_generation/models/bloom.py b/server/text_generation/models/bloom.py
index 008288f8..2a7405d3 100644
--- a/server/text_generation/models/bloom.py
+++ b/server/text_generation/models/bloom.py
@@ -82,7 +82,6 @@ class BLOOMSharded(CausalLM):
         torch.distributed.barrier(group=self.process_group)
         super(CausalLM, self).__init__(
             tokenizer=tokenizer,
-            num_heads=config.n_head // self.process_group.size(),
             device=device,
         )

diff --git a/server/text_generation/models/causal_lm.py b/server/text_generation/models/causal_lm.py
index 2c55508b..4e66ae3a 100644
--- a/server/text_generation/models/causal_lm.py
+++ b/server/text_generation/models/causal_lm.py
@@ -251,7 +251,6 @@ class CausalLM(Model):

         super(CausalLM, self).__init__(
             tokenizer=tokenizer,
-            num_heads=self.model.config.num_attention_heads,
             device=device,
         )

@@ -358,7 +357,7 @@ class CausalLM(Model):
             # Force past to be of dim [batch_size, num_heads, ...] for easy indexing
             next_batch_past_key_values = [
                 [
-                    t.view(-1, self.num_heads, *t.shape[-2:])[next_batch_keep_indices]
+                    t.view(batch.size, -1, *t.shape[-2:])[next_batch_keep_indices]
                     for t in layer
                 ]
                 for layer in past
diff --git a/server/text_generation/models/galactica.py b/server/text_generation/models/galactica.py
index abc3c36c..5de75ab4 100644
--- a/server/text_generation/models/galactica.py
+++ b/server/text_generation/models/galactica.py
@@ -185,7 +185,6 @@ class GalacticaSharded(Galactica):
         torch.distributed.barrier(group=self.process_group)
         super(CausalLM, self).__init__(
             tokenizer=tokenizer,
-            num_heads=config.num_attention_heads // self.process_group.size(),
             device=device,
         )

diff --git a/server/text_generation/models/model.py b/server/text_generation/models/model.py
index 7fb8142c..0331e193 100644
--- a/server/text_generation/models/model.py
+++ b/server/text_generation/models/model.py
@@ -10,9 +10,8 @@
 B = TypeVar("B", bound=Batch)

 class Model(ABC):
-    def __init__(self, tokenizer: Tokenizer, num_heads: int, device: torch.device):
+    def __init__(self, tokenizer: Tokenizer, device: torch.device):
         self.tokenizer = tokenizer
-        self.num_heads = num_heads
         self.device = device

     @property
diff --git a/server/text_generation/models/seq2seq_lm.py b/server/text_generation/models/seq2seq_lm.py
index f63a8849..e9c65596 100644
--- a/server/text_generation/models/seq2seq_lm.py
+++ b/server/text_generation/models/seq2seq_lm.py
@@ -319,7 +319,6 @@ class Seq2SeqLM(Model):

         super(Seq2SeqLM, self).__init__(
             tokenizer=tokenizer,
-            num_heads=self.model.config.num_attention_heads,
             device=device,
         )
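Note (not part of the patch): the only read of self.num_heads in this diff is the
past_key_values reshape in causal_lm.py (second hunk above), and that reshape can pin
the batch dimension instead and let PyTorch infer the head count. A minimal sketch of
the equivalence, assuming a BLOOM-style fused key/value layout of
[batch * heads, seq_len, head_dim] and made-up sizes:

    import torch

    # Hypothetical sizes for illustration only.
    batch_size, num_heads, seq_len, head_dim = 4, 8, 10, 64
    t = torch.randn(batch_size * num_heads, seq_len, head_dim)

    # Before: infer the batch dimension from a tracked per-model head count.
    old = t.view(-1, num_heads, *t.shape[-2:])

    # After: pin the batch dimension and let view() infer the head count.
    new = t.view(batch_size, -1, *t.shape[-2:])

    assert old.shape == new.shape == (batch_size, num_heads, seq_len, head_dim)
    assert torch.equal(old, new)

Because the head count is now inferred, the sharded models (BLOOMSharded,
GalacticaSharded) also no longer need to divide the config's head count by the
process-group size just to keep the field consistent across ranks.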