概要
話題のChatGPT使ってみたんですよ,ええ.
仕事で使えるかもしれないというので,フレーム補完手法のFILMという手法を調べてて.
ソースコードあるじゃんってなって見たけどよくわからんしTensorFlowで書かれてるし,PyTorch版もあったんだけどよくわからんってなった.
こりゃ理解するのに時間かかりそうだな,となっていたんだけど,ChatGPTさん速攻でそれらしいコード書いてくれてすげーなってなってる.
ChatGPTの回答
import torch import torch.nn as nn # 中間フレーム数 N = 5 class FrameInterpolationModel(nn.Module): def __init__(self): super(FrameInterpolationModel, self).__init__() # 特徴抽出器 self.feature_extractor = nn.Sequential( nn.Conv2d(3, 64, 3, padding=1), nn.ReLU(), nn.Conv2d(64, 128, 3, padding=1), nn.ReLU(), nn.Conv2d(128, 256, 3, padding=1), nn.ReLU(), nn.Conv2d(256, 512, 3, padding=1), nn.ReLU(), nn.Conv2d(512, 512, 3, padding=1), nn.ReLU() ) # 中間フレーム生成器 self.middle_frame_generator = nn.Sequential( nn.Conv2d(1024, 512, 3, padding=1), nn.ReLU(), nn.Conv2d(512, 256, 3, padding=1), nn.ReLU(), nn.Conv2d(256, 128, 3, padding=1), nn.ReLU(), nn.Conv2d(128, 64, 3, padding=1), nn.ReLU(), nn.Conv2d(64, 3, 3, padding=1), nn.Sigmoid() ) # モーション補間器 self.motion_interpolator = nn.Sequential( nn.Conv2d(512, 512, 3, padding=1), nn.ReLU(), nn.Conv2d(512, 256, 3, padding=1), nn.ReLU(), nn.Conv2d(256, 128, 3, padding=1), nn.ReLU(), nn.Conv2d(128, 64, 3, padding=1), nn.ReLU(), nn.Conv2d(64, 2, 3, padding=1) ) def forward(self, x0, x2): # 特徴抽出 feat0 = self.feature_extractor(x0) feat2 = self.feature_extractor(x2) # モーションベクトルの推定 motion_vector = self.motion_interpolator(torch.cat([feat0, feat2], dim=1)) motion_vector = torch.tanh(motion_vector) * 2.0 # 中間フレームの生成 x = torch.cat([feat0, feat2], dim=1) middle_frames = [] for i in range(N): # 重み w = (i+1) / (N+1) # 中間特徴量の生成 x_i = nn.functional.grid_sample(x, self._get_grid(motion_vector * w)) # 中間フレームの生成 middle_frame_i = self.middle_frame_generator(x_i) middle_frames.append(middle_frame_i) return middle_frames def _get_grid(self, motion_vector): batch_size, _, h, w = motion_vector.size() # グリッドの生成 xx = torch.arange(0,
途中で切れちゃってるけど,ここまでのコードをあまり時間をかけずに作ってくれるのやべーな.コード書く人そのうち駆逐されそうだ.