DEEPLAB_V3
The DEEPLAB_V3 node returns a segmentation mask from an input image. The input image is expected to be a DataContainer of an 'image' type.
The output is a DataContainer of an 'image' type with the same dimensions as the input image, but with the red, green, and blue channels replaced with the segmentation mask.

Params:

default : Image
    The input image.

Returns:

out : Image
Python Code
from flojoy import flojoy, run_in_venv, Image
@flojoy
@run_in_venv(
    pip_dependencies=[
        "torch==2.0.1",
        "torchvision==0.15.2",
        "Pillow",
        "numpy",
    ]
)
def DEEPLAB_V3(default: Image) -> Image:
    """The DEEPLAB_V3 node returns a segmentation mask from an input image.

    The input image is expected to be a DataContainer of an 'image' type.
    The output is a DataContainer of an 'image' type with the same dimensions
    as the input image, but with the red, green, and blue channels replaced
    with the segmentation mask.

    Parameters
    ----------
    default : Image
        The input image. Its alpha channel, if present, is discarded before
        the image is passed to the model.

    Returns
    -------
    Image
        An RGB image (alpha set to None) in which each pixel is colored by
        the class index that the DeepLabV3-ResNet50 model predicted for it.
    """
    # Imports live inside the function because the node body is executed in
    # its own virtual environment (see the run_in_venv decorator).
    import os

    import numpy as np
    import PIL.Image
    import torch
    import torchvision.transforms.functional as TF
    from torchvision import transforms

    from flojoy import Image
    from flojoy.utils import FLOJOY_CACHE_DIR

    # Parse input image: stack the per-channel planes along a new last axis.
    input_image = default
    r, g, b, a = input_image.r, input_image.g, input_image.b, input_image.a
    channels = (r, g, b, a) if a is not None else (r, g, b)
    nparray = np.stack(channels, axis=2)

    # Convert to a PIL image; convert("RGB") drops any alpha channel so the
    # model always receives 3-channel input.
    input_image = TF.to_pil_image(nparray).convert("RGB")

    # Keep downloaded hub code/weights inside the Flojoy cache directory.
    torch.hub.set_dir(os.path.join(FLOJOY_CACHE_DIR, "cache", "torch_hub"))
    # weights="DEFAULT" is the torchvision >= 0.13 replacement for the
    # deprecated pretrained=True flag and selects the same pretrained weights.
    model = torch.hub.load(
        "pytorch/vision:v0.15.2",
        "deeplabv3_resnet50",
        weights="DEFAULT",
        skip_validation=True,
    )
    model.eval()

    # Preprocessing: tensor conversion followed by per-channel normalization.
    preprocess_transform = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ]
    )

    # Feed the input image to the model as a batch of one.
    input_tensor = preprocess_transform(input_image)
    input_batch = input_tensor.unsqueeze(0)
    with torch.inference_mode():
        output = model(input_batch)["out"][0]

    # Per-pixel class index: argmax over the class dimension.
    output_predictions = output.argmax(0)

    # Build a deterministic color palette for the 21 class indices.
    palette = torch.tensor([2**25 - 1, 2**15 - 1, 2**21 - 1])
    colors = torch.arange(21)[:, None] * palette
    colors = (colors % 255).numpy().astype("uint8")

    # Plot the semantic segmentation predictions of 21 classes in each color.
    # The array holds class indices, so any resize must use NEAREST:
    # interpolating between indices would invent labels at region borders.
    label_map = PIL.Image.fromarray(output_predictions.byte().cpu().numpy())
    label_map = label_map.resize(input_image.size, resample=PIL.Image.NEAREST)
    label_map.putpalette(colors)
    out_img = np.array(label_map.convert("RGB"))

    # Build the output image (no alpha channel).
    return Image(
        r=out_img[:, :, 0],
        g=out_img[:, :, 1],
        b=out_img[:, :, 2],
        a=None,
    )
Example
Having problems with this example app? Join our Discord community and we will help you out!
In this example, the DEEPLAB_V3 node produces a segmentation mask from an input image provided by the LOCAL_FILE node.