DEEPLAB_V3

The DEEPLAB_V3 node returns a segmentation mask from an input image. The input image is expected to be a DataContainer of an 'image' type. The output is a DataContainer of an 'image' type with the same dimensions as the input image, but with the red, green, and blue channels replaced with the segmentation mask.

Parameters: none

Returns: out : Image
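
For orientation, the sketch below shows one way the returned container could be turned back into a single H x W x 3 array for display. It assumes only what the code further down establishes, namely that the flojoy Image type exposes per-channel r, g, and b arrays; the helper name is illustrative.

import numpy as np

def mask_to_array(result) -> np.ndarray:
    # 'result' stands for the Image returned by DEEPLAB_V3. Stacking the
    # red, green, and blue channels gives one H x W x 3 array that holds
    # the colour-coded segmentation mask.
    return np.stack((result.r, result.g, result.b), axis=2)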
Python Code
from flojoy import flojoy, run_in_venv, Image


@flojoy
@run_in_venv(
    pip_dependencies=[
        "torch==2.0.1",
        "torchvision==0.15.2",
        "Pillow",
        "numpy",
    ]
)
def DEEPLAB_V3(default: Image) -> Image:
    """The DEEPLAB_V3 node returns a segmentation mask from an input image in a dataframe.

    The input image is expected to be a DataContainer of an 'image' type.

    The output is a DataContainer of an 'image' type with the same dimensions as the input image, but with the red, green, and blue channels replaced with the segmentation mask.

    Returns
    -------
    Image
    """

    import os
    import numpy as np
    import PIL.Image
    import torch
    from torchvision import transforms
    import torchvision.transforms.functional as TF
    from flojoy import Image
    from flojoy.utils import FLOJOY_CACHE_DIR

    # Parse input image
    input_image = default
    r, g, b, a = input_image.r, input_image.g, input_image.b, input_image.a
    nparray = (
        np.stack((r, g, b, a), axis=2) if a is not None else np.stack((r, g, b), axis=2)
    )
    # Convert input image
    input_image = TF.to_pil_image(nparray).convert("RGB")
    # Set torch hub cache directory
    torch.hub.set_dir(os.path.join(FLOJOY_CACHE_DIR, "cache", "torch_hub"))
    model = torch.hub.load(
        "pytorch/vision:v0.15.2",
        "deeplabv3_resnet50",
        pretrained=True,
        skip_validation=True,
    )
    model.eval()
    # Preprocessing
    preprocess_transform = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ]
    )
    # Feed the input image to the model
    input_tensor = preprocess_transform(input_image)
    input_batch = input_tensor.unsqueeze(0)
    with torch.inference_mode():
        output = model(input_batch)["out"][0]
    # Fetch the output
    output_predictions = output.argmax(0)
    palette = torch.tensor([2**25 - 1, 2**15 - 1, 2**21 - 1])
    colors = torch.as_tensor([i for i in range(21)])[:, None] * palette
    colors = (colors % 255).numpy().astype("uint8")
    # plot the semantic segmentation predictions of 21 classes in each color
    r = PIL.Image.fromarray(output_predictions.byte().cpu().numpy()).resize(
        input_image.size
    )
    r.putpalette(colors)
    out_img = np.array(r.convert("RGB"))
    # Build the output image
    return Image(
        r=out_img[:, :, 0],
        g=out_img[:, :, 1],
        b=out_img[:, :, 2],
        a=None,
    )
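
The colour-mapping step near the end of this code gives each of the 21 class indices predicted by the model a distinct RGB triple, by multiplying the class index with a fixed palette and taking the result modulo 255. Below is a small, self-contained sketch of that trick, using a fake prediction map in place of real model output:

import numpy as np
import PIL.Image
import torch

# Fake 4x4 map of class indices 0..20, standing in for output.argmax(0)
preds = torch.randint(0, 21, (4, 4))

palette = torch.tensor([2**25 - 1, 2**15 - 1, 2**21 - 1])
colors = (torch.arange(21)[:, None] * palette % 255).numpy().astype("uint8")

mask = PIL.Image.fromarray(preds.byte().numpy())
mask.putpalette(colors)  # attach one RGB entry per class index
rgb = np.array(mask.convert("RGB"))
print(rgb.shape)  # (4, 4, 3)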

Find this Flojoy Block on GitHub

Example


In this example, the DEEPLAB_V3 node produces a segmentation mask from an input image loaded by the LOCAL_FILE node.
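
Outside of a Flojoy app, the same input can be prepared by hand. The sketch below is a hypothetical stand-in for LOCAL_FILE (the file path and variable names are illustrative): it loads a picture from disk and wraps it in the Image container that DEEPLAB_V3 expects on its default input.

import numpy as np
import PIL.Image
from flojoy import Image

pixels = np.array(PIL.Image.open("photo.png").convert("RGB"))  # illustrative path

input_image = Image(
    r=pixels[:, :, 0],
    g=pixels[:, :, 1],
    b=pixels[:, :, 2],
    a=None,
)
# Wired into DEEPLAB_V3, this container yields an Image of the same size
# whose r, g, and b channels hold the coloured segmentation mask.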