Quickstart

Meta-train with MetaTrainer

The fastest path: wrap a model in MAML, hand it to a MetaTrainer with a task sampler, and call fit.

import torch
from torch import nn
from iteryne import MAML, MetaTrainer, SinusoidTaskSampler

model = nn.Sequential(
    nn.Linear(1, 40), nn.ReLU(),
    nn.Linear(40, 40), nn.ReLU(),
    nn.Linear(40, 1),
)
maml = MAML(model, inner_lr=0.01, inner_steps=1, first_order=False)
meta_opt = torch.optim.Adam(maml.parameters(), lr=1e-3)

trainer = MetaTrainer(maml, meta_opt, nn.MSELoss(), SinusoidTaskSampler(seed=0))
trainer.fit(num_iterations=2000, meta_batch_size=25)

Adapt to a new task

After meta-training, clone the MAML wrapper to get a learner for the new task, take inner-loop steps on the task's support set, and predict on its query set:

task = SinusoidTaskSampler(seed=123).sample(1)[0]
learner = maml.clone()
learner.adapt_on(nn.MSELoss(), task.support_x, task.support_y)
predictions = learner(task.query_x)
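Conceptually, adaptation is ordinary gradient descent on the support loss, starting from the meta-learned weights: θ' = θ − α∇L_support(θ), repeated inner_steps times with α = inner_lr. The plain-Python sketch below assumes that is what adapt_on does (plain SGD with inner_lr and no optimizer state); that is an assumption, not a description of iteryne's internals. It uses a hypothetical one-feature linear model y = w·x + b so the gradients can be written by hand.

```python
def mse(w, b, xs, ys):
    """Mean squared error of the linear model y_hat = w * x + b."""
    return sum((w * x + b - y) ** 2 for x, y in zip(xs, ys)) / len(xs)


def mse_grads(w, b, xs, ys):
    """Gradient of mean((w*x + b - y)^2) with respect to w and b."""
    n = len(xs)
    residuals = [w * x + b - y for x, y in zip(xs, ys)]
    dw = 2.0 * sum(r * x for r, x in zip(residuals, xs)) / n
    db = 2.0 * sum(residuals) / n
    return dw, db


def adapt(w, b, support_x, support_y, inner_lr=0.01, inner_steps=1):
    """Hypothetical inner loop: inner_steps plain SGD steps on the support set."""
    for _ in range(inner_steps):
        dw, db = mse_grads(w, b, support_x, support_y)
        w, b = w - inner_lr * dw, b - inner_lr * db
    return w, b


# Made-up "meta-learned" starting point and a new task y = 3x + 1
w0, b0 = 0.5, 0.0
support_x = [-2.0, -1.0, 0.0, 1.0, 2.0]
support_y = [3.0 * x + 1.0 for x in support_x]

w1, b1 = adapt(w0, b0, support_x, support_y, inner_lr=0.05, inner_steps=5)
print(mse(w0, b0, support_x, support_y), "->", mse(w1, b1, support_x, support_y))  # 13.5 -> about 1.69
```

The meta-learned weights are only the starting point; everything task-specific comes from these few support-set steps, which is why MAML optimizes the starting point for how well it adapts.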

Write your own training loop

MetaTrainer is optional. The core MAML loop is just:

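import torch
from torch import nn
from iteryne import MAML, SinusoidTaskSampler

# `model` is the nn.Sequential defined in the first example
maml = MAML(model, inner_lr=0.01, inner_steps=5, first_order=False)
meta_opt = torch.optim.Adam(maml.parameters(), lr=1e-3)
loss_fn = nn.MSELoss()
sampler = SinusoidTaskSampler(seed=0)
task_batches = (sampler.sample(25) for _ in range(2000))  # 2000 meta-batches of 25 tasks

for task_batch in task_batches:           # each is a list[Task]
    meta_opt.zero_grad()
    for task in task_batch:
        learner = maml.clone()            # differentiable copy of maml for this task
        for _ in range(maml.inner_steps): # inner loop on the support set
            learner.adapt(loss_fn(learner(task.support_x), task.support_y))
        # outer loss on the query set; backward() accumulates into maml's gradients
        loss_fn(learner(task.query_x), task.query_y).backward()
    meta_opt.step()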

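No separate meta-gradient code is needed: each backward() call on a task's query loss adds that task's contribution to the meta-gradient in maml's parameters. With first_order=False the accumulated result is the full second-order meta-gradient; with first_order=True it is the first-order (FOMAML) approximation. See Concepts for why.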
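To see where the second-order term comes from, here is a worked one-step example in plain Python for a hypothetical scalar model f(x) = w·x with MSE loss; it illustrates the math, not iteryne's implementation. After one inner step w' = w − α·dL_s/dw, the chain rule gives dL_q(w')/dw = L_q'(w')·(1 − α·L_s''(w)). The second factor, which contains the support loss's second derivative, is what first-order MAML drops by treating dw'/dw as 1. With many parameters the same factor becomes the matrix I − α∇²L_s(θ).

```python
def moments(xs, ys):
    """Return (mean of x*x, mean of x*y) over a batch."""
    n = len(xs)
    return sum(x * x for x in xs) / n, sum(x * y for x, y in zip(xs, ys)) / n


def meta_gradients(w, alpha, support, query):
    """One-step MAML for f(x) = w * x with MSE loss.

    Returns (second_order, first_order): two estimates of dL_q(w')/dw,
    where w' = w - alpha * dL_s/dw is the adapted weight.
    """
    s_xx, s_xy = moments(*support)
    q_xx, q_xy = moments(*query)
    grad_support = 2.0 * (w * s_xx - s_xy)        # dL_s/dw
    w_adapted = w - alpha * grad_support           # inner-loop step
    grad_query = 2.0 * (w_adapted * q_xx - q_xy)   # dL_q/dw' evaluated at w'
    jacobian = 1.0 - alpha * 2.0 * s_xx            # dw'/dw = 1 - alpha * (second derivative of L_s)
    return grad_query * jacobian, grad_query       # FOMAML treats dw'/dw as 1


def adapted_query_loss(w, alpha, support, query):
    """L_q(w'(w)), used below to check the analytic gradient numerically."""
    s_xx, s_xy = moments(*support)
    w_adapted = w - alpha * 2.0 * (w * s_xx - s_xy)
    xs, ys = query
    return sum((w_adapted * x - y) ** 2 for x, y in zip(xs, ys)) / len(xs)


support = ([1.0, 2.0], [2.0, 4.0])   # toy task y = 2x
query = ([3.0], [6.0])
second, first = meta_gradients(w=0.0, alpha=0.1, support=support, query=query)
print(second, first)  # -9.0 -18.0

# Central finite difference of the adapted query loss around w = 0
h = 1e-4
fd = (adapted_query_loss(h, 0.1, support, query)
      - adapted_query_loss(-h, 0.1, support, query)) / (2 * h)
print(fd)  # close to -9.0, the second-order value
```

Here the curvature factor is 1 − 0.1·2·2.5 = 0.5, so the first-order estimate is twice the true meta-gradient. The finite difference agrees with the second-order value, which is the quantity backward() differentiates through the inner step to obtain when first_order=False.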

Bring your own data

Any object with a sample(meta_batch_size) -> list[Task] method can serve as a TaskSampler. A Task holds four tensors: support_x, support_y, query_x, and query_y. Their shapes and dtypes are up to you, as long as your model and loss function accept them, so classification, regression, or any other differentiable objective works.
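As an example of a custom sampler, here is a sketch of a common N-way K-shot classification sampler over any labelled tensor dataset. The class name FewShotClassificationSampler and its n_way, k_shot, and q_queries parameters are illustrative, not part of iteryne, and the Task dataclass is a hypothetical stand-in with the four fields listed above; in real code, build the library's own Task instead. Labels are remapped to 0..n_way-1 within each task, so a model ending in nn.Linear(..., n_way) paired with nn.CrossEntropyLoss fits directly.

```python
import random
from dataclasses import dataclass

import torch


@dataclass
class Task:
    """Hypothetical stand-in for iteryne's Task; use the library's own class in real code."""
    support_x: torch.Tensor
    support_y: torch.Tensor
    query_x: torch.Tensor
    query_y: torch.Tensor


class FewShotClassificationSampler:
    """N-way K-shot tasks drawn from a labelled dataset (features x, integer labels y)."""

    def __init__(self, x, y, n_way=5, k_shot=1, q_queries=15, seed=0):
        self.x, self.n_way, self.k_shot, self.q_queries = x, n_way, k_shot, q_queries
        self.rng = random.Random(seed)  # seeded so task draws are reproducible
        # Indices of every example, grouped by class label
        self.by_class = {}
        for i, label in enumerate(y.tolist()):
            self.by_class.setdefault(label, []).append(i)
        # Keep only classes with enough examples for support + query
        self.classes = [c for c, idx in self.by_class.items() if len(idx) >= k_shot + q_queries]

    def sample(self, meta_batch_size):
        return [self._sample_task() for _ in range(meta_batch_size)]

    def _sample_task(self):
        chosen = self.rng.sample(self.classes, self.n_way)
        s_idx, s_lab, q_idx, q_lab = [], [], [], []
        for new_label, cls in enumerate(chosen):  # relabel classes as 0..n_way-1 per task
            idx = self.rng.sample(self.by_class[cls], self.k_shot + self.q_queries)
            s_idx += idx[: self.k_shot]
            q_idx += idx[self.k_shot:]
            s_lab += [new_label] * self.k_shot
            q_lab += [new_label] * self.q_queries
        return Task(
            support_x=self.x[s_idx],
            support_y=torch.tensor(s_lab),
            query_x=self.x[q_idx],
            query_y=torch.tensor(q_lab),
        )


# Usage with random toy data: 200 examples, 8 features, 10 classes of 20 examples each
features = torch.randn(200, 8)
labels = torch.arange(200) % 10
sampler = FewShotClassificationSampler(features, labels, n_way=5, k_shot=1, q_queries=15)
task = sampler.sample(4)[0]
print(task.support_x.shape, task.query_x.shape)  # torch.Size([5, 8]) torch.Size([75, 8])
```

Because support and query examples for a class are drawn in one rng.sample call, they never overlap within a task. If fewer than n_way classes have at least k_shot + q_queries examples, rng.sample raises ValueError.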