Object Recognition and Pose Estimation in Java
Preface
While working on speech recognition a while back, I discovered that browsers already ship an API you can call directly, and it recognizes speech quickly and accurately. See: 使用 JavaScript 的 SpeechRecognition API 实现语音识别_speechrecognition js-CSDN博客.
Getting to the point
First of all, credit to the author 常康; the repository is at gitee.com/agriculture… I have been following this project for quite a while, but only got around to actually trying it recently when I needed it. This is just a preliminary test, so performance still needs further evaluation; on my own underpowered machine, inference is quite laggy.
A quick look at the result
Changes
The main change is to the pose-estimation demo: the original still-image inference is replaced with video inference. To read from a webcam instead, uncomment the video.open(0); line.
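For reference, the listing below relies on two Maven Central artifacts: com.microsoft.onnxruntime:onnxruntime (the ONNX Runtime Java API; there is also an onnxruntime_gpu variant for CUDA) and org.openpnp:opencv (a self-contained OpenCV build whose nu.pattern.OpenCV.loadLocally() extracts the native library at runtime, which is what the static block below calls). Pin whatever versions match your environment. With those in place, here is the full modified source: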
package cn.ck;
import ai.onnxruntime.OnnxTensor;
import ai.onnxruntime.OrtEnvironment;
import ai.onnxruntime.OrtException;
import ai.onnxruntime.OrtSession;
import cn.ck.config.PEConfig;
import cn.ck.domain.KeyPoint;
import cn.ck.domain.PEResult;
import cn.ck.utils.Letterbox;
import nu.pattern.OpenCV;
import org.opencv.core.Mat;
import org.opencv.core.Point;
import org.opencv.core.Scalar;
import org.opencv.core.Size;
import org.opencv.highgui.HighGui;
import org.opencv.imgproc.Imgproc;
import org.opencv.videoio.VideoCapture;
import org.opencv.videoio.Videoio;
import java.nio.FloatBuffer;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
/*
* Pose estimation: recognizes actions and so on, e.g. jump-rope technique
*/
public class PoseEstimation {
static {
// Load the OpenCV native library
//System.load(ClassLoader.getSystemResource("lib/opencv_java470-无用.dll").getPath());
OpenCV.loadLocally();
}
public static void main(String[] args) throws OrtException {
String model_path = "src/main/resources/model/yolov7-w6-pose-nms.onnx";
// Load the ONNX model
OrtEnvironment environment = OrtEnvironment.getEnvironment();
OrtSession.SessionOptions sessionOptions = new OrtSession.SessionOptions();
OrtSession session = environment.createSession(model_path, sessionOptions);
// Print the model's input info
session.getInputInfo().keySet().forEach(x -> {
try {
System.out.println("input name = " + x);
System.out.println(session.getInputInfo().get(x).getInfo().toString());
} catch (OrtException e) {
throw new RuntimeException(e);
}
});
VideoCapture video = new VideoCapture();
// video.open(0); // open camera index 0 on this machine
// The annotated frames can also be forwarded over RTMP to a media server for remote preview; use ffmpeg to assemble the consecutive frames into FLV and so on. A sketch follows the listing.
if (!video.isOpened()) {
System.err.println("打开视频流失败,未检测到监控,请先用vlc软件测试链接是否可以播放!,下面试用默认测试视频进行预览效果!");
video.open("video/test2.mp4");
}
// Skip frames between detections, typically around 3: frames a few milliseconds apart barely differ, so detecting more often gains nothing and wastes compute
int detect_skip = 4;
// Frame-skip counter
int detect_skip_index = 1;
// Inference output from the most recently processed frame
float[][] outputData = null;
// The current (latest) frame; the previous frame could also be cached here
Mat img = new Mat();
// Define the line thickness and keypoint radius up front (sizing them proportionally to the frame works best)
int minDwDh = Math.min((int)video.get(Videoio.CAP_PROP_FRAME_WIDTH), (int)video.get(Videoio.CAP_PROP_FRAME_HEIGHT));
int thickness = minDwDh / PEConfig.lineThicknessRatio;
int radius = minDwDh / PEConfig.dotRadiusRatio;
// Mat for the color-space-converted frame
Mat image = new Mat();
// Image preprocessing
Letterbox letterbox = new Letterbox();
letterbox.setNewShape(new Size(960, 960));
letterbox.setStride(64);
// Multithreading and a GPU both raise the frame rate; production deployments must be multithreaded: one thread pulls frames into a bounded queue, array, or collection, another runs model inference, and the two exchange data through a shared variable or queue. This sample stays single-threaded for clarity; a threaded sketch follows the listing.
while (video.read(img)) {
if ((detect_skip_index % detect_skip == 0) || outputData == null) {
Imgproc.cvtColor(img, image, Imgproc.COLOR_BGR2RGB);
image = letterbox.letterbox(image);
int rows = letterbox.getHeight();
int cols = letterbox.getWidth();
int channels = image.channels();
// Convert the image into the model's input layout
float[] pixels = new float[channels * rows * cols];
for (int i = 0; i < rows; i++) {
for (int j = 0; j < cols; j++) {
double[] pixel = image.get(j, i);
for (int k = 0; k < channels; k++) {
pixels[rows * cols * k + j * cols + i] = (float) pixel[k] / 255.0f;
}
}
}
detect_skip_index = 1;
OnnxTensor tensor = OnnxTensor.createTensor(environment, FloatBuffer.wrap(pixels), new long[]{1L, (long) channels, (long) rows, (long) cols});
OrtSession.Result output = session.run(Collections.singletonMap(session.getInputInfo().keySet().iterator().next(), tensor));
// Parse the output and draw the results
outputData = ((float[][]) output.get(0).getValue());
} else {
detect_skip_index++;
}
double ratio = letterbox.getRatio();
double dw = letterbox.getDw();
double dh = letterbox.getDh();
List<PEResult> peResults = new ArrayList<>();
for (float[] outputDatum : outputData) {
PEResult result = new PEResult(outputDatum);
if (result.getScore() > PEConfig.personScoreThreshold) {
peResults.add(result);
}
}
// Apply non-maximum suppression to the detections
peResults = nms(peResults, PEConfig.IoUThreshold);
for (PEResult peResult: peResults) {
System.out.println(peResult);
// Bounding box corners (the rectangle draw below is commented out)
Point topLeft = new Point((peResult.getX0()-dw)/ratio, (peResult.getY0()-dh)/ratio);
Point bottomRight = new Point((peResult.getX1()-dw)/ratio, (peResult.getY1()-dh)/ratio);
// Imgproc.rectangle(img, topLeft, bottomRight, new Scalar(255,0,0), thickness);
List<KeyPoint> keyPoints = peResult.getKeyPointList();
// Draw the keypoints
keyPoints.forEach(keyPoint->{
if (keyPoint.getScore()>PEConfig.keyPointScoreThreshold) {
Point center = new Point((keyPoint.getX()-dw)/ratio, (keyPoint.getY()-dh)/ratio);
Scalar color = PEConfig.poseKptColor.get(keyPoint.getId());
Imgproc.circle(img, center, radius, color, -1); // thickness -1 draws a filled circle
}
});
// Draw the skeleton lines
for (int i = 0; i< PEConfig.skeleton.length; i++){
int indexPoint1 = PEConfig.skeleton[i][0]-1;
int indexPoint2 = PEConfig.skeleton[i][1]-1;
if ( keyPoints.get(indexPoint1).getScore()>PEConfig.keyPointScoreThreshold &&
keyPoints.get(indexPoint2).getScore()>PEConfig.keyPointScoreThreshold ) {
Scalar color = PEConfig.poseLimbColor.get(i);
Point point1 = new Point(
(keyPoints.get(indexPoint1).getX()-dw)/ratio,
(keyPoints.get(indexPoint1).getY()-dh)/ratio
);
Point point2 = new Point(
(keyPoints.get(indexPoint2).getX()-dw)/ratio,
(keyPoints.get(indexPoint2).getY()-dh)/ratio
);
Imgproc.line(img, point1, point2, color, thickness);
}
}
}
// Server deployment: a headless server cannot pop up a preview window, so comment out the lines below there
HighGui.imshow("result", img);
// Press any key (possibly several times) to close the preview window and end the program
if(HighGui.waitKey(1) != -1){
break;
}
}
HighGui.destroyAllWindows();
video.release();
System.exit(0);
}
public static List<PEResult> nms(List<PEResult> boxes, float iouThreshold) {
// Sort the list by score, highest first
boxes.sort((b1, b2) -> Float.compare(b2.getScore(), b1.getScore()));
List<PEResult> resultList = new ArrayList<>();
for (int i = 0; i < boxes.size(); i++) {
PEResult box = boxes.get(i);
boolean keep = true;
// 从i+1开始,遍历之后的所有boxes,移除与box的IOU大于阈值的元素
for (int j = i + 1; j < boxes.size(); j++) {
PEResult otherBox = boxes.get(j);
float iou = getIntersectionOverUnion(box, otherBox);
if (iou > iouThreshold) {
keep = false;
break;
}
}
if (keep) {
resultList.add(box);
}
}
return resultList;
}
// IoU of two axis-aligned boxes: intersection area over union area
private static float getIntersectionOverUnion(PEResult box1, PEResult box2) {
float x1 = Math.max(box1.getX0(), box2.getX0());
float y1 = Math.max(box1.getY0(), box2.getY0());
float x2 = Math.min(box1.getX1(), box2.getX1());
float y2 = Math.min(box1.getY1(), box2.getY1());
float intersectionArea = Math.max(0, x2 - x1) * Math.max(0, y2 - y1);
float box1Area = (box1.getX1() - box1.getX0()) * (box1.getY1() - box1.getY0());
float box2Area = (box2.getX1() - box2.getX0()) * (box2.getY1() - box2.getY0());
float unionArea = box1Area + box2Area - intersectionArea;
return intersectionArea / unionArea;
}
}
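The comment in the main loop above recommends splitting capture and inference across threads for real deployments. Below is a minimal sketch of that producer/consumer split, using a bounded ArrayBlockingQueue so that stale frames get dropped when inference falls behind; the class name and queue size are my own choices, not from the original project.

import java.util.concurrent.ArrayBlockingQueue;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.TimeUnit;
import nu.pattern.OpenCV;
import org.opencv.core.Mat;
import org.opencv.videoio.VideoCapture;

public class CapturePipelineSketch {
    public static void main(String[] args) throws InterruptedException {
        OpenCV.loadLocally();
        // Small bounded queue: holds only the freshest frames
        BlockingQueue<Mat> frames = new ArrayBlockingQueue<>(2);
        VideoCapture video = new VideoCapture();
        video.open("video/test2.mp4"); // or video.open(0) for a webcam

        // Producer: pulls frames from the stream as fast as they arrive
        Thread capture = new Thread(() -> {
            Mat frame = new Mat();
            while (video.read(frame)) {
                Mat copy = frame.clone(); // the queue owns its own copy
                if (!frames.offer(copy)) {
                    frames.poll();        // queue full: drop the oldest frame
                    frames.offer(copy);
                }
            }
        });
        capture.setDaemon(true);
        capture.start();

        // Consumer (here the main thread): runs inference on queued frames
        while (capture.isAlive() || !frames.isEmpty()) {
            Mat frame = frames.poll(100, TimeUnit.MILLISECONDS);
            if (frame == null) continue;
            // letterbox + session.run(...) + drawing would go here
            frame.release(); // free the native buffer once processed
        }
        video.release();
    }
}

On the GPU side, the ONNX Runtime Java API exposes OrtSession.SessionOptions.addCUDA(deviceId), which takes effect when the onnxruntime_gpu artifact replaces the CPU one on the classpath.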
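The capture setup above also mentions forwarding the annotated video over RTMP. One straightforward route, sketched below, is to spawn an ffmpeg process and write raw BGR frames to its stdin; the RTMP URL, frame rate, and encoder flags are illustrative assumptions, not settings from the original project.

import java.io.IOException;
import java.io.OutputStream;
import org.opencv.core.Mat;

public class RtmpForwarder {
    private final Process ffmpeg;
    private final OutputStream ffmpegIn;

    // width/height must match the Mats that will be written
    public RtmpForwarder(int width, int height) throws IOException {
        // Assumes ffmpeg is on the PATH and an RTMP server is listening
        // at the (hypothetical) URL below
        ffmpeg = new ProcessBuilder(
                "ffmpeg", "-f", "rawvideo", "-pix_fmt", "bgr24",
                "-s", width + "x" + height, "-r", "25", "-i", "-",
                "-c:v", "libx264", "-preset", "ultrafast", "-f", "flv",
                "rtmp://localhost/live/pose")
                .redirectErrorStream(true)
                .start();
        ffmpegIn = ffmpeg.getOutputStream();
    }

    // Call once per annotated frame (8-bit 3-channel BGR, OpenCV's default)
    public void send(Mat frame) throws IOException {
        byte[] buf = new byte[(int) (frame.total() * frame.channels())];
        frame.get(0, 0, buf); // copy pixel data out of the native Mat
        ffmpegIn.write(buf);
    }

    public void close() throws IOException, InterruptedException {
        ffmpegIn.close();  // EOF lets ffmpeg flush and finish the stream
        ffmpeg.waitFor();
    }
}

In the main loop, constructing an RtmpForwarder with the capture's frame size and calling send(img) after the drawing code would stream the annotated frames; HighGui.imshow could then be dropped on headless servers.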
Pose-estimation model download:
File shared via Baidu Netdisk: yolov7-w6-pose-nms.onnx
Link: pan.baidu.com/s/1UdAUPWr1… Extraction code: du6y
Afterword
As the original author put it: not every developer knows Python, not every project is written in Python, and not every role involves deep learning.
Here's hoping Java finds a stronger footing in the AI space.
Author: 北冥有鱼518
Source: juejin.cn/post/7413234304278970404