blob: e5b52393ec86df16521052d694d26e2865fcd278 (
plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
|
from __future__ import annotations
import torch.nn
def get_param(model) -> torch.nn.Parameter:
    """
    Return the first parameter yielded by a model or module.

    If ``model`` has a ``model`` attribute that itself exposes
    ``parameters``, that inner object is used instead (unwrapping a
    descriptor/wrapper to reach the underlying Torch module).

    Raises:
        ValueError: if the (possibly unwrapped) object yields no parameters.
    """
    inner = getattr(model, "model", None)
    if hasattr(inner, "parameters"):
        # Unpeel a model descriptor to get at the actual Torch module.
        model = inner
    # Private sentinel so any value yielded by parameters() is distinguishable
    # from "iterator was empty".
    missing = object()
    first = next(iter(model.parameters()), missing)
    if first is missing:
        raise ValueError(f"No parameters found in model {model!r}")
    return first
|