mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 20:34:54 +00:00
Update mod.rs
feat(router): add automatic input token truncation to respect max_input_tokens limit
This commit is contained in:
parent
58848cb471
commit
7ca47777aa
@ -23,6 +23,10 @@ use tokio_stream::wrappers::UnboundedReceiverStream;
|
||||
use tokio_stream::StreamExt;
|
||||
use tracing::instrument;
|
||||
|
||||
use reqwest::Client;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::error::Error;
|
||||
|
||||
#[async_trait]
|
||||
pub trait Backend {
|
||||
fn schedule(
|
||||
@ -48,6 +52,18 @@ pub struct Infer {
|
||||
backend_health: Arc<AtomicBool>,
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() {
|
||||
let input_text = "Your input text here...";
|
||||
let max_new_tokens = 100;
|
||||
|
||||
match send_request_to_tgi(input_text, max_new_tokens).await {
|
||||
Ok(response) => println!("Generated Text: {}", response.generated_text),
|
||||
Err(e) => eprintln!("Error: {}", e),
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
impl Infer {
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub(crate) fn new(
|
||||
@ -296,6 +312,45 @@ impl Infer {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
struct RequestPayload {
|
||||
inputs: String,
|
||||
parameters: Parameters,
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
struct Parameters {
|
||||
max_new_tokens: u32,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct ResponseData {
|
||||
generated_text: String,
|
||||
}
|
||||
|
||||
async fn send_request_to_tgi(input_text: &str, max_new_tokens: u32) -> Result<ResponseData, Box<dyn Error>> {
|
||||
let url = "http://localhost:8080/generate_stream";
|
||||
|
||||
let payload = RequestPayload {
|
||||
inputs: input_text.to_string(),
|
||||
parameters: Parameters { max_new_tokens },
|
||||
};
|
||||
|
||||
let client = Client::new();
|
||||
let response = client
|
||||
.post(url)
|
||||
.json(&payload)
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
if !response.status().is_success() {
|
||||
return Err(format!("Request failed with status: {}", response.status()).into());
|
||||
}
|
||||
|
||||
let response_data = response.json::<ResponseData>().await?;
|
||||
Ok(response_data)
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct GeneratedText {
|
||||
pub text: String,
|
||||
|
Loading…
Reference in New Issue
Block a user