# VisualSemSeg / app.py
# Author: Nunzio — last commit 9213c5c ("modified weights, final commit")
import torch
import gradio as gr
from utils.imageHandling import hfImageToTensor, preprocessing, postprocessing, loadPreloadedImages, legendHandling
from model.modelLoading import loadModel
## %% CONSTANTS
# Directories holding the sample images shown in the "preloaded images" gallery.
gta_image_dir = "./preloadedImages/GTAV"
city_image_dir = "./preloadedImages/cityScapes"
turin_image_dir = "./preloadedImages/turin"
# Run on GPU when available, otherwise fall back to CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Map of model keys (upper-cased UI labels) to loaded, ready-to-run networks.
# The "-BASE" variants are intentionally disabled; only the best checkpoints ship.
MODELS = {
    # "BISENET-BASE": loadModel('bisenet', device, 'weight_Base.pth'),
    "BISENET-BEST": loadModel('bisenet', device, 'weight_Best.pth'),
    # "BISENETV2-BASE": loadModel('bisenetv2', device, 'weight_Base.pth'),
    "BISENETV2-BEST": loadModel('bisenetv2', device, 'weight_Best.pth')
}
# Flat list of gallery images gathered from the three directories above.
image_list = loadPreloadedImages(gta_image_dir, city_image_dir, turin_image_dir)
# %% prediction on an image
def predict(inputImage: torch.Tensor, model) -> torch.Tensor:
    """
    Predict the segmentation mask for the input image using the provided model.

    Args:
        inputImage (torch.Tensor): The input image tensor, in whatever layout
            ``preprocessing`` expects (assumed HWC uint8 — TODO confirm).
        model (BiSeNet): The BiSeNet model for segmentation, already on ``device``.

    Returns:
        torch.Tensor: The predicted class-index mask of shape (1, H, W),
        on the module-level ``device``.
    """
    with torch.no_grad():
        # Clone so that preprocessing cannot mutate the caller's tensor.
        output = model(preprocessing(inputImage.clone()).to(device))
        # BiSeNet variants may return auxiliary heads as a tuple/list;
        # keep only the main prediction head.
        output = output[0] if isinstance(output, (tuple, list)) else output
        # Take the first (only) batch element and collapse the class dimension.
        # The result already lives on `device`, so the original trailing
        # `.to(device)` was a redundant no-op transfer and has been removed.
        return output[0].argmax(dim=0, keepdim=True)
# %% Gradio interface
def run_prediction(image: gr.Image, selected_model: str) -> tuple[dict, dict]:
    """
    Validate the inputs, run the selected model and build the Gradio updates.

    Args:
        image (gr.Image): The uploaded PIL image, or None when nothing was uploaded.
        selected_model (str): Display name of the model; matched case-insensitively
            (after stripping whitespace) against the keys of ``MODELS``.

    Returns:
        tuple[dict, dict]: Gradio updates for (result image, error markdown).
        On failure the result image is hidden and an error message is shown;
        on success the colorized prediction is shown and the error is cleared.
    """
    hidden_result = gr.update(value=None, visible=False)
    if image is None:
        return (hidden_result, gr.update(value="❌ No image provided for prediction.", visible=True))
    if selected_model is None:
        return (hidden_result, gr.update(value="❌ No model selected for prediction.", visible=True))
    # Normalize once: UI labels differ from MODELS keys only by case/whitespace.
    model_key = selected_model.strip().upper() if isinstance(selected_model, str) else None
    if model_key not in MODELS:
        return (hidden_result, gr.update(value="❌ Invalid model selected.", visible=True))
    try:
        tensor = hfImageToTensor(image, width=1024, height=512)
        prediction = postprocessing(predict(tensor, MODELS[model_key]))
    except Exception as e:
        # UI boundary: surface any failure to the user instead of crashing the app.
        return (hidden_result, gr.update(value=f"❌ {str(e)}.", visible=True))
    return (gr.update(value=prediction, visible=True), gr.update(value="", visible=False))
# Gradio UI
# Page layout: header text, (input | output) row, class-color legend,
# clickable gallery of preloaded images, and the submit wiring.
with gr.Blocks(title="Semantic Segmentation Predictors") as demo:
    gr.Markdown("# Semantic Segmentation with Real-Time Networks")
    gr.Markdown('A small user interface created to run semantic segmentation on images using Cityscapes-like predictions and real-time segmentation networks.')
    gr.Markdown("Upload an image and choose your preferred model for segmentation, or otherwise use one of the preloaded images.")
    gr.Markdown("The full code for the project is available on [GitHub](https://github.com/Nuzz23/MLDL_SemanticSegmentation).")
    with gr.Row():
        # Left column: inputs (image upload, model choice, run button).
        with gr.Column():
            image_input = gr.Image(type="pil", label="Upload image")
            model_selector = gr.Radio(
                # Disabled choices mirror the commented-out entries in MODELS.
                choices=[ #"BiSeNet-Base",
                    "BiSeNet-Best",
                    # "BiSeNetV2-Base",
                    "BiSeNetV2-Best"],
                value="BiSeNet-Best",
                label="Select the real time segmentation model"
            )
            submit_btn = gr.Button("Run prediction")
        # Right column: outputs (prediction image, error message).
        with gr.Column():
            result_display = gr.Image(label="Model prediction", visible=True)
            error_text = gr.Markdown("", visible=False)
    gr.Markdown("The legend of the classes is the following (format **name** -> **color**)")
    with gr.Row():
        legend = legendHandling()
        # Render the legend two entries per row, as side-by-side columns.
        # Each legend entry is indexed as entry[1] = name, entry[3] = (r, g, b).
        for i in range(0, len(legend), 2):
            with gr.Row():
                with gr.Column(scale=1):
                    color_box0 = f"""<span style='display:inline-block; width:15px; height:15px;
                    background-color:rgb({legend[i][3][0]},{legend[i][3][1]},{legend[i][3][2]}); margin-left:6px; border:1px solid #000;'></span>"""
                    # NOTE(review): "β†’" looks like mojibake of "→" — confirm the arrow renders as intended.
                    gr.HTML(f"<div style='display:flex; align-items:center; margin-bottom:-10px; margin-top:-5px;'><b>{legend[i][1]}</b> β†’ {color_box0}</div>")
                with gr.Column(scale=1):
                    if i + 1 < len(legend):
                        color_box1 = f"""<span style='display:inline-block; width:15px; height:15px;
                        background-color:rgb({legend[i+1][3][0]},{legend[i+1][3][1]},{legend[i+1][3][2]}); margin-left:6px; border:1px solid #000;'></span>"""
                        gr.HTML(f"<div style='display:flex; align-items:center; margin-bottom:-10px; margin-top:-5px;'><b>{legend[i+1][1]}</b> β†’ {color_box1}</div>")
                    else:
                        # Odd number of classes: keep the grid aligned with an empty cell.
                        gr.Markdown("")
    with gr.Row():
        gr.Markdown("## Preloaded images to be used for testing the model")
        gr.Markdown("""You can use images taken from the Grand Theft Auto V video game, the Cityscapes dataset or
        even the city of Turin to be used as input for the model without need for manual upload.""")
    # Show the preloaded images in rows of 5.
    for i in range(0, len(image_list), 5):
        with gr.Row():
            for img in image_list[i:i+5]:
                img_comp = gr.Image(value=img, interactive=False, show_label=False, show_download_button=False, height=180, width=256,
                    show_fullscreen_button=False, show_share_button=False, mirror_webcam=False)
                # Clicking a thumbnail copies it into the upload slot (identity fn).
                img_comp.select(fn=lambda x:x, inputs=img_comp, outputs=image_input)
    # Wire the run button to the prediction handler.
    submit_btn.click(
        fn=run_prediction,
        inputs=[image_input, model_selector],
        outputs=[result_display, error_text],
    )
    gr.Markdown("Made by group 21 semantic segmentation project at Politecnico di Torino 2024/2025")
demo.launch()