Draft
Changes from all commits (23 commits)
cf13aad
[python-package]add_features_from with PyArrow Table incorrectly free…
suk1yak1 Apr 23, 2025
99ee66f
[python-package] add PyArrow Table case to test_add_features_from_dif…
suk1yak1 Apr 23, 2025
a3256a3
[python-package] fix handling and tests for PyArrow Table input in ad…
suk1yak1 Apr 24, 2025
9de2650
Merge branch 'master' into fix/6891-pyarrow-table-add-features
suk1yak1 Apr 24, 2025
01bc668
delete unnecessary-else
suk1yak1 Apr 28, 2025
27e2545
Merge branch 'master' of https://github.com/microsoft/LightGBM into f…
suk1yak1 Apr 28, 2025
219e61c
[python-package]add test for pyarrow table in Dataset.get_data()
suk1yak1 Apr 28, 2025
d4de309
[python-package]add test for add_features_from with pyarrow tables
suk1yak1 Apr 28, 2025
6686a60
Merge branch 'master' into fix/6891-pyarrow-table-add-features
suk1yak1 May 7, 2025
22ec314
[python-package] add PyArrow Table to get_data
suk1yak1 May 9, 2025
c7594c8
[python-package] add test for subset of PyArrow table dataset
suk1yak1 May 10, 2025
e0fce82
Merge branch 'master' into fix/6891-pyarrow-table-get-data
StrikerRUS May 20, 2025
1af15c6
Merge branch 'master' into fix/6891-pyarrow-table-get-data
suk1yak1 May 29, 2025
b5395a0
[python-package] improve PyArrow table subset tests for null values a…
suk1yak1 May 30, 2025
5f067ec
Merge branch 'master' into fix/6891-pyarrow-table-get-data
jameslamb Jul 27, 2025
eaf9510
Merge branch 'master' into fix/6891-pyarrow-table-get-data
suk1yak1 Aug 12, 2025
04db4aa
Merge branch 'master' into fix/6891-pyarrow-table-get-data
suk1yak1 Sep 1, 2025
e569ac5
[python-package]avoid TypeError: ChunkedArray.to_numpy() takes no key…
suk1yak1 Sep 1, 2025
5b4b197
Merge branch 'master' into fix/6891-pyarrow-table-get-data
suk1yak1 Sep 4, 2025
e826a4f
[python-package]Move final assertion before for loop to fail faster a…
suk1yak1 Sep 12, 2025
17dbc43
[python-package]Rename test helper and add docstring to clarify purpose
suk1yak1 Sep 12, 2025
8135e92
[python-package]Rename test helper
suk1yak1 Sep 12, 2025
71a65e5
Merge branch 'fix/6891-pyarrow-table-get-data' of github.com:suk1yak1…
suk1yak1 Sep 30, 2025
101 changes: 101 additions & 0 deletions python-package/lightgbm/basic.py
@@ -45,6 +45,9 @@
pd_Series,
)

if PYARROW_INSTALLED:
import pyarrow as pa

if TYPE_CHECKING:
from typing import Literal

@@ -3267,6 +3270,8 @@ def get_data(self) -> Optional[_LGBM_TrainDataType]:
self.data = self.data.iloc[self.used_indices].copy()
elif isinstance(self.data, Sequence):
self.data = self.data[self.used_indices]
elif isinstance(self.data, pa_Table):
self.data = self.data.take(self.used_indices)
Collaborator:
Thanks! This was definitely just something we'd missed.

Can you please add a test in https://github.com/microsoft/LightGBM/blob/master/tests/python_package_test/test_arrow.py just for this change to get_data()? The other changes you made in test_basic.py do not cover these changes.

When you do that, please check that the content of self.data AND the returned value are correct (e.g., contain exactly the expected values and data types).

If you'd like, I'd even support opening a new pull request that only has the get_data() changes + test (and then making this PR only about add_features_from()). Totally your choice, I want to be respectful of your time.

Contributor Author (suk1yak1):
Thanks for the suggestion!
I added tests for both get_data() and add_features_from() directly in test_arrow.py as part of this PR.
Please let me know if there’s anything else you’d like me to adjust!

Contributor Author (suk1yak1):
@jameslamb
I've opened a new pull request (#6911) that includes only the changes to get_data() along with the corresponding test. This should help keep things focused. I'd appreciate it if you could take a look when you have time.

Collaborator:
Thanks! I'll focus there.

elif _is_list_of_sequences(self.data) and len(self.data) > 0:
self.data = np.array(list(self._yield_row_from_seqlist(self.data, self.used_indices)))
else:
@@ -3451,6 +3456,21 @@ def add_features_from(self, other: "Dataset") -> "Dataset":
self.data = np.hstack((self.data, other.data.toarray()))
elif isinstance(other.data, pd_DataFrame):
self.data = np.hstack((self.data, other.data.values))
elif isinstance(other.data, pa_Table):
if not PYARROW_INSTALLED:
raise LightGBMError(
"Cannot add features to pyarrow.Table type of raw data "
"without pyarrow installed. "
"Install pyarrow and restart your session."
)
self.data = np.hstack(
(
self.data,
np.column_stack(
[other.data.column(i).to_numpy() for i in range(len(other.data.column_names))]
),
)
)
else:
self.data = None
elif isinstance(self.data, scipy.sparse.spmatrix):
@@ -3459,6 +3479,22 @@ def add_features_from(self, other: "Dataset") -> "Dataset":
self.data = scipy.sparse.hstack((self.data, other.data), format=sparse_format)
elif isinstance(other.data, pd_DataFrame):
self.data = scipy.sparse.hstack((self.data, other.data.values), format=sparse_format)
elif isinstance(other.data, pa_Table):
if not PYARROW_INSTALLED:
raise LightGBMError(
"Cannot add features to pyarrow.Table type of raw data "
"without pyarrow installed. "
"Install pyarrow and restart your session."
)
self.data = scipy.sparse.hstack(
(
self.data,
np.column_stack(
[other.data.column(i).to_numpy() for i in range(len(other.data.column_names))]
),
),
format=sparse_format,
)
else:
self.data = None
elif isinstance(self.data, pd_DataFrame):
@@ -3474,6 +3510,71 @@ def add_features_from(self, other: "Dataset") -> "Dataset":
self.data = concat((self.data, pd_DataFrame(other.data.toarray())), axis=1, ignore_index=True)
elif isinstance(other.data, pd_DataFrame):
self.data = concat((self.data, other.data), axis=1, ignore_index=True)
elif isinstance(other.data, pa_Table):
if not PYARROW_INSTALLED:
raise LightGBMError(
"Cannot add features to pyarrow.Table type of raw data "
"without pyarrow installed. "
"Install pyarrow and restart your session."
)
self.data = concat(
(
self.data,
pd_DataFrame(
{
other.data.column_names[i]: other.data.column(i).to_numpy()
for i in range(len(other.data.column_names))
}
),
),
axis=1,
ignore_index=True,
)
else:
self.data = None
elif isinstance(self.data, pa_Table):
if not PYARROW_INSTALLED:
raise LightGBMError(
"Cannot add features to pyarrow.Table type of raw data "
"without pyarrow installed. "
"Install pyarrow and restart your session."
)
if isinstance(other.data, np.ndarray):
self.data = pa_Table.from_arrays(
[
*self.data.columns,
*[pa.array(other.data[:, i]) for i in range(other.data.shape[1])],
],
names=[
*self.data.column_names,
*[f"D{len(self.data.column_names) + i + 1}" for i in range(other.data.shape[1])],
],
)
elif isinstance(other.data, scipy.sparse.spmatrix):
other_array = other.data.toarray()
self.data = pa_Table.from_arrays(
[
*self.data.columns,
*[pa.array(other_array[:, i]) for i in range(other_array.shape[1])],
],
names=[
*self.data.column_names,
*[f"D{len(self.data.column_names) + i + 1}" for i in range(other_array.shape[1])],
],
)
elif isinstance(other.data, pd_DataFrame):
self.data = pa_Table.from_arrays(
[
*self.data.columns,
*[pa.array(other.data.iloc[:, i].values) for i in range(len(other.data.columns))],
],
names=[*self.data.column_names, *map(str, other.data.columns.tolist())],
)
elif isinstance(other.data, pa_Table):
self.data = pa_Table.from_arrays(
[*self.data.columns, *other.data.columns],
names=[*self.data.column_names, *other.data.column_names],
)
else:
self.data = None
else:
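Taken together, these branches let pyarrow Tables behave like the other raw-data types in get_data() and add_features_from(). A minimal sketch of the two new behaviors (assuming pyarrow is installed; the column names and sizes here are illustrative):

import lightgbm as lgb
import numpy as np
import pyarrow as pa

n = 100
left = pa.Table.from_arrays([pa.array(np.arange(n, dtype=np.float64))], names=["a"])
right = pa.Table.from_arrays([pa.array(np.arange(n, dtype=np.float64))], names=["b"])

# add_features_from() requires both Datasets to be constructed with
# free_raw_data=False so the raw Tables are still available.
ds_left = lgb.Dataset(left, free_raw_data=False).construct()
ds_right = lgb.Dataset(right, free_raw_data=False).construct()

# Table + Table: columns are concatenated via pa_Table.from_arrays(...)
ds_left.add_features_from(ds_right)
combined = ds_left.get_data()
assert isinstance(combined, pa.Table)
assert combined.column_names == ["a", "b"]

# Subsets now slice the raw Table with Table.take(...)
idx = [0, 10, 20]
subset = lgb.Dataset(left, free_raw_data=False).construct().subset(idx).construct()
assert subset.get_data().num_rows == len(idx)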
99 changes: 98 additions & 1 deletion tests/python_package_test/test_arrow.py
@@ -118,7 +118,7 @@ def generate_random_arrow_array(
chunks = [chunk for chunk in chunks if len(chunk) > 0]

# Turn chunks into array
- return pa.chunked_array([data], type=pa.float32())
return pa.chunked_array(chunks, type=pa.float32())


def dummy_dataset_params() -> Dict[str, Any]:
@@ -456,6 +456,103 @@ def test_arrow_feature_name_manual():
assert booster.feature_name() == ["c", "d"]


def pyarrow_array_equal(arr1: pa.ChunkedArray, arr2: pa.ChunkedArray) -> bool:
"""Similar to ``np.array_equal()``, but for ``pyarrow.Array`` objects.

``pyarrow.Array`` objects with identical values do not compare equal if any of those
values are nulls. This function treats them as equal.
"""
if len(arr1) != len(arr2):
return False

np1 = arr1.to_numpy()
np2 = arr2.to_numpy()
return np.array_equal(np1, np2, equal_nan=True)
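
# An editorial aside, not part of this PR's diff: a quick illustration of why
# ``equal_nan=True`` is needed above. Nulls become NaN when an Arrow array is
# converted to NumPy, and NaN never compares equal to itself, so a plain
# ``np.array_equal()`` would report false mismatches. The helper name below is
# hypothetical.
def _null_handling_example():
    left = pa.chunked_array([[1.0, None, 3.0]], type=pa.float64())
    right = pa.chunked_array([[1.0, None, 3.0]], type=pa.float64())
    assert not np.array_equal(left.to_numpy(), right.to_numpy())  # NaN != NaN
    assert np.array_equal(left.to_numpy(), right.to_numpy(), equal_nan=True)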


def test_get_data_arrow_table():
original_table = generate_simple_arrow_table()
dataset = lgb.Dataset(original_table, free_raw_data=False)
dataset.construct()

returned_data = dataset.get_data()
assert isinstance(returned_data, pa.Table)
assert returned_data.schema == original_table.schema
assert returned_data.shape == original_table.shape

for column_name in original_table.column_names:
original_column = original_table[column_name]
returned_column = returned_data[column_name]

assert original_column.type == returned_column.type
assert original_column.num_chunks == returned_column.num_chunks
assert pyarrow_array_equal(original_column, returned_column)

for i in range(original_column.num_chunks):
original_chunk_array = pa.chunked_array([original_column.chunk(i)])
returned_chunk_array = pa.chunked_array([returned_column.chunk(i)])
assert pyarrow_array_equal(original_chunk_array, returned_chunk_array)


def test_get_data_arrow_table_subset(rng):
original_table = generate_random_arrow_table(num_columns=3, num_datapoints=1000, seed=42)
dataset = lgb.Dataset(original_table, free_raw_data=False)
dataset.construct()

subset_size = 100
used_indices = rng.choice(a=original_table.shape[0], size=subset_size, replace=False)
used_indices = sorted(used_indices)

subset_dataset = dataset.subset(used_indices).construct()
expected_subset = original_table.take(used_indices)
subset_data = subset_dataset.get_data()

assert isinstance(subset_data, pa.Table)
assert subset_data.schema == expected_subset.schema
assert subset_data.shape == expected_subset.shape
assert len(subset_data) == len(used_indices)
assert subset_data.shape == (subset_size, 3)

for column_name in expected_subset.column_names:
expected_col = expected_subset[column_name]
returned_col = subset_data[column_name]
assert expected_col.type == returned_col.type
assert pyarrow_array_equal(expected_col, returned_col)


def test_add_features_from_arrow_table():
table1 = pa.Table.from_arrays(
[pa.array([1, 2, 3, 4, 5], type=pa.int32()), pa.array([0.1, 0.2, 0.3, 0.4, 0.5], type=pa.float32())],
names=["feature1", "feature2"],
)

table2 = pa.Table.from_arrays(
[
pa.array([10, 20, 30, 40, 50], type=pa.int64()),
pa.array([1.1, 1.2, 1.3, 1.4, 1.5], type=pa.float64()),
pa.array([True, False, True, False, True], type=pa.bool_()),
],
names=["feature3", "feature4", "feature5"],
)

dataset1 = lgb.Dataset(table1, free_raw_data=False)
dataset2 = lgb.Dataset(table2, free_raw_data=False)

dataset1.construct()
dataset2.construct()

dataset1.add_features_from(dataset2)
combined_data = dataset1.get_data()
assert isinstance(combined_data, pa.Table)
assert combined_data.num_columns == table1.num_columns + table2.num_columns
assert set(combined_data.column_names) == set(table1.column_names + table2.column_names)
assert combined_data.num_rows == table1.num_rows
for column in table1.column_names:
assert combined_data[column].equals(table1[column])
for column in table2.column_names:
assert combined_data[column].equals(table2[column])


def test_dataset_construction_from_pa_table_without_cffi_raises_informative_error(missing_module_cffi):
with pytest.raises(
lgb.basic.LightGBMError, match="Cannot init Dataset from Arrow without 'pyarrow' and 'cffi' installed."
22 changes: 21 additions & 1 deletion tests/python_package_test/test_basic.py
@@ -17,6 +17,15 @@

from .utils import dummy_obj, load_breast_cancer, mse_obj, np_assert_array_equal

if getenv("ALLOW_SKIP_ARROW_TESTS") == "1":
pa = pytest.importorskip("pyarrow")
else:
import pyarrow as pa # type: ignore

assert lgb.compat.PYARROW_INSTALLED is True, (
"'pyarrow' and its dependencies must be installed to run the arrow tests"
)


def test_basic(tmp_path):
X_train, X_test, y_train, y_test = train_test_split(
@@ -348,7 +357,18 @@ def test_add_features_from_different_sources(rng):
n_row = 100
n_col = 5
X = rng.uniform(size=(n_row, n_col))
- xxs = [X, sparse.csr_matrix(X), pd.DataFrame(X)]
xxs = [
X,
sparse.csr_matrix(X),
pd.DataFrame(X),
]
if getenv("ALLOW_SKIP_ARROW_TESTS") != "1":
xxs.append(
pa.Table.from_arrays(
[pa.array(X[:, i]) for i in range(X.shape[1])], names=[f"D{i}" for i in range(X.shape[1])]
)
)

names = [f"col_{i}" for i in range(n_col)]
seq = _create_sequence_from_ndarray(X, 1, 30)
seq_ds = lgb.Dataset(seq, feature_name=names, free_raw_data=False).construct()