Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-11 04:14:52 +00:00)
Fixing some simple stuff, adding speculate to budget.

parent 5aa3a01971
commit 09839b05f4
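In short: the shard now always reports a speculation amount in InfoResponse, the router threads it from Infer::new through Queue::new into State, and State::next_batch reserves speculate extra decode tokens per entry so batches are sized for the extra tokens a speculative step can emit; the Python Model gains the matching speculate plumbing. A sketch of the budget arithmetic follows the queue.rs hunks below.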
@@ -32,7 +32,7 @@ message InfoResponse {
     string dtype = 2;
     string device_type = 3;
     optional uint32 window_size = 4;
-    optional uint32 speculate = 5;
+    uint32 speculate = 5;
 }

 /// Empty request
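Note: speculate becomes a plain uint32 rather than optional, so every shard now always reports a speculation amount; judging by the router tests below, 0 stands for no speculation.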
@@ -50,10 +50,11 @@ impl Infer {
         max_concurrent_requests: usize,
         requires_padding: bool,
         window_size: Option<u32>,
+        speculate: u32,
         generation_health: Arc<AtomicBool>,
     ) -> Self {
         // Infer shared state
-        let queue = Queue::new(requires_padding, 16, window_size);
+        let queue = Queue::new(requires_padding, 16, window_size, speculate);
         let shared = Arc::new(Shared {
             batching_task: Notify::new(),
         });
@@ -34,7 +34,12 @@ pub(crate) struct Queue {
 }

 impl Queue {
-    pub(crate) fn new(requires_padding: bool, block_size: u32, window_size: Option<u32>) -> Self {
+    pub(crate) fn new(
+        requires_padding: bool,
+        block_size: u32,
+        window_size: Option<u32>,
+        speculate: u32,
+    ) -> Self {
         // Create channel
         let (queue_sender, queue_receiver) = mpsc::unbounded_channel();

@@ -43,6 +48,7 @@ impl Queue {
             requires_padding,
             block_size,
             window_size,
+            speculate,
             queue_receiver,
         ));

@@ -91,9 +97,10 @@ async fn queue_task(
     requires_padding: bool,
     block_size: u32,
     window_size: Option<u32>,
+    speculate: u32,
     mut receiver: mpsc::UnboundedReceiver<QueueCommand>,
 ) {
-    let mut state = State::new(requires_padding, block_size, window_size);
+    let mut state = State::new(requires_padding, block_size, window_size, speculate);

     while let Some(cmd) = receiver.recv().await {
         match cmd {

@@ -136,10 +143,18 @@ struct State {

     /// Sliding window
     window_size: Option<u32>,
+
+    /// Speculation amount
+    speculate: u32,
 }

 impl State {
-    fn new(requires_padding: bool, block_size: u32, window_size: Option<u32>) -> Self {
+    fn new(
+        requires_padding: bool,
+        block_size: u32,
+        window_size: Option<u32>,
+        speculate: u32,
+    ) -> Self {
         Self {
             entries: VecDeque::with_capacity(128),
             next_id: 0,

@@ -147,6 +162,7 @@ impl State {
             requires_padding,
             block_size,
             window_size,
+            speculate,
         }
     }
@@ -221,7 +237,7 @@ impl State {
                         window_size.saturating_sub(entry.request.input_length),
                         entry.request.stopping_parameters.max_new_tokens,
                     ),
-                };
+                } + self.speculate;

                 // pad to block size
                 decode_tokens +=
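This hunk is the heart of the commit: each entry's worst-case decode reservation now includes the speculative tokens. Below is a minimal, self-contained sketch of that arithmetic; decode_budget and the asserts are illustrative names for this note, not code from the repository, and block-size padding is elided.

use std::cmp::min;

/// Illustrative only: worst-case decode tokens one queued entry reserves,
/// mirroring the max_new_tokens / sliding-window math in the hunk above,
/// plus the new speculate term added by this commit.
fn decode_budget(
    input_length: u32,
    max_new_tokens: u32,
    window_size: Option<u32>,
    speculate: u32,
) -> u32 {
    let base = match window_size {
        None => max_new_tokens,
        // A sliding window caps how many positions are still usable.
        Some(window) => min(window.saturating_sub(input_length), max_new_tokens),
    };
    base + speculate
}

fn main() {
    // Without speculation the budget is unchanged.
    assert_eq!(decode_budget(5, 10, None, 0), 10);
    // With speculate = 2, every entry reserves two extra token slots.
    assert_eq!(decode_budget(5, 10, None, 2), 12);
}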
@@ -359,7 +375,7 @@ mod tests {

     #[test]
     fn test_append() {
-        let mut state = State::new(false, 1, None);
+        let mut state = State::new(false, 1, None, 0);
         let (entry, _guard) = default_entry();

         assert_eq!(state.next_id, 0);

@@ -375,7 +391,7 @@ mod tests {

     #[test]
     fn test_next_batch_empty() {
-        let mut state = State::new(false, 1, None);
+        let mut state = State::new(false, 1, None, 0);

         assert!(state.next_batch(None, 1, 1).is_none());
         assert!(state.next_batch(Some(1), 1, 1).is_none());

@@ -383,7 +399,7 @@ mod tests {

     #[test]
     fn test_next_batch_min_size() {
-        let mut state = State::new(false, 1, None);
+        let mut state = State::new(false, 1, None, 0);
         let (entry1, _guard1) = default_entry();
         let (entry2, _guard2) = default_entry();
         state.append(entry1);

@@ -415,7 +431,7 @@ mod tests {

     #[test]
     fn test_next_batch_token_budget() {
-        let mut state = State::new(false, 1, None);
+        let mut state = State::new(false, 1, None, 0);
         let (entry1, _guard1) = default_entry();
         let (entry2, _guard2) = default_entry();
         state.append(entry1);

@@ -448,14 +464,14 @@ mod tests {

     #[tokio::test]
     async fn test_queue_append() {
-        let queue = Queue::new(false, 1, None);
+        let queue = Queue::new(false, 1, None, 0);
         let (entry, _guard) = default_entry();
         queue.append(entry);
     }

     #[tokio::test]
     async fn test_queue_next_batch_empty() {
-        let queue = Queue::new(false, 1, None);
+        let queue = Queue::new(false, 1, None, 0);

         assert!(queue.next_batch(None, 1, 1).await.is_none());
         assert!(queue.next_batch(Some(1), 1, 1).await.is_none());

@@ -463,7 +479,7 @@ mod tests {

     #[tokio::test]
     async fn test_queue_next_batch_min_size() {
-        let queue = Queue::new(false, 1, None);
+        let queue = Queue::new(false, 1, None, 0);
         let (entry1, _guard1) = default_entry();
         let (entry2, _guard2) = default_entry();
         queue.append(entry1);

@@ -496,7 +512,7 @@ mod tests {

     #[tokio::test]
     async fn test_queue_next_batch_token_budget() {
-        let queue = Queue::new(false, 1, None);
+        let queue = Queue::new(false, 1, None, 0);
         let (entry1, _guard1) = default_entry();
         let (entry2, _guard2) = default_entry();
         queue.append(entry1);

@@ -519,9 +535,20 @@ mod tests {
         assert_eq!(batch.size, 2);
     }

+    #[tokio::test]
+    async fn test_queue_next_batch_token_speculate() {
+        let queue = Queue::new(false, 1, None, 2);
+        let (entry1, _guard1) = default_entry();
+        let (entry2, _guard2) = default_entry();
+        queue.append(entry1);
+        queue.append(entry2);
+
+        assert!(queue.next_batch(None, 1, 1).await.is_none());
+    }
+
     #[tokio::test]
     async fn test_queue_next_batch_dropped_receiver() {
-        let queue = Queue::new(false, 1, None);
+        let queue = Queue::new(false, 1, None, 0);
         let (entry, _) = default_entry();
         queue.append(entry);

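Why the new test expects None: if, as the neighboring budget tests suggest, a default entry needs a single decode token, then with speculate = 2 each entry reserves 1 + 2 = 3 tokens, which exceeds the token budget of 1 passed to next_batch. A back-of-envelope check (entry_fits is a hypothetical helper for this note, not part of the codebase):

// Hypothetical mirror of the test arithmetic: with speculate = 2 and a
// token budget of 1, even a single one-token entry no longer fits.
fn entry_fits(max_new_tokens: u32, speculate: u32, token_budget: u32) -> bool {
    max_new_tokens + speculate <= token_budget
}

fn main() {
    assert!(!entry_fits(1, 2, 1)); // next_batch(None, 1, 1) yields None
    assert!(entry_fits(1, 0, 1));  // the speculate = 0 tests still batch
}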
@@ -596,6 +596,7 @@ pub async fn run(
         max_concurrent_requests,
         shard_info.requires_padding,
         shard_info.window_size,
+        shard_info.speculate,
         generation_health,
     );

@@ -938,8 +938,6 @@ class FlashCausalLM(Model):
             batch.next_token_chooser.do_sample,
             batch.next_token_chooser.seeds,
             batch.top_n_tokens,
-            # next_token_ids,
-            # next_token_logprobs,
             accepted_ids,
             batch_top_token_ids,
             batch_top_token_logprobs,

@@ -957,8 +955,6 @@ class FlashCausalLM(Model):
             do_sample,
             seed,
             top_n_tokens,
-            # next_token_id,
-            # next_token_logprob,
             n_accepted_ids,
             top_token_ids,
             top_token_logprobs,

@@ -968,21 +964,18 @@ class FlashCausalLM(Model):
             _next_token_logprobs = next_token_logprobs[index: index+n_accepted_ids]

             next_token_texts = []
-            left = 0
             for j in range(index, index + n_accepted_ids):
                 # Generated token
-                all_input_ids.append(next_token_ids[j])
+                next_token_id = next_token_ids[j]
+                all_input_ids.append(next_token_id)
                 next_token_text, prefix_offset, read_offset = self.decode_token(
                     all_input_ids,
                     prefix_offset,
                     read_offset,
                 )
                 next_token_texts.append(next_token_text)
-            index += n_accepted_ids

             # Evaluate stopping criteria
-
+            left = 0
             for j, next_token_id in enumerate(_next_token_ids):
                 stop, reason = stopping_criteria(
                     next_token_id,
                     next_token_text,

@@ -994,6 +987,7 @@ class FlashCausalLM(Model):
                     break
             else:
                 stopped = False
+            index += n_accepted_ids
             _next_token_ids = _next_token_ids[:len(_next_token_ids) - left]

             # Shard generations

@@ -1003,7 +997,7 @@ class FlashCausalLM(Model):
                 # Decode generated tokens
                 # Remove potentially accepted ids that do not respect
                 # the stopping_criteria
-                _ids = all_input_ids[:len(all_input_ids)-left]
+                _ids = all_input_ids
                 output_text, _, _ = self.decode_token(
                     _ids,
                     prefix_offset=len(_ids)
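On the Python side the multi-token decode path in FlashCausalLM is tidied rather than changed in substance: the commented-out next_token_ids / next_token_logprobs arguments are deleted, each accepted id is bound to a local next_token_id before being appended, the index += n_accepted_ids advance moves to after the stopping-criteria loop, and the final output_text decode no longer trims all_input_ids by left.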
@@ -6,6 +6,7 @@ from typing import List, Tuple, Optional, TypeVar, Type
 from transformers import PreTrainedTokenizerBase, PretrainedConfig

 from text_generation_server.models.types import Batch, Generation
+from text_generation_server.utils.speculate import get_speculate
 from text_generation_server.pb.generate_pb2 import InfoResponse

 B = TypeVar("B", bound=Batch)

@@ -22,6 +23,7 @@ class Model(ABC):
         rank: int = 0,
         world_size: int = 1,
         sliding_window: Optional[int] = None,
+        speculate: Optional[int] = None,
     ):
         self.model = model.eval()
         self.tokenizer = tokenizer

@@ -33,6 +35,10 @@ class Model(ABC):
         self.world_size = world_size
         self.sliding_window = sliding_window

+        if speculate is None:
+            speculate = get_speculate()
+        self.speculate = speculate
+
         self.has_position_ids = (
             inspect.signature(model.forward).parameters.get("position_ids", None)
             is not None

@@ -50,6 +56,7 @@ class Model(ABC):
             dtype=str(self.dtype),
             device_type=self.device.type,
             window_size=self.sliding_window,
+            speculate=self.speculate
         )

     @property
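With this, the Python Model resolves its speculation amount at construction time, falling back to the process-wide get_speculate() helper when the speculate argument is None, and advertises the resolved value through InfoResponse so the router can fold it into its token budget.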