-
Notifications
You must be signed in to change notification settings - Fork 962
/
Copy pathllm_flux_cogvideox.py
257 lines (217 loc) · 9.24 KB
/
llm_flux_cogvideox.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
"""
The original experimental code for this project can be found at:
https://gist.github.com/a-r-r-o-w/d070cce059ab4ceab3a9f289ff83c69c

This script generates video-description prompts with a local large language model, creates an
image for each prompt with the black-forest-labs/FLUX.1-dev model, and then animates each image
into a video with CogVideoX (image-to-video).
The entire process uses open-source models and requires no API keys.

To batch-generate images and videos automatically, run the generate.sh script in the same folder:

    bash generate.sh
"""
import argparse
import gc
import json
import os
import pathlib
import random
from typing import Any, Dict
from transformers import AutoTokenizer
# Verbose TorchDynamo logging (recompiles, graph breaks) to help debug `--compile`.
# These are set here, before `import torch` below, so torch sees them when it initialises.
# `setdefault` keeps any value already exported in the shell, so users can override or
# silence the logging (e.g. `TORCH_LOGS="" python llm_flux_cogvideox.py`).
os.environ.setdefault("TORCH_LOGS", "+dynamo,recompiles,graph_breaks")
os.environ.setdefault("TORCHDYNAMO_VERBOSE", "1")
import numpy as np
import torch
import transformers
from diffusers import CogVideoXImageToVideoPipeline, CogVideoXDPMScheduler, DiffusionPipeline
from diffusers.utils.logging import get_logger
from diffusers.utils import export_to_video
# Module-level logger from diffusers' logging utilities.
# Also let float32 matmuls use faster reduced-precision internal math (e.g. TF32)
# where the hardware supports it.
torch.set_float32_matmul_precision(precision="high")
logger = get_logger(name=__name__)
# System prompt for the caption-writing LLM: explains the task and the output rules, and
# gives a few example video descriptions. All examples use straight ASCII double quotes,
# consistent with the quote-stripping applied to generated captions in main().
SYSTEM_PROMPT = """
You are part of a team of people that create videos using generative models. You use a video-generation model that can generate a video about anything you describe.
For example, if you respond with "A beautiful morning in the woods with the sun peeking through the trees", the video generation model will create a video exactly as described. Your task is to summarize the descriptions of videos provided by users, and create detailed prompts to feed into the generative model.
There are a few rules to follow:
- You will only ever output a single video description per request.
- If the user asks you to summarize the prompt in [X] words, make sure not to exceed the limit.
Your responses should just be the video generation prompt. Here are examples:
- "A lone figure stands on a city rooftop at night, gazing up at the full moon. The moon glows brightly, casting a gentle light over the quiet cityscape. Below, the windows of countless homes shine with warm lights, creating a contrast between the bustling life below and the peaceful solitude above. The scene captures the essence of the Mid-Autumn Festival, where despite the distance, the figure feels connected to loved ones through the shared beauty of the moonlit sky."
- "A detailed wooden toy ship with intricately carved masts and sails is seen gliding smoothly over a plush, blue carpet that mimics the waves of the sea. The ship's hull is painted a rich brown, with tiny windows. The carpet, soft and textured, provides a perfect backdrop, resembling an oceanic expanse. Surrounding the ship are various other toys and children's items, hinting at a playful environment. The scene captures the innocence and imagination of childhood, with the toy ship's journey symbolizing endless adventures in a whimsical, indoor setting."
- "A street artist, clad in a worn-out denim jacket and a colorful bandana, stands before a vast concrete wall in the heart of the city, holding a can of spray paint, spray-painting a colorful bird on a mottled wall."
""".strip()

# User prompt template; `{0}` is replaced with the requested word limit.
USER_PROMPT = """
Could you generate a prompt for a video generation model?
Please limit the prompt to [{0}] words.
""".strip()
def get_args(argv=None):
    """Parse the command-line options for the caption -> image -> video pipeline.

    Args:
        argv: Optional list of argument strings to parse instead of the real command
            line. ``None`` (the default) keeps argparse's normal behaviour of reading
            ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with one attribute per option defined below.
    """
    parser = argparse.ArgumentParser(
        description="Generate videos from LLM-written captions via FLUX images and CogVideoX image-to-video."
    )
    parser.add_argument(
        "--num_videos",
        type=int,
        default=5,
        help="Number of unique videos you would like to generate.",
    )
    parser.add_argument(
        "--model_path",
        type=str,
        # Loaded with CogVideoXImageToVideoPipeline, so the default must be the I2V checkpoint.
        default="THUDM/CogVideoX-5b-I2V",
        help="Path or Hugging Face ID of the CogVideoX image-to-video (I2V) model.",
    )
    parser.add_argument(
        "--caption_generator_model_id",
        type=str,
        default="THUDM/glm-4-9b-chat",
        help="Caption generation model. Defaults to GLM-4-9B-Chat.",
    )
    parser.add_argument(
        "--caption_generator_cache_dir",
        type=str,
        default=None,
        help="Cache directory for caption generation model.",
    )
    parser.add_argument(
        "--image_generator_model_id",
        type=str,
        default="black-forest-labs/FLUX.1-dev",
        help="Image generation model.",
    )
    parser.add_argument(
        "--image_generator_cache_dir",
        type=str,
        default=None,
        help="Cache directory for image generation model.",
    )
    parser.add_argument(
        "--image_generator_num_inference_steps",
        type=int,
        default=50,
        help="Number of denoising steps used by the image generation model.",
    )
    parser.add_argument(
        "--guidance_scale",
        type=float,
        default=7.0,
        help="Guidance scale to be used for video generation.",
    )
    parser.add_argument(
        "--use_dynamic_cfg",
        action="store_true",
        help="Whether or not to use cosine dynamic guidance for generation [Recommended].",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="outputs/",
        help="Location where generated images and videos should be stored.",
    )
    parser.add_argument(
        "--compile",
        action="store_true",
        help="Whether or not to compile the transformer of image and video generators.",
    )
    parser.add_argument(
        "--enable_vae_tiling",
        action="store_true",
        help="Whether or not to use VAE tiling when encoding/decoding.",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=42,
        help="Seed for reproducibility.",
    )
    return parser.parse_args(argv)
def reset_memory():
    """Release unreferenced objects and cached CUDA memory between pipeline stages.

    Runs the garbage collector first so objects dropped with ``del`` (including ones
    only kept alive by reference cycles) are freed. It then returns cached allocator
    blocks to the driver and resets the CUDA peak/accumulated memory statistics.
    The CUDA calls are skipped when CUDA is unavailable, so this best-effort cleanup
    never raises on CPU-only machines.
    """
    gc.collect()
    if not torch.cuda.is_available():
        # reset_peak_memory_stats() would try to initialise CUDA and raise.
        return
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    torch.cuda.reset_accumulated_memory_stats()
def _clean_caption(caption):
    """Return the bare caption text from a raw LLM reply.

    Surrounding whitespace is removed. If the whole reply is wrapped in one matching
    pair of straight ("...") or curly (“...”) double quotes, that pair is removed too.
    """
    caption = caption.strip()
    for opening, closing in (('"', '"'), ("\u201c", "\u201d")):
        # Require two characters so a lone quote is not treated as an empty quoted string.
        if len(caption) >= 2 and caption.startswith(opening) and caption.endswith(closing):
            return caption[1:-1].strip()
    return caption


def _caption_to_filename(caption, max_chars=25):
    """Build a filesystem-safe filename stem from the first ``max_chars`` characters of a caption.

    Every character that is not a word character, a space or a hyphen is replaced by
    ``_``. This covers ``.``, ``'``, ``"`` and ``,``, and also path separators (``/``,
    ``\\``), newlines and other characters that are invalid in filenames. Left in
    place, those would be treated as directory separators or otherwise make saving fail.
    """
    import re

    return re.sub(r"[^\w \-]", "_", caption[:max_chars])


@torch.no_grad()
def main(args: argparse.Namespace) -> None:
    """Generate captions with an LLM, render each as a FLUX image, then animate it with CogVideoX.

    The three models are loaded one after another. Each is dropped (``del`` +
    ``reset_memory()``) before the next is loaded, so they are not resident at the
    same time.

    Files written to ``args.output_dir``:
        - ``captions.json``: the list of generated captions.
        - ``{index}_{caption prefix}.png``: the FLUX image for each caption.
        - ``{index}_{caption prefix}.mp4``: the 49-frame, 8 fps CogVideoX video for each caption.

    Args:
        args: Parsed command-line options (see ``get_args``).
    """
    output_dir = pathlib.Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # Seed every RNG that influences generation.
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    reset_memory()

    # ---- Stage 1: caption generation with a local LLM ----
    tokenizer = AutoTokenizer.from_pretrained(
        args.caption_generator_model_id,
        cache_dir=args.caption_generator_cache_dir,
        trust_remote_code=True,
    )
    caption_generator = transformers.pipeline(
        "text-generation",
        model=args.caption_generator_model_id,
        device_map="auto",
        model_kwargs={
            # Model weights are only loaded from local files (cache_dir or the default cache).
            "local_files_only": True,
            "cache_dir": args.caption_generator_cache_dir,
            "torch_dtype": torch.bfloat16,
        },
        trust_remote_code=True,
        tokenizer=tokenizer,
    )

    captions = []
    for _ in range(args.num_videos):
        # Randomly pick the requested caption length for each video.
        num_words = random.choice([50, 75, 100])
        messages = [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": USER_PROMPT.format(num_words)},
        ]
        outputs = caption_generator(messages, max_new_tokens=226)
        # With chat-style input the pipeline returns the conversation; the last message is the reply.
        caption = _clean_caption(outputs[0]["generated_text"][-1]["content"])
        captions.append(caption)
        logger.info("Generated caption: %s", caption)

    with open(output_dir / "captions.json", "w", encoding="utf-8") as file:
        json.dump(captions, file, ensure_ascii=False, indent=2)

    del caption_generator, tokenizer
    reset_memory()

    if not captions:
        logger.warning("No captions generated (--num_videos=%d); skipping image and video generation.", args.num_videos)
        return

    # ---- Stage 2: one FLUX image per caption ----
    image_generator = DiffusionPipeline.from_pretrained(
        args.image_generator_model_id,
        cache_dir=args.image_generator_cache_dir,
        torch_dtype=torch.bfloat16,
    )
    image_generator.to("cuda")
    if args.compile:
        image_generator.transformer = torch.compile(
            image_generator.transformer, mode="max-autotune", fullgraph=True
        )
    if args.enable_vae_tiling:
        image_generator.vae.enable_tiling()

    images = []
    for index, caption in enumerate(captions):
        image = image_generator(
            prompt=caption,
            height=480,
            width=720,
            num_inference_steps=args.image_generator_num_inference_steps,
            guidance_scale=3.5,
        ).images[0]
        image.save(output_dir / f"{index}_{_caption_to_filename(caption)}.png")
        images.append(image)

    del image_generator
    reset_memory()

    # ---- Stage 3: animate each image with CogVideoX image-to-video ----
    video_generator = CogVideoXImageToVideoPipeline.from_pretrained(
        args.model_path, torch_dtype=torch.bfloat16
    ).to("cuda")
    video_generator.scheduler = CogVideoXDPMScheduler.from_config(
        video_generator.scheduler.config, timestep_spacing="trailing"
    )
    if args.compile:
        video_generator.transformer = torch.compile(
            video_generator.transformer, mode="max-autotune", fullgraph=True
        )
    if args.enable_vae_tiling:
        video_generator.vae.enable_tiling()

    generator = torch.Generator().manual_seed(args.seed)
    for index, (caption, image) in enumerate(zip(captions, images)):
        video = video_generator(
            image=image,
            prompt=caption,
            height=480,
            width=720,
            num_frames=49,
            num_inference_steps=50,
            guidance_scale=args.guidance_scale,
            use_dynamic_cfg=args.use_dynamic_cfg,
            generator=generator,
        ).frames[0]
        # Pass a plain string path so older export_to_video backends accept it as well.
        video_path = output_dir / f"{index}_{_caption_to_filename(caption)}.mp4"
        export_to_video(video, str(video_path), fps=8)
if __name__ == "__main__":
    # Script entry point: parse the command line, then run the full
    # caption -> image -> video pipeline with those options.
    cli_args = get_args()
    main(cli_args)