11#pyright: reportIncompatibleMethodOverride=false
2- """"""
32from abc import ABC , abstractmethod
43from collections .abc import Iterable , Sequence
54from operator import itemgetter
@@ -48,6 +47,10 @@ def step(self, var: Var) -> Var:
4847
4948
5049class Add (BinaryOperation ):
50+ """Add :code:`other` to tensors. :code:`other` can be a number or a module.
51+
52+ If :code:`other` is a module, this calculates :code:`tensors + other(tensors)`
53+ """
5154 def __init__ (self , other : Chainable | float , alpha : float = 1 ):
5255 defaults = dict (alpha = alpha )
5356 super ().__init__ (defaults , other = other )
@@ -59,6 +62,10 @@ def transform(self, var, update: list[torch.Tensor], other: float | list[torch.T
5962 return update
6063
6164class Sub (BinaryOperation ):
65+ """Subtract :code:`other` from tensors. :code:`other` can be a number or a module.
66+
67+ If :code:`other` is a module, this calculates :code:`tensors - other(tensors)`
68+ """
6269 def __init__ (self , other : Chainable | float , alpha : float = 1 ):
6370 defaults = dict (alpha = alpha )
6471 super ().__init__ (defaults , other = other )
@@ -70,6 +77,10 @@ def transform(self, var, update: list[torch.Tensor], other: float | list[torch.T
7077 return update
7178
7279class RSub (BinaryOperation ):
80+ """Subtract tensors from :code:`other`. :code:`other` can be a number or a module.
81+
82+ If :code:`other` is a module, this calculates :code:`other(tensors) - tensors`
83+ """
7384 def __init__ (self , other : Chainable | float ):
7485 super ().__init__ ({}, other = other )
7586
@@ -78,6 +89,10 @@ def transform(self, var, update: list[torch.Tensor], other: float | list[torch.T
7889 return other - TensorList (update )
7990
8091class Mul (BinaryOperation ):
92+ """Multiply tensors by :code:`other`. :code:`other` can be a number or a module.
93+
94+ If :code:`other` is a module, this calculates :code:`tensors * other(tensors)`
95+ """
8196 def __init__ (self , other : Chainable | float ):
8297 super ().__init__ ({}, other = other )
8398
@@ -87,6 +102,10 @@ def transform(self, var, update: list[torch.Tensor], other: float | list[torch.T
87102 return update
88103
89104class Div (BinaryOperation ):
105+ """Divide tensors by :code:`other`. :code:`other` can be a number or a module.
106+
107+ If :code:`other` is a module, this calculates :code:`tensors / other(tensors)`
108+ """
90109 def __init__ (self , other : Chainable | float ):
91110 super ().__init__ ({}, other = other )
92111
@@ -96,6 +115,10 @@ def transform(self, var, update: list[torch.Tensor], other: float | list[torch.T
96115 return update
97116
98117class RDiv (BinaryOperation ):
118+ """Divide :code:`other` by tensors. :code:`other` can be a number or a module.
119+
120+ If :code:`other` is a module, this calculates :code:`other(tensors) / tensors`
121+ """
99122 def __init__ (self , other : Chainable | float ):
100123 super ().__init__ ({}, other = other )
101124
@@ -104,6 +127,10 @@ def transform(self, var, update: list[torch.Tensor], other: float | list[torch.T
104127 return other / TensorList (update )
105128
106129class Pow (BinaryOperation ):
130+ """Take tensors to the power of :code:`exponent`. :code:`exponent` can be a number or a module.
131+
132+ If :code:`exponent` is a module, this calculates :code:`tensors ^ exponent(tensors)`
133+ """
107134 def __init__ (self , exponent : Chainable | float ):
108135 super ().__init__ ({}, exponent = exponent )
109136
@@ -113,6 +140,10 @@ def transform(self, var, update: list[torch.Tensor], exponent: float | list[torc
113140 return update
114141
115142class RPow (BinaryOperation ):
143+ """Take :code:`other` to the power of tensors. :code:`other` can be a number or a module.
144+
145+ If :code:`other` is a module, this calculates :code:`other(tensors) ^ tensors`
146+ """
116147 def __init__ (self , other : Chainable | float ):
117148 super ().__init__ ({}, other = other )
118149
@@ -123,6 +154,10 @@ def transform(self, var, update: list[torch.Tensor], other: float | list[torch.T
123154 return other
124155
125156class Lerp (BinaryOperation ):
157+ """Does a linear interpolation of tensors and :code:`end` module based on a scalar :code:`weight`.
158+
159+ The output is given by :code:`output = tensors + weight * (end(tensors) - tensors)`
160+ """
126161 def __init__ (self , end : Chainable , weight : float ):
127162 defaults = dict (weight = weight )
128163 super ().__init__ (defaults , end = end )
@@ -133,6 +168,7 @@ def transform(self, var, update: list[torch.Tensor], end: list[torch.Tensor]):
133168 return update
134169
135170class CopySign (BinaryOperation ):
171+ """Returns tensors with sign copied from :code:`other(tensors)`."""
136172 def __init__ (self , other : Chainable ):
137173 super ().__init__ ({}, other = other )
138174
@@ -141,6 +177,7 @@ def transform(self, var, update: list[torch.Tensor], other: list[torch.Tensor]):
141177 return [u .copysign_ (o ) for u , o in zip (update , other )]
142178
143179class RCopySign (BinaryOperation ):
180+ """Returns :code:`other(tensors)` with sign copied from tensors."""
144181 def __init__ (self , other : Chainable ):
145182 super ().__init__ ({}, other = other )
146183
@@ -150,6 +187,10 @@ def transform(self, var, update: list[torch.Tensor], other: list[torch.Tensor]):
150187CopyMagnitude = RCopySign
151188
152189class Clip (BinaryOperation ):
190+ """clip tensors to be in :code:`(min, max)` range. :code:`min` and :code:`max: can be None, numbers or modules.
191+
192+ If code:`min` and :code:`max`: are modules, this calculates :code:`tensors.clip(min(tensors), max(tensors))`.
193+ """
153194 def __init__ (self , min : float | Chainable | None = None , max : float | Chainable | None = None ):
154195 super ().__init__ ({}, min = min , max = max )
155196
@@ -158,7 +199,10 @@ def transform(self, var, update: list[torch.Tensor], min: float | list[torch.Ten
158199 return TensorList (update ).clamp_ (min = min , max = max )
159200
160201class MirroredClip (BinaryOperation ):
161- """clip by -value, value"""
202+ """clip tensors to be in :code:`(-value, value)` range. :code:`value` can be a number or a module.
203+
204+ If :code:`value` is a module, this calculates :code:`tensors.clip(-value(tensors), value(tensors))`
205+ """
162206 def __init__ (self , value : float | Chainable ):
163207 super ().__init__ ({}, value = value )
164208
@@ -168,7 +212,7 @@ def transform(self, var, update: list[torch.Tensor], value: float | list[torch.T
168212 return TensorList (update ).clamp_ (min = min , max = value )
169213
170214class Graft (BinaryOperation ):
171- """use direction from update and magnitude from `magnitude` module """
215+ """Outputs tensors rescaled to have the same norm as :code: `magnitude(tensors)`. """
172216 def __init__ (self , magnitude : Chainable , tensorwise :bool = True , ord :float = 2 , eps :float = 1e-6 ):
173217 defaults = dict (tensorwise = tensorwise , ord = ord , eps = eps )
174218 super ().__init__ (defaults , magnitude = magnitude )
@@ -179,7 +223,7 @@ def transform(self, var, update: list[torch.Tensor], magnitude: list[torch.Tenso
179223 return TensorList (update ).graft_ (magnitude , tensorwise = tensorwise , ord = ord , eps = eps )
180224
181225class RGraft (BinaryOperation ):
182- """use direction from `direction` module and magnitude from update """
226+ """Outputs :code:`magnitude(tensors)` rescaled to have the same norm as tensors """
183227
184228 def __init__ (self , direction : Chainable , tensorwise :bool = True , ord :float = 2 , eps :float = 1e-6 ):
185229 defaults = dict (tensorwise = tensorwise , ord = ord , eps = eps )
@@ -193,6 +237,7 @@ def transform(self, var, update: list[torch.Tensor], direction: list[torch.Tenso
193237GraftToUpdate = RGraft
194238
195239class Maximum (BinaryOperation ):
240+ """Outputs :code:`maximum(tensors, other(tensors))`"""
196241 def __init__ (self , other : Chainable ):
197242 super ().__init__ ({}, other = other )
198243
@@ -202,6 +247,7 @@ def transform(self, var, update: list[torch.Tensor], other: list[torch.Tensor]):
202247 return update
203248
204249class Minimum (BinaryOperation ):
250+ """Outputs :code:`minimum(tensors, other(tensors))`"""
205251 def __init__ (self , other : Chainable ):
206252 super ().__init__ ({}, other = other )
207253
@@ -212,7 +258,7 @@ def transform(self, var, update: list[torch.Tensor], other: list[torch.Tensor]):
212258
213259
214260class GramSchimdt (BinaryOperation ):
215- """makes update orthonormal to `other` """
261+ """outputs tensors made orthogonal to `other(tensors)` via Gram-Schmidt. """
216262 def __init__ (self , other : Chainable ):
217263 super ().__init__ ({}, other = other )
218264
@@ -223,7 +269,7 @@ def transform(self, var, update: list[torch.Tensor], other: list[torch.Tensor]):
223269
224270
225271class Threshold (BinaryOperation ):
226- """update above/below threshold, value at and below """
272+ """Outputs tensors thresholded such that values above :code:` threshold` are set to :code:`value`. """
227273 def __init__ (self , threshold : Chainable | float , value : Chainable | float , update_above : bool ):
228274 defaults = dict (update_above = update_above )
229275 super ().__init__ (defaults , threshold = threshold , value = value )
0 commit comments