A fast, parallel, pure PyTorch implementation of "CIF: Continuous Integrate-and-Fire for End-to-End Speech Recognition" (https://arxiv.org/abs/1905.11235).
Install from PyPI:

```bash
pip install torch-cif
```

Or install from source:

```bash
git clone https://github.com/George0828Zhang/torch_cif
cd torch_cif
python setup.py install
```
```python
def cif_function(
    inputs: Tensor,
    alpha: Tensor,
    beta: float = 1.0,
    tail_thres: float = 0.5,
    padding_mask: Optional[Tensor] = None,
    target_lengths: Optional[Tensor] = None,
    eps: float = 1e-4,
    unbound_alpha: bool = False
) -> Dict[str, List[Tensor]]:
    r""" A fast parallel implementation of continuous integrate-and-fire (CIF)
    https://arxiv.org/abs/1905.11235

    Shapes:
        N: batch size
        S: source (encoder) sequence length
        C: source feature dimension
        T: target sequence length

    Args:
        inputs (Tensor): (N, S, C) Input features to be integrated.
        alpha (Tensor): (N, S) Weights corresponding to each element in the
            inputs. Expected to be the output of a sigmoid function.
        beta (float): The threshold used to determine firing.
        tail_thres (float): The threshold used to determine firing during
            tail handling.
        padding_mask (Tensor, optional): (N, S) A binary mask representing
            padded elements in the inputs. 1 is padding, 0 is not.
        target_lengths (Tensor, optional): (N,) Desired length of the targets
            for each sample in the minibatch.
        eps (float, optional): Epsilon to prevent underflow for divisions.
            Default: 1e-4
        unbound_alpha (bool, optional): If True, skip the check that
            0 <= alpha <= 1.

    Returns -> Dict[str, List[Tensor]]: Key/values described below.
        cif_out: (N, T, C) The output integrated from the source.
        cif_lengths: (N,) The output length for each element in the batch.
        alpha_sum: (N,) The sum of alpha for each element in the batch.
            Can be used to compute the quantity loss.
        delays: (N, T) The expected delay (in terms of source tokens) for
            each target token in the batch.
        tail_weights: (N,) During inference, the tail weights are returned.
        scaled_alpha: (N, S) alpha after applying weight scaling.
        cumsum_alpha: (N, S) cumsum of alpha after scaling.
        right_indices: (N, S) right scatter indices, i.e. floor(cumsum(alpha)).
        right_weights: (N, S) right scatter weights.
        left_indices: (N, S) left scatter indices.
        left_weights: (N, S) left scatter weights.
    """
```
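For intuition, the loop below is a minimal sequential sketch of the integrate-and-fire semantics described above, for a single unpadded `(S, C)` sequence with no tail handling. `cif_sequential` is a hypothetical helper written for this README, not part of the package; the package computes the same result in parallel.

```python
import torch
from torch import Tensor

def cif_sequential(inputs: Tensor, alpha: Tensor, beta: float = 1.0) -> Tensor:
    """Integrate (S, C) features weighted by (S,) alpha; fire one (C,) frame
    each time the accumulated weight reaches beta. Assumes 0 <= alpha <= beta,
    so each source step fires at most once."""
    outputs = []
    acc = 0.0                                  # weight accumulated so far
    frame = inputs.new_zeros(inputs.size(1))   # partially integrated frame
    for x, a in zip(inputs, alpha.tolist()):
        if acc + a < beta:                     # no boundary: integrate fully
            frame = frame + a * x
            acc = acc + a
        else:                                  # boundary: spend beta - acc here,
            used = beta - acc                  # fire, then carry the remainder
            outputs.append(frame + used * x)
            frame = (a - used) * x
            acc = a - used
    return torch.stack(outputs) if outputs else inputs.new_zeros(0, inputs.size(1))

x = torch.ones(4, 3)                       # S=4, C=3
w = torch.tensor([0.4, 0.6, 0.5, 0.5])     # weights sum to 2.0 -> T=2 frames
y = cif_sequential(x, w)                   # y.shape == (2, 3)
```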
- This implementation uses `cumsum` and `floor` to determine the firing positions, and uses `scatter` to merge the weighted source features. The figure below demonstrates this concept using the scaled weight sequence `(0.4, 1.8, 1.2, 1.2, 1.4)`.
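The firing positions for that weight sequence can be reproduced with exactly those two ops (plain PyTorch; `float64` is used here only so the final cumsum lands on 6.0 rather than just below it):

```python
import torch

# the scaled weight sequence from the figure
alpha = torch.tensor([0.4, 1.8, 1.2, 1.2, 1.4], dtype=torch.float64)
csum = alpha.cumsum(dim=0)         # [0.4, 2.2, 3.4, 4.6, 6.0]
right_idx = csum.floor().long()    # [0, 2, 3, 4, 6]: scatter indices for beta = 1

# a step fires once per integer boundary its cumsum crosses; step 2 jumps
# from 0.4 to 2.2 and therefore completes two output frames at once
fires = torch.diff(right_idx, prepend=torch.zeros(1, dtype=torch.long))
```

Summing `fires` gives T = 6 output frames; scattering the weighted features onto those indices then performs the merge in parallel.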
- Running the tests requires `pip install hypothesis expecttest`.
- If `beta != 1`, our implementation differs slightly from Algorithm 1 in the paper [1]:
  - When a boundary is located, the original algorithm adds the last feature to the current integration with weight `1 - accumulation` (line 11 in Algorithm 1), which causes negative weights in the next integration when `alpha < 1 - accumulation`.
  - We instead use weight `beta - accumulation`, so the weight in the next integration, `alpha - (beta - accumulation)`, is always positive.
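Concretely, with hypothetical numbers (a threshold `beta = 0.8`, an accumulation of 0.7, and a boundary step whose weight is 0.15):

```python
beta = 0.8    # firing threshold
acc = 0.7     # weight accumulated before the boundary step
alpha = 0.15  # weight of the boundary step; acc + alpha >= beta, so it fires

# original rule (line 11 of Algorithm 1): weight 1 - acc to the current frame
orig_used = 1 - acc            # 0.3, more than this step's whole alpha
orig_left = alpha - orig_used  # -0.15: a negative carry into the next frame

# this implementation: weight beta - acc instead
ours_used = beta - acc         # 0.1
ours_left = alpha - ours_used  # 0.05: non-negative whenever acc + alpha >= beta
```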
- Feel free to contact me if there are bugs in the code.