comment out prints

This commit is contained in:
Bill Nell 2023-10-12 14:53:28 -04:00
parent 5e02f40a83
commit 79b5a4068b

View File

@ -261,7 +261,7 @@ class DeepSparseDecoderModel:
if len(input_ids) == self.batch_size and self.batch_size != 1: if len(input_ids) == self.batch_size and self.batch_size != 1:
engine_inputs = self.engine_inputs_for_decode(input_ids) engine_inputs = self.engine_inputs_for_decode(input_ids)
print(f"GOT HERE!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! {past_key_values}") #print(f"GOT HERE!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! {past_key_values}")
logits, new_key_values = self.batched_singletoken_engine( logits, new_key_values = self.batched_singletoken_engine(
engine_inputs, engine_inputs,
past_key_values past_key_values
@ -329,10 +329,10 @@ class DeepSparseDecoderModel:
#print(f"forward pkv {past_key_values} {past_key_values[0] is None}") #print(f"forward pkv {past_key_values} {past_key_values[0] is None}")
if past_key_values[0] is None: if past_key_values[0] is None:
assert len(input_ids) == 1 assert len(input_ids) == 1
print("PREFILL!!!!!!!!!!!!!!!!!!!!!") #print("PREFILL!!!!!!!!!!!!!!!!!!!!!")
return self.prefill(input_ids[0]) return self.prefill(input_ids[0])
else: else:
print("DECODE!!!!!!!!!!!!!!!!!!!!!") #print("DECODE!!!!!!!!!!!!!!!!!!!!!")
return self.decode(input_ids, past_key_values) return self.decode(input_ids, past_key_values)
def __call__( def __call__(