BERT model + RabbitMQ queue for real-time prediction, so the graph does not have to be reloaded on every prediction

1. Create a new class that uses TensorFlow's built-in from_generator function. Sentences are fed in through a generator: inside the generator, a while loop fetches sentences from RabbitMQ via the channel and yields them for prediction. Because the generator never ends, estimator.predict keeps the same session (and the loaded graph) alive across all incoming messages.

# A class for real-time prediction. Sentences to classify arrive through a built-in RabbitMQ
# message queue, and the prediction results are printed in real time.
class BertPredictByGen(object):
    def __init__(self, estimator, label_list, tokenizer, channel, queue_name):
        self.estimator = estimator
        self.label_list = label_list
        self.tokenizer = tokenizer
        self.channel = channel
        self.queue_name = queue_name

    def input_fn_builder2(self):
        def gen():
            while True:
                method, properties, qs = self.channel.basic_get(self.queue_name, auto_ack=False)
                if not qs:
                    continue  # queue is empty, poll again
                self.channel.basic_ack(delivery_tag=method.delivery_tag)  # acknowledge the message
                text = str(qs, encoding='UTF-8')
                # guid is not actually used here, so any value works, but label must be one of the values in label_list
                examples = [InputExample(guid=0, text_a=text, text_b=None, label="其他")]
                features = convert_examples_to_features(examples, self.label_list, FLAGS.max_seq_length,
                                                        self.tokenizer)
                all_input_ids = []
                all_input_mask = []
                all_segment_ids = []
                all_label_ids = []

                for feature in features:
                    all_input_ids.append(feature.input_ids)
                    all_input_mask.append(feature.input_mask)
                    all_segment_ids.append(feature.segment_ids)
                    all_label_ids.append(feature.label_id)

                yield {
                       'input_ids': all_input_ids,
                       'input_mask': all_input_mask,
                       'segment_ids': all_segment_ids,
                       'label_ids': all_label_ids,
                       }

        def input_fn(params):
            # batch_size = params["batch_size"]
            types = {
                     'input_ids': tf.int32,
                     'input_mask': tf.int32,
                     'segment_ids': tf.int32,
                     'label_ids': tf.int32,
                     }
            shapes = {
                      'input_ids': (None, FLAGS.max_seq_length),
                      'input_mask': (None, FLAGS.max_seq_length),
                      'segment_ids': (None, FLAGS.max_seq_length),
                      'label_ids': (None,),
                      }
            return tf.data.Dataset.from_generator(gen, output_types=types, output_shapes=shapes).prefetch(1)

        return input_fn

    def predict(self):
        for result in self.estimator.predict(self.input_fn_builder2(), yield_single_examples=False):
            answer = self.label_list[np.argmax(result['probabilities'])]  # predicted label
            print("raw result:", answer)

2. Change the do_predict branch in the original source code to the following:

if FLAGS.do_predict:
        # Load the project config file. Write your own JSON config (modeled on bert_config_file)
        # to store the RabbitMQ settings.
        project_config = modeling.BertConfig.from_json_file(FLAGS.project_config_file)
        credentials = pika.PlainCredentials(username=project_config.queue_username,
                                            password=project_config.queue_password)
        connection = pika.BlockingConnection(
            pika.ConnectionParameters(host=project_config.queue_host, virtual_host=project_config.queue_virtual_host,
                                      credentials=credentials))
        channel = connection.channel()  # create a channel

        classifier = BertPredictByGen(estimator=estimator, label_list=label_list, tokenizer=tokenizer, channel=channel,
                                      queue_name=project_config.queue_name)  # instantiate the class
        classifier.predict()  # start predicting
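
The flag project_config_file is not defined in the stock run_classifier.py, and pika is not imported there either, so both presumably need to be added next to the existing imports and flag definitions, for example:

import pika

flags.DEFINE_string(
    "project_config_file", None,
    "JSON file holding the RabbitMQ connection settings.")

Since BertConfig.from_json_file simply copies every JSON key onto the config object as an attribute, a flat JSON containing just the queue fields used above is enough. A hypothetical example (all values are placeholders):

{
  "queue_username": "guest",
  "queue_password": "guest",
  "queue_host": "127.0.0.1",
  "queue_virtual_host": "/",
  "queue_name": "bert_predict_queue"
}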

 
