Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 20 additions & 8 deletions examples/Advanced/tasks_tutorial.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,15 @@
#
# We will start by simply listing only *supervised classification* tasks.
#
# **openml.tasks.list_tasks()** returns a dictionary of dictionaries by default, but we
# request a
# **openml.list_all("task")** (or **openml.tasks.list_tasks()**) returns a dictionary of
# dictionaries by default, but we request a
# [pandas dataframe](https://pandas.pydata.org/pandas-docs/stable/generated/pandas.DataFrame.html)
# instead to have better visualization capabilities and easier access:

# %%
tasks = openml.tasks.list_tasks(task_type=TaskType.SUPERVISED_CLASSIFICATION)
tasks = openml.list_all("task", task_type=TaskType.SUPERVISED_CLASSIFICATION)
# Legacy path still works:
# tasks = openml.tasks.list_tasks(task_type=TaskType.SUPERVISED_CLASSIFICATION)
print(tasks.columns)
print(f"First 5 of {len(tasks)} tasks:")
print(tasks.head())
Expand Down Expand Up @@ -66,23 +68,29 @@
# Similar to listing tasks by task type, we can list tasks by tags:

# %%
tasks = openml.tasks.list_tasks(tag="OpenML100")
tasks = openml.list_all("task", tag="OpenML100")
# Legacy path still works:
# tasks = openml.tasks.list_tasks(tag="OpenML100")
print(f"First 5 of {len(tasks)} tasks:")
print(tasks.head())

# %% [markdown]
# Furthermore, we can list tasks based on the dataset id:

# %%
tasks = openml.tasks.list_tasks(data_id=1471)
tasks = openml.list_all("task", data_id=1471)
# Legacy path still works:
# tasks = openml.tasks.list_tasks(data_id=1471)
print(f"First 5 of {len(tasks)} tasks:")
print(tasks.head())

# %% [markdown]
# In addition, a size limit and an offset can be applied both separately and simultaneously:

# %%
tasks = openml.tasks.list_tasks(size=10, offset=50)
tasks = openml.list_all("task", size=10, offset=50)
# Legacy path still works:
# tasks = openml.tasks.list_tasks(size=10, offset=50)
print(tasks)

# %% [markdown]
Expand All @@ -98,7 +106,9 @@
# Finally, it is also possible to list all tasks on OpenML with:

# %%
tasks = openml.tasks.list_tasks()
tasks = openml.list_all("task")
# Legacy path still works:
# tasks = openml.tasks.list_tasks()
print(len(tasks))

# %% [markdown]
Expand All @@ -118,7 +128,9 @@

# %%
task_id = 31
task = openml.tasks.get_task(task_id)
task = openml.get(task_id, object_type="task")
# Legacy path still works:
# task = openml.tasks.get_task(task_id)

# %%
# Properties of the task are stored as member variables:
Expand Down
12 changes: 10 additions & 2 deletions examples/Basics/simple_datasets_tutorial.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,23 @@
# ## List datasets stored on OpenML

# %%
datasets_df = openml.datasets.list_datasets()
datasets_df = openml.list_all("dataset")
print(datasets_df.head(n=10))

# Legacy path still works:
# datasets_df = openml.datasets.list_datasets()

# %% [markdown]
# ## Download a dataset

# %%
# Iris dataset https://www.openml.org/d/61
dataset = openml.datasets.get_dataset(dataset_id=61)
dataset = openml.get(61)
# You can also fetch by name:
# dataset = openml.get("Fashion-MNIST")

# Legacy path still works:
# dataset = openml.datasets.get_dataset(dataset_id=61)

# Print a summary
print(
Expand Down
15 changes: 14 additions & 1 deletion examples/Basics/simple_flows_and_runs_tutorial.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,25 @@
# %%
openml.config.start_using_configuration_for_example()

# %% [markdown]
# ## Quick: list flows and runs via unified entrypoints

# %%
flows_df = openml.list_all("flow", size=3)
print(flows_df.head())

runs_df = openml.list_all("run", size=3)
print(runs_df.head())

# %% [markdown]
# ## Train a machine learning model and evaluate it
# NOTE: We are using task 119 from the test server: https://test.openml.org/d/20

# %%
task = openml.tasks.get_task(119)
task = openml.get(119, object_type="task")

# Legacy path still works:
# task = openml.tasks.get_task(119)

# Get the data
dataset = task.get_dataset()
Expand Down
5 changes: 4 additions & 1 deletion examples/Basics/simple_tasks_tutorial.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,10 @@
# [supervised classification on credit-g](https://www.openml.org/search?type=task&id=31&source_data.data_id=31):

# %%
task = openml.tasks.get_task(31)
task = openml.get(31, object_type="task")

# Legacy path still works:
# task = openml.tasks.get_task(31)

# %% [markdown]
# Get the dataset and its data from the task.
Expand Down
5 changes: 5 additions & 0 deletions openml/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
_api_calls,
config,
datasets,
dispatchers,
evaluations,
exceptions,
extensions,
Expand All @@ -34,6 +35,7 @@
)
from .__version__ import __version__
from .datasets import OpenMLDataFeature, OpenMLDataset
from .dispatchers import get, list_all
from .evaluations import OpenMLEvaluation
from .flows import OpenMLFlow
from .runs import OpenMLRun
Expand Down Expand Up @@ -108,6 +110,7 @@ def populate_cache(
"OpenMLStudy",
"OpenMLBenchmarkSuite",
"datasets",
"dispatchers",
"evaluations",
"exceptions",
"extensions",
Expand All @@ -120,4 +123,6 @@ def populate_cache(
"utils",
"_api_calls",
"__version__",
"get",
"list_all",
]
107 changes: 107 additions & 0 deletions openml/dispatchers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
"""OpenML API dispatchers for unified get/list operations."""

# License: BSD 3-Clause
from __future__ import annotations

from typing import Any, Callable, Dict

from .datasets import get_dataset, list_datasets
from .flows import get_flow, list_flows
from .runs import get_run, list_runs
from .tasks import get_task, list_tasks

# Type aliases for the string -> callable dispatch tables defined below.
ListDispatcher = Dict[str, Callable[..., Any]]
GetDispatcher = Dict[str, Callable[..., Any]]

# Maps a lowercase object-type name to the matching ``list_*`` function.
_LIST_DISPATCH: ListDispatcher = dict(
    dataset=list_datasets,
    task=list_tasks,
    flow=list_flows,
    run=list_runs,
)

# Maps a lowercase object-type name to the matching ``get_*`` function.
_GET_DISPATCH: GetDispatcher = dict(
    dataset=get_dataset,
    task=get_task,
    flow=get_flow,
    run=get_run,
)


def list_all(object_type: str, /, **kwargs: Any) -> Any:
    """List OpenML objects of a given type (datasets, tasks, flows, runs).

    Thin dispatcher that forwards to the existing type-specific ``list_*``
    function. The per-type functions remain importable, so existing code
    keeps working unchanged.

    Parameters
    ----------
    object_type : str
        One of 'dataset', 'task', 'flow', 'run' (case-insensitive).
    **kwargs : Any
        Forwarded unchanged to the underlying list function.

    Returns
    -------
    Any
        Whatever the type-specific list function returns (typically a DataFrame).

    Raises
    ------
    TypeError
        If ``object_type`` is not a string.
    ValueError
        If ``object_type`` names an unsupported type.
    """
    if not isinstance(object_type, str):
        raise TypeError(f"object_type must be a string, got {type(object_type).__name__}")

    # Normalize once so lookups are case-insensitive.
    normalized = object_type.lower()
    if normalized not in _LIST_DISPATCH:
        valid_types = ", ".join(repr(name) for name in _LIST_DISPATCH)
        raise ValueError(
            f"Unsupported object_type {object_type!r}; expected one of {valid_types}.",
        )

    return _LIST_DISPATCH[normalized](**kwargs)


def get(identifier: int | str, *, object_type: str = "dataset", **kwargs: Any) -> Any:
    """Fetch a single OpenML object by identifier.

    Dispatches to the type-specific ``get_*`` function. Datasets are the
    default object type, so ``get(61)`` is equivalent to ``get_dataset(61)``.

    Parameters
    ----------
    identifier : int | str
        The ID of the object, or (for datasets) its name.
    object_type : str, default="dataset"
        One of 'dataset', 'task', 'flow', 'run' (case-insensitive).
    **kwargs : Any
        Forwarded unchanged to the underlying get function.

    Returns
    -------
    Any
        The requested OpenML object.

    Raises
    ------
    TypeError
        If ``object_type`` is not a string.
    ValueError
        If ``object_type`` names an unsupported type.

    Examples
    --------
    >>> openml.get(61) # Get dataset 61 (default object_type="dataset")
    >>> openml.get("Fashion-MNIST") # Get dataset by name
    >>> openml.get(31, object_type="task") # Get task 31
    >>> openml.get(10, object_type="flow") # Get flow 10
    >>> openml.get(20, object_type="run") # Get run 20
    """
    if not isinstance(object_type, str):
        raise TypeError(f"object_type must be a string, got {type(object_type).__name__}")

    # Normalize once so lookups are case-insensitive.
    key = object_type.lower()
    if key not in _GET_DISPATCH:
        valid_types = ", ".join(repr(name) for name in _GET_DISPATCH)
        raise ValueError(
            f"Unsupported object_type {object_type!r}; expected one of {valid_types}.",
        )

    return _GET_DISPATCH[key](identifier, **kwargs)
40 changes: 40 additions & 0 deletions tests/test_openml/test_openml.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,3 +41,43 @@ def test_populate_cache(
assert task_mock.call_count == 2
for argument, fixture in zip(task_mock.call_args_list, [(1,), (2,)]):
assert argument[0] == fixture

@mock.patch("openml.tasks.functions.list_tasks")
@mock.patch("openml.datasets.functions.list_datasets")
def test_list_dispatch(self, list_datasets_mock, list_tasks_mock):
    """openml.list_all() forwards to the type-specific list_* function."""
    # The dispatch dict in openml.dispatchers bound the real functions at
    # import time, so patching openml.tasks.functions / openml.datasets.functions
    # alone is not enough — the dict entries themselves must be replaced with
    # the mocks. mock.patch.dict restores the originals on exit.
    # (Decorator patches apply bottom-up, hence the mock argument order.)
    with mock.patch.dict(
        "openml.dispatchers._LIST_DISPATCH",
        {
            "dataset": list_datasets_mock,
            "task": list_tasks_mock,
        },
    ):
        # No kwargs: the dispatcher must call the list function with no arguments.
        openml.list_all("dataset")
        list_datasets_mock.assert_called_once_with()

        # kwargs are forwarded unchanged to the underlying list function.
        openml.list_all("task", size=5)
        list_tasks_mock.assert_called_once_with(size=5)

@mock.patch("openml.tasks.functions.get_task")
@mock.patch("openml.datasets.functions.get_dataset")
def test_get_dispatch(self, get_dataset_mock, get_task_mock):
    """openml.get() forwards identifier and kwargs to the type-specific get_* function."""
    # The dispatch dict in openml.dispatchers bound the real functions at
    # import time, so the dict entries themselves must be replaced with the
    # mocks; mock.patch.dict restores the originals on exit.
    # (Decorator patches apply bottom-up, hence the mock argument order.)
    with mock.patch.dict(
        "openml.dispatchers._GET_DISPATCH",
        {
            "dataset": get_dataset_mock,
            "task": get_task_mock,
        },
    ):
        # Default object_type is "dataset".
        openml.get(61)
        get_dataset_mock.assert_called_with(61)

        # Extra kwargs (e.g. version) are forwarded to the underlying getter.
        openml.get("Fashion-MNIST", version=2)
        get_dataset_mock.assert_called_with("Fashion-MNIST", version=2)

        # String identifiers (dataset names) are passed through unchanged.
        openml.get("Fashion-MNIST")
        get_dataset_mock.assert_called_with("Fashion-MNIST")

        # Explicit object_type selects a different getter.
        openml.get(31, object_type="task")
        get_task_mock.assert_called_with(31)