"""
ONNX Model Loader for Synapse-Base

Handles model loading and inference.
CPU-optimized for HF Spaces.
"""

import onnxruntime as ort
import numpy as np
import chess
import logging
from pathlib import Path

logger = logging.getLogger(__name__)

# Number of input planes produced by SynapseModel.fen_to_tensor.
_NUM_CHANNELS = 119

# Plane index for each piece type (white); black pieces use index + 6.
_PIECE_TO_CHANNEL = {
    chess.PAWN: 0,
    chess.KNIGHT: 1,
    chess.BISHOP: 2,
    chess.ROOK: 3,
    chess.QUEEN: 4,
    chess.KING: 5,
}

# (channel, piece type, color, normalizer) for the material-count planes.
# NOTE(review): there is no black-bishop plane -- channel 27 is the white
# attack map. The layout is intentionally left unchanged so inputs stay
# compatible with the trained model; confirm against the training encoder.
_MATERIAL_CHANNELS = (
    (22, chess.PAWN, chess.WHITE, 8.0),
    (23, chess.PAWN, chess.BLACK, 8.0),
    (24, chess.KNIGHT, chess.WHITE, 2.0),
    (25, chess.KNIGHT, chess.BLACK, 2.0),
    (26, chess.BISHOP, chess.WHITE, 2.0),
)

# Squares marked with 0.5 on the center-control plane (channel 67).
_CENTER_SQUARES = (chess.D4, chess.D5, chess.E4, chess.E5)


def _encode_pieces(board: chess.Board, planes: np.ndarray) -> None:
    """Channels 0-11: one-hot piece placement (white 0-5, black 6-11)."""
    for square, piece in board.piece_map().items():
        channel = _PIECE_TO_CHANNEL[piece.piece_type]
        if piece.color == chess.BLACK:
            channel += 6
        planes[0, channel, chess.square_rank(square), chess.square_file(square)] = 1.0


def _encode_game_state(board: chess.Board, planes: np.ndarray) -> None:
    """Channels 12-26: side to move, castling, en passant, clocks, check, material."""
    # Channel 12: side to move (1.0 = white).
    planes[0, 12] = 1.0 if board.turn == chess.WHITE else 0.0

    # Channels 13-16: castling rights (W kingside, W queenside, B kingside, B queenside).
    planes[0, 13] = float(board.has_kingside_castling_rights(chess.WHITE))
    planes[0, 14] = float(board.has_queenside_castling_rights(chess.WHITE))
    planes[0, 15] = float(board.has_kingside_castling_rights(chess.BLACK))
    planes[0, 16] = float(board.has_queenside_castling_rights(chess.BLACK))

    # Channel 17: en passant target square, if any.
    ep_square = board.ep_square
    if ep_square is not None:
        planes[0, 17, chess.square_rank(ep_square), chess.square_file(ep_square)] = 1.0

    # Channels 18-19: clocks, normalized by 100 and capped at 1.0.
    planes[0, 18] = min(board.halfmove_clock / 100.0, 1.0)
    planes[0, 19] = min(board.fullmove_number / 100.0, 1.0)

    # Channels 20-21: which side (if any) is currently in check.
    in_check = board.is_check()
    planes[0, 20] = float(in_check and board.turn == chess.WHITE)
    planes[0, 21] = float(in_check and board.turn == chess.BLACK)

    # Channels 22-26: normalized material counts (see _MATERIAL_CHANNELS).
    for channel, piece_type, color, normalizer in _MATERIAL_CHANNELS:
        planes[0, channel] = len(board.pieces(piece_type, color)) / normalizer


def _encode_attacks(board: chess.Board, planes: np.ndarray) -> None:
    """Channels 27-28: squares attacked by white / black. Channels 29-50 stay zero."""
    for square in chess.SQUARES:
        rank, file = chess.square_rank(square), chess.square_file(square)
        if board.is_attacked_by(chess.WHITE, square):
            planes[0, 27, rank, file] = 1.0
        if board.is_attacked_by(chess.BLACK, square):
            planes[0, 28, rank, file] = 1.0


def _encode_coordinates(planes: np.ndarray) -> None:
    """Channels 51-58: one-hot rank planes; channels 59-66: one-hot file planes."""
    for index in range(8):
        planes[0, 51 + index, index, :] = 1.0
        planes[0, 59 + index, :, index] = 1.0


def _encode_positional(board: chess.Board, planes: np.ndarray) -> None:
    """Channel 67: center-control bias; channels 68-69: white / black king zones."""
    for square in _CENTER_SQUARES:
        planes[0, 67, chess.square_rank(square), chess.square_file(square)] = 0.5

    for channel, color in ((68, chess.WHITE), (69, chess.BLACK)):
        king_square = board.king(color)
        if king_square is None:
            continue
        k_rank = chess.square_rank(king_square)
        k_file = chess.square_file(king_square)
        # 3x3 zone around the king, clipped to the board (slicing clips the
        # upper bound at 8; max() clips the lower bound at 0).
        planes[
            0,
            channel,
            max(k_rank - 1, 0):k_rank + 2,
            max(k_file - 1, 0):k_file + 2,
        ] = 1.0


class SynapseModel:
    """ONNX Runtime wrapper for the Synapse-Base model.

    Loads an ONNX model on the CPU execution provider, encodes FEN positions
    into the 119-plane input tensor, and runs inference.
    """

    def __init__(self, model_path: str, num_threads: int = 2):
        """Load the ONNX model and create a CPU inference session.

        Args:
            model_path: Path to the ONNX model file.
            num_threads: Number of CPU threads; applied to both the intra-op
                and inter-op thread settings of ONNX Runtime.

        Raises:
            FileNotFoundError: If ``model_path`` does not exist.
        """
        self.model_path = Path(model_path)
        if not self.model_path.exists():
            raise FileNotFoundError(f"Model not found: {model_path}")

        # ONNX Runtime session options (CPU optimized).
        sess_options = ort.SessionOptions()
        sess_options.intra_op_num_threads = num_threads
        # Per the ONNX Runtime docs, inter-op threads only matter in
        # ORT_PARALLEL mode; with ORT_SEQUENTIAL this setting is inert.
        sess_options.inter_op_num_threads = num_threads
        sess_options.execution_mode = ort.ExecutionMode.ORT_SEQUENTIAL
        sess_options.graph_optimization_level = (
            ort.GraphOptimizationLevel.ORT_ENABLE_ALL
        )

        logger.info("Loading model from %s...", model_path)
        self.session = ort.InferenceSession(
            str(self.model_path),
            sess_options=sess_options,
            providers=["CPUExecutionProvider"],
        )

        # The encoded position is fed to the first model input; every model
        # output is fetched on each run.
        self.input_name = self.session.get_inputs()[0].name
        self.output_names = [output.name for output in self.session.get_outputs()]

        logger.info("✅ Model loaded: %s -> %s", self.input_name, self.output_names)

    def fen_to_tensor(self, fen: str) -> np.ndarray:
        """Convert a FEN string to the model's 119-channel input tensor.

        Planes are indexed ``[rank, file]`` with 0 = rank 1 / file a.
        Channel layout:
            0-11   piece positions (white P,N,B,R,Q,K, then black)
            12     side to move (1.0 = white)
            13-16  castling rights (W-K, W-Q, B-K, B-Q)
            17     en passant target square
            18     halfmove clock / 100, capped at 1.0
            19     fullmove number / 100, capped at 1.0
            20-21  side in check (white, black)
            22-26  material: W pawns/8, B pawns/8, W knights/2,
                   B knights/2, W bishops/2
            27-28  squares attacked by white / black
            29-50  unused (zero)
            51-58  one-hot rank planes
            59-66  one-hot file planes
            67     0.5 on d4, d5, e4, e5
            68-69  3x3 zone around the white / black king
            70-118 reserved (zero)

        Args:
            fen: FEN string.

        Returns:
            float32 numpy array of shape (1, 119, 8, 8).

        Raises:
            ValueError: If python-chess rejects ``fen`` as invalid.
        """
        board = chess.Board(fen)
        planes = np.zeros((1, _NUM_CHANNELS, 8, 8), dtype=np.float32)

        _encode_pieces(board, planes)
        _encode_game_state(board, planes)
        _encode_attacks(board, planes)
        _encode_coordinates(planes)
        _encode_positional(board, planes)

        return planes

    def evaluate(self, fen: str) -> dict:
        """Evaluate a position with the loaded model.

        Args:
            fen: FEN string.

        Returns:
            dict with ``'value'`` (float, first element of the first model
            output) and, when the model has a second output, ``'policy'``
            (that output's first batch row).
        """
        input_tensor = self.fen_to_tensor(fen)

        outputs = self.session.run(
            self.output_names,
            {self.input_name: input_tensor},
        )

        # Value head is the first output. Flatten before indexing so both
        # (1,) and (1, 1) shaped value heads work; for (1, 1) this is the
        # same element as outputs[0][0][0].
        result = {"value": float(np.asarray(outputs[0]).reshape(-1)[0])}

        # Policy head, if the model exposes one.
        if len(outputs) > 1:
            result["policy"] = outputs[1][0]

        return result

    def get_size_mb(self) -> float:
        """Return the on-disk size of the model file in MiB."""
        return self.model_path.stat().st_size / (1024 * 1024)