数学的にどうなのさ?

大学時代にちょっと長く数学を勉強した人の雑記。数学のこと(主に統計)や趣味、メモなどが多くなります

ChatGPTぱないねって話

概要

話題のChatGPT使ってみたんですよ,ええ.
仕事で使えるかもしれないというので,フレーム補完手法のFILMという手法を調べてて.
ソースコードあるじゃんってなって見たけどよくわからんしTensorFlowで書かれてるし,PyTorch版もあったんだけどよくわからんってなった.
こりゃ理解するのに時間かかりそうだな,となっていたんだけど,ChatGPTさん速攻でそれらしいコード書いてくれてすげーなってなってる.

ChatGPTの回答

import torch
import torch.nn as nn

# 中間フレーム数
N = 5

class FrameInterpolationModel(nn.Module):
    def __init__(self):
        super(FrameInterpolationModel, self).__init__()
        # 特徴抽出器
        self.feature_extractor = nn.Sequential(
            nn.Conv2d(3, 64, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 128, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(128, 256, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(256, 512, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(512, 512, 3, padding=1),
            nn.ReLU()
        )
        # 中間フレーム生成器
        self.middle_frame_generator = nn.Sequential(
            nn.Conv2d(1024, 512, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(512, 256, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(256, 128, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(128, 64, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 3, 3, padding=1),
            nn.Sigmoid()
        )
        # モーション補間器
        self.motion_interpolator = nn.Sequential(
            nn.Conv2d(512, 512, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(512, 256, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(256, 128, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(128, 64, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 2, 3, padding=1)
        )

    def forward(self, x0, x2):
        # 特徴抽出
        feat0 = self.feature_extractor(x0)
        feat2 = self.feature_extractor(x2)

        # モーションベクトルの推定
        motion_vector = self.motion_interpolator(torch.cat([feat0, feat2], dim=1))
        motion_vector = torch.tanh(motion_vector) * 2.0

        # 中間フレームの生成
        x = torch.cat([feat0, feat2], dim=1)
        middle_frames = []
        for i in range(N):
            # 重み
            w = (i+1) / (N+1)
            # 中間特徴量の生成
            x_i = nn.functional.grid_sample(x, self._get_grid(motion_vector * w))
            # 中間フレームの生成
            middle_frame_i = self.middle_frame_generator(x_i)
            middle_frames.append(middle_frame_i)

        return middle_frames

    def _get_grid(self, motion_vector):
        batch_size, _, h, w = motion_vector.size()
        # グリッドの生成
        xx = torch.arange(0,

途中で切れちゃってるけど,ここまでのコードをあまり時間をかけずに作ってくれるのやべーな.コード書く人そのうち駆逐されそうだ.