mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-06-19 15:52:08 +00:00
Improve logging from python shards
I found these changes necessary to be able to do reasonable problem diagnosis: (1) ensure output from the shards makes its way to the launcher's stdout; (2) log full exceptions raised by the server RPC functions; (3) fix stderr logging in integration_tests.rs.
This commit is contained in:
parent
3efa5bbbfd
commit
337e1f8795
@ -299,6 +299,10 @@ fn shard_manager(
|
|||||||
"SAFETENSORS_FAST_GPU".parse().unwrap(),
|
"SAFETENSORS_FAST_GPU".parse().unwrap(),
|
||||||
"1".to_string().parse().unwrap(),
|
"1".to_string().parse().unwrap(),
|
||||||
),
|
),
|
||||||
|
(
|
||||||
|
"PYTHONUNBUFFERED".parse().unwrap(),
|
||||||
|
"1".to_string().parse().unwrap(),
|
||||||
|
),
|
||||||
];
|
];
|
||||||
|
|
||||||
// If the HUGGINGFACE_HUB_CACHE env var is set, pass it to the shard
|
// If the HUGGINGFACE_HUB_CACHE env var is set, pass it to the shard
|
||||||
@ -347,6 +351,12 @@ fn shard_manager(
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// Redirect STDOUT to the console
|
||||||
|
let shard_stdout = p.stdout.take().unwrap();
|
||||||
|
thread::spawn(move || BufReader::new(shard_stdout).lines().for_each(|line|
|
||||||
|
println!("Shard {}: {}", rank, line.unwrap())
|
||||||
|
));
|
||||||
|
|
||||||
let mut ready = false;
|
let mut ready = false;
|
||||||
let start_time = Instant::now();
|
let start_time = Instant::now();
|
||||||
let mut wait_time = Instant::now();
|
let mut wait_time = Instant::now();
|
||||||
|
@ -41,25 +41,21 @@ fn start_launcher(model_name: String, num_shard: usize, port: usize, master_port
|
|||||||
&argv,
|
&argv,
|
||||||
PopenConfig {
|
PopenConfig {
|
||||||
stdout: Redirection::Pipe,
|
stdout: Redirection::Pipe,
|
||||||
stderr: Redirection::Pipe,
|
stderr: Redirection::Merge,
|
||||||
..Default::default()
|
..Default::default()
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
.expect("Could not start launcher");
|
.expect("Could not start launcher");
|
||||||
|
|
||||||
// Redirect STDOUT and STDERR to the console
|
// Redirect STDOUT and STDERR to the console
|
||||||
|
// (STDERR is merged into STDOUT)
|
||||||
let launcher_stdout = launcher.stdout.take().unwrap();
|
let launcher_stdout = launcher.stdout.take().unwrap();
|
||||||
let launcher_stderr = launcher.stderr.take().unwrap();
|
|
||||||
|
|
||||||
thread::spawn(move || {
|
thread::spawn(move || {
|
||||||
let stdout = BufReader::new(launcher_stdout);
|
let stdout = BufReader::new(launcher_stdout);
|
||||||
let stderr = BufReader::new(launcher_stderr);
|
|
||||||
for line in stdout.lines() {
|
for line in stdout.lines() {
|
||||||
println!("{}", line.unwrap());
|
println!("{}", line.unwrap());
|
||||||
}
|
}
|
||||||
for line in stderr.lines() {
|
|
||||||
println!("{}", line.unwrap());
|
|
||||||
}
|
|
||||||
});
|
});
|
||||||
|
|
||||||
for _ in 0..60 {
|
for _ in 0..60 {
|
||||||
|
@ -1,4 +1,5 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
|
import logging
|
||||||
import os
|
import os
|
||||||
|
|
||||||
from grpc import aio
|
from grpc import aio
|
||||||
@ -12,6 +13,16 @@ from text_generation.models import Model, get_model
|
|||||||
from text_generation.pb import generate_pb2_grpc, generate_pb2
|
from text_generation.pb import generate_pb2_grpc, generate_pb2
|
||||||
|
|
||||||
|
|
||||||
|
def log_errs(func):
    """Decorate an async RPC handler so that any exception it raises is logged.

    The wrapped coroutine is awaited with the caller's arguments unchanged.
    If it raises an ``Exception``, the full traceback is written through
    ``logging.exception`` (prefixed with the handler's name) and the same
    exception is then re-raised so the caller still observes the failure.

    Args:
        func: The coroutine function to wrap.

    Returns:
        A coroutine function with the same call signature as ``func`` and
        with ``func``'s ``__name__``/``__doc__`` metadata preserved.
    """
    # Local import keeps this block self-contained without touching the
    # module-level import list.
    import functools

    @functools.wraps(func)
    async def func_with_log(*args, **kwargs):
        try:
            return await func(*args, **kwargs)
        except Exception:
            logging.exception(f"{func.__name__} failed")
            # Bare ``raise`` re-raises the active exception with its original
            # traceback intact.
            raise

    return func_with_log
|
||||||
|
|
||||||
|
|
||||||
class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
|
class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
|
||||||
def __init__(self, model: Model, cache: Cache, server_urls: List[str]):
|
def __init__(self, model: Model, cache: Cache, server_urls: List[str]):
|
||||||
self.cache = cache
|
self.cache = cache
|
||||||
@ -21,10 +32,12 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
|
|||||||
async def ServiceDiscovery(self, request, context):
|
async def ServiceDiscovery(self, request, context):
|
||||||
return generate_pb2.ServiceDiscoveryResponse(urls=self.server_urls)
|
return generate_pb2.ServiceDiscoveryResponse(urls=self.server_urls)
|
||||||
|
|
||||||
|
@log_errs
|
||||||
async def ClearCache(self, request, context):
|
async def ClearCache(self, request, context):
|
||||||
self.cache.clear()
|
self.cache.clear()
|
||||||
return generate_pb2.ClearCacheResponse()
|
return generate_pb2.ClearCacheResponse()
|
||||||
|
|
||||||
|
@log_errs
|
||||||
async def Generate(self, request, context):
|
async def Generate(self, request, context):
|
||||||
batch = self.model.batch_type.from_pb(
|
batch = self.model.batch_type.from_pb(
|
||||||
request.batch, self.model.tokenizer, self.model.device
|
request.batch, self.model.tokenizer, self.model.device
|
||||||
@ -40,6 +53,7 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
|
|||||||
batch=next_batch.to_pb() if next_batch else None,
|
batch=next_batch.to_pb() if next_batch else None,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@log_errs
|
||||||
async def GenerateWithCache(self, request, context):
|
async def GenerateWithCache(self, request, context):
|
||||||
if len(request.batches) == 0:
|
if len(request.batches) == 0:
|
||||||
raise ValueError("Must provide at least one batch")
|
raise ValueError("Must provide at least one batch")
|
||||||
|
Loading…
Reference in New Issue
Block a user