logo

classifier-plotimages-predictions

# TODO: Display an image along with the top 5 classes

# To visualize more than one result at a time, I added this function, which displays a grid of n results, each with its image and prediction.

import numpy as np
import seaborn as sb
import matplotlib.pyplot as plt
import pandas as pd

def plot_image_results(datasets, filter, count):
    """Show randomly chosen images beside bar plots of their predictions.

    Each sampled image gets one row of a two-column figure: the left cell
    holds a bar plot of the ``'probabilities'`` against the ``'classes'``
    returned by ``predict``, titled with the image's labelled flower name,
    and the right cell holds the image itself.

    Relies on names defined elsewhere in the file/notebook:
    ``flowers_to_name``, ``loaded_net``, ``predict``, ``process_image`` and
    ``imshow``.

    Args:
        datasets: Mapping from split name to a dataset exposing ``imgs``, a
            sequence whose items hold the image path at index 0
            (presumably torchvision ``ImageFolder`` objects -- TODO confirm).
        filter: Key of the split to sample from, e.g. ``'test'``.
        count: Number of images to sample. Indexes are drawn independently,
            so the same image may be picked more than once. Nothing is
            plotted when this is less than 1.
    """
    if count < 1:
        # A grid with no rows cannot be drawn, so there is nothing to show.
        return

    samples = datasets[filter].imgs

    # generate random n indexes to choose random testing images from dataset
    idx = np.random.randint(0, len(samples), size=(count,))
    print(idx)

    # Normalise Windows separators first so the folder lookup below works
    # no matter which separator the dataset paths were recorded with.
    batch_paths = [samples[x][0].replace('\\', '/') for x in idx]

    # the name of the folder containing each image is its class id
    batch_idx = [path.split('/')[-2] for path in batch_paths]
    print(batch_idx)
    print(batch_paths)

    # get actual flower name from the mapping back to the label
    labeled_names = [flowers_to_name[x] for x in batch_idx]
    print(labeled_names)

    # Keep (name, path) pairs in a list: a dict keyed by name would silently
    # drop images whenever two samples share the same flower name.
    data = list(zip(labeled_names, batch_paths))

    # one row per sample: bar plot on the left, image on the right
    rows = len(data)
    cols = 2
    fig, axs = plt.subplots(
        nrows=rows,
        ncols=cols,
        figsize=(cols * 4, rows * 3),
        squeeze=False,  # always a 2-D grid, even for a single row
    )

    for (name, path), (bar_ax, img_ax) in zip(data, axs):
        # get the predictions
        results_dict = predict(loaded_net, path, flowers_to_name)
        for k, v in results_dict.items():
            print('{}:{}'.format(k, v))
        print('flower is {}\n'.format(name))

        # barplot for the results
        bp = sb.barplot(
            x=results_dict['probabilities'],
            y=results_dict['classes'],
            ax=bar_ax,
        )
        bp.set_title(name)

        # plot the matching image
        img = process_image(path, 224, 224)
        imshow(img, img_ax)

    # Lay out only after everything is drawn so titles and tick labels are
    # taken into account.
    fig.tight_layout()
    plt.show()

# Show five randomly chosen images from the 'test' split with their predictions.
plot_image_results(datasets, filter='test', count=5)
  • Share