from typing import Mapping, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F

from .base_model import BaseModel
# Public API of this module: only ``DSResNet`` is listed in ``__all__``.
__all__ = ['DSResNet']
class DoubleConvBlock(BaseModel):
    """Two consecutive 3x3 convolution stages.

    Every stage is ``Conv2d`` (kernel 3, padding 1) followed by ``ReLU`` and
    then ``BatchNorm2d``. The first stage maps ``in_channels`` to
    ``out_channels``; the second keeps ``out_channels``.

    Args:
        in_channels: Number of channels of the incoming feature map.
        out_channels: Number of channels produced by both stages.
    """

    def __init__(self, in_channels: int, out_channels: int):
        super().__init__()

        def conv_stage(stage_in: int) -> nn.Sequential:
            # Layer order (conv, activation, norm) is kept as-is so that the
            # sub-module indices inside the Sequential do not change.
            layers = [
                nn.Conv2d(stage_in, out_channels, kernel_size=3, padding=1),
                nn.ReLU(),
                nn.BatchNorm2d(out_channels),
            ]
            return nn.Sequential(*layers)

        self.conv1 = conv_stage(in_channels)
        self.conv2 = conv_stage(out_channels)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run ``x`` through both convolution stages and return the result."""
        return self.conv2(self.conv1(x))
class ResEncoderBlock(BaseModel):
    """Residual encoder stage followed by a 2x2 max-pool.

    The input goes through a ``DoubleConvBlock`` and, in parallel, through a
    1x1 convolution shortcut; the two results are summed. The block returns
    both the pooled sum (to feed the next stage) and the un-pooled sum (used
    as a skip connection by the decoder).

    Args:
        in_channels: Number of channels of the incoming feature map.
        out_channels: Number of channels of both returned feature maps.
    """

    def __init__(self, in_channels: int, out_channels: int):
        super(ResEncoderBlock, self).__init__()
        self.double_conv = DoubleConvBlock(
            in_channels, out_channels
        )
        # 1x1 projection so the shortcut has the same channel count as the
        # output of ``double_conv`` and the two can be added.
        self.skip_conv = nn.Conv2d(
            in_channels, out_channels, kernel_size=1
        )
        self.down = nn.MaxPool2d(2)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Encode ``x``.

        Args:
            x: Input feature map with ``in_channels`` channels.

        Returns:
            A tuple ``(pooled, features)`` where ``features`` is the residual
            sum before pooling and ``pooled`` is ``features`` after
            ``MaxPool2d(2)``.
        """
        identity = self.skip_conv(x)
        out = self.double_conv(x)
        out = out + identity
        return self.down(out), out
class ResDecoderBlock(BaseModel):
    """Residual decoder stage with 2x bilinear upsampling.

    The incoming feature map is reduced to ``out_channels`` with a 1x1
    convolution, upsampled by a factor of two and concatenated (channel
    axis) with an encoder feature map. The concatenation is then processed
    by a ``DoubleConvBlock`` plus a 1x1 convolution shortcut, and the two
    results are summed.

    Args:
        in_channels: Channels of the incoming feature map; also the channel
            count expected after concatenation.
        out_channels: Channels of the returned feature map.
    """

    def __init__(self, in_channels: int, out_channels: int):
        super().__init__()
        # Kept inside a Sequential so the parameter names stay unchanged.
        self.transition_conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=1),
        )
        # Only applied in ``forward`` when a second skip tensor is supplied.
        self.enc_skip_conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)
        self.skip_conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)
        self.double_conv = DoubleConvBlock(in_channels, out_channels)

    def forward(
        self, x: torch.Tensor, encoder_input: torch.Tensor, skip_input: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """Decode one stage.

        Args:
            x: Feature map coming from the previous (coarser) stage.
            encoder_input: Encoder feature map to merge in.
            skip_input: Optional second encoder feature map. When given, it
                is concatenated with ``encoder_input`` and the pair is
                reduced by ``enc_skip_conv`` before merging.

        Returns:
            The decoded feature map with ``out_channels`` channels.
        """
        # Transition: channel reduction, then 2x bilinear upsampling.
        upsampled = F.interpolate(
            self.transition_conv(x),
            scale_factor=2,
            mode='bilinear',
            align_corners=False,
        )

        lateral = encoder_input
        if skip_input is not None:
            paired = torch.cat((encoder_input, skip_input), dim=1)
            lateral = self.enc_skip_conv(paired)

        fused = torch.cat((upsampled, lateral), dim=1)

        # Decoding: residual sum of the double conv and the 1x1 shortcut.
        shortcut = self.skip_conv(fused)
        decoded = self.double_conv(fused)
        return decoded + shortcut
class DSResNet(BaseModel):
    """A U-Net Inspired model for Monocular Depth Estimation and Image Segmentation.

    For information check the `Depth-Estimation-Segmentation repository
    <https://github.com/shan18/Depth-Estimation-Segmentation>`_.

    The network has two separate two-stage encoders (one per input image),
    a shared four-stage encoder on their merged features, and two parallel
    decoders (segmentation and depth) that each end in a 1x1 convolution
    with a single output channel.

    `Note`: This model inherits the ``BaseModel`` class.
    """

    def __init__(self):
        super(DSResNet, self).__init__()

        # Encoder Network
        # ===============

        # Preparation Block for bg
        self.b1 = ResEncoderBlock(3, 16)
        self.b2 = ResEncoderBlock(16, 32)

        # Preparation Block for bg_fg
        self.bf1 = ResEncoderBlock(3, 16)
        self.bf2 = ResEncoderBlock(16, 32)

        # Join both inputs: 32 + 32 concatenated channels reduced to 32
        self.merge = nn.Conv2d(64, 32, kernel_size=1)

        # Merged encoder network
        self.enc1 = ResEncoderBlock(32, 64)
        self.enc2 = ResEncoderBlock(64, 128)
        self.enc3 = ResEncoderBlock(128, 256)
        self.enc4 = ResEncoderBlock(256, 512)

        # Decoder Network
        # ===============

        # Decoder Network - Segmentation
        self.Mdec3 = ResDecoderBlock(512, 256)
        self.Mdec2 = ResDecoderBlock(256, 128)
        self.Mdec1 = ResDecoderBlock(128, 64)
        self.M2 = ResDecoderBlock(64, 32)
        self.M1 = ResDecoderBlock(32, 16)
        self.M0 = nn.Conv2d(16, 1, kernel_size=1)

        # Decoder Network - Depth
        self.Ddec3 = ResDecoderBlock(512, 256)
        self.Ddec2 = ResDecoderBlock(256, 128)
        self.Ddec1 = ResDecoderBlock(128, 64)
        self.D2 = ResDecoderBlock(64, 32)
        self.D1 = ResDecoderBlock(32, 16)
        self.D0 = nn.Conv2d(16, 1, kernel_size=1)

    def forward(
        self, x: Mapping[str, torch.Tensor]
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Predict depth and segmentation maps.

        Args:
            x: Mapping with the keys ``'bg'`` and ``'bg_fg'``. Each value is
                fed to an encoder whose first convolution expects 3 input
                channels. Presumably these are ``(N, 3, H, W)`` image
                batches -- confirm against the data loader. Five 2x2
                max-pools lie between the inputs and the coarsest feature
                map, and each decoder stage upsamples by exactly two, so
                ``H`` and ``W`` need to be multiples of 32 for the skip
                concatenations to line up.

        Returns:
            A tuple ``(outD, outM)``: the depth head output followed by the
            segmentation head output. Each has one channel and no final
            activation is applied here.
        """
        # bg
        b1_down, b1 = self.b1(x['bg'])
        b2_down, b2 = self.b2(b1_down)

        # bg_fg
        bf1_down, bf1 = self.bf1(x['bg_fg'])
        bf2_down, bf2 = self.bf2(bf1_down)

        # Merging
        merge = torch.cat([b2_down, bf2_down], dim=1)
        merge = self.merge(merge)

        # Merged Encoder
        enc1_down, enc1 = self.enc1(merge)
        enc2_down, enc2 = self.enc2(enc1_down)
        enc3_down, enc3 = self.enc3(enc2_down)
        # Only the un-pooled output of the last stage is used (bottleneck).
        _, enc4 = self.enc4(enc3_down)

        # Decoder - Segmentation
        Mdec3 = self.Mdec3(enc4, enc3)
        Mdec2 = self.Mdec2(Mdec3, enc2)
        Mdec1 = self.Mdec1(Mdec2, enc1)
        # The two shallowest stages take skips from both input encoders.
        m2 = self.M2(Mdec1, b2, bf2)
        m1 = self.M1(m2, b1, bf1)
        outM = self.M0(m1)

        # Decoder - Depth
        Ddec3 = self.Ddec3(enc4, enc3)
        Ddec2 = self.Ddec2(Ddec3, enc2)
        Ddec1 = self.Ddec1(Ddec2, enc1)
        d2 = self.D2(Ddec1, b2, bf2)
        d1 = self.D1(d2, b1, bf1)
        outD = self.D0(d1)

        return outD, outM