OlivierDehaene 2023-01-30 15:15:34 +01:00
parent 93f6acc396
commit 9285f67be5
7 changed files with 37 additions and 22 deletions


@@ -234,7 +234,9 @@ class BLOOMSharded(BLOOM):
         if name == "word_embeddings.weight":
             model.lm_head._parameters["weight"] = tensor
 
-    def forward(self, input_ids, attention_mask, position_ids, past_key_values: Optional = None):
+    def forward(
+        self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
+    ):
         outputs = self.model.forward(
             input_ids=input_ids,
             attention_mask=attention_mask,
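
The hunk itself only re-wraps the forward signature, but its context shows the sharded loader tying the language-model head to the word-embedding matrix. For reference, a minimal sketch of that weight-tying assignment, with toy shapes that are not from this commit:

import torch

# Toy-shape sketch (assumed dimensions): installing the embedding tensor
# directly into lm_head._parameters makes the head share storage with the
# embedding table instead of keeping a second copy.
embed = torch.nn.Embedding(10, 4)
lm_head = torch.nn.Linear(4, 10, bias=False)
lm_head._parameters["weight"] = embed.weight  # same tensor, no copy
assert lm_head.weight.data_ptr() == embed.weight.data_ptr()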


@@ -296,7 +296,10 @@ class CausalLM(Model):
         )
         with context_manager():
             logits, past = self.forward(
-                batch.input_ids, batch.attention_mask, batch.position_ids, batch.past_key_values
+                batch.input_ids,
+                batch.attention_mask,
+                batch.position_ids,
+                batch.past_key_values,
             )
 
         # List of indices to cache
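
The `with context_manager():` line in the context suggests the forward pass runs under a disabled-autograd scope. A hedged sketch of that pattern follows; the CPU/GPU split is an assumption about how the variable is chosen, not something this hunk shows:

import torch

# Assumed selection logic: inference_mode on accelerators, no_grad on CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
context_manager = torch.no_grad if device.type == "cpu" else torch.inference_mode
with context_manager():
    logits = torch.randn(2, 8, device=device)  # stand-in for self.forward(...)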
@@ -389,7 +392,7 @@ class CausalLM(Model):
                         token_ids=token_ids.squeeze(1).tolist(),
                         logprobs=logprobs,
                         reason=reason,
-                        seed=seed
+                        seed=seed,
                     )
                 )
             # add to the next batch


@@ -333,7 +333,9 @@ class GalacticaSharded(Galactica):
         if name == "model.decoder.embed_tokens.weight":
             model.lm_head._parameters["weight"] = tensor
 
-    def forward(self, input_ids, attention_mask, position_ids, past_key_values: Optional = None):
+    def forward(
+        self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
+    ):
         outputs = self.model.forward(
             input_ids=input_ids,
             attention_mask=attention_mask,


@@ -39,12 +39,16 @@ class SantaCoder(CausalLM):
             }
         )
 
-        self.model = AutoModelForCausalLM.from_pretrained(
-            model_name,
-            torch_dtype=dtype,
-            load_in_8bit=quantize,
-            trust_remote_code=True,  # required
-        ).to(device).eval()
+        self.model = (
+            AutoModelForCausalLM.from_pretrained(
+                model_name,
+                torch_dtype=dtype,
+                load_in_8bit=quantize,
+                trust_remote_code=True,  # required
+            )
+            .to(device)
+            .eval()
+        )
 
         super(CausalLM, self).__init__(
             tokenizer=tokenizer,
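
Black rewrites the chained call as a parenthesized fluent chain. A hedged, self-contained version of the same load; the checkpoint name, dtype, and quantize value are illustrative stand-ins for the class's `model_name`, `dtype`, and `quantize`:

import torch
from transformers import AutoModelForCausalLM

# Illustrative values; downloading the checkpoint requires network access.
model = (
    AutoModelForCausalLM.from_pretrained(
        "bigcode/santacoder",  # assumed example checkpoint
        torch_dtype=torch.float16,
        load_in_8bit=False,
        trust_remote_code=True,  # the model ships custom modeling code
    )
    .to("cuda" if torch.cuda.is_available() else "cpu")
    .eval()
)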


@@ -468,7 +468,7 @@ class Seq2SeqLM(Model):
                         token_ids=token_ids.tolist(),
                         logprobs=logprobs,
                         reason=reason,
-                        seed=seed
+                        seed=seed,
                     )
                 )
             # add to the next batch


@@ -50,5 +50,5 @@ class GeneratedText:
             token_ids=self.token_ids,
             logprobs=self.logprobs,
             finish_reason=self.reason,
-            seed=self.seed
+            seed=self.seed,
         )


@@ -33,7 +33,9 @@ class Sampling:
 
     def __call__(self, logits):
         probs = torch.nn.functional.softmax(logits, dim=-1)
-        next_tokens = torch.multinomial(probs, num_samples=1, generator=self.generator).squeeze(1)
+        next_tokens = torch.multinomial(
+            probs, num_samples=1, generator=self.generator
+        ).squeeze(1)
         return next_tokens
 
     @property
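
The point of threading `generator=self.generator` through `torch.multinomial` is reproducible sampling. A minimal sketch (not from this commit) showing that identically seeded generators draw identical tokens:

import torch

logits = torch.randn(1, 32000)  # assumed toy vocabulary size
probs = torch.nn.functional.softmax(logits, dim=-1)

g1 = torch.Generator().manual_seed(42)
g2 = torch.Generator().manual_seed(42)
a = torch.multinomial(probs, num_samples=1, generator=g1).squeeze(1)
b = torch.multinomial(probs, num_samples=1, generator=g2).squeeze(1)
assert torch.equal(a, b)  # same seed, same sampled token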
@@ -47,7 +49,9 @@ class Greedy:
 
 
 class NextTokenChooser:
-    def __init__(self, temperature=1.0, top_k=None, top_p=None, do_sample=False, seed=None):
+    def __init__(
+        self, temperature=1.0, top_k=None, top_p=None, do_sample=False, seed=None
+    ):
         warpers = LogitsProcessorList()
         # the following idea is largely copied from this PR: https://github.com/huggingface/transformers/pull/5420/files
         # all samplers can be found in `generation_utils_samplers.py`
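
The comments point at transformers' logits warpers. A hedged sketch of how a constructor with this signature might assemble them, following the linked PR; the guard conditions are assumptions:

from transformers import (
    LogitsProcessorList,
    TemperatureLogitsWarper,
    TopKLogitsWarper,
    TopPLogitsWarper,
)

def build_warpers(temperature=1.0, top_k=None, top_p=None):
    # Each warper is appended only when its parameter is actually set.
    warpers = LogitsProcessorList()
    if temperature is not None and temperature != 1.0:
        warpers.append(TemperatureLogitsWarper(temperature))
    if top_k is not None and top_k != 0:
        warpers.append(TopKLogitsWarper(top_k=top_k))
    if top_p is not None and top_p < 1.0:
        warpers.append(TopPLogitsWarper(top_p=top_p))
    return warpers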
@@ -84,7 +88,7 @@ class NextTokenChooser:
             top_k=pb.top_k,
             top_p=pb.top_p,
             do_sample=pb.do_sample,
-            seed=seed
+            seed=seed,
         )
@@ -100,10 +104,10 @@ class StopSequenceCriteria:
 
 class StoppingCriteria:
     def __init__(
-            self,
-            eos_token_id: int,
-            stop_sequence_criterias: List[StopSequenceCriteria],
-            max_new_tokens=20,
+        self,
+        eos_token_id: int,
+        stop_sequence_criterias: List[StopSequenceCriteria],
+        max_new_tokens=20,
     ):
         self.eos_token_id = eos_token_id
         self.stop_sequence_criterias = stop_sequence_criterias
@@ -128,9 +132,9 @@ class StoppingCriteria:
     @classmethod
     def from_pb(
-            cls,
-            pb: generate_pb2.StoppingCriteriaParameters,
-            tokenizer: PreTrainedTokenizerBase,
+        cls,
+        pb: generate_pb2.StoppingCriteriaParameters,
+        tokenizer: PreTrainedTokenizerBase,
     ) -> "StoppingCriteria":
         stop_sequence_criterias = [
             StopSequenceCriteria(sequence) for sequence in pb.stop_sequences
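
Putting the two hunks together, a hedged sketch of how classes with these signatures could behave; the matching logic (a plain endswith check) and the finish-reason strings are assumptions, not this repo's exact code:

from typing import List, Optional, Tuple

class StopSequenceCriteria:
    def __init__(self, stop_sequence: str):
        self.stop_sequence = stop_sequence

    def __call__(self, output: str) -> bool:
        # Assumed matching: stop once the decoded text ends with the sequence.
        return output.endswith(self.stop_sequence)

class StoppingCriteria:
    def __init__(
        self,
        eos_token_id: int,
        stop_sequence_criterias: List[StopSequenceCriteria],
        max_new_tokens=20,
    ):
        self.eos_token_id = eos_token_id
        self.stop_sequence_criterias = stop_sequence_criterias
        self.max_new_tokens = max_new_tokens
        self.current_tokens = 0

    def __call__(self, last_token_id: int, output: str) -> Tuple[bool, Optional[str]]:
        self.current_tokens += 1
        if self.current_tokens >= self.max_new_tokens:
            return True, "length"  # assumed reason label
        if last_token_id == self.eos_token_id:
            return True, "eos_token"  # assumed reason label
        if any(criteria(output) for criteria in self.stop_sequence_criterias):
            return True, "stop_sequence"  # assumed reason label
        return False, None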