-
Notifications
You must be signed in to change notification settings - Fork 5
/
chat.py
98 lines (81 loc) · 3.87 KB
/
chat.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
# -*- coding: utf-8 -*-
"""
Created on Sun May 12 16:34:49 2019
@author: Danish
"""
"""
Script for chatting with a trained chatbot model
"""
import datetime
from os import path
import general_utils
import chat_command_handler
from chat_settings import ChatSettings
from chatbot_model import ChatbotModel
from vocabulary import Vocabulary
# Reference checkpoint path left by the original author (not read at runtime):
# 'models/cornell_movie_dialog/trained_model_v2/best_weights_training.ckpt'

# Read the hyperparameters and resolve the model directory and checkpoint
# for a "chat" session.
_, model_dir, hparams, checkpoint, _, _ = general_utils.initialize_session("chat")

# Load the vocabulary: either one vocabulary shared by input and output,
# or a separate vocabulary for each side.
print()
print("Loading vocabulary...")
if not hparams.model_hparams.share_embedding:
    input_vocab_filepath = path.join(model_dir, Vocabulary.INPUT_VOCAB_FILENAME)
    output_vocab_filepath = path.join(model_dir, Vocabulary.OUTPUT_VOCAB_FILENAME)
    input_vocabulary = Vocabulary.load(input_vocab_filepath)
    output_vocabulary = Vocabulary.load(output_vocab_filepath)
else:
    shared_vocab_filepath = path.join(model_dir, Vocabulary.SHARED_VOCAB_FILENAME)
    # Both names refer to the very same Vocabulary object.
    input_vocabulary = output_vocabulary = Vocabulary.load(shared_vocab_filepath)

# Set up the chat: a timestamped chat-log file path under
# <model_dir>/chat_logs (only the path is built here) and the chat settings
# derived from the model and inference hyperparameters.
chatlog_filepath = path.join(
    model_dir,
    "chat_logs",
    "chatlog_{0}.txt".format(datetime.datetime.now().strftime("%Y%m%d_%H%M%S")),
)
chat_settings = ChatSettings(hparams.model_hparams, hparams.inference_hparams)
def chat_fun_english(question, model):
    """Answer one *question* with a freshly built chatbot model.

    Builds a ``ChatbotModel`` in ``"infer"`` mode from the module-level
    ``chat_settings``, vocabularies and ``model_dir``. It then loads the
    weights from ``checkpoint``, prints the available chat commands, and
    handles *question* exactly once.

    Args:
        question: User input. It is first passed to
            ``chat_command_handler.handle_command``. If that call does not
            treat it as a command, it is sent to the model via ``chat``.
        model: Ignored. Kept only so existing callers keep working; a new
            model is always built inside this function.

    Returns:
        The answer produced by the model's ``chat`` method for a regular
        question. Returns ``None`` when *question* was handled as a command,
        including the terminate and reload commands. A new model is built on
        every call, so any change a command makes to ``chat_settings`` is
        picked up on the next call.
    """
    print()
    print("Initializing model...")
    print()
    # Named ``chatbot`` so it no longer shadows the ``model`` parameter.
    with ChatbotModel(mode="infer",
                      model_hparams=chat_settings.model_hparams,
                      input_vocabulary=input_vocabulary,
                      output_vocabulary=output_vocabulary,
                      model_dir=model_dir) as chatbot:
        # Load the weights.
        print()
        print("Loading model weights...")
        print()
        chatbot.load(checkpoint)
        # Show the commands.
        chat_command_handler.print_commands()

        # *question* cannot change inside a single call, so it is handled
        # exactly once. Looping on an unchanging input would repeat the same
        # command forever, or rebuild the model forever on a reload command.
        is_command, terminate_chat, reload_model = chat_command_handler.handle_command(
            question, chatbot, chat_settings)
        if terminate_chat or reload_model or is_command:
            return None

        # Not a command (it is a question): ask the model for the answer.
        question_with_history, answer = chatbot.chat(question, chat_settings)

        # Print the answer (or every answer beam) and log to the chat log.
        if chat_settings.show_question_context:
            print("Question with history (context): {0}".format(question_with_history))
        if chat_settings.show_all_beams:
            for i, beam in enumerate(answer):
                print("ChatBot (Beam {0}): {1}".format(i, beam))
        else:
            print("ChatBot: {0}".format(answer))
        print()
        if chat_settings.inference_hparams.log_chat:
            chat_command_handler.append_to_chatlog(chatlog_filepath, question, answer)
        return answer
#chat_fun()