File size: 1,588 Bytes
8a10305
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
"""
Example: load AlphaDiffract and run inference on a PXRD pattern.

Requirements:
    pip install torch safetensors numpy
"""

import numpy as np
import torch
from model import AlphaDiffract

# 1. Load model ---------------------------------------------------------------
# One device setting shared by the model and the input tensor: PyTorch requires
# a module's weights and its inputs to live on the same device, so changing
# this single line is enough to switch the whole example to GPU.
device = "cpu"  # or "cuda"
model = AlphaDiffract.from_pretrained(".", device=device)

# 2. Prepare input -------------------------------------------------------------
# The model expects an 8192-point PXRD intensity pattern normalized to [0, 100].
# Replace this with your own data.
pattern = np.random.rand(8192).astype(np.float32)  # placeholder

# Min-max normalize to [0, 100]; the small epsilon keeps the division finite
# when the pattern is constant (max == min).
pattern = (pattern - pattern.min()) / (pattern.max() - pattern.min() + 1e-10) * 100.0
# Add a leading batch dimension and place the input on the same device as the
# model (a no-op when device == "cpu").
x = torch.from_numpy(pattern).unsqueeze(0).to(device)  # shape: (1, 8192)

# 3. Inference -----------------------------------------------------------------
# Gradients are not needed for prediction, so skip building the autograd graph.
with torch.no_grad():
    out = model(x)

# Turn the raw logits into probabilities over the last (class) dimension.
cs_probs = torch.softmax(out["cs_logits"], dim=-1)
sg_probs = torch.softmax(out["sg_logits"], dim=-1)
lp = out["lp"]

# 4. Results -------------------------------------------------------------------
# Index of the most probable class for the single pattern in the batch.
cs_idx = cs_probs.argmax(dim=-1).item()
sg_idx = sg_probs.argmax(dim=-1).item()

print(f"Crystal system : {AlphaDiffract.CRYSTAL_SYSTEMS[cs_idx]} "
      f"({cs_probs[0, cs_idx]:.1%})")
# The predicted index is zero-based; it is printed one-based as a space-group number.
print(f"Space group    : #{sg_idx + 1} ({sg_probs[0, sg_idx]:.1%})")

# Names and display units for the six values in the first row of `lp`.
labels = ["a", "b", "c", "alpha", "beta", "gamma"]
units = ["A", "A", "A", "deg", "deg", "deg"]
print("Lattice params :")
for name, val, unit in zip(labels, lp[0].tolist(), units):
    print(f"  {name:>5s} = {val:8.3f} {unit}")