# Named Entity Tagging with BiLSTM and CRF in Keras

## Character-based tagging scheme:

• B-Person
• I-Person
• B-Organization
• I-Organization
• O
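
To make the tag set concrete, here is a minimal sketch of how character-level tags could be derived from entity spans. The helper name `bio_tags`, the span format `(start, end, label)` and the sample sentence are assumptions made for illustration; they are not part of the original post.

```python
def bio_tags(text, entities):
    """Label each character of `text` with a B-/I-/O tag.

    `entities` is a list of (start, end, label) spans, with `end` exclusive.
    """
    tags = ['O'] * len(text)
    for start, end, label in entities:
        tags[start] = 'B-' + label          # first character of the entity
        for i in range(start + 1, end):
            tags[i] = 'I-' + label          # remaining characters of the entity
    return tags


text = '周恩来访问外交部'
entities = [(0, 3, 'Person'), (5, 8, 'Organization')]
for ch, tag in zip(text, bio_tags(text, entities)):
    print(ch, tag)
```

Every character outside an entity keeps the `O` tag, so the tag sequence always has the same length as the sentence.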

## Effect of adding a CRF layer on the LSTM network's output

**Diagram of the network without a CRF layer**

## CRF loss function

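The CRF loss function is based on the following ratio, where $P_{RealPath}$ belongs to the correct tag sequence and $P_1, \ldots, P_N$ range over all $N$ possible tag sequences; training pushes the ratio toward 1:
Loss Function = $\frac{P_{RealPath}}{P_1 + P_2 + \ldots + P_N}$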
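
A ratio like this is usually turned into a quantity to minimize by taking its negative logarithm. As a sketch of that step, assuming every path probability has the form $P_i = e^{S_i}$ with $S_i$ the score of path $i$:

$$-\log\frac{P_{RealPath}}{P_1 + P_2 + \ldots + P_N} = \log\left(e^{S_1} + e^{S_2} + \ldots + e^{S_N}\right) - S_{RealPath}$$

The second term is the score of the real path and the first term is the logarithm of the total score over all paths, which are the two quantities discussed next.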

### 1. Real path score

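$P_{RealPath} = e^{S_i}$, where $S_i$ is the score of the real path.

$S_i$ = EmissionScore + TransitionScore

EmissionScore = $x_{0,START}+x_{1,B-Person}+x_{2,I-Person}+x_{3,O}+x_{4,B-Organization}+x_{5,O}+x_{6,END}$

TransitionScore = $t_{START \to B-Person}+t_{B-Person \to I-Person}+t_{I-Person \to O}+t_{O \to B-Organization}+t_{B-Organization \to O}+t_{O \to END}$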
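
The score of one path can be sketched in a few lines. The example below assumes the transition score is the sum of the scores for moving between consecutive labels on the path; all numeric values are made up for illustration, and the helper name `path_score` is hypothetical.

```python
import math

labels = ['START', 'B-Person', 'I-Person', 'B-Organization', 'I-Organization', 'O', 'END']
idx = {name: i for i, name in enumerate(labels)}


def path_score(path, emissions, transitions):
    """Return S = EmissionScore + TransitionScore for one label path.

    emissions[i][y]   : score of label y at position i (the x_{i,y} terms)
    transitions[a][b] : score of moving from label a to label b
    """
    ids = [idx[name] for name in path]
    emission_score = sum(emissions[i][y] for i, y in enumerate(ids))
    transition_score = sum(transitions[a][b] for a, b in zip(ids, ids[1:]))
    return emission_score + transition_score


path = ['START', 'B-Person', 'I-Person', 'O', 'B-Organization', 'O', 'END']
# Made-up scores: every emission is 0.5 and every transition is 0.1.
emissions = [[0.5] * len(labels) for _ in path]
transitions = [[0.1] * len(labels) for _ in labels]

s_real = path_score(path, emissions, transitions)
print('S_RealPath =', s_real)
print('P_RealPath =', math.exp(s_real))
```

With seven positions there are seven emission terms and six transition terms, one per pair of neighbouring labels.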

### 2. Total score

Computing the total score is more involved; see https://createmomo.github.io/2017/11/11/CRF-Layer-on-the-Top-of-BiLSTM-5/
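
As a rough illustration of why the total score is tractable, the sketch below computes the logarithm of the summed path scores in two ways: by enumerating every path, and with a forward-style dynamic program that only keeps one running value per label. Labels are plain integer indices, START and END are left out for brevity, and all scores are made up; none of this is taken from the original post.

```python
import itertools
import math


def log_sum_exp(values):
    """Numerically stable log(sum(exp(v) for v in values))."""
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))


def score(path, emissions, transitions):
    """Emission plus transition score of one path of label indices."""
    s = sum(emissions[i][y] for i, y in enumerate(path))
    s += sum(transitions[a][b] for a, b in zip(path, path[1:]))
    return s


def log_total_brute_force(emissions, transitions):
    """Enumerate all paths: exponential in the sentence length."""
    n_labels = len(transitions)
    all_paths = itertools.product(range(n_labels), repeat=len(emissions))
    return log_sum_exp([score(p, emissions, transitions) for p in all_paths])


def log_total_forward(emissions, transitions):
    """Forward recursion: linear in the sentence length."""
    n_labels = len(transitions)
    # alpha[y] = log of the summed exp-scores of all partial paths ending in y
    alpha = list(emissions[0])
    for x in emissions[1:]:
        alpha = [
            x[y] + log_sum_exp([alpha[p] + transitions[p][y] for p in range(n_labels)])
            for y in range(n_labels)
        ]
    return log_sum_exp(alpha)


# Made-up scores: 3 positions, 2 labels.
emissions = [[1.0, 0.2], [0.3, 0.8], [0.5, 0.4]]
transitions = [[0.6, 0.1], [0.2, 0.7]]

brute = log_total_brute_force(emissions, transitions)
forward = log_total_forward(emissions, transitions)
real_path = [0, 1, 1]
loss = forward - score(real_path, emissions, transitions)  # negative log of the ratio
print(brute, forward, loss)
```

Both functions should agree up to floating-point error, and the resulting loss is positive because the total always includes the real path plus other paths.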

## Implementation (Keras version)

### 1. Building the network model

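    from keras.models import Sequential
    from keras.layers import Embedding, Bidirectional, LSTM
    from keras_contrib.layers import CRF  # assumption: the CRF layer comes from keras-contrib

    # vocab, chunk_tags, EMBED_DIM and BiRNN_UNITS are defined elsewhere in the project
    model = Sequential()
    model.add(Embedding(len(vocab), EMBED_DIM, mask_zero=True))  # randomly initialized embedding
    model.add(Bidirectional(LSTM(BiRNN_UNITS // 2, return_sequences=True)))  # half the units per direction
    crf = CRF(len(chunk_tags), sparse_target=True)  # targets are integer tag ids, not one-hot vectors
    model.add(crf)
    model.summary()
    model.compile('adam', loss=crf.loss_function, metrics=[crf.accuracy])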


### 2. Cleaning the data

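Each line of the corpus pairs one character with its tag, for example:

    《 O

The tags used in the corpus are:

    'O', 'B-PER', 'I-PER', 'B-LOC', 'I-LOC', "B-ORG", "I-ORG"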
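
Corpus lines such as `《 O` hold a character followed by its tag. The post does not show `_parse_data`, so the following is only a sketch of how such text could be split into sentences of `[char, tag]` rows. The function name `parse_corpus`, the blank line between sentences and the demo string are all assumptions.

```python
def parse_corpus(text):
    """Split corpus text into sentences, each a list of [char, tag] rows.

    Assumes one "char tag" pair per line and a blank line between sentences.
    """
    samples = []
    current = []
    for line in text.splitlines():
        line = line.strip()
        if not line:
            # A blank line closes the current sentence, if there is one.
            if current:
                samples.append(current)
                current = []
            continue
        current.append(line.split())
    if current:
        samples.append(current)
    return samples


demo = '《 O\n北 B-LOC\n京 I-LOC\n\n陈 B-PER\n毅 I-PER\n'
print(parse_corpus(demo))
```

The nested shape (sentence, row, column) matches how the loading code indexes `row[0]` for each `row` in each `sample`.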


    import pickle
    from collections import Counter

    # The enclosing function header is missing from the original listing;
    # load_data is an assumed name. _parse_data and _process_data are project helpers.
    def load_data():
        train = _parse_data(open('data/train_data.data', 'rb'))
        test = _parse_data(open('data/test_data.data', 'rb'))

        # Keep only characters that occur at least twice in the training set
        word_counts = Counter(row[0].lower() for sample in train for row in sample)
        vocab = [w for w, f in word_counts.items() if f >= 2]
        chunk_tags = ['O', 'B-PER', 'I-PER', 'B-LOC', 'I-LOC', "B-ORG", "I-ORG"]

        # save initial config data
        with open('model/config.pkl', 'wb') as outp:
            pickle.dump((vocab, chunk_tags), outp)

        train = _process_data(train, vocab, chunk_tags)
        test = _process_data(test, vocab, chunk_tags)
        return train, test, (vocab, chunk_tags)
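
`_process_data` is not shown in this post. As a hedged sketch of what the encoding step might look like, the example below maps characters to integer ids and pads on the left with 0, which fits the `mask_zero=True` embedding and the `[-length:]` slice used at prediction time. The function name `encode_chars` and the id layout (0 for padding, 1 for unknown characters) are assumptions, and the embedding size would have to match whatever layout is actually used.

```python
def encode_chars(text, vocab, maxlen):
    """Map characters to integer ids and left-pad with 0 up to maxlen."""
    # Assumed id layout: 0 = padding, 1 = unknown character, 2.. = vocab entries.
    char2idx = {ch: i + 2 for i, ch in enumerate(vocab)}
    ids = [char2idx.get(ch.lower(), 1) for ch in text][-maxlen:]
    padded = [0] * (maxlen - len(ids)) + ids
    return padded, len(ids)


padded, length = encode_chars('陈毅x', ['陈', '毅'], maxlen=5)
print(padded, length)
print(padded[-length:])  # the unpadded part, recovered by slicing from the end
```

Because the padding sits at the front, the last `length` positions always correspond to the real characters.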


### 3. Training

    import bilsm_crf_model

    EPOCHS = 10
    model, (train_x, train_y), (test_x, test_y) = bilsm_crf_model.create_model()
    # train model
    model.fit(train_x, train_y, batch_size=16, epochs=EPOCHS, validation_data=(test_x, test_y))
    model.save('model/crf.h5')


### 4. Validation

    import bilsm_crf_model
    import process_data
    import numpy as np

    model, (vocab, chunk_tags) = bilsm_crf_model.create_model(train=False)
    predict_text = '中华人民共和国国务院总理周恩来在外交部长陈毅的陪同下，连续访问了埃塞俄比亚等非洲10国以及阿尔巴尼亚'
    # encoded: index sequence fed to the model; length: presumably the unpadded sequence length
    encoded, length = process_data.process_data(predict_text, vocab)
    model.load_weights('model/crf.h5')
    # keep only the last `length` time steps of the prediction
    raw = model.predict(encoded)[0][-length:]
    result = [np.argmax(row) for row in raw]
    result_tags = [chunk_tags[i] for i in result]

    per, loc, org = '', '', ''

    # a B- tag starts a new entity (prefixed with a space); an I- tag extends it
    for s, t in zip(predict_text, result_tags):
        if t in ('B-PER', 'I-PER'):
            per += ' ' + s if (t == 'B-PER') else s
        if t in ('B-ORG', 'I-ORG'):
            org += ' ' + s if (t == 'B-ORG') else s
        if t in ('B-LOC', 'I-LOC'):
            loc += ' ' + s if (t == 'B-LOC') else s

    print(['person:' + per, 'location:' + loc, 'organization:' + org])

Output:

    ['person: 周恩来 陈毅, 王东', 'location: 埃塞俄比亚 非洲 阿尔巴尼亚', 'organization: 中华人民共和国国务院 外交部']
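
The loop above builds one space-separated string per entity type. If separate entities are wanted instead, one possible alternative is to group characters into `(text, type)` pairs. The function name `extract_entities` and the sample input are hypothetical, and this is a sketch rather than the author's method.

```python
def extract_entities(text, tags):
    """Group characters into (entity_text, entity_type) pairs from B-/I-/O tags."""
    entities = []
    current, current_type = '', None
    for ch, tag in zip(text, tags):
        if tag.startswith('B-'):
            # A B- tag closes any open entity and starts a new one.
            if current:
                entities.append((current, current_type))
            current, current_type = ch, tag[2:]
        elif tag.startswith('I-') and current_type == tag[2:]:
            # An I- tag of the same type extends the open entity.
            current += ch
        else:
            # An O tag, or an I- tag with no matching open entity, closes it.
            if current:
                entities.append((current, current_type))
            current, current_type = '', None
    if current:
        entities.append((current, current_type))
    return entities


sample_text = '周恩来在外交部'
sample_tags = ['B-PER', 'I-PER', 'I-PER', 'O', 'B-ORG', 'I-ORG', 'I-ORG']
print(extract_entities(sample_text, sample_tags))
```

Returning pairs keeps each entity distinct and preserves the order in which the entities appear in the sentence.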

posted @ 2018-03-26 16:29  帅虫哥