wandb.watch

Hooks into a PyTorch model to collect its gradients and topology.

watch(
    models, criterion=None, log="gradients", log_freq=1000, idx=None
)

Only PyTorch models are currently supported; this should be extended to accept arbitrary ML models.

Args

models

(torch.nn.Module) The model to hook; can also be a tuple of models

criterion

(torch.nn.functional) An optional loss value being optimized

log

(str) One of "gradients", "parameters", "all", or None

log_freq

(int) Log gradients and parameters every N batches

idx

(int) An index to be used when calling wandb.watch on multiple models

Returns

(wandb.Graph) The graph object, which is populated after the first backward pass
