"""
Helion-2.5-Rnd Batch Inference
Efficient batch processing for large-scale inference tasks
"""

import argparse
import logging
import time
from pathlib import Path
from typing import Dict, List, Optional

import pandas as pd
from tqdm import tqdm

from inference.client import HelionClient

logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)


class BatchProcessor:
    """Process large batches of inference requests"""

    def __init__(
        self,
        client: HelionClient,
        batch_size: int = 10,
        max_retries: int = 3,
        retry_delay: float = 1.0
    ):
        """
        Initialize batch processor

        Args:
            client: HelionClient instance
            batch_size: Number of requests grouped per batch (requests
                are issued sequentially within each batch)
            max_retries: Maximum retry attempts for failed requests
            retry_delay: Delay between retries in seconds
        """
        self.client = client
        self.batch_size = batch_size
        self.max_retries = max_retries
        self.retry_delay = retry_delay
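
        # Running totals, refreshed on each call to process_prompts()
        # and exposed via get_statistics().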
        self.stats = {
            'total': 0,
            'successful': 0,
            'failed': 0,
            'total_time': 0.0,
            'avg_time_per_request': 0.0
        }

    def process_prompts(
        self,
        prompts: List[str],
        temperature: float = 0.7,
        max_tokens: int = 1024,
        **kwargs
    ) -> List[Dict]:
        """
        Process a list of prompts

        Args:
            prompts: List of input prompts
            temperature: Sampling temperature
            max_tokens: Maximum tokens per response
            **kwargs: Additional generation parameters

        Returns:
            List of results with prompt, response, and metadata
        """
        results = []
        start_time = time.time()

        logger.info(f"Processing {len(prompts)} prompts...")
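
        # Walk the prompts in batch_size chunks so tqdm reports progress
        # per batch; requests within a chunk are sent sequentially.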
        for i in tqdm(range(0, len(prompts), self.batch_size)):
            batch = prompts[i:i + self.batch_size]

            for prompt in batch:
                result = self._process_single_with_retry(
                    prompt,
                    temperature,
                    max_tokens,
                    **kwargs
                )
                results.append(result)

        # Aggregate run statistics; guard against an empty prompt list
        # so the average does not divide by zero.
        n = max(len(prompts), 1)
        self.stats['total'] = len(prompts)
        self.stats['successful'] = sum(1 for r in results if r['success'])
        self.stats['failed'] = len(prompts) - self.stats['successful']
        self.stats['total_time'] = time.time() - start_time
        self.stats['avg_time_per_request'] = self.stats['total_time'] / n

        logger.info(
            f"Batch processing complete. "
            f"Success rate: {self.stats['successful']}/{self.stats['total']}"
        )

        return results

    def _process_single_with_retry(
        self,
        prompt: str,
        temperature: float,
        max_tokens: int,
        **kwargs
    ) -> Dict:
        """Process single prompt with retry logic"""
        for attempt in range(max(self.max_retries, 1)):
            try:
                start = time.time()
                response = self.client.complete(
                    prompt=prompt,
                    temperature=temperature,
                    max_tokens=max_tokens,
                    **kwargs
                )
                duration = time.time() - start

                return {
                    'prompt': prompt,
                    'response': response,
                    'success': True,
                    'duration': duration,
                    'attempts': attempt + 1
                }

            except Exception as e:
                logger.warning(f"Attempt {attempt + 1} failed: {e}")

                if attempt < self.max_retries - 1:
                    time.sleep(self.retry_delay)
                else:
                    return {
                        'prompt': prompt,
                        'response': None,
                        'success': False,
                        'error': str(e),
                        'attempts': attempt + 1
                    }

    def process_chat_conversations(
        self,
        conversations: List[List[Dict]],
        temperature: float = 0.7,
        max_tokens: int = 1024,
        **kwargs
    ) -> List[Dict]:
        """
        Process chat conversations in batch

        Args:
            conversations: List of message lists
            temperature: Sampling temperature
            max_tokens: Maximum tokens per response
            **kwargs: Additional generation parameters

        Returns:
            List of conversation results
        """
        results = []
        start_time = time.time()

        logger.info(f"Processing {len(conversations)} conversations...")
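
        # Conversations are sent one at a time; unlike process_prompts(),
        # failed chat requests are not retried here.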
        for conv in tqdm(conversations):
            try:
                start = time.time()
                response = self.client.chat(
                    messages=conv,
                    temperature=temperature,
                    max_tokens=max_tokens,
                    **kwargs
                )
                duration = time.time() - start

                results.append({
                    'conversation': conv,
                    'response': response,
                    'success': True,
                    'duration': duration
                })

            except Exception as e:
                logger.error(f"Conversation processing failed: {e}")
                results.append({
                    'conversation': conv,
                    'response': None,
                    'success': False,
                    'error': str(e)
                })

        total_time = time.time() - start_time
        successful = sum(1 for r in results if r['success'])

        logger.info(f"Processed {successful}/{len(conversations)} conversations in {total_time:.2f}s")

        return results

    def process_file(
        self,
        input_file: str,
        output_file: str,
        prompt_column: str = "prompt",
        temperature: float = 0.7,
        max_tokens: int = 1024,
        **kwargs
    ) -> pd.DataFrame:
        """
        Process prompts from file

        Args:
            input_file: Input CSV/JSON file path
            output_file: Output file path
            prompt_column: Column name containing prompts
            temperature: Sampling temperature
            max_tokens: Maximum tokens per response
            **kwargs: Additional generation parameters

        Returns:
            DataFrame with results
        """
        input_path = Path(input_file)

        if input_path.suffix == '.csv':
            df = pd.read_csv(input_path)
        elif input_path.suffix == '.json':
            df = pd.read_json(input_path)
        else:
            raise ValueError(f"Unsupported file format: {input_path.suffix}")

        if prompt_column not in df.columns:
            raise ValueError(f"Column '{prompt_column}' not found in input file")

        prompts = df[prompt_column].tolist()
        results = self.process_prompts(
            prompts,
            temperature=temperature,
            max_tokens=max_tokens,
            **kwargs
        )
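
        # Attach per-request results as new columns, aligned row-for-row
        # with the input frame.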
        df['response'] = [r['response'] for r in results]
        df['success'] = [r['success'] for r in results]
        df['duration'] = [r.get('duration') for r in results]
        df['error'] = [r.get('error') for r in results]

        output_path = Path(output_file)
        output_path.parent.mkdir(parents=True, exist_ok=True)

        if output_path.suffix == '.csv':
            df.to_csv(output_path, index=False)
        elif output_path.suffix == '.json':
            df.to_json(output_path, orient='records', indent=2)
        else:
            raise ValueError(f"Unsupported output format: {output_path.suffix}")

        logger.info(f"Results saved to {output_path}")

        return df

    def get_statistics(self) -> Dict:
        """Get processing statistics"""
        return self.stats.copy()
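

# A minimal programmatic usage sketch (assumes a reachable Helion server;
# HelionClient's constructor signature is taken from main() below):
#
#   client = HelionClient(base_url="http://localhost:8000")
#   processor = BatchProcessor(client, batch_size=4)
#   results = processor.process_prompts(["Hello, world"], max_tokens=64)
#   print(processor.get_statistics())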


class DatasetProcessor:
    """Process specific dataset formats"""

    def __init__(self, client: HelionClient):
        self.client = client
        self.processor = BatchProcessor(client)
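
    # Each helper below renders a task-specific prompt template and
    # delegates to BatchProcessor.process_prompts().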

    def process_qa_dataset(
        self,
        questions: List[str],
        contexts: Optional[List[str]] = None,
        temperature: float = 0.3,
        max_tokens: int = 512
    ) -> List[Dict]:
        """Process question-answering dataset"""
        prompts = []

        for i, question in enumerate(questions):
            if contexts and i < len(contexts):
                prompt = f"Context: {contexts[i]}\n\nQuestion: {question}\n\nAnswer:"
            else:
                prompt = f"Question: {question}\n\nAnswer:"

            prompts.append(prompt)

        return self.processor.process_prompts(
            prompts,
            temperature=temperature,
            max_tokens=max_tokens
        )

    def process_code_dataset(
        self,
        tasks: List[str],
        languages: Optional[List[str]] = None,
        temperature: float = 0.2,
        max_tokens: int = 1024
    ) -> List[Dict]:
        """Process code generation tasks"""
        prompts = []

        for i, task in enumerate(tasks):
            lang = languages[i] if languages and i < len(languages) else "python"
            prompt = f"Write a {lang} function to: {task}\n\n```{lang}\n"
            prompts.append(prompt)

        return self.processor.process_prompts(
            prompts,
            temperature=temperature,
            max_tokens=max_tokens
        )

    def process_translation_dataset(
        self,
        texts: List[str],
        source_lang: str,
        target_lang: str,
        temperature: float = 0.3,
        max_tokens: int = 1024
    ) -> List[Dict]:
        """Process translation tasks"""
        prompts = []

        for text in texts:
            prompt = (
                f"Translate the following text from {source_lang} "
                f"to {target_lang}:\n\n{text}\n\nTranslation:"
            )
            prompts.append(prompt)

        return self.processor.process_prompts(
            prompts,
            temperature=temperature,
            max_tokens=max_tokens
        )

    def process_summarization_dataset(
        self,
        documents: List[str],
        max_summary_length: int = 150,
        temperature: float = 0.5,
        max_tokens: int = 512
    ) -> List[Dict]:
        """Process document summarization"""
        prompts = []

        for doc in documents:
            prompt = (
                f"Summarize the following document in "
                f"{max_summary_length} words or less:\n\n{doc}\n\nSummary:"
            )
            prompts.append(prompt)

        return self.processor.process_prompts(
            prompts,
            temperature=temperature,
            max_tokens=max_tokens
        )


def main():
    """Main batch processing entry point"""
    parser = argparse.ArgumentParser(description="Batch inference with Helion")
    parser.add_argument("--base-url", type=str, default="http://localhost:8000")
    parser.add_argument("--input", type=str, required=True, help="Input file (CSV/JSON)")
    parser.add_argument("--output", type=str, required=True, help="Output file (CSV/JSON)")
    parser.add_argument("--prompt-column", type=str, default="prompt")
    parser.add_argument("--temperature", type=float, default=0.7)
    parser.add_argument("--max-tokens", type=int, default=1024)
    parser.add_argument("--batch-size", type=int, default=10)

    args = parser.parse_args()

    client = HelionClient(base_url=args.base_url)
    processor = BatchProcessor(client, batch_size=args.batch_size)
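
    # Run the batch and persist the results; summary statistics follow.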
    processor.process_file(
        input_file=args.input,
        output_file=args.output,
        prompt_column=args.prompt_column,
        temperature=args.temperature,
        max_tokens=args.max_tokens
    )

    stats = processor.get_statistics()
    logger.info("\nProcessing Statistics:")
    logger.info(f"Total requests: {stats['total']}")
    logger.info(f"Successful: {stats['successful']}")
    logger.info(f"Failed: {stats['failed']}")
    logger.info(f"Total time: {stats['total_time']:.2f}s")
    logger.info(f"Avg time per request: {stats['avg_time_per_request']:.2f}s")


if __name__ == "__main__":
    main()