Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-11 04:14:52 +00:00)

fix tests

parent 248eda7b20
commit 6f9366556a
@@ -8,8 +8,6 @@ use std::time::Duration;
 use tonic::transport::{Channel, Uri};
 use tracing::instrument;

 /// Text Generation Inference gRPC client
 #[derive(Debug, Clone)]
 pub struct Client {
@@ -163,7 +161,11 @@ impl Client {
     ) -> Result<(Vec<Generation>, Option<CachedBatch>, PrefillTimings)> {
         let request = tonic::Request::new(PrefillRequest { batch: Some(batch) }).inject_context();
         let response = self.stub.prefill(request).await?.into_inner();
-        Ok((response.generations, response.batch, PrefillTimings::new(response.forward_ns, response.decode_ns, response.total_ns)))
+        Ok((
+            response.generations,
+            response.batch,
+            PrefillTimings::new(response.forward_ns, response.decode_ns, response.total_ns),
+        ))
     }

     /// Generate one token for each request in the given cached batches
@@ -177,7 +179,16 @@ impl Client {
     ) -> Result<(Vec<Generation>, Option<CachedBatch>, DecodeTimings)> {
         let request = tonic::Request::new(DecodeRequest { batches }).inject_context();
         let response = self.stub.decode(request).await?.into_inner();
-        Ok((response.generations, response.batch, DecodeTimings::new(response.concat_ns, response.forward_ns, response.decode_ns, response.total_ns)))
+        Ok((
+            response.generations,
+            response.batch,
+            DecodeTimings::new(
+                response.concat_ns,
+                response.forward_ns,
+                response.decode_ns,
+                response.total_ns,
+            ),
+        ))
     }
 }
@@ -213,4 +224,4 @@ impl DecodeTimings {
             total: Duration::from_nanos(total_ns),
         }
     }
 }
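For context on the hunks above: PrefillTimings and DecodeTimings are built from raw nanosecond counters in the gRPC response (forward_ns, decode_ns, total_ns, plus concat_ns on the decode path), each wrapped with Duration::from_nanos. Below is a rough Python sketch of the same conversion, kept in Python like the other examples in this section; the helper name, the dict layout, and the optional handling of concat_ns are illustrative, not the router's API:

from datetime import timedelta
from typing import Optional

def timings_from_ns(forward_ns: int, decode_ns: int, total_ns: int,
                    concat_ns: Optional[int] = None) -> dict:
    # timedelta has no nanosecond field, so convert via microseconds.
    to_td = lambda ns: timedelta(microseconds=ns / 1_000)
    return {
        "concat": to_td(concat_ns) if concat_ns is not None else None,  # decode-only counter
        "forward": to_td(forward_ns),
        "decode": to_td(decode_ns),
        "total": to_td(total_ns),
    }

# Example with made-up nanosecond counters.
print(timings_from_ns(forward_ns=12_000_000, decode_ns=3_000_000, total_ns=15_500_000))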
@@ -1,10 +1,10 @@
+use crate::client::{DecodeTimings, PrefillTimings};
 /// Multi shard Client
 use crate::{Batch, CachedBatch, Client, Generation, HealthResponse, ShardInfo};
 use crate::{ClientError, Result};
 use futures::future::join_all;
 use tonic::transport::Uri;
 use tracing::instrument;
-use crate::client::{DecodeTimings, PrefillTimings};

 #[derive(Debug, Clone)]
 /// Text Generation Inference gRPC multi client
@@ -131,7 +131,8 @@ impl ShardedClient {
             join_all(futures).await.into_iter().collect();
         let mut results = results?;

-        let (mut generations, next_batch, mut timings) = results.pop().ok_or(ClientError::EmptyResults)?;
+        let (mut generations, next_batch, mut timings) =
+            results.pop().ok_or(ClientError::EmptyResults)?;

         // Merge generations from different model shards
         for (mut shard_generations, _, shard_timings) in results.into_iter() {
@@ -162,7 +163,8 @@ impl ShardedClient {
             join_all(futures).await.into_iter().collect();
         let mut results = results?;

-        let (mut generations, next_batch, mut timings) = results.pop().ok_or(ClientError::EmptyResults)?;
+        let (mut generations, next_batch, mut timings) =
+            results.pop().ok_or(ClientError::EmptyResults)?;

         // Merge generations from different model shards
         for (mut shard_generations, _, shard_timings) in results.into_iter() {
@@ -229,6 +229,15 @@ impl Validation {
                 stop_sequences.len(),
             ));
         }
+        for s in stop_sequences.iter() {
+            if s.chars().count() > 50 {
+                return Err(ValidationError::StopSequence(
+                    self.max_stop_sequences,
+                    stop_sequences.len(),
+                ));
+            }
+        }

         // If seed is None, assign a random one
         let seed = match seed {
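The loop added in this hunk caps each individual stop sequence at 50 characters, alongside the existing cap on how many stop sequences a request may send. A rough sketch of the combined rule, written in Python to match the other examples in this section; the function name, the limit constants, and the ValueError are illustrative, not the router's actual validation API:

MAX_STOP_SEQUENCES = 4        # stands in for self.max_stop_sequences above
MAX_STOP_SEQUENCE_CHARS = 50  # per-sequence limit introduced by this hunk

def validate_stop_sequences(stop_sequences: list) -> None:
    if len(stop_sequences) > MAX_STOP_SEQUENCES:
        raise ValueError(
            f"expected at most {MAX_STOP_SEQUENCES} stop sequences, got {len(stop_sequences)}"
        )
    for s in stop_sequences:
        # The Rust code counts Unicode scalar values (chars().count()); len() on a
        # Python str counts code points, which matches for typical text.
        if len(s) > MAX_STOP_SEQUENCE_CHARS:
            raise ValueError(
                f"stop sequence longer than {MAX_STOP_SEQUENCE_CHARS} characters: {s!r}"
            )

validate_stop_sequences(["###", "\nUser:"])   # passes
# validate_stop_sequences(["x" * 51])         # would raise ValueError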
@@ -943,6 +943,7 @@ class FlashCausalLM(Model):
         # GPU <-> CPU sync
         next_token_logprobs = next_token_logprobs.tolist()
         next_token_ids = next_input_ids.tolist()
+        accepted_ids = accepted_ids.tolist()
         start_decode = time.time_ns()

         # Zipped iterator
@@ -980,7 +981,6 @@ class FlashCausalLM(Model):
             # Append next token to all tokens
             next_token_texts = []
             left = 0
             before = stopping_criteria.current_tokens

             current_stopped = False
             for j in range(index, index + n_accepted_ids):
@@ -1095,7 +1095,7 @@ class FlashCausalLM(Model):
             generations.append(generation)

             # Update values
-            batch.input_lengths[i] = input_length + n_accepted_ids.item()
+            batch.input_lengths[i] = input_length + n_accepted_ids
             if batch.input_lengths[i] > batch.max_seqlen:
                 batch.max_seqlen = batch.input_lengths[i]
             batch.prefix_offsets[i] = prefix_offset
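The two FlashCausalLM hunks belong together: once accepted_ids joins the single GPU <-> CPU sync alongside the logprobs and token ids (the .tolist() added around line 943), each per-request n_accepted_ids is already a plain Python int, so the .item() call around line 1095 has to go. A minimal sketch of that distinction; the tensor values and input lengths are made up for illustration:

import torch

# Made-up batch: number of accepted (speculative) tokens per request.
accepted_ids = torch.tensor([1, 3, 2])
input_lengths = [10, 12, 8]

# Before the sync, each element is a 0-dim tensor and needs .item() to become an int.
assert accepted_ids[0].item() == 1

# After the single .tolist() sync, the elements are plain Python ints,
# so input_length + n_accepted_ids works without .item().
accepted_ids = accepted_ids.tolist()
new_lengths = [length + n for length, n in zip(input_lengths, accepted_ids)]
print(new_lengths)  # [11, 15, 10]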
@@ -92,7 +92,7 @@ class NextTokenChooser:
 class StopSequenceCriteria:
     def __init__(self, stop_sequence: str):
         stop_sequence = re.escape(stop_sequence)
-        self.regex = re.compile(f".*{stop_sequence}$")
+        self.regex = re.compile(f"{stop_sequence}$")

     def __call__(self, output: str) -> bool:
         if self.regex.findall(output):
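For the boolean check in __call__, dropping the leading .* should not change which outputs match: the pattern is still anchored to the end of the generated text, it just no longer drags a redundant wildcard through findall. A minimal stdlib sketch; the factory function below is illustrative, not part of the server:

import re

def make_stop_criteria(stop_sequence: str):
    # Same idea as StopSequenceCriteria above: escape the literal stop sequence
    # and anchor it to the end of the output.
    pattern = re.compile(f"{re.escape(stop_sequence)}$")
    return lambda output: bool(pattern.findall(output))

is_stopped = make_stop_criteria("###")
print(is_stopped("some generated text ###"))               # True: ends with the stop sequence
print(is_stopped("### seen early, generation continues"))  # False: not at the end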