feat(neuron): support on-device sampling

Repository: https://github.com/huggingface/text-generation-inference.git
Commit: 1bccd1a3e4 (parent: c8c5dcf352)
@@ -356,6 +356,10 @@ class NeuronGenerator(Generator):
         ]
         self.batch_id = 0
 
+    @property
+    def on_device_sampling(self) -> bool:
+        return getattr(self.model.neuron_config, "on_device_sampling", False)
+
     @property
     def info(self) -> InfoResponse:
         """Returns the expected InfoResponse."""
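For illustration only (not part of the diff): a self-contained sketch of the property added above, using SimpleNamespace as a stand-in for the real neuron config object. Only the attribute name on_device_sampling comes from the diff; the _Generator class and everything else here are hypothetical.

from types import SimpleNamespace

class _Generator:
    """Hypothetical stand-in for NeuronGenerator, just to show the flag lookup."""

    def __init__(self, neuron_config):
        # The real generator reads the flag from self.model.neuron_config.
        self.model = SimpleNamespace(neuron_config=neuron_config)

    @property
    def on_device_sampling(self) -> bool:
        # getattr with a default keeps configs without the field working (flag off).
        return getattr(self.model.neuron_config, "on_device_sampling", False)

print(_Generator(SimpleNamespace(on_device_sampling=True)).on_device_sampling)  # True
print(_Generator(SimpleNamespace()).on_device_sampling)                          # False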
@@ -454,6 +458,9 @@ class NeuronGenerator(Generator):
         )
         input_ids = padded_inputs.input_ids
         attention_mask = padded_inputs.attention_mask
+        sampling_params = (
+            torch.zeros(input_ids.shape[0], 3) if self.on_device_sampling else None
+        )
         # Pause previously active slots during generation
         for slot in active_slots:
             slot.pause()
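For illustration only: when on-device sampling is enabled, prefill now allocates one row of sampling parameters per sequence in the batch. The sketch below assumes the column layout implied by the assignments later in this diff (column 0 = top_k, column 1 = top_p, column 2 = temperature); the batch size and example values are made up.

import torch

batch_size = 4
on_device_sampling = True

# One row per slot; stays None when sampling happens on the host as before.
sampling_params = torch.zeros(batch_size, 3) if on_device_sampling else None
if sampling_params is not None:
    # e.g. slot 0 requests top_k=50, top_p=0.9, temperature=0.7
    sampling_params[0] = torch.tensor([50, 0.9, 0.7])
print(sampling_params)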
@@ -477,14 +484,21 @@ class NeuronGenerator(Generator):
             slot_input_ids = slot_input_ids.squeeze(dim=0).type(torch.int64)
             slot_attention_mask = attention_mask[i]
             slot.reset(slot_input_ids, slot_attention_mask, selector)
+            if sampling_params is not None:
+                sampling_params[i, 0] = slot.generation_config.top_k
+                sampling_params[i, 1] = slot.generation_config.top_p
+                sampling_params[i, 2] = slot.generation_config.temperature
         # Note: when rebuilding cache on prefill, the new tokens on paused slots will be ignored,
         # as they have already been generated and sent back in the last decode.
         model_inputs = self.model.prepare_inputs_for_prefill(
-            input_ids, attention_mask=attention_mask, seq_ids=seq_ids
+            input_ids,
+            attention_mask=attention_mask,
+            seq_ids=seq_ids,
+            sampling_params=sampling_params,
         )
-        logits = self.model(**model_inputs)[0]
+        tokens_or_logits = self.model(**model_inputs)[0]
         generation, next_batch = self._generate_token(
-            prefill_slots, self.batch_id, logits, input_ids
+            prefill_slots, self.batch_id, tokens_or_logits, input_ids
         )
         self.batch_id += 1
         # Reactivate previously active slots for the next decode
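For illustration only: how each row of sampling_params is filled from a request's generation settings before being forwarded to the model. The _Slot and _GenConfig classes below are hypothetical stand-ins for the real Slot and its generation_config, kept only so the snippet runs on its own.

import torch
from dataclasses import dataclass

@dataclass
class _GenConfig:
    top_k: int
    top_p: float
    temperature: float

@dataclass
class _Slot:
    generation_config: _GenConfig

slots = [_Slot(_GenConfig(50, 0.9, 0.7)), _Slot(_GenConfig(0, 1.0, 1.0))]
sampling_params = torch.zeros(len(slots), 3)
for i, slot in enumerate(slots):
    # Same column order as in the diff: top_k, top_p, temperature.
    sampling_params[i, 0] = slot.generation_config.top_k
    sampling_params[i, 1] = slot.generation_config.top_p
    sampling_params[i, 2] = slot.generation_config.temperature
print(sampling_params)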
@@ -544,22 +558,32 @@ class NeuronGenerator(Generator):
         for slot in decode_slots:
             max_length = max(max_length, slot.attention_mask.size(-1))
         attention_mask = torch.zeros([n_slots, max_length], dtype=torch.int64)
+        sampling_params = torch.zeros(n_slots, 3) if self.on_device_sampling else None
         for i, slot in enumerate(decode_slots):
             if slot.state != Slot.State.EMPTY:
                 # input_ids are simply the tokens generated by the last decode or prefill requests (other tokens are cached)
                 input_ids[i, 0] = slot.next_token
                 attention_mask[i, : slot.attention_mask.size(-1)] = slot.attention_mask
+                if sampling_params is not None:
+                    sampling_params[i, 0] = slot.generation_config.top_k
+                    sampling_params[i, 1] = slot.generation_config.top_p
+                    sampling_params[i, 2] = slot.generation_config.temperature
         model_inputs = self.model.prepare_inputs_for_decode(
-            input_ids, attention_mask=attention_mask, seq_ids=seq_ids
+            input_ids,
+            attention_mask=attention_mask,
+            seq_ids=seq_ids,
+            sampling_params=sampling_params,
+        )
+        tokens_or_logits = self.model(**model_inputs)[0]
+        return self._generate_token(
+            decode_slots, next_batch_id, tokens_or_logits, input_ids
         )
-        logits = self.model(**model_inputs)[0]
-        return self._generate_token(decode_slots, next_batch_id, logits, input_ids)
 
     def _generate_token(
         self,
         slots: List[Slot],
         next_batch_id: int,
-        logits: torch.Tensor,
+        tokens_or_logits: torch.Tensor,
         input_ids: torch.LongTensor,
     ) -> Tuple[List[Generation], CachedBatch]:
         generations = []
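For illustration only: the decode path mirrors prefill, and the forward-pass result is renamed to tokens_or_logits because its shape now depends on the mode. The shapes below are assumptions inferred from the indexing in _generate_token (tokens_or_logits[i] versus tokens_or_logits[i : i + 1, -1, :]); fake_forward is not the real model, just a stand-in to show the two result kinds.

import torch

def fake_forward(on_device_sampling, batch=2, seq_len=1, vocab=8):
    if on_device_sampling:
        # already-sampled token ids, one per sequence (assumed shape [batch])
        return torch.randint(0, vocab, (batch,))
    # raw logits to be sampled on the host (assumed shape [batch, seq_len, vocab])
    return torch.randn(batch, seq_len, vocab)

tokens_or_logits = fake_forward(on_device_sampling=True)
print(tokens_or_logits.shape)   # torch.Size([2])
tokens_or_logits = fake_forward(on_device_sampling=False)
print(tokens_or_logits.shape)   # torch.Size([2, 1, 8])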
@@ -568,8 +592,11 @@ class NeuronGenerator(Generator):
             if slot.state != Slot.State.READY:
                 continue
             request_id = slot.request_id
-            next_token_logits = logits[i : i + 1, -1, :]
             slot_input_ids = input_ids[i : i + 1, :]
-            next_token = slot.select(slot_input_ids, next_token_logits)
+            if self.on_device_sampling:
+                next_token = tokens_or_logits[i]
+            else:
+                next_token_logits = tokens_or_logits[i : i + 1, -1, :]
+                next_token = slot.select(slot_input_ids, next_token_logits)
             next_token_text = slot.append(next_token)
             generated_text = None
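For illustration only: the essence of the new branch in _generate_token, with a greedy argmax standing in for the host-side slot.select(). The real selector applies the request's logits processors; this sketch only shows which tensor each branch consumes, under the shape assumptions noted above.

import torch

def pick_next_token(tokens_or_logits, i, on_device_sampling):
    if on_device_sampling:
        # the device already sampled: the tensor holds one token id per sequence
        return int(tokens_or_logits[i])
    # otherwise the tensor holds logits: pick on the host from the last position
    next_token_logits = tokens_or_logits[i : i + 1, -1, :]
    return int(next_token_logits.argmax(dim=-1))

print(pick_next_token(torch.tensor([42, 7]), 0, True))    # 42 (device-sampled id)
print(pick_next_token(torch.randn(2, 5, 16), 1, False))   # argmax over the vocab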