"""
Example: load AlphaDiffract and run inference on a PXRD pattern.

Requirements:
    pip install torch safetensors numpy
"""

import numpy as np
import torch
from model import AlphaDiffract

# Load the pretrained model from the current directory and keep it on the CPU.
model = AlphaDiffract.from_pretrained(".", device="cpu")

# Stand-in input: 8192 uniformly random float32 intensities.
# Replace this with a real diffraction pattern of the same length.
# NOTE(review): 8192 is presumably the input length the model expects --
# confirm against the model definition.
pattern = np.random.rand(8192).astype(np.float32)

# Min-max rescale the intensities to the range [0, 100]; the 1e-10 term guards
# against division by zero for a constant pattern. unsqueeze(0) then adds a
# leading batch dimension of size 1.
pattern = (pattern - pattern.min()) / (pattern.max() - pattern.min() + 1e-10) * 100.0
x = torch.from_numpy(pattern).unsqueeze(0)

# Forward pass with gradient tracking disabled (inference only).
with torch.no_grad():
    out = model(x)

# The model output is indexed by key: classification logits for crystal system
# and space group (turned into probabilities with a softmax over the last
# axis), plus the "lp" entry that is printed below as lattice parameters.
cs_probs = torch.softmax(out["cs_logits"], dim=-1)
sg_probs = torch.softmax(out["sg_logits"], dim=-1)
lp = out["lp"]

# Most probable class for the single pattern in the batch.
cs_idx = cs_probs.argmax(dim=-1).item()
sg_idx = sg_probs.argmax(dim=-1).item()

print(f"Crystal system : {AlphaDiffract.CRYSTAL_SYSTEMS[cs_idx]} "
      f"({cs_probs[0, cs_idx]:.1%})")
# The space-group class index is zero-based; it is printed as index + 1.
print(f"Space group : #{sg_idx + 1} ({sg_probs[0, sg_idx]:.1%})")

# Pair each predicted value with its label and unit string. zip() stops at the
# shortest sequence, so at most six values from lp[0] are printed.
labels = ["a", "b", "c", "alpha", "beta", "gamma"]
units = ["A", "A", "A", "deg", "deg", "deg"]
print("Lattice params :")
for name, val, unit in zip(labels, lp[0].tolist(), units):
    print(f" {name:>5s} = {val:8.3f} {unit}")
|
|