fix import error

This commit is contained in:
shuaiqing 2023-06-19 20:46:44 +08:00
parent e9d5f061a5
commit e30a28bff0
2 changed files with 5 additions and 2 deletions

View File

@ -6,7 +6,6 @@ import numpy as np
import math import math
# https://download.openmmlab.com/mmpose/hand/hrnetv2/hrnetv2_w18_rhd2d_256x256-95b20dd8_20210330.pth # https://download.openmmlab.com/mmpose/hand/hrnetv2/hrnetv2_w18_rhd2d_256x256-95b20dd8_20210330.pth
# https://download.openmmlab.com/mmpose/hand/dark/hrnetv2_w18_onehand10k_256x256_dark-a2f80c64_20210330.pth # https://download.openmmlab.com/mmpose/hand/dark/hrnetv2_w18_onehand10k_256x256_dark-a2f80c64_20210330.pth
from .hrnet import PoseHighResolutionNet
from ..basetopdown import BaseTopDownModelCache, get_preds_from_heatmaps, gdown_models from ..basetopdown import BaseTopDownModelCache, get_preds_from_heatmaps, gdown_models
class TopDownAsMMPose(nn.Module): class TopDownAsMMPose(nn.Module):
@ -31,6 +30,7 @@ class MyHand2D(BaseTopDownModelCache):
def __init__(self, ckpt, url=None, mode='hrnet'): def __init__(self, ckpt, url=None, mode='hrnet'):
if mode == 'hrnet': if mode == 'hrnet':
super().__init__(name='hand2d', bbox_scale=1.1, res_input=256) super().__init__(name='hand2d', bbox_scale=1.1, res_input=256)
from .hrnet import PoseHighResolutionNet
backbone = PoseHighResolutionNet(inp_ch=3, out_ch=21, W=18, multi_scale_final=True, add_final_layer=False) backbone = PoseHighResolutionNet(inp_ch=3, out_ch=21, W=18, multi_scale_final=True, add_final_layer=False)
checkpoint = torch.load(ckpt, map_location='cpu')['state_dict'] checkpoint = torch.load(ckpt, map_location='cpu')['state_dict']
self.load_checkpoint(backbone, checkpoint, prefix='backbone.', strict=True) self.load_checkpoint(backbone, checkpoint, prefix='backbone.', strict=True)

View File

@ -5,8 +5,11 @@ opencv-python
yacs yacs
tabulate tabulate
termcolor termcolor
chumpy git+https://github.com/mattloper/chumpy.git
mediapipe==0.10.0 mediapipe==0.10.0
func_timeout func_timeout
ultralytics ultralytics
gdown gdown
setuptools==59.5.0
tensorboard==2.8.0
pytorch-lightning==1.5.0