-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path: main.py
More file actions
106 lines (97 loc) · 4.03 KB
/
main.py
File metadata and controls
106 lines (97 loc) · 4.03 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
import gc
import os
import platform
import socket
from pathlib import Path

import timm
import torch
import torchvision
import wandb

from config import args
# Data-pipeline helpers come from different modules depending on run mode:
# training pulls them from `ava`, inference from `inference`.  Both expose the
# same names (get_all, data_samplers, data_transforms, ava_data_reflect).
if not args.inference:
    from ava import get_all, data_samplers, data_transforms, ava_data_reflect
else:
    from inference import get_all, data_samplers, data_transforms, ava_data_reflect
from download import get_dataset
if args.inference:
    print('evaluating model')
    # deep_eval is only needed (and only imported) when evaluating.
    from inference import deep_eval
else:
    # Training path additionally needs the ava module itself (ava.loader below).
    import ava
if __name__ == '__main__':
    os.environ['WANDB_HOST'] = 'https://api.wandb.ai/'
    # Prefer GPU when available; `device` is used for model placement below.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Optionally fetch and/or unpack the dataset archive before anything else.
    if args.download:
        # Explicit validation instead of `assert` (asserts vanish under -O).
        if args.dataset not in ('ava',):
            raise ValueError(f'unsupported dataset: {args.dataset!r}')
        get_data = get_dataset(
            fid='crushed.zip', url='http://desigley.space/ava/crushed.zip')
        get_data.get_zip()
        get_data.unzip(out_dir='images')
    elif args.unzip_only:
        get_data = get_dataset(fid='../crushed.zip', url=None)
        get_data.unzip(out_dir='../images')

    # Load dataframe, label mapping and split data for the selected subset.
    df, y_g_dict, data, neg, pos = get_all(subset=args.subset)
    reflect_transforms = data_transforms(size=224)

    if args.inference:
        wandb.login(host='https://api.wandb.ai/')
        print(f'running in inference mode{"8"*30}')
        # Tags record where the run executed (OS, release, host).
        # Bug fix: original listed platform.system() twice.
        wb_tags = ['inference', platform.system(), platform.release(),
                   socket.gethostname(), platform.node()]
        if args.entity and args.project and args.tags:
            run = wandb.init(entity=args.entity, project=args.project,
                             tags=args.tags)
        elif args.entity and args.project:
            run = wandb.init(entity=args.entity, project=args.project,
                             group=platform.node())
        elif args.d.exists():
            # Fall back to a 'key=value' defaults file, applying each entry
            # onto `args` before initialising the run.
            with args.d.open('r') as hndl:
                for line in hndl:
                    parts = line.split('=')
                    if len(parts) == 2:
                        arg_, param = parts
                        # Bug fix: original did `args.arg_ = param`, which set a
                        # literal attribute named 'arg_' rather than the parsed key.
                        setattr(args, arg_.strip(), param.strip())
            run = wandb.init(entity=args.entity, project=args.project,
                             tags=wb_tags)
        else:
            run = wandb.init(group=platform.node(), tags=wb_tags)

        # ResNet-18 backbone with a fresh 2-class head, restored from checkpoint.
        model = torchvision.models.resnet18(pretrained=True)
        model.fc = torch.nn.Linear(512, 2)
        # NOTE(review): checkpoint is remapped to args.device while the model is
        # moved to the locally detected `device` — confirm these agree in config.
        loaded = torch.load('models/resnet_18',
                            map_location=torch.device(args.device))
        model.load_state_dict(loaded['model_state_dict'])
        model.to(device)
        data_load_dict = data_samplers(data, ava_data_reflect, reflect_transforms,
                                       batch_size=args.batch_size)
        evaluation = deep_eval(model, run, data_load_dict=data_load_dict)
        print('logging wandb table')
        run.finish()
    else:
        # Training path.
        # Bug fix: original forced device = 'cuda' here unconditionally, which
        # would fail on CPU-only hosts; keep the detected device instead.
        data_load_dict = data_samplers(data, ava_data_reflect, batch_size=128)
        torch.clear_autocast_cache()
        # Throwaway allocation to initialise the device context, then free it.
        model = torch.nn.Conv2d(2, 64, kernel_size=(3, 3), stride=(1, 1),
                                padding=(1, 1))
        model.to(device)
        del model
        gc.collect()
        if torch.cuda.is_available():  # guard: empty_cache needs CUDA
            torch.cuda.empty_cache()

        did = Path('drive/MyDrive/0.AVA/results/training_eq_conditions/')
        # exist_ok/parents also covers a missing 'results' parent directory.
        did.mkdir(parents=True, exist_ok=True)
        res_did = Path('drive/MyDrive/0.AVA/results/')
        # Names of models already trained (one sub-directory per model run).
        trained = [sub_dir.name for run_dir in did.iterdir()
                   for run_dir_sub in (run_dir,)
                   for sub_dir in run_dir_sub.iterdir()]

        # Candidate architectures: pretrained mobilevit variants, excluding the
        # tensorflow-ported ('tf') and re-weighted ('rw') checkpoints.
        avail_pretrained_models = timm.list_models(pretrained=True)
        nets = [net for net in avail_pretrained_models
                if 'mobilevit' in net and 'tf' not in net and 'rw' not in net]
        # Train only the last candidate; set to the full list to train the
        # whole stack of models under equal conditions.
        nets = [nets[-1]]
        still_to_train = {net: {'location': did / net, 'epochs': 10}
                          for net in nets}
        mods = ava.loader(still_to_train)