ONNX Image Model

Run a pix2pix image-to-image model exported to the ONNX format. This is similar to Image to Image, but uses an ONNX model instead of a TensorFlow.js model.

Currently, only models that take 512x512 images are supported.

To work well, the model needs input that resembles what it saw during training; e.g. if you trained it on a hand model, feed it hand models with the same colors and line thickness.

Parameters

  • Model: the .onnx model file.

Training

Here’s a simple example of how to train a pix2pix model using PyTorch. You will need torch, torchvision, onnx, onnxruntime and tqdm installed.

pip install torch==2.4.0 torchvision==0.19.0 onnx==1.16.1 onnxruntime==1.19.0 tqdm

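Save the code below as train.py and run it with python train.py --input_dir datasets/trees --output_dir output. Each training image in the input directory must be a .jpg or .png file that holds a pair side by side: the target image on the left half and the input image on the right half.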

import glob
import torch
import torch.nn as nn
import torch.optim as optim
import torch.onnx
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.utils import save_image
from PIL import Image
import os
import random
import argparse
from tqdm import tqdm


# Dataset of side-by-side training pairs
class Pix2PixDataset(Dataset):
    """Dataset of side-by-side image pairs for pix2pix training.

    Each image file in ``root_dir`` holds two images next to each other: the
    left half is the target (what the generator should produce) and the right
    half is the input (what the generator is fed).

    Args:
        root_dir: Directory containing the paired images. Only files directly
            inside it with a ``.jpg``, ``.jpeg`` or ``.png`` extension (any
            letter case) are used; sub-directories are not searched.
        transform: Optional callable applied to each half separately.
    """

    # Matched against the lower-cased file name, so "IMG_01.JPG" is included too.
    IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png")

    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        # Sorted so indices are reproducible: os.listdir returns entries in an
        # arbitrary, filesystem-dependent order.
        self.image_files = sorted(
            f
            for f in os.listdir(root_dir)
            if f.lower().endswith(self.IMAGE_EXTENSIONS)
        )

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        """Return ``(input_image, target_image)`` for the file at position ``idx``."""
        img_name = os.path.join(self.root_dir, self.image_files[idx])
        # The context manager releases the file handle (Image.open is lazy and
        # otherwise keeps it open). Forcing RGB matters because PNGs may be
        # RGBA, paletted or grayscale; those would not match the 3-channel
        # input the generator expects.
        with Image.open(img_name) as img:
            image = img.convert("RGB")

        # Left half is the target, right half is the input.
        w, h = image.size
        target_image = image.crop((0, 0, w // 2, h))
        input_image = image.crop((w // 2, 0, w, h))

        if self.transform:
            input_image = self.transform(input_image)
            target_image = self.transform(target_image)

        return input_image, target_image


# Shared strided conv block used by both the generator and the discriminator
class UNetBlock(nn.Module):
    """Strided (transposed) convolution, optional BatchNorm/Dropout, then an activation.

    With ``down=True`` the block uses a 4x4 stride-2 ``Conv2d`` (halving height
    and width) followed by LeakyReLU(0.2). With ``down=False`` it uses a 4x4
    stride-2 ``ConvTranspose2d`` (doubling height and width) followed by ReLU.
    Neither convolution has a bias term.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        down: Choose the downsampling (True) or upsampling (False) variant.
        bn: Insert ``BatchNorm2d`` after the convolution.
        dropout: Insert ``Dropout(0.5)`` after the (optional) normalization.
    """

    def __init__(self, in_channels, out_channels, down=True, bn=True, dropout=False):
        super().__init__()
        # The attribute names conv/bn/dropout/act form the state_dict keys and
        # are assigned in this order; keep both so saved snapshots still load.
        conv_cls = nn.Conv2d if down else nn.ConvTranspose2d
        self.conv = conv_cls(in_channels, out_channels, 4, 2, 1, bias=False)
        if bn:
            self.bn = nn.BatchNorm2d(out_channels)
        else:
            self.bn = None
        if dropout:
            self.dropout = nn.Dropout(0.5)
        else:
            self.dropout = None
        if down:
            self.act = nn.LeakyReLU(0.2)
        else:
            self.act = nn.ReLU()

    def forward(self, x):
        out = self.conv(x)
        # Apply whichever optional layers were configured, in order.
        for optional_layer in (self.bn, self.dropout):
            if optional_layer is not None:
                out = optional_layer(out)
        return self.act(out)


class Generator(nn.Module):
    """pix2pix U-Net generator: eight downsampling stages, seven upsampling stages and a final layer.

    Every decoder stage after the first receives its predecessor's output
    concatenated (along channels) with the mirrored encoder activation. The
    final transposed convolution maps back to 3 channels and is followed by
    Tanh, so outputs lie in [-1, 1]. Because of the eight stride-2 halvings,
    the spatial size presumably has to be divisible by 256 for the skip shapes
    to line up (the training script uses 512x512).
    """

    # (in_channels, out_channels, use_batchnorm) for down1 .. down8.
    _DOWN_SPECS = (
        (3, 64, False),
        (64, 128, True),
        (128, 256, True),
        (256, 512, True),
        (512, 512, True),
        (512, 512, True),
        (512, 512, True),
        (512, 512, False),
    )
    # (in_channels, out_channels, use_dropout) for up1 .. up7. From up2 on the
    # input width is doubled because a skip connection is concatenated.
    _UP_SPECS = (
        (512, 512, True),
        (1024, 512, True),
        (1024, 512, True),
        (1024, 512, False),
        (1024, 256, False),
        (512, 128, False),
        (256, 64, False),
    )

    def __init__(self):
        super().__init__()
        # setattr keeps the attribute names (down1.., up1..) and the registration
        # order unchanged, so state_dict keys and the parameter order seen by the
        # optimizers match saved snapshots.
        for idx, (cin, cout, use_bn) in enumerate(self._DOWN_SPECS, start=1):
            setattr(self, f"down{idx}", UNetBlock(cin, cout, bn=use_bn))
        for idx, (cin, cout, use_dropout) in enumerate(self._UP_SPECS, start=1):
            setattr(
                self, f"up{idx}", UNetBlock(cin, cout, down=False, dropout=use_dropout)
            )
        self.final = nn.Sequential(nn.ConvTranspose2d(128, 3, 4, 2, 1), nn.Tanh())

    def forward(self, x):
        # Encoder: push every activation onto a stack for the skip connections.
        skips = []
        h = x
        for idx in range(1, 9):
            h = getattr(self, f"down{idx}")(h)
            skips.append(h)

        # Decoder: the bottleneck (down8) feeds up1 directly; each later stage
        # pops the mirrored encoder output (down7 for up2, ..., down1 for final).
        h = self.up1(skips.pop())
        for idx in range(2, 8):
            h = getattr(self, f"up{idx}")(torch.cat([h, skips.pop()], 1))
        return self.final(torch.cat([h, skips.pop()], 1))


# Conditional discriminator that scores (input, candidate) image pairs
class Discriminator(nn.Module):
    """Score an (input, candidate) image pair with a spatial map of probabilities.

    ``forward(x, y)`` concatenates the conditioning image ``x`` and the
    candidate image ``y`` (real target or generated) along the channel axis;
    the first block takes 6 channels, so presumably 3 + 3. Four strided
    ``UNetBlock``s then feed a stride-1 4x4 convolution to a single channel,
    and Sigmoid yields one probability per output location rather than one
    per image.
    """

    def __init__(self):
        super().__init__()
        channels = [6, 64, 128, 256, 512]
        # Batch normalization on every block except the first one.
        layers = [
            UNetBlock(c_in, c_out, bn=(i > 0))
            for i, (c_in, c_out) in enumerate(zip(channels, channels[1:]))
        ]
        layers += [nn.Conv2d(512, 1, 4, 1, 1), nn.Sigmoid()]
        # A single Sequential named "model" keeps the state_dict keys model.<i>.*.
        self.model = nn.Sequential(*layers)

    def forward(self, x, y):
        paired = torch.cat((x, y), dim=1)
        return self.model(paired)


# Loss functions shared by the training loop (the optimizers are created in
# train()). BCE scores the discriminator's sigmoid outputs; L1 measures the
# pixel-wise distance between generated and target images.
criterion_gan, criterion_pixel = nn.BCELoss(), nn.L1Loss()


# Locate the snapshot to resume training from
def get_latest_snapshot(output_dir):
    """Return the path of the highest-epoch snapshot in ``output_dir``, or None.

    Snapshots are files named ``snapshot_epoch_<N>.pth``. The newest one is
    chosen by the epoch number ``N`` in its name, not by file timestamps.
    ``os.path.getctime`` is the inode change time on Unix, and copying or
    restoring the directory scrambles it, which could resume from an older epoch.

    Args:
        output_dir: Directory to search (not recursively).

    Returns:
        Path of the snapshot with the largest epoch number, or ``None`` when
        no file named ``snapshot_epoch_<digits>.pth`` exists.
    """
    prefix, suffix = "snapshot_epoch_", ".pth"
    # glob.escape so directory names containing "[", "*" or "?" match literally.
    pattern = os.path.join(glob.escape(output_dir), f"{prefix}*{suffix}")

    best_path, best_epoch = None, -1
    for path in glob.glob(pattern):
        epoch_str = os.path.basename(path)[len(prefix) : -len(suffix)]
        # Skip names that match the glob but do not carry a plain epoch number
        # (e.g. "snapshot_epoch_5.bak.pth"). isdecimal() guarantees int() succeeds.
        if not epoch_str.isdecimal():
            continue
        epoch = int(epoch_str)
        if epoch > best_epoch:
            best_path, best_epoch = path, epoch
    return best_path


def load_snapshot(generator, discriminator, g_optimizer, d_optimizer, snapshot_path):
    """Restore model and optimizer state from a training snapshot.

    Args:
        generator: Module whose state is restored from ``checkpoint["generator"]``.
        discriminator: Module restored from ``checkpoint["discriminator"]``.
        g_optimizer: Optimizer restored from ``checkpoint["g_optimizer"]``.
        d_optimizer: Optimizer restored from ``checkpoint["d_optimizer"]``.
        snapshot_path: Path of a ``snapshot_epoch_<N>.pth`` file.

    Returns:
        int: The epoch number ``N`` parsed from the file name.
    """
    # Load tensors onto the device the generator already lives on, instead of
    # the module-level ``device`` that only exists when run as a script.
    map_location = next(generator.parameters()).device
    # SECURITY: weights_only=False unpickles arbitrary objects. Only load
    # snapshots you produced yourself, never files from untrusted sources.
    checkpoint = torch.load(snapshot_path, map_location=map_location, weights_only=False)
    generator.load_state_dict(checkpoint["generator"])
    discriminator.load_state_dict(checkpoint["discriminator"])
    g_optimizer.load_state_dict(checkpoint["g_optimizer"])
    d_optimizer.load_state_dict(checkpoint["d_optimizer"])
    # "snapshot_epoch_12.pth" -> "snapshot_epoch_12" -> 12
    stem = os.path.splitext(os.path.basename(snapshot_path))[0]
    return int(stem.rsplit("_", 1)[-1])


# Training loop and its sample / checkpoint / ONNX-export helpers
def _save_sample(generator, fixed_input, fixed_target, path):
    """Save an [input | generated | target] image strip for the fixed sample to ``path``."""
    # Use eval() so BatchNorm applies its running statistics instead of updating
    # them from this single image (in train mode they are updated even under
    # no_grad), and Dropout is off. This is the same mode the ONNX export uses.
    was_training = generator.training
    generator.eval()
    try:
        with torch.no_grad():
            fake_img = generator(fixed_input)
    finally:
        generator.train(was_training)
    img_sample = torch.cat((fixed_input.cpu(), fake_img.cpu(), fixed_target.cpu()), -1)
    save_image(img_sample, path, nrow=3, normalize=True)


def _save_checkpoint(generator, discriminator, g_optimizer, d_optimizer, path):
    """Save model and optimizer state dicts to ``path`` under the keys load_snapshot reads."""
    torch.save(
        {
            "generator": generator.state_dict(),
            "discriminator": discriminator.state_dict(),
            "g_optimizer": g_optimizer.state_dict(),
            "d_optimizer": d_optimizer.state_dict(),
        },
        path,
    )


def _export_onnx(generator, path, device):
    """Export the generator to ONNX using a 1x3x512x512 example input and a dynamic batch axis."""
    was_training = generator.training
    generator.eval()
    try:
        dummy_input = torch.randn(1, 3, 512, 512, device=device)
        # torch.onnx.export traces the module itself, so a separate
        # torch.jit.trace pass is unnecessary.
        torch.onnx.export(
            generator,
            dummy_input,
            path,
            export_params=True,
            opset_version=11,
            do_constant_folding=True,
            input_names=["input"],
            output_names=["output"],
            dynamic_axes={"input": {0: "batch_size"}, "output": {0: "batch_size"}},
        )
    finally:
        generator.train(was_training)


def train(generator, discriminator, dataloader, args):
    """Train the pix2pix generator/discriminator pair with Adam, optionally resuming.

    Unless ``args.restart`` is set, the latest snapshot in ``args.output_dir``
    is loaded and training continues from the epoch it recorded. Sample strips
    are written every ``args.sample_interval`` iterations. Every
    ``args.snapshot_interval`` epochs, and after the final epoch, a snapshot
    and an ONNX export of the generator are saved.

    Args:
        generator: Generator module, already on the target device.
        discriminator: Discriminator module, on the same device.
        dataloader: Iterable of ``(input_img, target_img)`` batches.
        args: Parsed CLI namespace. Reads ``output_dir``, ``restart``,
            ``epochs``, ``sample_interval`` and ``snapshot_interval``.

    Raises:
        ValueError: If ``dataloader`` yields no batches.
    """
    # Use the device the models already live on rather than a module-level global.
    device = next(generator.parameters()).device

    g_optimizer = optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
    d_optimizer = optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))

    # Fixed input/target pair (first item of one batch) for progress images.
    try:
        fixed_batch = next(iter(dataloader))
    except StopIteration:
        raise ValueError("dataloader is empty; cannot train without images") from None
    fixed_input = fixed_batch[0][:1].to(device)
    fixed_target = fixed_batch[1][:1]

    start_epoch = 0
    if args.restart:
        print("Restarting training from scratch.")
    else:
        latest_snapshot = get_latest_snapshot(args.output_dir)
        if latest_snapshot:
            start_epoch = load_snapshot(
                generator, discriminator, g_optimizer, d_optimizer, latest_snapshot
            )
            print(f"Resuming training from epoch {start_epoch}")
        else:
            print("No snapshots found. Starting from scratch.")

    if start_epoch >= args.epochs:
        print(f"Already trained for {start_epoch} epochs (--epochs {args.epochs}); nothing to do.")

    # Continue from the resumed epoch. Restarting at 0 would re-run every epoch
    # and overwrite existing snapshots with mismatched epoch numbers.
    for epoch in range(start_epoch, args.epochs):
        progress = tqdm(dataloader, desc=f"Epoch {epoch + 1}/{args.epochs}")
        for i, (input_img, target_img) in enumerate(progress):
            input_img = input_img.to(device)
            target_img = target_img.to(device)

            # One generator forward pass per step: detached for the discriminator
            # update, attached for the generator update. The generator weights
            # do not change in between.
            fake_img = generator(input_img)

            # Train Discriminator
            d_optimizer.zero_grad()
            d_real = discriminator(input_img, target_img)
            d_fake = discriminator(input_img, fake_img.detach())
            d_loss_real = criterion_gan(d_real, torch.ones_like(d_real))
            d_loss_fake = criterion_gan(d_fake, torch.zeros_like(d_fake))
            d_loss = (d_loss_real + d_loss_fake) / 2
            d_loss.backward()
            d_optimizer.step()

            # Train Generator: fool the (just updated) discriminator, plus an L1
            # reconstruction term weighted 100x relative to the adversarial term.
            g_optimizer.zero_grad()
            d_fake = discriminator(input_img, fake_img)
            g_loss_gan = criterion_gan(d_fake, torch.ones_like(d_fake))
            g_loss_pixel = criterion_pixel(fake_img, target_img) * 100
            g_loss = g_loss_gan + g_loss_pixel
            g_loss.backward()
            g_optimizer.step()

            if i % args.sample_interval == 0:
                _save_sample(
                    generator,
                    fixed_input,
                    fixed_target,
                    os.path.join(args.output_dir, f"epoch_{epoch}_iter_{i}.jpg"),
                )

        epochs_done = epoch + 1
        # Also save after the last epoch so the final model is kept even when
        # ``epochs`` is not a multiple of ``snapshot_interval``.
        if epochs_done % args.snapshot_interval == 0 or epochs_done == args.epochs:
            _save_checkpoint(
                generator,
                discriminator,
                g_optimizer,
                d_optimizer,
                os.path.join(args.output_dir, f"snapshot_epoch_{epochs_done}.pth"),
            )
            onnx_path = os.path.join(args.output_dir, f"generator_epoch_{epochs_done}.onnx")
            _export_onnx(generator, onnx_path, device)
            print(f"ONNX model exported to {onnx_path}")


# Command-line configuration
def _positive_int(value):
    """argparse ``type`` that accepts only integers >= 1."""
    number = int(value)
    if number < 1:
        raise argparse.ArgumentTypeError(f"must be a positive integer, got {value!r}")
    return number


def parse_args(argv=None):
    """Parse the training script's command-line options.

    Args:
        argv: Argument list to parse. ``None`` (the default) means
            ``sys.argv[1:]``, as with ``argparse``.

    Returns:
        argparse.Namespace with ``input_dir``, ``output_dir``,
        ``sample_interval``, ``snapshot_interval``, ``epochs``, ``batch_size``
        and ``restart``.
    """
    parser = argparse.ArgumentParser(
        description="Conditional GAN with pix2pix architecture"
    )
    parser.add_argument(
        "--input_dir",
        type=str,
        default="datasets/trees",
        help="Input dataset directory",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="output",
        help="Output directory for sample images, snapshots and ONNX models",
    )
    # The intervals are used as modulo divisors in the training loop, so 0
    # would crash with ZeroDivisionError; validate them up front.
    parser.add_argument(
        "--sample_interval",
        type=_positive_int,
        default=100,
        help="Interval for saving sample images",
    )
    parser.add_argument(
        "--snapshot_interval",
        type=_positive_int,
        default=1,
        help="Interval for saving model snapshots",
    )
    parser.add_argument(
        "--epochs", type=_positive_int, default=200, help="Number of epochs to train"
    )
    parser.add_argument("--batch_size", type=_positive_int, default=8, help="Batch size")
    parser.add_argument(
        "--restart", action="store_true", help="Restart training from scratch"
    )
    return parser.parse_args(argv)


# Script entry point
def main():
    """Parse CLI options, build the data pipeline and models, then train.

    Raises:
        SystemExit: If the input directory contains no usable images.
    """
    args = parse_args()
    os.makedirs(args.output_dir, exist_ok=True)

    # Resolved locally so main() also works when this module is imported; the
    # module-level ``device`` is only defined under ``if __name__ == "__main__"``.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Each half of a pair is resized to 512x512, converted to [0, 1] by
    # ToTensor, then normalized to [-1, 1], the range of the generator's Tanh output.
    transform = transforms.Compose(
        [
            transforms.Resize((512, 512)),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ]
    )

    dataset = Pix2PixDataset(args.input_dir, transform=transform)
    if len(dataset) == 0:
        raise SystemExit(f"No training images found in {args.input_dir}")
    dataloader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=4,
        # Page-locked host memory speeds up host-to-GPU copies; it is useless on CPU.
        pin_memory=device.type == "cuda",
    )

    generator = Generator().to(device)
    discriminator = Discriminator().to(device)

    train(generator, discriminator, dataloader, args)


if __name__ == "__main__":
    # Kept at module level because code elsewhere in this script may read
    # ``device`` as a global.
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda") if use_cuda else torch.device("cpu")
    print(f"Using device: {device}")
    main()