Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
import pytest
from datafusion import SessionContext, Table
from datafusion.catalog import Schema
from datafusion_ffi_example import MyCatalogProvider
from datafusion_ffi_example import MyCatalogProvider, MyCatalogProviderList


def create_test_dataset() -> Table:
Expand All @@ -35,6 +35,30 @@ def create_test_dataset() -> Table:
return Table(dataset)


@pytest.mark.parametrize("inner_capsule", [True, False])
def test_ffi_catalog_provider_list(inner_capsule: bool) -> None:
"""Test basic FFI CatalogProviderList functionality."""
ctx = SessionContext()

# Register FFI catalog
catalog_provider_list = MyCatalogProviderList()
if inner_capsule:
catalog_provider_list = (
catalog_provider_list.__datafusion_catalog_provider_list__(ctx)
)

ctx.register_catalog_provider_list(catalog_provider_list)

# Verify the catalog exists
catalog = ctx.catalog("auto_ffi_catalog")
schema_names = catalog.names()
assert "my_schema" in schema_names

ctx.register_catalog_provider("second", MyCatalogProvider())

assert ctx.catalog_names() == {"auto_ffi_catalog", "second"}


@pytest.mark.parametrize("inner_capsule", [True, False])
def test_ffi_catalog_provider_basic(inner_capsule: bool) -> None:
"""Test basic FFI CatalogProvider functionality."""
Expand Down
69 changes: 67 additions & 2 deletions examples/datafusion-ffi-example/src/catalog_provider.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,12 @@ use std::sync::Arc;
use arrow::datatypes::Schema;
use async_trait::async_trait;
use datafusion_catalog::{
CatalogProvider, MemTable, MemoryCatalogProvider, MemorySchemaProvider, SchemaProvider,
TableProvider,
CatalogProvider, CatalogProviderList, MemTable, MemoryCatalogProvider,
MemoryCatalogProviderList, MemorySchemaProvider, SchemaProvider, TableProvider,
};
use datafusion_common::error::{DataFusionError, Result};
use datafusion_ffi::catalog_provider::FFI_CatalogProvider;
use datafusion_ffi::catalog_provider_list::FFI_CatalogProviderList;
use datafusion_ffi::schema_provider::FFI_SchemaProvider;
use pyo3::types::PyCapsule;
use pyo3::{pyclass, pymethods, Bound, PyAny, PyResult, Python};
Expand Down Expand Up @@ -203,3 +204,67 @@ impl MyCatalogProvider {
PyCapsule::new(py, provider, Some(name))
}
}

/// This catalog provider list is intended only for unit tests.
/// It pre-populates with a single catalog.
#[pyclass(
name = "MyCatalogProviderList",
module = "datafusion_ffi_example",
subclass
)]
#[derive(Debug, Clone)]
pub(crate) struct MyCatalogProviderList {
inner: Arc<MemoryCatalogProviderList>,
}

impl CatalogProviderList for MyCatalogProviderList {
fn as_any(&self) -> &dyn Any {
self
}

fn catalog_names(&self) -> Vec<String> {
self.inner.catalog_names()
}

fn catalog(&self, name: &str) -> Option<Arc<dyn CatalogProvider>> {
self.inner.catalog(name)
}

fn register_catalog(
&self,
name: String,
catalog: Arc<dyn CatalogProvider>,
) -> Option<Arc<dyn CatalogProvider>> {
self.inner.register_catalog(name, catalog)
}
}

#[pymethods]
impl MyCatalogProviderList {
#[new]
pub fn new() -> PyResult<Self> {
let inner = Arc::new(MemoryCatalogProviderList::new());

inner.register_catalog(
"auto_ffi_catalog".to_owned(),
Arc::new(MyCatalogProvider::new()?),
);

Ok(Self { inner })
}

pub fn __datafusion_catalog_provider_list__<'py>(
&self,
py: Python<'py>,
session: Bound<PyAny>,
) -> PyResult<Bound<'py, PyCapsule>> {
let name = cr"datafusion_catalog_provider_list".into();

let provider = Arc::clone(&self.inner) as Arc<dyn CatalogProviderList + Send>;

let codec = ffi_logical_codec_from_pycapsule(session)?;
let provider = FFI_CatalogProviderList::new_with_ffi_codec(provider, None, codec);

PyCapsule::new(py, provider, Some(name))
}
}
3 changes: 2 additions & 1 deletion examples/datafusion-ffi-example/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
use pyo3::prelude::*;

use crate::aggregate_udf::MySumUDF;
use crate::catalog_provider::{FixedSchemaProvider, MyCatalogProvider};
use crate::catalog_provider::{FixedSchemaProvider, MyCatalogProvider, MyCatalogProviderList};
use crate::scalar_udf::IsNullUDF;
use crate::table_function::MyTableFunction;
use crate::table_provider::MyTableProvider;
Expand All @@ -37,6 +37,7 @@ fn datafusion_ffi_example(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_class::<MyTableProvider>()?;
m.add_class::<MyTableFunction>()?;
m.add_class::<MyCatalogProvider>()?;
m.add_class::<MyCatalogProviderList>()?;
m.add_class::<FixedSchemaProvider>()?;
m.add_class::<IsNullUDF>()?;
m.add_class::<MySumUDF>()?;
Expand Down
89 changes: 89 additions & 0 deletions python/datafusion/catalog.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,13 +38,61 @@

__all__ = [
"Catalog",
"CatalogList",
"CatalogProvider",
"CatalogProviderList",
"Schema",
"SchemaProvider",
"Table",
]


class CatalogList:
"""DataFusion data catalog list."""

def __init__(self, catalog_list: df_internal.catalog.RawCatalogList) -> None:
"""This constructor is not typically called by the end user."""
self.catalog_list = catalog_list

def __repr__(self) -> str:
"""Print a string representation of the catalog list."""
return self.catalog_list.__repr__()

def names(self) -> set[str]:
"""This is an alias for `catalog_names`."""
return self.catalog_names()

def catalog_names(self) -> set[str]:
"""Returns the list of schemas in this catalog."""
return self.catalog_list.catalog_names()

@staticmethod
def memory_catalog(ctx: SessionContext | None = None) -> CatalogList:
"""Create an in-memory catalog provider list."""
catalog_list = df_internal.catalog.RawCatalogList.memory_catalog(ctx)
return CatalogList(catalog_list)

def catalog(self, name: str = "datafusion") -> Schema:
"""Returns the catalog with the given ``name`` from this catalog."""
catalog = self.catalog_list.catalog(name)

return (
Catalog(catalog)
if isinstance(catalog, df_internal.catalog.RawCatalog)
else catalog
)

def register_catalog(
self,
name: str,
catalog: Catalog | CatalogProvider | CatalogProviderExportable,
) -> Catalog | None:
"""Register a catalog with this catalog list."""
if isinstance(catalog, Catalog):
return self.catalog_list.register_catalog(name, catalog.catalog)
return self.catalog_list.register_catalog(name, catalog)


class Catalog:
"""DataFusion data catalog."""

Expand Down Expand Up @@ -195,6 +243,38 @@ def kind(self) -> str:
return self._inner.kind


class CatalogProviderList(ABC):
"""Abstract class for defining a Python based Catalog Provider List."""

@abstractmethod
def catalog_names(self) -> set[str]:
"""Set of the names of all catalogs in this catalog list."""
...

@abstractmethod
def catalog(self, name: str) -> Catalog | None:
"""Retrieve a specific catalog from this catalog list."""
...

def register_catalog( # noqa: B027
self, name: str, catalog: CatalogProviderExportable | CatalogProvider | Catalog
) -> None:
"""Add a catalog to this catalog list.

This method is optional. If your catalog provides a fixed list of catalogs, you
do not need to implement this method.
"""


class CatalogProviderListExportable(Protocol):
"""Type hint for object that has __datafusion_catalog_provider_list__ PyCapsule.

https://docs.rs/datafusion/latest/datafusion/catalog/trait.CatalogProviderList.html
"""

def __datafusion_catalog_provider_list__(self, session: Any) -> object: ...


class CatalogProvider(ABC):
"""Abstract class for defining a Python based Catalog Provider."""

Expand Down Expand Up @@ -229,6 +309,15 @@ def deregister_schema(self, name: str, cascade: bool) -> None: # noqa: B027
"""


class CatalogProviderExportable(Protocol):
"""Type hint for object that has __datafusion_catalog_provider__ PyCapsule.

https://docs.rs/datafusion/latest/datafusion/catalog/trait.CatalogProvider.html
"""

def __datafusion_catalog_provider__(self, session: Any) -> object: ...


class SchemaProvider(ABC):
"""Abstract class for defining a Python based Schema Provider."""

Expand Down
27 changes: 17 additions & 10 deletions python/datafusion/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,13 @@

import pyarrow as pa

from datafusion.catalog import Catalog
from datafusion.catalog import (
Catalog,
CatalogList,
CatalogProviderExportable,
CatalogProviderList,
CatalogProviderListExportable,
)
from datafusion.dataframe import DataFrame
from datafusion.expr import sort_list_to_raw_sort_list
from datafusion.record_batch import RecordBatchStream
Expand Down Expand Up @@ -91,15 +97,6 @@ class TableProviderExportable(Protocol):
def __datafusion_table_provider__(self, session: Any) -> object: ... # noqa: D105


class CatalogProviderExportable(Protocol):
"""Type hint for object that has __datafusion_catalog_provider__ PyCapsule.

https://docs.rs/datafusion/latest/datafusion/catalog/trait.CatalogProvider.html
"""

def __datafusion_catalog_provider__(self, session: Any) -> object: ... # noqa: D105


class SessionConfig:
"""Session configuration options."""

Expand Down Expand Up @@ -832,6 +829,16 @@ def catalog_names(self) -> set[str]:
"""Returns the list of catalogs in this context."""
return self.ctx.catalog_names()

def register_catalog_provider_list(
self,
provider: CatalogProviderListExportable | CatalogProviderList | CatalogList,
) -> None:
"""Register a catalog provider list."""
if isinstance(provider, CatalogList):
self.ctx.register_catalog_provider_list(provider.catalog)
else:
self.ctx.register_catalog_provider_list(provider)

def register_catalog_provider(
self, name: str, provider: CatalogProviderExportable | CatalogProvider | Catalog
) -> None:
Expand Down
35 changes: 34 additions & 1 deletion python/tests/test_catalog.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,16 @@
# under the License.
from __future__ import annotations

from typing import TYPE_CHECKING

import datafusion as dfn
import pyarrow as pa
import pyarrow.dataset as ds
import pytest
from datafusion import SessionContext, Table, udtf
from datafusion import Catalog, SessionContext, Table, udtf

if TYPE_CHECKING:
from datafusion.catalog import CatalogProvider, CatalogProviderExportable


# Note we take in `database` as a variable even though we don't use
Expand Down Expand Up @@ -93,6 +98,34 @@ def deregister_schema(self, name, cascade: bool):
del self.schemas[name]


class CustomCatalogProviderList(dfn.catalog.CatalogProviderList):
def __init__(self):
self.catalogs = {"my_catalog": CustomCatalogProvider()}

def catalog_names(self) -> set[str]:
return set(self.catalogs.keys())

def catalog(self, name: str) -> Catalog | None:
return self.catalogs[name]

def register_catalog(
self, name: str, catalog: CatalogProviderExportable | CatalogProvider | Catalog
) -> None:
self.catalogs[name] = catalog


def test_python_catalog_provider_list(ctx: SessionContext):
ctx.register_catalog_provider_list(CustomCatalogProviderList())

# Ensure `datafusion` catalog does not exist since
# we replaced the catalog list
assert ctx.catalog_names() == {"my_catalog"}

# Ensure registering works
ctx.register_catalog_provider("second_catalog", CustomCatalogProvider())
assert ctx.catalog_names() == {"my_catalog", "second_catalog"}


def test_python_catalog_provider(ctx: SessionContext):
ctx.register_catalog_provider("my_catalog", CustomCatalogProvider())

Expand Down
Loading