File size: 2,045 Bytes
fcbe832
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the GNU General Public License version 3.

from typing import Tuple
import os
import sys
import torch
import fire
import time
import json
from pathlib import Path
from llama import ModelArgs, Transformer, Tokenizer, LLaMA


def load(
        ckpt_dir: str,
        tokenizer_path: str,
        max_seq_len: int,
        max_batch_size: int,
) -> LLaMA:
    """Build a ``LLaMA`` generator from a checkpoint directory and a tokenizer.

    Reads ``params.json`` from ``ckpt_dir`` for the model hyper-parameters,
    constructs the tokenizer and the ``Transformer``, then loads weights from
    the last ``*.pth`` file in ``ckpt_dir`` (in sorted path order).

    Args:
        ckpt_dir: Directory containing ``params.json`` and at least one
            ``*.pth`` checkpoint file.
        tokenizer_path: Path handed to ``Tokenizer`` as ``model_path``.
        max_seq_len: Forwarded to ``ModelArgs``.
        max_batch_size: Forwarded to ``ModelArgs``.

    Returns:
        A ``LLaMA`` instance wrapping the loaded model and the tokenizer.

    Raises:
        FileNotFoundError: If ``ckpt_dir`` contains no ``*.pth`` file, or if
            ``params.json`` is missing.
    """
    print("Creating model...")
    start_time = time.time()

    ckpt_root = Path(ckpt_dir)
    checkpoints = sorted(ckpt_root.glob("*.pth"))
    # Fail fast with a clear message instead of an IndexError on
    # ``checkpoints[-1]`` after the (expensive) model has been constructed.
    if not checkpoints:
        raise FileNotFoundError(
            f"No '*.pth' checkpoint file found in '{ckpt_root}'"
        )

    with open(ckpt_root / "params.json", "r") as f:
        params = json.load(f)

    # Values from params.json are merged with the runtime limits; a key in
    # params.json that duplicates one of the explicit keywords would raise
    # TypeError here.
    model_args: ModelArgs = ModelArgs(
        max_seq_len=max_seq_len, max_batch_size=max_batch_size, **params
    )

    # The vocabulary size is taken from the tokenizer rather than params.json.
    tokenizer = Tokenizer(model_path=tokenizer_path)
    model_args.vocab_size = tokenizer.n_words

    model = Transformer(model_args)
    model.to("cpu")

    print("Loading merged checkpoint...")
    # Only the last checkpoint (sorted order) is used.
    # NOTE(review): the checkpoint tensors are mapped to CUDA while the model
    # was just moved to CPU; load_state_dict copies values into the existing
    # parameters, so the model presumably stays on CPU - confirm this is the
    # intended device placement.
    checkpoint = torch.load(checkpoints[-1], map_location="cuda")
    # strict=False: key mismatches between checkpoint and model are tolerated.
    model.load_state_dict(checkpoint, strict=False)
    # Drop the reference so the loaded state dict can be freed.
    del checkpoint

    generator = LLaMA(model, tokenizer)
    print(f"Loaded model in {time.time() - start_time:.2f} seconds")
    return generator


def main(
        ckpt_dir: str = './model',
        tokenizer_path: str = './tokenizer/tokenizer.model',
        temperature: float = 0.8,
        top_p: float = 0.95,
        max_seq_len: int = 256,  # up to 2048
        max_batch_size: int = 5,
        max_gen_len: int = 256,
) -> None:
    """Run an interactive prompt loop against a loaded model.

    Sets the default torch dtype to bfloat16, loads the generator via
    ``load`` and then repeatedly reads a prompt from stdin, generates a
    completion for it and prints every returned result. Blank prompts are
    skipped. The loop ends cleanly when stdin is exhausted (EOF) or when the
    user presses Ctrl-C at the prompt.

    Args:
        ckpt_dir: Checkpoint directory passed to ``load``.
        tokenizer_path: Tokenizer model path passed to ``load``.
        temperature: Forwarded to ``generator.generate``.
        top_p: Forwarded to ``generator.generate``.
        max_seq_len: Passed to ``load``.
        max_batch_size: Passed to ``load``.
        max_gen_len: Forwarded to ``generator.generate``; defaults to the
            previously hard-coded value of 256.
    """
    # torch.manual_seed(1)  # uncomment for reproducible sampling
    torch.set_default_dtype(torch.bfloat16)

    generator = load(ckpt_dir, tokenizer_path, max_seq_len, max_batch_size)

    while True:
        try:
            prompt = input('prompt> ')
        except (EOFError, KeyboardInterrupt):
            # Closed/piped stdin or Ctrl-D / Ctrl-C at the prompt: leave the
            # loop instead of dying with a traceback.
            print()
            break

        if not prompt.strip():
            continue

        results = generator.generate(
            [prompt], max_gen_len=max_gen_len, temperature=temperature, top_p=top_p
        )

        for result in results:
            print(result)


# Script entry point: python-fire exposes main()'s keyword arguments as
# command-line flags (e.g. --ckpt_dir, --temperature).
if __name__ == "__main__":
    fire.Fire(main)