diff --git a/src/oceanum/datamesh/cache.py b/src/oceanum/datamesh/cache.py index 031badb..f7047d2 100644 --- a/src/oceanum/datamesh/cache.py +++ b/src/oceanum/datamesh/cache.py @@ -5,6 +5,7 @@ import xarray as xr import pandas as pd import geopandas as gpd +from zarr.storage import ZipStore from .query import Query @@ -51,14 +52,15 @@ def unlock(self, query): def _get(self, query): cache_file = self._cachepath(query) try: - if os.path.exists(cache_file + ".nc"): + if os.path.exists(cache_file + ".zarr.zip"): if ( - os.path.getmtime(cache_file + ".nc") + self.cache_timeout + os.path.getmtime(cache_file + ".zarr.zip") + self.cache_timeout < time.time() ): - os.remove(cache_file + ".nc") + os.remove(cache_file + ".zarr.zip") return None - return xr.open_dataset(cache_file + ".nc") + with ZipStore(cache_file + ".zarr.zip") as store: + return xr.open_zarr(store, consolidated=True).load() elif os.path.exists(cache_file + ".gpq"): if ( os.path.getmtime(cache_file + ".gpq") + self.cache_timeout @@ -95,7 +97,8 @@ def copy(self, query, fname, ext): def put(self, query, data): cache_file = self._cachepath(query) if isinstance(data, xr.Dataset): - data.to_netcdf(cache_file + ".nc") + with ZipStore(cache_file + ".zarr.zip") as store: + data.to_zarr(store, consolidated=True) elif isinstance(data, gpd.GeoDataFrame): data.to_parquet(cache_file + ".gpq") elif isinstance(data, pd.DataFrame): diff --git a/src/oceanum/datamesh/connection.py b/src/oceanum/datamesh/connection.py index b837e5a..2da10c9 100644 --- a/src/oceanum/datamesh/connection.py +++ b/src/oceanum/datamesh/connection.py @@ -25,6 +25,7 @@ import numbers import urllib3 from pydantic import ValidationError +from zarr.storage import ZipStore from .datasource import Datasource from .catalog import Catalog @@ -354,7 +355,7 @@ def _query(self, query, use_dask=False, cache_timeout=0, retry=0): if cache_timeout: localcache.lock(query) transfer_format = ( - "application/x-netcdf4" + "application/zarr+zip" if stage.container == Container.Dataset else "application/parquet" ) @@ -392,10 +393,13 @@ def _query(self, query, use_dask=False, cache_timeout=0, retry=0): f.write(resp.content) f.seek(0) if stage.container == Container.Dataset: - ds = xarray.load_dataset( - f.name, decode_coords="all", mask_and_scale=True - ) - ext = ".nc" + with ZipStore(f.name) as store: + ds = xarray.open_zarr( + store, + decode_coords="all", + mask_and_scale=True, + ).load() + ext = ".zarr.zip" elif stage.container == Container.GeoDataFrame: ds = geopandas.read_parquet(f.name) ext = ".gpq" diff --git a/tests/test_cache.py b/tests/test_cache.py index 0c86779..e6a09a3 100644 --- a/tests/test_cache.py +++ b/tests/test_cache.py @@ -177,7 +177,7 @@ def test_put_and_get_xarray_dataset( ): """Test storing and retrieving xarray Dataset.""" cache.put(sample_query, sample_xarray_dataset) - cached_file = cache._cachepath(sample_query) + ".nc" + cached_file = cache._cachepath(sample_query) + ".zarr.zip" assert os.path.exists(cached_file) retrieved = cache.get(sample_query) @@ -260,7 +260,7 @@ def test_get_with_lock_timeout(self, cache, sample_query): def test_get_corrupted_cache_returns_none(self, cache, sample_query): """Test that corrupted cache file returns None.""" - cached_file = cache._cachepath(sample_query) + ".nc" + cached_file = cache._cachepath(sample_query) + ".zarr.zip" # Create a corrupted file with open(cached_file, "w") as f: @@ -275,7 +275,7 @@ class TestCopy: def test_copy_moves_file(self, cache, sample_query, temp_cache_dir): """Test that copy moves file to cache location.""" - source_file = os.path.join(temp_cache_dir, "source.nc") + source_file = os.path.join(temp_cache_dir, "source.zarr.zip") # Create a source file with open(source_file, "w") as f: @@ -283,9 +283,9 @@ def test_copy_moves_file(self, cache, sample_query, temp_cache_dir): assert os.path.exists(source_file) - cache.copy(sample_query, source_file, ".nc") + cache.copy(sample_query, source_file, ".zarr.zip") - cached_file = cache._cachepath(sample_query) + ".nc" + cached_file = cache._cachepath(sample_query) + ".zarr.zip" assert os.path.exists(cached_file) assert not os.path.exists(source_file) # Original should be moved diff --git a/tests/test_datamesh_query.py b/tests/test_datamesh_query.py index fd00bd6..0d42101 100644 --- a/tests/test_datamesh_query.py +++ b/tests/test_datamesh_query.py @@ -48,7 +48,7 @@ def test_query_table(conn): def test_query_table_cache(conn): q = Query(**{"datasource": "oceanum-sea-level-rise"}) cache = LocalCache(cache_timeout=600) - cached_file = cache._cachepath(q) + ".nc" + cached_file = cache._cachepath(q) + ".zarr.zip" if os.path.exists(cached_file): os.remove(cached_file) ds0 = conn.query(q, cache_timeout=600) @@ -91,7 +91,7 @@ def test_query_dataset_cache(conn): ) cache = LocalCache(cache_timeout=600) - cached_file = cache._cachepath(q) + ".nc" + cached_file = cache._cachepath(q) + ".zarr.zip" if os.path.exists(cached_file): os.remove(cached_file) ds0 = conn.query(q, use_dask=False, cache_timeout=600)