Skip to content

Commit 7fef862

Browse files
Add Cog file to build Docker image
1 parent b01ffcd commit 7fef862

File tree

3 files changed

+51
-0
lines changed

3 files changed

+51
-0
lines changed

README.md

+2
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@ I found bugs in the implementation thanks to @adambielski and @TropComplique! (h
1111

1212
Implementation of A Style-Based Generator Architecture for Generative Adversarial Networks (https://arxiv.org/abs/1812.04948) in PyTorch
1313

14+
* [Demo and Docker image on Replicate](https://replicate.ai/rosinality/style-based-gan-pytorch)
15+
1416
Usage:
1517

1618
You should prepare lmdb dataset

cog.yaml

+9
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
predict: predict.py:Predictor
2+
build:
3+
python_version: 3.8
4+
python_packages:
5+
- torch==1.7.0
6+
- torchvision==0.8.1
7+
- tqdm==4.59.0
8+
- pillow==8.1.2
9+
- lmdb==1.1.1

predict.py

+40
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
import os
2+
import math
3+
import tempfile
4+
from pathlib import Path
5+
import torch
6+
from torchvision import utils
7+
import cog
8+
9+
from generate import sample, get_mean_style
10+
from model import StyledGenerator
11+
12+
SIZE = 1024
13+
14+
15+
class Predictor(cog.Predictor):
16+
def setup(self):
17+
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
18+
self.generator = StyledGenerator(512).to(self.device)
19+
print("Loading checkpoint")
20+
self.generator.load_state_dict(
21+
torch.load(
22+
"stylegan-1024px-new.model",
23+
map_location=self.device,
24+
)["g_running"],
25+
)
26+
self.generator.eval()
27+
28+
@cog.input("seed", type=int, default=-1, help="Random seed, -1 for random")
29+
def predict(self, seed):
30+
if seed < 0:
31+
seed = int.from_bytes(os.urandom(2), "big")
32+
torch.manual_seed(seed)
33+
print(f"seed: {seed}")
34+
35+
mean_style = get_mean_style(self.generator, self.device)
36+
step = int(math.log(SIZE, 2)) - 2
37+
img = sample(self.generator, step, mean_style, 1, self.device)
38+
output_path = Path(tempfile.mkdtemp()) / "output.png"
39+
utils.save_image(img, output_path, normalize=True)
40+
return output_path

0 commit comments

Comments
 (0)