fix tests

OlivierDehaene 2023-12-14 14:46:44 +01:00
parent 248eda7b20
commit 6f9366556a
5 changed files with 33 additions and 11 deletions

View File

@@ -8,8 +8,6 @@ use std::time::Duration;
 use tonic::transport::{Channel, Uri};
 use tracing::instrument;
 
 /// Text Generation Inference gRPC client
 #[derive(Debug, Clone)]
 pub struct Client {
@@ -163,7 +161,11 @@ impl Client {
     ) -> Result<(Vec<Generation>, Option<CachedBatch>, PrefillTimings)> {
         let request = tonic::Request::new(PrefillRequest { batch: Some(batch) }).inject_context();
         let response = self.stub.prefill(request).await?.into_inner();
-        Ok((response.generations, response.batch, PrefillTimings::new(response.forward_ns, response.decode_ns, response.total_ns)))
+        Ok((
+            response.generations,
+            response.batch,
+            PrefillTimings::new(response.forward_ns, response.decode_ns, response.total_ns),
+        ))
     }
 
     /// Generate one token for each request in the given cached batches
@@ -177,7 +179,16 @@ impl Client {
     ) -> Result<(Vec<Generation>, Option<CachedBatch>, DecodeTimings)> {
         let request = tonic::Request::new(DecodeRequest { batches }).inject_context();
         let response = self.stub.decode(request).await?.into_inner();
-        Ok((response.generations, response.batch, DecodeTimings::new(response.concat_ns, response.forward_ns, response.decode_ns, response.total_ns)))
+        Ok((
+            response.generations,
+            response.batch,
+            DecodeTimings::new(
+                response.concat_ns,
+                response.forward_ns,
+                response.decode_ns,
+                response.total_ns,
+            ),
+        ))
     }
 }
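
Aside (not part of this commit): the reformatted `Ok((...))` returns bundle the raw nanosecond counters from the gRPC response into timing helpers. A minimal Python sketch of what a `PrefillTimings`-style holder could look like, assuming it merely converts those counters into durations (the names mirror the Rust code; the implementation is guessed):

```python
# Hedged sketch, not the repo's code: a PrefillTimings-like holder that turns
# raw nanosecond counters from a gRPC response into timedeltas.
from dataclasses import dataclass
from datetime import timedelta


def _from_ns(ns: int) -> timedelta:
    # timedelta bottoms out at microseconds, so scale nanoseconds down
    return timedelta(microseconds=ns / 1_000)


@dataclass
class PrefillTimings:
    forward: timedelta
    decode: timedelta
    total: timedelta

    @classmethod
    def new(cls, forward_ns: int, decode_ns: int, total_ns: int) -> "PrefillTimings":
        return cls(_from_ns(forward_ns), _from_ns(decode_ns), _from_ns(total_ns))
```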

View File

@@ -1,10 +1,10 @@
+use crate::client::{DecodeTimings, PrefillTimings};
 /// Multi shard Client
 use crate::{Batch, CachedBatch, Client, Generation, HealthResponse, ShardInfo};
 use crate::{ClientError, Result};
 use futures::future::join_all;
 use tonic::transport::Uri;
 use tracing::instrument;
-use crate::client::{DecodeTimings, PrefillTimings};
 
 #[derive(Debug, Clone)]
 /// Text Generation Inference gRPC multi client
@@ -131,7 +131,8 @@ impl ShardedClient {
             join_all(futures).await.into_iter().collect();
         let mut results = results?;
 
-        let (mut generations, next_batch, mut timings) = results.pop().ok_or(ClientError::EmptyResults)?;
+        let (mut generations, next_batch, mut timings) =
+            results.pop().ok_or(ClientError::EmptyResults)?;
 
         // Merge generations from different model shards
         for (mut shard_generations, _, shard_timings) in results.into_iter() {
@@ -162,7 +163,8 @@ impl ShardedClient {
             join_all(futures).await.into_iter().collect();
         let mut results = results?;
 
-        let (mut generations, next_batch, mut timings) = results.pop().ok_or(ClientError::EmptyResults)?;
+        let (mut generations, next_batch, mut timings) =
+            results.pop().ok_or(ClientError::EmptyResults)?;
 
         // Merge generations from different model shards
        for (mut shard_generations, _, shard_timings) in results.into_iter() {
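
Aside (not part of this commit): both hunks above reflow the same shard-merge step. Each shard returns a `(generations, next_batch, timings)` triple; one triple is popped as the base (erroring on an empty set, which is what `ClientError::EmptyResults` covers) and the remaining shards contribute their generations. A hedged Python rendering of the pattern, with invented data:

```python
# Hedged sketch of the shard-merge pattern in Python terms (data is made up).
results = [
    (["gen-a0", "gen-a1"], {"batch_id": 0}, {"total_ns": 120_000}),  # shard 0
    (["gen-b0", "gen-b1"], {"batch_id": 0}, {"total_ns": 140_000}),  # shard 1
]

if not results:
    raise ValueError("empty results")  # plays the role of ClientError::EmptyResults

# Pop one result as the base, then merge generations from the other shards
generations, next_batch, timings = results.pop()
for shard_generations, _, shard_timings in results:
    generations.extend(shard_generations)
```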

View File

@@ -229,6 +229,15 @@ impl Validation {
                 stop_sequences.len(),
             ));
         }
+
+        for s in stop_sequences.iter() {
+            if s.chars().count() > 50 {
+                return Err(ValidationError::StopSequence(
+                    self.max_stop_sequences,
+                    stop_sequences.len(),
+                ));
+            }
+        }
 
         // If seed is None, assign a random one
         let seed = match seed {
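
Aside (not part of this commit): the new guard rejects any stop sequence longer than 50 characters, counting characters (`s.chars().count()`) rather than bytes (Rust's `str::len`), so multi-byte UTF-8 text is not penalized. The same bytes-versus-characters gap, illustrated in Python:

```python
# Bytes vs. characters: the guard counts chars, not UTF-8 bytes.
stop = "停止" * 30  # 60 characters of CJK text

assert len(stop) == 60             # character count, what the 50-char limit checks
assert len(stop.encode()) == 180   # UTF-8 byte count (3 bytes per CJK character)
assert len(stop) > 50              # this sequence would be rejected
```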

View File

@@ -943,6 +943,7 @@ class FlashCausalLM(Model):
         # GPU <-> CPU sync
         next_token_logprobs = next_token_logprobs.tolist()
         next_token_ids = next_input_ids.tolist()
+        accepted_ids = accepted_ids.tolist()
 
         start_decode = time.time_ns()
         # Zipped iterator
@@ -980,7 +981,6 @@ class FlashCausalLM(Model):
             # Append next token to all tokens
             next_token_texts = []
             left = 0
-            before = stopping_criteria.current_tokens
 
             current_stopped = False
             for j in range(index, index + n_accepted_ids):
@@ -1095,7 +1095,7 @@ class FlashCausalLM(Model):
             generations.append(generation)
 
             # Update values
-            batch.input_lengths[i] = input_length + n_accepted_ids.item()
+            batch.input_lengths[i] = input_length + n_accepted_ids
             if batch.input_lengths[i] > batch.max_seqlen:
                 batch.max_seqlen = batch.input_lengths[i]
             batch.prefix_offsets[i] = prefix_offset
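
Aside (not part of this commit): because `accepted_ids` is now synced to a Python list alongside the other tensors, each `n_accepted_ids` drawn from it is a plain `int`, which is why the trailing `.item()` (a tensor method) is dropped in the last hunk. A quick illustration:

```python
import torch

accepted_ids = torch.tensor([1, 3, 2])

# Indexing a tensor yields a 0-dim tensor, which needs .item() to become an int
assert accepted_ids[0].item() == 1

# After .tolist(), elements are already plain Python ints, so .item() would fail
accepted_ids = accepted_ids.tolist()
n_accepted_ids = accepted_ids[0]
assert isinstance(n_accepted_ids, int)  # input_length + n_accepted_ids works directly
```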

View File

@@ -92,7 +92,7 @@ class NextTokenChooser:
 class StopSequenceCriteria:
     def __init__(self, stop_sequence: str):
         stop_sequence = re.escape(stop_sequence)
-        self.regex = re.compile(f".*{stop_sequence}$")
+        self.regex = re.compile(f"{stop_sequence}$")
 
     def __call__(self, output: str) -> bool:
         if self.regex.findall(output):
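
Aside (not part of this commit): dropping the `.*` prefix does not change what the criterion detects, because `re.findall` already scans every start position; `{stop_sequence}$` fires exactly when the output ends with the escaped stop sequence, with less pattern to backtrack through. A quick equivalence check:

```python
import re

stop_sequence = re.escape("</s>")
old = re.compile(f".*{stop_sequence}$")
new = re.compile(f"{stop_sequence}$")

for output in ["hello </s>", "hello", "multi\nline </s>"]:
    # findall scans all positions, so both patterns agree on suffix detection
    assert bool(old.findall(output)) == bool(new.findall(output))
```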