Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-11 04:14:52 +00:00)

fix tests

This commit is contained in:
parent 248eda7b20
commit 6f9366556a
@@ -8,8 +8,6 @@ use std::time::Duration;
 use tonic::transport::{Channel, Uri};
 use tracing::instrument;
-
-
 
 /// Text Generation Inference gRPC client
 #[derive(Debug, Clone)]
 pub struct Client {
@@ -163,7 +161,11 @@ impl Client {
     ) -> Result<(Vec<Generation>, Option<CachedBatch>, PrefillTimings)> {
         let request = tonic::Request::new(PrefillRequest { batch: Some(batch) }).inject_context();
         let response = self.stub.prefill(request).await?.into_inner();
-        Ok((response.generations, response.batch, PrefillTimings::new(response.forward_ns, response.decode_ns, response.total_ns)))
+        Ok((
+            response.generations,
+            response.batch,
+            PrefillTimings::new(response.forward_ns, response.decode_ns, response.total_ns),
+        ))
     }
 
     /// Generate one token for each request in the given cached batches
@@ -177,7 +179,16 @@ impl Client {
     ) -> Result<(Vec<Generation>, Option<CachedBatch>, DecodeTimings)> {
         let request = tonic::Request::new(DecodeRequest { batches }).inject_context();
         let response = self.stub.decode(request).await?.into_inner();
-        Ok((response.generations, response.batch, DecodeTimings::new(response.concat_ns, response.forward_ns, response.decode_ns, response.total_ns)))
+        Ok((
+            response.generations,
+            response.batch,
+            DecodeTimings::new(
+                response.concat_ns,
+                response.forward_ns,
+                response.decode_ns,
+                response.total_ns,
+            ),
+        ))
     }
 }
 
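For context on the tuples being reformatted above: `prefill` and `decode` return `(generations, next_batch, timings)`, where the timings are built from the nanosecond counters in the shard response. Below is a minimal sketch of what such timing types could look like, assuming they simply wrap those counters as `Duration`s; only the constructor argument names come from the diff, the field layout is illustrative and not the crate's actual definition.

use std::time::Duration;

// Hypothetical sketch of timings returned by `Client::prefill`.
#[derive(Debug)]
pub struct PrefillTimings {
    pub forward: Duration,
    pub decode: Duration,
    pub total: Duration,
}

impl PrefillTimings {
    pub fn new(forward_ns: u64, decode_ns: u64, total_ns: u64) -> Self {
        Self {
            forward: Duration::from_nanos(forward_ns),
            decode: Duration::from_nanos(decode_ns),
            total: Duration::from_nanos(total_ns),
        }
    }
}

// Hypothetical sketch of timings returned by `Client::decode`; decode
// additionally reports the time spent concatenating batches.
#[derive(Debug)]
pub struct DecodeTimings {
    pub concat: Duration,
    pub forward: Duration,
    pub decode: Duration,
    pub total: Duration,
}

impl DecodeTimings {
    pub fn new(concat_ns: u64, forward_ns: u64, decode_ns: u64, total_ns: u64) -> Self {
        Self {
            concat: Duration::from_nanos(concat_ns),
            forward: Duration::from_nanos(forward_ns),
            decode: Duration::from_nanos(decode_ns),
            total: Duration::from_nanos(total_ns),
        }
    }
}

fn main() {
    let p = PrefillTimings::new(1_200_000, 800_000, 2_000_000);
    let d = DecodeTimings::new(100_000, 1_200_000, 800_000, 2_100_000);
    println!("prefill forward {:?} of {:?}", p.forward, p.total);
    println!("decode concat {:?} of {:?}", d.concat, d.total);
}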
@@ -1,10 +1,10 @@
+use crate::client::{DecodeTimings, PrefillTimings};
 /// Multi shard Client
 use crate::{Batch, CachedBatch, Client, Generation, HealthResponse, ShardInfo};
 use crate::{ClientError, Result};
 use futures::future::join_all;
 use tonic::transport::Uri;
 use tracing::instrument;
-use crate::client::{DecodeTimings, PrefillTimings};
 
 #[derive(Debug, Clone)]
 /// Text Generation Inference gRPC multi client
@@ -131,7 +131,8 @@ impl ShardedClient {
             join_all(futures).await.into_iter().collect();
         let mut results = results?;
 
-        let (mut generations, next_batch, mut timings) = results.pop().ok_or(ClientError::EmptyResults)?;
+        let (mut generations, next_batch, mut timings) =
+            results.pop().ok_or(ClientError::EmptyResults)?;
 
         // Merge generations from different model shards
         for (mut shard_generations, _, shard_timings) in results.into_iter() {
@@ -162,7 +163,8 @@ impl ShardedClient {
             join_all(futures).await.into_iter().collect();
         let mut results = results?;
 
-        let (mut generations, next_batch, mut timings) = results.pop().ok_or(ClientError::EmptyResults)?;
+        let (mut generations, next_batch, mut timings) =
+            results.pop().ok_or(ClientError::EmptyResults)?;
 
         // Merge generations from different model shards
        for (mut shard_generations, _, shard_timings) in results.into_iter() {
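Both `ShardedClient` hunks only re-wrap a long binding, but the surrounding context shows the merge pattern they live in: one shard's result is popped as the base, and the remaining shards' generations and timings are folded into it. Here is a self-contained sketch of that pattern, using stand-in `Generation` and `Timings` types and a hypothetical `fold` helper; the real types and error handling (`ClientError::EmptyResults`) live in the client crate.

// Minimal sketch of the shard-merge pattern shown in the context lines above.
#[derive(Debug)]
struct Generation {
    request_id: u64,
}

#[derive(Debug)]
struct Timings {
    total_ns: u64,
}

impl Timings {
    // Hypothetical helper: keep the slowest shard's timing.
    fn fold(&mut self, other: &Timings) {
        self.total_ns = self.total_ns.max(other.total_ns);
    }
}

fn merge(mut results: Vec<(Vec<Generation>, Timings)>) -> Option<(Vec<Generation>, Timings)> {
    // Pop one shard's result as the base, mirroring
    // `results.pop().ok_or(ClientError::EmptyResults)?`.
    let (mut generations, mut timings) = results.pop()?;

    // Merge generations from the remaining shards.
    for (mut shard_generations, shard_timings) in results.into_iter() {
        generations.append(&mut shard_generations);
        timings.fold(&shard_timings);
    }

    Some((generations, timings))
}

fn main() {
    let shards = vec![
        (vec![Generation { request_id: 0 }], Timings { total_ns: 10 }),
        (vec![Generation { request_id: 1 }], Timings { total_ns: 12 }),
    ];
    let (generations, timings) = merge(shards).expect("no shard results");
    let ids: Vec<u64> = generations.iter().map(|g| g.request_id).collect();
    println!("merged requests {:?}, slowest shard {} ns", ids, timings.total_ns);
}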
@@ -229,6 +229,15 @@ impl Validation {
                 stop_sequences.len(),
             ));
         }
+        for s in stop_sequences {
+            if s.chars().count() > 50 {
+                return Err(ValidationError::StopSequence(
+                    self.max_stop_sequences,
+                    stop_sequences.len(),
+                ));
+            }
+        }
 
         // If seed is None, assign a random one
         let seed = match seed {
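The new validation loop limits each stop sequence to 50 characters. Counting with `s.chars().count()` rather than `s.len()` matters because `len()` on a `&str` is the UTF-8 byte length, which over-counts any non-ASCII sequence. A quick illustration:

fn main() {
    // A stop sequence containing non-ASCII characters.
    let s = "héllo→stop";

    // `len()` is the UTF-8 byte length ('é' is 2 bytes, '→' is 3).
    assert_eq!(s.len(), 13);

    // `chars().count()` counts Unicode scalar values, which is what a
    // "no more than 50 characters" rule is after.
    assert_eq!(s.chars().count(), 10);

    println!("{} bytes, {} chars", s.len(), s.chars().count());
}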
@@ -943,6 +943,7 @@ class FlashCausalLM(Model):
         # GPU <-> CPU sync
         next_token_logprobs = next_token_logprobs.tolist()
         next_token_ids = next_input_ids.tolist()
+        accepted_ids = accepted_ids.tolist()
         start_decode = time.time_ns()
 
         # Zipped iterator
@@ -980,7 +981,6 @@ class FlashCausalLM(Model):
             # Append next token to all tokens
             next_token_texts = []
             left = 0
-            before = stopping_criteria.current_tokens
 
             current_stopped = False
             for j in range(index, index + n_accepted_ids):
@@ -1095,7 +1095,7 @@ class FlashCausalLM(Model):
             generations.append(generation)
 
             # Update values
-            batch.input_lengths[i] = input_length + n_accepted_ids.item()
+            batch.input_lengths[i] = input_length + n_accepted_ids
             if batch.input_lengths[i] > batch.max_seqlen:
                 batch.max_seqlen = batch.input_lengths[i]
             batch.prefix_offsets[i] = prefix_offset
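The three Python hunks above belong to the same change: `accepted_ids` is moved into the existing "GPU <-> CPU sync" block via `.tolist()`, so the per-request loop later reads a plain `int` (`n_accepted_ids`) instead of calling `.item()` on a tensor element, which costs one device synchronization per request. A rough sketch of the idea, assuming PyTorch tensors and a much simplified loop; variable names follow the diff, the data is made up.

import torch

# Stand-in for per-request speculative-decoding counts produced on the device.
accepted_ids_gpu = torch.tensor([1, 3, 2], device="cpu")  # "cuda" on a GPU box
input_lengths = [10, 12, 7]

# Slower pattern: one host<->device sync per request inside the loop.
lengths_item = []
for input_length, n_accepted in zip(input_lengths, accepted_ids_gpu):
    lengths_item.append(input_length + n_accepted.item())

# Pattern from the diff: sync once up front, then work with plain Python ints.
accepted_ids = accepted_ids_gpu.tolist()
lengths_list = []
for input_length, n_accepted_ids in zip(input_lengths, accepted_ids):
    lengths_list.append(input_length + n_accepted_ids)

assert lengths_item == lengths_list == [11, 15, 9]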
@@ -92,7 +92,7 @@ class NextTokenChooser:
 class StopSequenceCriteria:
     def __init__(self, stop_sequence: str):
         stop_sequence = re.escape(stop_sequence)
-        self.regex = re.compile(f".*{stop_sequence}$")
+        self.regex = re.compile(f"{stop_sequence}$")
 
     def __call__(self, output: str) -> bool:
         if self.regex.findall(output):
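In `StopSequenceCriteria`, the leading `.*` is dropped from the compiled pattern. Since the pattern stays anchored with `$` and only the truthiness of `findall` is used, both versions answer the same question, whether the generated text currently ends with the stop sequence; the simpler pattern just avoids matching the rest of the line. Below is a trimmed copy of the class (the real `__call__` returns from an `if`) plus a quick check on a few sample outputs.

import re


class StopSequenceCriteria:
    """Trimmed copy of the criteria shown in the diff above."""

    def __init__(self, stop_sequence: str):
        stop_sequence = re.escape(stop_sequence)
        self.regex = re.compile(f"{stop_sequence}$")

    def __call__(self, output: str) -> bool:
        return bool(self.regex.findall(output))


if __name__ == "__main__":
    old = re.compile(f".*{re.escape('</s>')}$")
    crit = StopSequenceCriteria("</s>")

    for output in ["hello </s>", "hello </s> world", "line one\nline two </s>"]:
        # Old and new patterns agree on whether the text ends with the
        # stop sequence; only the captured text differs.
        assert bool(old.findall(output)) == crit(output)
        print(repr(output), "->", crit(output))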