fix: Some small fixes

- Avoid theoretical hang in batcher loop
- Avoid a couple of clones in py server generate method
- Keep attention mask tensors as integers
Nick Hill 2022-11-29 13:22:25 -08:00
parent daa1d81d5e
commit a172430d8b
4 changed files with 5 additions and 9 deletions

@@ -105,7 +105,7 @@ async fn batching_task(
         // This batch might be smaller than the maximum batch size if there are not enough requests
         // waiting in the DB
         let mut waiting_tokens = 0;
-        if let Some((request_ids, batch)) = db.next_batch(None, max_batch_size) {
+        while let Some((request_ids, batch)) = db.next_batch(None, max_batch_size) {
             let mut cached_batch = wrap_future(client.generate(batch), request_ids, &db).await;
             waiting_tokens += 1;
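
The `if let` to `while let` change is the "theoretical hang" fix: with `if let`, the batching task serves only the batch it finds on a single wake-up, so requests that are already queued behind it could sit until another notification arrives. Below is a minimal Python sketch of the same control-flow difference (the real router code is Rust; the queue and helper names here are made up for illustration):

    from collections import deque

    def next_batch(queue):
        """Stand-in for db.next_batch(): pop a ready batch, or return None."""
        return queue.popleft() if queue else None

    # Old shape: at most one batch is served per wake-up, so "batch-2"
    # stays queued until another notification arrives (if it ever does).
    queue = deque(["batch-1", "batch-2"])
    batch = next_batch(queue)
    if batch is not None:
        print("served", batch)  # batch-1 only

    # Fixed shape: keep draining ready batches before going back to sleep.
    queue = deque(["batch-1", "batch-2"])
    while (batch := next_batch(queue)) is not None:
        print("served", batch)  # batch-1, then batch-2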

@@ -90,11 +90,7 @@ async fn generate(
     // Validate request
     let (input_length, validated_request) = state
         .validation
-        // FIXME: can't we get rid of the cloning here??
-        .validate(GenerateRequest {
-            inputs: req.inputs.clone(),
-            parameters: req.parameters.clone(),
-        })
+        .validate(req.0)
         .await
         .map_err(|err| {
             tracing::error!("{}", err.to_string());

@@ -381,7 +381,7 @@ class CausalLM(Model):
         next_batch_attention_mask = torch.cat(
             [
                 next_batch_attention_mask,
-                torch.ones((next_batch_size, 1)).to(self.device),
+                next_batch_attention_mask.new_ones(next_batch_size, 1),
             ],
             dim=1,
         )
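
This is the "keep attention mask tensors as integers" fix: the mask produced by the tokenizer is an integer (int64) tensor, but `torch.ones` defaults to float32, so concatenating the two promoted the whole mask to float and also required an explicit `.to(self.device)` copy. `new_ones` allocates with the same dtype and device as the existing mask. A small standalone sketch of the difference:

    import torch

    # Tokenizers return attention masks as int64 tensors.
    mask = torch.ones((2, 3), dtype=torch.int64)

    # torch.ones defaults to float32, so the concatenated mask is
    # promoted to float32 (and would still need a .to(device) copy).
    padded = torch.cat([mask, torch.ones((2, 1))], dim=1)
    print(padded.dtype)  # torch.float32

    # new_ones matches the dtype and device of the source tensor,
    # so the mask stays an integer tensor with no extra copy.
    padded = torch.cat([mask, mask.new_ones(2, 1)], dim=1)
    print(padded.dtype)  # torch.int64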

@@ -87,7 +87,7 @@ class Seq2SeqLMBatch:
             inputs, return_tensors="pt", padding=True, pad_to_multiple_of=8
         ).to(device)
         # Convert decoder_input_ids to torch tensor of size [batch_size, 1]
-        decoder_input_ids = torch.tensor(decoder_input_ids).to(device).unsqueeze(-1)
+        decoder_input_ids = torch.tensor(decoder_input_ids, device=device).unsqueeze(-1)
 
         return cls(
             batch_id=pb.id,
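
The `decoder_input_ids` change removes one of the extra copies: `torch.tensor(x).to(device)` builds an intermediate CPU tensor and then copies it to the target device, while passing `device=` constructs the tensor on the target device in one step. A minimal sketch (the token values here are made up for illustration):

    import torch

    decoder_input_ids = [0, 0, 0]  # e.g. one decoder start token per sequence
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Old: build on CPU, then copy to the target device.
    ids_old = torch.tensor(decoder_input_ids).to(device).unsqueeze(-1)

    # New: construct directly on the target device, no intermediate tensor.
    ids_new = torch.tensor(decoder_input_ids, device=device).unsqueeze(-1)

    assert ids_old.shape == ids_new.shape == (3, 1)
    assert torch.equal(ids_old, ids_new)
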
@@ -499,7 +499,7 @@ class Seq2SeqLM(Model):
         next_batch_decoder_attention_mask = torch.cat(
             [
                 next_batch_decoder_attention_mask,
-                torch.ones((next_batch_size, 1)).to(self.device),
+                next_batch_decoder_attention_mask.new_ones(next_batch_size, 1),
             ],
             dim=1,
         )