Skip to content

Commit

Permalink
Add initial LSTM network layer
Browse files Browse the repository at this point in the history
  • Loading branch information
KevinCoble committed Jun 20, 2016
1 parent aaa74c1 commit c2eb1a2
Show file tree
Hide file tree
Showing 11 changed files with 1,452 additions and 592 deletions.
6 changes: 6 additions & 0 deletions AIToolbox.xcodeproj/project.pbxproj
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@
0E0E25FF1CA119160087932B /* PrincipalComponantAnalysis.swift in Sources */ = {isa = PBXBuildFile; fileRef = 0E0E25FD1CA119160087932B /* PrincipalComponantAnalysis.swift */; };
0E0E26011CA44EA80087932B /* PCATests.swift in Sources */ = {isa = PBXBuildFile; fileRef = 0E0E26001CA44EA80087932B /* PCATests.swift */; };
0E0E26021CA44EA80087932B /* PCATests.swift in Sources */ = {isa = PBXBuildFile; fileRef = 0E0E26001CA44EA80087932B /* PCATests.swift */; };
0E12025E1D0490AD0072257B /* LSTMTests.swift in Sources */ = {isa = PBXBuildFile; fileRef = 0E12025D1D0490AD0072257B /* LSTMTests.swift */; };
0E12025F1D0490AD0072257B /* LSTMTests.swift in Sources */ = {isa = PBXBuildFile; fileRef = 0E12025D1D0490AD0072257B /* LSTMTests.swift */; };
0E3CB5A01CCD30AE008ABA4E /* Validation.swift in Sources */ = {isa = PBXBuildFile; fileRef = 0E3CB59F1CCD30AE008ABA4E /* Validation.swift */; };
0E3CB5A11CCD30AE008ABA4E /* Validation.swift in Sources */ = {isa = PBXBuildFile; fileRef = 0E3CB59F1CCD30AE008ABA4E /* Validation.swift */; };
0E3CB5A31CCD30D3008ABA4E /* ValidationTests.swift in Sources */ = {isa = PBXBuildFile; fileRef = 0E3CB5A21CCD30D3008ABA4E /* ValidationTests.swift */; };
Expand Down Expand Up @@ -148,6 +150,7 @@
0E0A4FC21C153D8500AD5AAE /* Kernel.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = Kernel.swift; sourceTree = "<group>"; };
0E0E25FD1CA119160087932B /* PrincipalComponantAnalysis.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = PrincipalComponantAnalysis.swift; sourceTree = "<group>"; };
0E0E26001CA44EA80087932B /* PCATests.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = PCATests.swift; sourceTree = "<group>"; };
0E12025D1D0490AD0072257B /* LSTMTests.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = LSTMTests.swift; sourceTree = "<group>"; };
0E3CB59F1CCD30AE008ABA4E /* Validation.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = Validation.swift; sourceTree = "<group>"; };
0E3CB5A21CCD30D3008ABA4E /* ValidationTests.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = ValidationTests.swift; sourceTree = "<group>"; };
0E3CB5A71CDC3DEB008ABA4E /* RecurrentNeuralNetwork.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = RecurrentNeuralNetwork.swift; sourceTree = "<group>"; };
Expand Down Expand Up @@ -352,6 +355,7 @@
0E02FB9B1CC1B25400C32F9F /* NonLinearRegressionTests.swift */,
0E3CB5A21CCD30D3008ABA4E /* ValidationTests.swift */,
0E3CB5AD1CE6CA26008ABA4E /* LogisticRegressionTests.swift */,
0E12025D1D0490AD0072257B /* LSTMTests.swift */,
0EBA29BB1A91350C0012CEC9 /* Supporting Files */,
);
path = AIToolboxTests;
Expand Down Expand Up @@ -606,6 +610,7 @@
files = (
0E0A4FB61C1342EF00AD5AAE /* AlphaBetaTests.swift in Sources */,
0E0A4FB81C1342F500AD5AAE /* NeuralNetworkTests.swift in Sources */,
0E12025F1D0490AD0072257B /* LSTMTests.swift in Sources */,
0E02FB971CBF10E900C32F9F /* MixtureOfGaussianTests.swift in Sources */,
0E8C9F7E1C9BB0B000F88E34 /* KMeansTest.swift in Sources */,
0E3CB5A41CCD30D3008ABA4E /* ValidationTests.swift in Sources */,
Expand Down Expand Up @@ -662,6 +667,7 @@
files = (
0ECD88681A9C44BD00049F28 /* AlphaBetaTests.swift in Sources */,
0EA941AD1BD5F80E006BAECD /* NeuralNetworkTests.swift in Sources */,
0E12025E1D0490AD0072257B /* LSTMTests.swift in Sources */,
0E02FB961CBF10E900C32F9F /* MixtureOfGaussianTests.swift in Sources */,
0E8C9F7D1C9BB0B000F88E34 /* KMeansTest.swift in Sources */,
0E3CB5A31CCD30D3008ABA4E /* ValidationTests.swift in Sources */,
Expand Down
17 changes: 14 additions & 3 deletions AIToolbox/DataSet.swift
Original file line number Diff line number Diff line change
Expand Up @@ -256,11 +256,22 @@ public class DataSet {
// Validate the data
if (dataType != .Regression) { throw DataTypeError.DataWrongForType }
if (index < 0) { throw DataIndexError.Negative }
if (index > inputs.count) { throw DataIndexError.Negative }
if (index > inputs.count) { throw DataIndexError.IndexAboveDimension }
if (newOutput.count != outputDimension) { throw DataTypeError.WrongDimensionOnOutput }

// Add the new output item
outputs![index] = newOutput
// Make sure we have outputs up until this index (we have the inputs already)
if (index >= outputs!.count) {
while (index > outputs!.count) { // Insert any uncreated data between this index and existing values
outputs!.append([Double](count:outputDimension, repeatedValue: 0.0))
}
// Append the new data
outputs!.append(newOutput)
}

else {
// Replace the new output item
outputs![index] = newOutput
}
}

public func setClass(index: Int, newClass : Int) throws
Expand Down
49 changes: 31 additions & 18 deletions AIToolbox/KMeans.swift
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,8 @@ public class KMeans {
// 1. Choose one center uniformly at random from among the data points.
centroids = []
var pointIndex = Int(arc4random_uniform(UInt32(data.size)))
centroids.append(data.inputs[pointIndex])
let inputs = try data.getInput(pointIndex)
centroids.append(inputs)
data.classes![pointIndex] = 0

while (centroids.count < numClasses) {
Expand All @@ -63,7 +64,8 @@ public class KMeans {
var distanceSquared = 0.0
var minDistanceSquared = Double.infinity
for centroid in centroids {
vDSP_distancesqD(data.inputs[point], 1, centroid, 1, &distanceSquared, vDSP_Length(data.inputDimension))
let inputs = try data.getInput(point)
vDSP_distancesqD(inputs, 1, centroid, 1, &distanceSquared, vDSP_Length(data.inputDimension))
if (distanceSquared < minDistanceSquared) {
minDistanceSquared = distanceSquared
}
Expand All @@ -86,7 +88,8 @@ public class KMeans {
if (selectionDistance < totalDistanceToIndex) {break}
}
data.classes![pointIndex] = centroids.count
centroids.append(data.inputs[pointIndex])
let inputs = try data.getInput(pointIndex)
centroids.append(inputs)

// 4. Repeat Steps 2 and 3 until k centers have been chosen.
}
Expand All @@ -107,7 +110,8 @@ public class KMeans {
// Set the centroids to those point values
centroids = []
for classIndex in 0..<numClasses {
centroids.append(data.inputs[initialSet[classIndex]])
let inputs = try data.getInput(initialSet[classIndex])
centroids.append(inputs)
}
}

Expand All @@ -118,19 +122,23 @@ public class KMeans {
// Assign each point to the nearest mean
changedClass = false
for point in 0..<data.size {
var newClass = -1
var closestDistanceSquared = Double.infinity
for testClass in 0..<numClasses {
vDSP_distancesqD(data.inputs[point], 1, centroids[testClass], 1, &distanceSquared, vDSP_Length(data.inputDimension))
if (distanceSquared < closestDistanceSquared) {
newClass = testClass
closestDistanceSquared = distanceSquared
do {
let inputs = try data.getInput(point)
var newClass = -1
var closestDistanceSquared = Double.infinity
for testClass in 0..<numClasses {
vDSP_distancesqD(inputs, 1, centroids[testClass], 1, &distanceSquared, vDSP_Length(data.inputDimension))
if (distanceSquared < closestDistanceSquared) {
newClass = testClass
closestDistanceSquared = distanceSquared
}
}
if (newClass != data.classes![point]) {
data.classes![point] = newClass
changedClass = true
}
}
if (newClass != data.classes![point]) {
data.classes![point] = newClass
changedClass = true
}
catch { print("error indexing input array") }
}

// Move the centroid of each class to the mean of all the points assigned to the class
Expand All @@ -139,10 +147,15 @@ public class KMeans {
var startLoc = [Double](count: data.inputDimension, repeatedValue: 0.0)
centroids[testClass] = [Double](count: data.inputDimension, repeatedValue: 0.0)
for point in 0..<data.size {
if (data.classes![point] == testClass) {
vDSP_vaddD(data.inputs[point], 1, startLoc, 1, &startLoc, 1, vDSP_Length(data.inputDimension))
count += 1
do {
let inputs = try data.getInput(point)
let currentClass = try data.getClass(point)
if (currentClass == testClass) {
vDSP_vaddD(inputs, 1, startLoc, 1, &startLoc, 1, vDSP_Length(data.inputDimension))
count += 1
}
}
catch { print("error indexing class array") }
}
if (count > 0) {
var scale = 1.0 / Double(count)
Expand Down
10 changes: 8 additions & 2 deletions AIToolbox/Kernel.swift
Original file line number Diff line number Diff line change
Expand Up @@ -148,8 +148,14 @@ class Kernel {
func dotProduct(vector1Index : Int, _ vector2Index : Int) -> Double
{
var sum = 0.0
vDSP_dotprD(problemData.inputs[vector1Index], 1, problemData.inputs[vector2Index], 1, &sum, vDSP_Length(problemData.inputDimension))

do {
let vector1 = try problemData.getInput(vector1Index)
let vector2 = try problemData.getInput(vector2Index)
vDSP_dotprD(vector1, 1, vector2, 1, &sum, vDSP_Length(problemData.inputDimension))
}
catch {
print("invalid index in kernel dotProduct - \(vector1Index) or \(vector2Index)")
}
return sum
}

Expand Down
Loading

0 comments on commit c2eb1a2

Please sign in to comment.