Skip to content

Commit

Permalink
fixed bug with traversability estimator
Browse files Browse the repository at this point in the history
  • Loading branch information
JonasFrey96 committed Nov 2, 2023
1 parent a27f1bc commit d09caf9
Showing 1 changed file with 17 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def __init__(
extraction_store_folder=None,
anomaly_detection: bool = False,
vis_training_samples: bool = False,
**kwargs,
use_feature_extractor: bool = False**kwargs,
):
self._device = device
self._mode = mode
Expand Down Expand Up @@ -85,13 +85,15 @@ def __init__(
self._segmentation_type = segmentation_type
self._feature_type = feature_type

self._feature_extractor = FeatureExtractor(
self._device,
segmentation_type=self._segmentation_type,
feature_type=self._feature_type,
input_size=image_size,
**kwargs,
)
self._use_feature_extractor = use_feature_extractor
if use_feature_extractor:
self._feature_extractor = FeatureExtractor(
self._device,
segmentation_type=self._segmentation_type,
feature_type=self._feature_type,
input_size=image_size,
**kwargs,
)

# Mutex
self._learning_lock = Lock()
Expand Down Expand Up @@ -215,9 +217,10 @@ def change_device(self, device: str):
"""
self._supervision_graph.change_device(device)
self._mission_graph.change_device(device)
self._feature_extractor.change_device(device)
self._model = self._model.to(device)

if self._use_feature_extractor:
self._feature_extractor.change_device(device)
if self._scale_traversability:
# Use 500 bins for constant memory usuage
self._auxiliary_training_roc.to(device)
Expand All @@ -229,6 +232,11 @@ def update_features(self, node: MissionNode):
Args:
node (MissionNode): new node in the mission graph
"""
if not self._use_feature_extractor:
raise ValueError(
"Udate features can be not called given that when creating the TraversabilityEstimator the FeatureExtractor was not used: use_feature_extractor = False"
)

if self._mode != WVNMode.EXTRACT_LABELS:
# Extract features
# Check do we need to add here the .clone() in
Expand Down

0 comments on commit d09caf9

Please sign in to comment.