Skip to content

Commit bd3f7ee

Browse files
authored
Merge pull request #369 from rsokl/mirror-dtype
mirror numpy dtypes that are valid for tensors
2 parents 82bd144 + c7553bb commit bd3f7ee

File tree

3 files changed

+51
-0
lines changed

3 files changed

+51
-0
lines changed

src/mygrad/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
tensor,
55
Tensor,
66
)
7+
from mygrad._dtype_mirrors import *
78
from mygrad._utils.graph_tracking import no_autodiff
89
from mygrad._utils.lock_management import (
910
mem_guard_active,

src/mygrad/_dtype_mirrors.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
import numpy
2+
3+
__all__ = [
4+
"bool8",
5+
"int8",
6+
"int16",
7+
"int32",
8+
"int64",
9+
"uint8",
10+
"uint16",
11+
"uint32",
12+
"uint64",
13+
"intp",
14+
"uintp",
15+
"float16",
16+
"float32",
17+
"float64",
18+
"half",
19+
"single",
20+
"double",
21+
"longdouble",
22+
]
23+
24+
bool8 = numpy.bool8
25+
int8 = numpy.int8
26+
int16 = numpy.int16
27+
int32 = numpy.int32
28+
int64 = numpy.int64
29+
uint8 = numpy.uint8
30+
uint16 = numpy.uint16
31+
uint32 = numpy.uint32
32+
uint64 = numpy.uint64
33+
intp = numpy.intp
34+
uintp = numpy.uintp
35+
float16 = numpy.float16
36+
float32 = numpy.float32
37+
float64 = numpy.float64
38+
half = numpy.half
39+
single = numpy.single
40+
double = numpy.double
41+
longdouble = numpy.longdouble

tests/test_dtype_mirrors.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
import pytest
2+
3+
import mygrad as mg
4+
from mygrad._dtype_mirrors import __all__ as all_mirrored_dtyped
5+
6+
7+
@pytest.mark.parametrize("dtype_str", all_mirrored_dtyped)
8+
def test_mirrored_dtype_is_valid(dtype_str):
9+
mg.tensor(1, dtype=getattr(mg, dtype_str))

0 commit comments

Comments
 (0)