Skip to content

Commit e2d0876

Browse files
authored
Merge pull request #315 from MachineLearningLifeScience/feature/strict-typing
Adds strict typing using pyright, downgrades pytdc
2 parents 6897f89 + 99ca721 commit e2d0876

113 files changed

Lines changed: 1150 additions & 946 deletions

File tree

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

.pre-commit-config.yaml

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
repos:
22
- repo: https://github.com/pre-commit/pre-commit-hooks
3-
rev: v3.2.0
3+
rev: v5.0.0
44
hooks:
55
- id: trailing-whitespace
66
exclude: '.*\.pdb$'
@@ -20,4 +20,12 @@ repos:
2020
hooks:
2121
# Run the linter.
2222
- id: ruff
23-
args: [ --fix ]
23+
args: [ --fix ]
24+
- repo: local
25+
hooks:
26+
- id: pyright
27+
name: pyright
28+
entry: pyright
29+
language: system
30+
require_serial: true
31+
types: [python]

pyproject.toml

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ ehrlichholo = [
5353
"pytorch-holo",
5454
]
5555
tdc = [
56-
"pytdc",
56+
"pytdc==1.1.14",
5757
]
5858
dockstring = [
5959
"dockstring"
@@ -65,7 +65,7 @@ rosetta_energy = [
6565
"biopython",
6666
"pyrosetta-installer",
6767
]
68-
dev = ["black", "tox", "pytest", "bump-my-version"]
68+
dev = ["black", "tox", "pytest", "bump-my-version", "pre-commit", "pyright"]
6969
docs = ["sphinx", "furo"]
7070

7171
[project.urls]
@@ -83,9 +83,15 @@ markers = [
8383
"poli__rosetta_energy: marks tests that run in poli__rosetta_energy",
8484
"poli__ehrlich_holo: marks tests that run in poli__ehrlich_holo environment",
8585
"poli__dms: marks tests that run in poli__dms environment",
86+
"isolation: marks tests that require isolation of the black box function",
8687
"unmarked: All other tests, which usually run in the base environment",
8788
]
8889

90+
[tool.pyright]
91+
include = ["src/poli"]
92+
exclude = ["src/poli/core/util/proteins/rasp/inner_rasp", "src/poli/objective_repository/gfp_cbas", "examples", "src/poli/tests"]
93+
reportIncompatibleMethodOverride = "none"
94+
8995
[tool.isort]
9096
profile = "black"
9197

@@ -148,4 +154,7 @@ replace = 'version: {new_version}'
148154
[dependency-groups]
149155
dev = [
150156
"pre-commit>=4.2.0",
157+
"pyright>=1.1.403",
158+
"pytest>=8.4.0",
159+
"ruff>=0.12.3",
151160
]

src/poli/benchmarks/guacamol.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,7 @@ def _initialize_problem(self, index: int) -> Problem:
125125
problem_factory = self.problem_factories[index]
126126

127127
problem = problem_factory.create(
128-
string_representation=self.string_representation,
128+
string_representation=self.string_representation, # type: ignore
129129
seed=self.seed,
130130
batch_size=self.batch_size,
131131
parallelize=self.parallelize,

src/poli/benchmarks/pmo.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ def __init__(
7575
batch_size: Union[int, None] = None,
7676
parallelize: bool = False,
7777
num_workers: Union[int, None] = None,
78-
evaluation_budget: int = None,
78+
evaluation_budget: int | None = None,
7979
) -> None:
8080
super().__init__(
8181
string_representation=string_representation,

src/poli/benchmarks/toy_continuous_functions_benchmark.py

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,14 @@
1515
https://www.sfu.ca/~ssurjano/optimization.html.
1616
"""
1717

18-
from typing import List, Union
18+
from typing import Sequence, Union, cast
1919

2020
from poli.core.abstract_benchmark import AbstractBenchmark
2121
from poli.core.problem import Problem
2222
from poli.objective_repository import ToyContinuousProblemFactory
2323
from poli.objective_repository.toy_continuous_problem.toy_continuous_problem import (
2424
POSSIBLE_FUNCTIONS,
25+
POSSIBLE_FUNCTIONS_TYPE,
2526
SIX_DIMENSIONAL_PROBLEMS,
2627
TWO_DIMENSIONAL_PROBLEMS,
2728
)
@@ -49,12 +50,12 @@ def __init__(
4950
self,
5051
n_dimensions: int = 2,
5152
embed_in: Union[int, None] = None,
52-
dimensions_to_embed_in: Union[List[int], None] = None,
53+
dimensions_to_embed_in: Union[list[int], None] = None,
5354
seed: Union[int, None] = None,
5455
batch_size: Union[int, None] = None,
5556
parallelize: bool = False,
5657
num_workers: Union[int, None] = None,
57-
evaluation_budget: Union[int, List[int]] = None,
58+
evaluation_budget: int | None = None,
5859
) -> None:
5960
super().__init__(
6061
seed=seed,
@@ -66,7 +67,7 @@ def __init__(
6667
self.n_dimensions = n_dimensions
6768
self.embed_in = embed_in
6869
self.dimensions_to_embed_in = dimensions_to_embed_in
69-
self.function_names = list(
70+
self.function_names: Sequence[POSSIBLE_FUNCTIONS_TYPE] = list( # type: ignore
7071
(
7172
set(POSSIBLE_FUNCTIONS)
7273
- set(TWO_DIMENSIONAL_PROBLEMS)
@@ -78,7 +79,9 @@ def __init__(
7879
)
7980

8081
def _initialize_problem(self, index: int) -> Problem:
81-
problem_factory: ToyContinuousProblemFactory = self.problem_factories[index]
82+
problem_factory: ToyContinuousProblemFactory = cast(
83+
ToyContinuousProblemFactory, self.problem_factories[index]
84+
)
8285

8386
problem = problem_factory.create(
8487
function_name=self.function_names[index],
@@ -121,7 +124,7 @@ def __init__(
121124
batch_size: Union[int, None] = None,
122125
parallelize: bool = False,
123126
num_workers: Union[int, None] = None,
124-
evaluation_budget: Union[int, List[int]] = None,
127+
evaluation_budget: int | None = None,
125128
) -> None:
126129
super().__init__(
127130
seed,
@@ -134,7 +137,9 @@ def __init__(
134137
self.problem_factories = [ToyContinuousProblemFactory()] * len(self.embed_in)
135138

136139
def _initialize_problem(self, index: int) -> Problem:
137-
problem_factory: ToyContinuousProblemFactory = self.problem_factories[index]
140+
problem_factory: ToyContinuousProblemFactory = cast(
141+
ToyContinuousProblemFactory, self.problem_factories[index]
142+
)
138143

139144
problem = problem_factory.create(
140145
function_name="branin_2d",
@@ -174,7 +179,7 @@ def __init__(
174179
batch_size: Union[int, None] = None,
175180
parallelize: bool = False,
176181
num_workers: Union[int, None] = None,
177-
evaluation_budget: Union[int, List[int]] = None,
182+
evaluation_budget: int | None = None,
178183
) -> None:
179184
super().__init__(
180185
seed,
@@ -187,7 +192,9 @@ def __init__(
187192
self.problem_factories = [ToyContinuousProblemFactory()] * len(self.embed_in)
188193

189194
def _initialize_problem(self, index: int) -> Problem:
190-
problem_factory: ToyContinuousProblemFactory = self.problem_factories[index]
195+
problem_factory: ToyContinuousProblemFactory = cast(
196+
ToyContinuousProblemFactory, self.problem_factories[index]
197+
)
191198

192199
if index == 0:
193200
problem = problem_factory.create(

src/poli/core/abstract_benchmark.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
11
from __future__ import annotations
22

3-
from typing import List, Union
3+
from typing import Union
44

55
from poli.core.abstract_problem_factory import AbstractProblemFactory
66
from poli.core.problem import Problem
77

88

99
class AbstractBenchmark:
10-
problem_factories: List[AbstractProblemFactory]
10+
problem_factories: list[AbstractProblemFactory]
1111
index: int = 0
1212

1313
def __init__(
@@ -16,7 +16,7 @@ def __init__(
1616
batch_size: Union[int, None] = None,
1717
parallelize: bool = False,
1818
num_workers: Union[int, None] = None,
19-
evaluation_budget: int = None,
19+
evaluation_budget: int | None = None,
2020
) -> None:
2121
self.seed = seed
2222
self.batch_size = batch_size
@@ -46,7 +46,7 @@ def info(self) -> str:
4646
raise NotImplementedError
4747

4848
@property
49-
def problem_names(self) -> List[str]:
49+
def problem_names(self) -> list[str]:
5050
return [
5151
problem_factory.__module__.replace(
5252
"poli.objective_repository.", ""

src/poli/core/abstract_black_box.py

Lines changed: 17 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,11 @@
55
from __future__ import annotations
66

77
from multiprocessing import Pool, cpu_count
8+
from typing import cast
89
from warnings import warn
910

1011
import numpy as np
12+
from numpy.typing import NDArray
1113

1214
from poli.core.black_box_information import BlackBoxInformation
1315
from poli.core.exceptions import BudgetExhaustedException
@@ -22,17 +24,17 @@ class AbstractBlackBox:
2224
2325
Parameters
2426
----------
25-
batch_size : int, optional
27+
batch_size : int | None, optional
2628
The batch size for evaluating the black box function. Default is None.
2729
parallelize : bool, optional
2830
Flag indicating whether to evaluate the black box function in parallel.
2931
Default is False.
30-
num_workers : int, optional
32+
num_workers : int | None, optional
3133
The number of workers to use for parallel evaluation. Default is None,
3234
which uses half of the available CPU cores.
33-
evaluation_budget : int, optional
35+
evaluation_budget : int | None, optional
3436
The maximum number of evaluations allowed for the black box function.
35-
Default is None).
37+
Default is None, which means an infinite budget.
3638
3739
Attributes
3840
----------
@@ -44,7 +46,7 @@ class AbstractBlackBox:
4446
Flag indicating whether to evaluate the black box function in parallel.
4547
num_workers : int
4648
The number of workers to use for parallel evaluation.
47-
batch_size : int or None
49+
batch_size : int | None
4850
The batch size for evaluating the black box function.
4951
5052
Methods
@@ -84,13 +86,13 @@ def __init__(
8486
8587
Parameters
8688
----------
87-
batch_size : int, optional
89+
batch_size : int | None, optional
8890
The batch size for parallel execution, by default None.
8991
parallelize : bool, optional
9092
Flag indicating whether to parallelize the execution, by default False.
91-
num_workers : int, optional
93+
num_workers : int | None, optional
9294
The number of workers for parallel execution, by default we use half the available CPUs.
93-
evaluation_budget : int, optional
95+
evaluation_budget : int | None, optional
9496
The maximum number of evaluations allowed for the black box function, by default it is None, which means no limit.
9597
"""
9698
self.observer = None
@@ -145,13 +147,13 @@ def set_observer(self, observer: AbstractObserver):
145147
)
146148
self.observer = observer
147149

148-
def set_observer_info(self, observer_info: object):
150+
def set_observer_info(self, observer_info: dict[str, object] | None):
149151
"""
150152
Set the observer information after initialization.
151153
152154
Parameters
153155
----------
154-
observer_info : object
156+
observer_info : dict[str, object]
155157
The information given by the observer after initialization.
156158
"""
157159
self.observer_info = observer_info
@@ -160,7 +162,7 @@ def reset_evaluation_budget(self):
160162
"""Resets the evaluation budget by setting the number of evaluations made to 0."""
161163
self.num_evaluations = 0
162164

163-
def __call__(self, x: np.array, context=None):
165+
def __call__(self, x: NDArray[np.str_], context=None):
164166
"""Calls the black box function.
165167
166168
The purpose of this function is to enforce that inputs are equal across
@@ -340,7 +342,7 @@ def terminate(self) -> None:
340342
Terminate the black box optimization problem.
341343
"""
342344
if hasattr(self, "inner_function"):
343-
self.inner_function.terminate()
345+
self.inner_function.terminate() # type: ignore
344346
# if self.observer is not None:
345347
# # NOTE: terminating a problem should gracefully end the observer process -> write the last state.
346348
# self.observer.finish()
@@ -387,13 +389,13 @@ def __init__(self, f: AbstractBlackBox):
387389
batch_size=f.batch_size,
388390
parallelize=f.parallelize,
389391
num_workers=f.num_workers,
390-
evaluation_budget=f.evaluation_budget,
392+
evaluation_budget=cast(int | None, f.evaluation_budget),
391393
)
392394

393-
def __call__(self, x, context=None):
395+
def __call__(self, x: NDArray[np.str_], context=None):
394396
return -self.f.__call__(x, context)
395397

396-
def _black_box(self, x, context=None):
398+
def _black_box(self, x: NDArray[np.str_], context=None):
397399
return self.f._black_box(x, context)
398400

399401
def __str__(self) -> str:

src/poli/core/abstract_problem_factory.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -30,11 +30,11 @@ class AbstractProblemFactory(metaclass=MetaProblemFactory):
3030

3131
def create(
3232
self,
33-
seed: int = None,
34-
batch_size: int = None,
33+
seed: int | None = None,
34+
batch_size: int | None = None,
3535
parallelize: bool = False,
36-
num_workers: int = None,
37-
evaluation_budget: int = None,
36+
num_workers: int | None = None,
37+
evaluation_budget: int | None = None,
3838
force_isolation: bool = False,
3939
) -> Problem:
4040
"""

src/poli/core/benchmark_information.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ def __init__(
1010
fixed_length: bool,
1111
deterministic: bool,
1212
alphabet: list,
13-
log_transform_recommended: bool = None,
13+
log_transform_recommended: bool | None = None,
1414
discrete: bool = True,
1515
fidelity: Union[Literal["high", "low"], None] = None,
1616
padding_token: str = "",
@@ -111,7 +111,7 @@ def get_alphabet(self) -> list:
111111
"""
112112
return self.alphabet
113113

114-
def log_transform_recommended(self) -> bool:
114+
def is_log_transform_recommended(self) -> bool | None:
115115
"""
116116
Returns whether the black-box recommends log-transforming the targets.
117117

0 commit comments

Comments
 (0)