Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions operators/dimension_reduction/dimension_reduction.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,17 @@
import umap
from sklearn.manifold import TSNE

from feluda import Operator


class ReductionModel:
"""Base class for dimension reduction models."""

def __init__(self, params: Any) -> None:
self.params = params

def validate_embeddings(self, embeddings_array: np.ndarray) -> np.ndarray:
@staticmethod
def validate_embeddings(embeddings_array: np.ndarray) -> np.ndarray:
"""Validate embeddings array, converting list to numpy array if needed.

Args:
Expand Down Expand Up @@ -69,7 +72,7 @@ def __post_init__(self) -> None:
raise ValueError("learning_rate must be positive")
if self.max_iter < 1:
raise ValueError("max_iter must be at least 1")
if self.method not in ["barnes_hut", "exact"]:
if self.method not in {"barnes_hut", "exact"}:
raise ValueError("method must be 'barnes_hut' or 'exact'")


Expand Down Expand Up @@ -188,7 +191,7 @@ def run(self, embeddings_array: np.ndarray) -> np.ndarray:
raise RuntimeError(f"UMAP reduction failed: {e}")


class DimensionReduction:
class DimensionReduction(Operator):
"""Main interface for dimensionality reduction."""

def __init__(self, model_type: str, params: dict[str, Any] | None = None) -> None:
Expand Down
Loading