| """ |
| This file is self-contained: download it alongside `model.safetensors`, |
| `config.json`, and `maxsub.json` to load and run the model. |
| """ |
|
|
| import json |
| from collections import deque |
| from pathlib import Path |
| from typing import Any, Dict, List, Optional, Tuple |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
|
|
| |
| |
| |
def drop_path(
    x: torch.Tensor, drop_prob: float = 0.0, training: bool = False
) -> torch.Tensor:
    """Stochastic depth: randomly zero whole samples, rescaling survivors.

    During evaluation, or when ``drop_prob`` is zero, the input passes
    through untouched.  Otherwise each sample in the batch is kept with
    probability ``1 - drop_prob``, and kept samples are scaled by
    ``1 / (1 - drop_prob)`` so the expected activation is unchanged.
    """
    if not training or drop_prob == 0.0:
        return x
    keep_prob = 1 - drop_prob
    # One Bernoulli draw per sample, broadcast over all remaining dims.
    mask_shape = (x.shape[0],) + (1,) * (x.ndim - 1)
    mask = keep_prob + torch.rand(mask_shape, dtype=x.dtype, device=x.device)
    mask = mask.floor()  # 0 or 1 per sample
    return x.div(keep_prob) * mask
|
|
|
|
class DropPath(nn.Module):
    """Per-sample stochastic depth (module wrapper around :func:`drop_path`)."""

    def __init__(self, drop_prob: float = 0.0):
        super().__init__()
        # Probability of dropping an entire sample's residual branch.
        self.drop_prob = drop_prob

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Delegates to the functional form; only active in training mode.
        return drop_path(x, drop_prob=self.drop_prob, training=self.training)
|
|
|
|
| |
| |
| |
class ConvNeXtBlock1D(nn.Module):
    """A 1D ConvNeXt residual block.

    Layout: depthwise conv -> pointwise expansion (x4) -> activation ->
    pointwise projection, with optional layer scale (``gamma``) and
    stochastic depth on the residual branch.  Input and output are both
    ``(batch, dim, length)``; the depthwise conv uses "same" padding so
    the sequence length is preserved.
    """

    def __init__(
        self,
        dim: int,
        kernel_size: int,
        drop_path: float,
        layer_scale_init_value: float,
        activation: nn.Module,
    ):
        super().__init__()
        # Depthwise conv: one filter per channel (groups == channels).
        self.dwconv = nn.Conv1d(
            dim, dim, kernel_size=kernel_size, padding="same", groups=dim
        )
        self.pwconv1 = nn.Linear(dim, 4 * dim)
        # Accept either an activation class or a ready-made instance.
        self.act = activation() if isinstance(activation, type) else activation
        self.pwconv2 = nn.Linear(4 * dim, dim)
        # Layer scale: learnable per-channel residual gain, disabled when
        # the init value is non-positive.
        if layer_scale_init_value > 0:
            self.gamma = nn.Parameter(layer_scale_init_value * torch.ones(dim))
        else:
            self.gamma = None
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        residual = x
        out = self.dwconv(x)
        # The pointwise convs are Linear layers, so work channels-last.
        out = out.permute(0, 2, 1)
        out = self.pwconv2(self.act(self.pwconv1(out)))
        if self.gamma is not None:
            out = out * self.gamma
        out = out.permute(0, 2, 1)
        return residual + self.drop_path(out)
|
|
|
|
class ConvNextBlock1DAdaptor(nn.Module):
    """Stage wrapper: optional channel projection, optional ConvNeXt
    body, optional average-pool downsampling.

    NOTE(review): ``dropout`` and ``use_batchnorm`` are accepted for
    config compatibility but are not used by this implementation.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int,
        dropout: float,
        use_batchnorm: bool,
        activation: nn.Module,
        layer_scale_init_value: float,
        drop_path_rate: float,
        block_type: str,
    ):
        super().__init__()
        # Channel adaptor: pointwise (Linear) projection plus activation,
        # only needed when the channel count changes.
        self.pwconv = None
        if in_channels != out_channels:
            act = activation() if isinstance(activation, type) else activation
            self.pwconv = nn.Sequential(nn.Linear(in_channels, out_channels), act)

        # Main body: a ConvNeXt block, or a pass-through for any other
        # block_type value.
        self.block = None
        if block_type == "convnext":
            self.block = ConvNeXtBlock1D(
                dim=out_channels,
                kernel_size=kernel_size,
                drop_path=drop_path_rate,
                layer_scale_init_value=layer_scale_init_value,
                activation=activation,
            )

        # Temporal downsampling by `stride` via average pooling.
        self.reduction_pool = (
            nn.AvgPool1d(kernel_size=stride, stride=stride) if stride > 1 else None
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.pwconv is not None:
            # Linear expects channels-last; convs are channels-first.
            x = self.pwconv(x.permute(0, 2, 1)).permute(0, 2, 1)
        if self.block is not None:
            x = self.block(x)
        if self.reduction_pool is not None:
            x = self.reduction_pool(x)
        return x
|
|
|
|
| |
| |
| |
def make_mlp(
    input_dim: int,
    hidden_dims: Optional[Tuple[int, ...]],
    output_dim: int,
    dropout: float = 0.2,
    output_activation: Optional[nn.Module] = None,
) -> nn.Module:
    """Build a simple feed-forward head.

    Each hidden layer is Linear -> ReLU (-> Dropout when ``dropout`` > 0);
    a final Linear maps to ``output_dim``, optionally followed by
    ``output_activation``.  With no hidden dims this reduces to a single
    Linear layer.
    """
    modules: List[nn.Module] = []
    prev_dim = input_dim
    # `hidden_dims or ()` treats both None and empty tuple as "no hidden layers".
    for width in hidden_dims or ():
        modules.append(nn.Linear(prev_dim, width))
        modules.append(nn.ReLU())
        if dropout and dropout > 0:
            modules.append(nn.Dropout(dropout))
        prev_dim = width
    modules.append(nn.Linear(prev_dim, output_dim))
    if output_activation is not None:
        modules.append(output_activation)
    return nn.Sequential(*modules)
|
|
|
|
| |
| |
| |
class MultiscaleCNNBackbone1D(nn.Module):
    """Multi-stage 1D CNN feature extractor.

    Each stage is a ``ConvNextBlock1DAdaptor``; a pooling layer is placed
    between consecutive stages (and after the last stage when
    ``final_pool`` is set).  ``forward`` accepts ``(batch, length)`` or
    ``(batch, 1, length)`` input and returns either a global-average-pooled
    vector (``output_type == "gap"``) or a flattened feature map
    (``"flatten"``); the resulting width is exposed as ``self.dim_output``.
    """

    def __init__(
        self,
        dim_in: int,
        channels: Tuple[int, ...],
        kernel_sizes: Tuple[int, ...],
        strides: Tuple[int, ...],
        dropout_rate: float,
        ramped_dropout_rate: bool,
        block_type: str,
        pooling_type: str,
        final_pool: bool,
        use_batchnorm: bool,
        activation: nn.Module,
        output_type: str,
        layer_scale_init_value: float,
        drop_path_rate: float,
    ):
        super().__init__()
        assert len(channels) == len(kernel_sizes) == len(strides)
        self.dim_in = dim_in
        self.output_type = output_type

        # Per-stage dropout: linear ramp from 0 up to dropout_rate, or a
        # constant rate for every stage.
        if ramped_dropout_rate:
            stage_dropouts = torch.linspace(
                0.0, dropout_rate, steps=len(channels)
            ).tolist()
        else:
            stage_dropouts = [dropout_rate] * len(channels)

        # Pooling spec: (class, kwargs), instantiated once per gap.
        if pooling_type == "average":
            pool_spec = (nn.AvgPool1d, {"kernel_size": 3, "stride": 2})
        elif pooling_type == "max":
            pool_spec = (nn.MaxPool1d, {"kernel_size": 2, "stride": 2})
        else:
            raise ValueError(f"Invalid pooling_type '{pooling_type}'")

        stages: List[nn.Module] = []
        prev_ch = 1  # raw input is a single-channel 1D signal
        last_idx = len(channels) - 1
        for idx, (ch, ksize, stride) in enumerate(
            zip(channels, kernel_sizes, strides)
        ):
            stages.append(
                ConvNextBlock1DAdaptor(
                    in_channels=prev_ch,
                    out_channels=ch,
                    kernel_size=ksize,
                    stride=stride,
                    dropout=stage_dropouts[idx],
                    use_batchnorm=use_batchnorm,
                    activation=activation,
                    layer_scale_init_value=layer_scale_init_value,
                    drop_path_rate=drop_path_rate,
                    block_type=block_type,
                )
            )
            # Pool between stages, and after the last one if requested.
            if idx < last_idx or final_pool:
                stages.append(pool_spec[0](**pool_spec[1]))
            prev_ch = ch

        self.net = nn.Sequential(*stages)

        # Resolve the output feature width for downstream heads.
        if self.output_type == "gap":
            self.dim_output = channels[-1]
        elif self.output_type == "flatten":
            # Probe with a dummy input to discover the flattened width.
            with torch.no_grad():
                probe = self.net(torch.zeros(1, 1, self.dim_in))
            self.dim_output = int(probe.shape[1] * probe.shape[2])
        else:
            raise ValueError(f"Invalid output_type '{self.output_type}'")

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Accept (batch, length) by inserting a channel axis.
        if x.ndim == 2:
            x = x.unsqueeze(1)
        x = self.net(x)
        if self.output_type == "gap":
            return x.mean(dim=-1)
        return x.reshape(x.shape[0], -1)
|
|
|
|
| |
| |
| |
| def _build_distance_matrix_from_maxsub_lut( |
| maxsub_lut: Dict[str, List[int]], |
| num_sg_classes: int, |
| ) -> torch.Tensor: |
| adjacency: List[set] = [set() for _ in range(num_sg_classes)] |
| for key, neighbors in maxsub_lut.items(): |
| src = int(key) - 1 |
| for raw_dst in neighbors: |
| dst = int(raw_dst) - 1 |
| adjacency[src].add(dst) |
| adjacency[dst].add(src) |
|
|
| distance_matrix = torch.zeros( |
| (num_sg_classes, num_sg_classes), dtype=torch.float32 |
| ) |
| for src in range(num_sg_classes): |
| dists = [-1] * num_sg_classes |
| dists[src] = 0 |
| queue = deque([src]) |
| while queue: |
| cur = queue.popleft() |
| for nxt in adjacency[cur]: |
| if dists[nxt] == -1: |
| dists[nxt] = dists[cur] + 1 |
| queue.append(nxt) |
| distance_matrix[src] = torch.tensor(dists, dtype=torch.float32) |
| return distance_matrix |
|
|
|
|
def load_gemd_distance_matrix(
    path: str, num_sg_classes: int = 230
) -> torch.Tensor:
    """Load the space-group GEMD distance matrix from a JSON file.

    Two payload formats are accepted: a maximal-subgroup lookup table
    (dict with digit-string keys mapping to neighbor lists), from which
    hop distances are computed, or a pre-computed matrix stored as a
    nested list.  Anything else raises ``ValueError``.
    """
    with open(path, "r", encoding="utf-8") as f:
        payload: Any = json.load(f)
    is_lut = isinstance(payload, dict) and all(
        str(k).isdigit() for k in payload.keys()
    )
    if is_lut:
        return _build_distance_matrix_from_maxsub_lut(payload, num_sg_classes)
    if isinstance(payload, list):
        return torch.as_tensor(payload, dtype=torch.float32)
    raise ValueError(f"Could not parse GEMD data from {path}")
|
|
|
|
| |
| |
| |
class AlphaDiffract(nn.Module):
    """
    AlphaDiffract: multi-task 1D ConvNeXt for powder X-ray diffraction
    pattern analysis.

    Predicts crystal system (7 classes), space group (230 classes), and
    lattice parameters (6 values: a, b, c, alpha, beta, gamma).
    """

    # Index order matches the crystal-system head's output logits.
    CRYSTAL_SYSTEMS = [
        "Triclinic",
        "Monoclinic",
        "Orthorhombic",
        "Tetragonal",
        "Trigonal",
        "Hexagonal",
        "Cubic",
    ]

    def __init__(self, config: dict, maxsub_path: Optional[str] = None):
        """
        Args:
            config: Dict with "backbone", "heads", and "tasks" sections
                (see ``config.json``).
            maxsub_path: Optional path to a JSON file (``maxsub.json``)
                used to build the GEMD space-group distance matrix; when
                omitted, ``gemd_distance_matrix`` is set to None.
        """
        super().__init__()
        bb = config["backbone"]
        heads = config["heads"]
        tasks = config["tasks"]

        activation = nn.GELU

        # Shared feature extractor over the raw 1D pattern.
        self.backbone = MultiscaleCNNBackbone1D(
            dim_in=bb["dim_in"],
            channels=tuple(bb["channels"]),
            kernel_sizes=tuple(bb["kernel_sizes"]),
            strides=tuple(bb["strides"]),
            dropout_rate=bb["dropout_rate"],
            ramped_dropout_rate=bb["ramped_dropout_rate"],
            block_type=bb["block_type"],
            pooling_type=bb["pooling_type"],
            final_pool=bb["final_pool"],
            use_batchnorm=bb["use_batchnorm"],
            activation=activation,
            output_type=bb["output_type"],
            layer_scale_init_value=bb["layer_scale_init_value"],
            drop_path_rate=bb["drop_path_rate"],
        )
        feat_dim = self.backbone.dim_output

        # One MLP head per task, all fed from the shared features.
        self.cs_head = make_mlp(
            feat_dim, tuple(heads["cs_hidden"]), tasks["num_cs_classes"],
            dropout=heads["head_dropout"],
        )
        self.sg_head = make_mlp(
            feat_dim, tuple(heads["sg_hidden"]), tasks["num_sg_classes"],
            dropout=heads["head_dropout"],
        )
        self.lp_head = make_mlp(
            feat_dim, tuple(heads["lp_hidden"]), tasks["num_lp_outputs"],
            dropout=heads["head_dropout"],
        )

        # Optional sigmoid bounding of lattice parameters into
        # [lp_min, lp_max]; bounds are buffers so they follow .to(device)
        # without being trained.
        self.bound_lp_with_sigmoid = tasks["bound_lp_with_sigmoid"]
        self.register_buffer(
            "lp_min",
            torch.tensor(tasks["lp_bounds_min"], dtype=torch.float32),
        )
        self.register_buffer(
            "lp_max",
            torch.tensor(tasks["lp_bounds_max"], dtype=torch.float32),
        )

        if maxsub_path is not None:
            # Fix: size the GEMD matrix from the configured space-group
            # count (the same value used for sg_head) instead of relying
            # on the loader's hard-coded default of 230.
            gemd = load_gemd_distance_matrix(
                maxsub_path, num_sg_classes=tasks["num_sg_classes"]
            )
            self.register_buffer("gemd_distance_matrix", gemd)
        else:
            self.gemd_distance_matrix = None

    def forward(self, x: torch.Tensor) -> Dict[str, torch.Tensor]:
        """
        Args:
            x: PXRD pattern tensor of shape ``(batch, 8192)`` or
                ``(batch, 1, 8192)``, intensity-normalized to [0, 100].

        Returns:
            Dict with keys ``cs_logits``, ``sg_logits``, ``lp``.
        """
        feats = self.backbone(x)
        cs_logits = self.cs_head(feats)
        sg_logits = self.sg_head(feats)
        lp = self.lp_head(feats)
        if self.bound_lp_with_sigmoid:
            # Affine-rescale the sigmoid output into the physical bounds.
            lp = torch.sigmoid(lp) * (self.lp_max - self.lp_min) + self.lp_min
        return {"cs_logits": cs_logits, "sg_logits": sg_logits, "lp": lp}

    @classmethod
    def from_pretrained(
        cls,
        model_dir: str,
        device: str = "cpu",
    ) -> "AlphaDiffract":
        """Load model from a directory containing config.json,
        model.safetensors, and maxsub.json.

        Falls back to ``model.pt`` when no safetensors file is present.
        The returned model is on ``device`` and in eval mode.
        """
        model_dir = Path(model_dir)
        with open(model_dir / "config.json", "r") as f:
            config = json.load(f)

        maxsub_path = model_dir / "maxsub.json"
        model = cls(
            config,
            maxsub_path=str(maxsub_path) if maxsub_path.exists() else None,
        )

        weights_path = model_dir / "model.safetensors"
        if weights_path.exists():
            # Imported lazily so safetensors is only required when the
            # checkpoint actually uses it.
            from safetensors.torch import load_file
            state_dict = load_file(str(weights_path), device=device)
        else:
            # Fallback: plain PyTorch checkpoint. weights_only=True avoids
            # arbitrary-code execution via pickle.
            pt_path = model_dir / "model.pt"
            state_dict = torch.load(str(pt_path), map_location=device, weights_only=True)

        model.load_state_dict(state_dict)
        model.to(device)
        model.eval()
        return model
|
|