| | |
'''
@File    :   inference_cogview.py
@Time    :   2021/10/09 19:41:58
@Author  :   Ming Ding
@Contact :   [email protected]
'''
| |
|
| | |
| | import os |
| | import sys |
| | import math |
| | import random |
| | import torch |
| | import argparse |
| | import stat |
| |
|
| | from SwissArmyTransformer import mpu, get_args, get_tokenizer |
| | from SwissArmyTransformer.model import CachedAutoregressiveModel |
| | from SwissArmyTransformer.generation.sampling_strategies import BaseStrategy |
| | from SwissArmyTransformer.generation.autoregressive_sampling import filling_sequence |
| | from SwissArmyTransformer.generation.utils import timed_name, generate_continually |
| | from SwissArmyTransformer.training import set_random_seed |
| |
|
| | import json |
| |
|
def main(args):
    """Run batch text generation with a cached autoregressive model.

    Loads a pretrained ``CachedAutoregressiveModel``, then reads prompts via
    ``generate_continually``; each prompt is a JSON record with a ``question``
    field (optionally prefixed with ``<query_id>\\t`` when ``--with-id`` is
    set).  The first decoded continuation of every prompt is printed and
    appended to ``<output_path>/test_eval.txt``.

    Args:
        args: parsed argument namespace (device, fp16, seed, sampling and
            batching options, input/output paths).
    """
    # 2022/06/17: load via from_pretrained instead of load_checkpoint.
    model_path = '/path/to/checkpoints/'
    model, args = CachedAutoregressiveModel.from_pretrained(args, model_path)

    if args.fp16:
        model = model.half()
    model = model.to(args.device)
    set_random_seed(args.seed)
    model.eval()

    tokenizer = get_tokenizer(args)

    end_tokens = [tokenizer.get_command('eos').Id]
    strategy = BaseStrategy(temperature=args.temperature, top_k=args.top_k, end_tokens=end_tokens)

    def process(raw_text):
        # One input line -> one generation round, appended to test_eval.txt.
        if args.with_id:
            query_id, raw_text = raw_text.split('\t')
        raw_text = json.loads(raw_text)
        question = raw_text["question"] + "答:"
        raw_text = question
        seq = tokenizer._encode(raw_text)
        # Token id 20005 is presumably a spurious leading token emitted by
        # this tokenizer — TODO confirm against the tokenizer vocabulary.
        if len(seq) != 0 and seq[0] == 20005:
            seq = seq[1:]
        seq = [tokenizer.get_command('ENC').Id] + seq
        # FIX: validate length BEFORE padding.  The original padded first,
        # where a negative repeat count silently adds nothing; checking first
        # makes the error reflect the prompt itself.
        if len(seq) > args.max_sequence_length:
            raise ValueError('text too long.')
        seq += [-1] * (args.max_sequence_length - len(seq))  # -1 marks slots to fill

        # FIX: torch.cuda.LongTensor(..., device=...) is deprecated and fails
        # for non-CUDA devices; build the tensor portably instead.
        seq = torch.tensor(seq, dtype=torch.long, device=args.device)
        mbz = args.max_inference_batch_size
        # FIX: was an assert (stripped under -O); raise explicitly.
        if not (args.batch_size < mbz or args.batch_size % mbz == 0):
            raise ValueError(
                'batch_size must be smaller than max_inference_batch_size '
                'or a multiple of it.')
        output_list = []
        for _ in range(max(args.batch_size // mbz, 1)):
            output = filling_sequence(model, seq.clone(),
                                      batch_size=min(args.batch_size, mbz),
                                      strategy=strategy,
                                      log_attention_weights=None
                                      )[0]
            if isinstance(output, torch.Tensor):
                output = list(output)
            output_list.extend(output)

        for i in range(len(output_list)):
            output = output_list[i].tolist()
            # Everything from the first remaining -1 onwards is padding.
            try:
                unfinished = output.index(-1)
            except ValueError:
                unfinished = len(output)
            if output[unfinished - 1] in end_tokens:
                unfinished -= 1
            # Drop the leading ENC token; if an eos separator is present,
            # splice it out.  FIX: the original unconditionally called
            # output.index(eos) and crashed with ValueError whenever no eos
            # token was generated; fall back to the plain slice instead.
            try:
                bog = output.index(tokenizer.get_command('eos').Id)
            except ValueError:
                output_list[i] = output[1:unfinished]
            else:
                output_list[i] = output[1:bog] + output[bog + 1:unfinished]

        txts = [tokenizer.DecodeIds(ids) for ids in output_list]

        # NOTE(review): full_path is computed but never used below — output
        # actually goes to test_eval.txt.  Kept for parity with the original.
        if args.with_id:
            full_path = os.path.join(args.output_path, query_id + '.txt')
        else:
            prefix = raw_text.replace('/', '')[:20]
            full_path = timed_name(prefix, '.txt', args.output_path)
        print(txts[0])
        test_eval_path = os.path.join(args.output_path, 'test_eval.txt')
        with open(test_eval_path, 'a', encoding='utf-8') as fout:
            fout.write(txts[0] + '\n')
        # FIX: combine permission bits with | (bitwise or), not +.
        os.chmod(test_eval_path, stat.S_IRWXO | stat.S_IRWXG | stat.S_IRWXU)

    os.makedirs(args.output_path, exist_ok=True)
    generate_continually(process, args.input_source)
| |
|
| |
|
if __name__ == "__main__":
    # Collect any script-local flags first (none are declared today), then
    # hand everything left over to SwissArmyTransformer's own parser.
    local_parser = argparse.ArgumentParser(add_help=False)

    local_known, remaining = local_parser.parse_known_args()
    args = get_args(remaining)
    # Merge the two namespaces; inference only, so training is disabled.
    args = argparse.Namespace(**vars(args), **vars(local_known))
    args.do_train = False

    # Generation never needs gradients.
    with torch.no_grad():
        main(args)
| |
|
| |
|