diff --git a/router/client/src/sharded_client.rs b/router/client/src/sharded_client.rs index 986ec377..60b81fe6 100644 --- a/router/client/src/sharded_client.rs +++ b/router/client/src/sharded_client.rs @@ -1,5 +1,5 @@ -use crate::{Batch, Client, Generation, HealthResponse, Request, ShardInfo}; /// Multi shard Client +use crate::{Batch, Client, Generation, HealthResponse, Request, ShardInfo}; use crate::{ClientError, Result}; use futures::future::join_all; use tonic::transport::Uri; @@ -123,6 +123,7 @@ impl ShardedClient { } } +/// Merge generations from the different model shards fn merge_generations( mut results: Vec<(Vec<Generation>, Option<Batch>)>, ) -> Result<(Vec<Generation>, Option<Batch>)> { diff --git a/server/text_generation_server/models/causal_lm.py b/server/text_generation_server/models/causal_lm.py index c029e3ab..610dc4e2 100644 --- a/server/text_generation_server/models/causal_lm.py +++ b/server/text_generation_server/models/causal_lm.py @@ -572,6 +572,8 @@ class CausalLM(Model): if not stop: stopped = False + # Shard generations + # All generations will be appended in the rust sharded client if i % self.world_size == self.rank: if stop: # Decode generated tokens diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index eefa5be9..e862cfeb 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -690,6 +690,8 @@ class FlashCausalLM(Model): if not stop: stopped = False + # Shard generations + # All generations will be appended in the rust sharded client if i % self.world_size == self.rank: if stop: # Decode generated tokens diff --git a/server/text_generation_server/models/seq2seq_lm.py b/server/text_generation_server/models/seq2seq_lm.py index ed658126..d4a0ddcc 100644 --- a/server/text_generation_server/models/seq2seq_lm.py +++ b/server/text_generation_server/models/seq2seq_lm.py @@ -653,6 +653,8 @@ class Seq2SeqLM(Model): if not stop: 
stopped = False + # Shard generations + # All generations will be appended in the rust sharded client if i % self.world_size == self.rank: if stop: # Slice with decoder_input_length to remove padding