mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 04:14:52 +00:00
Fixing some simple stuff, adding speculate
to budget.
This commit is contained in:
parent
5aa3a01971
commit
09839b05f4
@ -32,7 +32,7 @@ message InfoResponse {
|
|||||||
string dtype = 2;
|
string dtype = 2;
|
||||||
string device_type = 3;
|
string device_type = 3;
|
||||||
optional uint32 window_size = 4;
|
optional uint32 window_size = 4;
|
||||||
optional uint32 speculate = 5;
|
uint32 speculate = 5;
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Empty request
|
/// Empty request
|
||||||
|
@ -50,10 +50,11 @@ impl Infer {
|
|||||||
max_concurrent_requests: usize,
|
max_concurrent_requests: usize,
|
||||||
requires_padding: bool,
|
requires_padding: bool,
|
||||||
window_size: Option<u32>,
|
window_size: Option<u32>,
|
||||||
|
speculate: u32,
|
||||||
generation_health: Arc<AtomicBool>,
|
generation_health: Arc<AtomicBool>,
|
||||||
) -> Self {
|
) -> Self {
|
||||||
// Infer shared state
|
// Infer shared state
|
||||||
let queue = Queue::new(requires_padding, 16, window_size);
|
let queue = Queue::new(requires_padding, 16, window_size, speculate);
|
||||||
let shared = Arc::new(Shared {
|
let shared = Arc::new(Shared {
|
||||||
batching_task: Notify::new(),
|
batching_task: Notify::new(),
|
||||||
});
|
});
|
||||||
|
@ -34,7 +34,12 @@ pub(crate) struct Queue {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl Queue {
|
impl Queue {
|
||||||
pub(crate) fn new(requires_padding: bool, block_size: u32, window_size: Option<u32>) -> Self {
|
pub(crate) fn new(
|
||||||
|
requires_padding: bool,
|
||||||
|
block_size: u32,
|
||||||
|
window_size: Option<u32>,
|
||||||
|
speculate: u32,
|
||||||
|
) -> Self {
|
||||||
// Create channel
|
// Create channel
|
||||||
let (queue_sender, queue_receiver) = mpsc::unbounded_channel();
|
let (queue_sender, queue_receiver) = mpsc::unbounded_channel();
|
||||||
|
|
||||||
@ -43,6 +48,7 @@ impl Queue {
|
|||||||
requires_padding,
|
requires_padding,
|
||||||
block_size,
|
block_size,
|
||||||
window_size,
|
window_size,
|
||||||
|
speculate,
|
||||||
queue_receiver,
|
queue_receiver,
|
||||||
));
|
));
|
||||||
|
|
||||||
@ -91,9 +97,10 @@ async fn queue_task(
|
|||||||
requires_padding: bool,
|
requires_padding: bool,
|
||||||
block_size: u32,
|
block_size: u32,
|
||||||
window_size: Option<u32>,
|
window_size: Option<u32>,
|
||||||
|
speculate: u32,
|
||||||
mut receiver: mpsc::UnboundedReceiver<QueueCommand>,
|
mut receiver: mpsc::UnboundedReceiver<QueueCommand>,
|
||||||
) {
|
) {
|
||||||
let mut state = State::new(requires_padding, block_size, window_size);
|
let mut state = State::new(requires_padding, block_size, window_size, speculate);
|
||||||
|
|
||||||
while let Some(cmd) = receiver.recv().await {
|
while let Some(cmd) = receiver.recv().await {
|
||||||
match cmd {
|
match cmd {
|
||||||
@ -136,10 +143,18 @@ struct State {
|
|||||||
|
|
||||||
/// Sliding window
|
/// Sliding window
|
||||||
window_size: Option<u32>,
|
window_size: Option<u32>,
|
||||||
|
|
||||||
|
/// Speculation amount
|
||||||
|
speculate: u32,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl State {
|
impl State {
|
||||||
fn new(requires_padding: bool, block_size: u32, window_size: Option<u32>) -> Self {
|
fn new(
|
||||||
|
requires_padding: bool,
|
||||||
|
block_size: u32,
|
||||||
|
window_size: Option<u32>,
|
||||||
|
speculate: u32,
|
||||||
|
) -> Self {
|
||||||
Self {
|
Self {
|
||||||
entries: VecDeque::with_capacity(128),
|
entries: VecDeque::with_capacity(128),
|
||||||
next_id: 0,
|
next_id: 0,
|
||||||
@ -147,6 +162,7 @@ impl State {
|
|||||||
requires_padding,
|
requires_padding,
|
||||||
block_size,
|
block_size,
|
||||||
window_size,
|
window_size,
|
||||||
|
speculate,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -221,7 +237,7 @@ impl State {
|
|||||||
window_size.saturating_sub(entry.request.input_length),
|
window_size.saturating_sub(entry.request.input_length),
|
||||||
entry.request.stopping_parameters.max_new_tokens,
|
entry.request.stopping_parameters.max_new_tokens,
|
||||||
),
|
),
|
||||||
};
|
} + self.speculate;
|
||||||
|
|
||||||
// pad to block size
|
// pad to block size
|
||||||
decode_tokens +=
|
decode_tokens +=
|
||||||
@ -359,7 +375,7 @@ mod tests {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_append() {
|
fn test_append() {
|
||||||
let mut state = State::new(false, 1, None);
|
let mut state = State::new(false, 1, None, 0);
|
||||||
let (entry, _guard) = default_entry();
|
let (entry, _guard) = default_entry();
|
||||||
|
|
||||||
assert_eq!(state.next_id, 0);
|
assert_eq!(state.next_id, 0);
|
||||||
@ -375,7 +391,7 @@ mod tests {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_next_batch_empty() {
|
fn test_next_batch_empty() {
|
||||||
let mut state = State::new(false, 1, None);
|
let mut state = State::new(false, 1, None, 0);
|
||||||
|
|
||||||
assert!(state.next_batch(None, 1, 1).is_none());
|
assert!(state.next_batch(None, 1, 1).is_none());
|
||||||
assert!(state.next_batch(Some(1), 1, 1).is_none());
|
assert!(state.next_batch(Some(1), 1, 1).is_none());
|
||||||
@ -383,7 +399,7 @@ mod tests {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_next_batch_min_size() {
|
fn test_next_batch_min_size() {
|
||||||
let mut state = State::new(false, 1, None);
|
let mut state = State::new(false, 1, None, 0);
|
||||||
let (entry1, _guard1) = default_entry();
|
let (entry1, _guard1) = default_entry();
|
||||||
let (entry2, _guard2) = default_entry();
|
let (entry2, _guard2) = default_entry();
|
||||||
state.append(entry1);
|
state.append(entry1);
|
||||||
@ -415,7 +431,7 @@ mod tests {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_next_batch_token_budget() {
|
fn test_next_batch_token_budget() {
|
||||||
let mut state = State::new(false, 1, None);
|
let mut state = State::new(false, 1, None, 0);
|
||||||
let (entry1, _guard1) = default_entry();
|
let (entry1, _guard1) = default_entry();
|
||||||
let (entry2, _guard2) = default_entry();
|
let (entry2, _guard2) = default_entry();
|
||||||
state.append(entry1);
|
state.append(entry1);
|
||||||
@ -448,14 +464,14 @@ mod tests {
|
|||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_queue_append() {
|
async fn test_queue_append() {
|
||||||
let queue = Queue::new(false, 1, None);
|
let queue = Queue::new(false, 1, None, 0);
|
||||||
let (entry, _guard) = default_entry();
|
let (entry, _guard) = default_entry();
|
||||||
queue.append(entry);
|
queue.append(entry);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_queue_next_batch_empty() {
|
async fn test_queue_next_batch_empty() {
|
||||||
let queue = Queue::new(false, 1, None);
|
let queue = Queue::new(false, 1, None, 0);
|
||||||
|
|
||||||
assert!(queue.next_batch(None, 1, 1).await.is_none());
|
assert!(queue.next_batch(None, 1, 1).await.is_none());
|
||||||
assert!(queue.next_batch(Some(1), 1, 1).await.is_none());
|
assert!(queue.next_batch(Some(1), 1, 1).await.is_none());
|
||||||
@ -463,7 +479,7 @@ mod tests {
|
|||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_queue_next_batch_min_size() {
|
async fn test_queue_next_batch_min_size() {
|
||||||
let queue = Queue::new(false, 1, None);
|
let queue = Queue::new(false, 1, None, 0);
|
||||||
let (entry1, _guard1) = default_entry();
|
let (entry1, _guard1) = default_entry();
|
||||||
let (entry2, _guard2) = default_entry();
|
let (entry2, _guard2) = default_entry();
|
||||||
queue.append(entry1);
|
queue.append(entry1);
|
||||||
@ -496,7 +512,7 @@ mod tests {
|
|||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_queue_next_batch_token_budget() {
|
async fn test_queue_next_batch_token_budget() {
|
||||||
let queue = Queue::new(false, 1, None);
|
let queue = Queue::new(false, 1, None, 0);
|
||||||
let (entry1, _guard1) = default_entry();
|
let (entry1, _guard1) = default_entry();
|
||||||
let (entry2, _guard2) = default_entry();
|
let (entry2, _guard2) = default_entry();
|
||||||
queue.append(entry1);
|
queue.append(entry1);
|
||||||
@ -519,9 +535,20 @@ mod tests {
|
|||||||
assert_eq!(batch.size, 2);
|
assert_eq!(batch.size, 2);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_queue_next_batch_token_speculate() {
|
||||||
|
let queue = Queue::new(false, 1, None, 2);
|
||||||
|
let (entry1, _guard1) = default_entry();
|
||||||
|
let (entry2, _guard2) = default_entry();
|
||||||
|
queue.append(entry1);
|
||||||
|
queue.append(entry2);
|
||||||
|
|
||||||
|
assert!(queue.next_batch(None, 1, 1).await.is_none());
|
||||||
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_queue_next_batch_dropped_receiver() {
|
async fn test_queue_next_batch_dropped_receiver() {
|
||||||
let queue = Queue::new(false, 1, None);
|
let queue = Queue::new(false, 1, None, 0);
|
||||||
let (entry, _) = default_entry();
|
let (entry, _) = default_entry();
|
||||||
queue.append(entry);
|
queue.append(entry);
|
||||||
|
|
||||||
|
@ -596,6 +596,7 @@ pub async fn run(
|
|||||||
max_concurrent_requests,
|
max_concurrent_requests,
|
||||||
shard_info.requires_padding,
|
shard_info.requires_padding,
|
||||||
shard_info.window_size,
|
shard_info.window_size,
|
||||||
|
shard_info.speculate,
|
||||||
generation_health,
|
generation_health,
|
||||||
);
|
);
|
||||||
|
|
||||||
|
@ -938,8 +938,6 @@ class FlashCausalLM(Model):
|
|||||||
batch.next_token_chooser.do_sample,
|
batch.next_token_chooser.do_sample,
|
||||||
batch.next_token_chooser.seeds,
|
batch.next_token_chooser.seeds,
|
||||||
batch.top_n_tokens,
|
batch.top_n_tokens,
|
||||||
# next_token_ids,
|
|
||||||
# next_token_logprobs,
|
|
||||||
accepted_ids,
|
accepted_ids,
|
||||||
batch_top_token_ids,
|
batch_top_token_ids,
|
||||||
batch_top_token_logprobs,
|
batch_top_token_logprobs,
|
||||||
@ -957,8 +955,6 @@ class FlashCausalLM(Model):
|
|||||||
do_sample,
|
do_sample,
|
||||||
seed,
|
seed,
|
||||||
top_n_tokens,
|
top_n_tokens,
|
||||||
# next_token_id,
|
|
||||||
# next_token_logprob,
|
|
||||||
n_accepted_ids,
|
n_accepted_ids,
|
||||||
top_token_ids,
|
top_token_ids,
|
||||||
top_token_logprobs,
|
top_token_logprobs,
|
||||||
@ -968,21 +964,18 @@ class FlashCausalLM(Model):
|
|||||||
_next_token_logprobs = next_token_logprobs[index: index+n_accepted_ids]
|
_next_token_logprobs = next_token_logprobs[index: index+n_accepted_ids]
|
||||||
|
|
||||||
next_token_texts = []
|
next_token_texts = []
|
||||||
|
left = 0
|
||||||
for j in range(index, index + n_accepted_ids):
|
for j in range(index, index + n_accepted_ids):
|
||||||
# Generated token
|
# Generated token
|
||||||
all_input_ids.append(next_token_ids[j])
|
next_token_id = next_token_ids[j]
|
||||||
|
all_input_ids.append(next_token_id)
|
||||||
next_token_text, prefix_offset, read_offset = self.decode_token(
|
next_token_text, prefix_offset, read_offset = self.decode_token(
|
||||||
all_input_ids,
|
all_input_ids,
|
||||||
prefix_offset,
|
prefix_offset,
|
||||||
read_offset,
|
read_offset,
|
||||||
)
|
)
|
||||||
next_token_texts.append(next_token_text)
|
next_token_texts.append(next_token_text)
|
||||||
index += n_accepted_ids
|
|
||||||
|
|
||||||
# Evaluate stopping criteria
|
|
||||||
|
|
||||||
left = 0
|
|
||||||
for j, next_token_id in enumerate(_next_token_ids):
|
|
||||||
stop, reason = stopping_criteria(
|
stop, reason = stopping_criteria(
|
||||||
next_token_id,
|
next_token_id,
|
||||||
next_token_text,
|
next_token_text,
|
||||||
@ -994,6 +987,7 @@ class FlashCausalLM(Model):
|
|||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
stopped = False
|
stopped = False
|
||||||
|
index += n_accepted_ids
|
||||||
_next_token_ids = _next_token_ids[:len(_next_token_ids) - left]
|
_next_token_ids = _next_token_ids[:len(_next_token_ids) - left]
|
||||||
|
|
||||||
# Shard generations
|
# Shard generations
|
||||||
@ -1003,7 +997,7 @@ class FlashCausalLM(Model):
|
|||||||
# Decode generated tokens
|
# Decode generated tokens
|
||||||
# Remove potentially accepted ids that do not respect
|
# Remove potentially accepted ids that do not respect
|
||||||
# the stopping_criteria
|
# the stopping_criteria
|
||||||
_ids = all_input_ids[:len(all_input_ids)-left]
|
_ids = all_input_ids
|
||||||
output_text, _, _ = self.decode_token(
|
output_text, _, _ = self.decode_token(
|
||||||
_ids,
|
_ids,
|
||||||
prefix_offset=len(_ids)
|
prefix_offset=len(_ids)
|
||||||
|
@ -6,6 +6,7 @@ from typing import List, Tuple, Optional, TypeVar, Type
|
|||||||
from transformers import PreTrainedTokenizerBase, PretrainedConfig
|
from transformers import PreTrainedTokenizerBase, PretrainedConfig
|
||||||
|
|
||||||
from text_generation_server.models.types import Batch, Generation
|
from text_generation_server.models.types import Batch, Generation
|
||||||
|
from text_generation_server.utils.speculate import get_speculate
|
||||||
from text_generation_server.pb.generate_pb2 import InfoResponse
|
from text_generation_server.pb.generate_pb2 import InfoResponse
|
||||||
|
|
||||||
B = TypeVar("B", bound=Batch)
|
B = TypeVar("B", bound=Batch)
|
||||||
@ -22,6 +23,7 @@ class Model(ABC):
|
|||||||
rank: int = 0,
|
rank: int = 0,
|
||||||
world_size: int = 1,
|
world_size: int = 1,
|
||||||
sliding_window: Optional[int] = None,
|
sliding_window: Optional[int] = None,
|
||||||
|
speculate: Optional[int] = None,
|
||||||
):
|
):
|
||||||
self.model = model.eval()
|
self.model = model.eval()
|
||||||
self.tokenizer = tokenizer
|
self.tokenizer = tokenizer
|
||||||
@ -33,6 +35,10 @@ class Model(ABC):
|
|||||||
self.world_size = world_size
|
self.world_size = world_size
|
||||||
self.sliding_window = sliding_window
|
self.sliding_window = sliding_window
|
||||||
|
|
||||||
|
if speculate is None:
|
||||||
|
speculate = get_speculate()
|
||||||
|
self.speculate = speculate
|
||||||
|
|
||||||
self.has_position_ids = (
|
self.has_position_ids = (
|
||||||
inspect.signature(model.forward).parameters.get("position_ids", None)
|
inspect.signature(model.forward).parameters.get("position_ids", None)
|
||||||
is not None
|
is not None
|
||||||
@ -50,6 +56,7 @@ class Model(ABC):
|
|||||||
dtype=str(self.dtype),
|
dtype=str(self.dtype),
|
||||||
device_type=self.device.type,
|
device_type=self.device.type,
|
||||||
window_size=self.sliding_window,
|
window_size=self.sliding_window,
|
||||||
|
speculate=self.speculate
|
||||||
)
|
)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
Loading…
Reference in New Issue
Block a user