From 68eb741d1e6d70933fdba36981a4af14a2716c9b Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Mon, 23 May 2022 13:12:27 -0500 Subject: [PATCH 01/16] add TagCountMapper --- pytato/analysis/__init__.py | 41 +++++++++++++++++++++++++++++++++++++ test/test_pytato.py | 27 ++++++++++++++++++++++++ 2 files changed, 68 insertions(+) diff --git a/pytato/analysis/__init__.py b/pytato/analysis/__init__.py index 7f6a93da5..bcc2bdcc3 100644 --- a/pytato/analysis/__init__.py +++ b/pytato/analysis/__init__.py @@ -33,6 +33,7 @@ ShapeType) from pytato.transform import Mapper, ArrayOrNames, CachedWalkMapper from pytato.loopy import LoopyCall +from pytools.tag import Tag if TYPE_CHECKING: from pytato.distributed import DistributedRecv, DistributedSendRefHolder @@ -47,6 +48,9 @@ .. autofunction:: get_num_nodes .. autoclass:: DirectPredecessorsGetter + +.. autoclass:: TagCountMapper +.. autofunction:: get_num_tags_of_type """ @@ -382,3 +386,40 @@ def get_num_nodes(outputs: Union[Array, DictOfNamedArrays]) -> int: return ncm.count # }}} + + +# {{{ TagCountMapper + +class TagCountMapper(CachedWalkMapper): + """ + Counts the number of nodes in a DAG that are tagged with a superset of *tags*. + + .. attribute:: count + + The number of nodes that are tagged with a superset of *tags*. + """ + + def __init__(self, tags: FrozenSet[Tag]) -> None: + super().__init__() + self._tags = tags + self.count = 0 + + def post_visit(self, expr: Any) -> None: + if hasattr(expr, "tags") and self._tags <= expr.tags: + self.count += 1 + + +def get_num_tags_of_type( + outputs: Union[Array, DictOfNamedArrays], tags: FrozenSet[Tag]) -> int: + """Returns the number of nodes in DAG *outputs* that are tagged with a + superset of *tags*.""" + + from pytato.codegen import normalize_outputs + outputs = normalize_outputs(outputs) + + tcm = TagCountMapper(tags) + tcm(outputs) + + return tcm.count + +# }}} diff --git a/test/test_pytato.py b/test/test_pytato.py index 00cd3e8ea..fa307d6de 100755 --- a/test/test_pytato.py +++ b/test/test_pytato.py @@ -884,6 +884,33 @@ def test_adv_indexing_into_zero_long_axes(): # }}} +def test_tagcountmapper(): + from testlib import RandomDAGContext, make_random_dag + from pytato.analysis import get_num_tags_of_type, get_num_nodes + from pytools.tag import Tag + + class NonExistentTag(Tag): + pass + + class ExistentTag(Tag): + pass + + seed = 199 + axis_len = 3 + + rdagc_pt = RandomDAGContext(np.random.default_rng(seed=seed), + axis_len=axis_len, use_numpy=False) + + dag = pt.make_dict_of_named_arrays( + {"out": make_random_dag(rdagc_pt).tagged(ExistentTag())}) + + assert get_num_tags_of_type(dag, frozenset()) == get_num_nodes(dag)-1 + assert get_num_tags_of_type(dag, frozenset((NonExistentTag(),))) == 0 + assert get_num_tags_of_type(dag, frozenset((ExistentTag(),))) == 1 + assert get_num_tags_of_type(dag, + frozenset((ExistentTag(), NonExistentTag()))) == 0 + + if __name__ == "__main__": if len(sys.argv) > 1: exec(sys.argv[1]) From 7949b1912dfeaadbb03e05fc290ffb6cc4f0f5c9 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Mon, 23 May 2022 17:45:04 -0500 Subject: [PATCH 02/16] clarify doc --- pytato/analysis/__init__.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/pytato/analysis/__init__.py b/pytato/analysis/__init__.py index bcc2bdcc3..cac6fc452 100644 --- a/pytato/analysis/__init__.py +++ b/pytato/analysis/__init__.py @@ -392,14 +392,14 @@ def get_num_nodes(outputs: Union[Array, DictOfNamedArrays]) -> int: class TagCountMapper(CachedWalkMapper): """ - Counts the number of nodes in a DAG that are tagged with a superset of *tags*. + Counts the number of nodes in a DAG that are tagged with all the tags in *tags*. .. attribute:: count - The number of nodes that are tagged with a superset of *tags*. + The number of nodes that are tagged with all the tags in *tags*. """ - def __init__(self, tags: FrozenSet[Tag]) -> None: + def __init__(self, tags: Union[Tag, Iterable[Tag]]) -> None: super().__init__() self._tags = tags self.count = 0 @@ -410,10 +410,10 @@ def post_visit(self, expr: Any) -> None: def get_num_tags_of_type( - outputs: Union[Array, DictOfNamedArrays], tags: FrozenSet[Tag]) -> int: - """Returns the number of nodes in DAG *outputs* that are tagged with a - superset of *tags*.""" - + outputs: Union[Array, DictOfNamedArrays], + tags: Union[Tag, Iterable[Tag]]) -> int: + """Returns the number of nodes in DAG *outputs* that are tagged with + all the tags in *tags*.""" from pytato.codegen import normalize_outputs outputs = normalize_outputs(outputs) From 4a578d8312c2545ca083bbc2448b8f0ca2ca1fdd Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Mon, 23 May 2022 17:45:28 -0500 Subject: [PATCH 03/16] small code cleanups --- pytato/analysis/__init__.py | 6 +++--- test/test_pytato.py | 9 ++++++--- 2 files changed, 9 insertions(+), 6 deletions(-) diff --git a/pytato/analysis/__init__.py b/pytato/analysis/__init__.py index cac6fc452..d4ec18778 100644 --- a/pytato/analysis/__init__.py +++ b/pytato/analysis/__init__.py @@ -401,11 +401,13 @@ class TagCountMapper(CachedWalkMapper): def __init__(self, tags: Union[Tag, Iterable[Tag]]) -> None: super().__init__() + if isinstance(tags, Tag): + tags = frozenset((tags,)) self._tags = tags self.count = 0 def post_visit(self, expr: Any) -> None: - if hasattr(expr, "tags") and self._tags <= expr.tags: + if isinstance(expr, Array) and self._tags <= expr.tags: self.count += 1 @@ -414,8 +416,6 @@ def get_num_tags_of_type( tags: Union[Tag, Iterable[Tag]]) -> int: """Returns the number of nodes in DAG *outputs* that are tagged with all the tags in *tags*.""" - from pytato.codegen import normalize_outputs - outputs = normalize_outputs(outputs) tcm = TagCountMapper(tags) tcm(outputs) diff --git a/test/test_pytato.py b/test/test_pytato.py index fa307d6de..556e7869e 100755 --- a/test/test_pytato.py +++ b/test/test_pytato.py @@ -901,11 +901,14 @@ class ExistentTag(Tag): rdagc_pt = RandomDAGContext(np.random.default_rng(seed=seed), axis_len=axis_len, use_numpy=False) - dag = pt.make_dict_of_named_arrays( - {"out": make_random_dag(rdagc_pt).tagged(ExistentTag())}) + out = make_random_dag(rdagc_pt).tagged(ExistentTag()) + dag = pt.make_dict_of_named_arrays({"out": out}) + + # get_num_nodes() returns an extra DictOfNamedArrays node assert get_num_tags_of_type(dag, frozenset()) == get_num_nodes(dag)-1 - assert get_num_tags_of_type(dag, frozenset((NonExistentTag(),))) == 0 + + assert get_num_tags_of_type(dag, NonExistentTag()) == 0 assert get_num_tags_of_type(dag, frozenset((ExistentTag(),))) == 1 assert get_num_tags_of_type(dag, frozenset((ExistentTag(), NonExistentTag()))) == 0 From 9a79943f051c9f762501e86cc673aea8560f9eac Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Mon, 23 May 2022 17:47:17 -0500 Subject: [PATCH 04/16] lint fixes --- pytato/analysis/__init__.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/pytato/analysis/__init__.py b/pytato/analysis/__init__.py index d4ec18778..e9d87c646 100644 --- a/pytato/analysis/__init__.py +++ b/pytato/analysis/__init__.py @@ -26,7 +26,7 @@ """ from typing import (Mapping, Dict, Union, Set, Tuple, Any, FrozenSet, - TYPE_CHECKING) + TYPE_CHECKING, Iterable) from pytato.array import (Array, IndexLambda, Stack, Concatenate, Einsum, DictOfNamedArrays, NamedArray, IndexBase, IndexRemappingBase, InputArgumentBase, @@ -403,6 +403,8 @@ def __init__(self, tags: Union[Tag, Iterable[Tag]]) -> None: super().__init__() if isinstance(tags, Tag): tags = frozenset((tags,)) + elif not isinstance(tags, frozenset): + tags = frozenset(tags) self._tags = tags self.count = 0 @@ -417,6 +419,9 @@ def get_num_tags_of_type( """Returns the number of nodes in DAG *outputs* that are tagged with all the tags in *tags*.""" + from pytato.codegen import normalize_outputs + outputs = normalize_outputs(outputs) + tcm = TagCountMapper(tags) tcm(outputs) From 80dbca59ab3f363ad85de8b37275150e883ef3c8 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Mon, 23 May 2022 19:14:41 -0500 Subject: [PATCH 05/16] use a CombineMapper --- pytato/analysis/__init__.py | 31 +++++++++++++++++++------------ 1 file changed, 19 insertions(+), 12 deletions(-) diff --git a/pytato/analysis/__init__.py b/pytato/analysis/__init__.py index e9d87c646..5853db615 100644 --- a/pytato/analysis/__init__.py +++ b/pytato/analysis/__init__.py @@ -31,7 +31,7 @@ DictOfNamedArrays, NamedArray, IndexBase, IndexRemappingBase, InputArgumentBase, ShapeType) -from pytato.transform import Mapper, ArrayOrNames, CachedWalkMapper +from pytato.transform import Mapper, ArrayOrNames, CachedWalkMapper, CombineMapper from pytato.loopy import LoopyCall from pytools.tag import Tag @@ -390,13 +390,9 @@ def get_num_nodes(outputs: Union[Array, DictOfNamedArrays]) -> int: # {{{ TagCountMapper -class TagCountMapper(CachedWalkMapper): +class TagCountMapper(CombineMapper[int]): """ - Counts the number of nodes in a DAG that are tagged with all the tags in *tags*. - - .. attribute:: count - - The number of nodes that are tagged with all the tags in *tags*. + Returns the number of nodes in a DAG that are tagged with all the tags in *tags*. """ def __init__(self, tags: Union[Tag, Iterable[Tag]]) -> None: @@ -406,11 +402,23 @@ def __init__(self, tags: Union[Tag, Iterable[Tag]]) -> None: elif not isinstance(tags, frozenset): tags = frozenset(tags) self._tags = tags - self.count = 0 - def post_visit(self, expr: Any) -> None: + def combine(self, *args: int) -> int: + from functools import reduce + return reduce(lambda a, b: a + b, args, 0) + + # type-ignore reason: incompatible return type with super class + def rec(self, expr: ArrayOrNames) -> int: # type: ignore + if expr in self.cache: + return self.cache[expr] + if isinstance(expr, Array) and self._tags <= expr.tags: - self.count += 1 + result = 1 + else: + result = 0 + + self.cache[expr] = result + super().rec(expr) + return self.cache[expr] def get_num_tags_of_type( @@ -423,8 +431,7 @@ def get_num_tags_of_type( outputs = normalize_outputs(outputs) tcm = TagCountMapper(tags) - tcm(outputs) - return tcm.count + return tcm(outputs) # }}} From 9ca88bc68e51f1aa0a45315b7fcf5889ddc57265 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Mon, 23 May 2022 23:16:27 -0500 Subject: [PATCH 06/16] simplify sum Co-authored-by: Kaushik Kulkarni <15399010+kaushikcfd@users.noreply.github.com> --- pytato/analysis/__init__.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/pytato/analysis/__init__.py b/pytato/analysis/__init__.py index 5853db615..f6614ae1f 100644 --- a/pytato/analysis/__init__.py +++ b/pytato/analysis/__init__.py @@ -404,8 +404,7 @@ def __init__(self, tags: Union[Tag, Iterable[Tag]]) -> None: self._tags = tags def combine(self, *args: int) -> int: - from functools import reduce - return reduce(lambda a, b: a + b, args, 0) + return sum(args) # type-ignore reason: incompatible return type with super class def rec(self, expr: ArrayOrNames) -> int: # type: ignore From fbdfd6ea77389249b6508b1c664119b09fa52a6f Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Mon, 23 May 2022 23:30:25 -0500 Subject: [PATCH 07/16] add another test --- test/test_pytato.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/test/test_pytato.py b/test/test_pytato.py index 85a89e123..e8c638d61 100755 --- a/test/test_pytato.py +++ b/test/test_pytato.py @@ -913,6 +913,12 @@ class ExistentTag(Tag): assert get_num_tags_of_type(dag, frozenset((ExistentTag(), NonExistentTag()))) == 0 + a = pt.make_data_wrapper(np.arange(27)) + dag = a+a+a+a+a+a+a+a + + # get_num_nodes() returns an extra DictOfNamedArrays node + assert get_num_tags_of_type(dag, frozenset()) == get_num_nodes(dag)-1 + def test_expand_dims_input_validate(): a = pt.make_placeholder("x", (10, 4), dtype="float64") From 29281d9a776545b4a71767ff41e3bed2d4a1cf86 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Tue, 24 May 2022 00:22:03 -0500 Subject: [PATCH 08/16] set cache to zero --- pytato/analysis/__init__.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/pytato/analysis/__init__.py b/pytato/analysis/__init__.py index f6614ae1f..9c71c0a80 100644 --- a/pytato/analysis/__init__.py +++ b/pytato/analysis/__init__.py @@ -412,12 +412,12 @@ def rec(self, expr: ArrayOrNames) -> int: # type: ignore return self.cache[expr] if isinstance(expr, Array) and self._tags <= expr.tags: - result = 1 + result = 1 + super().rec(expr) else: - result = 0 + result = 0 + super().rec(expr) - self.cache[expr] = result + super().rec(expr) - return self.cache[expr] + self.cache[expr] = 0 + return result def get_num_tags_of_type( From 33d7d1e307e9da602677331db0a4dc54a8999fa7 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Wed, 25 May 2022 10:06:50 -0500 Subject: [PATCH 09/16] remove normalize_outputs --- pytato/analysis/__init__.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/pytato/analysis/__init__.py b/pytato/analysis/__init__.py index 2644a692f..2a5515311 100644 --- a/pytato/analysis/__init__.py +++ b/pytato/analysis/__init__.py @@ -375,9 +375,6 @@ def post_visit(self, expr: Any) -> None: def get_num_nodes(outputs: Union[Array, DictOfNamedArrays]) -> int: """Returns the number of nodes in DAG *outputs*.""" - from pytato.codegen import normalize_outputs - outputs = normalize_outputs(outputs) - ncm = NodeCountMapper() ncm(outputs) @@ -424,9 +421,6 @@ def get_num_tags_of_type( """Returns the number of nodes in DAG *outputs* that are tagged with all the tags in *tags*.""" - from pytato.codegen import normalize_outputs - outputs = normalize_outputs(outputs) - tcm = TagCountMapper(tags) return tcm(outputs) From 68f159e7ae8f0f7aa0fefeeb29b1f482c5621af3 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Wed, 25 May 2022 10:30:37 -0500 Subject: [PATCH 10/16] fix tests --- test/test_pytato.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/test/test_pytato.py b/test/test_pytato.py index e8c638d61..f74356ee3 100755 --- a/test/test_pytato.py +++ b/test/test_pytato.py @@ -608,8 +608,7 @@ def test_nodecountmapper(): axis_len=axis_len, use_numpy=False) dag = make_random_dag(rdagc) - # Subtract 1 since NodeCountMapper counts an extra one for DictOfNamedArrays. - assert get_num_nodes(dag)-1 == len(pt.transform.DependencyMapper()(dag)) + assert get_num_nodes(dag) == len(pt.transform.DependencyMapper()(dag)) def test_rec_get_user_nodes(): @@ -916,8 +915,7 @@ class ExistentTag(Tag): a = pt.make_data_wrapper(np.arange(27)) dag = a+a+a+a+a+a+a+a - # get_num_nodes() returns an extra DictOfNamedArrays node - assert get_num_tags_of_type(dag, frozenset()) == get_num_nodes(dag)-1 + assert get_num_tags_of_type(dag, frozenset()) == get_num_nodes(dag) def test_expand_dims_input_validate(): From 8296c0410ee3cae02654478e6e6d4a87b546a635 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Mon, 23 May 2022 18:20:06 -0500 Subject: [PATCH 11/16] materialize_with_mpms: print number of materialized nodes --- pytato/transform/__init__.py | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/pytato/transform/__init__.py b/pytato/transform/__init__.py index c60c0e79c..f98882bce 100644 --- a/pytato/transform/__init__.py +++ b/pytato/transform/__init__.py @@ -1030,8 +1030,8 @@ def _materialize_if_mpms(expr: Array, ) -> MPMSMaterializerAccumulator: """ Returns an instance of :class:`MPMSMaterializerAccumulator`, that - materializes *expr* if it has more than 1 successors and more than 1 - materialized predecessors. + materializes *expr* if it has more than 1 successor and more than 1 + materialized predecessor. """ from functools import reduce @@ -1250,8 +1250,8 @@ def materialize_with_mpms(expr: DictOfNamedArrays) -> DictOfNamedArrays: .. note:: - MPMS materialization strategy is a greedy materialization algorithm in - which any node with more than 1 materialized predecessors and more than - 1 successors is materialized. + which any node with more than 1 materialized predecessor and more than + 1 successor is materialized. - Materializing here corresponds to tagging a node with :class:`~pytato.tags.ImplStored`. - Does not attempt to materialize sub-expressions in @@ -1292,13 +1292,19 @@ def materialize_with_mpms(expr: DictOfNamedArrays) -> DictOfNamedArrays: ====== ======== ======= """ - from pytato.analysis import get_nusers + from pytato.analysis import get_nusers, get_num_nodes, get_num_tags_of_type materializer = MPMSMaterializer(get_nusers(expr)) new_data = {} for name, ary in expr.items(): new_data[name] = materializer(ary.expr).expr - return DictOfNamedArrays(new_data) + res = DictOfNamedArrays(new_data) + + logger.info("materialize_with_mpms: materialized", + get_num_tags_of_type(res, ImplStored()), "out of", + get_num_nodes(res), "nodes") + + return res # }}} From 2eb85d4bea7691f12bdf49516f9068340c1afe6a Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Mon, 23 May 2022 18:29:04 -0500 Subject: [PATCH 12/16] fix string --- pytato/transform/__init__.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pytato/transform/__init__.py b/pytato/transform/__init__.py index f98882bce..6fed4f02d 100644 --- a/pytato/transform/__init__.py +++ b/pytato/transform/__init__.py @@ -1300,9 +1300,9 @@ def materialize_with_mpms(expr: DictOfNamedArrays) -> DictOfNamedArrays: res = DictOfNamedArrays(new_data) - logger.info("materialize_with_mpms: materialized", - get_num_tags_of_type(res, ImplStored()), "out of", - get_num_nodes(res), "nodes") + logger.info("materialize_with_mpms: materialized " + f"{get_num_tags_of_type(res, ImplStored())} out of " + f"{get_num_nodes(res)} nodes") return res From 262787c68c32e799ced65c0d3eb1440ce2a50c83 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Wed, 1 Jun 2022 11:42:26 -0500 Subject: [PATCH 13/16] use DEBUG_ENABLED --- pytato/transform/__init__.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/pytato/transform/__init__.py b/pytato/transform/__init__.py index 6fed4f02d..d118b7923 100644 --- a/pytato/transform/__init__.py +++ b/pytato/transform/__init__.py @@ -1300,9 +1300,11 @@ def materialize_with_mpms(expr: DictOfNamedArrays) -> DictOfNamedArrays: res = DictOfNamedArrays(new_data) - logger.info("materialize_with_mpms: materialized " - f"{get_num_tags_of_type(res, ImplStored())} out of " - f"{get_num_nodes(res)} nodes") + from pytato import DEBUG_ENABLED + if DEBUG_ENABLED: + logger.info("materialize_with_mpms: materialized " + f"{get_num_tags_of_type(res, ImplStored())} out of " + f"{get_num_nodes(res)} nodes") return res From e65c25d53cefac996efdd05ca914f9fcf26c420a Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Wed, 1 Jun 2022 14:54:31 -0500 Subject: [PATCH 14/16] add materialization counter --- pytato/analysis/__init__.py | 25 +++++++++++++++++++++++++ pytato/transform/__init__.py | 14 ++++++++++++++ test/test_pytato.py | 20 ++++++++++++++++++++ 3 files changed, 59 insertions(+) diff --git a/pytato/analysis/__init__.py b/pytato/analysis/__init__.py index 2a5515311..a7a04f34c 100644 --- a/pytato/analysis/__init__.py +++ b/pytato/analysis/__init__.py @@ -34,6 +34,7 @@ from pytato.transform import Mapper, ArrayOrNames, CachedWalkMapper, CombineMapper from pytato.loopy import LoopyCall from pytools.tag import Tag +from pytato.tags import ImplStored if TYPE_CHECKING: from pytato.distributed import DistributedRecv, DistributedSendRefHolder @@ -426,3 +427,27 @@ def get_num_tags_of_type( return tcm(outputs) # }}} + + +def get_num_materialized(outputs: Union[Array, DictOfNamedArrays]) \ + -> Dict[Array, int]: + """Returns the number of materialized nodes each node in *outputs* depends on.""" + from pytato.transform import rec_get_all_user_nodes + users = rec_get_all_user_nodes(outputs) + + def is_materialized(expr: ArrayOrNames) -> bool: + if (isinstance(expr, Array) and + any(isinstance(tag, ImplStored) for tag in expr.tags)): + return True + else: + return False + + res = {} + + for node in users.keys(): + if is_materialized(node): + for user in users[node]: + res.setdefault(user, 0) + res[user] += 1 + + return res diff --git a/pytato/transform/__init__.py b/pytato/transform/__init__.py index d118b7923..0f57786bb 100644 --- a/pytato/transform/__init__.py +++ b/pytato/transform/__init__.py @@ -86,6 +86,7 @@ .. autofunction:: reverse_graph .. autofunction:: tag_user_nodes .. autofunction:: rec_get_user_nodes +.. autofunction:: rec_get_all_user_nodes .. autofunction:: deduplicate_data_wrappers @@ -1509,6 +1510,19 @@ def rec_get_user_nodes(expr: ArrayOrNames, return _recursively_get_all_users(users, node) +def rec_get_all_user_nodes(expr: ArrayOrNames) -> FrozenSet[ArrayOrNames]: + """ + Returns all direct and indirect users of all nodes in *expr*. + """ + users = get_users(expr) + + res = {} + + for node in users.keys(): + res[node] = _recursively_get_all_users(users, node) + return res + + def tag_user_nodes( graph: Mapping[ArrayOrNames, Set[ArrayOrNames]], tag: Any, diff --git a/test/test_pytato.py b/test/test_pytato.py index f74356ee3..9c85fe991 100755 --- a/test/test_pytato.py +++ b/test/test_pytato.py @@ -935,6 +935,26 @@ def test_expand_dims_input_validate(): pt.expand_dims(a, -4) +def test_materialization_counter(): + from pytato.analysis import get_num_materialized + from testlib import RandomDAGContext, make_random_dag + + seed = 1999 + axis_len = 4 + + rdagc_pt = RandomDAGContext(np.random.default_rng(seed=seed), + axis_len=axis_len, use_numpy=False) + + out = make_random_dag(rdagc_pt) + + res = pt.make_dict_of_named_arrays({"out": out}) + res = pt.transform.materialize_with_mpms(res) + + r = get_num_materialized(res) + + assert max([v for v in r.values()]) == 6 + + if __name__ == "__main__": if len(sys.argv) > 1: exec(sys.argv[1]) From aaf2bf833a48f7413a4924cedb0c5dc5b216c5d9 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Wed, 1 Jun 2022 14:56:32 -0500 Subject: [PATCH 15/16] add to doc --- pytato/analysis/__init__.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pytato/analysis/__init__.py b/pytato/analysis/__init__.py index a7a04f34c..95832aa58 100644 --- a/pytato/analysis/__init__.py +++ b/pytato/analysis/__init__.py @@ -52,6 +52,8 @@ .. autoclass:: TagCountMapper .. autofunction:: get_num_tags_of_type + +.. autofunction:: get_num_materialized """ From bfe8d0b892d91fbb6052402bf668f253986f27f2 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Wed, 1 Jun 2022 15:00:40 -0500 Subject: [PATCH 16/16] mypy fixes --- pytato/analysis/__init__.py | 4 ++-- pytato/transform/__init__.py | 3 ++- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/pytato/analysis/__init__.py b/pytato/analysis/__init__.py index 95832aa58..b94ca5b10 100644 --- a/pytato/analysis/__init__.py +++ b/pytato/analysis/__init__.py @@ -432,7 +432,7 @@ def get_num_tags_of_type( def get_num_materialized(outputs: Union[Array, DictOfNamedArrays]) \ - -> Dict[Array, int]: + -> Dict[ArrayOrNames, int]: """Returns the number of materialized nodes each node in *outputs* depends on.""" from pytato.transform import rec_get_all_user_nodes users = rec_get_all_user_nodes(outputs) @@ -444,7 +444,7 @@ def is_materialized(expr: ArrayOrNames) -> bool: else: return False - res = {} + res: Dict[ArrayOrNames, int] = {} for node in users.keys(): if is_materialized(node): diff --git a/pytato/transform/__init__.py b/pytato/transform/__init__.py index 0f57786bb..1c69e2f47 100644 --- a/pytato/transform/__init__.py +++ b/pytato/transform/__init__.py @@ -1510,7 +1510,8 @@ def rec_get_user_nodes(expr: ArrayOrNames, return _recursively_get_all_users(users, node) -def rec_get_all_user_nodes(expr: ArrayOrNames) -> FrozenSet[ArrayOrNames]: +def rec_get_all_user_nodes(expr: ArrayOrNames) \ + -> Dict[ArrayOrNames, FrozenSet[ArrayOrNames]]: """ Returns all direct and indirect users of all nodes in *expr*. """