aboutsummaryrefslogtreecommitdiff
path: root/test/test_torch_utils.py
blob: 23ccb93a464b4fc9b196a5a4c841cfd890003c80 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
import types

import pytest
import torch

from modules import torch_utils


@pytest.mark.parametrize("wrapped", [True, False])
def test_get_param(wrapped):
    """Check that ``torch_utils.get_param`` reflects the module's dtype/device.

    The module is an fp16 ``Linear`` placed on the CPU. It is passed either
    directly or behind a ``.model`` attribute; the latter is a rough stand-in
    for how spandrel wraps a network. In both cases the returned object's
    ``dtype`` and ``device`` must match what the module was moved to.
    """
    target_device = torch.device("cpu")
    linear = torch.nn.Linear(1, 1)
    linear.to(dtype=torch.float16, device=target_device)

    # more or less how spandrel wraps a thing: the real module sits on `.model`
    subject = types.SimpleNamespace(model=linear) if wrapped else linear

    param = torch_utils.get_param(subject)
    assert param.dtype == torch.float16
    assert param.device == target_device