| | """ |
| | Sequence Prediction Dataset Generator. |
| | |
| | Generates image pairs for sequence prediction tasks with various |
| | mathematical sequences (arithmetic, geometric, fibonacci, etc.) |
| | """ |
| |
|
| | import json |
| | import random |
| | from pathlib import Path |
| | from typing import Callable |
| |
|
| | import matplotlib.pyplot as plt |
| | import matplotlib.patches as patches |
| |
|
| |
|
| | |
| |
|
def arithmetic_seq(start: int, diff: int, length: int = 4) -> list[int]:
    """Build an arithmetic progression.

    Args:
        start: First term of the progression.
        diff: Constant difference between consecutive terms.
        length: Number of terms to produce.

    Returns:
        The terms ``start, start + diff, start + 2*diff, ...``.
    """
    terms = []
    for step in range(length):
        terms.append(start + step * diff)
    return terms
| |
|
| |
|
def geometric_seq(start: int, ratio: int, length: int = 4) -> list[int]:
    """Build a geometric progression.

    Args:
        start: First term of the progression.
        ratio: Constant factor between consecutive terms.
        length: Number of terms to produce.

    Returns:
        The terms ``start, start*ratio, start*ratio**2, ...``.
    """
    terms = []
    for exponent in range(length):
        factor = ratio ** exponent
        terms.append(start * factor)
    return terms
| |
|
| |
|
def square_seq(start: int, length: int = 4) -> list[int]:
    """Return consecutive perfect squares.

    Args:
        start: Base of the first square.
        length: Number of terms to produce.

    Returns:
        The terms ``start**2, (start + 1)**2, ...``.
    """
    return list(map(lambda offset: (start + offset) ** 2, range(length)))
| |
|
| |
|
def cube_seq(start: int, length: int = 4) -> list[int]:
    """Return consecutive perfect cubes.

    Args:
        start: Base of the first cube.
        length: Number of terms to produce.

    Returns:
        The terms ``start**3, (start + 1)**3, ...``.
    """
    cubes = []
    for offset in range(length):
        base = start + offset
        cubes.append(base ** 3)
    return cubes
| |
|
| |
|
def triangular_seq(start: int, length: int = 4) -> list[int]:
    """Return consecutive triangular numbers ``n(n+1)/2``.

    Args:
        start: Value of ``n`` for the first term.
        length: Number of terms to produce.

    Returns:
        The triangular numbers for ``n = start, start + 1, ...``.
    """
    terms = []
    for offset in range(length):
        n = start + offset
        terms.append(n * (n + 1) // 2)
    return terms
| |
|
| |
|
def fibonacci_like_seq(a: int, b: int, length: int = 4) -> list[int]:
    """Build a Fibonacci-style sequence from two seed values.

    Args:
        a: First term.
        b: Second term.
        length: Number of terms to return.

    Returns:
        ``a, b, a+b, a+2b, ...`` where each later term is the sum of the two
        before it, truncated to ``length`` terms.
    """
    terms = [a, b]
    # Grow until enough terms exist; a no-op when length <= 2.
    while len(terms) < length:
        terms.append(terms[-1] + terms[-2])
    # The slice trims the two seeds when fewer than two terms are requested.
    return terms[:length]
| |
|
| |
|
def prime_seq(start_idx: int, length: int = 4) -> list[int]:
    """Return consecutive primes starting at a 0-based index.

    Primes are computed on demand by trial division, so the result always
    contains exactly ``length`` primes no matter how large ``start_idx`` is.

    Args:
        start_idx: 0-based position of the first prime to return
            (index 0 is the prime 2).
        length: Number of primes to return; values <= 0 yield an empty list.

    Returns:
        The primes at positions ``start_idx .. start_idx + length - 1``.

    Raises:
        ValueError: If ``start_idx`` is negative.
    """
    if start_idx < 0:
        raise ValueError(f"start_idx must be non-negative, got {start_idx}")
    if length <= 0:
        return []

    needed = start_idx + length
    primes: list[int] = []
    candidate = 2
    while len(primes) < needed:
        is_prime = True
        for p in primes:
            # Divisors above sqrt(candidate) cannot be the smallest factor.
            if p * p > candidate:
                break
            if candidate % p == 0:
                is_prime = False
                break
        if is_prime:
            primes.append(candidate)
        candidate += 1
    return primes[start_idx:]
| |
|
| |
|
def power_of_two_seq(start: int, length: int = 4) -> list[int]:
    """Return consecutive powers of two.

    Args:
        start: Exponent of the first term.
        length: Number of terms to produce.

    Returns:
        The terms ``2**start, 2**(start + 1), ...``.
    """
    powers = []
    for offset in range(length):
        exponent = start + offset
        powers.append(2 ** exponent)
    return powers
| |
|
| |
|
def factorial_seq(start: int, length: int = 4) -> list[int]:
    """Return consecutive factorials.

    Args:
        start: Argument of the first factorial.
        length: Number of terms to produce.

    Returns:
        The terms ``start!, (start + 1)!, ...``.
    """
    from math import factorial

    terms = []
    for offset in range(length):
        terms.append(factorial(start + offset))
    return terms
| |
|
| |
|
| | |
| |
|
# Registry of sequence families. Each entry maps a type name to a factory that
# takes a random.Random-like object and returns a freshly sampled sequence.
# Arguments are drawn left to right, so the number and order of RNG calls per
# factory is part of the reproducibility contract.
SEQUENCE_TYPES = {
    # Increasing progression: start in [1, 20], step in [1, 10].
    "arithmetic": lambda rng: arithmetic_seq(
        start=rng.randint(1, 20), diff=rng.randint(1, 10)
    ),
    # Decreasing progression: start in [20, 50], step in [-5, -1].
    "arithmetic_neg": lambda rng: arithmetic_seq(
        start=rng.randint(20, 50), diff=-rng.randint(1, 5)
    ),
    # Doubling progression: start in [1, 5].
    "geometric_2": lambda rng: geometric_seq(start=rng.randint(1, 5), ratio=2),
    # Tripling progression: start in [1, 3].
    "geometric_3": lambda rng: geometric_seq(start=rng.randint(1, 3), ratio=3),
    "square": lambda rng: square_seq(start=rng.randint(1, 10)),
    "cube": lambda rng: cube_seq(start=rng.randint(1, 5)),
    "triangular": lambda rng: triangular_seq(start=rng.randint(1, 10)),
    # Both seed values are drawn independently from [1, 5].
    "fibonacci": lambda rng: fibonacci_like_seq(
        a=rng.randint(1, 5), b=rng.randint(1, 5)
    ),
    # Start index in [0, 10] into the list of primes.
    "prime": lambda rng: prime_seq(start_idx=rng.randint(0, 10)),
    # First exponent in [0, 6].
    "power_of_2": lambda rng: power_of_two_seq(start=rng.randint(0, 6)),
}
| |
|
| |
|
def generate_sequence_pair(seq: list[int]) -> tuple[list, list]:
    """
    Split a sequence into its "question" and "answer" forms.

    Args:
        seq: The full sequence.

    Returns:
        (partial, complete): partial is a copy of ``seq`` whose final element
        is replaced by "", complete is ``seq`` itself (not a copy).
    """
    visible_terms = seq[:-1]
    blank_slot = [""]
    partial = visible_terms + blank_slot
    return partial, seq
| |
|
| |
|
| | |
| |
|
def round_to_multiple(x: int, multiple: int = 16) -> int:
    """Round ``x`` up to the nearest multiple of ``multiple``.

    Args:
        x: Value to round.
        multiple: Step size to round to.

    Returns:
        The smallest multiple of ``multiple`` that is >= ``x`` (for a
        positive ``multiple``); values already on a multiple are unchanged.
    """
    # Adding (multiple - 1) before flooring turns floor division into a ceiling.
    whole_steps, _ = divmod(x + multiple - 1, multiple)
    return whole_steps * multiple
| |
|
| |
|
def create_number_grid(
    numbers: list,
    save_path: str | Path,
    height: int = 224,
    width: int = 896,
    fontsize: int = 48,
    size_multiple: int = 16,
) -> None:
    """
    Create a 1xN grid image with numbers in each cell.

    The grid is drawn with matplotlib and written to ``save_path``. The saved
    file is then re-opened with Pillow and, if its pixel size differs from the
    rounded target size, resized and written again.

    Args:
        numbers: List of numbers/strings to display; each item is rendered
            with ``str()``.
        save_path: Output file path.
        height: Target height in pixels (will be rounded to size_multiple).
        width: Target width in pixels (will be rounded to size_multiple).
        fontsize: Font size for the numbers.
        size_multiple: Ensure dimensions are multiples of this (default 16).
    """
    from PIL import Image

    n = len(numbers)

    # Snap the requested size up to the required multiple.
    width = round_to_multiple(width, size_multiple)
    height = round_to_multiple(height, size_multiple)

    # Figure size is given in inches, so pixels = inches * dpi.
    dpi = 100
    fig, ax = plt.subplots(figsize=(width / dpi, height / dpi), dpi=dpi)
    try:
        # Let the axes fill the whole canvas (no margins).
        fig.subplots_adjust(left=0, right=1, top=1, bottom=0)

        for i, num in enumerate(numbers):
            # One unit-square cell per number, laid out left to right.
            cell = patches.Rectangle(
                (i, 0), 1, 1, linewidth=2,
                edgecolor='black', facecolor='white'
            )
            ax.add_patch(cell)
            ax.text(
                i + 0.5, 0.5, str(num), fontsize=fontsize,
                ha='center', va='center', fontweight='bold'
            )

        # Keep the x-range non-degenerate when there is nothing to draw.
        ax.set_xlim(0, max(n, 1))
        ax.set_ylim(0, 1)
        ax.set_aspect('equal')
        ax.axis('off')

        fig.savefig(save_path, dpi=dpi, facecolor='white', edgecolor='none')
    finally:
        # Always release the figure, even if drawing or saving failed.
        plt.close(fig)

    # Enforce the exact pixel size. The resized copy is produced while the
    # file is open, and written back only after the handle has been closed.
    resized = None
    with Image.open(save_path) as img:
        if img.size != (width, height):
            resized = img.resize((width, height), Image.Resampling.LANCZOS)
    if resized is not None:
        resized.save(save_path)
| |
|
| |
|
| | |
| |
|
class SequenceDatasetGenerator:
    """Generate sequence prediction dataset with train/test splits.

    Each sample consists of several (partial, complete) number-grid images
    derived from one base sequence. Training samples include the completed
    grid for every pair; test samples omit the final completed grid. The last
    pair's final number is stored as the sample's ``answer`` in the metadata
    for both splits.
    """

    def __init__(
        self,
        output_dir: str,
        seed: int = 42,
        num_pairs: tuple[int, int] = (2, 3),
        seq_types: list[str] | None = None,
        image_height: int = 224,
        image_width: int = 896,
        fontsize: int = 48,
    ) -> None:
        """
        Args:
            output_dir: Directory to save the dataset.
            seed: Random seed for reproducibility.
            num_pairs: Range of pairs per sample (min, max inclusive).
            seq_types: List of sequence types to use (None or empty = all).
            image_height: Image height in pixels (rounded to 16).
            image_width: Image width in pixels (rounded to 16).
            fontsize: Font size for numbers.

        Raises:
            ValueError: If ``num_pairs`` does not satisfy 1 <= min <= max, or
                if ``seq_types`` names a type not in ``SEQUENCE_TYPES``.
        """
        # Validate before touching the filesystem so a bad configuration
        # fails immediately instead of midway through generation.
        min_pairs, max_pairs = num_pairs
        if min_pairs < 1 or max_pairs < min_pairs:
            raise ValueError(
                f"num_pairs must satisfy 1 <= min <= max, got {num_pairs!r}"
            )
        chosen_types = list(seq_types) if seq_types else list(SEQUENCE_TYPES)
        unknown = [name for name in chosen_types if name not in SEQUENCE_TYPES]
        if unknown:
            raise ValueError(
                f"Unknown sequence types {unknown}; "
                f"expected a subset of {list(SEQUENCE_TYPES)}"
            )

        self.output_dir = Path(output_dir)
        self.rng = random.Random(seed)
        self.num_pairs = num_pairs
        self.seq_types = chosen_types
        self.image_height = round_to_multiple(image_height, 16)
        self.image_width = round_to_multiple(image_width, 16)
        self.fontsize = fontsize

        # Create <output_dir>/{train,test}/images up front.
        for split in ["train", "test"]:
            (self.output_dir / split / "images").mkdir(parents=True, exist_ok=True)

    def _generate_sample(self, sample_id: int) -> dict:
        """Generate a single sample with multiple sequence pairs.

        Args:
            sample_id: Identifier stored in the returned sample.

        Returns:
            Dict with keys ``id``, ``seq_type``, ``num_pairs`` and ``pairs``;
            each pair holds ``partial``, ``complete`` and ``answer`` (the last
            element of ``complete``).
        """
        # RNG draw order (pair count, type, sequence parameters) determines
        # the dataset produced for a given seed; keep it stable.
        num_pairs = self.rng.randint(*self.num_pairs)
        seq_type = self.rng.choice(self.seq_types)
        base_seq = SEQUENCE_TYPES[seq_type](self.rng)

        # Pair i is a variant of the base sequence; pair 0 is the base itself.
        if seq_type.startswith("arithmetic"):
            # Shifting every term by i*diff slides the window i steps along
            # the same progression.
            diff = base_seq[1] - base_seq[0]
            variants = [
                [x + i * diff for x in base_seq] for i in range(num_pairs)
            ]
        elif seq_type.startswith("geometric"):
            # Scaling every term by ratio**i slides the window i steps along
            # the same progression. Fall back to 2 if the first term is 0.
            ratio = base_seq[1] // base_seq[0] if base_seq[0] != 0 else 2
            variants = [
                [x * (ratio ** i) for x in base_seq] for i in range(num_pairs)
            ]
        else:
            # NOTE(review): adding a constant offset means that for i > 0 the
            # terms are no longer members of the named family (e.g. squares
            # plus one) - confirm this is the intended task design.
            variants = [[x + i for x in base_seq] for i in range(num_pairs)]

        pairs = []
        for seq in variants:
            partial, complete = generate_sequence_pair(seq)
            pairs.append({
                "partial": partial,
                "complete": complete,
                "answer": complete[-1],
            })

        return {
            "id": sample_id,
            "seq_type": seq_type,
            "num_pairs": num_pairs,
            "pairs": pairs,
        }

    def _render_grid(self, numbers: list, path: Path) -> None:
        """Draw ``numbers`` to ``path`` using this generator's image settings."""
        create_number_grid(
            numbers, path,
            height=self.image_height, width=self.image_width,
            fontsize=self.fontsize,
        )

    def _save_sample_images(
        self, sample: dict, split: str, include_last_answer: bool = True
    ) -> dict:
        """Save images for a sample and return metadata.

        Images are named ``<id:05d>_<k>.png`` where ``k`` counts the images
        written for this sample, in order: partial grid, then complete grid,
        for each pair.

        Args:
            sample: Sample dict as produced by ``_generate_sample``.
            split: Split sub-directory name ("train" or "test").
            include_last_answer: When False, the complete grid of the final
                pair is not rendered.

        Returns:
            Metadata dict with ``id``, ``seq_type``, ``num_pairs``, ``images``
            (file names relative to the split's images directory), ``answer``
            (answer of the last pair) and ``sequences`` (all complete
            sequences).
        """
        sample_id = sample["id"]
        image_dir = self.output_dir / split / "images"
        last_index = sample["num_pairs"] - 1

        images: list[str] = []
        for i, pair in enumerate(sample["pairs"]):
            grids = [pair["partial"]]
            if include_last_answer or i != last_index:
                grids.append(pair["complete"])
            for numbers in grids:
                # len(images) is the running per-sample image counter.
                filename = f"{sample_id:05d}_{len(images)}.png"
                self._render_grid(numbers, image_dir / filename)
                images.append(filename)

        return {
            "id": sample_id,
            "seq_type": sample["seq_type"],
            "num_pairs": sample["num_pairs"],
            "images": images,
            "answer": sample["pairs"][-1]["answer"],
            "sequences": [p["complete"] for p in sample["pairs"]],
        }

    def _generate_split(
        self, split: str, count: int, first_id: int, include_last_answer: bool
    ) -> list[dict]:
        """Generate and render ``count`` samples for one split; return metadata."""
        noun = "training" if split == "train" else split
        label = split.capitalize()

        print(f"Generating {count} {noun} samples...")
        metadata = []
        for i in range(count):
            sample = self._generate_sample(first_id + i)
            metadata.append(
                self._save_sample_images(
                    sample, split, include_last_answer=include_last_answer
                )
            )
            if (i + 1) % 50 == 0:
                print(f" {label}: {i + 1}/{count}")
        return metadata

    def generate(self, num_train: int, num_test: int) -> None:
        """
        Generate the full dataset.

        Writes the images plus ``train.json`` and ``test.json`` metadata files
        into the output directory. Test sample ids continue after the training
        ids, so ids are unique across both splits.

        Args:
            num_train: Number of training samples.
            num_test: Number of test samples.
        """
        # Train is generated before test so the RNG stream, and therefore the
        # dataset for a given seed, stays stable.
        train_meta = self._generate_split(
            "train", num_train, first_id=0, include_last_answer=True
        )
        test_meta = self._generate_split(
            "test", num_test, first_id=num_train, include_last_answer=False
        )

        with open(self.output_dir / "train.json", "w", encoding="utf-8") as f:
            json.dump(train_meta, f, indent=2)
        with open(self.output_dir / "test.json", "w", encoding="utf-8") as f:
            json.dump(test_meta, f, indent=2)

        print(f"\nDataset saved to {self.output_dir}")
        print(f" Train: {num_train} samples")
        print(f" Test: {num_test} samples")
        print(f" Image size: {self.image_width}x{self.image_height}")
        print(f" Sequence types: {self.seq_types}")
| |
|
| |
|
if __name__ == "__main__":
    # Script entry point: build a small demo dataset (100 train / 20 test
    # samples) with a fixed seed so repeated runs draw the same sequences.
    generator = SequenceDatasetGenerator(
        "/home/claude/sequence_dataset", seed=42, num_pairs=(2, 3)
    )
    generator.generate(100, 20)