Skip to content

Commit

Permalink
Fix bugs
Browse files Browse the repository at this point in the history
1. Fix bugs under loading networks
2. separate weights initialization module
  • Loading branch information
suquark committed Jan 24, 2017
1 parent 242c4c6 commit 7b31a85
Show file tree
Hide file tree
Showing 11 changed files with 150 additions and 121 deletions.
5 changes: 5 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@

gen_layers.js

zws5.js

63 changes: 63 additions & 0 deletions conf.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@

// syntactic sugar function for getting default parameter values
function getopt(opt, field_name, default_value) {
if (typeof field_name === 'string') {
// case of single string
return (typeof opt[field_name] !== 'undefined') ? opt[field_name] : default_value;
} else {
// assume we are given a list of string instead
var ret = default_value;
for (var i = 0; i < field_name.length; i++) {
var f = field_name[i];
if (typeof opt[f] !== 'undefined') {
ret = opt[f]; // overwrite return value
}
}
return ret;
}
}

function load_opt(self, opt) {

// required is a list of string
opt.required.forEach(function(key) {
if (typeof opt[key] !== 'undefined') {
self[key] = opt[key];
} else {
console.error('cannot find necessary value of "' + key +'"');
}
});

opt.optional.forEach(function(pair) {
let v = pair.find(x => typeof x !== 'undefined')
pair.forEach(function(key) {
self[key] = v;
});
});

// pair is a list whose values should be same
opt.bind.forEach(function(pair) {
let v = pair.find(x => typeof opt[x] !== 'undefined')
pair.forEach(function(key) {
self[key] = v;
});
});



}

export {
randf,
randi,
randn,
zeros,
maxmin,
randperm,
weightedSample,
arrUnique,
arrContains,
getopt,
assert,
indexOfMax
};
File renamed without changes.
File renamed without changes.
File renamed without changes.
14 changes: 5 additions & 9 deletions layers/conv.js
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import { Layer } from 'layers/layer.js';
import {Vol} from 'vol.js';
import { Vol, createVector, createMatrix, getVolFromJSON } from 'vol.js';
import {getopt} from 'util.js';

// This file contains all layers that do dot products with input,
Expand Down Expand Up @@ -37,14 +37,10 @@ class ConvLayer extends Layer {
this.out_sx = Math.floor((this.in_sx + this.pad * 2 - this.sx) / this.stride + 1);
this.out_sy = Math.floor((this.in_sy + this.pad * 2 - this.sy) / this.stride + 1);


// initializations
let bias = getopt(opt, 'bias_pref', 0.0);
this.filters = [];
for (var i = 0; i < this.out_depth; i++) {
this.filters.push(new Vol(1, 1, this.in_depth));
}
this.biases = new Vol(1, 1, this.out_depth, bias);
this.filters = createMatrix(this.out_depth, this.in_depth);
this.biases = createVector(this.out_depth, bias);

// record updated values for updating
this.updated = this.filters.concat([this.biases]);
Expand Down Expand Up @@ -149,8 +145,8 @@ class ConvLayer extends Layer {
this.l1_decay_mul = getopt(json, 'l1_decay_mul', 0.0);
this.l2_decay_mul = getopt(json, 'l2_decay_mul', 1.0);

this.filters = json.filters.map(x => new Vol(0, 0, 0, 0).fromJSON(x));
this.biases = new Vol(0, 0, 0, 0).fromJSON(json.biases);
this.filters = json.filters.map(getVolFromJSON);
this.biases = getVolFromJSON(json.biases);
}

}
Expand Down
14 changes: 5 additions & 9 deletions layers/deconv.js
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import { Layer } from 'layers/layer.js'
import { Vol } from 'vol.js'
import { Vol, createVector, createMatrix, getVolFromJSON } from 'vol.js';
import {getopt} from 'util.js'

var get_deconv_outsize = function(size, k, s, p) {
Expand Down Expand Up @@ -45,12 +45,8 @@ class DeconvLayer extends Layer {

// initializations
let bias = getopt(opt, 'bias_pref', 0.0);
this.filters = [];
for (var i = 0; i < this.out_depth; i++)
{
this.filters.push(new Vol(1, 1, this.in_depth));
}
this.biases = new Vol(1, 1, this.out_depth, bias);
this.filters = createMatrix(this.out_depth, this.in_depth);
this.biases = createVector(this.out_depth, bias);

// record updated values for updating
this.updated = this.filters.concat([this.biases]);
Expand Down Expand Up @@ -123,8 +119,8 @@ class DeconvLayer extends Layer {
this.l1_decay_mul = getopt(json, 'l1_decay_mul', 0.0);
this.l2_decay_mul = getopt(json, 'l2_decay_mul', 1.0);

this.filters = json.filters.map(x => new Vol(0,0,0,0).fromJSON(x));
this.biases = new Vol(0,0,0,0).fromJSON(json.biases);
this.filters = json.filters.map(getVolFromJSON);
this.biases = getVolFromJSON(json.biases);
}

}
Expand Down
16 changes: 6 additions & 10 deletions layers/fullconn.js
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import { OutputLayer } from 'layers/layer.js';
import { Vol } from 'vol.js';
import { Vol, createVector, createMatrix, getVolFromJSON } from 'vol.js';
import {getopt} from 'util.js';

class FullyConnLayer extends OutputLayer {
Expand All @@ -15,14 +15,10 @@ class FullyConnLayer extends OutputLayer {
this.l2_decay_mul = getopt(opt, 'l2_decay_mul', 1.0);

// initializations
let bias = getopt(opt, 'bias_pref', 0.0);
this.num_inputs = opt.in_sx * opt.in_sy * opt.in_depth;
this.filters = [];
for (var i = 0; i < this.out_depth; i++)
{
this.filters.push(new Vol(1, 1, this.num_inputs));
}
this.biases = new Vol(1, 1, this.out_depth, bias);
let bias = getopt(opt, 'bias_pref', 0.0);
this.filters = createMatrix(this.out_depth, this.num_inputs);
this.biases = createVector(this.out_depth, bias);

// record updated values for updating
this.updated = this.filters.concat([this.biases]);
Expand Down Expand Up @@ -79,8 +75,8 @@ class FullyConnLayer extends OutputLayer {
this.l1_decay_mul = getopt(json, 'l1_decay_mul', 0.0);
this.l2_decay_mul = getopt(json, 'l2_decay_mul', 1.0);

this.filters = json.filters.map(x => new Vol(0,0,0,0).fromJSON(x));
this.biases = new Vol(0,0,0,0).fromJSON(json.biases);
this.filters = json.filters.map(getVolFromJSON);
this.biases = getVolFromJSON(json.biases);
}
}

Expand Down
64 changes: 0 additions & 64 deletions test.html

This file was deleted.

76 changes: 47 additions & 29 deletions vol.js
Original file line number Diff line number Diff line change
@@ -1,40 +1,37 @@
import { randn } from 'util.js';
import { get_norm_weights } from 'weights_init.js';

class Vol {
constructor(sx, sy, depth, c) {
// this is how you check if a variable is an array. Oh, Javascript :)
if(Object.prototype.toString.call(sx) === '[object Array]') {
var arr = sx;
this.w = sx.slice(); // copy content
// we were given a list in sx, assume 1D volume and fill it up
sx = 1;
sy = 1;
depth = arr.length;
depth = this.w.length;
}

// we were given dimensions of the vol
this.sx = sx;
this.sy = sy;
this.depth = depth;
let n = this.sx * this.sy * this.depth;

if (Object.prototype.toString.call(sx) === '[object Array]') {
this.w = arr.slice(); // copy content
} else {
if(typeof c === 'undefined') {
// weight normalization is done to equalize the output
// variance of every neuron, otherwise neurons with a lot
// of incoming connections have outputs of larger variance
let scale = Math.sqrt(1.0 / (sx*sy*depth));
this.w = new Array(n).fill(0.);
for (let i = 0; i < n; i++) this.w[i] = randn(0.0, scale);

this.shape = [this.sx, this.sy, this.depth];
this.size = this.sx * this.sy * this.depth;

if (typeof this.w === 'undefined') {
if (typeof c === 'undefined') {
this.w = get_norm_weights(this.size);
} else {
this.w = new Array(n).fill(c);
this.w = new Array(this.size).fill(c);
}
}
this.dw = new Array(this.w.length).fill(0.);

this.dw = this.zeros_like();
this.length = this.w.length;
}


get(x, y, d) {
let ix = ((this.sx * y) + x) * this.depth + d;
return this.w[ix];
Expand All @@ -44,20 +41,20 @@ class Vol {
this.w[ix] = v;
}
add(x, y, d, v) {
var ix = ((this.sx * y) + x) * this.depth + d;
this.w[ix] += v;
var ix = ((this.sx * y) + x) * this.depth + d;
this.w[ix] += v;
}
get_grad(x, y, d) {
var ix = ((this.sx * y) + x) * this.depth + d;
return this.dw[ix];
var ix = ((this.sx * y) + x) * this.depth + d;
return this.dw[ix];
}
set_grad(x, y, d, v) {
var ix = ((this.sx * y) + x) * this.depth + d;
this.dw[ix] = v;
var ix = ((this.sx * y) + x) * this.depth + d;
this.dw[ix] = v;
}
add_grad(x, y, d, v) {
var ix = ((this.sx * y) + x) * this.depth + d;
this.dw[ix] += v;
var ix = ((this.sx * y) + x) * this.depth + d;
this.dw[ix] += v;
}
max(limit=V.w.length) {
if (limit === V.w.length)
Expand All @@ -77,7 +74,7 @@ class Vol {
return V;
}
zeros_like() {
return new Array(w.length).fill(0.);
return new Array(this.size).fill(0.);
}
addFrom(V) { for(var k = 0; k < this.w.length; k++) { this.w[k] += V.w[k]; }}
addFromScaled(V, a) { for(var k = 0; k < this.w.length; k++) { this.w[k] += a * V.w[k]; }}
Expand All @@ -104,11 +101,32 @@ class Vol {
this.w = new Array(n).fill(0.);
this.dw = new Array(n).fill(0.);
// copy over the elements.
this.w[i] = json.w.slice();
this.length = w.length;
this.w = json.w.slice();

this.shape = [this.sx, this.sy, this.depth];
this.size = this.sx * this.sy * this.depth;
this.length = this.size;
// for map function
return this;
}
}

export { Vol };
function createVector(depth, bias) {
// bias == undefined will cause initialize
return new Vol(1, 1, depth, bias);
}

function createMatrix(m, n) {
// m * n Matrix, m: output, n: input
let filters = [];
for (let i = 0; i < m; i++) {
filters.push(new Vol(1, 1, n));
}
return filters;
}

function getVolFromJSON(json) {
return new Vol(0, 0, 0, 0).fromJSON(json);
}

export { Vol, createVector, createMatrix, getVolFromJSON };
19 changes: 19 additions & 0 deletions weights_init.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
import { randn } from 'util.js';

// weight normalization is done to equalize the output
// variance of every neuron, otherwise neurons with a lot
// of incoming connections have outputs of larger variance

function norm_weights(V) {
let scale = Math.sqrt(1.0 / V.size);
for (let i = 0; i < n; i++) V.w[i] = randn(0.0, scale);
}

function get_norm_weights(size) {
let scale = Math.sqrt(1.0 / size);
let w = new Array(size)
for (let i = 0; i < size; i++) w[i] = randn(0.0, scale);
return w;
}

export { norm_weights, get_norm_weights };

0 comments on commit 7b31a85

Please sign in to comment.