Diffusion model - basic explained

Posted by Hao Do on July 15, 2023

Coding

1
2
3
4
5
6
7
8
'''
Diffusion model - basic explained.ipynb

Automatically generated by Colaboratory.

Original file is located at
    https://colab.research.google.com/drive/1C1IaA1CHMyRraZvCgFgRcYazyEXXxrwP
'''
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
'''
Mục đích của code này chỉ là dùng diffusion model cho forward, sau đó dùng unet để backward về lại ảnh gốc.
Có sử dụng để test classification task.
'''

import torchvision.transforms as transforms
import torch.nn as nn
import torchvision
import math
import matplotlib.pyplot as plt
import torch
import urllib
import numpy as np
import PIL

device = torch.device("cuda:0")

def get_sample_image()-> PIL.Image.Image:
    url = 'https://encrypted-tbn0.gstatic.com/images?q=tbn:ANd9GcTZmJy3aSZ1Ix573d2MlJXQowLCLQyIUsPdniOJ7rBsgG4XJb04g9ZFA9MhxYvckeKkVmo&usqp=CAU'
    filename = 'racoon.jpg'
    urllib.request.urlretrieve(url, filename)
    return PIL.Image.open(filename)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
'''
Hàm trên có tên là `plot_noise_distribution`, và nó được sử dụng để vẽ biểu đồ phân bố của hai loại nhiễu: nhiễu thực tế và nhiễu được dự đoán.

Dưới đây là cách hoạt động của hàm:

1. Hàm `plot_noise_distribution` nhận hai đối số là `noise` và `predicted_noise`, có thể giả định là hai tensor chứa thông tin về nhiễu thực tế và nhiễu được dự đoán tương ứng.

2. Đầu tiên, hàm chuyển đổi hai tensor `noise` và `predicted_noise` sang mảng NumPy bằng cách sử dụng phương thức `.cpu().numpy()`. Việc này là cần thiết nếu hai tensor này đang được sử dụng trong ngữ cảnh của mô hình PyTorch và muốn biểu đồ histogram.

3. Hàm sau đó sử dụng `plt.hist` (từ thư viện matplotlib) để tạo biểu đồ histogram cho hai loại nhiễu. Các tham số của hàm này bao gồm:

   - `noise.cpu().numpy().flatten()`: Dữ liệu của nhiễu thực tế được truyền vào để vẽ histogram.
   - `predicted_noise.cpu().numpy().flatten()`: Dữ liệu của nhiễu được dự đoán được truyền vào để vẽ histogram.
   - `density=True`: Tham số này chỉ định rằng histogram sẽ được chuẩn hóa để hiển thị phân phối xác suất thay vì đếm tần suất xuất hiện.
   - `alpha=0.8`: Tham số này xác định độ trong suốt của các thanh trong histogram. Trong trường hợp này, 0.8 có nghĩa là các thanh sẽ khá trong suốt, giúp dễ dàng nhìn thấy sự chồng chéo của hai histogram.
   - `label="ground truth noise"` và `label="predicted noise"`: Tham số này xác định nhãn cho mỗi histogram trong biểu đồ để có thể tạo chú thích.

4. Sau khi đã vẽ hai histogram, hàm sử dụng `plt.legend()` để hiển thị chú thích cho các loại nhiễu đã vẽ.

5. Cuối cùng, hàm sử dụng `plt.show()` để hiển thị biểu đồ histogram cho người dùng.

Tóm lại, hàm này được sử dụng để so sánh và hiển thị biểu đồ phân bố giữa nhiễu thực tế và nhiễu được dự đoán, giúp người dùng có cái nhìn trực quan về sự tương đồng hoặc khác biệt giữa hai loại nhiễu này.
'''

def plot_noise_distribution(noise, predicted_noise):
    plt.hist(noise.cpu().numpy().flatten(), density = True, alpha = 0.8, label = "ground truth noise")
    plt.hist(predicted_noise.cpu().numpy().flatten(), density = True, alpha = 0.8, label = "predicted noise")
    plt.legend()
    plt.show()
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
'''
Hàm `plot_noise_prediction` được sử dụng để hiển thị và so sánh hai bức ảnh nhiễu: nhiễu thực tế và nhiễu được dự đoán. Biểu đồ sẽ có hai cột ảnh nhiễu, mỗi cột đại diện cho một loại nhiễu.

Dưới đây là cách hoạt động của hàm:

1. Hàm `plot_noise_prediction` nhận hai đối số là `noise` và `predicted_noise`, có thể giả định là hai tensor chứa thông tin về nhiễu thực tế và nhiễu được dự đoán tương ứng.

2. Hàm sử dụng `plt.figure(figsize=(15, 15))` để tạo một khu vực vẽ mới với kích thước 15x15 inches.

3. Hàm sử dụng `plt.subplots(1, 2, figsize=(5, 5))` để tạo một lưới với một hàng và hai cột. Mỗi ô trong lưới chứa một bức ảnh nhiễu, trong trường hợp này là nhiễu thực tế và nhiễu được dự đoán. Kích thước của mỗi bức ảnh được đặt là 5x5 inches.

4. Hàm sử dụng `ax[0].imshow(reverse_transform(noise))` để hiển thị bức ảnh nhiễu thực tế trong ô đầu tiên của lưới. `reverse_transform` là một hàm được giả định đã được định nghĩa ở nơi khác để chuyển đổi lại từ dữ liệu tensor thành định dạng hình ảnh thường (ví dụ: numpy array hoặc PIL.Image.Image).

5. Hàm sử dụng `ax[0].set_title(f"ground truth noise", fontsize=10)` để đặt tiêu đề cho ảnh nhiễu thực tế.

6. Tương tự, hàm sử dụng `ax[1].imshow(reverse_transform(predicted_noise))` để hiển thị bức ảnh nhiễu được dự đoán trong ô thứ hai của lưới.

7. Hàm sử dụng `ax[1].set_title(f"predicted noise", fontsize=10)` để đặt tiêu đề cho ảnh nhiễu được dự đoán.

8. Cuối cùng, hàm sử dụng `plt.show()` để hiển thị lưới chứa hai ảnh nhiễu cho người dùng.

Tóm lại, hàm này dùng để hiển thị hai bức ảnh nhiễu (nhiễu thực tế và nhiễu được dự đoán) cạnh nhau để so sánh chúng và kiểm tra tính chính xác của quá trình dự đoán nhiễu trong một tác vụ cụ thể.
'''

def plot_noise_prediction(noise, predicted_noise):
    plt.figure(figsize=(15,15))
    f, ax = plt.subplots(1, 2, figsize = (5,5))
    ax[0].imshow(reverse_transform(noise))
    ax[0].set_title(f"ground truth noise", fontsize = 10)
    ax[1].imshow(reverse_transform(predicted_noise))
    ax[1].set_title(f"predicted noise", fontsize = 10)
    plt.show()
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
'''
Đoạn mã trên định nghĩa một lớp (class) có tên `DiffusionModel`, có chức năng thực hiện các phép tính liên quan đến mô hình diffusion (phân tán) trong xử lý ảnh. Được giải thích như sau:

1. Phương thức `__init__(self, start_schedule=0.0001, end_schedule=0.02, timesteps=300)`: Đây là hàm khởi tạo, được gọi khi một đối tượng của lớp `DiffusionModel` được tạo ra. Các tham số `start_schedule`, `end_schedule`, và `timesteps` là giá trị mặc định cho lịch trình (schedule) bắt đầu, kết thúc và số bước thời gian tạo ra trong mô hình diffusion.

2. Phương thức `forward(self, x_0, t, device)`: Đây là phương thức dự đoán (forward) của mô hình. Nó nhận vào đầu vào `x_0` là một ảnh, `t` là các chỉ số thời gian (time step), và `device` là thiết bị tính toán. Phương thức này tính toán sự lan tỏa của nhiễu trong ảnh và trả về giá trị ước lượng cho ảnh đã bị nhiễu và nhiễu chính xác.

3. Phương thức `backward(self, x, t, model, **kwargs)`: Đây là phương thức tính toán ngược (backward) của mô hình. Nó nhận vào đầu vào `x` là một ảnh, `t` là các chỉ số thời gian, `model` là mô hình dự đoán nhiễu, và `**kwargs` là các đối số khác của mô hình. Phương thức này sẽ dự đoán và trả về ảnh không bị nhiễu từ ảnh đã nhiễu, thực hiện việc loại bỏ nhiễu từ ảnh.

4. Phương thức `get_index_from_list(values, t, x_shape)`: Đây là phương thức tĩnh (static method) được sử dụng để lấy các giá trị từ một danh sách (tensor) dựa trên các chỉ số thời gian `t`. Nó nhận vào `values` là danh sách các giá trị, `t` là các chỉ số thời gian, và `x_shape` là kích thước của ảnh.

Lớp `DiffusionModel` này được sử dụng để mô hình hóa quá trình phân tán nhiễu trong ảnh và thực hiện việc dự đoán và loại bỏ nhiễu từ ảnh đã nhiễu.
'''

class DiffusionModel:
    def __init__(self, start_schedule=0.0001, end_schedule=0.02, timesteps = 300):
        self.start_schedule = start_schedule
        self.end_schedule = end_schedule
        self.timesteps = timesteps

        """
        if
            betas = [0.1, 0.2, 0.3, ...]
        then
            alphas = [0.9, 0.8, 0.7, ...]
            alphas_cumprod = [0.9, 0.9 * 0.8, 0.9 * 0.8, * 0.7, ...]


        """
        self.betas = torch.linspace(start_schedule, end_schedule, timesteps)
        self.alphas = 1 - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, axis=0)

    def forward(self, x_0, t, device):
        """
        x_0: (B, C, H, W)
        t: (B,)
        """
        noise = torch.randn_like(x_0)
        sqrt_alphas_cumprod_t = self.get_index_from_list(self.alphas_cumprod.sqrt(), t, x_0.shape)
        sqrt_one_minus_alphas_cumprod_t = self.get_index_from_list(torch.sqrt(1. - self.alphas_cumprod), t, x_0.shape)

        mean = sqrt_alphas_cumprod_t.to(device) * x_0.to(device)
        variance = sqrt_one_minus_alphas_cumprod_t.to(device) * noise.to(device)

        return mean + variance, noise.to(device)

    @torch.no_grad()
    def backward(self, x, t, model, **kwargs):
        """
        Calls the model to predict the noise in the image and returns
        the denoised image.
        Applies noise to this image, if we are not in the last step yet.
        """
        betas_t = self.get_index_from_list(self.betas, t, x.shape)
        sqrt_one_minus_alphas_cumprod_t = self.get_index_from_list(torch.sqrt(1. - self.alphas_cumprod), t, x.shape)
        sqrt_recip_alphas_t = self.get_index_from_list(torch.sqrt(1.0 / self.alphas), t, x.shape)
        mean = sqrt_recip_alphas_t * (x - betas_t * model(x, t, **kwargs) / sqrt_one_minus_alphas_cumprod_t)
        posterior_variance_t = betas_t

        if t == 0:
            return mean
        else:
            noise = torch.randn_like(x)
            variance = torch.sqrt(posterior_variance_t) * noise
            return mean + variance

    @staticmethod
    def get_index_from_list(values, t, x_shape):
        batch_size = t.shape[0]
        """
        pick the values from vals
        according to the indices stored in `t`
        """
        result = values.gather(-1, t.cpu())
        """
        if
        x_shape = (5, 3, 64, 64)
            -> len(x_shape) = 4
            -> len(x_shape) - 1 = 3

        and thus we reshape `out` to dims
        (batch_size, 1, 1, 1)

        """
        return result.reshape(batch_size, *((1,) * (len(x_shape) - 1))).to(t.device)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
IMAGE_SHAPE = (32, 32)

'''
Hai đoạn mã trên định nghĩa hai hàm giúp chuyển đổi dữ liệu ảnh giữa dạng chuẩn và dạng chuẩn hoá (normalized). Cụ thể:

1. Hàm `transform`: Được sử dụng để chuyển đổi dữ liệu ảnh về dạng chuẩn hoá, phù hợp với đầu vào của mô hình huấn luyện. Đây là quá trình tiền xử lý (pre-processing) cho dữ liệu ảnh.

- `transforms.Resize(IMAGE_SHAPE)`: Điều chỉnh kích thước của ảnh đầu vào thành kích thước được xác định trước (`IMAGE_SHAPE`).
- `transforms.ToTensor()`: Chuyển đổi ảnh thành tensor PyTorch và chia giá trị pixel của ảnh thành khoảng [0, 1].
- `transforms.Lambda(lambda t: (t * 2) - 1)`: Biến đổi dữ liệu của tensor để nằm trong khoảng [-1, 1].

2. Hàm `reverse_transform`: Được sử dụng để chuyển đổi dữ liệu ảnh từ dạng chuẩn hoá trở lại dạng ảnh gốc, phù hợp với đầu ra của mô hình huấn luyện. Đây là quá trình hậu xử lý (post-processing) cho kết quả của mô hình.

- `transforms.Lambda(lambda t: (t + 1) / 2)`: Biến đổi dữ liệu của tensor để nằm trong khoảng [0, 1].
- `transforms.Lambda(lambda t: t.permute(1, 2, 0))`: Chuyển đổi tensor từ dạng CHW (channels-first) sang dạng HWC (height-width-channels) để phù hợp với định dạng hình ảnh thông thường.
- `transforms.Lambda(lambda t: t * 255.)`: Scale dữ liệu của tensor để nằm trong khoảng [0, 255] (giá trị pixel hợp lệ cho ảnh uint8).
- `transforms.Lambda(lambda t: t.cpu().numpy().astype(np.uint8))`: Chuyển đổi tensor thành mảng numpy kiểu uint8.
- `transforms.ToPILImage()`: Chuyển đổi mảng numpy thành đối tượng hình ảnh của PIL (Python Imaging Library).

Tóm lại, hàm `transform` được sử dụng để chuẩn hoá dữ liệu ảnh đầu vào trước khi đưa vào mô hình huấn luyện. Hàm `reverse_transform` được sử dụng để chuyển đổi kết quả của mô hình trở lại dạng ảnh gốc để hiển thị hoặc lưu trữ.
'''

transform = transforms.Compose([
    transforms.Resize(IMAGE_SHAPE), # Resize the input image
    transforms.ToTensor(), # Convert to torch tensor (scales data into [0,1])
    transforms.Lambda(lambda t: (t * 2) - 1), # Scale data between [-1, 1]
])


reverse_transform = transforms.Compose([
    transforms.Lambda(lambda t: (t + 1) / 2), # Scale data between [0,1]
    transforms.Lambda(lambda t: t.permute(1, 2, 0)), # CHW to HWC
    transforms.Lambda(lambda t: t * 255.), # Scale data between [0.,255.]
    transforms.Lambda(lambda t: t.cpu().numpy().astype(np.uint8)), # Convert into an uint8 numpy array
    transforms.ToPILImage(), # Convert to PIL image
])

pil_image = get_sample_image()
torch_image = transform(pil_image)

diffusion_model = DiffusionModel()

NO_DISPLAY_IMAGES = 5
torch_image_batch = torch.stack([torch_image] * NO_DISPLAY_IMAGES)
t = torch.linspace(0, diffusion_model.timesteps - 1, NO_DISPLAY_IMAGES).long()
noisy_image_batch, _ = diffusion_model.forward(torch_image_batch, t, device)

plt.figure(figsize=(15,15))
f, ax = plt.subplots(1, NO_DISPLAY_IMAGES, figsize = (100,100))

for idx, image in enumerate(noisy_image_batch):
    ax[idx].imshow(reverse_transform(image))
    ax[idx].set_title(f"Iteration: {t[idx].item()}", fontsize = 100)
plt.show()

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
'''
Hàm `SinusoidalPositionEmbeddings` được sử dụng để tính toán các nhúng vị trí (position embeddings) sử dụng phép nhúng chuỗi dạng sóng (sinusoidal embeddings). Điều này thường được sử dụng trong các mô hình xử lý dữ liệu tuần tự như Transformer để giúp mô hình học được thông tin về thứ tự và vị trí của các thành phần trong dữ liệu.

Đoạn mã trên định nghĩa một lớp (class) `SinusoidalPositionEmbeddings` kế thừa từ `nn.Module`, và có một phương thức `forward` để tính toán các embeddings.

Cụ thể, các bước trong phương thức `forward` là:

1. `self.dim`: Đây là kích thước của nhúng vị trí, được chỉ định khi khởi tạo lớp.

2. Tính toán các thông số liên quan đến nhúng chuỗi dạng sóng:

   - `half_dim = self.dim // 2`: Chia kích thước của nhúng vị trí cho 2 để xác định số chiều cho mỗi phần tử dạng sóng (sin và cos) trong embeddings.
   - `embeddings = math.log(10000) / (half_dim - 1)`: Tính toán giá trị cơ số (base) để sử dụng trong tính toán embeddings dạng sóng.
   - `embeddings = torch.exp(torch.arange(half_dim, device=device) * -embeddings)`: Tính toán các giá trị embeddings dạng sóng sử dụng hàm mũ (exponential).

3. Tính toán embeddings dạng sóng cho các vị trí thời gian (time) được đưa vào:

   - `embeddings = time[:, None] * embeddings[None, :]`: Nhân ma trận thời gian với ma trận embeddings dạng sóng để tính toán embeddings cuối cùng. `time` là đối số đầu vào đại diện cho các vị trí thời gian.
   - `embeddings = torch.cat((embeddings.sin(), embeddings.cos()), dim=-1)`: Kết hợp các phần tử dạng sóng sin và cos để tạo ra embeddings cuối cùng. Mục đích của việc này là để bảo đảm tính chất chu kỳ và không tương quan giữa các chiều embeddings.

Cuối cùng, phương thức `forward` trả về các embeddings dạng sóng đã tính toán.

Tóm lại, lớp `SinusoidalPositionEmbeddings` này được sử dụng để tính toán các nhúng vị trí sử dụng phép nhúng chuỗi dạng sóng (sinusoidal embeddings), thường được sử dụng trong các mô hình xử lý dữ liệu tuần tự để giúp mô hình hiểu thêm về thông tin vị trí trong dữ liệu.
'''

class SinusoidalPositionEmbeddings(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, time):
        device = time.device
        half_dim = self.dim // 2
        embeddings = math.log(10000) / (half_dim - 1)
        embeddings = torch.exp(torch.arange(half_dim, device=device) * -embeddings)
        embeddings = time[:, None] * embeddings[None, :]
        embeddings = torch.cat((embeddings.sin(), embeddings.cos()), dim=-1)
        return embeddings

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
'''
Hàm `Block` là một module trong mô hình deep learning và được sử dụng để xây dựng các khối (blocks) của mô hình. Module này bao gồm một loạt các lớp liên tiếp như Convolution, Batch Normalization, và Linear (Fully Connected) để thực hiện các phép tính xử lý ảnh.

Cụ thể, mô-đun `Block` nhận các đối số sau khi được khởi tạo:

- `channels_in`: Số lượng kênh (channels) của dữ liệu đầu vào.
- `channels_out`: Số lượng kênh (channels) của dữ liệu đầu ra.
- `time_embedding_dims`: Kích thước của embeddings dạng sóng (sinusoidal embeddings) cho thông tin thời gian.
- `labels`: Xác định xem dữ liệu có chứa nhãn (labels) không.
- `num_filters`: Số lượng bộ lọc (filters) của lớp Convolution.
- `downsample`: Xác định xem module này có phải là module downsample hay không (downsampling giảm kích thước ảnh).

Sau khi khởi tạo, các lớp và phép tính trong module `Block` bao gồm:

- `self.time_embedding`: Một lớp `SinusoidalPositionEmbeddings` dùng để tính toán embeddings dạng sóng cho thông tin thời gian.
- `self.label_mlp`: Một lớp Linear (Fully Connected) dùng để xử lý nhãn (labels) nếu có.
- `self.conv1`: Lớp Convolution 2D đầu tiên. Nếu `downsample=True`, thì kích thước ảnh sẽ được giảm xuống sau lớp này; ngược lại, kích thước ảnh sẽ được giữ nguyên.
- `self.final`: Lớp Convolution hoặc ConvolutionTranspose 2D cuối cùng. Nó sẽ xác định kích thước ảnh đầu ra của module.
- `self.bnorm1`: Lớp Batch Normalization đầu tiên.
- `self.bnorm2`: Lớp Batch Normalization thứ hai.
- `self.conv2`: Lớp Convolution 2D thứ hai.
- `self.time_mlp`: Một lớp Linear (Fully Connected) dùng để xử lý thông tin thời gian (time embeddings).
- `self.relu`: Hàm kích hoạt ReLU (Rectified Linear Unit).

Phương thức `forward` của module `Block` thực hiện phép tính chuyển tiếp (forward pass) của dữ liệu qua các lớp và trả về kết quả đầu ra. Cụ thể:

- Đầu tiên, dữ liệu đầu vào `x` được đưa qua lớp Convolution đầu tiên (`self.conv1`), sau đó áp dụng Batch Normalization và hàm kích hoạt ReLU.
- Thông tin thời gian `t` được chuyển qua `self.time_embedding` để tính toán embeddings dạng sóng.
- Thực hiện các phép tính liên quan đến thông tin thời gian và nhãn nếu có (dựa vào giá trị của `self.labels` và đối số `**kwargs`).
- Dữ liệu được đưa qua lớp Convolution thứ hai (`self.conv2`) và sau đó áp dụng Batch Normalization và hàm kích hoạt ReLU.
- Kết quả đầu ra được tính toán qua lớp Convolution/CovolutionTranspose cuối cùng (`self.final`) để xác định kích thước ảnh đầu ra của module.

Tóm lại, module `Block` này được sử dụng để xây dựng một khối trong mô hình deep learning với các lớp Convolution, Batch Normalization, và Linear để xử lý dữ liệu ảnh và thông tin thời gian (nếu có). Nó được thiết kế để thực hiện phép tính chuyển tiếp qua các lớp và trả về kết quả đầu ra.
'''

class Block(nn.Module):
    def __init__(self, channels_in, channels_out, time_embedding_dims, labels, num_filters = 3, downsample=True):
        super().__init__()

        self.time_embedding_dims = time_embedding_dims
        self.time_embedding = SinusoidalPositionEmbeddings(time_embedding_dims)
        self.labels = labels
        if labels:
            self.label_mlp = nn.Linear(1, channels_out)

        self.downsample = downsample

        if downsample:
            self.conv1 = nn.Conv2d(channels_in, channels_out, num_filters, padding=1)
            self.final = nn.Conv2d(channels_out, channels_out, 4, 2, 1)
        else:
            self.conv1 = nn.Conv2d(2 * channels_in, channels_out, num_filters, padding=1)
            self.final = nn.ConvTranspose2d(channels_out, channels_out, 4, 2, 1)

        self.bnorm1 = nn.BatchNorm2d(channels_out)
        self.bnorm2 = nn.BatchNorm2d(channels_out)

        self.conv2 = nn.Conv2d(channels_out, channels_out, 3, padding=1)
        self.time_mlp = nn.Linear(time_embedding_dims, channels_out)
        self.relu = nn.ReLU()

    def forward(self, x, t, **kwargs):
        o = self.bnorm1(self.relu(self.conv1(x)))
        o_time = self.relu(self.time_mlp(self.time_embedding(t)))
        o = o + o_time[(..., ) + (None, ) * 2]
        if self.labels:
            label = kwargs.get('labels')
            o_label = self.relu(self.label_mlp(label))
            o = o + o_label[(..., ) + (None, ) * 2]

        o = self.bnorm2(self.relu(self.conv2(o)))

        return self.final(o)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
'''
Hàm `UNet` là một mô hình Unet được định nghĩa dưới dạng một lớp (class) kế thừa từ `nn.Module`. Mô hình này được sử dụng để thực hiện bài toán xây dựng và ánh xạ các dữ liệu ảnh (2D) có kích thước và thông tin thời gian (sequence) thành dữ liệu ảnh tương tự như ảnh gốc.

Cụ thể, mô hình `UNet` được khởi tạo với các đối số sau:

- `img_channels`: Số lượng kênh (channels) của ảnh đầu vào.
- `time_embedding_dims`: Kích thước của embeddings dạng sóng (sinusoidal embeddings) cho thông tin thời gian.
- `labels`: Xác định xem dữ liệu có chứa nhãn (labels) không.
- `sequence_channels`: Danh sách (tuple) chứa số lượng kênh của các khối trong quá trình trích xuất đặc trưng của mô hình.

Sau khi khởi tạo, mô hình `UNet` bao gồm các lớp và phép tính sau:

- `self.time_embedding_dims` và `sequence_channels_rev`: Đây là các biến được sử dụng để khởi tạo dữ liệu cho các khối trong mô hình Unet.
- `self.downsampling`: Là một danh sách (list) các module `Block` được sử dụng để thực hiện phép xuống mẫu dữ liệu (downsampling) thông qua các lớp Convolution, Batch Normalization, và Linear (nếu có).
- `self.upsampling`: Là một danh sách (list) các module `Block` được sử dụng để thực hiện phép lên mẫu dữ liệu (upsampling) thông qua các lớp Convolution hoặc ConvolutionTranspose.
- `self.conv1` và `self.conv2`: Là các lớp Convolution 2D được sử dụng để thực hiện các phép tính Convolution lên dữ liệu ảnh đầu vào.

Phương thức `forward` của mô hình `UNet` thực hiện phép tính chuyển tiếp (forward pass) của dữ liệu qua các khối (downsampling và upsampling) của mô hình Unet và trả về kết quả đầu ra. Cụ thể:

- Dữ liệu đầu vào `x` được đưa qua lớp Convolution đầu tiên (`self.conv1`).
- Dữ liệu đầu ra từ Convolution đầu tiên sau đó được truyền qua mạng downsampling bằng cách gọi các lớp trong `self.downsampling` theo thứ tự.
- Dữ liệu sau khi đã đi qua các khối downsampling sẽ được truyền qua mạng upsampling bằng cách gọi các lớp trong `self.upsampling` theo thứ tự.
- Cuối cùng, dữ liệu được đưa qua lớp Convolution cuối cùng (`self.conv2`) để tạo ra kết quả đầu ra của mô hình.

Tóm lại, mô hình `UNet` này được sử dụng để thực hiện các phép tính upsampling và downsampling thông qua các lớp Convolution và ConvolutionTranspose để ánh xạ các dữ liệu ảnh có thông tin thời gian thành dữ liệu ảnh tương tự như ảnh gốc.
'''

class UNet(nn.Module):
    def __init__(self, img_channels = 3, time_embedding_dims = 128, labels = False, sequence_channels = (64, 128, 256, 512, 1024)):
        super().__init__()
        self.time_embedding_dims = time_embedding_dims
        sequence_channels_rev = reversed(sequence_channels)

        self.downsampling = nn.ModuleList([Block(channels_in, channels_out, time_embedding_dims, labels) for channels_in, channels_out in zip(sequence_channels, sequence_channels[1:])])
        self.upsampling = nn.ModuleList([Block(channels_in, channels_out, time_embedding_dims, labels,downsample=False) for channels_in, channels_out in zip(sequence_channels[::-1], sequence_channels[::-1][1:])])
        self.conv1 = nn.Conv2d(img_channels, sequence_channels[0], 3, padding=1)
        self.conv2 = nn.Conv2d(sequence_channels[0], img_channels, 1)


    def forward(self, x, t, **kwargs):
        residuals = []
        o = self.conv1(x)
        for ds in self.downsampling:
            o = ds(o, t, **kwargs)
            residuals.append(o)
        for us, res in zip(self.upsampling, reversed(residuals)):
            o = us(torch.cat((o, res), dim=1), t, **kwargs)

        return self.conv2(o)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
NO_EPOCHS = 2000
PRINT_FREQUENCY = 400
LR = 0.001
BATCH_SIZE = 128
VERBOSE = True

unet = UNet(labels=False)
unet.to(device)
optimizer = torch.optim.Adam(unet.parameters(), lr=LR)

for epoch in range(NO_EPOCHS):
    mean_epoch_loss = []

    batch = torch.stack([torch_image] * BATCH_SIZE)
    t = torch.randint(0, diffusion_model.timesteps, (BATCH_SIZE,)).long().to(device)

    batch_noisy, noise = diffusion_model.forward(batch, t, device)
    predicted_noise = unet(batch_noisy, t)

    optimizer.zero_grad()
    loss = torch.nn.functional.mse_loss(noise, predicted_noise)
    mean_epoch_loss.append(loss.item())
    loss.backward()
    optimizer.step()

    if epoch % PRINT_FREQUENCY == 0:
        print('---')
        print(f"Epoch: {epoch} | Train Loss {np.mean(mean_epoch_loss)}")
        if VERBOSE:
            with torch.no_grad():
                plot_noise_prediction(noise[0], predicted_noise[0])
                plot_noise_distribution(noise, predicted_noise)

with torch.no_grad():
    img = torch.randn((1, 3) + IMAGE_SHAPE).to(device)
    for i in reversed(range(diffusion_model.timesteps)):
        t = torch.full((1,), i, dtype=torch.long, device=device)
        img = diffusion_model.backward(img, t, unet.eval())
        if i % 50 == 0:
            plt.figure(figsize=(2,2))
            plt.imshow(reverse_transform(img[0]))
            plt.show()


1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
## Test with CIFAR10 dataset.

BATCH_SIZE = 256
NO_EPOCHS = 100
PRINT_FREQUENCY = 10
LR = 0.001
VERBOSE = False

unet = UNet(labels=True)
unet.to(device)
optimizer = torch.optim.Adam(unet.parameters(), lr=LR)

trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=BATCH_SIZE, shuffle=True, num_workers=8, drop_last=True)

testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=BATCH_SIZE, shuffle=False, num_workers=8, drop_last=True)

'''
Đoạn mã trên là một vòng lặp huấn luyện mô hình neural network sử dụng hai tập dữ liệu là `trainloader` và `testloader`. Mô hình bao gồm một `diffusion_model` (mô hình phân tán) và một `unet` (mô hình Unet). Mục tiêu là để huấn luyện mô hình Unet để dự đoán và loại bỏ nhiễu từ các ảnh nhiễu được tạo từ mô hình phân tán.

Các bước trong quá trình huấn luyện là:

1. Đối với mỗi epoch (vòng lặp huấn luyện), thực hiện huấn luyện trên tập dữ liệu huấn luyện (`trainloader`) và đánh giá kết quả trên tập dữ liệu kiểm tra (`testloader`).

2. Trong quá trình huấn luyện:
   - Lấy ngẫu nhiên một thời điểm (`t`) từ tập `diffusion_model.timesteps` để tạo dữ liệu nhiễu. Lưu ý rằng `diffusion_model` là mô hình phân tán, và dữ liệu nhiễu được tạo từ mô hình này.
   - Lấy một batch dữ liệu từ `trainloader`, đưa dữ liệu vào thiết bị tính toán `device` (thường là GPU).
   - Sử dụng `diffusion_model` để tạo dữ liệu nhiễu từ dữ liệu đầu vào.
   - Sử dụng `unet` để dự đoán nhiễu và loại bỏ nhiễu từ dữ liệu nhiễu.
   - Tính toán hàm mất mát giữa nhiễu thực tế và nhiễu được dự đoán bằng cách sử dụng hàm lỗi bình phương trung bình (mean squared error).
   - Thực hiện quá trình lan truyền ngược (backpropagation) để cập nhật các trọng số của `unet` thông qua bước cập nhật tối ưu hóa `optimizer`.

3. Đánh giá trên tập kiểm tra (`testloader`):
   - Lấy ngẫu nhiên một thời điểm (`t`) từ tập `diffusion_model.timesteps` để tạo dữ liệu nhiễu.
   - Lấy một batch dữ liệu từ `testloader`, đưa dữ liệu vào thiết bị tính toán `device` (thường là GPU).
   - Sử dụng `diffusion_model` để tạo dữ liệu nhiễu từ dữ liệu đầu vào.
   - Sử dụng `unet` để dự đoán nhiễu và loại bỏ nhiễu từ dữ liệu nhiễu.
   - Tính toán hàm mất mát giữa nhiễu thực tế và nhiễu được dự đoán bằng cách sử dụng hàm lỗi bình phương trung bình (mean squared error).

4. Thông báo và lưu mô hình:
   - Nếu số epoch hiện tại chia hết cho `PRINT_FREQUENCY`, hiển thị thông tin về epoch, mất mát trung bình trên tập huấn luyện (`mean_epoch_loss`) và mất mát trung bình trên tập kiểm tra (`mean_epoch_loss_val`).
   - Nếu được yêu cầu (`VERBOSE`), hiển thị đồ họa của nhiễu thực tế và nhiễu dự đoán qua hàm `plot_noise_prediction` và hiển thị biểu đồ phân bố nhiễu thực tế và nhiễu dự đoán qua hàm `plot_noise_distribution`.
   - Lưu trạng thái của mô hình `unet` vào tệp tin có tên chứa số epoch hiện tại (`f"epoch: {epoch}"`).

Tóm lại, đoạn mã trên thực hiện việc huấn luyện mô hình Unet để dự đoán và loại bỏ nhiễu từ dữ liệu ảnh nhiễu được tạo từ mô hình phân tán.
'''
for epoch in range(NO_EPOCHS):
    mean_epoch_loss = []
    mean_epoch_loss_val = []
    for batch, label in trainloader:
        t = torch.randint(0, diffusion_model.timesteps, (BATCH_SIZE,)).long().to(device)
        batch = batch.to(device)
        batch_noisy, noise = diffusion_model.forward(batch, t, device)
        predicted_noise = unet(batch_noisy, t, labels = label.reshape(-1,1).float().to(device))

        optimizer.zero_grad()
        loss = torch.nn.functional.mse_loss(noise, predicted_noise)
        mean_epoch_loss.append(loss.item())
        loss.backward()
        optimizer.step()

    for batch, label in testloader:

        t = torch.randint(0, diffusion_model.timesteps, (BATCH_SIZE,)).long().to(device)
        batch = batch.to(device)

        batch_noisy, noise = diffusion_model.forward(batch, t, device)
        predicted_noise = unet(batch_noisy, t, labels = label.reshape(-1,1).float().to(device))

        loss = torch.nn.functional.mse_loss(noise, predicted_noise)
        mean_epoch_loss_val.append(loss.item())

    if epoch % PRINT_FREQUENCY == 0:
        print('---')
        print(f"Epoch: {epoch} | Train Loss {np.mean(mean_epoch_loss)} | Val Loss {np.mean(mean_epoch_loss_val)}")
        if VERBOSE:
            with torch.no_grad():
                plot_noise_prediction(noise[0], predicted_noise[0])
                plot_noise_distribution(noise, predicted_noise)

        torch.save(unet.state_dict(), f"epoch: {epoch}")

unet = UNet(labels=True)
unet.load_state_dict(torch.load(("epoch: 80")))

classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

NUM_CLASSES = len(classes)
NUM_DISPLAY_IMAGES = 5

torch.manual_seed(16)

plt.figure(figsize=(15,15))
f, ax = plt.subplots(NUM_CLASSES, NUM_DISPLAY_IMAGES, figsize = (100,100))

for c in range(NUM_CLASSES):
    imgs = torch.randn((NUM_DISPLAY_IMAGES, 3) + IMAGE_SHAPE).to(device)
    for i in reversed(range(diffusion_model.timesteps)):
        t = torch.full((1,), i, dtype=torch.long, device=device)
        labels = torch.tensor([c] * NUM_DISPLAY_IMAGES).resize(NUM_DISPLAY_IMAGES, 1).float().to(device)
        imgs = diffusion_model.backward(x=imgs, t=t, model=unet.eval().to(device), labels = labels)
    for idx, img in enumerate(imgs):
        ax[c][idx].imshow(reverse_transform(img))
        ax[c][idx].set_title(f"Class: {classes[c]}", fontsize = 100)

plt.show()

Ref

link

Tài liệu tham khảo

Internet

Hết.