How to Implement Mahout RandomForest Drivers - Large-Scale Distributed Machine Learning and Classification -
Apache Mahout is a library for large-scale distributed data mining and machine learning that runs on Hadoop.
Random Forest is an algorithm that achieves highly accurate classification on large-scale data.
To make Random Forest trainable and usable for classification in a large-scale distributed setting "as easily as running it in R",
I implemented several Drivers using Mahout.
Below, I describe how to run them and how they are implemented.
- org.mahoutjp.df.ForestDriver
- A Driver that runs everything from distributed training of a Random Forest through distributed classification, output of the prediction results, and accuracy evaluation.
- org.mahoutjp.df.ForestClassificationDriver
- A Driver that uses a built Forest Model to run distributed classification, output of the prediction results, and accuracy evaluation.
Both Drivers are implemented so that distributed execution can be launched easily with a single command.
In addition to command-line execution, the public static methods ForestDriver.run(...) and ForestClassificationDriver.run(...) are provided, so the Drivers can also be invoked simply from code.
The implementation also resolves the various issues in Mahout's example source code, such as prediction output that could not be read.
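For example, a full training-plus-classification run can be launched from your own Java code roughly as follows. This is a minimal sketch: the wrapper class name is hypothetical, error handling is omitted, and the argument values are simply those used in the command-line example later in this article.

import java.util.Arrays;
import java.util.List;
import org.apache.hadoop.conf.Configuration;
import org.mahoutjp.df.ForestDriver;

public class ForestDriverCallExample {
    public static void main(String[] args) throws Exception {
        Configuration conf = new Configuration();
        // Data descriptor tokens: N = Numerical, C = Categorical, L = Label
        List<String> descriptor = Arrays.asList(
                "N", "4", "C", "2", "N", "C", "4", "N", "C", "8", "N", "2", "C", "19", "N", "L");
        ForestDriver.run(conf,
                "testdata/kdd/KDDTrain.arff",  // training data path
                descriptor,                    // data descriptor
                "testdata/kdd/KDD.info",       // generated descriptor file path
                "testdata/kdd/forest",         // Decision Forest output path
                7,                             // m: variables selected at each tree-node
                0L,                            // seed for the random number generator
                500,                           // number of trees to grow
                true,                          // use the Partial Data implementation
                true,                          // estimate the out-of-bag error
                "testdata/kdd/KDDTest",        // test data path
                "testdata/kdd/predictions");   // prediction output path
    }
}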
Introduction to Random Forest
Random Forest is an algorithm that classifies even large-scale data with higher accuracy and speed than other classifiers.
Introductory material:
org.mahoutjp.df.ForestDriver: Distributed Training and Classification
How to Run
This Driver runs everything from distributed training of the Forest Model through distributed classification, output of the prediction results, and accuracy evaluation.
$HADOOP_HOME/bin/hadoop jar mahoutjp-0.1-job.jar \
  org.mahoutjp.df.ForestDriver \
  - ... (Options)
Execution Options
The following options can be specified.
Options
--data (-d) path Data path
--descriptor (-dsc) descriptor [descriptor ...] data descriptor (N:Numerical, C:Categorical, L:Label)
--descriptor_out (-ds) file path Path to generated descriptor file
--selection (-sl) m Number of variables to select randomly at each tree-node
--nbtrees (-t) nbtrees Number of trees to grow
--forest_out (-fo) dir path Output path of Decision Forest
-oob Optional, estimate the out-of-bag error
--seed (-sd) seed Optional, seed value used to initialise the Random number generator
--partial (-p) Optional, use the Partial Data implementation
--testdata (-td) dir path Test data path
--predout (-po) dir path Path to generated prediction output file
--help (-h) Print out help
Data Format
Input format:
dataid, attribute1, attribute2, ... (comma-separated)
Output format:
dataid, predictionIndex, realLabelIndex (tab-separated)
Execution Example
$HADOOP_HOME/bin/hadoop jar mahoutjp-0.1-job.jar \
  org.mahoutjp.df.ForestDriver \
  -Dmapred.max.split.size=5074231 \
  -d testdata/kdd/KDDTrain.arff \
  -ds testdata/kdd/KDD.info \
  -fo testdata/kdd/forest \
  -dsc N 4 C 2 N C 4 N C 8 N 2 C 19 N L \
  -oob \
  -sl 7 \
  -p \
  -t 500 \
  -td testdata/kdd/KDDTest \
  -po testdata/kdd/predictions
※ Shorthand notation for the Data Descriptor: a count followed by a type can be used as an abbreviation. For example, "N N N I N N C C L I I I I I" can be written as "3 N I N N 2 C L 5 I".
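As a quick check, the shorthand can be expanded with the same Mahout utility that ForestDriver itself uses (see the source below). A minimal sketch, assuming DescriptorUtils.generateDescriptor accepts the token list the same way it does in that source:

import java.util.Arrays;
import org.apache.mahout.df.data.DescriptorUtils;

public class DescriptorExample {
    public static void main(String[] args) throws Exception {
        // Shorthand tokens; should expand to the long form "N N N I N N C C L I I I I I"
        String descriptor = DescriptorUtils.generateDescriptor(
                Arrays.asList("3", "N", "I", "N", "N", "2", "C", "L", "5", "I"));
        System.out.println(descriptor);
    }
}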
Data:
NSL-KDD data
The data-definition lines (lines starting with @) were removed, and a dataid was added as the first column.
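For reference, that preprocessing could be done with a small helper along the following lines. This is a hypothetical sketch, not part of the published Drivers: it assumes a local ARFF file, drops the @ lines, and prepends a running dataid starting at 1000001 (for the test data you would start at 2000001).

import java.io.BufferedReader;
import java.io.FileReader;
import java.io.PrintWriter;

public class ArffToCsv {
    public static void main(String[] args) throws Exception {
        // args[0]: input ARFF file, args[1]: output CSV file
        try (BufferedReader in = new BufferedReader(new FileReader(args[0]));
             PrintWriter out = new PrintWriter(args[1])) {
            long dataId = 1000001L;  // running id added as the first column
            String line;
            while ((line = in.readLine()) != null) {
                line = line.trim();
                if (line.isEmpty() || line.startsWith("@")) {
                    continue;  // skip @relation / @attribute / @data definition lines
                }
                out.println(dataId++ + "," + line);  // dataid, attribute1, attribute2, ...
            }
        }
    }
}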
Training data:
- KDDTrain+.arff, 125,973 records
- http://nsl.cs.unb.ca/NSL-KDD/KDDTrain+.arff
- Path: testdata/kdd/KDDTrain.arff
dataid, attribute1, attribute2, ...
1000001,0,"tcp","ftp_data","SF",491,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,2,2,0,0,0,0,1,0,0,150,25,0.17,0.03,0.17,0,0,0,0.05,0,"normal"
1000002,0,"udp","other","SF",146,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,13,1,0,0,0,0,0.08,0.15,0,255,1,0,0.6,0.88,0,0,0,0,0,"normal"
1000003,0,"tcp","private","S0",0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,123,6,1,1,0,0,0.05,0.07,0,255,26,0.1,0.05,0,0,1,1,0,0,"anomaly"
...
Test data:
- KDDTest+.arff, 22,544 records
- http://nsl.cs.unb.ca/NSL-KDD/KDDTest+.arff
- Path: testdata/kdd/Test/part-00000, part-00001, part-00002
testdata/kdd/Test/part-00000
2000001,0,"tcp","private","REJ",0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,229,10,0,0,1,1,0.04,0.06,0,255,10,0.04,0.06,0,0,0,0,1,1,"anomaly"
2000002,0,"tcp","private","REJ",0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,136,1,0,0,1,1,0.01,0.06,0,255,1,0,0.06,0,0,0,0,1,1,"anomaly"
2000003,2,"tcp","ftp_data","SF",12983,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,1,0,0,0,0,1,0,0,134,86,0.61,0.04,0.61,0.02,0,0,0,0,"normal"
...
Output:
- testdata/kdd/predictions/part-00000, part-00001, part-00002
dataid, predictionIndex, realLabelIndex
testdata/kdd/predictions/part-00000
2000001  1  1
2000002  1  1
2000003  0  0
...
Execution Log (Excerpt)
Accuracy evaluation is also run as part of the job.
11/05/07 21:32:09 INFO df.ForestDriver: Generating the descriptor...
11/05/07 21:32:09 INFO df.ForestDriver: generating the descriptor dataset...
11/05/07 21:32:12 INFO df.ForestDriver: storing the dataset description
11/05/07 21:32:13 INFO df.ForestDriver: Partial Mapred
11/05/07 21:32:13 INFO df.ForestDriver: Building the forest...
...
11/05/07 21:35:24 INFO df.ForestDriver: Build Time: 0h 3m 11s 360
11/05/07 21:35:26 INFO df.ForestDriver: oob error estimate : 0.0016114564232017972
11/05/07 21:35:26 INFO df.ForestDriver: Storing the forest in: testdata/kdd/forest/forest.seq
...
11/05/07 21:36:00 INFO df.ForestClassifier: # of data which cannot be predicted: 0
11/05/07 21:36:00 INFO df.ForestClassificationDriver:
=======================================================
Summary
-------------------------------------------------------
Correctly Classified Instances   : 18013   79.9015%
Incorrectly Classified Instances : 4531    20.0985%
Total Classified Instances       : 22544
=======================================================
Confusion Matrix
-------------------------------------------------------
a      b      <--Classified as
9453   4273   | 13726  a = "normal"
258    8560   | 8818   b = "anomaly"
Default Category: unknown: 2
org.mahoutjp.df.ForestClassificationDriver: Distributed Classification
This Driver runs only the distributed classification step: using a built Forest Model, it performs distributed classification, output of the prediction results, and accuracy evaluation.
How to Run
$HADOOP_HOME/bin/hadoop jar mahoutjp-0.1-job.jar \
  org.mahoutjp.df.ForestClassificationDriver \
  -...(Options)
Execution Options
--testdata (-td) dir path Test data path
--dataset (-ds) file path Dataset path
--model (-m) dir path Path to the Decision Forest
--predout (-po) dir path Path to generated predictions file
--analyze (-a) Analyze Results
--help (-h) Print out help
Execution Example
$HADOOP_HOME/bin/hadoop jar mahoutjp-0.1-job.jar \
  org.mahoutjp.df.ForestClassificationDriver \
  -td testdata/kdd/KDDTest \
  -ds testdata/kdd/KDD.info \
  -m testdata/kdd/forest \
  -po testdata/kdd/predictions \
  -a
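The classification step can likewise be invoked programmatically. A minimal sketch using the same paths as the command-line example above (the wrapper class name is hypothetical):

import org.apache.hadoop.conf.Configuration;
import org.mahoutjp.df.ForestClassificationDriver;

public class ForestClassificationCallExample {
    public static void main(String[] args) throws Exception {
        Configuration conf = new Configuration();
        ForestClassificationDriver.run(conf,
                "testdata/kdd/forest",       // path to the built Decision Forest
                "testdata/kdd/KDDTest",      // test data path
                "testdata/kdd/KDD.info",     // dataset (descriptor) path
                "testdata/kdd/predictions",  // prediction output path
                true);                       // analyze: log the accuracy summary
    }
}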
Source Code
Overview:
- Distributed training and classification of a Random Forest
- org.mahoutjp.df.ForestDriver
- Distributed classification with a built Forest Model
- org.mahoutjp.df.ForestClassificationDriver
- org.mahoutjp.df.ForestClassifier
org.mahoutjp.df.ForestDriver
package org.mahoutjp.df; import java.io.IOException; import java.util.ArrayList; import java.util.Collection; import java.util.List; import java.util.Random; import org.apache.commons.cli2.CommandLine; import org.apache.commons.cli2.Group; import org.apache.commons.cli2.Option; import org.apache.commons.cli2.OptionException; import org.apache.commons.cli2.builder.ArgumentBuilder; import org.apache.commons.cli2.builder.DefaultOptionBuilder; import org.apache.commons.cli2.builder.GroupBuilder; import org.apache.commons.cli2.commandline.Parser; import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.fs.FileSystem; import org.apache.hadoop.fs.Path; import org.apache.hadoop.util.ToolRunner; import org.apache.mahout.common.AbstractJob; import org.apache.mahout.common.CommandLineUtil; import org.apache.mahout.common.RandomUtils; import org.apache.mahout.df.DFUtils; import org.apache.mahout.df.DecisionForest; import org.apache.mahout.df.ErrorEstimate; import org.apache.mahout.df.builder.DefaultTreeBuilder; import org.apache.mahout.df.callback.ForestPredictions; import org.apache.mahout.df.data.Data; import org.apache.mahout.df.data.DataLoader; import org.apache.mahout.df.data.Dataset; import org.apache.mahout.df.data.DescriptorException; import org.apache.mahout.df.data.DescriptorUtils; import org.apache.mahout.df.mapreduce.Builder; import org.apache.mahout.df.mapreduce.inmem.InMemBuilder; import org.apache.mahout.df.mapreduce.partial.PartialBuilder; import org.slf4j.Logger; import org.slf4j.LoggerFactory; /** * Random Forest Driver * * @author hamadakoichi */ public class ForestDriver extends AbstractJob { private static final Logger log = LoggerFactory.getLogger(ForestDriver.class); @Override public int run(String[] args) throws Exception { DefaultOptionBuilder obuilder = new DefaultOptionBuilder(); ArgumentBuilder abuilder = new ArgumentBuilder(); GroupBuilder gbuilder = new GroupBuilder(); Option dataOpt = obuilder.withLongName("data").withShortName("d") .withRequired(true).withArgument(abuilder.withName("file path").withMinimum(1).withMaximum(1).create()) .withDescription("Data file path") .create(); Option descriptorOpt = obuilder .withLongName("descriptor").withShortName("dsc") .withRequired(true) .withArgument(abuilder.withName("descriptor").withMinimum(1).create()) .withDescription("data descriptor file path").create(); Option descPathOpt = obuilder.withLongName("descriptor_out") .withShortName("ds") .withRequired(true).withArgument(abuilder.withName("file path").withMinimum(1).withMaximum(1).create()) .withDescription("Path to generated descriptor file").create(); Option selectionOpt = obuilder .withLongName("selection").withShortName("sl") .withRequired(true) .withArgument(abuilder.withName("m").withMinimum(1).withMaximum(1).create()) .withDescription("Number of variables to select randomly at each tree-node").create(); Option oobOpt = obuilder.withShortName("oob").withRequired(false) .withDescription("Optional, estimate the out-of-bag error").create(); Option seedOpt = obuilder .withLongName("seed").withShortName("sd") .withRequired(false) .withArgument(abuilder.withName("seed").withMinimum(1).withMaximum(1).create()) .withDescription("Optional, seed value used to initialise the Random number generator").create(); Option partialOpt = obuilder.withLongName("partial").withShortName("p") .withRequired(false).withDescription("Optional, use the Partial Data implementation").create(); Option nbtreesOpt = obuilder.withLongName("nbtrees").withShortName("t") 
.withRequired(true).withArgument(abuilder.withName("nbtrees").withMinimum(1).withMaximum(1).create()) .withDescription("Number of trees to grow").create(); Option forestOutOpt = obuilder.withLongName("forest_out").withShortName("fo") .withRequired(true).withArgument(abuilder.withName("dir path").withMinimum(1).withMaximum(1).create()) .withDescription("Output path of Decision Forest").create(); //For Predictions Option testDataOpt = obuilder.withLongName("testdata").withShortName("td") .withRequired(true).withArgument(abuilder.withName("dir path").withMinimum(1).withMaximum(1).create()) .withDescription("Test data path").create(); Option predictionOutOpt = obuilder.withLongName("predout").withShortName("po") .withRequired(true).withArgument(abuilder.withName("dir path").withMinimum(1).withMaximum(1).create()). withDescription("Path to generated prediction output file").create(); Option helpOpt = obuilder.withLongName("help").withDescription("Print out help").withShortName("h").create(); Group group = gbuilder.withName("Options").withOption(oobOpt) .withOption(dataOpt).withOption(descriptorOpt).withOption(descPathOpt) .withOption(selectionOpt).withOption(seedOpt).withOption(partialOpt) .withOption(nbtreesOpt).withOption(forestOutOpt) .withOption(testDataOpt).withOption(predictionOutOpt) .withOption(helpOpt).create(); try { Parser parser = new Parser(); parser.setGroup(group); CommandLine cmdLine = parser.parse(args); if (cmdLine.hasOption("help")) { CommandLineUtil.printHelp(group); return -1; } // Forest Parameter boolean isPartial = cmdLine.hasOption(partialOpt); boolean isOob = cmdLine.hasOption(oobOpt); String dataName = cmdLine.getValue(dataOpt).toString(); List<String> descriptor = convert(cmdLine.getValues(descriptorOpt)); String descriptorName = cmdLine.getValue(descPathOpt).toString(); String forestName = cmdLine.getValue(forestOutOpt).toString(); int m = Integer.parseInt(cmdLine.getValue(selectionOpt).toString()); int nbTrees = Integer.parseInt(cmdLine.getValue(nbtreesOpt).toString()); Long seed = 0L ; if (cmdLine.hasOption(seedOpt)) { seed = Long.valueOf(cmdLine.getValue(seedOpt).toString()); } // Classification Parameters String testdataName = cmdLine.getValue(testDataOpt).toString(); String predictionOutName = cmdLine.getValue(predictionOutOpt).toString(); log.debug("data : {}", dataName); log.debug("descriptor: {}", descriptor); log.debug("descriptorOut: {}", descriptorName); log.debug("forestOut : {}", forestName); log.debug("m : {}", m); log.debug("seed : {}", seed); log.debug("nbtrees : {}", nbTrees); log.debug("isPartial : {}", isPartial); log.debug("isOob : {}", isOob); log.debug("testData : {}", testdataName); log.debug("predictionOut : {}", predictionOutName); // Execute run(getConf(), dataName, descriptor, descriptorName, forestName, m, seed, nbTrees, isPartial, isOob, testdataName, predictionOutName); } catch (OptionException e) { log.error("Exception", e); CommandLineUtil.printHelp(group); return -1; } return 0; } public static void run(Configuration conf, String dataPathName, List<String> description, String descriptorPathName, String forestPathName, int m, Long seed, int nbTrees, boolean isPartial, boolean isOob, String testdataPathName, String predictionOutPathName) throws DescriptorException, IOException, ClassNotFoundException, InterruptedException { // Create Descriptor Path dataPath = validateInput(dataPathName); Path descriptorPath = validateOutput(descriptorPathName); createDescriptor(conf, dataPath, description, descriptorPath); // Build Forest Path 
forestPath = validateOutput(forestPathName); buildForest(conf, dataPath, description, descriptorPath, forestPath, m, seed, nbTrees, isPartial, isOob); // Predict boolean analyze = true; ForestClassificationDriver.run(conf, forestPathName, testdataPathName, descriptorPathName, predictionOutPathName, analyze); } /** * Greate Descreptor */ private static void createDescriptor(Configuration conf, Path dataPath, List<String> description, Path outPath) throws DescriptorException, IOException { log.info("Generating the descriptor..."); String descriptor = DescriptorUtils.generateDescriptor(description); log.info("generating the descriptor dataset..."); Dataset dataset = generateDataset(descriptor, dataPath); log.info("storing the dataset description"); DFUtils.storeWritable(conf, outPath, dataset); } /** * Generate DataSet */ private static Dataset generateDataset(String descriptor, Path dataPath) throws IOException, DescriptorException { FileSystem fs = dataPath.getFileSystem(new Configuration()); Path[] files = DFUtils.listOutputFiles(fs, dataPath); return DataLoader.generateDataset(descriptor, fs, files[0]); } /** * Build Forest */ private static void buildForest(Configuration conf, Path dataPath, List<String> description, Path descriptorPath, Path forestPath, int m, Long seed, int nbTrees, boolean isPartial, boolean isOob) throws IOException, ClassNotFoundException,InterruptedException { FileSystem ofs = forestPath.getFileSystem(conf); if (ofs.exists(forestPath)) { log.error("Forest Output Path already exists"); return; } DefaultTreeBuilder treeBuilder = new DefaultTreeBuilder(); treeBuilder.setM(m); Dataset dataset = Dataset.load(conf, descriptorPath); ForestPredictions callback = isOob ? new ForestPredictions(dataset .nbInstances(), dataset.nblabels()) : null; Builder forestBuilder; if (isPartial) { log.info("Partial Mapred"); forestBuilder = new PartialBuilder(treeBuilder, dataPath, descriptorPath, seed, conf); } else { log.info("InMemory Mapred"); forestBuilder = new InMemBuilder(treeBuilder, dataPath, descriptorPath, seed, conf); } forestBuilder.setOutputDirName(forestPath.getName()); log.info("Building the forest..."); long time = System.currentTimeMillis(); DecisionForest forest = forestBuilder.build(nbTrees, callback); time = System.currentTimeMillis() - time; log.info("Build Time: {}", DFUtils.elapsedTime(time)); if (isOob) { Random rng; if (seed != null) { rng = RandomUtils.getRandom(seed); } else { rng = RandomUtils.getRandom(); } FileSystem fs = dataPath.getFileSystem(conf); int[] labels = Data.extractLabels(dataset, fs, dataPath); log.info("oob error estimate : " + ErrorEstimate.errorRate(labels, callback .computePredictions(rng))); } // store the forest Path forestoutPath = new Path(forestPath, "forest.seq"); log.info("Storing the forest in: " + forestoutPath); DFUtils.storeWritable(conf, forestPath, forest); } /** * Load data */ protected static Data loadData(Configuration conf, Path dataPath, Dataset dataset) throws IOException { log.info("Loading the data..."); FileSystem fs = dataPath.getFileSystem(conf); Data data = DataLoader.loadData(dataset, fs, dataPath); log.info("Data Loaded"); return data; } /** * Convert Collections to a String List */ private static List<String> convert(Collection<?> values) { List<String> list = new ArrayList<String>(values.size()); for (Object value : values) { list.add(value.toString()); } return list; } /** * Validation of the Output Path */ private static Path validateOutput(String filePath) throws IOException { Path path = new Path(filePath); 
FileSystem fs = path.getFileSystem(new Configuration()); if (fs.exists(path)) { throw new IllegalStateException(path.toString() + " already exists"); } return path; } /** * Validation of the Input Path */ private static Path validateInput(String filePath) throws IOException { Path path = new Path(filePath); FileSystem fs = path.getFileSystem(new Configuration()); if (!fs.exists(path)) { throw new IllegalArgumentException(path.toString() + " does not exist"); } return path; } public static void main(String[] args) throws Exception { ToolRunner.run(new Configuration(), new ForestDriver(), args); } }
org.mahoutjp.df.ForestClassificationDriver
package org.mahoutjp.df; import java.io.IOException; import org.apache.commons.cli2.CommandLine; import org.apache.commons.cli2.Group; import org.apache.commons.cli2.Option; import org.apache.commons.cli2.OptionException; import org.apache.commons.cli2.builder.ArgumentBuilder; import org.apache.commons.cli2.builder.DefaultOptionBuilder; import org.apache.commons.cli2.builder.GroupBuilder; import org.apache.commons.cli2.commandline.Parser; import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.fs.FileSystem; import org.apache.hadoop.fs.Path; import org.apache.hadoop.util.ToolRunner; import org.apache.mahout.common.AbstractJob; import org.apache.mahout.common.CommandLineUtil; import org.slf4j.Logger; import org.slf4j.LoggerFactory; /** * Classification Mapreduce using a built Decision Forest * * @author hamadakoichi */ public class ForestClassificationDriver extends AbstractJob { private static final Logger log = LoggerFactory.getLogger(ForestClassificationDriver.class); @Override public int run(String[] args) throws Exception { DefaultOptionBuilder obuilder = new DefaultOptionBuilder(); ArgumentBuilder abuilder = new ArgumentBuilder(); GroupBuilder gbuilder = new GroupBuilder(); Option testDataOpt = obuilder.withLongName("testdata").withShortName("td") .withRequired(true).withArgument(abuilder.withName("dir path").withMinimum(1).withMaximum(1).create()) .withDescription("Test data path").create(); Option datasetOpt = obuilder.withLongName("dataset").withShortName("ds") .withRequired(true).withArgument(abuilder.withName("file path").withMinimum(1).withMaximum(1).create()) .withDescription("Dataset path").create(); Option modelOpt = obuilder.withLongName("model").withShortName("m") .withRequired(true).withArgument(abuilder.withName("dir path").withMinimum(1).withMaximum(1).create()) .withDescription("Path to the Decision Forest").create(); Option predictionOutOpt = obuilder.withLongName("predout").withShortName("po") .withRequired(false).withArgument(abuilder.withName("dir path").withMinimum(1).withMaximum(1).create()) .withDescription("Path to generated predictions file").create(); Option analyzeOpt = obuilder.withLongName("analyze").withShortName("a") .withRequired(false).create(); Option helpOpt = obuilder.withLongName("help").withDescription("Print out help").withShortName("h").create(); Group group = gbuilder.withName("Options").withOption(testDataOpt) .withOption(datasetOpt).withOption(modelOpt).withOption(predictionOutOpt) .withOption(analyzeOpt).withOption(helpOpt).create(); try { Parser parser = new Parser(); parser.setGroup(group); CommandLine cmdLine = parser.parse(args); if (cmdLine.hasOption("help")) { CommandLineUtil.printHelp(group); return -1; } //Parameters String testDataName = cmdLine.getValue(testDataOpt).toString(); String descriptorName = cmdLine.getValue(datasetOpt).toString(); String forestName = cmdLine.getValue(modelOpt).toString(); String predictionName = (cmdLine.hasOption(predictionOutOpt)) ? 
cmdLine .getValue(predictionOutOpt).toString() : null; boolean analyze = cmdLine.hasOption(analyzeOpt); log.debug("inout : {}", testDataName); log.debug("descriptor: {}", descriptorName); log.debug("forest : {}", forestName); log.debug("prediction: {}", predictionName); log.debug("analyze : {}", analyze); //Execute Classification log.info("Execute the Mapreduce Classification ..."); run(getConf(), forestName, testDataName, descriptorName, predictionName, analyze); } catch (OptionException e) { log.warn(e.toString(), e); CommandLineUtil.printHelp(group); return -1; } return 0; } public static void run(Configuration conf, String forestPathName, String testDataPathName, String descriptorPathName, String predictionPathName, boolean analyze) throws IOException, ClassNotFoundException, InterruptedException{ // Classify data Path testDataPath = validateInput(testDataPathName); Path descriptorPath = validateInput(descriptorPathName); Path forestPath = validateInput(forestPathName); Path predictionPath = validateOutput(conf, predictionPathName); ForestClassifier classifier = new ForestClassifier(conf, forestPath, testDataPath, descriptorPath, predictionPath, analyze); classifier.run(); // Analyze Results if (analyze) { log.info(classifier.getAnalyzer().summarize()); } } /** * Validation of the Output Path */ private static Path validateOutput(Configuration conf, String filePath) throws IOException { Path path = new Path(filePath); FileSystem fs = path.getFileSystem(conf); if (fs.exists(path)) { throw new IllegalStateException(path.toString() + " already exists"); } return path; } /** * Validation of the Input Path */ private static Path validateInput(String filePath) throws IOException { Path path = new Path(filePath); FileSystem fs = path.getFileSystem(new Configuration()); if (!fs.exists(path)) { throw new IllegalArgumentException(path.toString() + " does not exist"); } return path; } public static void main(String[] args) throws Exception { ToolRunner.run(new Configuration(), new ForestClassificationDriver(), args); } }
org.mahoutjp.df.ForestClassifier
package org.mahoutjp.df; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.apache.mahout.df.DecisionForest; import org.apache.mahout.df.DFUtils; import org.apache.mahout.df.data.DataConverter; import org.apache.mahout.df.data.Dataset; import org.apache.mahout.df.data.Instance; import org.apache.mahout.common.RandomUtils; import org.apache.hadoop.fs.FSDataInputStream; import org.apache.hadoop.fs.Path; import org.apache.hadoop.fs.FileSystem; import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.filecache.DistributedCache; import org.apache.hadoop.mapreduce.Job; import org.apache.hadoop.mapreduce.Mapper; import org.apache.hadoop.mapreduce.JobContext; import org.apache.hadoop.mapreduce.lib.input.FileInputFormat; import org.apache.hadoop.mapreduce.lib.input.TextInputFormat; import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat; import org.apache.hadoop.mapreduce.lib.output.TextOutputFormat; import org.apache.hadoop.io.Text; import org.apache.hadoop.io.LongWritable; import java.io.IOException; import java.util.Arrays; import java.util.Random; import java.util.Scanner; import java.net.URI; import org.apache.mahout.classifier.ClassifierResult; import org.apache.mahout.classifier.ResultAnalyzer; /** * Mapreduce implementation of Classifier with Forest * * @author hamadakoichi */ public class ForestClassifier { private static final Logger log = LoggerFactory .getLogger(ForestClassifier.class); private final Path forestPath; private final Path inputPath; private final Path datasetPath; private final Configuration conf; private final ResultAnalyzer analyzer; private final Dataset dataset; private final Path outputPath; public ForestClassifier(Configuration conf, Path forestPath, Path inputPath, Path datasetPath, Path outputPath, boolean analyze) throws IOException { this.forestPath = forestPath; this.inputPath = inputPath; this.datasetPath = datasetPath; this.conf = conf; if (analyze) { dataset = Dataset.load(conf, datasetPath); analyzer = new ResultAnalyzer(Arrays.asList(dataset.labels()), "unknown"); } else { dataset = null; analyzer = null; } this.outputPath = outputPath; } /** * Classification Job Configure */ private void configureClsssifyJob(Job job) throws IOException { job.setJarByClass(ForestClassifier.class); FileInputFormat.setInputPaths(job, inputPath); FileOutputFormat.setOutputPath(job, outputPath); job.setOutputKeyClass(Text.class); job.setOutputValueClass(Text.class); job.setMapperClass(ClassifyMapper.class); job.setNumReduceTasks(0); // Classification Mapper Only job.setInputFormatClass(ClassifyTextInputFormat.class); job.setOutputFormatClass(TextOutputFormat.class); } public void run() throws IOException, ClassNotFoundException, InterruptedException { FileSystem fs = FileSystem.get(conf); if (fs.exists(outputPath)) { throw new IOException(outputPath + " already exists"); } // put the dataset log.info("Adding the dataset to the DistributedCache"); DistributedCache.addCacheFile(datasetPath.toUri(), conf); // load the forest log.info("Adding the forest to the DistributedCache"); DistributedCache.addCacheFile(forestPath.toUri(), conf); // Classification Job cjob = new Job(conf, "Decision Forest classification"); log.info("Configuring the Classification Job..."); configureClsssifyJob(cjob); log.info("Running the Classification Job..."); if (!cjob.waitForCompletion(true)) { log.error("Classification Job failed!"); return; } // Analyze Results if (analyzer != null) { analyzeOutput(cjob); } } public ResultAnalyzer getAnalyzer() { return analyzer; } 
/** * Analyze the Classification Results * @param job */ private void analyzeOutput(Job job) throws IOException { Configuration conf = job.getConfiguration(); Integer prediction; Integer realLabel; FileSystem fs = outputPath.getFileSystem(conf); Path[] outfiles = DFUtils.listOutputFiles(fs, outputPath); int cnt_cnp = 0; for (Path path : outfiles) { FSDataInputStream input = fs.open(path); Scanner scanner = new Scanner(input); while (scanner.hasNext()) { String line = scanner.nextLine(); if (line.isEmpty()) { continue; } // id, predict, realLabel with \t sep String[] tmp = line.split("\t", -1); prediction = Integer.parseInt(tmp[1]); realLabel = Integer.parseInt(tmp[2]); if(prediction == -1) { // label cannot be predicted cnt_cnp++; } else{ if (analyzer != null) { analyzer.addInstance(dataset.getLabel(prediction), new ClassifierResult(dataset .getLabel(realLabel), 1.0)); } } } } log.info("# of data which cannot be predicted: " + cnt_cnp); } /** * Text Input Format: Each file is processed by single Mapper. */ public static class ClassifyTextInputFormat extends TextInputFormat { @Override protected boolean isSplitable(JobContext jobContext, Path path) { return false; } } /** * Classification Mapper. */ public static class ClassifyMapper extends Mapper<LongWritable, Text, LongWritable, Text> { private DataConverter converter; private DecisionForest forest; private final Random rng = RandomUtils.getRandom(); private final Text val = new Text(); @Override protected void setup(Context context) throws IOException, InterruptedException { super.setup(context); // To change body of overridden methods use Configuration conf = context.getConfiguration(); URI[] files = DistributedCache.getCacheFiles(conf); if ((files == null) || (files.length < 2)) { throw new IOException( "not enough paths in the DistributedCache"); } Dataset dataset = Dataset.load(conf, new Path(files[0].getPath())); converter = new DataConverter(dataset); forest = DecisionForest.load(conf, new Path(files[1].getPath())); if (forest == null) { throw new InterruptedException("DecisionForest not found!"); } } @Override protected void map(LongWritable key, Text value, Context context) throws IOException, InterruptedException { String line = value.toString(); if (!line.isEmpty()) { String[] idVal = line.split(",", -1); Integer id = Integer.parseInt(idVal[0]); Instance instance = converter.convert(id, line); int prediction = forest.classify(rng, instance); // key:id key.set(instance.getId()); // val: prediction, originalLabel (with tab sep) StringBuffer sb = new StringBuffer(); sb.append(Integer.toString(prediction)); sb.append("\t"); sb.append(instance.getLabel()); val.set(sb.toString()); context.write(key, val); } } } }
Related Resources
- Launched Mahout JP #MahoutJP - hamadakoichi blog
- Apache Mahout 0.4 - Random Forests - #TokyoWebmining #8 - hamadakoichi slideshare
- Held the 8th Data Mining + WEB Study Meeting @ Tokyo (#TokyoWebmining #8) - Large-Scale Analytics, Web, and Quants Festival - hamadakoichi blog
- "A Thorough Introduction to Random Forest in R - Classification and Prediction with Ensemble Learning" - Lectured at #TokyoR #11 - hamadakoichi blog