Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
97 changes: 97 additions & 0 deletions mojo/stdlib/std/builtin/variadics.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -304,6 +304,67 @@ struct Variadic:
values: The values to zip.
"""

comptime filter_types[
T: type_of(AnyType),
//,
*element_types: T,
predicate: _TypePredicateGenerator[T],
] = _ReduceVariadicAndIdxToVariadic[
BaseVal = Variadic.empty_of_trait[T],
VariadicType=element_types,
Reducer = _FilterReducer[T, predicate],
]
"""Filter types from a variadic sequence based on a predicate function.

Returns a new variadic containing only the types for which the predicate
returns True.

Parameters:
T: The trait that the types conform to.
element_types: The input variadic sequence.
predicate: A generator function that takes a type and returns Bool.

Examples:

```mojo
from std.builtin.variadics import Variadic
from utils import Variant
from sys.intrinsics import _type_is_eq

comptime FullVariant = Variant[Int, String, Float64, Bool]

# Exclude a single type
comptime IsNotInt[Type: AnyType] = not _type_is_eq[Type, Int]()
comptime WithoutInt = Variadic.filter_types[*FullVariant.Ts, predicate=IsNotInt]
comptime FilteredVariant = Variant[*WithoutInt]
# FilteredVariant is Variant[String, Float64, Bool]

# Keep only specific types
comptime IsNumeric[Type: AnyType] = (
_type_is_eq[Type, Int]() or _type_is_eq[Type, Float64]()
)
comptime OnlyNumeric = Variadic.filter_types[*FullVariant.Ts, predicate=IsNumeric]
# OnlyNumeric is Variadic.types[T=AnyType, Int, Float64]

# Exclude multiple types using a variadic check
comptime ExcludeList = Variadic.types[T=AnyType, Int, Bool]
comptime NotInList[Type: AnyType] = not Variadic.contains[
type=Type, element_types=ExcludeList
]
comptime Filtered = Variadic.filter_types[*FullVariant.Ts, predicate=NotInList]
# Filtered is Variadic.types[T=AnyType, String, Float64]
```

Filter operations can be chained for complex transformations:

```mojo
comptime IsNotBool[Type: AnyType] = not _type_is_eq[Type, Bool]()
comptime Step1 = Variadic.filter_types[*FullVariant.Ts, predicate=IsNotBool]
comptime Step2 = Variadic.filter_types[*Step1, predicate=IsNotInt]
comptime ChainedVariant = Variant[*Step2]
```
"""


# ===-----------------------------------------------------------------------===#
# VariadicList / VariadicListMem
Expand Down Expand Up @@ -1223,3 +1284,39 @@ Parameters:
From: The input variadic sequence.
idx: The current index being processed.
"""

comptime _TypePredicateGenerator[T: type_of(AnyType)] = __mlir_type[
`!lit.generator<<"Type": `,
T,
`>`,
Bool,
`>`,
]
"""Generator type for type predicates.

A predicate takes a type and returns a boolean indicating whether to keep it.

Parameters:
T: The trait that the types conform to.
"""

comptime _FilterReducer[
Trait: type_of(AnyType),
Predicate: _TypePredicateGenerator[Trait],
Prev: Variadic.TypesOfTrait[Trait],
From: Variadic.TypesOfTrait[Trait],
idx: Int,
] = (
Variadic.concat[Prev, Variadic.types[T=Trait, From[idx]]] if Predicate[
From[idx]
] else Prev
)
"""A reducer that filters types based on a predicate function.

Parameters:
Trait: The trait that the types conform to.
Predicate: A generator that takes a type and returns Bool.
Prev: The accumulated result variadic so far.
From: The input variadic sequence.
idx: The current index being processed.
"""
49 changes: 49 additions & 0 deletions mojo/stdlib/test/builtin/test_variadic.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -304,5 +304,54 @@ def test_map_types_to_types():
assert_true(_type_is_eq[variadic[1], String]())


def test_filter_types_exclude_one():
comptime IsNotInt[Type: Movable] = not _type_is_eq[Type, Int]()
comptime without_int = Variadic.filter_types[
*Tuple[Int, String, Float64, Bool].element_types, predicate=IsNotInt
]
assert_equal(Variadic.size(without_int), 3)
assert_true(_type_is_eq[without_int[0], String]())
assert_true(_type_is_eq[without_int[1], Float64]())
assert_true(_type_is_eq[without_int[2], Bool]())


def test_filter_types_keep_only():
comptime IsStringOrFloat[Type: Movable] = (
_type_is_eq[Type, String]() or _type_is_eq[Type, Float64]()
)
comptime kept = Variadic.filter_types[
*Tuple[Int, String, Float64, Bool].element_types,
predicate=IsStringOrFloat,
]
assert_equal(Variadic.size(kept), 2)
assert_true(_type_is_eq[kept[0], String]())
assert_true(_type_is_eq[kept[1], Float64]())


def test_filter_types_exclude_many():
comptime NotIntOrBool[Type: Movable] = (
not _type_is_eq[Type, Int]() and not _type_is_eq[Type, Bool]()
)
comptime filtered = Variadic.filter_types[
*Tuple[Int, String, Float64, Bool].element_types,
predicate=NotIntOrBool,
]
assert_equal(Variadic.size(filtered), 2)
assert_true(_type_is_eq[filtered[0], String]())
assert_true(_type_is_eq[filtered[1], Float64]())


def test_filter_types_chained():
comptime IsNotBool[Type: Movable] = not _type_is_eq[Type, Bool]()
comptime IsNotInt[Type: Movable] = not _type_is_eq[Type, Int]()
comptime step1 = Variadic.filter_types[
*Tuple[Int, String, Float64, Bool].element_types, predicate=IsNotBool
]
comptime step2 = Variadic.filter_types[*step1, predicate=IsNotInt]
assert_equal(Variadic.size(step2), 2)
assert_true(_type_is_eq[step2[0], String]())
assert_true(_type_is_eq[step2[1], Float64]())


def main():
TestSuite.discover_tests[__functions_in_module()]().run()
Loading