Files
COMP90044_a1/I3D.py
2025-04-14 17:15:38 +10:00

53 lines
1.6 KiB
Python

import torch
import torch.nn.functional as F
import torchvision.transforms as transforms
from pytorchvideo.models.hub import i3d_r50
import cv2
import numpy as np
from PIL import Image
# Build the I3D-R50 network with pretrained weights and switch it to
# inference mode. `eval()` returns the module itself, so calling it as a
# separate statement leaves `model` bound to the same object.
model = i3d_r50(pretrained=True)
model.eval()

# Chain every block of the network except the last one, so a forward pass
# yields the intermediate feature map instead of the final block's output
# (presumably the classification head -- confirm against pytorchvideo docs).
feature_extractor = torch.nn.Sequential(*model.blocks[:-1])
def preprocess_video(video_path, num_frames=32, size=224):
    """Sample frames evenly from a video and pack them into a model-ready tensor.

    Frame positions are spread linearly between the first and the last frame
    reported by OpenCV. Each decoded frame is converted from BGR to RGB,
    resized to ``size`` x ``size``, turned into a float tensor and normalised
    with the fixed per-channel mean/std listed below.

    Args:
        video_path: Path of the video file, passed to ``cv2.VideoCapture``.
        num_frames: Number of frame positions to sample.
        size: Side length, in pixels, of the square each frame is resized to.

    Returns:
        A tensor of shape ``(1, 3, T, size, size)``. ``T`` equals
        ``num_frames`` unless decoding stops early, in which case only the
        frames read so far are kept.

    Raises:
        RuntimeError: If the video cannot be opened or no frame can be decoded.
    """
    transform = transforms.Compose([
        transforms.Resize((size, size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.45, 0.45, 0.45], std=[0.225, 0.225, 0.225]),
    ])

    frames = []
    cap = cv2.VideoCapture(video_path)
    try:
        if not cap.isOpened():
            raise RuntimeError(f"Could not open video: {video_path!r}")

        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        # Some containers report a frame count of 0 (or less); clamp the last
        # index so we never ask OpenCV to seek to a negative position.
        last_index = max(total_frames - 1, 0)
        frame_indices = np.linspace(0, last_index, num_frames).astype(int)

        for i in frame_indices:
            cap.set(cv2.CAP_PROP_POS_FRAMES, int(i))
            ret, frame = cap.read()
            if not ret:
                # Best effort: keep whatever was decoded before the failure.
                break
            # OpenCV decodes to BGR; PIL and the transforms expect RGB.
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            frames.append(transform(Image.fromarray(frame)))
    finally:
        # Release the capture handle even if decoding or a transform raised.
        cap.release()

    if not frames:
        raise RuntimeError(f"No frames could be decoded from video: {video_path!r}")

    # Stack to (T, C, H, W), move channels first -> (C, T, H, W), then add a
    # leading batch dimension -> (1, C, T, H, W).
    video_tensor = torch.stack(frames).permute(1, 0, 2, 3).unsqueeze(0)
    return video_tensor
def extract_features(video_tensor):
    """Compute a pooled feature vector for a preprocessed clip.

    The clip is passed through the module-level ``feature_extractor`` with
    gradient tracking disabled. The resulting feature map is averaged down to
    a single value per channel and flattened.

    Args:
        video_tensor: Input tensor accepted by ``feature_extractor``
            (e.g. the output of ``preprocess_video``).

    Returns:
        A 1-D NumPy array. Note that ``flatten()`` also collapses the batch
        dimension, so a batch larger than one yields one concatenated vector.
    """
    with torch.no_grad():
        feature_map = feature_extractor(video_tensor)
        # Global average pooling over the three trailing dimensions.
        pooled = F.adaptive_avg_pool3d(feature_map, 1)
    return pooled.flatten().numpy()
def video_features(video_path):
    """Return the feature vector of the video stored at ``video_path``.

    Convenience wrapper: preprocesses the file with ``preprocess_video``
    (using its default sampling settings) and feeds the result to
    ``extract_features``.
    """
    return extract_features(preprocess_video(video_path))
# Demo run: extract features for one hard-coded local video and print the
# shape of the resulting array.
# NOTE(review): these statements execute on import as well as when the file
# is run directly; consider an `if __name__ == "__main__":` guard and a
# configurable path if other modules import `video_features` -- confirm usage.
video_path = r'D:\DESKTOP\2025\44\a1\dataset\org\0.mp4'
video_f = video_features(video_path)
print(f"视频features: {video_f.shape}")