onnxruntime only used for GPU: not install otherwise
This commit is contained in:
parent
556a5c6125
commit
8af6ec8075
@ -39,7 +39,6 @@ import logging
|
||||
from tqdm import tqdm
|
||||
import numpy as np
|
||||
import cv2
|
||||
import onnxruntime as ort
|
||||
|
||||
from rtmlib import PoseTracker, Body, Wholebody, BodyWithFeet, draw_skeleton
|
||||
from Pose2Sim.common import natural_sort_key
|
||||
@ -358,23 +357,28 @@ def rtm_estimator(config_dict):
|
||||
frame_rate = 60
|
||||
|
||||
# If CUDA is available, use it with ONNXRuntime backend; else use CPU with openvino
|
||||
if 'CUDAExecutionProvider' in ort.get_available_providers():
|
||||
try:
|
||||
import torch
|
||||
import onnxruntime as ort
|
||||
if torch.cuda.is_available() == False and 'CUDAExecutionProvider' in ort.get_available_providers():
|
||||
device = 'cuda'
|
||||
backend = 'onnxruntime'
|
||||
logging.info(f"\nValid CUDA installation found: using ONNXRuntime backend with GPU.")
|
||||
else:
|
||||
raise
|
||||
except:
|
||||
try:
|
||||
import torch
|
||||
if torch.cuda.is_available() == False:
|
||||
device = 'cuda'
|
||||
import onnxruntime as ort
|
||||
if 'MPSExecutionProvider' in ort.get_available_providers() or 'CoreMLExecutionProvider' in ort.get_available_providers():
|
||||
device = 'mps'
|
||||
backend = 'onnxruntime'
|
||||
logging.info(f"\nValid CUDA installation found: using ONNXRuntime backend with GPU.")
|
||||
logging.info(f"\nValid MPS installation found: using ONNXRuntime backend with GPU.")
|
||||
else:
|
||||
raise
|
||||
except:
|
||||
pass
|
||||
elif 'MPSExecutionProvider' in ort.get_available_providers() or 'CoreMLExecutionProvider' in ort.get_available_providers():
|
||||
device = 'mps'
|
||||
backend = 'onnxruntime'
|
||||
logging.info(f"\nValid MPS installation found: using ONNXRuntime backend with GPU.")
|
||||
else:
|
||||
device = 'cpu'
|
||||
backend = 'openvino'
|
||||
logging.info(f"\nNo valid CUDA installation found: using OpenVINO backend with CPU.")
|
||||
device = 'cpu'
|
||||
backend = 'openvino'
|
||||
logging.info(f"\nNo valid CUDA installation found: using OpenVINO backend with CPU.")
|
||||
|
||||
if det_frequency>1:
|
||||
logging.info(f'Inference run only every {det_frequency} frames. Inbetween, pose estimation tracks previously detected points.')
|
||||
|
Loading…
Reference in New Issue
Block a user