-
Notifications
You must be signed in to change notification settings - Fork 0
/
run.py
59 lines (46 loc) · 1.96 KB
/
run.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
"""
This file is used to run the model.
"""
import argparse
import logging
import os
import time

import torch

from model import BigramLanguageModel
from train import config
# Setting up logging
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
def parse_arguments(argv=None):
    """Parse command-line arguments for text generation.

    Args:
        argv: Optional list of argument strings to parse; defaults to
            ``sys.argv[1:]`` when ``None`` (argparse's own default), which
            preserves the original CLI behavior while making the function
            testable in isolation.

    Returns:
        argparse.Namespace with a ``prompt`` attribute (str, default ``""``).
    """
    parser = argparse.ArgumentParser(description="Generate text using the Bigram Language Model.")
    parser.add_argument("--prompt", type=str, default="", help="Initial text to start generation")
    return parser.parse_args(argv)
if __name__ == "__main__":
    args = parse_arguments()
    start_time = time.time()
    logging.info("Loading model...")
    logging.info("Config: %s", config)

    # Rebuild the character vocabulary from the training text.
    # NOTE(review): this assumes data/data.txt is byte-identical to the file
    # used at training time — otherwise indices won't match the checkpoint.
    with open("data/data.txt", "r", encoding="utf-8") as f:
        text = f.read()
    chars = sorted(set(text))

    # Character-level encoder/decoder lookup tables.
    stoi = {ch: i for i, ch in enumerate(chars)}
    itos = {i: ch for i, ch in enumerate(chars)}
    encode = lambda s: [stoi[c] for c in s]  # str -> list[int]
    decode = lambda l: "".join(itos[i] for i in l)  # list[int] -> str

    # Load trained weights onto the device named in the training config.
    state_dict = torch.load("models/model_9999.pt", map_location=torch.device(config.device))
    model = BigramLanguageModel(config=config)
    model.load_state_dict(state_dict)
    model.eval()  # inference mode: disable dropout/batchnorm training behavior
    logging.info("Model loaded in %.2fs", time.time() - start_time)

    # Generate from the prompt; fall back to a single space so the context
    # tensor is never empty.
    start_generation_time = time.time()
    initial_context = args.prompt if args.prompt else " "
    if len(initial_context) > 256:
        raise ValueError("Initial context must be less than 256 characters.")
    # Reject prompt characters outside the training vocabulary up front —
    # otherwise `encode` raises a cryptic KeyError.
    unknown = set(initial_context) - set(stoi)
    if unknown:
        raise ValueError(f"Prompt contains characters not in the vocabulary: {sorted(unknown)!r}")
    context = torch.tensor([encode(initial_context)], dtype=torch.long)
    with torch.no_grad():  # generation needs no autograd bookkeeping
        generated_text = decode(model.generate(context, max_new_tokens=2000)[0].tolist())

    # Save under generations/ with a filesystem-safe name: the raw generated
    # text may contain newlines, slashes, or other illegal filename characters.
    os.makedirs("generations", exist_ok=True)
    slug = "".join(c for c in generated_text[:6] if c.isalnum()) or "sample"
    with open(f"generations/{slug}.txt", "w", encoding="utf-8") as f:
        f.write(generated_text)
    logging.info(generated_text)
    logging.info("Generation completed in %.2fs", time.time() - start_generation_time)