mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-10 03:44:54 +00:00
add docs
This commit is contained in:
parent
89565b4eaf
commit
f0609e73d8
@ -1,5 +1,5 @@
|
|||||||
use crate::{Batch, Client, Generation, HealthResponse, Request, ShardInfo};
|
|
||||||
/// Multi shard Client
|
/// Multi shard Client
|
||||||
|
use crate::{Batch, Client, Generation, HealthResponse, Request, ShardInfo};
|
||||||
use crate::{ClientError, Result};
|
use crate::{ClientError, Result};
|
||||||
use futures::future::join_all;
|
use futures::future::join_all;
|
||||||
use tonic::transport::Uri;
|
use tonic::transport::Uri;
|
||||||
@ -123,6 +123,7 @@ impl ShardedClient {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Merge generations from the different model shards
|
||||||
fn merge_generations(
|
fn merge_generations(
|
||||||
mut results: Vec<(Vec<Generation>, Option<Batch>)>,
|
mut results: Vec<(Vec<Generation>, Option<Batch>)>,
|
||||||
) -> Result<(Vec<Generation>, Option<Batch>)> {
|
) -> Result<(Vec<Generation>, Option<Batch>)> {
|
||||||
|
@ -572,6 +572,8 @@ class CausalLM(Model):
|
|||||||
if not stop:
|
if not stop:
|
||||||
stopped = False
|
stopped = False
|
||||||
|
|
||||||
|
# Shard generations
|
||||||
|
# All generations will be appended in the rust sharded client
|
||||||
if i % self.world_size == self.rank:
|
if i % self.world_size == self.rank:
|
||||||
if stop:
|
if stop:
|
||||||
# Decode generated tokens
|
# Decode generated tokens
|
||||||
|
@ -690,6 +690,8 @@ class FlashCausalLM(Model):
|
|||||||
if not stop:
|
if not stop:
|
||||||
stopped = False
|
stopped = False
|
||||||
|
|
||||||
|
# Shard generations
|
||||||
|
# All generations will be appended in the rust sharded client
|
||||||
if i % self.world_size == self.rank:
|
if i % self.world_size == self.rank:
|
||||||
if stop:
|
if stop:
|
||||||
# Decode generated tokens
|
# Decode generated tokens
|
||||||
|
@ -653,6 +653,8 @@ class Seq2SeqLM(Model):
|
|||||||
if not stop:
|
if not stop:
|
||||||
stopped = False
|
stopped = False
|
||||||
|
|
||||||
|
# Shard generations
|
||||||
|
# All generations will be appended in the rust sharded client
|
||||||
if i % self.world_size == self.rank:
|
if i % self.world_size == self.rank:
|
||||||
if stop:
|
if stop:
|
||||||
# Slice with decoder_input_length to remove padding
|
# Slice with decoder_input_length to remove padding
|
||||||
|
Loading…
Reference in New Issue
Block a user