diff --git a/src/lighteval/models/vllm/vllm_model.py b/src/lighteval/models/vllm/vllm_model.py index 3100c56b7..cda7ee335 100644 --- a/src/lighteval/models/vllm/vllm_model.py +++ b/src/lighteval/models/vllm/vllm_model.py @@ -355,7 +355,7 @@ def _greedy_until( # The choice we go for here is to avoid truncating the prompt if we can, since it # should have been managed by the prompt creator/few shot manager if requested by the user. inputs = tokenized["input_ids"] - context_size = len(inputs[0]) + context_size = max(len(inp) for inp in inputs) # left truncate the inputs to the maximum length if self.max_length is None: @@ -365,7 +365,7 @@ elif max_new_tokens is not None: if context_size + max_new_tokens > self.max_length: logger.warning( - f"{context_size + max_new_tokens=} which is greater than {self.max_length=}. Truncating context to {self.max_length - max_new_tokens} tokens." + f"Batch max length {context_size} + {max_new_tokens=} which is greater than {self.max_length=}. Truncating context to {self.max_length - max_new_tokens} tokens." ) context_size = self.max_length - max_new_tokens if context_size < 0: @@ -377,7 +377,7 @@ else: if context_size > self.max_length: logger.warning( - f"{context_size=} which is greater than {self.max_length=}. Truncating context to {self.max_length} tokens." + f"Batch max length {context_size} which is greater than {self.max_length=}. Truncating context to {self.max_length} tokens." ) context_size = self.max_length inputs = [input[-context_size:] for input in inputs]