logo

classifier-predict

# helper methods

def predict(network, image_path, class_names, topk=5):
    ''' Predict the class (or classes) of an image using a trained deep learning model.

    The image is loaded through process_image, given a batch dimension of 1,
    run through network.model in eval mode with gradients disabled, and the
    topk highest-scoring entries are handed to map_idx_to_classes.

    Args:
        network (nn_model.Neural_Network): The Neural_Network instance to use for the prediction.
            Its `model` and `device` attributes are used; the model is moved to
            `device` and left in eval mode as a side effect.
        image_path (str): The path to the image we want to test
        class_names (dict of ints): The label map with the class names
        topk (int): The number of top probabilities and classes we want.

    Raises:
        Nothing is raised explicitly here. Errors from process_image, the model
        forward pass, Tensor.topk (presumably when topk exceeds the number of
        classes -- confirm) and map_idx_to_classes propagate unchanged.

    Returns:
        data_dict (dict): The dictionary built by map_idx_to_classes, documented as containing
                                                'predicted_indexes': indexes predicted by network,
                                                'idx_to_class': mapped idx_to_class,
                                                'classes': class names,
                                                'probabilites': the probabilities for classes.
    '''

    # convert the image file into a tensor (224 x 224 is requested from process_image)
    img = process_image(image_path, 224, 224)
    # the model expects a leading batch dimension; use the non-mutating
    # unsqueeze so the tensor returned by process_image is not modified in place
    img = img.unsqueeze(0)

    # make sure the model lives on the target device
    network.model.to(network.device)
    # enable eval mode, turn off dropout
    network.model.eval()
    # turn off the gradients since we are not updating params
    with torch.no_grad():
        # Tensor.to returns a new tensor, so the result must be rebound
        img = img.to(network.device, dtype=torch.float)
        # forward pass; the output is assumed to be log-probabilities
        # (e.g. a LogSoftmax final layer) -- TODO confirm against the model
        output = network.model(img)
        # exponentiate to turn log-probabilities into probabilities
        probabilities = torch.exp(output)
        # get the top k values and their indexes along the class dimension
        top_probabilities, top_classes = probabilities.topk(topk, dim=1)

    # Move to cpu and drop ONLY the batch dimension. A bare squeeze() would
    # also drop the class dimension when topk == 1, making tolist() return a
    # scalar instead of a one-element list.
    top_probabilities = top_probabilities.cpu().squeeze(0).tolist()
    top_classes = top_classes.cpu().squeeze(0).tolist()

    # generate the idx_to_class mapping dict
    data_dict = map_idx_to_classes(network.model.classifier.class_to_idx, class_names, top_classes, top_probabilities)

    return data_dict
  • Share