svm과 머신러닝
This commit is contained in:
48
ch15/svmplane/main.cpp
Normal file
48
ch15/svmplane/main.cpp
Normal file
@@ -0,0 +1,48 @@
|
||||
#include "opencv2/opencv.hpp"
|
||||
#include <iostream>
|
||||
|
||||
using namespace cv;
|
||||
using namespace cv::ml;
|
||||
using namespace std;
|
||||
|
||||
int main() {
|
||||
Mat train = Mat_<float>({ 8, 2 }, {
|
||||
150, 200, 200, 250, 100, 250, 150, 300,
|
||||
350, 100, 400, 200, 400, 300, 350, 400 });
|
||||
Mat label = Mat_<int>({ 8, 1 }, { 0, 0, 0, 0, 1, 1, 1, 1 });
|
||||
|
||||
Ptr<SVM> svm = SVM::create();
|
||||
svm->setType(SVM::Types::C_SVC);
|
||||
svm->setKernel(SVM::KernelTypes::RBF);
|
||||
svm->trainAuto(train, ROW_SAMPLE, label);
|
||||
|
||||
Mat img = Mat::zeros(Size(500, 500), CV_8UC3);
|
||||
|
||||
for (int j = 0; j < img.rows; j++) {
|
||||
for (int i = 0; i < img.cols; i++) {
|
||||
Mat test = Mat_<float>({ 1, 2 }, { (float)i, (float)j });
|
||||
int res = cvRound(svm->predict(test));
|
||||
|
||||
if (res == 0)
|
||||
img.at<Vec3b>(j, i) = Vec3b(128, 128, 255);
|
||||
else
|
||||
img.at<Vec3b>(j, i) = Vec3b(128, 255, 128);
|
||||
}
|
||||
}
|
||||
|
||||
for (int i = 0; i < train.rows; i++) {
|
||||
int x = cvRound(train.at<float>(i, 0));
|
||||
int y = cvRound(train.at<float>(i, 1));
|
||||
int l = label.at<int>(i, 0);
|
||||
|
||||
if (l == 0)
|
||||
circle(img, Point(x, y), 5, Scalar(0, 0, 128), -1, LINE_AA);
|
||||
else
|
||||
circle(img, Point(x, y), 5, Scalar(0, 128, 0), -1, LINE_AA);
|
||||
}
|
||||
|
||||
imshow("img", img);
|
||||
|
||||
waitKey();
|
||||
return 0;
|
||||
}
|
||||
Reference in New Issue
Block a user