55from __future__ import annotations
66
77from multiprocessing import Pool , cpu_count
8+ from typing import cast
89from warnings import warn
910
1011import numpy as np
12+ from numpy .typing import NDArray
1113
1214from poli .core .black_box_information import BlackBoxInformation
1315from poli .core .exceptions import BudgetExhaustedException
@@ -22,17 +24,17 @@ class AbstractBlackBox:
2224
2325 Parameters
2426 ----------
25- batch_size : int, optional
27+ batch_size : int | None , optional
2628 The batch size for evaluating the black box function. Default is None.
2729 parallelize : bool, optional
2830 Flag indicating whether to evaluate the black box function in parallel.
2931 Default is False.
30- num_workers : int, optional
32+ num_workers : int | None , optional
3133 The number of workers to use for parallel evaluation. Default is None,
3234 which uses half of the available CPU cores.
33- evaluation_budget : int, optional
35+ evaluation_budget : int | None , optional
3436 The maximum number of evaluations allowed for the black box function.
35- Default is None) .
37+ Default is None, which means an infinite budget .
3638
3739 Attributes
3840 ----------
@@ -44,7 +46,7 @@ class AbstractBlackBox:
4446 Flag indicating whether to evaluate the black box function in parallel.
4547 num_workers : int
4648 The number of workers to use for parallel evaluation.
47- batch_size : int or None
49+ batch_size : int | None
4850 The batch size for evaluating the black box function.
4951
5052 Methods
@@ -84,13 +86,13 @@ def __init__(
8486
8587 Parameters
8688 ----------
87- batch_size : int, optional
89+ batch_size : int | None , optional
8890 The batch size for parallel execution, by default None.
8991 parallelize : bool, optional
9092 Flag indicating whether to parallelize the execution, by default False.
91- num_workers : int, optional
93+ num_workers : int | None , optional
9294 The number of workers for parallel execution, by default we use half the available CPUs.
93- evaluation_budget : int, optional
95+ evaluation_budget : int | None , optional
9496 The maximum number of evaluations allowed for the black box function, by default it is None, which means no limit.
9597 """
9698 self .observer = None
@@ -145,13 +147,13 @@ def set_observer(self, observer: AbstractObserver):
145147 )
146148 self .observer = observer
147149
148- def set_observer_info (self , observer_info : object ):
150+ def set_observer_info (self , observer_info : dict [ str , object ] | None ):
149151 """
150152 Set the observer information after initialization.
151153
152154 Parameters
153155 ----------
154- observer_info : object
156+ observer_info : dict[str, object]
155157 The information given by the observer after initialization.
156158 """
157159 self .observer_info = observer_info
@@ -160,7 +162,7 @@ def reset_evaluation_budget(self):
160162 """Resets the evaluation budget by setting the number of evaluations made to 0."""
161163 self .num_evaluations = 0
162164
163- def __call__ (self , x : np .array , context = None ):
165+ def __call__ (self , x : NDArray [ np .str_ ] , context = None ):
164166 """Calls the black box function.
165167
166168 The purpose of this function is to enforce that inputs are equal across
@@ -340,7 +342,7 @@ def terminate(self) -> None:
340342 Terminate the black box optimization problem.
341343 """
342344 if hasattr (self , "inner_function" ):
343- self .inner_function .terminate ()
345+ self .inner_function .terminate () # type: ignore
344346 # if self.observer is not None:
345347 # # NOTE: terminating a problem should gracefully end the observer process -> write the last state.
346348 # self.observer.finish()
@@ -387,13 +389,13 @@ def __init__(self, f: AbstractBlackBox):
387389 batch_size = f .batch_size ,
388390 parallelize = f .parallelize ,
389391 num_workers = f .num_workers ,
390- evaluation_budget = f .evaluation_budget ,
392+ evaluation_budget = cast ( int | None , f .evaluation_budget ) ,
391393 )
392394
393- def __call__ (self , x , context = None ):
395+ def __call__ (self , x : NDArray [ np . str_ ] , context = None ):
394396 return - self .f .__call__ (x , context )
395397
396- def _black_box (self , x , context = None ):
398+ def _black_box (self , x : NDArray [ np . str_ ] , context = None ):
397399 return self .f ._black_box (x , context )
398400
399401 def __str__ (self ) -> str :
0 commit comments