fix: Some small fixes

- Avoid theoretical hang in batcher loop
- Avoid a couple of clones in the Python server's generate method
- Keep attention mask tensors as integers
This commit is contained in:
Nick Hill 2022-11-29 13:22:25 -08:00
parent daa1d81d5e
commit a172430d8b
4 changed files with 5 additions and 9 deletions

View File

@ -105,7 +105,7 @@ async fn batching_task(
// This batch might be smaller than the maximum batch size if there are not enough requests
// waiting in the DB
let mut waiting_tokens = 0;
if let Some((request_ids, batch)) = db.next_batch(None, max_batch_size) {
while let Some((request_ids, batch)) = db.next_batch(None, max_batch_size) {
let mut cached_batch = wrap_future(client.generate(batch), request_ids, &db).await;
waiting_tokens += 1;

View File

@ -90,11 +90,7 @@ async fn generate(
// Validate request
let (input_length, validated_request) = state
.validation
// FIXME: can't we get rid of the cloning here??
.validate(GenerateRequest {
inputs: req.inputs.clone(),
parameters: req.parameters.clone(),
})
.validate(req.0)
.await
.map_err(|err| {
tracing::error!("{}", err.to_string());

View File

@ -381,7 +381,7 @@ class CausalLM(Model):
next_batch_attention_mask = torch.cat(
[
next_batch_attention_mask,
torch.ones((next_batch_size, 1)).to(self.device),
next_batch_attention_mask.new_ones(next_batch_size, 1),
],
dim=1,
)

View File

@ -87,7 +87,7 @@ class Seq2SeqLMBatch:
inputs, return_tensors="pt", padding=True, pad_to_multiple_of=8
).to(device)
# Convert decoder_input_ids to torch tensor of size [batch_size, 1]
decoder_input_ids = torch.tensor(decoder_input_ids).to(device).unsqueeze(-1)
decoder_input_ids = torch.tensor(decoder_input_ids, device=device).unsqueeze(-1)
return cls(
batch_id=pb.id,
@ -499,7 +499,7 @@ class Seq2SeqLM(Model):
next_batch_decoder_attention_mask = torch.cat(
[
next_batch_decoder_attention_mask,
torch.ones((next_batch_size, 1)).to(self.device),
next_batch_decoder_attention_mask.new_ones(next_batch_size, 1),
],
dim=1,
)