File size: 7,369 Bytes
b6d27d2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 |
"""
ONNX Model Loader for Synapse-Base
Handles model loading and inference
CPU-optimized for HF Spaces
"""
import onnxruntime as ort
import numpy as np
import chess
import logging
from pathlib import Path
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
class SynapseModel:
    """ONNX Runtime wrapper for Synapse-Base model.

    Loads an ONNX model on the CPU execution provider, converts FEN strings
    into the (1, 119, 8, 8) float32 input tensor and runs inference.

    Input plane layout written by ``fen_to_tensor`` (every plane not listed
    is left at zero):

        0-5    white pawn, knight, bishop, rook, queen, king (one-hot)
        6-11   black pawn, knight, bishop, rook, queen, king (one-hot)
        12     side to move (all 1.0 when white is to move)
        13-16  castling rights: white K-side, white Q-side,
               black K-side, black Q-side
        17     en passant square (one-hot)
        18     halfmove clock / 100, capped at 1.0
        19     fullmove number / 100, capped at 1.0
        20-21  side to move is in check (white, black)
        22-23  pawn count / 8 (white, black)
        24-25  knight count / 2 (white, black)
        26     white bishop count / 2
        27-28  squares attacked by white, by black
        51-58  rank indicator planes
        59-66  file indicator planes
        67     centre squares d4/d5/e4/e5 set to 0.5
        68-69  3x3 zone around the white king, the black king
    """

    def __init__(self, model_path: str, num_threads: int = 2):
        """
        Initialize model

        Args:
            model_path: Path to ONNX model file
            num_threads: Number of CPU threads to use

        Raises:
            FileNotFoundError: If ``model_path`` does not exist.
        """
        self.model_path = Path(model_path)
        if not self.model_path.exists():
            raise FileNotFoundError(f"Model not found: {model_path}")

        # ONNX Runtime session options (CPU optimized)
        sess_options = ort.SessionOptions()
        sess_options.intra_op_num_threads = num_threads
        # NOTE(review): ONNX Runtime's threading docs describe inter-op
        # threads as relevant only in ORT_PARALLEL mode; the setting is kept
        # so the session configuration stays exactly as before - confirm.
        sess_options.inter_op_num_threads = num_threads
        sess_options.execution_mode = ort.ExecutionMode.ORT_SEQUENTIAL
        sess_options.graph_optimization_level = (
            ort.GraphOptimizationLevel.ORT_ENABLE_ALL
        )

        # Create session
        logger.info("Loading model from %s...", model_path)
        self.session = ort.InferenceSession(
            str(self.model_path),
            sess_options=sess_options,
            providers=['CPUExecutionProvider'],
        )

        # Only the first graph input is fed; every graph output is requested.
        self.input_name = self.session.get_inputs()[0].name
        self.output_names = [
            output.name for output in self.session.get_outputs()
        ]
        logger.info(
            "✅ Model loaded: %s -> %s", self.input_name, self.output_names
        )

    def fen_to_tensor(self, fen: str) -> np.ndarray:
        """
        Convert FEN to 119-channel tensor

        The plane layout is listed in the class docstring.

        Args:
            fen: FEN string

        Returns:
            numpy array of shape (1, 119, 8, 8), dtype float32
        """
        board = chess.Board(fen)
        tensor = np.zeros((1, 119, 8, 8), dtype=np.float32)
        # View onto the single batch entry: helper writes land in `tensor`.
        planes = tensor[0]

        self._encode_pieces(board, planes)
        self._encode_game_state(board, planes)
        self._encode_material(board, planes)
        self._encode_attacks(board, planes)
        self._encode_coordinates(planes)
        self._encode_positional(board, planes)
        return tensor

    @staticmethod
    def _encode_pieces(board: "chess.Board", planes: np.ndarray) -> None:
        """Write one-hot piece positions into planes 0-11."""
        piece_to_channel = {
            chess.PAWN: 0, chess.KNIGHT: 1, chess.BISHOP: 2,
            chess.ROOK: 3, chess.QUEEN: 4, chess.KING: 5,
        }
        for square, piece in board.piece_map().items():
            rank, file = divmod(square, 8)
            channel = piece_to_channel[piece.piece_type]
            if piece.color == chess.BLACK:
                channel += 6  # black pieces use planes 6-11
            planes[channel, rank, file] = 1.0

    @staticmethod
    def _encode_game_state(board: "chess.Board", planes: np.ndarray) -> None:
        """Write turn, castling, en passant, clocks and check into 12-21."""
        white_to_move = board.turn == chess.WHITE

        # Plane 12: turn (1 = white to move)
        planes[12] = 1.0 if white_to_move else 0.0

        # Planes 13-16: castling rights
        planes[13] = float(board.has_kingside_castling_rights(chess.WHITE))
        planes[14] = float(board.has_queenside_castling_rights(chess.WHITE))
        planes[15] = float(board.has_kingside_castling_rights(chess.BLACK))
        planes[16] = float(board.has_queenside_castling_rights(chess.BLACK))

        # Plane 17: en passant square
        if board.ep_square is not None:
            ep_rank, ep_file = divmod(board.ep_square, 8)
            planes[17, ep_rank, ep_file] = 1.0

        # Planes 18-19: halfmove clock / fullmove number (normalized, capped)
        planes[18] = min(board.halfmove_clock / 100.0, 1.0)
        planes[19] = min(board.fullmove_number / 100.0, 1.0)

        # Planes 20-21: check status; the check test is evaluated once and
        # shared by both planes.
        in_check = board.is_check()
        planes[20] = float(in_check and white_to_move)
        planes[21] = float(in_check and board.turn == chess.BLACK)

    @staticmethod
    def _encode_material(board: "chess.Board", planes: np.ndarray) -> None:
        """Write normalized material counts into planes 22-26."""
        planes[22] = len(board.pieces(chess.PAWN, chess.WHITE)) / 8.0
        planes[23] = len(board.pieces(chess.PAWN, chess.BLACK)) / 8.0
        planes[24] = len(board.pieces(chess.KNIGHT, chess.WHITE)) / 2.0
        planes[25] = len(board.pieces(chess.KNIGHT, chess.BLACK)) / 2.0
        planes[26] = len(board.pieces(chess.BISHOP, chess.WHITE)) / 2.0
        # NOTE(review): there is no black-bishop plane. The previous code
        # counted black bishops but never stored the value. The layout is
        # deliberately unchanged here; presumably the trained model expects
        # it - confirm against the training-time encoder before adding one.

    @staticmethod
    def _encode_attacks(board: "chess.Board", planes: np.ndarray) -> None:
        """Mark squares attacked by white (plane 27) and black (plane 28)."""
        for square in chess.SQUARES:
            rank, file = divmod(square, 8)
            if board.is_attacked_by(chess.WHITE, square):
                planes[27, rank, file] = 1.0
            if board.is_attacked_by(chess.BLACK, square):
                planes[28, rank, file] = 1.0

    @staticmethod
    def _encode_coordinates(planes: np.ndarray) -> None:
        """Write rank indicators (51-58) and file indicators (59-66)."""
        for index in range(8):
            planes[51 + index, index, :] = 1.0
            planes[59 + index, :, index] = 1.0

    @staticmethod
    def _encode_positional(board: "chess.Board", planes: np.ndarray) -> None:
        """Write the centre bonus (67) and king zones (68-69)."""
        # Center control bonus
        for square in (chess.D4, chess.D5, chess.E4, chess.E5):
            rank, file = divmod(square, 8)
            planes[67, rank, file] = 0.5

        # King safety zones: 3x3 around each king, clipped to the board
        for color_offset, color in enumerate((chess.WHITE, chess.BLACK)):
            king_square = board.king(color)
            if king_square is None:
                continue
            king_rank, king_file = divmod(king_square, 8)
            rank_lo, rank_hi = max(king_rank - 1, 0), min(king_rank + 1, 7)
            file_lo, file_hi = max(king_file - 1, 0), min(king_file + 1, 7)
            planes[
                68 + color_offset,
                rank_lo:rank_hi + 1,
                file_lo:file_hi + 1,
            ] = 1.0

    def evaluate(self, fen: str) -> dict:
        """
        Evaluate position

        Args:
            fen: FEN string

        Returns:
            dict with 'value' (float taken from element [0][0] of the first
            model output) and, when the model has more than one output,
            'policy' (first batch row of the second output)
        """
        # Convert FEN to tensor
        input_tensor = self.fen_to_tensor(fen)

        # Run inference
        outputs = self.session.run(
            self.output_names,
            {self.input_name: input_tensor},
        )

        # Value head (always first output)
        result = {'value': float(outputs[0][0][0])}

        # Policy head (if available)
        if len(outputs) > 1:
            result['policy'] = outputs[1][0]
        return result

    def get_size_mb(self) -> float:
        """Get model size in MB (file size in bytes / 1024 / 1024)"""
        return self.model_path.stat().st_size / (1024 * 1024)