#!/usr/bin/env python
"""Streamlit + Plotly dashboard for visualizing Kernel Tuner cache files.

Run via: ``streamlit run streamlit.py -- path/to/cache.json``
or launch from the ``ktdashboard`` CLI with ``--backend streamlit``.
"""
import json
import argparse
from typing import Tuple, Dict, Any, List, Optional

import streamlit as st
import pandas as pd
import plotly.express as px
import numpy as np


def _read_cachefile(cache_file: str) -> Dict[str, Any]:
    """Read a Kernel Tuner cache file, repairing a truncated tail if needed.

    Cache files of interrupted tuning runs may end with a trailing comma and
    miss the two closing braces; attempt that repair only when the file does
    not already parse as valid JSON.

    Raises:
        ValueError: if the file is empty.
        json.JSONDecodeError: if the contents cannot be parsed even after repair.
    """
    with open(cache_file, "r") as fh:
        filestr = fh.read().strip()

    if filestr == "":
        raise ValueError("Cache file is empty")

    try:
        # Fast path: the file is complete, valid JSON already.
        return json.loads(filestr)
    except json.JSONDecodeError:
        pass

    # File was not properly closed: drop a stray trailing comma and close the
    # "cache" object plus the root object, then parse again.
    if filestr.endswith(","):
        filestr = filestr[:-1]
    return json.loads(filestr + "}\n}")


def prepare_dataframe(
    cached_data: Dict[str, Any], objective: Optional[str] = None
) -> Tuple[pd.DataFrame, Dict[str, List[Any]]]:
    """Turn a parsed cache dict into a DataFrame plus key metadata.

    Drops failed configurations (objective value 1e20 or an error string),
    keeps only scalar-valued columns, adds a numeric ``index`` column, and
    converts tunable-parameter columns to ordered categoricals.

    Returns:
        (df, meta) where meta holds ``tune_param_keys``, ``all_tune_params``,
        ``scalar_value_keys`` and ``output_keys``.
    """
    if objective is None:
        objective = cached_data.get("objective", "time")

    data = list(cached_data["cache"].values())
    # 1e20 is Kernel Tuner's sentinel for a failed run; strings are error tags.
    data = [
        d
        for d in data
        if d.get(objective) != 1e20 and not isinstance(d.get(objective), str)
    ]

    tune_params_keys = cached_data["tune_params_keys"]
    all_tune_params: Dict[str, List[Any]] = {}
    for key in tune_params_keys:
        values = cached_data["tune_params"][key]
        # The cache may contain values outside the declared search space
        # (e.g. after the space was narrowed); fold those in, keeping order.
        for row in data:
            if row[key] not in values:
                values = sorted(values + [row[key]])
        all_tune_params[key] = values

    # Parameters with a single value carry no information for plotting.
    single_value_tune_param_keys = [
        key for key in tune_params_keys if len(all_tune_params[key]) == 1
    ]
    tune_param_keys = [
        key for key in tune_params_keys if key not in single_value_tune_param_keys
    ]
    scalar_value_keys = [
        key
        for key in data[0].keys()
        if not isinstance(data[0][key], list)
        and key not in single_value_tune_param_keys
    ]
    output_keys = [key for key in scalar_value_keys if key not in tune_param_keys]

    df = pd.DataFrame(data)[scalar_value_keys]

    # Add an 'index' column (numeric) so the UI can select "index" as an axis.
    df = df.reset_index(drop=True)
    df.insert(0, "index", df.index.astype(int))

    # Ordered categoricals preserve the search-space ordering in plots.
    for key in tune_param_keys:
        if key in df.columns:
            df[key] = pd.Categorical(
                df[key], categories=all_tune_params[key], ordered=True
            )

    return df, {
        "tune_param_keys": tune_param_keys,
        "all_tune_params": all_tune_params,
        "scalar_value_keys": scalar_value_keys,
        "output_keys": output_keys,
    }


def filter_dataframe(
    df: pd.DataFrame, selections: Dict[str, List[Any]]
) -> pd.DataFrame:
    """Return the rows of *df* matching every non-empty per-column selection."""
    mask = pd.Series(True, index=df.index)
    for k, v in selections.items():
        if v:
            mask &= df[k].isin(v)
    return df[mask]


def plot_scatter(
    df: pd.DataFrame,
    x: str,
    y: str,
    color: str,
    xscale: str,
    yscale: str,
    palette: str = "Viridis",
) -> Any:
    """Build a Plotly scatter figure of *y* vs *x*, colored by *color*.

    Categorical axes are mapped to integer positions with Gaussian jitter for
    visual separation; tick labels are restored from the category names.
    ``xscale``/``yscale`` are either "linear" or "log".
    """
    df_plot = df.copy()

    def jitter(col):
        # Map categories to integer positions and add jitter; numeric columns
        # pass through unchanged. Returns (values, categories-or-None).
        dtype = df_plot[col].dtype
        if isinstance(dtype, pd.CategoricalDtype) or dtype == object:
            categories = list(pd.Categorical(df_plot[col]).categories)
            mapping = {c: i for i, c in enumerate(categories)}
            arr = df_plot[col].map(mapping).astype(float)
            arr += np.random.normal(scale=0.15, size=len(arr))
            return arr, categories
        else:
            return df_plot[col], None

    x_vals, x_cats = jitter(x)
    y_vals, y_cats = jitter(y)

    df_plot["_x"] = x_vals
    df_plot["_y"] = y_vals

    color_arg = color if color in df_plot.columns else None

    # Resolve the named palette from plotly's sequential palettes.
    seq = getattr(px.colors.sequential, palette, None)
    color_kwargs = {"color_continuous_scale": seq}

    fig = px.scatter(
        df_plot,
        x="_x",
        y="_y",
        color=color_arg,
        hover_data=df_plot.columns,
        height=600,
        width=900,
        labels={"_x": x, "_y": y},
        **color_kwargs,
    )

    # If an axis corresponded to categories, restore the category tick labels.
    if x_cats is not None:
        fig.update_xaxes(
            tickmode="array", tickvals=list(range(len(x_cats))), ticktext=x_cats
        )
    if y_cats is not None:
        fig.update_yaxes(
            tickmode="array", tickvals=list(range(len(y_cats))), ticktext=y_cats
        )

    if xscale == "log":
        fig.update_xaxes(type="log")
    if yscale == "log":
        fig.update_yaxes(type="log")

    return fig


def main():
    """Streamlit entry point: sidebar controls plus scatter plot and top table."""
    parser = argparse.ArgumentParser(description="Streamlit Kernel Tuner Dashboard")
    parser.add_argument("cachefile", nargs="?", help="Path to cache file (JSON)")
    args = parser.parse_args()

    st.set_page_config(layout="wide", page_title="Kernel Tuner Dashboard")

    st.title("Kernel Tuner Dashboard")

    # Allow providing the file via CLI argument or the sidebar uploader.
    cachefile = args.cachefile

    if not cachefile:
        uploaded = st.sidebar.file_uploader("Upload cache JSON file", type=["json"])
        if uploaded is not None:
            # Persist the upload to a temp file so _read_cachefile can open it.
            import tempfile

            tf = tempfile.NamedTemporaryFile(delete=False, suffix=".json")
            tf.write(uploaded.read())
            tf.flush()
            cachefile = tf.name

    if not cachefile:
        st.info(
            "Provide a cache file via command-line (streamlit run ... -- ) or upload one in the sidebar."
        )
        return

    try:
        cached_data = _read_cachefile(cachefile)
    except Exception as exc:
        st.error(f"Failed to read cache file: {exc}")
        return

    kernel_name = cached_data.get("kernel_name", "")
    device_name = cached_data.get("device_name", "")

    st.sidebar.markdown(f"**Kernel:** {kernel_name}")
    st.sidebar.markdown(f"**Device:** {device_name}")

    df, meta = prepare_dataframe(cached_data)

    scalar_value_keys = meta["scalar_value_keys"]
    tune_param_keys = meta["tune_param_keys"]
    all_tune_params = meta["all_tune_params"]

    # Prefer GFLOP/s, then time, then whatever scalar comes first.
    default_key = (
        "GFLOP/s"
        if "GFLOP/s" in scalar_value_keys
        else ("time" if "time" in scalar_value_keys else scalar_value_keys[0])
    )

    yvariable = st.sidebar.selectbox(
        "Y", options=scalar_value_keys, index=scalar_value_keys.index(default_key)
    )
    xvariable = st.sidebar.selectbox(
        "X", options=["index"] + scalar_value_keys, index=0
    )
    colorvariable = st.sidebar.selectbox(
        "Color By",
        options=scalar_value_keys,
        index=scalar_value_keys.index(default_key),
    )
    xscale = st.sidebar.radio("X axis scale", options=["linear", "log"], index=0)
    yscale = st.sidebar.radio("Y axis scale", options=["linear", "log"], index=0)

    # Color palette chooser (sequential palettes only).
    seq_names = [
        name
        for name in dir(px.colors.sequential)
        if not name.startswith("_")
        and isinstance(getattr(px.colors.sequential, name), list)
    ]
    seq_names = sorted(seq_names)
    default_idx = seq_names.index("Viridis") if "Viridis" in seq_names else 0
    palette = st.sidebar.selectbox(
        "Color palette", options=seq_names, index=default_idx
    )

    # One multi-select per tunable parameter, defaulting to "everything".
    selections = {}
    for tp in tune_param_keys:
        selections[tp] = st.sidebar.multiselect(
            tp, options=all_tune_params[tp], default=list(all_tune_params[tp])
        )

    filtered_df = filter_dataframe(df, selections)

    st.markdown(f"## Auto-tuning {kernel_name} on {device_name}")

    fig = plot_scatter(
        filtered_df,
        xvariable,
        yvariable,
        colorvariable,
        xscale,
        yscale,
        palette=palette,
    )
    st.plotly_chart(fig, width="stretch")

    st.markdown("---")
    st.markdown("### Top results")

    # Show best by selected y (if numeric). Use the filtered view so the table
    # agrees with the plot. Note: sorts ascending, which suits "time"-like
    # objectives; for throughput metrics the best rows are at the bottom.
    if pd.api.types.is_numeric_dtype(filtered_df[yvariable]):
        best = filtered_df.sort_values(yvariable).head(10)
        st.dataframe(best)
    else:
        st.write("Y variable is non-numeric; showing raw filtered data")
        st.dataframe(filtered_df)


if __name__ == "__main__":
    main()
from typing import Dict, Any, List
import json
import pandas as pd


class Dashboard:
    """Backend-agnostic model of a Kernel Tuner cache file.

    Parses the cache once, exposes the benchmark records as a DataFrame with
    ordered-categorical tunable-parameter columns, and keeps the file handle
    open so new records appended by a live tuning run can be streamed in via
    :meth:`read_new_contents`.
    """

    def __init__(self, cache_file: str):
        self.cache_file = cache_file

        # Keep the handle open: subsequent read() calls pick up appended data.
        self.cache_file_handle = open(cache_file, "r")
        filestr = self.cache_file_handle.read().strip()

        try:
            cached_data = json.loads(filestr) if filestr else {}
        except json.JSONDecodeError:
            # File of an interrupted run: drop a stray trailing comma and
            # append the closing braces of the "cache" and root objects.
            if filestr.endswith(","):
                filestr = filestr[:-1]
            cached_data = json.loads(filestr + "}\n}")

        self.kernel_name = cached_data.get("kernel_name", "")
        self.device_name = cached_data.get("device_name", "")
        self.objective = cached_data.get("objective", "time")

        # Raw performance records; 1e20 is Kernel Tuner's failure sentinel
        # and string objective values are error tags -- drop both.
        data = list(cached_data.get("cache", {}).values())
        data = [
            d
            for d in data
            if d.get(self.objective) != 1e20
            and not isinstance(d.get(self.objective), str)
        ]

        # Number of records consumed so far; advanced by read_new_contents().
        self.index = len(data)
        self._all_data = data

        # Tunable parameters: the cache may contain values outside the
        # declared search space, so fold those in (sorted).
        self.tune_params_keys = cached_data.get("tune_params_keys", [])
        self.all_tune_params: Dict[str, List[Any]] = {}
        for key in self.tune_params_keys:
            values = cached_data.get("tune_params", {}).get(key, [])
            for row in data:
                if row.get(key) not in values:
                    values = sorted(values + [row.get(key)])
            self.all_tune_params[key] = values

        # Single-valued parameters carry no information for plotting.
        self.single_value_tune_param_keys = [
            k for k in self.tune_params_keys if len(self.all_tune_params[k]) == 1
        ]
        self.tune_param_keys = [
            k
            for k in self.tune_params_keys
            if k not in self.single_value_tune_param_keys
        ]

        scalar_value_keys = [
            k
            for k in (data[0].keys() if data else [])
            if not isinstance(data[0][k], list)
            and k not in self.single_value_tune_param_keys
        ]
        self.scalar_value_keys = scalar_value_keys
        self.output_keys = [
            k for k in scalar_value_keys if k not in self.tune_param_keys
        ]
        self.float_keys = [
            k for k in self.output_keys if data and isinstance(data[0].get(k), float)
        ]

        # Prepare the DataFrame with a leading numeric 'index' column.
        self.data_df = (
            pd.DataFrame(data)[self.scalar_value_keys]
            if len(self.scalar_value_keys) > 0
            else pd.DataFrame()
        )
        self.data_df = self.data_df.reset_index(drop=True)
        if not self.data_df.empty:
            self.data_df.insert(0, "index", self.data_df.index.astype(int))

        # Ordered categoricals preserve the search-space ordering in plots.
        for key in self.tune_param_keys:
            if key in self.data_df.columns:
                self.data_df[key] = pd.Categorical(
                    self.data_df[key],
                    categories=self.all_tune_params[key],
                    ordered=True,
                )

        # Per-parameter value selections used by get_filtered_df().
        self.selected_tune_params = {
            key: self.all_tune_params[key].copy() for key in self.tune_param_keys
        }

    def close(self):
        """Close the cache file handle; safe to call more than once."""
        try:
            self.cache_file_handle.close()
        except Exception:
            pass

    def get_filtered_df(self) -> pd.DataFrame:
        """Return filtered DataFrame based on selected_tune_params."""
        mask = pd.Series(True, index=self.data_df.index)
        for k, v in self.selected_tune_params.items():
            mask &= self.data_df[k].isin(v)
        return self.data_df[mask]

    def update_selection(self, key: str, values: List[Any]):
        """Update selection for a tune parameter."""
        self.selected_tune_params[key] = values

    def get_stream_for_index(self, i: int) -> Dict[str, List[Any]]:
        """Return a stream dict for a single element at index i."""
        element = self._all_data[i]
        return {
            k: [v]
            for k, v in dict(element, index=i).items()
            if k in ["index"] + self.scalar_value_keys
        }

    def read_new_contents(self) -> List[Dict[str, List[Any]]]:
        """Read appended JSON from the cachefile and return list of stream_dicts for new entries."""
        new_contents = self.cache_file_handle.read().strip()
        stream_dicts = []
        if new_contents:
            # Wrap the appended fragment into a JSON object.
            # NOTE(review): the [:-1] assumes the fragment ends with a trailing
            # comma (a still-running tuner) -- confirm for completed runs.
            new_contents_json = "{" + new_contents[:-1] + "}"
            new_data = list(json.loads(new_contents_json).values())
            for i, element in enumerate(new_data):
                stream_dict = {
                    k: [v]
                    for k, v in dict(element, index=self.index + i).items()
                    if k in ["index"] + self.scalar_value_keys
                }
                stream_dicts.append(stream_dict)
            self.index += len(new_data)
        return stream_dicts
#!/usr/bin/env python
"""Command-line entry point for KTdashboard with backend selection."""
import argparse
import os
import subprocess
import sys


def main():
    """Parse CLI arguments and launch the selected dashboard backend.

    Exits with status 1 when the cache file does not exist. The ``streamlit``
    backend is launched as a subprocess (``python -m streamlit run ...``);
    the ``panel`` backend (default) is served in-process.
    """
    parser = argparse.ArgumentParser(prog="ktdashboard")
    parser.add_argument(
        "--backend",
        choices=["panel", "streamlit"],
        default="panel",
        help="Backend to use for visualization",
    )
    parser.add_argument("filename", help="Path to cache JSON file")

    args = parser.parse_args()

    if not os.path.isfile(args.filename):
        print("Cachefile not found")
        sys.exit(1)

    if args.backend == "streamlit":
        # "--" separates streamlit's own options from the app's arguments.
        script_path = os.path.join(
            os.path.dirname(__file__), "streamlit_dashboard.py"
        )
        cmd = [
            sys.executable, "-m", "streamlit", "run", script_path, "--",
            args.filename,
        ]
        subprocess.run(cmd)
        return

    # Default: panel backend. Import lazily so the streamlit path does not
    # require panel to be installed. Prefer the package-relative import (the
    # installed console script runs this module inside the ktdashboard
    # package); fall back to a plain import when run as a loose script.
    try:
        from .panel_dashboard import serve_panel
    except ImportError:
        from panel_dashboard import serve_panel
    serve_panel(args.filename)


if __name__ == "__main__":
    main()
"""Panel + Bokeh dashboard backend for KTdashboard."""
from typing import Optional
import panel as pn
import panel.widgets as pnw
import pandas as pd
import bokeh.palettes
from bokeh.models.ranges import FactorRange
from bokeh.transform import jitter
from bokeh.models import HoverTool, LinearColorMapper, CategoricalColorMapper
from bokeh.plotting import ColumnDataSource, figure

# Prefer the package-relative import; fall back to a plain import so the
# module also works when run as a loose script outside the package.
try:
    from .dashboard import Dashboard
except ImportError:
    from dashboard import Dashboard


class PanelDashboard:
    """Panel UI wrapping a :class:`Dashboard` model.

    Builds a Bokeh scatter plot of the benchmark data with sidebar controls
    for axis/color variables, axis scales, and per-parameter filtering, and
    streams newly appended cache entries into the plot.
    """

    def __init__(self, cachefile: str, default_key: Optional[str] = None):
        self.model = Dashboard(cachefile)

        # Local aliases for UI code.
        self.data_df = self.model.data_df
        self.scalar_value_keys = self.model.scalar_value_keys
        self.tune_param_keys = self.model.tune_param_keys
        self.all_tune_params = self.model.all_tune_params
        # Bokeh cannot serialize pandas Categorical columns; use strings.
        self.source = ColumnDataSource(data=self._df_categorical_to_str(self.data_df))
        self.selected_tune_params = {
            key: self.all_tune_params[key].copy() for key in self.tune_param_keys
        }

        # Layout: fixed height, width stretches to the page.
        self.plot_height = 600
        plot_options = dict(
            height=self.plot_height,
            min_height=self.plot_height,
            sizing_mode="stretch_width",
        )
        float_keys = [k for k in self.scalar_value_keys if k in self.model.float_keys]
        plot_options["tools"] = [
            HoverTool(
                tooltips=[
                    (k, "@{" + k + "}" + ("{0.00}" if k in float_keys else ""))
                    for k in self.scalar_value_keys
                ]
            ),
            "box_select,box_zoom,save,reset",
        ]
        self.plot_options = plot_options

        # Default variable: GFLOP/s, then time, then the first scalar key.
        if default_key is None:
            default_key = "GFLOP/s"
        if default_key not in self.scalar_value_keys:
            default_key = (
                "time"
                if "time" in self.scalar_value_keys
                else (self.scalar_value_keys[0] if self.scalar_value_keys else None)
            )

        # Widgets
        self.yvariable = pnw.Select(
            name="Y", value=default_key, options=self.scalar_value_keys
        )
        self.xvariable = pnw.Select(
            name="X", value="index", options=["index"] + self.scalar_value_keys
        )
        self.colorvariable = pnw.Select(
            name="Color By", value=default_key, options=self.scalar_value_keys
        )
        self.xscale = pnw.RadioButtonGroup(name="xscale", options=["linear", "log"])
        self.yscale = pnw.RadioButtonGroup(name="yscale", options=["linear", "log"])

        # Re-render the scatter whenever a widget changes.
        self.scatter = pn.bind(
            self.make_scatter,
            xvariable=self.xvariable,
            yvariable=self.yvariable,
            color_by=self.colorvariable,
            xscale=self.xscale,
            yscale=self.yscale,
        )

        # Assemble the dashboard template.
        self.dashboard = pn.template.BootstrapTemplate(title="Kernel Tuner Dashboard")
        self.dashboard.main.append(self.scatter)
        self.dashboard.sidebar.append(
            pn.Column(self.yvariable, self.xvariable, self.colorvariable)
        )
        self.dashboard.sidebar.append(pn.layout.Divider())
        self.dashboard.sidebar.append(pn.Row(pn.pane.Markdown("X axis"), self.xscale))
        self.dashboard.sidebar.append(pn.Row(pn.pane.Markdown("Y axis"), self.yscale))
        self.dashboard.sidebar.append(pn.layout.Divider())

        for tune_param in self.tune_param_keys:
            values = self.all_tune_params[tune_param]
            multi_choice = pnw.MultiChoice(
                name=tune_param, value=values, options=values
            )
            self.dashboard.sidebar.append(multi_choice)
            row = pn.bind(self.update_data_selection, tune_param, multi_choice)
            self.dashboard.sidebar.append(row)

    def _df_categorical_to_str(self, df: pd.DataFrame) -> pd.DataFrame:
        """Return a copy of `df` where categorical columns are converted to strings."""
        df2 = df.copy()
        for c in df2.columns:
            if isinstance(df2[c].dtype, pd.CategoricalDtype):
                df2[c] = df2[c].astype(str)
        return df2

    def _convert_stream_dict(self, sd: dict) -> dict:
        """Convert values inside a stream-dict to string for categorical columns."""
        sd2 = {}
        for k, v in sd.items():
            if k in self.data_df.columns and isinstance(
                self.data_df[k].dtype, pd.CategoricalDtype
            ):
                sd2[k] = [str(x) for x in v]
            else:
                sd2[k] = v
        return sd2

    def update_data_selection(self, tune_param, multi_choice):
        """Recompute the plotted subset from all per-parameter selections.

        Selections are remembered across parameters, so the mask is rebuilt
        from scratch every time any selection changes.
        """
        self.selected_tune_params[tune_param] = multi_choice
        mask = pd.Series(True, index=self.data_df.index)
        for k, v in self.selected_tune_params.items():
            mask &= self.data_df[k].isin(v)
        data_df = self.data_df[mask]
        self.source.data = self._df_categorical_to_str(data_df)

    def update_colors(self, color_by):
        """Return a Bokeh color spec (field + mapper) for the chosen column."""
        dtype = self.data_df.dtypes[color_by]

        if isinstance(dtype, pd.CategoricalDtype):
            # Factors must be strings to match the str-converted source data.
            factors = [str(f) for f in dtype.categories]
            if len(factors) < 10:
                palette = bokeh.palettes.Category10[10]
            else:
                palette = bokeh.palettes.Category20[20]
            color_mapper = CategoricalColorMapper(palette=palette, factors=factors)
        else:
            color_mapper = LinearColorMapper(
                palette="Viridis256",
                low=min(self.data_df[color_by]),
                high=max(self.data_df[color_by]),
            )

        return {"field": color_by, "transform": color_mapper}

    def make_scatter(self, xvariable, yvariable, color_by, xscale, yscale):
        """Build the scatter figure; categorical axes get jittered factors."""
        color = self.update_colors(color_by)

        x = xvariable
        y = yvariable

        plot_options = dict(self.plot_options)
        plot_options["x_axis_type"] = xscale
        plot_options["y_axis_type"] = yscale

        dtype = self.data_df.dtypes.get(xvariable)
        if isinstance(dtype, pd.CategoricalDtype):
            x_factors = [str(f) for f in dtype.categories]
            plot_options["x_range"] = x_factors
            x = jitter(
                xvariable,
                width=0.02,
                distribution="normal",
                range=FactorRange(*x_factors),
            )

        dtype = self.data_df.dtypes.get(yvariable)
        if isinstance(dtype, pd.CategoricalDtype):
            y_factors = [str(f) for f in dtype.categories]
            plot_options["y_range"] = y_factors
            y = jitter(
                yvariable,
                width=0.02,
                distribution="normal",
                range=FactorRange(*y_factors),
            )

        f = figure(**plot_options)
        f.scatter(x, y, size=5, color=color, alpha=0.5, source=self.source)
        f.xaxis.axis_label = xvariable
        f.yaxis.axis_label = yvariable

        # Fixed height, but let the width follow the page (this class sets no
        # plot_width attribute -- the figure stretches via sizing_mode).
        bokeh_pane = pn.pane.Bokeh(
            object=f,
            min_height=self.plot_height,
            max_height=self.plot_height,
            sizing_mode="stretch_width",
        )
        pane = pn.Column(
            pn.pane.Markdown(
                f"## Auto-tuning {self.model.kernel_name} on {self.model.device_name}"
            ),
            bokeh_pane,
            sizing_mode="stretch_width",
        )
        return pane

    def update_plot(self, i):
        """Stream record *i* of the model into the plot's data source."""
        sd = self.model.get_stream_for_index(i)
        self.source.stream(self._convert_stream_dict(sd))

    def update_data(self):
        """Poll the cache file for appended records and stream them in."""
        stream_dicts = self.model.read_new_contents()
        for sd in stream_dicts:
            self.source.stream(self._convert_stream_dict(sd))


def serve_panel(cachefile: str) -> None:
    """Serve the Panel dashboard for *cachefile*, polling for updates every second."""
    ui = PanelDashboard(cachefile)

    ui.dashboard.servable()

    def dashboard_f():
        # Wrapper so the periodic callback is registered per-session.
        pn.state.add_periodic_callback(ui.update_data, 1000)
        return ui.dashboard

    pn.serve(dashboard_f, show=False)
cached_data.get("objective", "time") - - data = list(cached_data["cache"].values()) - data = [ - d - for d in data - if d.get(objective) != 1e20 and not isinstance(d.get(objective), str) - ] - - tune_params_keys = cached_data["tune_params_keys"] - all_tune_params = {} - for key in tune_params_keys: - values = cached_data["tune_params"][key] - for row in data: - if row[key] not in values: - values = sorted(values + [row[key]]) - all_tune_params[key] = values - - # figure out which keys are interesting - single_value_tune_param_keys = [ - key for key in tune_params_keys if len(all_tune_params[key]) == 1 - ] - tune_param_keys = [ - key for key in tune_params_keys if key not in single_value_tune_param_keys - ] - scalar_value_keys = [ - key - for key in data[0].keys() - if not isinstance(data[0][key], list) - and key not in single_value_tune_param_keys - ] - output_keys = [key for key in scalar_value_keys if key not in tune_param_keys] - - df = pd.DataFrame(data)[scalar_value_keys] - - # Add 'index' column (numeric) so the UI can select "index" as an axis - df = df.reset_index(drop=True) - df.insert(0, "index", df.index.astype(int)) - - # Convert tune params to categorical where appropriate to preserve ordering - for key in tune_param_keys: - if key in df.columns: - df[key] = pd.Categorical( - df[key], categories=all_tune_params[key], ordered=True - ) - - return df, { - "tune_param_keys": tune_param_keys, - "all_tune_params": all_tune_params, - "scalar_value_keys": scalar_value_keys, - "output_keys": output_keys, - } - - -def filter_dataframe( - df: pd.DataFrame, selections: Dict[str, List[Any]] -) -> pd.DataFrame: - mask = pd.Series(True, index=df.index) - for k, v in selections.items(): - if v: - mask &= df[k].isin(v) - return df[mask] - - -def plot_scatter( - df: pd.DataFrame, - x: str, - y: str, - color: str, - xscale: str, - yscale: str, - palette: str = "Viridis", -) -> Any: - # For categorical axes, we can map categories to numbers and add jitter for visual 
separation - df_plot = df.copy() - - def jitter(col): - dtype = df_plot[col].dtype - if isinstance(dtype, pd.CategoricalDtype) or dtype == object: - categories = list(pd.Categorical(df_plot[col]).categories) - mapping = {c: i for i, c in enumerate(categories)} - arr = df_plot[col].map(mapping).astype(float) - arr += np.random.normal(scale=0.15, size=len(arr)) - return arr, categories - else: - return df_plot[col], None - - x_vals, x_cats = jitter(x) - y_vals, y_cats = jitter(y) - - df_plot["_x"] = x_vals - df_plot["_y"] = y_vals - - color_arg = color if color in df_plot.columns else None - - # Determine palette lists from plotly (sequential palettes only) - seq = getattr(px.colors.sequential, palette, None) - color_kwargs = {"color_continuous_scale": seq} - - fig = px.scatter( - df_plot, - x="_x", - y="_y", - color=color_arg, - hover_data=df_plot.columns, - height=600, - width=900, - labels={"_x": x, "_y": y}, - **color_kwargs, - ) - - # If axis corresponded to categories, set tick labels - if x_cats is not None: - fig.update_xaxes( - tickmode="array", tickvals=list(range(len(x_cats))), ticktext=x_cats - ) - if y_cats is not None: - fig.update_yaxes( - tickmode="array", tickvals=list(range(len(y_cats))), ticktext=y_cats - ) - - # Set log scales if requested - if xscale == "log": - fig.update_xaxes(type="log") - if yscale == "log": - fig.update_yaxes(type="log") - - return fig - - -def main(): - parser = argparse.ArgumentParser(description="Streamlit Kernel Tuner Dashboard") - parser.add_argument("cachefile", nargs="?", help="Path to cache file (JSON)") - args = parser.parse_args() - - st.set_page_config(layout="wide", page_title="Kernel Tuner Dashboard") - - st.title("Kernel Tuner Dashboard") - - # Allow providing file via arg or file uploader - cachefile = args.cachefile - - if not cachefile: - uploaded = st.sidebar.file_uploader("Upload cache JSON file", type=["json"]) - if uploaded is not None: - # save to a temp file so we can read locations - import tempfile - 
- tf = tempfile.NamedTemporaryFile(delete=False, suffix=".json") - tf.write(uploaded.read()) - tf.flush() - cachefile = tf.name - - if not cachefile: - st.info( - "Provide a cache file via command-line (streamlit run ... -- ) or upload one in the sidebar." - ) - return - - try: - cached_data = _read_cachefile(cachefile) - except Exception as exc: - st.error(f"Failed to read cache file: {exc}") - return - - kernel_name = cached_data.get("kernel_name", "") - device_name = cached_data.get("device_name", "") - - st.sidebar.markdown(f"**Kernel:** {kernel_name}") - st.sidebar.markdown(f"**Device:** {device_name}") - - df, meta = prepare_dataframe(cached_data) - - scalar_value_keys = meta["scalar_value_keys"] - tune_param_keys = meta["tune_param_keys"] - all_tune_params = meta["all_tune_params"] - - default_key = ( - "GFLOP/s" - if "GFLOP/s" in scalar_value_keys - else ("time" if "time" in scalar_value_keys else scalar_value_keys[0]) - ) - - yvariable = st.sidebar.selectbox( - "Y", options=scalar_value_keys, index=scalar_value_keys.index(default_key) - ) - xvariable = st.sidebar.selectbox( - "X", options=["index"] + scalar_value_keys, index=0 - ) - colorvariable = st.sidebar.selectbox( - "Color By", - options=scalar_value_keys, - index=scalar_value_keys.index(default_key), - ) - xscale = st.sidebar.radio("X axis scale", options=["linear", "log"], index=0) - yscale = st.sidebar.radio("Y axis scale", options=["linear", "log"], index=0) - - # Color palette chooser (sequential palettes only) - seq_names = [ - name - for name in dir(px.colors.sequential) - if not name.startswith("_") - and isinstance(getattr(px.colors.sequential, name), list) - ] - seq_names = sorted(seq_names) - default_idx = seq_names.index("Viridis") if "Viridis" in seq_names else 0 - palette = st.sidebar.selectbox( - "Color palette", options=seq_names, index=default_idx - ) - - # tune param multi-selects - selections = {} - for tp in tune_param_keys: - selections[tp] = st.sidebar.multiselect( - tp, 
options=all_tune_params[tp], default=list(all_tune_params[tp]) - ) - - filtered_df = filter_dataframe(df, selections) - - st.markdown(f"## Auto-tuning {kernel_name} on {device_name}") - - fig = plot_scatter( - filtered_df, - xvariable, - yvariable, - colorvariable, - xscale, - yscale, - palette=palette, - ) - st.plotly_chart(fig, width="stretch") - - st.markdown("---") - st.markdown("### Top results") - - # Show best by selected y (if numeric) - if pd.api.types.is_numeric_dtype(df[yvariable]): - best = df.sort_values(yvariable).head(10) - st.dataframe(best) - else: - st.write("Y variable is non-numeric; showing raw filtered data") - st.dataframe(filtered_df) - - -if __name__ == "__main__": - main() diff --git a/ktdashboard/streamlit_dashboard.py b/ktdashboard/streamlit_dashboard.py new file mode 100644 index 0000000..d446bb2 --- /dev/null +++ b/ktdashboard/streamlit_dashboard.py @@ -0,0 +1,168 @@ +import argparse +from typing import Any +import streamlit as st +import pandas as pd +import plotly.express as px +import numpy as np + +from dashboard import Dashboard + + +class StreamlitDashboard: + def __init__(self, cachefile: str): + self.model = Dashboard(cachefile) + self.df = self.model.data_df + + def plot_scatter( + self, + df: pd.DataFrame, + x: str, + y: str, + color: str, + xscale: str, + yscale: str, + palette: str = "Viridis", + ) -> Any: + df_plot = df.copy() + + def jitter(col): + dtype = df_plot[col].dtype + if isinstance(dtype, pd.CategoricalDtype) or dtype == object: + categories = list(pd.Categorical(df_plot[col]).categories) + mapping = {c: i for i, c in enumerate(categories)} + arr = df_plot[col].map(mapping).astype(float) + arr += np.random.normal(scale=0.15, size=len(arr)) + return arr, categories + else: + return df_plot[col], None + + x_vals, x_cats = jitter(x) + y_vals, y_cats = jitter(y) + + df_plot["_x"] = x_vals + df_plot["_y"] = y_vals + + color_arg = color if color in df_plot.columns else None + + seq = getattr(px.colors.sequential, 
palette, None) + color_kwargs = {"color_continuous_scale": seq} + + fig = px.scatter( + df_plot, + x="_x", + y="_y", + color=color_arg, + hover_data=df_plot.columns, + height=600, + width=900, + labels={"_x": x, "_y": y}, + **color_kwargs, + ) + + if x_cats is not None: + fig.update_xaxes( + tickmode="array", tickvals=list(range(len(x_cats))), ticktext=x_cats + ) + if y_cats is not None: + fig.update_yaxes( + tickmode="array", tickvals=list(range(len(y_cats))), ticktext=y_cats + ) + + if xscale == "log": + fig.update_xaxes(type="log") + if yscale == "log": + fig.update_yaxes(type="log") + + return fig + + def render(self): + st.set_page_config(layout="wide", page_title="Kernel Tuner Dashboard") + st.title("Kernel Tuner Dashboard") + + kernel_name = self.model.kernel_name + device_name = self.model.device_name + + st.sidebar.markdown(f"**Kernel:** {kernel_name}") + st.sidebar.markdown(f"**Device:** {device_name}") + + scalar_value_keys = self.model.scalar_value_keys + tune_param_keys = self.model.tune_param_keys + all_tune_params = self.model.all_tune_params + + default_key = ( + "GFLOP/s" + if "GFLOP/s" in scalar_value_keys + else ("time" if "time" in scalar_value_keys else scalar_value_keys[0]) + ) + + yvariable = st.sidebar.selectbox( + "Y", options=scalar_value_keys, index=scalar_value_keys.index(default_key) + ) + xvariable = st.sidebar.selectbox( + "X", options=["index"] + scalar_value_keys, index=0 + ) + colorvariable = st.sidebar.selectbox( + "Color By", + options=scalar_value_keys, + index=scalar_value_keys.index(default_key), + ) + xscale = st.sidebar.radio("X axis scale", options=["linear", "log"], index=0) + yscale = st.sidebar.radio("Y axis scale", options=["linear", "log"], index=0) + + # Color palette chooser (sequential palettes only) + seq_names = [ + name + for name in dir(px.colors.sequential) + if not name.startswith("_") + and isinstance(getattr(px.colors.sequential, name), list) + ] + seq_names = sorted(seq_names) + default_idx = 
seq_names.index("Viridis") if "Viridis" in seq_names else 0 + palette = st.sidebar.selectbox( + "Color palette", options=seq_names, index=default_idx + ) + + # tune param multi-selects + selections = {} + for tp in tune_param_keys: + selections[tp] = st.sidebar.multiselect( + tp, options=all_tune_params[tp], default=list(all_tune_params[tp]) + ) + self.model.update_selection(tp, selections[tp]) + + filtered_df = self.model.get_filtered_df() + + st.markdown(f"## Auto-tuning {kernel_name} on {device_name}") + + fig = self.plot_scatter( + filtered_df, + xvariable, + yvariable, + colorvariable, + xscale, + yscale, + palette=palette, + ) + st.plotly_chart(fig, width="stretch") + + st.markdown("---") + st.markdown("### Top results") + + # Show best by selected y (if numeric) + if pd.api.types.is_numeric_dtype(filtered_df[yvariable]): + sorted_df = filtered_df.sort_values(yvariable).head(10) + st.dataframe(sorted_df) + else: + st.dataframe(filtered_df) + + +def serve_streamlit(cachefile: str) -> None: + sd = StreamlitDashboard(cachefile) + sd.render() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(prog="ktdashboard") + parser.add_argument("filename", help="Path to cache JSON file") + args = parser.parse_args() + serve_streamlit(args.filename) diff --git a/setup.py b/setup.py index c50cd83..c6d0f3d 100644 --- a/setup.py +++ b/setup.py @@ -27,6 +27,6 @@ 'Topic :: System :: Distributed Computing', 'Development Status :: 3 - Alpha ', ], - install_requires=['bokeh','pandas','panel'], + install_requires=['bokeh','pandas','panel','streamlit','plotly'], entry_points={'console_scripts': ['ktdashboard = ktdashboard.ktdashboard:cli']}, ) From 2462f48f908e7858e9e7a2ae95cccbded00e4d9d Mon Sep 17 00:00:00 2001 From: Bram Veenboer Date: Mon, 12 Jan 2026 15:42:08 +0100 Subject: [PATCH 3/6] Add data table to panel dashboard --- ktdashboard/ktdashboard.py | 20 +++++++-- ktdashboard/panel_dashboard.py | 72 +++++++++++++++++++++++------- ktdashboard/streamlit_dashboard.py 
| 29 +++++++----- 3 files changed, 91 insertions(+), 30 deletions(-) diff --git a/ktdashboard/ktdashboard.py b/ktdashboard/ktdashboard.py index 1739dfa..0e7308c 100644 --- a/ktdashboard/ktdashboard.py +++ b/ktdashboard/ktdashboard.py @@ -1,11 +1,17 @@ #!/usr/bin/env python import argparse, subprocess, sys, os + def main(): """Command-line interface with backend selection.""" parser = argparse.ArgumentParser(prog="ktdashboard") - parser.add_argument("--backend", choices=["panel", "streamlit"], default="panel", help="Backend to use for visualization") + parser.add_argument( + "--backend", + choices=["panel", "streamlit"], + default="panel", + help="Backend to use for visualization", + ) parser.add_argument("filename", help="Path to cache JSON file") args = parser.parse_args() @@ -16,18 +22,26 @@ def main(): if args.backend == "streamlit": script_path = os.path.join(os.path.dirname(__file__), "streamlit_dashboard.py") - cmd = [sys.executable, "-m", "streamlit", "run", script_path, "--", args.filename] + cmd = [ + sys.executable, + "-m", + "streamlit", + "run", + script_path, + "--", + args.filename, + ] subprocess.run(cmd) return if args.backend == "panel": from panel_dashboard import serve_panel + serve_panel(args.filename) return exit(1) - if __name__ == "__main__": main() diff --git a/ktdashboard/panel_dashboard.py b/ktdashboard/panel_dashboard.py index d0ac02f..8766cf2 100644 --- a/ktdashboard/panel_dashboard.py +++ b/ktdashboard/panel_dashboard.py @@ -5,14 +5,21 @@ import bokeh.palettes from bokeh.models.ranges import FactorRange from bokeh.transform import jitter -from bokeh.models import HoverTool, LinearColorMapper, CategoricalColorMapper +from bokeh.models import ( + DataTable, + CategoricalColorMapper, + HoverTool, + LinearColorMapper, + TableColumn, +) from bokeh.plotting import ColumnDataSource, figure from dashboard import Dashboard class PanelDashboard: - def __init__(self, cachefile: str, default_key: Optional[str] = None): + def __init__( + self, 
cachefile: str, default_key: Optional[str] = None): self.model = Dashboard(cachefile) # local copies for UI @@ -67,14 +74,18 @@ def __init__(self, cachefile: str, default_key: Optional[str] = None): self.xscale = pnw.RadioButtonGroup(name="xscale", options=["linear", "log"]) self.yscale = pnw.RadioButtonGroup(name="yscale", options=["linear", "log"]) + # checkbox to show/hide the data table, toggling re-renders the pane + self.show_table_checkbox = pnw.Checkbox(name="Show table", value=False) + # connect widgets self.scatter = pn.bind( - self.make_scatter, + self.make_pane, xvariable=self.xvariable, yvariable=self.yvariable, color_by=self.colorvariable, xscale=self.xscale, yscale=self.yscale, + show_table=self.show_table_checkbox, ) # build up the dashboard @@ -87,6 +98,8 @@ def __init__(self, cachefile: str, default_key: Optional[str] = None): self.dashboard.sidebar.append(pn.Row(pn.pane.Markdown("X axis"), self.xscale)) self.dashboard.sidebar.append(pn.Row(pn.pane.Markdown("Y axis"), self.yscale)) self.dashboard.sidebar.append(pn.layout.Divider()) + self.dashboard.sidebar.append(self.show_table_checkbox) + self.dashboard.sidebar.append(pn.layout.Divider()) for tune_param in self.tune_param_keys: values = self.all_tune_params[tune_param] @@ -146,7 +159,7 @@ def update_colors(self, color_by): color = {"field": color_by, "transform": color_mapper} return color - def make_scatter(self, xvariable, yvariable, color_by, xscale, yscale): + def make_pane(self, xvariable, yvariable, color_by, xscale, yscale, show_table: bool = True): color = self.update_colors(color_by) x = xvariable @@ -156,6 +169,12 @@ def make_scatter(self, xvariable, yvariable, color_by, xscale, yscale): plot_options["x_axis_type"] = xscale plot_options["y_axis_type"] = yscale + # If the table is disabled we want the plot to take the full page height + if not show_table: + plot_options.pop("height", None) + plot_options.pop("min_height", None) + plot_options["sizing_mode"] = "stretch_both" + dtype = 
self.data_df.dtypes.get(xvariable) if pd.api.types.is_categorical_dtype(dtype): x_factors = [str(f) for f in dtype.categories] @@ -183,19 +202,40 @@ def make_scatter(self, xvariable, yvariable, color_by, xscale, yscale): f.xaxis.axis_label = xvariable f.yaxis.axis_label = yvariable - bokeh_pane = pn.pane.Bokeh( - object=f, - min_width=self.plot_width, - min_height=self.plot_height, - max_width=self.plot_width, - max_height=self.plot_height, + # DataTable showing the raw data + columns = [TableColumn(field=c, title=c) for c in self.source.column_names][1:] + data_table = DataTable( + source=self.source, + columns=columns, + selectable=True, + sizing_mode="stretch_width", ) - pane = pn.Column( - pn.pane.Markdown( - f"## Auto-tuning {self.model.kernel_name} on {self.model.device_name}" - ), - bokeh_pane, + + pane_title = ( + f"## Auto-tuning {self.model.kernel_name} on {self.model.device_name}" ) + + if show_table: + bokeh_pane = pn.pane.Bokeh( + object=f, + sizing_mode="stretch_width", + min_height=self.plot_height, + max_height=self.plot_height, + ) + pane_children = [ + pn.pane.Markdown(pane_title), + bokeh_pane, + pn.layout.Divider(), + data_table, + ] + else: + bokeh_pane = pn.pane.Bokeh(object=f, sizing_mode="stretch_both") + pane_children = [ + pn.pane.Markdown(pane_title), + bokeh_pane, + ] + + pane = pn.Column(*pane_children) return pane def update_plot(self, i): @@ -208,7 +248,7 @@ def update_data(self): self.source.stream(self._convert_stream_dict(sd)) -def serve_panel(cachefile: str) -> None: +def serve_panel(cachefile: str, show_table: bool = True) -> None: ui = PanelDashboard(cachefile) ui.dashboard.servable() diff --git a/ktdashboard/streamlit_dashboard.py b/ktdashboard/streamlit_dashboard.py index d446bb2..1495722 100644 --- a/ktdashboard/streamlit_dashboard.py +++ b/ktdashboard/streamlit_dashboard.py @@ -12,6 +12,7 @@ class StreamlitDashboard: def __init__(self, cachefile: str): self.model = Dashboard(cachefile) self.df = self.model.data_df + 
self.plot_height = 600 def plot_scatter( self, @@ -53,8 +54,7 @@ def jitter(col): y="_y", color=color_arg, hover_data=df_plot.columns, - height=600, - width=900, + height=self.plot_height, labels={"_x": x, "_y": y}, **color_kwargs, ) @@ -109,6 +109,9 @@ def render(self): xscale = st.sidebar.radio("X axis scale", options=["linear", "log"], index=0) yscale = st.sidebar.radio("Y axis scale", options=["linear", "log"], index=0) + # Show table control + show_table = st.sidebar.checkbox("Show table", value=True) + # Color palette chooser (sequential palettes only) seq_names = [ name @@ -134,6 +137,10 @@ def render(self): st.markdown(f"## Auto-tuning {kernel_name} on {device_name}") + plot_height = self.plot_height + if not show_table: + plot_height = int(plot_height * 1.5) + fig = self.plot_scatter( filtered_df, xvariable, @@ -143,17 +150,17 @@ def render(self): yscale, palette=palette, ) - st.plotly_chart(fig, width="stretch") - st.markdown("---") - st.markdown("### Top results") + st.plotly_chart(fig, height=plot_height, width="stretch") - # Show best by selected y (if numeric) - if pd.api.types.is_numeric_dtype(filtered_df[yvariable]): - sorted_df = filtered_df.sort_values(yvariable).head(10) - st.dataframe(sorted_df) - else: - st.dataframe(filtered_df) + if show_table: + st.markdown("---") + + if pd.api.types.is_numeric_dtype(filtered_df[yvariable]): + sorted_df = filtered_df.sort_values(yvariable) + st.dataframe(sorted_df) + else: + st.dataframe(filtered_df) def serve_streamlit(cachefile: str) -> None: From 188e4dcdb75a6cc327bc159d007202b358f4f09b Mon Sep 17 00:00:00 2001 From: Bram Veenboer Date: Mon, 12 Jan 2026 17:08:06 +0100 Subject: [PATCH 4/6] Reduce size of the title --- ktdashboard/streamlit_dashboard.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/ktdashboard/streamlit_dashboard.py b/ktdashboard/streamlit_dashboard.py index 1495722..0e344ab 100644 --- a/ktdashboard/streamlit_dashboard.py +++ b/ktdashboard/streamlit_dashboard.py @@ 
-77,7 +77,6 @@ def jitter(col): def render(self): st.set_page_config(layout="wide", page_title="Kernel Tuner Dashboard") - st.title("Kernel Tuner Dashboard") kernel_name = self.model.kernel_name device_name = self.model.device_name @@ -135,7 +134,7 @@ def render(self): filtered_df = self.model.get_filtered_df() - st.markdown(f"## Auto-tuning {kernel_name} on {device_name}") + st.markdown(f"### Auto-tuning {kernel_name} on {device_name}") plot_height = self.plot_height if not show_table: From 73c4b85f5a17b5e11cdab0f3fd779d35ee6ca964 Mon Sep 17 00:00:00 2001 From: Bram Veenboer Date: Mon, 12 Jan 2026 17:22:18 +0100 Subject: [PATCH 5/6] Add file uploader to streamlit dashboard --- ktdashboard/ktdashboard.py | 24 +++++++++++---- ktdashboard/streamlit_dashboard.py | 48 +++++++++++++++++++++++++----- 2 files changed, 59 insertions(+), 13 deletions(-) diff --git a/ktdashboard/ktdashboard.py b/ktdashboard/ktdashboard.py index 0e7308c..5048dfa 100644 --- a/ktdashboard/ktdashboard.py +++ b/ktdashboard/ktdashboard.py @@ -3,8 +3,6 @@ def main(): - """Command-line interface with backend selection.""" - parser = argparse.ArgumentParser(prog="ktdashboard") parser.add_argument( "--backend", @@ -12,11 +10,25 @@ def main(): default="panel", help="Backend to use for visualization", ) - parser.add_argument("filename", help="Path to cache JSON file") + parser.add_argument( + "filename", nargs="?", help="Path to cache JSON file (optional for streamlit)" + ) args = parser.parse_args() - if not os.path.isfile(args.filename): + if args.backend == "panel": + if not args.filename: + print("Cachefile is required for the 'panel' backend") + exit(1) + if not os.path.isfile(args.filename): + print("Cachefile not found") + exit(1) + + if ( + args.backend == "streamlit" + and args.filename + and not os.path.isfile(args.filename) + ): print("Cachefile not found") exit(1) @@ -29,8 +41,10 @@ def main(): "run", script_path, "--", - args.filename, ] + # pass filename only when provided + if 
args.filename: + cmd.append(args.filename) subprocess.run(cmd) return diff --git a/ktdashboard/streamlit_dashboard.py b/ktdashboard/streamlit_dashboard.py index 0e344ab..4569083 100644 --- a/ktdashboard/streamlit_dashboard.py +++ b/ktdashboard/streamlit_dashboard.py @@ -4,15 +4,30 @@ import pandas as pd import plotly.express as px import numpy as np +import tempfile from dashboard import Dashboard class StreamlitDashboard: - def __init__(self, cachefile: str): - self.model = Dashboard(cachefile) - self.df = self.model.data_df + def __init__(self, cachefile: str | None = None, show_table: bool = True): + self.cachefile = cachefile + self.model = Dashboard(cachefile) if cachefile else None + self.df = self.model.data_df if self.model else pd.DataFrame() self.plot_height = 600 + self.show_table = show_table + + def load_from_uploaded_file(self, uploaded_file) -> None: + tf = tempfile.NamedTemporaryFile(delete=False, suffix=".json") + content = uploaded_file.read() + #if isinstance(content, str): + # content = content.encode("utf-8") + tf.write(content) + tf.flush() + tf.close() + self.cachefile = tf.name + self.model = Dashboard(self.cachefile) + self.df = self.model.data_df def plot_scatter( self, @@ -78,6 +93,17 @@ def jitter(col): def render(self): st.set_page_config(layout="wide", page_title="Kernel Tuner Dashboard") + if self.model is None: + uploaded = st.sidebar.file_uploader("Upload a cache file", type=["json"]) + if uploaded is None: + st.info("Upload a cache JSON file via the sidebar to get started.") + return + try: + self.load_from_uploaded_file(uploaded) + except Exception as e: + st.error(f"Failed to read uploaded file: {e}") + return + kernel_name = self.model.kernel_name device_name = self.model.device_name @@ -109,7 +135,7 @@ def render(self): yscale = st.sidebar.radio("Y axis scale", options=["linear", "log"], index=0) # Show table control - show_table = st.sidebar.checkbox("Show table", value=True) + show_table = st.sidebar.checkbox("Show table", 
value=self.show_table) # Color palette chooser (sequential palettes only) seq_names = [ @@ -162,13 +188,19 @@ def render(self): st.dataframe(filtered_df) -def serve_streamlit(cachefile: str) -> None: - sd = StreamlitDashboard(cachefile) +def serve_streamlit(cachefile: str | None = None, show_table: bool = True) -> None: + sd = StreamlitDashboard(cachefile, show_table=show_table) sd.render() if __name__ == "__main__": parser = argparse.ArgumentParser(prog="ktdashboard") - parser.add_argument("filename", help="Path to cache JSON file") + parser.add_argument("filename", nargs="?", help="Path to cache JSON file (optional)") + parser.add_argument( + "--table", + action=argparse.BooleanOptionalAction, + default=True, + help="Enable data table (default: enabled)", + ) args = parser.parse_args() - serve_streamlit(args.filename) + serve_streamlit(args.filename if args.filename else None, show_table=args.table) From f53994cf90f3ecb80c242d7504e2a15ca44d16f6 Mon Sep 17 00:00:00 2001 From: Bram Veenboer Date: Tue, 13 Jan 2026 09:23:46 +0100 Subject: [PATCH 6/6] Remove commented code --- ktdashboard/streamlit_dashboard.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/ktdashboard/streamlit_dashboard.py b/ktdashboard/streamlit_dashboard.py index 4569083..516d396 100644 --- a/ktdashboard/streamlit_dashboard.py +++ b/ktdashboard/streamlit_dashboard.py @@ -20,8 +20,6 @@ def __init__(self, cachefile: str | None = None, show_table: bool = True): def load_from_uploaded_file(self, uploaded_file) -> None: tf = tempfile.NamedTemporaryFile(delete=False, suffix=".json") content = uploaded_file.read() - #if isinstance(content, str): - # content = content.encode("utf-8") tf.write(content) tf.flush() tf.close()