【Python】Spring Boot部署深度学习模型(Java/Pytorch)

为什么使用Java框架Spring Boot部署深度学习模型

稍早前训练了一些深度学习模型后,遇到了模型部署的一些问题,首先现有的很多业务都是java实现的,例如预算控制,用户限额等,图片识别直接和这些系统交互会造成一定的代码侵入,以及多个系统出现冗余,所以考虑使用Spring Boot将图片侦测服务包装起来,以独立的领域,搭建一个的服务,对外提供图片侦测的功能。其次Spring框架在服务管理、负载等方面有成熟的方案,也方便日后的扩展升级。

本文记录了使用Java部署深度学习模型的过程,注意模型核心还是运行在Pytorch框架上的,这里只是一个提供外围访问或域内调用的API。

可直接参见完整Java项目:https://github.com/anylots/detection
python模型项目的DetectNet:https://github.com/anylots/DetectNet,提供http接口;based on Yet-Another-EfficientDet-Pytorch

框架组成

管理时应用架构为Spring Boot+Thymeleaf+Bootstrap组合,运行时为Pytorch+Flask组合。

设想中的系统架构:
【Python】Spring Boot部署深度学习模型(Java/Pytorch)

Java 管理时部分

第一步,使用接收到的imageLink或上传的文件调用图片识别服务,返回数据为图片的BASE64编码。

第二步,组装Spring的ModelAndView对象 。

第三步,返回ModelAndView对象 ,Thymeleaf引擎会将识别结果返回给前端。

`@Controller

public class ImageDetectController {

/**

* service of imageDetect

*/

@Autowired

private ImageDetectService imageDetectService;

/**

* detect

*

* @return detect.html

*/

@RequestMapping(value = "/detect", method = RequestMethod.GET)

public String detect() {

return "detect";

}

/**

* detect out

*

* @param imageLink

* @return detectOut.html

*/

@RequestMapping(value = "/detectImage", method = RequestMethod.POST)

public ModelAndView detectOut(String imageLink) {

// step 1. detect image by imageUrl

String detectFrame = imageDetectService.detect(imageLink);

// step 2. assemble modelAndView

ModelAndView modelAndView = new ModelAndView();

modelAndView.setViewName("detectOut");

modelAndView.addObject("img", detectFrame);

// step 3. return detect result page

return modelAndView;

}`

* 1

* 2

* 3

* 4

* 5

* 6

* 7

* 8

* 9

* 10

* 11

* 12

* 13

* 14

* 15

* 16

* 17

* 18

* 19

* 20

* 21

* 22

* 23

* 24

* 25

* 26

* 27

* 28

* 29

* 30

* 31

* 32

* 33

* 34

* 35

* 36

* 37

* 38

* 39

* 40

spring boot 项目结构

【Python】Spring Boot部署深度学习模型(Java/Pytorch)

Python 运行时部分

使用flask提供http接口
这里先根据传入的url获取图片,然后调用service层得到识别后的图片信息,最后通过http接口返回给spring boot管理时(现在对python的rpc框架还不了解,后续再研究研究)。
图片数据格式就参考了旷视公司的图片识别接口,采用BASE64编码传输图片信息,

`@app.route('/detect/imageDetect', methods=['post'])

def process():

# step 1. receive image url

image_link = request.form.get("imageLink")

if not image_link.strip():

return "error" # check request

response = req.get(image_link)

image = Image.open(BytesIO(response.content))

# step 2. detect image

image_array = service.detect(image)

# step 3. convert image_array to byte_array

img = Image.fromarray(image_array, 'RGB')

img_byte_array = io.BytesIO()

img.save(img_byte_array, format='JPEG')

# step 4. return image_info to page

image_info = base64.b64encode(img_byte_array.getvalue()).decode('ascii')

return image_info

if __name__ == '__main__':

app.jinja_env.auto_reload = True

app.config['TEMPLATES_AUTO_RELOAD'] = True

app.run(debug=False, port=8081)`

* 1

* 2

* 3

* 4

* 5

* 6

* 7

* 8

* 9

* 10

* 11

* 12

* 13

* 14

* 15

* 16

* 17

* 18

* 19

* 20

* 21

* 22

* 23

* 24

* 25

* 26

* 27

* 28

Pytorch部署EfficientDet

这里使用里一个service层来包装EfficientDet模型,将transforms 、CLASS分类信息、识别器定义为全局变量,避免每次请求都去初始化这些信息,降低耗时。

`import random

import time

import cv2 as opencv

import numpy as np

import torchvision

from PIL import Image

from detector import *

# image detector,return output of detection data

detector = Detector()

# data transforms

transforms = torchvision.transforms.Compose([

torchvision.transforms.ToTensor()

])

# set of names and colors

names = cfg.COCO_CLASS

# draw identification frame based on detection data

class ImgDetectService:

# return a image with boxes based on detection data

def detect(self, img):

start_time = time.time()

# convert image to array

frame = np.array(img)

# convert to cv format

frames = frame[:, :, ::-1]

# convert to model format

image = Image.fromarray(frames, 'RGB')

width, high = image.size

x_w = width / 416

y_h = high / 416

normal_img = image.resize((416, 416))

img_data = transforms(normal_img)

img_data = torch.FloatTensor(img_data).view(-1, 3, 416, 416).to(cfg.DEVICE)

# detect image

y = detector(img_data, 0.7, cfg.ANCHORS_GROUP)[0]

tl = round(0.002 * (width + high) / 2) + 1 # line thickness

tf = 1

for i in y:

# plots one bounding box on image img

x1 = int((i[0]) * x_w)

y1 = int((i[1]) * y_h)

x2 = int((i[2]) * x_w)

y2 = int((i[3]) * y_h)

cls = i[5]

color = [random.randint(0, 255) for _ in range(3)]

opencv.rectangle(frame, (x1, y1), (x2, y2), color, thickness=2)

# plots label

label = names[int(cls)]

label_size = opencv.getTextSize(label, 0, fontScale=tl / 3, thickness=tf)[0]

opencv.rectangle(frame, (x1, y1), (x1 + label_size[0], y1 - label_size[1] - 3), color, -1)

opencv.putText(frame, label, (x1, y1 - 8), 0, tl / 3, [225, 255, 255], thickness=tf,

lineType=opencv.LINE_AA)

end_time = time.time()

print(end_time - start_time)

return frame`

* 1

* 2

* 3

* 4

* 5

* 6

* 7

* 8

* 9

* 10

* 11

* 12

* 13

* 14

* 15

* 16

* 17

* 18

* 19

* 20

* 21

* 22

* 23

* 24

* 25

* 26

* 27

* 28

* 29

* 30

* 31

* 32

* 33

* 34

* 35

* 36

* 37

* 38

* 39

* 40

* 41

* 42

* 43

* 44

* 45

* 46

* 47

* 48

* 49

* 50

* 51

* 52

* 53

* 54

* 55

* 56

* 57

* 58

* 59

* 60

* 61

* 62

* 63

* 64

* 65

* 66

* 67

* 68

* 69

* 70

效果演示:

分别启动detection和DetectNet项目
填入需要识别的图片url或者上传图片文件,点击提交
【Python】Spring Boot部署深度学习模型(Java/Pytorch)

识别结果

请求总耗时150ms左右,其中pytorch运行时耗时在90ms(device=CUDA,GTX1050Ti),管理时耗时60ms(i5 8400 8GRAM)。耗时较大,这个估计和http接口有关,后续研究下python的rpc调用,以及数据压缩传输。
【Python】Spring Boot部署深度学习模型(Java/Pytorch)
模型权重:链接: https://pan.baidu.com/s/1SyIa... 提取码: 3pif

说明:本文记录细节和逻辑还有很多未完善的地方,对图片识别服务搭建、部署还将继续研究,然后继续更新
原文:https://blog.csdn.net/m0_4650...

以上是 【Python】Spring Boot部署深度学习模型(Java/Pytorch) 的全部内容, 来源链接: utcz.com/a/89135.html

回到顶部