Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 13 additions & 13 deletions .mypy/baseline.json
Original file line number Diff line number Diff line change
Expand Up @@ -192,23 +192,23 @@
"column": 27,
"message": "Need type annotation for \"index_to_axis_length\"",
"offset": 46,
"src": "index_to_axis_length = immutabledict(index_to_axis_length_dict)",
"src": "index_to_axis_length = constantdict(index_to_axis_length_dict)",
"target": "pytato.array._normalize_einsum_in_subscript"
},
{
"code": "var-annotated",
"column": 27,
"message": "See https://kotlinisland.github.io/basedmypy/_refs.html#code-var-annotated for more info",
"offset": 0,
"src": "index_to_axis_length = immutabledict(index_to_axis_length_dict)",
"src": "index_to_axis_length = constantdict(index_to_axis_length_dict)",
"target": null
},
{
"code": "var-annotated",
"column": 21,
"message": "Need type annotation for \"index_to_descr\"",
"offset": 1,
"src": "index_to_descr = immutabledict(index_to_descr_dict)",
"src": "index_to_descr = constantdict(index_to_descr_dict)",
"target": "pytato.array._normalize_einsum_in_subscript"
},
{
Expand Down Expand Up @@ -840,23 +840,23 @@
"column": 20,
"message": "\"Call\" gets multiple values for keyword argument \"tags\"",
"offset": 230,
"src": "call_site = Call(self, bindings=immutabledict(kwargs),",
"src": "call_site = Call(self, bindings=constantdict(kwargs),",
"target": "pytato.function.FunctionDefinition.__call__"
},
{
"code": "call-arg",
"column": 20,
"message": "Missing positional argument \"function\" in call to \"Call\"",
"offset": 0,
"src": "call_site = Call(self, bindings=immutabledict(kwargs),",
"src": "call_site = Call(self, bindings=constantdict(kwargs),",
"target": "pytato.function.FunctionDefinition.__call__"
},
{
"code": "arg-type",
"column": 25,
"message": "Argument 1 to \"Call\" has incompatible type \"FunctionDefinition\"; expected \"frozenset[Tag]\"",
"offset": 0,
"src": "call_site = Call(self, bindings=immutabledict(kwargs),",
"src": "call_site = Call(self, bindings=constantdict(kwargs),",
"target": "pytato.function.FunctionDefinition.__call__"
},
{
Expand Down Expand Up @@ -896,15 +896,15 @@
"column": 38,
"message": "Too few arguments for \"__getitem__\" of \"Call\"",
"offset": 3,
"src": "return immutabledict({kw: call_site[kw] for kw in self.returns})",
"src": "return constantdict({kw: call_site[kw] for kw in self.returns})",
"target": "pytato.function.FunctionDefinition.__call__"
},
{
"code": "index",
"column": 48,
"message": "Invalid index type \"str\" for \"Call\"; expected type \"Call\"",
"offset": 0,
"src": "return immutabledict({kw: call_site[kw] for kw in self.returns})",
"src": "return constantdict({kw: call_site[kw] for kw in self.returns})",
"target": "pytato.function.FunctionDefinition.__call__"
},
{
Expand Down Expand Up @@ -1088,7 +1088,7 @@
{
"code": "arg-type",
"column": 39,
"message": "Argument 2 to \"LoopyCall\" has incompatible type \"immutabledict[str, Array | int | integer[Any] | float | complex | inexact[Any, float | complex] | bool | numpy.bool[bool]]\"; expected \"TranslationUnit\"",
"message": "Argument 2 to \"LoopyCall\" has incompatible type \"constantdict[str, Array | int | integer[Any] | float | complex | inexact[Any, float | complex] | bool | numpy.bool[bool]]\"; expected \"TranslationUnit\"",
"offset": 0,
"src": "return LoopyCall(translation_unit, bindings_new, entrypoint,",
"target": "pytato.loopy.call_loopy"
Expand Down Expand Up @@ -1834,9 +1834,9 @@
{
"code": "arg-type",
"column": 20,
"message": "Argument 2 to \"Call\" has incompatible type \"immutabledict[Never, Never]\"; expected \"FunctionDefinition\"",
"message": "Argument 2 to \"Call\" has incompatible type \"constantdict[Never, Never]\"; expected \"FunctionDefinition\"",
"offset": 1,
"src": "immutabledict({name: self.rec(bnd)",
"src": "constantdict({name: self.rec(bnd)",
"target": "pytato.transform.CopyMapper.map_call"
},
{
Expand Down Expand Up @@ -1970,9 +1970,9 @@
{
"code": "arg-type",
"column": 20,
"message": "Argument 2 to \"Call\" has incompatible type \"immutabledict[Never, Never]\"; expected \"FunctionDefinition\"",
"message": "Argument 2 to \"Call\" has incompatible type \"constantdict[Never, Never]\"; expected \"FunctionDefinition\"",
"offset": 1,
"src": "immutabledict({name: self.rec(bnd, *args, **kwargs)",
"src": "constantdict({name: self.rec(bnd, *args, **kwargs)",
"target": "pytato.transform.CopyMapperWithExtraArgs.map_call"
},
{
Expand Down
6 changes: 3 additions & 3 deletions doc/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,10 @@
"loopy": ("https://documen.tician.de/loopy/", None),
"sumpy": ("https://documen.tician.de/sumpy/", None),
"islpy": ("https://documen.tician.de/islpy/", None),
"jax": ("https://jax.readthedocs.io/en/latest/", None),
"jax": ("https://docs.jax.dev/en/latest/", None),
"mpi4py": ("https://mpi4py.readthedocs.io/en/latest", None),
"immutabledict": ("https://immutabledict.corenting.fr/", None),
"orderedsets": ("https://matthiasdiener.github.io/orderedsets", None),
"constantdict": ("https://matthiasdiener.github.io/constantdict/", None),
"orderedsets": ("https://matthiasdiener.github.io/orderedsets/", None),
}

# Some modules need to import things just so that sphinx can resolve symbols in
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ classifiers = [
]
dependencies = [
"bidict",
"immutabledict",
"constantdict",
"loopy>=2020.2",
"pytools>=2024.1.21",
"pymbolic>=2024.2",
Expand Down
64 changes: 32 additions & 32 deletions pytato/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,7 @@
from warnings import warn

import numpy as np
from immutabledict import immutabledict
from constantdict import constantdict
from typing_extensions import Self

import pymbolic.primitives as prim
Expand Down Expand Up @@ -735,7 +735,7 @@ def _unary_op(self, op: Any) -> Array:
indices = tuple(var(f"_{i}") for i in range(self.ndim))
expr = op(var("_in0")[indices])

bindings: Mapping[str, Array] = immutabledict({"_in0": self})
bindings: Mapping[str, Array] = constantdict({"_in0": self})
return IndexLambda(
expr=expr,
shape=self.shape,
Expand All @@ -744,7 +744,7 @@ def _unary_op(self, op: Any) -> Array:
tags=_get_default_tags(),
axes=_get_default_axes(self.ndim),
non_equality_tags=_get_created_at_tag(),
var_to_reduction_descr=immutabledict())
var_to_reduction_descr=constantdict())

__mul__ = partialmethod(_binary_op, operator.mul)
__rmul__ = partialmethod(_binary_op, operator.mul, reverse=True)
Expand Down Expand Up @@ -1086,8 +1086,8 @@ class IndexLambda(_SuppliedAxesAndTagsMixin, _SuppliedShapeAndDtypeMixin, Array)

if __debug__:
def __post_init__(self) -> None:
assert isinstance(self.bindings, immutabledict)
assert isinstance(self.var_to_reduction_descr, immutabledict)
assert isinstance(self.bindings, constantdict)
assert isinstance(self.var_to_reduction_descr, constantdict)
super().__post_init__()

def with_tagged_reduction(self,
Expand All @@ -1114,7 +1114,7 @@ def with_tagged_reduction(self,
f" '{self.var_to_reduction_descr.keys()}',"
f" got '{reduction_variable}'.")

assert isinstance(self.var_to_reduction_descr, immutabledict)
assert isinstance(self.var_to_reduction_descr, constantdict)
new_var_to_redn_descr = dict(self.var_to_reduction_descr)
new_var_to_redn_descr[reduction_variable] = \
self.var_to_reduction_descr[reduction_variable].tagged(tags)
Expand All @@ -1124,7 +1124,7 @@ def with_tagged_reduction(self,
dtype=self.dtype,
bindings=self.bindings,
axes=self.axes,
var_to_reduction_descr=immutabledict
var_to_reduction_descr=constantdict
(new_var_to_redn_descr),
tags=self.tags,
non_equality_tags=self.non_equality_tags)
Expand Down Expand Up @@ -1203,7 +1203,7 @@ class Einsum(_SuppliedAxesAndTagsMixin, Array):

if __debug__:
def __post_init__(self) -> None:
assert isinstance(self.redn_axis_to_redn_descr, immutabledict)
assert isinstance(self.redn_axis_to_redn_descr, constantdict)
super().__post_init__()

@memoize_method
Expand All @@ -1230,7 +1230,7 @@ def _access_descr_to_axis_len(self
else:
descr_to_axis_len[descr] = arg_axis_len

return immutabledict(descr_to_axis_len)
return constantdict(descr_to_axis_len)

@cached_property
def shape(self) -> ShapeType:
Expand Down Expand Up @@ -1278,15 +1278,15 @@ def with_tagged_reduction(self,

# }}}

assert isinstance(self.redn_axis_to_redn_descr, immutabledict)
assert isinstance(self.redn_axis_to_redn_descr, constantdict)
new_redn_axis_to_redn_descr = dict(self.redn_axis_to_redn_descr)
new_redn_axis_to_redn_descr[redn_axis] = \
self.redn_axis_to_redn_descr[redn_axis].tagged(tags)

return type(self)(access_descriptors=self.access_descriptors,
args=self.args,
axes=self.axes,
redn_axis_to_redn_descr=immutabledict
redn_axis_to_redn_descr=constantdict
(new_redn_axis_to_redn_descr),
tags=self.tags,
non_equality_tags=self.non_equality_tags,
Expand All @@ -1296,7 +1296,7 @@ def with_tagged_reduction(self,
EINSUM_FIRST_INDEX = re.compile(r"^\s*((?P<alpha>[a-zA-Z])|(?P<ellipsis>\.\.\.))\s*")


def _normalize_einsum_out_subscript(subscript: str) -> immutabledict[str,
def _normalize_einsum_out_subscript(subscript: str) -> constantdict[str,
EinsumAxisDescriptor]:
"""
Normalizes the output subscript of an einsum (provided in the explicit
Expand Down Expand Up @@ -1336,7 +1336,7 @@ def _normalize_einsum_out_subscript(subscript: str) -> immutabledict[str,
raise ValueError("Used an input more than once to refer to the"
f" output axis in '{subscript}")

return immutabledict({idx: EinsumElementwiseAxis(i)
return constantdict({idx: EinsumElementwiseAxis(i)
for i, idx in enumerate(normalized_indices)})


Expand All @@ -1347,9 +1347,9 @@ def _normalize_einsum_in_subscript(subscript: str,
index_to_axis_length: Mapping[str,
ShapeComponent],
) -> tuple[tuple[EinsumAxisDescriptor, ...],
immutabledict
constantdict
[str, EinsumAxisDescriptor],
immutabledict[str, ShapeComponent]]:
constantdict[str, ShapeComponent]]:
"""
Normalizes the subscript for an input operand in an einsum. Returns
``(access_descrs, updated_index_to_descr, updated_to_index_to_axis_length)``,
Expand Down Expand Up @@ -1423,8 +1423,8 @@ def _normalize_einsum_in_subscript(subscript: str,

in_operand_axis_descrs.append(index_to_descr_dict[index_char])

index_to_axis_length = immutabledict(index_to_axis_length_dict)
index_to_descr = immutabledict(index_to_descr_dict)
index_to_axis_length = constantdict(index_to_axis_length_dict)
index_to_descr = constantdict(index_to_descr_dict)

return (tuple(in_operand_axis_descrs), index_to_descr, index_to_axis_length)

Expand Down Expand Up @@ -1459,7 +1459,7 @@ def einsum(subscripts: str, *operands: Array,
)

index_to_descr = _normalize_einsum_out_subscript(out_spec)
index_to_axis_length: Mapping[str, ShapeComponent] = immutabledict()
index_to_axis_length: Mapping[str, ShapeComponent] = constantdict()
access_descriptors = []

for in_spec, in_operand in zip(in_specs, operands, strict=True):
Expand Down Expand Up @@ -1494,7 +1494,7 @@ def einsum(subscripts: str, *operands: Array,
if isinstance(descr,
EinsumElementwiseAxis)})
),
redn_axis_to_redn_descr=immutabledict(redn_axis_to_redn_descr),
redn_axis_to_redn_descr=constantdict(redn_axis_to_redn_descr),
non_equality_tags=_get_created_at_tag(),
)

Expand Down Expand Up @@ -2349,11 +2349,11 @@ def full(shape: ConvertibleToShape, fill_value: Scalar | prim.NaN,

return IndexLambda(expr=cast("ArithmeticExpression", fill_value),
shape=shape, dtype=conv_dtype,
bindings=immutabledict(),
bindings=constantdict(),
tags=_get_default_tags(),
non_equality_tags=_get_created_at_tag(),
axes=_get_default_axes(len(shape)),
var_to_reduction_descr=immutabledict())
var_to_reduction_descr=constantdict())


def zeros(shape: ConvertibleToShape, dtype: Any = float,
Expand Down Expand Up @@ -2396,11 +2396,11 @@ def eye(N: int, M: int | None = None, k: int = 0, # noqa: N803
raise ValueError(f"k must be int, got {type(k)}.")

return IndexLambda(expr=prim.If(parse(f"(_1 - _0) == {k}"), 1, 0),
shape=(N, M), dtype=dtype, bindings=immutabledict({}),
shape=(N, M), dtype=dtype, bindings=constantdict({}),
tags=_get_default_tags(),
non_equality_tags=_get_created_at_tag(),
axes=_get_default_axes(2),
var_to_reduction_descr=immutabledict())
var_to_reduction_descr=constantdict())

# }}}

Expand Down Expand Up @@ -2494,11 +2494,11 @@ def arange(*args: Any, **kwargs: Any) -> Array:

from pymbolic.primitives import Variable
return IndexLambda(expr=start + Variable("_0") * step,
shape=(size,), dtype=dtype, bindings=immutabledict(),
shape=(size,), dtype=dtype, bindings=constantdict(),
tags=_get_default_tags(),
non_equality_tags=_get_created_at_tag(),
axes=_get_default_axes(1),
var_to_reduction_descr=immutabledict())
var_to_reduction_descr=constantdict())

# }}}

Expand Down Expand Up @@ -2623,7 +2623,7 @@ def logical_not(x: ArrayOrScalar) -> Array | bool:
tags=_get_default_tags(),
non_equality_tags=_get_created_at_tag(),
axes=_get_default_axes(len(x.shape)),
var_to_reduction_descr=immutabledict())
var_to_reduction_descr=constantdict())

# }}}

Expand Down Expand Up @@ -2674,11 +2674,11 @@ def where(condition: ArrayOrScalar,
expr=prim.If(expr1, expr2, expr3),
shape=result_shape,
dtype=dtype,
bindings=immutabledict(bindings),
bindings=constantdict(bindings),
tags=_get_default_tags(),
non_equality_tags=_get_created_at_tag(),
axes=_get_default_axes(len(result_shape)),
var_to_reduction_descr=immutabledict())
var_to_reduction_descr=constantdict())

# }}}

Expand Down Expand Up @@ -2771,13 +2771,13 @@ def make_index_lambda(
# }}}

return IndexLambda(expr=expression,
bindings=immutabledict(bindings),
bindings=constantdict(bindings),
shape=shape,
dtype=dtype,
tags=_get_default_tags(),
non_equality_tags=_get_created_at_tag(),
axes=_get_default_axes(len(shape)),
var_to_reduction_descr=immutabledict
var_to_reduction_descr=constantdict
(processed_var_to_reduction_descr))

# }}}
Expand Down Expand Up @@ -2859,11 +2859,11 @@ def broadcast_to(array: Array, shape: ShapeType) -> Array:
shape)),
shape=shape,
dtype=array.dtype,
bindings=immutabledict({"in": array}),
bindings=constantdict({"in": array}),
tags=_get_default_tags(),
non_equality_tags=_get_created_at_tag(),
axes=_get_default_axes(len(shape)),
var_to_reduction_descr=immutabledict())
var_to_reduction_descr=constantdict())

# }}}

Expand Down
6 changes: 3 additions & 3 deletions pytato/cmath.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@
from typing import TYPE_CHECKING, cast

import numpy as np
from immutabledict import immutabledict
from constantdict import constantdict

import pymbolic.primitives as prim
from pymbolic import Scalar, var
Expand Down Expand Up @@ -129,11 +129,11 @@ def _apply_elem_wise_func(inputs: tuple[ArrayOrScalar, ...],
return IndexLambda(
expr=prim.Call(var(f"pytato.c99.{func_name}"),
tuple(sym_args)),
shape=shape, dtype=ret_dtype, bindings=immutabledict(bindings),
shape=shape, dtype=ret_dtype, bindings=constantdict(bindings),
tags=_get_default_tags(),
non_equality_tags=_get_created_at_tag(stacklevel=2),
axes=_get_default_axes(len(shape)),
var_to_reduction_descr=immutabledict(),
var_to_reduction_descr=constantdict(),
)


Expand Down
Loading
Loading