diff --git a/server/text_generation_server/server.py b/server/text_generation_server/server.py index 55069efc..e0efbcf5 100644 --- a/server/text_generation_server/server.py +++ b/server/text_generation_server/server.py @@ -69,9 +69,7 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer): ) generations, next_batch = self.model.generate_token(batch) - - if next_batch is not None: - self.cache.set(next_batch) + self.cache.set(next_batch) return generate_pb2.PrefillResponse( generations=[generation.to_pb() for generation in generations], @@ -98,9 +96,7 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer): batch = batches[0] generations, next_batch = self.model.generate_token(batch) - - if next_batch is not None: - self.cache.set(next_batch) + self.cache.set(next_batch) return generate_pb2.DecodeResponse( generations=[generation.to_pb() for generation in generations],