-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathutils.py
More file actions
53 lines (47 loc) · 1.68 KB
/
utils.py
File metadata and controls
53 lines (47 loc) · 1.68 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
import torch
import triton
from torch import Tensor
import triton.language as tl
from jaxtyping import Float32, Int32
import triton_viz
import inspect
from triton_viz.interpreter import record_builder
def test(puzzle, puzzle_spec, nelem={}, B={"B0": 32}, viz=True):
    """Run ``puzzle`` through triton_viz and compare its output with ``puzzle_spec``.

    Tensor shapes and dtypes are read from the annotations on ``puzzle_spec``:
    every parameter contributes one randomly filled tensor, and the return
    annotation contributes one extra tensor that is passed last and read back
    as the kernel's output after the launch.

    Args:
        puzzle: Kernel to trace. It is invoked as
            ``triton_viz.trace(puzzle)[grid](*tensors, **B, **nelem)``.
        puzzle_spec: Reference implementation. Each parameter annotation and
            the return annotation must expose ``dims`` (items with ``.size``);
            parameter annotations must also expose ``dtypes``.
        nelem: Problem sizes. ``"N0"`` is required; ``"N1"`` and ``"N2"`` are
            optional and count as 1 when the launch grid is computed. Only
            read, never modified.
        B: Block sizes. ``"B1"``/``"B2"`` are filled in as 32 when the matching
            ``"N1"``/``"N2"`` is present but no block size was given. The
            caller's dict is copied, never modified.
        viz: When true, call ``triton_viz.launch()`` and treat a truthy result
            as a failure.

    Returns:
        None. The outcome is reported through ``print``.

    Raises:
        KeyError: If ``nelem`` has no ``"N0"`` entry (raised when the launch
            grid is evaluated).
    """
    # One tolerance pair for both the overall verdict and the per-element
    # report, so the printed mask agrees with "Results match".
    rtol = atol = 1e-3

    B = dict(B)
    for axis in ("1", "2"):
        if "N" + axis in nelem and "B" + axis not in B:
            B["B" + axis] = 32

    triton_viz.interpreter.record_builder.reset()
    torch.manual_seed(0)

    signature = inspect.signature(puzzle_spec)
    # Maps "<name>_ptr" -> (shape, parameter). The output slot "z_ptr" is added
    # last and carries no parameter.
    args = {}
    for n, p in signature.parameters.items():
        print(p)
        args[n + "_ptr"] = ([d.size for d in p.annotation.dims], p)
    args["z_ptr"] = ([d.size for d in signature.return_annotation.dims], None)

    tt_args = []
    for shape, param in args.values():
        # Always draw the float tensor first, even when it is replaced below:
        # skipping the draw would shift the seeded RNG stream and change every
        # value generated afterwards.
        tt_args.append(torch.rand(*shape) - 0.5)
        if param is not None and param.annotation.dtypes[0] == "int32":
            # NOTE(review): torch.randint defaults to int64, not int32 --
            # confirm whether the kernel relies on the wider dtype.
            tt_args[-1] = torch.randint(-100000, 100000, shape)

    def grid(meta):
        # One program per block along each axis; missing axes collapse to 1.
        return (
            triton.cdiv(nelem["N0"], meta["B0"]),
            triton.cdiv(nelem.get("N1", 1), meta.get("B1", 1)),
            triton.cdiv(nelem.get("N2", 1), meta.get("B2", 1)),
        )

    triton_viz.trace(puzzle)[grid](*tt_args, **B, **nelem)

    # The last tensor is the output buffer; the rest are the spec's inputs.
    *inputs, z = tt_args
    z_ = puzzle_spec(*inputs)

    match = torch.allclose(z, z_, rtol=rtol, atol=atol)
    print("Results match:", match)

    failures = triton_viz.launch() if viz else False
    if not match or failures:
        print("Invalid Access:", failures)
        print("Yours:", z)
        print("Spec:", z_)
        print(torch.isclose(z, z_, rtol=rtol, atol=atol))