diff --git a/samgeo/text_sam.py b/samgeo/text_sam.py index 11e3c45c..369e6dc1 100644 --- a/samgeo/text_sam.py +++ b/samgeo/text_sam.py @@ -257,7 +257,7 @@ def predict( save_args (dict, optional): Save arguments for the prediction. Defaults to {}. return_results (bool, optional): Whether to return the results. Defaults to False. detection_filter (callable, optional): - Callable which with box, mask, logit, phrase, and index args returns a boolean. + Callable with box, mask, logit, phrase, and index args returns a boolean. If provided, the function will be called for each detected object. Defaults to None. @@ -325,8 +325,7 @@ def predict( if not callable(detection_filter): raise ValueError("detection_filter must be callable.") - req_nargs = 6 if inspect.ismethod(detection_filter) else 5 - if not len(inspect.signature(detection_filter).parameters) == req_nargs: + if not len(inspect.signature(detection_filter).parameters) == 5: raise ValueError( "detection_filter required args: " "box, mask, logit, phrase, and index."