import torch
import gradio as gr

from utils.imageHandling import hfImageToTensor, preprocessing, postprocessing, loadPreloadedImages, legendHandling
from model.modelLoading import loadModel

# %% CONSTANTS
gta_image_dir = "./preloadedImages/GTAV"
city_image_dir = "./preloadedImages/cityScapes"
turin_image_dir = "./preloadedImages/turin"

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Keys are the normalized (stripped + upper-cased) radio-button labels.
# Only the "Best" checkpoints are exposed; the "Base" ones are kept commented
# out for reference, mirroring the commented choices in the radio selector.
MODELS = {
    # "BISENET-BASE": loadModel('bisenet', device, 'weight_Base.pth'),
    "BISENET-BEST": loadModel('bisenet', device, 'weight_Best.pth'),
    # "BISENETV2-BASE": loadModel('bisenetv2', device, 'weight_Base.pth'),
    "BISENETV2-BEST": loadModel('bisenetv2', device, 'weight_Best.pth'),
}

image_list = loadPreloadedImages(gta_image_dir, city_image_dir, turin_image_dir)


# %% prediction on an image
def predict(inputImage: torch.Tensor, model) -> torch.Tensor:
    """
    Predict the segmentation mask for the input image using the provided model.

    Args:
        inputImage (torch.Tensor): The input image tensor.
        model (BiSeNet): The BiSeNet model for segmentation.

    Returns:
        prediction (torch.Tensor): The predicted segmentation mask as class
            indices (argmax over the channel dimension, keepdim=True).
    """
    with torch.no_grad():
        output = model(preprocessing(inputImage.clone()).to(device))
        # Some models return auxiliary heads as a tuple/list; keep the main output.
        output = output[0] if isinstance(output, (tuple, list)) else output
        return output[0].argmax(dim=0, keepdim=True).to(device)


# %% Gradio interface
def run_prediction(image, selected_model: str) -> tuple[dict, dict]:
    """
    Validate the UI inputs, run the selected model and return Gradio updates.

    Args:
        image: The image uploaded or selected in the UI (PIL image, or None
            when nothing was provided).
        selected_model (str): Label of the model chosen in the radio selector;
            normalized (strip + upper) before looking it up in ``MODELS``.

    Returns:
        tuple[dict, dict]: ``(result_display update, error_text update)`` —
        on success the prediction is shown and the error text hidden;
        on any failure the error message is shown instead.
    """
    if image is None:
        return (gr.update(value=None, visible=False),
                gr.update(value="❌ No image provided for prediction.", visible=True))
    if selected_model is None:
        return (gr.update(value=None, visible=False),
                gr.update(value="❌ No model selected for prediction.", visible=True))
    if not isinstance(selected_model, str) or selected_model.strip().upper() not in MODELS:
        return (gr.update(value=None, visible=False),
                gr.update(value="❌ Invalid model selected.", visible=True))

    try:
        image = hfImageToTensor(image, width=1024, height=512)
        prediction = predict(image, MODELS[selected_model.strip().upper()])
        prediction = postprocessing(prediction)
    except Exception as e:
        # Surface any failure (bad image, CUDA error, ...) in the UI instead of crashing.
        return (gr.update(value=None, visible=False),
                gr.update(value=f"❌ {str(e)}.", visible=True))

    return (gr.update(value=prediction, visible=True),
            gr.update(value="", visible=False))


def _color_swatch(color) -> str:
    """Return an inline-HTML colored square for one legend entry.

    NOTE(review): the original swatch markup was garbled in the source
    (empty triple-quoted f-strings); reconstructed as a small colored box.
    Assumes ``color`` (``legend[i][0]``) is a CSS-usable color value — confirm
    against ``legendHandling()``.
    """
    return (f'<div style="width:20px;height:20px;display:inline-block;'
            f'vertical-align:middle;border:1px solid #000;'
            f'background-color:{color};"></div>')


# Gradio UI
with gr.Blocks(title="Semantic Segmentation Predictors") as demo:
    gr.Markdown("# Semantic Segmentation with Real-Time Networks")
    gr.Markdown('A small user interface created to run semantic segmentation on images using Cityscapes-like predictions and real-time segmentation networks.')
    gr.Markdown("Upload an image and choose your preferred model for segmentation, or otherwise use one of the preloaded images.")
    gr.Markdown("The full code for the project is available on [GitHub](https://github.com/Nuzz23/MLDL_SemanticSegmentation).")

    with gr.Row():
        with gr.Column():
            image_input = gr.Image(type="pil", label="Upload image")
            model_selector = gr.Radio(
                choices=[
                    # "BiSeNet-Base",
                    "BiSeNet-Best",
                    # "BiSeNetV2-Base",
                    "BiSeNetV2-Best",
                ],
                value="BiSeNet-Best",
                label="Select the real time segmentation model",
            )
            submit_btn = gr.Button("Run prediction")
        with gr.Column():
            result_display = gr.Image(label="Model prediction", visible=True)
            error_text = gr.Markdown("", visible=False)

    gr.Markdown("The legend of the classes is the following (format **name** -> **color**)")

    # Legend entries laid out two per row; each entry is "name → colored box".
    legend = legendHandling()
    for i in range(0, len(legend), 2):
        with gr.Row():
            with gr.Column(scale=1):
                gr.HTML(f"{legend[i][1]} → {_color_swatch(legend[i][0])}")
            with gr.Column(scale=1):
                if i + 1 < len(legend):
                    gr.HTML(f"{legend[i + 1][1]} → {_color_swatch(legend[i + 1][0])}")
                else:
                    # Odd number of classes: keep the grid aligned with a filler.
                    gr.Markdown("")

    with gr.Row():
        gr.Markdown("## Preloaded images to be used for testing the model")
    gr.Markdown("""You can use images taken from the Grand Theft Auto V video game, the Cityscapes dataset or even the city of Turin to be used as input for the model without need for manual upload.""")

    # Show the preloaded images in rows of 5.
    for i in range(0, len(image_list), 5):
        with gr.Row():
            for img in image_list[i:i + 5]:
                img_comp = gr.Image(value=img, interactive=False, show_label=False,
                                    show_download_button=False, height=180, width=256,
                                    show_fullscreen_button=False, show_share_button=False,
                                    mirror_webcam=False)
                # Clicking a preloaded image copies it into the upload slot.
                # `inputs=img_comp` binds the component explicitly, so the
                # late-binding-lambda pitfall does not apply here.
                img_comp.select(fn=lambda x: x, inputs=img_comp, outputs=image_input)

    submit_btn.click(
        fn=run_prediction,
        inputs=[image_input, model_selector],
        outputs=[result_display, error_text],
    )

    gr.Markdown("Made by group 21 semantic segmentation project at Politecnico di Torino 2024/2025")

demo.launch()