
Overall Problem

I have a 3D cryo-electron tomogram. I have used a segmentation model for the 3D segmentation of an organelle. However, the segmentation of the edges of the organelle is not very accurate. I want to use a prior: the organelle membrane is a bilayer, and the intensity profile of a bilayer membrane can be modelled as two Gaussians with some distance between their peaks. Using this prior, I should be able to refine the segmentation edges so that they lie at the center of the membrane.

Overall Project Plan

I achieved something like this with the following pipeline:

  1. Take a running average (window 5) along the z-axis, which has the fewest artifacts.
  2. For each 2D slice of the averaged tomogram (along the z-axis), represent the segmentation edges as splines.
  3. At each spline point, take a small rectangle along the spline normal and average the intensity across it (to enhance the SNR), yielding a 1D profile along the normal direction.
  4. Cross-correlate each profile with an estimated two-Gaussian membrane template to estimate the membrane center along that normal.
  5. Collect the estimated centers from all normals in all slices into a point cloud, then apply outlier removal and Poisson surface reconstruction to obtain a mesh representing the membrane center. This mesh is the new segmentation edge.
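The matching step can be sketched as follows. This is a minimal NumPy illustration, not the original implementation; the helper names and window sizes are made up:

```python
import numpy as np

def bi_gaussian_template(length, sigma, delta):
    """1D bilayer model: two Gaussian peaks of width `sigma`,
    separated by `delta` pixels, centered in the window."""
    x = np.arange(length) - (length - 1) / 2.0
    return (np.exp(-((x - delta / 2) ** 2) / (2 * sigma ** 2))
            + np.exp(-((x + delta / 2) ** 2) / (2 * sigma ** 2)))

def center_shift(profile, template):
    """Signed shift (in pixels) of the best cross-correlation peak
    relative to the profile center. Moving the spline point by this
    amount along its normal places it on the estimated membrane center."""
    p = profile - profile.mean()
    t = template - template.mean()
    corr = np.correlate(p, t, mode="same")
    return int(np.argmax(corr)) - (len(profile) - 1) // 2
```

For a profile that is just the template shifted by a few pixels, `center_shift` recovers that shift exactly; on real data the profile is the rectangle-averaged intensity along the normal.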

I want to simplify and improve this in the following ways:

  • Goal: Use a neural network that takes the tomogram and the 3D segmentation voxels as input and predicts the mesh directly.
  • Initialization: First train the network (if necessary) to predict a mesh along the segmentation edges (without refinement).
  • Sampling: Sample along the mesh generated by the network, and get intensity profiles along the sampled triangle normals.
  • Loss Function: Initially, use cross-correlation with a template to define a loss function to optimize.
  • Adaptive Template: After a few iterations, additionally allow the template to vary (the width of each Gaussian, and the distance between their peaks), adding Laplacian-type loss terms so that these parameters don't vary too quickly between neighboring points.
  • Self-Supervised: Together, these components form a self-supervised training algorithm for the mesh-generation network.

After training on a few segmentations, the network should produce good results on the first pass.

Current Plan Iteration (Gemini)

1. Core Problem & Objective

  • Target: Refine inaccurate 3D organelle segmentations to align with the bilayer membrane center.
  • The Prior: Biological membranes exhibit a bi-Gaussian intensity profile (two peaks for lipid heads, a central valley for hydrophobic tails).
  • Method: A hybrid CNN-GCN network that uses differentiable sampling and a self-supervised cross-correlation loss to achieve sub-voxel accuracy.

2. Technical Architecture: "Voxel-to-Graph Fusion"

A. Feature Extraction (The "Eyes")

  • Conv3D Encoder: A 3D CNN extracts local photometric context from the tomogram.
  • Oriented Cuboid Sampling: For rotation invariance, the network samples a local 3D volume at each vertex, oriented along the vertex normal.
  • Canonical Alignment: By rotating the data into a local coordinate system where the normal is the z-axis, the network only needs to learn a single pattern (the "sandwich" profile).
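The canonical alignment can be sketched as below (NumPy, assumed cuboid dimensions). In the actual network the resulting coordinates would feed a differentiable trilinear sampler such as `torch.nn.functional.grid_sample`:

```python
import numpy as np

def local_frame(normal):
    """Orthonormal frame (3x3, rows = local x, y, z axes) whose z-axis is
    the vertex normal, so every sampled cuboid sees the membrane in the
    same canonical orientation."""
    n = np.asarray(normal, float)
    n = n / np.linalg.norm(n)
    # Helper axis guaranteed not parallel to n
    helper = (np.array([1.0, 0.0, 0.0]) if abs(n[0]) < 0.9
              else np.array([0.0, 1.0, 0.0]))
    u = np.cross(n, helper)
    u /= np.linalg.norm(u)
    v = np.cross(n, u)
    return np.stack([u, v, n])

def cuboid_coords(vertex, frame, half=(2, 2, 8)):
    """World coordinates of a (2*hx+1, 2*hy+1, 2*hz+1) sampling grid
    centered on the vertex, long axis along the normal."""
    hx, hy, hz = half
    gx, gy, gz = np.meshgrid(np.arange(-hx, hx + 1),
                             np.arange(-hy, hy + 1),
                             np.arange(-hz, hz + 1), indexing="ij")
    local = np.stack([gx, gy, gz], axis=-1).astype(float)
    return vertex + local @ frame  # rotate local offsets into world space
```

Because the long grid axis always lies along the normal, the "sandwich" profile appears in the same place in every sample, which is what lets the network learn a single pattern.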

B. Geometric Processing (The "Muscle Memory")

  • Mesh Deformation GCN: A Graph Convolutional Network processes vertex features. It uses the mesh's adjacency matrix to ensure neighbor consensus.
  • Stochastic Update: Instead of deforming the whole mesh, the network samples ∼500 vertices per iteration, reducing memory and preventing global "drifting."
  • Scalar Shift Constraint: The network predicts a scalar displacement along the normal ($v_{new} = v_{old} + s \cdot n$) rather than a raw 3D vector. This prevents mesh tangling and improves stability.
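The stochastic, normal-constrained update is simple to sketch (NumPy, with made-up shapes; in the model the `shifts` would come from the GCN head):

```python
import numpy as np

def scalar_shift_update(verts, normals, shifts, idx):
    """Displace only the sampled vertex subset `idx` by a scalar distance
    along each vertex's own normal: v_new = v_old + s * n. Motion
    restricted to the normal direction cannot swap neighboring vertices
    tangentially, which is what prevents mesh tangling."""
    out = verts.copy()
    out[idx] = verts[idx] + shifts[idx, None] * normals[idx]
    return out

# Stochastic update: one random subset of ~500 vertices per iteration,
# e.g. idx = rng.choice(len(verts), size=500, replace=False)
```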

3. The Self-Supervised Learning Loop

A. The Bilayer Loss Function

  • Differentiable Template: A 1D bi-Gaussian function $T(x; \sigma, \delta)$ is generated on the fly.
  • Cross-Correlation: The sampled intensity profile is cross-correlated with the template. The network minimizes negative correlation to "snap" the mesh to the membrane center.
  • Adaptive Priors: The network predicts local $\sigma$ (width) and $\delta$ (peak distance), allowing the prior to adapt to local biological variations.
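A sketch of the loss under these definitions (NumPy for brevity; in training this would be written in an autodiff framework so gradients flow back to the vertex shifts and to the local $\sigma$, $\delta$):

```python
import numpy as np

def bilayer_template(x, sigma, delta):
    """Bi-Gaussian membrane model T(x; sigma, delta): lipid-head peaks
    at +/- delta/2 with a valley between them."""
    return (np.exp(-((x - delta / 2) ** 2) / (2 * sigma ** 2))
            + np.exp(-((x + delta / 2) ** 2) / (2 * sigma ** 2)))

def bilayer_loss(profile, sigma, delta):
    """Negative zero-lag normalized correlation between the sampled
    intensity profile and the template: it reaches -1 when the profile
    is a perfectly centered bilayer, so minimizing it snaps vertices
    onto the membrane center."""
    x = np.arange(len(profile)) - (len(profile) - 1) / 2.0
    t = bilayer_template(x, sigma, delta)
    p = profile - profile.mean()
    t = t - t.mean()
    return -float(p @ t / (np.linalg.norm(p) * np.linalg.norm(t) + 1e-8))
```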

B. Regularization (Physics-Informed Constraints)

  • Mesh Laplacian: Penalizes sharp spikes; ensures a smooth manifold.
  • Parameter Laplacian: Penalizes abrupt changes in the learned $\sigma$ and $\delta$ across the surface.
  • Missing Wedge Weighting: Vertices with normals parallel to the electron beam receive lower loss weights, forcing the network to rely on "clean" side-view data and neighborhood smoothing.
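Both Laplacian terms and the wedge weighting are a few lines each; this sketch assumes a per-vertex neighbor list and a beam direction along z:

```python
import numpy as np

def laplacian_penalty(values, neighbors):
    """Mean squared deviation of each vertex value from the mean of its
    neighbors. Apply to vertex positions for mesh smoothness, or to the
    per-vertex sigma/delta fields for parameter smoothness."""
    vals = np.asarray(values, float).reshape(len(values), -1)
    avg = np.array([vals[nb].mean(axis=0) for nb in neighbors])
    return float(((vals - avg) ** 2).sum(axis=1).mean())

def wedge_weights(normals, beam=(0.0, 0.0, 1.0)):
    """Per-vertex loss weights 1 - |n . beam|: vertices whose normals
    face the electron beam (worst missing-wedge distortion) contribute
    least, side-facing membrane gets full weight."""
    return 1.0 - np.abs(np.asarray(normals) @ np.asarray(beam))
```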

4. Implementation Strategy (Phase 1: 2D Prototype)

  • Pixel2Contour: Simplify to 2D slices to debug the oriented rectangle sampling and circular graph connectivity.
  • Differentiable Normal Updates: Ensure normals are re-calculated using the cross-product of adjacent edges in every pass so the "sampling search" is dynamic.
  • Hybrid SDF/Remeshing: Use Differentiable Marching Cubes (or FlexiCubes) every X iterations to fix topological errors (e.g., merging components) before fine-tuning with the GCN.
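Recomputing normals from the current vertex positions is plain arithmetic on the vertex array, so an autodiff framework can backpropagate through it. A NumPy sketch of the cross-product recipe (area-weighted vertex normals, my own helper name):

```python
import numpy as np

def vertex_normals(verts, faces):
    """Vertex normals from face edge cross products: each face normal is
    accumulated onto its three corners (area weighting comes for free
    from the cross-product magnitude), then normalized."""
    fv = verts[faces]                                        # (F, 3, 3)
    fn = np.cross(fv[:, 1] - fv[:, 0], fv[:, 2] - fv[:, 0])  # face normals
    vn = np.zeros_like(verts)
    for i in range(3):
        np.add.at(vn, faces[:, i], fn)  # scatter-add onto face corners
    return vn / (np.linalg.norm(vn, axis=1, keepdims=True) + 1e-12)
```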

5. Initial 2D Testing Plan

Before full 3D training, we will implement a simplified 2D version to validate the loss landscape and gradient flow.

  • Architecture:
    • Input: 2D Tomogram Slice ($H \times W$).
    • Graph: 2D Contour (Circular linked list of vertices).
    • Encoder: Shallow 2D CNN.
    • Sampler: Bilinear interpolation on oriented 2D patches.
  • Workflow:
    1. Extract 2D contours from the initial 3D segmentation.
    2. Train the 2D model to refine these contours using the bi-Gaussian prior.
    3. Visualize the "snapping" behavior of vertices to the membrane density.
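Two pieces of this 2D prototype are easy to get wrong and easy to unit-test: the circular contour connectivity/normals and the bilinear sub-pixel sampler. A NumPy sketch (function names are mine):

```python
import numpy as np

def contour_normals(pts):
    """Normals for a closed 2D contour: rotate the central-difference
    tangent by 90 degrees. np.roll supplies the circular (closed-loop)
    connectivity; for a counter-clockwise contour these point outward."""
    tang = np.roll(pts, -1, axis=0) - np.roll(pts, 1, axis=0)
    nrm = np.stack([tang[:, 1], -tang[:, 0]], axis=1)
    return nrm / (np.linalg.norm(nrm, axis=1, keepdims=True) + 1e-12)

def bilinear(img, y, x):
    """Bilinear interpolation at fractional (y, x): the sub-pixel sampler
    used to read intensities along each vertex normal. Caller must keep
    coordinates within [0, H-2] x [0, W-2]."""
    y0, x0 = int(np.floor(y)), int(np.floor(x))
    dy, dx = y - y0, x - x0
    return ((1 - dy) * (1 - dx) * img[y0, x0]
            + (1 - dy) * dx * img[y0, x0 + 1]
            + dy * (1 - dx) * img[y0 + 1, x0]
            + dy * dx * img[y0 + 1, x0 + 1])
```

For a contour sampled from a circle, the computed normals should be radial, which gives a quick sanity check before wiring in the CNN encoder.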