aboutsummaryrefslogtreecommitdiff
path: root/modules/torch_utils.py
blob: e5b52393ec86df16521052d694d26e2865fcd278 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
from __future__ import annotations

import torch.nn


def get_param(model) -> torch.nn.Parameter:
    """
    Return the first parameter yielded by a Torch module or model wrapper.

    If ``model`` carries a ``model`` attribute that itself provides a
    ``parameters`` method (e.g. a model descriptor wrapping the real module),
    that inner object is searched instead. Only one layer is unwrapped.

    Raises:
        ValueError: if the (possibly unwrapped) object yields no parameters.
    """
    # Peel a single wrapper layer to reach the actual Torch module. A missing
    # or parameter-less ``.model`` leaves ``model`` untouched.
    inner = getattr(model, "model", None)
    if hasattr(inner, "parameters"):
        model = inner

    # A private sentinel distinguishes "iterator exhausted" from any value the
    # iterator might legitimately yield.
    exhausted = object()
    first = next(iter(model.parameters()), exhausted)
    if first is exhausted:
        raise ValueError(f"No parameters found in model {model!r}")
    return first