Skip to content

Commit 7684384

Browse files
committed
avoid deep copies from dc.asdict()
1 parent 5c2c514 commit 7684384

3 files changed

Lines changed: 21 additions & 4 deletions

File tree

src/finegrain_toolbox/dc.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
import dataclasses as dc
2+
from typing import TYPE_CHECKING, Any
3+
4+
if TYPE_CHECKING:
5+
from _typeshed import DataclassInstance
6+
7+
8+
def shallow_asdict(m: "DataclassInstance") -> dict[str, Any]:
9+
return {field.name: getattr(m, field.name) for field in dc.fields(m)}
10+
11+
12+
class DcMixin:
13+
def shallow_asdict(self) -> dict[str, Any]:
14+
assert dc.is_dataclass(self)
15+
return shallow_asdict(self)

src/finegrain_toolbox/flux/model.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import torch
88
from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler, FluxTransformer2DModel
99

10+
from ..dc import DcMixin
1011
from ..models import SafePushToHubMixin
1112
from ..torch import default_device, default_dtype
1213
from ..types import Self
@@ -25,7 +26,7 @@ def get_mu(scheduler: FlowMatchEulerDiscreteScheduler, image_seq_len: int) -> fl
2526

2627

2728
@dc.dataclass(kw_only=True)
28-
class Model(SafePushToHubMixin):
29+
class Model(SafePushToHubMixin, DcMixin):
2930
device: torch.device
3031
dtype: torch.dtype
3132
scheduler: FlowMatchEulerDiscreteScheduler
@@ -62,7 +63,7 @@ def to(
6263
device: torch.device | None = None,
6364
dtype: torch.dtype | None = None,
6465
) -> Self:
65-
params = dc.asdict(self)
66+
params = self.shallow_asdict()
6667
if device is not None:
6768
params["device"] = device
6869
if dtype is not None:

src/finegrain_toolbox/flux/prompt.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import torch
55
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
66

7+
from ..dc import DcMixin
78
from ..torch import default_device, default_dtype
89
from ..types import Self
910

@@ -16,7 +17,7 @@ class Tokenized:
1617

1718

1819
@dc.dataclass(kw_only=True)
19-
class Prompt:
20+
class Prompt(DcMixin):
2021
t5_embeds: torch.Tensor # (batch_size, seq_len, 4096)
2122
clip_embeds: torch.Tensor # (batch_size, 768)
2223
text_ids: torch.Tensor # shape (seq_len, 3)
@@ -38,7 +39,7 @@ def to(
3839
device: torch.device | None = None,
3940
dtype: torch.dtype | None = None,
4041
) -> Self:
41-
params = dc.asdict(self)
42+
params = self.shallow_asdict()
4243
params["t5_embeds"] = self.t5_embeds.to(device=device, dtype=dtype)
4344
params["clip_embeds"] = self.clip_embeds.to(device=device, dtype=dtype)
4445
params["text_ids"] = self.text_ids.to(device=device, dtype=dtype)

0 commit comments

Comments
 (0)