-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathpredict.py
More file actions
46 lines (36 loc) · 1.33 KB
/
predict.py
File metadata and controls
46 lines (36 loc) · 1.33 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
import os
import argparse
import joblib
import numpy as np
import pandas as pd
from glob import glob
def load_model(model_path):
    """Load and return the trained model stored at *model_path*.

    The file is read with ``joblib.load``, which unpickles it; only load
    model files that come from a trusted source.
    """
    estimator = joblib.load(model_path)
    return estimator
def predict_and_save(model, path, log_target=False):
    """Add a ``predicted`` column to the DataFrame pickled at *path*, in place.

    Files whose DataFrame has no ``flood_depth`` column are skipped (a message
    is printed). Otherwise every numeric column except ``flood_depth``,
    ``huc8``, ``return_period`` and any ``predicted`` column left by an
    earlier run is passed to ``model.predict``. The result is stored in
    ``df['predicted']`` and the DataFrame is written back to *path*.

    Args:
        model: Fitted estimator exposing ``predict(features)``.
        path: Path to a joblib-pickled DataFrame; overwritten on success.
        log_target: If True, predictions are mapped back with ``np.expm1``
            (for a model trained on ``log1p`` of the target).

    Raises:
        Whatever ``joblib.load``, ``model.predict`` or ``joblib.dump`` raise;
        on a failed write the original file is left untouched.

    Note:
        ``joblib.load`` unpickles arbitrary objects — only run this on files
        you trust.
    """
    df = joblib.load(path)
    if 'flood_depth' not in df.columns:
        # Make the skip visible so a file without predictions is explainable.
        print(f"⚠️ Skipped (no 'flood_depth' column): {path}")
        return

    # Exclude 'predicted' too: without it, re-running the script on an
    # already-updated file feeds the previous predictions back to the model
    # as an extra feature column.
    features = df.drop(
        columns=['flood_depth', 'huc8', 'return_period', 'predicted'],
        errors='ignore',
    )
    features = features.select_dtypes(include=[np.number])

    preds = model.predict(features)
    if log_target:
        preds = np.expm1(preds)
    df['predicted'] = preds

    # Dump to a sibling temp file and atomically swap it in, so an interrupted
    # write cannot leave a truncated test file behind. The temp name keeps the
    # original suffix so joblib infers the same compression from it.
    directory, filename = os.path.split(path)
    tmp_path = os.path.join(directory, f".tmp-{filename}")
    try:
        joblib.dump(df, tmp_path)
        os.replace(tmp_path, path)
    except BaseException:
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise
    print(f"✅ Updated: {path}")
def main():
    """Command-line entry point: add model predictions to ``*_test.pkl`` files.

    Recursively collects ``*_test.pkl`` under ``--root_dir`` and runs
    ``predict_and_save`` on each with the model loaded from ``--model``.
    A failure on one file is printed and the remaining files are still
    processed.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--model', required=True, help='Path to trained model.pkl')
    parser.add_argument('--root_dir', required=True, help='Directory with *_test.pkl files')
    parser.add_argument('--log_target', action='store_true', help='If model was trained on log1p')
    args = parser.parse_args()

    # glob() on a missing directory just returns [], which made a mistyped
    # --root_dir a silent, "successful" no-op. Fail loudly instead (exit 2).
    if not os.path.isdir(args.root_dir):
        parser.error(f"--root_dir is not a directory: {args.root_dir}")

    model = load_model(args.model)

    # Sorted so the processing order (and the log output) is reproducible.
    test_files = sorted(
        glob(os.path.join(args.root_dir, "**", "*_test.pkl"), recursive=True)
    )
    if not test_files:
        print(f"⚠️ No *_test.pkl files found under {args.root_dir}")
        return

    failures = 0
    for path in test_files:
        try:
            predict_and_save(model, path, log_target=args.log_target)
        except Exception as e:
            # Best-effort batch: report this file and continue with the rest.
            failures += 1
            print(f"❌ Failed on {path}: {e}")
    print(f"Done: {len(test_files)} file(s) processed, {failures} failed")
# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()