import os

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
# Build a TensorFlow Lite interpreter for the converted model and allocate its
# tensors up front, so inputs can be written and outputs read afterwards.
interpreter = tf.lite.Interpreter(model_path="TFLite_converted_model.tflite")
interpreter.allocate_tensors()

# Per-tensor metadata dicts (keys such as 'name', 'shape', 'dtype', 'index')
# describing the model's inputs and outputs; this script uses entry [0] of each.
input_details, output_details = (
    interpreter.get_input_details(),
    interpreter.get_output_details(),
)
input_shape, output_shape = input_details[0]['shape'], output_details[0]['shape']
# Print the name, shape, dtype and tensor index of the model's first input and
# first output. Both sections use the same layout: a header line followed by
# one "Field: value" line per attribute, then a blank separator line.
# (Previously the output section printed "name:" on the header line with
# lowercase field names, inconsistent with the input section.)
for section_title, tensor_info in (("Input details", input_details[0]),
                                   ("Output details", output_details[0])):
    print(section_title)
    print("Name:", tensor_info['name'])
    print("Shape:", tensor_info['shape'])
    print("Type:", tensor_info['dtype'])
    print("Index:", tensor_info['index'])
    print()
# Load a test image, preprocess it to match the model's declared input, run
# inference, and report the raw output plus its argmax index.
# Alternative test image (human class):
# image_path = '/usr/local/notebooks/datasets/horses_or_humans_dataset/horse-or-human/horse-or-human/validation/humans/valhuman01-00.png'
image_path = '/usr/local/notebooks/datasets/horses_or_humans_dataset/horse-or-human/horse-or-human/validation/horses/horse2-011.png'

# Resize to the spatial size the model declares instead of a hard-coded 224x224.
# NOTE(review): assumes an NHWC input layout (batch, height, width, channels),
# which is what img_to_array + expand_dims(axis=0) below produce — confirm for
# other models.
target_size = (int(input_shape[1]), int(input_shape[2]))
input_image = tf.keras.preprocessing.image.load_img(image_path, target_size=target_size)
input_image = tf.keras.preprocessing.image.img_to_array(input_image)
input_image = np.expand_dims(input_image, axis=0)
# Scale pixels to [0, 1]. input_image stays in this float form (it is also
# what gets displayed); the model gets its own correctly-typed copy below.
input_image = input_image / 255.0

# set_tensor rejects arrays whose dtype differs from the model input's dtype,
# so build the fed tensor in exactly that dtype. Integer inputs are quantized
# with the model's (scale, zero_point): q = real / scale + zero_point.
input_dtype = input_details[0]['dtype']
input_tensor = input_image
if np.issubdtype(input_dtype, np.integer):
    scale, zero_point = input_details[0]['quantization']
    if not scale:
        raise ValueError(
            "Integer model input has no quantization parameters; "
            "cannot map [0, 1] pixel values to %s" % np.dtype(input_dtype).name)
    dtype_range = np.iinfo(input_dtype)
    input_tensor = np.clip(np.round(input_tensor / scale + zero_point),
                           dtype_range.min, dtype_range.max)
input_tensor = input_tensor.astype(input_dtype)

# Set the input tensor and run inference.
interpreter.set_tensor(input_details[0]['index'], input_tensor)
interpreter.invoke()

# Read back the first output tensor.
output_data = interpreter.get_tensor(output_details[0]['index'])
print('\nOutput data:', output_data)
# Index of the highest-scoring entry over the flattened output.
predicted_label_index = np.argmax(output_data)
print('\npredicted label index:', predicted_label_index)
# Display the preprocessed image with its predicted label and save the figure.
# Class index convention: 0 = horse, 1 = human.
# NOTE(review): assumes a 2-class output; a larger argmax index would raise
# IndexError here.
CLASS_NAMES = ('Horse', 'Human')
# Derive the label from the argmax index computed above, so the saved label
# always agrees with the printed index.
predicted_label = CLASS_NAMES[int(predicted_label_index)]

# Drop the size-1 batch dimension for imshow (np.squeeze keeps it a NumPy array).
plt.imshow(np.squeeze(input_image))
plt.xticks([])
plt.yticks([])
plt.grid(False)
plt.xlabel('Predicted label: %s' % predicted_label)

# savefig does not create missing directories, so make sure 'output/' exists.
os.makedirs('output', exist_ok=True)
plt.savefig('output/test_image.png', dpi=300)
# Release the figure; this script does not display it interactively.
plt.close()