diff --git a/.mypy/baseline.json b/.mypy/baseline.json index 78ead0261..45404c815 100644 --- a/.mypy/baseline.json +++ b/.mypy/baseline.json @@ -192,7 +192,7 @@ "column": 27, "message": "Need type annotation for \"index_to_axis_length\"", "offset": 46, - "src": "index_to_axis_length = immutabledict(index_to_axis_length_dict)", + "src": "index_to_axis_length = constantdict(index_to_axis_length_dict)", "target": "pytato.array._normalize_einsum_in_subscript" }, { @@ -200,7 +200,7 @@ "column": 27, "message": "See https://kotlinisland.github.io/basedmypy/_refs.html#code-var-annotated for more info", "offset": 0, - "src": "index_to_axis_length = immutabledict(index_to_axis_length_dict)", + "src": "index_to_axis_length = constantdict(index_to_axis_length_dict)", "target": null }, { @@ -208,7 +208,7 @@ "column": 21, "message": "Need type annotation for \"index_to_descr\"", "offset": 1, - "src": "index_to_descr = immutabledict(index_to_descr_dict)", + "src": "index_to_descr = constantdict(index_to_descr_dict)", "target": "pytato.array._normalize_einsum_in_subscript" }, { @@ -840,7 +840,7 @@ "column": 20, "message": "\"Call\" gets multiple values for keyword argument \"tags\"", "offset": 230, - "src": "call_site = Call(self, bindings=immutabledict(kwargs),", + "src": "call_site = Call(self, bindings=constantdict(kwargs),", "target": "pytato.function.FunctionDefinition.__call__" }, { @@ -848,7 +848,7 @@ "column": 20, "message": "Missing positional argument \"function\" in call to \"Call\"", "offset": 0, - "src": "call_site = Call(self, bindings=immutabledict(kwargs),", + "src": "call_site = Call(self, bindings=constantdict(kwargs),", "target": "pytato.function.FunctionDefinition.__call__" }, { @@ -856,7 +856,7 @@ "column": 25, "message": "Argument 1 to \"Call\" has incompatible type \"FunctionDefinition\"; expected \"frozenset[Tag]\"", "offset": 0, - "src": "call_site = Call(self, bindings=immutabledict(kwargs),", + "src": "call_site = Call(self, bindings=constantdict(kwargs),", "target": "pytato.function.FunctionDefinition.__call__" }, { @@ -896,7 +896,7 @@ "column": 38, "message": "Too few arguments for \"__getitem__\" of \"Call\"", "offset": 3, - "src": "return immutabledict({kw: call_site[kw] for kw in self.returns})", + "src": "return constantdict({kw: call_site[kw] for kw in self.returns})", "target": "pytato.function.FunctionDefinition.__call__" }, { @@ -904,7 +904,7 @@ "column": 48, "message": "Invalid index type \"str\" for \"Call\"; expected type \"Call\"", "offset": 0, - "src": "return immutabledict({kw: call_site[kw] for kw in self.returns})", + "src": "return constantdict({kw: call_site[kw] for kw in self.returns})", "target": "pytato.function.FunctionDefinition.__call__" }, { @@ -1088,7 +1088,7 @@ { "code": "arg-type", "column": 39, - "message": "Argument 2 to \"LoopyCall\" has incompatible type \"immutabledict[str, Array | int | integer[Any] | float | complex | inexact[Any, float | complex] | bool | numpy.bool[bool]]\"; expected \"TranslationUnit\"", + "message": "Argument 2 to \"LoopyCall\" has incompatible type \"constantdict[str, Array | int | integer[Any] | float | complex | inexact[Any, float | complex] | bool | numpy.bool[bool]]\"; expected \"TranslationUnit\"", "offset": 0, "src": "return LoopyCall(translation_unit, bindings_new, entrypoint,", "target": "pytato.loopy.call_loopy" @@ -1834,9 +1834,9 @@ { "code": "arg-type", "column": 20, - "message": "Argument 2 to \"Call\" has incompatible type \"immutabledict[Never, Never]\"; expected \"FunctionDefinition\"", + "message": "Argument 2 to \"Call\" has incompatible type \"constantdict[Never, Never]\"; expected \"FunctionDefinition\"", "offset": 1, - "src": "immutabledict({name: self.rec(bnd)", + "src": "constantdict({name: self.rec(bnd)", "target": "pytato.transform.CopyMapper.map_call" }, { @@ -1970,9 +1970,9 @@ { "code": "arg-type", "column": 20, - "message": "Argument 2 to \"Call\" has incompatible type \"immutabledict[Never, Never]\"; expected \"FunctionDefinition\"", + "message": "Argument 2 to \"Call\" has incompatible type \"constantdict[Never, Never]\"; expected \"FunctionDefinition\"", "offset": 1, - "src": "immutabledict({name: self.rec(bnd, *args, **kwargs)", + "src": "constantdict({name: self.rec(bnd, *args, **kwargs)", "target": "pytato.transform.CopyMapperWithExtraArgs.map_call" }, { diff --git a/doc/conf.py b/doc/conf.py index fb1367aec..90e1292e0 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -28,10 +28,10 @@ "loopy": ("https://documen.tician.de/loopy/", None), "sumpy": ("https://documen.tician.de/sumpy/", None), "islpy": ("https://documen.tician.de/islpy/", None), - "jax": ("https://jax.readthedocs.io/en/latest/", None), + "jax": ("https://docs.jax.dev/en/latest/", None), "mpi4py": ("https://mpi4py.readthedocs.io/en/latest", None), - "immutabledict": ("https://immutabledict.corenting.fr/", None), - "orderedsets": ("https://matthiasdiener.github.io/orderedsets", None), + "constantdict": ("https://matthiasdiener.github.io/constantdict/", None), + "orderedsets": ("https://matthiasdiener.github.io/orderedsets/", None), } # Some modules need to import things just so that sphinx can resolve symbols in diff --git a/pyproject.toml b/pyproject.toml index a9f499478..3e8d9a5e5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -33,7 +33,7 @@ classifiers = [ ] dependencies = [ "bidict", - "immutabledict", + "constantdict", "loopy>=2020.2", "pytools>=2024.1.21", "pymbolic>=2024.2", diff --git a/pytato/array.py b/pytato/array.py index 8d83ed7e2..e616c75d4 100644 --- a/pytato/array.py +++ b/pytato/array.py @@ -203,7 +203,7 @@ from warnings import warn import numpy as np -from immutabledict import immutabledict +from constantdict import constantdict from typing_extensions import Self import pymbolic.primitives as prim @@ -735,7 +735,7 @@ def _unary_op(self, op: Any) -> Array: indices = tuple(var(f"_{i}") for i in range(self.ndim)) expr = op(var("_in0")[indices]) - bindings: Mapping[str, Array] = immutabledict({"_in0": self}) + bindings: Mapping[str, Array] = constantdict({"_in0": self}) return IndexLambda( expr=expr, shape=self.shape, @@ -744,7 +744,7 @@ def _unary_op(self, op: Any) -> Array: tags=_get_default_tags(), axes=_get_default_axes(self.ndim), non_equality_tags=_get_created_at_tag(), - var_to_reduction_descr=immutabledict()) + var_to_reduction_descr=constantdict()) __mul__ = partialmethod(_binary_op, operator.mul) __rmul__ = partialmethod(_binary_op, operator.mul, reverse=True) @@ -1086,8 +1086,8 @@ class IndexLambda(_SuppliedAxesAndTagsMixin, _SuppliedShapeAndDtypeMixin, Array) if __debug__: def __post_init__(self) -> None: - assert isinstance(self.bindings, immutabledict) - assert isinstance(self.var_to_reduction_descr, immutabledict) + assert isinstance(self.bindings, constantdict) + assert isinstance(self.var_to_reduction_descr, constantdict) super().__post_init__() def with_tagged_reduction(self, @@ -1114,7 +1114,7 @@ def with_tagged_reduction(self, f" '{self.var_to_reduction_descr.keys()}'," f" got '{reduction_variable}'.") - assert isinstance(self.var_to_reduction_descr, immutabledict) + assert isinstance(self.var_to_reduction_descr, constantdict) new_var_to_redn_descr = dict(self.var_to_reduction_descr) new_var_to_redn_descr[reduction_variable] = \ self.var_to_reduction_descr[reduction_variable].tagged(tags) @@ -1124,7 +1124,7 @@ def with_tagged_reduction(self, dtype=self.dtype, bindings=self.bindings, axes=self.axes, - var_to_reduction_descr=immutabledict + var_to_reduction_descr=constantdict (new_var_to_redn_descr), tags=self.tags, non_equality_tags=self.non_equality_tags) @@ -1203,7 +1203,7 @@ class Einsum(_SuppliedAxesAndTagsMixin, Array): if __debug__: def __post_init__(self) -> None: - assert isinstance(self.redn_axis_to_redn_descr, immutabledict) + assert isinstance(self.redn_axis_to_redn_descr, constantdict) super().__post_init__() @memoize_method @@ -1230,7 +1230,7 @@ def _access_descr_to_axis_len(self else: descr_to_axis_len[descr] = arg_axis_len - return immutabledict(descr_to_axis_len) + return constantdict(descr_to_axis_len) @cached_property def shape(self) -> ShapeType: @@ -1278,7 +1278,7 @@ def with_tagged_reduction(self, # }}} - assert isinstance(self.redn_axis_to_redn_descr, immutabledict) + assert isinstance(self.redn_axis_to_redn_descr, constantdict) new_redn_axis_to_redn_descr = dict(self.redn_axis_to_redn_descr) new_redn_axis_to_redn_descr[redn_axis] = \ self.redn_axis_to_redn_descr[redn_axis].tagged(tags) @@ -1286,7 +1286,7 @@ def with_tagged_reduction(self, return type(self)(access_descriptors=self.access_descriptors, args=self.args, axes=self.axes, - redn_axis_to_redn_descr=immutabledict + redn_axis_to_redn_descr=constantdict (new_redn_axis_to_redn_descr), tags=self.tags, non_equality_tags=self.non_equality_tags, @@ -1296,7 +1296,7 @@ def with_tagged_reduction(self, EINSUM_FIRST_INDEX = re.compile(r"^\s*((?P[a-zA-Z])|(?P\.\.\.))\s*") -def _normalize_einsum_out_subscript(subscript: str) -> immutabledict[str, +def _normalize_einsum_out_subscript(subscript: str) -> constantdict[str, EinsumAxisDescriptor]: """ Normalizes the output subscript of an einsum (provided in the explicit @@ -1336,7 +1336,7 @@ def _normalize_einsum_out_subscript(subscript: str) -> immutabledict[str, raise ValueError("Used an input more than once to refer to the" f" output axis in '{subscript}") - return immutabledict({idx: EinsumElementwiseAxis(i) + return constantdict({idx: EinsumElementwiseAxis(i) for i, idx in enumerate(normalized_indices)}) @@ -1347,9 +1347,9 @@ def _normalize_einsum_in_subscript(subscript: str, index_to_axis_length: Mapping[str, ShapeComponent], ) -> tuple[tuple[EinsumAxisDescriptor, ...], - immutabledict + constantdict [str, EinsumAxisDescriptor], - immutabledict[str, ShapeComponent]]: + constantdict[str, ShapeComponent]]: """ Normalizes the subscript for an input operand in an einsum. Returns ``(access_descrs, updated_index_to_descr, updated_to_index_to_axis_length)``, @@ -1423,8 +1423,8 @@ def _normalize_einsum_in_subscript(subscript: str, in_operand_axis_descrs.append(index_to_descr_dict[index_char]) - index_to_axis_length = immutabledict(index_to_axis_length_dict) - index_to_descr = immutabledict(index_to_descr_dict) + index_to_axis_length = constantdict(index_to_axis_length_dict) + index_to_descr = constantdict(index_to_descr_dict) return (tuple(in_operand_axis_descrs), index_to_descr, index_to_axis_length) @@ -1459,7 +1459,7 @@ def einsum(subscripts: str, *operands: Array, ) index_to_descr = _normalize_einsum_out_subscript(out_spec) - index_to_axis_length: Mapping[str, ShapeComponent] = immutabledict() + index_to_axis_length: Mapping[str, ShapeComponent] = constantdict() access_descriptors = [] for in_spec, in_operand in zip(in_specs, operands, strict=True): @@ -1494,7 +1494,7 @@ def einsum(subscripts: str, *operands: Array, if isinstance(descr, EinsumElementwiseAxis)}) ), - redn_axis_to_redn_descr=immutabledict(redn_axis_to_redn_descr), + redn_axis_to_redn_descr=constantdict(redn_axis_to_redn_descr), non_equality_tags=_get_created_at_tag(), ) @@ -2349,11 +2349,11 @@ def full(shape: ConvertibleToShape, fill_value: Scalar | prim.NaN, return IndexLambda(expr=cast("ArithmeticExpression", fill_value), shape=shape, dtype=conv_dtype, - bindings=immutabledict(), + bindings=constantdict(), tags=_get_default_tags(), non_equality_tags=_get_created_at_tag(), axes=_get_default_axes(len(shape)), - var_to_reduction_descr=immutabledict()) + var_to_reduction_descr=constantdict()) def zeros(shape: ConvertibleToShape, dtype: Any = float, @@ -2396,11 +2396,11 @@ def eye(N: int, M: int | None = None, k: int = 0, # noqa: N803 raise ValueError(f"k must be int, got {type(k)}.") return IndexLambda(expr=prim.If(parse(f"(_1 - _0) == {k}"), 1, 0), - shape=(N, M), dtype=dtype, bindings=immutabledict({}), + shape=(N, M), dtype=dtype, bindings=constantdict({}), tags=_get_default_tags(), non_equality_tags=_get_created_at_tag(), axes=_get_default_axes(2), - var_to_reduction_descr=immutabledict()) + var_to_reduction_descr=constantdict()) # }}} @@ -2494,11 +2494,11 @@ def arange(*args: Any, **kwargs: Any) -> Array: from pymbolic.primitives import Variable return IndexLambda(expr=start + Variable("_0") * step, - shape=(size,), dtype=dtype, bindings=immutabledict(), + shape=(size,), dtype=dtype, bindings=constantdict(), tags=_get_default_tags(), non_equality_tags=_get_created_at_tag(), axes=_get_default_axes(1), - var_to_reduction_descr=immutabledict()) + var_to_reduction_descr=constantdict()) # }}} @@ -2623,7 +2623,7 @@ def logical_not(x: ArrayOrScalar) -> Array | bool: tags=_get_default_tags(), non_equality_tags=_get_created_at_tag(), axes=_get_default_axes(len(x.shape)), - var_to_reduction_descr=immutabledict()) + var_to_reduction_descr=constantdict()) # }}} @@ -2674,11 +2674,11 @@ def where(condition: ArrayOrScalar, expr=prim.If(expr1, expr2, expr3), shape=result_shape, dtype=dtype, - bindings=immutabledict(bindings), + bindings=constantdict(bindings), tags=_get_default_tags(), non_equality_tags=_get_created_at_tag(), axes=_get_default_axes(len(result_shape)), - var_to_reduction_descr=immutabledict()) + var_to_reduction_descr=constantdict()) # }}} @@ -2771,13 +2771,13 @@ def make_index_lambda( # }}} return IndexLambda(expr=expression, - bindings=immutabledict(bindings), + bindings=constantdict(bindings), shape=shape, dtype=dtype, tags=_get_default_tags(), non_equality_tags=_get_created_at_tag(), axes=_get_default_axes(len(shape)), - var_to_reduction_descr=immutabledict + var_to_reduction_descr=constantdict (processed_var_to_reduction_descr)) # }}} @@ -2859,11 +2859,11 @@ def broadcast_to(array: Array, shape: ShapeType) -> Array: shape)), shape=shape, dtype=array.dtype, - bindings=immutabledict({"in": array}), + bindings=constantdict({"in": array}), tags=_get_default_tags(), non_equality_tags=_get_created_at_tag(), axes=_get_default_axes(len(shape)), - var_to_reduction_descr=immutabledict()) + var_to_reduction_descr=constantdict()) # }}} diff --git a/pytato/cmath.py b/pytato/cmath.py index ec981aa32..82872e7c1 100644 --- a/pytato/cmath.py +++ b/pytato/cmath.py @@ -60,7 +60,7 @@ from typing import TYPE_CHECKING, cast import numpy as np -from immutabledict import immutabledict +from constantdict import constantdict import pymbolic.primitives as prim from pymbolic import Scalar, var @@ -129,11 +129,11 @@ def _apply_elem_wise_func(inputs: tuple[ArrayOrScalar, ...], return IndexLambda( expr=prim.Call(var(f"pytato.c99.{func_name}"), tuple(sym_args)), - shape=shape, dtype=ret_dtype, bindings=immutabledict(bindings), + shape=shape, dtype=ret_dtype, bindings=constantdict(bindings), tags=_get_default_tags(), non_equality_tags=_get_created_at_tag(stacklevel=2), axes=_get_default_axes(len(shape)), - var_to_reduction_descr=immutabledict(), + var_to_reduction_descr=constantdict(), ) diff --git a/pytato/codegen.py b/pytato/codegen.py index cb957f076..9e1f33ff9 100644 --- a/pytato/codegen.py +++ b/pytato/codegen.py @@ -40,7 +40,7 @@ import dataclasses from typing import TYPE_CHECKING, Any -from immutabledict import immutabledict +from constantdict import constantdict import loopy as lp from pymbolic.mapper.optimize import optimize_mapper @@ -211,7 +211,7 @@ def map_loopy_call(self, expr: LoopyCall) -> LoopyCall: # }}} - bindings: Mapping[str, Any] = immutabledict( + bindings: Mapping[str, Any] = constantdict( {name: (self.rec(subexpr) if isinstance(subexpr, Array) else subexpr) for name, subexpr in sorted(expr.bindings.items())}) @@ -327,7 +327,7 @@ def preprocess(outputs: DictOfNamedArrays, target: Target) -> PreprocessResult: for out in outputs.values())) # only look for dependencies between the outputs - deps: Mapping[str, Any] = immutabledict({name: get_deps(output.expr) + deps: Mapping[str, Any] = constantdict({name: get_deps(output.expr) for name, output in outputs.items()}) # represent deps in terms of output names diff --git a/pytato/distributed/partition.py b/pytato/distributed/partition.py index 4b74bf02b..13c644af3 100644 --- a/pytato/distributed/partition.py +++ b/pytato/distributed/partition.py @@ -73,7 +73,7 @@ cast, ) -from immutabledict import immutabledict +from constantdict import constantdict from orderedsets import FrozenOrderedSet, OrderedSet from pymbolic.mapper.optimize import optimize_mapper @@ -372,11 +372,11 @@ def _make_distributed_partition( partition_input_names=frozenset( comm_replacer.partition_input_name_to_placeholder.keys()), output_names=frozenset(name_to_part_output.keys()), - name_to_recv_node=immutabledict({ + name_to_recv_node=constantdict({ recvd_ary_to_name[local_recv_id_to_recv_node[recv_id]]: local_recv_id_to_recv_node[recv_id] for recv_id in comm_ids.recv_ids}), - name_to_send_nodes=immutabledict(name_to_send_nodes)) + name_to_send_nodes=constantdict(name_to_send_nodes)) result = DistributedGraphPartition( parts=parts, diff --git a/pytato/function.py b/pytato/function.py index c96070bf5..3a6644c9d 100644 --- a/pytato/function.py +++ b/pytato/function.py @@ -72,7 +72,7 @@ TypeVar, ) -from immutabledict import immutabledict +from constantdict import constantdict from pytools import memoize_method from pytools.tag import Tag, Taggable @@ -162,7 +162,7 @@ class FunctionDefinition(Taggable): if __debug__: def __post_init__(self) -> None: - assert isinstance(self.returns, immutabledict) + assert isinstance(self.returns, constantdict) @cached_property def _placeholders(self) -> Mapping[str, Placeholder]: @@ -185,7 +185,7 @@ def _placeholders(self) -> Mapping[str, Placeholder]: f"Found non-argument placeholder '{next(iter(extra_pl_names))}' " \ "in function definition." - return immutabledict({arg.name: arg for arg in all_placeholders}) + return constantdict({arg.name: arg for arg in all_placeholders}) def get_placeholder(self, name: str) -> Placeholder: """ @@ -227,7 +227,7 @@ def __call__(self, **kwargs: Array # }}} - call_site = Call(self, bindings=immutabledict(kwargs), + call_site = Call(self, bindings=constantdict(kwargs), tags=_get_default_tags()) if self.return_type == ReturnType.ARRAY: @@ -236,7 +236,7 @@ def __call__(self, **kwargs: Array return tuple(call_site[f"_{iarg}"] for iarg in range(len(self.returns))) elif self.return_type == ReturnType.DICT_OF_ARRAYS: - return immutabledict({kw: call_site[kw] for kw in self.returns}) + return constantdict({kw: call_site[kw] for kw in self.returns}) else: raise NotImplementedError(self.return_type) @@ -323,7 +323,7 @@ def __post_init__(self) -> None: # check that the invocation parameters and the function definition # parameters agree with each other. assert frozenset(self.bindings) == self.function.parameters - assert isinstance(self.bindings, immutabledict) + assert isinstance(self.bindings, constantdict) super().__post_init__() def __contains__(self, name: object) -> bool: @@ -422,7 +422,7 @@ def trace_call(f: Callable[..., ReturnT], function = FunctionDefinition( frozenset(pl_arg.name for pl_arg in pl_args) | frozenset(pl_kwargs), return_type, - immutabledict(returns), + constantdict(returns), tags=_get_default_tags() | (frozenset([FunctionIdentifier(identifier)]) if identifier else frozenset()) diff --git a/pytato/loopy.py b/pytato/loopy.py index db7bfd22f..5858ed7c1 100644 --- a/pytato/loopy.py +++ b/pytato/loopy.py @@ -35,7 +35,7 @@ import islpy as isl import numpy as np -from immutabledict import immutabledict +from constantdict import constantdict import loopy as lp import pymbolic.primitives as prim @@ -108,7 +108,7 @@ class LoopyCall(AbstractResultWithNamedArrays): copy = dataclasses.replace def __post_init__(self) -> None: - assert isinstance(self.bindings, immutabledict) + assert isinstance(self.bindings, constantdict) super().__post_init__() @property @@ -243,7 +243,7 @@ def call_loopy(translation_unit: lp.TranslationUnit, # {{{ perform shape inference here bindings_new = extend_bindings_with_shape_inference(translation_unit[entrypoint], - immutabledict(bindings)) + constantdict(bindings)) del bindings # }}} @@ -412,7 +412,7 @@ def _get_pt_dim_expr(dim: Integer | Array) -> ScalarExpression: def extend_bindings_with_shape_inference(knl: lp.LoopKernel, bindings: Mapping[str, ArrayOrScalar] - ) -> immutabledict[str, ArrayOrScalar]: + ) -> constantdict[str, ArrayOrScalar]: from functools import reduce from loopy.kernel.array import ArrayBase @@ -538,7 +538,7 @@ def extend_bindings_with_shape_inference(knl: lp.LoopKernel, bindings_dict[var] = val_sp - return immutabledict(bindings_dict) + return constantdict(bindings_dict) # }}} diff --git a/pytato/raising.py b/pytato/raising.py index 13efb737c..90cee5b4c 100644 --- a/pytato/raising.py +++ b/pytato/raising.py @@ -5,7 +5,7 @@ from typing import TYPE_CHECKING, Any import numpy as np -from immutabledict import immutabledict +from constantdict import constantdict import pymbolic.primitives as p @@ -346,7 +346,7 @@ def index_lambda_to_high_level_op(expr: IndexLambda) -> HighLevelOp: .expr .inner_expr .aggregate.name], - axes=immutabledict({i: idx.name + axes=constantdict({i: idx.name for i, idx in enumerate(expr .expr .inner_expr diff --git a/pytato/reductions.py b/pytato/reductions.py index 0d2c5fc1e..13d6c3197 100644 --- a/pytato/reductions.py +++ b/pytato/reductions.py @@ -33,7 +33,7 @@ from typing import TYPE_CHECKING, Any import numpy as np -from immutabledict import immutabledict +from constantdict import constantdict import pymbolic.primitives as prim from pymbolic import ArithmeticExpression @@ -225,7 +225,7 @@ def _get_reduction_indices_bounds(shape: ShapeType, indices.append(prim.Variable(f"_{n_out_dims}")) n_out_dims += 1 - return indices, immutabledict(redn_bounds) + return indices, constantdict(redn_bounds) def _get_var_to_redn_descr( @@ -269,7 +269,7 @@ def _get_var_to_redn_descr( var_to_redn_descr[idx] = redn_descr n_redn_dims += 1 - return immutabledict(var_to_redn_descr) + return constantdict(var_to_redn_descr) def _make_reduction_lambda( diff --git a/pytato/scalar_expr.py b/pytato/scalar_expr.py index 1bc2beaac..053392c9b 100644 --- a/pytato/scalar_expr.py +++ b/pytato/scalar_expr.py @@ -52,7 +52,7 @@ ) import numpy as np -from immutabledict import immutabledict +from constantdict import constantdict from typing_extensions import TypeIs import pymbolic.primitives as prim @@ -136,7 +136,7 @@ def map_reduce(self, cast("ArithmeticExpression", self.rec(expr.inner_expr, *args, **kwargs)), expr.op, - immutabledict({ + constantdict({ name: ( self.rec(lower, *args, **kwargs), self.rec(upper, *args, **kwargs) @@ -157,7 +157,7 @@ def map_reduce(self, expr: Reduce) -> ScalarExpression: assert not isinstance(inner_expr, tuple) return Reduce(inner_expr, op=expr.op, - bounds=immutabledict( + bounds=constantdict( {name: self.rec(bound) for name, bound in expr.bounds.items()})) diff --git a/pytato/stringifier.py b/pytato/stringifier.py index 17c915778..d74a7d341 100644 --- a/pytato/stringifier.py +++ b/pytato/stringifier.py @@ -29,7 +29,7 @@ from typing import TYPE_CHECKING, Any, cast import numpy as np -from immutabledict import immutabledict +from constantdict import constantdict from pytato.array import ( Array, @@ -99,7 +99,7 @@ def __call__(self, expr: Any, depth: int = 0) -> str: def map_foreign(self, expr: Any, depth: int) -> str: if isinstance(expr, tuple): return "(" + ", ".join(self.rec(el, depth) for el in expr) + ")" - elif isinstance(expr, dict | immutabledict): + elif isinstance(expr, dict | constantdict): return ("{" + ", ".join(f"{key!r}: {self.rec(val, depth)}" for key, val diff --git a/pytato/target/loopy/__init__.py b/pytato/target/loopy/__init__.py index 2f16e3fa3..2055408b5 100644 --- a/pytato/target/loopy/__init__.py +++ b/pytato/target/loopy/__init__.py @@ -55,7 +55,7 @@ from typing import TYPE_CHECKING, Any import numpy as np -from immutabledict import immutabledict +from constantdict import constantdict import loopy @@ -202,7 +202,7 @@ def _get_processed_bound_arguments( f" Got {type(bnd_arg).__name__} for '{name}'." ) from None - result: Mapping[str, Any] = immutabledict(proc_bnd_args) + result: Mapping[str, Any] = constantdict(proc_bnd_args) assert set(result.keys()) == set(self.bound_arguments.keys()) self._processed_bound_args_cache[cache_key] = result return result diff --git a/pytato/target/python/numpy_like.py b/pytato/target/python/numpy_like.py index f8bb080be..b56756aa4 100644 --- a/pytato/target/python/numpy_like.py +++ b/pytato/target/python/numpy_like.py @@ -37,7 +37,7 @@ ) import numpy as np -from immutabledict import immutabledict +from constantdict import constantdict from typing_extensions import NotRequired from pytools import UniqueNameGenerator @@ -626,4 +626,4 @@ def generate_numpy_like(expr: Array | Mapping[str, Array] | DictOfNamedArrays, program, function_name, expected_arguments=frozenset(cgen_mapper.arg_names), - bound_arguments=immutabledict(cgen_mapper.bound_arguments)) + bound_arguments=constantdict(cgen_mapper.bound_arguments)) diff --git a/pytato/transform/__init__.py b/pytato/transform/__init__.py index 28d423628..0d226c8ad 100644 --- a/pytato/transform/__init__.py +++ b/pytato/transform/__init__.py @@ -42,7 +42,7 @@ ) import numpy as np -from immutabledict import immutabledict +from constantdict import constantdict from typing_extensions import Self from pymbolic.mapper.optimize import optimize_mapper @@ -784,7 +784,7 @@ def rec_idx_or_size_tuple(self, situp: tuple[IndexOrShapeExpr, ...] for s in situp) def map_index_lambda(self, expr: IndexLambda) -> Array: - bindings: Mapping[str, Array] = immutabledict({ + bindings: Mapping[str, Array] = constantdict({ name: self.rec(subexpr) for name, subexpr in sorted(expr.bindings.items())}) return IndexLambda(expr=expr.expr, @@ -892,7 +892,7 @@ def map_dict_of_named_arrays(self, ) def map_loopy_call(self, expr: LoopyCall) -> LoopyCall: - bindings: Mapping[Any, Any] = immutabledict( + bindings: Mapping[Any, Any] = constantdict( {name: (self.rec(subexpr) if isinstance(subexpr, Array) else subexpr) for name, subexpr in sorted(expr.bindings.items())}) @@ -945,11 +945,11 @@ def map_function_definition(self, new_mapper = self.clone_for_callee(expr) new_returns = {name: new_mapper(ret) for name, ret in expr.returns.items()} - return dataclasses.replace(expr, returns=immutabledict(new_returns)) + return dataclasses.replace(expr, returns=constantdict(new_returns)) def map_call(self, expr: Call) -> AbstractResultWithNamedArrays: return Call(self.rec_function_definition(expr.function), - immutabledict({name: self.rec(bnd) + constantdict({name: self.rec(bnd) for name, bnd in expr.bindings.items()}), tags=expr.tags, ) @@ -981,7 +981,7 @@ def rec_idx_or_size_tuple(self, situp: tuple[IndexOrShapeExpr, ...], def map_index_lambda(self, expr: IndexLambda, *args: P.args, **kwargs: P.kwargs) -> Array: - bindings: Mapping[str, Array] = immutabledict({ + bindings: Mapping[str, Array] = constantdict({ name: self.rec(subexpr, *args, **kwargs) for name, subexpr in sorted(expr.bindings.items())}) return IndexLambda(expr=expr.expr, @@ -1107,7 +1107,7 @@ def map_dict_of_named_arrays(self, def map_loopy_call(self, expr: LoopyCall, *args: P.args, **kwargs: P.kwargs) -> LoopyCall: - bindings: Mapping[Any, Any] = immutabledict( + bindings: Mapping[Any, Any] = constantdict( {name: (self.rec(subexpr, *args, **kwargs) if isinstance(subexpr, Array) else subexpr) @@ -1169,7 +1169,7 @@ def map_function_definition( def map_call(self, expr: Call, *args: P.args, **kwargs: P.kwargs) -> AbstractResultWithNamedArrays: return Call(self.rec_function_definition(expr.function, *args, **kwargs), - immutabledict({name: self.rec(bnd, *args, **kwargs) + constantdict({name: self.rec(bnd, *args, **kwargs) for name, bnd in expr.bindings.items()}), tags=expr.tags, ) @@ -1890,7 +1890,7 @@ def map_index_lambda(self, expr: IndexLambda) -> MPMSMaterializerAccumulator: new_expr = IndexLambda(expr=expr.expr, shape=expr.shape, dtype=expr.dtype, - bindings=immutabledict({bnd_name: bnd.expr + bindings=constantdict({bnd_name: bnd.expr for bnd_name, bnd in sorted(children_rec.items())}), axes=expr.axes, var_to_reduction_descr=expr.var_to_reduction_descr, diff --git a/pytato/transform/einsum_distributive_law.py b/pytato/transform/einsum_distributive_law.py index 694901b03..b156b7f0d 100644 --- a/pytato/transform/einsum_distributive_law.py +++ b/pytato/transform/einsum_distributive_law.py @@ -38,7 +38,7 @@ from typing import TYPE_CHECKING, cast import numpy as np -from immutabledict import immutabledict +from constantdict import constantdict from pytato.array import ( Array, @@ -231,7 +231,7 @@ def map_index_lambda(self, expr=expr.expr, shape=expr.shape, dtype=expr.dtype, - bindings=immutabledict({name: _verify_is_array(self.rec(bnd, None)) + bindings=constantdict({name: _verify_is_array(self.rec(bnd, None)) for name, bnd in sorted(expr.bindings.items())}), var_to_reduction_descr=expr.var_to_reduction_descr, tags=expr.tags, @@ -251,10 +251,10 @@ def map_einsum(self, else: ctx = _EinsumDistributiveLawMapperContext( expr.access_descriptors, - immutabledict({iarg: arg + constantdict({iarg: arg for iarg, arg in enumerate(expr.args) if iarg != distributive_law_descr.ioperand}), - immutabledict(expr.redn_axis_to_redn_descr), + constantdict(expr.redn_axis_to_redn_descr), tags=expr.tags, axes=expr.axes, ) diff --git a/pytato/transform/lower_to_index_lambda.py b/pytato/transform/lower_to_index_lambda.py index 507a450cd..c7f7e0cb4 100644 --- a/pytato/transform/lower_to_index_lambda.py +++ b/pytato/transform/lower_to_index_lambda.py @@ -31,7 +31,7 @@ from dataclasses import dataclass from typing import TYPE_CHECKING, Any, Never, TypeVar, cast -from immutabledict import immutabledict +from constantdict import constantdict import pymbolic.primitives as prim from pymbolic import ArithmeticExpression @@ -274,7 +274,7 @@ def map_index_lambda(self, expr: IndexLambda) -> IndexLambda: return IndexLambda(expr=expr.expr, shape=self._rec_shape(expr.shape), dtype=expr.dtype, - bindings=immutabledict({name: self.rec(bnd) + bindings=constantdict({name: self.rec(bnd) for name, bnd in sorted(expr.bindings.items())}), axes=expr.axes, @@ -312,8 +312,8 @@ def map_stack(self, expr: Stack) -> IndexLambda: shape=self._rec_shape(expr.shape), dtype=expr.dtype, axes=expr.axes, - bindings=immutabledict(bindings), - var_to_reduction_descr=immutabledict(), + bindings=constantdict(bindings), + var_to_reduction_descr=constantdict(), tags=expr.tags, non_equality_tags=expr.non_equality_tags) @@ -360,9 +360,9 @@ def get_subscript(array_index: int, offset: ScalarExpression) -> Subscript: return IndexLambda(expr=concat_expr, shape=self._rec_shape(expr.shape), dtype=expr.dtype, - bindings=immutabledict(bindings), + bindings=constantdict(bindings), axes=expr.axes, - var_to_reduction_descr=immutabledict(), + var_to_reduction_descr=constantdict(), tags=expr.tags, non_equality_tags=expr.non_equality_tags) @@ -429,14 +429,14 @@ def map_einsum(self, expr: Einsum) -> IndexLambda: from pytato.reductions import SumReductionOperation inner_expr = Reduce(inner_expr, SumReductionOperation(), - immutabledict(redn_bounds)) + constantdict(redn_bounds)) return IndexLambda(expr=inner_expr, shape=self._rec_shape(expr.shape), dtype=expr.dtype, - bindings=immutabledict(bindings), + bindings=constantdict(bindings), axes=expr.axes, - var_to_reduction_descr=immutabledict(var_to_redn_descr), + var_to_reduction_descr=constantdict(var_to_redn_descr), tags=expr.tags, non_equality_tags=expr.non_equality_tags) @@ -464,10 +464,10 @@ def map_roll(self, expr: Roll) -> IndexLambda: return IndexLambda(expr=index_expr, shape=self._rec_shape(expr.shape), dtype=expr.dtype, - bindings=immutabledict({name: self.rec(bnd) + bindings=constantdict({name: self.rec(bnd) for name, bnd in bindings.items()}), axes=expr.axes, - var_to_reduction_descr=immutabledict(), + var_to_reduction_descr=constantdict(), tags=expr.tags, non_equality_tags=expr.non_equality_tags) @@ -535,11 +535,11 @@ def map_contiguous_advanced_index(self, return IndexLambda(expr=prim.Subscript(prim.Variable(in_ary), tuple(indices)), - bindings=immutabledict(bindings), + bindings=constantdict(bindings), shape=self._rec_shape(expr.shape), dtype=expr.dtype, axes=expr.axes, - var_to_reduction_descr=immutabledict(), + var_to_reduction_descr=constantdict(), tags=expr.tags, non_equality_tags=expr.non_equality_tags, ) @@ -604,11 +604,11 @@ def map_non_contiguous_advanced_index( return IndexLambda(expr=prim.Subscript(prim.Variable(in_ary), tuple(indices)), - bindings=immutabledict(bindings), + bindings=constantdict(bindings), shape=self._rec_shape(expr.shape), dtype=expr.dtype, axes=expr.axes, - var_to_reduction_descr=immutabledict(), + var_to_reduction_descr=constantdict(), tags=expr.tags, non_equality_tags=expr.non_equality_tags, ) @@ -641,11 +641,11 @@ def map_basic_index(self, expr: BasicIndex) -> IndexLambda: return IndexLambda(expr=prim.Subscript(prim.Variable(in_ary), tuple(indices)), - bindings=immutabledict(bindings), + bindings=constantdict(bindings), shape=self._rec_shape(expr.shape), dtype=expr.dtype, axes=expr.axes, - var_to_reduction_descr=immutabledict(), + var_to_reduction_descr=constantdict(), tags=expr.tags, non_equality_tags=expr.non_equality_tags, ) @@ -656,9 +656,9 @@ def map_reshape(self, expr: Reshape) -> IndexLambda: return IndexLambda(expr=index_expr, shape=self._rec_shape(expr.shape), dtype=expr.dtype, - bindings=immutabledict({"_in0": self.rec(expr.array)}), + bindings=constantdict({"_in0": self.rec(expr.array)}), axes=expr.axes, - var_to_reduction_descr=immutabledict(), + var_to_reduction_descr=constantdict(), tags=expr.tags, non_equality_tags=expr.non_equality_tags) @@ -673,9 +673,9 @@ def map_axis_permutation(self, expr: AxisPermutation) -> IndexLambda: return IndexLambda(expr=index_expr, shape=self._rec_shape(expr.shape), dtype=expr.dtype, - bindings=immutabledict({"_in0": self.rec(expr.array)}), + bindings=constantdict({"_in0": self.rec(expr.array)}), axes=expr.axes, - var_to_reduction_descr=immutabledict(), + var_to_reduction_descr=constantdict(), tags=expr.tags, non_equality_tags=expr.non_equality_tags) diff --git a/pytato/utils.py b/pytato/utils.py index 31b621176..3974c441b 100644 --- a/pytato/utils.py +++ b/pytato/utils.py @@ -32,7 +32,7 @@ import islpy as isl import numpy as np -from immutabledict import immutabledict +from constantdict import constantdict import pymbolic.primitives as prim from pymbolic import ArithmeticExpression, Bool, Scalar @@ -260,10 +260,10 @@ def cast_to_result_type( return IndexLambda(expr=op(expr1, expr2), shape=result_shape, dtype=result_dtype, - bindings=immutabledict(bindings), + bindings=constantdict(bindings), tags=tags, non_equality_tags=non_equality_tags, - var_to_reduction_descr=immutabledict(), + var_to_reduction_descr=constantdict(), axes=_get_default_axes(len(result_shape))) diff --git a/test/test_pytato.py b/test/test_pytato.py index 5ee9cb66b..0cd924496 100644 --- a/test/test_pytato.py +++ b/test/test_pytato.py @@ -824,7 +824,7 @@ def test_basic_index_equality_traverses_underlying_arrays(): def test_idx_lambda_to_hlo(): - from immutabledict import immutabledict + from constantdict import constantdict from pytato.raising import ( BinaryOp, @@ -873,11 +873,11 @@ def test_idx_lambda_to_hlo(): assert (index_lambda_to_high_level_op(pt.sum(b, axis=1)) == ReduceOp(SumReductionOperation(), b, - immutabledict({1: "_r0"}))) + constantdict({1: "_r0"}))) assert (index_lambda_to_high_level_op(pt.prod(a)) == ReduceOp(ProductReductionOperation(), a, - immutabledict({0: "_r0", + constantdict({0: "_r0", 1: "_r1"}))) assert index_lambda_to_high_level_op(pt.sinh(a)) == C99CallOp("sinh", (a,)) assert index_lambda_to_high_level_op(pt.arctan2(b, a)) == C99CallOp("atan2", @@ -1062,8 +1062,8 @@ def __init__(self) -> None: def map_index_lambda(self, expr: pt.IndexLambda) -> pt.Array: assert not any(isinstance(s, pt.Array) for s in expr.shape) - from immutabledict import immutabledict - new_bindings: Mapping[str, pt.Array] = immutabledict({ + from constantdict import constantdict + new_bindings: Mapping[str, pt.Array] = constantdict({ name: self.rec(subexpr) for name, subexpr in sorted(expr.bindings.items())}) if ( @@ -1093,8 +1093,8 @@ def __init__(self) -> None: def map_index_lambda(self, expr: pt.IndexLambda) -> pt.Array: assert not any(isinstance(s, pt.Array) for s in expr.shape) - from immutabledict import immutabledict - new_bindings: Mapping[str, pt.Array] = immutabledict({ + from constantdict import constantdict + new_bindings: Mapping[str, pt.Array] = constantdict({ name: self.rec(subexpr) for name, subexpr in sorted(expr.bindings.items())}) return pt.IndexLambda(expr=expr.expr, @@ -1458,7 +1458,7 @@ def post_visit(self, expr, passed_number): def test_unify_axes_tags_indexlambda(): - from immutabledict import immutabledict + from constantdict import constantdict from testlib import BarTag, FooTag, TestlibTag from pymbolic import primitives as prim @@ -1474,11 +1474,11 @@ def test_unify_axes_tags_indexlambda(): prim.Subscript(prim.Variable("_in1"), (prim.Variable("_1"), 0))) ), - bindings=immutabledict({"_in0": x, "_in1": y}), + bindings=constantdict({"_in0": x, "_in1": y}), dtype=float, axes=pt.array._get_default_axes(2), tags=pt.array._get_default_tags(), shape=(10, 4), - var_to_reduction_descr=immutabledict({})) + var_to_reduction_descr=constantdict({})) z_unified = pt.unify_axes_tags(z) @@ -1608,7 +1608,7 @@ def test_unify_axes_tags(): # {{ Reduction Operations with IndexLambda # {{{ Reduce on outside of scalar expression - from immutabledict import immutabledict + from constantdict import constantdict import pymbolic.primitives as prim @@ -1640,7 +1640,7 @@ def assert_tags_were_propagated_appropriately(arr): frozenset([BazTag()]) def get_def_reduction_descrs(): - return immutabledict({"_r0": pt.array.ReductionDescriptor(frozenset([])), + return constantdict({"_r0": pt.array.ReductionDescriptor(frozenset([])), "_r1": pt.array.ReductionDescriptor(frozenset([])) }) @@ -1648,9 +1648,9 @@ def get_def_reduction_descrs(): # sum((_r0, _r1), a[_0] + b[_0, _r0] + b[_0,_r1])) w = pt.IndexLambda(expr=pt.scalar_expr.Reduce(prim.Sum((x, y, z)), pt.reductions.SumReductionOperation, - immutabledict({"_r0": (0, 10), + constantdict({"_r0": (0, 10), "_r1": (0, 10)})), - bindings=immutabledict({"_in0": a, "_in1": b, "_in2": c}), + bindings=constantdict({"_in0": a, "_in1": b, "_in2": c}), shape=(512,), tags=pt.array._get_default_tags(), axes=pt.array._get_default_axes(1), dtype=float, @@ -1670,11 +1670,11 @@ def get_def_reduction_descrs(): w = pt.IndexLambda(expr=prim.Sum((x, pt.scalar_expr.Reduce(y, pt.reductions.SumReductionOperation, - immutabledict({"_r0": (0, 10)})), + constantdict({"_r0": (0, 10)})), pt.scalar_expr.Reduce(z, pt.reductions.SumReductionOperation, - immutabledict({"_r1": (0, 10)})))), - bindings=immutabledict({"_in0": a, "_in1": b, "_in2": c}), + constantdict({"_r1": (0, 10)})))), + bindings=constantdict({"_in0": a, "_in1": b, "_in2": c}), shape=(512,), tags=pt.array._get_default_tags(), axes=pt.array._get_default_axes(1), dtype=float, @@ -1692,10 +1692,10 @@ def get_def_reduction_descrs(): w = pt.IndexLambda(expr=prim.Sum((x, pt.scalar_expr.Reduce(prim.Sum((y, pt.scalar_expr.Reduce(z, pt.reductions.SumReductionOperation, - immutabledict({"_r1": (0, 10)})))), + constantdict({"_r1": (0, 10)})))), pt.reductions.SumReductionOperation, - immutabledict({"_r0": (0, 10)})))), - bindings=immutabledict({"_in0": a, "_in1": b, "_in2": c}), + constantdict({"_r0": (0, 10)})))), + bindings=constantdict({"_in0": a, "_in1": b, "_in2": c}), shape=(512,), tags=pt.array._get_default_tags(), axes=pt.array._get_default_axes(1), dtype=float, @@ -1718,7 +1718,7 @@ def test_unify_axes_tags_with_unbroadcastable_expressions(): a = a.with_tagged_axis(1, QuuxTag()) a = a.with_tagged_axis(2, FooTag()) - from immutabledict import immutabledict + from constantdict import constantdict import pymbolic.primitives as prim @@ -1727,11 +1727,11 @@ def test_unify_axes_tags_with_unbroadcastable_expressions(): y = prim.Subscript(prim.Variable("_in1"), (prim.Variable("_0"), prim.Variable("_1"))) - z = pt.IndexLambda(expr=x+y, bindings=immutabledict({"_in0": a, "_in1": b}), + z = pt.IndexLambda(expr=x+y, bindings=constantdict({"_in0": a, "_in1": b}), shape=(512, 10, 8), tags=pt.array._get_default_tags(), axes=pt.array._get_default_axes(3), dtype=float, - var_to_reduction_descr=immutabledict({})) + var_to_reduction_descr=constantdict({})) z_unified = pt.unify_axes_tags(z)