Skip to content

Commit

Permalink
feat: add triang_filter_bank, tests and utils
Browse files Browse the repository at this point in the history
  • Loading branch information
Teagum committed Jul 4, 2024
1 parent 691e60e commit c8b2d90
Show file tree
Hide file tree
Showing 4 changed files with 205 additions and 0 deletions.
111 changes: 111 additions & 0 deletions src/apollon/signal/filter.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,14 @@
"""
Simple filter implementations
"""
from typing import Literal, Self, Sequence

import numpy as np
from pydantic import BaseModel, model_validator
import scipy.signal as _scs

from .. typing import FloatArray, floatarray
from . tools import mel_to_hz, hz_to_mel


def coef_bw_bandpass(low: int, high: int, fps: int, order: int = 4
Expand Down Expand Up @@ -41,3 +46,109 @@ def bandpass_filter(inp: FloatArray, fps: int, low: int, high: int,
"""
coeffs = coef_bw_bandpass(low, high, fps, order)
return floatarray(_scs.lfilter(*coeffs, inp))


def triang_filter_bank(low: float, high: float, n_filters: int, fps: int, size: int,
domain: Literal["mel"] = "mel"
) -> FloatArray:
"""Compute a bank of triangular filters.
This function computes ``n_filters`` triangular filters. The center
frequencies are linearly spaced in the given domain. Currently, only
'Mel' domain is implemented.
Args:
low: Lower cut-off frequency in Hz
high: Upper cut-off frequency in Hz
n_filters: Number of filters
fps: Sample rate
n_fft: FFT length
domain: Spacing domain, either "mel", "hz". Default ist "mel".
Returns:
Array with ``n_filters`` rows and columns determined by ``n_fft``.
"""
if low < 0:
raise ValueError("Lower cut-off frequency below 0 Hz")

if low >= high:
raise ValueError("Lower cut-off frequency greater or equal then high")

if high > fps//2:
raise ValueError("Upper cut-off frequency greater or equal Nyquist")

frq_space = mel_space(low, high, n_filters+2, endpoint=True)
filter_frqs = np.lib.stride_tricks.sliding_window_view(frq_space.ravel(), 3)
return triang(fps, size, filter_frqs)


def mel_space(start: float, stop: float, num: int, endpoint: bool = True) -> FloatArray:
space = np.linspace(hz_to_mel(start), hz_to_mel(stop), num, endpoint=endpoint)
return mel_to_hz(space)


def bin_from_frq(fps: int, size: int, frqs: float | FloatArray) -> FloatArray:
"""Compute the index of the FFT bin with closest center frequency to ``frqs``.
This function computes the bin index regarding a real FFT.
Args:
fps: Sample rate
n_fft: FFT length
frqs: Frequencies in Hz
Returns:
Index of nearest FFT bin.
"""
out = np.empty_like(frqs, dtype=int)
np.rint(frqs*size/fps, casting="unsafe", out=out)
return out


def triang(fps: int, n_fft: int, frqs: FloatArray,
amps: tuple[float, float, float] = (0.0, 1.0, 0.0)
) -> FloatArray:
"""Compute a triangular filter.
Compute a triangular filter of size ``n_fft'' from an array of frequencies
``frqs''. The frequency array must be of shape (n, 3), where each row
corresponds to a filter and the columns are interpreted as the lower
cut-off, center, and upper cut-off frequencies.
The filter response at the constituting frequencies is controlled with a
triplet of amplitudes ``amps''. The default specifies no response at the
cut-off frequencies and maximal response at the center frequency.
The filter has zero response at each of the remaining frequencies.
Args:
fps: Sampling rate
n_fft: Length of the filter
frqs: Constituting freqencies
amps: Amplitude of the filter at the constituting frequencies
Returns:
Array of triangular filters with shape (frqs.shape[0], size).
"""
if n_fft < 4:
raise ValueError("``n_fft'' is less than 3")

filters = []
for low, ctr, high in bin_from_frq(fps, n_fft, frqs):
out = np.zeros((n_fft+1)//2 if n_fft % 2 else n_fft//2+1)
roi = np.arange(low, high+1, dtype=int)
out[roi] = np.interp(roi, (low, ctr, high), amps)
filters.append(out)
return np.vstack(filters)


class TriangFilterSpec(BaseModel):
low: float
high: float
n_filters: int

@model_validator(mode="after")
def check_low_lt_high(self) -> Self:
if self.low >= self.high:
raise ValueError("low freq must be less then high")
return self
Empty file added tests/signal/__init__.py
Empty file.
17 changes: 17 additions & 0 deletions tests/signal/strategies.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
from hypothesis import strategies as st
from hypothesis import assume, given


__all__ = ["frequencies", "samplerates", "fftsizes"]


def frequencies(min_value: float = 0, max_value: float | None = None) -> st.SearchStrategy[float]:
if min_value < 0:
raise ValueError("Value of lower frequency bound less than 0")
return st.floats(min_value=min_value, max_value=max_value, allow_infinity=False, allow_nan=False)

def samplerates() -> st.SearchStrategy[int]:
return st.integers(min_value=1, max_value=96000)

def fftsizes(min_value: int = 1) -> st.SearchStrategy[int]:
return st.integers(min_value=min_value, max_value=48000)
77 changes: 77 additions & 0 deletions tests/signal/test_filter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
from unittest import TestCase

from apollon.typing import FloatArray
from apollon.signal.filter import triang, triang_filter_bank
from hypothesis import strategies as st
from hypothesis import assume, given
import numpy as np

from .strategies import fftsizes, frequencies, samplerates


@st.composite
def triang_const_frqs(draw: st.DrawFn, fps: int, n_fft: int) -> tuple[float, float, float]:
if n_fft < 3:
raise ValueError("n_fft less than 4")

frqs = np.fft.rfftfreq(n_fft, 1/fps)

if n_fft == 3:
return tuple(frqs)
else:
center = frqs.size // 2
low = draw(st.sampled_from(frqs[:center])) # type: ignore
high = draw(st.sampled_from(frqs[center+1:])) # type: ignore

return low, center, high


@st.composite
def filterspecs(draw: st.DrawFn) -> tuple[int, int, FloatArray]:
fps = draw(samplerates())
n_fft = draw(fftsizes(min_value=4))
items = draw(st.lists(triang_const_frqs(fps, n_fft), min_size=1, max_size=50))
return (fps, n_fft, np.array(items))


@st.composite
def triangspec(draw: st.DrawFn) -> tuple[float, float, int, int, int]:
fps = draw(samplerates())
n_fft = draw(fftsizes(min_value=4))
f_max = (n_fft+1)//2 if n_fft % 2 else n_fft//2+1
low = draw(frequencies(max_value=f_max//4))
high = draw(frequencies(min_value=f_max//2, max_value=f_max))
nflt = draw(st.integers(min_value=1, max_value=100))
return (low, high, nflt, fps, n_fft)

class TestTriang(TestCase):
@given(
filterspecs(),
st.just((0,1,0))
)
def test_triang(self, filterspec: tuple[int, int, FloatArray], amps: tuple[float, float, float]) -> None:
fps, size, frqs = filterspec
res = triang(fps, size, frqs, amps)
self.assertEqual(res.shape[0], frqs.shape[0])

@given(st.integers(min_value=0, max_value=2))
def test_bad_nfft(self, n_fft: int) -> None:
with self.assertRaises(ValueError):
triang(1, n_fft, np.array([[0, 1, 2]]))


class TestTriangFilterBank(TestCase):
@given(triangspec())
def test_triang_filter_bank(self, spec: tuple[float, float, int, int, int]) -> None:
low, high, n_filters, fps, size = spec
if low < high:
if high > fps//2:
with self.assertRaises(ValueError):
fb = triang_filter_bank(low, high, n_filters, fps, size)
else:
fb = triang_filter_bank(low, high, n_filters, fps, size)
self.assertIsInstance(fb, np.ndarray)
else:
with self.assertRaises(ValueError):
fb = triang_filter_bank(low, high, n_filters, fps, size)
self.assertIsInstance(fb, np.ndarray)

0 comments on commit c8b2d90

Please sign in to comment.