-
Notifications
You must be signed in to change notification settings - Fork 10
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #33 from johndpope/feat/32-cycle-consistency
DRAFT - Feat/32 cycle consistency
- Loading branch information
Showing
26 changed files
with
6,022 additions
and
271 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,108 @@ | ||
## EMODataset Class Summary | ||
|
||
### Overview | ||
The `EMODataset` class is a PyTorch dataset for processing and augmenting video frames, with functionalities to remove backgrounds, warp and crop faces, and save/load processed frames efficiently. The class is designed to handle large video datasets and includes methods to streamline the preprocessing pipeline. | ||
|
||
### Dependencies | ||
The class relies on the following libraries: | ||
- `moviepy.editor`: Video editing and processing. | ||
- `PIL.Image`: Image processing. | ||
- `torch`: PyTorch for tensor operations and model support. | ||
- `torchvision.transforms`: Image transformations. | ||
- `decord`: Efficient video reading. | ||
- `rembg`: Background removal. | ||
- `face_recognition`: Face detection. | ||
- `skimage.transform`: Image warping. | ||
- `cv2`: Video writing with OpenCV. | ||
- `numpy`: Array operations. | ||
- `io`, `os`, `json`, `Path`, `subprocess`, `tqdm`: Standard libraries for file handling, I/O operations, and progress visualization. | ||
|
||
### Initialization | ||
The `__init__` method sets up the dataset with various parameters: | ||
- `use_gpu`, `sample_rate`, `n_sample_frames`, `width`, `height`, `img_scale`, `img_ratio`, `video_dir`, `drop_ratio`, `json_file`, `stage`, `transform`, `remove_background`, `use_greenscreen`, `apply_crop_warping` | ||
- Loads video metadata from the provided JSON file. | ||
- Initializes decord for video reading with PyTorch tensor output. | ||
|
||
### Methods | ||
|
||
#### `__len__` | ||
Returns the length of the dataset, determined by the number of video IDs. | ||
|
||
#### `warp_and_crop_face` | ||
Processes an image tensor to detect, warp, and crop the face region: | ||
- Converts tensor to PIL image. | ||
- Removes background. | ||
- Detects face locations. | ||
- Crops the face region. | ||
- Optionally applies thin-plate-spline warping. | ||
- Converts the processed image back to a tensor and returns it. | ||
|
||
#### `load_and_process_video` | ||
Loads and processes video frames: | ||
- Checks if processed tensor file exists; if so, loads tensors. | ||
- If not, processes video frames, applies augmentation, and saves frames as PNG images and tensors. | ||
- Saves processed tensors as compressed numpy arrays for efficient loading. | ||
|
||
#### `augmentation` | ||
Applies transformations and optional background removal to the provided images: | ||
- Supports both single images and lists of images. | ||
- Returns transformed tensors. | ||
|
||
#### `remove_bg` | ||
Removes the background from the provided image using `rembg`: | ||
- Optionally applies a green screen background. | ||
- Converts image to RGB format and returns it. | ||
|
||
#### `save_video` | ||
Saves a list of frames as a video file: | ||
- Uses OpenCV to write frames to a video file. | ||
|
||
#### `process_video` | ||
Processes all frames of a video: | ||
- Uses the `process_video_frames` method to process frames. | ||
|
||
#### `process_video_frames` | ||
Processes frames of a video using decord: | ||
- Reads frames using decord and applies augmentation. | ||
- Returns processed frames. | ||
|
||
#### `__getitem__` | ||
Returns a sample from the dataset: | ||
- Loads and processes source and driving videos. | ||
- Returns a dictionary containing video IDs and frames. | ||
|
||
### Usage | ||
To use the `EMODataset` class: | ||
1. Initialize the dataset with appropriate parameters. | ||
2. Use PyTorch DataLoader to iterate over the dataset and retrieve samples. | ||
3. Process the frames as needed for training or inference in a machine learning model. | ||
|
||
### Example | ||
```python | ||
from torchvision import transforms | ||
|
||
transform = transforms.Compose([ | ||
transforms.Resize((512, 512)), | ||
transforms.ToTensor(), | ||
]) | ||
|
||
dataset = EMODataset( | ||
use_gpu=False, | ||
sample_rate=5, | ||
n_sample_frames=16, | ||
width=512, | ||
height=512, | ||
img_scale=(0.9, 1.0), | ||
video_dir="path/to/videos", | ||
json_file="path/to/metadata.json", | ||
transform=transform, | ||
remove_background=True, | ||
use_greenscreen=False, | ||
apply_crop_warping=True | ||
) | ||
|
||
for sample in dataset: | ||
print(sample) | ||
``` | ||
|
||
This class provides a comprehensive pipeline for processing video data, making it suitable for tasks such as training deep learning models on video datasets. |
Oops, something went wrong.