Skip to content

Commit

Permalink
Release first demo about regression
Browse files Browse the repository at this point in the history
1. Release first demo about regression
2. Add plot module for 1D function plotting
  • Loading branch information
suquark committed Feb 25, 2017
1 parent c0298b7 commit 2749da0
Show file tree
Hide file tree
Showing 11 changed files with 590 additions and 63 deletions.
23 changes: 23 additions & 0 deletions .eslintrc.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
{
"env": {
"browser": true,
"commonjs": true,
"es6": true,
"node": true
},
"parserOptions": {
"ecmaFeatures": {
"jsx": true
},
"sourceType": "module"
},
"rules": {
"no-const-assign": "warn",
"no-this-before-super": "warn",
"no-undef": "warn",
"no-unreachable": "warn",
"no-unused-vars": "warn",
"constructor-super": "warn",
"valid-typeof": "warn"
}
}
15 changes: 9 additions & 6 deletions backend/symbols.js
Original file line number Diff line number Diff line change
Expand Up @@ -208,8 +208,13 @@ function _saveDir(dir, maplist, buf) {
packet = { name: name, nodes: [] };
_saveDir(dir[name], packet.nodes, buf);
} else {
// We attach save() method to all objects...
packet = dir[name].save(buf);
// We attach __save__() method to all objects...
if (dir[name].__save__){
packet = dir[name].__save__(buf);
} else {
// objects
packet = { type:'json', content: dir[name] }
}
packet.name = name; // override name
}
maplist.push(packet);
Expand All @@ -222,15 +227,13 @@ Object.getPrototypeOf(Int8Array.prototype).__save__ = function(buf) {
// really ... hacking ways
buf.write(this.slice().buffer);
return {type: getNativeType(this), name: this.name, length: this.length};
}
};


ArrayBuffer.prototype.__save__ = function(buf) {
buf.write(this);
return { name: this.name, type:'ArrayBuffer', byteLength:this.byteLength };
}

Object.prototype.__save__ = function() { return { name: this.name, type:'json', content: this }; }
};

export {
globals,
Expand Down
43 changes: 27 additions & 16 deletions backend/tensor.js
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import { assert, checkClass, isArray } from 'util/assert.js';
import { assert, isArray } from 'util/assert.js';
import { zeros } from 'util/array.js';

class Tensor {
constructor(shape=[null], rawdata=undefined) {
constructor(shape = [null], rawdata = undefined) {
this._shape = shape;
this._size = 1;
let vacant = -1;
Expand All @@ -12,7 +12,7 @@ class Tensor {
assert(vacant <= 0, 'A tensor can have at most 1 dim of arbitary length.');
vacant = i;
} else {
assert(n > 0, 'Length of a specified dim should be positive.');
assert(n > 0, 'Length of a specified dim should be positive.');
assert(Math.floor(n) == n, 'Length of a specified dim should be an interger.');
this._size *= n;
}
Expand All @@ -37,6 +37,9 @@ class Tensor {

// turn into native
this.w = typeof rawdata !== 'undefined' ? Float32Array.from(rawdata) : zeros(this._size);

// TODO: optimize it out
this.dw = this.w.slice();
}

get buffer() {
Expand Down Expand Up @@ -77,19 +80,21 @@ class Tensor {
}

axis(idx) {
return idx >= 0 ? this.shape[idx] : this.shape[this.ndim + idx];
return idx >= 0 ? this.shape[idx] : this.shape[this.ndim + idx];
}

get max() {
let limit = this.size;
let amax = this.w[0], w = this.w;
let amax = this.w[0],
w = this.w;
for (let i = 1; i < limit; i++) {
if(w[i] > amax) amax = w[i];
if (w[i] > amax) amax = w[i];
}
return amax;
}

get softmax() {
let N = this.size;
let es = zeros(N);
this.softmax_a(es);
return es;
Expand All @@ -111,13 +116,12 @@ class Tensor {
for (let i = 0; i < N; i++) es[i] /= esum;
}

get max_index() {
let limit = this.size;
get max_index() {
let limit = this.size;
let amax = this.w[0];
let idx = 0;
for (let i = 1; i < limit; i++) {
if (this.w[i] > amax)
{
if (this.w[i] > amax) {
idx = i;
amax = this.w[i];
}
Expand All @@ -130,7 +134,7 @@ class Tensor {
this.dw[i] /= batch_size;
}
}

cloneAndZero() { return new Tensor(this.shape); }

clone() {
Expand All @@ -141,9 +145,16 @@ class Tensor {
return zeros(this.size);
}

save(buf) {
static fromNumber(value) {
return new Tensor([1], [value]);
}

/**
* Save tensor. Interface implement.
*/
__save__(buf) {
buf.write(this.w);
return {name: this.name, shape: this.shape, type: 'tensor'};
return { name: this.name, shape: this.shape, type: 'tensor' };
}

/**
Expand All @@ -161,7 +172,7 @@ class Tensor {
t.name = map.name;
return t;
}

}

class Vector extends Tensor {
Expand All @@ -171,10 +182,10 @@ class Vector extends Tensor {
}

class Placeholder {
constructor(shape=[null]) {
constructor(shape = [null]) {
this._shape = shape;
this._size = 1;
}
}

export { Tensor, Vector, Placeholder };
export { Tensor, Vector, Placeholder };
172 changes: 172 additions & 0 deletions demo/regression/demo.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,172 @@
import { Sequential } from 'topology/sequential.js';
import { Trainer, Batch } from 'trainer.js';

import { Tensor } from 'backend/tensor.js';
import { randf, sampleFunctionUniform } from 'util/random.js';
import { assert } from 'util/assert.js';

import { zip } from 'util/functools.js';

import { AvgWindow } from 'util/record.js';
import { Plot1D } from 'visualize/plot.js';

// d3 is loaded globally from the page; this demo cannot render without it.
assert(d3, "d3js is required for this demo");

/* constant */

// batch_size: minibatch size fed to the trainer.
// epoch: number of minibatch steps performed per animation frame.
const batch_size = 32,
epoch = 5;

/* net */

// Small MLP regressor: 1 scalar input -> three 20-unit hidden layers -> 1 scalar output.
var net = new Sequential();
net.makeLayers([
{ type: 'input', out_sx: 1, out_sy: 1, out_depth: 1 },
{ type: 'fc', num_neurons: 20 },
'lrelu',
{ type: 'fc', num_neurons: 20 },
'lrelu',
{ type: 'fc', num_neurons: 20 },
'sigmoid',
{ type: 'fc', num_neurons: 1 }
]);

// Adam + MSE loss with mild L2 weight decay; lr_decay disabled.
var trainer = new Trainer(net, {
learning_rate: 0.001,
lr_decay: 0,
method: 'adam',
loss: 'mse',
batch_size: batch_size,
l2_decay: 0.001
});


// Scalar view of the network: wraps x in a 1-element tensor, runs a forward
// pass, and unwraps the single output value. Used to draw the learned curve.
var net_f = x => net.forward(Tensor.fromNumber(x)).w[0];

/* input */

// Slider controlling how many sine components the target function has.
// Changing it regenerates both the target function and the training data.
var complexity = d3.select("#complexity").on("input", function() {
d3.select("#complexity-text").text(`complexity = ${+this.value}`);
reload();
}).node();

// Slider controlling the number of training points; changing it only
// resamples data from the current target function.
var traindata_size = d3.select('#traindata_size').on("input", function() {
d3.select("#traindata_size-text").text(`# of data points = ${+this.value}`);
gen_data(+this.value);
}).node();

/* control */

// Regenerate target function + data on demand.
d3.select("#reload").on("click", function() {
reload();
});


// Play/pause toggle: flips `conti` and swaps the material icon; when
// resuming, restarts the animation loop explicitly.
d3.select("#play_control").on("click", function() {
if (conti) {
d3.select("#play_icon").text("play_arrow");
conti = false;
} else {
d3.select("#play_icon").text("pause");
conti = true;
iterate();
}

});


/* plot */

// Fixed-size SVG canvas holding both curves and the sampled points.
var svg = d3.select('#testbox')
.append('svg')
.attr('width', '800')
.attr('height', '400');

// Plot over x in [-5, 5]; ground truth is drawn with higher resolution
// (450 samples) than the live regression curve (100 samples).
var plot = new Plot1D(svg, [-5, 5]);
var groundTruthPlot = plot.createPlot(450),
regressionPlot = plot.createPlot(100);

/* data */

// targetf: current random target function; batch: current training batch.
// Both are (re)assigned by gen_func / gen_data below.
var targetf, batch;

/**
 * Build a random target function as a sum of `vari` sine components and
 * draw it as the ground-truth curve.
 * @param {number} vari - number of sine components (the "complexity").
 */
function gen_func(vari = 1) {
    const amplitudes = [];
    const phases = [];
    const freqs = [];
    // Draw parameters interleaved per component (keeps the RNG call order
    // stable: amplitude, phase, frequency for each i).
    for (let i = 0; i < vari; i++) {
        amplitudes.push(randf(-1, 1));
        phases.push(randf(-2, 2));
        freqs.push(randf(-vari, vari));
    }

    targetf = function(x) {
        let acc = 0;
        for (let i = 0; i < vari; i++) {
            acc += amplitudes[i] * Math.sin(freqs[i] * x + phases[i]);
        }
        return acc;
    };

    plot.setSpan(targetf, 800, 400);
    groundTruthPlot.draw(targetf);
}


/**
 * Sample N training points uniformly from the current target function,
 * wrap them into a minibatch source, and scatter-plot them.
 * @param {number} N - number of training points to draw.
 */
function gen_data(N) {
    const pairs = sampleFunctionUniform(targetf, N, -5, 5);
    const [xs, ys] = pairs;

    // Each scalar sample becomes a 1-element tensor for the network.
    const inputs = Array.from(xs).map(Tensor.fromNumber);
    const targets = Array.from(ys).map(Tensor.fromNumber);
    batch = new Batch(inputs, targets, batch_size);

    // Scatter the raw (x, y) pairs and bump the regression curve's
    // resolution so it stays at least as dense as the data.
    const points = zip(pairs).map(([x, y]) => ({ x, y }));
    regressionPlot.accurate = Math.max(100, N * 2);
    plot.drawPoints(points);
}

/**
 * Re-read both sliders and rebuild the demo: new random target function,
 * then a fresh data sample from it.
 */
function reload() {
    const numComponents = complexity.value | 0;    // slider value -> int
    const numPoints = traindata_size.value | 0;
    gen_func(numComponents);
    gen_data(numPoints);
}

/* record */

// Text nodes showing average loss, average step time, and frame count.
var avlosstext = d3.select('#avloss'),
avtimetext = d3.select('#avtime'),
iterstext = d3.select('#iters');

// Sliding-window averages: loss over one window of training steps,
// timing over a 10x longer window for a smoother readout.
var loss_record = new AvgWindow(batch_size * epoch, 1);
var time_record = new AvgWindow(batch_size * epoch * 10, 1);


/* train */

// Number of animation frames (update() calls) completed so far.
var steps = 0;

/**
 * Run one animation frame of training: `epoch` minibatch steps, stats
 * bookkeeping, a throttled readout refresh, and a redraw of the learned
 * regression curve.
 */
function update() {
    ++steps;
    for (let iters = 0; iters < epoch; iters++) {
        const stats = trainer.trainBatch(...batch.nextBatch());
        loss_record.push(stats.loss);
        time_record.push(stats.batch_time);
    }
    // Throttle DOM text updates to every 5th frame.
    // Fix: use strict equality (===) instead of loose == per JS best
    // practice and this commit's own eslint setup.
    if (steps % 5 === 0) {
        iterstext.text(`${steps}`);
        avlosstext.text(`${loss_record.average.toExponential(3)}`);
        avtimetext.text(`${(time_record.average * epoch).toFixed(3)} ms`);
    }
    regressionPlot.draw(net_f);
}

/* running control */

// While true, iterate() keeps rescheduling itself via requestAnimationFrame;
// the play/pause button flips this flag.
var conti = true;

/**
 * Animation loop: train one frame, then reschedule itself for the next
 * frame as long as the demo is playing (`conti` is true).
 */
function iterate() {
    update();
    if (conti) {
        window.requestAnimationFrame(iterate);
    }
}

// Kick off: build the target function and training data, then start the
// train/render loop.
reload();
iterate();
Loading

0 comments on commit 2749da0

Please sign in to comment.