Last active
May 6, 2017 06:59
-
-
Save VictorVation/c92575e3190c1274bf4c28ee242c257e to your computer and use it in GitHub Desktop.
when you really don't want to do your work term report
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/* | |
* This is how I procrastinate doing my work term report. | |
* Too bad it doesn't even work :( | |
*/ | |
#include <math.h>

#include <cstdint>
#include <cstdlib>
#include <iostream>
#include <utility>
#include <vector>
// One scalar edge of the computation graph: the value that flows forward
// through it and the gradient that gates add into it on the way back.
class Wire {
 public:
  Wire(double value, double grad) : value_{value}, grad_{grad} {}

  double value_;  // value produced during the forward pass
  double grad_;   // gradient accumulated by gate backward() calls
};
class Multiply { | |
public: | |
Multiply(Wire u0, Wire u1) : | |
u0_(u0), | |
u1_(u1), | |
utop_(Wire(u0.value_ + u1.value_, 0)) {}; | |
Wire forward(Wire u0, Wire u1) { | |
u0_ = u0; | |
u1_ = u1; | |
utop_ = Wire(u0.value_ * u1.value_, 0); | |
return utop_; | |
} | |
void backward() { | |
u0_.grad_ += u1_.value_ * utop_.grad_; | |
u1_.grad_ += u0_.value_ * utop_.grad_; | |
} | |
Wire u0_; | |
Wire u1_; | |
Wire utop_; | |
}; | |
class Add { | |
public: | |
Add(Wire u0, Wire u1) : | |
u0_(u0), | |
u1_(u1), | |
utop_(Wire(u0.value_ + u1.value_, 0)) {}; | |
Wire forward(Wire u0, Wire u1) { | |
u0_ = u0; | |
u1_ = u1; | |
utop_ = Wire(u0.value_ + u1.value_, 0); | |
return utop_; | |
} | |
void backward() { | |
u0_.grad_ += u1_.value_ * utop_.grad_; | |
u1_.grad_ += u0_.value_ * utop_.grad_; | |
} | |
Wire u0_; | |
Wire u1_; | |
Wire utop_; | |
}; | |
// class Sigmoid { | |
// public: | |
// Wire forward(Wire u0) { | |
// u0_ = u0; | |
// utop_ = Wire(Sigmoid::sigmoidImpl(u0_.value_), 0); | |
// return utop_; | |
// } | |
// void backward() { | |
// double s = Sigmoid::sigmoidImpl(u0_.value_); | |
// u0_.grad_ += (s * (1 - s)) * utop_.grad_; | |
// } | |
// Wire u0_; | |
// Wire utop_; | |
// private: | |
// double sigmoidImpl(double x) { | |
//     return 1 / (1 + exp(-x));
// } | |
// }; | |
class Circuit { | |
public: | |
Circuit() : | |
mulg0_(Wire(0, 0), Wire(0, 0)), | |
mulg1_(Wire(0, 0), Wire(0, 0)), | |
addg0_(Wire(0, 0), Wire(0, 0)), | |
addg1_(Wire(0, 0), Wire(0, 0)), | |
ax_(Wire(0, 0)), | |
by_(Wire(0, 0)), | |
axpby_(Wire(0, 0)), | |
axpbypc_(Wire(0, 0)) {}; | |
Wire forward(Wire x, Wire y, Wire a, Wire b, Wire c) { | |
ax_ = mulg0_.forward(a, x); // a * x | |
by_ = mulg1_.forward(b, y); // b * y | |
axpby_ = addg0_.forward(ax_, by_); // a*x + b*y | |
axpbypc_ = addg1_.forward(axpby_, c); | |
return axpbypc_; | |
} | |
void backward(int gtop) { | |
axpbypc_.grad_ = gtop; | |
addg1_.backward(); | |
addg0_.backward(); | |
mulg1_.backward(); | |
mulg0_.backward(); | |
} | |
Multiply mulg0_; | |
Multiply mulg1_; | |
Add addg0_; | |
Add addg1_; | |
Wire ax_; | |
Wire by_; | |
Wire axpby_; | |
Wire axpbypc_; | |
}; | |
class SVM { | |
public: | |
SVM() : | |
a_(Wire(1, 0)), | |
b_(Wire(-2, 0)), | |
c_(Wire(-1, 0)), | |
out_(Wire(0, 0)), | |
circuit_(Circuit()) {} | |
Wire forward(Wire x, Wire y) { | |
out_ = circuit_.forward(x, y, a_, b_, c_); | |
return out_; | |
} | |
void backward(int label) { | |
a_.grad_ = 0; | |
b_.grad_ = 0; | |
c_.grad_ = 0; | |
// Feedback | |
double correction = 0; | |
if (label == 1 && out_.value_ < 1) { | |
correction = 1; | |
} | |
if (label == 1 && out_.value_ > -1) { | |
correction = -1; | |
} | |
// Regularization | |
a_.grad_ -= a_.value_; | |
b_.grad_ -= b_.value_; | |
} | |
void updateWeights() { | |
double step = 0.01; | |
a_.value_ += step * a_.grad_; | |
b_.value_ += step * b_.grad_; | |
c_.value_ += step * c_.grad_; | |
} | |
void learn(Wire x, Wire y, int label) { | |
// Set weight in all wires | |
forward(x, y); | |
// Set diff in all wires | |
backward(label); | |
// Propogate | |
updateWeights(); | |
} | |
Wire a_; | |
Wire b_; | |
Wire c_; | |
Wire out_; | |
Circuit circuit_; | |
}; | |
double evalTrainingAccuracy( | |
SVM svm, | |
std::vector<std::pair<double, double> > data, | |
std::vector<int> labels) { | |
double num_correct = 0; | |
for (auto i = 0; i < data.size(); ++i) { | |
auto x = Wire(data[i].first, 0.0); | |
auto y = Wire(data[i].second, 0.0); | |
auto groundTruth = labels[i]; | |
// see if the prediction matches the provided label | |
auto prediction = svm.forward(x, y).value_ > 0 ? 1 : -1; | |
if (prediction == groundTruth) { | |
num_correct++; | |
} | |
} | |
std::cout << "WITHIN: CORR: " << num_correct << " dataSize: " << data.size() << '\n'; | |
return num_correct / data.size(); | |
} | |
int main() { | |
std::vector<std::pair<double, double> > data; | |
std::vector<int> labels; | |
data.push_back(std::make_pair(1.2, 0.7)); labels.push_back(1); | |
data.push_back(std::make_pair(-0.3, -0.5)); labels.push_back(-1); | |
data.push_back(std::make_pair(3.0, 0.1)); labels.push_back(1); | |
data.push_back(std::make_pair(-0.1, -1.0)); labels.push_back(-1); | |
data.push_back(std::make_pair(-1.0, 1.1)); labels.push_back(-1); | |
data.push_back(std::make_pair(2.1, -3)); labels.push_back(1); | |
auto svm = SVM(); | |
// the learning loop | |
for (auto iter = 0; iter < 400; iter++) { | |
// pick a random data point | |
auto i = floor(rand() % data.size()); | |
auto x = Wire(data[i].first, 0.0); | |
auto y = Wire(data[i].second, 0.0); | |
auto label = labels[i]; | |
svm.learn(x, y, label); | |
if (iter % 25 == 0) { // every 10 iterations... | |
std::cout << "training accuracy at iter " << iter << ": " | |
<< evalTrainingAccuracy(svm, data, labels) << '\n'; | |
} | |
} | |
return 0; | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment