# Synapse-Base-Inference-api/model_loader.py
"""
ONNX model loader for Synapse-Base.
Handles model loading and single-position inference.
CPU-optimized for Hugging Face Spaces.
"""
import onnxruntime as ort
import numpy as np
import chess
import logging
from pathlib import Path
logger = logging.getLogger(__name__)
class SynapseModel:
"""ONNX Runtime wrapper for Synapse-Base model"""
def __init__(self, model_path: str, num_threads: int = 2):
"""
Initialize model
Args:
model_path: Path to ONNX model file
num_threads: Number of CPU threads to use
"""
self.model_path = Path(model_path)
if not self.model_path.exists():
raise FileNotFoundError(f"Model not found: {model_path}")
# ONNX Runtime session options (CPU optimized)
sess_options = ort.SessionOptions()
sess_options.intra_op_num_threads = num_threads
sess_options.inter_op_num_threads = num_threads
sess_options.execution_mode = ort.ExecutionMode.ORT_SEQUENTIAL
sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
# Create session
logger.info(f"Loading model from {model_path}...")
self.session = ort.InferenceSession(
str(self.model_path),
sess_options=sess_options,
providers=['CPUExecutionProvider']
)
# Get input/output names
self.input_name = self.session.get_inputs()[0].name
self.output_names = [output.name for output in self.session.get_outputs()]
logger.info(f"✅ Model loaded: {self.input_name} -> {self.output_names}")
def fen_to_tensor(self, fen: str) -> np.ndarray:
"""
Convert FEN to 119-channel tensor
Args:
fen: FEN string
Returns:
numpy array of shape (1, 119, 8, 8)
"""
board = chess.Board(fen)
tensor = np.zeros((1, 119, 8, 8), dtype=np.float32)
        # === CHANNELS 0-11: Piece Positions (0-5 white, 6-11 black) ===
piece_map = board.piece_map()
piece_to_channel = {
chess.PAWN: 0, chess.KNIGHT: 1, chess.BISHOP: 2,
chess.ROOK: 3, chess.QUEEN: 4, chess.KING: 5
}
for square, piece in piece_map.items():
rank = square // 8
file = square % 8
channel = piece_to_channel[piece.piece_type]
if piece.color == chess.BLACK:
channel += 6
tensor[0, channel, rank, file] = 1.0
# === CHANNELS 12-26: Game State Metadata ===
# Channel 12: Turn (1 = white to move)
tensor[0, 12, :, :] = 1.0 if board.turn == chess.WHITE else 0.0
# Channels 13-16: Castling rights
tensor[0, 13, :, :] = float(board.has_kingside_castling_rights(chess.WHITE))
tensor[0, 14, :, :] = float(board.has_queenside_castling_rights(chess.WHITE))
tensor[0, 15, :, :] = float(board.has_kingside_castling_rights(chess.BLACK))
tensor[0, 16, :, :] = float(board.has_queenside_castling_rights(chess.BLACK))
# Channel 17: En passant square
if board.ep_square is not None:
ep_rank = board.ep_square // 8
ep_file = board.ep_square % 8
tensor[0, 17, ep_rank, ep_file] = 1.0
# Channel 18: Halfmove clock (normalized)
tensor[0, 18, :, :] = min(board.halfmove_clock / 100.0, 1.0)
# Channel 19: Fullmove number (normalized)
tensor[0, 19, :, :] = min(board.fullmove_number / 100.0, 1.0)
        # Channels 20-21: Check status (is_check() refers to the side to move)
tensor[0, 20, :, :] = float(board.is_check() and board.turn == chess.WHITE)
tensor[0, 21, :, :] = float(board.is_check() and board.turn == chess.BLACK)
# Channels 22-26: Material count (normalized)
white_pawns = len(board.pieces(chess.PAWN, chess.WHITE))
black_pawns = len(board.pieces(chess.PAWN, chess.BLACK))
tensor[0, 22, :, :] = white_pawns / 8.0
tensor[0, 23, :, :] = black_pawns / 8.0
white_knights = len(board.pieces(chess.KNIGHT, chess.WHITE))
black_knights = len(board.pieces(chess.KNIGHT, chess.BLACK))
tensor[0, 24, :, :] = white_knights / 2.0
tensor[0, 25, :, :] = black_knights / 2.0
        white_bishops = len(board.pieces(chess.BISHOP, chess.WHITE))
        tensor[0, 26, :, :] = white_bishops / 2.0
        # NOTE: black bishop count is not encoded; channels 22-26 hold only
        # five material features, and the attack maps begin at channel 27.
        # === CHANNELS 27-28: Attack Maps (channels 29-50 reserved) ===
# White attacks
for square in chess.SQUARES:
if board.is_attacked_by(chess.WHITE, square):
rank = square // 8
file = square % 8
tensor[0, 27, rank, file] = 1.0
# Black attacks
for square in chess.SQUARES:
if board.is_attacked_by(chess.BLACK, square):
rank = square // 8
file = square % 8
tensor[0, 28, rank, file] = 1.0
# === CHANNELS 51-66: Coordinate Encoding ===
# Rank encoding
for rank in range(8):
tensor[0, 51 + rank, rank, :] = 1.0
# File encoding
for file in range(8):
tensor[0, 59 + file, :, file] = 1.0
# === CHANNELS 67-118: Positional Biases (Static) ===
# Center control bonus
center_squares = [chess.D4, chess.D5, chess.E4, chess.E5]
for square in center_squares:
rank = square // 8
file = square % 8
tensor[0, 67, rank, file] = 0.5
# King safety zones
for color_offset, color in [(0, chess.WHITE), (1, chess.BLACK)]:
king_square = board.king(color)
if king_square is not None:
king_rank = king_square // 8
king_file = king_square % 8
# Mark king zone (3x3 around king)
for dr in [-1, 0, 1]:
for df in [-1, 0, 1]:
r = king_rank + dr
f = king_file + df
if 0 <= r < 8 and 0 <= f < 8:
tensor[0, 68 + color_offset, r, f] = 1.0
        # Remaining channels (29-50 and 70-118) are left as zeros
        # (placeholders reserved for future features)
return tensor
def evaluate(self, fen: str) -> dict:
"""
Evaluate position
Args:
fen: FEN string
Returns:
dict with 'value' and optionally 'policy'
"""
# Convert FEN to tensor
input_tensor = self.fen_to_tensor(fen)
# Run inference
outputs = self.session.run(
self.output_names,
{self.input_name: input_tensor}
)
        # Parse outputs
        result = {}
        # Value head (first output); flatten so (1,) and (1, 1) output shapes both work
        result['value'] = float(np.asarray(outputs[0]).reshape(-1)[0])
        # Policy head (second output, if the exported model provides one)
        if len(outputs) > 1:
            result['policy'] = outputs[1][0]
return result
def get_size_mb(self) -> float:
"""Get model size in MB"""
return self.model_path.stat().st_size / (1024 * 1024)
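
# --- Usage sketch ---
# A minimal smoke test of SynapseModel. The default model path
# "synapse_base.onnx" and the CLI flags below are illustrative assumptions;
# substitute the actual filename that the Space ships with.
if __name__ == "__main__":
    import argparse

    logging.basicConfig(level=logging.INFO)

    parser = argparse.ArgumentParser(description="Smoke-test the Synapse-Base ONNX model")
    parser.add_argument("--model", default="synapse_base.onnx",
                        help="Path to the ONNX model file (assumed default)")
    parser.add_argument("--fen", default=chess.STARTING_FEN,
                        help="Position to evaluate, as a FEN string")
    args = parser.parse_args()

    model = SynapseModel(args.model, num_threads=2)
    print(f"Model size: {model.get_size_mb():.1f} MB")

    # Encode the position and confirm the expected input shape
    tensor = model.fen_to_tensor(args.fen)
    print(f"Input tensor shape: {tensor.shape}")  # expected: (1, 119, 8, 8)

    result = model.evaluate(args.fen)
    print(f"Value: {result['value']:.4f}")
    if "policy" in result:
        print(f"Policy output length: {len(result['policy'])}")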