mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-04-27 13:02:12 +00:00
fix: pass model_id for all causal and seq2seq lms
This commit is contained in:
parent
88bd5c2c92
commit
dc0f76553c
@ -90,6 +90,7 @@ class BLOOMSharded(CausalLM):
|
|||||||
|
|
||||||
torch.distributed.barrier(group=self.process_group)
|
torch.distributed.barrier(group=self.process_group)
|
||||||
super(CausalLM, self).__init__(
|
super(CausalLM, self).__init__(
|
||||||
|
model_id=model_id,
|
||||||
model=model,
|
model=model,
|
||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
requires_padding=True,
|
requires_padding=True,
|
||||||
|
@ -538,6 +538,7 @@ class CausalLM(Model):
|
|||||||
tokenizer.add_special_tokens({"pad_token": "[PAD]"})
|
tokenizer.add_special_tokens({"pad_token": "[PAD]"})
|
||||||
|
|
||||||
super(CausalLM, self).__init__(
|
super(CausalLM, self).__init__(
|
||||||
|
model_id=model_id,
|
||||||
model=model,
|
model=model,
|
||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
requires_padding=True,
|
requires_padding=True,
|
||||||
|
@ -212,6 +212,7 @@ class GalacticaSharded(CausalLM):
|
|||||||
|
|
||||||
torch.distributed.barrier(group=self.process_group)
|
torch.distributed.barrier(group=self.process_group)
|
||||||
super(CausalLM, self).__init__(
|
super(CausalLM, self).__init__(
|
||||||
|
model_id=model_id,
|
||||||
model=model,
|
model=model,
|
||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
requires_padding=True,
|
requires_padding=True,
|
||||||
|
@ -65,6 +65,7 @@ class GPTNeoxSharded(CausalLM):
|
|||||||
|
|
||||||
torch.distributed.barrier(group=self.process_group)
|
torch.distributed.barrier(group=self.process_group)
|
||||||
super(CausalLM, self).__init__(
|
super(CausalLM, self).__init__(
|
||||||
|
model_id=model_id,
|
||||||
model=model,
|
model=model,
|
||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
requires_padding=True,
|
requires_padding=True,
|
||||||
|
@ -90,6 +90,7 @@ class MPTSharded(CausalLM):
|
|||||||
|
|
||||||
torch.distributed.barrier(group=self.process_group)
|
torch.distributed.barrier(group=self.process_group)
|
||||||
super(CausalLM, self).__init__(
|
super(CausalLM, self).__init__(
|
||||||
|
model_id=model_id,
|
||||||
model=model,
|
model=model,
|
||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
requires_padding=False,
|
requires_padding=False,
|
||||||
|
@ -63,6 +63,7 @@ class OPTSharded(CausalLM):
|
|||||||
|
|
||||||
torch.distributed.barrier(group=self.process_group)
|
torch.distributed.barrier(group=self.process_group)
|
||||||
super(CausalLM, self).__init__(
|
super(CausalLM, self).__init__(
|
||||||
|
model_id=model_id,
|
||||||
model=model,
|
model=model,
|
||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
requires_padding=True,
|
requires_padding=True,
|
||||||
|
@ -60,6 +60,7 @@ class Phi(CausalLM):
|
|||||||
model = PhiForCausalLM(config, weights)
|
model = PhiForCausalLM(config, weights)
|
||||||
torch.distributed.barrier(group=self.process_group)
|
torch.distributed.barrier(group=self.process_group)
|
||||||
super(CausalLM, self).__init__(
|
super(CausalLM, self).__init__(
|
||||||
|
model_id=model_id,
|
||||||
model=model,
|
model=model,
|
||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
requires_padding=True,
|
requires_padding=True,
|
||||||
|
@ -62,6 +62,7 @@ class RW(CausalLM):
|
|||||||
tokenizer.add_special_tokens({"pad_token": "[PAD]"})
|
tokenizer.add_special_tokens({"pad_token": "[PAD]"})
|
||||||
|
|
||||||
super(CausalLM, self).__init__(
|
super(CausalLM, self).__init__(
|
||||||
|
model_id=model_id,
|
||||||
model=model,
|
model=model,
|
||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
requires_padding=True,
|
requires_padding=True,
|
||||||
|
@ -62,6 +62,7 @@ class SantaCoder(CausalLM):
|
|||||||
)
|
)
|
||||||
|
|
||||||
super(CausalLM, self).__init__(
|
super(CausalLM, self).__init__(
|
||||||
|
model_id=model_id,
|
||||||
model=model,
|
model=model,
|
||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
requires_padding=True,
|
requires_padding=True,
|
||||||
|
@ -575,6 +575,7 @@ class Seq2SeqLM(Model):
|
|||||||
tokenizer.bos_token_id = model.config.decoder_start_token_id
|
tokenizer.bos_token_id = model.config.decoder_start_token_id
|
||||||
|
|
||||||
super(Seq2SeqLM, self).__init__(
|
super(Seq2SeqLM, self).__init__(
|
||||||
|
model_id=model_id,
|
||||||
model=model,
|
model=model,
|
||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
requires_padding=True,
|
requires_padding=True,
|
||||||
|
@ -73,6 +73,7 @@ class T5Sharded(Seq2SeqLM):
|
|||||||
|
|
||||||
torch.distributed.barrier(group=self.process_group)
|
torch.distributed.barrier(group=self.process_group)
|
||||||
super(Seq2SeqLM, self).__init__(
|
super(Seq2SeqLM, self).__init__(
|
||||||
|
model_id=model_id,
|
||||||
model=model,
|
model=model,
|
||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
requires_padding=True,
|
requires_padding=True,
|
||||||
|
Loading…
Reference in New Issue
Block a user