Skip to content

Commit

Permalink
Create test_TResNetV2
Browse files Browse the repository at this point in the history
  • Loading branch information
mrT23 authored Jan 11, 2023
1 parent bdcd888 commit e0db3f5
Showing 1 changed file with 101 additions and 0 deletions.
101 changes: 101 additions & 0 deletions tests/test_TResNetV2
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@

# test_TResNetV2.py - Generated by https://www.codium.ai/

import unittest

"""
Code Analysis:
- This class is a subclass of the Module class from the torch.nn library.
- It initializes the TResNetV2 model with the given parameters.
- It creates a convolutional neural network with a body, head, and global pooling layers.
- It uses a SpaceToDepthModule, conv2d_ABN, AntiAliasDownsampleLayer, FastGlobalAvgPool2d, SEModule, InPlaceABN, and ABN.
- It uses a Bottleneck block with a convolutional layer, batch normalization, and ReLU activation.
- It uses a BasicBlock with a convolutional layer, batch normalization, and ReLU activation.
- It initializes the weights and biases of the convolutional layers and linear layers with Kaiming normal and constant values, respectively.
- It sets the conv2 and conv3 weights of the BasicBlock and Bottleneck layers to zero.
- It has a forward method which takes in an input and returns the logits.
"""


"""
Test strategies:
- test_init(): tests that the TResNetV2 model is initialized correctly with the given parameters.
- test_conv1(): tests that the conv1 layer is initialized correctly.
- test_anti_alias_layer(): tests that the AntiAliasDownsampleLayer is initialized correctly.
- test_global_pool_layer(): tests that the FastGlobalAvgPool2d is initialized correctly.
- test_bottleneck_block(): tests that the Bottleneck block is initialized correctly.
- test_basic_block(): tests that the BasicBlock is initialized correctly.
- test_weights_and_biases(): tests that the weights and biases of the convolutional layers and linear layers are initialized correctly.
- test_conv2_and_conv3_weights(): tests that the conv2 and conv3 weights of the BasicBlock and Bottleneck layers are set to zero.
- test_forward(): tests that the forward method takes in an input and returns the logits.
"""


class TestTResNetV2(unittest.TestCase):

def setUp(self):
self.layers = [3, 4, 6, 3]
self.in_chans = 3
self.num_classes = 1000
self.width_factor = 1.0
self.remove_model_jit = False
self.model = TResNetV2(self.layers, self.in_chans, self.num_classes, self.width_factor, self.remove_model_jit)

def test_init(self):
self.assertEqual(self.model.inplanes, 64)
self.assertEqual(self.model.planes, 64)
self.assertEqual(self.model.num_features, 2048)

def test_conv1(self):
conv1 = self.model.body[1]
self.assertIsInstance(conv1, ABN)
self.assertEqual(conv1.in_channels, 48)
self.assertEqual(conv1.out_channels, 64)
self.assertEqual(conv1.kernel_size, (3, 3))
self.assertEqual(conv1.stride, (1, 1))

def test_anti_alias_layer(self):
layer1 = self.model.body[3]
anti_alias_layer = layer1[0].anti_alias_layer
self.assertIsInstance(anti_alias_layer, partial)
self.assertEqual(anti_alias_layer.func, AntiAliasDownsampleLayer)
self.assertEqual(anti_alias_layer.args, (self.remove_model_jit,))

def test_global_pool_layer(self):
global_pool = self.model.global_pool[0]
self.assertIsInstance(global_pool, FastGlobalAvgPool2d)
self.assertTrue(global_pool.flatten)

def test_bottleneck_block(self):
layer1 = self.model.body[3]
bottleneck = layer1[0]
self.assertIsInstance(bottleneck, Bottleneck)
self.assertEqual(bottleneck.inplanes, 64)
self.assertEqual(bottleneck.planes, 64)
self.assertEqual(bottleneck.stride, 1)
self.assertIsInstance(bottleneck.downsample, nn.Sequential)
self.assertIsInstance(bottleneck.se, SEModule)

def test_basic_block(self):
layer2 = self.model.body[4]
basicblock = layer2[0]
self.assertIsInstance(basicblock, BasicBlock)
self.assertEqual(basicblock.inplanes, 128)
self.assertEqual(basicblock.planes, 128)
self.assertEqual(basicblock.stride, 2)
self.assertIsInstance(basicblock.downsample, nn.Sequential)

def test_weights_and_biases(self):
for m in self.model.modules():
if isinstance(m, nn.Conv2d):
weight = m.weight
fan = math.sqrt(weight[0].numel())
nn.init.kaiming_normal_(weight, mode='fan_out', nonlinearity='leaky_relu')
for w in weight:
for i in range(w[0].numel()):
self.assertAlmostEqual(w[0][i], 0, delta=0.001 * fan)

elif isinstance(m, nn.BatchNorm2d) or isinstance(m, InPlaceABN):
weight = m.weight
bias = m.bias
nn.init.constant_(weight, 1)

0 comments on commit e0db3f5

Please sign in to comment.