| import re |
| from typing import List, Dict |
| from tqdm import tqdm |
|
|
| |
# Source parquet file and destination for the converted copy.
input_path = "/home/data/raw/test/1159-L6.parquet"
output_path = "/home/data/raw/test/1159-L6_format.parquet"

# A literal "<|eot_id|>" token followed only by whitespace at the end of a string.
EOT_TAIL = re.compile(r"<\|eot_id\|>\s*$")
# A trailing "|eot_id|" marker whose angle brackets may be missing.
TAIL_TAGS = re.compile(r"[<]?\|eot_id\|[>]?\s*$")
# A "|tag|..." fragment on the last line of the string (through to the end),
# e.g. "<|im_end|>". The optional leading "<" is consumed as well; without it a
# tag such as "<|im_end|>" was stripped down to a dangling "<".
PIPE_TRAIL = re.compile(r"(?:<?\|[A-Za-z0-9_]+\|[^\n]*)\s*$")
def is_mistral_format(text):
    """Return True when *text* already carries ChatML turn markers.

    Both ``<|im_start|>`` and ``<|im_end|>`` must occur somewhere in the text;
    their order and count are not checked.
    """
    required_markers = ("<|im_start|>", "<|im_end|>")
    return all(marker in text for marker in required_markers)
|
|
def _chatml_turn(role: str, content: str) -> str:
    """Render one closed ChatML turn: header line, content, then ``<|im_end|>``."""
    return f"<|im_start|>{role}\n{content}\n<|im_end|>\n"


def _open_assistant_turn(content: str) -> str:
    """Render a final assistant turn with no ``<|im_end|>``, ending in a newline."""
    if not content.endswith("\n"):
        content += "\n"
    return f"<|im_start|>assistant\n{content}"


def _normalize_chatml(text: str, add_generation_prompt: bool) -> str:
    """Tidy text that is already ChatML.

    Whitespace before each ``<|im_end|>`` becomes a single newline, and runs of
    three or more newlines collapse to two. By default the final ``<|im_end|>``
    (plus surrounding whitespace) is removed, leaving the last turn open. With
    ``add_generation_prompt`` the last turn stays closed and an empty
    ``<|im_start|>assistant`` header is appended, unless the text already ends
    with one. That matches how the other input formats are finalised.
    """
    text = re.sub(r"\s*<\|im_end\|>", r"\n<|im_end|>", text)
    text = re.sub(r"\n{3,}", "\n\n", text)
    if add_generation_prompt:
        text = text.rstrip()
        if not text.endswith("<|im_start|>assistant"):
            text += "\n<|im_start|>assistant"
        return text
    return re.sub(r"\s*<\|im_end\|>\s*$", "", text).rstrip()


def _llama3_to_chatml(text: str) -> str:
    """Convert Llama-3 ``<|start_header_id|>role<|end_header_id|>`` turns to ChatML.

    Only ``system``, ``user`` and ``assistant`` roles are emitted; any other
    header role is skipped. A body holding several ``<|eot_id|>``-separated
    parts yields one turn per non-empty part. If the final pair is an
    assistant turn, it is left open.
    """
    # With one capture group, re.split alternates [preamble, role, body, role, body, ...].
    segments = re.split(r"<\|start_header_id\|>(.*?)<\|end_header_id\|>", text, flags=re.S)
    pairs = []
    for i in range(1, len(segments), 2):
        role = segments[i].strip()
        for part in re.split(r"<\|eot_id\|>", segments[i + 1].strip()):
            part = part.strip()
            if part:
                pairs.append((role, part))

    output = ""
    last = len(pairs) - 1
    for idx, (role, content) in enumerate(pairs):
        if role in ("system", "user"):
            output += _chatml_turn(role, content)
        elif role == "assistant":
            output += _open_assistant_turn(content) if idx == last else _chatml_turn(role, content)
    return output


def _inst_to_chatml(text: str) -> str:
    """Convert ``[INST] user [/INST] reply`` turns (plus optional ``[SYSTEM_PROMPT]``) to ChatML.

    A turn whose reply is empty ends with a bare ``<|im_start|>assistant`` header.
    """
    output = ""
    system_match = re.search(r"\[SYSTEM_PROMPT\](.*?)\[/SYSTEM_PROMPT\]", text, re.S)
    if system_match:
        output += _chatml_turn("system", system_match.group(1).strip())

    # A reply runs until the next [INST], a </s>, or the end of the text.
    turns = re.findall(r"\[INST\](.*?)\[/INST\](.*?)(?=(\[INST\]|</s>|$))", text, re.S)
    for user_msg, assistant_msg, _ in turns:
        output += _chatml_turn("user", user_msg.strip())
        reply = assistant_msg.strip()
        output += _chatml_turn("assistant", reply) if reply else "<|im_start|>assistant\n"
    return output


def _gemma_to_chatml(text: str) -> str:
    """Convert Gemma ``<start_of_turn>user|model ... <end_of_turn>`` turns to ChatML.

    The first ``[System: ...]`` span, if present, becomes the system turn.
    ``model`` maps to ``assistant``. If the last captured turn is an assistant
    turn, it is left open with its content not stripped.
    """
    output = ""
    system_match = re.search(r"\[System:(.*?)\]", text, re.S)
    if system_match:
        output += _chatml_turn("system", system_match.group(1).strip())

    turns = re.findall(r"<start_of_turn>(user|model)\s*\n?(.*?)<end_of_turn>", text, re.S)
    last = len(turns) - 1
    for idx, (speaker, content) in enumerate(turns):
        role = "assistant" if speaker == "model" else "user"
        if role == "assistant" and idx == last:
            output += _open_assistant_turn(content)
        else:
            output += _chatml_turn(role, content.strip())
    return output


def _dialogue_to_chatml(text: str) -> str:
    """Convert ``Name: message`` dialogue lines (optionally after ``persona<start>``) to ChatML.

    Everything before the first ``<start>`` becomes the system turn. Only the
    text after the last ``<start>`` is parsed as dialogue, and lines without a
    ``:`` are ignored. Lines beginning with ``You``/``User``/``你`` plus a colon
    are user turns. Every other line is an assistant turn whose speaker name
    (up to the first ``:``) is dropped. A final assistant line is left open.
    """
    output = ""
    persona_match = re.search(r"(.*?)<start>", text, re.S)
    if persona_match:
        output += _chatml_turn("system", persona_match.group(1).strip())

    dialogue = text.split("<start>")[-1]
    lines = [l.strip() for l in dialogue.strip().split("\n") if ":" in l]
    last = len(lines) - 1
    for idx, line in enumerate(lines):
        if re.match(r"^(You|User|你)[::]", line):
            content = re.sub(r"^(You|User|你)[::]", "", line).strip()
            output += _chatml_turn("user", content)
        else:
            _, content = line.split(":", 1)
            content = content.strip()
            output += _open_assistant_turn(content) if idx == last else _chatml_turn("assistant", content)
    return output


def _finalize_chatml(output: str, add_generation_prompt: bool) -> str:
    """Apply the shared clean-up to output produced from a non-ChatML input."""
    output = output.strip()
    # An assistant header immediately followed by <|im_end|> is reopened.
    if output.endswith("<|im_start|>assistant\n<|im_end|>"):
        output = output[: -len("<|im_end|>")].rstrip()
    # A final closed assistant turn that holds only a short "Name:" prefix is
    # reopened (presumably so generation continues after the speaker name).
    if re.search(r"<\|im_start\|>assistant\n([^\n<\|]{1,100}):\s*<\|im_end\|>$", output):
        output = re.sub(r"<\|im_end\|>$", "", output).rstrip()
    if add_generation_prompt and not output.endswith("<|im_start|>assistant"):
        output += "\n<|im_start|>assistant"
    return output.strip()


def convert_to_mistral_format(text: str, add_generation_prompt: bool = False) -> "str | None":
    """Convert a chat transcript into the ChatML ``<|im_start|>`` / ``<|im_end|>`` layout.

    Input formats are detected in this order:

    * text that already contains ``<|im_start|>`` and ``<|im_end|>`` (tidied only);
    * Llama-3 header tokens (``<|start_header_id|>`` / ``<|end_header_id|>``);
    * ``[INST] ... [/INST]`` turns with an optional ``[SYSTEM_PROMPT]`` block;
    * Gemma ``<start_of_turn>`` turns with an optional ``[System: ...]`` span;
    * ``Name: message`` dialogue lines, optionally preceded by ``persona<start>``.
      Note that this last check accepts any text with a line containing a colon.

    Args:
        text: The raw transcript.
        add_generation_prompt: If true, ensure the result ends with an empty
            ``<|im_start|>assistant`` header. This applies to every input
            format, including text that is already ChatML.

    Returns:
        The converted transcript, or ``None`` when no supported format is detected.
    """
    if "<|im_start|>" in text and "<|im_end|>" in text:
        return _normalize_chatml(text, add_generation_prompt)

    if "<|start_header_id|>" in text and "<|end_header_id|>" in text:
        output = _llama3_to_chatml(text)
    elif "[INST]" in text and "[/INST]" in text:
        output = _inst_to_chatml(text)
    elif "<start_of_turn>" in text:
        output = _gemma_to_chatml(text)
    elif "<start>" in text or re.search(r"(?m)^You[::]|^.*?[::].*?$", text):
        output = _dialogue_to_chatml(text)
    else:
        return None
    return _finalize_chatml(output, add_generation_prompt)
|
|
| |
def standardize_dataset_to_mistral_format(dataset: List[Dict]) -> List[Dict]:
    """Convert the ``text`` field of every sample to ChatML, keeping only successes.

    Args:
        dataset: Iterable of sample dicts; only the ``"text"`` key is read.

    Returns:
        A new list of ``{"text": converted}`` dicts. A sample is dropped when its
        ``text`` is not a string (e.g. a null cell), when the converter returns
        ``None`` (format not recognised), or when the result is an empty string.
    """
    converted = []
    for sample in tqdm(dataset):
        text = sample.get("text", "")
        # The converter relies on substring tests and raises TypeError on None,
        # so skip non-string values instead of aborting the whole run.
        if not isinstance(text, str):
            continue
        new_text = convert_to_mistral_format(text)
        if new_text:
            converted.append({"text": new_text})
    return converted
|
|
| from datasets import load_dataset |
| from tqdm import tqdm |
| import re |
|
|
|
|
def clean_chosen_tail(text: str) -> str:
    """Strip trailing control markers and trailing whitespace from a response.

    Two substitutions run in order. ``TAIL_TAGS`` removes a trailing
    ``|eot_id|`` marker (angle brackets optional). ``PIPE_TRAIL`` then removes a
    trailing ``|tag|...`` fragment on the last line. Trailing whitespace is
    trimmed last. Non-string inputs (e.g. ``None``) are returned unchanged.
    """
    if isinstance(text, str):
        without_eot = TAIL_TAGS.sub("", text)
        without_pipe_tag = PIPE_TRAIL.sub("", without_eot)
        return without_pipe_tag.rstrip()
    return text
|
|
def apply_format_conversion(example):
    """Clean the response tails of one preference sample and convert its prompt.

    Trailing end-of-turn markers are stripped from ``chosen`` and ``reject``.
    ``chosen_prompt`` is then rewritten with ``convert_to_mistral_format``.
    ``example`` is mutated in place.

    Returns:
        The mutated ``example``, or ``None`` when the prompt template is not
        recognised. NOTE(review): ``datasets.Dataset.map`` does not drop rows
        whose mapper returns ``None``; such samples must be filtered out
        before mapping.
    """
    # First drop a trailing literal <|eot_id|> (plus whitespace) from both responses.
    for field in ("chosen", "reject"):
        if isinstance(example[field], str):
            example[field] = EOT_TAIL.sub("", example[field])

    # Then remove any remaining tail markers. clean_chosen_tail returns
    # non-string values untouched, so reject is passed through unguarded.
    chosen = example.get("chosen")
    if isinstance(chosen, str):
        example["chosen"] = clean_chosen_tail(chosen)
    example["reject"] = clean_chosen_tail(example["reject"])

    converted = convert_to_mistral_format(example["chosen_prompt"], add_generation_prompt=False)
    if converted is None:
        return None
    example["chosen_prompt"] = converted
    return example
|
|
| |
dataset = load_dataset("parquet", data_files=input_path, split="train")

# Dataset.map cannot delete rows. Per the `datasets` docs, a mapper returns a
# dict to update the row, or None when it is run only for side effects.
# Dataset.filter only ever receives dict rows, so the former
# `filter(lambda example: example is not None)` never removed anything.
# Drop samples whose prompt template cannot be converted *before* mapping, so
# apply_format_conversion never reaches its None branch.
dataset = dataset.filter(
    lambda example: convert_to_mistral_format(example["chosen_prompt"], add_generation_prompt=False) is not None
)
converted_dataset = dataset.map(apply_format_conversion)

converted_dataset.to_pandas().to_parquet(output_path, index=False)

print(f"✅ 处理完成,共保留样本 {len(converted_dataset)} 条,已保存至:{output_path}")
|
|
| import random |
|
|
| |
dataset1 = load_dataset("parquet", data_files=output_path, split="train")

# Print a random handful of converted rows for manual inspection.
# random.sample raises ValueError when asked for more items than exist,
# so cap the sample size at the number of rows that survived conversion.
num_samples = min(15, len(dataset1))
indices = random.sample(range(len(dataset1)), num_samples)
samples = dataset1.select(indices)

for idx, item in zip(indices, samples):
    print(f"\n=== Sample index {idx} ===")
    for key, value in item.items():
        print(f"[{key}]")
        print(value)
        print("-" * 60)
|
|
|
|