Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
|
@@ -201,7 +201,7 @@ def is_valid_image_filename(name):
|
|
| 201 |
return False
|
| 202 |
|
| 203 |
|
| 204 |
-
def
|
| 205 |
video = cv2.VideoCapture(video_file)
|
| 206 |
total_frames = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
|
| 207 |
interval = total_frames // num_frames
|
|
@@ -216,7 +216,7 @@ def sample_frames_old(video_file, num_frames):
|
|
| 216 |
video.release()
|
| 217 |
return frames
|
| 218 |
|
| 219 |
-
def
|
| 220 |
video_frames = []
|
| 221 |
vr = VideoReader(video_path, ctx=cpu(0))
|
| 222 |
total_frames = len(vr)
|
|
@@ -240,6 +240,22 @@ def sample_frames(video_path, frame_count=32):
|
|
| 240 |
|
| 241 |
return video_frames
|
| 242 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 243 |
|
| 244 |
def load_image(image_file):
|
| 245 |
if image_file.startswith("http") or image_file.startswith("https"):
|
|
@@ -319,6 +335,7 @@ def bot(history, temperature, top_p, max_output_tokens):
|
|
| 319 |
images_this_term = []
|
| 320 |
text_this_term = ""
|
| 321 |
|
|
|
|
| 322 |
num_new_images = 0
|
| 323 |
# previous_image = False
|
| 324 |
for i, message in enumerate(history[:-1]):
|
|
@@ -332,7 +349,9 @@ def bot(history, temperature, top_p, max_output_tokens):
|
|
| 332 |
if is_valid_video_filename(message[0][0]):
|
| 333 |
# raise ValueError("Video is not supported")
|
| 334 |
# num_new_images += our_chatbot.num_frames
|
| 335 |
-
num_new_images += len(sample_frames(message[0][0], our_chatbot.num_frames))
|
|
|
|
|
|
|
| 336 |
elif is_valid_image_filename(message[0][0]):
|
| 337 |
print("#### Load image from local file",message[0][0])
|
| 338 |
num_new_images += 1
|
|
@@ -343,6 +362,7 @@ def bot(history, temperature, top_p, max_output_tokens):
|
|
| 343 |
num_new_images = 0
|
| 344 |
# previous_image = False
|
| 345 |
|
|
|
|
| 346 |
image_list = []
|
| 347 |
for f in images_this_term:
|
| 348 |
if is_valid_video_filename(f):
|
|
@@ -388,19 +408,21 @@ def bot(history, temperature, top_p, max_output_tokens):
|
|
| 388 |
with open(file_path, "rb") as src, open(filename, "wb") as dst:
|
| 389 |
dst.write(src.read())
|
| 390 |
|
| 391 |
-
|
| 392 |
-
|
| 393 |
-
|
| 394 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 395 |
]
|
| 396 |
-
.
|
| 397 |
-
|
| 398 |
-
|
| 399 |
-
]
|
| 400 |
|
| 401 |
|
| 402 |
-
|
| 403 |
-
image_token = DEFAULT_IMAGE_TOKEN * num_new_images
|
| 404 |
|
| 405 |
inp = text
|
| 406 |
inp = image_token + "\n" + inp
|
|
@@ -440,6 +462,7 @@ def bot(history, temperature, top_p, max_output_tokens):
|
|
| 440 |
max_new_tokens=max_output_tokens,
|
| 441 |
use_cache=False,
|
| 442 |
stopping_criteria=[stopping_criteria],
|
|
|
|
| 443 |
)
|
| 444 |
|
| 445 |
t = Thread(target=our_chatbot.model.generate, kwargs=generate_kwargs)
|
|
|
|
| 201 |
return False
|
| 202 |
|
| 203 |
|
| 204 |
+
def sample_frames_v1(video_file, num_frames):
|
| 205 |
video = cv2.VideoCapture(video_file)
|
| 206 |
total_frames = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
|
| 207 |
interval = total_frames // num_frames
|
|
|
|
| 216 |
video.release()
|
| 217 |
return frames
|
| 218 |
|
| 219 |
+
def sample_frames_v2(video_path, frame_count=32):
|
| 220 |
video_frames = []
|
| 221 |
vr = VideoReader(video_path, ctx=cpu(0))
|
| 222 |
total_frames = len(vr)
|
|
|
|
| 240 |
|
| 241 |
return video_frames
|
| 242 |
|
| 243 |
+
def sample_frames(video_path, num_frames=8):
    """Uniformly sample up to *num_frames* RGB frames from a video file.

    Args:
        video_path: Path to a video file readable by OpenCV.
        num_frames: Number of frames to sample, evenly spaced across the
            clip (default 8).

    Returns:
        A list of PIL.Image frames in RGB order. The list may be shorter
        than ``num_frames`` (or empty) if the video cannot be opened, has
        no frames, or some frames fail to decode.
    """
    cap = cv2.VideoCapture(video_path)
    frames = []
    try:
        # Guard: an unopenable file (or one reporting zero frames) would
        # otherwise make np.linspace(0, -1, ...) emit negative seek
        # indices; bail out with an empty result instead.
        if not cap.isOpened():
            return frames
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        if total_frames <= 0 or num_frames <= 0:
            return frames

        # Evenly spaced frame indices spanning the whole clip.
        indices = np.linspace(0, total_frames - 1, num_frames, dtype=int)

        for i in indices:
            cap.set(cv2.CAP_PROP_POS_FRAMES, i)
            ret, frame = cap.read()
            if ret:
                # OpenCV decodes to BGR; PIL expects RGB.
                frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                frames.append(Image.fromarray(frame))
        return frames
    finally:
        # Always release the capture handle, even on a decode error.
        cap.release()
|
| 258 |
+
|
| 259 |
|
| 260 |
def load_image(image_file):
|
| 261 |
if image_file.startswith("http") or image_file.startswith("https"):
|
|
|
|
| 335 |
images_this_term = []
|
| 336 |
text_this_term = ""
|
| 337 |
|
| 338 |
+
is_video = False
|
| 339 |
num_new_images = 0
|
| 340 |
# previous_image = False
|
| 341 |
for i, message in enumerate(history[:-1]):
|
|
|
|
| 349 |
if is_valid_video_filename(message[0][0]):
|
| 350 |
# raise ValueError("Video is not supported")
|
| 351 |
# num_new_images += our_chatbot.num_frames
|
| 352 |
+
# num_new_images += len(sample_frames(message[0][0], our_chatbot.num_frames))
|
| 353 |
+
num_new_images += 1
|
| 354 |
+
is_video = True
|
| 355 |
elif is_valid_image_filename(message[0][0]):
|
| 356 |
print("#### Load image from local file",message[0][0])
|
| 357 |
num_new_images += 1
|
|
|
|
| 362 |
num_new_images = 0
|
| 363 |
# previous_image = False
|
| 364 |
|
| 365 |
+
|
| 366 |
image_list = []
|
| 367 |
for f in images_this_term:
|
| 368 |
if is_valid_video_filename(f):
|
|
|
|
| 408 |
with open(file_path, "rb") as src, open(filename, "wb") as dst:
|
| 409 |
dst.write(src.read())
|
| 410 |
|
| 411 |
+
if not is_video:
|
| 412 |
+
image_tensor = [
|
| 413 |
+
our_chatbot.image_processor.preprocess(f, return_tensors="pt")["pixel_values"][
|
| 414 |
+
0
|
| 415 |
+
]
|
| 416 |
+
.half()
|
| 417 |
+
.to(our_chatbot.model.device)
|
| 418 |
+
for f in image_list
|
| 419 |
]
|
| 420 |
+
image_tensor = torch.stack(image_tensor)
|
| 421 |
+
else:
|
| 422 |
+
image_tensor = our_chatbot.image_processor.preprocess(image_list, return_tensors="pt")["pixel_values"].half().to(our_chatbot.model.device)
|
|
|
|
| 423 |
|
| 424 |
|
| 425 |
+
image_token = DEFAULT_IMAGE_TOKEN * num_new_images if not is_video else DEFAULT_IMAGE_TOKEN * num_new_images
|
|
|
|
| 426 |
|
| 427 |
inp = text
|
| 428 |
inp = image_token + "\n" + inp
|
|
|
|
| 462 |
max_new_tokens=max_output_tokens,
|
| 463 |
use_cache=False,
|
| 464 |
stopping_criteria=[stopping_criteria],
|
| 465 |
+
modalities=["video"] if is_video else ["image"]
|
| 466 |
)
|
| 467 |
|
| 468 |
t = Thread(target=our_chatbot.model.generate, kwargs=generate_kwargs)
|