Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add detection_filter to predict() method to allow for user-defined logic #307

Merged
merged 2 commits into from
Aug 19, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 31 additions & 2 deletions samgeo/text_sam.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,11 @@
Credits to Luca Medeiros for the original implementation.
"""

import argparse
import inspect
import os
import warnings
import argparse

import numpy as np
import torch
from PIL import Image
Expand Down Expand Up @@ -238,6 +240,7 @@ def predict(
save_args={},
return_results=False,
return_coords=False,
detection_filter=None,
**kwargs,
):
"""
Expand All @@ -253,6 +256,10 @@ def predict(
dtype (np.dtype, optional): Data type for the prediction. Defaults to np.uint8.
save_args (dict, optional): Save arguments for the prediction. Defaults to {}.
return_results (bool, optional): Whether to return the results. Defaults to False.
detection_filter (callable, optional):
Callable that accepts box, mask, logit, phrase, and index arguments and returns a boolean.
If provided, it is called once for each detected object, and objects for which it returns a falsy value are skipped.
Defaults to None.

Returns:
tuple: Tuple containing masks, boxes, phrases, and logits.
Expand Down Expand Up @@ -312,12 +319,34 @@ def predict(
image_np[..., 0], dtype=dtype
) # Adjusted for single channel

for i, (box, mask) in enumerate(zip(boxes, masks)):
# Validate the detection_filter argument
if detection_filter is not None:

if not callable(detection_filter):
raise ValueError("detection_filter must be callable.")

req_nargs = 6 if inspect.ismethod(detection_filter) else 5
if not len(inspect.signature(detection_filter).parameters) == req_nargs:
raise ValueError(
"detection_filter required args: "
"box, mask, logit, phrase, and index."
)

for i, (box, mask, logit, phrase) in enumerate(
zip(boxes, masks, logits, phrases)
):

# Convert tensor to numpy array if necessary and ensure it contains integers
if isinstance(mask, torch.Tensor):
mask = (
mask.cpu().numpy().astype(dtype)
) # If mask is on GPU, use .cpu() before .numpy()

# Apply the user-supplied filtering logic if provided
if detection_filter is not None:
if not detection_filter(box, mask, logit, phrase, i):
continue

mask_overlay += ((mask > 0) * (i + 1)).astype(
dtype
) # Assign a unique value for each mask
Expand Down
Loading