Skip to content

Commit 3c5647e

Browse files
committed
everything is documented now (at least barely)
1 parent 0add06e commit 3c5647e

12 files changed

Lines changed: 334 additions & 46 deletions

File tree

docs/source/docstring template.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ class MyModule:
2424
[A title or short sentence describing the first example]:
2525
2626
.. code-block:: python
27+
2728
opt = tz.Modular(
2829
model.parameters(),
2930
...
@@ -32,6 +33,7 @@ class MyModule:
3233
[A title or short sentence for a second, different example]:
3334
3435
.. code-block:: python
36+
3537
opt = tz.Modular(
3638
model.parameters(),
3739
...

torchzero/modules/ops/binary.py

Lines changed: 52 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
#pyright: reportIncompatibleMethodOverride=false
2-
""""""
32
from abc import ABC, abstractmethod
43
from collections.abc import Iterable, Sequence
54
from operator import itemgetter
@@ -48,6 +47,10 @@ def step(self, var: Var) -> Var:
4847

4948

5049
class Add(BinaryOperation):
50+
"""Add :code:`other` to tensors. :code:`other` can be a number or a module.
51+
52+
If :code:`other` is a module, this calculates :code:`tensors + other(tensors)`
53+
"""
5154
def __init__(self, other: Chainable | float, alpha: float = 1):
5255
defaults = dict(alpha=alpha)
5356
super().__init__(defaults, other=other)
@@ -59,6 +62,10 @@ def transform(self, var, update: list[torch.Tensor], other: float | list[torch.T
5962
return update
6063

6164
class Sub(BinaryOperation):
65+
"""Subtract :code:`other` from tensors. :code:`other` can be a number or a module.
66+
67+
If :code:`other` is a module, this calculates :code:`tensors - other(tensors)`
68+
"""
6269
def __init__(self, other: Chainable | float, alpha: float = 1):
6370
defaults = dict(alpha=alpha)
6471
super().__init__(defaults, other=other)
@@ -70,6 +77,10 @@ def transform(self, var, update: list[torch.Tensor], other: float | list[torch.T
7077
return update
7178

7279
class RSub(BinaryOperation):
80+
"""Subtract tensors from :code:`other`. :code:`other` can be a number or a module.
81+
82+
If :code:`other` is a module, this calculates :code:`other(tensors) - tensors`
83+
"""
7384
def __init__(self, other: Chainable | float):
7485
super().__init__({}, other=other)
7586

@@ -78,6 +89,10 @@ def transform(self, var, update: list[torch.Tensor], other: float | list[torch.T
7889
return other - TensorList(update)
7990

8091
class Mul(BinaryOperation):
92+
"""Multiply tensors by :code:`other`. :code:`other` can be a number or a module.
93+
94+
If :code:`other` is a module, this calculates :code:`tensors * other(tensors)`
95+
"""
8196
def __init__(self, other: Chainable | float):
8297
super().__init__({}, other=other)
8398

@@ -87,6 +102,10 @@ def transform(self, var, update: list[torch.Tensor], other: float | list[torch.T
87102
return update
88103

89104
class Div(BinaryOperation):
105+
"""Divide tensors by :code:`other`. :code:`other` can be a number or a module.
106+
107+
If :code:`other` is a module, this calculates :code:`tensors / other(tensors)`
108+
"""
90109
def __init__(self, other: Chainable | float):
91110
super().__init__({}, other=other)
92111

@@ -96,6 +115,10 @@ def transform(self, var, update: list[torch.Tensor], other: float | list[torch.T
96115
return update
97116

98117
class RDiv(BinaryOperation):
118+
"""Divide :code:`other` by tensors. :code:`other` can be a number or a module.
119+
120+
If :code:`other` is a module, this calculates :code:`other(tensors) / tensors`
121+
"""
99122
def __init__(self, other: Chainable | float):
100123
super().__init__({}, other=other)
101124

@@ -104,6 +127,10 @@ def transform(self, var, update: list[torch.Tensor], other: float | list[torch.T
104127
return other / TensorList(update)
105128

106129
class Pow(BinaryOperation):
130+
"""Take tensors to the power of :code:`exponent`. :code:`exponent` can be a number or a module.
131+
132+
If :code:`exponent` is a module, this calculates :code:`tensors ^ exponent(tensors)`
133+
"""
107134
def __init__(self, exponent: Chainable | float):
108135
super().__init__({}, exponent=exponent)
109136

@@ -113,6 +140,10 @@ def transform(self, var, update: list[torch.Tensor], exponent: float | list[torc
113140
return update
114141

115142
class RPow(BinaryOperation):
143+
"""Take :code:`other` to the power of tensors. :code:`other` can be a number or a module.
144+
145+
If :code:`other` is a module, this calculates :code:`other(tensors) ^ tensors`
146+
"""
116147
def __init__(self, other: Chainable | float):
117148
super().__init__({}, other=other)
118149

@@ -123,6 +154,10 @@ def transform(self, var, update: list[torch.Tensor], other: float | list[torch.T
123154
return other
124155

125156
class Lerp(BinaryOperation):
157+
"""Does a linear interpolation of tensors and :code:`end` module based on a scalar :code:`weight`.
158+
159+
The output is given by :code:`output = tensors + weight * (end(tensors) - tensors)`
160+
"""
126161
def __init__(self, end: Chainable, weight: float):
127162
defaults = dict(weight=weight)
128163
super().__init__(defaults, end=end)
@@ -133,6 +168,7 @@ def transform(self, var, update: list[torch.Tensor], end: list[torch.Tensor]):
133168
return update
134169

135170
class CopySign(BinaryOperation):
171+
"""Returns tensors with sign copied from :code:`other(tensors)`."""
136172
def __init__(self, other: Chainable):
137173
super().__init__({}, other=other)
138174

@@ -141,6 +177,7 @@ def transform(self, var, update: list[torch.Tensor], other: list[torch.Tensor]):
141177
return [u.copysign_(o) for u, o in zip(update, other)]
142178

143179
class RCopySign(BinaryOperation):
180+
"""Returns :code:`other(tensors)` with sign copied from tensors."""
144181
def __init__(self, other: Chainable):
145182
super().__init__({}, other=other)
146183

@@ -150,6 +187,10 @@ def transform(self, var, update: list[torch.Tensor], other: list[torch.Tensor]):
150187
CopyMagnitude = RCopySign
151188

152189
class Clip(BinaryOperation):
190+
"""clip tensors to be in :code:`(min, max)` range. :code:`min` and :code:`max: can be None, numbers or modules.
191+
192+
If :code:`min` and :code:`max` are modules, this calculates :code:`tensors.clip(min(tensors), max(tensors))`.
193+
"""
153194
def __init__(self, min: float | Chainable | None = None, max: float | Chainable | None = None):
154195
super().__init__({}, min=min, max=max)
155196

@@ -158,7 +199,10 @@ def transform(self, var, update: list[torch.Tensor], min: float | list[torch.Ten
158199
return TensorList(update).clamp_(min=min, max=max)
159200

160201
class MirroredClip(BinaryOperation):
161-
"""clip by -value, value"""
202+
"""clip tensors to be in :code:`(-value, value)` range. :code:`value` can be a number or a module.
203+
204+
If :code:`value` is a module, this calculates :code:`tensors.clip(-value(tensors), value(tensors))`
205+
"""
162206
def __init__(self, value: float | Chainable):
163207
super().__init__({}, value=value)
164208

@@ -168,7 +212,7 @@ def transform(self, var, update: list[torch.Tensor], value: float | list[torch.T
168212
return TensorList(update).clamp_(min=min, max=value)
169213

170214
class Graft(BinaryOperation):
171-
"""use direction from update and magnitude from `magnitude` module"""
215+
"""Outputs tensors rescaled to have the same norm as :code:`magnitude(tensors)`."""
172216
def __init__(self, magnitude: Chainable, tensorwise:bool=True, ord:float=2, eps:float = 1e-6):
173217
defaults = dict(tensorwise=tensorwise, ord=ord, eps=eps)
174218
super().__init__(defaults, magnitude=magnitude)
@@ -179,7 +223,7 @@ def transform(self, var, update: list[torch.Tensor], magnitude: list[torch.Tenso
179223
return TensorList(update).graft_(magnitude, tensorwise=tensorwise, ord=ord, eps=eps)
180224

181225
class RGraft(BinaryOperation):
182-
"""use direction from `direction` module and magnitude from update"""
226+
"""Outputs :code:`magnitude(tensors)` rescaled to have the same norm as tensors"""
183227

184228
def __init__(self, direction: Chainable, tensorwise:bool=True, ord:float=2, eps:float = 1e-6):
185229
defaults = dict(tensorwise=tensorwise, ord=ord, eps=eps)
@@ -193,6 +237,7 @@ def transform(self, var, update: list[torch.Tensor], direction: list[torch.Tenso
193237
GraftToUpdate = RGraft
194238

195239
class Maximum(BinaryOperation):
240+
"""Outputs :code:`maximum(tensors, other(tensors))`"""
196241
def __init__(self, other: Chainable):
197242
super().__init__({}, other=other)
198243

@@ -202,6 +247,7 @@ def transform(self, var, update: list[torch.Tensor], other: list[torch.Tensor]):
202247
return update
203248

204249
class Minimum(BinaryOperation):
250+
"""Outputs :code:`minimum(tensors, other(tensors))`"""
205251
def __init__(self, other: Chainable):
206252
super().__init__({}, other=other)
207253

@@ -212,7 +258,7 @@ def transform(self, var, update: list[torch.Tensor], other: list[torch.Tensor]):
212258

213259

214260
class GramSchimdt(BinaryOperation):
215-
"""makes update orthonormal to `other`"""
261+
"""outputs tensors made orthogonal to `other(tensors)` via Gram-Schmidt."""
216262
def __init__(self, other: Chainable):
217263
super().__init__({}, other=other)
218264

@@ -223,7 +269,7 @@ def transform(self, var, update: list[torch.Tensor], other: list[torch.Tensor]):
223269

224270

225271
class Threshold(BinaryOperation):
226-
"""update above/below threshold, value at and below"""
272+
"""Outputs tensors thresholded such that values above :code:`threshold` are set to :code:`value`."""
227273
def __init__(self, threshold: Chainable | float, value: Chainable | float, update_above: bool):
228274
defaults = dict(update_above=update_above)
229275
super().__init__(defaults, threshold=threshold, value=value)

torchzero/modules/ops/debug.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from ...utils.tensorlist import Distributions
77

88
class PrintUpdate(Module):
9+
"""Prints current update."""
910
def __init__(self, text = 'update = ', print_fn = print):
1011
defaults = dict(text=text, print_fn=print_fn)
1112
super().__init__(defaults)
@@ -15,6 +16,7 @@ def step(self, var):
1516
return var
1617

1718
class PrintShape(Module):
19+
"""Prints shapes of the update."""
1820
def __init__(self, text = 'shapes = ', print_fn = print):
1921
defaults = dict(text=text, print_fn=print_fn)
2022
super().__init__(defaults)

0 commit comments

Comments
 (0)