mirror of https://github.com/huggingface/text-generation-inference.git
synced 2025-09-09 11:24:53 +00:00

black

commit 9285f67be5 (parent 93f6acc396)
@@ -234,7 +234,9 @@ class BLOOMSharded(BLOOM):
                 if name == "word_embeddings.weight":
                     model.lm_head._parameters["weight"] = tensor
 
-    def forward(self, input_ids, attention_mask, position_ids, past_key_values: Optional = None):
+    def forward(
+        self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
+    ):
         outputs = self.model.forward(
             input_ids=input_ids,
             attention_mask=attention_mask,
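Aside: this hunk is pure black output. Any def whose one-line signature exceeds black's default 88-character limit gets wrapped into the parenthesized three-line form above. A minimal sketch using black's Python API (the snippet is illustrative, not part of the commit):

import black

# Black only parses the source, so the unresolved name Optional is fine here.
src = (
    "def forward(self, input_ids, attention_mask, position_ids,"
    " past_key_values: Optional = None):\n"
    "    pass\n"
)
# Prints the same three-line signature the '+' lines above show.
print(black.format_str(src, mode=black.Mode()))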
@@ -296,7 +296,10 @@ class CausalLM(Model):
         )
         with context_manager():
             logits, past = self.forward(
-                batch.input_ids, batch.attention_mask, batch.position_ids, batch.past_key_values
+                batch.input_ids,
+                batch.attention_mask,
+                batch.position_ids,
+                batch.past_key_values,
             )
 
         # List of indices to cache
@@ -389,7 +392,7 @@ class CausalLM(Model):
                         token_ids=token_ids.squeeze(1).tolist(),
                         logprobs=logprobs,
                         reason=reason,
-                        seed=seed
+                        seed=seed,
                     )
                 )
                 # add to the next batch
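Aside: the explosion into one argument per line in the first hunk happens because the joined argument line no longer fits in 88 columns at that nesting depth; black then appends its "magic trailing comma", which keeps the call exploded on future runs. The bare trailing-comma additions (seed=seed,) recur in the Seq2SeqLM, GeneratedText, and NextTokenChooser hunks below. An illustrative reproduction via black's API (class and names are placeholders):

import black

src = (
    "class CausalLM:\n"
    "    def generate_token(self, batch):\n"
    "        with context_manager():\n"
    "            logits, past = self.forward("
    "batch.input_ids, batch.attention_mask, "
    "batch.position_ids, batch.past_key_values)\n"
)
# At 12 spaces of statement indentation the argument list cannot fit on
# one 88-column line, so black explodes it and adds the trailing comma.
print(black.format_str(src, mode=black.Mode()))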
@@ -333,7 +333,9 @@ class GalacticaSharded(Galactica):
                 if name == "model.decoder.embed_tokens.weight":
                     model.lm_head._parameters["weight"] = tensor
 
-    def forward(self, input_ids, attention_mask, position_ids, past_key_values: Optional = None):
+    def forward(
+        self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
+    ):
         outputs = self.model.forward(
             input_ids=input_ids,
             attention_mask=attention_mask,
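Aside: the context lines in the BLOOM and Galactica hunks both reuse the input embedding tensor as the lm_head weight, the standard weight-tying trick for checkpoints that do not store a separate output projection. A toy sketch of the idea (the modules here are hypothetical, not TGI code):

import torch

embed_tokens = torch.nn.Embedding(32, 8)
lm_head = torch.nn.Linear(8, 32, bias=False)
# Reuse the embedding matrix as the output projection, as the sharded
# loaders above do via model.lm_head._parameters["weight"] = tensor.
lm_head.weight = embed_tokens.weight
assert lm_head.weight.data_ptr() == embed_tokens.weight.data_ptr()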
@@ -39,12 +39,16 @@ class SantaCoder(CausalLM):
             }
         )
 
-        self.model = AutoModelForCausalLM.from_pretrained(
-            model_name,
-            torch_dtype=dtype,
-            load_in_8bit=quantize,
-            trust_remote_code=True,  # required
-        ).to(device).eval()
+        self.model = (
+            AutoModelForCausalLM.from_pretrained(
+                model_name,
+                torch_dtype=dtype,
+                load_in_8bit=quantize,
+                trust_remote_code=True,  # required
+            )
+            .to(device)
+            .eval()
+        )
 
         super(CausalLM, self).__init__(
             tokenizer=tokenizer,
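Aside: this hunk shows black's treatment of fluent chains. When f(...).to(device).eval() cannot stay on the assignment line, black wraps the whole right-hand side in parentheses and gives each trailing .method() its own line. A reproduction sketch (black versions of the commit's era, roughly 22.x, emit the '+' side shown above; the names below are placeholders black never resolves):

import black

src = (
    "class SantaCoder:\n"
    "    def __init__(self):\n"
    "        self.model = AutoModelForCausalLM.from_pretrained(\n"
    "            model_name,\n"
    "            torch_dtype=dtype,\n"
    "            load_in_8bit=quantize,\n"
    "            trust_remote_code=True,\n"
    "        ).to(device).eval()\n"
)
# The magic trailing comma keeps from_pretrained exploded; the chained
# .to()/.eval() calls are then split one per line inside new parentheses.
print(black.format_str(src, mode=black.Mode()))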
@@ -468,7 +468,7 @@ class Seq2SeqLM(Model):
                         token_ids=token_ids.tolist(),
                         logprobs=logprobs,
                         reason=reason,
-                        seed=seed
+                        seed=seed,
                     )
                 )
                 # add to the next batch
@@ -50,5 +50,5 @@ class GeneratedText:
             token_ids=self.token_ids,
             logprobs=self.logprobs,
             finish_reason=self.reason,
-            seed=self.seed
+            seed=self.seed,
         )
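Aside: the fields here suggest GeneratedText is a small response container serialized to protobuf by this method. A hedged reconstruction of the shape the hunk implies (field types are assumptions, not read from the file):

from dataclasses import dataclass
from typing import List, Optional

@dataclass
class GeneratedText:
    token_ids: List[int]    # ids of the generated tokens
    logprobs: List[float]   # per-token log-probabilities
    reason: str             # finish reason, mapped to finish_reason in the pb
    seed: Optional[int]     # sampling seed, if one was used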
@@ -33,7 +33,9 @@ class Sampling:
 
     def __call__(self, logits):
         probs = torch.nn.functional.softmax(logits, dim=-1)
-        next_tokens = torch.multinomial(probs, num_samples=1, generator=self.generator).squeeze(1)
+        next_tokens = torch.multinomial(
+            probs, num_samples=1, generator=self.generator
+        ).squeeze(1)
         return next_tokens
 
     @property
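Aside: functionally this hunk is a no-op reflow. What Sampling.__call__ does: softmax the logits into a categorical distribution, then draw one token id per row with torch.multinomial. A self-contained sketch (shapes and the seed are illustrative):

import torch

generator = torch.Generator().manual_seed(0)
logits = torch.randn(2, 1000)  # (batch, vocab_size)
probs = torch.nn.functional.softmax(logits, dim=-1)
next_tokens = torch.multinomial(
    probs, num_samples=1, generator=generator
).squeeze(1)
print(next_tokens.shape)  # torch.Size([2]): one sampled id per batch row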
@@ -47,7 +49,9 @@ class Greedy:
 
 
 class NextTokenChooser:
-    def __init__(self, temperature=1.0, top_k=None, top_p=None, do_sample=False, seed=None):
+    def __init__(
+        self, temperature=1.0, top_k=None, top_p=None, do_sample=False, seed=None
+    ):
         warpers = LogitsProcessorList()
         # the following idea is largely copied from this PR: https://github.com/huggingface/transformers/pull/5420/files
         # all samplers can be found in `generation_utils_samplers.py`
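Aside: the warpers list built here is transformers' LogitsProcessorList; each warper rescales or filters the logits in turn before sampling. A minimal sketch with transformers' public warper classes (the threshold values are illustrative, not the ones NextTokenChooser receives):

import torch
from transformers import (
    LogitsProcessorList,
    TemperatureLogitsWarper,
    TopKLogitsWarper,
    TopPLogitsWarper,
)

warpers = LogitsProcessorList(
    [
        TemperatureLogitsWarper(0.8),
        TopKLogitsWarper(top_k=50),
        TopPLogitsWarper(top_p=0.95),
    ]
)
input_ids = torch.tensor([[1, 2, 3]])
scores = torch.randn(1, 100)
# Applies temperature, then top-k, then top-p filtering to the scores.
warped = warpers(input_ids, scores)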
@@ -84,7 +88,7 @@ class NextTokenChooser:
             top_k=pb.top_k,
             top_p=pb.top_p,
             do_sample=pb.do_sample,
-            seed=seed
+            seed=seed,
         )