机器学习-用本地摄像头做图片分类训练

  • 图片分类是通过采集一定数量的样本,训练出适用的模型,再通过调用模型对新的图片进行识别的一个过程。相较于物体检测,图片分类更加适用于单个物体的分类识别。为了展示图片分类的效果,这里介绍基于本地摄像头的图片分类。在没有树莓派的情况下也可以通过电脑进行模型的训练。

准备阶段

  • 首先,把USB摄像头模块连接到电脑上。(使用电脑自带摄像头也可)

  • 登陆古德微平台后,按顺序点击更多功能——机器学习——基于本地摄像头的图片分类。

    图1

    (图1)进入图片分类界面

注:基于本地摄像头的图片分类训练模型是使用本地的摄像头,如果在树莓派上打开网页进行训练那就是使用连接在树莓派上的摄像头。

训练模型

  • 接下来开始训练模型,以“苹果”模型和“梨”模型为例。首先,进行图片采样和标注,将需要识别的物体放在背景下,再点击摄像头标志打开摄像头,点击拍照按钮进行采样(可以转动物体模型,多角度对物体进行采样,这样训练识别效果会更好,采样10-20张为宜)。采样结束后将类别名修改为目标类别。比如,上面放的苹果模型就标注为苹果,重复上述步骤采样及标注另外的物体模型。

    图2

    (图2)开启摄像头及标注类别

    图3

    (图3)拍照采样


  • 采样结束后,点击“开始训练”对样本图片进行训练。(训练需要一定时间)

    图4

    (图4)点击开始训练


  • 训练完成后右方可以查看效果预览,来验证训练模型识别的准确性。 我们通过置信度的高低判断拍摄物体与训练物体的相似程度,并以此来判断物体的类别。效果如下图:

    图5

    (图5)图片分类模型预览效果


  • 训练完成后点击下载树莓派可用模型,下方会出现一个进度条,当进度条读完时模型下载完毕,一共下载两个文件文件, 一个是后缀名为.tflite的模型文件,一个是后缀名为.txt的标签文件。

    图6

    (图6)下载模型

在树莓派中使用前面训练的模型进行图片分类

  • 如何使用图片分类模型进行识别

    • 将上一步下载的模型文件和标签文件拷贝到树莓派(如何拷贝文件到树莓派),调用以下积木,即可进行图片分类。

      图7

      (图7)图片分类简单

    • 在加载图片分类模型的积木块中,分别输入要加载的模型文件和标签文件路径。将要识别的物体放在摄像头下,点击运行,在右侧调试区查看运行结果。

    • 点击这里下载本案例代码。

  • 图片分类应用案例:水果分类器

    • 案例简介:使用180舵机控制平板,当把苹果放到平板上时,舵机控制平板往一侧翻转;当把梨放到平板上时,舵机控制平板往另一侧翻转。
    • 代码截图:
      图8

      (图8)水果分类器代码

    • 效果演示:
      图9

      (图9)水果分类器效果演示

    • 点击这里下载本案例代码。
  • 使用python加载图片分类模型案例

    • 案例简介:使用python加载图片分类模型,实时对画面进行识别分类,将分类结果显示到画面中。
    • 效果演示
      <div align="center">
          <img src="/media/水果分类识别python演示.gif" alt="图11" width="800">
          <h4>(图10)水果分类识别演示</h4>
       </div>
      
    • python代码

        import time
        import numpy as np
        import cv2
        from tflite_runtime.interpreter import Interpreter
        from PIL import Image, ImageDraw, ImageFont
        def load_labels(path):
            with open(path, 'r') as f:
                return {i: line.strip() for i, line in enumerate(f.readlines())}
        def cv2ImgAddText(img, text, left, top, textColor=(0, 255, 0), textSize=20):
            if (isinstance(img, np.ndarray)):  # 判断是否OpenCV图片类型
                img = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
            # 创建一个可以在给定图像上绘图的对象
            draw = ImageDraw.Draw(img)
            # 字体的格式
            fontStyle = ImageFont.truetype('/usr/share/fonts/truetype/wqy/wqy-zenhei.ttc', textSize)
            # 绘制文本
            draw.text((left, top), text, textColor, font=fontStyle)
            # 转换回OpenCV格式
            return cv2.cvtColor(np.asarray(img), cv2.COLOR_RGB2BGR)
        #加载水果分类模型
        labels_path = '/home/pi/model/image_classification/labels.txt'
        model_path = '/home/pi/model/image_classification/model.tflite'
        labels = load_labels(labels_path)
        print("#load model")
        interpreter = Interpreter(model_path = model_path)
        interpreter.allocate_tensors()
        input_details = interpreter.get_input_details()
        output_details = interpreter.get_output_details()
        _, height, width, _ = input_details[0]['shape']
        try:
            cap = cv2.VideoCapture(0)
            while cap.isOpened():
                ret, new_img = cap.read()        
                origin_img = new_img.copy()
                new_img = cv2.cvtColor(new_img, cv2.COLOR_BGR2RGB)
                new_img = cv2.resize(new_img,(width,height))
                new_img = np.expand_dims(new_img, axis=0)
                start_time = time.time()
                interpreter.set_tensor(input_details[0]['index'], new_img)
                # 开始预测
                interpreter.invoke()   
                # 获取预测的结果
                output_data = np.squeeze(interpreter.get_tensor(output_details[0]['index']))        
                max_label_id = np.argmax(output_data)
                if output_details[0]['dtype'] == np.uint8:
                    scale, zero_point = output_details[0]['quantization']
                    output_data = scale * (output_data - zero_point)            
                elapsed_ms = (time.time() - start_time)
                origin_img = cv2ImgAddText(origin_img, '%s accuracy:%.2f fps:%d' % (labels[max_label_id], output_data[max_label_id], int(1/elapsed_ms)),30,30,(0,255,0), 40)
                cv2.imshow("frame", origin_img)
                if cv2.waitKey(1) == ord('q'):
                    break
        finally:
            cv2.destroyAllWindows()
      
Copyright © 古德微 2023 all right reserved,powered by GDWRobot本课修订时间: 2022-11-07

results matching ""

    No results matching ""