XAI: Variance-Based Analysis

MVA - APM_5MV75_TP

matfontaine.github.io/APM_5MV75_TP

Mathieu FONTAINE
mathieu.fontaine@telecom-paris.fr

January 2025

Outline

I - What Does “Feature Importance” Really Mean?

II - Measuring Importance Through Variance

III - Sobol Decomposition: Theory and Interpretation [Sob. 01]

IV - Local Feature Importance via Perturbations [Fel. 21]

V - Sobol Indices Estimation

VI - What Sobol Explains (and What It Does Not)

VII - Take-Home Messages

  • Sobol, I. M. Global sensitivity indices for nonlinear mathematical models and their Monte Carlo estimates. Mathematics and computers in simulation, 2001
  • Fel, Thomas, et al. "Look at the variance! efficient black-box explanations with sobol-based sensitivity analysis." NeurIPS, 2021

I - What Does “Feature Importance” Really Mean?

Not really mathematical…

  • Consider a model $f:\mathbb{R}^d \to \mathbb{R}$ (e.g. a classification score).
  • Intuitive definition:

A feature is important if changing it changes the prediction.


Changing it how?

  • Gradient-based importance (infinitesimal changes)
  • Occlusion / masking (finite removal)
  • SHAP (coalitions / marginal contributions)
  • Sobol (variance decomposition + interactions)

Example 1: Linear model

  • Consider $f(x_1,x_2)=3x_1+0.1x_2$.
  • Changing $x_1$ affects the output much more than changing $x_2$.
  • In this special case: importance $\approx$ coefficient magnitude.
  • This is a global notion (same for all inputs).

Example 2: Gradient (local)

  • Consider $f(x_1,x_2)=x_1^2+x_2$.
  • $\nabla f(x) = (2x_1,\,1)$.
  • At $x=(0,1)$: the sensitivity to $x_1$ is $0$, to $x_2$ is $1$.
  • We have a local notion of importance.
  • It depends on the local metric / scaling.

Example 3: Occlusion (finite change)

Occlusion = remove some input information and compare the prediction before/after.

  • Image: hide a patch
  • Audio: remove a frequency band / TF region
  • Text: delete a word

Let $f(x_1,x_2)=x_1x_2$ and input $(x_1,x_2)=(1,1)$. Then $f(1,1)=1$. If we remove $x_1$ (set $x_1=0$), then $f(0,1)=0$. If we remove $x_2$, $f(1,0)=0$.

Both are important… individually or only jointly?

Hidden interaction (Sobol motivation)

In $f(x_1,x_2)=x_1x_2$, neither variable has a strong individual effect everywhere: the effect comes from their interaction.

How can we separate what comes from individual effects and what comes from interactions?

Feature importance (in this course)

Feature importance: measure how much the prediction varies when a given feature is allowed to vary, possibly jointly with others.

Sobol: provides an exact and additive decomposition of variance into main effects and interactions (under assumptions).

Variance-based feature importance

  • To define "importance", we must define input variability.
$$\textbf{(1) Random input:}\qquad X=(X_1,\dots,X_d)\sim p_X$$

$$\textbf{(2) Random output:}\qquad Y=f(X)$$

$$\textbf{(3) Global variability:}\qquad \mathrm{Var}(Y)=\mathbb{E}[Y^2]-\mathbb{E}[Y]^2$$

$$\textbf{(4) Variance explained by }X_i:\qquad \mathrm{Var}\!\big(\mathbb{E}[Y\mid X_i]\big)$$
Intuition: if knowing $X_i$ changes the conditional mean prediction a lot, then $X_i$ is important.

II - Measuring Importance Through Variance

Setup: make the output random

  • Deterministic model: $f:\mathbb{R}^d \to \mathbb{R}$
  • Introduce random input: $X=(X_1,\dots,X_d)$
  • Output becomes a random variable:
$$Y=f(X)$$
All importance notions depend on how we choose the distribution of $X$ (or local perturbations later).

Notation

  • For a given index $i$, we denote:
$$X_{\sim i}=(X_1,\dots,X_{i-1},X_{i+1},\dots,X_d)$$
$X_i$ = “feature of interest”, $X_{\sim i}$ = “all other features”.

Variance as global sensitivity

  • Mean and variance:
$$\mathbb{E}[Y]=\int f(x)\,p_X(x)\,dx$$ $$\mathrm{Var}(Y)=\mathbb{E}[Y^2]-\mathbb{E}[Y]^2$$
High variance $\Rightarrow$ sensitive output. Low variance $\Rightarrow$ stable output.

Law of total variance

  • For any $Y$ and any variable $X_i$:
$$\mathrm{Var}(Y)=\mathrm{Var}\big(\mathbb{E}[Y\mid X_i]\big)+\mathbb{E}\big[\mathrm{Var}(Y\mid X_i)\big]$$
This identity is the starting point for variance-based feature importance.

Main effect vs total effect

  • Main-effect variance contribution:
$$D_i \triangleq \mathrm{Var}\big(\mathbb{E}[Y\mid X_i]\big)$$
  • Total-effect variance contribution (includes interactions):
$$D_i^{\mathrm{tot}} \triangleq \mathbb{E}\big[\mathrm{Var}(Y\mid X_{\sim i})\big]$$ $$= \mathrm{Var}(Y)-\mathrm{Var}\big(\mathbb{E}[Y\mid X_{\sim i}]\big)$$
$D_i$: effect of $X_i$ alone.
$D_i^{\mathrm{tot}}$: everything involving $X_i$ (main + interactions).

Normalized indices

$$S_i \triangleq \frac{D_i}{\mathrm{Var}(Y)} \qquad (\text{first-order})$$ $$ST_i \triangleq \frac{D_i^{\mathrm{tot}}}{\mathrm{Var}(Y)} \qquad (\text{total-order})$$
$$0 \le S_i \le ST_i \le 1$$

Toy example A: additive model

  • $X_1,X_2 \sim \mathcal{U}[0,1]$ independent, $Y=X_1+X_2$
$$\mathrm{Var}(Y)=\mathrm{Var}(X_1)+\mathrm{Var}(X_2)=\frac{1}{12}+\frac{1}{12}=\frac{1}{6}$$ $$D_1=\mathrm{Var}(\mathbb{E}[Y|X_1])=\mathrm{Var}(X_1)=\frac{1}{12}$$ $$S_1=\frac{D_1}{\mathrm{Var}(Y)}=\frac{1/12}{1/6}=\frac{1}{2}$$
No interactions: $S_1=ST_1=1/2$, $S_2=ST_2=1/2$.
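A quick Monte Carlo check of this example, as a minimal sketch: the conditional mean $\mathbb{E}[Y\mid X_1]$ is approximated by binning $X_1$ (the sample size and bin count are arbitrary choices for illustration).

```python
# Monte Carlo check of Toy Example A: Y = X1 + X2 with X1, X2 ~ U[0,1].
# E[Y | X1] is estimated by binning X1 and averaging Y within each bin.
import numpy as np

rng = np.random.default_rng(0)
n, n_bins = 200_000, 50
x1 = rng.uniform(size=n)
x2 = rng.uniform(size=n)
y = x1 + x2

var_y = y.var()                                     # should be close to 1/6

# conditional mean E[Y | X1] per bin of X1
bins = np.minimum((x1 * n_bins).astype(int), n_bins - 1)
cond_mean = np.array([y[bins == b].mean() for b in range(n_bins)])
d1 = cond_mean.var()                                # close to Var(X1) = 1/12
s1 = d1 / var_y                                     # first-order index ~ 1/2
```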

Toy example B: pure interaction

  • $X_1,X_2 \sim \mathcal{U}[0,1]$ independent, $Y=(X_1-\frac{1}{2})(X_2 - \frac{1}{2})$
$$\mathbb{E}[Y]=0$$
$$\mathbb{E}[Y^2]=\frac{1}{12}\cdot\frac{1}{12}=\frac{1}{144} = \mathrm{Var}(Y)$$
$\mathbb{E}[Y \mid X_1]=(X_1-\tfrac{1}{2})\,\mathbb{E}[X_2-\tfrac{1}{2}]=0$, and similarly $\mathbb{E}[Y \mid X_2]=0$.
Main effects vanish: $S_1=S_2=0$. Total effects are maximal: $ST_1=ST_2=1$.

III - Sobol Decomposition: Theory and Interpretation [Sob. 01]

Functional ANOVA / Hoeffding decomposition

  • Assume $X_1,\dots,X_d$ independent and $f\in L^2$.
  • Assume the zero-mean (centering) constraints: $\int f_u(x_u)\,d\mathbb{P}_{X_i} = 0$ for all $i \in u$ and all nonempty $u \subseteq \{1, \dots, d\}$.
  • Then $f$ admits a unique decomposition:
$$f(X)=f_0+\sum_i f_i(X_i)+\sum_{i< j} f_{ij}(X_i,X_j)+\cdots+f_{1\cdots d}(X_1,\dots,X_d)$$
“main effects + pairwise interactions + higher-order interactions”.
  • Sobol, I. M. Sensitivity estimates for nonlinear mathematical models. Mathematical Modelling and Computational Experiments 1, 407-414, 1993
  • Van der Vaart, A. W. Asymptotic Statistics. Cambridge University Press, 2012

Centering constraint for variance additivity (1/2)

  • Without constraints, the decomposition is not unique (constants can be moved between terms).
  • The two previous assumptions (i.e. the zero-mean constraints) make the terms identifiable:
$$\mathbb{E}[f_i(X_i)] = 0,\quad \mathbb{E}[f_{ij}(X_i,X_j)\mid X_i]=0,\quad \mathbb{E}[f_{ij}(X_i,X_j)\mid X_j]=0 $$ etc.
Each term contains only the “new effect” not already captured by lower-order components.

Centering constraint for variance additivity (2/2)

  • Under independence + centering constraints, ANOVA components are orthogonal:
$$\mathbb{E}\big[f_u(X_u)\,f_v(X_v)\big]=0 \quad \text{for } u\neq v$$
  • Therefore, variance decomposes exactly:
$$\mathrm{Var}(f(X))=\sum_{u\neq\emptyset}\mathrm{Var}(f_u(X_u))$$
This is the core reason Sobol indices are “clean”: exact variance budget, no heuristics.

Sobol indices for subsets

  • For any subset $u\subseteq\{1,\dots,d\}$:
$$D_u \triangleq \mathrm{Var}(f_u(X_u)),\qquad S_u\triangleq \frac{D_u}{\mathrm{Var}(f(X))}$$
  • Properties:
$$S_u\ge 0,\qquad \sum_{u\neq\emptyset} S_u = 1$$

Alternative expressions (useful)

  • Main-effect function (intuition):
$$f_i(x_i)=\mathbb{E}[f(X)\mid X_i=x_i]-f_0$$
  • Main-effect variance contribution (from Part II):
$$D_i=\mathrm{Var}\big(\mathbb{E}[f(X)\mid X_i]\big)$$
This is the bridge between Part II (conditional variance) and Part III (functional ANOVA).

Total-order index and interactions

  • Total-order index aggregates all terms involving feature $i$:
$$ST_i = \sum_{u\ni i} S_u$$
Interpretation: $ST_i$ measures “everything that uses feature $i$” (directly or via interactions).
  • Quick diagnostic:
$$ST_i \approx S_i \Rightarrow \text{weak interactions}, \qquad ST_i \gg S_i \Rightarrow \text{strong interactions}.$$

IV - Local Feature Importance via Perturbations [Fel. 21]

Local explanation (1/2)

Local variance-based explanation $\texttt{[Fel. 21]}$: perturb patches, evaluate the model, estimate total-order Sobol indices, and visualize as a heatmap.
  • We want to explain a single input $x$ (image / audio / text).
  • Introduce a random mask $M=(M_1,\dots,M_d)$ controlling visibility of regions.
  • Define a perturbation operator $\Phi(x,M)$.

Local explanation (2/2)

  • We want to explain a single input $x$ (image / audio / text).
  • Introduce a random mask $M=(M_1,\dots,M_d)$ controlling visibility of regions.
  • Define a perturbation operator $\Phi(x,M)$.
$$M=(M_1,\dots,M_d)\in[0,1]^d$$ $$\tilde x = \Phi(x,M)$$ $$Y = f(\tilde x)=f(\Phi(x,M))$$
We apply variance-based analysis to the random variable $Y$ induced by random perturbations $M$.

Perturbation operator: baseline inpainting

  • Typical continuous perturbation used in the paper:
$$\Phi(x,M)=x\odot M + (1-M)\odot \mu$$
  • $\mu$ is a baseline / reference input (black image / blur / mean / inpainted content).
Explanations are conditional on the perturbation model $(\Phi,\mu)$.
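A minimal sketch of this operator on a toy 4×4 "image", with a mean-value baseline $\mu$ (the baseline choice and the 2×2 masked patch are assumptions for illustration):

```python
# Baseline-inpainting perturbation Phi(x, M) = x*M + (1-M)*mu on a toy image.
import numpy as np

def phi(x, m, mu):
    """Blend input x with baseline mu according to mask m in [0,1]^d."""
    return x * m + (1.0 - m) * mu

x = np.arange(16, dtype=float).reshape(4, 4)
mu = np.full_like(x, x.mean())                # baseline: mean image
m = np.ones_like(x)
m[1:3, 1:3] = 0.0                             # hide the central 2x2 patch

x_tilde = phi(x, m, mu)                       # masked pixels -> baseline value
```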

From perturbations to a saliency map

  • We sample many masks $M^{(1)},\dots,M^{(N)}$.
  • We build perturbed inputs $\tilde x^{(n)}=\Phi(x,M^{(n)})$.
  • We evaluate the model outputs $Y^{(n)}=f(\tilde x^{(n)})$.
  • We estimate Sobol indices for each mask component $M_i$.
One importance score per region $\Rightarrow$ reshape into a grid $\Rightarrow$ heatmap.

Local Sobol importance (Total-order)

  • For a given region/feature $i$ (mask component $M_i$), we use the total-order Sobol index.
$$ST_i = \frac{ \mathbb{E}_{M_{\sim i}}\left[\mathrm{Var}_{M_i}\left(Y \mid M_{\sim i}\right)\right] }{ \mathrm{Var}(Y) }$$
$ST_i$ captures the contribution of region $i$, including all interactions with other regions.

Why total-order is the default in this paper

  • First-order $S_i$ measures the isolated effect of region $i$.
  • Total-order $ST_i$ measures everything that involves region $i$.
  • If $ST_i \gg S_i$, region $i$ mostly acts through interactions.
For local saliency, interactions matter a lot (edges + context, harmonics + formants, etc.).

Signed importance (optional)

  • Sobol indices are non-negative: they measure magnitude (how much it matters), not direction.
  • A simple signed variant attaches a sign using an occlusion test:
$$\mathrm{sign}_i=\mathrm{sign}\big(f(x)-f(x\setminus i)\big)$$ $$ST_i^{\Delta} = ST_i \cdot \mathrm{sign}_i$$
The sign is baseline-dependent. The magnitude $ST_i$ is the robust part.

V - Sobol Indices Estimation

Why we need estimators

  • Definitions involve conditional expectations / high-dimensional integrals.
  • We want a black-box estimator using only evaluations of $f(\Phi(x,M))$.
  • Key trick: build pairs of perturbations that differ in one component.
We will estimate the conditional variance term inside $ST_i$.

Saltelli sampling: A, B, and mixed C

  • Sample two matrices of masks: $A,B\in[0,1]^{N\times d}$.
  • For each feature $i$, build $C^{(i)}$ by replacing column $i$ of $A$ with column $i$ of $B$.
$$C^{(i)} = (A_1,\dots,A_{i-1},B_i,A_{i+1},\dots,A_d)$$
Evaluate $f(\Phi(x,\cdot))$ on $A$, $B$, and each $C^{(i)}$.
  • Saltelli, Andrea, et al. "Variance based sensitivity analysis of model output. Design and estimator for the total sensitivity index." Computer physics communications 181.2, 2010

Jansen estimator: total-order

  • Estimate the output variance (using $A$):
$$\hat V = \frac{1}{N}\sum_{j=1}^N f(\Phi(x,A_j))^2 - \left(\frac{1}{N}\sum_{j=1}^N f(\Phi(x,A_j))\right)^2$$
  • Total-order estimator:
$$\widehat{ST}_i=\frac{1}{2N\hat V}\sum_{j=1}^N \Big(f(\Phi(x,A_j))-f(\Phi(x,C^{(i)}_j))\Big)^2$$
  • Jansen, M. Analysis of variance designs for model output. Computer Physics Communications, 1999

Jansen estimator: first-order

  • First-order estimator:
$$\widehat{S}_i = 1-\frac{1}{2N\hat V}\sum_{j=1}^N \Big(f(\Phi(x,B_j))-f(\Phi(x,C^{(i)}_j))\Big)^2$$
In XAI, we typically visualize $\widehat{ST}_i$ as the main saliency map.
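The full Saltelli/Jansen pipeline can be sketched end-to-end. On the pure-interaction toy model of Part II, $Y=(X_1-\tfrac12)(X_2-\tfrac12)$, it should return $S_i\approx 0$ and $ST_i\approx 1$ (sample size is an arbitrary choice here):

```python
# Saltelli sampling (A, B, mixed C^(i)) + Jansen estimators, tested on the
# pure-interaction toy model Y = (X1 - 1/2)(X2 - 1/2).
import numpy as np

def jansen_indices(f, n, d, rng):
    A = rng.uniform(size=(n, d))
    B = rng.uniform(size=(n, d))
    fA, fB = f(A), f(B)
    var = fA.var()
    S, ST = np.empty(d), np.empty(d)
    for i in range(d):
        C = A.copy()
        C[:, i] = B[:, i]                         # replace column i of A by B's
        fC = f(C)
        ST[i] = ((fA - fC) ** 2).mean() / (2 * var)      # total-order (Jansen)
        S[i] = 1 - ((fB - fC) ** 2).mean() / (2 * var)   # first-order (Jansen)
    return S, ST

f = lambda X: (X[:, 0] - 0.5) * (X[:, 1] - 0.5)
S, ST = jansen_indices(f, n=100_000, d=2, rng=np.random.default_rng(0))
```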

Computational cost

  • Evaluations required: $N(d+2)$ forward passes.
  • Typical setting in the paper: mask grid $11\times 11$ $\Rightarrow d=121$ and $N=32$:
$$32\times(121+2)=3936 \text{ forward passes}$$
Still feasible with batching on GPU.

Quasi-Monte Carlo (Sobol sequences)

  • Instead of i.i.d. Monte Carlo masks, use a low-discrepancy sequence.
  • Better coverage of $[0,1]^d$ $\Rightarrow$ lower estimator variance in practice.
  • Gerber, M. On integration methods based on scrambled nets of arbitrary size. Journal of Complexity, 2015
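To illustrate low-discrepancy sampling without any library dependency, here is a from-scratch Halton sequence. Note this is a simpler construction than the scrambled Sobol sequences used in the paper; substituting Halton is this sketch's assumption.

```python
# Halton low-discrepancy sequence in [0,1]^dim (a stand-in for Sobol sequences).
def van_der_corput(n, base):
    """n-th element of the van der Corput sequence in the given base."""
    q, inv_base = 0.0, 1.0 / base
    while n > 0:
        n, r = divmod(n, base)
        q += r * inv_base
        inv_base /= base
    return q

def halton(n_points, dim, bases=(2, 3, 5, 7, 11, 13)):
    """First n_points of the Halton sequence (supports dim <= len(bases))."""
    return [[van_der_corput(i, bases[j]) for j in range(dim)]
            for i in range(1, n_points + 1)]

pts = halton(8, 2)          # evenly covers the unit square
```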

White-box vs Black-box explanations

In XAI, the evaluation protocol often depends on whether we have access to the model internals.
  • White-box methods use gradients / internals:
    • Examples: Saliency maps (∇), Integrated Gradients, Grad-CAM, LRP
    • Require access to model structure + backpropagation

  • Black-box methods use only model outputs:
    • Examples: Occlusion, RISE, SHAP, Sobol perturbation-based saliency
    • Only require forward passes: $f(x)$ queries
Black-box explanations are generally more expensive, but more model-agnostic.

Evaluation #1: Pointing Game

  • Goal: check whether the saliency map points to the object of interest.
  • Requires: ground-truth object location (bounding box or segmentation mask).
  • Protocol:
    • Compute a saliency map $S(x)\in\mathbb{R}^{H\times W}$
    • Find the most salient pixel / region:
    • $$p^\star = \arg\max_{p} S_p(x)$$
    • Hit if $p^\star$ lies inside the ground-truth object mask / box
Score = fraction of images where the explanation “points” inside the target object.
Simple and intuitive, but only tests the single most salient point.


Evaluation #2: Deletion score

  • Goal: measure how “faithful” the saliency map is to the model decision.
  • Idea: remove the most important pixels first and track the output drop.
  • Protocol:
    • Rank pixels/regions by importance (descending): $p_1,p_2,\dots$
    • Create a sequence of perturbed inputs $x^{(k)}$ by deleting top-$k$ pixels
    • Track the score (e.g., class logit/probability) along the deletion path:
    • $$s_k = f_c\!\left(x^{(k)}\right)$$
A good explanation yields a fast decrease of the target class score under deletion.
  • Typical metric: Area Under the Curve (AUC) of $s_k$ vs. deleted fraction
  • Lower AUC = more faithful (faster score collapse)
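The protocol above can be sketched on a toy linear model; the per-feature ranking, the zero baseline, and all numbers are illustrative assumptions.

```python
# Deletion-score sketch on a toy linear model f(x) = w.x: delete the
# top-ranked features first and track the score along the deletion path.
import numpy as np

w = np.array([3.0, 2.0, 0.5, 0.1])
x = np.ones(4)
saliency = np.abs(w * x)                      # simple per-feature importance
order = np.argsort(-saliency)                 # most important first

scores = [float(w @ x)]
x_del = x.copy()
for idx in order:
    x_del[idx] = 0.0                          # "delete" feature idx
    scores.append(float(w @ x_del))

# trapezoidal AUC of the deletion curve over the deleted fraction in [0, 1]
auc = sum((scores[i] + scores[i + 1]) / 2
          for i in range(len(scores) - 1)) / (len(scores) - 1)
```

With a faithful ranking the curve collapses quickly, so the AUC is small; a random ranking would give a larger AUC on the same model.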


Fel et al. (2021): what to remember

  • Local Sobol saliency maps are obtained by applying Sobol indices to a perturbation model.
  • Total-order maps naturally capture interactions between regions.
  • QMC sampling improves stability / convergence for a fixed evaluation budget.
  • Provides a principled “variance budget” view of explanations.

VI - What Sobol Explains (and What It Does Not)

Interpretation: sensitivity, not causality

Sobol-based explanations quantify sensitivity under a user-defined perturbation model.
They do not provide causal explanations.
  • Depends on $p(M)$ (mask distribution) and $\Phi(x,M)$ (perturbation operator).
  • Depends on baseline $\mu$.

Limitations / pitfalls

  • Off-manifold: perturbations may create unrealistic inputs.
  • Feature dependence: classical Sobol assumes independent variables (true for masks; not always for raw features).
  • Granularity: pixel vs superpixel vs patch changes the explanation.
  • Faithfulness metrics: deletion/insertion depend on the perturbation protocol.
Sobol is principled, but the explanation is conditional on modeling choices.

Audio / Signal Segment (Why Sobol is natural here)

Time-frequency features

  • Mixture: $x(t)=s(t)+n(t)$
  • STFT: $X(\omega,\tau)$
  • Mask TF regions:
$$\widetilde{X}(\omega,\tau)=M(\omega,\tau)\odot X(\omega,\tau)$$ $$Y = f(\widetilde{X})$$
Feature can be a TF-bin, a TF-patch, or a frequency band.

Why interactions matter in audio

  • Speech/music structure:
    • harmonics (correlated frequencies)
    • formants (bands)
    • time coherence (onsets/transients)
  • Many features are informative only jointly.
Total-order Sobol indices capture these spectro-temporal interactions.

Spectrogram in audio

The word "friend" seems important (maybe the phoneme?), as do some specific frequencies.

VII - Take-Home Messages

Take-Home Messages

  1. Feature importance must be defined: it depends on “how you vary the input”.
  2. Variance is a principled sensitivity measure (scalar, stable, decomposable).
  3. Sobol = functional ANOVA: exact decomposition into main effects + interactions (under assumptions).
  4. Total-order indices are key in interaction-heavy models.
  5. Local XAI: apply Sobol to $g(M)=f(\Phi(x,M))$ using perturbation masks.

References

  • Sobol, I. M. Global sensitivity indices for nonlinear mathematical models and their Monte Carlo estimates. Mathematics and computers in simulation, 2001
  • Fel, Thomas, et al. "Look at the variance! efficient black-box explanations with sobol-based sensitivity analysis." NeurIPS, 2021
  • Jansen, M. Analysis of variance designs for model output. Computer Physics Communications, 1999
  • Gerber, M. On integration methods based on scrambled nets of arbitrary size. Journal of Complexity, 2015

XAI: Counterfactual explanations

MVA - APM_5MV75_TP

matfontaine.github.io/APM_5MV75_TP

Mathieu FONTAINE
mathieu.fontaine@telecom-paris.fr

January 2025

Outline

I — Why counterfactuals?

II — Formal definition

III — Single counterfactual = optimization view

IV — DiCE: multiple diverse counterfactuals

V — Distances, constraints, mixed data

VI — Evaluation + limitations

VII — App: Counterfactuals for medical image

VIII — App: Counterfactuals in text-audio

I — Why Counterfactual Explanations?

Explanation as “minimal change”

  • We have an instance $x$ and a model prediction $y=f(x)$.
  • We want a new point $c$ such that the decision becomes $y^\star$.
Counterfactual = “a feasible input close to $x$ that flips the outcome”.

Why one counterfactual is not enough

  • Many different changes can flip the decision.
  • We want multiple options to provide flexibility.
For a given $x$, the set $\{c: f(c)=y^\star\}$ can be large and disconnected.

II — Formal Definition

Model + target outcome

  • Let $f:\mathbb{R}^D \to \mathcal{Y}$ be a black-box predictor.
  • Given an input $x\in\mathbb{R}^D$, we want $c\in\mathbb{R}^D$ such that:
$$f(c)=y^\star$$
In practice: $f$ may output a probability/logit; we enforce $f(c)\ge \tau$ for some target threshold.

Core desiderata

  • Validity: $c$ achieves the target outcome.
  • Proximity: $c$ is close to $x$ (small change).
  • Feasibility / actionability: respect constraints (immutable features).
  • Interpretability: sparse changes preferred (few features).

III — Single Counterfactual: Optimization View

Constrained formulation

  • Minimal change subject to achieving the target:
$$\min_{c\in\mathcal{C}} \; d(c,x) \quad \text{s.t.}\quad f(c)=y^\star$$
$\mathcal{C}$ encodes feasibility: immutable features, box constraints, categorical constraints.

Relaxed (penalized) formulation

  • Convert constraint to a loss term:
$$\min_{c\in\mathcal{C}}\; \lambda \,\ell\big(f(c),y^\star\big) + d(c,x)$$
$\ell$ can be hinge-like on the target class probability (or cross-entropy on $y^\star$).
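This penalized objective can be minimized by plain gradient descent in input space. Below is a minimal sketch on a toy logistic model (the model, weights $\lambda$, and step size are illustrative assumptions, not DiCE itself):

```python
# Penalized counterfactual search: minimize lam * (-log f(c)) + ||c - x||^2
# by gradient descent, for a toy logistic model f(x) = sigmoid(w.x + b).
import numpy as np

w, b = np.array([2.0, -1.0]), -0.5
sigmoid = lambda z: 1.0 / (1.0 + np.exp(-z))
f = lambda x: sigmoid(w @ x + b)              # P(y = 1 | x)

x = np.array([-1.0, 1.0])                     # f(x) ~ 0.03: predicted class 0
c, lam, lr = x.copy(), 5.0, 0.1
for _ in range(500):
    p = f(c)
    # gradient of  lam * (-log p) + ||c - x||^2  with respect to c
    grad = lam * (p - 1.0) * w + 2.0 * (c - x)
    c -= lr * grad
# c is now a nearby point classified as the target class y* = 1
```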

Discrete / categorical constraints

  • Some features are categorical (e.g. {CDI, CDD, FreeLancer}, one-hot encoded), ordinal ({Bac < Bachelor < Master < PhD}), or continuous (salary, etc.).
  • Practical solutions:
    • continuous relaxation (categories become real-valued) + projection back onto the category space
    • heuristic search (beam search, simulated annealing, etc.)
    • optimization in a continuous latent space of a generator

IV — DiCE: Diverse Counterfactual Explanations

K counterfactuals

  • Find $\mathcal{C}=\{c_1,\dots,c_K\}$ such that each $c_k$ is valid and close.
  • Additionally: counterfactuals should be diverse.
$$f(c_k)=y^\star,\quad k=1,\dots,K$$
Diversity gives multiple “paths” to reach the desired outcome.

DiCE objective in a nutshell

  • Optimize all CFs jointly:
$$\min_{c_1,\dots,c_K\in\mathcal{C}} \;\; \sum_{k=1}^K \Big[ \lambda\,\ell(f(c_k),y^\star) + d(c_k,x)\Big] \;-\; \eta \cdot \mathrm{Diversity}(c_1,\dots,c_K)$$

DPP-style diversity (determinant)

  • Define a similarity kernel between counterfactuals, e.g. RBF:
$$K_{ij}=\exp\!\left(-\frac{\|c_i-c_j\|_2^2}{\sigma^2}\right)$$
  • Diversity proxy:
$$\mathrm{Diversity}(c_1,\dots,c_K)=\det(K)$$
$\det(K)$ is large when points are dissimilar (spread out).
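A small numeric check of this behavior (points and bandwidth are arbitrary toy values): spread-out counterfactuals give a determinant near 1, near-duplicates give a determinant near 0.

```python
# DPP-style diversity: det of the RBF similarity kernel between counterfactuals.
import numpy as np

def diversity(cs, sigma=1.0):
    """det(K) with K_ij = exp(-||c_i - c_j||^2 / sigma^2)."""
    c = np.asarray(cs, dtype=float)
    sq = ((c[:, None, :] - c[None, :, :]) ** 2).sum(-1)
    return np.linalg.det(np.exp(-sq / sigma ** 2))

spread = diversity([[0.0, 0.0], [3.0, 0.0], [0.0, 3.0]])    # dissimilar CFs
bunched = diversity([[0.0, 0.0], [0.1, 0.0], [0.0, 0.1]])   # near-duplicates
```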

Practical note: gradients

  • If $f$ is differentiable (NN), optimize by gradient descent in input space.
  • If $f$ is non-differentiable, DiCE can use model-agnostic search / sampling.
Differentiable optimization is fast but can produce off-manifold counterfactuals.

V — Distances, Constraints, Mixed Data

Continuous features: robust scaling (MAD)

  • For continuous features, normalize by median absolute deviation (MAD):
$$d_{\text{cont}}(c,x)=\sum_{j\in \mathcal{J}_{\text{cont}}}\frac{|c_j-x_j|}{\mathrm{MAD}_j}$$
MAD is more robust than standard deviation (less sensitive to outliers).

Categorical features: mismatch penalty

  • For one-hot encoded categorical features:
$$d_{\text{cat}}(c,x)=\sum_{j\in \mathcal{J}_{\text{cat}}} \mathbf{1}[c_j \neq x_j]$$
Enforce valid one-hot vectors via projection (argmax) or constraints.
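The two distance terms combine into a single mixed-data distance. A minimal sketch follows; the toy dataset, the feature split (column 0 continuous, column 1 a categorical code), and the function names are illustrative assumptions.

```python
# Mixed-data counterfactual distance: MAD-scaled L1 on continuous features
# plus a mismatch count on categorical ones.
import numpy as np

def mad(v):
    """Median absolute deviation of a 1-D sample."""
    return np.median(np.abs(v - np.median(v)))

def cf_distance(c, x, cont_idx, cat_idx, mads):
    d_cont = sum(abs(c[j] - x[j]) / mads[j] for j in cont_idx)
    d_cat = sum(c[j] != x[j] for j in cat_idx)
    return d_cont + d_cat

# toy data: column 0 = salary (continuous), column 1 = contract type (code)
data = np.array([[30.0, 0], [40.0, 1], [50.0, 0], [35.0, 2]])
mads = {0: mad(data[:, 0])}
x, c = data[0], np.array([45.0, 1])
dist = cf_distance(c, x, cont_idx=[0], cat_idx=[1], mads=mads)
```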

Immutable / actionable constraints

  • Partition features into:
    • immutable: cannot change (e.g., age, past defaults)
    • actionable: can change (e.g., savings, debt)
    • conditionally actionable: changes allowed only in one direction
$$c_j = x_j \;\;\text{for } j\in\mathcal{J}_{\text{immut}}$$

Sparsity

  • Interpretability often requires changing few features.
  • Proxy: encourage sparse deltas $\Delta = c-x$ via $\ell_1$ penalty:
$$d(c,x) \;+\; \gamma \|c-x\|_1$$
Alternative: post-process CFs to keep only the largest changes.

VI — Evaluation and Limitations

Standard evaluation metrics

  • Validity: fraction of CFs achieving the target.
  • Proximity: average distance $d(c_k,x)$.
  • Diversity: determinant-based or pairwise distance statistics.
  • Sparsity: average number of changed features.

Local decision boundary approximation (intuition)

  • Counterfactuals approximate the local boundary around $x$.
  • One can fit a simple surrogate near $x$ (e.g., 1-NN on generated samples).
Idea: evaluate how well CFs “cover” different boundary directions locally.

Main limitations

  • Off-manifold: CFs can be unrealistic (not resembling real data).
  • Feature dependence: changing one feature may require changing others.
  • Causality: actionable recourse should respect causal relations.
  • Multiple solutions: must balance proximity vs diversity vs sparsity.
Counterfactuals are “action suggestions” only under a chosen feasibility model.

Take-Home Messages

Take-Home Messages

  1. Counterfactuals explain by giving changes that flip the outcome.
  2. Mathematically: solve a constrained optimization (validity + proximity).
  3. DiCE: generate K CFs and add a diversity term (determinant / DPP-style).
  4. Mixed data + constraints are central (actionability, immutability, categorical).
  5. CFs are only meaningful under an explicit feasibility model.

References

  • DiCE: Mothilal, Sharma, Tan. “Explaining ML Classifiers through Diverse Counterfactual Explanations.” FAT* 2020.
  • (Optional background) Wachter et al. “Counterfactual explanations without opening the black box.” 2017.

Application (Images) — Counterfactuals via Diffusion Autoencoder [Atad et al., 2024]

Problem setting (medical imaging)

  • We have an image classifier / regressor: what minimal change flips the decision?
  • We want counterfactual image $x_{cf}$ such that:
$$\text{(flip)}\quad f(x_{cf}) \neq f(x) \qquad \text{and}\qquad \text{(realism)}\quad x_{cf}\sim p_{\text{data}}$$
Pixel-space optimization often produces adversarial-like artifacts. Generative latent spaces help keep edits on-manifold.
  • Matan Atad et al., Counterfactual Explanations for Medical Image Classification and Regression using Diffusion Autoencoder, JMLBI, 2024

Diffusion Autoencoder (DAE) as unsupervised feature extractor

  • Train a DAE on unlabeled images → semantic latent space $z_{\mathrm{sem}}$.
  • DAE learns a compressed representation where interpolations are meaningful.
$$x \xrightarrow{\text{encoder}} z_{\mathrm{sem}} \in \mathbb{R}^m \qquad\qquad \hat x = \text{decode}(z_{\mathrm{sem}}, \text{(stochastic latent)})$$
Unsupervised training → same latent space can support multiple downstream tasks (classification + ordinal regression).

Linear decision boundary in latent space

  • Encode labeled samples: $w = z_{\mathrm{sem}}(x)$
  • Fit a linear classifier (SVM / linear layer): hyperplane
$$\mathcal{P}:\quad n^\top w + b = 0$$ $$\text{signed distance:}\quad \mathrm{dist}(w,\mathcal{P})=\frac{n^\top w + b}{\|n\|}$$
  • $n$: semantic direction corresponding to the presence of the pathology
  • $b$: bias term
The vector $n$ becomes a semantic direction for the pathology.

Binary counterfactual = reflect across hyperplane

  • Given a latent code $w$, the paper proposes a closed-form CE in latent space.
$$w_{cf} = w - 2\cdot \mathrm{dist}(w,\mathcal{P})\;\frac{n}{\|n\|}$$
  • Then decode: $$x_{cf} = \mathrm{decode}(w_{cf})$$
CE is an on-manifold semantic edit: flip class while staying realistic.
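The reflection formula can be checked numerically in a few lines (the hyperplane and latent code are toy values): the signed distance of the reflected code is exactly the negative of the original, so the linear classifier's decision flips.

```python
# Closed-form latent counterfactual: reflect w across the hyperplane n.w + b = 0.
import numpy as np

n, b = np.array([1.0, 2.0]), -1.0             # toy hyperplane parameters

def signed_dist(w):
    return (n @ w + b) / np.linalg.norm(n)

w = np.array([2.0, 3.0])                      # latent code on the positive side
w_cf = w - 2.0 * signed_dist(w) * n / np.linalg.norm(n)
```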

Ordinal counterfactuals (severity progression)

  • Severity measure: severity is defined as the signed distance to the hyperplane $$ \text{dist}(w,P) = \frac{n^\top w + b}{\|n\|} $$
  • Calibration to ordinal grades: $$ \text{grade}(w) \approx \alpha\,\text{dist}(w,P) + \beta $$
  • Controlled severity traversal: we explicitly parametrize a latent trajectory along the normal direction $$ w(t) = w_0 + t\,\frac{n}{\|n\|} $$
  • Key property: along this path, $$ \text{dist}(w(t),P) = \text{dist}(w_0,P) + t $$ $\Rightarrow$ increasing $t$ increases severity linearly.
  • Counterfactual generation: decoded images form a smooth severity progression $$ x(t) = \text{decode}(w(t)) $$

Results of the generator


Results: accuracy, grade, etc.

Limitations

  • Linearity assumption: a single hyperplane direction may not match complex pathologies.
  • Confounders in latent space: “move to healthy” may introduce spurious anatomical changes (bias).
  • Calibration: mapping distance→grade can be imperfect (especially under label scarcity / imbalance).
Counterfactuals are only as good as (i) the generative model, (ii) the latent geometry, (iii) the downstream classifier.

Application (Audio) — Learning with Counterfactual Captions [Vosoughi et al., 2024]

Main idea

This paper can be chosen for the course evaluation!

  • Ali Vosoughi et al., Learning Audio Concepts from Counterfactual Natural Language, ICASSP, 2024

XAI: Prototypical Networks

MVA - APM_5MV75_TP

matfontaine.github.io/APM_5MV75_TP

Mathieu FONTAINE
mathieu.fontaine@telecom-paris.fr

February 2025

Why prototypes?

  • Standard deep nets rely on abstract latent features.
  • Prototype nets explain by analogy: “this looks like that”.
  • Interpretability is intrinsic (part of the model), not post-hoc.
We build a model where logits are sums of similarities to learned prototypes.
  • Snell et al. Prototypical Networks for Few Shot Learning, NIPS 2017

Outline

I — Motivations

II — Prototypical networks: $\texttt{[Sne. 17]}$

III — ProtoPNet: patch prototypes $\texttt{[Che. 19]}$

IV - Prototypical and Counterfactual

V — Prototypes in Sound Classification $\texttt{[Zin. 21]}$

VI — take-home

  • Snell et al. Prototypical Networks for Few Shot Learning, NIPS 2017
  • Chen et al. This Looks Like That: Deep Learning for Interpretable Image Recognition, NeurIPS 2019
  • Zinemanas et al. An Interpretable Deep Learning Model for Automatic Sound Classification, Electronics, 2021.

II — Prototypical networks: $\texttt{[Sne. 17]}$


Few-shot classification setting

  • We consider an episodic learning setup.
  • Each episode mimics a few-shot task.
Support set: $S = \{(x_i, y_i)\}_{i=1}^N,\quad y_i \in \{1,\dots,K\}$
$S_k = \{x_i : y_i = k\}$

Learning an embedding

  • Inputs are mapped to a representation space:
$f_\phi : \mathbb{R}^D \to \mathbb{R}^M$
  • $f_\phi$ is a neural network (CNN, ResNet, …)
  • All reasoning happens in the embedding space

Class prototypes

  • Each class is represented by a prototype.
$$ c_k = \frac{1}{|S_k|} \sum_{(x_i,y_i)\in S_k} f_\phi(x_i) $$
A prototype is the mean embedding of support examples.

Classification rule

  • Given a query point $x$:
$$ p_\phi(y=k\mid x) = \frac{\exp\big(-d(f_\phi(x), c_k)\big)} {\sum_{k'} \exp\big(-d(f_\phi(x), c_{k'})\big)} $$
  • $d(\cdot,\cdot)$ is a distance in embedding space
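The classification rule is short enough to sketch directly; the support embeddings below are toy values, and the identity is used in place of a learned $f_\phi$ for illustration.

```python
# Minimal ProtoNet classification: prototypes are mean support embeddings;
# query probabilities are a softmax over negative squared distances.
import numpy as np

support = {0: np.array([[0.0, 0.0], [0.2, 0.0]]),
           1: np.array([[3.0, 3.0], [3.0, 2.8]])}
protos = np.stack([v.mean(axis=0) for v in support.values()])  # class means

def predict_proba(q):
    d2 = ((protos - q) ** 2).sum(axis=1)      # squared Euclidean distances
    logits = -d2
    e = np.exp(logits - logits.max())         # numerically stable softmax
    return e / e.sum()

p = predict_proba(np.array([0.5, 0.1]))       # query near class 0
```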

Episodic training objective

  • Minimize negative log-likelihood on queries:
$$ \mathcal{L}(\phi) = - \log p_\phi(y=k \mid x) $$
  • Episodes are sampled during training
  • SGD on $\phi$

Why is the prototype the mean?

  • Assume the distance is a Bregman divergence ($\varphi$ strictly convex on a convex domain).
$$ d_\varphi(z,z') = \varphi(z) - \varphi(z') - (z-z')^\top \nabla\varphi(z') $$
For Bregman divergences, the cluster representative minimizing $\sum_i d(z_i, c)$ is the mean.

Connection to mixture models

  • Exponential family density:
$$ p(z\mid\theta) = \exp\{-d_\varphi(z,\mu(\theta)) - g(z)\} $$
  • One component per class
  • Equal mixture weights

Squared Euclidean distance

$$ d(z,c) = \|z-c\|^2 $$
  • Corresponds to spherical Gaussian clusters
  • Prototype = empirical mean

Equivalence to a linear classifier

$$ -\|f(x)-c_k\|^2 = 2c_k^\top f(x) - \|c_k\|^2 + \text{const} $$
$$ = w_k^\top f(x) + b_k $$
ProtoNet = linear classifier on learned embeddings
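A numeric check of this equivalence on random toy prototypes: the two score vectors differ only by a query-dependent constant ($-\|f(x)\|^2$), so softmax outputs and the argmax coincide.

```python
# Check: -||f(x) - c_k||^2 equals w_k.f(x) + b_k up to a constant in k,
# with w_k = 2 c_k and b_k = -||c_k||^2.
import numpy as np

rng = np.random.default_rng(0)
protos = rng.normal(size=(5, 3))              # prototypes c_k
z = rng.normal(size=3)                        # embedding f(x)

dist_scores = -((protos - z) ** 2).sum(axis=1)            # -||f(x) - c_k||^2
w, bias = 2.0 * protos, -(protos ** 2).sum(axis=1)        # w_k, b_k
lin_scores = w @ z + bias                                 # w_k.f(x) + b_k

diff = dist_scores - lin_scores               # constant across k: -||f(x)||^2
```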

Few-shot terminology and datasets

  • N-way: number of classes in a classification episode
  • K-shot: number of labeled examples per class
An episode = N-way K-shot classification task
  • Omniglot: handwritten characters, many classes, low intra-class variability
  • miniImageNet: natural images, fewer classes, high variability

Experimental validation (ProtoNet)

  • Benchmarks: Omniglot and miniImageNet
  • Few-shot classification tasks:
    • $N$-way, $K$-shot (e.g. 5-way 1-shot, 5-way 5-shot)
  • Episodic training and testing
Support set → prototypes → softmax over distances
No fine-tuning at test time, only prototype computation.

Results

Key takeaways

  • Classification via distances to prototypes
  • Prototype = mean (theoretically optimal)
  • Distance choice encodes distributional assumptions
  • Model is simple, stable, and interpretable
Simplicity is the inductive bias.

Why ProtoNets matter for interpretability

  • Each decision is explained by distances to class prototypes
  • Prototypes live in the same space as data embeddings
  • No opaque classifier head
Interpretability is intrinsic, not post-hoc.

III — ProtoPNet: patch prototypes $\texttt{[Che. 19]}$

Patch embedding map

  • Backbone (ResNet, VGG ...) outputs feature map: $F_\theta(x)\in\mathbb{R}^{H\times W\times d}$
  • d: channel number
  • Patch vectors: $z_{ij}(x)\in\mathbb{R}^d$
A prototype corresponds to a patch pattern (beak, wheel, texture...), not the whole image.

Matching: max over patches

$$ s_k(x)=\max_{i,j}\; -\|z_{ij}(x)-p_k\|_2^2 $$
Non-smooth due to max; use subgradients or a softmax approximation.

Softmax (log-sum-exp) smoothing

$$ s_k^\tau(x)= \tau \log\sum_{i,j}\exp\Big(-\tfrac{1}{\tau}\|z_{ij}(x)-p_k\|^2\Big) $$
As $\tau\to 0$, $s_k^\tau(x)\to \max_{i,j} -\|z_{ij}(x)-p_k\|^2$.
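A quick check, with arbitrary toy patch distances, that $\tau\log\sum\exp(-d^2/\tau)$ converges to the hard score $\max(-d^2)$ as $\tau\to 0$ (and upper-bounds it for $\tau>0$):

```python
# Log-sum-exp smoothing of the max-over-patches matching score.
import numpy as np

d2 = np.array([4.0, 1.0, 2.5, 1.1])           # toy ||z_ij - p_k||^2 values

def s_tau(tau):
    return tau * np.log(np.sum(np.exp(-d2 / tau)))

hard = -d2.min()                              # max_ij  -||z_ij - p_k||^2
```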

Objective: accuracy + interpretability

$$ \mathcal{L}=\mathcal{L}_{cls}+\lambda_1\mathcal{L}_{p2d}+\lambda_2\mathcal{L}_{d2p} $$
  • $\mathcal{L}_{cls}$: cross-entropy on logits
  • $\mathcal{L}_{p2d}$: each prototype close to at least one real patch (same class)
  • $\mathcal{L}_{d2p}$: each sample covered by at least one prototype (same class)

Prototype-to-data term (patch version)

$$ \mathcal{L}_{p2d}=\sum_{k=1}^K \min_{n:\,y_n=c(k)}\; \min_{i,j}\; \|z_{ij}(x_n)-p_k\|_2^2 $$
Makes each prototype represent a real pattern from the training data.

Data-to-prototype term

$$ \mathcal{L}_{d2p}=\mathbb{E}\Big[\min_{k:c(k)=y}\min_{i,j}\|z_{ij}(x)-p_k\|_2^2\Big] $$

Prototype projection (key XAI step)

$$ (n^\star,i^\star,j^\star)=\arg\min_{n:\,y_n=c(k)}\;\min_{i,j}\;\|z_{ij}(x_n)-p_k\|^2 $$ $$ p_k \leftarrow z_{i^\star j^\star}(x_{n^\star}) $$
Prototypes become exact training patches, hence directly visualizable.

IV — Prototypical and Counterfactual?

Prototype-driven counterfactual (idea)

  • Target class $y'$ has prototypes $\{p_k:c(k)=y'\}$.
  • Seek $c$ close to $x$ but near a target prototype:
$$ \min_c \|c-x\| \quad\text{s.t.}\quad \exists k \text{ with } c(k)=y':\;\; \|f_\phi(c)-p_k\|\le \epsilon $$

Proto-based Interpretability for Audio: APNet [Zin. 21]

What APNet tries to fix (audio XAI)

  • Deep audio classifiers are accurate but opaque.
  • Post-hoc saliency (gradients / LRP / etc.) can be unstable or hard to interpret in TF.
  • APNet: explanation is case-based: prediction is based on similarity to a small set of learned prototypes.
Interpretability is intrinsic: the decision is explicitly computed from prototype similarities.

Architecture at a glance

  • Input: log-mel spectrogram $X_i \in \mathbb{R}^{T\times F}$
  • Encoder: $Z_i = f(X_i)$ with $Z_i \in \mathbb{R}^{T'\times F'\times C}$
  • Prototype layer: store $M$ prototypes $P_j \in \mathbb{R}^{T'\times F'\times C}$
  • Similarity $\rightarrow$ weighted sum over frequency $\rightarrow$ linear classifier

Autoencoder objective (makes prototypes audible)

  • Decoder reconstructs $\hat X_i = g(Z_i)$
  • Reconstruction loss:
$$\ell_r=\frac{1}{N}\sum_{i=1}^N \|X_i-\hat X_i\|_2^2$$

Prototype loss = “prototypes are real” + “data are covered”

  • Distance from each sample to each prototype (latent space):
$$D_{ij}=\|Z_i-P_j\|_2^2$$
  • Prototype objective:
$$\ell_p=\frac{1}{N}\sum_{i=1}^N\min_j D_{ij} \;+\; \frac{1}{M}\sum_{j=1}^M\min_i D_{ij}$$

Audio-specific twist: frequency-dependent similarity

  • They compute similarity per frequency bin $f$ (latent TF structure):
$$S_{ij}[f]=\exp\!\left(-\sum_{t=1}^{T'}\sum_{c=1}^{C}\big(Z_i[t,f,c]-P_j[t,f,c]\big)^2\right)$$
  • Then aggregate frequency with a learnable kernel $H_j[f]$:
$$\hat S_{ij}=\sum_{f=1}^{F'} H_j[f]\,S_{ij}[f]$$
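These two steps can be sketched in numpy; the latent shapes, the random prototype, and the random frequency kernel are all illustrative assumptions (in APNet, $P_j$ and $H_j$ are learned).

```python
# APNet-style similarity: per-frequency similarity S_ij[f], then aggregation
# by a frequency kernel H_j[f].
import numpy as np

rng = np.random.default_rng(0)
Tp, Fp, C = 4, 6, 3                           # latent time, freq, channel dims
Z = rng.normal(size=(Tp, Fp, C))              # encoded sample Z_i
P = Z + 0.1 * rng.normal(size=(Tp, Fp, C))    # a nearby prototype P_j

S_f = np.exp(-((Z - P) ** 2).sum(axis=(0, 2)))   # S_ij[f], one value per bin
H = rng.uniform(size=Fp)                      # frequency kernel H_j[f]
S_hat = H @ S_f                               # aggregated similarity
```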

Classifier + training objective

  • Linear layer on similarities (no bias for interpretability):
$$\hat Y = \mathrm{softmax}(\hat S\,W)$$
  • Cross-entropy:
$$\ell_c=-\frac{1}{N}\sum_{i=1}^N\sum_{k=1}^K Y_{ik}\log \hat Y_{ik}$$
  • Final loss:
$$\ell=\alpha \ell_c + \beta \ell_p + \gamma \ell_r$$

Datasets (3 tasks)

  • UrbanSound8K: 10 urban classes (≤4s clips).
  • Medley-solos-DB: 9 instruments, 3s clips.
  • Google Speech Commands V2: 35 keywords, 1s clips.

Results

  • They compare to SB-CNN, Att-CRNN, and OpenL3 features + MLP.
  • APNet is competitive on all tasks.

Inspection: what makes it interpretable? (1/2)

  • Decode prototypes $\to$ mel-spectrograms (and audio) to inspect “what a class looks like”.

Inspection: what makes it interpretable? (2/2)

  • Inspect $W$ (prototype-to-class connections): prototypes mostly connect to their own class.

Refinement (because you can debug it)

  • Prototype redundancy pruning using prototype-prototype distances.
  • Channel redundancy pruning by analyzing prototype distances per channel.
  • After pruning + short retrain: fewer params and sometimes better accuracy.

Interested? Then select its extension for the project :-)

Take-home (APNet)

  1. Prototype networks give intrinsic explanations: “this sounds like that”.
  2. Audio needs domain-aware similarity (frequency matters).
  3. Decoding prototypes to audio makes explanations tangible (listen + view).
  4. Interpretable structure enables pruning + debugging beyond accuracy.