import numpy as np
from src import layer, net, util
import matplotlib.pyplot as plt
from nptyping import NDArray
# Load MNIST from ./dataset. The first call uses load_mnist's default split;
# the second requests the "t10k" split. Each call is unpacked as
# (labels, images), following the variable names used throughout this file.
train_label, train_img = util.load_mnist("./dataset")
test_label, test_img = util.load_mnist("./dataset", "t10k")

# Preview one training sample, unflattened into a 28x28 grid for display.
plt.imshow(train_img[1].reshape((28, 28)))
# Out: <matplotlib.image.AxesImage at 0x7f6946c778d0>
# Fully connected network. The list presumably gives layer widths:
# 784 (= 28 * 28 input pixels) -> 20 -> 20 -> 10 outputs — TODO confirm
# against net.neu_net.
nt = net.neu_net([784, 20, 20, 10])

# Sanity check: raw output of the untrained network for the first training image.
nt.get_simple_predict(train_img[0])
# Out: array([0.47027136, 0.35077137, 0.77519676, 0.34190864, 0.57000972,
#             0.40998121, 0.84363719, 0.84418194, 0.59020439, 0.27889296])
def train(rep_cnt: int, dset_img, dset_label, *, inner_steps: int = 5,
          max_rate: float = 0.04, model=None) -> None:
    """Train a network by repeated single-example updates on random samples.

    For each of ``rep_cnt`` iterations one example is drawn uniformly at
    random (with replacement) from the dataset. ``model.bp`` is then called on
    that example ``inner_steps`` times, each time with a fresh learning rate
    drawn uniformly from ``[0, max_rate)``. ``bp`` is presumably one
    backpropagation step of the network — confirm in ``src.net``.

    Args:
        rep_cnt: Number of random examples to draw. Values <= 0 do nothing.
        dset_img: Indexable collection of input images.
        dset_label: Labels aligned index-for-index with ``dset_img``.
        inner_steps: ``bp`` calls per drawn example (default 5, the original
            hard-coded value).
        max_rate: Exclusive upper bound of the random learning rate
            (default 0.04, the original hard-coded value).
        model: Object exposing ``bp(img, label, rate)``. Defaults to the
            module-level network ``nt``.

    Raises:
        ValueError: If ``dset_img`` and ``dset_label`` differ in length, or if
            training is requested on an empty dataset.
    """
    n_samples = len(dset_img)
    # A length mismatch would otherwise pair images with the wrong labels, or
    # fail with an IndexError at a random point mid-training.
    if n_samples != len(dset_label):
        raise ValueError(
            "dset_img and dset_label differ in length: {} vs {}".format(
                n_samples, len(dset_label)))
    if rep_cnt > 0 and n_samples == 0:
        raise ValueError("cannot train on an empty dataset")

    model = nt if model is None else model
    for _ in range(rep_cnt):
        # Sampling with replacement: some examples may repeat, others be skipped.
        idx = np.random.randint(n_samples)
        img, label = dset_img[idx], dset_label[idx]
        for _ in range(inner_steps):
            # The learning rate is re-randomized on every step.
            model.bp(img, label, max_rate * np.random.rand())
# Train on 114514 randomly drawn training samples.
train(114514, train_img, train_label)

# Spot-check test sample #1:
# print the predicted digit, display the image, then print the true digit.
print(nt.get_simple_predict(test_img[1]).argmax())
plt.imshow(test_img[1].reshape((28, 28)))
print(test_label[1].argmax())
# Out: 2 2  (predicted digit, then true label, for test sample #1)
def get_correct_rate(tset: np.ndarray, tlabel: np.ndarray, *, model=None,
                     verbose: bool = True) -> float:
    """Return the fraction of samples in ``tset`` that are classified correctly.

    A sample counts as correct when the argmax of the model's output equals
    the argmax of the matching entry of ``tlabel``. Labels appear to be
    one-hot/score vectors — TODO confirm against ``util.load_mnist``.

    Args:
        tset: Input samples; the first axis indexes samples.
        tlabel: Labels aligned index-for-index with ``tset``.
        model: Object exposing ``get_simple_predict(img)``. Defaults to the
            module-level network ``nt``.
        verbose: When True (the default, matching the original behaviour),
            print ``tset.shape``. Also, for every misclassified sample, print
            the true/predicted classes and display the image with matplotlib.

    Returns:
        ``correct / len(tset)`` as a float in ``[0, 1]``.

    Raises:
        ValueError: If ``tset`` is empty, or if its length differs from
            ``tlabel``'s.
    """
    total = tset.shape[0]
    # zip() would silently truncate to the shorter input, so a mismatch would
    # divide by the wrong total.
    if total != len(tlabel):
        raise ValueError(
            "tset and tlabel differ in length: {} vs {}".format(
                total, len(tlabel)))
    if total == 0:
        raise ValueError("cannot compute a correct rate on an empty set")

    model = nt if model is None else model
    if verbose:
        print(tset.shape)

    correct = 0
    for img, label in zip(tset, tlabel):
        pred = int(model.get_simple_predict(img).argmax())
        ans = int(label.argmax())
        if pred == ans:
            correct += 1
        elif verbose:
            print("wrong ans, correct: {}, predicted {}".format(ans, pred))
            # Display assumes 784-pixel (28x28) images, as used elsewhere here.
            plt.figure()
            plt.imshow(img.reshape(28, 28))
            plt.show()
    return correct / total
# Accuracy on the first 200 test samples (misclassified ones are printed and shown).
get_correct_rate(test_img[:200], test_label[:200])
# Out:
# (200, 784)
# wrong ans, correct: 5, predicted 6
# wrong ans, correct: 4, predicted 0
# wrong ans, correct: 2, predicted 3
# wrong ans, correct: 3, predicted 5
# wrong ans, correct: 7, predicted 1
# wrong ans, correct: 2, predicted 8
# wrong ans, correct: 7, predicted 9
# wrong ans, correct: 2, predicted 9
# 0.96  (192 of 200 correct)