Spaces:
Running
on
Zero
Running
on
Zero
# Copyright (c) Meta Platforms, Inc. and affiliates.
| import time | |
| from pathlib import Path | |
| import numpy as np | |
| import pyarrow as pa | |
| import torch | |
| import typer | |
| from rich.progress import Progress, TextColumn | |
| from bytelatent.data.iterators.arrow_iterator import ArrowFileIterator | |
def _write_entropy_batch(writer, schema, sample_ids, texts, entropies):
    """Write one record batch of buffered rows to ``writer``.

    ``sample_ids``, ``texts`` and ``entropies`` are parallel lists: index
    ``i`` of each one describes the same document. The dictionary keys used
    here are the field names declared in ``schema`` by ``main``.
    """
    batch = pa.record_batch(
        {
            "entropies": entropies,
            "sample_id": sample_ids,
            "text": texts,
        },
        schema=schema,
    )
    writer.write(batch)


def main(
    input_file: str,
    output_file: str,
    patching_device: str = "cuda",
    log_step: int = 10_000,
    entropy_model_checkpoint_dir: str = "entropy_checkpoint_dir",
    dry_run: bool = False,
):
    """Compute per-document entropies and store them in an Arrow IPC file.

    NOTE: this entry point is currently disabled. It raises
    ``NotImplementedError`` as its first statement, so none of the code
    below that statement runs. The remaining body is kept as the reference
    implementation to be ported (see the TODO).

    Intended flow of the disabled body: iterate the documents of
    ``input_file``, byte-tokenize each document's text, score the tokens
    with the entropy model, and write ``sample_id`` / ``text`` /
    ``entropies`` rows to ``output_file`` in record batches. When writing
    succeeds an empty ``<output_file>.complete`` marker is created; when
    anything fails ``output_file`` is deleted and the error is re-raised.

    Args:
        input_file: Path handed to ``ArrowFileIterator`` as ``file_path``.
        output_file: Path of the Arrow IPC file to create.
        patching_device: Device string passed to ``load_entropy_model``,
            ``to_device`` and ``calculate_entropies``.
        log_step: Print a progress line every ``log_step`` documents.
        entropy_model_checkpoint_dir: Checkpoint directory passed to
            ``load_entropy_model``.
        dry_run: If true, return right after printing the configuration,
            before the entropy model is loaded or any output is written.

    Raises:
        NotImplementedError: Always, until the TODO below is resolved.
    """
    # TODO: Modify this to work with the new code
    raise NotImplementedError()

    # ------------------------------------------------------------------
    # Unreachable reference implementation.
    # NOTE(review): load_entropy_model, to_device, Tokenizer,
    # get_id_from_doc and calculate_entropies are not imported in the
    # visible part of this file; they must be brought into scope before
    # the raise above is removed.
    # ------------------------------------------------------------------
    iterator = ArrowFileIterator(
        file_path=input_file,
        worker_id=0,
        num_workers=1,
    )
    tokenization_mode = "bytes"
    print(f"Preprocessing entropies, input: {input_file}, output: {output_file}")
    print("Loading entropy model", entropy_model_checkpoint_dir)
    if dry_run:
        return

    entropy_model = load_entropy_model(
        entropy_model_checkpoint_dir, device=patching_device
    )
    entropy_model, _ = to_device(entropy_model, patching_device)
    print("Creating patcher")
    patching_batch_size = 32
    print("Creating tokenizer")
    tokenizer = Tokenizer(
        # NOTE(review): absolute path inside one user's home directory;
        # it will not exist on other machines.
        model_path="/home/artidoro/tokenizers/llama_v2.tokenizer.model",
        tokenization_mode=tokenization_mode,
        # BYTE_UNITS
        vocab_size_unit_1=256,
        bos=True,
        eos=True,
        bpe_delim=False,
        # This isn't used, just stores a reference for other calls we don't use
        patcher=None,
    )

    step = 0
    print("starting")
    start_time = time.time()
    # Accumulated per document below but not reported anywhere in this
    # function.
    patch_time = 0

    entropy_field = pa.field("entropies", pa.list_(pa.float16()), nullable=False)
    sample_id_field = pa.field("sample_id", pa.string(), nullable=False)
    text_field = pa.field("text", pa.string(), nullable=False)
    schema = pa.schema([sample_id_field, text_field, entropy_field])
    # Number of documents buffered in memory before a record batch is written.
    arrow_batch_size = 1_000

    try:
        with pa.OSFile(output_file, "wb") as sink:
            with pa.ipc.new_file(sink, schema) as writer:
                id_buffer = []
                entropies_buffer = []
                text_buffer = []
                with Progress(
                    *Progress.get_default_columns(),
                    TextColumn("Completed: {task.completed}"),
                ) as progress:
                    task = progress.add_task(
                        "[green]Calculating entropies...", total=None
                    )
                    for doc in iterator:
                        sample_id = get_id_from_doc(doc)
                        # Documents store their body under either "text"
                        # or "content".
                        if "text" in doc:
                            text = doc["text"]
                        elif "content" in doc:
                            text = doc["content"]
                        else:
                            raise ValueError(
                                f"Could not find a text key from: {doc.keys()}"
                            )
                        tokens = torch.tensor(tokenizer.encode(text))

                        patch_start = time.time()
                        scores = calculate_entropies(
                            tokens,
                            entropy_model,
                            patching_batch_size,
                            patching_device,
                        )
                        entropies_buffer.append(
                            np.array(scores.tolist(), dtype=np.float16)
                        )
                        id_buffer.append(sample_id)
                        text_buffer.append(text)

                        if len(entropies_buffer) == arrow_batch_size:
                            _write_entropy_batch(
                                writer,
                                schema,
                                id_buffer,
                                text_buffer,
                                entropies_buffer,
                            )
                            # Start fresh buffers for the next batch.
                            id_buffer = []
                            entropies_buffer = []
                            text_buffer = []

                        patch_time += time.time() - patch_start
                        step += 1
                        if step % log_step == 0:
                            print("Completed steps:", step)
                        progress.update(task, advance=1)

                    if len(entropies_buffer) > 0:
                        # Write last things
                        _write_entropy_batch(
                            writer,
                            schema,
                            id_buffer,
                            text_buffer,
                            entropies_buffer,
                        )
        # Both the writer and the sink are closed at this point, so the
        # marker is only created once the output file has been finalized.
        Path(f"{output_file}.complete").touch()
    except BaseException:
        # Deliberately broad (covers KeyboardInterrupt too): never leave a
        # partially written output file behind, then propagate the error.
        Path(output_file).unlink(missing_ok=True)
        raise

    elapsed = time.time() - start_time
    print("steps", step)
    print("done in:", elapsed)
# Script entry point: typer builds the command-line interface from the
# signature of ``main`` and invokes it with the parsed arguments.
if __name__ == "__main__":
    typer.run(main)