diff --git a/pytato/analysis/__init__.py b/pytato/analysis/__init__.py
index e9c80d578..b94ca5b10 100644
--- a/pytato/analysis/__init__.py
+++ b/pytato/analysis/__init__.py
@@ -26,13 +26,15 @@
 """
 
 from typing import (Mapping, Dict, Union, Set, Tuple, Any, FrozenSet,
-                    TYPE_CHECKING)
+                    TYPE_CHECKING, Iterable)
 from pytato.array import (Array, IndexLambda, Stack, Concatenate, Einsum,
                           DictOfNamedArrays, NamedArray,
                           IndexBase, IndexRemappingBase, InputArgumentBase,
                           ShapeType)
-from pytato.transform import Mapper, ArrayOrNames, CachedWalkMapper
+from pytato.transform import Mapper, ArrayOrNames, CachedWalkMapper, CombineMapper
 from pytato.loopy import LoopyCall
+from pytools.tag import Tag
+from pytato.tags import ImplStored
 
 if TYPE_CHECKING:
     from pytato.distributed import DistributedRecv, DistributedSendRefHolder
@@ -47,6 +49,11 @@
 .. autofunction:: get_num_nodes
 
 .. autoclass:: DirectPredecessorsGetter
+
+.. autoclass:: TagCountMapper
+.. autofunction:: get_num_tags_of_type
+
+.. autofunction:: get_num_materialized
 """
 
 
@@ -371,12 +378,79 @@ def post_visit(self, expr: Any) -> None:
 def get_num_nodes(outputs: Union[Array, DictOfNamedArrays]) -> int:
     """Returns the number of nodes in DAG *outputs*."""
 
-    from pytato.codegen import normalize_outputs
-    outputs = normalize_outputs(outputs)
-
     ncm = NodeCountMapper()
     ncm(outputs)
 
     return ncm.count
 
 # }}}
+
+
+# {{{ TagCountMapper
+
+class TagCountMapper(CombineMapper[int]):
+    """
+    Returns the number of nodes in a DAG that are tagged with all the tags in *tags*.
+    """
+
+    def __init__(self, tags: Union[Tag, Iterable[Tag]]) -> None:
+        super().__init__()
+        if isinstance(tags, Tag):
+            tags = frozenset((tags,))
+        elif not isinstance(tags, frozenset):
+            tags = frozenset(tags)
+        self._tags = tags
+
+    def combine(self, *args: int) -> int:
+        return sum(args)
+
+    # type-ignore reason: incompatible return type with super class
+    def rec(self, expr: ArrayOrNames) -> int:  # type: ignore
+        if expr in self.cache:
+            return self.cache[expr]
+
+        result = super().rec(expr)
+        if isinstance(expr, Array) and self._tags <= expr.tags:
+            result += 1
+
+        # Cache 0 rather than *result*: when *expr* is reached again through
+        # another path it must not contribute a second time.
+        self.cache[expr] = 0
+        return result
+
+
+def get_num_tags_of_type(
+        outputs: Union[Array, DictOfNamedArrays],
+        tags: Union[Tag, Iterable[Tag]]) -> int:
+    """Returns the number of nodes in DAG *outputs* that are tagged with
+    all the tags in *tags*."""
+
+    tcm = TagCountMapper(tags)
+
+    return tcm(outputs)
+
+# }}}
+
+
+def get_num_materialized(outputs: Union[Array, DictOfNamedArrays]) \
+        -> Dict[ArrayOrNames, int]:
+    """Returns the number of materialized nodes each node in *outputs* depends on.
+
+    Nodes that do not depend on any materialized node are omitted from the
+    result.
+    """
+    from pytato.transform import rec_get_all_user_nodes
+    users = rec_get_all_user_nodes(outputs)
+
+    def is_materialized(expr: ArrayOrNames) -> bool:
+        return (isinstance(expr, Array)
+                and any(isinstance(tag, ImplStored) for tag in expr.tags))
+
+    res: Dict[ArrayOrNames, int] = {}
+
+    for node, node_users in users.items():
+        if is_materialized(node):
+            for user in node_users:
+                res[user] = res.get(user, 0) + 1
+
+    return res
diff --git a/pytato/transform/__init__.py b/pytato/transform/__init__.py
index c60c0e79c..1c69e2f47 100644
--- a/pytato/transform/__init__.py
+++ b/pytato/transform/__init__.py
@@ -86,6 +86,7 @@
 .. autofunction:: reverse_graph
 .. autofunction:: tag_user_nodes
 .. autofunction:: rec_get_user_nodes
+.. autofunction:: rec_get_all_user_nodes
 
 .. autofunction:: deduplicate_data_wrappers
 
@@ -1030,8 +1031,8 @@ def _materialize_if_mpms(expr: Array,
                          ) -> MPMSMaterializerAccumulator:
     """
     Returns an instance of :class:`MPMSMaterializerAccumulator`, that
-    materializes *expr* if it has more than 1 successors and more than 1
-    materialized predecessors.
+    materializes *expr* if it has more than 1 successor and more than 1
+    materialized predecessor.
     """
     from functools import reduce
 
@@ -1250,8 +1251,8 @@ def materialize_with_mpms(expr: DictOfNamedArrays) -> DictOfNamedArrays:
     .. note::
 
         - MPMS materialization strategy is a greedy materialization algorithm in
-          which any node with more than 1 materialized predecessors and more than
-          1 successors is materialized.
+          which any node with more than 1 materialized predecessor and more than
+          1 successor is materialized.
         - Materializing here corresponds to tagging a node with
           :class:`~pytato.tags.ImplStored`.
         - Does not attempt to materialize sub-expressions in
@@ -1292,13 +1293,21 @@ def materialize_with_mpms(expr: DictOfNamedArrays) -> DictOfNamedArrays:
         ====== ======== =======
 
     """
-    from pytato.analysis import get_nusers
+    from pytato.analysis import get_nusers, get_num_nodes, get_num_tags_of_type
     materializer = MPMSMaterializer(get_nusers(expr))
     new_data = {}
     for name, ary in expr.items():
         new_data[name] = materializer(ary.expr).expr
 
-    return DictOfNamedArrays(new_data)
+    res = DictOfNamedArrays(new_data)
+
+    from pytato import DEBUG_ENABLED
+    if DEBUG_ENABLED:
+        logger.info("materialize_with_mpms: materialized "
+                    f"{get_num_tags_of_type(res, ImplStored())} out of "
+                    f"{get_num_nodes(res)} nodes")
+
+    return res
 
 # }}}
 
@@ -1501,6 +1510,16 @@ def rec_get_user_nodes(expr: ArrayOrNames,
     return _recursively_get_all_users(users, node)
 
 
+def rec_get_all_user_nodes(expr: ArrayOrNames) \
+        -> Dict[ArrayOrNames, FrozenSet[ArrayOrNames]]:
+    """
+    Returns all direct and indirect users of all nodes in *expr*.
+    """
+    users = get_users(expr)
+
+    return {node: _recursively_get_all_users(users, node) for node in users}
+
+
 def tag_user_nodes(
         graph: Mapping[ArrayOrNames, Set[ArrayOrNames]],
         tag: Any,
diff --git a/test/test_pytato.py b/test/test_pytato.py
index 3a698c3a9..9c85fe991 100755
--- a/test/test_pytato.py
+++ b/test/test_pytato.py
@@ -608,8 +608,7 @@ def test_nodecountmapper():
                                  axis_len=axis_len, use_numpy=False)
         dag = make_random_dag(rdagc)
 
-        # Subtract 1 since NodeCountMapper counts an extra one for DictOfNamedArrays.
-        assert get_num_nodes(dag)-1 == len(pt.transform.DependencyMapper()(dag))
+        assert get_num_nodes(dag) == len(pt.transform.DependencyMapper()(dag))
 
 
 def test_rec_get_user_nodes():
@@ -884,6 +883,41 @@ def test_adv_indexing_into_zero_long_axes():
     # }}}
 
 
+def test_tagcountmapper():
+    from testlib import RandomDAGContext, make_random_dag
+    from pytato.analysis import get_num_tags_of_type, get_num_nodes
+    from pytools.tag import Tag
+
+    class NonExistentTag(Tag):
+        pass
+
+    class ExistentTag(Tag):
+        pass
+
+    seed = 199
+    axis_len = 3
+
+    rdagc_pt = RandomDAGContext(np.random.default_rng(seed=seed),
+                                axis_len=axis_len, use_numpy=False)
+
+    out = make_random_dag(rdagc_pt).tagged(ExistentTag())
+
+    dag = pt.make_dict_of_named_arrays({"out": out})
+
+    # get_num_nodes() returns an extra DictOfNamedArrays node
+    assert get_num_tags_of_type(dag, frozenset()) == get_num_nodes(dag)-1
+
+    assert get_num_tags_of_type(dag, NonExistentTag()) == 0
+    assert get_num_tags_of_type(dag, frozenset((ExistentTag(),))) == 1
+    assert get_num_tags_of_type(dag,
+                                frozenset((ExistentTag(), NonExistentTag()))) == 0
+
+    a = pt.make_data_wrapper(np.arange(27))
+    dag = a+a+a+a+a+a+a+a
+
+    assert get_num_tags_of_type(dag, frozenset()) == get_num_nodes(dag)
+
+
 def test_expand_dims_input_validate():
     a = pt.make_placeholder("x", (10, 4), dtype="float64")
 
@@ -901,6 +935,26 @@ def test_expand_dims_input_validate():
         pt.expand_dims(a, -4)
 
 
+def test_materialization_counter():
+    from pytato.analysis import get_num_materialized
+    from testlib import RandomDAGContext, make_random_dag
+
+    seed = 1999
+    axis_len = 4
+
+    rdagc_pt = RandomDAGContext(np.random.default_rng(seed=seed),
+                                axis_len=axis_len, use_numpy=False)
+
+    out = make_random_dag(rdagc_pt)
+
+    res = pt.make_dict_of_named_arrays({"out": out})
+    res = pt.transform.materialize_with_mpms(res)
+
+    r = get_num_materialized(res)
+
+    assert max(r.values()) == 6
+
+
 if __name__ == "__main__":
     if len(sys.argv) > 1:
         exec(sys.argv[1])