Widget for GP intuition

Widget for GP intuition

The widget here provides a Gaussian process (GP) visualization and some things to try to help build intuition about GPs. A GP defines a distribution over functions, so in the simulation we will see draws of random functions, in analogy to drawing random variables from a normal distribution (recall that “normal” and “Gaussian” are synonymous here). Our principal goal is to experience how GPs look in practice and how they are used for interpolation or regression.

The defining feature of an ensemble of GP function draws is that, at any fixed point, the histogram of the function values at that point approximates a Gaussian distribution. For any two points, the pairs of function values are draws from a bivariate Gaussian distribution with a prescribed covariance matrix, and so on: the values at more points are draws from a multivariate Gaussian distribution. In each case, the covariance between any two points is specified by a kernel evaluated at those points.

Setting up to use the widget

We use Panel and Bokeh here for the widget. If you don’t have them installed and you use conda, uncomment the command in the next cell. Alternatively, you could run `pip install bokeh panel`.

We use scikit-learn to generate the GP draws and optimize hyperparameters; we assume this is already available.

#! conda install -c conda-forge bokeh panel -y

Create the widget

# Gaussian Process Regression — Panel + Bokeh (with 68% & 95% bands)
# -------------------------------------------------------------
# • Shows BOTH 68% (±1σ) and 95% (±1.96σ) credible bands with different alpha
# • Bands are underlays; mean and samples stay visible on top

import numpy as np

import panel as pn
from bokeh.plotting import figure
from bokeh.models import ColumnDataSource, HoverTool
from bokeh.events import Tap
from bokeh.palettes import Category10, Category20, Turbo256, Viridis256
from itertools import cycle

from scipy.linalg import solve_triangular

import warnings
from sklearn.exceptions import ConvergenceWarning
from sklearn.gaussian_process import GaussianProcessRegressor
from sklearn.gaussian_process.kernels import (
    RBF, Matern, RationalQuadratic, ExpSineSquared,
    WhiteKernel, DotProduct, ConstantKernel
)

# Load Panel's front-end components. sizing_mode="stretch_width" makes
# components expand horizontally to fill the available width by default.
pn.extension(sizing_mode="stretch_width")

# -------------------- Config / State --------------------
# Unseeded generator: each session (and each reseed) gives different draws.
rng = np.random.default_rng()

# Plotting window; the x-range is also the domain of the evaluation grid.
Xmin, Xmax = -5.0, 5.0
Ymin, Ymax = -3.5, 3.5

# Evaluation grid for prior/posterior draws, stored as an (n_test, 1) column
# because the sklearn kernels expect 2-D inputs.
n_test = 400
Xtest = np.linspace(Xmin, Xmax, n_test)[:, np.newaxis]

# Training data, grown by clicking: X_train is (n, 1), y_train is (n,).
X_train = np.zeros((0, 1))
y_train = np.zeros(0)

# Counter stored in the samples ColumnDataSource (CDS) so updates stay distinct.
sample_nonce = 0

def reseed_rng():
    """Swap in a new, unseeded NumPy Generator and advance the sample nonce.

    Rebinds the module-level ``rng`` and increments ``sample_nonce`` so the
    next batch of sample paths is tagged differently from the previous one.
    """
    global rng, sample_nonce
    sample_nonce = sample_nonce + 1
    rng = np.random.default_rng()

def set_status(msg: str):
    """Write *msg* to the status pane; print it instead if the pane can't be set.

    The print fallback covers, e.g., calls made before ``status_pane`` exists.
    """
    shown = True
    try:
        status_pane.object = msg
    except Exception:
        shown = False
    if not shown:
        # last resort in notebooks: still show *something*
        print(msg)

def set_status_if_empty(msg: str):
    """Show *msg* in the status pane only if the pane is currently blank.

    Used as a "Ready." default so it does not overwrite a more informative
    message (e.g. a fit summary or "Cleared."). A pane whose text is None or
    whitespace-only counts as blank. If the pane cannot be read or written
    (e.g. it does not exist yet), the message is printed instead, matching
    ``set_status``'s fallback rather than raising.
    """
    try:
        current = (status_pane.object or "").strip()
    except Exception:
        current = ""
    if current:
        return
    try:
        status_pane.object = msg
    except Exception:
        # Previously this write was unguarded and raised NameError whenever
        # the read above had failed for the same reason.
        print(msg)
        


# -------------------- Widgets --------------------
# Kernel family (default: RBF, i.e. squared exponential).
kernel_sel = pn.widgets.Select(
    name="Kernel",
    value="RBF (SE)",
    options=[
        "RBF (SE)",
        "Matern ν=1.5",
        "Matern ν=2.5",
        "RationalQuadratic",
        "Periodic (ExpSineSquared)",
        "Linear (DotProduct)",
    ],
)

# Hyperparameter sliders. make_kernel() reads their values and also uses each
# slider's start/end as the optimizer bounds.
#   ell         -- length scale ℓ (passed as sigma_0 for the DotProduct kernel)
#   sigma_f     -- signal standard deviation σ_f
#   sigma_n     -- observation white-noise standard deviation σ_n
#   alpha_rq    -- shape parameter α, RationalQuadratic only (hidden otherwise)
#   periodicity -- period of the ExpSineSquared kernel (hidden otherwise)
ell = pn.widgets.FloatSlider(name="ℓ (length scale)", start=0.05, end=20.0, step=0.01, value=1.0)
sigma_f = pn.widgets.FloatSlider(name="σ_f (signal std)", start=0.05, end=10.0, step=0.01, value=1.0)
sigma_n = pn.widgets.FloatSlider(name="σ_n (noise std)", start=0.0, end=2.0, step=0.001, value=0.1)
alpha_rq = pn.widgets.FloatSlider(name="α (RQ)", start=0.05, end=10.0, step=0.01, value=1.0, visible=False)
periodicity = pn.widgets.FloatSlider(name="Period", start=0.1, end=10.0, step=0.01, value=2.0, visible=False)

# Display controls: number of sample paths, optimizer toggle,
# posterior-vs-prior switch, mean/band visibility.
draw_n = pn.widgets.IntSlider(name="# samples", start=0, end=100, step=1, value=3)
optimize = pn.widgets.Checkbox(name="Optimize hyperparameters", value=False)
posterior = pn.widgets.Checkbox(name="Draw from posterior [prior]", value=True)
show_mean = pn.widgets.Checkbox(name="Show mean & bands (68% & 95%)", value=True)

# Action buttons.
btn_draw = pn.widgets.Button(name="Draw samples", button_type="primary")
btn_clear = pn.widgets.Button(name="Clear points", button_type="warning")
btn_reseed = pn.widgets.Button(name="Reseed RNG", button_type="success")

# Larger text for the main hyperparameter sliders and the optimizer toggle
# (all four share the same style dict).
big_font = {"font-size": "12pt"}
for _styled in (ell, sigma_f, sigma_n, optimize):
    _styled.styles = big_font

# Status line (written via set_status / set_status_if_empty).
status_pane = pn.pane.Markdown("Ready.", sizing_mode="stretch_width")


# -------------------- Bokeh Figure --------------------
p = figure(height=380, x_range=(Xmin, Xmax), y_range=(Ymin, Ymax),
           tools="pan,wheel_zoom,reset,tap", active_drag="pan",
           title="Gaussian Process Regression — click to add points")
# Axis-label font sizes are configured below; give the labels actual text.
p.xaxis.axis_label = "x"
p.yaxis.axis_label = "f(x)"

# Training points added by clicking.
src_pts = ColumnDataSource(data=dict(x=[], y=[]))
pts_renderer = p.scatter('x', 'y', source=src_pts, marker="circle", size=6,
                         alpha=0.9, color="#1f77b4")  # medium blue color

# GP mean (posterior mean, or zero for the prior).
src_mean = ColumnDataSource(data=dict(x=[], y=[]))
mean_renderer = p.line('x', 'y', source=src_mean, line_width=3.5, color="#d62728",
                       alpha=0.9, legend_label="GP mean")

# Two bands: 95% (lighter) and 68% (darker), both underlays so the mean and
# the sample paths are drawn on top of them.
src_band95 = ColumnDataSource(data=dict(x=[], y1=[], y2=[]))
band95 = p.varea(x='x', y1='y1', y2='y2', source=src_band95, fill_alpha=0.20, fill_color="lightgray")
band95.level = "underlay"

src_band68 = ColumnDataSource(data=dict(x=[], y1=[], y2=[]))
band68 = p.varea(x='x', y1='y1', y2='y2', source=src_band68, fill_alpha=0.30, fill_color="gray")
band68.level = "underlay"

# Random function draws (one line per sample); 'nonce' tags each batch.
src_samples = ColumnDataSource(data=dict(xs=[], ys=[], nonce=[]))
p.multi_line(xs='xs', ys='ys', source=src_samples,
             line_color="#1f77b4", line_alpha=0.4, line_width=2)

# Hover readout only for glyphs whose sources have scalar 'x' and 'y' columns.
# The bands (x/y1/y2) and sample paths (xs/ys) lack a 'y' column, so attaching
# the tool to them would display "???" for @y (and @x for the paths).
p.add_tools(HoverTool(renderers=[pts_renderer, mean_renderer],
                      tooltips=[("x", "@x{0.000}"), ("y", "@y{0.000}")]))
p.legend.location = "top_left"

# Larger text throughout for readability.
p.title.text_font_size = "14pt"
p.xaxis.axis_label_text_font_size = "14pt"
p.yaxis.axis_label_text_font_size = "14pt"
p.xaxis.major_label_text_font_size = "14pt"
p.yaxis.major_label_text_font_size = "14pt"
p.legend.label_text_font_size = "14pt"

# -------------------- GP Helpers --------------------
def make_kernel():
    """Build ``σ_f² · k_latent(ℓ, …) + WhiteKernel(σ_n²)`` from the current widgets.

    The latent kernel is chosen by ``kernel_sel``; unknown names fall back to
    RBF. Each slider's start/end range is reused as the hyperparameter bound
    seen by the optimizer. ConstantKernel and WhiteKernel are parameterized by
    variances, so their bounds and values are squared.

    Note: for "Linear (DotProduct)", the ℓ slider supplies DotProduct's
    ``sigma_0`` (not a length scale) and its range is used as sigma_0's bounds.

    Returns
    -------
    sklearn kernel of the form ``(ConstantKernel * latent) + WhiteKernel``.
    """
    # Slider-driven bounds.
    ell_bounds = (ell.start, ell.end)
    sf_bounds = (sigma_f.start**2, sigma_f.end**2)             # ConstantKernel uses variance
    sn_bounds = (max(1e-9, sigma_n.start**2), sigma_n.end**2)  # avoid exact 0
    alpha_bounds = (alpha_rq.start, alpha_rq.end)
    per_bounds = (periodicity.start, periodicity.end)

    l = ell.value

    # Lazily constructed latent kernels, keyed by the Select option text.
    latent_factories = {
        "RBF (SE)": lambda: RBF(length_scale=l, length_scale_bounds=ell_bounds),
        "Matern ν=1.5": lambda: Matern(length_scale=l, length_scale_bounds=ell_bounds, nu=1.5),
        "Matern ν=2.5": lambda: Matern(length_scale=l, length_scale_bounds=ell_bounds, nu=2.5),
        "RationalQuadratic": lambda: RationalQuadratic(
            length_scale=l, alpha=alpha_rq.value,
            length_scale_bounds=ell_bounds, alpha_bounds=alpha_bounds),
        "Periodic (ExpSineSquared)": lambda: ExpSineSquared(
            length_scale=l, periodicity=periodicity.value,
            length_scale_bounds=ell_bounds, periodicity_bounds=per_bounds),
        "Linear (DotProduct)": lambda: DotProduct(sigma_0=l, sigma_0_bounds=ell_bounds),
    }
    make_latent = latent_factories.get(kernel_sel.value, latent_factories["RBF (SE)"])

    amplitude = ConstantKernel(sigma_f.value**2, constant_value_bounds=sf_bounds)
    k_base = amplitude * make_latent()

    # The σ_n slider starts at 0, but sklearn stores hyperparameters as logs
    # (kernel.theta), so a zero noise level becomes log(0) = -inf and sits
    # outside noise_level_bounds. Clamp to the lower bound instead.
    sn2 = max(sigma_n.value**2, sn_bounds[0])
    return k_base + WhiteKernel(noise_level=sn2, noise_level_bounds=sn_bounds)
    

import warnings
from sklearn.exceptions import ConvergenceWarning

def _fitted_lml(gp):
    """Log-marginal likelihood of a fitted regressor at its fitted hyperparameters."""
    try:
        return gp.log_marginal_likelihood_value_
    except AttributeError:
        return gp.log_marginal_likelihood(gp.kernel_.theta)


def fit_gp():
    """Fit a GaussianProcessRegressor to the clicked training points.

    The kernel comes from the sliders (make_kernel). With "Optimize
    hyperparameters" checked, sklearn's L-BFGS-B optimizer is used with 12
    restarts; otherwise the slider values are used unchanged. A status line
    is always written: the LML, whether the parameters moved, any
    ConvergenceWarnings, the fitted kernel, or the error that occurred.

    Returns
    -------
    The fitted GaussianProcessRegressor, or None when there is no data or
    "Draw from posterior" is unchecked.

    Raises
    ------
    Whatever sklearn raises if a fit without optimization fails. The error is
    shown in the status line first.
    """
    if X_train.shape[0] == 0 or not posterior.value:
        set_status("No posterior fit (need data and 'Draw from posterior' checked).")
        return None

    init_kernel = make_kernel()
    init_theta = init_kernel.theta.copy()  # (log-transformed) hyperparameters before fitting

    # Configure optimizer
    want_opt = bool(optimize.value)
    gp = GaussianProcessRegressor(
        kernel=init_kernel,
        optimizer=("fmin_l_bfgs_b" if want_opt else None),
        n_restarts_optimizer=(12 if want_opt else 0),
        normalize_y=False,
    )

    try:
        with warnings.catch_warnings(record=True) as caught:
            warnings.simplefilter("always", ConvergenceWarning)
            gp.fit(X_train, y_train)
    except Exception as e:
        stage = "Optimization" if want_opt else "Fit"
        err = f"⚠️ {stage} failed: {type(e).__name__}: {e}."
        if not want_opt:
            # This fit already used the slider values; an identical retry
            # would fail the same way, so report and propagate.
            set_status(err)
            raise
        set_status(f"{err} Falling back to current sliders.")
        gp_fb = GaussianProcessRegressor(kernel=init_kernel, optimizer=None, normalize_y=False)
        gp_fb.fit(X_train, y_train)
        # Keep the error in the final message; it used to be overwritten
        # immediately by the fallback summary.
        set_status(f"{err} Fallback fit LML={_fitted_lml(gp_fb):.3f}.  Kernel: {gp_fb.kernel_}")
        return gp_fb

    # Did the optimizer actually change the hyperparameters?
    moved = (gp.kernel_.theta.shape == init_theta.shape) and \
            (not np.allclose(gp.kernel_.theta, init_theta))
    lml = _fitted_lml(gp)
    warn_msgs = [str(w.message) for w in caught if issubclass(w.category, ConvergenceWarning)]

    if want_opt:
        status = f"Optimized LML={lml:.3f}; " + ("params updated." if moved else "no change.")
    else:
        status = f"Fitted (no optimization) LML={lml:.3f}."
    if warn_msgs:
        status += " Warnings: " + " | ".join(warn_msgs)
    status += f"  Kernel: {gp.kernel_}"
    set_status(status)
    return gp

def _base_kernel(k):
    """Return the latent (noise-free) part of a ``base + WhiteKernel`` kernel.

    make_kernel() builds ``k_base + WhiteKernel(...)``, a composite with
    ``k1``/``k2`` operands, and a fitted ``gp.kernel_`` has the same structure.
    Whichever operand is the WhiteKernel is dropped and the other is returned.
    Any kernel without a WhiteKernel operand is returned unchanged.
    """
    # WhiteKernel is already imported at module level; no re-import or
    # catch-all exception handler is needed.
    left = getattr(k, "k1", None)
    right = getattr(k, "k2", None)
    if isinstance(right, WhiteKernel):
        return left
    if isinstance(left, WhiteKernel):
        return right
    return k

def _latent_cov_on_grid(kernel):
    """Prior covariance of the latent function on the ``Xtest`` grid.

    The WhiteKernel (observation-noise) term is stripped via ``_base_kernel``
    before evaluating. Prior draws are therefore latent functions, without
    per-point noise on the diagonal. Only the needed covariance matrix is
    computed; the previous version first evaluated the full kernel on the
    grid and then usually discarded it.
    """
    return _base_kernel(kernel)(Xtest)

def posterior_latent_stats(gp, Xstar):
    """Latent (noise-free) posterior mean and covariance of a fitted GP at ``Xstar``.

    Reuses the fitted quantities of sklearn's GaussianProcessRegressor:
    ``gp.L_`` (lower Cholesky factor of the training kernel matrix) and
    ``gp.alpha_``. The WhiteKernel term is removed from the cross- and test
    covariances (``_base_kernel``), so the result describes the latent
    function rather than noisy observations:

        mu  = K_base(X*, X) @ alpha
        cov = K_base(X*, X*) - v.T @ v,   with  v = L^{-1} K_base(X, X*)

    Assumes ``normalize_y=False`` (as configured in fit_gp). With
    normalization, the mean and covariance would need rescaling.

    Returns
    -------
    mu : 1-D array with one entry per row of Xstar
    cov : square array, explicitly symmetrized to remove round-off asymmetry
    """
    K_base = _base_kernel(gp.kernel_)
    # Use the inputs the model was fitted on, so they always match L_ and
    # alpha_, rather than the module-level X_train, which can change.
    X_fit = gp.X_train_
    K_trans = K_base(X_fit, Xstar)
    K_ss = K_base(Xstar)
    v = solve_triangular(gp.L_, K_trans, lower=True)
    mu = (K_trans.T @ gp.alpha_).ravel()
    cov = K_ss - v.T @ v
    cov = (cov + cov.T) * 0.5
    return mu, cov


# -------------------- Update logic --------------------
def _set_mean_and_bands(xs_list, mu, std):
    """Fill the mean line and the 68%/95% band sources, or clear them when hidden."""
    if not show_mean.value:
        src_mean.data = dict(x=[], y=[])
        src_band68.data = dict(x=[], y1=[], y2=[])
        src_band95.data = dict(x=[], y1=[], y2=[])
        return
    src_mean.data = dict(x=xs_list, y=mu.tolist())
    # 68% band (±1σ)
    src_band68.data = dict(x=xs_list,
                           y1=(mu - 1.0 * std).tolist(),
                           y2=(mu + 1.0 * std).tolist())
    # 95% band (±1.96σ)
    src_band95.data = dict(x=xs_list,
                           y1=(mu - 1.96 * std).tolist(),
                           y2=(mu + 1.96 * std).tolist())


def _draw_paths(mu, cov, n):
    """Return *n* draws from N(mu, cov) as plain lists for Bokeh.

    If the sampler raises LinAlgError, retry once with 1e-10 added to the diagonal.
    """
    try:
        paths = rng.multivariate_normal(mean=mu, cov=cov, size=n)
    except np.linalg.LinAlgError:
        paths = rng.multivariate_normal(mean=mu, cov=cov + 1e-10 * np.eye(len(mu)), size=n)
    return [path.tolist() for path in paths]


def update_plot():
    """Redraw training points, mean/credible bands, and sample paths.

    With data present and "Draw from posterior" checked, the GP is fitted
    (fit_gp) and its latent posterior is shown. Otherwise the zero-mean latent
    prior of the current kernel is shown, with the WhiteKernel noise removed.
    Finally, "Ready." is shown if the status line is empty.
    """
    src_pts.data = dict(x=X_train.ravel().tolist(), y=y_train.tolist())
    xs_list = Xtest.ravel().tolist()

    if posterior.value and X_train.shape[0] > 0:
        gp = fit_gp()
        mu, cov = posterior_latent_stats(gp, Xtest)
    else:
        # Prior (latent)
        mu = np.zeros(n_test)
        cov = _latent_cov_on_grid(make_kernel())

    # Clip tiny negative variances caused by round-off before taking sqrt.
    std = np.sqrt(np.clip(np.diag(cov), 0.0, np.inf))
    _set_mean_and_bands(xs_list, mu, std)

    n = draw_n.value
    empty_samples = dict(xs=[], ys=[], nonce=[])
    if n > 0:
        new_data = dict(
            xs=[xs_list] * n,
            ys=_draw_paths(mu, cov, n),
            nonce=[sample_nonce] * n,
        )
        # HARD REFRESH: empty first, then assign new
        src_samples.data = empty_samples
        src_samples.data = new_data
    else:
        src_samples.data = empty_samples

    set_status_if_empty("Ready.")

# -------------------- Event handlers --------------------
def on_tap(event):
    """Tap callback: add the clicked point as training data and redraw.

    Clicks outside the [Xmin, Xmax] × [Ymin, Ymax] window (reachable after
    panning or zooming) are ignored. Points are kept sorted by x.
    """
    global X_train, y_train
    x, y = float(event.x), float(event.y)
    inside = Xmin <= x <= Xmax and Ymin <= y <= Ymax
    if not inside:
        return
    X_grown = np.vstack([X_train, [[x]]])
    y_grown = np.concatenate([y_train, [y]])
    order = np.argsort(X_grown.ravel())
    X_train = X_grown[order]
    y_train = y_grown[order]
    update_plot()

def on_any_param_change(event=None):
    """Generic watcher callback: ignore the event payload and redraw."""
    update_plot()

def on_draw_click(event=None):
    """"Draw samples" button: clear the current paths, start a new RNG stream, redraw.

    reseed_rng() already creates a new generator and bumps ``sample_nonce``,
    so this handler does not touch those globals itself. (It previously
    declared an unused ``global rng`` and incremented the nonce a second time.)
    """
    # Same column schema as everywhere else (xs/ys/nonce, no stray 'color').
    src_samples.data = dict(xs=[], ys=[], nonce=[])
    reseed_rng()
    update_plot()

def on_clear_click(event=None):
    """"Clear points" button: drop all training data, reset the plot sources, redraw.

    With no data left, update_plot() shows the prior. The "Cleared." status
    stays visible because update_plot() only writes "Ready." into an empty
    status line.
    """
    global X_train, y_train
    X_train = np.empty((0, 1))
    y_train = np.empty((0,))
    src_pts.data = dict(x=[], y=[])
    src_mean.data = dict(x=[], y=[])
    src_band68.data = dict(x=[], y1=[], y2=[])
    src_band95.data = dict(x=[], y1=[], y2=[])
    # Same column schema as everywhere else (xs/ys/nonce, no stray 'color').
    src_samples.data = dict(xs=[], ys=[], nonce=[])
    set_status("Cleared.")
    update_plot()

def on_reseed_click(event=None):
    """"Reseed RNG" button: switch to a fresh random stream and redraw.

    reseed_rng() rebinds the module-level generator itself, so no ``global``
    declaration is needed here.
    """
    reseed_rng()
    update_plot()


# -------------------- Wiring --------------------
# A click (Tap event) on the figure calls on_tap, which adds a training point.
p.on_event(Tap, on_tap)

def _toggle_kernel_specific(event=None):
    """Show only the sliders that apply to the selected kernel, then redraw.

    α is shown only for RationalQuadratic and Period only for
    ExpSineSquared. The event argument is accepted so this can be used as
    a watcher callback, but it is ignored.
    """
    chosen = kernel_sel.value
    for slider, owner in ((alpha_rq, "RationalQuadratic"),
                          (periodicity, "Periodic (ExpSineSquared)")):
        slider.visible = (chosen == owner)
    update_plot()

# Redraw whenever a hyperparameter or display option changes. kernel_sel is
# wired separately below because it must also toggle slider visibility.
for w in [ell, sigma_f, sigma_n, alpha_rq, periodicity,
          posterior, show_mean, draw_n, optimize]:
    w.param.watch(lambda e: on_any_param_change(), "value")

# Kernel choice: show/hide the α (RQ) and Period sliders, then redraw.
# Previously this toggle ran only once at startup, so those sliders never
# appeared after switching kernels. _toggle_kernel_specific() calls
# update_plot() itself, so each change triggers exactly one redraw.
kernel_sel.param.watch(_toggle_kernel_specific, "value")

# Buttons: watch the "clicks" counter (robust button wiring in notebooks).
btn_draw.param.watch(lambda e: on_draw_click(), "clicks")
btn_clear.param.watch(lambda e: on_clear_click(), "clicks")
btn_reseed.param.watch(lambda e: on_reseed_click(), "clicks")

# Initial state: slider visibility for the default kernel plus the first draw.
# _toggle_kernel_specific() already calls update_plot(), so no second call is needed.
_toggle_kernel_specific()

# -------------------- Layout --------------------
# Row 1: kernel choice and the main hyperparameters.
# Row 2: kernel-specific sliders, number of samples, optimizer toggle.
# Row 3: prior/posterior and display toggles, then the action buttons.
controls1 = pn.Row(kernel_sel, ell, sigma_f, sigma_n)
controls2 = pn.Row(alpha_rq, periodicity, draw_n, optimize)
posterior_row = pn.Column(posterior)
toggles = pn.Row(posterior_row, show_mean, btn_draw, btn_clear, btn_reseed)

header = "### Gaussian Process Regression (Panel + Bokeh)"
hint = pn.pane.HTML("<i>Click in the plot to add training points. Adjust kernel & hyperparameters.</i>")
app = pn.Column(header, hint, controls1, controls2, toggles, p, status_pane)

# Last expression in the cell: the notebook renders the app.
app