diff --git a/README.md b/README.md index e6ed5f6e..7345e0e5 100644 --- a/README.md +++ b/README.md @@ -100,7 +100,270 @@ After installation completes, run the training script. ## Wan 2.1 Training - Coming soon. + in the first part, we'll run on a single host VM to get familiar with the workflow, then run on xpk for large scale training. + + Although not required, attaching an external disk is recommended as weights take up a lot of disk space. [Follow these instructions if you would like to attach an external disk](https://cloud.google.com/tpu/docs/attach-durable-block-storage). + + This workflow was tested using v5p-8 with a 500GB disk attached. + + ### Dataset Preparation + + For this example, we'll be using the [PusaV1 dataset](https://huggingface.co/datasets/RaphaelLiu/PusaV1_training). + + First, download the dataset. + + ```bash + export HF_DATASET_DIR=/mnt/disks/external_disk/PusaV1_training/ + export TFRECORDS_DATASET_DIR=/mnt/disks/external_disk/wan_tfr_dataset_pusa_v1 + huggingface-cli download RaphaelLiu/PusaV1_training --repo-type dataset --local-dir $HF_DATASET_DIR + ``` + + Next run the TFRecords conversion script. This step prepares training and eval datasets. Validation is done as described in [Scaling Rectified Flow Transformers for High-Resolution Image Synthesis](https://arxiv.org/pdf/2403.03206). More details [here](https://github.com/mlcommons/training/tree/master/text_to_image#5-quality) + + Training dataset. + + ```bash + python src/maxdiffusion/data_preprocessing/wan_pusav1_to_tfrecords.py src/maxdiffusion/configs/base_wan_14b.yml train_data_dir=$HF_DATASET_DIR tfrecords_dir=$TFRECORDS_DATASET_DIR/train no_records_per_shard=10 enable_eval_timesteps=False + ``` + + The script will not have an output, but you can check the progress using: + + ```bash + ls -ll $TFRECORDS_DATASET_DIR/train + ``` + + Evaluation dataset. + + ```bash + python src/maxdiffusion/data_preprocessing/wan_pusav1_to_tfrecords.py src/maxdiffusion/configs/base_wan_14b.yml train_data_dir=$HF_DATASET_DIR tfrecords_dir=$TFRECORDS_DATASET_DIR/eval no_records_per_shard=10 enable_eval_timesteps=True + ``` + + The evaluation dataset creation takes the first 420 samples of the dataset and adds a timestep field. We then need to manually delete the first 420 samples from the `train` folder so they are not used in training. + + + ```bash + printf "%s\n" $TFRECORDS_DATASET_DIR/train/file_*-*.tfrec | awk -F '[-.]' '$2+0 <= 420' | xargs -d '\n' rm + ``` + + And verify that they do not exist. + + ```bash + printf "%s\n" $TFRECORDS_DATASET_DIR/train/file_*-*.tfrec | awk -F '[-.]' '$2+0 <= 420' | xargs -d '\n' echo + ``` + + After the script is done running, you should see the following directory structure inside `$TFRECORDS_DATASET_DIR` + + ``` + train + eval_timesteps + ``` + + In some instances an empty file `file_42-430.tfrec` is created inside `eval_timesteps`, for sanity check, let's run a delete command. + + ```bash + rm $TFRECORDS_DATASET_DIR/eval_timesteps/file_42-430.tfrec + ``` + + ### Training on a Single VM + + Loading the data is supported both locally from the disk created above, or from `gcs`. In this guide, we'll be using a gcs bucket to train. First copy the data to the GCS bucket. + + ```bash + BUCKET_NAME=my-bucket + gsutil -m cp -r $TFRECORDS_DATASET_DIR gs://$BUCKET_NAME/${TFRECORDS_DATASET_DIR##*/} + ``` + + Now run the training command: + + ```bash + RUN_NAME=jfacevedo-wan-v5p-8-${RANDOM} + OUTPUT_DIR=gs://$BUCKET_NAME/wan/ + DATASET_DIR=gs://$BUCKET_NAME/${TFRECORDS_DATASET_DIR##*/}/train/ + EVAL_DATA_DIR=gs://$BUCKET_NAME/${TFRECORDS_DATASET_DIR##*/}/eval_timesteps/ + SAVE_DATASET_DIR=gs://$BUCKET_NAME/${TFRECORDS_DATASET_DIR##*/}/save/ + ``` + + ```bash + export LIBTPU_INIT_ARGS='--xla_tpu_enable_async_collective_fusion_fuse_all_gather=true \ + --xla_tpu_megacore_fusion_allow_ags=false \ + --xla_enable_async_collective_permute=true \ + --xla_tpu_enable_ag_backward_pipelining=true \ + --xla_tpu_enable_data_parallel_all_reduce_opt=true \ + --xla_tpu_data_parallel_opt_different_sized_ops=true \ + --xla_tpu_enable_async_collective_fusion=true \ + --xla_tpu_enable_async_collective_fusion_multiple_steps=true \ + --xla_tpu_overlap_compute_collective_tc=true \ + --xla_enable_async_all_gather=true \ + --xla_tpu_scoped_vmem_limit_kib=65536 \ + --xla_tpu_enable_async_all_to_all=true \ + --xla_tpu_enable_all_experimental_scheduler_features=true \ + --xla_tpu_enable_scheduler_memory_pressure_tracking=true \ + --xla_tpu_host_transfer_overlap_limit=24 \ + --xla_tpu_aggressive_opt_barrier_removal=ENABLED \ + --xla_lhs_prioritize_async_depth_over_stall=ENABLED \ + --xla_should_allow_loop_variant_parameter_in_chain=ENABLED \ + --xla_should_add_loop_invariant_op_in_chain=ENABLED \ + --xla_max_concurrent_host_send_recv=100 \ + --xla_tpu_scheduler_percent_shared_memory_limit=100 \ + --xla_latency_hiding_scheduler_rerun=2 \ + --xla_tpu_use_minor_sharding_for_major_trivial_input=true \ + --xla_tpu_relayout_group_size_threshold_for_reduce_scatter=1 \ + --xla_tpu_assign_all_reduce_scatter_layout=true' + ``` + + ```bash + HF_HUB_CACHE=/mnt/disks/external_disk/maxdiffusion_hf_cache/ python src/maxdiffusion/train_wan.py \ + src/maxdiffusion/configs/base_wan_14b.yml \ + attention='flash' \ + weights_dtype=bfloat16 \ + activations_dtype=bfloat16 \ + guidance_scale=5.0 \ + flow_shift=5.0 \ + fps=16 \ + skip_jax_distributed_system=False \ + run_name=${RUN_NAME} \ + output_dir=${OUTPUT_DIR} \ + train_data_dir=${DATASET_DIR} \ + load_tfrecord_cached=True \ + height=1280 \ + width=720 \ + num_frames=81 \ + num_inference_steps=50 \ + jax_cache_dir=${OUTPUT_DIR}/jax_cache/ \ + max_train_steps=1000 \ + enable_profiler=True \ + dataset_save_location=${SAVE_DATASET_DIR} \ + remat_policy='FULL' \ + flash_min_seq_length=0 \ + seed=$RANDOM \ + skip_first_n_steps_for_profiler=3 \ + profiler_steps=3 \ + per_device_batch_size=0.25 \ + ici_data_parallelism=1 \ + ici_fsdp_parallelism=4 \ + ici_tensor_parallelism=1 + ``` + + It is important to note a couple of things: + - per_device_batch_size can be a fractional, but must be a whole number when multiplied by number of devices. In this example, 0.25 * 4 (devices) = effective global batch size = 1. + - The step time in v5p-8 with global batch size = 1 is large due to using `FULL` remat. On larger number of chips we can run larger batch sizes greatly increasing MFU, as we will see in the next session of deploying with xpk. + - To enable eval during training set `eval_every` to a value > 0. + - In Wan2.1, the ici_fsdp_parallelism axis is used for sequence parallelism, the ici_tensor_parallelism axis is used for head parallelism. + - You can enable both, keeping in mind that Wan2.1 has 40 heads and 40 must be evenly divisible by ici_tensor_parallelism. + - For Sequence parallelism, the code pads the sequence length to evenly divide the sequence. Try out different ici_fsdp_parallelism numbers, but we find 2 and 4 to be the best right now. + + You should eventually see a training run as: + + ```bash + ***** Running training ***** + Instantaneous batch size per device = 0.25 + Total train batch size (w. parallel & distributed) = 1 + Total optimization steps = 1000 + Calculated TFLOPs per pass: 4893.2719 + Warning, batch dimension should be shardable among the devices in data and fsdp axis, batch dimension: 1, devices_in_data_fsdp: 4 + Warning, batch dimension should be shardable among the devices in data and fsdp axis, batch dimension: 1, devices_in_data_fsdp: 4 + Warning, batch dimension should be shardable among the devices in data and fsdp axis, batch dimension: 1, devices_in_data_fsdp: 4 + Warning, batch dimension should be shardable among the devices in data and fsdp axis, batch dimension: 1, devices_in_data_fsdp: 4 + completed step: 0, seconds: 142.395, TFLOP/s/device: 34.364, loss: 0.270 + To see full metrics 'tensorboard --logdir=gs://jfacevedo-maxdiffusion-v5p/wan/jfacevedo-wan-v5p-8-17263/tensorboard/' + completed step: 1, seconds: 137.207, TFLOP/s/device: 35.664, loss: 0.144 + completed step: 2, seconds: 36.014, TFLOP/s/device: 135.871, loss: 0.210 + completed step: 3, seconds: 36.016, TFLOP/s/device: 135.864, loss: 0.120 + completed step: 4, seconds: 36.008, TFLOP/s/device: 135.894, loss: 0.107 + completed step: 5, seconds: 36.008, TFLOP/s/device: 135.895, loss: 0.346 + completed step: 6, seconds: 36.006, TFLOP/s/device: 135.900, loss: 0.169 + ``` + + ### Deploying with XPK + + This assummes the user has already created an xpk cluster, installed all dependencies and the also created the dataset from the step above. For getting started with MaxDiffusion and xpk see [this guide](docs/getting_started/run_maxdiffusion_via_xpk.md). + + Using v5p-256 Then the command to run on xpk is as follows: + + ```bash + RUN_NAME=jfacevedo-wan-v5p-8-${RANDOM} + OUTPUT_DIR=gs://$BUCKET_NAME/wan/ + DATASET_DIR=gs://$BUCKET_NAME/${TFRECORDS_DATASET_DIR##*/}/train/ + EVAL_DATA_DIR=gs://$BUCKET_NAME/${TFRECORDS_DATASET_DIR##*/}/eval_timesteps/ + SAVE_DATASET_DIR=gs://$BUCKET_NAME/${TFRECORDS_DATASET_DIR##*/}/save/ + ``` + + ```bash + LIBTPU_INIT_ARGS='--xla_tpu_enable_async_collective_fusion_fuse_all_gather=true \ + --xla_tpu_megacore_fusion_allow_ags=false \ + --xla_enable_async_collective_permute=true \ + --xla_tpu_enable_ag_backward_pipelining=true \ + --xla_tpu_enable_data_parallel_all_reduce_opt=true \ + --xla_tpu_data_parallel_opt_different_sized_ops=true \ + --xla_tpu_enable_async_collective_fusion=true \ + --xla_tpu_enable_async_collective_fusion_multiple_steps=true \ + --xla_tpu_overlap_compute_collective_tc=true \ + --xla_enable_async_all_gather=true \ + --xla_tpu_scoped_vmem_limit_kib=65536 \ + --xla_tpu_enable_async_all_to_all=true \ + --xla_tpu_enable_all_experimental_scheduler_features=true \ + --xla_tpu_enable_scheduler_memory_pressure_tracking=true \ + --xla_tpu_host_transfer_overlap_limit=24 \ + --xla_tpu_aggressive_opt_barrier_removal=ENABLED \ + --xla_lhs_prioritize_async_depth_over_stall=ENABLED \ + --xla_should_allow_loop_variant_parameter_in_chain=ENABLED \ + --xla_should_add_loop_invariant_op_in_chain=ENABLED \ + --xla_max_concurrent_host_send_recv=100 \ + --xla_tpu_scheduler_percent_shared_memory_limit=100 \ + --xla_latency_hiding_scheduler_rerun=2 \ + --xla_tpu_use_minor_sharding_for_major_trivial_input=true \ + --xla_tpu_relayout_group_size_threshold_for_reduce_scatter=1 \ + --xla_tpu_assign_all_reduce_scatter_layout=true' + ``` + + ```bash + python3 ~/xpk/xpk.py workload create \ + --cluster=$CLUSTER_NAME \ + --project=$PROJECT \ + --zone=$ZONE \ + --device-type=$DEVICE_TYPE \ + --num-slices=1 \ + --command=" \ + HF_HUB_CACHE=/mnt/disks/external_disk/maxdiffusion_hf_cache/ python src/maxdiffusion/train_wan.py \ + src/maxdiffusion/configs/base_wan_14b.yml \ + attention='flash' \ + weights_dtype=bfloat16 \ + activations_dtype=bfloat16 \ + guidance_scale=5.0 \ + flow_shift=5.0 \ + fps=16 \ + skip_jax_distributed_system=False \ + run_name=${RUN_NAME} \ + output_dir=${OUTPUT_DIR} \ + train_data_dir=${DATASET_DIR} \ + load_tfrecord_cached=True \ + height=1280 \ + width=720 \ + num_frames=81 \ + num_inference_steps=50 \ + jax_cache_dir=${OUTPUT_DIR}/jax_cache/ \ + enable_profiler=True \ + dataset_save_location=${SAVE_DATASET_DIR} \ + remat_policy='HIDDEN_STATE_WITH_OFFLOAD' \ + flash_min_seq_length=0 \ + seed=$RANDOM \ + skip_first_n_steps_for_profiler=3 \ + profiler_steps=3 \ + per_device_batch_size=0.25 \ + ici_data_parallelism=32 \ + ici_fsdp_parallelism=4 \ + ici_tensor_parallelism=1" \ + max_train_steps=5000 \ + eval_every=100 \ + eval_data_dir=${EVAL_DATA_DIR} \ + enable_generate_video_for_eval=True \ + warmup_steps_fraction=0.025" + --base-docker-image=${IMAGE_DIR} \ + --enable-debug-logs \ + --workload=${RUN_NAME} \ + --priority=medium \ + --max-restarts=0 + ``` ## Flux Training diff --git a/src/maxdiffusion/configs/base_wan_14b.yml b/src/maxdiffusion/configs/base_wan_14b.yml index 4a973045..78a65377 100644 --- a/src/maxdiffusion/configs/base_wan_14b.yml +++ b/src/maxdiffusion/configs/base_wan_14b.yml @@ -237,7 +237,7 @@ global_batch_size: 0 tfrecords_dir: '' no_records_per_shard: 0 enable_eval_timesteps: False -considered_timesteps_list: [125, 250, 375, 500, 625, 750, 875] +timesteps_list: [125, 250, 375, 500, 625, 750, 875] num_eval_samples: 420 warmup_steps_fraction: 0.1 @@ -321,6 +321,6 @@ qwix_module_path: ".*" eval_every: -1 eval_data_dir: "" enable_generate_video_for_eval: False # This will increase the used TPU memory. -eval_max_number_of_samples_in_bucket: 60 # The number of samples per bucket for evaluation. This is calculated by num_eval_samples / len(considered_timesteps_list). +eval_max_number_of_samples_in_bucket: 60 # The number of samples per bucket for evaluation. This is calculated by num_eval_samples / len(timesteps_list). enable_ssim: False \ No newline at end of file