In [20]:
def _scores_to_classes(scores):
    """Turn raw model outputs into integer class labels.

    Outputs whose last axis has more than one entry are reduced with
    ``argmax`` over that axis; otherwise the values are thresholded at 0.5
    and cast to ``int32``.
    """
    if scores.shape[-1] > 1:
        return scores.argmax(axis=-1)
    return (scores > 0.5).astype('int32')


def predict_classes(tf_model, tf_lite_interpreter, x, max_images=32):
    """Plot a batch of images titled with the TFLite model's predicted class.

    The same batch is run through ``tf_model`` and ``tf_lite_interpreter``.
    Each image gets the TFLite prediction as its title, drawn in blue when
    both models agree and in red when they differ.

    Args:
        tf_model: Model exposing ``predict(x)``; its output is converted to
            class labels the same way as the TFLite output.
        tf_lite_interpreter: TFLite interpreter. Assumed to have its tensors
            already allocated and an input shape matching ``x`` -- TODO
            confirm against the cell that creates it.
        x: Batch of images, indexable along the first axis; fed to both
            models and shown with ``plt.imshow``.
        max_images: Upper bound on the number of images drawn (default 32).

    Returns:
        None. The figure is created through ``plt`` as a side effect.
    """
    # Look the tensor indices up on the interpreter that was actually passed
    # in, instead of relying on notebook globals that may belong to another
    # interpreter (or not exist after a kernel restart).
    input_index = tf_lite_interpreter.get_input_details()[0]['index']
    output_index = tf_lite_interpreter.get_output_details()[0]['index']

    tf_lite_interpreter.set_tensor(input_index, x)
    tf_lite_interpreter.invoke()
    tflite_results = _scores_to_classes(
        tf_lite_interpreter.get_tensor(output_index))

    # `predict_classes` is not available on current Keras models; `predict`
    # followed by the same argmax/threshold conversion works on both old and
    # new versions. Use the `tf_model` argument, not a global `model`.
    tf_results = _scores_to_classes(tf_model.predict(x))

    # Never index past the end of the batch (the last batch of a dataset is
    # often smaller than the nominal batch size).
    n_show = min(len(x), max_images)
    n_cols = 6
    # Keep the 6x6 layout for the default case; grow only if more rows are
    # needed (ceil division without importing math).
    n_rows = max(6, -(-n_show // n_cols))

    plt.figure(figsize=(12, 10))
    plt.subplots_adjust(hspace=0.3)
    for n in range(n_show):
        plt.subplot(n_rows, n_cols, n + 1)
        plt.imshow(x[n])
        color = "blue" if tf_results[n] == tflite_results[n] else "red"
        plt.title(tflite_results[n], color=color)
        plt.axis('off')
    plt.suptitle("Model predictions (blue: same, red: different)")

In [21]:
# Compare the TensorFlow model and the TFLite interpreter on the image batch.
predict_classes(tf_model=model, tf_lite_interpreter=interpreter, x=image_batch)