-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathuncertainty_aware_trajectory_model
More file actions
2248 lines (1850 loc) · 87.1 KB
/
uncertainty_aware_trajectory_model
File metadata and controls
2248 lines (1850 loc) · 87.1 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
"""
SafePathAI: Advanced Uncertainty-Aware Trajectory Prediction for Safe Autonomous Driving
Complete implementation with advanced features for research and real-world applications
"""
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.patches import Ellipse
import math
import pandas as pd
import random
from collections import defaultdict
from typing import Dict, List, Tuple, Optional, Union, Callable, Any, Set
import copy
import tqdm
import time
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from sklearn.cluster import KMeans
from sklearn.metrics import r2_score
from scipy.stats import multivariate_normal
import networkx as nx
# Optional third-party backends. Each *_AVAILABLE flag records whether the
# corresponding import succeeded so the rest of the module can degrade
# gracefully instead of failing at import time.

# Normalizing flows (Pyro)
try:
    import pyro
    import pyro.distributions as dist
    import pyro.distributions.transforms as T
    from pyro.nn import DenseNN
except ImportError:
    PYRO_AVAILABLE = False
    print("Warning: Pyro not available. Normalizing flows will not be available.")
else:
    PYRO_AVAILABLE = True

# Multi-GPU training (Horovod)
try:
    import horovod.torch as hvd
except ImportError:
    HOROVOD_AVAILABLE = False
    print("Warning: Horovod not available. Multi-GPU training will not be available.")
else:
    HOROVOD_AVAILABLE = True

# Tensor visualizations (TensorboardX)
try:
    from tensorboardX import SummaryWriter
except ImportError:
    TB_AVAILABLE = False
    print("Warning: TensorboardX not available. Visualizations will be limited.")
else:
    TB_AVAILABLE = True
#######################
# CONFIG
#######################
class Config:
    """Advanced configuration for the SafePathAI model.

    Every setting is a plain instance attribute assigned in ``__init__``.
    Constructing a ``Config`` also creates the log, checkpoint, result and
    visualization directories on disk.
    """

    def __init__(self):
        # Data parameters
        self.input_seq_len = 20  # Length of input trajectory sequence
        self.pred_seq_len = 30  # Length of predicted trajectory sequence
        self.input_dim = 4  # Base features (x, y, vx, vy)
        self.output_dim = 2  # Position prediction (x, y)
        self.map_features = 64  # Number of map features
        self.agent_features = 32  # Features for other agents
        # NOTE(review): WeatherData encodes 4 scalar values plus a 5-way weather
        # type one-hot (9 values); with 8 slots the last type cannot be encoded.
        self.weather_features = 8  # Weather condition features
        self.time_features = 4  # Time of day features
        self.intention_classes = 6  # Number of intention classes
        # Dataset parameters
        self.dataset_type = "nuScenes"  # Options: "nuScenes", "Argoverse", "Waymo", "Lyft"
        self.data_path = "./data"
        self.use_map_data = True
        self.use_traffic_rules = True
        self.use_weather_data = True
        self.use_time_data = True
        self.random_rotation = True  # Apply random rotations for data augmentation
        self.random_noise = True  # Add random noise for data augmentation
        self.max_agents = 32  # Maximum number of agents to consider
        self.max_neighbor_dist = 50.0  # Maximum distance to consider agents as neighbors (meters)
        # Model parameters
        self.model_type = "transformer"  # Options: "lstm", "transformer", "gnn", "social_gan", "scene_transformer"
        self.hidden_size = 256
        self.embedding_size = 128
        self.num_layers = 4
        self.num_heads = 8
        self.dropout = 0.1
        self.mc_dropout_samples = 50
        self.ensemble_size = 5
        self.num_modes = 6  # Number of prediction modes (for multimodal prediction)
        self.use_attention = True
        self.use_scene_graph = True
        self.use_map_encoder = True
        self.use_social_encoder = True
        self.use_multimodal = True
        self.use_evidential = True
        self.use_normalizing_flows = PYRO_AVAILABLE
        # Training parameters
        self.batch_size = 64
        self.learning_rate = 0.001
        self.weight_decay = 1e-5
        self.num_epochs = 100
        self.lr_scheduler = "cosine"  # Options: "step", "plateau", "cosine"
        self.lr_step_size = 20
        self.lr_gamma = 0.5
        self.gradient_clip = 1.0
        self.early_stopping = True
        self.patience = 10
        self.validation_freq = 1
        self.save_freq = 5
        self.num_workers = 8
        self.seed = 42
        # Kalman Filter parameters
        self.kf_type = "IMM"  # Options: "KF", "EKF", "UKF", "IMM"
        self.num_motion_models = 3  # Number of motion models for IMM
        self.q_var = 0.01  # Process noise variance
        self.r_var = 0.1  # Measurement noise variance
        # Uncertainty thresholds
        self.uncertainty_threshold = 0.5  # Threshold for high uncertainty
        self.safety_factor = 2.0  # Safety factor for uncertainty inflation
        # Evaluation parameters
        self.eval_metrics = ["ade", "fde", "nll", "mr", "calibration", "collisions"]
        self.adversarial_test = True
        self.eval_k_predictions = [1, 5, 10]  # K values for minADE/minFDE
        # Safety parameters
        self.use_safety_envelope = True
        self.safety_envelope_size = 2.0  # Size of safety envelope (meters)
        self.collision_threshold = 1.5  # Collision threshold (meters)
        self.emergency_stopping_threshold = 0.8  # Emergency stopping threshold
        self.use_contingency_planning = True
        self.max_contingency_plans = 3
        # Visualization parameters
        self.visualize_predictions = True
        self.visualize_uncertainty = True
        self.visualize_attention = True
        self.visualize_scene_graph = True
        self.save_visualizations = True
        self.vis_path = "./visualizations"
        # Logging parameters
        self.log_dir = "./logs"
        self.checkpoint_dir = "./checkpoints"
        self.result_dir = "./results"
        self.use_tensorboard = TB_AVAILABLE
        self.log_freq = 10
        # Environment parameters
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.use_multi_gpu = HOROVOD_AVAILABLE and torch.cuda.device_count() > 1
        self.precision = "mixed"  # Options: "fp32", "fp16", "mixed"
        # Create directories
        os.makedirs(self.log_dir, exist_ok=True)
        os.makedirs(self.checkpoint_dir, exist_ok=True)
        os.makedirs(self.result_dir, exist_ok=True)
        os.makedirs(self.vis_path, exist_ok=True)

    def update(self, **kwargs):
        """Update existing config parameters in place.

        All keys are validated before any value is applied, so a bad key
        leaves the config unchanged.

        Args:
            **kwargs: Names of attributes set in ``__init__`` and their new values.

        Returns:
            ``self``, so calls can be chained.

        Raises:
            ValueError: If a key is not an existing config attribute.
        """
        # Check the instance dict rather than ``hasattr``. ``hasattr`` also finds
        # class attributes, so ``update(to_dict=...)`` would shadow a method.
        settings = vars(self)
        for key in kwargs:
            if key not in settings:
                raise ValueError(f"Config has no attribute '{key}'")
        for key, value in kwargs.items():
            setattr(self, key, value)
        return self

    def to_dict(self):
        """Convert config to dictionary."""
        return {key: value for key, value in self.__dict__.items()
                if not key.startswith('__') and not callable(value)}

    def __str__(self):
        """String representation of the config."""
        return "\n".join(f"{key}: {value}" for key, value in self.to_dict().items())
#######################
# DATA PROCESSING
#######################
class MapData:
    """
    Handle map data processing for trajectory prediction.

    Loads per-dataset map layers into dictionaries keyed by map/location id.
    Most query methods are still placeholders that return random values.
    """

    def __init__(self, config: "Config"):
        """
        Initialize the map data processor.

        Args:
            config: Configuration object
        """
        self.config = config
        self.maps = {}
        self.map_features = {}
        self.traffic_lights = {}
        self.traffic_signs = {}
        self.lane_connections = {}
        self.crosswalks = {}
        self.speed_limits = {}
        # Load map data if specified
        if self.config.use_map_data:
            self._load_map_data()

    def _load_map_data(self):
        """
        Load map data from the specified dataset.

        Dispatches on ``config.dataset_type``. Any type other than nuScenes,
        Argoverse or Waymo is treated as a custom dataset with one map file per
        entry in ``<data_path>/maps``.
        """
        print(f"Loading map data from {self.config.dataset_type}...")
        if self.config.dataset_type == "nuScenes":
            self._load_nuscenes_maps()
        elif self.config.dataset_type == "Argoverse":
            self._load_argoverse_maps()
        elif self.config.dataset_type == "Waymo":
            self._load_waymo_maps()
        else:
            self._load_custom_maps()
        print(f"Loaded {len(self.maps)} maps.")

    def _use_placeholder_maps(self):
        """Register an empty 'placeholder' map for when real map data is unavailable."""
        self.maps["placeholder"] = {"lanes": [], "crosswalks": []}
        self.map_features["placeholder"] = {"lanes": [], "crosswalks": []}

    def _load_custom_maps(self):
        """Load one map per file from ``<data_path>/maps`` (custom datasets)."""
        maps_dir = os.path.join(self.config.data_path, "maps")
        if not os.path.isdir(maps_dir):
            # Fall back like the other loaders do, instead of raising
            # FileNotFoundError from os.listdir.
            print(f"Warning: map directory '{maps_dir}' not found. Using placeholder map data.")
            self._use_placeholder_maps()
            return
        for map_file in sorted(os.listdir(maps_dir)):
            # Hidden files (e.g. ".DS_Store") would otherwise produce an empty map id.
            if map_file.startswith("."):
                continue
            map_id = map_file.split(".")[0]
            # NOTE(review): custom maps are not added to self.map_features, so
            # extract_map_features()/render_map() treat them as missing.
            self.maps[map_id] = self._load_map(os.path.join(maps_dir, map_file))

    def _load_nuscenes_maps(self):
        """Load maps from nuScenes dataset (falls back to a placeholder without the devkit)."""
        try:
            from nuscenes.map_expansion.map_api import NuScenesMap
            from nuscenes.nuscenes import NuScenes
        except ImportError:
            print("Warning: nuScenes API not available. Using placeholder map data.")
            self._use_placeholder_maps()
            return
        # NOTE(review): `nusc` is not used below. It is constructed only so that
        # loading behaves as before; confirm whether it is still needed.
        nusc = NuScenes(version='v1.0-mini', dataroot=self.config.data_path, verbose=True)
        for location in ('boston-seaport', 'singapore-hollandvillage',
                         'singapore-onenorth', 'singapore-queenstown'):
            nusc_map = NuScenesMap(dataroot=self.config.data_path, map_name=location)
            self.maps[location] = nusc_map
            # NuScenesMap exposes each map layer as an attribute named after the
            # layer ('lane', 'ped_crossing', ...). There are no '*_polygons' attributes.
            self.map_features[location] = {
                'lanes': nusc_map.lane,
                'ped_crossings': nusc_map.ped_crossing,
                'walkways': nusc_map.walkway,
                'stop_lines': nusc_map.stop_line,
            }
            # Placeholders: traffic lights/signs, lane connectivity and speed
            # limits are not extracted yet.
            self.traffic_lights[location] = []
            self.traffic_signs[location] = []
            self.lane_connections[location] = {}
            self.speed_limits[location] = {}

    def _load_argoverse_maps(self):
        """Load maps from Argoverse dataset (falls back to a placeholder without the API)."""
        try:
            from argoverse.map_representation.map_api import ArgoverseMap
        except ImportError:
            print("Warning: Argoverse API not available. Using placeholder map data.")
            self._use_placeholder_maps()
            return
        argoverse_map = ArgoverseMap()
        self.maps["argoverse"] = argoverse_map
        for city in argoverse_map.city_name_to_city_id.keys():
            # Placeholders: lanes, crosswalks, traffic controls, connectivity
            # and speed limits are not extracted yet.
            self.map_features[city] = {
                'lanes': [],
                'crosswalks': []
            }
            self.traffic_lights[city] = []
            self.traffic_signs[city] = []
            self.lane_connections[city] = {}
            self.speed_limits[city] = {}

    def _load_waymo_maps(self):
        """Load maps from Waymo dataset (not implemented: always uses a placeholder)."""
        print("Warning: Waymo map loading not implemented. Using placeholder map data.")
        self._use_placeholder_maps()

    def _load_map(self, map_path: str):
        """Load a single map file (placeholder: returns an empty map, ignores the file)."""
        return {"lanes": [], "crosswalks": []}

    def extract_map_features(self, position: np.ndarray,
                             map_id: str,
                             radius: float = 50.0) -> np.ndarray:
        """
        Extract map features around a given position.

        Args:
            position: (x, y) position
            map_id: ID of the map to use
            radius: Radius around the position to extract features

        Returns:
            Array of length ``config.map_features``. It is all zeros when map
            data is disabled or ``map_id`` is unknown. Otherwise it is currently
            a random unit-norm placeholder vector.
        """
        if not self.config.use_map_data or map_id not in self.map_features:
            return np.zeros(self.config.map_features)
        map_features = np.random.randn(self.config.map_features)
        # Normalize to unit length (epsilon guards against a zero vector)
        return map_features / (np.linalg.norm(map_features) + 1e-8)

    def get_lane_direction(self, position: np.ndarray, map_id: str) -> np.ndarray:
        """
        Get the lane direction at a given position.

        Returns:
            Unit vector [dx, dy] (placeholder: random direction)
        """
        direction = np.random.randn(2)
        return direction / (np.linalg.norm(direction) + 1e-8)

    def get_nearest_lane_distance(self, position: np.ndarray, map_id: str) -> float:
        """
        Get the distance to the nearest lane.

        Returns:
            Distance in meters (placeholder: uniform in [0, 5))
        """
        return np.random.uniform(0, 5)

    def get_speed_limit(self, position: np.ndarray, map_id: str) -> float:
        """
        Get the speed limit at a given position.

        Returns:
            Speed limit in m/s (placeholder: uniform in [5, 25), i.e. 18-90 km/h)
        """
        return np.random.uniform(5, 25)

    def get_traffic_light_state(self, position: np.ndarray,
                                map_id: str,
                                direction: np.ndarray) -> int:
        """
        Get the traffic light state in the direction of travel.

        Returns:
            Traffic light state (0=red, 1=yellow, 2=green, -1=none) as a Python
            int (placeholder: random draw)
        """
        # np.random.choice returns np.int64; cast to match the annotation.
        return int(np.random.choice([-1, 0, 1, 2], p=[0.7, 0.1, 0.05, 0.15]))

    def is_on_crosswalk(self, position: np.ndarray, map_id: str) -> bool:
        """
        Check if the position is on a crosswalk.

        Returns:
            True if on crosswalk (placeholder: 5% random chance)
        """
        return np.random.random() < 0.05

    def render_map(self, ax, position: np.ndarray, map_id: str, radius: float = 50.0):
        """
        Render the map on a matplotlib axis.

        Placeholder drawing: a grid, green "lanes" every 4 m and yellow
        "crosswalks" every 20 m around ``position``.

        Args:
            ax: Matplotlib axis
            position: Center position [x, y]
            map_id: ID of the map to use
            radius: Radius around the position to render
        """
        if not self.config.use_map_data or map_id not in self.map_features:
            return
        ax.grid(True, linestyle='--', alpha=0.7)
        for i in range(-3, 4):
            # Horizontal lanes
            ax.plot([position[0] - radius, position[0] + radius],
                    [position[1] + i * 4, position[1] + i * 4],
                    'g-', alpha=0.5)
            # Vertical lanes
            ax.plot([position[0] + i * 4, position[0] + i * 4],
                    [position[1] - radius, position[1] + radius],
                    'g-', alpha=0.5)
        for i in range(-2, 3, 2):
            # Horizontal crosswalks
            ax.plot([position[0] - radius/4, position[0] + radius/4],
                    [position[1] + i * 10, position[1] + i * 10],
                    'y-', linewidth=3, alpha=0.7)
            # Vertical crosswalks
            ax.plot([position[0] + i * 10, position[0] + i * 10],
                    [position[1] - radius/4, position[1] + radius/4],
                    'y-', linewidth=3, alpha=0.7)
class WeatherData:
    """
    Handle weather data processing for trajectory prediction.

    Weather records are stored in ``self.weather_data`` keyed by
    "YYYY-MM-DD HH:MM:SS" strings. Currently they are synthetic hourly samples.
    """

    # Weather categories in one-hot order (feature slots 4, 5, ...).
    WEATHER_TYPES = ("clear", "rain", "snow", "fog", "cloudy")

    def __init__(self, config: "Config"):
        """
        Initialize the weather data processor.

        Args:
            config: Configuration object
        """
        self.config = config
        self.weather_data = {}
        # Cache of parsed weather_data keys used by get_weather_features().
        self._weather_keys = None
        self._weather_index = None
        if self.config.use_weather_data:
            self._load_weather_data()

    def _load_weather_data(self):
        """Generate synthetic hourly weather for 2020 (placeholder for real data)."""
        print("Loading weather data...")
        temperature_range = (-10, 40)  # Celsius
        precipitation_range = (0, 50)  # mm/h
        visibility_range = (50, 10000)  # meters
        wind_speed_range = (0, 30)  # m/s
        # Every hour of 2020. "h" replaces the "H" alias deprecated in pandas 2.2,
        # and the end bound includes the last day's remaining hours.
        dates = pd.date_range(start="2020-01-01", end="2020-12-31 23:00:00", freq="h")
        for date in dates:
            day_of_year = date.dayofyear / 365.0
            hour_of_day = date.hour / 24.0
            # Seasonal and daily sinusoidal variations
            season_factor = math.sin(day_of_year * 2 * math.pi)
            day_factor = math.sin(hour_of_day * 2 * math.pi)
            temperature = temperature_range[0] + (temperature_range[1] - temperature_range[0]) * (
                0.5 + 0.4 * season_factor + 0.1 * day_factor)
            # Precipitation more likely in certain seasons
            precip_probability = 0.2 + 0.2 * (1 + season_factor)
            precipitation = 0
            if random.random() < precip_probability:
                precipitation = random.uniform(0, precipitation_range[1])
            # Visibility reduced by precipitation and at night
            visibility_factor = 1.0
            if precipitation > 10:
                visibility_factor = 0.5
            elif precipitation > 0:
                visibility_factor = 0.8
            if hour_of_day < 0.25 or hour_of_day > 0.75:  # Night time
                visibility_factor *= 0.7
            visibility = visibility_range[0] + (visibility_range[1] - visibility_range[0]) * visibility_factor
            wind_speed = random.uniform(wind_speed_range[0], wind_speed_range[1])
            if precipitation > 20 and temperature < 2:
                weather_type = "snow"
            elif precipitation > 5:
                weather_type = "rain"
            elif visibility < 1000:
                weather_type = "fog"
            elif random.random() < 0.3:
                weather_type = "cloudy"
            else:
                weather_type = "clear"
            date_str = date.strftime("%Y-%m-%d %H:%M:%S")
            self.weather_data[date_str] = {
                "temperature": temperature,
                "precipitation": precipitation,
                "visibility": visibility,
                "wind_speed": wind_speed,
                "weather_type": weather_type
            }
        self._weather_keys = None  # invalidate the parsed-timestamp cache
        print(f"Generated weather data for {len(self.weather_data)} time points.")

    def _closest_timestamp_key(self, timestamp: str) -> str:
        """Return the weather_data key whose time is nearest to ``timestamp``.

        Keys are parsed once and re-parsed only when ``weather_data`` changes.
        Ties resolve to the earliest key in insertion order.
        """
        keys = list(self.weather_data)
        if getattr(self, "_weather_keys", None) != keys:
            self._weather_keys = keys
            self._weather_index = pd.DatetimeIndex([pd.Timestamp(k) for k in keys])
        deltas = np.abs((self._weather_index - pd.Timestamp(timestamp)).to_numpy())
        return keys[int(np.argmin(deltas))]

    def get_weather_features(self, timestamp: str) -> np.ndarray:
        """
        Get weather features for a given timestamp.

        Layout: [temperature, precipitation, visibility, wind speed, one-hot
        weather type in ``WEATHER_TYPES`` order]. The vector is truncated or
        zero-padded to ``config.weather_features`` entries. With fewer than 9
        slots, the trailing weather types cannot be represented.

        Args:
            timestamp: Timestamp in the format "YYYY-MM-DD HH:MM:SS"

        Returns:
            Array of length ``config.weather_features`` (all zeros when weather
            data is disabled or empty)
        """
        if not self.config.use_weather_data or not self.weather_data:
            return np.zeros(self.config.weather_features)
        weather = self.weather_data[self._closest_timestamp_key(timestamp)]
        weather_type = weather["weather_type"]
        type_idx = self.WEATHER_TYPES.index(weather_type) if weather_type in self.WEATHER_TYPES else 0
        full = np.zeros(4 + len(self.WEATHER_TYPES))
        full[0] = (weather["temperature"] - 15) / 25  # roughly [-1, 1]
        full[1] = min(1.0, weather["precipitation"] / 50)  # [0, 1]
        full[2] = weather["visibility"] / 10000  # [0, 1]
        full[3] = weather["wind_speed"] / 30  # [0, 1]
        full[4 + type_idx] = 1.0
        # Fit into the configured size. Writing directly into a vector of
        # config.weather_features entries raised IndexError for "cloudy" with 8.
        features = np.zeros(self.config.weather_features)
        n = min(len(features), len(full))
        features[:n] = full[:n]
        return features
class TrafficRules:
    """
    Traffic-rule lookup used as context for trajectory prediction.

    Rule tables are placeholders. ``check_traffic_rules`` combines map queries
    with randomly drawn sign/restriction flags.
    """

    def __init__(self, config: Config, map_data: MapData):
        """
        Set up the traffic rules processor.

        Args:
            config: Configuration object; ``use_traffic_rules`` toggles loading
            map_data: Map data object queried for speed limits and light states
        """
        self.config = config
        self.map_data = map_data
        self.traffic_rules = {}
        if self.config.use_traffic_rules:
            self._load_traffic_rules()

    def _load_traffic_rules(self):
        """Create an empty table for each supported rule category (placeholder)."""
        print("Loading traffic rules...")
        categories = ("speed_limit", "stop_sign", "yield", "no_turn", "one_way")
        self.traffic_rules.update({category: {} for category in categories})

    def check_traffic_rules(self,
                            position: np.ndarray,
                            direction: np.ndarray,
                            map_id: str) -> Dict[str, Any]:
        """
        Collect the traffic rules that apply at a position and heading.

        Args:
            position: (x, y) position
            direction: Direction of travel [dx, dy]
            map_id: ID of the map to use

        Returns:
            Dict with keys speed_limit, traffic_light, stop_sign, yield_sign,
            no_turn and one_way, or an empty dict when rules are disabled
        """
        if not self.config.use_traffic_rules:
            return {}
        # Map-derived rules (speed limit is queried before the light state).
        rules = {
            "speed_limit": self.map_data.get_speed_limit(position, map_id),
            "traffic_light": self.map_data.get_traffic_light_state(position, map_id, direction),
        }
        # Synthetic flags: draw all four random numbers in a fixed order.
        stop_draw = random.random()
        yield_draw = random.random()
        rules["stop_sign"] = stop_draw < 0.05
        # A yield sign is never reported together with a stop sign.
        rules["yield_sign"] = yield_draw < 0.05 and not rules["stop_sign"]
        rules["no_turn"] = random.random() < 0.1
        rules["one_way"] = random.random() < 0.3
        return rules
class AgentState:
    """
    Represents the state of an agent (vehicle, pedestrian, etc.) in the scene.
    """

    def __init__(self,
                 agent_id: str,
                 agent_type: str,
                 position: np.ndarray,
                 velocity: np.ndarray,
                 heading: float,
                 length: float = 4.0,
                 width: float = 2.0,
                 timestamp: float = 0.0):
        """
        Initialize an agent state.

        Args:
            agent_id: Unique identifier for the agent
            agent_type: Type of agent (vehicle, pedestrian, cyclist, etc.)
            position: (x, y) position
            velocity: (vx, vy) velocity
            heading: Heading angle in radians
            length: Length of the agent in meters
            width: Width of the agent in meters
            timestamp: Timestamp of the state
        """
        self.agent_id = agent_id
        self.agent_type = agent_type
        # Store float copies. Integer inputs would otherwise give integer arrays
        # that truncate any fractional value written into them in place.
        self.position = np.array(position, dtype=float)
        self.velocity = np.array(velocity, dtype=float)
        self.heading = heading
        self.length = length
        self.width = width
        self.timestamp = timestamp
        # Derived properties
        self.speed = np.linalg.norm(self.velocity)
        self.acceleration = np.zeros(2)  # Updated by update_derived_properties()
        self.yaw_rate = 0.0  # Updated by update_derived_properties()
        # Future trajectory (ground truth)
        self.future_trajectory = None
        # Additional properties
        self.lane_id = None
        self.on_intersection = False
        self.distance_to_lane = 0.0
        self.distance_to_intersection = float('inf')
        self.nearest_agents = {}  # {agent_id: distance}
        self.traffic_rules = {}
        self.intention = None
        self.intention_prob = None

    def update_derived_properties(self, prev_state: Optional['AgentState'] = None):
        """
        Recompute speed and, given an earlier state, acceleration and yaw rate.

        Acceleration and yaw rate are finite differences over the timestamp
        gap. They are left unchanged when the gap is not positive.

        Args:
            prev_state: Previous state of the agent
        """
        self.speed = np.linalg.norm(self.velocity)
        if prev_state is not None:
            dt = self.timestamp - prev_state.timestamp
            if dt > 0:
                self.acceleration = (self.velocity - prev_state.velocity) / dt
                # Wrap the heading change to [-pi, pi) before differentiating
                heading_diff = self.heading - prev_state.heading
                heading_diff = (heading_diff + np.pi) % (2 * np.pi) - np.pi
                self.yaw_rate = heading_diff / dt

    def to_vector(self) -> np.ndarray:
        """
        Convert agent state to feature vector.

        Returns:
            [x, y, vx, vy, heading, speed, acc_x, acc_y, yaw_rate]
        """
        features = np.zeros(9)
        features[0:2] = self.position
        features[2:4] = self.velocity
        features[4] = self.heading
        features[5] = self.speed
        features[6:8] = self.acceleration
        features[8] = self.yaw_rate
        return features

    def predict_state(self, dt: float, motion_model: str = "CV") -> 'AgentState':
        """
        Predict future state after time dt using a motion model.

        The returned state is an independent deep copy. Its arrays never alias
        this state's arrays.

        Args:
            dt: Time increment in seconds
            motion_model: Motion model to use (CV: constant velocity,
                CA: constant acceleration,
                CTRV: constant turn rate and velocity)

        Returns:
            Predicted agent state

        Raises:
            ValueError: If ``motion_model`` is not one of CV, CA, CTRV
        """
        predicted_state = copy.deepcopy(self)
        predicted_state.timestamp += dt
        if motion_model == "CV":  # Constant velocity
            predicted_state.position = self.position + self.velocity * dt
            # Copy so that mutating the prediction can't change this state.
            predicted_state.velocity = self.velocity.copy()
            predicted_state.heading = self.heading
        elif motion_model == "CA":  # Constant acceleration
            predicted_state.position = self.position + self.velocity * dt + 0.5 * self.acceleration * dt**2
            predicted_state.velocity = self.velocity + self.acceleration * dt
            # Re-derive heading from velocity only when moving fast enough for
            # the direction to be meaningful.
            if np.linalg.norm(predicted_state.velocity) > 0.5:
                predicted_state.heading = np.arctan2(predicted_state.velocity[1], predicted_state.velocity[0])
        elif motion_model == "CTRV":  # Constant turn rate and velocity
            if abs(self.yaw_rate) < 1e-6:
                # Degenerates to constant velocity when not turning
                predicted_state.position = self.position + self.velocity * dt
                predicted_state.heading = self.heading
            else:
                new_heading = self.heading + self.yaw_rate * dt
                predicted_state.heading = new_heading
                # Exact arc integration with radius speed / yaw_rate
                v_over_omega = self.speed / self.yaw_rate
                predicted_state.position = self.position + v_over_omega * np.array([
                    np.sin(new_heading) - np.sin(self.heading),
                    -np.cos(new_heading) + np.cos(self.heading),
                ])
                # Speed stays constant; its direction follows the new heading
                predicted_state.velocity = np.array([
                    self.speed * np.cos(new_heading),
                    self.speed * np.sin(new_heading)
                ])
        else:
            raise ValueError(f"Unknown motion model: {motion_model}")
        return predicted_state

    def to_dict(self) -> Dict:
        """Convert agent state to dictionary (arrays become lists)."""
        return {
            'agent_id': self.agent_id,
            'agent_type': self.agent_type,
            'position': self.position.tolist(),
            'velocity': self.velocity.tolist(),
            'heading': self.heading,
            'length': self.length,
            'width': self.width,
            'timestamp': self.timestamp,
            'speed': self.speed,
            'acceleration': self.acceleration.tolist(),
            'yaw_rate': self.yaw_rate,
            'lane_id': self.lane_id,
            'on_intersection': self.on_intersection,
            'distance_to_lane': self.distance_to_lane,
            'distance_to_intersection': self.distance_to_intersection,
            'intention': self.intention,
            'intention_prob': self.intention_prob
        }

    @classmethod
    def from_dict(cls, data: Dict) -> 'AgentState':
        """Create agent state from a dictionary produced by ``to_dict``.

        Keys beyond the constructor arguments are optional. ``speed`` is
        recomputed from ``velocity`` rather than read from ``data``.
        """
        state = cls(
            agent_id=data['agent_id'],
            agent_type=data['agent_type'],
            position=np.array(data['position']),
            velocity=np.array(data['velocity']),
            heading=data['heading'],
            length=data['length'],
            width=data['width'],
            timestamp=data['timestamp']
        )
        if 'acceleration' in data:
            state.acceleration = np.array(data['acceleration'])
        if 'yaw_rate' in data:
            state.yaw_rate = data['yaw_rate']
        if 'lane_id' in data:
            state.lane_id = data['lane_id']
        if 'on_intersection' in data:
            state.on_intersection = data['on_intersection']
        if 'distance_to_lane' in data:
            state.distance_to_lane = data['distance_to_lane']
        if 'distance_to_intersection' in data:
            state.distance_to_intersection = data['distance_to_intersection']
        if 'intention' in data:
            state.intention = data['intention']
        if 'intention_prob' in data:
            state.intention_prob = data['intention_prob']
        return state

    def get_bounding_box(self) -> np.ndarray:
        """
        Get the bounding box of the agent.

        Returns:
            (4, 2) array of world-frame corners, in the order rear-right,
            front-right, front-left, rear-left
        """
        # Counter-clockwise rotation by heading: local x points forward and
        # local y points to the agent's left.
        c, s = np.cos(self.heading), np.sin(self.heading)
        R = np.array([[c, -s], [s, c]])
        hl = self.length / 2
        hw = self.width / 2
        corners_local = np.array([
            [-hl, -hw],  # rear-right
            [hl, -hw],   # front-right
            [hl, hw],    # front-left
            [-hl, hw]    # rear-left
        ])
        # Row-wise R @ corner, then translate to the agent position
        return self.position + corners_local @ R.T

    def check_collision(self, other: 'AgentState') -> bool:
        """
        Approximate collision check between two agents.

        Treats each agent as a circle whose radius is half its length and
        ignores width and heading, so it is coarse.

        Args:
            other: The other agent state

        Returns:
            True if the center distance is below the sum of half-lengths
        """
        dist = np.linalg.norm(self.position - other.position)
        min_dist = (self.length + other.length) / 2
        return dist < min_dist
class SceneState:
"""
Represents the state of the entire scene, including all agents and environmental context.
"""
def __init__(self,
timestamp: float,
map_id: str,
ego_agent_id: str,
weather_info: Optional[Dict] = None,
time_info: Optional[Dict] = None):
"""
Initialize a scene state.
Args:
timestamp: Timestamp of the scene
map_id: ID of the map
ego_agent_id: ID of the ego agent
weather_info: Weather information
time_info: Time information (time of day, etc.)
"""
self.timestamp = timestamp
self.map_id = map_id
self.ego_agent_id = ego_agent_id
self.weather_info = weather_info if weather_info is not None else {}
self.time_info = time_info if time_info is not None else {}
# Agents in the scene
self.agents: Dict[str, AgentState] = {}
# Scene graph representation
self.scene_graph = nx.DiGraph()
# Traffic rules applicable to the scene
self.traffic_rules = {}