1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43
| model_pred = self.model_predict(
    # NOTE(review): this statement is a fragment of a method that is not
    # visible here; `self.model_predict` and every argument below are defined
    # elsewhere, so their shapes/dtypes cannot be confirmed from this snippet.
    # All arguments are passed by keyword, so the call does not depend on the
    # positional order of model_predict's parameters.
    batch=batch,
    latents=latents,
    noisy_latents=noisy_latents,
    encoder_hidden_states=encoder_hidden_states,
    # NOTE(review): both added_cond_kwargs and add_text_embeds are forwarded;
    # whether model_predict uses one, the other, or both is not visible here.
    added_cond_kwargs=added_cond_kwargs,
    add_text_embeds=add_text_embeds,
    timesteps=timesteps,
) ```
```Python init_backend["sampler"] = MultiAspectSampler( id=init_backend["id"], metadata_backend=init_backend["metadata_backend"], data_backend=init_backend["data_backend"], accelerator=accelerator, batch_size=args.train_batch_size, debug_aspect_buckets=args.debug_aspect_buckets, delete_unwanted_images=backend.get( "delete_unwanted_images", args.delete_unwanted_images ), resolution=backend.get("resolution", args.resolution), resolution_type=backend.get("resolution_type", args.resolution_type), caption_strategy=backend.get("caption_strategy", args.caption_strategy), use_captions=use_captions, prepend_instance_prompt=backend.get( "prepend_instance_prompt", args.prepend_instance_prompt ), instance_prompt=backend.get("instance_prompt", args.instance_prompt), conditioning_type=conditioning_type, is_regularisation_data=is_regularisation_data, )
init_backend["train_dataloader"] = torch.utils.data.DataLoader( init_backend["train_dataset"], batch_size=1, shuffle=False, sampler=init_backend["sampler"], collate_fn=lambda examples: collate_fn(examples), num_workers=0, persistent_workers=False, )
|