In [ ]:
import numpy as np
from src import layer, net, util
import matplotlib.pyplot as plt
from nptyping import NDArray

train_label, train_img = util.load_mnist("./dataset")
test_label, test_img = util.load_mnist("./dataset", "t10k")
plt.imshow(train_img[1].reshape(28, 28))
Out[ ]:
<matplotlib.image.AxesImage at 0x7f6946c778d0>
In [ ]:
nt = net.neu_net([28*28, 20, 20, 10])
In [ ]:
nt.get_simple_predict(train_img[0])
Out[ ]:
array([0.47027136, 0.35077137, 0.77519676, 0.34190864, 0.57000972,
       0.40998121, 0.84363719, 0.84418194, 0.59020439, 0.27889296])
In [ ]:
def train(rep_cnt: int, dset_img, dset_label):
    for i in range(rep_cnt):
        idx = np.random.randint(len(dset_img))
        for j in range(5):
            rate = 0.04 * np.random.rand() 
            nt.bp(dset_img[idx], dset_label[idx], rate)
In [ ]:
train(114514, train_img, train_label)
In [ ]:
print(nt.get_simple_predict(test_img[1]).argmax())
plt.imshow(test_img[1].reshape(28, 28))
print(test_label[1].argmax())
2
2
In [ ]:
def get_correct_rate(tset: NDArray, tlabel: NDArray):
    tot = tset.shape[0]
    correct = 0
    print(tset.shape)
    for img, label in zip(tset, tlabel):
        ret = nt.get_simple_predict(img)
        pred: int = ret.argmax()
        ans: int = label.argmax()
        if (pred == ans):
            correct += 1
        else:
            print("wrong ans, correct: {}, predicted {}".format(ans, pred))
            plt.figure()
            plt.imshow(img.reshape(28, 28))
            plt.show()
            
    return correct / tot
In [ ]:
get_correct_rate(test_img[0 : 200], test_label[0 : 200])
(200, 784)
wrong ans, correct: 5, predicted 6
wrong ans, correct: 4, predicted 0
wrong ans, correct: 2, predicted 3
wrong ans, correct: 3, predicted 5
wrong ans, correct: 7, predicted 1
wrong ans, correct: 2, predicted 8
wrong ans, correct: 7, predicted 9
wrong ans, correct: 2, predicted 9
Out[ ]:
0.96