diff --git a/examples/datafusion-ffi-example/python/tests/_test_catalog_provider.py b/examples/datafusion-ffi-example/python/tests/_test_catalog_provider.py index b26e12085..a862b23ba 100644 --- a/examples/datafusion-ffi-example/python/tests/_test_catalog_provider.py +++ b/examples/datafusion-ffi-example/python/tests/_test_catalog_provider.py @@ -22,7 +22,7 @@ import pytest from datafusion import SessionContext, Table from datafusion.catalog import Schema -from datafusion_ffi_example import MyCatalogProvider +from datafusion_ffi_example import MyCatalogProvider, MyCatalogProviderList def create_test_dataset() -> Table: @@ -35,6 +35,30 @@ def create_test_dataset() -> Table: return Table(dataset) +@pytest.mark.parametrize("inner_capsule", [True, False]) +def test_ffi_catalog_provider_list(inner_capsule: bool) -> None: + """Test basic FFI CatalogProviderList functionality.""" + ctx = SessionContext() + + # Register FFI catalog + catalog_provider_list = MyCatalogProviderList() + if inner_capsule: + catalog_provider_list = ( + catalog_provider_list.__datafusion_catalog_provider_list__(ctx) + ) + + ctx.register_catalog_provider_list(catalog_provider_list) + + # Verify the catalog exists + catalog = ctx.catalog("auto_ffi_catalog") + schema_names = catalog.names() + assert "my_schema" in schema_names + + ctx.register_catalog_provider("second", MyCatalogProvider()) + + assert ctx.catalog_names() == {"auto_ffi_catalog", "second"} + + @pytest.mark.parametrize("inner_capsule", [True, False]) def test_ffi_catalog_provider_basic(inner_capsule: bool) -> None: """Test basic FFI CatalogProvider functionality.""" diff --git a/examples/datafusion-ffi-example/src/catalog_provider.rs b/examples/datafusion-ffi-example/src/catalog_provider.rs index 570222748..aee23602f 100644 --- a/examples/datafusion-ffi-example/src/catalog_provider.rs +++ b/examples/datafusion-ffi-example/src/catalog_provider.rs @@ -22,11 +22,12 @@ use std::sync::Arc; use arrow::datatypes::Schema; use async_trait::async_trait; use datafusion_catalog::{ - CatalogProvider, MemTable, MemoryCatalogProvider, MemorySchemaProvider, SchemaProvider, - TableProvider, + CatalogProvider, CatalogProviderList, MemTable, MemoryCatalogProvider, + MemoryCatalogProviderList, MemorySchemaProvider, SchemaProvider, TableProvider, }; use datafusion_common::error::{DataFusionError, Result}; use datafusion_ffi::catalog_provider::FFI_CatalogProvider; +use datafusion_ffi::catalog_provider_list::FFI_CatalogProviderList; use datafusion_ffi::schema_provider::FFI_SchemaProvider; use pyo3::types::PyCapsule; use pyo3::{pyclass, pymethods, Bound, PyAny, PyResult, Python}; @@ -203,3 +204,67 @@ impl MyCatalogProvider { PyCapsule::new(py, provider, Some(name)) } } + +/// This catalog provider list is intended only for unit tests. +/// It pre-populates with a single catalog. +#[pyclass( + name = "MyCatalogProviderList", + module = "datafusion_ffi_example", + subclass +)] +#[derive(Debug, Clone)] +pub(crate) struct MyCatalogProviderList { + inner: Arc, +} + +impl CatalogProviderList for MyCatalogProviderList { + fn as_any(&self) -> &dyn Any { + self + } + + fn catalog_names(&self) -> Vec { + self.inner.catalog_names() + } + + fn catalog(&self, name: &str) -> Option> { + self.inner.catalog(name) + } + + fn register_catalog( + &self, + name: String, + catalog: Arc, + ) -> Option> { + self.inner.register_catalog(name, catalog) + } +} + +#[pymethods] +impl MyCatalogProviderList { + #[new] + pub fn new() -> PyResult { + let inner = Arc::new(MemoryCatalogProviderList::new()); + + inner.register_catalog( + "auto_ffi_catalog".to_owned(), + Arc::new(MyCatalogProvider::new()?), + ); + + Ok(Self { inner }) + } + + pub fn __datafusion_catalog_provider_list__<'py>( + &self, + py: Python<'py>, + session: Bound, + ) -> PyResult> { + let name = cr"datafusion_catalog_provider_list".into(); + + let provider = Arc::clone(&self.inner) as Arc; + + let codec = ffi_logical_codec_from_pycapsule(session)?; + let provider = FFI_CatalogProviderList::new_with_ffi_codec(provider, None, codec); + + PyCapsule::new(py, provider, Some(name)) + } +} diff --git a/examples/datafusion-ffi-example/src/lib.rs b/examples/datafusion-ffi-example/src/lib.rs index 005d8b80a..6c64c9fe5 100644 --- a/examples/datafusion-ffi-example/src/lib.rs +++ b/examples/datafusion-ffi-example/src/lib.rs @@ -18,7 +18,7 @@ use pyo3::prelude::*; use crate::aggregate_udf::MySumUDF; -use crate::catalog_provider::{FixedSchemaProvider, MyCatalogProvider}; +use crate::catalog_provider::{FixedSchemaProvider, MyCatalogProvider, MyCatalogProviderList}; use crate::scalar_udf::IsNullUDF; use crate::table_function::MyTableFunction; use crate::table_provider::MyTableProvider; @@ -37,6 +37,7 @@ fn datafusion_ffi_example(m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_class::()?; m.add_class::()?; m.add_class::()?; + m.add_class::()?; m.add_class::()?; m.add_class::()?; m.add_class::()?; diff --git a/python/datafusion/catalog.py b/python/datafusion/catalog.py index 16c3ccc2a..9bb39df21 100644 --- a/python/datafusion/catalog.py +++ b/python/datafusion/catalog.py @@ -38,13 +38,61 @@ __all__ = [ "Catalog", + "CatalogList", "CatalogProvider", + "CatalogProviderList", "Schema", "SchemaProvider", "Table", ] +class CatalogList: + """DataFusion data catalog list.""" + + def __init__(self, catalog_list: df_internal.catalog.RawCatalogList) -> None: + """This constructor is not typically called by the end user.""" + self.catalog_list = catalog_list + + def __repr__(self) -> str: + """Print a string representation of the catalog list.""" + return self.catalog_list.__repr__() + + def names(self) -> set[str]: + """This is an alias for `catalog_names`.""" + return self.catalog_names() + + def catalog_names(self) -> set[str]: + """Returns the list of schemas in this catalog.""" + return self.catalog_list.catalog_names() + + @staticmethod + def memory_catalog(ctx: SessionContext | None = None) -> CatalogList: + """Create an in-memory catalog provider list.""" + catalog_list = df_internal.catalog.RawCatalogList.memory_catalog(ctx) + return CatalogList(catalog_list) + + def catalog(self, name: str = "datafusion") -> Schema: + """Returns the catalog with the given ``name`` from this catalog.""" + catalog = self.catalog_list.catalog(name) + + return ( + Catalog(catalog) + if isinstance(catalog, df_internal.catalog.RawCatalog) + else catalog + ) + + def register_catalog( + self, + name: str, + catalog: Catalog | CatalogProvider | CatalogProviderExportable, + ) -> Catalog | None: + """Register a catalog with this catalog list.""" + if isinstance(catalog, Catalog): + return self.catalog_list.register_catalog(name, catalog.catalog) + return self.catalog_list.register_catalog(name, catalog) + + class Catalog: """DataFusion data catalog.""" @@ -195,6 +243,38 @@ def kind(self) -> str: return self._inner.kind +class CatalogProviderList(ABC): + """Abstract class for defining a Python based Catalog Provider List.""" + + @abstractmethod + def catalog_names(self) -> set[str]: + """Set of the names of all catalogs in this catalog list.""" + ... + + @abstractmethod + def catalog(self, name: str) -> Catalog | None: + """Retrieve a specific catalog from this catalog list.""" + ... + + def register_catalog( # noqa: B027 + self, name: str, catalog: CatalogProviderExportable | CatalogProvider | Catalog + ) -> None: + """Add a catalog to this catalog list. + + This method is optional. If your catalog provides a fixed list of catalogs, you + do not need to implement this method. + """ + + +class CatalogProviderListExportable(Protocol): + """Type hint for object that has __datafusion_catalog_provider_list__ PyCapsule. + + https://docs.rs/datafusion/latest/datafusion/catalog/trait.CatalogProviderList.html + """ + + def __datafusion_catalog_provider_list__(self, session: Any) -> object: ... + + class CatalogProvider(ABC): """Abstract class for defining a Python based Catalog Provider.""" @@ -229,6 +309,15 @@ def deregister_schema(self, name: str, cascade: bool) -> None: # noqa: B027 """ +class CatalogProviderExportable(Protocol): + """Type hint for object that has __datafusion_catalog_provider__ PyCapsule. + + https://docs.rs/datafusion/latest/datafusion/catalog/trait.CatalogProvider.html + """ + + def __datafusion_catalog_provider__(self, session: Any) -> object: ... + + class SchemaProvider(ABC): """Abstract class for defining a Python based Schema Provider.""" diff --git a/python/datafusion/context.py b/python/datafusion/context.py index be647feff..19722e100 100644 --- a/python/datafusion/context.py +++ b/python/datafusion/context.py @@ -31,7 +31,13 @@ import pyarrow as pa -from datafusion.catalog import Catalog +from datafusion.catalog import ( + Catalog, + CatalogList, + CatalogProviderExportable, + CatalogProviderList, + CatalogProviderListExportable, +) from datafusion.dataframe import DataFrame from datafusion.expr import sort_list_to_raw_sort_list from datafusion.record_batch import RecordBatchStream @@ -91,15 +97,6 @@ class TableProviderExportable(Protocol): def __datafusion_table_provider__(self, session: Any) -> object: ... # noqa: D105 -class CatalogProviderExportable(Protocol): - """Type hint for object that has __datafusion_catalog_provider__ PyCapsule. - - https://docs.rs/datafusion/latest/datafusion/catalog/trait.CatalogProvider.html - """ - - def __datafusion_catalog_provider__(self, session: Any) -> object: ... # noqa: D105 - - class SessionConfig: """Session configuration options.""" @@ -832,6 +829,16 @@ def catalog_names(self) -> set[str]: """Returns the list of catalogs in this context.""" return self.ctx.catalog_names() + def register_catalog_provider_list( + self, + provider: CatalogProviderListExportable | CatalogProviderList | CatalogList, + ) -> None: + """Register a catalog provider list.""" + if isinstance(provider, CatalogList): + self.ctx.register_catalog_provider_list(provider.catalog) + else: + self.ctx.register_catalog_provider_list(provider) + def register_catalog_provider( self, name: str, provider: CatalogProviderExportable | CatalogProvider | Catalog ) -> None: diff --git a/python/tests/test_catalog.py b/python/tests/test_catalog.py index 08f494dee..a0c9bb920 100644 --- a/python/tests/test_catalog.py +++ b/python/tests/test_catalog.py @@ -16,11 +16,16 @@ # under the License. from __future__ import annotations +from typing import TYPE_CHECKING + import datafusion as dfn import pyarrow as pa import pyarrow.dataset as ds import pytest -from datafusion import SessionContext, Table, udtf +from datafusion import Catalog, SessionContext, Table, udtf + +if TYPE_CHECKING: + from datafusion.catalog import CatalogProvider, CatalogProviderExportable # Note we take in `database` as a variable even though we don't use @@ -93,6 +98,34 @@ def deregister_schema(self, name, cascade: bool): del self.schemas[name] +class CustomCatalogProviderList(dfn.catalog.CatalogProviderList): + def __init__(self): + self.catalogs = {"my_catalog": CustomCatalogProvider()} + + def catalog_names(self) -> set[str]: + return set(self.catalogs.keys()) + + def catalog(self, name: str) -> Catalog | None: + return self.catalogs[name] + + def register_catalog( + self, name: str, catalog: CatalogProviderExportable | CatalogProvider | Catalog + ) -> None: + self.catalogs[name] = catalog + + +def test_python_catalog_provider_list(ctx: SessionContext): + ctx.register_catalog_provider_list(CustomCatalogProviderList()) + + # Ensure `datafusion` catalog does not exist since + # we replaced the catalog list + assert ctx.catalog_names() == {"my_catalog"} + + # Ensure registering works + ctx.register_catalog_provider("second_catalog", CustomCatalogProvider()) + assert ctx.catalog_names() == {"my_catalog", "second_catalog"} + + def test_python_catalog_provider(ctx: SessionContext): ctx.register_catalog_provider("my_catalog", CustomCatalogProvider()) diff --git a/src/catalog.rs b/src/catalog.rs index 10ca1dd12..b5b983970 100644 --- a/src/catalog.rs +++ b/src/catalog.rs @@ -21,10 +21,12 @@ use std::sync::Arc; use async_trait::async_trait; use datafusion::catalog::{ - CatalogProvider, MemoryCatalogProvider, MemorySchemaProvider, SchemaProvider, + CatalogProvider, CatalogProviderList, MemoryCatalogProvider, MemoryCatalogProviderList, + MemorySchemaProvider, SchemaProvider, }; use datafusion::common::DataFusionError; use datafusion::datasource::TableProvider; +use datafusion_ffi::catalog_provider::FFI_CatalogProvider; use datafusion_ffi::proto::logical_extension_codec::FFI_LogicalExtensionCodec; use datafusion_ffi::schema_provider::FFI_SchemaProvider; use pyo3::exceptions::PyKeyError; @@ -40,6 +42,18 @@ use crate::utils::{ wait_for_future, }; +#[pyclass( + frozen, + name = "RawCatalogList", + module = "datafusion.catalog", + subclass +)] +#[derive(Clone)] +pub struct PyCatalogList { + pub catalog_list: Arc, + codec: Arc, +} + #[pyclass(frozen, name = "RawCatalog", module = "datafusion.catalog", subclass)] #[derive(Clone)] pub struct PyCatalog { @@ -72,6 +86,77 @@ impl PySchema { } } +#[pymethods] +impl PyCatalogList { + #[new] + pub fn new( + py: Python, + catalog_list: Py, + session: Option>, + ) -> PyResult { + let codec = extract_logical_extension_codec(py, session)?; + let catalog_list = Arc::new(RustWrappedPyCatalogProviderList::new( + catalog_list, + codec.clone(), + )) as Arc; + Ok(Self { + catalog_list, + codec, + }) + } + + #[staticmethod] + pub fn memory_catalog_list(py: Python, session: Option>) -> PyResult { + let codec = extract_logical_extension_codec(py, session)?; + let catalog_list = + Arc::new(MemoryCatalogProviderList::default()) as Arc; + Ok(Self { + catalog_list, + codec, + }) + } + + pub fn catalog_names(&self) -> HashSet { + self.catalog_list.catalog_names().into_iter().collect() + } + + #[pyo3(signature = (name="public"))] + pub fn catalog(&self, name: &str) -> PyResult> { + let catalog = self + .catalog_list + .catalog(name) + .ok_or(PyKeyError::new_err(format!( + "Schema with name {name} doesn't exist." + )))?; + + Python::attach(|py| { + match catalog + .as_any() + .downcast_ref::() + { + Some(wrapped_catalog) => Ok(wrapped_catalog.catalog_provider.clone_ref(py)), + None => PyCatalog::new_from_parts(catalog, self.codec.clone()).into_py_any(py), + } + }) + } + + pub fn register_catalog(&self, name: &str, catalog_provider: Bound<'_, PyAny>) -> PyResult<()> { + let provider = extract_catalog_provider_from_pyobj(catalog_provider, self.codec.as_ref())?; + + let _ = self + .catalog_list + .register_catalog(name.to_owned(), provider); + + Ok(()) + } + + pub fn __repr__(&self) -> PyResult { + let mut names: Vec = self.catalog_names().into_iter().collect(); + names.sort(); + Ok(format!("CatalogList(catalog_names=[{}])", names.join(", "))) + } +} + #[pymethods] impl PyCatalog { #[new] @@ -373,8 +458,9 @@ impl CatalogProvider for RustWrappedPyCatalogProvider { Python::attach(|py| { let provider = self.catalog_provider.bind(py); provider - .getattr("schema_names") - .and_then(|names| names.extract::>()) + .call_method0("schema_names") + .and_then(|names| names.extract::>()) + .map(|names| names.into_iter().collect()) .unwrap_or_else(|err| { log::error!("Unable to get schema_names: {err}"); Vec::default() @@ -442,6 +528,138 @@ impl CatalogProvider for RustWrappedPyCatalogProvider { } } +#[derive(Debug)] +pub(crate) struct RustWrappedPyCatalogProviderList { + pub(crate) catalog_provider_list: Py, + codec: Arc, +} + +impl RustWrappedPyCatalogProviderList { + pub fn new(catalog_provider_list: Py, codec: Arc) -> Self { + Self { + catalog_provider_list, + codec, + } + } + + fn catalog_inner(&self, name: &str) -> PyResult>> { + Python::attach(|py| { + let provider = self.catalog_provider_list.bind(py); + + let py_schema = provider.call_method1("catalog", (name,))?; + if py_schema.is_none() { + return Ok(None); + } + + extract_catalog_provider_from_pyobj(py_schema, self.codec.as_ref()).map(Some) + }) + } +} + +#[async_trait] +impl CatalogProviderList for RustWrappedPyCatalogProviderList { + fn as_any(&self) -> &dyn Any { + self + } + + fn catalog_names(&self) -> Vec { + Python::attach(|py| { + let provider = self.catalog_provider_list.bind(py); + provider + .call_method0("catalog_names") + .and_then(|names| names.extract::>()) + .map(|names| names.into_iter().collect()) + .unwrap_or_else(|err| { + log::error!("Unable to get catalog_names: {err}"); + Vec::default() + }) + }) + } + + fn catalog(&self, name: &str) -> Option> { + self.catalog_inner(name).unwrap_or_else(|err| { + log::error!("CatalogProvider catalog returned error: {err}"); + None + }) + } + + fn register_catalog( + &self, + name: String, + catalog: Arc, + ) -> Option> { + Python::attach(|py| { + let py_catalog = match catalog + .as_any() + .downcast_ref::() + { + Some(wrapped_schema) => wrapped_schema.catalog_provider.as_any().clone_ref(py), + None => { + match PyCatalog::new_from_parts(catalog, self.codec.clone()).into_py_any(py) { + Ok(c) => c, + Err(err) => { + log::error!( + "register_catalog returned error during conversion to PyAny: {err}" + ); + return None; + } + } + } + }; + + let provider = self.catalog_provider_list.bind(py); + let catalog = match provider.call_method1("register_catalog", (name, py_catalog)) { + Ok(c) => c, + Err(err) => { + log::error!("register_catalog returned error: {err}"); + return None; + } + }; + if catalog.is_none() { + return None; + } + + let catalog = Arc::new(RustWrappedPyCatalogProvider::new( + catalog.into(), + self.codec.clone(), + )) as Arc; + + Some(catalog) + }) + } +} + +fn extract_catalog_provider_from_pyobj( + mut catalog_provider: Bound, + codec: &FFI_LogicalExtensionCodec, +) -> PyResult> { + if catalog_provider.hasattr("__datafusion_catalog_provider__")? { + let py = catalog_provider.py(); + let codec_capsule = create_logical_extension_capsule(py, codec)?; + catalog_provider = catalog_provider + .getattr("__datafusion_catalog_provider__")? + .call1((codec_capsule,))?; + } + + let provider = if let Ok(capsule) = catalog_provider.downcast::() { + validate_pycapsule(capsule, "datafusion_catalog_provider")?; + + let provider = unsafe { capsule.reference::() }; + let provider: Arc = provider.into(); + provider as Arc + } else { + match catalog_provider.extract::() { + Ok(py_catalog) => py_catalog.catalog, + Err(_) => Arc::new(RustWrappedPyCatalogProvider::new( + catalog_provider.into(), + Arc::new(codec.clone()), + )) as Arc, + } + }; + + Ok(provider) +} + fn extract_schema_provider_from_pyobj( mut schema_provider: Bound, codec: &FFI_LogicalExtensionCodec, diff --git a/src/context.rs b/src/context.rs index 1cd04ac2f..145f89ab7 100644 --- a/src/context.rs +++ b/src/context.rs @@ -26,7 +26,7 @@ use arrow::pyarrow::FromPyArrow; use datafusion::arrow::datatypes::{DataType, Schema, SchemaRef}; use datafusion::arrow::pyarrow::PyArrowType; use datafusion::arrow::record_batch::RecordBatch; -use datafusion::catalog::CatalogProvider; +use datafusion::catalog::{CatalogProvider, CatalogProviderList}; use datafusion::common::{exec_err, ScalarValue, TableReference}; use datafusion::datasource::file_format::file_compression_type::FileCompressionType; use datafusion::datasource::file_format::parquet::ParquetFormat; @@ -47,6 +47,7 @@ use datafusion::prelude::{ AvroReadOptions, CsvReadOptions, DataFrame, NdJsonReadOptions, ParquetReadOptions, }; use datafusion_ffi::catalog_provider::FFI_CatalogProvider; +use datafusion_ffi::catalog_provider_list::FFI_CatalogProviderList; use datafusion_ffi::execution::FFI_TaskContextProvider; use datafusion_ffi::proto::logical_extension_codec::FFI_LogicalExtensionCodec; use datafusion_proto::logical_plan::DefaultLogicalExtensionCodec; @@ -58,7 +59,9 @@ use pyo3::IntoPyObjectExt; use url::Url; use uuid::Uuid; -use crate::catalog::{PyCatalog, RustWrappedPyCatalogProvider}; +use crate::catalog::{ + PyCatalog, PyCatalogList, RustWrappedPyCatalogProvider, RustWrappedPyCatalogProviderList, +}; use crate::common::data_type::PyScalarValue; use crate::dataframe::PyDataFrame; use crate::dataset::Dataset; @@ -627,6 +630,40 @@ impl PySessionContext { Ok(()) } + pub fn register_catalog_provider_list( + &self, + mut provider: Bound, + ) -> PyDataFusionResult<()> { + if provider.hasattr("__datafusion_catalog_provider_list__")? { + let py = provider.py(); + let codec_capsule = create_logical_extension_capsule(py, self.logical_codec.as_ref())?; + provider = provider + .getattr("__datafusion_catalog_provider_list__")? + .call1((codec_capsule,))?; + } + + let provider = + if let Ok(capsule) = provider.downcast::().map_err(py_datafusion_err) { + validate_pycapsule(capsule, "datafusion_catalog_provider_list")?; + + let provider = unsafe { capsule.reference::() }; + let provider: Arc = provider.into(); + provider as Arc + } else { + match provider.extract::() { + Ok(py_catalog_list) => py_catalog_list.catalog_list, + Err(_) => Arc::new(RustWrappedPyCatalogProviderList::new( + provider.into(), + Arc::clone(&self.logical_codec), + )) as Arc, + } + }; + + self.ctx.register_catalog_list(provider); + + Ok(()) + } + pub fn register_catalog_provider( &self, name: &str,