Skip to content

Commit

Permalink
Save prediction result
Browse files Browse the repository at this point in the history
  • Loading branch information
giswqs committed Sep 15, 2024
1 parent b37bfed commit c38d149
Showing 1 changed file with 37 additions and 2 deletions.
39 changes: 37 additions & 2 deletions samgeo/samgeo2.py
Original file line number Diff line number Diff line change
Expand Up @@ -1090,6 +1090,7 @@ def set_video(
raise ValueError("Input video_path must be a string.")

self.video_path = output_dir
self._num_images = len(os.listdir(output_dir))
self.inference_state = self.predictor.init_state(video_path=output_dir)

def predict_video(
Expand All @@ -1105,6 +1106,27 @@ def predict_video(
output_dir (Optional[str]): The directory to save the output images. Defaults to None.
img_ext (str): The file extension for the output images. Defaults to "png".
"""

from PIL import Image

def save_image_from_dict(data, output_path="output_image.png"):
# Find the shape of the first array in the dictionary (assuming all arrays have the same shape)
array_shape = next(iter(data.values())).shape[1:]

# Initialize an empty array with the same shape as the arrays in the dictionary, filled with zeros
output_array = np.zeros(array_shape, dtype=np.uint8)

# Iterate over each key and array in the dictionary
for key, array in data.items():
# Assign the key value wherever the boolean array is True
output_array[array[0]] = key

# Convert the output array to a PIL image
image = Image.fromarray(output_array)

# Save the image
image.save(output_path)

prompts = self._convert_prompts(prompts)
predictor = self.predictor
inference_state = self.inference_state
Expand All @@ -1121,6 +1143,13 @@ def predict_video(
)

video_segments = {}
num_frames = self._num_images
num_digits = len(str(num_frames))

if output_dir is not None:
if not os.path.exists(output_dir):
os.makedirs(output_dir)

for out_frame_idx, out_obj_ids, out_mask_logits in predictor.propagate_in_video(
inference_state
):
Expand All @@ -1129,10 +1158,16 @@ def predict_video(
for i, out_obj_id in enumerate(out_obj_ids)
}

if output_dir is not None:
output_path = os.path.join(
output_dir, f"{str(out_frame_idx).zfill(num_digits)}.{img_ext}"
)
save_image_from_dict(video_segments[out_frame_idx], output_path)

self.video_segments = video_segments

if output_dir is not None:
self.save_video_segments(output_dir, img_ext)
# if output_dir is not None:
# self.save_video_segments(output_dir, img_ext)

def save_video_segments(self, output_dir: str, img_ext: str = "png") -> None:
"""Save the video segments to the output directory.
Expand Down

0 comments on commit c38d149

Please sign in to comment.