mirror of https://github.com/huggingface/text-generation-inference.git
synced 2025-09-10 11:54:52 +00:00

added streaming endpoint

This commit is contained in:
parent 1539d3cbbe
commit d37b2d3fb9

Cargo.lock (generated): 2 changes
@@ -1801,6 +1801,7 @@ version = "0.1.0"
 dependencies = [
  "futures",
  "prost",
+ "serde",
  "thiserror",
  "tokio",
  "tonic",
@@ -1829,6 +1830,7 @@ dependencies = [
 name = "text-generation-router"
 version = "0.1.0"
 dependencies = [
+ "async-stream",
  "axum",
  "clap 4.0.22",
  "futures",
@@ -84,6 +84,13 @@ message GeneratedText {
     string finish_reason = 7;
 }
 
+message Intermediate {
+    uint64 request_id = 1;
+    uint32 token_id = 2;
+    float logprob = 3;
+    string token = 4;
+}
+
 message GenerateRequest {
     /// Batch
     Batch batch = 1;
@@ -94,6 +101,8 @@ message GenerateResponse {
     repeated GeneratedText generated_texts = 1;
     /// Next batch (cached)
     optional Batch batch = 2;
+
+    repeated Intermediate intermediates = 3;
 }
 
 message GenerateWithCacheRequest {
@@ -106,4 +115,6 @@ message GenerateWithCacheResponse {
     repeated GeneratedText generated_texts = 1;
     /// Next batch (cached)
     optional Batch batch = 2;
+
+    repeated Intermediate intermediates = 3;
 }
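
For reference, a minimal sketch of how the new `Intermediate` message and the `intermediates` field can be populated on the Python server side, using the generated `generate_pb2` stubs imported elsewhere in this diff; the concrete values are illustrative only, not taken from the commit.

    # Illustrative only: building the messages added in the proto hunks above.
    from text_generation.pb import generate_pb2

    intermediate = generate_pb2.Intermediate(
        request_id=0,       # id of the request this token belongs to
        token_id=318,       # made-up token id
        logprob=-0.42,      # made-up logprob
        token=" is",        # decoded token text
    )

    response = generate_pb2.GenerateResponse(
        generated_texts=[],             # only requests that met their stopping criteria
        batch=None,                     # next cached batch, if any
        intermediates=[intermediate],   # one entry per request per decoding step
    )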
@@ -23,7 +23,8 @@ serde = "1.0.145"
 serde_json = "1.0.85"
 thiserror = "1.0.37"
 tokenizers = "0.13.0"
-tokio = { version = "1.21.1", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync"] }
+tokio = { version = "1.21.1", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync", "net"] }
 tracing = "0.1.36"
 tracing-subscriber = { version = "0.3.15", features = ["json"] }
+async-stream = "0.3.3"
 
@@ -12,6 +12,7 @@ tonic = "^0.6"
 tower = "^0.4"
 tracing = "^0.1"
 tracing-error = "^0.2"
+serde = { version = "1.0", features = ["derive"] }
 
 [build-dependencies]
 tonic-build = "0.6.2"
@@ -73,7 +73,7 @@ impl Client {
     /// Returns a list of generated texts of request that met their stopping criteria
     /// and the next cached batch
     #[instrument(skip(self))]
-    pub async fn generate(&mut self, batch: Batch) -> Result<(Vec<GeneratedText>, Option<Batch>)> {
+    pub async fn generate(&mut self, batch: Batch) -> Result<(Vec<GeneratedText>, Option<Batch>, Vec<Intermediate>)> {
         let request = tonic::Request::new(GenerateRequest { batch: Some(batch) });
         let response = self
             .stub
@@ -81,7 +81,7 @@ impl Client {
             .instrument(info_span!("generate"))
             .await?
             .into_inner();
-        Ok((response.generated_texts, response.batch))
+        Ok((response.generated_texts, response.batch, response.intermediates))
     }
 
     /// Generate one token for each request in the given cached batch
@@ -92,7 +92,7 @@ impl Client {
     pub async fn generate_with_cache(
         &mut self,
         batches: Vec<Batch>,
-    ) -> Result<(Vec<GeneratedText>, Option<Batch>)> {
+    ) -> Result<(Vec<GeneratedText>, Option<Batch>, Vec<Intermediate>)> {
         let request = tonic::Request::new(GenerateWithCacheRequest { batches });
         let response = self
             .stub
@@ -100,6 +100,6 @@ impl Client {
             .instrument(info_span!("generate_with_cache"))
             .await?
             .into_inner();
-        Ok((response.generated_texts, response.batch))
+        Ok((response.generated_texts, response.batch, response.intermediates))
     }
 }
@@ -7,7 +7,7 @@ mod sharded_client;
 
 pub use client::Client;
 pub use pb::generate::v1::{
-    Batch, GeneratedText, NextTokenChooserParameters, Request, StoppingCriteriaParameters,
+    Batch, GeneratedText, NextTokenChooserParameters, Request, StoppingCriteriaParameters, Intermediate,
 };
 pub use sharded_client::ShardedClient;
 use thiserror::Error;
@@ -35,3 +35,10 @@ impl From<transport::Error> for ClientError {
 }
 
 pub type Result<T> = std::result::Result<T, ClientError>;
+
+#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
+pub struct IntermediateEvent {
+    pub token: String,
+    pub token_id: u32,
+    pub logprob: f32,
+}
@@ -1,6 +1,6 @@
 /// Multi shard Client
 use crate::Result;
-use crate::{Batch, Client, GeneratedText};
+use crate::{Batch, Client, GeneratedText, Intermediate};
 use futures::future::join_all;
 use futures::future::select_all;
 use tonic::transport::Uri;
@@ -41,7 +41,7 @@ impl ShardedClient {
     ///
     /// Returns a list of generated texts of request that met their stopping criteria
     /// and the next cached batch
-    pub async fn generate(&mut self, batch: Batch) -> Result<(Vec<GeneratedText>, Option<Batch>)> {
+    pub async fn generate(&mut self, batch: Batch) -> Result<(Vec<GeneratedText>, Option<Batch>, Vec<Intermediate>)> {
         let futures: Vec<_> = self
             .clients
             .iter_mut()
@@ -59,7 +59,7 @@ impl ShardedClient {
     pub async fn generate_with_cache(
         &mut self,
         batches: Vec<Batch>,
-    ) -> Result<(Vec<GeneratedText>, Option<Batch>)> {
+    ) -> Result<(Vec<GeneratedText>, Option<Batch>, Vec<Intermediate>)> {
         let futures: Vec<_> = self
             .clients
             .iter_mut()
@@ -4,11 +4,13 @@ use crate::{ErrorResponse, GenerateRequest};
 use axum::http::StatusCode;
 use axum::Json;
 use nohash_hasher::IntMap;
+
+
 use std::future::Future;
 use std::sync::Arc;
-use text_generation_client::{Batch, ClientError, GeneratedText, ShardedClient};
+use text_generation_client::{Batch, ClientError, GeneratedText, ShardedClient, Intermediate};
 use thiserror::Error;
-use tokio::sync::{oneshot, Notify};
+use tokio::sync::{oneshot, Notify, mpsc};
 use tokio::time::Instant;
 use tracing::instrument;
 
@@ -51,6 +53,36 @@ impl Batcher {
         Self { db, shared }
     }
 
+    /// Add a new request to the database and return a future that will generate the text
+    pub(crate) fn infer_stream(
+        &self,
+        input_length: usize,
+        request: GenerateRequest,
+        intermediate_tx: mpsc::UnboundedSender<Result<Option<Intermediate>, ClientError>>,
+        response_tx: oneshot::Sender<Result<InferResponse, ClientError>>,
+    ) {
+        // Try to append the request to the database
+        self.db.append(Entry {
+            request,
+            response_tx,
+            intermediate_tx: Some(intermediate_tx),
+            input_length,
+            time: Instant::now(),
+            batch_time: None,
+        });
+
+        // Notify the background task that we have a new entry in the database that needs
+        // to be batched
+        self.shared.batching_task.notify_one();
+
+        // // Await on the response from the background task
+        // // We can safely unwrap as the background task will never drop the sender
+        // response_rx
+        //     .await
+        //     .unwrap()
+        //     .map_err(|err| InferError::GenerationError(err.to_string()))
+    }
+
     /// Add a new request to the database and return a future that will generate the text
     pub(crate) async fn infer(
         &self,
@@ -64,6 +96,7 @@ impl Batcher {
         self.db.append(Entry {
             request,
             response_tx,
+            intermediate_tx: None,
             input_length,
             time: Instant::now(),
             batch_time: None,
@@ -152,12 +185,12 @@ async fn batching_task(
 
 /// Wrap a future inside a match statement to handle errors and send the response to the Batcher
 async fn wrap_future(
-    future: impl Future<Output = Result<(Vec<GeneratedText>, Option<Batch>), ClientError>>,
+    future: impl Future<Output = Result<(Vec<GeneratedText>, Option<Batch>, Vec<Intermediate>), ClientError>>,
     entries: &mut IntMap<u64, Entry>,
 ) -> Option<Batch> {
     match future.await {
-        Ok((generated_texts, next_batch)) => {
-            send_generated(generated_texts, entries);
+        Ok((generated_texts, next_batch, intermediates)) => {
+            send_generated(generated_texts, intermediates, entries);
             next_batch
         }
         // If we have an error, we discard the whole batch
@@ -177,12 +210,26 @@ fn send_error(error: ClientError, entries: &mut IntMap<u64, Entry>) {
 }
 
 /// Send `generated_text` to the Batcher for all `finished`
-fn send_generated(finished: Vec<GeneratedText>, entries: &mut IntMap<u64, Entry>) {
+fn send_generated(finished: Vec<GeneratedText>, intermediates: Vec<Intermediate>, entries: &mut IntMap<u64, Entry>) {
+    intermediates.into_iter().for_each(|intermediate| {
+        // We can `expect` here as the request id should always be in the DB
+        let entry = entries.get(&intermediate.request_id).expect("ID not found in db. This is a bug.");
+
+
+        if let Some(tx) = &entry.intermediate_tx {
+            // unwrap_or is valid here as we don't care if the receiver is gone.
+            tx.send(Ok(Some(intermediate))).unwrap_or(());
+        }
+    });
+
     finished.into_iter().for_each(|output| {
         // We can `expect` here as the request id should always be in the entries
         let entry = entries
             .remove(&output.request.unwrap().id)
             .expect("ID not found in entries. This is a bug.");
+        if let Some(tx) = &entry.intermediate_tx {
+            tx.send(Ok(None)).unwrap_or(());
+        }
 
         let response = InferResponse {
             output_text: output.output_text,
@@ -6,7 +6,7 @@ use parking_lot::Mutex;
 use std::collections::BTreeMap;
 use std::sync::Arc;
 use text_generation_client::{
-    Batch, ClientError, NextTokenChooserParameters, Request, StoppingCriteriaParameters,
+    Batch, ClientError, NextTokenChooserParameters, Request, StoppingCriteriaParameters, Intermediate,
 };
 use tokio::sync::oneshot::Sender;
 use tokio::time::Instant;
@@ -18,6 +18,8 @@ pub(crate) struct Entry {
     pub request: GenerateRequest,
     /// Response sender to communicate between the Batcher and the batching_task
     pub response_tx: Sender<Result<InferResponse, ClientError>>,
+    /// Intermediate sender to communicate between the Batcher and the batching_task
+    pub intermediate_tx: Option<tokio::sync::mpsc::UnboundedSender<Result<Option<Intermediate>, ClientError>>>,
     /// Number of tokens in the input
     pub input_length: usize,
     /// Instant when this entry was created
@@ -8,12 +8,17 @@ use axum::routing::{get, post};
 use axum::{Json, Router};
 use std::net::SocketAddr;
 use std::sync::Arc;
-use text_generation_client::ShardedClient;
+use text_generation_client::{ShardedClient, IntermediateEvent};
 use tokenizers::Tokenizer;
 use tokio::signal;
 use tokio::sync::Semaphore;
 use tokio::time::Instant;
 use tracing::instrument;
+use tokio::sync::{oneshot, mpsc};
+
+use axum::response::sse::{Event, KeepAlive, Sse};
+use std::convert::Infallible;
+use futures::stream::Stream;
 
 // Server shared state
 #[derive(Clone)]
@@ -62,6 +67,80 @@ async fn health(state: Extension<ServerState>) -> Result<(), (StatusCode, Json<E
     Ok(())
 }
 
+#[derive(serde::Serialize)]
+struct StreamEvent {
+    is_end: bool,
+    event: Option<IntermediateEvent>,
+    generated_text: Option<GeneratedText>,
+}
+
+async fn generate_stream(
+    state: Extension<ServerState>,
+    req: Json<GenerateRequest>,
+) -> Sse<impl Stream<Item = Result<Event, Infallible>>> {
+    let (intermediate_tx, mut intermediate_rx) = mpsc::unbounded_channel();
+    let (response_tx, response_rx) = oneshot::channel();
+
+    let (input_length, validated_request) =
+        state.validation.validate(req.0).await.map_err(|err| {
+            tracing::error!("{}", err.to_string());
+            err
+        }).unwrap();
+
+    // Inference
+    state.batcher.infer_stream(input_length, validated_request, intermediate_tx, response_tx);
+
+    let stream = async_stream::stream! {
+        while let Some(item) = intermediate_rx.recv().await {
+            match item {
+                Ok(item) => {
+                    match item {
+                        Some(item) => {
+                            let event_data = IntermediateEvent {
+                                token: item.token,
+                                token_id: item.token_id,
+                                logprob: item.logprob,
+                            };
+                            let stream_event = StreamEvent {
+                                is_end: false,
+                                event: Some(event_data),
+                                generated_text: None,
+                            };
+                            yield Ok(Event::default().data(serde_json::to_string(&stream_event).unwrap()));
+                        }
+                        None => {
+                            break
+                        }
+                    }
+                }
+                Err(err) => {
+                    yield Ok(Event::default().data(err.to_string()));
+                }
+            }
+        }
+        let response = response_rx.await.unwrap();
+        match response {
+            Ok(response) => {
+                let response = GeneratedText {
+                    generated_text: response.output_text,
+                    details: None,
+                };
+                let stream_event = StreamEvent {
+                    is_end: true,
+                    event: None,
+                    generated_text: Some(response),
+                };
+                yield Ok(Event::default().data(serde_json::to_string(&stream_event).unwrap()));
+            }
+            Err(err) => {
+                yield Ok(Event::default().data(err.to_string()));
+            }
+        }
+    };
+
+    Sse::new(stream).keep_alive(KeepAlive::default())
+}
+
 /// Generate method
 #[instrument(
     skip(state),
@@ -197,6 +276,7 @@ pub async fn run(
     let app = Router::new()
         .route("/", post(generate))
         .route("/generate", post(generate))
+        .route("/generate_stream", post(generate_stream))
        .route("/", get(health))
        .route("/health", get(health))
        .layer(Extension(shared_state.clone()));
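
To illustrate the route wired up above, here is a minimal sketch of a client consuming the new `/generate_stream` SSE endpoint. It assumes the router listens on localhost:3000 and that the request body matches the existing `/generate` payload (the `inputs`/`parameters` field names are assumptions, not shown in this diff); each SSE `data:` line carries one serialized `StreamEvent`, and error handling is omitted.

    # Illustrative only: streaming tokens from the /generate_stream endpoint.
    import json
    import requests

    resp = requests.post(
        "http://localhost:3000/generate_stream",  # assumed address
        json={"inputs": "Hello", "parameters": {"max_new_tokens": 20}},  # assumed payload shape
        stream=True,
    )

    for line in resp.iter_lines():
        if not line or not line.startswith(b"data:"):
            continue  # skip keep-alives and blank separators
        event = json.loads(line[len(b"data:"):])
        if event["is_end"]:
            # Final event carries the full generated text
            print(event["generated_text"]["generated_text"])
        else:
            # Intermediate event carries one token with its id and logprob
            print(event["event"]["token"], end="", flush=True)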
@@ -5,7 +5,7 @@ from transformers import AutoTokenizer, AutoModelForCausalLM, PreTrainedTokenize
 from typing import Optional, Tuple, List, Type
 
 from text_generation.models import Model
-from text_generation.models.types import GeneratedText, Batch
+from text_generation.models.types import GeneratedText, Batch, Intermediate
 from text_generation.pb import generate_pb2
 from text_generation.utils import NextTokenChooser, StoppingCriteria
 
@@ -314,6 +314,7 @@ class CausalLM(Model):
 
         # Finished requests
         generated_texts: List[GeneratedText] = []
+        intermediates: List[Intermediate] = []
 
         # Zipped iterator
         iterator = zip(
@@ -352,12 +353,23 @@ class CausalLM(Model):
             next_token_logprob = logprobs[-1, next_token]
             all_logprobs = torch.cat([all_logprobs, next_token_logprob])
 
+            next_token_sq = next_token.squeeze()
+            decoded_next_token = self.tokenizer.decode(
+                next_token_sq, clean_up_tokenization_spaces=False
+            )
+            next_token_logprob = all_logprobs[-1]
+            intermediate = Intermediate(
+                request_id=request.id,
+                token=decoded_next_token,
+                logprob=next_token_logprob.item(),
+                token_id=next_token_sq.item(),
+            )
+            intermediates.append(intermediate.to_pb())
+
             # Evaluate stopping criteria
             stop, reason = stopping_criteria(
-                next_token.squeeze(),
-                self.tokenizer.decode(
-                    next_token.squeeze(), clean_up_tokenization_spaces=False
-                ),
+                next_token_sq,
+                decoded_next_token,
             )
             if stop:
                 # Decode generated tokens
@@ -399,7 +411,7 @@ class CausalLM(Model):
 
         # We finished all generations in the batch; there is no next batch
         if not next_batch_keep_indices:
-            return generated_texts, None
+            return generated_texts, None, intermediates
 
         next_batch_input_ids = torch.cat(next_batch_input_ids, dim=0)
         # If we finished at least one generation, we need to evict the indices of the generations that finished
@@ -459,4 +471,4 @@ class CausalLM(Model):
             max_sequence_length=next_batch_max_sequence_length,
             keys_head_dim_last=batch.keys_head_dim_last,
         )
-        return generated_texts, next_batch
+        return generated_texts, next_batch, intermediates
@@ -50,3 +50,19 @@ class GeneratedText:
             logprobs=self.logprobs,
             finish_reason=self.reason,
         )
+
+
+@dataclass
+class Intermediate:
+    request_id: int
+    token_id: int
+    logprob: float
+    token: str
+
+    def to_pb(self) -> generate_pb2.Intermediate:
+        return generate_pb2.Intermediate(
+            request_id=self.request_id,
+            token_id=self.token_id,
+            logprob=self.logprob,
+            token=self.token,
+        )
@@ -32,7 +32,7 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
             request.batch, self.model.tokenizer, self.model.device
         )
 
-        generated_texts, next_batch = self.model.generate_token(batch)
+        generated_texts, next_batch, intermediates = self.model.generate_token(batch)
         self.cache.set(next_batch)
 
         return generate_pb2.GenerateResponse(
@@ -40,6 +40,7 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
                 generated_text.to_pb() for generated_text in generated_texts
             ],
             batch=next_batch.to_pb() if next_batch else None,
+            intermediates=intermediates,
         )
 
     async def GenerateWithCache(self, request, context):
@@ -58,7 +59,7 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
         else:
             batch = batches[0]
 
-        generated_texts, next_batch = self.model.generate_token(batch)
+        generated_texts, next_batch, intermediates = self.model.generate_token(batch)
         self.cache.set(next_batch)
 
         return generate_pb2.GenerateWithCacheResponse(
@@ -66,6 +67,7 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
                 generated_text.to_pb() for generated_text in generated_texts
             ],
             batch=next_batch.to_pb() if next_batch else None,
+            intermediates=intermediates,
        )