Paddle对话故事续写, 创造你的无限世界

  • 如果可以,你最想穿越到哪部电影,小说里?
  • 这次利用 GPT2 文本生成和 wechaty 对话展开故事续写,与 AI 互动共同创造剧情
  • 预先设置你想交互的世界场景,比如面对三体人的进攻,发挥你的脑洞输入你想到的任何动作或者剧情大纲,模型的续写保证让你赞(hei)不(ren)绝(wen)口(hao)

项目介绍

本地测试

github 中有具体的本机部署流程

本 notebook 中代码则是运行在命令行体验

灵感来源

  • 跑团:一个桌面角色扮演游戏,主持人描述剧情,其他人来推动剧情,每个人有自己的职业、技能等等。

  • 无限流: 穿越到各个电影,动漫,小说中与原剧情互动。参考作品:幻影英雄,魔方大厦,无限恐怖。

  • AI Dungeon : 这是世界上第一款真正意义上纯 AI生成的文字冒险游戏, 通过这款 AIDungeon 2 ,玩家将接触到一个完全自由的虚拟世界。整个游戏世界,都将围绕着玩家的输入展开与续写,可谓真正的 Everything is in my hands

    https://play.aidungeon.io/main/home

生成样例展示

  • 示例1:
主神阿巴:冰冷,抖动……醒来的瞬间,你猛的从地面跳了起来,惊慌的看向四周,脑海里的办公室环境和眼前的环境瞬间出现了混淆,几秒之后你从混淆里清醒过来。现在你来到一个城市之中,丧尸在街道上嘶吼,你意识到核弹将在一小时后降临。

我:环顾四周,看看有没有防身的武器

主神阿巴: 是否有可能存活下去。 

我:突然在旁边发现一把手枪,装满了子弹,这时,一只丧尸冲了出来

主神阿巴:毫不犹豫的向丧尸头部开枪,然后丧尸倒地,亡。

我:看看周围有没有能发动的汽车,尽快逃离这座城市

主神阿巴:丧尸围了过来,你不得以跑向旁边的建筑,此时你来到一个地下实验室,旁边冰冷的警报声在提示着你T病毒已经泄露,一管病毒试剂摆在你面前。
  • 示例2:
我: 最后我开始反攻三体世界

主神阿巴:借助科技,我终于变出了超英赶美,猩红的双眼中闪烁着人类坚定的眼神,超英基地将被我攻陷,下一秒,无数丧尸朝我冲来,记住用喷火器,嗯,我是革命战士,拿起手中的那根电锯,万物一体,指向丧尸大迁徙的方向......

我:最后赶超英美,达到世界和平

主神阿巴:达到世界霸权,走向威武雄狮的彩虹之路......(好像越来越不押韵了) 先更到这里,有时间再来补个结局<eod>改行IT 到底有没有前途? 我是一个毕业了一年的 程序员 ,苦逼的生活让我找份工作都是挑战。无处可逃的你跑到了“自然选择”号飞船的队伍之中,你意识到三体人的舰队即将抵达,章北海会劫持飞船逃离太阳系。

我: ???

B站链接

安装依赖

# 注意想体验 gpt2 必须安装这个版本
!pip install paddlehub==2.0.4 paddlenlp==2.0.0rc14 sentencepiece==0.1.92

创建Model并加载官方预训练模型

  • PaddleNLP 中内置了许多预训练语义模型参数,其中就包括中文 GPT2 模型
  • 通过 GPT2ForPretraining.from_pretrained('gpt2-base-cn') 就可以调用内置的预训练模型参数
  • 安装高版本 paddlenlp 则可以 GPTForPretraining.from_pretrained('gpt-cpm-small-cn-distill') 体验CPM等模型
import paddle
import paddle.nn as nn
import numpy as np
import paddlehub as hub
from paddlenlp.transformers import *
import queue

class Model():
    def __init__(self):
        # self.model = GPTForPretraining.from_pretrained('gpt-cpm-small-cn-distill')
        self.model = GPT2ForPretraining.from_pretrained('gpt2-base-cn')
        self.model.eval()
        # self.tokenizer = GPTChineseTokenizer.from_pretrained('gpt-cpm-small-cn-distill')
        self.tokenizer = GPT2ChineseTokenizer.from_pretrained('gpt2-base-cn')
        self.tokenizer.encode('.') # 初始化


    def top_k_top_p_filtering(self, logits, top_k=0, top_p=1.0, filter_value=-float('Inf')):
        top_k = min(top_k, logits.shape[-1])  # Safety check
        logits_np = logits.numpy()
        if top_k > 0:
            # Remove all tokens with a probability less than the last token of the top-k
            indices_to_remove = logits_np < np.sort(logits_np)[-top_k]
            logits_np[indices_to_remove] = filter_value

        if top_p < 1.0:
            sorted_logits = paddle.sort(logits, descending=True)
            sorted_indices = paddle.argsort(logits, descending=True).numpy()
            cumulative_probs = paddle.cumsum(paddle.nn.functional.softmax(
                sorted_logits, axis=-1), axis=-1).numpy()

            # Remove tokens with cumulative probability above the threshold
            sorted_indices_to_remove = cumulative_probs > top_p
            # Shift the indices to the right to keep also the first token above the threshold
            sorted_indices_to_remove[...,
                                        1:] = sorted_indices_to_remove[..., :-1]
            sorted_indices_to_remove[..., 0] = 0

            indices_to_remove = sorted_indices[sorted_indices_to_remove]
            logits_np[indices_to_remove] = filter_value

        return paddle.to_tensor(logits_np)

    def sample(self, text, max_len=32, end_word='。', repitition_penalty=1.0, temperature=1.0, top_p=0.9):
        with paddle.no_grad():
            # 终止标志
            if end_word is not None:
                stop_id = self.tokenizer.encode(end_word)
                if 'input_ids' in stop_id: # 判断使用 cpm 的情况
                    stop_id = stop_id['input_ids']
                stop_id = stop_id[-1]

            ids = self.tokenizer.encode(text)
            if 'input_ids' in ids: # 判断使用 cpm 的情况 
                ids = ids['input_ids']
                input_ids = paddle.to_tensor(ids).unsqueeze(0)
            else:
                input_ids = paddle.to_tensor(np.array(ids).reshape(1, -1).astype('int64'))
            output, cached_kvs = self.model(input_ids, use_cache=True)
            next_token_logits = output[0, -1, :]
            for id in set(ids):
                next_token_logits[id] /= repitition_penalty
            next_token_logits = next_token_logits / temperature
            filtered_logits = self.top_k_top_p_filtering(next_token_logits, top_k=0, top_p=1.0)
            next_token = paddle.multinomial(paddle.nn.functional.softmax(filtered_logits, axis=-1), num_samples=1).numpy()
            ids += [int(next_token)]
            response = [int(next_token)] # 只要后面生成的话

            for i in range(max_len):
                input_id = paddle.to_tensor(np.array([next_token]).reshape(1, -1).astype('int64'))
                output, cached_kvs = self.model(input_id, use_cache=True, cache=cached_kvs)
                next_token_logits = output[0, -1, :]
                for id in set(ids):
                    next_token_logits[id] /= repitition_penalty
                next_token_logits = next_token_logits / temperature
                filtered_logits = self.top_k_top_p_filtering(next_token_logits, top_k=0, top_p=1.0)
                next_token = paddle.multinomial(paddle.nn.functional.softmax(filtered_logits, axis=-1), num_samples=1).numpy()
                ids += [int(next_token)]
                response += [int(next_token)]

                # 根据终止标志停止预测
                if (end_word is not None) and (int(next_token) == stop_id):
                    break

            # 如果使用 cpm, 截止8.3日的tokenizer还没有deocde方法
            if hasattr(self.tokenizer, 'convert_ids_to_string'):
                return self.tokenizer.convert_ids_to_string(response)

            return self.tokenizer.decode(response)


设定世界观

background_events = [
    '现在你来到一个城市之中,丧尸在街道上嘶吼,你好像是来到了是浣熊市,而核弹将在一小时后降临。',
    '历经千辛万苦,你来到一个地下实验室,旁边冰冷的警报声在提示着你T病毒已经泄露,一管病毒试剂摆在你面前。',
    '你来到了三体世界,你身上沾染的病毒不小心泄露到了这个世界,现在整个世界都在面临生化危机。',
    '无处可逃的你跑到了“自然选择”号飞船的队伍之中,你意识到三体人的舰队即将抵达,章北海会劫持飞船逃离太阳系。',
]

class World():
    def __init__(self):
        self.init = False
        self.model = Model()

    def sample(self, sentence):
        outputs = self.model.sample(
            sentence, # 输入文本
            max_len=128, # 最大生成文本的长度
            end_word='。', # 碰到句号生成结束
        )
        return outputs

    def start(self):
        self.context_list = [
            '冰冷,抖动……醒来的瞬间,你猛的从地面跳了起来,惊慌的看向四周,脑海里的办公室环境和眼前的环境瞬间出现了混淆,几秒之后你从混淆里清醒过来。',
        ]
        self.lastPoint = 1
        self.init = True
        self.worlds_event = queue.Queue(maxsize=0)
        for i in background_events:
            i = i.replace('\n','')
            self.worlds_event.put(i)
        event = self.worlds_event.get()
        self.context_list.append(event)
        return ''.join(self.context_list)

    def receive(self, user_input):
        if self.worlds_event.empty():
            print('你的冒险已经结束')
            return '你的冒险已经结束'
        user_output = ''
        inputs = user_input + ','
        self.context_list.append(inputs)
        context = ''.join(self.context_list) # 可以选择完整的历史记录
        # context = ''.join(self.context_list[-2:]) # 选择最近的几条聊天历史
        # context = self.summary(self.context_list) # 对之前的对话历史做文本摘要
        user_output = self.sample(context)
        self.context_list.append(user_output)
        self.lastPoint += 1

        if self.lastPoint % 4 == 0:
            if self.worlds_event.empty():
                print(user_output, '结束啦')
                return '你的冒险已经结束'
            event = self.worlds_event.get()
            self.context_list.append(event)
            user_output = user_output + event
        
        print(user_output)
        return user_output

载入模型

# 这里要下载一个 gpt 13个G,耐心等待
world = World()

开始游戏

  • 由于硬件显存问题,多次运行可能遇到***out of memory***的问题,需重启执行器
print(world.start())
while True:
    ans = input()
    print(ans)
    if world.receive(ans) == '你的冒险已经结束':
        break
冰冷,抖动……醒来的瞬间,你猛的从地面跳了起来,惊慌的看向四周,脑海里的办公室环境和眼前的环境瞬间出现了混淆,几秒之后你从混淆里清醒过来。现在你来到一个城市之中,丧尸在街道上嘶吼,你好像是来到了是浣熊市,而核弹将在一小时后降临。
爱丽丝在吗,不在我就往城外跑,顺便打一打路上的丧尸
我们可以打退他们的进攻。
其实我觉得打退不了,他们太多了,一会儿爆炸了怎么办啊
要不看看倒计时,六秒之后,我就成功逃走了...... 地图上的的标志都变成了一个红色的半径。
??那我往哪儿跑的,要不朝厕所跑吧
好远啊,你又不像打游戏跑地图的,不会迷路。历经千辛万苦,你来到一个地下实验室,旁边冰冷的警报声在提示着你T病毒已经泄露,一管病毒试剂摆在你面前。
富人靠科技,穷人靠变异,打了这针我就可以进化了
直接变怪兽,这样也成啊。
这是生化危机,不是虐杀原型,现在我可以一路杀过去了
扮猪吃老虎。
那我最后跑到了城外,活了下来
幸好我记得一切的一切,但为什么眼前的女孩子是我?满眼满眼地都是,窒息的感觉,只能一闭眼,人生真是不公啊,第一眼就可以知道人家想的什么。
我打了针反而变成了女生,可我想变的是超人
猩红的双眼,喷涌而出火热的鲜血。你来到了三体世界,你身上沾染的病毒不小心泄露到了这个世界,现在整个世界都在面临生化危机。
让我找个安全的地方,朝着南北极出发
弗瑞奥的家很近啊。
找到一个北极的地下基地,让我看看怎么阻止世界末日吧
这样可以有较多的时间与这个世界宝贵的人类在一起了。
最后我开始反攻三体世界
借助科技,我终于变出了超英赶美,猩红的双眼中闪烁着人类坚定的眼神,超英基地将被我攻陷,下一秒,无数丧尸朝我冲来,记住用喷火器,嗯,我是革命战士,拿起手中的那根电锯,万物一体,指向丧尸大迁徙的方向......命运啊,不要让我一路做着坏事,背上那黑黝黝的兽皮北向,这些都为我铺好了成功的
最后赶超英美,达到世界和平
达到世界霸权,走向威武雄狮的彩虹之路......(好像越来越不押韵了) 先更到这里,有时间再来补个结局<eod>改行IT 到底有没有前途? 我是一个毕业了一年的 程序员 ,苦逼的生活让我找份工作都是挑战。无处可逃的你跑到了“自然选择”号飞船的队伍之中,你意识到三体人的舰队即将抵达,章北海会劫持飞船逃离太阳系。
好的好的知道了,那我就开启流浪地球计划
你的冒险已经结束

总结

  • 利用PaddleNLP和其内置的中文预训练模型,可以实现简单的文本生成任务
  • 为了保证剧情顺畅可延续,需要预先定义部分关键剧情,比如“一管病毒试剂摆在你面前你要怎么做”等等。
  • 现在没有finetune,生成故事的质量参差不齐, 如果有特定语料微调效果会更好, 但经过最新测试,低估了数据获取的难度,没有合适的上文,动作,下文的数据集
  • 而通过自己标注的样本去 finetune GPT-2,测试发现几十个样本的 finetune 效果几乎为0
  • 原版的Ai Dungeon 已经向我们展示了在有良好的训练流程,训练数据的情况下,这种游戏可以达到一个什么样的水平。不论这类项目的思路是不是未来游戏发展的方向,但毫无疑问会在某些地方发挥重要的作用。比如让游戏中的NPC有一定程度的可交互性,与文字工作者进行头脑风暴,提供创作灵感等等。

参考项目

关于

天津大学语言与心理计算 LMc Lab, 研究方向为智能对话与心理计算,个性化推荐等

欢迎大家有问题留言交流学习,共同进步成长

Logo

学大模型,用大模型上飞桨星河社区!每天8点V100G算力免费领!免费领取ERNIE 4.0 100w Token >>>

更多推荐