def _probs_to_classes(probs):
    """Convert raw model output scores to integer class labels.

    Mirrors the removed Keras ``Sequential.predict_classes`` rule: argmax over
    the last axis when there are several output units, otherwise a 0.5
    threshold on the single unit (result keeps its trailing axis, as before).
    """
    if probs.shape[-1] > 1:
        return probs.argmax(axis=-1)
    return (probs > 0.5).astype('int32')


def predict_classes(tf_model, tf_lite_interpreter, x, images=None, max_images=32):
    """Compare class predictions of a Keras model and a TFLite interpreter on one batch.

    Runs ``x`` through both ``tf_model`` and ``tf_lite_interpreter``, converts
    both raw outputs to class labels with the same rule, and plots up to
    ``max_images`` images in a 6x6 grid. Each image is titled with the TFLite
    label, in blue when both models agree and red when they differ.

    Args:
        tf_model: Keras model; only its ``predict`` method is used.
        tf_lite_interpreter: TFLite interpreter. Presumably its tensors are
            already allocated, since ``allocate_tensors()`` is not called
            here (TODO confirm at call sites).
        x: Input batch fed to both models. It must match the shape and dtype
            of the interpreter's first input tensor.
        images: Images to display, indexed in step with ``x``. Defaults to
            ``x`` itself.
        max_images: Maximum number of images to plot. The count is also
            capped by the 36 grid slots and by the batch length.

    Returns:
        Tuple ``(tf_results, tflite_results)`` of predicted class labels.
    """
    # Query the interpreter we were given instead of relying on globals.
    input_index = tf_lite_interpreter.get_input_details()[0]['index']
    output_index = tf_lite_interpreter.get_output_details()[0]['index']
    tf_lite_interpreter.set_tensor(input_index, x)
    tf_lite_interpreter.invoke()
    tflite_results = _probs_to_classes(tf_lite_interpreter.get_tensor(output_index))

    # Model.predict_classes was removed in TF 2.6. It was predict() followed by
    # the same argmax/threshold rule, so apply that rule explicitly.
    tf_results = _probs_to_classes(tf_model.predict(x, verbose=0))

    if images is None:
        images = x
    # Never index past the batch or past the 6x6 = 36 subplot slots.
    count = min(max_images, 36, len(images))

    plt.figure(figsize=(12, 10))
    plt.subplots_adjust(hspace=0.3)
    for n in range(count):
        plt.subplot(6, 6, n + 1)
        plt.imshow(images[n])
        color = "blue" if tf_results[n] == tflite_results[n] else "red"
        plt.title(tflite_results[n], color=color)
        plt.axis('off')
    plt.suptitle("Model predictions (blue: same, red: different)")
    return tf_results, tflite_results