PersonaGPT 是一款开放域对话代理,旨在完成两项任务:
它基于 DialoGPT-medium 预训练模型构建,而该模型又基于 GPT-2 架构。 此模型在 Persona-Chat 数据集上进行训练,并添加了特殊标记,以便更好地区分对话历史和双人对话中的个性特征。此外,还采用了一些主动学习方法来训练模型,使其能够使用轮次级目标进行受控解码。
预处理、训练和实现细节可在 personaGPT 代码库 中找到。
from transformers import GPT2Tokenizer, GPT2LMHeadModel
import torch
tokenizer = GPT2Tokenizer.from_pretrained("af1tang/personaGPT")
model = GPT2LMHeadModel.from_pretrained("af1tang/personaGPT")
if torch.cuda.is_available():
model = model.cuda()
## utility functions ##
flatten = lambda l: [item for sublist in l for item in sublist]
def to_data(x):
if torch.cuda.is_available():
x = x.cpu()
return x.data.numpy()
def to_var(x):
if not torch.is_tensor(x):
x = torch.Tensor(x)
if torch.cuda.is_available():
x = x.cuda()
return x
def display_dialog_history(dialog_hx):
for j, line in enumerate(dialog_hx):
msg = tokenizer.decode(line)
if j %2 == 0:
print(">> User: "+ msg)
else:
print("Bot: "+msg)
print()
def generate_next(bot_input_ids, do_sample=True, top_k=10, top_p=.92,
max_length=1000, pad_token=tokenizer.eos_token_id):
full_msg = model.generate(bot_input_ids, do_sample=True,
top_k=top_k, top_p=top_p,
max_length=max_length, pad_token_id=tokenizer.eos_token_id)
msg = to_data(full_msg.detach()[0])[bot_input_ids.shape[-1]:]
return msg# get personality facts for conversation
personas = []
for i in range(3):
response = input(">> Fact %d: "%(i+1))+ tokenizer.eos_token
personas.append(response)
personas = tokenizer.encode(''.join(['<|p2|>'] + personas + ['<|sep|>'] + ['<|start|>']))# converse for 8 turns
dialog_hx = []
for step in range(8):
# encode the user input
user_inp = tokenizer.encode(input(">> User: ") + tokenizer.eos_token)
# append to the chat history
dialog_hx.append(user_inp)
# generated a response while limiting the total chat history to 1000 tokens,
bot_input_ids = to_var([personas + flatten(dialog_hx)]).long()
msg = generate_next(bot_input_ids)
dialog_hx.append(msg)
print("Bot: {}".format(tokenizer.decode(msg, skip_special_tokens=True)))个性化解码示例:
| 人物设定事实 | |
|---|---|
| 1 | 我64岁了,但感觉自己还很年轻。 |
| 2 | 我努力吃得健康,但每周会吃一次麦当劳。 |
| 3 | 我后悔过去20年一直当医生。 |
| 4 | 我的秘密爱好是制作自助类YouTube视频。 |
| 5 | 我去过西班牙很多次,希望能再去一次。 |
| 角色 | 回应 |
|---|---|
| 人类 | 哟,最近咋样啊,伙计? |
| 机器人 | 你好,最近过得怎么样? |
| 人类 | 还行,你今天在干嘛呢? |
| 机器人 | 到了60多岁,现在吃得健康多了。 |
| 人类 | 我更喜欢快餐。 |
| 机器人 | 我也是,但为了健康,我会尽量避免。 |
| 人类 | 你平时有什么娱乐活动? |
| 机器人 | 我喜欢制作视频,也很注重健康饮食。 |
| 人类 | 不错啊,是什么类型的视频? |
| 机器人 | 我喜欢制作关于自助内容的视频。 |
| 人类 | 我觉得我需要一些帮助。 |
| 机器人 | 你需要哪方面的帮助呢? |
| 人类 | 我21岁了,感觉人生很迷茫。 |
| 机器人 | 我能理解,我才65岁呢。 |
## available actions ##
action_space = [ 'ask about kids.', "ask about pets.", 'talk about work.',
'ask about marital status.', 'talk about travel.', 'ask about age and gender.',
'ask about hobbies.', 'ask about favorite food.', 'talk about movies.',
'talk about music.', 'talk about politics.']
# converse for 8 turns
dialog_hx = []
for step in range(8):
# choose an action
act = None
while act not in action_space:
display_dialog_history(dialog_hx)
print()
print(" actions: ")
for k,v in enumerate(action_space): print(k,v)
try:
act = action_space[int(input(" input [0-10]: " ))]
except:
act = None
print()
# format into prefix code
action_prefix = tokenizer.encode(''.join(['<|act|> '] + [act] + ['<|p1|>'] + [] + ['<|sep|>'] + ['<|start|>']))
bot_input_ids = to_var([action_prefix + flatten(dialog_hx)]).long()
# generate query conditioned on action
msg = generate_next(bot_input_ids, top_k=top_k, top_p=top_p, max_length=max_length)
dialog_hx.append(msg)
# generate bot response
bot_input_ids = to_var([personas+ flatten(dialog_hx)]).long()
msg = generate_next(bot_input_ids, top_k=top_k, top_p=top_p, max_length=max_length)
dialog_hx.append(msg)
display_dialog_history(dialog_hx)受控响应生成示例:
| Bot 2 的角色设定事实 | |
|---|---|
| 1 | 我是一名安静的工程师。 |
| 2 | 我单身,正在寻找爱情。 |
| 3 | 遗憾的是,我没有什么能引起共鸣的爱好。 |
| 4 | 不过幸运的是,我身材高大且热爱运动。 |
| 5 | 周五晚上,我独自看《辛普森一家》的重播。 |
| 角色 | 动作 | 响应 |
|---|---|---|
| Bot 1 | 谈论电影 | 你好,你喜欢电影吗? |
| Bot 2 | - | 我喜欢《辛普森一家》。你呢? |
| Bot 1 | 谈论音乐 | 我喜欢《行尸走肉》。你会演奏乐器吗? |
| Bot 2 | - | 是的,我会拉小提琴。 |
| Bot 1 | 询问婚姻状况 | 你已婚还是单身? |
| Bot 2 | - | 我单身,但正在寻找爱情。 |
| Bot 1 | 谈论工作 | 我的工作是在工厂上班。你的职业是什么? |
| Bot 2 | - | 工程师。我很安静,所以没人注意到我。 |
| Bot 1 | 谈论爱好 | 你有什么爱好吗? |
| Bot 2 | - | 我看《辛普森一家》的重播。 |
| Bot 1 | 询问喜欢的食物 | 你最喜欢的食物是什么? |
| Bot 2 | - | 我喜欢披萨。你呢? |
| Bot 1 | 询问宠物 | 我也喜欢披萨。你喜欢动物吗? |
| Bot 2 | - | 我有两只狗。你的职业是什么? |
| Bot 1 | 谈论工作 | 我是一名工厂工人。你的梦想工作是什么? |
| Bot 2 | - | 我希望有一天能成为一名作家。 |