mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 12:24:53 +00:00
feat: improve logs by passing span to internal functions
This commit is contained in:
parent
bd28c36815
commit
a7bf3196d4
@ -164,6 +164,15 @@ async fn generate(
|
|||||||
Json(req): Json<GenerateRequest>,
|
Json(req): Json<GenerateRequest>,
|
||||||
) -> Result<(HeaderMap, Json<GenerateResponse>), (StatusCode, Json<ErrorResponse>)> {
|
) -> Result<(HeaderMap, Json<GenerateResponse>), (StatusCode, Json<ErrorResponse>)> {
|
||||||
let span = tracing::Span::current();
|
let span = tracing::Span::current();
|
||||||
|
generate_internal(infer, ComputeType(compute_type), Json(req), span).await
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn generate_internal(
|
||||||
|
infer: Extension<Infer>,
|
||||||
|
ComputeType(compute_type): ComputeType,
|
||||||
|
Json(req): Json<GenerateRequest>,
|
||||||
|
span: tracing::Span,
|
||||||
|
) -> Result<(HeaderMap, Json<GenerateResponse>), (StatusCode, Json<ErrorResponse>)> {
|
||||||
let start_time = Instant::now();
|
let start_time = Instant::now();
|
||||||
metrics::increment_counter!("tgi_request_count");
|
metrics::increment_counter!("tgi_request_count");
|
||||||
|
|
||||||
@ -362,12 +371,13 @@ async fn generate_stream(
|
|||||||
HeaderMap,
|
HeaderMap,
|
||||||
Sse<impl Stream<Item = Result<Event, Infallible>>>,
|
Sse<impl Stream<Item = Result<Event, Infallible>>>,
|
||||||
) {
|
) {
|
||||||
|
let span = tracing::Span::current();
|
||||||
let on_message_callback = |stream_token: StreamResponse| {
|
let on_message_callback = |stream_token: StreamResponse| {
|
||||||
let event = Event::default();
|
let event = Event::default();
|
||||||
event.json_data(stream_token).unwrap()
|
event.json_data(stream_token).unwrap()
|
||||||
};
|
};
|
||||||
let (headers, response_stream) =
|
let (headers, response_stream) =
|
||||||
generate_stream_internal(infer, compute_type, Json(req), on_message_callback).await;
|
generate_stream_internal(infer, compute_type, Json(req), on_message_callback, span).await;
|
||||||
let sse = Sse::new(response_stream).keep_alive(KeepAlive::default());
|
let sse = Sse::new(response_stream).keep_alive(KeepAlive::default());
|
||||||
(headers, sse)
|
(headers, sse)
|
||||||
}
|
}
|
||||||
@ -377,8 +387,8 @@ async fn generate_stream_internal(
|
|||||||
ComputeType(compute_type): ComputeType,
|
ComputeType(compute_type): ComputeType,
|
||||||
Json(req): Json<GenerateRequest>,
|
Json(req): Json<GenerateRequest>,
|
||||||
on_message_callback: impl Fn(StreamResponse) -> Event,
|
on_message_callback: impl Fn(StreamResponse) -> Event,
|
||||||
|
span: tracing::Span,
|
||||||
) -> (HeaderMap, impl Stream<Item = Result<Event, Infallible>>) {
|
) -> (HeaderMap, impl Stream<Item = Result<Event, Infallible>>) {
|
||||||
let span = tracing::Span::current();
|
|
||||||
let start_time = Instant::now();
|
let start_time = Instant::now();
|
||||||
metrics::increment_counter!("tgi_request_count");
|
metrics::increment_counter!("tgi_request_count");
|
||||||
|
|
||||||
@ -584,6 +594,7 @@ async fn completions(
|
|||||||
Extension(info): Extension<Info>,
|
Extension(info): Extension<Info>,
|
||||||
Json(req): Json<CompletionRequest>,
|
Json(req): Json<CompletionRequest>,
|
||||||
) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
|
) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
|
||||||
|
let span = tracing::Span::current();
|
||||||
metrics::increment_counter!("tgi_request_count");
|
metrics::increment_counter!("tgi_request_count");
|
||||||
|
|
||||||
let stream = req.stream;
|
let stream = req.stream;
|
||||||
@ -657,7 +668,7 @@ async fn completions(
|
|||||||
format!("{}-{}", info.version, info.docker_label.unwrap_or("native"));
|
format!("{}-{}", info.version, info.docker_label.unwrap_or("native"));
|
||||||
let infer_clone = infer.clone();
|
let infer_clone = infer.clone();
|
||||||
let compute_type_clone = compute_type.clone();
|
let compute_type_clone = compute_type.clone();
|
||||||
let params_clone = generate_request.parameters.clone();
|
let span_clone = span.clone();
|
||||||
|
|
||||||
// Create a future for each generate_stream_internal call.
|
// Create a future for each generate_stream_internal call.
|
||||||
let generate_future = async move {
|
let generate_future = async move {
|
||||||
@ -691,32 +702,28 @@ async fn completions(
|
|||||||
let (header_tx, header_rx) = oneshot::channel();
|
let (header_tx, header_rx) = oneshot::channel();
|
||||||
let (sse_tx, sse_rx) = tokio::sync::mpsc::unbounded_channel();
|
let (sse_tx, sse_rx) = tokio::sync::mpsc::unbounded_channel();
|
||||||
|
|
||||||
tokio::spawn(
|
tokio::spawn(async move {
|
||||||
async move {
|
let (header_map, sse) = generate_stream_internal(
|
||||||
let (header_map, sse) = generate_stream_internal(
|
infer_clone.clone(),
|
||||||
infer_clone.clone(),
|
compute_type_clone.clone(),
|
||||||
compute_type_clone.clone(),
|
Json(generate_request),
|
||||||
Json(generate_request),
|
on_message_callback,
|
||||||
on_message_callback,
|
span_clone.clone(),
|
||||||
)
|
)
|
||||||
.await;
|
.await;
|
||||||
|
|
||||||
// send and dont wait for response
|
// send and dont wait for response
|
||||||
let _ = header_tx.send(header_map);
|
let _ = header_tx.send(header_map);
|
||||||
|
|
||||||
// pin an emit messages to the sse_tx
|
// pin an emit messages to the sse_tx
|
||||||
let mut sse = Box::pin(sse);
|
let mut sse = Box::pin(sse);
|
||||||
while let Some(event) = sse.next().await {
|
while let Some(event) = sse.next().await {
|
||||||
if sse_tx.send(event).is_err() {
|
if sse_tx.send(event).is_err() {
|
||||||
tracing::error!("Failed to send event. Receiver dropped.");
|
tracing::error!("Failed to send event. Receiver dropped.");
|
||||||
break;
|
break;
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
.instrument(
|
});
|
||||||
tracing::info_span!("request", index = %index, parameters = ?params_clone),
|
|
||||||
),
|
|
||||||
);
|
|
||||||
|
|
||||||
(header_rx, sse_rx)
|
(header_rx, sse_rx)
|
||||||
};
|
};
|
||||||
@ -802,17 +809,17 @@ async fn completions(
|
|||||||
for (index, generate_request) in generate_requests.into_iter().enumerate() {
|
for (index, generate_request) in generate_requests.into_iter().enumerate() {
|
||||||
let infer_clone = infer.clone();
|
let infer_clone = infer.clone();
|
||||||
let compute_type_clone = compute_type.clone();
|
let compute_type_clone = compute_type.clone();
|
||||||
let params_clone = generate_request.parameters.clone();
|
let span_clone = span.clone();
|
||||||
let response_future = async move {
|
let response_future = async move {
|
||||||
let result = generate(
|
let result = generate_internal(
|
||||||
Extension(infer_clone),
|
Extension(infer_clone),
|
||||||
Extension(compute_type_clone),
|
compute_type_clone,
|
||||||
Json(generate_request),
|
Json(generate_request),
|
||||||
|
span_clone,
|
||||||
)
|
)
|
||||||
.await;
|
.await;
|
||||||
result.map(|(headers, generation)| (index, headers, generation))
|
result.map(|(headers, generation)| (index, headers, generation))
|
||||||
}
|
};
|
||||||
.instrument(tracing::info_span!("request", index = %index, parameters = ?params_clone));
|
|
||||||
responses.push(response_future);
|
responses.push(response_future);
|
||||||
}
|
}
|
||||||
let generate_responses = responses.try_collect::<Vec<_>>().await?;
|
let generate_responses = responses.try_collect::<Vec<_>>().await?;
|
||||||
@ -979,6 +986,7 @@ async fn chat_completions(
|
|||||||
Extension(info): Extension<Info>,
|
Extension(info): Extension<Info>,
|
||||||
Json(req): Json<ChatRequest>,
|
Json(req): Json<ChatRequest>,
|
||||||
) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
|
) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
|
||||||
|
let span = tracing::Span::current();
|
||||||
metrics::increment_counter!("tgi_request_count");
|
metrics::increment_counter!("tgi_request_count");
|
||||||
|
|
||||||
let ChatRequest {
|
let ChatRequest {
|
||||||
@ -1116,17 +1124,14 @@ async fn chat_completions(
|
|||||||
compute_type,
|
compute_type,
|
||||||
Json(generate_request),
|
Json(generate_request),
|
||||||
on_message_callback,
|
on_message_callback,
|
||||||
|
span,
|
||||||
)
|
)
|
||||||
.await;
|
.await;
|
||||||
let sse = Sse::new(response_stream).keep_alive(KeepAlive::default());
|
let sse = Sse::new(response_stream).keep_alive(KeepAlive::default());
|
||||||
Ok((headers, sse).into_response())
|
Ok((headers, sse).into_response())
|
||||||
} else {
|
} else {
|
||||||
let (headers, Json(generation)) = generate(
|
let (headers, Json(generation)) =
|
||||||
Extension(infer),
|
generate_internal(Extension(infer), compute_type, Json(generate_request), span).await?;
|
||||||
Extension(compute_type),
|
|
||||||
Json(generate_request),
|
|
||||||
)
|
|
||||||
.await?;
|
|
||||||
|
|
||||||
let current_time = std::time::SystemTime::now()
|
let current_time = std::time::SystemTime::now()
|
||||||
.duration_since(std::time::UNIX_EPOCH)
|
.duration_since(std::time::UNIX_EPOCH)
|
||||||
@ -1223,6 +1228,7 @@ async fn vertex_compatibility(
|
|||||||
Extension(compute_type): Extension<ComputeType>,
|
Extension(compute_type): Extension<ComputeType>,
|
||||||
Json(req): Json<VertexRequest>,
|
Json(req): Json<VertexRequest>,
|
||||||
) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
|
) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
|
||||||
|
let span = tracing::Span::current();
|
||||||
metrics::increment_counter!("tgi_request_count");
|
metrics::increment_counter!("tgi_request_count");
|
||||||
|
|
||||||
// check that theres at least one instance
|
// check that theres at least one instance
|
||||||
@ -1254,10 +1260,11 @@ async fn vertex_compatibility(
|
|||||||
};
|
};
|
||||||
|
|
||||||
async {
|
async {
|
||||||
generate(
|
generate_internal(
|
||||||
Extension(infer.clone()),
|
Extension(infer.clone()),
|
||||||
Extension(compute_type.clone()),
|
compute_type.clone(),
|
||||||
Json(generate_request),
|
Json(generate_request),
|
||||||
|
span.clone(),
|
||||||
)
|
)
|
||||||
.await
|
.await
|
||||||
.map(|(_, Json(generation))| generation.generated_text)
|
.map(|(_, Json(generation))| generation.generated_text)
|
||||||
|
Loading…
Reference in New Issue
Block a user