Mirror of https://github.com/huggingface/text-generation-inference.git

commit 9285f67be5
parent 93f6acc396

    black
@@ -234,7 +234,9 @@ class BLOOMSharded(BLOOM):
             if name == "word_embeddings.weight":
                 model.lm_head._parameters["weight"] = tensor
 
-    def forward(self, input_ids, attention_mask, position_ids, past_key_values: Optional = None):
+    def forward(
+        self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
+    ):
         outputs = self.model.forward(
             input_ids=input_ids,
             attention_mask=attention_mask,
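Black reflows any signature that overruns its 88-character default line length into the parenthesized, one-argument-per-line form seen here (the Galactica hunk below gets the identical treatment). A minimal sketch of reproducing this reflow with Black's Python API; the source string is illustrative:

# Requires `pip install black`; black.format_str and black.Mode are the
# stable public entry points for formatting a source string in memory.
import black

SRC = (
    "def forward(self, input_ids, attention_mask, position_ids,"
    " past_key_values=None):\n"
    "    pass\n"
)

# black.Mode() defaults to the 88-character line length used in this commit,
# so the over-long signature comes back in the wrapped, parenthesized form.
print(black.format_str(SRC, mode=black.Mode()))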
@@ -296,7 +296,10 @@ class CausalLM(Model):
         )
         with context_manager():
             logits, past = self.forward(
-                batch.input_ids, batch.attention_mask, batch.position_ids, batch.past_key_values
+                batch.input_ids,
+                batch.attention_mask,
+                batch.position_ids,
+                batch.past_key_values,
             )
 
         # List of indices to cache
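The reflowed call feeds `batch.past_key_values` back into `forward`, the usual key/value-cache pattern for incremental decoding: each step reuses the cached attention states and only processes the newest token. A hedged sketch of that loop with stock transformers; the checkpoint name is illustrative:

# Greedy incremental decoding with a KV cache. Any causal LM checkpoint
# works; bigscience/bloom-560m is just a small stand-in.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bigscience/bloom-560m")
model = AutoModelForCausalLM.from_pretrained("bigscience/bloom-560m").eval()

inputs = tokenizer("Hello", return_tensors="pt")
input_ids = inputs["input_ids"]
attention_mask = inputs["attention_mask"]
past_key_values = None

with torch.no_grad():
    for _ in range(5):
        outputs = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            use_cache=True,
        )
        past_key_values = outputs.past_key_values
        next_token = outputs.logits[:, -1, :].argmax(dim=-1, keepdim=True)
        # Only the new token is fed on the next step; the cache carries the rest.
        input_ids = next_token
        attention_mask = torch.cat(
            [attention_mask, attention_mask.new_ones((1, 1))], dim=-1
        )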
@@ -389,7 +392,7 @@ class CausalLM(Model):
                         token_ids=token_ids.squeeze(1).tolist(),
                         logprobs=logprobs,
                         reason=reason,
-                        seed=seed
+                        seed=seed,
                     )
                 )
                 # add to the next batch
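Several hunks in this commit (here, and in Seq2SeqLM, GeneratedText, and NextTokenChooser below) are the same one-character change: Black appends a trailing comma to the last argument of every call it explodes across lines, its "magic trailing comma", which also pins the call to the exploded form on later runs and keeps future diffs to one line per added argument. A self-contained sketch with a hypothetical call:

# Hypothetical helper, just to show the convention on a runnable example.
def make_result(token_ids, logprobs, reason, seed):
    return dict(token_ids=token_ids, logprobs=logprobs, reason=reason, seed=seed)

result = make_result(
    token_ids=[1, 2, 3],
    logprobs=[-0.1, -0.5, -0.2],
    reason="length",
    seed=0,  # trailing comma: Black adds one to every multi-line call
)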
@@ -333,7 +333,9 @@ class GalacticaSharded(Galactica):
             if name == "model.decoder.embed_tokens.weight":
                 model.lm_head._parameters["weight"] = tensor
 
-    def forward(self, input_ids, attention_mask, position_ids, past_key_values: Optional = None):
+    def forward(
+        self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
+    ):
         outputs = self.model.forward(
             input_ids=input_ids,
             attention_mask=attention_mask,
@@ -39,12 +39,16 @@ class SantaCoder(CausalLM):
             }
         )
 
-        self.model = AutoModelForCausalLM.from_pretrained(
-            model_name,
-            torch_dtype=dtype,
-            load_in_8bit=quantize,
-            trust_remote_code=True, # required
-        ).to(device).eval()
+        self.model = (
+            AutoModelForCausalLM.from_pretrained(
+                model_name,
+                torch_dtype=dtype,
+                load_in_8bit=quantize,
+                trust_remote_code=True,  # required
+            )
+            .to(device)
+            .eval()
+        )
 
         super(CausalLM, self).__init__(
             tokenizer=tokenizer,
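Besides the formatting change (Black parenthesizes the whole expression so the `.to(device).eval()` chain wraps one call per line), the hunk shows the loading pattern itself: optional 8-bit weights via `load_in_8bit` and `trust_remote_code=True` because SantaCoder ships custom modeling code with the checkpoint. A hedged, self-contained sketch; the checkpoint and flag values are stand-ins:

# Assumes the transformers API of this era, where from_pretrained accepted
# load_in_8bit directly (backed by bitsandbytes). With load_in_8bit=True the
# explicit .to(device) should be skipped; it is shown here to mirror the diff.
import torch
from transformers import AutoModelForCausalLM

model_name = "bigcode/santacoder"  # illustrative checkpoint
dtype = torch.float16
quantize = False
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = (
    AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=dtype,
        load_in_8bit=quantize,
        trust_remote_code=True,  # required: custom modeling code
    )
    .to(device)
    .eval()
)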
@@ -468,7 +468,7 @@ class Seq2SeqLM(Model):
                         token_ids=token_ids.tolist(),
                         logprobs=logprobs,
                         reason=reason,
-                        seed=seed
+                        seed=seed,
                     )
                 )
                 # add to the next batch
@@ -50,5 +50,5 @@ class GeneratedText:
             token_ids=self.token_ids,
             logprobs=self.logprobs,
             finish_reason=self.reason,
-            seed=self.seed
+            seed=self.seed,
         )
@@ -33,7 +33,9 @@ class Sampling:
 
     def __call__(self, logits):
         probs = torch.nn.functional.softmax(logits, dim=-1)
-        next_tokens = torch.multinomial(probs, num_samples=1, generator=self.generator).squeeze(1)
+        next_tokens = torch.multinomial(
+            probs, num_samples=1, generator=self.generator
+        ).squeeze(1)
         return next_tokens
 
     @property
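`Sampling.__call__` is the standard softmax-then-multinomial draw, now split so the `torch.multinomial` call fits the line limit. A runnable sketch of the same step on dummy logits:

# One multinomial draw per sequence in the batch; seeding the generator
# makes the draws reproducible, which is what the `seed` plumbing in the
# surrounding hunks is for.
import torch

generator = torch.Generator().manual_seed(0)
logits = torch.randn(2, 10)  # (batch, vocab) dummy logits

probs = torch.nn.functional.softmax(logits, dim=-1)
next_tokens = torch.multinomial(
    probs, num_samples=1, generator=generator
).squeeze(1)
print(next_tokens.shape)  # torch.Size([2]) -- one token id per sequence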
@@ -47,7 +49,9 @@ class Greedy:
 
 
 class NextTokenChooser:
-    def __init__(self, temperature=1.0, top_k=None, top_p=None, do_sample=False, seed=None):
+    def __init__(
+        self, temperature=1.0, top_k=None, top_p=None, do_sample=False, seed=None
+    ):
         warpers = LogitsProcessorList()
         # the following idea is largely copied from this PR: https://github.com/huggingface/transformers/pull/5420/files
         # all samplers can be found in `generation_utils_samplers.py`
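The hunk only shows the reflowed `__init__` signature; per the in-code comment, the body populates a `LogitsProcessorList` of warpers in the style of the linked transformers PR. A hedged sketch of the typical assembly (the exact body is outside this hunk, so the guard conditions here are assumptions):

# TemperatureLogitsWarper, TopKLogitsWarper, and TopPLogitsWarper are the
# stock transformers warpers; each one rescales or filters the logits
# before sampling.
from transformers import (
    LogitsProcessorList,
    TemperatureLogitsWarper,
    TopKLogitsWarper,
    TopPLogitsWarper,
)

def build_warpers(temperature=1.0, top_k=None, top_p=None):
    warpers = LogitsProcessorList()
    if temperature is not None and temperature != 1.0:
        warpers.append(TemperatureLogitsWarper(temperature))
    if top_k is not None and top_k != 0:
        warpers.append(TopKLogitsWarper(top_k=top_k))
    if top_p is not None and top_p < 1.0:
        warpers.append(TopPLogitsWarper(top_p=top_p))
    return warpers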
@@ -84,7 +88,7 @@ class NextTokenChooser:
             top_k=pb.top_k,
             top_p=pb.top_p,
             do_sample=pb.do_sample,
-            seed=seed
+            seed=seed,
         )
 
 
@@ -100,10 +104,10 @@ class StopSequenceCriteria:
 
 class StoppingCriteria:
     def __init__(
         self,
         eos_token_id: int,
         stop_sequence_criterias: List[StopSequenceCriteria],
         max_new_tokens=20,
     ):
         self.eos_token_id = eos_token_id
         self.stop_sequence_criterias = stop_sequence_criterias
@@ -128,9 +132,9 @@ class StoppingCriteria:
 
     @classmethod
     def from_pb(
         cls,
         pb: generate_pb2.StoppingCriteriaParameters,
         tokenizer: PreTrainedTokenizerBase,
     ) -> "StoppingCriteria":
         stop_sequence_criterias = [
             StopSequenceCriteria(sequence) for sequence in pb.stop_sequences
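The two hunks above show `StoppingCriteria`'s constructor and its `from_pb` factory, which builds one `StopSequenceCriteria` per configured stop sequence. A hedged sketch of the stopping logic these pieces imply (the actual `__call__` body is not in the diff, so the reason strings and regex matching are assumptions):

# Stop on EOS, on any matching stop sequence at the end of the decoded
# output, or once max_new_tokens have been generated.
import re
from typing import Optional, Tuple

class StopSequenceCriteria:
    def __init__(self, stop_sequence: str):
        # match the stop sequence at the end of the decoded output
        self.regex = re.compile(f".*{re.escape(stop_sequence)}$")

    def __call__(self, output: str) -> bool:
        return self.regex.match(output) is not None

class StoppingCriteria:
    def __init__(self, eos_token_id, stop_sequence_criterias, max_new_tokens=20):
        self.eos_token_id = eos_token_id
        self.stop_sequence_criterias = stop_sequence_criterias
        self.max_new_tokens = max_new_tokens
        self.current_tokens = 0

    def __call__(self, last_token_id: int, output: str) -> Tuple[bool, Optional[str]]:
        self.current_tokens += 1
        if self.current_tokens >= self.max_new_tokens:
            return True, "length"
        if last_token_id == self.eos_token_id:
            return True, "eos_token"
        for criteria in self.stop_sequence_criterias:
            if criteria(output):
                return True, "stop_sequence"
        return False, None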