Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-06-19 15:52:08 +00:00)
fix: Some small fixes

- Avoid theoretical hang in batcher loop
- Avoid a couple of clones in py server generate method
- Keep attention mask tensors as integers
This commit is contained in: parent daa1d81d5e, commit a172430d8b
@@ -105,7 +105,7 @@ async fn batching_task(
         // This batch might be smaller than the maximum batch size if there are not enough requests
         // waiting in the DB
         let mut waiting_tokens = 0;
-        if let Some((request_ids, batch)) = db.next_batch(None, max_batch_size) {
+        while let Some((request_ids, batch)) = db.next_batch(None, max_batch_size) {
             let mut cached_batch = wrap_future(client.generate(batch), request_ids, &db).await;
             waiting_tokens += 1;

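For context on the first bullet, a minimal Python sketch of the control-flow difference (illustrative only; the real loop is the Rust shown above, and the queue here is a hypothetical stand-in for the router's request DB): with the `if`-style check the task handles at most one waiting batch per wakeup, while the `while`-style loop keeps draining until nothing is left waiting, which is where the theoretical hang could otherwise come from.

from collections import deque

# Hypothetical stand-in for the router's request database (illustration only).
waiting = deque(["batch-1", "batch-2", "batch-3"])

def next_batch(queue):
    return queue.popleft() if queue else None

# Old shape ("if let"): at most one waiting batch is handled per wakeup.
batch = next_batch(waiting)
if batch is not None:
    print("processed", batch)

# New shape ("while let"): keep draining until the queue is empty,
# so batches queued in the meantime are not left stranded.
while (batch := next_batch(waiting)) is not None:
    print("processed", batch)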
@@ -90,11 +90,7 @@ async fn generate(
     // Validate request
     let (input_length, validated_request) = state
         .validation
-        // FIXME: can't we get rid of the cloning here??
-        .validate(GenerateRequest {
-            inputs: req.inputs.clone(),
-            parameters: req.parameters.clone(),
-        })
+        .validate(req.0)
         .await
         .map_err(|err| {
             tracing::error!("{}", err.to_string());
@@ -381,7 +381,7 @@ class CausalLM(Model):
             next_batch_attention_mask = torch.cat(
                 [
                     next_batch_attention_mask,
-                    torch.ones((next_batch_size, 1)).to(self.device),
+                    next_batch_attention_mask.new_ones(next_batch_size, 1),
                 ],
                 dim=1,
             )
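The attention-mask changes in this commit (here and in the Seq2SeqLM hunk further down) are about dtype and device: `torch.ones` defaults to float32 on the default device, while `Tensor.new_ones` builds the appended column with the existing mask's own dtype and device, so the mask stays an integer tensor. A small standalone sketch of the difference (not the server code):

import torch

# Tokenizers return attention masks as integer tensors.
attention_mask = torch.ones((2, 4), dtype=torch.int64)

# Old pattern: a fresh float32 column that then has to be moved to the right device.
print(torch.ones((2, 1)).dtype)  # torch.float32

# New pattern: new_ones inherits dtype and device from the existing mask,
# so concatenating keeps the attention mask as integers.
new_column = attention_mask.new_ones(2, 1)
print(new_column.dtype)  # torch.int64
print(torch.cat([attention_mask, new_column], dim=1).dtype)  # torch.int64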
@@ -87,7 +87,7 @@ class Seq2SeqLMBatch:
             inputs, return_tensors="pt", padding=True, pad_to_multiple_of=8
         ).to(device)
         # Convert decoder_input_ids to torch tensor of size [batch_size, 1]
-        decoder_input_ids = torch.tensor(decoder_input_ids).to(device).unsqueeze(-1)
+        decoder_input_ids = torch.tensor(decoder_input_ids, device=device).unsqueeze(-1)

         return cls(
             batch_id=pb.id,
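A small standalone sketch of the `decoder_input_ids` change (the token values are made up for illustration): passing `device=` to `torch.tensor` allocates the tensor directly on the target device, instead of building it on the CPU and then copying it over with `.to(device)`.

import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
decoder_input_ids = [0, 0, 0, 0]  # hypothetical per-request decoder start tokens

# Old pattern: create on CPU, then copy to the target device.
ids = torch.tensor(decoder_input_ids).to(device).unsqueeze(-1)

# New pattern: allocate directly on the target device, skipping the extra copy.
ids = torch.tensor(decoder_input_ids, device=device).unsqueeze(-1)
print(ids.shape, ids.device)  # torch.Size([4, 1]) on the chosen device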
@@ -499,7 +499,7 @@ class Seq2SeqLM(Model):
             next_batch_decoder_attention_mask = torch.cat(
                 [
                     next_batch_decoder_attention_mask,
-                    torch.ones((next_batch_size, 1)).to(self.device),
+                    next_batch_decoder_attention_mask.new_ones(next_batch_size, 1),
                 ],
                 dim=1,
             )