SimpleTuner

读SimpleTuner代码

learning_rate

```python
self.config.learning_rate = (
    self.config.learning_rate
    * self.config.gradient_accumulation_steps
    * self.config.train_batch_size
    * getattr(self.accelerator, "num_processes", 1)
)
```
```python
self.config.total_batch_size = (
    self.config.train_batch_size
    * self.accelerator.num_processes
    * self.config.gradient_accumulation_steps
)
```

返回batch的逻辑在MultiAspectSampler._validate_and_yield_images_from_samples

```python
batch = iterator_fn(step, *iterator_args)

latents = batch["latent_batch"].to(
    self.accelerator.device, dtype=self.config.weight_dtype
)

noise = torch.randn_like(latents)

bsz = latents.shape[0]

timesteps = segmented_timestep_selection(
    actual_num_timesteps=self.noise_scheduler.config.num_train_timesteps,
    bsz=bsz,
    weights=weights,
    use_refiner_range=StateTracker.is_sdxl_refiner()
    and not StateTracker.get_args().sdxl_refiner_uses_full_range,
).to(self.accelerator.device)

noisy_latents = self.noise_scheduler.add_noise(
    latents.float(), input_noise.float(), timesteps
).to(
    device=self.accelerator.device,
    dtype=self.config.weight_dtype,
)

encoder_hidden_states = batch["prompt_embeds"].to(
    dtype=self.config.weight_dtype, device=self.accelerator.device
)

add_text_embeds = batch["add_text_embeds"]
```
```python
model_pred = self.model_predict(
    batch=batch,
    latents=latents,
    noisy_latents=noisy_latents,
    encoder_hidden_states=encoder_hidden_states,
    added_cond_kwargs=added_cond_kwargs,
    add_text_embeds=add_text_embeds,
    timesteps=timesteps,
)
```

```python
init_backend["sampler"] = MultiAspectSampler(
    id=init_backend["id"],
    metadata_backend=init_backend["metadata_backend"],
    data_backend=init_backend["data_backend"],
    accelerator=accelerator,
    batch_size=args.train_batch_size,
    debug_aspect_buckets=args.debug_aspect_buckets,
    delete_unwanted_images=backend.get(
        "delete_unwanted_images", args.delete_unwanted_images
    ),
    resolution=backend.get("resolution", args.resolution),
    resolution_type=backend.get("resolution_type", args.resolution_type),
    caption_strategy=backend.get("caption_strategy", args.caption_strategy),
    use_captions=use_captions,
    prepend_instance_prompt=backend.get(
        "prepend_instance_prompt", args.prepend_instance_prompt
    ),
    instance_prompt=backend.get("instance_prompt", args.instance_prompt),
    conditioning_type=conditioning_type,
    is_regularisation_data=is_regularisation_data,
)

init_backend["train_dataloader"] = torch.utils.data.DataLoader(
    init_backend["train_dataset"],
    batch_size=1,  # The sampler handles batching
    shuffle=False,  # The sampler handles shuffling
    sampler=init_backend["sampler"],
    collate_fn=lambda examples: collate_fn(examples),
    num_workers=0,
    persistent_workers=False,
)
```

调用逻辑:sampler产生img_metadata,传给dataset.__getitem__,之后传给collate_fn,最后传给model_predict

conditioning_type:

```python
conditioning_type = backend.get("conditioning_type")
if (
    backend.get("dataset_type") == "conditioning"
    or conditioning_type is not None
):
    backend["dataset_type"] = "conditioning"
```