mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-06-19 15:52:08 +00:00
Remove unneeded Model.num_heads field
This commit is contained in:
parent
a172430d8b
commit
1747365e25
@ -82,7 +82,6 @@ class BLOOMSharded(CausalLM):
|
|||||||
torch.distributed.barrier(group=self.process_group)
|
torch.distributed.barrier(group=self.process_group)
|
||||||
super(CausalLM, self).__init__(
|
super(CausalLM, self).__init__(
|
||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
num_heads=config.n_head // self.process_group.size(),
|
|
||||||
device=device,
|
device=device,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -251,7 +251,6 @@ class CausalLM(Model):
|
|||||||
|
|
||||||
super(CausalLM, self).__init__(
|
super(CausalLM, self).__init__(
|
||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
num_heads=self.model.config.num_attention_heads,
|
|
||||||
device=device,
|
device=device,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -358,7 +357,7 @@ class CausalLM(Model):
|
|||||||
# Force past to be of dim [batch_size, num_heads, ...] for easy indexing
|
# Force past to be of dim [batch_size, num_heads, ...] for easy indexing
|
||||||
next_batch_past_key_values = [
|
next_batch_past_key_values = [
|
||||||
[
|
[
|
||||||
t.view(-1, self.num_heads, *t.shape[-2:])[next_batch_keep_indices]
|
t.view(batch.size, -1, *t.shape[-2:])[next_batch_keep_indices]
|
||||||
for t in layer
|
for t in layer
|
||||||
]
|
]
|
||||||
for layer in past
|
for layer in past
|
||||||
|
@ -185,7 +185,6 @@ class GalacticaSharded(Galactica):
|
|||||||
torch.distributed.barrier(group=self.process_group)
|
torch.distributed.barrier(group=self.process_group)
|
||||||
super(CausalLM, self).__init__(
|
super(CausalLM, self).__init__(
|
||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
num_heads=config.num_attention_heads // self.process_group.size(),
|
|
||||||
device=device,
|
device=device,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -10,9 +10,8 @@ B = TypeVar("B", bound=Batch)
|
|||||||
|
|
||||||
|
|
||||||
class Model(ABC):
|
class Model(ABC):
|
||||||
def __init__(self, tokenizer: Tokenizer, num_heads: int, device: torch.device):
|
def __init__(self, tokenizer: Tokenizer, device: torch.device):
|
||||||
self.tokenizer = tokenizer
|
self.tokenizer = tokenizer
|
||||||
self.num_heads = num_heads
|
|
||||||
self.device = device
|
self.device = device
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
@ -319,7 +319,6 @@ class Seq2SeqLM(Model):
|
|||||||
|
|
||||||
super(Seq2SeqLM, self).__init__(
|
super(Seq2SeqLM, self).__init__(
|
||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
num_heads=self.model.config.num_attention_heads,
|
|
||||||
device=device,
|
device=device,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user