ONNX Image Model

Run a PIX2PIX image-to-image model developed in ONNX. This is similar to the Image to Image node but uses an ONNX model instead of a TensorFlow.js model.

The node automatically reads the required input and output dimensions from the ONNX model metadata. Input images must match the model’s expected size.

To work well, it needs an input that is similar to what it has seen in training; e.g. if you trained it on a hand model, you need to feed it hand models with the same color and line thickness.

Parameters

  • Model The .onnx model file.

Training

Here’s a simple example of how to train a PIX2PIX model using PyTorch. You will need torch, torchvision, onnx and onnxruntime installed.

pip install torch==2.4.0 torchvision==0.19.0 onnx==1.16.1 onnxruntime==1.19.0

Run the code with python train.py --input_dir datasets/trees --output_dir output.

import argparse
import glob
import os
import random
import re

import torch
import torch.nn as nn
import torch.optim as optim
import torch.onnx
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.utils import save_image
from PIL import Image
from tqdm import tqdm


# Create the dataset class
class Pix2PixDataset(Dataset):
    """Dataset of paired pix2pix images.

    Each .jpg/.png file in *root_dir* is one combined image: the left
    half is the target and the right half is the input (see the crop
    logic in ``__getitem__``).

    Args:
        root_dir: Directory containing the combined images.
        transform: Optional transform applied to both halves.
    """

    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        # Sort so indexing is deterministic across platforms --
        # os.listdir() returns entries in arbitrary order.
        self.image_files = sorted(
            f for f in os.listdir(root_dir) if f.endswith((".jpg", ".png"))
        )

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        img_name = os.path.join(self.root_dir, self.image_files[idx])
        # Force RGB: a grayscale or RGBA file would otherwise produce the
        # wrong channel count for the 3-channel Normalize downstream.
        image = Image.open(img_name).convert("RGB")

        # Split the combined image: target on the left, input on the right.
        w, h = image.size
        target_image = image.crop((0, 0, w // 2, h))
        input_image = image.crop((w // 2, 0, w, h))

        if self.transform:
            input_image = self.transform(input_image)
            target_image = self.transform(target_image)

        return input_image, target_image


# Implement the UNet architecture for the generator
class UNetBlock(nn.Module):
    """One U-Net stage: a strided conv (down) or transposed conv (up),
    optionally followed by batch norm and dropout, then an activation.

    Both convolution variants use kernel 4, stride 2, padding 1, so a
    down block halves the spatial size and an up block doubles it.

    Args:
        in_channels: Number of input feature channels.
        out_channels: Number of output feature channels.
        down: If True use Conv2d (encoder); otherwise ConvTranspose2d
            (decoder).
        bn: Whether to apply BatchNorm2d after the convolution.
        dropout: Whether to apply Dropout(0.5).
    """

    def __init__(self, in_channels, out_channels, down=True, bn=True, dropout=False):
        super(UNetBlock, self).__init__()
        self.conv = (
            nn.Conv2d(in_channels, out_channels, 4, 2, 1, bias=False)
            if down
            else nn.ConvTranspose2d(in_channels, out_channels, 4, 2, 1, bias=False)
        )
        self.bn = nn.BatchNorm2d(out_channels) if bn else None
        self.dropout = nn.Dropout(0.5) if dropout else None
        # LeakyReLU on the way down, plain ReLU on the way up.
        self.act = nn.LeakyReLU(0.2) if down else nn.ReLU()

    def forward(self, x):
        x = self.conv(x)
        # Explicit None checks: relying on module truthiness is fragile
        # (e.g. an empty nn.Sequential evaluates as falsy via __len__).
        if self.bn is not None:
            x = self.bn(x)
        if self.dropout is not None:
            x = self.dropout(x)
        return self.act(x)


class Generator(nn.Module):
    """pix2pix U-Net generator.

    An 8-stage encoder followed by a 7-stage decoder with skip
    connections, finished by a transposed conv + Tanh producing a
    3-channel image in [-1, 1].
    """

    def __init__(self):
        super(Generator, self).__init__()
        # Encoder: each stage halves the spatial resolution.
        self.down1 = UNetBlock(3, 64, down=True, bn=False)
        self.down2 = UNetBlock(64, 128)
        self.down3 = UNetBlock(128, 256)
        self.down4 = UNetBlock(256, 512)
        self.down5 = UNetBlock(512, 512)
        self.down6 = UNetBlock(512, 512)
        self.down7 = UNetBlock(512, 512)
        self.down8 = UNetBlock(512, 512, bn=False)

        # Decoder: input channel counts are doubled (1024, 512, 256)
        # because each stage receives a skip concatenation.
        self.up1 = UNetBlock(512, 512, down=False, dropout=True)
        self.up2 = UNetBlock(1024, 512, down=False, dropout=True)
        self.up3 = UNetBlock(1024, 512, down=False, dropout=True)
        self.up4 = UNetBlock(1024, 512, down=False)
        self.up5 = UNetBlock(1024, 256, down=False)
        self.up6 = UNetBlock(512, 128, down=False)
        self.up7 = UNetBlock(256, 64, down=False)

        self.final = nn.Sequential(nn.ConvTranspose2d(128, 3, 4, 2, 1), nn.Tanh())

    def forward(self, x):
        # Encoder pass, remembering every activation for the skips.
        skips = []
        for stage in (self.down1, self.down2, self.down3, self.down4,
                      self.down5, self.down6, self.down7, self.down8):
            x = stage(x)
            skips.append(x)

        # Decoder pass: start from the deepest activation, then consume
        # the remaining skips from deepest to shallowest.
        x = self.up1(skips.pop())
        for stage in (self.up2, self.up3, self.up4, self.up5, self.up6, self.up7):
            x = stage(torch.cat([x, skips.pop()], 1))
        return self.final(torch.cat([x, skips.pop()], 1))


# Implement the discriminator
class Discriminator(nn.Module):
    """Discriminator scoring (input, target-or-fake) image pairs.

    The two 3-channel images are concatenated into a 6-channel tensor,
    reduced through four stride-2 blocks, and mapped to a spatial grid
    of per-patch probabilities via a final conv + Sigmoid.
    """

    def __init__(self):
        super(Discriminator, self).__init__()
        layers = [
            UNetBlock(6, 64, bn=False),
            UNetBlock(64, 128),
            UNetBlock(128, 256),
            UNetBlock(256, 512),
            nn.Conv2d(512, 1, 4, 1, 1),
            nn.Sigmoid(),
        ]
        self.model = nn.Sequential(*layers)

    def forward(self, x, y):
        pair = torch.cat([x, y], 1)
        return self.model(pair)


# Define the loss functions and optimizers
# Adversarial loss: BCE over the discriminator's sigmoid outputs.
criterion_gan = nn.BCELoss()
# Per-pixel L1 reconstruction loss between generated and target images.
criterion_pixel = nn.L1Loss()


# Load snapshot if available
def get_latest_snapshot(output_dir):
    """Return the snapshot file with the highest epoch number, or None.

    The epoch is parsed from the file name rather than taken from the
    file's ctime: ctime changes when files are copied or restored, so it
    is not a reliable "latest" signal.

    Args:
        output_dir: Directory containing ``snapshot_epoch_<N>.pth`` files.
    """
    snapshots = glob.glob(os.path.join(output_dir, "snapshot_epoch_*.pth"))
    if not snapshots:
        return None

    def _epoch(path):
        # Non-matching names sort first so a well-formed snapshot wins.
        m = re.search(r"snapshot_epoch_(\d+)\.pth$", os.path.basename(path))
        return int(m.group(1)) if m else -1

    return max(snapshots, key=_epoch)


def load_snapshot(generator, discriminator, g_optimizer, d_optimizer, snapshot_path):
    """Restore model and optimizer state from *snapshot_path*.

    Returns:
        The epoch number encoded in the snapshot file name, i.e. the
        epoch training should resume from.

    Raises:
        ValueError: If the file name does not match ``snapshot_epoch_<N>.pth``.
    """
    # NOTE(review): weights_only=False executes arbitrary pickle code on
    # load -- only ever load snapshots you created yourself.
    # `device` is a module-level global set in the __main__ guard.
    checkpoint = torch.load(snapshot_path, map_location=device, weights_only=False)
    generator.load_state_dict(checkpoint["generator"])
    discriminator.load_state_dict(checkpoint["discriminator"])
    g_optimizer.load_state_dict(checkpoint["g_optimizer"])
    d_optimizer.load_state_dict(checkpoint["d_optimizer"])
    # Parse the epoch with a regex instead of chained split()s, which
    # silently misparse unexpected file names.
    match = re.search(r"snapshot_epoch_(\d+)\.pth$", os.path.basename(snapshot_path))
    if match is None:
        raise ValueError(f"Cannot parse epoch from snapshot name: {snapshot_path}")
    return int(match.group(1))


# 6. Create the training loop
def train(generator, discriminator, dataloader, args):
    """Run the pix2pix adversarial training loop.

    Alternates discriminator and generator updates per batch, saves a
    sample image every ``args.sample_interval`` iterations, and every
    ``args.snapshot_interval`` epochs writes a .pth snapshot plus an
    ONNX export of the generator.

    Args:
        generator: Generator network, already moved to ``device``.
        discriminator: Discriminator network, already moved to ``device``.
        dataloader: Yields (input_img, target_img) batches.
        args: Parsed CLI namespace (output_dir, epochs, intervals, restart).
    """
    g_optimizer = optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
    d_optimizer = optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))

    # A fixed (input, target) pair reused for progress visualization.
    fixed_set = next(iter(dataloader))
    fixed_input = fixed_set[0][0].unsqueeze(0)
    fixed_target = fixed_set[1][0].unsqueeze(0)

    start_epoch = 0
    if not args.restart:
        latest_snapshot = get_latest_snapshot(args.output_dir)
        if latest_snapshot:
            start_epoch = load_snapshot(
                generator, discriminator, g_optimizer, d_optimizer, latest_snapshot
            )
            print(f"Resuming training from epoch {start_epoch}")
        else:
            print("No snapshots found. Starting from scratch.")
    else:
        print("Restarting training from scratch.")

    # BUGFIX: resume from start_epoch. Previously the loop always ran
    # range(args.epochs), so a resumed run restarted at epoch 0 and
    # renumbered/overwrote earlier samples and snapshots.
    for epoch in range(start_epoch, args.epochs):
        for i, (input_img, target_img) in enumerate(tqdm(dataloader)):
            input_img = input_img.to(device)
            target_img = target_img.to(device)

            # --- Train discriminator: real pairs -> 1, fake pairs -> 0 ---
            d_optimizer.zero_grad()
            fake_img = generator(input_img)
            # detach() so no gradient flows into the generator here.
            d_real = discriminator(input_img, target_img)
            d_fake = discriminator(input_img, fake_img.detach())
            d_loss_real = criterion_gan(d_real, torch.ones_like(d_real))
            d_loss_fake = criterion_gan(d_fake, torch.zeros_like(d_fake))
            d_loss = (d_loss_real + d_loss_fake) / 2
            d_loss.backward()
            d_optimizer.step()

            # --- Train generator: fool D and stay close to the target ---
            g_optimizer.zero_grad()
            fake_img = generator(input_img)
            d_fake = discriminator(input_img, fake_img)
            g_loss_gan = criterion_gan(d_fake, torch.ones_like(d_fake))
            # L1 term weighted x100 relative to the adversarial term.
            g_loss_pixel = criterion_pixel(fake_img, target_img) * 100
            g_loss = g_loss_gan + g_loss_pixel
            g_loss.backward()
            g_optimizer.step()

            if i % args.sample_interval == 0:
                # Save input | generated | target concatenated side by side.
                with torch.no_grad():
                    fake_img = generator(fixed_input.to(device))
                    img_sample = torch.cat(
                        (fixed_input.cpu(), fake_img.cpu(), fixed_target.cpu()), -1
                    )
                    save_image(
                        img_sample,
                        f"{args.output_dir}/epoch_{epoch}_iter_{i}.jpg",
                        nrow=3,
                        normalize=True,
                    )

        if (epoch + 1) % args.snapshot_interval == 0:
            torch.save(
                {
                    "generator": generator.state_dict(),
                    "discriminator": discriminator.state_dict(),
                    "g_optimizer": g_optimizer.state_dict(),
                    "d_optimizer": d_optimizer.state_dict(),
                },
                f"{args.output_dir}/snapshot_epoch_{epoch + 1}.pth",
            )

            # Export the generator to ONNX; eval() disables dropout so the
            # trace is deterministic.
            onnx_path = f"{args.output_dir}/generator_epoch_{epoch + 1}.onnx"
            generator.eval()
            dummy_input = torch.randn(1, 3, 512, 512).to(device)
            traced_script_module = torch.jit.trace(generator, dummy_input)
            torch.onnx.export(
                traced_script_module,
                dummy_input,
                onnx_path,
                export_params=True,
                opset_version=11,
                do_constant_folding=True,
                input_names=["input"],
                output_names=["output"],
                dynamic_axes={"input": {0: "batch_size"}, "output": {0: "batch_size"}},
            )
            print(f"ONNX model exported to {onnx_path}")
            generator.train()  # back to training mode for the next epoch


# 7. Implement the argument parser for configuration
def parse_args():
    """Parse and return the command-line options for the training script."""
    parser = argparse.ArgumentParser(
        description="Conditional GAN with pix2pix architecture"
    )
    add = parser.add_argument
    add("--input_dir", type=str, default="datasets/trees",
        help="Input dataset directory")
    add("--output_dir", type=str, default="output",
        help="Output directory for generated images")
    add("--sample_interval", type=int, default=100,
        help="Interval for saving sample images")
    add("--snapshot_interval", type=int, default=1,
        help="Interval for saving model snapshots")
    add("--epochs", type=int, default=200, help="Number of epochs to train")
    add("--batch_size", type=int, default=8, help="Batch size")
    add("--restart", action="store_true", help="Restart training from scratch")
    return parser.parse_args()


# 8. Set up the main function to run the training
def main():
    """Entry point: build the dataset, models, and start training."""
    args = parse_args()
    os.makedirs(args.output_dir, exist_ok=True)

    # Resize to the fixed model resolution; Normalize(0.5, 0.5) maps the
    # [0, 1] tensor range to [-1, 1] to match the generator's Tanh output.
    pipeline = transforms.Compose([
        transforms.Resize((512, 512)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])

    dataset = Pix2PixDataset(args.input_dir, transform=pipeline)
    dataloader = DataLoader(
        dataset, batch_size=args.batch_size, shuffle=True, num_workers=4
    )

    generator = Generator().to(device)
    discriminator = Discriminator().to(device)

    train(generator, discriminator, dataloader, args)


if __name__ == "__main__":
    # `device` is a module-level global consumed by main(), train() and
    # load_snapshot(). NOTE(review): it is only defined when run as a
    # script, so importing this module and calling those functions
    # directly would raise NameError.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")
    main()