[TECS'23] CNNBench tool for generation and evaluation of CNN architectures.


CNNBench: A CNN Design-Space Generation Tool and Benchmark

This repository contains CNNBench, a tool to generate and evaluate different Convolutional Neural Network (CNN) architectures pertinent to the domain of machine-learning accelerators. The repository was forked from nasbench and then expanded to cover a larger set of CNN architectures.

Table of Contents

  • Environment setup
  • Basic run of the tool
  • Job Scripts
  • Colab
  • Todo

Environment setup

Clone this repository

git clone https://github.com/JHA-Lab/cnn_design-space.git
cd cnn_design-space

Setup python environment

  • PIP
virtualenv cnnbench
source cnnbench/bin/activate
pip install -r requirements.txt
  • CONDA
source env_setup.sh

This installs the GPU version of TensorFlow. To run on CPU only, install tensorflow-cpu instead.

Basic run of the tool

Running a basic version of the tool consists of the following:

  • CNNs with modules of up to two vertices, each performing one of the operations in [MAXPOOL3X3, CONV1X1, CONV3X3].
  • Each module is stacked three times. A base stem of 3x3 convolution with 128 output channels is used. The stack of modules is followed by global average pooling and a final dense softmax layer.
  • Training on the CIFAR-10 dataset.
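The module encoding above follows the nasbench convention: an upper-triangular adjacency matrix plus one operation label per vertex. As a sketch only (the class and field names here are hypothetical, not CNNBench's actual API), a two-vertex module could be represented as:

```python
# Illustrative sketch of a module spec in the basic design space.
# The class and field names are hypothetical, not CNNBench's actual API;
# the adjacency-matrix-plus-ops encoding follows the nasbench convention.
ALLOWED_OPS = ["maxpool3x3", "conv1x1", "conv3x3"]

class ModuleSpec:
    def __init__(self, matrix, ops):
        # matrix[i][j] == 1 means a directed edge from vertex i to vertex j.
        self.matrix = matrix
        self.ops = ops  # one operation label per vertex
        assert all(op in ALLOWED_OPS for op in ops)

    def is_upper_triangular(self):
        # A valid computational graph is a DAG: edges only go forward.
        n = len(self.matrix)
        return all(self.matrix[i][j] == 0
                   for i in range(n) for j in range(i + 1))

# A two-vertex module: conv3x3 feeding into maxpool3x3.
spec = ModuleSpec(matrix=[[0, 1],
                          [0, 0]],
                  ops=["conv3x3", "maxpool3x3"])
assert spec.is_upper_triangular()
```

In the basic run, this module would be stacked three times on top of the 3x3 convolution stem before the global average pooling and softmax head.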

Download and prepare the CIFAR-10 dataset

cd cnnbenchs/scripts
python generate_tfrecords.py

To use another dataset (among CIFAR-10, CIFAR-100, MNIST, or ImageNet), pass the corresponding input arguments; check: python generate_tfrecords.py --help.

Generate computational graphs

cd ../../job_scripts
python generate_graphs_script.py

This creates a .json file of all generated graphs at ../results/vertices_2/generate_graphs.json, using MD5 as the graph-hashing algorithm.

To generate graphs of up to 'n' vertices with the SHA-256 hashing algorithm, use: python generate_graphs_script.py --max_vertices n --hash_algo sha256.
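The hashing step can be pictured as fingerprinting a canonical serialization of each graph. The sketch below only illustrates how an MD5 or SHA-256 digest keys a graph, using Python's hashlib; CNNBench's actual canonical form (e.g. any isomorphism-invariant hashing) may differ:

```python
import hashlib
import json

def graph_fingerprint(matrix, ops, algo="md5"):
    # Hash a canonical serialization of a graph. Illustration only:
    # CNNBench's actual canonicalization may differ from plain JSON.
    canonical = json.dumps({"matrix": matrix, "ops": ops}, sort_keys=True)
    h = hashlib.new(algo)  # "md5" or "sha256"
    h.update(canonical.encode("utf-8"))
    return h.hexdigest()

matrix = [[0, 1], [0, 0]]
ops = ["conv3x3", "maxpool3x3"]
print(graph_fingerprint(matrix, ops))            # 32-hex-char MD5 digest
print(graph_fingerprint(matrix, ops, "sha256"))  # 64-hex-char SHA-256 digest
```

Two structurally identical graphs always serialize to the same string and hence the same digest, which is what makes the hash usable as a dictionary key for deduplication.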

Run evaluation over all generated graphs

python run_evaluation_script.py

This will save all the evaluated results and model checkpoints to ../results/vertices_2/evaluation.

To run evaluation over graphs generated with 'n' vertices, use: python run_evaluation_script.py --module_vertices n. For more input arguments, check: python run_evaluation_script.py --helpfull.

Generate the CNNBench dataset

python generate_dataset_script.py

This generates the CNNBench dataset as a cnnbench.tfrecord file containing the evaluation results for all trained computational graphs.

For visualization use: visualization/cnnbench_results.ipynb.

The basic run described above can be executed end-to-end with the script: job_scripts/basic_run.sh.

Job Scripts

To efficiently use multiple GPUs/TPUs on a cluster, a slurm script is provided at: job_scripts/job_basic.slurm. To run the tool on multiple nodes and utilize multiple GPUs in a cluster under the given design-space constraints, use job_scripts/job_creator_script.sh.
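For orientation, a single-GPU Slurm submission script for the evaluation step might look like the sketch below. The directives, walltime, and paths are placeholders; the repository's job_scripts/job_basic.slurm is the authoritative version:

```shell
#!/bin/bash
#SBATCH --job-name=cnnbench_basic   # placeholder job name
#SBATCH --nodes=1
#SBATCH --ntasks=1
#SBATCH --gres=gpu:1                # request one GPU
#SBATCH --time=04:00:00             # placeholder walltime; adjust to cluster limits

# Activate the environment created during setup (virtualenv shown here;
# use the conda equivalent if the environment came from env_setup.sh).
source cnnbench/bin/activate

cd job_scripts
python run_evaluation_script.py --module_vertices 2
```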

For more details on how to use this script, check: source job_scripts/job_creator_script.sh --help. Currently, these scripts only support the Adroit and Tiger clusters at Princeton University.

More information about these clusters and their usage can be found at the Princeton Research Computing website.

Colab

You can directly run tests on the generated dataset in Google Colaboratory, without installing anything on your local machine. Click "Open in Colab" below:

Open In Colab

Todo

Broad-level tasks left:

  1. Implement end-to-end PyTorch training (replacing functions running in compatibility mode).
  2. Implement automatic hyper-parameter tuning.
  3. Define popular networks in expanded CNNBench framework.
  4. Run training on popular networks and correlate with performance in literature.
  5. Implement graph generation in the expanded design space starting from clusters around popular networks.