commit 9285f67be5
parent 93f6acc396
Author: OlivierDehaene
Date: 2023-01-30 15:15:34 +01:00

7 changed files with 37 additions and 22 deletions

View File

@@ -234,7 +234,9 @@ class BLOOMSharded(BLOOM):
             if name == "word_embeddings.weight":
                 model.lm_head._parameters["weight"] = tensor

-    def forward(self, input_ids, attention_mask, position_ids, past_key_values: Optional = None):
+    def forward(
+        self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
+    ):
         outputs = self.model.forward(
             input_ids=input_ids,
             attention_mask=attention_mask,
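A note on the `forward` signature being rewrapped here: `past_key_values` carries the attention key/value cache across decoding steps, so every step after the first only runs the model over the newest token. A minimal sketch of that pattern, using `gpt2` as a stand-in checkpoint rather than this repo's sharded BLOOM:

```python
# Sketch: incremental decoding with a KV cache (gpt2 as a stand-in model).
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2").eval()

inputs = tokenizer("Hello world", return_tensors="pt")
with torch.no_grad():
    out = model(**inputs, use_cache=True)  # full prompt pass, returns the cache

# Feed back only the newly chosen token; attention over the prompt
# comes from past_key_values instead of being recomputed.
next_token = out.logits[:, -1, :].argmax(dim=-1, keepdim=True)
with torch.no_grad():
    out = model(
        input_ids=next_token,
        past_key_values=out.past_key_values,
        use_cache=True,
    )
```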

View File

@@ -296,7 +296,10 @@ class CausalLM(Model):
         )
         with context_manager():
             logits, past = self.forward(
-                batch.input_ids, batch.attention_mask, batch.position_ids, batch.past_key_values
+                batch.input_ids,
+                batch.attention_mask,
+                batch.position_ids,
+                batch.past_key_values,
             )

         # List of indices to cache
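`context_manager()` wraps the forward pass in a no-grad context. A sketch of that pattern, under the assumption (not visible in this hunk) that CPU uses `torch.no_grad` and accelerators use `torch.inference_mode`:

```python
# Sketch: choosing a no-grad context for inference. The selection rule
# below is an assumption; only the `with context_manager():` call appears
# in the diff.
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
context_manager = torch.no_grad if device.type == "cpu" else torch.inference_mode

with context_manager():
    logits = torch.randn(2, 4) @ torch.randn(4, 8)  # stand-in for self.forward(...)
```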
@@ -389,7 +392,7 @@ class CausalLM(Model):
                     token_ids=token_ids.squeeze(1).tolist(),
                     logprobs=logprobs,
                     reason=reason,
-                    seed=seed
+                    seed=seed,
                 )
             )
         # add to the next batch

View File

@@ -333,7 +333,9 @@ class GalacticaSharded(Galactica):
             if name == "model.decoder.embed_tokens.weight":
                 model.lm_head._parameters["weight"] = tensor

-    def forward(self, input_ids, attention_mask, position_ids, past_key_values: Optional = None):
+    def forward(
+        self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
+    ):
         outputs = self.model.forward(
             input_ids=input_ids,
             attention_mask=attention_mask,

View File

@@ -39,12 +39,16 @@ class SantaCoder(CausalLM):
             }
         )

-        self.model = AutoModelForCausalLM.from_pretrained(
-            model_name,
-            torch_dtype=dtype,
-            load_in_8bit=quantize,
-            trust_remote_code=True,  # required
-        ).to(device).eval()
+        self.model = (
+            AutoModelForCausalLM.from_pretrained(
+                model_name,
+                torch_dtype=dtype,
+                load_in_8bit=quantize,
+                trust_remote_code=True,  # required
+            )
+            .to(device)
+            .eval()
+        )

         super(CausalLM, self).__init__(
             tokenizer=tokenizer,
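For reference, a hedged sketch of the same loading pattern in isolation; `bigcode/santacoder` is an assumed checkpoint name here, and the 8-bit branch is omitted because weights loaded with `load_in_8bit=True` are placed by `accelerate` and should not be moved with a plain `.to(device)` afterwards:

```python
# Sketch: loading a checkpoint that ships custom modeling code.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "bigcode/santacoder"  # assumed checkpoint, for illustration only
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = (
    AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.float16,
        trust_remote_code=True,  # executes code from the hub repo; trust it first
    )
    .to("cuda")
    .eval()
)
```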

View File

@@ -468,7 +468,7 @@ class Seq2SeqLM(Model):
                     token_ids=token_ids.tolist(),
                     logprobs=logprobs,
                     reason=reason,
-                    seed=seed
+                    seed=seed,
                 )
             )
         # add to the next batch

View File

@@ -50,5 +50,5 @@ class GeneratedText:
             token_ids=self.token_ids,
             logprobs=self.logprobs,
             finish_reason=self.reason,
-            seed=self.seed
+            seed=self.seed,
         )

View File

@@ -33,7 +33,9 @@ class Sampling:

     def __call__(self, logits):
         probs = torch.nn.functional.softmax(logits, dim=-1)
-        next_tokens = torch.multinomial(probs, num_samples=1, generator=self.generator).squeeze(1)
+        next_tokens = torch.multinomial(
+            probs, num_samples=1, generator=self.generator
+        ).squeeze(1)
         return next_tokens

     @property
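The wrapped call above is standard multinomial sampling over softmax probabilities. A self-contained sketch with illustrative shapes:

```python
# Sketch: multinomial next-token sampling, mirroring Sampling.__call__.
import torch

generator = torch.Generator()
generator.manual_seed(0)  # a seeded generator keeps draws reproducible

logits = torch.randn(2, 50257)  # (batch_size, vocab_size); sizes illustrative
probs = torch.nn.functional.softmax(logits, dim=-1)
next_tokens = torch.multinomial(probs, num_samples=1, generator=generator).squeeze(1)

print(next_tokens.shape)  # torch.Size([2]), one token id per sequence
```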
@@ -47,7 +49,9 @@ class Greedy:


 class NextTokenChooser:
-    def __init__(self, temperature=1.0, top_k=None, top_p=None, do_sample=False, seed=None):
+    def __init__(
+        self, temperature=1.0, top_k=None, top_p=None, do_sample=False, seed=None
+    ):
         warpers = LogitsProcessorList()
         # the following idea is largely copied from this PR: https://github.com/huggingface/transformers/pull/5420/files
         # all samplers can be found in `generation_utils_samplers.py`
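`warpers` is a `transformers.LogitsProcessorList` that the constructor fills from `temperature`, `top_k`, and `top_p`. A sketch of how such a chain composes; the exact warpers and their order in this repo are an assumption here:

```python
# Sketch: a temperature/top-k/top-p warper chain (composition assumed).
import torch
from transformers import (
    LogitsProcessorList,
    TemperatureLogitsWarper,
    TopKLogitsWarper,
    TopPLogitsWarper,
)

warpers = LogitsProcessorList(
    [
        TemperatureLogitsWarper(0.7),
        TopKLogitsWarper(top_k=50),
        TopPLogitsWarper(top_p=0.95),
    ]
)

input_ids = torch.tensor([[0, 1, 2]])  # illustrative prompt token ids
scores = torch.randn(1, 50257)
warped = warpers(input_ids, scores)  # rescaled and filtered logits
```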
@@ -84,7 +88,7 @@ class NextTokenChooser:
             top_k=pb.top_k,
             top_p=pb.top_p,
             do_sample=pb.do_sample,
-            seed=seed
+            seed=seed,
         )
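Threading `seed` through `from_pb` is what makes a sampled generation replayable later. A small sketch of the guarantee this buys, assuming one `torch.Generator` per request:

```python
# Sketch: the same seed reproduces the same multinomial draw.
import torch

def sample_with_seed(seed: int) -> torch.Tensor:
    generator = torch.Generator().manual_seed(seed)
    probs = torch.full((1, 10), 0.1)  # uniform 10-token vocabulary
    return torch.multinomial(probs, num_samples=1, generator=generator)

assert torch.equal(sample_with_seed(42), sample_with_seed(42))
```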
@@ -100,10 +104,10 @@ class StopSequenceCriteria:

 class StoppingCriteria:
     def __init__(
-            self,
-            eos_token_id: int,
-            stop_sequence_criterias: List[StopSequenceCriteria],
-            max_new_tokens=20,
+        self,
+        eos_token_id: int,
+        stop_sequence_criterias: List[StopSequenceCriteria],
+        max_new_tokens=20,
     ):
         self.eos_token_id = eos_token_id
         self.stop_sequence_criterias = stop_sequence_criterias
@@ -128,9 +132,9 @@ class StoppingCriteria:

     @classmethod
     def from_pb(
-            cls,
-            pb: generate_pb2.StoppingCriteriaParameters,
-            tokenizer: PreTrainedTokenizerBase,
+        cls,
+        pb: generate_pb2.StoppingCriteriaParameters,
+        tokenizer: PreTrainedTokenizerBase,
     ) -> "StoppingCriteria":
         stop_sequence_criterias = [
             StopSequenceCriteria(sequence) for sequence in pb.stop_sequences
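Since these hunks only show signatures, here is a hypothetical sketch of how the pieces can compose. The class and parameter names mirror the diff; the method bodies and the finish-reason strings are illustrative assumptions, not the repo's code:

```python
# Hypothetical sketch; only the signatures appear in the diff above,
# the bodies here are assumptions.
import re
from typing import List, Optional, Tuple

class StopSequenceCriteria:
    def __init__(self, stop_sequence: str):
        # Fire only when the decoded text ends with the stop sequence.
        self.regex = re.compile(re.escape(stop_sequence) + "$")

    def __call__(self, output: str) -> bool:
        return self.regex.search(output) is not None

class StoppingCriteria:
    def __init__(
        self,
        eos_token_id: int,
        stop_sequence_criterias: List[StopSequenceCriteria],
        max_new_tokens=20,
    ):
        self.eos_token_id = eos_token_id
        self.stop_sequence_criterias = stop_sequence_criterias
        self.max_new_tokens = max_new_tokens
        self.current_tokens = 0

    def __call__(self, last_token: int, output: str) -> Tuple[bool, Optional[str]]:
        self.current_tokens += 1
        if self.current_tokens >= self.max_new_tokens:
            return True, "length"
        if last_token == self.eos_token_id:
            return True, "eos_token"
        if any(criteria(output) for criteria in self.stop_sequence_criterias):
            return True, "stop_sequence"
        return False, None
```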