
Concepts

The MAML objective

MAML learns an initialization \(\theta\) such that, for a task \(\mathcal{T}_i\) drawn from \(p(\mathcal{T})\), a few gradient steps on the task's support set yield parameters that perform well on its query set.

Inner loop (adaptation) — one or more gradient steps on the support loss:

\[\theta_i' = \theta - \alpha \, \nabla_\theta \, \mathcal{L}_{\mathcal{T}_i}(f_\theta)\]

Outer loop (meta-update) — minimize the query loss of the adapted model, summed over a batch of tasks, with respect to the original \(\theta\):

\[\theta \leftarrow \theta - \beta \, \nabla_\theta \sum_i \mathcal{L}_{\mathcal{T}_i}(f_{\theta_i'})\]

Here \(\alpha\) is the inner step size (inner_lr) and \(\beta\) is the meta step size (your torch.optim optimizer's learning rate).
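For a single inner step with a scalar \(\alpha\), the chain rule shows what the outer loop actually differentiates. This derivation introduces the shorthand \(\mathcal{L}^{\text{sup}}_{\mathcal{T}_i}\) and \(\mathcal{L}^{\text{qry}}_{\mathcal{T}_i}\) for the support and query losses, which the equations above both write as \(\mathcal{L}_{\mathcal{T}_i}\):

\[\nabla_\theta \, \mathcal{L}^{\text{qry}}_{\mathcal{T}_i}(f_{\theta_i'}) = \left(I - \alpha \, \nabla^2_\theta \, \mathcal{L}^{\text{sup}}_{\mathcal{T}_i}(f_\theta)\right) \nabla_{\theta'} \, \mathcal{L}^{\text{qry}}_{\mathcal{T}_i}(f_{\theta'}) \Big|_{\theta' = \theta_i'}\]

The Hessian is symmetric, so no transpose is needed. Multiplying it by the query gradient is the Hessian-vector term that second-order MAML keeps. Dropping it leaves \(\nabla_{\theta'} \mathcal{L}^{\text{qry}}_{\mathcal{T}_i}\) evaluated at \(\theta_i'\), which is the first-order (FOMAML) meta-gradient.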

Model-agnostic by construction

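The inner loop must differentiate through a parameter update, and an ordinary in-place param -= lr * grad would break this. Autograd rejects in-place changes to a leaf tensor that requires grad. Doing the update under torch.no_grad() instead records nothing in the graph and also overwrites \(\theta\), which the outer loop still needs. iteryne therefore carries parameters as a dict[str, Tensor], builds the adapted values out of place, and runs the model statelessly with torch.func.functional_call. Nothing about your nn.Module needs to change, so the same code path serves MLPs, CNNs, modules with BatchNorm, and so on.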
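Here is a minimal sketch of this pattern in plain PyTorch. It is not iteryne's own wrapper, whose API is not shown on this page. The toy model, the random data, and the inner_lr value are illustrative assumptions.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.func import functional_call

torch.manual_seed(0)
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 1))
inner_lr = 0.1  # alpha (illustrative value)

# Parameters as a plain dict[str, Tensor], keyed like "0.weight", "2.bias".
params = dict(model.named_parameters())
before = {name: p.detach().clone() for name, p in params.items()}

x_s, y_s = torch.randn(5, 4), torch.randn(5, 1)  # support set
x_q, y_q = torch.randn(5, 4), torch.randn(5, 1)  # query set

# Stateless forward: the unchanged module runs with whatever tensors the dict holds.
support_loss = F.mse_loss(functional_call(model, params, (x_s,)), y_s)
grads = torch.autograd.grad(support_loss, list(params.values()), create_graph=True)

# Out-of-place update: new tensors that stay differentiable w.r.t. the originals.
adapted = {name: p - inner_lr * g for (name, p), g in zip(params.items(), grads)}

# The query loss of the adapted model backpropagates into model.parameters().
query_loss = F.mse_loss(functional_call(model, adapted, (x_q,)), y_q)
query_loss.backward()
```

The module's stored parameters are never modified, and only the adapted dict differs. query_loss.backward() reaches model.parameters() through the identity path and through the gradient kept by create_graph=True.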

First-order vs second-order in one code path

The inner step is computed with torch.autograd.grad(support_loss, params, create_graph=not first_order) and then \(p' = p - \alpha \, g\).

  • Second-order MAML (first_order=False, create_graph=True): the adapted parameters stay connected to \(\theta\) through the gradient term \(g\), so a later query_loss.backward() differentiates through the inner step, producing the exact meta-gradient (including the Hessian-vector terms).
  • First-order MAML / FOMAML (first_order=True, create_graph=False): \(g\) is returned without gradient history, so \(p' = p - \alpha \, g\) keeps only the identity edge \(p' \to p\). The same query_loss.backward() then accumulates \(\nabla_{\theta'} \mathcal{L}_{\mathcal{T}_i}(f_{\theta'})\), evaluated at \(\theta' = \theta_i'\), into the gradient of \(\theta\). This is exactly the FOMAML meta-gradient, obtained at roughly a third less compute.

Because the only difference is the create_graph flag, the wrapper, trainer, and variants never special-case the two regimes.
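You can check on a one-parameter problem that the flag alone selects the regime, because both meta-gradients have closed forms there. This sketch assumes a scalar \(\theta\), a support loss \(\tfrac12 h\theta^2\) (so \(g = h\theta\) and the Hessian is \(h\)), a query loss \(\tfrac12(\theta' - b)^2\), and a single inner step. meta_grad is a hypothetical helper, not part of iteryne.

```python
import torch


def meta_grad(theta0: float, h: float, b: float, alpha: float, first_order: bool) -> float:
    """Meta-gradient d(query loss)/d(theta) after one inner step (hypothetical helper)."""
    theta = torch.tensor(theta0, requires_grad=True)
    support_loss = 0.5 * h * theta ** 2  # gradient h * theta, Hessian h
    (g,) = torch.autograd.grad(support_loss, theta, create_graph=not first_order)
    theta_adapted = theta - alpha * g    # p' = p - alpha * g
    query_loss = 0.5 * (theta_adapted - b) ** 2
    query_loss.backward()
    return theta.grad.item()


# With these values theta' = (1 - alpha * h) * theta0 = 0.8.
fo = meta_grad(1.0, h=2.0, b=0.0, alpha=0.1, first_order=True)   # approx. theta' - b = 0.8
so = meta_grad(1.0, h=2.0, b=0.0, alpha=0.1, first_order=False)  # approx. (1 - alpha * h) * 0.8 = 0.64
```

The gap between 0.8 and 0.64 is exactly the \(1 - \alpha h\) factor that first-order MAML drops. With \(h = 0\) (a flat support loss) the Hessian term vanishes and both regimes agree.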

Meta-SGD

MetaSGD adds a learnable inner learning rate for every parameter entry. The update becomes \(p' = p - \alpha \odot g\), where \(\alpha\) has the same shape as \(p\), \(\odot\) is the elementwise product, and \(\alpha\) is trained alongside \(\theta\) by the meta-optimizer. It is a thin subclass that registers the rates as parameters and feeds them into the inner step.
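Here is a minimal sketch of the elementwise update. It assumes nothing about iteryne's actual MetaSGD class: MetaSGDSketch and its adapt method are hypothetical names. The rates start at a constant, which is an assumption. They are registered in an nn.ParameterDict, so one meta-optimizer over wrapper.parameters() trains both \(\theta\) and \(\alpha\).

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.func import functional_call


class MetaSGDSketch(nn.Module):
    """Hypothetical Meta-SGD wrapper: one learnable inner rate per parameter entry."""

    def __init__(self, model: nn.Module, init_lr: float = 0.01):
        super().__init__()
        self.model = model
        # ParameterDict keys may not contain ".", so "0.weight" is stored as "0_weight".
        self.lrs = nn.ParameterDict({
            name.replace(".", "_"): nn.Parameter(torch.full_like(p, init_lr))
            for name, p in model.named_parameters()
        })

    def adapt(self, x, y, first_order: bool = False) -> dict:
        params = dict(self.model.named_parameters())
        loss = F.mse_loss(functional_call(self.model, params, (x,)), y)
        grads = torch.autograd.grad(loss, list(params.values()), create_graph=not first_order)
        # p' = p - alpha * g elementwise, with alpha the same shape as p.
        return {
            name: p - self.lrs[name.replace(".", "_")] * g
            for (name, p), g in zip(params.items(), grads)
        }


torch.manual_seed(0)
wrapper = MetaSGDSketch(nn.Sequential(nn.Linear(3, 4), nn.Tanh(), nn.Linear(4, 1)))
meta_opt = torch.optim.Adam(wrapper.parameters(), lr=1e-3)  # trains theta and alpha together

x_s, y_s = torch.randn(6, 3), torch.randn(6, 1)
x_q, y_q = torch.randn(6, 3), torch.randn(6, 1)
adapted = wrapper.adapt(x_s, y_s, first_order=True)
query_loss = F.mse_loss(functional_call(wrapper.model, adapted, (x_q,)), y_q)
meta_opt.zero_grad()
query_loss.backward()
meta_opt.step()
```

Passing first_order=True still trains \(\alpha\). Detaching \(g\) removes only the Hessian path, while \(\partial p' / \partial \alpha = -g\) remains.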

ANIL

ANIL ("Almost No Inner Loop") adapts only the network head in the inner loop; the body (the shared representation) is updated solely by the meta-optimizer. In iteryne this is simply MAML with adapt_names restricted to the head's parameters.
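The sketch below shows the restriction, assuming the head is the model's last layer. Here adapt_names is a plain set of parameter names passed to a hypothetical inner_step helper, which may differ from how iteryne receives the option. Parameters outside the set are forwarded unchanged, so only the head moves during adaptation. The body still receives a meta-gradient from the query loss.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.func import functional_call


def inner_step(model, params, adapt_names, x, y, inner_lr, first_order=False):
    """One inner step that updates only the parameters named in adapt_names (hypothetical helper)."""
    names = [n for n in params if n in adapt_names]
    loss = F.cross_entropy(functional_call(model, params, (x,)), y)
    grads = torch.autograd.grad(loss, [params[n] for n in names], create_graph=not first_order)
    adapted = dict(params)  # body entries are passed through as-is
    for n, g in zip(names, grads):
        adapted[n] = params[n] - inner_lr * g
    return adapted


torch.manual_seed(0)
model = nn.Sequential(nn.Linear(10, 16), nn.ReLU(), nn.Linear(16, 5))  # body "0.*", head "2.*"
params = dict(model.named_parameters())
head_names = {"2.weight", "2.bias"}

x_s, y_s = torch.randn(10, 10), torch.randint(0, 5, (10,))
x_q, y_q = torch.randn(15, 10), torch.randint(0, 5, (15,))

adapted = inner_step(model, params, head_names, x_s, y_s, inner_lr=0.4)
query_loss = F.cross_entropy(functional_call(model, adapted, (x_q,)), y_q)
query_loss.backward()  # the body is updated only through this outer-loop gradient
```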

BatchNorm and buffers

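Buffers such as BatchNorm running statistics are passed through the functional forward but are not adapted in the inner loop. For few-shot settings consider track_running_stats=False. BatchNorm then normalizes with the statistics of the current batch in both training and evaluation mode, instead of with running averages accumulated across tasks.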
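The following check of the suggested setting uses a toy CNN, which is an assumption of this sketch. With track_running_stats=False the BatchNorm layer registers no running-statistics buffers and normalizes with the current batch in both train and eval mode. A functional forward therefore gives the same result whichever mode the module is in.

```python
import torch
import torch.nn as nn
from torch.func import functional_call

torch.manual_seed(0)
model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.BatchNorm2d(8, track_running_stats=False),  # running_mean / running_var are None
    nn.ReLU(),
)
params = dict(model.named_parameters())
buffers = dict(model.named_buffers())  # empty: buffers set to None are skipped
x = torch.randn(5, 3, 6, 6)

model.train()
out_train = functional_call(model, {**params, **buffers}, (x,))
model.eval()
out_eval = functional_call(model, {**params, **buffers}, (x,))
```

With the default track_running_stats=True, every train-mode forward pass on support or query data would instead update the running statistics in place, which mixes statistics across tasks.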