Quickstart
Meta-train with MetaTrainer
The fastest path: wrap a model in MAML, pass it to a MetaTrainer together with a meta-optimizer, a loss function, and a task sampler, and call fit.
```python
import torch
from torch import nn
from iteryne import MAML, MetaTrainer, SinusoidTaskSampler

model = nn.Sequential(
    nn.Linear(1, 40), nn.ReLU(),
    nn.Linear(40, 40), nn.ReLU(),
    nn.Linear(40, 1),
)
maml = MAML(model, inner_lr=0.01, inner_steps=1, first_order=False)
meta_opt = torch.optim.Adam(maml.parameters(), lr=1e-3)
trainer = MetaTrainer(maml, meta_opt, nn.MSELoss(), SinusoidTaskSampler(seed=0))
trainer.fit(num_iterations=2000, meta_batch_size=25)
```
Adapt to a new task
After meta-training, clone the MAML wrapper to get a fresh learner, adapt it on a new task's support set, and then predict on that task's query set:
```python
task = SinusoidTaskSampler(seed=123).sample(1)[0]
learner = maml.clone()
learner.adapt_on(nn.MSELoss(), task.support_x, task.support_y)
predictions = learner(task.query_x)
```
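In standard MAML, adapting a learner means running a few steps of gradient descent on a copy of the meta-parameters while keeping each update in the autograd graph, so the meta-parameters themselves are never overwritten. The sketch below shows that idea in plain PyTorch with torch.func.functional_call (PyTorch 2.0 or later). It is not iteryne's implementation; the helper adapt_params and its defaults are hypothetical.

```python
import torch
from torch import nn
from torch.func import functional_call


def adapt_params(model, loss_fn, x, y, inner_lr=0.01, inner_steps=1, first_order=False):
    """Hypothetical helper: return adapted parameters without modifying `model`."""
    # Start from the meta-parameters so the graph leads back to them.
    params = dict(model.named_parameters())
    for _ in range(inner_steps):
        loss = loss_fn(functional_call(model, params, (x,)), y)
        # create_graph=True keeps the inner gradient differentiable (second order).
        grads = torch.autograd.grad(loss, list(params.values()), create_graph=not first_order)
        # Plain SGD step on the copy; the new tensors stay in the autograd graph.
        params = {name: p - inner_lr * g for (name, p), g in zip(params.items(), grads)}
    return params


torch.manual_seed(0)
model = nn.Linear(1, 1)
x = torch.linspace(-1, 1, 10).unsqueeze(1)
y = 3 * x  # a single toy regression task
loss_fn = nn.MSELoss()

before = loss_fn(model(x), y).item()
adapted = adapt_params(model, loss_fn, x, y, inner_lr=0.1, inner_steps=3)
# Reusing (x, y) as the query set for brevity.
query_loss = loss_fn(functional_call(model, adapted, (x,)), y)
after = query_loss.item()
query_loss.backward()  # gradient flows back to the model's own (meta) parameters
```

Working on a dictionary of tensors rather than mutating the module is what lets one backward() on the query loss reach the original parameters.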
Write your own training loop
MetaTrainer is optional. The core MAML loop is just:
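```python
import torch
from torch import nn
from iteryne import MAML, SinusoidTaskSampler

# `model` is the nn.Sequential defined in the first example.
maml = MAML(model, inner_lr=0.01, inner_steps=5, first_order=False)
meta_opt = torch.optim.Adam(maml.parameters(), lr=1e-3)
loss_fn = nn.MSELoss()
# Any iterable of task batches works; each batch is a list[Task].
sampler = SinusoidTaskSampler(seed=0)
task_batches = (sampler.sample(25) for _ in range(2000))
for task_batch in task_batches:
    meta_opt.zero_grad()
    for task in task_batch:
        learner = maml.clone()  # fresh learner cloned from the meta-parameters
        for _ in range(maml.inner_steps):  # inner loop on the support set
            learner.adapt(loss_fn(learner(task.support_x), task.support_y))
        # Outer loss on the query set; backward() accumulates into maml's gradients.
        loss_fn(learner(task.query_x), task.query_y).backward()
    meta_opt.step()
```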
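The same backward() call on each task's query loss produces the full second-order meta-gradient when first_order=False and the first-order (FOMAML) meta-gradient when first_order=True; the per-task gradients accumulate in the meta-parameters' .grad until meta_opt.step(). See Concepts for why.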
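A toy scalar example in plain PyTorch (not iteryne code) makes the difference concrete, assuming first_order=True is implemented the usual way, by treating the inner-loop gradient as a constant. With support loss (theta - a)^2, one inner step theta' = theta - alpha * 2(theta - a), and query loss (theta' - b)^2, the chain rule gives a meta-gradient of 2(theta' - b)(1 - 2*alpha); FOMAML drops the (1 - 2*alpha) factor.

```python
import torch

alpha = 0.1      # inner learning rate
a, b = 3.0, 0.0  # toy support and query targets


def meta_grad(first_order):
    theta = torch.tensor(1.0, requires_grad=True)  # meta-parameter
    support_loss = (theta - a) ** 2
    # Without create_graph, g is a constant and second-order terms vanish.
    (g,) = torch.autograd.grad(support_loss, theta, create_graph=not first_order)
    theta_adapted = theta - alpha * g  # one inner SGD step
    query_loss = (theta_adapted - b) ** 2
    query_loss.backward()  # one backward pass in both cases
    return theta.grad.item()


second = meta_grad(first_order=False)  # approx. 2 * 1.4 * (1 - 0.2) = 2.24
first = meta_grad(first_order=True)    # approx. 2 * 1.4 = 2.8
print(second, first)
```

The only switch is whether the inner gradient is itself differentiable; the outer backward() is identical.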
Bring your own data
Any object with a sample(meta_batch_size) method that returns a list[Task] is a TaskSampler. A Task holds four tensors: support_x, support_y, query_x, and query_y. The tensors can be anything your model and loss function accept, so classification, regression, or any other differentiable objective works.
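As an illustration, here is a hypothetical sinusoid-regression sampler written from scratch. It assumes, as the paragraph above suggests, that any object exposing the four attributes is accepted as a Task; SineTask and MySinusoidSampler are made-up names, and if iteryne exports a Task class you can construct, use that instead. The ranges (amplitude in [0.1, 5.0], phase in [0, pi], x in [-5, 5]) follow the sinusoid benchmark from the original MAML paper; iteryne's own SinusoidTaskSampler may use different ones.

```python
import math
from dataclasses import dataclass

import torch


@dataclass
class SineTask:
    # Hypothetical stand-in, assumed to satisfy iteryne's Task interface.
    support_x: torch.Tensor
    support_y: torch.Tensor
    query_x: torch.Tensor
    query_y: torch.Tensor


class MySinusoidSampler:
    """Duck-typed TaskSampler: only sample(meta_batch_size) -> list is required."""

    def __init__(self, k_support=10, k_query=10, seed=0):
        self.k_support = k_support
        self.k_query = k_query
        self.gen = torch.Generator().manual_seed(seed)  # reproducible sampling

    def _uniform(self, low, high, *shape):
        return low + (high - low) * torch.rand(*shape, generator=self.gen)

    def sample(self, meta_batch_size):
        tasks = []
        for _ in range(meta_batch_size):
            # One task = one sinusoid with its own amplitude and phase.
            amplitude = self._uniform(0.1, 5.0, 1)
            phase = self._uniform(0.0, math.pi, 1)
            x = self._uniform(-5.0, 5.0, self.k_support + self.k_query, 1)
            y = amplitude * torch.sin(x - phase)
            k = self.k_support
            tasks.append(SineTask(x[:k], y[:k], x[k:], y[k:]))  # support / query split
        return tasks


sampler = MySinusoidSampler(seed=0)
batch = sampler.sample(25)
```

An instance can then be passed wherever the examples above use SinusoidTaskSampler, for example as the fourth argument to MetaTrainer.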