import torch
import torch.nn.functional as F
import math
import numpy as np
import matplotlib.pyplot as plt

@torch.no_grad()
def gaussian_layer_stack_pipeline(
    x: torch.Tensor,
    n_layers: int,
    base_ksize: int = 3,
    ksize_growth: int = 2,
    sigma: float | None = None,
    eps: float = 1e-8,
):
    """
    All-in-one GPU batch pipeline:
      1) Per-sample min-max normalize to [0,1]
      2) Resize to (32,32)
      3) Apply L Gaussian blurs, one kernel size per layer (layer 0 gets the
         largest), in a single horizontal conv + single vertical conv with
         depthwise groups (each kernel zero-padded into a shared max-size kernel)
      4) Renormalize each layer to [0,1]
      5) Return stacked (B,L,32,32), flat (B,L,1024), tiled (B,L,1024,1024 view)

    Args:
      x: (B,H,W) or (B,1,H,W) tensor (any device/dtype)
      n_layers: number of layers
      base_ksize: base kernel size (e.g., 3); layer kernels are
        base_ksize + i * ksize_growth for i = n_layers..1, forced odd
      ksize_growth: kernel-size increment per layer step (e.g., 2)
      sigma: if None, uses (ksize-1)/6 per layer; else fixed sigma for all
      eps: small number for safe division

    Returns:
      stacked: (B, n_layers, 32, 32)  float on x.device
      flat:    (B, n_layers, 1024)
      tiled:   (B, n_layers, 1024, 1024)  (expand view; memory-cheap)
    """
    assert n_layers >= 1, "n_layers must be >= 1"

    # ---- Ensure 4D, 1 channel; cast to float (stay on same device) ----
    if x.ndim == 3:
        x = x.unsqueeze(1)  # (B,1,H,W)
    elif x.ndim != 4 or x.shape[1] != 1:
        raise ValueError(f"Expected (B,H,W) or (B,1,H,W); got {tuple(x.shape)}")
    x = x.float()

    B, _, H, W = x.shape

    # ---- Per-sample min-max normalize to [0,1] ----
    xmin = x.amin(dim=(2, 3), keepdim=True)
    xmax = x.amax(dim=(2, 3), keepdim=True)
    denom = (xmax - xmin).clamp_min(eps)
    x = (x - xmin) / denom  # (B,1,H,W) in [0,1]

    # ---- Resize to 32x32 on GPU ----
    x = F.interpolate(x, size=(32, 32), mode="bilinear", align_corners=False)  # (B,1,32,32)

    # ---- Prepare per-layer kernel sizes (odd) ----
    ksizes = []
    for i in range(n_layers, 0, -1):  # i = n_layers..1, so layer 0 gets the largest kernel
        k = int(base_ksize + i * ksize_growth)
        if k % 2 == 0:
            k += 1
        k = max(k, 1)
        ksizes.append(k)

    Kmax = max(ksizes)
    pad = Kmax // 2
    # reflect padding requires pad < spatial size; guards against very deep stacks
    assert pad < 32, f"Kmax={Kmax} is too large for reflect padding on 32x32 inputs"

    # ---- Build per-layer 1D Gaussian vectors and embed into shared Kmax kernel ----
    # We create horizontal weights of shape (L,1,1,Kmax) and vertical (L,1,Kmax,1)
    device, dtype = x.device, x.dtype
    weight_h = torch.zeros((n_layers, 1, 1, Kmax), device=device, dtype=dtype)
    weight_v = torch.zeros((n_layers, 1, Kmax, 1), device=device, dtype=dtype)

    for idx, k in enumerate(ksizes):
        # choose sigma; clamp away from zero (k == 1 would give (k - 1) / 6 == 0)
        sig = sigma if (sigma is not None and sigma > 0) else (k - 1) / 6.0
        sig = max(sig, eps)
        r = k // 2
        xp = torch.arange(-r, r + 1, device=device, dtype=dtype)
        g = torch.exp(-(xp * xp) / (2.0 * sig * sig))
        g = g / g.sum()  # (k,)

        # center g into Kmax with zeros around
        start = (Kmax - k) // 2
        end = start + k

        # horizontal row
        weight_h[idx, 0, 0, start:end] = g  # (1 x Kmax)

        # vertical column
        weight_v[idx, 0, start:end, 0] = g  # (Kmax x 1)

    # ---- Duplicate input across L channels (depthwise groups) ----
    xL = x.expand(B, n_layers, 32, 32).contiguous()  # (B,L,32,32)

    # ---- Separable Gaussian blur with a single pass per axis (groups=L) ----
    # Horizontal
    xh = F.pad(xL, (pad, pad, 0, 0), mode="reflect")
    xh = F.conv2d(xh, weight=weight_h, bias=None, stride=1, padding=0, groups=n_layers)  # (B,L,32,32)

    # Vertical
    xv = F.pad(xh, (0, 0, pad, pad), mode="reflect")
    yL = F.conv2d(xv, weight=weight_v, bias=None, stride=1, padding=0, groups=n_layers)  # (B,L,32,32)

    # ---- Renormalize each layer to [0,1] (per-sample, per-layer) ----
    y_min = yL.amin(dim=(2, 3), keepdim=True)
    y_max = yL.amax(dim=(2, 3), keepdim=True)
    y_den = (y_max - y_min).clamp_min(eps)
    stacked = (yL - y_min) / y_den  # (B,L,32,32) in [0,1]

    # ---- Flatten + tile (expand view; caution w/ later materialization) ----
    flat = stacked.reshape(B, n_layers, 32 * 32)               # (B,L,1024)
    tiled = flat.unsqueeze(-2).expand(-1, -1, 32 * 32, -1)     # (B,L,1024,1024) view

    return stacked, flat, tiled
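
# Sanity-check sketch (illustrative, not part of the pipeline): the grouped conv
# above assumes that a k-tap Gaussian zero-padded into the shared Kmax window,
# applied after reflect padding of Kmax//2, matches a direct k-tap blur with
# reflect padding of k//2. The helper below verifies this for one assumed
# (k, kmax) pair on random data; names and defaults here are illustrative.
def _check_shared_kernel_equivalence(k: int = 3, kmax: int = 7) -> bool:
    x = torch.rand(1, 1, 32, 32)
    r = k // 2
    xp = torch.arange(-r, r + 1, dtype=torch.float32)
    g = torch.exp(-(xp * xp) / (2.0 * ((k - 1) / 6.0) ** 2))
    g = g / g.sum()
    # direct horizontal blur: k-tap kernel, reflect pad of k//2
    direct = F.conv2d(F.pad(x, (r, r, 0, 0), mode="reflect"), g.view(1, 1, 1, k))
    # embedded blur: same taps centered in a kmax window, reflect pad of kmax//2
    w = torch.zeros(1, 1, 1, kmax)
    start = (kmax - k) // 2
    w[0, 0, 0, start:start + k] = g
    embedded = F.conv2d(F.pad(x, (kmax // 2, kmax // 2, 0, 0), mode="reflect"), w)
    return torch.allclose(direct, embedded, atol=1e-6)
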

def plot_layers_any(
    x,
    *,
    max_batches=None,
    vlim=(0, 1),
    one_indexed: bool = False,
    max_cols: int = 6,
):
    """
    Plot layers for each batch sample in separate figures.

    Accepts:
      - stacked: (B, L, H, W)
      - flat:    (B, L, HW)
      - tiled:   (B, L, HW, HW)

    Behavior:
      - Creates one figure PER BATCH (up to `max_batches`).
      - At most `max_cols` layers per row (default 6).
      - Panel titles: 'Layer {i}' for i = 0..L-1 (or 1..L if one_indexed=True).
      - Figure title per batch: 'Masks for input {i} out of {B}'.

    Returns:
      A list of (fig, axes) tuples, one per plotted batch.
    """
    # ---- Normalize input to torch ----
    if isinstance(x, np.ndarray):
        x = torch.from_numpy(x)
    if not isinstance(x, torch.Tensor):
        raise TypeError(f"Expected torch.Tensor or np.ndarray, got {type(x)}")

    if x.ndim not in (3, 4):
        raise ValueError(f"Expected ndim 3 or 4, got shape {tuple(x.shape)}")

    # ---- Convert to (B, L, H, W) 'stacked' ----
    if x.ndim == 4:
        B, L, A, A2 = x.shape
        if A == A2:
            # Could be stacked (H==W) or tiled (HW x HW). Heuristic: if A is a perfect square
            # and reasonably large (e.g., 1024), treat as tiled and collapse to flat.
            s = int(math.isqrt(A))
            if s * s == A and A >= 64:
                flat = x[..., 0, :].detach()  # (B, L, HW)
                H = W = s
                stacked = flat.reshape(B, L, H, W)
            else:
                stacked = x.detach()
        else:
            stacked = x.detach()
    else:
        # x.ndim == 3 -> (B, L, HW)
        B, L, HW = x.shape
        s = int(math.isqrt(HW))
        if s * s != HW:
            raise ValueError(
                f"Cannot infer square image size from HW={HW}. "
                f"Provide stacked (B,L,H,W) or flat with square HW."
            )
        H = W = s
        stacked = x.detach().reshape(B, L, H, W)

    # Ensure float & CPU for plotting
    stacked = stacked.to(torch.float32).cpu().numpy()

    # ---- Batch selection ----
    B, L, H, W = stacked.shape
    plot_B = B if max_batches is None else max(1, min(B, int(max_batches)))

    # ---- Layout params ----
    cols = max(1, int(max_cols))
    rows = (L + cols - 1) // cols  # ceil(L / cols)

    figs = []
    for b in range(plot_B):
        fig, axes = plt.subplots(rows, cols, figsize=(cols * 3, rows * 3), squeeze=False)
        fig.suptitle(f"Masks for input {b} out of {B}", fontsize=12, y=1.02)

        for l in range(L):
            rr = l // cols
            cc = l % cols
            ax = axes[rr, cc]
            if vlim is None:
                ax.imshow(stacked[b, l], cmap="gray")
            else:
                ax.imshow(stacked[b, l], cmap="gray", vmin=vlim[0], vmax=vlim[1])
            ax.axis("off")

            # Title each panel with its layer index
            label_num = (l + 1) if one_indexed else l
            ax.set_title(f"Layer {label_num}", fontsize=10)

        # Hide any unused axes (when L is not a multiple of cols)
        total_slots = rows * cols
        for empty_idx in range(L, total_slots):
            rr = empty_idx // cols
            cc = empty_idx % cols
            axes[rr, cc].axis("off")

        plt.tight_layout()
        figs.append((fig, axes))
    return figs
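
if __name__ == "__main__":
    # Usage sketch (random data, assumed shapes): run the pipeline on a fake
    # (B,H,W) batch, confirm the documented output shapes, and plot the layers.
    demo = torch.rand(2, 28, 28)  # any H,W works; the pipeline resizes to 32x32
    stacked, flat, tiled = gaussian_layer_stack_pipeline(demo, n_layers=4)
    assert stacked.shape == (2, 4, 32, 32)
    assert flat.shape == (2, 4, 1024)
    assert tiled.shape == (2, 4, 1024, 1024)
    assert _check_shared_kernel_equivalence()
    plot_layers_any(stacked, max_batches=1)
    plt.show()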