# OpenAlphaDiffract / example_inference.py
# Uploaded by linked-liszt ("Upload folder using huggingface_hub"), commit 8a10305 (verified).
"""
Example: load AlphaDiffract and run inference on a PXRD pattern.
Requirements:
pip install torch safetensors numpy
"""
import numpy as np
import torch
from model import AlphaDiffract
# 1. Load model ---------------------------------------------------------------
# Choose the device once and reuse it for the model AND the input tensor:
# running a model on "cuda" with a CPU input raises a device-mismatch error.
device = "cpu"  # or "cuda"
model = AlphaDiffract.from_pretrained(".", device=device)
# Switch to inference mode (affects layers such as dropout/batch-norm, if the
# architecture has any). Harmless if from_pretrained already did this; the
# guard is there only because the model's base class is defined elsewhere.
if isinstance(model, torch.nn.Module):
    model.eval()
# 2. Prepare input -------------------------------------------------------------
# The model expects an 8192-point PXRD intensity pattern normalized to [0, 100].
# Replace this with your own data.
pattern = np.random.rand(8192).astype(np.float32)  # placeholder
# Coerce to a float32 ndarray so user-supplied lists or float64 arrays work
# with torch.from_numpy and produce the same dtype as the placeholder.
pattern = np.asarray(pattern, dtype=np.float32)
# Min-max normalize to [0, 100]. The epsilon avoids division by zero for a
# flat (constant) pattern, which then maps to all zeros.
pattern = (pattern - pattern.min()) / (pattern.max() - pattern.min() + 1e-10) * 100.0
# Add a leading batch dimension -> shape (1, 8192), on the model's device.
x = torch.from_numpy(pattern).unsqueeze(0).to(device)
# 3. Inference -----------------------------------------------------------------
# No gradients are needed for prediction; disabling autograd saves memory/time.
with torch.no_grad():
    # The model returns a mapping with (at least) the keys "cs_logits",
    # "sg_logits" and "lp" used below.
    out = model(x)
    # Turn raw logits into class probabilities over the last dimension.
    cs_probs = torch.softmax(out["cs_logits"], dim=-1)  # crystal-system probabilities
    sg_probs = torch.softmax(out["sg_logits"], dim=-1)  # space-group probabilities
    # Lattice-parameter predictions; presumably shape (batch, 6) matching
    # a, b, c, alpha, beta, gamma printed below -- TODO confirm.
    lp = out["lp"]
# 4. Results -------------------------------------------------------------------
# .item() requires a single-element tensor, so this section assumes a batch of
# exactly one pattern (sample index 0 below).
cs_idx = cs_probs.argmax(dim=-1).item()
sg_idx = sg_probs.argmax(dim=-1).item()
cs_conf = cs_probs[0, cs_idx].item()
sg_conf = sg_probs[0, sg_idx].item()
print(f"Crystal system : {AlphaDiffract.CRYSTAL_SYSTEMS[cs_idx]} "
      f"({cs_conf:.1%})")
# The code assumes class index i corresponds to space-group number i + 1
# (0-based class index vs. 1-based space-group numbering) -- TODO confirm.
print(f"Space group : #{sg_idx + 1} ({sg_conf:.1%})")
labels = ["a", "b", "c", "alpha", "beta", "gamma"]
units = ["A", "A", "A", "deg", "deg", "deg"]
print("Lattice params :")
for name, val, unit in zip(labels, lp[0].tolist(), units):
    print(f" {name:>5s} = {val:8.3f} {unit}")