# Author: Rafs-an09002
# Commit: "Create engine/transposition.py"
# Revision: 94d271b (verified)
"""
Transposition Table with Zobrist Hashing
Sizing: Stockfish is commonly run with multi-GB tables; this table is sized via `size_mb` (default 256MB, ~400MB in the Colab setup) to fit Colab memory constraints.
References:
- Zobrist (1970) - Hash functions for chess positions
- Stockfish TT - Replacement strategies
- AlphaBeta enhancements - Exact/Lower/Upper bounds
"""
import chess
import numpy as np
from typing import Optional, Dict, Tuple
from enum import Enum
class NodeType(Enum):
    """Which kind of bound a stored transposition-table score represents."""

    # The score is exact: it fell strictly inside the search window (PV node).
    EXACT = 0
    # The score is a lower bound: the search failed high / cut off at beta (cut node).
    LOWER_BOUND = 1
    # The score is an upper bound: no move raised alpha (all node, failed low).
    UPPER_BOUND = 2
class TTEntry:
    """
    One stored search result.

    Holds the position key, the depth searched, the score and the kind of
    bound it represents, the best move found (if any), and the search
    generation ("age") in which it was written.
    """

    # A fixed attribute set: instances carry no per-object __dict__, which
    # keeps a large number of entries compact.
    __slots__ = ['zobrist_key', 'depth', 'score', 'node_type', 'best_move', 'age']

    def __init__(
        self,
        zobrist_key: int,
        depth: int,
        score: float,
        node_type: NodeType,
        best_move: Optional[chess.Move],
        age: int
    ):
        # Fill every slot in a single tuple-unpacking assignment.
        (
            self.zobrist_key,
            self.depth,
            self.score,
            self.node_type,
            self.best_move,
            self.age,
        ) = (zobrist_key, depth, score, node_type, best_move, age)
class TranspositionTable:
    """
    Zobrist-hashed transposition table.

    Entries live in a dict keyed by the position's Zobrist key.

    Replacement strategy: an existing entry is kept only when it was written
    in the current search generation AND is deeper than the incoming result;
    otherwise it is overwritten. When the table grows past ``max_entries``,
    the entries with the oldest generation are evicted.
    """

    # Row index into ``zobrist_pieces`` for each (piece_type, color) pair.
    # Built once for the class instead of on every hash computation.
    _PIECE_INDEX = {
        (chess.PAWN, chess.WHITE): 0,
        (chess.KNIGHT, chess.WHITE): 1,
        (chess.BISHOP, chess.WHITE): 2,
        (chess.ROOK, chess.WHITE): 3,
        (chess.QUEEN, chess.WHITE): 4,
        (chess.KING, chess.WHITE): 5,
        (chess.PAWN, chess.BLACK): 6,
        (chess.KNIGHT, chess.BLACK): 7,
        (chess.BISHOP, chess.BLACK): 8,
        (chess.ROOK, chess.BLACK): 9,
        (chess.QUEEN, chess.BLACK): 10,
        (chess.KING, chess.BLACK): 11,
    }

    def __init__(self, size_mb: int = 256):
        """
        Initialize transposition table.

        Args:
            size_mb: Table size in megabytes (default 256MB). It is converted
                to an entry budget assuming ~64 bytes per entry. The budget is
                at least one entry, so a tiny or zero size cannot disable
                eviction.
        """
        # Assumed cost of one entry.
        # NOTE(review): a CPython TTEntry plus its dict slot is likely larger
        # than 64 bytes, so real memory use may exceed size_mb.
        bytes_per_entry = 64
        # Clamp to >= 1: a budget of 0 made eviction remove nothing (the table
        # grew without bound) and made get_stats() divide by zero.
        self.max_entries = max(1, int(size_mb * 1024 * 1024) // bytes_per_entry)
        # Hash table (dict for simplicity, could use array for speed)
        self.table: Dict[int, TTEntry] = {}
        # Statistics
        self.hits = 0
        self.misses = 0
        # Kept for get_stats(). Entries are stored under their full key, so a
        # mismatch can never be observed here and this stays 0.
        self.collisions = 0
        self.current_age = 0
        # Zobrist keys for hashing (initialized once)
        self._init_zobrist_keys()

    def _init_zobrist_keys(self):
        """
        Initialize Zobrist random keys.

        Creates one key per (piece, square), one for side to move, four for
        castling rights and eight for en-passant files.

        A private ``RandomState(42)`` produces exactly the same keys as
        seeding NumPy's global generator with 42. Unlike that approach, it
        does not reset the global ``np.random`` state for the rest of the
        program.
        """
        rng = np.random.RandomState(42)  # Reproducible keys, isolated RNG
        self.zobrist_pieces = rng.randint(0, 2**63, size=(12, 64), dtype=np.int64)
        # Additional keys for game state (drawn in the same order as before)
        self.zobrist_turn = rng.randint(0, 2**63, dtype=np.int64)
        self.zobrist_castling = rng.randint(0, 2**63, size=4, dtype=np.int64)
        self.zobrist_ep = rng.randint(0, 2**63, size=8, dtype=np.int64)

        # Plain Python int copies used for hashing. XOR on ints avoids NumPy
        # scalar overhead, and the hash is returned as a real ``int``.
        self._piece_keys = self.zobrist_pieces.tolist()
        self._turn_key = int(self.zobrist_turn)
        self._castling_keys = self.zobrist_castling.tolist()
        self._ep_keys = self.zobrist_ep.tolist()

    def compute_zobrist_key(self, board: chess.Board) -> int:
        """
        Compute Zobrist hash for position.

        Args:
            board: chess.Board

        Returns:
            Zobrist key as a non-negative int (below 2**63).
        """
        key = 0
        piece_keys = self._piece_keys
        piece_index = self._PIECE_INDEX

        # Piece positions
        for square, piece in board.piece_map().items():
            key ^= piece_keys[piece_index[(piece.piece_type, piece.color)]][square]

        # Turn
        if board.turn == chess.BLACK:
            key ^= self._turn_key

        # Castling rights
        castling_keys = self._castling_keys
        if board.has_kingside_castling_rights(chess.WHITE):
            key ^= castling_keys[0]
        if board.has_queenside_castling_rights(chess.WHITE):
            key ^= castling_keys[1]
        if board.has_kingside_castling_rights(chess.BLACK):
            key ^= castling_keys[2]
        if board.has_queenside_castling_rights(chess.BLACK):
            key ^= castling_keys[3]

        # En passant: hash the file only when an en-passant capture is legal.
        # python-chess sets ep_square after double pawn pushes even when no
        # capture is possible. Hashing it unconditionally would give identical
        # positions different keys and lose transpositions.
        if board.ep_square is not None and board.has_legal_en_passant():
            key ^= self._ep_keys[chess.square_file(board.ep_square)]

        return key

    def probe(
        self,
        zobrist_key: int,
        depth: int,
        alpha: float,
        beta: float
    ) -> Optional[Tuple[Optional[float], Optional[chess.Move]]]:
        """
        Probe transposition table.

        Args:
            zobrist_key: Zobrist hash of position
            depth: Current search depth
            alpha: Alpha value
            beta: Beta value

        Returns:
            ``None`` if nothing is stored for ``zobrist_key``.

            ``(score, best_move)`` if the entry was searched at least
            ``depth`` deep and its score is usable for the (alpha, beta)
            window: exact, a lower bound >= beta, or an upper bound <= alpha.

            ``(None, best_move)`` if an entry exists but its score cannot be
            used (too shallow, or a bound that does not cross the window).
            The move is still returned for move ordering.
        """
        entry = self.table.get(zobrist_key)
        if entry is None:
            self.misses += 1
            return None

        # Entries are stored under their own key, so entry.zobrist_key always
        # equals zobrist_key here. True hash collisions (two positions with
        # one key) cannot be detected with this layout.

        # Depth check: only trust the score if searched at least this deep.
        if entry.depth < depth:
            self.misses += 1
            # The move from a shallower search (e.g. the previous iterative
            # deepening iteration) is still the best ordering hint available.
            return (None, entry.best_move)

        self.hits += 1

        # Check if score is usable based on node type
        score = entry.score
        node_type = entry.node_type
        usable = (
            node_type == NodeType.EXACT
            or (node_type == NodeType.LOWER_BOUND and score >= beta)
            or (node_type == NodeType.UPPER_BOUND and score <= alpha)
        )
        if usable:
            return (score, entry.best_move)

        # Entry exists but not usable for cutoff; still return best_move
        # for move ordering.
        return (None, entry.best_move)

    def store(
        self,
        zobrist_key: int,
        depth: int,
        score: float,
        node_type: NodeType,
        best_move: Optional[chess.Move]
    ):
        """
        Store entry in transposition table.

        An existing entry is kept (and the new result dropped) only if it was
        written in the current generation and is deeper than ``depth``. When
        an entry is overwritten and ``best_move`` is None, the previously
        stored move for the same key is carried over.

        Args:
            zobrist_key: Zobrist hash
            depth: Search depth
            score: Position score
            node_type: Type of node (exact/lower/upper)
            best_move: Best move found
        """
        existing = self.table.get(zobrist_key)
        if existing is not None:
            # Keep a deeper entry from the current search; replace anything
            # from an older generation, or any entry at most as deep.
            if depth < existing.depth and existing.age == self.current_age:
                return  # Keep existing deeper entry
            # Don't lose a known best move when the new search produced none
            # (e.g. a fail-low node). It is the same position, so the move
            # remains a useful ordering hint.
            if best_move is None:
                best_move = existing.best_move

        # Store new entry
        self.table[zobrist_key] = TTEntry(
            zobrist_key=zobrist_key,
            depth=depth,
            score=score,
            node_type=node_type,
            best_move=best_move,
            age=self.current_age
        )

        # Cleanup if table too large (simple strategy)
        if len(self.table) > self.max_entries:
            self._cleanup_old_entries()

    def _cleanup_old_entries(self):
        """
        Evict the entries with the oldest generation.

        Removes 10% of the entry budget. It always removes at least one entry,
        and at least enough to bring the table back to ``max_entries``.
        Entries of equal age are removed in insertion order: ``sorted`` is
        stable and dicts iterate in insertion order.
        """
        overflow = len(self.table) - self.max_entries
        entries_to_remove = max(1, overflow, self.max_entries // 10)

        # Remove oldest entries (by age)
        oldest = sorted(self.table.items(), key=lambda item: item[1].age)
        for key, _entry in oldest[:entries_to_remove]:
            del self.table[key]

    def increment_age(self):
        """Increment generation counter (call at search start)."""
        self.current_age += 1

    def clear(self):
        """Clear all entries and reset hit/miss/collision statistics (age is kept)."""
        self.table.clear()
        self.hits = 0
        self.misses = 0
        self.collisions = 0

    def get_stats(self) -> Dict:
        """
        Get table statistics.

        Returns:
            Dict with entry count, capacity, usage percent, hits, misses,
            hit rate (percent of probes that were hits) and collisions.
        """
        total_probes = self.hits + self.misses
        hit_rate = (self.hits / total_probes * 100) if total_probes > 0 else 0
        return {
            'entries': len(self.table),
            'max_entries': self.max_entries,
            'usage_percent': len(self.table) / self.max_entries * 100,
            'hits': self.hits,
            'misses': self.misses,
            'hit_rate': hit_rate,
            'collisions': self.collisions
        }