-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathrun_heatmap.py
More file actions
498 lines (430 loc) · 18.6 KB
/
run_heatmap.py
File metadata and controls
498 lines (430 loc) · 18.6 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
import argparse
import os
import random
import json
from pathlib import Path
from types import SimpleNamespace
import numpy as np
import torch
import torch.nn.functional as F
import torch.multiprocessing as mp
from PIL import Image
from tqdm import tqdm
from transformers.utils import is_flash_attn_2_available
from models.Qwen2_5_VL.Qwen2_5_VL_hf import Qwen2_5_VL
QUESTION = "Is there any anomaly in the image?\nAnswer the question using a single word or phrase."
IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png", ".bmp", ".tif", ".tiff")
MMXU_DEFAULT_IMAGE_ROOT = "path/to/physionet.org/files/mimic-cxr-jpg/2.1.0/"
# ------------------------------
# Utils: robust probability map
# ------------------------------
def to_prob_map_from_tensor(t4: torch.Tensor, target_hw):
    """Convert a preview tensor into an [H, W] map with values in [0, 1].

    Returns (map, mode) where mode is one of:
      - "prob":    one channel; treated as logits -> sigmoid, or kept as-is
                   when the values already lie in [0, 1]
      - "softmax": two channels; (bg, fg) logits -> softmax foreground prob
      - "feature": more channels; ReLU, channel-max, min-max rescale
                   (visualization only)
    """
    if t4.dim() == 3:
        t4 = t4.unsqueeze(0)  # promote [C,h,w] to [1,C,h,w]
    assert t4.dim() == 4, f"expect [N,C,h,w] or [C,h,w], got {t4.shape}"
    t4 = t4[:1]  # only the first batch item matters
    num_channels = t4.shape[1]
    if t4.shape[-2:] != tuple(target_hw):
        t4 = F.interpolate(t4, size=target_hw, mode="bilinear", align_corners=False)
    if num_channels == 1:
        raw = t4[0, 0]
        already_unit = raw.max() <= 1.0 and raw.min() >= 0.0
        # Values already in [0, 1] are taken as probabilities; otherwise
        # the channel is interpreted as logits.
        m = raw if already_unit else torch.sigmoid(raw)
        mode = "prob"
    elif num_channels == 2:
        m = torch.softmax(t4[0, :2], dim=0)[1]  # foreground probability
        mode = "softmax"
    else:
        activated = F.relu(t4[0]).amax(dim=0)  # [H, W]
        lo, hi = activated.min(), activated.max()
        m = (activated - lo) / (hi - lo + 1e-6)  # min-max, guarded for flat maps
        mode = "feature"
    return m.clamp(0.0, 1.0), mode
def shard_samples(samples, rank: int, world_size: int):
    """Return the round-robin shard of *samples* owned by *rank*.

    With a single worker the original list is returned unchanged.
    """
    return samples if world_size <= 1 else samples[rank::world_size]
def aggregate_metrics(metrics_list: list[dict]) -> dict:
    """Merge per-rank metric dicts into a single summary for one dataset.

    Counts and sums are added across ranks; ratios (accuracy, precision,
    recall, F1, means) are recomputed from the merged totals.

    Raises:
        ValueError: if *metrics_list* is empty.
    """
    if not metrics_list:
        raise ValueError("metrics_list must not be empty")

    def count(key):
        # Required integer field; missing keys raise KeyError by design.
        return sum(int(m[key]) for m in metrics_list)

    def count_opt(key):
        # Optional integer field, defaulting to 0 per rank.
        return sum(int(m.get(key, 0)) for m in metrics_list)

    def accum(key):
        # Optional float accumulator, defaulting to 0.0 per rank.
        return sum(float(m.get(key, 0.0)) for m in metrics_list)

    def ratio(num, den):
        return num / den if den else 0.0

    samples = count("samples")
    tp, fp, fn, tn = count("tp"), count("fp"), count("fn"), count("tn")
    iou_sum, iou_count = accum("iou_sum"), count_opt("iou_samples")
    l2_sum, l2_count = accum("l2_sum"), count_opt("l2_samples")
    auc_sum, auc_count = accum("auc_sum"), count_opt("auc_samples")

    precision = ratio(tp, tp + fp)
    recall = ratio(tp, tp + fn)
    return {
        "dataset": metrics_list[0]["dataset"],
        "samples": samples,
        "samples_no_mask": count_opt("samples_no_mask"),
        "tp": tp,
        "fp": fp,
        "fn": fn,
        "tn": tn,
        "accuracy": ratio(tp + tn, samples),
        "precision": precision,
        "recall": recall,
        "f1": ratio(2 * precision * recall, precision + recall),
        "mean_iou": ratio(iou_sum, iou_count),
        "iou_samples": iou_count,
        "iou_sum": iou_sum,
        "mean_l2": ratio(l2_sum, l2_count),
        "l2_samples": l2_count,
        "l2_sum": l2_sum,
        "mean_auc": ratio(auc_sum, auc_count),
        "auc_samples": auc_count,
        "auc_sum": auc_sum,
    }
def format_dataset_metrics(metrics: dict) -> str:
    """Render one dataset's metrics as a single human-readable line."""
    pieces = [
        f"[{metrics['dataset']}]",
        f"IoU: {metrics['mean_iou']:.4f},",
        f"L2: {metrics['mean_l2']:.4f},",
        f"AUC: {metrics['mean_auc']:.4f}",
        f"(n={metrics['samples']} / {metrics['iou_samples']} IoU /",
        f"{metrics['l2_samples']} L2 / {metrics['auc_samples']} AUC)",
    ]
    return " ".join(pieces)
# ------------------------------
# Dataset helpers
# ------------------------------
def discover_class_dirs(root: str):
    """Return (class_dir, class_name) pairs for the dataset at *root*.

    First looks for 'good'/'ungood' subfolders that each contain an 'img'
    directory. If none exist, falls back to treating *root* itself as a
    class folder when it has an 'img' subfolder and its own basename is
    'good' or 'ungood'.
    """
    root = os.path.abspath(root)
    found = [
        (os.path.join(root, name), name)
        for name in ("good", "ungood")
        if os.path.isdir(os.path.join(root, name))
        and os.path.isdir(os.path.join(root, name, "img"))
    ]
    if not found and os.path.isdir(os.path.join(root, "img")):
        own_name = os.path.basename(root.rstrip(os.sep)).lower()
        if own_name in ("good", "ungood"):
            found.append((root, own_name))
    return found
def _is_mmxu_jsonl(root: str) -> bool:
return os.path.isfile(root) and root.lower().endswith(".jsonl")
def _gather_mmxu_samples(jsonl_path: str, image_root: str):
samples = []
with open(jsonl_path, "r") as f:
for idx, line in enumerate(f):
line = line.strip()
if not line:
continue
item = json.loads(line)
sample_id = item.get("id", idx)
cur_rel = item.get("cur_image_path")
prior_rel = item.get("prior_image_path")
if cur_rel:
samples.append((os.path.join(image_root, cur_rel), None, "mmxu", sample_id, "cur"))
if prior_rel:
samples.append((os.path.join(image_root, prior_rel), None, "mmxu", sample_id, "prior"))
return samples
def gather_samples(root: str, mmxu_image_root=None):
    """Collect sample tuples under the given dataset root.

    Directory roots yield (image_path, mask_path, class_name) triples;
    an MMXU .jsonl root instead yields five-tuples that also carry
    (sample_id, tag). Images without a matching mask file are skipped.
    """
    if _is_mmxu_jsonl(root):
        return _gather_mmxu_samples(root, mmxu_image_root or MMXU_DEFAULT_IMAGE_ROOT)
    collected = []
    for class_dir, class_name in discover_class_dirs(root):
        img_dir = os.path.join(class_dir, "img")
        mask_dir = os.path.join(class_dir, "label")
        if not (os.path.isdir(img_dir) and os.path.isdir(mask_dir)):
            continue  # class folder is missing one of its halves
        for filename in sorted(os.listdir(img_dir)):
            if not filename.lower().endswith(IMAGE_EXTENSIONS):
                continue
            mask_path = os.path.join(mask_dir, filename)
            if not os.path.exists(mask_path):
                continue  # only keep images with a same-named mask
            collected.append((os.path.join(img_dir, filename), mask_path, class_name))
    return collected
def resolve_dataset_name(root: str) -> str:
    """Derive a dataset display name from a root path or jsonl file.

    For jsonl roots the file stem is used. For directory roots whose leaf
    is a split name (train/test/val/validation), the parent directory name
    is preferred so the dataset, not the split, labels the results.
    """
    if _is_mmxu_jsonl(root):
        return Path(root).stem
    normalized = os.path.normpath(root)
    leaf = os.path.basename(normalized)
    if leaf.lower() in {"train", "test", "val", "validation"}:
        parent = os.path.basename(os.path.dirname(normalized))
        if parent:
            return parent
    return leaf
# Kept for potential future use; now we prefer to_prob_map_from_tensor
def prepare_prediction_mask(preview_tensor: torch.Tensor, target_hw):
    """Resize preview logits/probs to target spatial resolution (legacy path).

    Keeps only the first batch item and first channel, bilinearly resizing
    to *target_hw* when needed, and returns the resulting [H, W] tensor.

    Raises:
        ValueError: if the input is not 3- or 4-dimensional.
    """
    if preview_tensor.dim() == 3:
        preview_tensor = preview_tensor.unsqueeze(0)  # [C,h,w] -> [1,C,h,w]
    if preview_tensor.dim() != 4:
        raise ValueError(f"Unexpected preview tensor shape: {preview_tensor.shape}")
    single = preview_tensor[:1, :1]  # first batch item, first channel
    if single.shape[-2:] != target_hw:
        single = F.interpolate(
            single, size=target_hw, mode="bilinear", align_corners=False
        )
    return single[0, 0]
def compute_binary_auc(prob_map: torch.Tensor, target_mask: torch.Tensor):
    """Compute ROC AUC for a binary segmentation map.

    Returns None when the target contains only one class, since AUC is
    undefined without both positives and negatives.
    """
    scores = prob_map.reshape(-1).float()
    labels = target_mask.reshape(-1).float()
    n_pos = int(labels.sum().item())
    n_neg = labels.numel() - n_pos
    if n_pos == 0 or n_neg == 0:
        return None
    # Walk the ROC curve by sweeping the threshold from high to low scores.
    order = torch.argsort(scores, descending=True)
    ranked = labels[order]
    tpr = torch.cumsum(ranked, dim=0) / n_pos
    fpr = torch.cumsum(1.0 - ranked, dim=0) / n_neg
    # Trapezoidal area under TPR as a function of FPR.
    return torch.trapz(tpr, fpr).item()
# ------------------------------
# Core evaluation
# ------------------------------
def evaluate_dataset(model_wrapper, dataset_name, samples, seg_threshold, shuffle=True):
    """Run the model over *samples*, accumulating classification and segmentation metrics.

    Args:
        model_wrapper: object exposing ``generate_with_segment_preview(message, tune_mode)``
            and an ``llm`` attribute whose ``_last_segmentation_preview`` dict (with a
            "probs" tensor) is populated as a side effect of generation.
        dataset_name: label used for the progress bar and in the returned dict.
        samples: tuples of (image_path, mask_path, class_name[, sample_id, tag]).
            NOTE(review): shuffled IN PLACE when ``shuffle`` is True — callers keep
            the reordered list.
        seg_threshold: probability cutoff used to binarize the predicted map.
        shuffle: randomize evaluation order.

    Returns:
        dict of counts, sums, and per-dataset means in the shape expected by
        ``aggregate_metrics`` / ``format_dataset_metrics``.
    """
    if shuffle:
        random.shuffle(samples)
    tp = fp = fn = tn = 0
    # NOTE(review): the *_count variables start at 0.0 via the chained init, so
    # they accumulate as floats even though they are sample counts.
    iou_sum = iou_count = 0.0
    l2_sum = l2_count = 0.0
    auc_sum = auc_count = 0.0
    no_mask_samples = 0
    for sample in tqdm(samples, desc=f"{dataset_name}"):
        # MMXU tuples have 5 fields; only the first three are used here.
        if len(sample) >= 3:
            img_path, mask_path, class_name = sample[:3]
        else:
            continue
        try:
            image = Image.open(img_path).convert("RGB")
            # MMXU samples carry mask_path=None, so mask stays None for them.
            mask = Image.open(mask_path).convert("L") if mask_path else None
        except Exception as exc:
            # Best-effort: skip unreadable images/masks rather than abort the run.
            continue
        # Run model; make sure previous preview is cleared so a stale map from
        # the prior sample can never be mistaken for this sample's output.
        setattr(model_wrapper.llm, "_last_segmentation_preview", None)
        message = {"prompt": QUESTION, "image": image}
        response, _ = model_wrapper.generate_with_segment_preview(message, tune_mode="visualize")
        # Normalize yes/no textual answer (substring fallback for verbose replies).
        response_clean = response.strip().lower().strip(".")
        if response_clean not in {"yes", "no"}:
            if "yes" in response_clean:
                response_clean = "yes"
            elif "no" in response_clean:
                response_clean = "no"
        has_mask = mask is not None and class_name in {"good", "ungood"}
        if has_mask:
            # "ungood" folders hold anomalous images, so ground truth is "yes".
            gt = "yes" if class_name == "ungood" else "no"
            if response_clean == "yes":
                tp += (gt == "yes")
                fp += (gt == "no")
            elif response_clean == "no":
                fn += (gt == "yes")
                tn += (gt == "no")
            # NOTE(review): answers that are neither yes nor no are silently
            # excluded from the confusion matrix (and from "samples").
        # Get preview tensor from model (could be logits or features)
        preview = getattr(model_wrapper.llm, "_last_segmentation_preview", None)
        if preview is None or "probs" not in preview:
            continue
        prob_tensor = preview["probs"].detach().float()
        img_w, img_h = image.size
        if not has_mask:
            # Mask-free (e.g. MMXU) samples: only validate that a usable
            # probability map exists, then count and move on.
            pred_prob, prob_mode = to_prob_map_from_tensor(prob_tensor, (img_h, img_w))
            if prob_mode == "feature":
                continue
            no_mask_samples += 1
            continue
        mask_array = np.array(mask, dtype=np.uint8)
        gt_tensor = torch.from_numpy((mask_array > 0).astype(np.float32)) # [H,W]
        # Build probability map robustly
        pred_prob, prob_mode = to_prob_map_from_tensor(prob_tensor, gt_tensor.shape) # [H,W] in [0,1]
        if prob_mode == "feature":
            # Feature-like previews are visualization-only; skip their metrics.
            continue
        pred_binary = (pred_prob >= seg_threshold)
        gt_binary = (gt_tensor > 0)
        gt_float = gt_binary.float()
        # IoU (empty prediction on an empty mask counts as a perfect 1.0)
        intersection = torch.logical_and(pred_binary, gt_binary).sum().item()
        union = torch.logical_or(pred_binary, gt_binary).sum().item()
        iou = 1.0 if union == 0 else intersection / max(union, 1)
        iou_sum += iou; iou_count += 1
        # Mean L2 / AUC (use probabilities consistently)
        l2_value = F.mse_loss(pred_prob, gt_float)
        l2_sum += l2_value.item(); l2_count += 1
        auc_value = compute_binary_auc(pred_prob, gt_float)
        if auc_value is not None:
            auc_sum += auc_value; auc_count += 1
    # Aggregate metrics for this dataset
    total = tp + fp + fn + tn
    precision = tp / (tp + fp) if (tp + fp) else 0.0
    recall = tp / (tp + fn) if (tp + fn) else 0.0
    accuracy = (tp + tn) / total if total else 0.0
    f1 = 2 * precision * recall / (precision + recall) if (precision + recall) else 0.0
    mean_iou = iou_sum / iou_count if iou_count else 0.0
    mean_l2 = l2_sum / l2_count if l2_count else 0.0
    mean_auc = auc_sum / auc_count if auc_count else 0.0
    metrics = {
        "dataset": dataset_name,
        "samples": total,
        "samples_no_mask": no_mask_samples,
        "tp": tp, "fp": fp, "fn": fn, "tn": tn,
        "accuracy": accuracy, "precision": precision, "recall": recall, "f1": f1,
        "mean_iou": mean_iou, "iou_samples": iou_count, "iou_sum": iou_sum,
        "mean_l2": mean_l2, "l2_samples": l2_count, "l2_sum": l2_sum,
        "mean_auc": mean_auc, "auc_samples": auc_count, "auc_sum": auc_sum,
    }
    return metrics
def evaluate_worker(rank, world_size, args, model_path, model_args, dataset_jobs, results):
    """Per-process entry point for mp.spawn: evaluate this rank's shard of every dataset.

    Loads its own model instance pinned to ``cuda:{rank}`` and writes the list of
    per-dataset metric dicts into ``results[rank]`` (a Manager-backed shared list).
    """
    if torch.cuda.is_available():
        torch.cuda.set_device(rank)
    # Copy the shared namespace so per-rank tweaks don't leak across processes.
    worker_args = SimpleNamespace(**vars(model_args))
    worker_args.device_map = f"cuda:{rank}" if torch.cuda.is_available() else "auto"
    # NOTE(review): bf16 is forced here presumably because flash-attention
    # requires a half-precision dtype — confirm against Qwen2_5_VL's loader.
    if getattr(worker_args, "attn_implementation", "flash_attention_2") == "flash_attention_2":
        worker_args.torch_dtype = torch.bfloat16
    model_wrapper = Qwen2_5_VL(model_path, worker_args)
    worker_metrics = []
    for job in dataset_jobs:
        # Round-robin shard so every rank evaluates a disjoint subset.
        local_samples = shard_samples(job["samples"], rank, world_size)
        metrics = evaluate_dataset(
            model_wrapper,
            job["dataset_name"],
            local_samples,
            args.seg_threshold,
            shuffle=False,  # deterministic order keeps shards disjoint and reproducible
        )
        worker_metrics.append(metrics)
    results[rank] = worker_metrics
def main():
    """CLI entry point: evaluate anomaly detection/segmentation over dataset roots.

    Spawns one worker per GPU (via ``mp.spawn``) when more than one GPU is
    visible and merges the per-rank metrics; otherwise evaluates in-process.
    Finally prints per-dataset and globally aggregated results.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", type=str,
                        default="wooohyeooon/MEDIC-AD",
                        help="Path to the locally fine-tuned Qwen2.5-VL model.")
    parser.add_argument(
        "--dataset-roots",
        nargs="+",
        default=[
            "/path/to/med_anomaly_seg/chestx_det/test",
            "/path/to/med_anomaly_seg/BraTS2021_slice/test",
            "/path/to/med_anomaly_seg/RESC/test",
            "/path/to/med_anomaly_seg/hist_DIY/test",
        ],
        help="One or more dataset roots. Each root should contain 'good'/'ungood' folders "
             "with 'img' and 'label' subfolders, or point directly to a class folder. "
             "You may also pass a .jsonl file (MMXU) to evaluate images without masks.",
    )
    parser.add_argument("--mmxu-image-root", type=str, default=MMXU_DEFAULT_IMAGE_ROOT,
                        help="Base path for MMXU jsonl image paths when dataset-roots include a .jsonl file.")
    parser.add_argument("--seg-threshold", type=float, default=0.5,
                        help="Threshold applied to probabilities to produce binary masks.")
    parser.add_argument("--num_gpus", type=int, default=None,
                        help="Number of GPUs to use (default: all visible)")
    args = parser.parse_args()

    # Decide the worker count: never more than the visible GPUs.
    visible_gpus = torch.cuda.device_count()
    requested_gpus = args.num_gpus if args.num_gpus is not None else visible_gpus
    world_size = min(requested_gpus, visible_gpus) if visible_gpus > 0 else 0
    use_multi_gpu = world_size > 1

    attn_implementation = "flash_attention_2" if is_flash_attn_2_available() else "eager"
    model_args = SimpleNamespace(
        temperature=0.0,
        top_p=0.0001,
        repetition_penalty=1.0,
        max_new_tokens=16,
        output_attentions=False,
        attn_implementation=attn_implementation,
        device_map=None,
    )

    # Collect samples for each root up front; roots yielding no samples are skipped.
    dataset_jobs = []
    for root in args.dataset_roots:
        samples = gather_samples(root, mmxu_image_root=args.mmxu_image_root)
        if not samples:
            continue
        dataset_jobs.append(
            {
                "root": root,
                "dataset_name": resolve_dataset_name(root),
                "samples": samples,
            }
        )
    if not dataset_jobs:
        return

    if use_multi_gpu:
        manager = mp.Manager()
        results = manager.list([None] * world_size)  # one slot per rank
        mp.spawn(
            evaluate_worker,
            args=(world_size, args, args.model, model_args, dataset_jobs, results),
            nprocs=world_size,
        )
        # Merge the per-rank partial metrics dataset by dataset.
        all_metrics = []
        for dataset_idx, _job in enumerate(dataset_jobs):
            per_rank_metrics = [results[rank][dataset_idx] for rank in range(world_size)]
            all_metrics.append(aggregate_metrics(per_rank_metrics))
    else:
        # Fix: mirror evaluate_worker so flash-attention also gets a bf16 dtype
        # in the single-process path (previously only the spawned workers did).
        if model_args.attn_implementation == "flash_attention_2":
            model_args.torch_dtype = torch.bfloat16
        model_wrapper = Qwen2_5_VL(args.model, model_args)
        all_metrics = []
        for job in dataset_jobs:
            metrics = evaluate_dataset(
                model_wrapper,
                job["dataset_name"],
                job["samples"],
                args.seg_threshold,
                shuffle=True,
            )
            all_metrics.append(metrics)

    for metrics in all_metrics:
        print(f"\n{format_dataset_metrics(metrics)}")

    # Global aggregation (main process only). Means are weighted by the number
    # of contributing samples via the per-dataset sums.
    if all_metrics:
        total_samples = sum(m["samples"] for m in all_metrics)
        total_iou = sum(m.get("iou_sum", 0.0) for m in all_metrics)
        total_iou_count = sum(m["iou_samples"] for m in all_metrics)
        total_l2 = sum(m.get("l2_sum", 0.0) for m in all_metrics)
        total_l2_count = sum(m["l2_samples"] for m in all_metrics)
        total_auc = sum(m.get("auc_sum", 0.0) for m in all_metrics)
        total_auc_count = sum(m["auc_samples"] for m in all_metrics)
        mean_iou = total_iou / total_iou_count if total_iou_count else 0.0
        mean_l2 = total_l2 / total_l2_count if total_l2_count else 0.0
        mean_auc = total_auc / total_auc_count if total_auc_count else 0.0
        print("\n=== Aggregate Results ===")
        print(
            f"Mean IoU: {mean_iou:.4f}, Mean L2: {mean_l2:.4f}, Mean AUC: {mean_auc:.4f}"
        )
        print(
            f"Samples: {total_samples}, IoU Samples: {total_iou_count}, "
            f"L2 Samples: {total_l2_count}, AUC Samples: {total_auc_count}"
        )
        print("\n=== Final Dataset Results ===")
        for metrics in all_metrics:
            print(format_dataset_metrics(metrics))


if __name__ == "__main__":
    main()