Spaces:
Sleeping
Sleeping
| import torch | |
| import gradio as gr | |
| from utils.imageHandling import hfImageToTensor, preprocessing, postprocessing, loadPreloadedImages, legendHandling | |
| from model.modelLoading import loadModel | |
## %% CONSTANTS
# Directories holding the sample images bundled with the app, one per source
# (GTA V game captures, the Cityscapes dataset, and photos of Turin).
gta_image_dir = "./preloadedImages/GTAV"
city_image_dir = "./preloadedImages/cityScapes"
turin_image_dir = "./preloadedImages/turin"
# Run inference on GPU when available, otherwise fall back to CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Pre-instantiated segmentation networks, keyed by the upper-cased UI choice
# (the radio selection is normalized with .strip().upper() before lookup).
# NOTE(review): models are loaded eagerly at import time — app startup cost.
MODELS = {
    # "BISENET-BASE": loadModel('bisenet', device, 'weight_Base.pth'),
    "BISENET-BEST": loadModel('bisenet', device, 'weight_Best.pth'),
    # "BISENETV2-BASE": loadModel('bisenetv2', device, 'weight_Base.pth'),
    "BISENETV2-BEST": loadModel('bisenetv2', device, 'weight_Best.pth')
}
# Flat list of preloaded gallery images gathered from the three directories.
image_list = loadPreloadedImages(gta_image_dir, city_image_dir, turin_image_dir)
| # %% prediction on an image | |
# %% prediction on an image
def predict(inputImage: torch.Tensor, model) -> torch.Tensor:
    """
    Predict the segmentation mask for the input image using the provided model.

    Args:
        inputImage (torch.Tensor): The input image tensor. It is cloned before
            preprocessing so the caller's tensor is never mutated.
        model: Segmentation network (e.g. BiSeNet/BiSeNetV2). The forward pass
            may return a tuple/list of heads, in which case the first element
            is taken as the main logits.

    Returns:
        torch.Tensor: Predicted class-index mask for the first batch element,
        with a kept channel dim (shape (1, H, W)), on `device`.
    """
    with torch.no_grad():
        output = model(preprocessing(inputImage.clone()).to(device))
        # BiSeNet variants return auxiliary heads alongside the main output;
        # keep only the main one.
        output = output[0] if isinstance(output, (tuple, list)) else output
        # Fix: dropped the original trailing .to(device) — the tensor is
        # already on `device` after the forward pass, so it was a no-op.
        return output[0].argmax(dim=0, keepdim=True)
| # %% Gradio interface | |
# %% Gradio interface
def run_prediction(image: gr.Image, selected_model: str) -> tuple[torch.Tensor]:
    """
    Validate the UI inputs, run segmentation, and return Gradio component updates.

    Args:
        image (gr.Image): Image provided through the Gradio UI (PIL image), or
            None when nothing was uploaded/selected.
        selected_model (str): Display name chosen in the radio selector;
            matched case-insensitively against the MODELS keys.

    Returns:
        tuple: (update for the result image component, update for the error
        markdown component). Exactly one of the two is made visible: the
        prediction on success, the error message otherwise.
    """
    # NOTE(review): the "β" prefix in the error strings looks like a
    # mis-encoded symbol (e.g. a cross-mark emoji) — confirm intended glyph.
    if image is None:
        # Fix: removed the f-prefix — these messages have no placeholders.
        return (gr.update(value=None, visible=False),
                gr.update(value="β No image provided for prediction.", visible=True))
    if selected_model is None:
        return (gr.update(value=None, visible=False),
                gr.update(value="β No model selected for prediction.", visible=True))
    # Normalize the selection once instead of re-deriving it for the lookup.
    key = selected_model.strip().upper() if isinstance(selected_model, str) else None
    if key is None or key not in MODELS:
        return (gr.update(value=None, visible=False),
                gr.update(value="β Invalid model selected.", visible=True))
    try:
        tensor = hfImageToTensor(image, width=1024, height=512)
        prediction = postprocessing(predict(tensor, MODELS[key]))
    except Exception as e:
        # Surface any failure (bad image, model error) to the user in the UI
        # instead of letting the callback crash.
        return (gr.update(value=None, visible=False),
                gr.update(value=f"β {str(e)}.", visible=True))
    return (gr.update(value=prediction, visible=True), gr.update(value="", visible=False))
| # Gradio UI | |
# Gradio UI — the layout is defined by the nesting of the Blocks/Row/Column
# context managers, so statement order here is structural, not incidental.
with gr.Blocks(title="Semantic Segmentation Predictors") as demo:
    gr.Markdown("# Semantic Segmentation with Real-Time Networks")
    gr.Markdown('A small user interface created to run semantic segmentation on images using Cityscapes-like predictions and real-time segmentation networks.')
    gr.Markdown("Upload an image and choose your preferred model for segmentation, or otherwise use one of the preloaded images.")
    gr.Markdown("The full code for the project is available on [GitHub](https://github.com/Nuzz23/MLDL_SemanticSegmentation).")
    # Left column: inputs (image + model choice + submit); right column: outputs.
    with gr.Row():
        with gr.Column():
            image_input = gr.Image(type="pil", label="Upload image")
            model_selector = gr.Radio(
                choices=[  # "BiSeNet-Base",
                    "BiSeNet-Best",
                    # "BiSeNetV2-Base",
                    "BiSeNetV2-Best"],
                value="BiSeNet-Best",
                label="Select the real time segmentation model"
            )
            submit_btn = gr.Button("Run prediction")
        with gr.Column():
            # Only one of these two is visible at a time (see run_prediction).
            result_display = gr.Image(label="Model prediction", visible=True)
            error_text = gr.Markdown("", visible=False)
    gr.Markdown("The legend of the classes is the following (format **name** -> **color**)")
    # Class legend rendered two entries per row as small colored swatches.
    # legendHandling() entries are indexed as [i][1] = class name and
    # [i][3] = (R, G, B) color tuple — presumably; verify against the helper.
    with gr.Row():
        legend = legendHandling()
        for i in range(0, len(legend), 2):
            with gr.Row():
                with gr.Column(scale=1):
                    color_box0 = f"""<span style='display:inline-block; width:15px; height:15px;
                        background-color:rgb({legend[i][3][0]},{legend[i][3][1]},{legend[i][3][2]}); margin-left:6px; border:1px solid #000;'></span>"""
                    gr.HTML(f"<div style='display:flex; align-items:center; margin-bottom:-10px; margin-top:-5px;'><b>{legend[i][1]}</b> β {color_box0}</div>")
                with gr.Column(scale=1):
                    # Guard against an odd-length legend: the last row may
                    # have only one entry, so pad with an empty cell.
                    if i + 1 < len(legend):
                        color_box1 = f"""<span style='display:inline-block; width:15px; height:15px;
                            background-color:rgb({legend[i+1][3][0]},{legend[i+1][3][1]},{legend[i+1][3][2]}); margin-left:6px; border:1px solid #000;'></span>"""
                        gr.HTML(f"<div style='display:flex; align-items:center; margin-bottom:-10px; margin-top:-5px;'><b>{legend[i+1][1]}</b> β {color_box1}</div>")
                    else:
                        gr.Markdown("")
    with gr.Row():
        gr.Markdown("## Preloaded images to be used for testing the model")
        gr.Markdown("""You can use images taken from the Grand Theft Auto V video game, the Cityscapes dataset or
        even the city of Turin to be used as input for the model without need for manual upload.""")
    # Show the preloaded gallery as rows of 5 images each.
    # (Translated from the original Italian comment.)
    for i in range(0, len(image_list), 5):
        with gr.Row():
            for img in image_list[i:i+5]:
                img_comp = gr.Image(value=img, interactive=False, show_label=False, show_download_button=False, height=180, width=256,
                    show_fullscreen_button=False, show_share_button=False, mirror_webcam=False)
                # Clicking a gallery thumbnail copies that image into the
                # upload slot (identity function: thumbnail -> image_input).
                img_comp.select(fn=lambda x: x, inputs=img_comp, outputs=image_input)
    # Wire the submit button to the prediction callback defined above.
    submit_btn.click(
        fn=run_prediction,
        inputs=[image_input, model_selector],
        outputs=[result_display, error_text],
    )
    gr.Markdown("Made by group 21 semantic segmentation project at Politecnico di Torino 2024/2025")
demo.launch()