style-transfer
《计算机图形学》期末大作业,代码在github
复现
先执行bash train_all.sh
进行模型微调,在../styles下生成bin文件,
之后运行python run_all.py
在当前code文件夹下生成结果图片,
最后运行python changePic.py
进行部分图片的替换。
原理
整体框架:DreamBooth-Lora
我们的整体框架依旧是Dreambooth-Lora的结构。将参考图片作为输入微调基底模型,生成给定物体在特定风格下的图片。微调时prompt统一采取物体+in style xx的结构,使得模型将一组图片的风格与prompt中的style xx绑定,生成时改变物体的prompt即可。
使用Lora去微调stablediffusion里的UNet结构。Lora将注意力层的矩阵分解为两个较小低秩的矩阵,不对微调结果产生较大影响的同时,显著降低了训练的计算量,以及存储的空间。
创新点
1.超参数调整
我们调整了代码里的各个参数,探究其对生成质量的影响,调节的参数有lora_rank,train_steps,learning_rate_scheduler以及gradient_accumulation_steps。
- lora_rank:
当秩较低时,模型更新的自由度较低,微调过程只能在一个较小的空间内调整参数。这可以减少过拟合的风险,并且计算成本较低,但可能无法捕捉到更复杂的风格信息。相对地,当秩较高时,模型有更大的自由度进行微调,能够捕捉更复杂的风格信息。然而,这也意味着计算成本的增加,并且如果秩过高,可能会导致过拟合。实验中,我们尝试了rank为8,16,32,64的情况,最终选择rank=32。
- train_steps:train_steps更多意味着模型可以充分学习参考图像的风格,从而生成更好的图像,但过高也会导致过拟合问题,导致模型不遵从文本的现象发生。
- learning_rate_scheduler:我们使用了cosine学习率调度器,有以下几个优势:
- 平滑的学习率变化:余弦退火学习率调度器提供了一种平滑的学习率降低方式,避免了学习率突然变化可能导致的训练不稳定。
- 避免局部最小值:通过在训练过程中周期性地增加学习率,模型有机会跳出当前所在的局部最小值,从而可能找到更好的解决方案。
- gradient_accumulation_steps:梯度累积的基本思想是在更新模型参数之前,先累积多个小批量数据的梯度。这样做可以使用更大的批量大小(batch
size)而不会遇到内存不足的问题。但实验中我们设置gradient_accumulation_steps=10,期望着可以平滑梯度,一次学到参考图像的共性(风格),而减少对特性(内容)的学习,实践中发现这会显著提升训练时间,效果没有显著提升,因此最终依旧设置gradient_accumulation_steps=10。
#### 2.prompt调优
在生成prompt过程中,由于数据量对人工来说比较大,我们使用了openai大模型的api识别图片批量生成prompt,详见generate_all.py。
##### prompt设计 在baseline的训练中,图片的prompt只有style
xx,我们额外加入内容参考,使其成为物体+in style
xx结构,这样在训练时模型能够结合先验知识区分图片的内容和风格,更容易学到风格特征。
后续,观察到某些风格模型不易学习,我们使用大模型识图主动将一些风格特征喂给模型,降低模型学习的难度。 ##### prompt对点调优 由于通用的prompt对于个别点表现比较差,我们可以人工查看某些点在特定prompt下的生成结果,调整prompt,使得生成结果更加符合要求,详见run_specific.py。在对点调优过程中,我们有如下发现: 由于Dreambooth的风格迁移本质上是将物体与style xx绑定在物体+in style xx的语义下,因此调整物体部分的prompt显著影响生成效果,例如,可以提供更精确的描述帮助理解内容,或者提供粗略的风格帮助理解风格。然而,上述绑定存在问题:由于是基于语义的绑定,因此物体内容和风格的区分并不够好,具体表现为:
- 不方便有效控制主次:物体部分的prompt过长,或有特征明显的prompt,会导致风格退化或内容错误。
- 前者过于复杂的要求使得模型难以聚焦到要求的主体物体,导致生成不明内容的色块。
- 后者例如风格00以梵高油画风格为粗略的风格帮助理解风格,但梵高特征过于明显,导致prompt较短时错误生成梵高自画像;风格07以黄橙色火焰拼贴画风格为粗略的风格,但生成物体Strawberry时,由于其鲜明的外形与颜色,导致生成图片颜色为玫红色且火焰外形不明显。
针对以上问题,我们提出了以下策略: - 同一个prompt,同义换一换语序可能方便理解 - 太长的prompt不work,试着删一些过多的要求 - 不方便理解的词,用近义词替代或用简单词汇解释含义 - 避免使用否定词,如没有盛放物品的盘子,模型难以理解
3.数据增强
对图片进行随机翻转
由于样本量不足,训练轮数提高时容易产生过拟合现象,因此我们尝试进行一些简单的数据增强。在训练时,我们对参考图片进行随机水平或竖直翻转,以提高模型的泛化能力。
对图片进行裁剪
最初的考虑是图片的风格与内容无关,只是一种纹理,只学图片的一部分也可以学到,在参考图片1024*1024的情况下,尝试将其裁剪为4张512*512的图片进行训练。这会导致生成图片不完整,因为参考图片本身不完整。另一方面,查看具体参考图片后,风格不单单是物体的纹理,对物体的位置形状等也有一些要求,所以不能喂给模型不完整的图片。
不完整的图片如下
4.强化学习:迭代训练
在实验过程中,发现生成图片的质量与prompt高度相关,对于一些常见词,如dog、flower等,很容易生成高质量的图片,但对于一些比较生僻的词,如museum,生成的图片质量较差。由于数据集受到限制,我们希望模型自我学习,即按照强化学习的思路从生成结果中学习好的生成结果,改善坏的生成结果。因此,我们设计了一种迭代训练
的框架:
-
采样阶段:批量推理得到样本。该过程可以通过调整prompt、seed、超参数等方式获取多种分布下的样本。
- 人工标注阶段:人工评估样本质量,并给出打分。 -
迭代训练阶段:将人工评分超过阈值的样本加入训练用的数据集,重新训练模型。重复上述过程,直到满意为止。
然而在实验过程中,发现迭代训练效果不理想,出现了以下问题: -
分辨率降低,可能是由于训练用的参考图片是
论文主页Direct Preference for Denoising Diffusion Policy Optimization (D3PO)
论文地址Using Human Feedback to Fine-tune Diffusion Models without Any Reward Model
在采用了上述的迭代训练发现效果不理想后,我们认识到了这种“原始”的强化学习的弊端,尝试使用更加成熟的深度强化学习框架。这篇被CVPR2024收录的论文引入了直接偏好去噪扩散策略优化(Direct
Preference for Denoising Diffusion Policy
Optimization,D3PO)方法来直接微调扩散模型。其基本思路是通过比较img
pair,每个prompt生成两张图,让模型向rm分数较好的图像w学习,并远离rm分数较低的图像。整体流程如下:
-
采样阶段:在推理模式下,使用同一prompt生成两张图,并采集prompt、各时间步必要数据。
- 人类反馈阶段:人类标注两张图的rm分数。 -
计算必要数据:根据采样数据与人类标注,计算各时间步的特定数据,为优化做准备。
- 训练优化阶段:使用图中公式计算D3PO loss,更新模型参数,实现优化。
我们选用D3PO,有以下重要原因: - 无需reward model:不仅无需对奖励模型进行训练,因此更直接、更经济、计算开销最小;更重要的是符合比赛要求,不引入其他模型。 - 良好的兼容性:作者在github的issue中指出D3PO兼容我们的框架。
因此,我们试图参考D3PO的github实现(基于PyTorch),在jittor框架下实现D3PO。然而,遇到了严重问题:实现了采样器sample.py后,发现PyTorch下可以保存的采样数据在jittor下会报错。由于无法解决,因此没有继续实现后续的update.py训练器,最终遗憾放弃。同时,由于jittor框架下的pipeline与PyTorch的pipeline有一定差异而不方便采样,因此我改造了原有的pipeline以获取采样数据,详见上传的pipeline_stable_diffusion_jittor.py。
6.扩散模型优化:修正噪声采样均值方差
论文主页Diffusion in Style (ivrl.github.io)
CVF论文地址ICCV 2023 Open Access Repository (thecvf.com)
通过实现发现,在Stable-diffusion中,初始潜张量影响生成的图像的风格和布局。使用相同的初始潜张量和不同的文本提示生成的图像通常会导致具有共享属性的图像,例如相似的颜色,亮度和对象定位。因此我们尝试着使用风格相关的初始潜张量开始去噪过程。我们通过简单地估计一小组目标风格图像的潜在编码的元素方式的平均值和标准差,获得初始潜张量的风格特定的分布。后续采样噪声的均值标准差也与初始噪声一致
- 实现一:根据参考图像的潜在向量来估计高斯分布的均值与方差,即
- 实现二:与实现一类似,根据论文,更适合对于少量参考图片的情况。
实现后发现效果不理想,有模糊现象,如下。
可能因为参考图像过少,也可能因为均值和方差取得不是很好(只简单的考虑了原本图片的潜向量),具体原因以及改进有待后续进一步研究。 #### 7.扩散模型优化:有偏移的初始噪声
参考博客
一般的扩散模型总是生成平均值接近0.5的图像,大多数情况下,这些图像仍然是合理的。但是,这种平均值趋近于0.5的软约束可能导致图像显得淡化、亮雾区域平衡其他暗区域、高频纹理(在标志中)而不是空白区域、灰色背景而不是白色或黑色等。
在代码中,目前的训练循环使用的噪声如下:
noise = torch.randn_like(latents)
取而代之改成如下方法
noise = torch.randn_like(latents) + 0.1 * torch.randn(latents.shape[0], latents.shape[1], 1, 1)
使用偏移噪声会展示更丰富的黑色调。 #### 8.扩散模型优化:FreeU
论文主页FreeU : Free Lunch in Diffusion U-Net
论文地址FreeU: Free Lunch in Diffusion U-Net
这篇CVPR2024
Oral论文提出了新的方法优化了扩散模型。扩散模型由扩散过程和去噪过程组成。在扩散过程中,高斯噪声逐渐添加到输入数据中,最终将其破坏为近似纯高斯噪声。在去噪过程中,通过可学习的逆扩散操作,从噪声逐渐恢复为原始输入数据。
作者发现低频分量体现了图像的全局结构,高频分量包含图像的边缘和纹理信息。而在UNet结构中,backbone对应于低频分量与去噪能力,而skip
connection对应于高频分量。然而,在推理阶段Backbone
固有的去噪能力减弱了。因此,作者提出了FreeU方法,通过标量因子s衰减高频分量,标量因子b增强低频分量。结果是使用FreeU方法推理的图片质量更高,语义对齐和细节更好。
因此我们尝试也在比赛中使用本方法。本以为只需在推理部分加入一行代码即可使用pipeline内置的FreeU方法:
1
pipe.enable_freeu(s1=0.9, s2=0.2, b1=1.4, b2=1.6)
经过实验,发现FreeU方法在本项目的表现并不尽如人意。典型情况如下:
与未使用FreeU方法的结果相比,虽然使用FreeU方法使图像细节、结构更好(如示例的狮子毛发细节与狮头结构),但在风格上却没能充分学习风格(该示例来自风格00,为梵高油画风格,而示例过于写实,并表现出原本的黄色而非风格下的蓝棕色)。提交后发现各项得分与上述分析结论相同,最终得分(综合风格、质量与语义)甚至有所下降。
分析原因,可能是因为本挑战赛赛题要求是风格迁移,而风格可以认为是风格化结构+纹理+颜色。而第一项“风格化”结构本身要求与正常结构相比有一定变形,而后两项更是集中于高频分量。因此,FreeU方法衰减高频分量,可能容易破坏风格。
9.建立优质图库
我们注意到,提供的用于训练的参考图像质量很高,而要求生成的图像中,部分prompt恰好来自参考图像。从工程角度考虑,最终目标是带给用户更好的体验,因此不妨将参考图像组织为优质图库,覆盖常见的prompt,当用户恰好生成这些prompt时,直接从图库中提取并优先展示,可以加快速度、提高质量,从而提升用户体验。注意要将图库中的图的分辨率调整到用户所需后使用,具体代码见changePic.py。
## 代码逻辑
由于我们的代码是基于baseline的,因此baseline中没有涉及到代码逻辑修改的run_all.py等文件的逻辑就不再赘述了。下面介绍与我们的创新点相对应的代码逻辑。
### torch_utils.py
如果要使用FreeU方法,将diffusers库中diffusers/utils/torch_utils.py替换为上传的torch_utils.py,然后运行推理时加入pipe.enable_freeu(s1=0.9, s2=0.2, b1=1.4, b2=1.6)
。二者区别为原版使用torch.fft,在jittor框架下不能运行;新版使用numpy模拟torch.fft,并利用jittor中的复数与Var类型封装:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31def fourier_filter(x_in: "torch.Tensor", threshold: int, scale: int) -> "torch.Tensor":
"""Fourier filter as introduced in FreeU (https://arxiv.org/abs/2309.11497).
This version of the method comes from here:
https://github.com/huggingface/diffusers/pull/5164#issuecomment-1732638706
"""
x = x_in
B, C, H, W = x.shape
# Non-power of 2 images must be float32
if (W & (W - 1)) != 0 or (H & (H - 1)) != 0:
x = x.to(dtype=torch.float32)
# FFT
x_freq = np.fft.fftn(x,axes=(-2, -1)) # x.numpy->less style
x_freq = np.fft.fftshift(x_freq, axes=(-2, -1))
B, C, H, W = x_freq.shape
mask = torch.ones((B, C, H, W), device=x.device)
crow, ccol = H // 2, W // 2
mask[..., crow - threshold : crow + threshold, ccol - threshold : ccol + threshold] = scale
x_freq = jt.nn.ComplexNumber(x_freq.real, x_freq.imag) * mask
rarray = x_freq.real.numpy().astype(complex)
iarray = x_freq.imag.numpy().astype(complex)
x_freq = rarray+iarray*complex(0,1)
# IFFT
x_freq = np.fft.ifftshift(x_freq, axes=(-2, -1))
x_filtered = jt.array(np.fft.ifftn(x_freq, axes=(-2, -1)).real.astype(np.float32))
return x_filtered.to(dtype=x_in.dtype)return_dict=False
时,后者返回元组image,
has_nsfw_concept, all_latents,
output_prompt_embeds,分别为图片、nsfw、各步latents,输入prompt的embedding。而前者只采集并输出image,
has_nsfw_concept。其他方面二者相同。 ### sample.py
D3PO的采样器。与推理代码run_all.py类似,但是增加了采样部分。大体逻辑为设置optimizer与DDIMScheduler,同一prompt生成7张图片与latents与embedding,最终保存图片、latents、next_latents、prompt_embeds、timesteps,用于训练。由于上文所述的权重保存问题,未能实际应用D3PO。
### update.py
D3PO的训练器。由于上文所述的权重保存问题,未能实际应用D3PO,故未实现。仅作为D3PO框架展示。
### run_specific.py
prompt对点调优使用。TASKID为风格编号(从0开始),STATEID为该风格下第几个生成(从0开始),KEY为物体名(prompt.json中的key)。使用时需要在prompt.json的同级目录下建立style.json,其格式与prompt.json相同,但value为如下规则:为0代表不进行对点调优;否则以该值替换为实际的prompt。即:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20with open(f"{dataset_root}/{taskid}/style.json", "r") as file:
styles = json.load(file)
with open(f"{dataset_root}/{taskid}/prompt.json", "r") as file:
prompts = json.load(file)
with open(f"{dataset_root}/{taskid}/style.json", "r") as file:
styles = json.load(file)
style_prompt=prompts["style"]
prompt=style_prompt.replace("*",prompts[KEY])
if(styles[KEY]!='0'): prompt=styles[KEY]
print(prompt)
for i in range(max(0,(2*(25*TASKID+STATEID)-1))):
rd=randn_tensor((1, 4, 64, 64), seed=SEED, dtype='float32')
image = pipe(
prompt + f" in style_{taskid}",
num_inference_steps=50,
width=512,
height=512,
guidance_scale=12,
negative_prompt=NEGATIVE_PROMPT
).images[0]
generate_all.py
对应prompt调优-prompt设计 使用openai的国内镜像,调用gpt-4V的api,让大模型识别示例图片生成prompt,并将结果保存成易读取的json文件。
crop.py
对应数据增强-对图片进行裁剪
遍历参考图像文件夹,将每一张1024*1024的图像按左上左下右上右下切为四张512*512的图片。
getNormal.py
对应扩散模型优化:修正噪声采样均值方差
使用stablediffusion中vae的encoder对图像进行编码,之后计算参考图像的均值与标准差,保存到文件中待后续读取。
changePic.py
对应建立优质图库 将参考图像文件夹中的图像压缩为512*512的图像,保存到新的文件夹中,以便后续使用。
参考代码
D3PO:参考Direct Preference for Denoising Diffusion Policy Optimization (D3PO)。主要涉及上传的sample.py。 FreeU:参考原diffusers/utils/torch_utils.py。将涉及torch.fft的部分用jittor与numpy替代。 有偏移的初始噪声:参考个人博客,主要对修改了一行代码,涉及到train.py
结果
zip里上传了可复现的结果图片和最好版本的结果图片。
因为最开始没有对随机数种子锁定,所以最好版本的无法复现。
截止到6月30号晚22:09,我们队在A榜排名11。
一些好看的结果