# Import required packages
import os
import re
import numpy as np
import matplotlib.pyplot as plt

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

# Pretrained CNN backbone
from tensorflow.keras.applications import efficientnet
from tensorflow.keras.layers.experimental.preprocessing import TextVectorization

import json
import jieba # ! pip install jieba
import tqdm
import cv2
config = tf.compat.v1.ConfigProto()
config.gpu_options.per_process_gpu_memory_fraction = 0.5
session = tf.compat.v1.Session(config=config)
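# The three lines above cap TensorFlow at 50% of GPU memory through the TF1 compat API.
# A TF2-native alternative (a sketch, not used here) is to enable memory growth, which
# must run before any GPU op is executed:
# for gpu in tf.config.list_physical_devices("GPU"):
#     tf.config.experimental.set_memory_growth(gpu, True)
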
# Image directory
IMAGES_PATH = "./ai_challenger_caption_validation_20170910/caption_validation_images_20170910/"

# Target image size
IMAGE_SIZE = (299, 299)

# Vocabulary size
VOCAB_SIZE = 10000

# Caption sequence length (tokens)
SEQ_LENGTH = 25

# Embedding dimension
EMBED_DIM = 512

# Feed-forward layer dimension
FF_DIM = 512

# Training parameters
BATCH_SIZE = 64
EPOCHS = 30
AUTOTUNE = tf.data.AUTOTUNE

class ImageCaptioningModel(keras.Model):
    def __init__(
            self, cnn_model, encoder, decoder, num_captions_per_image=5, image_aug=None,
    ):
        super().__init__()
        self.cnn_model = cnn_model
        self.encoder = encoder
        self.decoder = decoder
        self.loss_tracker = keras.metrics.Mean(name="loss")
        self.acc_tracker = keras.metrics.Mean(name="accuracy")
        self.num_captions_per_image = num_captions_per_image
        self.image_aug = image_aug

    def calculate_loss(self, y_true, y_pred, mask):
        loss = self.loss(y_true, y_pred)
        mask = tf.cast(mask, dtype=loss.dtype)
        loss *= mask
        return tf.reduce_sum(loss) / tf.reduce_sum(mask)
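
    # Worked example (illustrative): if the per-token loss is [2.0, 1.0, 3.0] and
    # the mask is [1, 1, 0] (the last position is padding), the masked mean is
    # (2.0 + 1.0) / 2 = 1.5 rather than the plain mean over all three positions.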

    def calculate_accuracy(self, y_true, y_pred, mask):
        accuracy = tf.equal(y_true, tf.argmax(y_pred, axis=2))
        accuracy = tf.math.logical_and(mask, accuracy)
        accuracy = tf.cast(accuracy, dtype=tf.float32)
        mask = tf.cast(mask, dtype=tf.float32)
        return tf.reduce_sum(accuracy) / tf.reduce_sum(mask)

    def _compute_caption_loss_and_acc(self, img_embed, batch_seq, training=True):
        '''
        Compute the loss and accuracy for one batch of captions.
        '''

        # Feed the image embeddings to the encoder; output shape (N, 100, 512)
        encoder_out = self.encoder(img_embed, training=training)

        # batch_seq shape: (64, 25)
        # First 24 tokens (drop the last one) are the decoder input
        batch_seq_inp = batch_seq[:, :-1]

        # Last 24 tokens (drop the first one) are the ground-truth targets
        batch_seq_true = batch_seq[:, 1:]

        # Mask: compare every element of batch_seq_true with 0, giving something
        # like [True, True, False]. Zeros mark the padding of captions shorter
        # than 25 tokens, so those positions are excluded from loss and accuracy.
        mask = tf.math.not_equal(batch_seq_true, 0)

        # Sequence predicted by the decoder
        batch_seq_pred = self.decoder(
            batch_seq_inp, encoder_out, training=training, mask=mask
        )
        # Compute loss and accuracy
        loss = self.calculate_loss(batch_seq_true, batch_seq_pred, mask)
        acc = self.calculate_accuracy(batch_seq_true, batch_seq_pred, mask)
        return loss, acc
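
    # Shift example (illustrative): for a padded caption of 25 ids
    #   [<start>, 一个, 男人, <end>, 0, ..., 0]
    # the decoder input is the first 24 ids and the target is the last 24 ids,
    # so at every position the model is trained to predict the next token,
    # while the mask zeroes out the loss on the trailing padding positions.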

    def train_step(self, batch_data):
        '''
        Training step.
        '''
        # Unpack images and captions
        batch_img, batch_seq = batch_data
        # Initialize accumulators
        batch_loss = 0
        batch_acc = 0
        # Apply data augmentation if configured
        if self.image_aug:
            batch_img = self.image_aug(batch_img)

        # Extract image embedding features
        img_embed = self.cnn_model(batch_img)

        # Loop over the 5 captions of each image
        for i in range(self.num_captions_per_image):
            with tf.GradientTape() as tape:
                # Compute loss and accuracy
                # batch_seq shape: (64, 5, 25)
                loss, acc = self._compute_caption_loss_and_acc(
                    img_embed, batch_seq[:, i, :], training=True
                )

                # Accumulate loss and accuracy
                batch_loss += loss
                batch_acc += acc

            # Collect all trainable parameters
            train_vars = (
                    self.encoder.trainable_variables + self.decoder.trainable_variables
            )

            # Compute the gradients
            grads = tape.gradient(loss, train_vars)

            # Update the parameters
            self.optimizer.apply_gradients(zip(grads, train_vars))

        # Update the trackers
        batch_acc /= float(self.num_captions_per_image)
        self.loss_tracker.update_state(batch_loss)
        self.acc_tracker.update_state(batch_acc)

        return {"loss": self.loss_tracker.result(), "acc": self.acc_tracker.result()}

    def test_step(self, batch_data):
        batch_img, batch_seq = batch_data
        batch_loss = 0
        batch_acc = 0

        # Extract image embedding features
        img_embed = self.cnn_model(batch_img)

        # Loop over the 5 captions of each image
        for i in range(self.num_captions_per_image):
            loss, acc = self._compute_caption_loss_and_acc(
                img_embed, batch_seq[:, i, :], training=False
            )

            batch_loss += loss
            batch_acc += acc

        batch_acc /= float(self.num_captions_per_image)

        self.loss_tracker.update_state(batch_loss)
        self.acc_tracker.update_state(batch_acc)

        return {"loss": self.loss_tracker.result(), "acc": self.acc_tracker.result()}

    @property
    def metrics(self):
        return [self.loss_tracker, self.acc_tracker]


class TransformerEncoderBlock(layers.Layer):
    # Transformer encoder block, see: https://www.youtube.com/watch?v=n9TlOhRjYoc
    def __init__(self, embed_dim, dense_dim, num_heads, **kwargs):
        super().__init__()
        self.embed_dim = embed_dim
        self.dense_dim = dense_dim
        self.num_heads = num_heads
        self.attention_1 = layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=embed_dim, dropout=0.0
        )
        self.layernorm_1 = layers.LayerNormalization()
        self.layernorm_2 = layers.LayerNormalization()
        self.dense_1 = layers.Dense(embed_dim, activation="relu")

    def call(self, inputs, training, mask=None):
        # layer norm
        inputs = self.layernorm_1(inputs)

        inputs = self.dense_1(inputs)
        # multi head attention
        # training: Boolean flag for training vs. inference (controls dropout)
        attention_output_1 = self.attention_1(
            query=inputs,
            value=inputs,
            key=inputs,
            attention_mask=None,
            training=training,
        )
        # Residual connection followed by layer norm
        out_1 = self.layernorm_2(inputs + attention_output_1)
        return out_1


class PositionalEmbedding(layers.Layer):
    # Positional embedding
    def __init__(self, sequence_length, vocab_size, embed_dim, **kwargs):
        super().__init__()
        '''
        How the Embedding layer works: https://stats.stackexchange.com/questions/270546/how-does-keras-embedding-layer-work
        input_dim: vocabulary size; output_dim: embedding vector size
        It maps each index in the left column to the vector in the right column:
        +------------+------------+
        |   index    |  Embedding |
        +------------+------------+
        |     0      | [1.2, 3.1] |
        |     1      | [0.1, 4.2] |
        |     2      | [1.0, 3.1] |
        |     3      | [0.3, 2.1] |
        |     4      | [2.2, 1.4] |
        |     5      | [0.7, 1.7] |
        |     6      | [4.1, 2.0] |
        +------------+------------+
        '''
        # Token embedding: vocab_size entries, each mapped to an embed_dim-dimensional vector
        self.token_embeddings = layers.Embedding(
            input_dim=vocab_size, output_dim=embed_dim
        )
        # Position embedding: one embed_dim-dimensional vector per position index
        self.position_embeddings = layers.Embedding(
            input_dim=sequence_length, output_dim=embed_dim
        )
        self.sequence_length = sequence_length
        self.vocab_size = vocab_size
        self.embed_dim = embed_dim

        # sqrt(512) ≈ 22.627, the embedding scaling factor: https://jalammar.github.io/illustrated-transformer/
        self.embed_scale = tf.math.sqrt(tf.cast(embed_dim, tf.float32))

    def call(self, inputs):
        # Caption length, here 24 (the first 24 tokens)
        length = tf.shape(inputs)[-1]

        # Position indices 0 .. length-1
        positions = tf.range(start=0, limit=length, delta=1)

        # Convert token indices to embeddings, shape (N, 24, 512)
        embedded_tokens = self.token_embeddings(inputs)
        # Scale by sqrt(embed_dim) ≈ 22.63
        embedded_tokens = embedded_tokens * self.embed_scale

        # Position embeddings, shape (24, 512)
        embedded_positions = self.position_embeddings(positions)

        # Add them (broadcast over the batch) and return
        return embedded_tokens + embedded_positions




class TransformerDecoderBlock(layers.Layer):

    def __init__(self, embed_dim, ff_dim, num_heads, **kwargs):
        super().__init__()
        self.embed_dim = embed_dim
        self.ff_dim = ff_dim
        self.num_heads = num_heads
        self.attention_1 = layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=embed_dim, dropout=0.1
        )
        self.attention_2 = layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=embed_dim, dropout=0.1
        )
        self.ffn_layer_1 = layers.Dense(ff_dim, activation="relu")
        self.ffn_layer_2 = layers.Dense(embed_dim)

        self.layernorm_1 = layers.LayerNormalization()
        self.layernorm_2 = layers.LayerNormalization()
        self.layernorm_3 = layers.LayerNormalization()

        # Positional embedding
        self.embedding = PositionalEmbedding(
            embed_dim=EMBED_DIM, sequence_length=SEQ_LENGTH, vocab_size=VOCAB_SIZE
        )

        self.out = layers.Dense(VOCAB_SIZE, activation="softmax")

        self.dropout_1 = layers.Dropout(0.3)
        self.dropout_2 = layers.Dropout(0.5)
        self.supports_masking = True

    def call(self, inputs, encoder_outputs, training, mask=None):
        # Apply token + positional embedding: (N, 24) --> (N, 24, 512)
        inputs = self.embedding(inputs)

        '''
        shape: (64, 24, 24)
        64 identical (24, 24) causal masks, e.g.:

        [[1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
         [1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
         [1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
         [1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
         [1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
         [1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
         [1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
         [1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
         [1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
         [1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
         [1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0]
         [1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0]
         [1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0]
         [1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0]
         [1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0]
         [1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0]
         [1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0]
         [1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0]
         [1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0]
         [1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0]
         [1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0]
         [1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0]
         [1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0]
         [1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1]]

        '''
        causal_mask = self.get_causal_attention_mask(inputs)

        '''
        mask (64,24) --> padding_mask (64, 24, 1)

        64 masks of shape (24, 1), e.g.:

        [[1][1][1]...[0][0][0][0][0]]

        '''
        padding_mask = tf.cast(mask[:, :, tf.newaxis], dtype=tf.int32)

        '''
        mask (64,24) --> combined_mask (64, 1, 24)            
        64 masks of shape (1, 24), e.g.:
        [[1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]]

        '''
        combined_mask = tf.cast(mask[:, tf.newaxis, :], dtype=tf.int32)

        '''
        Element-wise minimum of combined_mask and causal_mask, shape (64, 24, 24):
        64 masks of shape (24, 24) that are no longer identical, e.g.:

        [[1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
         [1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
         [1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
         [1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
         [1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
         [1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
         [1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
         [1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
         [1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
         [1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
         [1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
         [1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
         [1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
         [1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
         [1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
         [1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
         [1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
         [1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
         [1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
         [1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
         [1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
         [1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
         [1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
         [1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]]

        '''

        combined_mask = tf.minimum(combined_mask, causal_mask)

        # First (masked) self-attention: Q, K and V are all the decoder inputs; the mask is the causal mask, which forces each position to attend only to the tokens on its left so the model can decode autoregressively.

        attention_output_1 = self.attention_1(
            query=inputs,
            value=inputs,
            key=inputs,
            attention_mask=combined_mask,
            training=training,
        )
        out_1 = self.layernorm_1(inputs + attention_output_1)

        # Cross-attention: K and V come from the encoder output, Q comes from the previous decoder attention output; the mask is the padding mask, which hides the padded positions of the 25-token caption.
        attention_output_2 = self.attention_2(
            query=out_1,
            value=encoder_outputs,
            key=encoder_outputs,
            attention_mask=padding_mask,
            training=training,
        )
        out_2 = self.layernorm_2(out_1 + attention_output_2)

        ffn_out = self.ffn_layer_1(out_2)
        ffn_out = self.dropout_1(ffn_out, training=training)
        ffn_out = self.ffn_layer_2(ffn_out)

        ffn_out = self.layernorm_3(ffn_out + out_2, training=training)
        ffn_out = self.dropout_2(ffn_out, training=training)

        # The final output is a VOCAB_SIZE-sized vector per position, holding the probability of each token; the index of the largest value can be looked up to recover the word.
        preds = self.out(ffn_out)
        return preds

    def get_causal_attention_mask(self, inputs):
        '''
        Build the causal (autoregressive) attention mask.
        '''
        # Input shape (N, 24, 512)
        input_shape = tf.shape(inputs)
        # N and 24 respectively
        batch_size, sequence_length = input_shape[0], input_shape[1]

        # Indices 0..23 as a (24, 1) column
        i = tf.range(sequence_length)[:, tf.newaxis]
        # Indices 0..23 as a row
        j = tf.range(sequence_length)
        mask = tf.cast(i >= j, dtype="int32")

        # Shape (1, 24, 24)
        mask = tf.reshape(mask, (1, input_shape[1], input_shape[1]))

        scale = tf.concat(
            [tf.expand_dims(batch_size, -1), tf.constant([1, 1], dtype=tf.int32)],
            axis=0,
        )
        # Tile (1, 24, 24) to (batch_size, 24, 24)
        result = tf.tile(mask, scale)

        return result
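
    # A minimal standalone sketch of the same construction for sequence_length = 4
    # (illustrative only, not part of the model):
    #   i = tf.range(4)[:, tf.newaxis]      # shape (4, 1)
    #   j = tf.range(4)                     # shape (4,)
    #   tf.cast(i >= j, dtype="int32")      # lower-triangular matrix:
    #   [[1 0 0 0]
    #    [1 1 0 0]
    #    [1 1 1 0]
    #    [1 1 1 1]]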

def get_cnn_model():
    # CNN feature extractor
    base_model = efficientnet.EfficientNetB0(
        input_shape=(*IMAGE_SIZE, 3), include_top=False, weights="imagenet",
    )
    # Freeze the feature-extraction layers
    base_model.trainable = False
    base_model_out = base_model.output
    base_model_out = layers.Reshape((-1, base_model_out.shape[-1]))(base_model_out)
    cnn_model = keras.models.Model(base_model.input, base_model_out)
    return cnn_model
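
# Note (assuming EfficientNetB0's usual output stride of 32): a 299x299 input
# produces a 10x10x1280 feature map, so the Reshape above turns it into a
# sequence of 100 "image tokens" of dimension 1280, which the encoder later
# projects down to EMBED_DIM = 512.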


image_augmentation = keras.Sequential(
    [
        layers.experimental.preprocessing.RandomFlip("horizontal"),
        layers.experimental.preprocessing.RandomRotation(0.2),
        layers.experimental.preprocessing.RandomContrast(0.3),
    ]
)


cnn_model = get_cnn_model()
encoder = TransformerEncoderBlock(embed_dim=EMBED_DIM, dense_dim=FF_DIM, num_heads=1)
decoder = TransformerDecoderBlock(embed_dim=EMBED_DIM, ff_dim=FF_DIM, num_heads=2)
strip_chars = "!\"#$%&'()*+,-./:;<=>?@[\]^_`{|}~"
strip_chars = strip_chars.replace("<", "")
strip_chars = strip_chars.replace(">", "")
# train_data, valid_data = train_val_split(captions_mapping)
def custom_standardization(input_string):
    # Lowercase everything
    lowercase = tf.strings.lower(input_string)
    return tf.strings.regex_replace(lowercase, "[%s]" % re.escape(strip_chars), "")
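
# Example (illustrative): custom_standardization(tf.constant("<Start> 一个 男人, 在 跑步!"))
# yields the string tensor "<start> 一个 男人 在 跑步": lowercased, with the punctuation
# listed in strip_chars removed, while "<" and ">" are preserved so the <start>/<end>
# markers survive vectorization.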

token_len = []
def load_captions_json(filename):
    caption_mapping = {}
    text_data = []
    images_to_skip = set()

    with open(filename, encoding="utf-8") as f:  # the annotation file is UTF-8 JSON
        # Read the JSON file
        json_data = json.load(f)
        # Iterate over all items
        for item in tqdm.tqdm(json_data):
            # Image file name
            img_name = item['image_id']
            img_name = os.path.join(IMAGES_PATH, img_name.strip())
            # Iterate over the 5 captions
            for caption in item['caption']:

                # Tokenize with jieba
                tokens = [word for word in jieba.cut(caption)]

                # Rebuild the caption with spaces between tokens
                caption = " ".join(tokens)

                # Record the token count
                token_len.append(len(tokens))

                if len(tokens) < 3 or len(tokens) > SEQ_LENGTH:
                    images_to_skip.add(img_name)
                    continue

                # If the file name ends with jpg and the image is not flagged to skip
                if img_name.endswith("jpg") and img_name not in images_to_skip:
                    # Add the start and end tokens
                    caption = "<start> " + caption.strip() + " <end>"
                    text_data.append(caption)

                    if img_name in caption_mapping:
                        # Append to the existing list
                        caption_mapping[img_name].append(caption)
                    else:
                        # Initialize the list
                        caption_mapping[img_name] = [caption]

        # Remove every image flagged in images_to_skip from caption_mapping
        for img_name in images_to_skip:
            if img_name in caption_mapping:
                del caption_mapping[img_name]


        return caption_mapping, text_data
captions_mapping, text_data = load_captions_json(r"D:\V1.0\code\image_caption\ai_challenger_caption_validation_20170910\caption_validation_annotations_20170910.json")  # raw string so backslashes are not parsed as escapes

vectorization = TextVectorization(
    max_tokens=VOCAB_SIZE,
    output_mode="int",
    output_sequence_length=SEQ_LENGTH,
    standardize=custom_standardization,
)
vectorization.adapt(text_data)
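
# Example (illustrative): after adapt(), vectorization(["<start> 一个 男人 <end>"]) returns
# an int tensor of shape (1, SEQ_LENGTH): the ids of "<start>", "一个", "男人", "<end>"
# followed by zero padding, since captions are padded/truncated to SEQ_LENGTH = 25;
# out-of-vocabulary tokens map to the [UNK] id (1) and 0 is reserved for padding.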
# valid_dataset = make_dataset(list(valid_data.keys()), list(valid_data.values()))



caption_model = ImageCaptioningModel(
    cnn_model=cnn_model, encoder=encoder, decoder=decoder, image_aug=image_augmentation,
)
load_status = caption_model.load_weights("image_caption/my_model/checkpoint")
vocab = vectorization.get_vocabulary()
index_lookup = dict(zip(range(len(vocab)), vocab))
max_decoded_sentence_length = SEQ_LENGTH - 1
# valid_images = list(valid_data.keys())
# valid_caption = list(valid_data.values())
# # valid_len = len(valid_images)


def decode_and_resize(img_path):
    # Read the image and resize it
    img = tf.io.read_file(img_path)
    img = tf.image.decode_jpeg(img, channels=3)
    img = tf.image.resize(img, IMAGE_SIZE)
    img = tf.image.convert_image_dtype(img, tf.float32)
    return img
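
# Note: tf.image.resize already returns float32 values in the original 0-255 range,
# and tf.image.convert_image_dtype does not rescale float-to-float inputs, so the
# tensor returned above keeps 0-255 values; that is why generate_caption() below
# clips it to [0, 255] and casts to uint8 before displaying it.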



import random


def generate_caption():
    # Pick a random image from the validation set (disabled, kept for reference)
    # random_index = random.randrange(0, valid_len)
    # sample_img = valid_images[random_index]
    # sample_img = r"C:\Users\admin\Desktop\9.virtual reader\image_caption\ai_challenger_caption_validation_20170910\caption_validation_images_20170910\0a1bd8c4a0bdeacb4142138b471bd09ec739f20f.jpg"

    # sample_img=cv2.imread(sample_img)
    # img=tf.convert_to_tensor(sample_img, dtype=tf.uint8)
    # img = tf.image.resize(sample_img, IMAGE_SIZE)
    # img = tf.image.convert_image_dtype(img, tf.float32)

    # img_show=cv2.resize(sample_img,(299,299))
    # img_show= tf.image.convert_image_dtype(sample_img, tf.float32)
    # sample_caption = valid_caption[random_index][0]
    # Read the image
    sample_img = r"D:\V1.0\code\image_caption\image_caption.jpg"  # raw string avoids backslash escapes

    # Check that the file exists
    if not os.path.exists(sample_img):
        print(f"Error: file not found - {sample_img}")
        # List the directory contents to help debugging
        dir_path = os.path.dirname(sample_img)
        print(f"Directory contents: {os.listdir(dir_path)}")
        return "Unable to generate a caption"
    # sample_img=r"C:\Users\admin\Desktop\9.virtual reader\image_caption\ai_challenger_caption_validation_20170910\caption_validation_images_20170910\0a1bd8c4a0bdeacb4142138b471bd09ec739f20f.jpg"
    # sample_img = r"C:\Users\admin\Desktop\9.virtual reader\image_caption\ai_challenger_caption_validation_20170910\caption_validation_images_20170910\0a6cb526ac4fc835f2bde94cff56d568c0158fba.jpg"
    # sample_img = r"C:\Users\admin\Desktop\9.virtual reader\image_caption\ai_challenger_caption_validation_20170910\caption_validation_images_20170910\0a02cfeb05ad00160daee682f655eb04b92d9199.jpg"
    # sample_img = r"C:\Users\admin\Documents\Tencent Files\3105625151\FileRecv\MobileFile\demo.jpg"
    sample_img = decode_and_resize(sample_img)
    img_show = sample_img.numpy().clip(0, 255).astype(np.uint8)
    # print(type(img_show))
    # img_show=cv2.imread(r"C:\Users\admin\Desktop\9.virtual reader\image_caption\ai_challenger_caption_validation_20170910\caption_validation_images_20170910\0a2a6dbc1c6646510907e980d746bfeb97a0cf6c.jpg")

    plt.imshow(img_show)
    plt.axis('off')
    plt.show()

    # Save the displayed image (optional)
    # cv2.imwrite('./img/raw.jpg', cv2.cvtColor(img_show, cv2.COLOR_RGB2BGR))
    # Extract the CNN features
    img = tf.expand_dims(sample_img, 0)
    img = caption_model.cnn_model(img)

    # Pass them to the encoder
    encoded_img = caption_model.encoder(img, training=False)

    # 1. Start from the prompt "<start> "
    # 2. Run the decoder to predict the next token
    # 3. Keep feeding the growing caption back into the model until <end> appears
    # 4. Also stop once the loop exceeds the maximum sentence length
    decoded_caption = "<start> "
    for i in range(max_decoded_sentence_length):
        tokenized_caption = vectorization([decoded_caption])[:, :-1]
        mask = tf.math.not_equal(tokenized_caption, 0)

        # Predict the next token
        predictions = caption_model.decoder(
            tokenized_caption, encoded_img, training=False, mask=mask
        )
        sampled_token_index = np.argmax(predictions[0, i, :])
        sampled_token = index_lookup[sampled_token_index]
        if sampled_token == "<end>":  # vocabulary tokens never carry a leading space
            break
        decoded_caption += " " + sampled_token
        # decoded_caption += sampled_token

    decoded_caption = decoded_caption.replace("<start> ", "")
    decoded_caption = decoded_caption.replace(" <end>", "").strip()
    #
    # sample_caption = sample_caption.replace("<start> ", "")
    # sample_caption = sample_caption.replace(" <end>", "").strip()
    print(decoded_caption)
    return decoded_caption
    # print('Ground truth:', sample_caption)
# generate_caption()
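
# Minimal usage sketch (assuming the checkpoint above was restored successfully and
# the sample image path exists):
# if __name__ == "__main__":
#     caption = generate_caption()
#     print("Predicted caption:", caption)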

 
