diff --git a/conftest.py b/conftest.py index 51aaa4d4..2d2cd835 100644 --- a/conftest.py +++ b/conftest.py @@ -9,7 +9,6 @@ import functools import logging import os -import sys from copy import deepcopy from pathlib import Path @@ -22,7 +21,6 @@ import pytest from packaging.version import Version -from imas.backends.imas_core.imas_interface import has_imas as _has_imas from imas.backends.imas_core.imas_interface import ll_interface, lowlevel from imas.dd_zip import dd_etree, dd_xml_versions, latest_dd_version from imas.ids_defs import ( @@ -39,17 +37,7 @@ os.environ["IMAS_AL_DISABLE_VALIDATE"] = "1" - -try: - import imas # noqa -except ImportError: - - class SkipOnIMASAccess: - def __getattr__(self, attr): - pytest.skip("This test requires the `imas` HLI, which is not available.") - - # Any test that tries to access an attribute from the `imas` package will be skipped - sys.modules["imas"] = SkipOnIMASAccess() +import imas # noqa def pytest_addoption(parser): @@ -78,7 +66,6 @@ def pytest_addoption(parser): if "not available" in str(iex.message): _BACKENDS.pop("mdsplus") - try: import pytest_xdist except ImportError: @@ -91,28 +78,11 @@ def worker_id(): @pytest.fixture(params=_BACKENDS) def backend(pytestconfig: pytest.Config, request: pytest.FixtureRequest): backends_provided = any(map(pytestconfig.getoption, _BACKENDS)) - if not _has_imas: - if backends_provided: - raise RuntimeError( - "Explicit backends are provided, but IMAS is not available." 
- ) - pytest.skip("No IMAS available, skip tests using a backend") if backends_provided and not pytestconfig.getoption(request.param): pytest.skip(f"Tests for {request.param} backend are skipped.") return _BACKENDS[request.param] -@pytest.fixture() -def has_imas(): - return _has_imas - - -@pytest.fixture() -def requires_imas(): - if not _has_imas: - pytest.skip("No IMAS available") - - def pytest_generate_tests(metafunc): if "ids_name" in metafunc.fixturenames: if metafunc.config.getoption("ids"): @@ -214,7 +184,7 @@ def wrapper(*args, **kwargs): @pytest.fixture -def log_lowlevel_calls(monkeypatch, requires_imas): +def log_lowlevel_calls(monkeypatch): """Debugging fixture to log calls to the imas lowlevel module.""" for al_function in dir(lowlevel): if al_function.startswith("ual_") or al_function.startswith("al"): diff --git a/docs/source/api.rst b/docs/source/api.rst index 5df6e579..0eaa3ed3 100644 --- a/docs/source/api.rst +++ b/docs/source/api.rst @@ -19,4 +19,5 @@ IMAS-Python IDS manipulation ids_toplevel.IDSToplevel ids_primitive.IDSPrimitive ids_structure.IDSStructure + ids_slice.IDSSlice ids_struct_array.IDSStructArray diff --git a/docs/source/array_slicing.rst b/docs/source/array_slicing.rst new file mode 100644 index 00000000..1daf8731 --- /dev/null +++ b/docs/source/array_slicing.rst @@ -0,0 +1,131 @@ +.. _array-slicing: + +Array Slicing +============= + +The ``IDSStructArray`` class supports Python's standard slicing syntax. + +Key Difference +--------------- + +- ``array[0]`` returns ``IDSStructure`` (single element) +- ``array[:]`` or ``array[1:5]`` returns ``IDSSlice`` (collection with ``values()`` method) + +Basic Usage +----------- + +.. 
code-block:: python + + import imas + + entry = imas.DBEntry("imas:hdf5?path=my-testdb") + cp = entry.get("core_profiles") + + # Integer indexing + first = cp.profiles_1d[0] # IDSStructure + last = cp.profiles_1d[-1] # IDSStructure + + # Slice operations + subset = cp.profiles_1d[1:5] # IDSSlice + every_other = cp.profiles_1d[::2] # IDSSlice + + # Access nested arrays + all_ions = cp.profiles_1d[:].ion[:] # IDSSlice of individual ions + + # Extract values + labels = all_ions.label.values() + +Multi-Dimensional Slicing +--------------------------- + +The ``IDSSlice`` class supports multi-dimensional shape tracking and array conversion. + +**Check shape of sliced data:** + +.. code-block:: python + + # Get shape information for multi-dimensional data + print(cp.profiles_1d[:].grid.shape) # (106,) + print(cp.profiles_1d[:].ion.shape) # (106, ~3) + print(cp.profiles_1d[1:3].ion[0].element.shape) # (2, ~3) + +**Extract values with shape preservation:** + +.. code-block:: python + + # Extract as list + grid_values = cp.profiles_1d[:].grid.values() + + # Extract as numpy array + grid_array = cp.profiles_1d[:].grid.to_array() + + # Extract as numpy array + ion_array = cp.profiles_1d[:].ion.to_array() + +**Nested structure access:** + +.. code-block:: python + + # Access through nested arrays + grid_data = cp.profiles_1d[1:3].grid.rho_tor.to_array() + + # Ion properties across multiple profiles + ion_labels = cp.profiles_1d[:].ion[:].label.to_array() + ion_charges = cp.profiles_1d[:].ion[:].z_ion.to_array() + +Common Patterns +--------------- + +**Process a range:** + +.. code-block:: python + + for element in cp.profiles_1d[5:10]: + print(element.time) + +**Iterate over nested arrays:** + +.. code-block:: python + + for ion in cp.profiles_1d[:].ion[:]: + print(ion.label.value) + +**Get all values:** + +.. 
code-block:: python + + times = cp.profiles_1d[:].time.values() + + # Or as numpy array + times_array = cp.profiles_1d[:].time.to_array() + +Important: Array-wise Indexing +------------------------------- + +When accessing attributes through a slice of ``IDSStructArray`` elements, +the slice operation automatically applies to each array (array-wise indexing): + +.. code-block:: python + + # Array-wise indexing: [:] applies to each ion array + all_ions = cp.profiles_1d[:].ion[:] + labels = all_ions.label.values() + + # Equivalent to manually iterating: + labels = [] + for profile in cp.profiles_1d[:]: + for ion in profile.ion: + labels.append(ion.label.value) + +Lazy-Loaded Arrays +------------------- + +Both individual indexing and slicing work with lazy loading: + +.. code-block:: python + + element = lazy_array[0] # OK - loads on demand + subset = lazy_array[1:5] # OK - loads only requested elements on demand + +When slicing lazy-loaded arrays, only the elements in the slice range are loaded, +making it memory-efficient for large datasets. diff --git a/docs/source/courses/advanced/explore.rst b/docs/source/courses/advanced/explore.rst index 7b383bc5..02f1201f 100644 --- a/docs/source/courses/advanced/explore.rst +++ b/docs/source/courses/advanced/explore.rst @@ -72,6 +72,32 @@ structures (modeled by :py:class:`~imas.ids_struct_array.IDSStructArray`) are (a name applies) arrays containing :py:class:`~imas.ids_structure.IDSStructure`\ s. Data nodes can contain scalar or array data of various types. +**Slicing Arrays of Structures** + +Arrays of structures support Python slice notation, which returns an +:py:class:`~imas.ids_slice.IDSSlice` object containing matched elements: + +.. 
code-block:: python + + import imas + + core_profiles = imas.IDSFactory().core_profiles() + core_profiles.profiles_1d.resize(10) # Create 10 profiles + + # Integer indexing returns a single structure + first = core_profiles.profiles_1d[0] + + # Slice notation returns an IDSSlice + subset = core_profiles.profiles_1d[2:5] # Elements 2, 3, 4 + every_other = core_profiles.profiles_1d[::2] # Every second element + + # IDSSlice supports array-wise indexing and values() for data access + all_ions = core_profiles.profiles_1d[:].ion[:] + for ion in all_ions: + print(ion.label.value) + +For detailed information on slicing operations, see :doc:`../../array_slicing`. + Some methods and properties are defined for all data nodes and arrays of structures: ``len()`` diff --git a/docs/source/imas_architecture.rst b/docs/source/imas_architecture.rst index 182d2a0c..756d8f79 100644 --- a/docs/source/imas_architecture.rst +++ b/docs/source/imas_architecture.rst @@ -168,6 +168,12 @@ The following submodules and classes represent IDS nodes. :py:class:`~imas.ids_struct_array.IDSStructArray` class, which models Arrays of Structures. It also contains some :ref:`dev lazy loading` logic. +- :py:mod:`imas.ids_slice` contains the + :py:class:`~imas.ids_slice.IDSSlice` class, which represents a collection of IDS + nodes matching a slice expression. It provides slicing operations on + :py:class:`~imas.ids_struct_array.IDSStructArray` elements with array-wise + indexing and supports the ``values()`` method for extracting raw data. + - :py:mod:`imas.ids_structure` contains the :py:class:`~imas.ids_structure.IDSStructure` class, which models Structures. 
It contains the :ref:`lazy instantiation` logic and some of the :ref:`dev lazy loading` diff --git a/docs/source/index.rst b/docs/source/index.rst index 8388f5b5..7b8f98fc 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -50,6 +50,7 @@ Manual configuring cli netcdf + array_slicing changelog examples diff --git a/docs/source/intro.rst b/docs/source/intro.rst index 3027a242..125b407c 100644 --- a/docs/source/intro.rst +++ b/docs/source/intro.rst @@ -154,3 +154,38 @@ can use ``.get()`` to load IDS data from disk: >>> dbentry2 = imas.DBEntry("mypulsefile.nc","r") >>> core_profiles2 = dbentry2.get("core_profiles") >>> print(core_profiles2.ids_properties.comment.value) + + +.. _`Multi-Dimensional Slicing`: + +Multi-Dimensional Slicing +'''''''''''''''''''''''''' + +IMAS-Python supports advanced slicing of hierarchical data structures with automatic +shape tracking and array conversion to numpy. This enables intuitive access to +multi-dimensional scientific data: + +.. code-block:: python + + >>> # Load data + >>> entry = imas.DBEntry("mypulsefile.nc","r") + >>> cp = entry.get("core_profiles", autoconvert=False, lazy=True) + + >>> # Check shape of sliced data + >>> cp.profiles_1d[:].grid.shape + (106,) + >>> cp.profiles_1d[:].ion.shape + (106, ~3) # ~3 ions per profile + + >>> # Extract values + >>> grid_values = cp.profiles_1d[:].grid.to_array() + >>> ion_labels = cp.profiles_1d[:].ion[:].label.to_array() + + >>> # Work with subsets + >>> subset_grid = cp.profiles_1d[1:3].grid.to_array() + >>> subset_ions = cp.profiles_1d[1:3].ion.to_array() + +The ``IDSSlice`` class tracks multi-dimensional shapes and provides both +``.values()`` and ``.to_array()`` (numpy array) +methods for data extraction. For more details, see :ref:`array-slicing`. 
+ diff --git a/imas/backends/db_entry_impl.py b/imas/backends/db_entry_impl.py index df1e4638..0c1b2cd6 100644 --- a/imas/backends/db_entry_impl.py +++ b/imas/backends/db_entry_impl.py @@ -78,7 +78,7 @@ def get( destination: IDSToplevel, lazy: bool, nbc_map: Optional[NBCPathMap], - ) -> None: + ) -> IDSToplevel: """Implement DBEntry.get/get_slice/get_sample. Load data from the data source. Args: diff --git a/imas/backends/imas_core/db_entry_al.py b/imas/backends/imas_core/db_entry_al.py index dad5019b..b9d118dd 100644 --- a/imas/backends/imas_core/db_entry_al.py +++ b/imas/backends/imas_core/db_entry_al.py @@ -38,7 +38,7 @@ from .al_context import ALContext, LazyALContext from .db_entry_helpers import delete_children, get_children, put_children -from .imas_interface import LLInterfaceError, has_imas, ll_interface +from .imas_interface import LLInterfaceError, ll_interface from .mdsplus_model import mdsplus_model_dir from .uda_support import extract_idsdef, get_dd_version_from_idsdef_xml @@ -52,14 +52,6 @@ logger = logging.getLogger(__name__) -def require_imas_available(): - if not has_imas: - raise RuntimeError( - "The IMAS Core library is not available. Please install 'imas_core', " - "or load a supported IMAS module if you use an HPC environment." 
- ) - - class ALDBEntryImpl(DBEntryImpl): """DBEntry implementation using imas_core as a backend.""" @@ -86,7 +78,6 @@ def __init__(self, uri: str, mode: int, factory: IDSFactory): @classmethod def from_uri(cls, uri: str, mode: str, factory: IDSFactory) -> "ALDBEntryImpl": - require_imas_available() if mode not in _OPEN_MODES: modes = list(_OPEN_MODES) raise ValueError(f"Unknown mode {mode!r}, was expecting any of {modes}") @@ -105,8 +96,6 @@ def from_pulse_run( options: Any, factory: IDSFactory, ) -> "ALDBEntryImpl": - # Raise an error if imas is not available - require_imas_available() # Set defaults user_name = user_name or getpass.getuser() diff --git a/imas/backends/imas_core/imas_interface.py b/imas/backends/imas_core/imas_interface.py index 8fa3963b..c9d69a02 100644 --- a/imas/backends/imas_core/imas_interface.py +++ b/imas/backends/imas_core/imas_interface.py @@ -12,30 +12,17 @@ from packaging.version import Version -logger = logging.getLogger(__name__) +# Import the Access Layer module +# First try to import imas_core, which is available since AL 5.2 +from imas_core import _al_lowlevel as lowlevel +from imas_core import imasdef # noqa: F401 +logger = logging.getLogger(__name__) -# Import the Access Layer module -has_imas = True -try: - # First try to import imas_core, which is available since AL 5.2 - from imas_core import _al_lowlevel as lowlevel - from imas_core import imasdef - - # Enable throwing exceptions from the _al_lowlevel interface - enable_exceptions = getattr(lowlevel, "imas_core_config_enable_exceptions", None) - if enable_exceptions: - enable_exceptions() - -except ImportError as exc: - imas = None - has_imas = False - imasdef = None - lowlevel = None - logger.warning( - "Could not import 'imas_core': %s. 
Some functionality is not available.", - exc, - ) +# Enable throwing exceptions from the _al_lowlevel interface +enable_exceptions = getattr(lowlevel, "imas_core_config_enable_exceptions", None) +if enable_exceptions: + enable_exceptions() class LLInterfaceError(RuntimeError): diff --git a/imas/backends/netcdf/nc2ids.py b/imas/backends/netcdf/nc2ids.py index 1b1dbfe8..564d5210 100644 --- a/imas/backends/netcdf/nc2ids.py +++ b/imas/backends/netcdf/nc2ids.py @@ -157,6 +157,8 @@ def run(self, lazy: bool) -> None: for index, node in indexed_tree_iter(self.ids, target_metadata): value = data[index] if value != getattr(var, "_FillValue", None): + if isinstance(value, np.generic): + value = value.item() # NOTE: bypassing IDSPrimitive.value.setter logic node._IDSPrimitive__value = value @@ -166,10 +168,16 @@ def run(self, lazy: bool) -> None: # here, we'll let IDSPrimitive.value.setter take care of it: self.ids[target_metadata.path].value = data - else: + # We need to unpack 0D ints, floats and complex numbers. 
For better + # performance this check is done outside the for-loop: + elif metadata.ndim or metadata.data_type is IDSDataType.STR: for index, node in indexed_tree_iter(self.ids, target_metadata): # NOTE: bypassing IDSPrimitive.value.setter logic node._IDSPrimitive__value = data[index] + else: + for index, node in indexed_tree_iter(self.ids, target_metadata): + # NOTE: bypassing IDSPrimitive.value.setter logic + node._IDSPrimitive__value = data[index].item() # Unpack 0D value def validate_variables(self) -> None: """Validate that all variables in the netCDF Group exist and match the DD.""" @@ -365,7 +373,7 @@ def get_child(self, child): value = var[self.index] if value is not None: - if isinstance(value, np.ndarray): + if isinstance(value, (np.ndarray, np.generic)): if value.ndim == 0: # Unpack 0D numpy arrays: value = value.item() else: diff --git a/imas/db_entry.py b/imas/db_entry.py index 471a50ad..5a470641 100644 --- a/imas/db_entry.py +++ b/imas/db_entry.py @@ -160,7 +160,7 @@ def __init__( legacy = True except TypeError as exc2: raise TypeError( - f"Incorrect arguments to {__class__.__name__}.__init__(): " + "Incorrect arguments to DBEntry.__init__(): " f"{exc1.args[0]}, {exc2.args[0]}" ) from None @@ -561,7 +561,7 @@ def _get( raise RuntimeError("Database entry is not open.") if lazy and destination: raise ValueError("Cannot supply a destination IDS when lazy loading.") - if not self._ids_factory.exists(ids_name): + if autoconvert and not self._ids_factory.exists(ids_name): raise IDSNameError(ids_name, self._ids_factory) # Note: this will raise an exception when the ids/occurrence is not filled: @@ -577,7 +577,7 @@ def _get( ids_name, occurrence, ) - elif dd_version != self.dd_version and dd_version not in dd_xml_versions(): + elif dd_version not in dd_xml_versions() and dd_version != self.dd_version: # We don't know the DD version that this IDS was written with if ignore_unknown_dd_version: # User chooses to ignore this problem, load as if it was stored with 
diff --git a/imas/exception.py b/imas/exception.py index 737680c2..737284d8 100644 --- a/imas/exception.py +++ b/imas/exception.py @@ -20,10 +20,8 @@ # Expose ALException, which may be thrown by the lowlevel -if _imas_interface.has_imas: - ALException = _imas_interface.lowlevel.ALException -else: - ALException = None + +ALException = _imas_interface.lowlevel.ALException class IDSNameError(ValueError): diff --git a/imas/ids_defs.py b/imas/ids_defs.py index af4ed45c..3ac3c6be 100644 --- a/imas/ids_defs.py +++ b/imas/ids_defs.py @@ -86,86 +86,46 @@ Identifier for the default serialization protocol. """ -import functools import logging -from imas.backends.imas_core.imas_interface import has_imas, imasdef +from imas.backends.imas_core.imas_interface import imasdef logger = logging.getLogger(__name__) -if has_imas: - ASCII_BACKEND = imasdef.ASCII_BACKEND - CHAR_DATA = imasdef.CHAR_DATA - CLOSE_PULSE = imasdef.CLOSE_PULSE - CLOSEST_INTERP = imasdef.CLOSEST_INTERP - CREATE_PULSE = imasdef.CREATE_PULSE - DOUBLE_DATA = imasdef.DOUBLE_DATA - COMPLEX_DATA = imasdef.COMPLEX_DATA - EMPTY_COMPLEX = imasdef.EMPTY_COMPLEX - EMPTY_FLOAT = imasdef.EMPTY_FLOAT - EMPTY_INT = imasdef.EMPTY_INT - ERASE_PULSE = imasdef.ERASE_PULSE - FORCE_CREATE_PULSE = imasdef.FORCE_CREATE_PULSE - FORCE_OPEN_PULSE = imasdef.FORCE_OPEN_PULSE - HDF5_BACKEND = imasdef.HDF5_BACKEND - IDS_TIME_MODE_HETEROGENEOUS = imasdef.IDS_TIME_MODE_HETEROGENEOUS - IDS_TIME_MODE_HOMOGENEOUS = imasdef.IDS_TIME_MODE_HOMOGENEOUS - IDS_TIME_MODE_INDEPENDENT = imasdef.IDS_TIME_MODE_INDEPENDENT - IDS_TIME_MODE_UNKNOWN = imasdef.IDS_TIME_MODE_UNKNOWN - IDS_TIME_MODES = imasdef.IDS_TIME_MODES - INTEGER_DATA = imasdef.INTEGER_DATA - LINEAR_INTERP = imasdef.LINEAR_INTERP - MDSPLUS_BACKEND = imasdef.MDSPLUS_BACKEND - MEMORY_BACKEND = imasdef.MEMORY_BACKEND - NODE_TYPE_STRUCTURE = imasdef.NODE_TYPE_STRUCTURE - OPEN_PULSE = imasdef.OPEN_PULSE - PREVIOUS_INTERP = imasdef.PREVIOUS_INTERP - READ_OP = imasdef.READ_OP - UDA_BACKEND = 
imasdef.UDA_BACKEND - UNDEFINED_INTERP = imasdef.UNDEFINED_INTERP - UNDEFINED_TIME = imasdef.UNDEFINED_TIME - WRITE_OP = imasdef.WRITE_OP - ASCII_SERIALIZER_PROTOCOL = getattr(imasdef, "ASCII_SERIALIZER_PROTOCOL", 60) - FLEXBUFFERS_SERIALIZER_PROTOCOL = getattr( - imasdef, "FLEXBUFFERS_SERIALIZER_PROTOCOL", None - ) - DEFAULT_SERIALIZER_PROTOCOL = getattr(imasdef, "DEFAULT_SERIALIZER_PROTOCOL", 60) - -else: - # Preset some constants which are used elsewhere - # this is a bit ugly, perhaps reuse the list of imports from above? - # it seems no problem to use None, since the use of the values should not - # be allowed, they are only used in operations which use the backend, - # which we (should) gate - ASCII_BACKEND = CHAR_DATA = CLOSE_PULSE = CLOSEST_INTERP = DOUBLE_DATA = None - FORCE_OPEN_PULSE = CREATE_PULSE = ERASE_PULSE = None - COMPLEX_DATA = FORCE_CREATE_PULSE = HDF5_BACKEND = None - INTEGER_DATA = LINEAR_INTERP = MDSPLUS_BACKEND = MEMORY_BACKEND = None - NODE_TYPE_STRUCTURE = OPEN_PULSE = PREVIOUS_INTERP = READ_OP = None - UDA_BACKEND = UNDEFINED_INTERP = UNDEFINED_TIME = WRITE_OP = None - # These constants are also useful when not working with the AL - EMPTY_FLOAT = -9e40 - EMPTY_INT = -999_999_999 - EMPTY_COMPLEX = complex(EMPTY_FLOAT, EMPTY_FLOAT) - IDS_TIME_MODE_UNKNOWN = EMPTY_INT - IDS_TIME_MODE_HETEROGENEOUS = 0 - IDS_TIME_MODE_HOMOGENEOUS = 1 - IDS_TIME_MODE_INDEPENDENT = 2 - IDS_TIME_MODES = [0, 1, 2] - ASCII_SERIALIZER_PROTOCOL = 60 - FLEXBUFFERS_SERIALIZER_PROTOCOL = None - DEFAULT_SERIALIZER_PROTOCOL = 60 - - -def needs_imas(func): - if has_imas: - return func - - @functools.wraps(func) - def wrapper(*args, **kwargs): - raise RuntimeError( - f"Function {func.__name__} requires IMAS, but IMAS is not available." 
- ) - - return wrapper +ASCII_BACKEND = imasdef.ASCII_BACKEND +CHAR_DATA = imasdef.CHAR_DATA +CLOSE_PULSE = imasdef.CLOSE_PULSE +CLOSEST_INTERP = imasdef.CLOSEST_INTERP +CREATE_PULSE = imasdef.CREATE_PULSE +DOUBLE_DATA = imasdef.DOUBLE_DATA +COMPLEX_DATA = imasdef.COMPLEX_DATA +EMPTY_COMPLEX = imasdef.EMPTY_COMPLEX +EMPTY_FLOAT = imasdef.EMPTY_FLOAT +EMPTY_INT = imasdef.EMPTY_INT +ERASE_PULSE = imasdef.ERASE_PULSE +FORCE_CREATE_PULSE = imasdef.FORCE_CREATE_PULSE +FORCE_OPEN_PULSE = imasdef.FORCE_OPEN_PULSE +HDF5_BACKEND = imasdef.HDF5_BACKEND +IDS_TIME_MODE_HETEROGENEOUS = imasdef.IDS_TIME_MODE_HETEROGENEOUS +IDS_TIME_MODE_HOMOGENEOUS = imasdef.IDS_TIME_MODE_HOMOGENEOUS +IDS_TIME_MODE_INDEPENDENT = imasdef.IDS_TIME_MODE_INDEPENDENT +IDS_TIME_MODE_UNKNOWN = imasdef.IDS_TIME_MODE_UNKNOWN +IDS_TIME_MODES = imasdef.IDS_TIME_MODES +INTEGER_DATA = imasdef.INTEGER_DATA +LINEAR_INTERP = imasdef.LINEAR_INTERP +MDSPLUS_BACKEND = imasdef.MDSPLUS_BACKEND +MEMORY_BACKEND = imasdef.MEMORY_BACKEND +NODE_TYPE_STRUCTURE = imasdef.NODE_TYPE_STRUCTURE +OPEN_PULSE = imasdef.OPEN_PULSE +PREVIOUS_INTERP = imasdef.PREVIOUS_INTERP +READ_OP = imasdef.READ_OP +UDA_BACKEND = imasdef.UDA_BACKEND +UNDEFINED_INTERP = imasdef.UNDEFINED_INTERP +UNDEFINED_TIME = imasdef.UNDEFINED_TIME +WRITE_OP = imasdef.WRITE_OP +ASCII_SERIALIZER_PROTOCOL = getattr(imasdef, "ASCII_SERIALIZER_PROTOCOL", 60) +FLEXBUFFERS_SERIALIZER_PROTOCOL = getattr( + imasdef, "FLEXBUFFERS_SERIALIZER_PROTOCOL", None +) +DEFAULT_SERIALIZER_PROTOCOL = getattr(imasdef, "DEFAULT_SERIALIZER_PROTOCOL", 60) diff --git a/imas/ids_factory.py b/imas/ids_factory.py index b840d8a8..5a8209db 100644 --- a/imas/ids_factory.py +++ b/imas/ids_factory.py @@ -41,6 +41,17 @@ def __init__( version: DD version string, e.g. "3.38.1". xml_path: XML file containing data dictionary definition. 
""" + if version is None and xml_path is None: + # Defer loading the DD definitions until we really need them + self.__deferred_init = True + else: + # If a specific version or xml_path is requested, we still load immediately + # so any exceptions are raise when creating the IDSfactory + self.__do_init(version, xml_path) + self.__deferred_init = False + + def __do_init(self, version: str | None, xml_path: str | pathlib.Path | None): + """Actual initialization logic""" self._xml_path = xml_path self._etree = dd_zip.dd_etree(version, xml_path) self._ids_elements = { @@ -71,10 +82,16 @@ def __dir__(self) -> Iterable[str]: return sorted(set(object.__dir__(self)).union(self._ids_elements)) def __getattr__(self, name: str) -> Any: + # Actually initialize when we deferred it before + if self.__deferred_init: + self.__do_init(None, None) + self.__deferred_init = False + return getattr(self, name) + # Check if the name matches any IDS and return a 'constructor' for it if name in self._ids_elements: # Note: returning a partial to mimic AL HLI, e.g. factory.core_profiles() return partial(IDSToplevel, self, self._ids_elements[name]) - raise AttributeError(f"{type(self)!r} object has no attribute {name!r}") + raise AttributeError(f"'IDSFactory' has no attribute {name!r}") def __iter__(self) -> Iterator[str]: """Iterate over the IDS names defined by the loaded Data Dictionary""" diff --git a/imas/ids_metadata.py b/imas/ids_metadata.py index 4d2d5dbb..cb458865 100644 --- a/imas/ids_metadata.py +++ b/imas/ids_metadata.py @@ -287,6 +287,20 @@ def __getitem__(self, path) -> "IDSMetadata": ) from None return item + @property + def ids_name(self) -> str: + """Get the root IDS name (e.g., 'core_profiles', 'equilibrium'). + + Traverses up the metadata hierarchy to find the toplevel IDS name. + + Returns: + The name of the root IDS node. 
+ """ + current = self + while current._parent is not None: + current = current._parent + return current.name + @property def identifier_enum(self) -> Optional[Type[IDSIdentifier]]: """The identifier enum for this IDS node (if available). diff --git a/imas/ids_slice.py b/imas/ids_slice.py new file mode 100644 index 00000000..bce7d78b --- /dev/null +++ b/imas/ids_slice.py @@ -0,0 +1,596 @@ +# This file is part of IMAS-Python. +# You should have received the IMAS-Python LICENSE file with this project. +"""IDSSlice represents a collection of IDS nodes matching a slice expression. + +This module provides the IDSSlice class, which enables slicing of arrays of +structures while maintaining the hierarchy and allowing further operations on +the resulting collection. +""" + +import logging +from typing import TYPE_CHECKING, Any, Iterator, List, Optional, Tuple, Union + +import numpy as np + +from imas.ids_metadata import IDSMetadata + +if TYPE_CHECKING: + from imas.ids_struct_array import IDSStructArray + +logger = logging.getLogger(__name__) + + +class IDSSlice: + """Represents a slice of IDS struct array elements. + + When slicing an IDSStructArray, instead of returning a regular Python list, + an IDSSlice is returned. This allows for: + - Tracking the slice operation in the path + - Further slicing of child elements + - Child node access on all matched elements + - Iteration over matched elements + + Attributes: + metadata: Metadata from the parent array (always present) + """ + + __slots__ = [ + "metadata", + "_matched_elements", + "_slice_path", + "_parent_array", + "_virtual_shape", + "_element_hierarchy", + ] + + def __init__( + self, + metadata: IDSMetadata, + matched_elements: List[Any], + full_path: str, + parent_array: Optional["IDSStructArray"] = None, + virtual_shape: Optional[Tuple[int, ...]] = None, + element_hierarchy: Optional[List[Any]] = None, + ): + """Initialize IDSSlice. 
+ + Args: + metadata: Metadata from the parent array (required) + matched_elements: List of elements that matched the slice + full_path: Full path from the IDS root (e.g., "profiles_1d[:].ion[:]") + parent_array: Optional reference to the parent IDSStructArray for context + virtual_shape: Optional tuple representing multi-dimensional shape + element_hierarchy: Optional tracking of element grouping + """ + self.metadata = metadata + self._matched_elements = matched_elements + self._slice_path = full_path + self._parent_array = parent_array + self._virtual_shape = virtual_shape or (len(matched_elements),) + self._element_hierarchy = element_hierarchy or [len(matched_elements)] + + @property + def _path(self) -> str: + """Return the path representation of this slice.""" + return self._slice_path + + @property + def is_ragged(self) -> bool: + """Check if the underlying data is ragged (non-rectangular). + + Ragged arrays have varying sizes at one or more dimensions. + + Returns: + True if any dimension has varying sizes, False otherwise + + """ + # Check if any level in the hierarchy has non-uniform sizes + for sizes_list in self._element_hierarchy: + # sizes_list can be a list of sizes or a single integer + if isinstance(sizes_list, list) and len(sizes_list) > 1: + if len(set(sizes_list)) > 1: + return True + return False + + @property + def shape(self) -> Tuple[int, ...]: + """Get the virtual multi-dimensional shape. + + Returns the shape of the data as if it were organized in a multi-dimensional + array, based on the hierarchy of slicing operations performed. + + Raises: + ValueError: The underlying data is ragged (non-rectangular). + Use .is_ragged to check first, or use .values() to extract + values as a flat list. + + Returns: + Tuple of dimensions. + """ + if self.is_ragged: + raise ValueError( + "Cannot get shape of ragged array: dimensions have varying " + "sizes. Use .is_ragged to check if data is ragged, or .values() " + "to get a flat list of elements." 
+ ) + + # Build shape from hierarchy + shape = [] + for i, hierarchy_level in enumerate(self._element_hierarchy): + if isinstance(hierarchy_level, list): + # This is a list of sizes + if i == 0: + # First level with a list means grouped data + # The number of groups is the first hierarchy level (implicit) + shape.append(len(hierarchy_level)) + else: + # Subsequent levels: use first size (uniform, we checked is_ragged) + shape.append(hierarchy_level[0] if hierarchy_level else 0) + else: + # This is a single count + shape.append(hierarchy_level) + + return tuple(shape) + + def __len__(self) -> int: + """Return the number of elements matched by this slice.""" + return len(self._matched_elements) + + def __iter__(self) -> Iterator[Any]: + """Iterate over all matched elements.""" + return iter(self._matched_elements) + + def __getitem__(self, item: Union[int, slice]) -> "IDSSlice": + """Get element(s) from the slice using slice notation. + + Only slice operations are supported. Integer indexing on IDSSlice + is not allowed to avoid confusion with array-wise operations. + Use direct indexing on the IDS structure instead. + + Args: + item: Slice object to apply + + Returns: + IDSSlice: A new slice with the applied slice operation + + Raises: + TypeError: If item is an integer (not supported) + + Examples: + Slice operations (supported):: + + # Get ions 0 through 2 from all profiles + result = cp.profiles_1d[:].ion[:3] # OK - returns IDSSlice + result = cp.profiles_1d[:].ion[1:3] # OK - returns IDSSlice + result = cp.profiles_1d[:].ion[::2] # OK - returns IDSSlice + + Integer indexing (NOT supported):: + + # These will raise TypeError + result = cp.profiles_1d[:].ion[0] # ERROR! 
+ + Recommended alternatives to integer indexing:: + + # Option 1: Direct indexing (best - most efficient, clearest) + result = cp.profiles_1d[0].ion[:] + + # Option 2: Convert slice to list first + ions_list = list(cp.profiles_1d[:].ion) + result = ions_list[0] + + # Option 3: Extract values + ions_values = cp.profiles_1d[:].ion.values() + result = ions_values[0] + """ + if isinstance(item, slice): + return self._handle_list_slice(item) + else: + # Integer indexing not allowed + raise TypeError( + f"Cannot index IDSSlice with integer {item}. " + f"IDSSlice only supports slice notation (e.g., [0:5], [::2]).\n\n" + f"To access elements, use one of these alternatives:\n" + f" 1. Direct indexing (recommended):\n" + f" ids[{item}].node # Access element directly\n" + f" 2. Convert to list first:\n" + f" list(ids)[{item}] # Convert slice to list\n" + f" 3. Extract values:\n" + f" ids.values()[{item}] # Get values as flat list" + ) + + def _handle_list_slice(self, item: slice) -> "IDSSlice": + """Apply a slice operation to the matched elements list. + + Updates the first dimension of the shape to reflect the new + number of elements after slicing. 
+ + Args: + item: The slice object to apply + + Returns: + IDSSlice with updated shape and hierarchy + """ + from imas.ids_struct_array import IDSStructArray + + slice_str = self._format_slice(item) + # Full path: current path + slice operation + full_path = self._path + slice_str + + # Check if matched elements are IDSStructArray (nested arrays) + if self._matched_elements and isinstance( + self._matched_elements[0], IDSStructArray + ): + # When slicing nested arrays, apply slice to each array and then flatten + flattened_elements = [] + new_hierarchy_values = [] + for array in self._matched_elements: + sliced_array = array[item] + new_hierarchy_values.append(len(sliced_array)) + # Flatten: add each element from the sliced array to flattened list + for element in sliced_array: + flattened_elements.append(element) + + # Build new hierarchy + # The key is: if we have a multi-level grouped hierarchy + # (like [3, [2, 2, 2], ...]), we're dealing with a nested + # structure that's already been flattened. We should only update + # the innermost level, NOT create a new top-level grouping. + + num_groups = len(self._matched_elements) + + if ( + len(self._element_hierarchy) >= 2 + and isinstance(self._element_hierarchy[0], int) + and isinstance(self._element_hierarchy[1], list) + ): + # Multi-level hierarchy like [3, [2, 2, 2], ...] 
+ # The top level is the original grouping, so DON'T recreate it + # Just replace the last (innermost) level + new_hierarchy = self._element_hierarchy[:-1] + [new_hierarchy_values] + else: + # Single level or not grouped yet - create new grouping + new_hierarchy = [num_groups, new_hierarchy_values] + + return IDSSlice( + self.metadata, + flattened_elements, + full_path, + parent_array=self._parent_array, + virtual_shape=(len(flattened_elements),), + element_hierarchy=new_hierarchy, + ) + else: + # Normal slice on outer list + sliced_elements = self._matched_elements[item] + + # Update shape to reflect the slice on first dimension + new_virtual_shape = (len(sliced_elements),) + self._virtual_shape[1:] + new_element_hierarchy = [len(sliced_elements)] + self._element_hierarchy[1:] + + return IDSSlice( + self.metadata, + sliced_elements, + full_path, + parent_array=self._parent_array, + virtual_shape=new_virtual_shape, + element_hierarchy=new_element_hierarchy, + ) + + def __getattr__(self, name: str) -> "IDSSlice": + """Access a child node on all matched elements. + + Returns a new IDSSlice containing the child node from each matched + element. Validates the attribute name against metadata, allowing + empty slices with valid child node names. + + Args: + name: Name of the node to access + + Returns: + A new IDSSlice containing the child node from each matched element, + or an empty IDSSlice if the matched_elements is empty but the + attribute name is valid according to metadata. + + Raises: + AttributeError: If name is not a valid child node in the metadata + """ + from imas.ids_struct_array import IDSStructArray + from imas.ids_primitive import IDSNumericArray + + # Validate attribute name via metadata + try: + child_metadata = self.metadata[name] + except (KeyError, TypeError): + raise AttributeError( + f"'{self.metadata.name}' has no child node '{name}'" + ) from None + + # Full path: current path + attribute access + full_path = self._path + "." 
+ name + + # Handle empty slice - valid if metadata says it's a valid node + if not self._matched_elements: + return IDSSlice( + child_metadata, + [], + full_path, + parent_array=self._parent_array, + virtual_shape=(0,), + element_hierarchy=[0], + ) + + # Get attributes from all non-empty matched elements + # Special case: if matched_elements are IDSStructArray, keep them grouped + if self._matched_elements and isinstance( + self._matched_elements[0], IDSStructArray + ): + # For nested arrays, return the arrays themselves, not attributes from them + # This allows chaining like .ion[:].element[:] to work + child_elements = self._matched_elements + else: + child_elements = [getattr(element, name) for element in self] + + # Check if children are IDSStructArray (nested arrays) or IDSNumericArray + if not child_elements: + # Empty child elements + return IDSSlice( + child_metadata, + child_elements, + full_path, + parent_array=self._parent_array, + virtual_shape=self._virtual_shape, + element_hierarchy=self._element_hierarchy, + ) + + # If matched_elements are IDSStructArray and we're accessing an + # attribute on them, we need to get that attribute from each + # array's elements + if isinstance(self._matched_elements[0], IDSStructArray): + # Accessing attribute on nested arrays: get attr from each + # array's elements + flattened_elements = [] + for array in child_elements: + # array is IDSStructArray, get attribute from its elements + for element in array: + flattened_elements.append(getattr(element, name)) + + # Keep track of grouping for shape preservation + child_sizes = [len(array) for array in child_elements] + + return IDSSlice( + child_metadata, + flattened_elements, + full_path, + parent_array=self._parent_array, + virtual_shape=self._virtual_shape + (None,), + element_hierarchy=self._element_hierarchy + [child_sizes], + ) + + if isinstance(child_elements[0], IDSStructArray): + # Children are IDSStructArray - track the new dimension + child_sizes = [len(arr) 
 for arr in child_elements]
+
+            # New virtual shape: current shape + new dimension
+            # Store actual sizes (may be ragged) - don't assume all are the same!
+            new_virtual_shape = self._virtual_shape + (None,)
+            new_hierarchy = self._element_hierarchy + [child_sizes]
+
+            return IDSSlice(
+                child_metadata,
+                child_elements,
+                full_path,
+                parent_array=self._parent_array,
+                virtual_shape=new_virtual_shape,
+                element_hierarchy=new_hierarchy,
+            )
+        elif isinstance(child_elements[0], IDSNumericArray):
+            # Children are IDSNumericArray - track the array dimension
+            # Each IDSNumericArray has a size (length of its data)
+            child_sizes = [len(arr) for arr in child_elements]
+
+            # New virtual shape: current shape + new dimension
+            # Store actual sizes (may be ragged) - don't assume all are the same!
+            new_virtual_shape = self._virtual_shape + (None,)
+            new_hierarchy = self._element_hierarchy + [child_sizes]
+
+            return IDSSlice(
+                child_metadata,
+                child_elements,
+                full_path,
+                parent_array=self._parent_array,
+                virtual_shape=new_virtual_shape,
+                element_hierarchy=new_hierarchy,
+            )
+        else:
+            # Children are not arrays (structures or other primitives)
+            return IDSSlice(
+                child_metadata,
+                child_elements,
+                full_path,
+                parent_array=self._parent_array,
+                virtual_shape=self._virtual_shape,
+                element_hierarchy=self._element_hierarchy,
+            )
+
+    def __repr__(self) -> str:
+        """Build a string representation of this IDSSlice.
+
+        Returns a string showing:
+        - The IDS type name (e.g., 'equilibrium')
+        - The full path including slice operations (e.g., 'profiles_1d[:].ion[:]')
+        - The number of matched elements
+
+        Returns:
+            String representation like:
+            '<IDSSlice (IDS:equilibrium, profiles_1d[:].ion[:] with 20 items)>'
+        """
+        ids_name = self.metadata.ids_name
+        item_word = "item" if len(self) == 1 else "items"
+        return (
+            f"<{type(self).__name__} (IDS:{ids_name}, {self._path} with "
+            f"{len(self)} {item_word})>"
+        )
+
+    def values(self) -> List[Any]:
+        """Extract raw values from elements in this slice.
+ + For IDSPrimitive elements, this extracts the wrapped value. + For other element types, returns them as-is. + + Returns a flat list of extracted values. This is useful for getting + the actual data without the IDS wrapper when accessing scalar fields + through a slice, without requiring explicit looping through the + original collection. + + For multi-dimensional access to values, use one of these approaches: + + - Use direct indexing: ``ids_obj[i1].collection[i2].value`` (best + performance and clarity) + - Use ``.to_array()`` if you need numpy array integration + + Returns: + List of raw Python/numpy values or unwrapped elements + + Examples: + Extract scalar values from a 1D slice:: + + # Get list of temperatures from all profiles + temps = core_profiles.profiles_1d[:].te.values() + + For multi-dimensional access, use direct indexing instead:: + + # Get a specific temperature (more efficient than slicing) + temp = core_profiles.profiles_1d[0].te.values()[5] + + # Or better yet, direct access + temp_value = core_profiles.profiles_1d[0].te[5] + + For converting to numpy arrays:: + + # Use to_array() for tensorization + array = core_profiles.profiles_1d[:].te.to_array() + """ + from imas.ids_primitive import IDSPrimitive + + result = [] + for element in self._matched_elements: + if isinstance(element, IDSPrimitive): + result.append(element.value) + else: + result.append(element) + return result + + def to_array(self) -> np.ndarray: + """Convert this slice to a numpy array - for leaf node slices only. + + This method converts a slice containing scalar or numeric array leaf nodes + to a regular numpy array with shape self.shape. It is designed for + tensorization of leaf nodes only (e.g., slices of FLT_1D, profiles, etc.). + + For multi-dimensional access to non-leaf nodes, use direct indexing instead: + ``ids[i1][i2]`` rather than slicing with ``.to_array()``. + + Returns: + numpy.ndarray with shape self.shape containing the extracted values. 
+ + Raises: + ValueError: If slice refers to IDSStructure or IDSStructArray elements + (non-leaf nodes). Use direct indexing instead. + ValueError: If the data is ragged/non-rectangular (dimensions have + varying sizes). Use direct indexing or ``.values()`` instead. + ValueError: If values cannot be converted to numpy array. + + Examples: + Tensorize a 1D slice of numeric data:: + + # Works: leaf nodes are numeric arrays + array = core_profiles.profiles_1d[:].te.to_array() + # Shape: (n_profiles,) + + Multi-dimensional tensorization:: + + # Works: accessing leaf nodes from nested structure + array = core_profiles.profiles_1d[:].te.to_array() + # Shape: (n_profiles,) + + Direct indexing for non-leaf nodes:: + + # Don't do this - will raise ValueError + # array = core_profiles.profiles_1d[:].to_array() # ERROR! + + # Do this instead + profile = core_profiles.profiles_1d[0] # Direct access + te = profile.te.to_array() # Then tensorize + """ + from imas.ids_primitive import IDSPrimitive, IDSNumericArray + from imas.ids_struct_array import IDSStructArray + from imas.ids_structure import IDSStructure + + # Validate: slice must refer to leaf nodes only + if self._matched_elements: + first = self._matched_elements[0] + if isinstance(first, (IDSStructure, IDSStructArray)): + raise ValueError( + f"Cannot tensorize {type(first).__name__} slice - only " + f"works for leaf nodes (scalars, numeric arrays). Use " + f"direct indexing instead: ids[i][j] to access structures." + ) + + # Validate: data must be rectangular (not ragged) + if self.is_ragged: + raise ValueError( + "Cannot tensorize ragged array - dimensions have varying " + "sizes. Use .values() to get a flat list, or use direct " + "indexing for multi-dimensional access." 
+ ) + + # Get the target shape (we validated it's not ragged) + actual_shape = self.shape + + # Handle empty slice + if len(self._matched_elements) == 0: + return np.empty(actual_shape, dtype=float) + + # Extract values from leaf nodes + flat_values = [] + for element in self._matched_elements: + if isinstance(element, IDSPrimitive): + flat_values.append(element.value) + elif isinstance(element, IDSNumericArray): + flat_values.append(element.value) + else: + flat_values.append(element) + + # Tensorize to target shape + arr = np.array(flat_values) + + # For 1D, no reshape needed + if len(actual_shape) == 1: + return arr + + # For multi-dimensional, reshape to target shape + try: + return arr.reshape(actual_shape) + except (ValueError, TypeError) as e: + raise ValueError( + f"Failed to convert slice to array with shape {actual_shape}: {e}" + ) + + @staticmethod + def _format_slice(slice_obj: slice) -> str: + """Format a slice object as a string. + + Args: + slice_obj: The slice object to format + + Returns: + String representation like "[1:5]", "[::2]", etc. + """ + start = slice_obj.start if slice_obj.start is not None else "" + stop = slice_obj.stop if slice_obj.stop is not None else "" + step = slice_obj.step if slice_obj.step is not None else "" + + if step: + return f"[{start}:{stop}:{step}]" + else: + return f"[{start}:{stop}]" diff --git a/imas/ids_struct_array.py b/imas/ids_struct_array.py index b1768649..41ed5094 100644 --- a/imas/ids_struct_array.py +++ b/imas/ids_struct_array.py @@ -121,12 +121,41 @@ def _element_structure(self): return struct def __getitem__(self, item): - # value is a list, so the given item should be convertable to integer - # TODO: perhaps we should allow slices as well? - list_idx = int(item) - if self._lazy: - self._load(item) - return self.value[list_idx] + """Get element(s) from the struct array. 
+ + Args: + item: Integer index or slice object + + Returns: + A single IDSStructure if item is an int, or an IDSSlice if item is a slice + """ + if isinstance(item, slice): + from imas.ids_slice import IDSSlice + from imas.util import get_full_path + + if self._lazy: + # Use __getitem__ for each index to trigger proper lazy loading + matched_elements = [self[i] for i in range(*item.indices(len(self)))] + else: + # Direct slice for non-lazy case + matched_elements = self.value[item] + + slice_str = IDSSlice._format_slice(item) + # Build full path: parent path + this array name + slice + full_path = get_full_path(self) + slice_str + + return IDSSlice( + self.metadata, + matched_elements, + full_path, + parent_array=self, + ) + else: + # Handle integer index + list_idx = int(item) + if self._lazy: + self._load(item) + return self.value[list_idx] def __setitem__(self, item, value): # value is a list, so the given item should be convertable to integer diff --git a/imas/ids_toplevel.py b/imas/ids_toplevel.py index 947bf72f..fcda5f0d 100644 --- a/imas/ids_toplevel.py +++ b/imas/ids_toplevel.py @@ -22,7 +22,6 @@ IDS_TIME_MODE_INDEPENDENT, IDS_TIME_MODE_UNKNOWN, IDS_TIME_MODES, - needs_imas, ) from imas.ids_metadata import IDSMetadata, IDSType, get_toplevel_metadata from imas.ids_structure import IDSStructure @@ -99,7 +98,6 @@ def default_serializer_protocol(): """Return the default serializer protocol.""" return DEFAULT_SERIALIZER_PROTOCOL - @needs_imas def serialize(self, protocol=None) -> bytes: """Serialize this IDS to a data buffer. @@ -169,7 +167,6 @@ def serialize(self, protocol=None) -> bytes: return bytes(buffer) raise ValueError(f"Unrecognized serialization protocol: {protocol}") - @needs_imas def deserialize(self, data: bytes) -> None: """Deserialize the data buffer into this IDS. 
@@ -289,7 +286,6 @@ def _validate(self): for child in self.iter_nonempty_(accept_lazy=True): child._validate() - @needs_imas def get(self, occurrence: int = 0, db_entry: Optional["DBEntry"] = None) -> None: """Get data from AL backend storage format. @@ -300,7 +296,6 @@ def get(self, occurrence: int = 0, db_entry: Optional["DBEntry"] = None) -> None raise NotImplementedError() db_entry.get(self.metadata.name, occurrence, destination=self) - @needs_imas def getSlice( self, time_requested: float, @@ -323,7 +318,6 @@ def getSlice( destination=self, ) - @needs_imas def putSlice( self, occurrence: int = 0, db_entry: Optional["DBEntry"] = None ) -> None: @@ -336,7 +330,6 @@ def putSlice( raise NotImplementedError() db_entry.put_slice(self, occurrence) - @needs_imas def deleteData( self, occurrence: int = 0, db_entry: Optional["DBEntry"] = None ) -> None: @@ -349,7 +342,6 @@ def deleteData( raise NotImplementedError() db_entry.delete_data(self, occurrence) - @needs_imas def put(self, occurrence: int = 0, db_entry: Optional["DBEntry"] = None) -> None: """Put this IDS to the backend. 
diff --git a/imas/test/test_cli.py b/imas/test/test_cli.py index 0f4b305e..130aa287 100644 --- a/imas/test/test_cli.py +++ b/imas/test/test_cli.py @@ -17,7 +17,7 @@ def test_imas_version(): @pytest.mark.cli -def test_db_analysis(tmp_path, requires_imas): +def test_db_analysis(tmp_path): # This only tests the happy flow, error handling is not tested db_path = tmp_path / "test_db_analysis" with DBEntry(f"imas:hdf5?path={db_path}", "w") as entry: @@ -42,7 +42,7 @@ def test_db_analysis(tmp_path, requires_imas): @pytest.mark.cli -def test_db_analysis_csv(tmp_path, requires_imas): +def test_db_analysis_csv(tmp_path): with DBEntry(f"imas:hdf5?path={tmp_path}/entry1", "w") as entry: eq = entry.factory.equilibrium() eq.ids_properties.homogeneous_time = 2 diff --git a/imas/test/test_dbentry.py b/imas/test/test_dbentry.py index e13d82a4..f014eb9b 100644 --- a/imas/test/test_dbentry.py +++ b/imas/test/test_dbentry.py @@ -6,7 +6,7 @@ from imas.test.test_helpers import compare_children, open_dbentry -def test_dbentry_contextmanager(requires_imas): +def test_dbentry_contextmanager(): entry = imas.DBEntry(imas.ids_defs.MEMORY_BACKEND, "test", 1, 1) entry.create() ids = entry.factory.core_profiles() @@ -22,7 +22,7 @@ def test_dbentry_contextmanager(requires_imas): assert entry2._dbe_impl is None -def test_dbentry_contextmanager_uri(tmp_path, requires_imas): +def test_dbentry_contextmanager_uri(tmp_path): entry = imas.DBEntry(f"imas:ascii?path={tmp_path}/testdb", "w") ids = entry.factory.core_profiles() ids.ids_properties.homogeneous_time = 0 @@ -77,7 +77,7 @@ def test_dbentry_constructor(): assert get_entry_attrs(entry) == (1, 2, 3, 4, None, 6) -def test_ignore_unknown_dd_version(monkeypatch, worker_id, tmp_path, requires_imas): +def test_ignore_unknown_dd_version(monkeypatch, worker_id, tmp_path): entry = open_dbentry(imas.ids_defs.MEMORY_BACKEND, "w", worker_id, tmp_path) ids = entry.factory.core_profiles() ids.ids_properties.homogeneous_time = 0 diff --git 
a/imas/test/test_exception.py b/imas/test/test_exception.py index 37bebfce..c0b66230 100644 --- a/imas/test/test_exception.py +++ b/imas/test/test_exception.py @@ -4,7 +4,7 @@ from imas.backends.imas_core.imas_interface import ll_interface -def test_catch_al_exception(requires_imas): +def test_catch_al_exception(): # Do something which lets the lowlevel Cython interface throw an ALException # Ensure we can catch it: with pytest.raises(imas.exception.ALException): diff --git a/imas/test/test_ids_ascii_data.py b/imas/test/test_ids_ascii_data.py index d15fecf1..20ae8a66 100644 --- a/imas/test/test_ids_ascii_data.py +++ b/imas/test/test_ids_ascii_data.py @@ -18,7 +18,7 @@ def test_data_exists(): @pytest.fixture -def test_data(requires_imas): +def test_data(): db_entry = imas.training.get_training_db_entry() yield db_entry db_entry.close() diff --git a/imas/test/test_ids_slice.py b/imas/test/test_ids_slice.py new file mode 100644 index 00000000..cb5e7f1d --- /dev/null +++ b/imas/test/test_ids_slice.py @@ -0,0 +1,464 @@ +# This file is part of IMAS-Python. +# You should have received the IMAS-Python LICENSE file with this project. 
+ +import numpy as np +import pytest + +from imas.ids_factory import IDSFactory +from imas.ids_slice import IDSSlice + + +@pytest.fixture +def wall_with_units(): + return create_wall_with_units() + + +@pytest.fixture +def wall_varying_sizes(): + return create_wall_with_units(total_units=2, element_counts=[4, 2]) + + +def create_wall_with_units( + total_units: int = 12, + element_counts=None, + *, + dd_version: str = "3.39.0", +): + + if total_units < 2: + raise ValueError("Need at least two units to exercise slice edge cases.") + + wall = IDSFactory(dd_version).wall() + wall.description_2d.resize(1) + + units = wall.description_2d[0].vessel.unit + units.resize(total_units) + + if element_counts is None: + element_counts = [4, 2] + [3] * (total_units - 2) + + element_counts = list(element_counts) + if len(element_counts) != total_units: + raise ValueError("element_counts length must match total_units.") + + for unit_idx, unit in enumerate(units): + unit.name = f"unit-{unit_idx}" + unit.element.resize(element_counts[unit_idx]) + for elem_idx, element in enumerate(unit.element): + element.name = f"element-{unit_idx}-{elem_idx}" + + return wall + + +def safe_element_lookup(units_slice, element_index: int): + collected = [] + skipped_units = [] + for idx, unit in enumerate(units_slice): + elements = unit.element + if element_index >= len(elements): + skipped_units.append(idx) + continue + collected.append(elements[element_index].name.value) + return {"collected": collected, "skipped_units": skipped_units} + + +class TestBasicSlicing: + + def test_slice_with_start_and_stop(self): + cp = IDSFactory("3.39.0").core_profiles() + cp.profiles_1d.resize(10) + + result = cp.profiles_1d[3:7] + assert isinstance(result, IDSSlice) + assert len(result) == 4 + + result = cp.profiles_1d[::2] + assert isinstance(result, IDSSlice) + assert len(result) == 5 + + result = cp.profiles_1d[-5:] + assert isinstance(result, IDSSlice) + assert len(result) == 5 + + def 
test_slice_corner_cases(self): + cp = IDSFactory("3.39.0").core_profiles() + cp.profiles_1d.resize(10) + + result = cp.profiles_1d[0:100] + assert len(result) == 10 + + result = cp.profiles_1d[10:20] + assert len(result) == 0 + + result = cp.profiles_1d[::-1] + assert len(result) == 10 + + def test_integer_index_still_works(self): + cp = IDSFactory("3.39.0").core_profiles() + cp.profiles_1d.resize(10) + + result = cp.profiles_1d[5] + assert not isinstance(result, IDSSlice) + assert hasattr(result, "_path") + + +class TestIDSSlicePath: + + def test_slice_path_representation(self): + cp = IDSFactory("3.39.0").core_profiles() + cp.profiles_1d.resize(10) + + result = cp.profiles_1d[5:8] + expected_path = "[5:8]" + assert expected_path in result._path + + result = cp.profiles_1d[5:8][1:3] + assert "[" in result._path + + def test_attribute_access_path(self, wall_with_units): + wall = wall_with_units + units = wall.description_2d[0].vessel.unit[8:] + + element_slice = units.element + assert "element" in element_slice._path + + +class TestIDSSliceIteration: + + def test_iteration_and_len(self): + cp = IDSFactory("3.39.0").core_profiles() + cp.profiles_1d.resize(5) + + slice_obj = cp.profiles_1d[1:4] + + items = list(slice_obj) + assert len(items) == 3 + + assert len(slice_obj) == 3 + + +class TestIDSSliceIndexing: + + def test_integer_indexing_slice(self): + cp = IDSFactory("3.39.0").core_profiles() + cp.profiles_1d.resize(10) + + slice_obj = cp.profiles_1d[3:7] + # Integer indexing not supported on IDSSlice - use list() conversion instead + elements_list = list(slice_obj) + element = elements_list[1] + assert not isinstance(element, IDSSlice) + + def test_slice_indexing_slice(self): + cp = IDSFactory("3.39.0").core_profiles() + cp.profiles_1d.resize(10) + + slice_obj = cp.profiles_1d[2:8] + nested_slice = slice_obj[1:4] + assert isinstance(nested_slice, IDSSlice) + assert len(nested_slice) == 3 + + +class TestIDSSliceAttributeAccess: + + def 
test_attribute_access_nested_attributes(self, wall_with_units): + wall = wall_with_units + units = wall.description_2d[0].vessel.unit[8:] + + names = units.name + assert isinstance(names, IDSSlice) + assert len(names) == 4 + + units_full = wall.description_2d[0].vessel.unit + elements = units_full[:].element + assert isinstance(elements, IDSSlice) + + +class TestIDSSliceRepr: + + def test_repr_count_display(self): + cp = IDSFactory("3.39.0").core_profiles() + cp.profiles_1d.resize(10) + + slice_obj = cp.profiles_1d[5:6] + repr_str = repr(slice_obj) + assert "IDSSlice" in repr_str + assert "1 item" in repr_str + + slice_obj = cp.profiles_1d[5:8] + repr_str = repr(slice_obj) + assert "IDSSlice" in repr_str + assert "3 items" in repr_str + + +class TestWallExampleSlicing: + + def test_wall_units_nested_element_access(self, wall_with_units): + wall = wall_with_units + units = wall.description_2d[0].vessel.unit + + units_slice = units[8:] + assert isinstance(units_slice, IDSSlice) + assert len(units_slice) == 4 + + elements_slice = units_slice.element + assert isinstance(elements_slice, IDSSlice) + + +class TestEdgeCases: + + def test_slice_empty_array(self): + cp = IDSFactory("3.39.0").core_profiles() + + result = cp.profiles_1d[:] + assert isinstance(result, IDSSlice) + assert len(result) == 0 + + def test_slice_single_element(self): + cp = IDSFactory("3.39.0").core_profiles() + cp.profiles_1d.resize(1) + + result = cp.profiles_1d[:] + assert isinstance(result, IDSSlice) + assert len(result) == 1 + + def test_invalid_step_zero(self): + cp = IDSFactory("3.39.0").core_profiles() + cp.profiles_1d.resize(10) + + with pytest.raises(ValueError): + cp.profiles_1d[::0] + + +class TestFlatten: + + def test_flatten_basic_and_partial(self): + cp = IDSFactory("3.39.0").core_profiles() + cp.profiles_1d.resize(3) + + for profile in cp.profiles_1d: + profile.ion.resize(5) + + slice_obj = cp.profiles_1d[:].ion + flattened = slice_obj[:] + assert isinstance(flattened, IDSSlice) + 
assert len(flattened) == 15 + + cp2 = IDSFactory("3.39.0").core_profiles() + cp2.profiles_1d.resize(4) + for profile in cp2.profiles_1d: + profile.ion.resize(3) + flattened2 = cp2.profiles_1d[:2].ion[:] + assert len(flattened2) == 6 + + def test_flatten_empty_and_single(self): + cp = IDSFactory("3.39.0").core_profiles() + cp.profiles_1d.resize(2) + empty_flattened = cp.profiles_1d[:].ion[:] + assert len(empty_flattened) == 0 + + cp2 = IDSFactory("3.39.0").core_profiles() + cp2.profiles_1d.resize(1) + cp2.profiles_1d[0].ion.resize(4) + single_flattened = cp2.profiles_1d[:].ion[:] + assert len(single_flattened) == 4 + + def test_flatten_indexing_and_slicing(self): + cp = IDSFactory("3.39.0").core_profiles() + cp.profiles_1d.resize(2) + + for i, profile in enumerate(cp.profiles_1d): + profile.ion.resize(3) + for j, ion in enumerate(profile.ion): + ion.label = f"ion_{i}_{j}" + + flattened = cp.profiles_1d[:].ion[:] + + # Integer indexing not supported on IDSSlice - use values() method instead + flattened_list = list(flattened) + assert flattened_list[0].label == "ion_0_0" + assert flattened_list[3].label == "ion_1_0" + + subset = flattened[1:4] + assert isinstance(subset, IDSSlice) + assert len(subset) == 3 + labels = [ion.label for ion in subset] + assert labels == ["ion_0_1", "ion_0_2", "ion_1_0"] + + def test_flatten_repr_and_path(self): + cp = IDSFactory("3.39.0").core_profiles() + cp.profiles_1d.resize(2) + for profile in cp.profiles_1d: + profile.ion.resize(2) + + flattened = cp.profiles_1d[:].ion[:] + repr_str = repr(flattened) + + assert "IDSSlice" in repr_str + assert "4 items" in repr_str + assert "[:]" in flattened._path + + def test_flatten_complex_case(self, wall_with_units): + wall = wall_with_units + units = wall.description_2d[0].vessel.unit[:5] + + all_elements = units.element[:] + assert len(all_elements) == 4 + 2 + 3 + 3 + 3 + + +class TestVaryingArraySizeIndexing: + + def test_unit_slice_element_integer_indexing(self, wall_varying_sizes): + units = 
wall_varying_sizes.description_2d[0].vessel.unit + units_slice = units[:2] + element_slice = units_slice.element + + # Integer indexing not supported on IDSSlice + # Use list() to check length instead + elements_list = list(element_slice) + assert len(elements_list) == 2 + # Access beyond available elements should be handled via list indexing + with pytest.raises(IndexError): + elements_list[2] + + def test_unit_slice_element_safe_indexing_scenarios(self, wall_varying_sizes): + units = wall_varying_sizes.description_2d[0].vessel.unit + units_slice = units[:2] + + result = safe_element_lookup(units_slice, 1) + assert len(result["collected"]) == 2 + assert result["collected"] == ["element-0-1", "element-1-1"] + + result = safe_element_lookup(units_slice, 2) + assert len(result["collected"]) == 1 + assert result["skipped_units"] == [1] + + result = safe_element_lookup(units_slice, 4) + assert len(result["collected"]) == 0 + assert result["skipped_units"] == [0, 1] + + def test_unit_slice_element_individual_access(self, wall_varying_sizes): + units = wall_varying_sizes.description_2d[0].vessel.unit + element_slice = units[:2].element + + # Access coordinate data from the flattened element arrays + arrays = list(element_slice) + assert len(arrays[0]) == 4 + assert arrays[0][2].name.value == "element-0-2" + + assert len(arrays[1]) == 2 + + with pytest.raises(IndexError): + arrays[1][2] + + def test_wall_with_diverse_element_counts(self): + wall = create_wall_with_units(total_units=5, element_counts=[3, 1, 4, 2, 5]) + + units = wall.description_2d[0].vessel.unit + units_slice = units[:3] + element_slice = units_slice.element + + # Access coordinate data from the flattened element arrays + arrays = list(element_slice) + assert len(arrays[0]) == 3 + assert len(arrays[2]) == 4 + + result = safe_element_lookup(units_slice, 2) + assert len(result["collected"]) == 2 + assert result["skipped_units"] == [1] + + +class TestIDSSliceValues: + + def test_values_basic_extraction(self, 
wall_with_units): + wall = wall_with_units + units = wall.description_2d[0].vessel.unit + + names_slice = units[:].name + names = names_slice.values() + + assert isinstance(names, list) + assert len(names) == 12 + assert all(isinstance(name, str) and name.startswith("unit-") for name in names) + assert names == [f"unit-{i}" for i in range(12)] + + def test_values_integer_and_float_extraction(self): + cp = IDSFactory("3.39.0").core_profiles() + cp.profiles_1d.resize(3) + + for profile in cp.profiles_1d: + profile.ion.resize(2) + for i, ion in enumerate(profile.ion): + ion.neutral_index = i + ion.z_ion = float(i + 1) + + ions = cp.profiles_1d[:].ion[:] + indices = ions[:].neutral_index.values() + assert all(isinstance(idx, (int, np.integer)) for idx in indices) + + z_values = ions[:].z_ion.values() + assert all(isinstance(z, (float, np.floating)) for z in z_values) + + def test_values_partial_and_empty_slices(self, wall_with_units): + wall = wall_with_units + units = wall.description_2d[0].vessel.unit + + names = units[:5].name.values() + assert len(names) == 5 + assert names == [f"unit-{i}" for i in range(5)] + + cp = IDSFactory("3.39.0").core_profiles() + cp.profiles_1d.resize(5) + # Empty slices return empty values when accessing attributes + empty_values = cp.profiles_1d[5:10].time.values() + assert len(empty_values) == 0 + + def test_values_with_step_and_negative_indices(self, wall_with_units): + wall = wall_with_units + units = wall.description_2d[0].vessel.unit + + names_step = units[::2].name.values() + assert len(names_step) == 6 + assert names_step == [f"unit-{i}" for i in range(0, 12, 2)] + + names_neg = units[-3:].name.values() + assert len(names_neg) == 3 + assert names_neg == [f"unit-{i}" for i in range(9, 12)] + + def test_values_structure_preservation(self): + cp = IDSFactory("3.39.0").core_profiles() + cp.profiles_1d.resize(3) + + for profile in cp.profiles_1d: + profile.ion.resize(2) + + ions = cp.profiles_1d[:].ion[:].values() + + assert len(ions) 
== 6 + for ion in ions: + assert hasattr(ion, "_path") + from imas.ids_primitive import IDSPrimitive + + assert not isinstance(ion, IDSPrimitive) + + def test_values_array_primitives(self): + cp = IDSFactory("3.39.0").core_profiles() + cp.profiles_1d.resize(2) + + cp.profiles_1d[0].grid.psi = np.linspace(0, 1, 10) + cp.profiles_1d[1].grid.psi = np.linspace(1, 2, 10) + + psi_values = cp.profiles_1d[:].grid.psi.values() + + assert len(psi_values) == 2 + assert all(isinstance(psi, np.ndarray) for psi in psi_values) + + def test_values_consistency_with_iteration(self, wall_with_units): + wall = wall_with_units + units = wall.description_2d[0].vessel.unit + + names_via_values = units[:5].name.values() + + names_via_iteration = [unit.name.value for unit in units[:5]] + + assert names_via_values == names_via_iteration diff --git a/imas/test/test_ids_struct_array.py b/imas/test/test_ids_struct_array.py index ab128dfa..8c31f221 100644 --- a/imas/test/test_ids_struct_array.py +++ b/imas/test/test_ids_struct_array.py @@ -87,3 +87,15 @@ def test_struct_array_eq(): assert cp1.profiles_1d != cp2.profiles_1d cp2.profiles_1d[0].time = 1 assert cp1.profiles_1d == cp2.profiles_1d + + +def test_struct_array_slice(): + cp1 = IDSFactory("3.39.0").core_profiles() + cp1.profiles_1d.resize(20) + + assert len(cp1.profiles_1d) == 20 + assert len(cp1.profiles_1d[:]) == 20 + assert len(cp1.profiles_1d[5:10]) == 5 + assert len(cp1.profiles_1d[10:]) == 10 + assert len(cp1.profiles_1d[:5]) == 5 + assert len(cp1.profiles_1d[::2]) == 10 diff --git a/imas/test/test_ids_toplevel.py b/imas/test/test_ids_toplevel.py index a5855817..e55bac4d 100644 --- a/imas/test/test_ids_toplevel.py +++ b/imas/test/test_ids_toplevel.py @@ -46,7 +46,7 @@ def test_pretty_print(ids): assert pprint.pformat(ids) == "" -def test_serialize_nondefault_dd_version(requires_imas): +def test_serialize_nondefault_dd_version(): ids = IDSFactory("3.31.0").core_profiles() fill_with_random_data(ids) data = ids.serialize() diff --git 
a/imas/test/test_ids_validate.py b/imas/test/test_ids_validate.py index 7970c7e2..c3f8f157 100644 --- a/imas/test/test_ids_validate.py +++ b/imas/test/test_ids_validate.py @@ -245,7 +245,7 @@ def test_validate_coordinate_same_as(): (None, True), ], ) -def test_validate_on_put(monkeypatch, env_value, should_validate, requires_imas): +def test_validate_on_put(monkeypatch, env_value, should_validate): dbentry = DBEntry(MEMORY_BACKEND, "test", 1, 1) dbentry.create() ids = dbentry.factory.core_profiles() diff --git a/imas/test/test_latest_dd_autofill.py b/imas/test/test_latest_dd_autofill.py index 6d34b766..6b7fbb6a 100644 --- a/imas/test/test_latest_dd_autofill.py +++ b/imas/test/test_latest_dd_autofill.py @@ -55,7 +55,7 @@ def test_latest_dd_autofill(ids_name, backend, worker_id, tmp_path): @pytest.mark.parametrize( "serializer", [ASCII_SERIALIZER_PROTOCOL, FLEXBUFFERS_SERIALIZER_PROTOCOL] ) -def test_latest_dd_autofill_serialize(serializer, ids_name, has_imas): +def test_latest_dd_autofill_serialize(serializer, ids_name): """Serialize and then deserialize again all IDSToplevels""" if serializer is None: pytest.skip("Unsupported serializer") @@ -64,8 +64,6 @@ def test_latest_dd_autofill_serialize(serializer, ids_name, has_imas): ids = factory.new(ids_name) fill_with_random_data(ids) - if not has_imas: - return # rest of the test requires an IMAS install data = ids.serialize(serializer) ids2 = factory.new(ids_name) diff --git a/imas/test/test_lazy_loading.py b/imas/test/test_lazy_loading.py index 4a7c65ca..1dcd0bff 100644 --- a/imas/test/test_lazy_loading.py +++ b/imas/test/test_lazy_loading.py @@ -94,7 +94,7 @@ def iterate(structure): dbentry.close() -def test_lazy_load_close_dbentry(requires_imas): +def test_lazy_load_close_dbentry(): dbentry = DBEntry(MEMORY_BACKEND, "ITER", 1, 1) dbentry.create() @@ -109,7 +109,7 @@ def test_lazy_load_close_dbentry(requires_imas): print(lazy_ids.time) -def test_lazy_load_readonly(requires_imas): +def test_lazy_load_readonly(): 
dbentry = DBEntry(MEMORY_BACKEND, "ITER", 1, 1) dbentry.create() run_lazy_load_readonly(dbentry) @@ -151,7 +151,7 @@ def run_lazy_load_readonly(dbentry): dbentry.close() -def test_lazy_load_no_put(requires_imas): +def test_lazy_load_no_put(): dbentry = DBEntry(MEMORY_BACKEND, "ITER", 1, 1) dbentry.create() @@ -169,7 +169,7 @@ def test_lazy_load_no_put(requires_imas): dbentry.close() -def test_lazy_load_with_new_aos(requires_imas): +def test_lazy_load_with_new_aos(): dbentry = DBEntry(MEMORY_BACKEND, "ITER", 1, 1, dd_version="3.30.0") dbentry.create() et = dbentry.factory.edge_transport() @@ -214,7 +214,7 @@ def test_lazy_load_with_new_aos_netcdf(tmp_path): assert len(lazy_et.model[0].ggd[0].electrons.particles.d_radial) == 0 -def test_lazy_load_with_new_structure(requires_imas): +def test_lazy_load_with_new_structure(): dbentry = DBEntry(MEMORY_BACKEND, "ITER", 1, 1, dd_version="3.30.0") dbentry.create() diff --git a/imas/test/test_multidim_slicing.py b/imas/test/test_multidim_slicing.py new file mode 100644 index 00000000..06114897 --- /dev/null +++ b/imas/test/test_multidim_slicing.py @@ -0,0 +1,374 @@ +# This file is part of IMAS-Python. +# You should have received the IMAS-Python LICENSE file with this project. 
+"""Tests for multi-dimensional slicing support in IDSSlice.""" + +import numpy as np +import pytest + +from imas.ids_factory import IDSFactory + + +class TestMultiDimSlicing: + """Shape tracking and conversion methods.""" + + def test_shape_property_single_level(self): + """Test shape property for single-level slice.""" + cp = IDSFactory("3.39.0").core_profiles() + cp.profiles_1d.resize(10) + + result = cp.profiles_1d[:] + assert hasattr(result, "shape") + assert result.shape == (10,) + + def test_shape_property_two_level(self): + """Test shape property for 2D array access.""" + cp = IDSFactory("3.39.0").core_profiles() + cp.profiles_1d.resize(5) + for p in cp.profiles_1d: + p.grid.rho_tor_norm = np.array([0.0, 0.5, 1.0]) + + result = cp.profiles_1d[:].grid.rho_tor_norm + assert result.shape == (5, 3) + + def test_shape_property_three_level(self): + """Test shape property for 3D nested structure.""" + cp = IDSFactory("3.39.0").core_profiles() + cp.profiles_1d.resize(3) + for p in cp.profiles_1d: + p.ion.resize(2) + for i in p.ion: + i.element.resize(2) + + result = cp.profiles_1d[:].ion[:].element[:] + assert result.shape == (3, 2, 2) + + def test_to_array_2d_regular(self): + """Test to_array() with regular 2D array.""" + cp = IDSFactory("3.39.0").core_profiles() + cp.profiles_1d.resize(5) + for i, p in enumerate(cp.profiles_1d): + p.grid.rho_tor_norm = np.array([0.0, 0.5, 1.0]) + + result = cp.profiles_1d[:].grid.rho_tor_norm + array = result.to_array() + + assert isinstance(array, np.ndarray) + assert array.shape == (5, 3) + assert np.allclose(array[0], [0.0, 0.5, 1.0]) + assert np.allclose(array[4], [0.0, 0.5, 1.0]) + + def test_to_array_3d_regular(self): + """Test to_array() with regular 3D array.""" + cp = IDSFactory("3.39.0").core_profiles() + cp.profiles_1d.resize(3) + for p in cp.profiles_1d: + p.ion.resize(2) + for i_idx, i in enumerate(p.ion): + i.element.resize(2) + for e_idx, e in enumerate(i.element): + e.z_n = float(e_idx) + + result = 
cp.profiles_1d[:].ion[:].element[:].z_n + array = result.to_array() + + assert isinstance(array, np.ndarray) + assert array.shape == (3, 2, 2) + assert np.allclose(array[0, 0, :], [0.0, 1.0]) + assert np.allclose(array[0, 1, :], [0.0, 1.0]) + + def test_to_array_variable_size(self): + """Test to_array() raises error for ragged arrays.""" + cp = IDSFactory("3.39.0").core_profiles() + cp.profiles_1d.resize(3) + cp.profiles_1d[0].grid.rho_tor_norm = np.array([0.0, 0.5, 1.0]) + cp.profiles_1d[1].grid.rho_tor_norm = np.array([0.0, 0.25, 0.5, 0.75, 1.0]) + cp.profiles_1d[2].grid.rho_tor_norm = np.array([0.0, 0.5, 1.0]) + + result = cp.profiles_1d[:].grid.rho_tor_norm + + # to_array() should raise ValueError for ragged data + with pytest.raises(ValueError, match="Cannot tensorize ragged array"): + result.to_array() + + # But .values() should still work + values = result.values() + assert len(values) == 3 + assert len(values[0]) == 3 + assert len(values[1]) == 5 + assert len(values[2]) == 3 + + def test_enhanced_values_2d(self): + """Test enhanced values() method for 2D extraction.""" + cp = IDSFactory("3.39.0").core_profiles() + cp.profiles_1d.resize(3) + for p in cp.profiles_1d: + p.grid.rho_tor_norm = np.array([0.0, 0.5, 1.0]) + + result = cp.profiles_1d[:].grid.rho_tor_norm + values = result.values() + + # Should be a list of 3 arrays + assert isinstance(values, list) + assert len(values) == 3 + for v in values: + assert isinstance(v, np.ndarray) + assert len(v) == 3 + + def test_enhanced_values_3d(self): + """Test enhanced values() method for 3D extraction.""" + cp = IDSFactory("3.39.0").core_profiles() + cp.profiles_1d.resize(2) + for p in cp.profiles_1d: + p.ion.resize(2) + for i in p.ion: + i.element.resize(2) + for e_idx, e in enumerate(i.element): + e.z_n = float(e_idx) + + result = cp.profiles_1d[:].ion[:].element[:].z_n + values = result.values() + + assert isinstance(values, list) + assert len(values) == 8 # 2 profiles * 2 ions * 2 elements + + def 
test_slice_preserves_groups(self): + """Test that slicing preserves group structure.""" + cp = IDSFactory("3.39.0").core_profiles() + cp.profiles_1d.resize(10) + for p in cp.profiles_1d: + p.ion.resize(3) + + # Get all ions, then slice + result = cp.profiles_1d[:].ion[:] + + # Should still know the structure: 10 profiles, 3 ions each + assert result.shape == (10, 3) + assert len(result) == 30 # Flattened for iteration, but shape preserved + + def test_integer_index_not_supported(self): + """Test that integer indexing on IDSSlice is not supported.""" + cp = IDSFactory("3.39.0").core_profiles() + cp.profiles_1d.resize(5) + for i, p in enumerate(cp.profiles_1d): + p.ion.resize(2) + for j, ion in enumerate(p.ion): + ion.label = f"ion_{i}_{j}" + + # Integer indexing on IDSSlice should raise TypeError + with pytest.raises(TypeError, match="Cannot index IDSSlice with integer"): + cp.profiles_1d[:].ion[0] + + # Show the correct alternatives + # Option 1: Direct indexing (recommended) + ion_0_from_first_profile = cp.profiles_1d[0].ion[:1] # Use slice, not int index + assert len(ion_0_from_first_profile) == 1 + + # Option 2: Convert to list + ions_list = list(cp.profiles_1d[:].ion) + ions_from_first_profile = ions_list[0] + assert len(ions_from_first_profile) == 2 + + # Option 3: Extract values + ions_values = cp.profiles_1d[:].ion.values() + first_profile_ions = ions_values[0] + assert len(first_profile_ions) == 2 + + def test_slice_on_nested_arrays(self): + """Test slicing on nested arrays.""" + cp = IDSFactory("3.39.0").core_profiles() + cp.profiles_1d.resize(5) + for p in cp.profiles_1d: + p.ion.resize(4) + + # Get first 2 ions from each profile + result = cp.profiles_1d[:].ion[:2] + + assert result.shape == (5, 2) + assert len(result) == 10 # 5 profiles * 2 ions each + + def test_step_slicing_on_nested(self): + """Test step slicing on nested structures.""" + cp = IDSFactory("3.39.0").core_profiles() + cp.profiles_1d.resize(5) + for p in cp.profiles_1d: + p.ion.resize(6) 
+ + # Get every other ion + result = cp.profiles_1d[:].ion[::2] + + assert result.shape == (5, 3) # 5 profiles, 3 ions each (0, 2, 4) + assert len(result) == 15 + + def test_negative_indexing_not_supported(self): + """Test that negative integer indexing on IDSSlice is not supported.""" + cp = IDSFactory("3.39.0").core_profiles() + cp.profiles_1d.resize(5) + for p in cp.profiles_1d: + p.ion.resize(3) + for j, ion in enumerate(p.ion): + ion.label = f"ion_{j}" + + # Negative integer indexing on IDSSlice should raise TypeError + with pytest.raises(TypeError, match="Cannot index IDSSlice with integer"): + cp.profiles_1d[:].ion[-1] + + # Show the correct alternative: use slice instead + # Get last ion from each profile using slice + result = cp.profiles_1d[:].ion[2:3] # Get last element with slice + assert result.shape == (5, 1) + + # Or better: direct indexing + last_ions = [p.ion[-1] for p in cp.profiles_1d] + assert len(last_ions) == 5 + assert all(ion.label == "ion_2" for ion in last_ions) + + def test_to_array_grouped_structure(self): + """Test that to_array preserves grouped structure.""" + cp = IDSFactory("3.39.0").core_profiles() + cp.profiles_1d.resize(3) + for p_idx, p in enumerate(cp.profiles_1d): + p.ion.resize(2) + for i_idx, i in enumerate(p.ion): + i.z_ion = float(p_idx * 10 + i_idx) + + result = cp.profiles_1d[:].ion[:].z_ion + array = result.to_array() + + # Should be (3, 2) array + assert array.shape == (3, 2) + assert array[0, 0] == 0.0 + assert array[1, 0] == 10.0 + assert array[2, 1] == 21.0 + + @pytest.mark.skip(reason="Phase 3 feature - boolean indexing not yet implemented") + def test_boolean_indexing_simple(self): + """Test boolean indexing on slices.""" + cp = IDSFactory("3.39.0").core_profiles() + cp.profiles_1d.resize(5) + for i, p in enumerate(cp.profiles_1d): + p.electrons.density = np.array([float(i)] * 5) + + result = cp.profiles_1d[:].electrons.density + + mask = np.array([True, False, True, False, True]) + filtered = result[mask] + 
assert len(filtered) == 3 + + def test_assignment_on_slice(self): + """Test assignment through slices.""" + cp = IDSFactory("3.39.0").core_profiles() + cp.profiles_1d.resize(3) + for p in cp.profiles_1d: + p.grid.rho_tor_norm = np.array([0.0, 0.5, 1.0]) + + # This requires assignment support + # cp.profiles_1d[:].grid.rho_tor_norm[:] = new_values + # For now, verify slicing works for reading + + result = cp.profiles_1d[:].grid.rho_tor_norm + array = result.to_array() + assert array.shape == (3, 3) + + def test_xarray_integration_compatible(self): + """Test that output is compatible with xarray.""" + cp = IDSFactory("3.39.0").core_profiles() + cp.profiles_1d.resize(3) + cp.time = np.array([1.0, 2.0, 3.0]) + + for i, p in enumerate(cp.profiles_1d): + p.grid.rho_tor_norm = np.array([0.0, 0.5, 1.0]) + p.electrons.temperature = np.array([1.0, 2.0, 3.0]) * (i + 1) + + # Test that we can extract values in xarray-compatible format + temps = cp.profiles_1d[:].electrons.temperature.to_array() + times = cp.time + + assert temps.shape == (3, 3) + assert len(times) == 3 + + def test_performance_large_hierarchy(self): + """Test performance with large nested hierarchies.""" + cp = IDSFactory("3.39.0").core_profiles() + n_profiles = 50 + cp.profiles_1d.resize(n_profiles) + + for p in cp.profiles_1d: + p.grid.rho_tor_norm = np.linspace(0, 1, 100) + p.ion.resize(5) + for i in p.ion: + i.element.resize(3) + + # Should handle large data without significant slowdown + result = cp.profiles_1d[:].grid.rho_tor_norm + array = result.to_array() + + assert array.shape == (n_profiles, 100) + + def test_lazy_loading_with_multidim(self): + """Test that lazy loading works with multi-dimensional slicing.""" + # This would require a database, so we'll test with in-memory + cp = IDSFactory("3.39.0").core_profiles() + cp.profiles_1d.resize(5) + for p in cp.profiles_1d: + p.grid.rho_tor_norm = np.array([0.0, 0.5, 1.0]) + + result = cp.profiles_1d[:].grid.rho_tor_norm + + # Verify IDSSlice attributes 
are preserved + assert hasattr(result, "_parent_array") + assert hasattr(result, "_matched_elements") + assert len(result._matched_elements) > 0 + + +class TestEdgeCases: + """Test edge cases and error conditions.""" + + def test_empty_slice(self): + """Test slicing that results in empty arrays.""" + cp = IDSFactory("3.39.0").core_profiles() + cp.profiles_1d.resize(5) + for p in cp.profiles_1d: + p.ion.resize(0) + + result = cp.profiles_1d[:].ion + assert len(result) == 5 + for ions in result: + # Each should be empty + pass + + def test_single_element_2d(self): + """Test 2D extraction with single element.""" + cp = IDSFactory("3.39.0").core_profiles() + cp.profiles_1d.resize(1) + cp.profiles_1d[0].grid.rho_tor_norm = np.array([0.0, 0.5, 1.0]) + + result = cp.profiles_1d[:].grid.rho_tor_norm + assert result.shape == (1, 3) + + def test_single_dimension_value(self): + """Test accessing scalar values from nested structures.""" + cp = IDSFactory("3.39.0").core_profiles() + cp.profiles_1d.resize(3) + for p in cp.profiles_1d: + p.ion.resize(2) + for i in p.ion: + i.z_ion = 1.0 + + # Use slice notation instead of integer indexing + result = cp.profiles_1d[:].ion[:1].z_ion # Get first ion only + + # Should be 3 items (one per profile, one ion per profile) + assert len(result) == 3 + + def test_slice_of_slice(self): + """Test slicing a slice.""" + cp = IDSFactory("3.39.0").core_profiles() + cp.profiles_1d.resize(10) + for p in cp.profiles_1d: + p.ion.resize(3) + + result1 = cp.profiles_1d[::2].ion # Every other profile's ions + assert result1.shape == (5, 3) + + result2 = result1[:2] # First 2 from each + assert result2.shape == (5, 2) diff --git a/imas/test/test_nbc_change.py b/imas/test/test_nbc_change.py index 91ede0e3..b34949df 100644 --- a/imas/test/test_nbc_change.py +++ b/imas/test/test_nbc_change.py @@ -49,7 +49,7 @@ def test_nbc_structure_to_aos(caplog): assert caplog.record_tuples[0][:2] == ("imas.ids_convert", logging.WARNING) -def test_nbc_0d_to_1d(caplog, 
requires_imas): +def test_nbc_0d_to_1d(caplog): # channel/filter_spectrometer/radiance_calibration in spectrometer visible changed # from FLT_0D to FLT_1D in DD 3.39.0 ids = IDSFactory("3.32.0").spectrometer_visible() diff --git a/imas/test/test_snippets.py b/imas/test/test_snippets.py index 0574b185..8ed49a83 100644 --- a/imas/test/test_snippets.py +++ b/imas/test/test_snippets.py @@ -13,7 +13,7 @@ @pytest.mark.skip(reason="skipping hli test") @pytest.mark.filterwarnings("ignore:The input coordinates to pcolormesh:UserWarning") @pytest.mark.parametrize("snippet", course_snippets) -def test_script_execution(snippet, monkeypatch, tmp_path, requires_imas): +def test_script_execution(snippet, monkeypatch, tmp_path): monkeypatch.chdir(tmp_path) # Prevent showing plots in a GUI monkeypatch.delenv("DISPLAY", raising=False) diff --git a/imas/test/test_static_ids.py b/imas/test/test_static_ids.py index 2c66811d..05133615 100644 --- a/imas/test/test_static_ids.py +++ b/imas/test/test_static_ids.py @@ -21,7 +21,7 @@ def test_ids_valid_type(): assert ids_types in ({IDSType.NONE}, {IDSType.CONSTANT, IDSType.DYNAMIC}) -def test_constant_ids(caplog, requires_imas): +def test_constant_ids(caplog): ids = imas.IDSFactory().new("amns_data") if ids.metadata.type is IDSType.NONE: pytest.skip("IDS definition has no constant IDSs") diff --git a/imas/test/test_to_xarray.py b/imas/test/test_to_xarray.py index 1767a6d9..a5df6a1e 100644 --- a/imas/test/test_to_xarray.py +++ b/imas/test/test_to_xarray.py @@ -9,7 +9,7 @@ @pytest.fixture -def entry(requires_imas, monkeypatch): +def entry(monkeypatch): monkeypatch.setenv("IMAS_VERSION", "3.39.0") # Use fixed DD version return imas.training.get_training_db_entry() diff --git a/imas/test/test_util.py b/imas/test/test_util.py index 15a2a8c0..1834af9c 100644 --- a/imas/test/test_util.py +++ b/imas/test/test_util.py @@ -54,7 +54,7 @@ def test_inspect(): inspect(cp.profiles_1d[1].grid.rho_tor_norm) # IDSPrimitive -def 
test_inspect_lazy(requires_imas): +def test_inspect_lazy(): with get_training_db_entry() as entry: cp = entry.get("core_profiles", lazy=True) inspect(cp) @@ -141,7 +141,7 @@ def test_idsdiffgen(): assert diff[0] == ("profiles_1d/time", -1, 0) -def test_idsdiff(requires_imas): +def test_idsdiff(): # Test the diff rendering for two sample IDSs with get_training_db_entry() as entry: imas.util.idsdiff(entry.get("core_profiles"), entry.get("equilibrium")) @@ -179,7 +179,7 @@ def test_get_toplevel(): assert get_toplevel(cp) is cp -def test_is_lazy_loaded(requires_imas): +def test_is_lazy_loaded(): with get_training_db_entry() as entry: assert is_lazy_loaded(entry.get("core_profiles")) is False assert is_lazy_loaded(entry.get("core_profiles", lazy=True)) is True