[论文阅读] LANIT@Language-Driven Image-to-Image Translation for Unlabeled Data

pre

title: LANIT: Language-Driven Image-to-Image Translation for Unlabeled Data
accepted: arxiv
paper: arxiv
code: https://github.com/KU-CVLAB/LANIT

亮点:借助CLIP,可学习prompt

针对问题

图像翻译I2I里的两大关键问题:

  1. 严重依赖样本级的域标注,如每张图都必须有标签
  2. 缺乏处理多属性图片的能力,一张人脸图像可能具有金发、碧眼等属性

第一个问题去年的TUNIT实际上解决了,但作者在下面提到,那种聚类来的标签没有明确语义,不够令人满意。

相关研究

图1 不同级别的监督

a是传统方式,需要样本级域监督;b是无监督,但性能跟语义理解有限;c是本文的方式,需要有可能的文本域描述的数据集,即数据集级标注

以前的方法都假设每个样本都有一个one-hot的域标签,指的应该是每个样本按照单一属性分类,如人、猫、狗等
作者就举例人脸会有许多属性

近来有些办法引入few-shot,或半监督学习,比如 TUNIT 就使用统一模型以完全无监督的方式联合学习域聚类(把未标注的数据拿去聚类,用一个网络预测给定图像的域标签,生成器判别器再根据预测的标签进行学习),但由于标签是聚类来的,不具有明确的语义,作者认为这样会阻碍方法的应用,同时这些方法还都是one-hot的域标签。

CLIP被应用于许多新模型中,取得了很好的效果,特别是图片合成跟编辑方面,但I2I领域还没用到。

大规模与训练语言模型如GPT-3成功后,不少改进prompt的方法被提出,也有一些工作使用CLIP的同时还优化输入的prompt,如CoOp跟CPT,它们优化连续的prompt,比手动设计的离散prompt在zero-shot图片分类上好很多。但作者认为它们需要类别监督class supervision。

核心思想

先给定一系列关于数据集的文本描述,作为dataset-level的域标签,借助CLIP计算它们跟未标注图片的相似度,进而通过top-k选出multi-hot label,汇聚风格后送入生成器生成转换后的图片。

方法(模型)流程

Overview

图2 网络格局

模型由内容编码器C,风格编码器S,映射编码器M和生成器G构成,对于N个域里的第n个域,给定一系列可能的域描述作为prompt \(t_n\),然后用内容x跟风格y两张图片生成转换后的图片\(\hat{y}\),具体来说应该是用内容向量跟汇聚后的风格向量。为了得到y的伪域标签\(d^y\),利用CLIP之类的大规模视觉语言模型,它有视觉跟语言两个编码器。

reference-guided阶段,首先风格图通过V编码得到\(v^y\),候选域prompts(图2左上角那N个短语)通过L编码得到\(u^t\),然后二者计算相似度:

\[f^{y,t}_n = \bar{v}^y \cdot \bar{u}_n^t \tag{1} \]

上面加一横表示取平均

再通过top-k得到multi-hot伪域标签,应该是取最相似的k个元素记1,其余记0,转换为0-1的multi-hot向量即\(d^y\)

之后风格向量 \(s_n^y ,\space n=1...N\) 根据 \(d^y\) 汇聚成风格向量\(a^y\),之后\(a^y\)再跟内容特征\(c^x = C(x)\)一起送入生成器去生成\(\hat{y} = G(c^x,\; a^y)\)

作者提到这里的\(s_n^y\)也可以用\(\tilde{s} = \mathcal{E}_{M}(\mathbf{z})\)代替,它来自随机的隐向量z,感觉有点类似StyleGAN那个映射网络。(\(\mathcal{E}_{M}\) 就等于 上文说的 M,只是旁边那个符号太麻烦默认省略)

训练阶段将prompt \(t_n\) 跟图片转换器\(C, S, M, G\),即图2淡灰色框那些一起联合优化,而CLIP那俩编码器V、L保持固定。

图里没有判别器,但这里提到用多域判别器结合multi-hot域标签计算\(\hat{y}\)跟风格图y之间的对抗损失,也加上了循环一致性、风格重建和风格多样性损失,还有域一致性损失域多样性损失,明明连公式也没有写出来,莫名其妙。

Language-Driven Domain Labeling

使用文本域描述(textual domain description)作为prompt \(t_n\),用人工标注的域描述或关键词初始化,也就是dataset-level监督,然后在训练中微调。这个吊\(t_n\)说的不清不楚到底是一个词还是一系列的词,也不跟图保持一致,各种符号瞎jb乱来。

提及抽取的语言特征\(u^t\)来自所有的prompt \(t \in \mathbb{R}^{N\times t}\),所以这个t代表了所有的prompt的封装?,详见下方Setting Up Domain Prompt分析。

Image Translation with Pseudo Domain Label

作者说S预测的那些\(s_n^y\)现存方法都是根据ground-truth挑一个one-hot域标签,而本文则是将其汇聚起来

\[\mathbf{a}^{\mathbf{y}}={\frac{1}{M^{\mathbf{y}}}}\sum_{n=1}^{N}{\mathbf{s}}_{n}^{\mathbf{y}}d_{n}^{\mathbf{y}}, \tag{2} \]

这里\(M^y\)\(d^y\)中非零值的数量,其实就是根据伪域标签加权平均...

Setting Up Domain Prompt

本框架高度依赖prompt \(t_n\),如果不合适就会对性能造成影响, 因此提出prompt学习网络,见图3

图3 使用域一致性损失进行prompt学习

风格图跟输出分别得到域标签,然后计算损失,优化损失就可以学到最优的prompt,同时学习转换模块

所有域共享的可学习连续向量首先被定义为template,每个域一个可学习向量,用域描述初始化,然后在训练中微调,prompt \(t_n\)定义如下:

\[\mathbf{t}_{n} = [p_{1},p_{2},\ldots,p_{L},p_{n}^{\mathrm{domain}}], \tag{3} \]

其中\(p_l,\; l \in \{1, \ldots, L \}\)是template中一个词的向量,维度跟\(p_{n}^{domain}\)一样,这L从后文可以知道是template tokens的数量,其实就是这里p的数量,总之\(p_{n}^{domain}\)是第n个文本域描述的向量。

从附录可以知道template就是 "a face of, a photo of anime with, a car painted with..." 这种,应该是给CLIP作提示的,这样一想 \(t_n\)(即prompt) 应该就是对一个域描述(短语)的包装,代表一种域,比如 "a face of smiling" 这样,由提示template跟描述domain两部分组成。

有N个候选域应该就要有N个 prompt,学习时就更新身为变量的prompt,使之能更准确地描述一个域。

既然是学习,那么 template: "a face of", domain: "smiling" 大概都只是初始化才用,后面prompt自己学完向量长度固定,可能就没有明确对应的文本(如 a face of, smiling 这样)。

为了促进收敛,所有\(p_l\; and\; p_{n}^{domain}\)都用template跟给定的文本域描述初始化,如"a face with"跟"black hair",然后在训练时微调。这sb作者这么喜欢两个事情用 and + respectively 混在一起描述。而且还是没懂,这么多\(p_l\)向量都是 "a face with" 这种模板??\(t_n\)也没在图里体现,可能得去看代码才知道。

输入域描述应当被给出,但也可以用预定义的词典获得,比如根据图像与候选文本的相似性,可以选择与数据集相关度高的文本。

同时假设给定的域描述必须忠实代表每个样本,否则会造成伪标签不准确。因此加上一个空集域\(\emptyset\)来处理未知或不确定的域,并附加数据集级描述如"food", "face",总之含义更加宽泛。

Loss Functions

Adversarial Loss

\[\mathcal{L}_{adv} = {\mathbb{E}_{\bf{x},\bf{y}}\sum_{n=1}^{N}\left[\log D_{n}({\bf{y}})d_{n}^{{\bf{y}}}+\log(1-D_{n}({\cal{G}}({\bf{x}},{\bf{a}}^{\bf{y}}))d_{n}^{{\bf{y}}})\right]}, \tag{4} \]

经典对抗损失。D是多域判别器,只是输出要用\(d^y\)加权,\(D_{n}(\cdot)\)表示第n个判别器输出,\(d^y_n\)同理

Domain-Consistency Loss

要约束\(\hat{y}\)忠实按照风格图y的multi-hot域标签生成,容易想到\(\mathcal{L} = f^{y}\cdot f^{\hat{y}}\),但作者举反例,应该是说所有域用同一个prompt(constant prompt at all domains)

\[\mathcal{L}_{\mathrm{dc}}=\mathcal{H}(\mathbf{d}^{\mathrm{y}},\mathrm{f}^{\hat{\mathrm{y}}})+\mathcal{H}(\mathbf{d}^{\hat{\mathrm{y}}},\mathrm{f}^{\mathrm{y}}), \tag{5} \]

其中\(\mathcal{H}(\cdot)\)表示交叉熵,处理multi-hot标签。第一项只训练prompt,第二项固定当前prompt训练剩余的translation模块,有助于稳定收敛跟局部最小值问题的解决。f不确定是什么,应该是之前提到的相似度吧?

Cycle-Consistency Loss

其中\(c^{\hat{y}}\)表示来自\(\hat{y}\)的内容,\(a^x\)表示x的风格向量,经典的循环一致性损失

\[\mathcal{L}_{cyc} = \mathbb{E}_{x,y}\left[\left|\left|x - \mathcal{G}(c^{\hat{y}}, a^{x})\right|\right|_{1}\right], \tag{6} \]

Style Reconstruction Loss

\[\mathcal{L}_{\mathrm{sty}} = \mathbb{E}_{\mathbf{x,y}}[||s^{\mathbf{y}}-\mathcal{E}_{S}(\hat{\mathbf{y}})||_1]. \tag{7} \]

提取生成图片的风格跟风格图的风格计算L1损失,很直接的方式

Overall Objective

\[\mathcal{L}_{\mathrm{total}}=\lambda_{\mathrm{adv}}\mathcal{L}_{\mathrm{adv}}+\lambda_{\mathrm{dc}}\mathcal{L}_{\mathrm{dc}}+\lambda_{\mathrm{cyc}}\mathcal{L}_{\mathrm{cyc}}+\lambda_{\mathrm{sty}}\mathcal{L}_{\mathrm{sty}}, \tag{8} \]

其中那些\(\lambda\)都是超参数

实验

template tokens的数量,L,设置为4

数据集和指标

五个标准数据集: Animal Faces-10, Food-10, CelebA-HQ, FFHQ, 和 LHQ,跟一些细节...

四个定量指标:mFID,Density and Coverage (D&C) ,Acc (该不会D&C算两个吧)

实验结果

表1 CelebA-HQ上的定量比较

表2 Animal Faces-10 和 Food-10上的定量比较

图5 对应表1的定性比较,LANIT效果好,甚至跟有监督的StarGAN2对比也是

图6 风格指导的Reference-guided图片转换结果,注意LANIT比其他模型反映了更多的属性(胡须、刘海、微笑等)

图7 隐向量指导的Latent-guided多样图片合成结果,每张图片上面的是对应的域文本描述

风格指导感觉应该是拿风格图作为输入,提取风格向量。而隐向量指导应该是用目标域标签(或域文本)描述作输入,经过映射网络跟风格编码器再得到风格向量。
而且从图7可以看出一个域描述由多个词构成,但学习的时候应该是所有词混在一起,计算相似度再计算multi-hot标签

Ablation Study and Analysis

Number of Domain Descriptions N
见表3,研究候选域的数量N的影响,就是用于计算相似度的短语数量。

表3 更变域数量,LANIT跟TUNIT比较

图8 每个数据集上不同数量域描述的分类比较

不同的N对本模型的影响小于TUNIT,较鲁棒。候选域多一些性能好,但计算量也大。

∅ Domain
表1(III)就研究过,用上能大幅超越baseline。图8也能说明空集域带来的性能提升。

Prompt Learning
图1(IV)可说明用上prompt学习可以通过提炼prompt减少标签噪声,提高聚类效果

Number of Activated Attributes K
激活属性数量k,应该是top-k的那个k,表1可以看出top-3超越top-1表现,意思是用多个关键词/文本描述可以更准确地描述图片,先前的工作都假设一张图只有一个label。同时也能为Animal Faces-10 和 Food-10只选取1个label,因为物种跟食物类型无法描述为多属性,体现了本框架的灵活性。

User Study
图9,略,反正就是证明LANIT这模型好

总结

本模型LANIT利用数据集级监督处理未标注数据,既避免逐样本标注,对一张图也可以考虑multi-hot域标签。因为用户提供的域描述可能无法覆盖整个数据集,引入空集来过滤数据。同时引入可学习prompt,联合学习prompt跟图像转换模块。
LANIT性能类似或超越了现存需要样本级监督(要学习summer-winter转换就得有两张图片分别是summer跟winter)的I2I模型
将来会将LANIT用在其他任务上,如domain adaptation 和 domain generalization

也是一篇利用CLIP的工作,角度比较巧妙,但行文糟糕,一遍看下来还是没能搞明白那些prompt、domain、description、text、word互相都是什么关系,模型图没包括D跟M,而且过于简洁。prompt learning那更是糟糕,图跟没有差不多。

待解明

  1. prompt \(t_n\)到底是什么
  2. Domain-Consistency Loss为什么用\(d^y, \; f^y\)计算
posted @ 2022-10-30 11:39  NoNoe  阅读(327)  评论(0编辑  收藏  举报