"""
Example: load AlphaDiffract and run inference on a PXRD pattern.

Requirements:
    pip install torch safetensors numpy
"""

import numpy as np
import torch

from model import AlphaDiffract

# Single source of truth for where the model AND its input live. Previously only
# the model's device was configurable, so switching to "cuda" left the input
# tensor on the CPU and the forward pass failed with a device mismatch.
DEVICE = "cpu"  # or "cuda"

# Number of intensity points the model expects per pattern.
N_POINTS = 8192

# 1. Load model ---------------------------------------------------------------
model = AlphaDiffract.from_pretrained(".", device=DEVICE)
# Put the model in inference mode so dropout/batch-norm layers (if any) are
# deterministic. Harmless if from_pretrained already did this.
# NOTE(review): assumes AlphaDiffract is a torch.nn.Module — confirm.
model.eval()

# 2. Prepare input -------------------------------------------------------------
# The model expects an 8192-point PXRD intensity pattern normalized to [0, 100].
# Replace this with your own data.
pattern = np.random.rand(N_POINTS).astype(np.float32)  # placeholder

if pattern.shape != (N_POINTS,):
    raise ValueError(
        f"expected a 1-D pattern of {N_POINTS} points, got shape {pattern.shape}"
    )

# Min-max normalize to [0, 100]; the 1e-10 avoids division by zero for a flat
# (constant) pattern, which then maps to all zeros.
pattern = (pattern - pattern.min()) / (pattern.max() - pattern.min() + 1e-10) * 100.0
# Force float32: user-supplied data is often float64, and torch.from_numpy keeps
# the numpy dtype, which would produce a double tensor.
pattern = np.ascontiguousarray(pattern, dtype=np.float32)

x = torch.from_numpy(pattern).unsqueeze(0).to(DEVICE)  # shape: (1, 8192)

# 3. Inference -----------------------------------------------------------------
with torch.no_grad():
    out = model(x)
    cs_probs = torch.softmax(out["cs_logits"], dim=-1)
    sg_probs = torch.softmax(out["sg_logits"], dim=-1)
    lp = out["lp"]

# 4. Results -------------------------------------------------------------------
cs_idx = cs_probs.argmax(dim=-1).item()
sg_idx = sg_probs.argmax(dim=-1).item()

print(f"Crystal system : {AlphaDiffract.CRYSTAL_SYSTEMS[cs_idx]} "
      f"({cs_probs[0, cs_idx].item():.1%})")
# NOTE(review): assumes logit index i corresponds to space group number i + 1
# (i.e. 0-based indices over groups 1..230) — confirm against the label encoding.
print(f"Space group : #{sg_idx + 1} ({sg_probs[0, sg_idx].item():.1%})")

labels = ["a", "b", "c", "alpha", "beta", "gamma"]
units = ["A", "A", "A", "deg", "deg", "deg"]
print("Lattice params :")
# .tolist() copies to host memory, so this works for CPU and CUDA tensors alike.
for name, val, unit in zip(labels, lp[0].tolist(), units):
    print(f" {name:>5s} = {val:8.3f} {unit}")