#include <algorithm>
#include <cmath>
#include <iostream>
#include <vector>

#include "MnistFile.cpp"

using namespace std;
// Maximum number of incoming synapses (predecessor pointers) per node.
const int synapseNums = 800;

// One node of the network. Nodes are linked into a single circular list
// (see Chain) through `next`; incoming edges are stored in `pre`.
class Node {
public:
    // int weight = 1; // weight (currently unused)
    double value = 0;             // current value (input pixel or activation)
    Node *pre[synapseNums] = {};  // incoming edges; Network::setNull points unused slots at Chain::null
    Node *next = nullptr;         // next node in the circular chain
};
// Circular singly linked list holding every node of the network.
// Ring order: start (output layer) -> input (input layer) -> others (hidden
// nodes) ... -> end -> back to start.
class Chain {
public:
    Node *start = nullptr;        // first node of the output layer
    Node *input = nullptr;        // first node of the input layer
    Node *others = nullptr;       // first hidden node; hidden nodes are added/removed dynamically
    Node *end = nullptr;          // last node of the ring; end->next == start
    Node null;                    // sentinel that every unused or removed edge points at
    Node **outputLink = nullptr;  // array of pointers to the output-layer nodes
};
// Self-growing network whose nodes all live in one circular Chain.
// A Network owns every node in the ring plus the chain.outputLink array and
// frees them in ~Network(), so copying is disabled (a copy would double-delete).
class Network {
public:
    Chain chain;
    const int inputNodeNums;   // number of input nodes (input dimension)
    const int outputNodeNums;  // number of output nodes (output dimension)
    int othersNodeNums = 0;    // number of hidden ("others") nodes
    int inputIndexValue = -1;  // label of the currently loaded sample; -1 = none loaded yet
public:
    Network(int in, int out);
    ~Network();
    Network(const Network &) = delete;
    Network &operator=(const Network &) = delete;
    void init();  // build the initial node ring
    void setNull(Node *p);
    void setNodes(Node **p, int nums);
    void inputValue(vector<double> input, double index);
    double activate(double x);
    int softMax(double **x, int n);  // NOTE: returns the arg-max index; no normalization
    int retIndexOf(Node *p, Node *q);
    int retNullIndex(Node *p);
    void forward();
    double getIndexValue(int index);  // value of output node `index`
    int getIndex();                   // label of the currently loaded sample
    double getChange();               // forward pass, then the output value for the current label
    void newNode();
    void deleteNode(Node *p);
    void deleteNode2(Node *p);
    void newEdge(Node **p);
    void newEdge2(Node **ip, vector<double> &labels, vector<vector<double> >&images);
    void newOutputEdge();
    //void deleteedge();
    void train(int trainSampleNums, vector<double>&labels, vector<vector<double> > &images);
    void train2(int trainSampleNums, vector<double>&labels, vector<vector<double> > &images);
    int predictIndex();  // predicted label (index of the largest output value)
    //void writeDate();
    //void readData(string s);
    double evalStudyRate(int testSampleNums, vector<double> &labels, vector<vector<double> >&images);
};
// Marks every synapse slot of p as unused by pointing it at the shared
// chain.null sentinel.
void Network::setNull(Node *p) {
    Node *const sentinel = &chain.null;
    std::fill(p->pre, p->pre + synapseNums, sentinel);
}
// Builds the network. It prepares the null sentinel, creates the node ring via
// init(), then caches pointers to the output-layer nodes in chain.outputLink.
Network::Network(int in, int out) : inputNodeNums(in), outputNodeNums(out) {
    // Sentinel: no successor, value 0, and every synapse pointing at itself.
    Node &sentinel = chain.null;
    sentinel.value = 0;
    sentinel.next = nullptr;
    setNull(&sentinel);

    init();

    // The output layer is the first outputNodeNums nodes starting at chain.start.
    chain.outputLink = new Node*[outputNodeNums];
    Node *cursor = chain.start;
    for (int k = 0; k < outputNodeNums; ++k, cursor = cursor->next)
        chain.outputLink[k] = cursor;
}
// Frees every node in the ring (output, input and hidden layers) and the
// chain.outputLink array. chain.null is a data member, not heap-allocated.
// The previous loop stopped at chain.input and so leaked every input and
// hidden node.
Network::~Network() {
    cout << "~Network ..." << endl;
    Node *p = chain.start;
    if (p != nullptr) {
        // Break the ring first so the deletion walk below ends at nullptr
        // instead of coming back around to already-freed nodes.
        Node *last = p;
        while (last->next != chain.start)
            last = last->next;
        last->next = nullptr;
    }
    chain.start = chain.input = chain.others = chain.end = nullptr;
    while (p != nullptr) {
        Node *q = p->next;
        delete p;
        p = q;
    }
    delete [] chain.outputLink;
    chain.outputLink = nullptr;
}
// Appends `nums` freshly allocated nodes (synapses cleared to chain.null)
// after *p, advancing *p so it ends on the last node created.
void Network::setNodes(Node **p, int nums) {
    Node *&tail = *p;
    for (int made = 0; made < nums; ++made) {
        Node *fresh = new Node;
        setNull(fresh);
        tail->next = fresh;
        tail = fresh;
    }
}
void Network::init() {
cout << "init ..." << endl;
//将所有节点拉成链
//建立输出层
Node *p = new Node;
setNull(p);
chain.start = p;//p前节点
setNodes(&p, outputNodeNums - 1);
//建立输入层
setNodes(&p, 1);
chain.input = p;//p前节点
setNodes(&p, inputNodeNums - 1);
//连接others
setNodes(&p, 1);
othersNodeNums = 1;
//连接最后节点指针
chain.others = chain.end = p;
chain.end->next = chain.start;//循环链表
}
// Logistic sigmoid: 1 / (1 + e^(-x)).
double Network::activate(double x) {
    const double decay = exp(-x);
    return 1.0 / (decay + 1.0);
}
// Despite its name, this is an arg-max. It returns the index of the largest of
// the n values at *x (the first one on ties), or 0 when n <= 0. No softmax
// normalization is performed.
int Network::softMax(double **x, int n) {
    const double *first = *x;
    if (n <= 0)
        return 0;
    return static_cast<int>(std::max_element(first, first + n) - first);
}
// Loads one sample into the input layer and records its label.
//   input: one value per input node, stored unnormalized. If it has fewer than
//          inputNodeNums entries, the remaining input nodes are set to 0
//          instead of being read out of range.
//   index: the sample's label, stored (truncated to int) in inputIndexValue.
void Network::inputValue(vector<double> input, double index) {
    Node *p = chain.input;
    for (int i = 0; i < inputNodeNums; ++i) {
        p->value = static_cast<size_t>(i) < input.size() ? input[i] : 0.0;
        p = p->next;
    }
    inputIndexValue = static_cast<int>(index);
}
// Forward pass. Walks the ring from chain.others up to (not including)
// chain.input, i.e. every hidden node and then every output node. Each node
// gets value = activate(sum of the values of the nodes in its pre[] slots).
// Unused slots point at chain.null (value 0) and contribute nothing.
// The sum is now reset for each node; previously it carried over from node to
// node, so every activation included the inputs of all earlier nodes.
void Network::forward() {
    for (Node *p = chain.others; p != chain.input; p = p->next) {
        double sum = 0;
        for (int i = 0; i < synapseNums; ++i)
            sum += p->pre[i]->value;
        p->value = activate(sum);
    }
}
void Network::newNode() {
cout << "newNode ..." << endl;
Node *p = new Node;
setNull(p);
othersNodeNums++;
chain.end->next = p;
chain.end = p;
chain.end->next = chain.start;
}
// Removes hidden node p from the ring and frees it.
// Every edge of a hidden or output node that pointed at p is first redirected
// to chain.null. Nothing happens when p is null, when p is not a hidden node
// (chain.others .. chain.end), or when p is the only remaining hidden node.
// chain.others / chain.end are updated when p was the first / last hidden node.
void Network::deleteNode(Node *p) {
    if (p == nullptr || othersNodeNums <= 1)
        return;
    // Find p's predecessor. The hidden segment starts right after the last
    // input node, so begin there; this also allows chain.others itself to be removed.
    Node *q = chain.input;
    while (q->next != chain.others)
        q = q->next;
    for (;;) {
        if (q == chain.end)
            return;  // walked the whole hidden segment: p is not a hidden node
        if (q->next == p)
            break;
        q = q->next;
    }
    // Redirect edges into p (hidden and output nodes) to the null sentinel.
    for (Node *t = chain.others; t != chain.input; t = t->next) {
        for (int i = 0; i < synapseNums; ++i) {
            if (t->pre[i] == p)
                t->pre[i] = &chain.null;
        }
    }
    // Unlink p while keeping the chain's segment pointers valid.
    q->next = p->next;
    if (p == chain.others)
        chain.others = p->next;
    if (p == chain.end)
        chain.end = q;
    delete p;
    othersNodeNums--;
}
// "Soft" delete: zeroes p's current value but keeps p and all of its edges.
// NOTE(review): forward() recomputes the value of every hidden and output node,
// so for those nodes this only lasts until the next forward pass.
void Network::deleteNode2(Node *p) {
    p->value = 0.0;
}
double Network::getIndexValue(int index) {
return chain.outputLink[index]->value;
}
// Runs a forward pass and returns the output value for the current sample's label.
double Network::getChange() {
    forward();  // refresh all hidden and output values first
    const int label = getIndex();
    const double labelOutput = getIndexValue(label);
    return labelOutput;
}
int Network::getIndex() {
return this->inputIndexValue;
}
// Predicted label: the index of the output node with the largest value (the
// first one on ties), based on the values from the most recent forward pass.
// Uses std::vector instead of a variable-length array (not standard C++).
// The old memset cleared only outputNodeNums bytes rather than doubles, and
// <cstring> was not included.
int Network::predictIndex() {
    if (outputNodeNums <= 0)
        return 0;
    std::vector<double> scores(static_cast<size_t>(outputNodeNums), 0.0);
    Node *p = chain.start;  // output layer occupies the first outputNodeNums nodes
    for (double &score : scores) {
        score = p->value;
        p = p->next;
    }
    double *data = scores.data();
    return softMax(&data, outputNodeNums);
}
// Returns the first slot i with p->pre[i] == q (i.e. p has an edge from q),
// or -1 when there is none.
int Network::retIndexOf(Node *p, Node *q) {
    Node **const slots = p->pre;
    Node **const stop = slots + synapseNums;
    Node **const hit = std::find(slots, stop, q);
    return hit == stop ? -1 : static_cast<int>(hit - slots);
}
// Finds a free synapse slot of p, i.e. the first one still pointing at the
// chain.null sentinel. Returns -1 when every slot is in use.
int Network::retNullIndex(Node *p){
    return retIndexOf(p, &chain.null);
}
// Greedily wires incoming edges into *ip for the currently loaded sample.
// Candidate sources t start at chain.input and walk the ring (input layer,
// then hidden nodes), stopping before the output layer. For each candidate:
//   - (*ip)->pre[i] = t is set tentatively;
//   - the output node of the correct label is given an edge from *ip if it has
//     none and a free slot exists;
//   - the edge is kept only if that output's value increases; otherwise the
//     slot is restored to its previous target.
// The loop stops when all synapseNums slots are filled, the candidates run
// out, or the network predicts the correct label.
void Network::newEdge(Node **ip) {
    cout << "newEdge ..." << endl;
    Node *p = *ip;
    const int label = getIndex();
    // No sample loaded (label -1) or label outside the output layer: indexing
    // chain.outputLink with it would be out of bounds.
    if (p == nullptr || label < 0 || label >= outputNodeNums)
        return;
    Node *target = chain.outputLink[label];
    Node *t = chain.input;
    for (int i = 0; i < synapseNums;) {
        // Output value for the correct label before the tentative edge.
        const double old = getChange();
        Node *const saved = p->pre[i];
        p->pre[i] = t;
        // Make sure the correct output node reads from p.
        if (retIndexOf(target, p) == -1) {
            const int rn = retNullIndex(target);
            if (rn == -1)
                cerr << "retNullIndex is overfill!!\n";  // no free slot: do not write pre[-1]
            else
                target->pre[rn] = p;
        }
        // Output value for the correct label with the tentative edge.
        const double now = getChange();
        if (now > old) {
            ++i;                // keep the edge and move on to the next slot
        } else {
            p->pre[i] = saved;  // reject: restore the slot
            forward();          // refresh outputs so predictIndex() sees the restored net
        }
        t = t->next;
        if (t == chain.start)
            break;
        if (predictIndex() == label)
            break;              // the network now predicts the correct label
    }
}
//void
//void Network::newOutputEdge() {
// Node *p = chain.start;//输出层节点指针
// double old = 0;
// double now = 0;
// Node *q = chain.input;//输入层和中间层节点指针
//
// while (p != chain.input){
// int r = retNullIndex(p);//空闲指针
//
// p = p->next;
//
// }
//}
// Accuracy, in percent, over up to testSampleNums held-out samples starting
// at index 50000 of labels/images (the first 50000 are treated as the
// training portion).
// The request is clamped to the samples that actually exist, which avoids
// out-of-range indexing. Returns 0.0 when no held-out sample is available or
// testSampleNums <= 0, instead of dividing by zero.
// Side effect: leaves the last evaluated sample loaded in the network.
double Network::evalStudyRate(int testSampleNums, vector<double> &labels, vector<vector<double> >&images) {
    cout << "evalStudyRate : " ;
    const size_t testOffset = 50000;
    const size_t available = std::min(labels.size(), images.size());
    size_t count = testSampleNums > 0 ? static_cast<size_t>(testSampleNums) : 0;
    count = available > testOffset ? std::min(count, available - testOffset) : 0;
    if (count == 0)
        return 0.0;
    int rightNums = 0;
    for (size_t i = 0; i < count; ++i) {
        inputValue(images[testOffset + i], labels[testOffset + i]);
        forward();
        if (predictIndex() == inputIndexValue)
            rightNums++;
    }
    return static_cast<double>(rightNums) / count * 100.0;
}
// Variant of newEdge that scores each tentative edge by held-out accuracy
// (evalStudyRate over 100 samples) instead of one sample's output value.
// Candidate sources t walk the ring from chain.input up to (not including) the
// output layer. An edge is kept only if accuracy increases; otherwise the slot
// is restored to its previous target, consistent with newEdge().
// NOTE(review): evalStudyRate reloads the input layer and label, so the
// predictIndex() == getIndex() stop check refers to the last evaluated sample.
void Network::newEdge2(Node **ip, vector<double> &labels, vector<vector<double> >&images) {
    cout << "newEdge ..." << endl;
    Node *p = *ip;
    if (p == nullptr)
        return;
    Node *t = chain.input;
    for (int i = 0; i < synapseNums;) {
        // Accuracy before the tentative edge.
        const double old = evalStudyRate(100, labels, images);
        Node *const saved = p->pre[i];
        p->pre[i] = t;
        // Accuracy with the tentative edge.
        const double now = evalStudyRate(100, labels, images);
        if (now > old) {
            ++i;                // keep the edge and move on to the next slot
        } else {
            p->pre[i] = saved;  // reject: restore the slot
            forward();          // refresh outputs so predictIndex() sees the restored net
        }
        t = t->next;
        if (t == chain.start)
            break;
        if (predictIndex() == getIndex())
            break;
    }
}
//训练这个网络,得到网络结构
void Network::train(int trainSampleNums, vector<double>&labels, vector<vector<double> > &images) {
cout << "train ... " << endl;
//训练
for (int i = 0; i < trainSampleNums; ++i) {
//输入数据进入网络
inputValue(images[i], labels[i]);
//连接边
newEdge(&chain.end);
//新建节点
newNode();
cout << "othersNodeNums = " << othersNodeNums << endl;
if (othersNodeNums >= 50) {
cout << "if (othersNodeNums >= 8) stop" << endl;
break;
}
}
//评估学习率
cout << evalStudyRate(9000, labels, images) << "%" << endl;
}
//训练这个网络,得到网络结构
void Network::train2(int trainSampleNums, vector<double>&labels, vector<vector<double> > &images) {
cout << "train ... " << endl;
//训练
for (int i = 0; i < trainSampleNums; ++i) {
//输入数据进入网络
inputValue(images[i], labels[i]);
//连接边
newEdge(&chain.end);
//新建节点
newNode();
cout << "othersNodeNums = " << othersNodeNums << endl;
if (othersNodeNums >= 4) {
cout << "if (othersNodeNums >= 8) stop" << endl;
break;
}
}
// Node *p = chain.others;
// while (p != chain.input) {
// newEdge2(&p, labels, images);
// }
//评估学习率
cout << evalStudyRate(9000, labels, images) << "%" << endl;
}
int main(){
//读数据
cout << "read data..." << endl;
vector<double>labels;
read_Mnist_Label("train-labels.idx1-ubyte", labels);
vector<vector<double>> images;
read_Mnist_Images("train-images.idx3-ubyte", images);
Network network(784, 10);
network.train2(10, labels, images);
}