File size: 7,369 Bytes
b6d27d2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
"""
ONNX Model Loader for Synapse-Base
Handles model loading and inference
CPU-optimized for HF Spaces
"""

import onnxruntime as ort
import numpy as np
import chess
import logging
from pathlib import Path

logger = logging.getLogger(__name__)


class SynapseModel:
    """ONNX Runtime wrapper for the Synapse-Base chess evaluation model.

    Loads the model with the CPU execution provider (sized for HF Spaces
    hardware) and exposes FEN-based position evaluation. The encoding
    produced by :meth:`fen_to_tensor` must match the layout the model was
    trained with, so the tensor writes below are treated as a fixed
    contract — do not reorder or renumber channels.
    """
    
    def __init__(self, model_path: str, num_threads: int = 2):
        """
        Initialize the ONNX Runtime inference session.
        
        Args:
            model_path: Path to the ONNX model file.
            num_threads: Number of CPU threads used for both intra-op and
                inter-op parallelism.
        
        Raises:
            FileNotFoundError: If ``model_path`` does not exist.
        """
        self.model_path = Path(model_path)
        
        if not self.model_path.exists():
            raise FileNotFoundError(f"Model not found: {model_path}")
        
        # ONNX Runtime session options (CPU optimized): sequential execution
        # with full graph optimization is the usual sweet spot for small
        # CPU-only deployments.
        sess_options = ort.SessionOptions()
        sess_options.intra_op_num_threads = num_threads
        sess_options.inter_op_num_threads = num_threads
        sess_options.execution_mode = ort.ExecutionMode.ORT_SEQUENTIAL
        sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
        
        # Create session (CPU-only provider — no GPU assumed)
        logger.info(f"Loading model from {model_path}...")
        self.session = ort.InferenceSession(
            str(self.model_path),
            sess_options=sess_options,
            providers=['CPUExecutionProvider']
        )
        
        # Cache input/output names once so evaluate() doesn't re-query the
        # session graph on every call.
        self.input_name = self.session.get_inputs()[0].name
        self.output_names = [output.name for output in self.session.get_outputs()]
        
        logger.info(f"✅ Model loaded: {self.input_name} -> {self.output_names}")
    
    def fen_to_tensor(self, fen: str) -> np.ndarray:
        """
        Convert a FEN string to the model's 119-channel board tensor.
        
        Args:
            fen: FEN string describing the position.
        
        Returns:
            numpy array of shape (1, 119, 8, 8), dtype float32.
        
        NOTE(review): although the tensor has 119 channels, only a subset
        is actually populated below (0-28, 51-69); the rest stay zero.
        The section headers describe the *intended* layout — confirm it
        against the training-time encoder before extending.
        """
        board = chess.Board(fen)
        tensor = np.zeros((1, 119, 8, 8), dtype=np.float32)
        
        # === CHANNELS 0-11: Piece Positions ===
        # Channels 0-5 = white P,N,B,R,Q,K; 6-11 = the same for black.
        piece_map = board.piece_map()
        piece_to_channel = {
            chess.PAWN: 0, chess.KNIGHT: 1, chess.BISHOP: 2,
            chess.ROOK: 3, chess.QUEEN: 4, chess.KING: 5
        }
        
        for square, piece in piece_map.items():
            # python-chess squares are 0-63 with A1=0, so //8 is the rank
            # and %8 the file.
            rank = square // 8
            file = square % 8
            channel = piece_to_channel[piece.piece_type]
            if piece.color == chess.BLACK:
                channel += 6
            tensor[0, channel, rank, file] = 1.0
        
        # === CHANNELS 12-26: Game State Metadata ===
        # Channel 12: Turn (whole plane set to 1 when white to move)
        tensor[0, 12, :, :] = 1.0 if board.turn == chess.WHITE else 0.0
        
        # Channels 13-16: Castling rights (WK, WQ, BK, BQ as full planes)
        tensor[0, 13, :, :] = float(board.has_kingside_castling_rights(chess.WHITE))
        tensor[0, 14, :, :] = float(board.has_queenside_castling_rights(chess.WHITE))
        tensor[0, 15, :, :] = float(board.has_kingside_castling_rights(chess.BLACK))
        tensor[0, 16, :, :] = float(board.has_queenside_castling_rights(chess.BLACK))
        
        # Channel 17: En passant target square (single one-hot cell)
        if board.ep_square is not None:
            ep_rank = board.ep_square // 8
            ep_file = board.ep_square % 8
            tensor[0, 17, ep_rank, ep_file] = 1.0
        
        # Channel 18: Halfmove clock, clamped to [0, 1] at 100 plies
        # (the 50-move-rule horizon).
        tensor[0, 18, :, :] = min(board.halfmove_clock / 100.0, 1.0)
        
        # Channel 19: Fullmove number, clamped to [0, 1] at move 100
        tensor[0, 19, :, :] = min(board.fullmove_number / 100.0, 1.0)
        
        # Channels 20-21: Check status for the side to move
        # (20 fires only when white is both in check and to move; 21 for black).
        tensor[0, 20, :, :] = float(board.is_check() and board.turn == chess.WHITE)
        tensor[0, 21, :, :] = float(board.is_check() and board.turn == chess.BLACK)
        
        # Channels 22-26: Material count (normalized by starting count)
        white_pawns = len(board.pieces(chess.PAWN, chess.WHITE))
        black_pawns = len(board.pieces(chess.PAWN, chess.BLACK))
        tensor[0, 22, :, :] = white_pawns / 8.0
        tensor[0, 23, :, :] = black_pawns / 8.0
        
        white_knights = len(board.pieces(chess.KNIGHT, chess.WHITE))
        black_knights = len(board.pieces(chess.KNIGHT, chess.BLACK))
        tensor[0, 24, :, :] = white_knights / 2.0
        tensor[0, 25, :, :] = black_knights / 2.0
        
        # NOTE(review): only *white* bishops get a channel here — there is no
        # black-bishop plane, even though the header says material runs
        # through channel 26. A dead `black_bishops` count was previously
        # computed and discarded; confirm against the training encoder
        # whether a black-bishop channel was intended.
        white_bishops = len(board.pieces(chess.BISHOP, chess.WHITE))
        tensor[0, 26, :, :] = white_bishops / 2.0
        
        # === CHANNELS 27-50: Attack Maps ===
        # NOTE(review): despite the 27-50 header, only channels 27 (white)
        # and 28 (black) are written; 29-50 remain zero.
        # Channel 27: squares attacked by white
        for square in chess.SQUARES:
            if board.is_attacked_by(chess.WHITE, square):
                rank = square // 8
                file = square % 8
                tensor[0, 27, rank, file] = 1.0
        
        # Channel 28: squares attacked by black
        for square in chess.SQUARES:
            if board.is_attacked_by(chess.BLACK, square):
                rank = square // 8
                file = square % 8
                tensor[0, 28, rank, file] = 1.0
        
        # === CHANNELS 51-66: Coordinate Encoding ===
        # Channels 51-58: one plane per rank, the matching row set to 1
        for rank in range(8):
            tensor[0, 51 + rank, rank, :] = 1.0
        
        # Channels 59-66: one plane per file, the matching column set to 1
        for file in range(8):
            tensor[0, 59 + file, :, file] = 1.0
        
        # === CHANNELS 67-118: Positional Biases (Static) ===
        # Channel 67: center control bonus (0.5 on d4/d5/e4/e5)
        center_squares = [chess.D4, chess.D5, chess.E4, chess.E5]
        for square in center_squares:
            rank = square // 8
            file = square % 8
            tensor[0, 67, rank, file] = 0.5
        
        # Channels 68-69: king safety zones (3x3 neighborhood around each
        # king; 68 = white, 69 = black).
        for color_offset, color in [(0, chess.WHITE), (1, chess.BLACK)]:
            king_square = board.king(color)
            if king_square is not None:
                king_rank = king_square // 8
                king_file = king_square % 8
                
                # Mark king zone (3x3 around king), clipped at board edges
                for dr in [-1, 0, 1]:
                    for df in [-1, 0, 1]:
                        r = king_rank + dr
                        f = king_file + df
                        if 0 <= r < 8 and 0 <= f < 8:
                            tensor[0, 68 + color_offset, r, f] = 1.0
        
        # Channels 70-118 intentionally left zero (reserved for future
        # features).
        
        return tensor
    
    def evaluate(self, fen: str) -> dict:
        """
        Evaluate a position given as FEN.
        
        Args:
            fen: FEN string describing the position.
        
        Returns:
            dict with key 'value' (float score from the first output head)
            and, when the model exposes a second output, 'policy' (the raw
            first-batch slice of that output).
        """
        # Convert FEN to the model's input tensor
        input_tensor = self.fen_to_tensor(fen)
        
        # Run inference against all declared outputs
        outputs = self.session.run(
            self.output_names,
            {self.input_name: input_tensor}
        )
        
        # Parse outputs
        result = {}
        
        # Value head: assumed to be the first output with shape (1, 1)
        # — NOTE(review): confirm this shape against the exported model;
        # a scalar or (1,) output would break this double index.
        result['value'] = float(outputs[0][0][0])
        
        # Policy head (only present on models exporting a second output)
        if len(outputs) > 1:
            result['policy'] = outputs[1][0]
        
        return result
    
    def get_size_mb(self) -> float:
        """Return the on-disk model size in mebibytes."""
        return self.model_path.stat().st_size / (1024 * 1024)