01-15 00:15
Recent Posts
Recent Comments
관리 메뉴

너와나의 관심사

pytorch model 을 Android 에서 동작 본문

카테고리 없음

pytorch model 을 Android 에서 동작

벤치마킹 2020. 11. 13. 02:14

Pytorch pth Model 을 Android 에서 동작 시킬려면 c++과 연동되는 pt 모델 

torch.jit.save API 를 통해서 모델 변환이 필요하다. 


import torch
import torchvision

model = torchvision.models.resnet18(pretrained=True)
model.eval()
example = torch.rand(1, 3, 224, 224)
traced_script_module = torch.jit.trace(model, example)
traced_script_module.save("app/src/main/assets/model.pt")

5. Loading TorchScript Module

Module module = Module.load(assetFilePath(this, "model.pt"));
Tensor outputTensor = module.forward(IValue.from(inputTensor)).toTensor();
float[] scores = outputTensor.getDataAsFloatArray();


torch.jit.save(torchscript_model_optimized, "cartoon_test2.pt")


하지만 여기서 중요한 포인트는 model.eval() 로 inference 모드로 변환이 안될때가 있다. 

이때는 아래 예제 처럼 처리해주자 


1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
import torch
from models.generator import Generator
from torch.utils.mobile_optimizer import optimize_for_mobile
 
device = torch.device('cpu')
 
 
 
ModelName = "trained_netG.pth"
example_input = torch.rand(13452340, dtype=torch.float)
model2  = Generator().to(device)
model2.load_state_dict(torch.load(ModelName, map_location=device))
model2.eval()
script_module = torch.jit.trace(model2.forward, example_input)
 
 
torchscript_model_optimized = optimize_for_mobile(script_module)
torch.jit.save(torchscript_model_optimized, "cartoon_test2.pt")
 
script_module.save('cartoon_test.pt')
 
 
loaded_model = torch.jit.load("cartoon_test.pt")
loaded_model.eval()
loaded_model.save('cartoon_test3.pt')
 
 
cs

 




Comments