diff --git a/pyrit/cli/_banner.py b/pyrit/cli/_banner.py index bd5a3d40fe..a76d286b27 100644 --- a/pyrit/cli/_banner.py +++ b/pyrit/cli/_banner.py @@ -255,7 +255,7 @@ def add(line: str, role: ColorRole, segments: Optional[list[tuple[int, int, Colo "Commands:", " • list-scenarios - See all available scenarios", " • list-initializers - See all available initializers", - " • list-targets - See all available targets in the registry", + " • list-targets [opts] - See all available targets in the registry", " • run [opts] - Execute a security scenario", " • scenario-history - View your session history", " • print-scenario [N] - Display detailed results", diff --git a/pyrit/cli/_cli_args.py b/pyrit/cli/_cli_args.py index 1264956ccb..80e383dfe5 100644 --- a/pyrit/cli/_cli_args.py +++ b/pyrit/cli/_cli_args.py @@ -457,6 +457,48 @@ def parse_run_arguments(*, args_string: str) -> dict[str, Any]: return result +def parse_list_targets_arguments(*, args_string: str) -> dict[str, Any]: + """ + Parse list-targets command arguments from a string (for shell mode). + + Args: + args_string: Space-separated argument string (e.g., "--initializers target"). + + Returns: + Dictionary with parsed arguments: + - initializers: Optional[list[str | dict[str, Any]]] + - initialization_scripts: Optional[list[str]] + + Raises: + ValueError: If parsing or validation fails. + """ + parts = args_string.split() + + result: dict[str, Any] = { + "initializers": None, + "initialization_scripts": None, + } + + i = 0 + while i < len(parts): + if parts[i] == "--initializers": + result["initializers"] = [] + i += 1 + while i < len(parts) and not parts[i].startswith("--"): + result["initializers"].append(_parse_initializer_arg(parts[i])) + i += 1 + elif parts[i] == "--initialization-scripts": + result["initialization_scripts"] = [] + i += 1 + while i < len(parts) and not parts[i].startswith("--"): + result["initialization_scripts"].append(parts[i]) + i += 1 + else: + raise ValueError(f"Unknown argument: {parts[i]}") + + return result + + # --------------------------------------------------------------------------- # Shared argparse builder # --------------------------------------------------------------------------- diff --git a/pyrit/cli/frontend_core.py b/pyrit/cli/frontend_core.py index 3db5552011..8bcf5f4655 100644 --- a/pyrit/cli/frontend_core.py +++ b/pyrit/cli/frontend_core.py @@ -27,6 +27,7 @@ from pyrit.cli._cli_args import _parse_initializer_arg as _parse_initializer_arg from pyrit.cli._cli_args import add_common_arguments as add_common_arguments from pyrit.cli._cli_args import non_negative_int as non_negative_int +from pyrit.cli._cli_args import parse_list_targets_arguments as parse_list_targets_arguments from pyrit.cli._cli_args import parse_memory_labels as parse_memory_labels from pyrit.cli._cli_args import parse_run_arguments as parse_run_arguments from pyrit.cli._cli_args import positive_int as positive_int @@ -254,18 +255,16 @@ async def list_initializers_async( async def list_targets_async( *, context: FrontendCore, - initializer_names: Optional[list[Any]] = None, ) -> list[str]: """ List available target names from the TargetRegistry. Since targets are registered by initializers, this function requires initializers - to have been run first. If initializer_names are provided, they will be resolved - and run before querying the registry. + to have been run first. Configure initializers on the FrontendCore context + (via initializer_names or initialization_scripts) before calling this function. Args: context: PyRIT context with loaded registries. - initializer_names: Optional list of initializer entries to run before listing. Returns: Sorted list of registered target names. @@ -273,25 +272,24 @@ async def list_targets_async( if not context._initialized: await context.initialize_async() - # If initializer names are provided, run them to populate the target registry - if initializer_names or context._initializer_configs: - configs = context._initializer_configs - if configs: - initializer_instances = [] - for config in configs: + # Run initializers and/or initialization scripts to populate the target registry + if context._initializer_configs or context._initialization_scripts: + initializer_instances = [] + if context._initializer_configs: + for config in context._initializer_configs: initializer_class = context.initializer_registry.get_class(config.name) instance = initializer_class() if config.args: instance.set_params_from_args(args=config.args) initializer_instances.append(instance) - await initialize_pyrit_async( - memory_db_type=context._database, - initialization_scripts=context._initialization_scripts, - initializers=initializer_instances, - env_files=context._env_files, - silent=getattr(context, "_silent_reinit", False), - ) + await initialize_pyrit_async( + memory_db_type=context._database, + initialization_scripts=context._initialization_scripts, + initializers=initializer_instances or None, + env_files=context._env_files, + silent=getattr(context, "_silent_reinit", False), + ) target_registry = TargetRegistry.get_registry_singleton() return target_registry.get_names() diff --git a/pyrit/cli/pyrit_scan.py b/pyrit/cli/pyrit_scan.py index aefdfa5f22..f6f77bea2b 100644 --- a/pyrit/cli/pyrit_scan.py +++ b/pyrit/cli/pyrit_scan.py @@ -197,17 +197,19 @@ def main(args: Optional[list[str]] = None) -> int: if parsed_args.list_targets: # Need initializers to populate target registry - context = frontend_core.FrontendCore( - config_file=parsed_args.config_file, - initializer_names=parsed_args.initializers, - log_level=parsed_args.log_level, - ) - return asyncio.run(frontend_core.print_targets_list_async(context=context)) + initialization_scripts = None + if parsed_args.initialization_scripts: + try: + initialization_scripts = frontend_core.resolve_initialization_scripts( + script_paths=parsed_args.initialization_scripts + ) + except FileNotFoundError as e: + print(f"Error: {e}") + return 1 - if parsed_args.list_targets: - # Need initializers to populate target registry context = frontend_core.FrontendCore( config_file=parsed_args.config_file, + initialization_scripts=initialization_scripts, initializer_names=parsed_args.initializers, log_level=parsed_args.log_level, ) diff --git a/pyrit/cli/pyrit_shell.py b/pyrit/cli/pyrit_shell.py index f19602bee0..6707f14e95 100644 --- a/pyrit/cli/pyrit_shell.py +++ b/pyrit/cli/pyrit_shell.py @@ -33,7 +33,7 @@ class PyRITShell(cmd.Cmd): Commands: list-scenarios - List all available scenarios list-initializers - List all available initializers - list-targets - List all available targets from the registry + list-targets [opts] - List all available targets from the registry run [opts] - Run a scenario with optional parameters scenario-history - List all previous scenario runs print-scenario [N] - Print detailed results for scenario run(s) @@ -189,6 +189,9 @@ def cmdloop(self, intro: Optional[str] = None) -> None: def do_list_scenarios(self, arg: str) -> None: """List all available scenarios.""" + if arg.strip(): + print(f"Error: list-scenarios does not accept arguments, got: {arg.strip()}") + return self._ensure_initialized() try: asyncio.run(self._fc.print_scenarios_list_async(context=self.context)) @@ -197,6 +200,9 @@ def do_list_scenarios(self, arg: str) -> None: def do_list_initializers(self, arg: str) -> None: """List all available initializers.""" + if arg.strip(): + print(f"Error: list-initializers does not accept arguments, got: {arg.strip()}") + return self._ensure_initialized() try: asyncio.run(self._fc.print_initializers_list_async(context=self.context)) @@ -204,10 +210,49 @@ def do_list_initializers(self, arg: str) -> None: print(f"Error listing initializers: {e}") def do_list_targets(self, arg: str) -> None: - """List all available targets from the TargetRegistry.""" + """ + List all available targets from the TargetRegistry. + + Usage: + list-targets + list-targets --initializers [ ...] + list-targets --initialization-scripts [ ...] + + Options: + --initializers ... Built-in initializers to run first + --initialization-scripts <...> Custom Python scripts to run first + + Examples: + list-targets --initializers target + list-targets --initializers target:tags=default,scorer + """ self._ensure_initialized() try: - asyncio.run(self._fc.print_targets_list_async(context=self.context)) + context_to_use = self.context + + if arg.strip(): + args = self._fc.parse_list_targets_arguments(args_string=arg) + + resolved_scripts = None + if args["initialization_scripts"]: + resolved_scripts = self._fc.resolve_initialization_scripts( + script_paths=args["initialization_scripts"] + ) + + context_to_use = self._fc.FrontendCore( + initialization_scripts=resolved_scripts, + initializer_names=args["initializers"], + log_level=self.default_log_level, + ) + context_to_use._scenario_registry = self.context._scenario_registry + context_to_use._initializer_registry = self.context._initializer_registry + context_to_use._initialized = True + + asyncio.run(self._fc.print_targets_list_async(context=context_to_use)) + except ValueError as e: + print(f"Error: {e}") + except FileNotFoundError as e: + print(f"Error: {e}") except Exception as e: print(f"Error listing targets: {e}") @@ -338,6 +383,9 @@ def do_scenario_history(self, arg: str) -> None: Shows a numbered list of all scenario runs with the commands used. """ + if arg.strip(): + print(f"Error: scenario-history does not accept arguments, got: {arg.strip()}") + return if not self._scenario_history: print("No scenario runs in history.") return @@ -467,8 +515,9 @@ def do_help(self, arg: str) -> None: print(" pyrit_shell") print(" pyrit_shell --config-file ./my_config.yaml --log-level DEBUG") else: - # Show help for specific command - super().do_help(arg) + # Convert hyphens to underscores (e.g. help list-targets -> help list_targets) for command lookup + normalized_arg = arg.replace("-", "_") + super().do_help(normalized_arg) def do_exit(self, arg: str) -> bool: """ diff --git a/tests/unit/cli/test_frontend_core.py b/tests/unit/cli/test_frontend_core.py index 61b3c7bb50..e3cbbca8d0 100644 --- a/tests/unit/cli/test_frontend_core.py +++ b/tests/unit/cli/test_frontend_core.py @@ -672,6 +672,46 @@ def test_parse_run_arguments_missing_value(self): frontend_core.parse_run_arguments(args_string="test_scenario --max-concurrency") +class TestParseListTargetsArguments: + """Tests for parse_list_targets_arguments function.""" + + def test_parse_list_targets_arguments_empty(self): + """Test parsing empty string returns defaults.""" + result = frontend_core.parse_list_targets_arguments(args_string="") + assert result["initializers"] is None + assert result["initialization_scripts"] is None + + def test_parse_list_targets_arguments_with_initializers(self): + """Test parsing with initializers.""" + result = frontend_core.parse_list_targets_arguments(args_string="--initializers target init2") + assert result["initializers"] == ["target", "init2"] + + def test_parse_list_targets_arguments_with_initializer_params(self): + """Test parsing initializers with key=value params.""" + result = frontend_core.parse_list_targets_arguments(args_string="--initializers target:tags=default,scorer") + assert result["initializers"] == [{"name": "target", "args": {"tags": ["default", "scorer"]}}] + + def test_parse_list_targets_arguments_with_initialization_scripts(self): + """Test parsing with initialization-scripts.""" + result = frontend_core.parse_list_targets_arguments( + args_string="--initialization-scripts script1.py script2.py" + ) + assert result["initialization_scripts"] == ["script1.py", "script2.py"] + + def test_parse_list_targets_arguments_with_both(self): + """Test parsing with both initializers and scripts.""" + result = frontend_core.parse_list_targets_arguments( + args_string="--initializers target --initialization-scripts script1.py" + ) + assert result["initializers"] == ["target"] + assert result["initialization_scripts"] == ["script1.py"] + + def test_parse_list_targets_arguments_unknown_arg_raises(self): + """Test parsing with unknown argument raises ValueError.""" + with pytest.raises(ValueError, match="Unknown argument"): + frontend_core.parse_list_targets_arguments(args_string="--unknown-flag") + + @pytest.mark.asyncio @pytest.mark.usefixtures("patch_central_database") class TestRunScenarioAsync: @@ -1141,3 +1181,31 @@ async def test_print_targets_list_empty( captured = capsys.readouterr() assert "No targets found" in captured.out assert "--initializers target" in captured.out + + @patch("pyrit.cli.frontend_core.TargetRegistry") + @patch("pyrit.cli.frontend_core.initialize_pyrit_async", new_callable=AsyncMock) + async def test_list_targets_with_initialization_scripts_calls_initialize( + self, + mock_init: AsyncMock, + mock_target_registry_class: MagicMock, + ): + """Test list_targets_async calls initialize_pyrit_async when only scripts are configured.""" + mock_registry = MagicMock() + mock_registry.get_names.return_value = ["script_target"] + mock_target_registry_class.get_registry_singleton.return_value = mock_registry + + context = frontend_core.FrontendCore() + context._scenario_registry = MagicMock() + context._initializer_registry = MagicMock() + context._initialized = True + context._initialization_scripts = ["/path/to/script.py"] + context._initializer_configs = None + + result = await frontend_core.list_targets_async(context=context) + + assert result == ["script_target"] + # Verify initialize_pyrit_async was called with the scripts + mock_init.assert_called_once() + call_kwargs = mock_init.call_args[1] + assert call_kwargs["initialization_scripts"] == ["/path/to/script.py"] + assert call_kwargs["initializers"] is None diff --git a/tests/unit/cli/test_pyrit_scan.py b/tests/unit/cli/test_pyrit_scan.py index 34a8b8ad52..a4c3620ca7 100644 --- a/tests/unit/cli/test_pyrit_scan.py +++ b/tests/unit/cli/test_pyrit_scan.py @@ -214,6 +214,55 @@ def test_main_list_scenarios_with_missing_script(self, mock_resolve_scripts: Mag assert result == 1 + @patch("pyrit.cli.frontend_core.print_targets_list_async", new_callable=AsyncMock) + @patch("pyrit.cli.frontend_core.FrontendCore") + def test_main_list_targets_with_initializers( + self, + mock_frontend_core: MagicMock, + mock_print_targets: AsyncMock, + ): + """Test main with --list-targets and --initializers passes initializers to FrontendCore.""" + mock_print_targets.return_value = 0 + + result = pyrit_scan.main(["--list-targets", "--initializers", "target"]) + + assert result == 0 + mock_frontend_core.assert_called_once() + call_kwargs = mock_frontend_core.call_args[1] + assert call_kwargs["initializer_names"] == ["target"] + mock_print_targets.assert_called_once() + + @patch("pyrit.cli.frontend_core.print_targets_list_async", new_callable=AsyncMock) + @patch("pyrit.cli.frontend_core.resolve_initialization_scripts") + @patch("pyrit.cli.frontend_core.FrontendCore") + def test_main_list_targets_with_scripts( + self, + mock_frontend_core: MagicMock, + mock_resolve_scripts: MagicMock, + mock_print_targets: AsyncMock, + ): + """Test main with --list-targets and --initialization-scripts passes scripts to FrontendCore.""" + mock_resolve_scripts.return_value = [Path("/test/script.py")] + mock_print_targets.return_value = 0 + + result = pyrit_scan.main(["--list-targets", "--initialization-scripts", "script.py"]) + + assert result == 0 + mock_resolve_scripts.assert_called_once_with(script_paths=["script.py"]) + mock_frontend_core.assert_called_once() + call_kwargs = mock_frontend_core.call_args[1] + assert call_kwargs["initialization_scripts"] == [Path("/test/script.py")] + mock_print_targets.assert_called_once() + + @patch("pyrit.cli.frontend_core.resolve_initialization_scripts") + def test_main_list_targets_with_missing_script(self, mock_resolve_scripts: MagicMock): + """Test main with --list-targets and missing script file.""" + mock_resolve_scripts.side_effect = FileNotFoundError("Script not found") + + result = pyrit_scan.main(["--list-targets", "--initialization-scripts", "missing.py"]) + + assert result == 1 + def test_main_no_scenario_specified(self, capsys): """Test main without scenario name.""" result = pyrit_scan.main([]) diff --git a/tests/unit/cli/test_pyrit_shell.py b/tests/unit/cli/test_pyrit_shell.py index 89a218644d..eca05ec159 100644 --- a/tests/unit/cli/test_pyrit_shell.py +++ b/tests/unit/cli/test_pyrit_shell.py @@ -151,6 +151,15 @@ def test_do_list_scenarios_with_exception(self, mock_print_scenarios: AsyncMock, captured = capsys.readouterr() assert "Error listing scenarios" in captured.out + def test_do_list_scenarios_rejects_args(self, shell, capsys): + """Test do_list_scenarios rejects unexpected arguments.""" + s, ctx, _ = shell + + s.do_list_scenarios("--unknown foo") + + captured = capsys.readouterr() + assert "does not accept arguments" in captured.out + @patch("pyrit.cli.frontend_core.print_initializers_list_async", new_callable=AsyncMock) def test_do_list_initializers(self, mock_print_initializers: AsyncMock, shell): """Test do_list_initializers command.""" @@ -171,6 +180,73 @@ def test_do_list_initializers_with_exception(self, mock_print_initializers: Asyn captured = capsys.readouterr() assert "Error listing initializers" in captured.out + def test_do_list_initializers_rejects_args(self, shell, capsys): + """Test do_list_initializers rejects unexpected arguments.""" + s, ctx, _ = shell + + s.do_list_initializers("--unknown foo") + + captured = capsys.readouterr() + assert "does not accept arguments" in captured.out + + @patch("pyrit.cli.frontend_core.print_targets_list_async", new_callable=AsyncMock) + def test_do_list_targets_no_args(self, mock_print_targets: AsyncMock, shell): + """Test do_list_targets with no arguments uses the default context.""" + s, ctx, _ = shell + + s.do_list_targets("") + + mock_print_targets.assert_called_once_with(context=ctx) + + @patch("pyrit.cli.frontend_core.print_targets_list_async", new_callable=AsyncMock) + @patch("pyrit.cli.frontend_core.FrontendCore") + @patch("pyrit.cli.frontend_core.parse_list_targets_arguments") + def test_do_list_targets_with_initializers( + self, + mock_parse: MagicMock, + mock_fc_class: MagicMock, + mock_print_targets: AsyncMock, + shell, + ): + """Test do_list_targets with --initializers creates a new context.""" + s, ctx, _ = shell + mock_parse.return_value = {"initializers": ["target"], "initialization_scripts": None} + mock_run_context = MagicMock() + mock_fc_class.return_value = mock_run_context + + s.do_list_targets("--initializers target") + + mock_parse.assert_called_once_with(args_string="--initializers target") + mock_fc_class.assert_called_once_with( + initialization_scripts=None, + initializer_names=["target"], + log_level=s.default_log_level, + ) + assert mock_run_context._scenario_registry == ctx._scenario_registry + assert mock_run_context._initializer_registry == ctx._initializer_registry + assert mock_run_context._initialized is True + mock_print_targets.assert_called_once_with(context=mock_run_context) + + @patch("pyrit.cli.frontend_core.print_targets_list_async", new_callable=AsyncMock) + def test_do_list_targets_with_exception(self, mock_print_targets: AsyncMock, shell, capsys): + """Test do_list_targets handles exceptions.""" + s, ctx, _ = shell + mock_print_targets.side_effect = RuntimeError("Test error") + + s.do_list_targets("") + + captured = capsys.readouterr() + assert "Error listing targets" in captured.out + + def test_do_list_targets_parse_error(self, shell, capsys): + """Test do_list_targets shows error for invalid args.""" + s, ctx, _ = shell + + s.do_list_targets("--unknown-flag") + + captured = capsys.readouterr() + assert "Error" in captured.out + def test_do_run_empty_line(self, shell, capsys): """Test do_run with empty line.""" s, ctx, _ = shell @@ -380,6 +456,15 @@ def test_do_scenario_history_empty(self, shell, capsys): captured = capsys.readouterr() assert "No scenario runs in history" in captured.out + def test_do_scenario_history_rejects_args(self, shell, capsys): + """Test do_scenario_history rejects unexpected arguments.""" + s, ctx, _ = shell + + s.do_scenario_history("--unknown foo") + + captured = capsys.readouterr() + assert "does not accept arguments" in captured.out + def test_do_scenario_history_with_runs(self, shell, capsys): """Test do_scenario_history with scenario runs.""" s, ctx, _ = shell @@ -502,6 +587,14 @@ def test_do_help_with_arg(self, shell): s.do_help("run") mock_parent_help.assert_called_with("run") + def test_do_help_with_hyphenated_arg(self, shell): + """Test do_help converts hyphens to underscores for command lookup.""" + s, ctx, _ = shell + + with patch("cmd.Cmd.do_help") as mock_parent_help: + s.do_help("list-targets") + mock_parent_help.assert_called_with("list_targets") + @patch.object(cmd.Cmd, "cmdloop") @patch.object(banner, "play_animation") def test_cmdloop_sets_intro_via_play_animation(self, mock_play: MagicMock, mock_cmdloop: MagicMock, shell):