Kim Seon Deok

ResNet 구현 본문

AI/논문

ResNet 구현

seondeok 2022. 3. 25. 03:17

 

CIFAR - 10 dataset 이용

훨씬 더 깊은 네트워크에서도 ResNet이 error rate를 줄일 수 있는지를 확인하기 위해 실험을 진행

 

size : 32*32 >> imagenet(224*224)에 비해 크기가 훨씬 작다.

따라서 파라미터의 수를 감소시켜 별도의 resnet을 사용

number of class : 10

number of training images : 50000

number of test images : 10000

 

 

input : 32*32*3

도입부 

7*7 conv, 64, 1/2  >>16*16*64

pool, 1/2  >> 8*8*64

 

크기가 같으므로 네트워크 변형 x

 

중간층

layer1

3*3 conv, 64

3*3 conv, 64

3*3 conv, 64

3*3 conv, 64 >> 8*8*64

network의 변형과정 : 8*8*64 에서 4*4*128 로 만들어 주는 과정

layer2

3*3 conv, 128, 1/2  >>4*4*128

3*3 conv, 128

3*3 conv, 128

3*3 conv, 128

layer3

3*3 conv, 256, 1/2

3*3 conv, 256

3*3 conv, 256

3*3 conv, 256

layer4

3*3 conv, 512, 1/2

3*3 conv, 512

3*3 conv, 512

3*3 conv, 512

 

아웃풋

Adaptive Average Pooling

출력 H,W:1*1

 

Fully Connected Layer

입력 크기 : 512

출력 크기 : class의 수(10)

입력 받기 전 tensor의 모양 변경 필요(view 함수)

 

import torchvision
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import torch

class BasicBlock(nn.Module):
    """Residual block made of two 3x3 convolutions and a skip connection.

    Main path: conv3x3 -> BN -> ReLU -> conv3x3 -> BN.  The skip path is added
    to the main path and the final ReLU is applied to the sum.

    Args:
        in_channel: Number of channels of the input feature map.
        out_channel: Number of channels produced by both convolutions.
        downsample: When true, the first convolution uses stride 2 and the
            skip path is projected by a strided 1x1 convolution followed by
            batch norm so that both paths have the same shape.  When false,
            the input is added unchanged, so ``in_channel`` must equal
            ``out_channel`` for the addition to be valid.
    """

    def __init__(self, in_channel, out_channel, downsample):
        super().__init__()

        self.downsample = downsample  # True or False
        # A downsampling block halves height/width in its first convolution.
        stride = 2 if downsample else 1

        if downsample:
            # Projection shortcut (the "dotted" skip connection): brings the
            # input to the channel count / resolution of the main path.
            self.down_cnn = nn.Conv2d(in_channels=in_channel, out_channels=out_channel,
                                      kernel_size=1, stride=stride)
            self.down_bn = nn.BatchNorm2d(num_features=out_channel)
        # Otherwise the shortcut is the identity (the "solid" skip connection)
        # and needs no extra layers.

        self.conv1 = nn.Conv2d(in_channels=in_channel, out_channels=out_channel,
                               kernel_size=3, stride=stride, padding=1)
        self.bn1 = nn.BatchNorm2d(num_features=out_channel)
        self.relu = nn.ReLU()

        self.conv2 = nn.Conv2d(in_channels=out_channel, out_channels=out_channel,
                               kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(num_features=out_channel)

    def forward(self, x):
        """Return ``ReLU(F(x) + shortcut(x))`` for the input feature map ``x``."""
        if self.downsample:
            skip = self.down_bn(self.down_cnn(x))
        else:
            skip = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        # The activation comes AFTER the residual addition; applying it before
        # the sum would let negative values from the shortcut leak through.
        out = self.relu(out + skip)

        return out
      
      class ResNet18(nn.Module):
    def __init__(self, num_classes):   # 초기화 과정
        super().__init__()
        
        # 도입부
        self.conv1 = nn.Conv2d(in_channels = 3, out_channels=64, kernel_size = 7, stride = 2, padding = 3)
        self.bn1 = nn.BatchNorm2d(num_features = 64)
        self.relu = nn.ReLU()
        self.maxpool = nn.MaxPool2d(kernel_size = 3, stride = 2, padding = 1)

        # 중간층(layer1 ~ layer4)
        self.layer1 = self.make_layer(in_channel=64, out_channel=64, num_blocks = 2,
        								 downsample=False)  
        # downsample : skip connection과정에서 다음 layer로 넘어갈 때 size 변경
        self.layer2 = self.make_layer(in_channel=64, out_channel=128, num_blocks = 2,
       									 downsample=True)
        self.layer3 = self.make_layer(in_channel=128, out_channel=256, num_blocks = 2,
       									 downsample=True)
        self.layer4 = self.make_layer(in_channel=256, out_channel=512, num_blocks = 2, 
        								 downsample=True)     

        # 아웃풋
        self.avg_pool = nn.AdaptiveAvgPool2d((1,1))  
        # filter size만 주고 알아서 width와 height를 1*1로 만들도록 함
        self.fc = nn.Linear(in_features = 512, out_features = num_classes)

    def make_layer(self, num_blocks, in_channel, out_channel, downsample ):
        layer = []  # basic block이 들어감
        layer.append(BasicBlock(in_channel = in_channel, out_channel = out_channel, 
        						downsample=downsample))
        for _ in range(1,num_blocks):
          layer.append(BasicBlock(in_channel = out_channel, out_channel = out_channel, 
          						downsample = False))
        return nn.Sequential(*layer)  # python 에서 *은 unpacking을 의미

    def forward(self, x):   # 이미지를 다루는 과정
        batch_size = x.shape[0]
        # 도입부
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.maxpool(out)

        # 중간층(layer1 ~ layer4)
        # bacis block = 3*3 Conv 2개와 skip connection으로 구성
        out = self.layer1(out)  
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)


        # 아웃풋
        out = self.avg_pool(out)
        out = out.view(batch_size, -1)
        out = self.fc(out)


        return out
# hyper-parameters
num_classes = 10
num_epochs = 10
batch_size = 100
learning_rate = 0.001

# Device used by the model and by every batch moved with ``.to(device)`` below.
# Falls back to the CPU when no CUDA device is available.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Data preprocessing
# NOTE(review): the std values below differ slightly from the widely copied
# (0.2023, 0.1994, 0.2010) -- confirm whether this is intentional.
transforms_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2003, 0.1944, 0.2010)),
])

transforms_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2003, 0.1944, 0.2010)),
])

# CIFAR10 dataset provided by torchvision
train_dev_dataset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                                 transform=transforms_train, download=True)
test_dataset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                            transform=transforms_test, download=True)
# 45000 / 5000 train / dev split of the 50000 training images.
# NOTE(review): both splits come from train_dev_dataset, so the dev split is
# also read through transforms_train (random crop / flip) -- confirm intended.
train_dataset, dev_dataset = torch.utils.data.random_split(train_dev_dataset, [45000, 5000])

# Data loaders that serve the datasets in batches
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                           batch_size=batch_size,
                                           shuffle=True)
dev_loader = torch.utils.data.DataLoader(dataset=dev_dataset,
                                         batch_size=batch_size,
                                         shuffle=False)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                          batch_size=batch_size,
                                          shuffle=False)

# Move the model to the selected device
model = ResNet18(num_classes).to(device)

criterion = nn.CrossEntropyLoss()
# model.parameters() yields the learnable weights the optimizer updates
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

def evaluation(data_loader):
    """Return the accuracy (fraction in [0, 1]) of the global ``model`` on ``data_loader``.

    Each batch is moved to the global ``device``, and the class with the
    highest score is taken as the prediction.  The function does not switch
    the model between train/eval mode; the caller is expected to call
    ``model.eval()`` beforehand.

    Args:
        data_loader: Iterable yielding ``(images, labels)`` batches.

    Returns:
        ``correct / total`` over all samples, or ``0.0`` when the loader
        yields no samples.
    """
    correct = 0
    total = 0
    # Evaluation never needs gradients; disabling them here keeps the
    # function cheap even when the caller forgets to wrap it in no_grad.
    with torch.no_grad():
        for images, labels in data_loader:
            images = images.to(device)
            labels = labels.to(device)
            predicted = model(images).argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    if total == 0:
        # Empty loader: avoid ZeroDivisionError.
        return 0.0
    return correct / total

# Training loss sampled every 150 steps, stored as plain Python floats so the
# list does not keep autograd graphs alive and can be plotted directly.
loss_arr = []
# Best dev accuracy seen so far (named so it does not shadow the builtin max).
best_dev_acc = 0.0
total_step = len(train_loader)
for epoch in range(num_epochs):
    for i, (images, labels) in enumerate(train_loader):
        # Re-enter training mode every step: the periodic evaluation below
        # switches the model to eval mode.
        model.train()
        # Move tensors to the configured device
        images = images.to(device)
        labels = labels.to(device)
        # Forward pass
        outputs = model(images)
        loss = criterion(outputs, labels)
        # Backward and optimize
        optimizer.zero_grad()  # reset gradients accumulated by the previous step
        loss.backward()        # differentiate the loss w.r.t. the weights
        optimizer.step()       # update the weights

        if (i + 1) % 150 == 0:
            loss_arr.append(loss.item())
            print('Epoch [{}/{}], step [{}/{}], Loss: {:.4f}'
                  .format(epoch+1, num_epochs, i+1, total_step, loss.item()))

            with torch.no_grad():
                model.eval()
                acc = evaluation(dev_loader)
                # Keep only the checkpoint with the best dev accuracy.
                if best_dev_acc < acc:
                    best_dev_acc = acc
                    print('max dev accuracy:', best_dev_acc)
                    torch.save(model.state_dict(), 'model.ckpt')

 

 

torchvision.models의 ResNet18로 CIFAR data 학습

 

'AI > 논문' 카테고리의 다른 글

EfficientNet  (0) 2022.03.30
DenseNet  (0) 2022.03.30
ResNet  (0) 2022.03.24
GoogleNet  (0) 2022.03.24
VGGNet  (0) 2022.03.22
Comments