A LeNet-5 implementation in Python 3 using Keras
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

main.py 3.2KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109
  1. from lenet5 import LeNet
  2. from keras.datasets import mnist
  3. from keras.optimizers import SGD
  4. from keras.utils import np_utils
  5. from keras import backend as K
  6. import numpy as np
  7. import argparse
  8. import cv2
  9. # parse arguments
  10. ap = argparse.ArgumentParser()
  11. ap.add_argument(
  12. "-s",
  13. "--save-model",
  14. type=int,
  15. default=-1,
  16. help="(optional) whether or not model should be saved to disk",
  17. )
  18. ap.add_argument(
  19. "-l",
  20. "--load-model",
  21. type=int,
  22. default=-1,
  23. help="(optional) whether or not pre-trained model should be loaded",
  24. )
  25. ap.add_argument("-w", "--weights", type=str, help="(optional) path to weights file")
  26. args = vars(ap.parse_args())
  27. # download MNIST
  28. print("[INFO] uwu downloading MNIST pwease wait...")
  29. ((trainData, trainLabels), (testData, testLabels)) = mnist.load_data()
  30. # check if "channels first" or "channels last" and reshape data accordingly
  31. if K.image_data_format() == "channels_first":
  32. trainData = trainData.reshape((trainData.shape[0], 1, 28, 28))
  33. testData = testData.reshape((testData.shape[0], 1, 28, 28))
  34. else:
  35. trainData = trainData.reshape((trainData.shape[0], 28, 28, 1))
  36. testData = testData.reshape((testData.shape[0], 28, 28, 1))
  37. # scale from [0, 255] to [0, 1]
  38. trainData = trainData.astype("float32") / 255.0
  39. testData = testData.astype("float32") / 255.0
  40. # turns labels into vectors
  41. trainLabels = np_utils.to_categorical(trainLabels, 10)
  42. testLabels = np_utils.to_categorical(testLabels, 10)
  43. # make model with optimizer and compile
  44. print("[INFO] OwO notices compiling of model")
  45. opt = SGD(lr=0.01) # optimizer
  46. model = LeNet.build(
  47. numChannels=1,
  48. imgRows=28,
  49. imgCols=28,
  50. numClasses=10,
  51. weightsPath=args["weights"] if args["load_model"] > 0 else None,
  52. )
  53. model.compile(loss="categorical_crossentropy", optimizer=opt, metrics=["accuracy"])
  54. # trains model if NOT loading a model
  55. if args["load_model"] < 0:
  56. print("[INFO] training... great time to listen to The Eye of the Tiger")
  57. model.fit(trainData, trainLabels, batch_size=128, epochs=20, verbose=1)
  58. print("[INFO] evaluating... twying vwey hawd")
  59. (loss, accuracy) = model.evaluate(testData, testLabels, batch_size=128, verbose=1)
  60. print("[INFO] accuracy: {:.2f}%".format(accuracy * 100))
  61. # checks to see if model should be saved
  62. if args["save_model"] > 0:
  63. print("[INFO] saving weights to file...")
  64. model.save_weights(args["weights"], overwrite=True)
  65. # test a few random digits
  66. for i in np.random.choice(np.arange(0, len(testLabels)), size=(10,)):
  67. probs = model.predict(testData[np.newaxis, i])
  68. prediction = probs.argmax(axis=1)
  69. if K.image_data_format() == "channels_first":
  70. image = (testData[i][0] * 255).astype("uint8")
  71. else:
  72. image = (testData[i] * 255).astype("uint8")
  73. # merge channels into a single image
  74. image = cv2.merge([image] * 3)
  75. image = cv2.resize(image, (96, 96), interpolation=cv2.INTER_LINEAR)
  76. # display output prediction
  77. cv2.putText(
  78. image,
  79. str(prediction[0]),
  80. (5, 20),
  81. cv2.FONT_HERSHEY_SIMPLEX,
  82. 0.75,
  83. (0, 255, 0),
  84. 2,
  85. )
  86. print(
  87. "[INFO] predicted {}, actual: {}".format(
  88. prediction[0], np.argmax(testLabels[i])
  89. )
  90. )
  91. cv2.imshow("digit", image)
  92. cv2.waitKey(0)