-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathrouter.py
More file actions
54 lines (47 loc) · 1.28 KB
/
router.py
File metadata and controls
54 lines (47 loc) · 1.28 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
import utils
from omegaconf import OmegaConf
from router import coreset_similarity_test, semantic_similarity_test
# Default target classes passed to the routing tests. The names match the 15
# categories of the MVTec AD anomaly-detection benchmark.
DEFAULT_CLASSES = (
    "bottle",
    "cable",
    "capsule",
    "carpet",
    "grid",
    "hazelnut",
    "leather",
    "metal_nut",
    "pill",
    "screw",
    "tile",
    "toothbrush",
    "transistor",
    "wood",
    "zipper",
)

# Router names accepted in ``cfg.ROUTER.name``.
_SUPPORTED_ROUTERS = ('coreset_similarity', 'semantic_similarity')


def run(cfg, classes=None):
    """Run the routing test selected by ``cfg.ROUTER.name``.

    Args:
        cfg: Configuration object (an OmegaConf config when launched from the
            command line). Reads ``cfg.ROUTER.name``, ``cfg.DEFAULT.device_ids``
            and, for the ``coreset_similarity`` router, ``cfg.DEFAULT.coreset_dir``.
        classes: Optional iterable of class names to test. Defaults to
            ``DEFAULT_CLASSES``.

    Raises:
        ValueError: If ``cfg.ROUTER.name`` is not one of the supported routers,
            or if the ``coreset_similarity`` router is selected without a
            ``coreset_dir`` in ``cfg.DEFAULT``.
    """
    router_name = cfg.ROUTER.name

    # Validate before touching the device so a bad config fails fast. An
    # unknown name used to fall through the if/elif and silently do nothing.
    if router_name not in _SUPPORTED_ROUTERS:
        raise ValueError(
            f"Unknown router {router_name!r}; expected one of {_SUPPORTED_ROUTERS}."
        )
    # Explicit raise instead of ``assert`` so the check survives ``python -O``.
    if router_name == 'coreset_similarity' and not cfg.DEFAULT.get('coreset_dir'):
        raise ValueError('The saved coreset filepath is needed.')

    # target classes (a fresh list, as the original passed, so callees may mutate it)
    class_list = list(DEFAULT_CLASSES if classes is None else classes)

    # set device
    device = utils.set_torch_device(gpu_ids=cfg.DEFAULT.device_ids)

    # routing test
    if router_name == 'coreset_similarity':
        coreset_similarity_test(
            cfg=cfg,
            classes=class_list,
            device=device,
        )
    else:  # 'semantic_similarity' (guaranteed by the validation above)
        semantic_similarity_test(
            cfg=cfg,
            classes=class_list,
            device=device,
        )
if __name__ == '__main__':
    import sys

    # Parse ``key=value`` command-line arguments, e.g.
    #   config=path/to/cfg.yaml ROUTER.name=coreset_similarity
    args = OmegaConf.from_cli()
    # ``config=`` is mandatory; without it ``args.config`` fails with an opaque
    # OmegaConf/loader error, so report proper usage instead.
    if 'config' not in args:
        sys.exit(f'usage: {sys.argv[0]} config=<path/to/config.yaml> [KEY=VALUE ...]')
    # load default config from the file named by ``config=...``
    cfg = OmegaConf.load(args.config)
    # ``config`` only names the base file; drop it so it is not merged in as a key.
    del args['config']
    # merge config with new keys: the CLI values (merged last) override file values
    cfg = OmegaConf.merge(cfg, args)
    print(OmegaConf.to_yaml(cfg))
    run(cfg)