【深度学习基础模型】Variational Autoencoders (VAE) 详细理解并附实现代码。

news/2024/9/28 9:19:39 标签: 深度学习, 人工智能, 机器学习, 分类, python, VAE

VAE__0">【深度学习基础模型】Variational Autoencoders (VAE) 详细理解并附实现代码

深度学习基础模型】Variational Autoencoders (VAE) 详细理解并附实现代码


文章目录

  • 深度学习基础模型】Variational Autoencoders (VAE) 详细理解并附实现代码
  • 1.Variational Autoencoders (VAE) 的原理和应用
    • 1.1 VAE 原理
    • 1.2 VAE 的主要特征:
    • 1.3 VAE 的应用领域:
  • 2.Python 代码实现 VAE 在遥感领域的应用
    • 2.1VAE 模型的实现
    • 2.2代码解释
  • 3.总结


参考地址:https://www.asimovinstitute.org/neural-network-zoo/
论文地址:https://arxiv.org/pdf/1312.6114v10

欢迎宝子们点赞、关注、收藏!欢迎宝子们批评指正!

VAE__13">1.Variational Autoencoders (VAE) 的原理和应用

VAE__14">1.1 VAE 原理

变分自编码器(Variational Autoencoder, VAE)是生成模型的一种,旨在学习输入数据的潜在概率分布VAE 与传统的自编码器(AE)相比,其核心区别在于它采用了贝叶斯方法进行推理

VAE__16">1.2 VAE 的主要特征:

  • 架构:与 AE 相同,VAE 也由编码器和解码器组成,但编码器的输出是潜在变量的概率分布(通常为高斯分布)
  • 重参数化技巧:为了解决标准反向传播无法有效训练模型的问题,VAE 引入了重参数化技巧。通过将潜在变量表示为固定分布(如标准正态分布)与参数化分布(均值和方差)的组合,模型能够有效学习
  • 损失函数VAE 的损失函数由两部分组成:重构损失和 KL 散度(Kullback-Leibler Divergence)。重构损失衡量生成样本与真实样本之间的差异,而 KL 散度则确保潜在分布接近先验分布(通常是标准正态分布)。

VAE__21">1.3 VAE 的应用领域:

  • 图像生成VAE 可以生成新图像,广泛应用于计算机视觉。
  • 数据插值:通过在潜在空间中进行插值,VAE 可以生成两种输入之间的过渡图像。
  • 异常检测:在学习正常数据的分布后,VAE 可以检测到异常样本。

在遥感领域,VAE 可以用于处理高维遥感数据,生成新图像,或从复杂的多光谱图像中提取潜在特征。

VAE__27">2.Python 代码实现 VAE 在遥感领域的应用

下面通过一个简单的 VAE 实现,演示如何在遥感图像处理中应用 VAE

VAE__29">2.1VAE 模型的实现

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import numpy as np
import matplotlib.pyplot as plt

# 定义 VAE 模型
class VAE(nn.Module):
    def __init__(self, input_size, hidden_size):
        super(VAE, self).__init__()
        
        # 编码器
        self.fc1 = nn.Linear(input_size, hidden_size)  # 输入到隐藏层
        self.fc21 = nn.Linear(hidden_size, hidden_size)  # 均值
        self.fc22 = nn.Linear(hidden_size, hidden_size)  # 对数方差
        
        # 解码器
        self.fc3 = nn.Linear(hidden_size, hidden_size)
        self.fc4 = nn.Linear(hidden_size, input_size)

    def encode(self, x):
        h1 = torch.relu(self.fc1(x))
        return self.fc21(h1), self.fc22(h1)  # 返回均值和对数方差

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)  # 标准差
        eps = torch.randn_like(std)  # 随机噪声
        return mu + eps * std  # 重新参数化

    def decode(self, z):
        h3 = torch.relu(self.fc3(z))
        return torch.sigmoid(self.fc4(h3))  # 输出为概率值 [0, 1]

    def forward(self, x):
        mu, logvar = self.encode(x)  # 编码
        z = self.reparameterize(mu, logvar)  # 重新参数化
        return self.decode(z), mu, logvar  # 解码及返回均值和对数方差

    def loss_function(self, recon_x, x, mu, logvar):
        BCE = nn.functional.binary_cross_entropy(recon_x, x, reduction='sum')  # 重构损失
        KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())  # KL 散度
        return BCE + KLD  # 总损失

# 生成模拟遥感图像数据 (64 维特征)
X = np.random.rand(1000, 64)  # 1000 个样本,每个样本有 64 维光谱特征
X = torch.tensor(X, dtype=torch.float32)

# 创建数据加载器
dataset = TensorDataset(X)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

# 定义模型、优化器
input_size = 64
hidden_size = 32  # 隐藏层大小
vae = VAE(input_size=input_size, hidden_size=hidden_size)
optimizer = optim.Adam(vae.parameters(), lr=0.001)

# 训练 VAE 模型
num_epochs = 50
for epoch in range(num_epochs):
    for data in dataloader:
        optimizer.zero_grad()
        recon_batch, mu, logvar = vae(data[0])  # 前向传播
        loss = vae.loss_function(recon_batch, data[0], mu, logvar)  # 计算损失
        loss.backward()
        optimizer.step()
    
    if epoch % 10 == 0:
        print(f'Epoch {epoch}, Loss: {loss.item()}')

# 使用训练好的模型进行数据生成
with torch.no_grad():
    sample = torch.randn(64)  # 从标准正态分布生成潜在变量
    generated_data = vae.decode(sample).numpy()  # 解码生成新样本

# 可视化原始数据与生成数据
plt.figure(figsize=(6, 3))
plt.subplot(1, 2, 1)
plt.title('Generated Data')
plt.imshow(generated_data[:10], aspect='auto', cmap='hot')
plt.subplot(1, 2, 2)
plt.title('Original Data Sample')
plt.imshow(X.numpy()[:10], aspect='auto', cmap='hot')
plt.show()

2.2代码解释

1.模型定义:

class VAE(nn.Module):
    def __init__(self, input_size, hidden_size):
        super(VAE, self).__init__()
        
        # 编码器
        self.fc1 = nn.Linear(input_size, hidden_size)  # 输入到隐藏层
        self.fc21 = nn.Linear(hidden_size, hidden_size)  # 均值
        self.fc22 = nn.Linear(hidden_size, hidden_size)  # 对数方差
        
        # 解码器
        self.fc3 = nn.Linear(hidden_size, hidden_size)
        self.fc4 = nn.Linear(hidden_size, input_size)
  • VAE 类定义了编码器和解码器结构,包括均值和对数方差的输出。

2.重参数化技巧:

def reparameterize(self, mu, logvar):
    std = torch.exp(0.5 * logvar)  # 标准差
    eps = torch.randn_like(std)  # 随机噪声
    return mu + eps * std  # 重新参数化
  • 使用随机噪声与潜在变量均值和标准差结合,生成潜在表示。

3.数据生成:

X = np.random.rand(1000, 64)  # 生成 1000 个样本,每个样本有 64 维光谱特征
X = torch.tensor(X, dtype=torch.float32)
  • 模拟生成随机的遥感光谱数据。

4.数据加载器:

dataset = TensorDataset(X)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)
  • 使用 DataLoader 创建批处理数据集。

5.模型训练

for epoch in range(num_epochs):
    for data in dataloader:
        optimizer.zero_grad()
        recon_batch, mu, logvar = vae(data[0])
        loss = vae.loss_function(recon_batch, data[0], mu, logvar)
        loss.backward()
        optimizer.step()
  • 使用 50 个 epoch 进行训练,计算重构损失和 KL 散度,更新权重。

6.生成新数据:

with torch.no_grad():
    sample = torch.randn(64)  # 从标准正态分布生成潜在变量
    generated_data = vae.decode(sample).numpy()  # 解码生成新样本
  • 通过从潜在空间生成样本,使用解码器生成新数据。

7.可视化:

plt.subplot(1, 2, 1)
plt.title('Generated Data')
plt.imshow(generated_data[:10], aspect='auto', cmap='hot')
  • 可视化生成的数据与原始数据的对比。

3.总结

变分自编码器(VAE)是一种强大的生成模型,能够有效学习输入数据的潜在概率分布。在遥感领域,VAE 可以用于数据生成、特征提取和异常检测等任务。通过简单的 Python 实现,我们展示了如何使用 VAE 处理遥感数据,生成新样本,并可视化结果。


http://www.niftyadmin.cn/n/5680804.html

相关文章

golang fmt.Sprintf 引用前述变量

可使用 %[1]s 之类的写法,引用前述变量,避免重复写 例: package mainimport "fmt"func main() {s : fmt.Sprintf("%[1]s, %[2]d, %[3]s, %[1]s %[3]s", "hello", 777, "world")fmt.Print(s) // 输…

iOS--RunLoop原理

前言 曾经在写项目的时候遇到过这么一个问题。: 项目中添加了一个tableview,然后还有一个计时器,当滑动tableview的时候会阻塞计时器,你得执行这么一段代码后,计时器才能正常运行。 RunLoop.current.add(timer, for…

Spring Boot 调用外部接口的常用方式!

使用Feign进行服务消费是一种简化HTTP调用的方式,可以通过声明式的接口定义来实现。下面是一个使用Feign的示例,包括设置Feign客户端和调用服务的方法。 添加依赖 首先,请确保你的项目中已经添加了Feign的依赖。如果你使用的是Maven&#xf…

基于Hive和Hadoop的电信流量分析系统

本项目是一个基于大数据技术的电信流量分析系统,旨在为用户提供全面的通信数据和深入的流量使用分析。系统采用 Hadoop 平台进行大规模数据存储和处理,利用 MapReduce 进行数据分析和处理,通过 Sqoop 实现数据的导入导出,以 Spark…

DOM元素导出图片与PDF:多种方案对比与实现

背景 在日常前端开发中,经常会有把页面的 DOM 元素作为 PNG 或者 PDF 下载到本地的需求。例如海报功能,简历导出功能等等。在我们自家的产品「代码小抄」中,就使用了 html2canvas 来实现代码片段导出为图片: 是不是还行&#xff…

vue仿chatGpt的AI聊天功能--大模型通义千问(阿里云)

vue仿chatGpt的AI聊天功能–大模型通义千问(阿里云) 通义千问是由阿里云自主研发的大语言模型,用于理解和分析用户输入的自然语言。 1. 创建API-KEY并配置环境变量 打开通义千问网站进行登录,登陆之后创建api-key,右…

C++-list使用学习

###list(链表)是C里面的一种容器,底层是双向的; 这就决定了它的迭代器使用的场景和能够使用的算法;双向(例如list)不能像随机(例如vector)那样用迭代器任意加上几去使用&…

Harbor使用

文章目录 1、上传镜像1.1、在Harbor上创建一个项目1.2、docker添加安全访问权限1.3、推送docker镜像到该项目中1.3.1、登录到Harbor1.3.2、给镜像重新打一个标签1.3.3、推送镜像到Harbor中 2、拉取镜像2.1、先删掉原来的镜像2.2、执行拉取命令 1、上传镜像 需求:将…