The mengzi-t5 Pretrained Model
First, download the mengzi-t5-base model from Hugging Face for the later training steps. Downloads from Hugging Face are slow in mainland China, so you can either go through a proxy or download the files locally and upload them to the server. Here we download from the hf-mirror mirror site.
```shell
!curl -L -O https://hf-mirror.com/Langboat/mengzi-t5-base/resolve/main/pytorch_model.bin?download=True
!curl -L -O https://hf-mirror.com/Langboat/mengzi-t5-base/resolve/main/config.json?download=true
!curl -L -O https://hf-mirror.com/Langboat/mengzi-t5-base/resolve/main/spiece.vocab?download=true
!curl -L -O https://hf-mirror.com/Langboat/mengzi-t5-base/resolve/main/spiece.model?download=true
```
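Alternatively, hf-mirror can be used through the Hugging Face CLI by pointing HF_ENDPOINT at the mirror. This is only a sketch, not part of the original notebook; it assumes a recent huggingface_hub is installed and downloads into ./MengziT5_base, the local path used when loading the model later.

```shell
# Sketch: download the whole repo via huggingface-cli with the endpoint set to hf-mirror.
!pip install -U huggingface_hub
!HF_ENDPOINT=https://hf-mirror.com huggingface-cli download Langboat/mengzi-t5-base --local-dir ./MengziT5_base
```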
Data Preparation
Downloading the Dataset
The data is the collection of Tang and Song poems from chinese-poetry. The dataset is already built into the AI Studio (飞桨) platform, so we only need to add it to the project and unzip it:
```shell
!unzip -n ./data/data70759/poems_json.zip
```
Data Processing
The poems in the dataset are written in traditional Chinese characters, so we use the chinese-converter library to convert them to simplified Chinese.
```shell
!pip install chinese-converter
```
Import the libraries.
```python
import json
import urllib.request
import pickle
import os

import pandas as pd
import numpy as np
import chinese_converter

IS_TEST_FLOW = False
```
IS_TEST_FLOW is a flag that switches between a quick test run and the full training run; when it is True, only a small amount of data is processed.
The dataset is in JSON format. Each JSON file contains 1000 poems and looks like this:
```json
[
  {
    "author": "太宗皇帝",
    "paragraphs": [
      "秦川雄帝宅,函谷壯皇居。",
      "綺殿千尋起,離宮百雉餘。",
      "連甍遙接漢,飛觀迥凌虛。",
      "雲日隱層闕,風煙出綺疎。"
    ],
    "note": [],
    "title": "帝京篇十首 一"
  }
]
```
Process the JSON files: build df_list, where each element is a DataFrame read from one file, and merge everything at the end with pd.concat.
```python
POEM_CONTENT = {
    'tang': {
        'total': 58,
        'pattern': "./poems_json/poet.tang.{0}.json"
    },
    'song': {
        'total': 255,
        'pattern': "./poems_json/poet.song.{0}.json"
    }
}

def get_poems(is_test=True, verbose=True):
    df_list = []
    for dynasty in POEM_CONTENT:
        size = 3 if is_test else POEM_CONTENT[dynasty]['total']
        for i in range(size):
            url = POEM_CONTENT[dynasty]['pattern'].format(i * 1000)
            if verbose:
                print(f"load {url} now")
            df_list.append(pd.read_json(url))
    return pd.concat(df_list)
```
Use df.apply to convert the traditional-Chinese text to simplified Chinese.
```python
df = get_poems(is_test=IS_TEST_FLOW, verbose=True)
df['concat_paragraphs'] = [''.join(map(str, l)) for l in df['paragraphs']]
df = df[['author', 'title', 'concat_paragraphs']]

def convert_schinese(tchinese):
    return chinese_converter.to_simplified(tchinese)

df['s_content'] = df.apply(lambda row: convert_schinese(''.join(row.concat_paragraphs)), axis=1)
df['s_title'] = df.apply(lambda row: convert_schinese(''.join(row.title)), axis=1)
df['s_author'] = df.apply(lambda row: convert_schinese(''.join(row.author)), axis=1)

my_df = df
print("my_df size", len(my_df))
```
Define trim functions that strip some special characters and limit the lengths of the author, title, and content fields.
```python
MAX_AUTHOR_CHAR = 4
MAX_TITLE_CHAR = 12
MIN_CONTENT_CHAR = 20
MAX_CONTENT_CHAR = 32
BAD_TOKENS = " ()[]《》()□{}abcdefgxyz一"

def trim_author_fn(row):
    return row.s_author[:MAX_AUTHOR_CHAR]

def trim_title_fn(row):
    trimed_title = row.s_title[:MAX_TITLE_CHAR]
    for b in BAD_TOKENS:
        trimed_title = trimed_title.replace(b, "")
    return trimed_title

def trim_content_fn(row):
    trimed_content = row.s_content[:MAX_CONTENT_CHAR]
    for b in BAD_TOKENS:
        trimed_content = trimed_content.replace(b, "")
    last_period = trimed_content.rfind("。")
    return trimed_content[:last_period + 1]

my_df['s_author_trim'] = my_df.copy().apply(trim_author_fn, axis=1)
my_df['s_title_trim'] = my_df.copy().apply(trim_title_fn, axis=1)
my_df['s_content_trim'] = my_df.copy().apply(trim_content_fn, axis=1)
print("my_df size", len(my_df))
```
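As a quick check of the trimming logic, the sketch below (a made-up row, not part of the original notebook) shows how trim_content_fn first cuts the body to MAX_CONTENT_CHAR characters and then drops everything after the last full stop.

```python
# Illustrative only: a synthetic row with three 12-character couplets (36 characters total).
demo_row = pd.Series({'s_content': "秦川雄帝宅,函谷壮皇居。绮殿千寻起,离宫百雉余。连甍遥接汉,飞观迥凌虚。"})
# The 32-character cut lands in the middle of the third couplet,
# so the result ends at the second "。".
print(trim_content_fn(demo_row))
```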
Filter out invalid rows, such as empty titles, content that is too short, and placeholder "无正文" (no body) entries.
```python
empty_title_mask = (my_df['s_title_trim'].str.len() == 0)
too_short_content_mask = (my_df['s_content_trim'].str.len() <= MIN_CONTENT_CHAR)
invalid_mask = (('无正文' == my_df['s_content_trim']) | ('无正文' == my_df['s_author_trim']))
too_short_mask = empty_title_mask | too_short_content_mask | invalid_mask

my_df = my_df.loc[~too_short_mask][[
    's_author_trim',
    's_title_trim',
    's_content_trim'
]]
print("my_df size", len(my_df))
```
Finally, keep only regular verse: split each poem body on punctuation and require that every resulting line has the same length of either 5 or 7 characters.

```python
import re

result_dict = {
    's_author_trim': [],
    's_title_trim': [],
    's_content_trim': [],
}
for i, row in my_df.iterrows():
    c = row['s_content_trim']
    # split the body into individual verse lines on , 。 ?
    snippets = list(re.split(",|。|?", c))
    lens = [len(s) for s in snippets if s.strip() != '']
    # keep only poems whose lines are all 5 or all 7 characters long
    if max(lens) != min(lens) or max(lens) not in [5, 7]:
        continue
    result_dict['s_author_trim'].append(row['s_author_trim'])
    result_dict['s_title_trim'].append(row['s_title_trim'])
    result_dict['s_content_trim'].append(c)
my_df = pd.DataFrame(data=result_dict)
print("left", len(my_df))
```
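To see what this filter operates on, here is a small illustration (a synthetic line, not from the original code) of how one poem body is split into verse lines; the empty string left after the final 。 is discarded by the strip() check.

```python
# Synthetic example: splitting one body into verse lines before the length check.
line = "秦川雄帝宅,函谷壮皇居。"
print(re.split(",|。|?", line))  # ['秦川雄帝宅', '函谷壮皇居', ''] -> two 5-character lines, so it is kept
```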
Building the Dataset
Build the dataset with source_text and target_text columns.
```python
AUTHOR_PROMPT = "模仿:"
TITLE_PROMPT = "作诗:"
EOS_TOKEN = '</s>'

def build_dataset_df(df, include_author=True):
    dfc = df.copy()
    if include_author:
        dfc['source_text'] = TITLE_PROMPT + df['s_title_trim'] + EOS_TOKEN + AUTHOR_PROMPT + df['s_author_trim']
    else:
        dfc['source_text'] = TITLE_PROMPT + df['s_title_trim']
    dfc['target_text'] = df['s_content_trim']
    dfc = dfc[['source_text', 'target_text']]
    return dfc
```
The dataset with the author included:
```python
df_author_title_content = build_dataset_df(my_df, True)
```
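To verify the prompt format, one can peek at a single training pair (an illustrative check, not in the original notebook):

```python
# Inspect one source/target pair; the source follows the pattern 作诗:<title></s>模仿:<author>.
sample = df_author_title_content.iloc[0]
print(sample['source_text'])
print(sample['target_text'])
```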
The dataset without the author:
```python
df_title_content = build_dataset_df(my_df, False)
```
Merge the two datasets.
```python
merged_df = pd.concat([df_author_title_content, df_title_content])
merged_df = merged_df.sample(frac=1.)
```
Here sample(frac=1.) draws all rows in random order, i.e. it shuffles the dataset.
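If you want the shuffle to be reproducible, you can seed it and rebuild a clean index; this is an optional sketch, not part of the original notebook.

```python
# Equivalent shuffle with a fixed seed and a fresh 0..N-1 index.
merged_df = merged_df.sample(frac=1., random_state=42).reset_index(drop=True)
```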
Training
Install the required libraries such as torch and simplet5, then import them.
```python
!pip install torch
!pip install simplet5

import torch
from simplet5 import SimpleT5
from transformers import T5Tokenizer, T5ForConditionalGeneration
```
Defining the Model
Define a model class that inherits from SimpleT5 and loads the local mengzi-t5-base weights.
```python
torch.cuda.empty_cache()

local_model_path = "./MengziT5_base"
extra_ids = 100
additional_special_tokens = [f"<extra_id_{i}>" for i in range(extra_ids)]

class MengziSimpleT5(SimpleT5):
    def __init__(self) -> None:
        super().__init__()
        self.device = torch.device("cuda")

    def load_my_model(self, use_gpu: bool = True):
        self.tokenizer = T5Tokenizer.from_pretrained(local_model_path)
        self.model = T5ForConditionalGeneration.from_pretrained(local_model_path)

model = MengziSimpleT5()
model.load_my_model()
model.model = model.model.to('cuda')
```
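A quick sanity check (illustrative; the title 秋思 is just an example) that the tokenizer and weights loaded correctly:

```python
# Encode one prompt and decode it back to confirm the tokenizer round-trips.
ids = model.tokenizer(TITLE_PROMPT + "秋思", return_tensors="pt").input_ids
print(ids.shape)
print(model.tokenizer.decode(ids[0]))
```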
Splitting the Dataset
Split the dataset into a training set and a validation set with a 0.98 / 0.02 ratio.
```python
from sklearn.model_selection import train_test_split

merged_df = merged_df.sample(frac=1)
train_df, eval_df = train_test_split(merged_df, test_size=0.02)
print("train", len(train_df), "eval", len(eval_df))
```
Training the Model
Train on train_df and validate on eval_df.
```python
model.train(train_df=train_df,
            eval_df=eval_df,
            source_max_token_len=(len(TITLE_PROMPT) + MAX_TITLE_CHAR + 1 + len(AUTHOR_PROMPT) + MAX_AUTHOR_CHAR),
            target_max_token_len=MAX_CONTENT_CHAR,
            batch_size=256,
            max_epochs=5,
            use_gpu=True,
            outputdir="./Models/t5-poem-v2.1")
```
Testing the Model
Define a helper that generates a poem from a title and an optional author.
```python
def poem(title_str, opt_author=None, model=model,
         is_input_traditional_chinese=False, num_beams=2):
    model.model = model.model.to('cuda')
    if opt_author:
        in_request = TITLE_PROMPT + title_str[:MAX_TITLE_CHAR] + EOS_TOKEN + AUTHOR_PROMPT + opt_author[:MAX_AUTHOR_CHAR]
    else:
        in_request = TITLE_PROMPT + title_str[:MAX_TITLE_CHAR]
    if is_input_traditional_chinese:
        in_request = chinese_converter.to_simplified(in_request)
    out = model.predict(in_request, max_length=MAX_CONTENT_CHAR, num_beams=num_beams)[0].replace(",", ",")
    if is_input_traditional_chinese:
        out = chinese_converter.to_traditional(out)
        print(f"標題: {in_request.replace('</s>', ' ')}\n詩歌: {out}")
    else:
        print(f"标题: {in_request.replace('</s>', ' ')}\n诗歌: {out}")
```
Test the model with a few titles and authors.
```python
for title in ['秋思', "百花", '佳人有约']:
    for author in ['', "杜甫", "李白", "李清照", "苏轼"]:
        poem(title, author)
        print()
```
Test the model with different num_beams values.
```python
for title in ['冬雪']:
    for author in ['', "杜甫"]:
        for num_beams in (2, 3, 5, 10, 20, 50, 100, 200):
            print(f"num beams: {num_beams}")
            poem(title, author, num_beams=num_beams)
            print("-" * 80)
```
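Beam search is deterministic, so large beam counts are slow and give the same output on every run. As an alternative that is not in the original notebook, one can sample from the underlying Hugging Face model directly; a minimal sketch:

```python
# Illustrative sketch: sampling-based generation instead of beam search.
inputs = model.tokenizer(TITLE_PROMPT + "冬雪", return_tensors="pt").to("cuda")
out_ids = model.model.generate(**inputs,
                               max_length=MAX_CONTENT_CHAR,
                               do_sample=True,
                               top_p=0.9)
print(model.tokenizer.decode(out_ids[0], skip_special_tokens=True))
```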
Using the Model
Load the fine-tuned checkpoint and use it to generate poems.
```python
import json
import torch
import chinese_converter
from transformers import LogitsProcessor, LogitsProcessorList
from transformers import T5Tokenizer, T5ForConditionalGeneration
from simplet5 import SimpleT5

MODEL_PATH = "./Models/t5-poem-v2.1/simplet5-epoch-4-train-loss-3.4329-val-loss-3.4315"

class PoemModel(SimpleT5):
    def __init__(self) -> None:
        super().__init__()
        self.device = torch.device("cuda")

    def load_my_model(self):
        self.tokenizer = T5Tokenizer.from_pretrained(MODEL_PATH)
        self.model = T5ForConditionalGeneration.from_pretrained(MODEL_PATH)
```
Restore the prompt constants, load the fine-tuned checkpoint, and redefine the poem helper; note the larger MAX_CONTENT_CHAR (64) and the higher default num_beams (100) used for inference.

```python
AUTHOR_PROMPT = "模仿:"
TITLE_PROMPT = "作诗:"
EOS_TOKEN = '</s>'

poem_model = PoemModel()
poem_model.load_my_model()
poem_model.model = poem_model.model.to('cuda')

MAX_AUTHOR_CHAR = 4
MAX_TITLE_CHAR = 12
MIN_CONTENT_CHAR = 10
MAX_CONTENT_CHAR = 64

def poem(title_str, opt_author=None, model=poem_model,
         is_input_traditional_chinese=False, num_beams=100):
    model.model = model.model.to('cuda')
    if opt_author:
        in_request = TITLE_PROMPT + title_str[:MAX_TITLE_CHAR] + EOS_TOKEN + AUTHOR_PROMPT + opt_author[:MAX_AUTHOR_CHAR]
    else:
        in_request = TITLE_PROMPT + title_str[:MAX_TITLE_CHAR]
    if is_input_traditional_chinese:
        in_request = chinese_converter.to_simplified(in_request)
    out = model.predict(in_request, max_length=MAX_CONTENT_CHAR, num_beams=num_beams)[0].replace(",", ",")
    if is_input_traditional_chinese:
        out = chinese_converter.to_traditional(out)
        print(f"標題: {in_request.replace('</s>', ' ')}\n詩歌: {out}")
    else:
        print(f"标题: {in_request.replace('</s>', ' ')}\n诗歌: {out}")
```
```python
for title in ['秋思', '佳人', '相思', "幽梦"]:
    for author in ['', "杜甫", "李白", "李清照", "苏轼"]:
        poem(title, author)
        print()
```
Conclusion
We fine-tuned mengzi-t5 on the Tang and Song poetry dataset and obtained a model for classical Chinese poem generation.
The slides are available here.
A demo of the generated results is available here.
GitHub repository: poem_generate
AI Studio (飞桨) project: test
This work mainly follows (borrows heavily from) chinese-ai-writing-share.
References
aistudio
chinese-poetry
hf-mirror
chinese-ai-writing-share
aichpoem