Commit a9933f90 authored by Cao Duc Anh

add export onnx

parent e61ef9b9
.vscode
__pycache__
\ No newline at end of file
{
"cells": [
{
"cell_type": "code",
"execution_count": 24,
"metadata": {},
"outputs": [],
"source": [
"import matplotlib.pyplot as plt\n",
"from PIL import Image\n",
"from tool.config import Cfg\n",
"from tool.translate import build_model, process_input, translate\n",
"import torch\n",
"import onnxruntime\n",
"import numpy as np"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"config = Cfg.load_config_from_file('./config/vgg-seq2seq.yml')\n",
"config['cnn']['pretrained']=False\n",
"config['device'] = 'cpu'\n",
"model, vocab = build_model(config)\n",
"weight_path = './weight/transformerocr.pth'"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"# load weight\n",
"model.load_state_dict(torch.load(weight_path, map_location=torch.device(config['device'])))\n",
"model = model.eval() "
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Export CNN part"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"def convert_cnn_part(img, save_path, model, max_seq_length=128, sos_token=1, eos_token=2): \n",
" with torch.no_grad(): \n",
" src = model.cnn(img)\n",
" torch.onnx.export(model.cnn, img, save_path, export_params=True, opset_version=12, do_constant_folding=True, verbose=True, input_names=['img'], output_names=['output'], dynamic_axes={'img': {3: 'lenght'}, 'output': {0: 'channel'}})\n",
" \n",
" return src"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"graph(%img : Float(1, 3, 32, *, strides=[45600, 15200, 475, 1], requires_grad=0, device=cpu),\n",
" %model.last_conv_1x1.weight : Float(256, 512, 1, 1, strides=[512, 1, 1, 1], requires_grad=1, device=cpu),\n",
" %model.last_conv_1x1.bias : Float(256, strides=[1], requires_grad=1, device=cpu),\n",
" %190 : Float(64, 3, 3, 3, strides=[27, 9, 3, 1], requires_grad=0, device=cpu),\n",
" %191 : Float(64, strides=[1], requires_grad=0, device=cpu),\n",
" %193 : Float(64, 64, 3, 3, strides=[576, 9, 3, 1], requires_grad=0, device=cpu),\n",
" %194 : Float(64, strides=[1], requires_grad=0, device=cpu),\n",
" %196 : Float(128, 64, 3, 3, strides=[576, 9, 3, 1], requires_grad=0, device=cpu),\n",
" %197 : Float(128, strides=[1], requires_grad=0, device=cpu),\n",
" %199 : Float(128, 128, 3, 3, strides=[1152, 9, 3, 1], requires_grad=0, device=cpu),\n",
" %200 : Float(128, strides=[1], requires_grad=0, device=cpu),\n",
" %202 : Float(256, 128, 3, 3, strides=[1152, 9, 3, 1], requires_grad=0, device=cpu),\n",
" %203 : Float(256, strides=[1], requires_grad=0, device=cpu),\n",
" %205 : Float(256, 256, 3, 3, strides=[2304, 9, 3, 1], requires_grad=0, device=cpu),\n",
" %206 : Float(256, strides=[1], requires_grad=0, device=cpu),\n",
" %208 : Float(256, 256, 3, 3, strides=[2304, 9, 3, 1], requires_grad=0, device=cpu),\n",
" %209 : Float(256, strides=[1], requires_grad=0, device=cpu),\n",
" %211 : Float(256, 256, 3, 3, strides=[2304, 9, 3, 1], requires_grad=0, device=cpu),\n",
" %212 : Float(256, strides=[1], requires_grad=0, device=cpu),\n",
" %214 : Float(512, 256, 3, 3, strides=[2304, 9, 3, 1], requires_grad=0, device=cpu),\n",
" %215 : Float(512, strides=[1], requires_grad=0, device=cpu),\n",
" %217 : Float(512, 512, 3, 3, strides=[4608, 9, 3, 1], requires_grad=0, device=cpu),\n",
" %218 : Float(512, strides=[1], requires_grad=0, device=cpu),\n",
" %220 : Float(512, 512, 3, 3, strides=[4608, 9, 3, 1], requires_grad=0, device=cpu),\n",
" %221 : Float(512, strides=[1], requires_grad=0, device=cpu),\n",
" %223 : Float(512, 512, 3, 3, strides=[4608, 9, 3, 1], requires_grad=0, device=cpu),\n",
" %224 : Float(512, strides=[1], requires_grad=0, device=cpu),\n",
" %226 : Float(512, 512, 3, 3, strides=[4608, 9, 3, 1], requires_grad=0, device=cpu),\n",
" %227 : Float(512, strides=[1], requires_grad=0, device=cpu),\n",
" %229 : Float(512, 512, 3, 3, strides=[4608, 9, 3, 1], requires_grad=0, device=cpu),\n",
" %230 : Float(512, strides=[1], requires_grad=0, device=cpu),\n",
" %232 : Float(512, 512, 3, 3, strides=[4608, 9, 3, 1], requires_grad=0, device=cpu),\n",
" %233 : Float(512, strides=[1], requires_grad=0, device=cpu),\n",
" %235 : Float(512, 512, 3, 3, strides=[4608, 9, 3, 1], requires_grad=0, device=cpu),\n",
" %236 : Float(512, strides=[1], requires_grad=0, device=cpu)):\n",
" %189 : Float(1, 64, 32, *, strides=[972800, 15200, 475, 1], requires_grad=0, device=cpu) = onnx::Conv[dilations=[1, 1], group=1, kernel_shape=[3, 3], pads=[1, 1, 1, 1], strides=[1, 1]](%img, %190, %191)\n",
" %117 : Float(1, 64, 32, *, strides=[972800, 15200, 475, 1], requires_grad=0, device=cpu) = onnx::Relu(%189) # /home/manhbui/anaconda3/envs/manhbq/lib/python3.7/site-packages/torch/nn/functional.py:1204:0\n",
" %192 : Float(1, 64, 32, *, strides=[972800, 15200, 475, 1], requires_grad=0, device=cpu) = onnx::Conv[dilations=[1, 1], group=1, kernel_shape=[3, 3], pads=[1, 1, 1, 1], strides=[1, 1]](%117, %193, %194)\n",
" %120 : Float(1, 64, 32, *, strides=[972800, 15200, 475, 1], requires_grad=0, device=cpu) = onnx::Relu(%192) # /home/manhbui/anaconda3/envs/manhbq/lib/python3.7/site-packages/torch/nn/functional.py:1204:0\n",
" %121 : Long(8, strides=[1], device=cpu) = onnx::Constant[value= 0 0 0 0 0 0 0 0 [ CPULongType{8} ]]()\n",
" %122 : Float(1, 64, 32, *, device=cpu) = onnx::Pad[mode=\"constant\"](%120, %121)\n",
" %123 : Float(1, 64, 16, *, strides=[242688, 3792, 237, 1], requires_grad=0, device=cpu) = onnx::AveragePool[ceil_mode=0, kernel_shape=[2, 2], pads=[0, 0, 0, 0], strides=[2, 2]](%122) # /home/manhbui/anaconda3/envs/manhbq/lib/python3.7/site-packages/torch/nn/modules/pooling.py:616:0\n",
" %195 : Float(1, 128, 16, *, strides=[485376, 3792, 237, 1], requires_grad=0, device=cpu) = onnx::Conv[dilations=[1, 1], group=1, kernel_shape=[3, 3], pads=[1, 1, 1, 1], strides=[1, 1]](%123, %196, %197)\n",
" %126 : Float(1, 128, 16, *, strides=[485376, 3792, 237, 1], requires_grad=0, device=cpu) = onnx::Relu(%195) # /home/manhbui/anaconda3/envs/manhbq/lib/python3.7/site-packages/torch/nn/functional.py:1204:0\n",
" %198 : Float(1, 128, 16, *, strides=[485376, 3792, 237, 1], requires_grad=0, device=cpu) = onnx::Conv[dilations=[1, 1], group=1, kernel_shape=[3, 3], pads=[1, 1, 1, 1], strides=[1, 1]](%126, %199, %200)\n",
" %129 : Float(1, 128, 16, *, strides=[485376, 3792, 237, 1], requires_grad=0, device=cpu) = onnx::Relu(%198) # /home/manhbui/anaconda3/envs/manhbq/lib/python3.7/site-packages/torch/nn/functional.py:1204:0\n",
" %130 : Long(8, strides=[1], device=cpu) = onnx::Constant[value= 0 0 0 0 0 0 0 0 [ CPULongType{8} ]]()\n",
" %131 : Float(1, 128, 16, *, device=cpu) = onnx::Pad[mode=\"constant\"](%129, %130)\n",
" %132 : Float(1, 128, 8, *, strides=[120832, 944, 118, 1], requires_grad=0, device=cpu) = onnx::AveragePool[ceil_mode=0, kernel_shape=[2, 2], pads=[0, 0, 0, 0], strides=[2, 2]](%131) # /home/manhbui/anaconda3/envs/manhbq/lib/python3.7/site-packages/torch/nn/modules/pooling.py:616:0\n",
" %201 : Float(1, 256, 8, *, strides=[241664, 944, 118, 1], requires_grad=0, device=cpu) = onnx::Conv[dilations=[1, 1], group=1, kernel_shape=[3, 3], pads=[1, 1, 1, 1], strides=[1, 1]](%132, %202, %203)\n",
" %135 : Float(1, 256, 8, *, strides=[241664, 944, 118, 1], requires_grad=0, device=cpu) = onnx::Relu(%201) # /home/manhbui/anaconda3/envs/manhbq/lib/python3.7/site-packages/torch/nn/functional.py:1204:0\n",
" %204 : Float(1, 256, 8, *, strides=[241664, 944, 118, 1], requires_grad=0, device=cpu) = onnx::Conv[dilations=[1, 1], group=1, kernel_shape=[3, 3], pads=[1, 1, 1, 1], strides=[1, 1]](%135, %205, %206)\n",
" %138 : Float(1, 256, 8, *, strides=[241664, 944, 118, 1], requires_grad=0, device=cpu) = onnx::Relu(%204) # /home/manhbui/anaconda3/envs/manhbq/lib/python3.7/site-packages/torch/nn/functional.py:1204:0\n",
" %207 : Float(1, 256, 8, *, strides=[241664, 944, 118, 1], requires_grad=0, device=cpu) = onnx::Conv[dilations=[1, 1], group=1, kernel_shape=[3, 3], pads=[1, 1, 1, 1], strides=[1, 1]](%138, %208, %209)\n",
" %141 : Float(1, 256, 8, *, strides=[241664, 944, 118, 1], requires_grad=0, device=cpu) = onnx::Relu(%207) # /home/manhbui/anaconda3/envs/manhbq/lib/python3.7/site-packages/torch/nn/functional.py:1204:0\n",
" %210 : Float(1, 256, 8, *, strides=[241664, 944, 118, 1], requires_grad=0, device=cpu) = onnx::Conv[dilations=[1, 1], group=1, kernel_shape=[3, 3], pads=[1, 1, 1, 1], strides=[1, 1]](%141, %211, %212)\n",
" %144 : Float(1, 256, 8, *, strides=[241664, 944, 118, 1], requires_grad=0, device=cpu) = onnx::Relu(%210) # /home/manhbui/anaconda3/envs/manhbq/lib/python3.7/site-packages/torch/nn/functional.py:1204:0\n",
" %145 : Long(8, strides=[1], device=cpu) = onnx::Constant[value= 0 0 0 0 0 0 0 0 [ CPULongType{8} ]]()\n",
" %146 : Float(1, 256, 8, *, device=cpu) = onnx::Pad[mode=\"constant\"](%144, %145)\n",
" %147 : Float(1, 256, 4, *, strides=[120832, 472, 118, 1], requires_grad=0, device=cpu) = onnx::AveragePool[ceil_mode=0, kernel_shape=[2, 1], pads=[0, 0, 0, 0], strides=[2, 1]](%146) # /home/manhbui/anaconda3/envs/manhbq/lib/python3.7/site-packages/torch/nn/modules/pooling.py:616:0\n",
" %213 : Float(1, 512, 4, *, strides=[241664, 472, 118, 1], requires_grad=0, device=cpu) = onnx::Conv[dilations=[1, 1], group=1, kernel_shape=[3, 3], pads=[1, 1, 1, 1], strides=[1, 1]](%147, %214, %215)\n",
" %150 : Float(1, 512, 4, *, strides=[241664, 472, 118, 1], requires_grad=0, device=cpu) = onnx::Relu(%213) # /home/manhbui/anaconda3/envs/manhbq/lib/python3.7/site-packages/torch/nn/functional.py:1204:0\n",
" %216 : Float(1, 512, 4, *, strides=[241664, 472, 118, 1], requires_grad=0, device=cpu) = onnx::Conv[dilations=[1, 1], group=1, kernel_shape=[3, 3], pads=[1, 1, 1, 1], strides=[1, 1]](%150, %217, %218)\n",
" %153 : Float(1, 512, 4, *, strides=[241664, 472, 118, 1], requires_grad=0, device=cpu) = onnx::Relu(%216) # /home/manhbui/anaconda3/envs/manhbq/lib/python3.7/site-packages/torch/nn/functional.py:1204:0\n",
" %219 : Float(1, 512, 4, *, strides=[241664, 472, 118, 1], requires_grad=0, device=cpu) = onnx::Conv[dilations=[1, 1], group=1, kernel_shape=[3, 3], pads=[1, 1, 1, 1], strides=[1, 1]](%153, %220, %221)\n",
" %156 : Float(1, 512, 4, *, strides=[241664, 472, 118, 1], requires_grad=0, device=cpu) = onnx::Relu(%219) # /home/manhbui/anaconda3/envs/manhbq/lib/python3.7/site-packages/torch/nn/functional.py:1204:0\n",
" %222 : Float(1, 512, 4, *, strides=[241664, 472, 118, 1], requires_grad=0, device=cpu) = onnx::Conv[dilations=[1, 1], group=1, kernel_shape=[3, 3], pads=[1, 1, 1, 1], strides=[1, 1]](%156, %223, %224)\n",
" %159 : Float(1, 512, 4, *, strides=[241664, 472, 118, 1], requires_grad=0, device=cpu) = onnx::Relu(%222) # /home/manhbui/anaconda3/envs/manhbq/lib/python3.7/site-packages/torch/nn/functional.py:1204:0\n",
" %160 : Long(8, strides=[1], device=cpu) = onnx::Constant[value= 0 0 0 0 0 0 0 0 [ CPULongType{8} ]]()\n",
" %161 : Float(1, 512, 4, *, device=cpu) = onnx::Pad[mode=\"constant\"](%159, %160)\n",
" %162 : Float(1, 512, 2, *, strides=[120832, 236, 118, 1], requires_grad=0, device=cpu) = onnx::AveragePool[ceil_mode=0, kernel_shape=[2, 1], pads=[0, 0, 0, 0], strides=[2, 1]](%161) # /home/manhbui/anaconda3/envs/manhbq/lib/python3.7/site-packages/torch/nn/modules/pooling.py:616:0\n",
" %225 : Float(1, 512, 2, *, strides=[120832, 236, 118, 1], requires_grad=0, device=cpu) = onnx::Conv[dilations=[1, 1], group=1, kernel_shape=[3, 3], pads=[1, 1, 1, 1], strides=[1, 1]](%162, %226, %227)\n",
" %165 : Float(1, 512, 2, *, strides=[120832, 236, 118, 1], requires_grad=0, device=cpu) = onnx::Relu(%225) # /home/manhbui/anaconda3/envs/manhbq/lib/python3.7/site-packages/torch/nn/functional.py:1204:0\n",
" %228 : Float(1, 512, 2, *, strides=[120832, 236, 118, 1], requires_grad=0, device=cpu) = onnx::Conv[dilations=[1, 1], group=1, kernel_shape=[3, 3], pads=[1, 1, 1, 1], strides=[1, 1]](%165, %229, %230)\n",
" %168 : Float(1, 512, 2, *, strides=[120832, 236, 118, 1], requires_grad=0, device=cpu) = onnx::Relu(%228) # /home/manhbui/anaconda3/envs/manhbq/lib/python3.7/site-packages/torch/nn/functional.py:1204:0\n",
" %231 : Float(1, 512, 2, *, strides=[120832, 236, 118, 1], requires_grad=0, device=cpu) = onnx::Conv[dilations=[1, 1], group=1, kernel_shape=[3, 3], pads=[1, 1, 1, 1], strides=[1, 1]](%168, %232, %233)\n",
" %171 : Float(1, 512, 2, *, strides=[120832, 236, 118, 1], requires_grad=0, device=cpu) = onnx::Relu(%231) # /home/manhbui/anaconda3/envs/manhbq/lib/python3.7/site-packages/torch/nn/functional.py:1204:0\n",
" %234 : Float(1, 512, 2, *, strides=[120832, 236, 118, 1], requires_grad=0, device=cpu) = onnx::Conv[dilations=[1, 1], group=1, kernel_shape=[3, 3], pads=[1, 1, 1, 1], strides=[1, 1]](%171, %235, %236)\n",
" %174 : Float(1, 512, 2, *, strides=[120832, 236, 118, 1], requires_grad=0, device=cpu) = onnx::Relu(%234) # /home/manhbui/anaconda3/envs/manhbq/lib/python3.7/site-packages/torch/nn/functional.py:1204:0\n",
" %175 : Long(8, strides=[1], device=cpu) = onnx::Constant[value= 0 0 0 0 0 0 0 0 [ CPULongType{8} ]]()\n",
" %176 : Float(1, 512, 2, *, device=cpu) = onnx::Pad[mode=\"constant\"](%174, %175)\n",
" %177 : Float(1, 512, 2, *, strides=[120832, 236, 118, 1], requires_grad=0, device=cpu) = onnx::AveragePool[ceil_mode=0, kernel_shape=[1, 1], pads=[0, 0, 0, 0], strides=[1, 1]](%176) # /home/manhbui/anaconda3/envs/manhbq/lib/python3.7/site-packages/torch/nn/functional.py:1076:0\n",
" %178 : Float(1, 256, 2, *, strides=[60416, 236, 118, 1], requires_grad=0, device=cpu) = onnx::Conv[dilations=[1, 1], group=1, kernel_shape=[1, 1], pads=[0, 0, 0, 0], strides=[1, 1]](%177, %model.last_conv_1x1.weight, %model.last_conv_1x1.bias) # /home/manhbui/anaconda3/envs/manhbq/lib/python3.7/site-packages/torch/nn/modules/conv.py:396:0\n",
" %179 : Float(1, 256, *, 2, strides=[60416, 236, 1, 118], requires_grad=0, device=cpu) = onnx::Transpose[perm=[0, 1, 3, 2]](%178) # /home/manhbui/manhbq_workspace/ConvertVietOcr2Onnx/model/backbone/vgg.py:40:0\n",
" %180 : Long(4, strides=[1], device=cpu) = onnx::Shape(%179)\n",
" %181 : Long(1, strides=[1], device=cpu) = onnx::Constant[value={0}]()\n",
" %182 : Long(1, strides=[1], device=cpu) = onnx::Constant[value={0}]()\n",
" %183 : Long(1, strides=[1], device=cpu) = onnx::Constant[value={2}]()\n",
" %184 : Long(2, strides=[1], device=cpu) = onnx::Slice(%180, %182, %183, %181)\n",
" %185 : Long(1, strides=[1], device=cpu) = onnx::Constant[value={-1}]()\n",
" %186 : Long(3, strides=[1], device=cpu) = onnx::Concat[axis=0](%184, %185)\n",
" %187 : Float(1, 256, 236, strides=[60416, 236, 1], requires_grad=0, device=cpu) = onnx::Reshape(%179, %186) # /home/manhbui/manhbq_workspace/ConvertVietOcr2Onnx/model/backbone/vgg.py:41:0\n",
" %output : Float(236, 1, 256, strides=[1, 60416, 236], requires_grad=0, device=cpu) = onnx::Transpose[perm=[2, 0, 1]](%187) # /home/manhbui/manhbq_workspace/ConvertVietOcr2Onnx/model/backbone/vgg.py:42:0\n",
" return (%output)\n",
"\n"
]
}
],
"source": [
"img = torch.rand(1, 3, 32, 475)\n",
"src = convert_cnn_part(img, './weight/cnn.onnx', model)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Export encoder part"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"def convert_encoder_part(model, src, save_path): \n",
" encoder_outputs, hidden = model.transformer.encoder(src) \n",
" torch.onnx.export(model.transformer.encoder, src, save_path, export_params=True, opset_version=11, do_constant_folding=True, input_names=['src'], output_names=['encoder_outputs', 'hidden'], dynamic_axes={'src':{0: \"channel_input\"}, 'encoder_outputs': {0: 'channel_output'}}) \n",
" return hidden, encoder_outputs"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/manhbui/anaconda3/envs/manhbq/lib/python3.7/site-packages/torch/onnx/symbolic_opset9.py:1945: UserWarning: Exporting a model to ONNX with a batch_size other than 1, with a variable length with GRU can cause an error when running the ONNX model with a different batch size. Make sure to save the model with a batch size of 1, or define the initial states (h0/c0) as inputs of the model. \n",
" \"or define the initial states (h0/c0) as inputs of the model. \")\n"
]
}
],
"source": [
"hidden, encoder_outputs = convert_encoder_part(model, src, './weight/encoder.onnx')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Export decoder part"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"def convert_decoder_part(model, tgt, hidden, encoder_outputs, save_path):\n",
" tgt = tgt[-1]\n",
" \n",
" torch.onnx.export(model.transformer.decoder,\n",
" (tgt, hidden, encoder_outputs),\n",
" save_path,\n",
" export_params=True,\n",
" opset_version=11,\n",
" do_constant_folding=True,\n",
" input_names=['tgt', 'hidden', 'encoder_outputs'],\n",
" output_names=['output', 'hidden_out', 'last'],\n",
" dynamic_axes={'encoder_outputs':{0: \"channel_input\"},\n",
" 'last': {0: 'channel_output'}})"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"device = img.device\n",
"tgt = torch.LongTensor([[1] * len(img)]).to(device)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/manhbui/manhbq_workspace/ConvertVietOcr2Onnx/model/seqmodel/seq2seq.py:93: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!\n",
" assert (output == hidden).all()\n"
]
}
],
"source": [
"convert_decoder_part(model, tgt, hidden, encoder_outputs, './weight/decoder.onnx')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Load and check model"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"import onnx"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
"cnn = onnx.load('./weight/cnn.onnx')\n",
"decoder = onnx.load('./weight/encoder.onnx')\n",
"encoder = onnx.load('./weight/decoder.onnx')"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
"# confirm model has valid schema\n",
"onnx.checker.check_model(cnn)\n",
"onnx.checker.check_model(decoder)\n",
"onnx.checker.check_model(encoder)"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": "'graph torch-jit-export (\\n %tgt[INT64, 1]\\n %hidden[FLOAT, 1x256]\\n %encoder_outputs[FLOAT, channel_inputx1x512]\\n) initializers (\\n %attention.attn.bias[FLOAT, 256]\\n %embedding.weight[FLOAT, 233x256]\\n %fc_out.weight[FLOAT, 233x1024]\\n %fc_out.bias[FLOAT, 233]\\n %116[INT64, 1]\\n %117[INT64, 1]\\n %118[INT64, 1]\\n %119[INT64, 1]\\n %120[FLOAT, 768x256]\\n %121[FLOAT, 256x1]\\n %139[FLOAT, 1x768x768]\\n %140[FLOAT, 1x768x256]\\n %141[FLOAT, 1x1536]\\n) {\\n %13 = Unsqueeze[axes = [0]](%tgt)\\n %14 = Gather(%embedding.weight, %13)\\n %15 = Shape(%encoder_outputs)\\n %16 = Constant[value = <Scalar Tensor []>]()\\n %17 = Gather[axis = 0](%15, %16)\\n %18 = Unsqueeze[axes = [1]](%hidden)\\n %22 = Unsqueeze[axes = [0]](%17)\\n %24 = Concat[axis = 0](%116, %22, %117)\\n %26 = Unsqueeze[axes = [0]](%17)\\n %28 = Concat[axis = 0](%118, %26, %119)\\n %29 = Shape(%24)\\n %30 = ConstantOfShape[value = <Tensor>](%29)\\n %31 = Expand(%18, %30)\\n %32 = Tile(%31, %28)\\n %33 = Transpose[perm = [1, 0, 2]](%encoder_outputs)\\n %34 = Concat[axis = 2](%32, %33)\\n %36 = MatMul(%34, %120)\\n %37 = Add(%36, %attention.attn.bias)\\n %38 = Tanh(%37)\\n %40 = MatMul(%38, %121)\\n %41 = Squeeze[axes = [2]](%40)\\n %42 = Softmax[axis = 1](%41)\\n %43 = Unsqueeze[axes = [1]](%42)\\n %44 = Transpose[perm = [1, 0, 2]](%encoder_outputs)\\n %45 = MatMul(%43, %44)\\n %46 = Transpose[perm = [1, 0, 2]](%45)\\n %47 = Concat[axis = 2](%14, %46)\\n %48 = Unsqueeze[axes = [0]](%hidden)\\n %106, %107 = GRU[hidden_size = 256, linear_before_reset = 1](%47, %139, %140, %141, %, %48)\\n %108 = Squeeze[axes = [1]](%106)\\n %109 = Squeeze[axes = [0]](%14)\\n %110 = Squeeze[axes = [0]](%108)\\n %111 = Squeeze[axes = [0]](%46)\\n %112 = Concat[axis = 1](%110, %111, %109)\\n %output = Gemm[alpha = 1, beta = 1, transB = 1](%112, %fc_out.weight, %fc_out.bias)\\n %hidden_out = Squeeze[axes = [0]](%107)\\n %last = Squeeze[axes = [1]](%43)\\n return %output, %hidden_out, %last\\n}'"
},
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Print a human-readable representation of the decoder graph\n",
"onnx.helper.printable_graph(decoder.graph)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Inference directly"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [],
"source": [
"img = Image.open('./sample/35944.png')\n",
"img = process_input(img, config['dataset']['image_height'], \n",
" config['dataset']['image_min_width'], config['dataset']['image_max_width']) \n",
"img = img.to(config['device'])"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": "'Mâm non: 141 thí sinh'"
},
"execution_count": 20,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"s = translate(img, model)[0].tolist()\n",
"s = vocab.decode(s)\n",
"s"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Inference with ONNX Runtime's Python API"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [],
"source": [
"# create inference session\n",
"cnn_session = onnxruntime.InferenceSession(\"./weight/cnn.onnx\")\n",
"encoder_session = onnxruntime.InferenceSession(\"./weight/encoder.onnx\")\n",
"decoder_session = onnxruntime.InferenceSession(\"./weight/decoder.onnx\")"
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {},
"outputs": [],
"source": [
"def translate_onnx(img, session, max_seq_length=128, sos_token=1, eos_token=2):\n",
" \"\"\"data: BxCxHxW\"\"\"\n",
" cnn_session, encoder_session, decoder_session = session\n",
" \n",
" # create cnn input\n",
" cnn_input = {cnn_session.get_inputs()[0].name: img}\n",
" src = cnn_session.run(None, cnn_input)\n",
" \n",
" # create encoder input\n",
" encoder_input = {encoder_session.get_inputs()[0].name: src[0]}\n",
" encoder_outputs, hidden = encoder_session.run(None, encoder_input)\n",
" translated_sentence = [[sos_token] * len(img)]\n",
" max_length = 0\n",
"\n",
" while max_length <= max_seq_length and not all(\n",
" np.any(np.asarray(translated_sentence).T == eos_token, axis=1)\n",
" ):\n",
" tgt_inp = translated_sentence\n",
" decoder_input = {decoder_session.get_inputs()[0].name: tgt_inp[-1], decoder_session.get_inputs()[1].name: hidden, decoder_session.get_inputs()[2].name: encoder_outputs}\n",
"\n",
" output, hidden, _ = decoder_session.run(None, decoder_input)\n",
" output = np.expand_dims(output, axis=1)\n",
" output = torch.Tensor(output)\n",
"\n",
" values, indices = torch.topk(output, 1)\n",
" indices = indices[:, -1, 0]\n",
" indices = indices.tolist()\n",
"\n",
" translated_sentence.append(indices)\n",
" max_length += 1\n",
"\n",
" del output\n",
"\n",
" translated_sentence = np.asarray(translated_sentence).T\n",
"\n",
" return translated_sentence"
]
},
{
"cell_type": "code",
"execution_count": 28,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": "'Mâm non: 141 thí sinh'"
},
"execution_count": 28,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"session = (cnn_session, encoder_session, decoder_session)\n",
"s = translate_onnx(np.array(img), session)[0].tolist()\n",
"s = vocab.decode(s)\n",
"s"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3.7.10 64-bit ('manhbq': conda)",
"name": "python3710jvsc74a57bd08073656a449d74c1402c8646685842603f61a524b0c74948679a8c6893091938"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.10"
},
"orig_nbformat": 2
},
"nbformat": 4,
"nbformat_minor": 2
}
\ No newline at end of file
# ConvertVietOcr2Onnx
Tutorial: [Converting a deep learning model to ONNX](https://viblo.asia/p/chuyen-doi-mo-hinh-hoc-sau-ve-onnx-bWrZnz4vZxw)
# change to list chars of your dataset or use default vietnamese chars
vocab: 'aAàÀảẢãÃáÁạẠăĂằẰẳẲẵẴắẮặẶâÂầẦẩẨẫẪấẤậẬbBcCdDđĐeEèÈẻẺẽẼéÉẹẸêÊềỀểỂễỄếẾệỆfFgGhHiIìÌỉỈĩĨíÍịỊjJkKlLmMnNoOòÒỏỎõÕóÓọỌôÔồỒổỔỗỖốỐộỘơƠờỜởỞỡỠớỚợỢpPqQrRsStTuUùÙủỦũŨúÚụỤưƯừỪửỬữỮứỨựỰvVwWxXyYỳỲỷỶỹỸýÝỵỴzZ0123456789!"#$%&''()*+,-./:;<=>?@[\]^_`{|}~ '
# cpu, cuda, cuda:0
device: cuda:0
seq_modeling: transformer
transformer:
d_model: 256
nhead: 8
num_encoder_layers: 6
num_decoder_layers: 6
dim_feedforward: 2048
max_seq_length: 1024
pos_dropout: 0.1
trans_dropout: 0.1
optimizer:
max_lr: 0.0003
pct_start: 0.1
trainer:
batch_size: 32
print_every: 200
valid_every: 4000
iters: 100000
# where to save our model for prediction
export: ./weights/transformerocr.pth
checkpoint: ./checkpoint/transformerocr_checkpoint.pth
log: ./train.log
# null to disable computing accuracy, or set to a number of samples to enable validation while training
metrics: null
dataset:
# name of your dataset
name: data
# path to annotation and image
data_root: ./img/
train_annotation: annotation_train.txt
valid_annotation: annotation_val_small.txt
# resize image to a height of 32; a larger height will increase accuracy
image_height: 32
image_min_width: 32
image_max_width: 512
dataloader:
num_workers: 3
pin_memory: True
aug:
image_aug: true
masked_language_model: true
predictor:
# disable or enable beam search during prediction; using beam search is slower
beamsearch: False
quiet: False
\ No newline at end of file
import torch
from torch import nn
import model.backbone.vgg as vgg
class CNN(nn.Module):
def __init__(self, backbone, **kwargs):
super(CNN, self).__init__()
if backbone == 'vgg11_bn':
self.model = vgg.vgg11_bn(**kwargs)
elif backbone == 'vgg19_bn':
self.model = vgg.vgg19_bn(**kwargs)
def forward(self, x):
return self.model(x)
def freeze(self):
for name, param in self.model.features.named_parameters():
if name != 'last_conv_1x1':
param.requires_grad = False
def unfreeze(self):
for param in self.model.features.parameters():
param.requires_grad = True
\ No newline at end of file
import torch
from torch import nn
from torchvision import models
from einops import rearrange
from torchvision.models._utils import IntermediateLayerGetter
class Vgg(nn.Module):
def __init__(self, name, ss, ks, hidden, pretrained=True, dropout=0.5):
super(Vgg, self).__init__()
if name == 'vgg11_bn':
cnn = models.vgg11_bn(pretrained=pretrained)
elif name == 'vgg19_bn':
cnn = models.vgg19_bn(pretrained=pretrained)
pool_idx = 0
for i, layer in enumerate(cnn.features):
if isinstance(layer, torch.nn.MaxPool2d):
cnn.features[i] = torch.nn.AvgPool2d(kernel_size=ks[pool_idx], stride=ss[pool_idx], padding=0)
pool_idx += 1
self.features = cnn.features
self.dropout = nn.Dropout(dropout)
self.last_conv_1x1 = nn.Conv2d(512, hidden, 1)
def forward(self, x):
"""
Shape:
- x: (N, C, H, W)
- output: (W, N, C)
"""
conv = self.features(x)
conv = self.dropout(conv)
conv = self.last_conv_1x1(conv)
# conv = rearrange(conv, 'b d h w -> b d (w h)')
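# equivalent to the rearrange above: (N, C, H, W) -> (W*H, N, C), i.e. a width-major feature sequence for the sequence model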
conv = conv.permute(0, 1, 3, 2)
conv = conv.flatten(2)
conv = conv.permute(2, 0, 1)
return conv
def vgg11_bn(ss, ks, hidden, pretrained=True, dropout=0.5):
return Vgg('vgg11_bn', ss, ks, hidden, pretrained, dropout)
def vgg19_bn(ss, ks, hidden, pretrained=True, dropout=0.5):
return Vgg('vgg19_bn', ss, ks, hidden, pretrained, dropout)
\ No newline at end of file
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
class Encoder(nn.Module):
def __init__(self, emb_dim, enc_hid_dim, dec_hid_dim, dropout):
super().__init__()
self.rnn = nn.GRU(emb_dim, enc_hid_dim, bidirectional = True)
self.fc = nn.Linear(enc_hid_dim * 2, dec_hid_dim)
self.dropout = nn.Dropout(dropout)
def forward(self, src):
"""
src: src_len x batch_size x img_channel
outputs: src_len x batch_size x hid_dim
hidden: batch_size x hid_dim
"""
embedded = self.dropout(src)
outputs, hidden = self.rnn(embedded)
hidden = torch.tanh(self.fc(torch.cat((hidden[-2,:,:], hidden[-1,:,:]), dim = 1)))
return outputs, hidden
class Attention(nn.Module):
def __init__(self, enc_hid_dim, dec_hid_dim):
super().__init__()
self.attn = nn.Linear((enc_hid_dim * 2) + dec_hid_dim, dec_hid_dim)
self.v = nn.Linear(dec_hid_dim, 1, bias = False)
def forward(self, hidden, encoder_outputs):
"""
hidden: batch_size x hid_dim
encoder_outputs: src_len x batch_size x hid_dim,
outputs: batch_size x src_len
"""
batch_size = encoder_outputs.shape[1]
src_len = encoder_outputs.shape[0]
hidden = hidden.unsqueeze(1).repeat(1, src_len, 1)
encoder_outputs = encoder_outputs.permute(1, 0, 2)
energy = torch.tanh(self.attn(torch.cat((hidden, encoder_outputs), dim = 2)))
attention = self.v(energy).squeeze(2)
return F.softmax(attention, dim = 1)
class Decoder(nn.Module):
def __init__(self, output_dim, emb_dim, enc_hid_dim, dec_hid_dim, dropout, attention):
super().__init__()
self.output_dim = output_dim
self.attention = attention
self.embedding = nn.Embedding(output_dim, emb_dim)
self.rnn = nn.GRU((enc_hid_dim * 2) + emb_dim, dec_hid_dim)
self.fc_out = nn.Linear((enc_hid_dim * 2) + dec_hid_dim + emb_dim, output_dim)
self.dropout = nn.Dropout(dropout)
def forward(self, input, hidden, encoder_outputs):
"""
inputs: batch_size
hidden: batch_size x hid_dim
encoder_outputs: src_len x batch_size x hid_dim
"""
input = input.unsqueeze(0)
embedded = self.dropout(self.embedding(input))
a = self.attention(hidden, encoder_outputs)
a = a.unsqueeze(1)
encoder_outputs = encoder_outputs.permute(1, 0, 2)
weighted = torch.bmm(a, encoder_outputs)
weighted = weighted.permute(1, 0, 2)
rnn_input = torch.cat((embedded, weighted), dim = 2)
output, hidden = self.rnn(rnn_input, hidden.unsqueeze(0))
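# for a single step with a single-layer GRU, output equals the new hidden state; this assert is what triggers the TracerWarning seen during ONNX export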
assert (output == hidden).all()
embedded = embedded.squeeze(0)
output = output.squeeze(0)
weighted = weighted.squeeze(0)
prediction = self.fc_out(torch.cat((output, weighted, embedded), dim = 1))
return prediction, hidden.squeeze(0), a.squeeze(1)
class Seq2Seq(nn.Module):
def __init__(self, vocab_size, encoder_hidden, decoder_hidden, img_channel, decoder_embedded, dropout=0.1):
super().__init__()
attn = Attention(encoder_hidden, decoder_hidden)
self.encoder = Encoder(img_channel, encoder_hidden, decoder_hidden, dropout)
self.decoder = Decoder(vocab_size, decoder_embedded, encoder_hidden, decoder_hidden, dropout, attn)
def forward_encoder(self, src):
"""
src: timestep x batch_size x channel
hidden: batch_size x hid_dim
encoder_outputs: src_len x batch_size x hid_dim
"""
encoder_outputs, hidden = self.encoder(src)
return (hidden, encoder_outputs)
def forward_decoder(self, tgt, memory):
"""
tgt: timestep x batch_size
hidden: batch_size x hid_dim
encouder: src_len x batch_size x hid_dim
output: batch_size x 1 x vocab_size
"""
tgt = tgt[-1]
hidden, encoder_outputs = memory
output, hidden, _ = self.decoder(tgt, hidden, encoder_outputs)
output = output.unsqueeze(1)
return output, (hidden, encoder_outputs)
def forward(self, src, trg):
"""
src: time_step x batch_size
trg: time_step x batch_size
outputs: batch_size x time_step x vocab_size
"""
batch_size = src.shape[1]
trg_len = trg.shape[0]
trg_vocab_size = self.decoder.output_dim
device = src.device
outputs = torch.zeros(trg_len, batch_size, trg_vocab_size).to(device)
encoder_outputs, hidden = self.encoder(src)
for t in range(trg_len):
input = trg[t]
output, hidden, _ = self.decoder(input, hidden, encoder_outputs)
outputs[t] = output
outputs = outputs.transpose(0, 1).contiguous()
return outputs
def expand_memory(self, memory, beam_size):
hidden, encoder_outputs = memory
hidden = hidden.repeat(beam_size, 1)
encoder_outputs = encoder_outputs.repeat(1, beam_size, 1)
return (hidden, encoder_outputs)
def get_memory(self, memory, i):
hidden, encoder_outputs = memory
hidden = hidden[[i]]
encoder_outputs = encoder_outputs[:, [i],:]
return (hidden, encoder_outputs)
from model.backbone.cnn import CNN
from model.seqmodel.seq2seq import Seq2Seq
from torch import nn
class VietOCR(nn.Module):
def __init__(self, vocab_size,
backbone,
cnn_args,
transformer_args, seq_modeling='transformer'):
super(VietOCR, self).__init__()
self.cnn = CNN(backbone, **cnn_args)
self.seq_modeling = seq_modeling
self.transformer = Seq2Seq(vocab_size, **transformer_args)
def forward(self, img, tgt_input, tgt_key_padding_mask):
"""
Shape:
- img: (N, C, H, W)
- tgt_input: (T, N)
- tgt_key_padding_mask: (N, T)
- output: b t v
"""
src = self.cnn(img)
outputs = self.transformer(src, tgt_input)
return outputs
\ No newline at end of file
class Vocab():
def __init__(self, chars):
self.pad = 0
self.go = 1
self.eos = 2
self.mask_token = 3
self.chars = chars
self.c2i = {c:i+4 for i, c in enumerate(chars)}
self.i2c = {i+4:c for i, c in enumerate(chars)}
self.i2c[0] = '<pad>'
self.i2c[1] = '<sos>'
self.i2c[2] = '<eos>'
self.i2c[3] = '*'
def encode(self, chars):
return [self.go] + [self.c2i[c] for c in chars] + [self.eos]
def decode(self, ids):
first = 1 if self.go in ids else 0
last = ids.index(self.eos) if self.eos in ids else None
sent = ''.join([self.i2c[i] for i in ids[first:last]])
return sent
def __len__(self):
return len(self.c2i) + 4
def batch_decode(self, arr):
texts = [self.decode(ids) for ids in arr]
return texts
def __str__(self):
return self.chars
import yaml
def load_config(config_file):
with open(config_file, encoding='utf-8') as f:
config = yaml.safe_load(f)
return config
class Cfg(dict):
def __init__(self, config_dict):
super(Cfg, self).__init__(**config_dict)
self.__dict__ = self
@staticmethod
def load_config_from_file(fname, base_file='./config/base.yml'):
base_config = load_config(base_file)
with open(fname, encoding='utf-8') as f:
config = yaml.safe_load(f)
base_config.update(config)
return Cfg(base_config)
def save(self, fname):
with open(fname, 'w') as outfile:
yaml.dump(dict(self), outfile, default_flow_style=False, allow_unicode=True)
import torch
import numpy as np
import cv2
from model.vocab import Vocab
from model.transformerocr import VietOCR
import math
from PIL import Image
def translate(img, model, max_seq_length=128, sos_token=1, eos_token=2):
"""data: BxCxHxW"""
model.eval()
device = img.device
with torch.no_grad():
src = model.cnn(img)
memory = model.transformer.forward_encoder(src)
translated_sentence = [[sos_token] * len(img)]
max_length = 0
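# greedy decoding: append the most likely next token until every sequence in the batch emits eos_token or max_seq_length is reached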
while max_length <= max_seq_length and not all(np.any(np.asarray(translated_sentence).T == eos_token, axis=1)):
tgt_inp = torch.LongTensor(translated_sentence).to(device)
output, memory = model.transformer.forward_decoder(tgt_inp, memory)
output = output.to('cpu')
values, indices = torch.topk(output, 1)
indices = indices[:, -1, 0]
indices = indices.tolist()
translated_sentence.append(indices)
max_length += 1
del output
translated_sentence = np.asarray(translated_sentence).T
return translated_sentence
def build_model(config):
vocab = Vocab(config['vocab'])
device = config['device']
model = VietOCR(len(vocab),
config['backbone'],
config['cnn'],
config['transformer'],
config['seq_modeling'])
model = model.to(device)
return model, vocab
def resize(w, h, expected_height, image_min_width, image_max_width):
new_w = int(expected_height * float(w) / float(h))
round_to = 10
new_w = math.ceil(new_w/round_to)*round_to
new_w = max(new_w, image_min_width)
new_w = min(new_w, image_max_width)
return new_w, expected_height
def process_image(image, image_height, image_min_width, image_max_width):
img = image.convert('RGB')
w, h = img.size
new_w, image_height = resize(w, h, image_height, image_min_width, image_max_width)
img = img.resize((new_w, image_height), Image.ANTIALIAS)
img = np.asarray(img).transpose(2,0, 1)
img = img/255
return img
def process_input(image, image_height, image_min_width, image_max_width):
img = process_image(image, image_height, image_min_width, image_max_width)
img = img[np.newaxis, ...]
img = torch.FloatTensor(img)
return img
......@@ -41,4 +41,26 @@ test_annotation is used to compute the validation loss.
docker-compose -f training.docker-compose.yml up --build
```
Monitor the results in the terminal. <br>
The trained model is saved in **"./weights"** <br>
After training: docker system prune
## Export the PyTorch model to ONNX
### 1. Config
The .pth model is stored in **"./weights"** <br>
Edit the **vgg-seq2seq.yml** file. Settings to pay attention to:
```
device: cuda:0
transformer:
encoder_hidden: 256
decoder_hidden: 256
img_channel: 256
decoder_embedded: 256
dropout: 0.1
```
### 2. Run
```
docker-compose -f export.docker-compose.yml up --build
```
The exported models are saved in **"./weights"**
\ No newline at end of file
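As a quick sanity check after the export, something like the sketch below can be run outside the container (a sketch only: the `./weights` paths and the dummy input size are assumptions to adapt to your setup):
```
import numpy as np
import onnx
import onnxruntime

# validate the three exported graphs
for name in ["cnn", "encoder", "decoder"]:
    onnx.checker.check_model(onnx.load(f"./weights/{name}.onnx"))

# run the CNN on a dummy 32-pixel-high image to confirm the dynamic width axis works
cnn_session = onnxruntime.InferenceSession("./weights/cnn.onnx")
dummy = np.random.rand(1, 3, 32, 160).astype(np.float32)
src = cnn_session.run(None, {cnn_session.get_inputs()[0].name: dummy})[0]
print(src.shape)  # expected (sequence_length, 1, hidden), with hidden = 256 here
```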
......@@ -40,17 +40,17 @@ dataset:
valid_annotation: test_annotation.txt
device: cuda:0
optimizer:
max_lr: 0.0003
max_lr: 0.001
pct_start: 0.1
predictor:
beamsearch: false
pretrain: https://vocr.vn/data/vietocr/vgg_transformer.pth
pretrain: https://vocr.vn/data/vietocr/vgg_seq2seq.pth
quiet: false
seq_modeling: transformer
seq_modeling: seq2seq
trainer:
batch_size: 8
checkpoint: ./checkpoint/transformerocr_checkpoint.pth
export: ./weights/transformerocr.pth
export: ./weights/vetocr.pth
iters: 20000
log: ./train.log
metrics: 2
......@@ -66,4 +66,4 @@ transformer:
pos_dropout: 0.1
trans_dropout: 0.1
vocab: 'aAàÀảẢãÃáÁạẠăĂằẰẳẲẵẴắẮặẶâÂầẦẩẨẫẪấẤậẬbBcCdDđĐeEèÈẻẺẽẼéÉẹẸêÊềỀểỂễỄếẾệỆfFgGhHiIìÌỉỈĩĨíÍịỊjJkKlLmMnNoOòÒỏỎõÕóÓọỌôÔồỒổỔỗỖốỐộỘơƠờỜởỞỡỠớỚợỢpPqQrRsStTuUùÙủỦũŨúÚụỤưƯừỪửỬữỮứỨựỰvVwWxXyYỳỲỷỶỹỸýÝỵỴzZ0123456789!"#$%&''()*+,-./:;<=>?@[\]^_`{|}~ '
weights: https://vocr.vn/data/vietocr/vgg_transformer.pth
weights: https://vocr.vn/data/vietocr/vgg_seq2seq.pth
FROM pytorch/pytorch:2.3.1-cuda11.8-cudnn8-runtime
ENV DEBIAN_FRONTEND=noninteractive
ENV PYTHONUNBUFFERED=True \
PORT=9090
# Install dependencies
RUN apt-get update \
&& apt-get install -y wget libgl1-mesa-glx libglib2.0-0
WORKDIR /src
COPY ./export_requirements.txt /src/requirements.txt
RUN pip install --no-cache-dir -r requirements.txt
COPY ./ConvertVietOcr2Onnx/ /src/
COPY ./export.py /src/
ENV PYTHONPATH=/src
\ No newline at end of file
version: '3.9'
services:
export-vietocr:
build:
context: ./
dockerfile: export.Dockerfile
volumes:
- ./vgg-seq2seq.yml:/src/config/export_config.yml
- ./weights:/src/weights/
deploy:
resources:
reservations:
devices:
- driver: nvidia
device_ids: ['0']
capabilities: [gpu]
command: python export.py
# tty: True
\ No newline at end of file
import matplotlib.pyplot as plt
from tool.config import Cfg
from tool.translate import build_model, process_input, translate
import torch
import onnxruntime
import numpy as np
config = Cfg.load_config_from_file('/src/config/export_config.yml')
config['cnn']['pretrained']=False
weight_path = '/src/weights/vetocr.pth'
# build model
model, vocab = build_model(config)
# load weight
model.load_state_dict(torch.load(weight_path, map_location=torch.device(config['device'])))
model = model.eval()
# Export the CNN model
def convert_cnn_part(img, save_path, model):
with torch.no_grad():
src = model.cnn(img)
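# export with dynamic axes: axis 3 of the input (image width) and the first axis of the output may vary at inference time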
torch.onnx.export(model.cnn, img, save_path, export_params=True, opset_version=12, do_constant_folding=True, verbose=True, input_names=['img'], output_names=['output'], dynamic_axes={'img': {3: 'lenght'}, 'output': {0: 'channel'}})
return src
img = torch.rand(1, 3, 32, 475).cuda()
src = convert_cnn_part(img, '/src/weights/cnn.onnx', model)
# Export the encoder model
def convert_encoder_part(model, src, save_path):
encoder_outputs, hidden = model.transformer.encoder(src)
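# the first axis of src (the CNN output sequence length) and of encoder_outputs are exported as dynamic axes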
torch.onnx.export(model.transformer.encoder, src, save_path, export_params=True, opset_version=11, do_constant_folding=True, input_names=['src'], output_names=['encoder_outputs', 'hidden'], dynamic_axes={'src':{0: "channel_input"}, 'encoder_outputs': {0: 'channel_output'}})
return hidden, encoder_outputs
hidden, encoder_outputs = convert_encoder_part(model, src, '/src/weights/encoder.onnx')
# Export the decoder model
def convert_decoder_part(model, tgt, hidden, encoder_outputs, save_path):
tgt = tgt[-1]
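# the decoder is traced for a single decoding step: the last target token, the current hidden state and the encoder outputs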
torch.onnx.export(model.transformer.decoder,
(tgt, hidden, encoder_outputs),
save_path,
export_params=True,
opset_version=11,
do_constant_folding=True,
input_names=['tgt', 'hidden', 'encoder_outputs'],
output_names=['output', 'hidden_out', 'last'],
dynamic_axes={'encoder_outputs':{0: "channel_input"},
'last': {0: 'channel_output'}})
device = img.device
tgt = torch.LongTensor([[1] * len(img)]).to(device)
convert_decoder_part(model, tgt, hidden, encoder_outputs, '/src/weights/decoder.onnx')
# Check the models after conversion
import onnx
# load model from onnx
cnn = onnx.load('/src/weights/cnn.onnx')
encoder = onnx.load('/src/weights/encoder.onnx')
decoder = onnx.load('/src/weights/decoder.onnx')
# confirm model has valid schema
onnx.checker.check_model(cnn)
onnx.checker.check_model(decoder)
onnx.checker.check_model(encoder)
# Print a human-readable representation of the decoder graph
onnx.helper.printable_graph(decoder.graph)
\ No newline at end of file
# Based on pytorch/pytorch:2.3.1-cuda11.8-cudnn8-runtime
vietocr
albumentations
matplotlib
onnxruntime
\ No newline at end of file
pretrain:
id_or_url: 1nTKlEog9YFK74kPyX0qLwCWi60_YHHk4
md5: efcabaa6d3adfca8e52bda2fd7d2ee04
cached: /tmp/tranformerorc.pth
device: cuda:0
# url or local path
weights: https://drive.google.com/uc?id=1nTKlEog9YFK74kPyX0qLwCWi60_YHHk4
backbone: vgg19_bn
cnn:
# pooling stride size
ss:
- [2, 2]
- [2, 2]
- [2, 1]
- [2, 1]
- [1, 1]
# pooling kernel size
ks:
- [2, 2]
- [2, 2]
- [2, 1]
- [2, 1]
- [1, 1]
# dim of output feature map
hidden: 256
seq_modeling: seq2seq
transformer:
encoder_hidden: 256
decoder_hidden: 256
img_channel: 256
decoder_embedded: 256
dropout: 0.1
optimizer:
max_lr: 0.001
pct_start: 0.1
\ No newline at end of file
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
.pybuilder/
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# pytype static type analyzer
.pytype/
# Cython debug symbols
cython_debug/
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
# VietOCR
**Please update to the latest version to avoid errors.**
<p align="center">
<img src="https://github.com/pbcquoc/vietocr/raw/master/image/sample.png" width="1000" height="300">
</p>
In this project, I implement a Transformer OCR model that recognizes Vietnamese handwritten and printed text. The architecture combines a CNN with a Transformer (the model underlying the well-known BERT). TransformerOCR has many advantages over the CRNN architecture I implemented previously. You can read about the architecture and how to train the model on different datasets [here](https://pbcquoc.github.io/vietocr).
The VietOCR model generalizes very well; it even achieves fairly high accuracy on a new dataset it has never been trained on.
<p align="center">
<img src="https://raw.githubusercontent.com/pbcquoc/vietocr/master/image/vietocr.jpg" width="512" height="614">
</p>
# Installation
To install, run the following command:
```
pip install vietocr
```
# Quick Start
See [this](https://github.com/pbcquoc/vietocr/blob/master/vietocr_gettingstart.ipynb) notebook for how to use the library.
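For reference, a minimal sketch of the pip-package workflow (the model name and image path below are placeholders; the linked notebook remains the authoritative guide):
```
from PIL import Image
from vietocr.tool.predictor import Predictor
from vietocr.tool.config import Cfg

config = Cfg.load_config_from_name('vgg_transformer')  # or 'vgg_seq2seq'
config['device'] = 'cpu'

detector = Predictor(config)
img = Image.open('./image/sample.png')  # placeholder path
print(detector.predict(img))
```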
# How to create the train/test file
A train/test file has two columns: the first is the file name and the second is the label (which must not contain the \t character); the two columns are separated by \t.
```
20160518_0151_25432_1_tg_3_5.png để nghe phổ biến chủ trương của UBND tỉnh Phú Yên
20160421_0102_25464_2_tg_0_4.png môi trường lại đều đồng thanh
```
See a sample file [here](https://vocr.vn/data/vietocr/data_line.zip).
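For illustration, a small sketch (file names below are placeholders) of writing and reading an annotation file in this format:
```
# write one "<file name>\t<label>" line per sample
samples = [("img/0001.png", "để nghe phổ biến chủ trương của UBND tỉnh Phú Yên"),
           ("img/0002.png", "môi trường lại đều đồng thanh")]
with open("annotation_train.txt", "w", encoding="utf-8") as f:
    for name, label in samples:
        f.write(f"{name}\t{label}\n")

# read it back, splitting on the first tab only
with open("annotation_train.txt", encoding="utf-8") as f:
    pairs = [line.rstrip("\n").split("\t", 1) for line in f]
```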
# Model Zoo
This library implements both kinds of sequence model: attention-based seq2seq and Transformer. Seq2seq predicts very fast and is widely used in industry, while the Transformer is more accurate but noticeably slower at prediction time, so both are provided for you to choose from.
The models were trained on a dataset of 10M images covering many image types: synthetically generated images, handwriting, and real scanned documents.
A pretrained model is provided.
# Benchmark results on the 10M dataset
| Backbone | Config | Precision full sequence | time |
| ------------- |:-------------:| ---:|---:|
| VGG19-bn - Transformer | vgg_transformer | 0.8800 | 86ms @ 1080ti |
| VGG19-bn - Seq2Seq | vgg_seq2seq | 0.8701 | 12ms @ 1080ti |
The prediction time of the vgg-transformer model is much longer than that of the seq2seq model, while there is no clear difference in accuracy between the two architectures.
# Dataset
Only a sample dataset of about 1M synthetic images is provided; you can download it [here](https://drive.google.com/file/d/1T0cmkhTgu3ahyMIwGZeby612RpVdDxOR/view).
# License
This library is released under the terms of the [Apache 2.0 license]().
# Contact
If you have any problem, please open an issue or contact me at pbcquoc@gmail.com
# change to the character list of your dataset, or use the default Vietnamese characters
vocab: 'aAàÀảẢãÃáÁạẠăĂằẰẳẲẵẴắẮặẶâÂầẦẩẨẫẪấẤậẬbBcCdDđĐeEèÈẻẺẽẼéÉẹẸêÊềỀểỂễỄếẾệỆfFgGhHiIìÌỉỈĩĨíÍịỊjJkKlLmMnNoOòÒỏỎõÕóÓọỌôÔồỒổỔỗỖốỐộỘơƠờỜởỞỡỠớỚợỢpPqQrRsStTuUùÙủỦũŨúÚụỤưƯừỪửỬữỮứỨựỰvVwWxXyYỳỲỷỶỹỸýÝỵỴzZ0123456789!"#$%&''()*+,-./:;<=>?@[\]^_`{|}~ '
# cpu, cuda, cuda:0
device: cuda:0
seq_modeling: transformer
transformer:
d_model: 256
nhead: 8
num_encoder_layers: 6
num_decoder_layers: 6
dim_feedforward: 2048
max_seq_length: 1024
pos_dropout: 0.1
trans_dropout: 0.1
optimizer:
max_lr: 0.0003
pct_start: 0.1
trainer:
batch_size: 32
print_every: 200
valid_every: 4000
iters: 100000
# where to save our model for prediction
export: ./weights/transformerocr.pth
checkpoint: ./checkpoint/transformerocr_checkpoint.pth
log: ./train.log
# null to disable accuracy computation, or set to a number of samples to enable validation during training
metrics: null
dataset:
# name of your dataset
name: data
# path to annotation and image
data_root: ./img/
train_annotation: annotation_train.txt
valid_annotation: annotation_val_small.txt
# resize images to a height of 32; a larger height can increase accuracy
image_height: 32
image_min_width: 32
image_max_width: 512
dataloader:
num_workers: 3
pin_memory: True
aug:
image_aug: true
masked_language_model: true
predictor:
# enable or disable beam search during prediction; beam search is slower
beamsearch: False
quiet: False
pretrain:
id_or_url: 13327Y1tz1ohsm5YZMyXVMPIOjoOA0OaA
md5: 7068030afe2e8fc639d0e1e2c25612b3
cached: /tmp/tranformerorc.pth
weights: https://drive.google.com/uc?id=12dTOZ9VP7ZVzwQgVvqBWz5JO5RXXW5NY
backbone: resnet50
cnn:
ss:
- [2, 2]
- [2, 1]
- [2, 1]
- [2, 1]
- [1, 1]
hidden: 256
pretrain:
id_or_url: 13327Y1tz1ohsm5YZMyXVMPIOjoOA0OaA
md5: 7068030afe2e8fc639d0e1e2c25612b3
cached: /tmp/tranformerorc.pth
weights: https://drive.google.com/uc?id=12dTOZ9VP7ZVzwQgVvqBWz5JO5RXXW5NY
backbone: resnet50_fpn
cnn: {}
pretrain:
id_or_url: 13327Y1tz1ohsm5YZMyXVMPIOjoOA0OaA
md5: fbefa85079ad9001a71eb1bf47a93785
cached: /tmp/tranformerorc.pth
# url or local path
weights: https://drive.google.com/uc?id=13327Y1tz1ohsm5YZMyXVMPIOjoOA0OaA
backbone: vgg19_bn
cnn:
# pooling stride size
ss:
- [2, 2]
- [2, 2]
- [2, 1]
- [2, 1]
- [1, 1]
# pooling kernel size
ks:
- [2, 2]
- [2, 2]
- [2, 1]
- [2, 1]
- [1, 1]
# dim of output feature map
hidden: 256
seq_modeling: convseq2seq
transformer:
emb_dim: 256
hid_dim: 512
enc_layers: 10
dec_layers: 10
enc_kernel_size: 3
dec_kernel_size: 3
dropout: 0.1
pad_idx: 0
device: cuda:0
enc_max_length: 512
dec_max_length: 512
# for train
pretrain: https://vocr.vn/data/vietocr/vgg_seq2seq.pth
# url or local path (for predict)
weights: https://vocr.vn/data/vietocr/vgg_seq2seq.pth
backbone: vgg19_bn
cnn:
# pooling stride size
ss:
- [2, 2]
- [2, 2]
- [2, 1]
- [2, 1]
- [1, 1]
# pooling kernel size
ks:
- [2, 2]
- [2, 2]
- [2, 1]
- [2, 1]
- [1, 1]
# dim of output feature map
hidden: 256
seq_modeling: seq2seq
transformer:
encoder_hidden: 256
decoder_hidden: 256
img_channel: 256
decoder_embedded: 256
dropout: 0.1
optimizer:
max_lr: 0.001
pct_start: 0.1
# for training
pretrain: https://vocr.vn/data/vietocr/vgg_transformer.pth
# url or local path (predict)
weights: https://vocr.vn/data/vietocr/vgg_transformer.pth
backbone: vgg19_bn
cnn:
pretrained: True
# pooling stride size
ss:
- [2, 2]
- [2, 2]
- [2, 1]
- [2, 1]
- [1, 1]
# pooling kernel size
ks:
- [2, 2]
- [2, 2]
- [2, 1]
- [2, 1]
- [1, 1]
# dim of output feature map
hidden: 256
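For reference, here is a minimal sketch (module path taken from the imports used elsewhere in this repo; the input tensor is illustrative) of how the `cnn` section above maps onto the VGG backbone: the `ss`/`ks` lists replace the max-pooling layers and `hidden` sets the output feature dimension.
```
# Minimal sketch; vietocr.model.backbone.vgg is the path imported elsewhere in this repo.
import torch
from vietocr.model.backbone.vgg import vgg19_bn

ss = [[2, 2], [2, 2], [2, 1], [2, 1], [1, 1]]  # pooling strides from the config above
ks = [[2, 2], [2, 2], [2, 1], [2, 1], [1, 1]]  # pooling kernel sizes from the config above

cnn = vgg19_bn(ss, ks, hidden=256, pretrained=False)

x = torch.randn(1, 3, 32, 128)  # (N, C, H, W); height 32 as in the dataset config
features = cnn(x)               # per Vgg.forward: output has shape (seq_len, N, hidden)
print(features.shape)
```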
import setuptools
with open("README.md", "r") as fh:
long_description = fh.read()
setuptools.setup(
name="vietocr",
version="0.3.13",
author="pbcquoc",
author_email="pbcquoc@gmail.com",
description="Transformer base text detection",
long_description=long_description,
long_description_content_type="text/markdown",
url="https://github.com/pbcquoc/vietocr",
packages=setuptools.find_packages(),
install_requires=[
'einops==0.2.0',
'gdown==4.4.0',
'prefetch_generator==1.0.1',
'imgaug==0.4.0',
'albumentations==1.4.2',
'lmdb>=1.0.0',
'scikit-image>=0.21.0',
'pillow==10.2.0'
],
classifiers=[
"Programming Language :: Python :: 3",
"License :: OSI Approved :: MIT License",
"Operating System :: OS Independent",
],
python_requires='>=3.6',
)
from PIL import Image
import numpy as np
from imgaug import augmenters as iaa
import imgaug as ia
import albumentations as A
class ImgAugTransform:
def __init__(self):
sometimes = lambda aug: iaa.Sometimes(0.3, aug)
self.aug = iaa.Sequential(iaa.SomeOf((1, 5),
[
# blur
sometimes(iaa.OneOf([iaa.GaussianBlur(sigma=(0, 1.0)),
iaa.MotionBlur(k=3)])),
# color
sometimes(iaa.AddToHueAndSaturation(value=(-10, 10), per_channel=True)),
sometimes(iaa.SigmoidContrast(gain=(3, 10), cutoff=(0.4, 0.6), per_channel=True)),
sometimes(iaa.Invert(0.25, per_channel=0.5)),
sometimes(iaa.Solarize(0.5, threshold=(32, 128))),
sometimes(iaa.Dropout2d(p=0.5)),
sometimes(iaa.Multiply((0.5, 1.5), per_channel=0.5)),
sometimes(iaa.Add((-40, 40), per_channel=0.5)),
sometimes(iaa.JpegCompression(compression=(5, 80))),
# distort
sometimes(iaa.Crop(percent=(0.01, 0.05), sample_independently=True)),
sometimes(iaa.PerspectiveTransform(scale=(0.01, 0.01))),
sometimes(iaa.Affine(scale=(0.7, 1.3), translate_percent=(-0.1, 0.1),
# rotate=(-5, 5), shear=(-5, 5),
order=[0, 1], cval=(0, 255),
mode=ia.ALL)),
sometimes(iaa.PiecewiseAffine(scale=(0.01, 0.01))),
sometimes(iaa.OneOf([iaa.Dropout(p=(0, 0.1)),
iaa.CoarseDropout(p=(0, 0.1), size_percent=(0.02, 0.25))])),
],
random_order=True),
random_order=True)
def __call__(self, img):
img = np.array(img)
img = self.aug.augment_image(img)
img = Image.fromarray(img)
return img
class ImgAugTransformV2:
def __init__(self):
self.aug = A.Compose([
A.InvertImg(p=0.2),
A.ColorJitter(p=0.2),
A.MotionBlur(blur_limit=3, p=0.2),
A.RandomBrightnessContrast(p=0.2),
A.Perspective(scale=(0.01, 0.05))
])
def __call__(self, img):
img = np.array(img)
transformed = self.aug(image=img)
img = transformed["image"]
img = Image.fromarray(img)
return img
import sys
import os
import random
from PIL import Image
from PIL import ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True
from collections import defaultdict
import numpy as np
import torch
import lmdb
import six
import time
from tqdm import tqdm
from torch.utils.data import Dataset
from torch.utils.data.sampler import Sampler
from vietocr.tool.translate import process_image
from vietocr.tool.create_dataset import createDataset
from vietocr.tool.translate import resize
class OCRDataset(Dataset):
def __init__(self, lmdb_path, root_dir, annotation_path, vocab, image_height=32, image_min_width=32, image_max_width=512, transform=None):
self.root_dir = root_dir
self.annotation_path = os.path.join(root_dir, annotation_path)
self.vocab = vocab
self.transform = transform
self.image_height = image_height
self.image_min_width = image_min_width
self.image_max_width = image_max_width
self.lmdb_path = lmdb_path
if os.path.isdir(self.lmdb_path):
print('{} exists. Remove the folder if you want to create a new dataset'.format(self.lmdb_path))
sys.stdout.flush()
else:
createDataset(self.lmdb_path, root_dir, annotation_path)
self.env = lmdb.open(
self.lmdb_path,
max_readers=8,
readonly=True,
lock=False,
readahead=False,
meminit=False)
self.txn = self.env.begin(write=False)
nSamples = int(self.txn.get('num-samples'.encode()))
self.nSamples = nSamples
self.build_cluster_indices()
def build_cluster_indices(self):
self.cluster_indices = defaultdict(list)
pbar = tqdm(range(self.__len__()),
desc='{} build cluster'.format(self.lmdb_path),
ncols = 100, position=0, leave=True)
for i in pbar:
bucket = self.get_bucket(i)
self.cluster_indices[bucket].append(i)
def get_bucket(self, idx):
key = 'dim-%09d'%idx
dim_img = self.txn.get(key.encode())
dim_img = np.frombuffer(dim_img, dtype=np.int32)
imgH, imgW = dim_img
new_w, image_height = resize(imgW, imgH, self.image_height, self.image_min_width, self.image_max_width)
return new_w
def read_buffer(self, idx):
img_file = 'image-%09d'%idx
label_file = 'label-%09d'%idx
path_file = 'path-%09d'%idx
imgbuf = self.txn.get(img_file.encode())
label = self.txn.get(label_file.encode()).decode()
img_path = self.txn.get(path_file.encode()).decode()
buf = six.BytesIO()
buf.write(imgbuf)
buf.seek(0)
return buf, label, img_path
def read_data(self, idx):
buf, label, img_path = self.read_buffer(idx)
img = Image.open(buf).convert('RGB')
if self.transform:
img = self.transform(img)
img_bw = process_image(img, self.image_height, self.image_min_width, self.image_max_width)
word = self.vocab.encode(label)
return img_bw, word, img_path
def __getitem__(self, idx):
img, word, img_path = self.read_data(idx)
img_path = os.path.join(self.root_dir, img_path)
sample = {'img': img, 'word': word, 'img_path': img_path}
return sample
def __len__(self):
return self.nSamples
class ClusterRandomSampler(Sampler):
def __init__(self, data_source, batch_size, shuffle=True):
self.data_source = data_source
self.batch_size = batch_size
self.shuffle = shuffle
def flatten_list(self, lst):
return [item for sublist in lst for item in sublist]
def __iter__(self):
batch_lists = []
for cluster, cluster_indices in self.data_source.cluster_indices.items():
if self.shuffle:
random.shuffle(cluster_indices)
batches = [cluster_indices[i:i + self.batch_size] for i in range(0, len(cluster_indices), self.batch_size)]
batches = [_ for _ in batches if len(_) == self.batch_size]
if self.shuffle:
random.shuffle(batches)
batch_lists.append(batches)
lst = self.flatten_list(batch_lists)
if self.shuffle:
random.shuffle(lst)
lst = self.flatten_list(lst)
return iter(lst)
def __len__(self):
return len(self.data_source)
class Collator(object):
def __init__(self, masked_language_model=True):
self.masked_language_model = masked_language_model
def __call__(self, batch):
filenames = []
img = []
target_weights = []
tgt_input = []
max_label_len = max(len(sample['word']) for sample in batch)
for sample in batch:
img.append(sample['img'])
filenames.append(sample['img_path'])
label = sample['word']
label_len = len(label)
tgt = np.concatenate((
label,
np.zeros(max_label_len - label_len, dtype=np.int32)))
tgt_input.append(tgt)
one_mask_len = label_len - 1
target_weights.append(np.concatenate((
np.ones(one_mask_len, dtype=np.float32),
np.zeros(max_label_len - one_mask_len,dtype=np.float32))))
img = np.array(img, dtype=np.float32)
tgt_input = np.array(tgt_input, dtype=np.int64).T
tgt_output = np.roll(tgt_input, -1, 0).T
tgt_output[:, -1]=0
# random mask token
if self.masked_language_model:
mask = np.random.random(size=tgt_input.shape) < 0.05
mask = mask & (tgt_input != 0) & (tgt_input != 1) & (tgt_input != 2)
tgt_input[mask] = 3
tgt_padding_mask = np.array(target_weights)==0
rs = {
'img': torch.FloatTensor(img),
'tgt_input': torch.LongTensor(tgt_input),
'tgt_output': torch.LongTensor(tgt_output),
'tgt_padding_mask': torch.BoolTensor(tgt_padding_mask),
'filenames': filenames
}
return rs
import torch
import numpy as np
from PIL import Image
import random
from vietocr.model.vocab import Vocab
from vietocr.tool.translate import process_image
import os
from collections import defaultdict
import math
from prefetch_generator import background
class BucketData(object):
def __init__(self, device):
self.max_label_len = 0
self.data_list = []
self.label_list = []
self.file_list = []
self.device = device
def append(self, datum, label, filename):
self.data_list.append(datum)
self.label_list.append(label)
self.file_list.append(filename)
self.max_label_len = max(len(label), self.max_label_len)
return len(self.data_list)
def flush_out(self):
"""
Shape:
- img: (N, C, H, W)
- tgt_input: (T, N)
- tgt_output: (N, T)
- tgt_padding_mask: (N, T)
"""
# encoder part
img = np.array(self.data_list, dtype=np.float32)
# decoder part
target_weights = []
tgt_input = []
for label in self.label_list:
label_len = len(label)
tgt = np.concatenate((
label,
np.zeros(self.max_label_len - label_len, dtype=np.int32)))
tgt_input.append(tgt)
one_mask_len = label_len - 1
target_weights.append(np.concatenate((
np.ones(one_mask_len, dtype=np.float32),
np.zeros(self.max_label_len - one_mask_len,dtype=np.float32))))
# reshape to fit input shape
tgt_input = np.array(tgt_input, dtype=np.int64).T
tgt_output = np.roll(tgt_input, -1, 0).T
tgt_output[:, -1]=0
tgt_padding_mask = np.array(target_weights)==0
filenames = self.file_list
self.data_list, self.label_list, self.file_list = [], [], []
self.max_label_len = 0
rs = {
'img': torch.FloatTensor(img).to(self.device),
'tgt_input': torch.LongTensor(tgt_input).to(self.device),
'tgt_output': torch.LongTensor(tgt_output).to(self.device),
'tgt_padding_mask':torch.BoolTensor(tgt_padding_mask).to(self.device),
'filenames': filenames
}
return rs
def __len__(self):
return len(self.data_list)
def __iadd__(self, other):
self.data_list += other.data_list
self.label_list += other.label_list
self.file_list += other.file_list
self.max_label_len = max(self.max_label_len, other.max_label_len)
return self
def __add__(self, other):
res = BucketData(self.device)
res.data_list = self.data_list + other.data_list
res.label_list = self.label_list + other.label_list
res.file_list = self.file_list + other.file_list
res.max_label_len = max(self.max_label_len, other.max_label_len)
return res
class DataGen(object):
def __init__(self,data_root, annotation_fn, vocab, device, image_height=32, image_min_width=32, image_max_width=512):
self.image_height = image_height
self.image_min_width = image_min_width
self.image_max_width = image_max_width
self.data_root = data_root
self.annotation_path = os.path.join(data_root, annotation_fn)
self.vocab = vocab
self.device = device
self.clear()
def clear(self):
self.bucket_data = defaultdict(lambda: BucketData(self.device))
@background(max_prefetch=1)
def gen(self, batch_size, last_batch=True):
with open(self.annotation_path, 'r') as ann_file:
lines = ann_file.readlines()
np.random.shuffle(lines)
for l in lines:
img_path, lex = l.strip().split('\t')
img_path = os.path.join(self.data_root, img_path)
try:
img_bw, word = self.read_data(img_path, lex)
except IOError:
print('could not read image: {}'.format(img_path))
continue
width = img_bw.shape[-1]
bs = self.bucket_data[width].append(img_bw, word, img_path)
if bs >= batch_size:
b = self.bucket_data[width].flush_out()
yield b
if last_batch:
for bucket in self.bucket_data.values():
if len(bucket) > 0:
b = bucket.flush_out()
yield b
self.clear()
def read_data(self, img_path, lex):
with open(img_path, 'rb') as img_file:
img = Image.open(img_file).convert('RGB')
img_bw = process_image(img, self.image_height, self.image_min_width, self.image_max_width)
word = self.vocab.encode(lex)
return img_bw, word
import torch
from torch import nn
import vietocr.model.backbone.vgg as vgg
from vietocr.model.backbone.resnet import Resnet50
class CNN(nn.Module):
def __init__(self, backbone, **kwargs):
super(CNN, self).__init__()
if backbone == 'vgg11_bn':
self.model = vgg.vgg11_bn(**kwargs)
elif backbone == 'vgg19_bn':
self.model = vgg.vgg19_bn(**kwargs)
elif backbone == 'resnet50':
self.model = Resnet50(**kwargs)
def forward(self, x):
return self.model(x)
def freeze(self):
for name, param in self.model.features.named_parameters():
if name != 'last_conv_1x1':
param.requires_grad = False
def unfreeze(self):
for param in self.model.features.parameters():
param.requires_grad = True
import torch
from torch import nn
class BasicBlock(nn.Module):
expansion = 1
def __init__(self, inplanes, planes, stride=1, downsample=None):
super(BasicBlock, self).__init__()
self.conv1 = self._conv3x3(inplanes, planes)
self.bn1 = nn.BatchNorm2d(planes)
self.conv2 = self._conv3x3(planes, planes)
self.bn2 = nn.BatchNorm2d(planes)
self.relu = nn.ReLU(inplace=True)
self.downsample = downsample
self.stride = stride
def _conv3x3(self, in_planes, out_planes, stride=1):
"3x3 convolution with padding"
return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
padding=1, bias=False)
def forward(self, x):
residual = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
if self.downsample is not None:
residual = self.downsample(x)
out += residual
out = self.relu(out)
return out
class ResNet(nn.Module):
def __init__(self, input_channel, output_channel, block, layers):
super(ResNet, self).__init__()
self.output_channel_block = [int(output_channel / 4), int(output_channel / 2), output_channel, output_channel]
self.inplanes = int(output_channel / 8)
self.conv0_1 = nn.Conv2d(input_channel, int(output_channel / 16),
kernel_size=3, stride=1, padding=1, bias=False)
self.bn0_1 = nn.BatchNorm2d(int(output_channel / 16))
self.conv0_2 = nn.Conv2d(int(output_channel / 16), self.inplanes,
kernel_size=3, stride=1, padding=1, bias=False)
self.bn0_2 = nn.BatchNorm2d(self.inplanes)
self.relu = nn.ReLU(inplace=True)
self.maxpool1 = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
self.layer1 = self._make_layer(block, self.output_channel_block[0], layers[0])
self.conv1 = nn.Conv2d(self.output_channel_block[0], self.output_channel_block[
0], kernel_size=3, stride=1, padding=1, bias=False)
self.bn1 = nn.BatchNorm2d(self.output_channel_block[0])
self.maxpool2 = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
self.layer2 = self._make_layer(block, self.output_channel_block[1], layers[1], stride=1)
self.conv2 = nn.Conv2d(self.output_channel_block[1], self.output_channel_block[
1], kernel_size=3, stride=1, padding=1, bias=False)
self.bn2 = nn.BatchNorm2d(self.output_channel_block[1])
self.maxpool3 = nn.MaxPool2d(kernel_size=2, stride=(2, 1), padding=(0, 1))
self.layer3 = self._make_layer(block, self.output_channel_block[2], layers[2], stride=1)
self.conv3 = nn.Conv2d(self.output_channel_block[2], self.output_channel_block[
2], kernel_size=3, stride=1, padding=1, bias=False)
self.bn3 = nn.BatchNorm2d(self.output_channel_block[2])
self.layer4 = self._make_layer(block, self.output_channel_block[3], layers[3], stride=1)
self.conv4_1 = nn.Conv2d(self.output_channel_block[3], self.output_channel_block[
3], kernel_size=2, stride=(2, 1), padding=(0, 1), bias=False)
self.bn4_1 = nn.BatchNorm2d(self.output_channel_block[3])
self.conv4_2 = nn.Conv2d(self.output_channel_block[3], self.output_channel_block[
3], kernel_size=2, stride=1, padding=0, bias=False)
self.bn4_2 = nn.BatchNorm2d(self.output_channel_block[3])
def _make_layer(self, block, planes, blocks, stride=1):
downsample = None
if stride != 1 or self.inplanes != planes * block.expansion:
downsample = nn.Sequential(
nn.Conv2d(self.inplanes, planes * block.expansion,
kernel_size=1, stride=stride, bias=False),
nn.BatchNorm2d(planes * block.expansion),
)
layers = []
layers.append(block(self.inplanes, planes, stride, downsample))
self.inplanes = planes * block.expansion
for i in range(1, blocks):
layers.append(block(self.inplanes, planes))
return nn.Sequential(*layers)
def forward(self, x):
x = self.conv0_1(x)
x = self.bn0_1(x)
x = self.relu(x)
x = self.conv0_2(x)
x = self.bn0_2(x)
x = self.relu(x)
x = self.maxpool1(x)
x = self.layer1(x)
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
x = self.maxpool2(x)
x = self.layer2(x)
x = self.conv2(x)
x = self.bn2(x)
x = self.relu(x)
x = self.maxpool3(x)
x = self.layer3(x)
x = self.conv3(x)
x = self.bn3(x)
x = self.relu(x)
x = self.layer4(x)
x = self.conv4_1(x)
x = self.bn4_1(x)
x = self.relu(x)
x = self.conv4_2(x)
x = self.bn4_2(x)
conv = self.relu(x)
conv = conv.transpose(-1, -2)
conv = conv.flatten(2)
conv = conv.permute(-1, 0, 1)
return conv
def Resnet50(ss, hidden):
return ResNet(3, hidden, BasicBlock, [1, 2, 5, 3])
import torch
from torch import nn
from torchvision import models
from einops import rearrange
from torchvision.models._utils import IntermediateLayerGetter
class Vgg(nn.Module):
def __init__(self, name, ss, ks, hidden, pretrained=True, dropout=0.5):
super(Vgg, self).__init__()
if pretrained:
weights = 'DEFAULT'
else:
weights = None
if name == 'vgg11_bn':
cnn = models.vgg11_bn(weights=weights)
elif name == 'vgg19_bn':
cnn = models.vgg19_bn(weights=weights)
pool_idx = 0
for i, layer in enumerate(cnn.features):
if isinstance(layer, torch.nn.MaxPool2d):
cnn.features[i] = torch.nn.AvgPool2d(kernel_size=ks[pool_idx], stride=ss[pool_idx], padding=0)
pool_idx += 1
self.features = cnn.features
self.dropout = nn.Dropout(dropout)
self.last_conv_1x1 = nn.Conv2d(512, hidden, 1)
def forward(self, x):
"""
Shape:
- x: (N, C, H, W)
- output: (W, N, C)
"""
conv = self.features(x)
conv = self.dropout(conv)
conv = self.last_conv_1x1(conv)
# conv = rearrange(conv, 'b d h w -> b d (w h)')
conv = conv.transpose(-1, -2)
conv = conv.flatten(2)
conv = conv.permute(-1, 0, 1)
return conv
def vgg11_bn(ss, ks, hidden, pretrained=True, dropout=0.5):
return Vgg('vgg11_bn', ss, ks, hidden, pretrained, dropout)
def vgg19_bn(ss, ks, hidden, pretrained=True, dropout=0.5):
return Vgg('vgg19_bn', ss, ks, hidden, pretrained, dropout)
import torch
class Beam:
def __init__(self, beam_size=8, min_length=0, n_top=1, ranker=None,
start_token_id=1, end_token_id=2):
self.beam_size = beam_size
self.min_length = min_length
self.ranker = ranker
self.end_token_id = end_token_id
self.top_sentence_ended = False
self.prev_ks = []
self.next_ys = [torch.LongTensor(beam_size).fill_(start_token_id)] # remove padding
self.current_scores = torch.FloatTensor(beam_size).zero_()
self.all_scores = []
# Time and k pair for finished.
self.finished = []
self.n_top = n_top
self.ranker = ranker
def advance(self, next_log_probs):
# next_probs : beam_size X vocab_size
vocabulary_size = next_log_probs.size(1)
# current_beam_size = next_log_probs.size(0)
current_length = len(self.next_ys)
if current_length < self.min_length:
for beam_index in range(len(next_log_probs)):
next_log_probs[beam_index][self.end_token_id] = -1e10
if len(self.prev_ks) > 0:
beam_scores = next_log_probs + self.current_scores.unsqueeze(1).expand_as(next_log_probs)
# Don't let EOS have children.
last_y = self.next_ys[-1]
for beam_index in range(last_y.size(0)):
if last_y[beam_index] == self.end_token_id:
beam_scores[beam_index] = -1e10 # -1e20 raises error when executing
else:
beam_scores = next_log_probs[0]
flat_beam_scores = beam_scores.view(-1)
top_scores, top_score_ids = flat_beam_scores.topk(k=self.beam_size, dim=0, largest=True, sorted=True)
self.current_scores = top_scores
self.all_scores.append(self.current_scores)
prev_k = top_score_ids // vocabulary_size # (beam_size, )
next_y = top_score_ids - prev_k * vocabulary_size # (beam_size, )
self.prev_ks.append(prev_k)
self.next_ys.append(next_y)
for beam_index, last_token_id in enumerate(next_y):
if last_token_id == self.end_token_id:
# skip scoring
self.finished.append((self.current_scores[beam_index], len(self.next_ys) - 1, beam_index))
if next_y[0] == self.end_token_id:
self.top_sentence_ended = True
def get_current_state(self):
"Get the outputs for the current timestep."
return torch.stack(self.next_ys, dim=1)
def get_current_origin(self):
"Get the backpointers for the current timestep."
return self.prev_ks[-1]
def done(self):
return self.top_sentence_ended and len(self.finished) >= self.n_top
def get_hypothesis(self, timestep, k):
hypothesis = []
for j in range(len(self.prev_ks[:timestep]) - 1, -1, -1):
hypothesis.append(self.next_ys[j + 1][k])
# for RNN, [:, k, :], and for transformer, [k, :, :]
k = self.prev_ks[j][k]
return hypothesis[::-1]
def sort_finished(self, minimum=None):
if minimum is not None:
i = 0
# Add from beam until we have minimum outputs.
while len(self.finished) < minimum:
# global_scores = self.global_scorer.score(self, self.scores)
# s = global_scores[i]
s = self.current_scores[i]
self.finished.append((s, len(self.next_ys) - 1, i))
i += 1
self.finished = sorted(self.finished, key=lambda a: a[0], reverse=True)
scores = [sc for sc, _, _ in self.finished]
ks = [(t, k) for _, t, k in self.finished]
return scores, ks
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
class Encoder(nn.Module):
def __init__(self,
emb_dim,
hid_dim,
n_layers,
kernel_size,
dropout,
device,
max_length = 512):
super().__init__()
assert kernel_size % 2 == 1, "Kernel size must be odd!"
self.device = device
self.scale = torch.sqrt(torch.FloatTensor([0.5])).to(device)
# self.tok_embedding = nn.Embedding(input_dim, emb_dim)
self.pos_embedding = nn.Embedding(max_length, emb_dim)
self.emb2hid = nn.Linear(emb_dim, hid_dim)
self.hid2emb = nn.Linear(hid_dim, emb_dim)
self.convs = nn.ModuleList([nn.Conv1d(in_channels = hid_dim,
out_channels = 2 * hid_dim,
kernel_size = kernel_size,
padding = (kernel_size - 1) // 2)
for _ in range(n_layers)])
self.dropout = nn.Dropout(dropout)
def forward(self, src):
#src = [batch size, src len]
src = src.transpose(0, 1)
batch_size = src.shape[0]
src_len = src.shape[1]
device = src.device
#create position tensor
pos = torch.arange(0, src_len).unsqueeze(0).repeat(batch_size, 1).to(device)
#pos = [0, 1, 2, 3, ..., src len - 1]
#pos = [batch size, src len]
#embed tokens and positions
# tok_embedded = self.tok_embedding(src)
tok_embedded = src
pos_embedded = self.pos_embedding(pos)
#tok_embedded = pos_embedded = [batch size, src len, emb dim]
#combine embeddings by elementwise summing
embedded = self.dropout(tok_embedded + pos_embedded)
#embedded = [batch size, src len, emb dim]
#pass embedded through linear layer to convert from emb dim to hid dim
conv_input = self.emb2hid(embedded)
#conv_input = [batch size, src len, hid dim]
#permute for convolutional layer
conv_input = conv_input.permute(0, 2, 1)
#conv_input = [batch size, hid dim, src len]
#begin convolutional blocks...
for i, conv in enumerate(self.convs):
#pass through convolutional layer
conved = conv(self.dropout(conv_input))
#conved = [batch size, 2 * hid dim, src len]
#pass through GLU activation function
conved = F.glu(conved, dim = 1)
#conved = [batch size, hid dim, src len]
#apply residual connection
conved = (conved + conv_input) * self.scale
#conved = [batch size, hid dim, src len]
#set conv_input to conved for next loop iteration
conv_input = conved
#...end convolutional blocks
#permute and convert back to emb dim
conved = self.hid2emb(conved.permute(0, 2, 1))
#conved = [batch size, src len, emb dim]
#elementwise sum output (conved) and input (embedded) to be used for attention
combined = (conved + embedded) * self.scale
#combined = [batch size, src len, emb dim]
return conved, combined
class Decoder(nn.Module):
def __init__(self,
output_dim,
emb_dim,
hid_dim,
n_layers,
kernel_size,
dropout,
trg_pad_idx,
device,
max_length = 512):
super().__init__()
self.kernel_size = kernel_size
self.trg_pad_idx = trg_pad_idx
self.device = device
self.scale = torch.sqrt(torch.FloatTensor([0.5])).to(device)
self.tok_embedding = nn.Embedding(output_dim, emb_dim)
self.pos_embedding = nn.Embedding(max_length, emb_dim)
self.emb2hid = nn.Linear(emb_dim, hid_dim)
self.hid2emb = nn.Linear(hid_dim, emb_dim)
self.attn_hid2emb = nn.Linear(hid_dim, emb_dim)
self.attn_emb2hid = nn.Linear(emb_dim, hid_dim)
self.fc_out = nn.Linear(emb_dim, output_dim)
self.convs = nn.ModuleList([nn.Conv1d(in_channels = hid_dim,
out_channels = 2 * hid_dim,
kernel_size = kernel_size)
for _ in range(n_layers)])
self.dropout = nn.Dropout(dropout)
def calculate_attention(self, embedded, conved, encoder_conved, encoder_combined):
#embedded = [batch size, trg len, emb dim]
#conved = [batch size, hid dim, trg len]
#encoder_conved = encoder_combined = [batch size, src len, emb dim]
#permute and convert back to emb dim
conved_emb = self.attn_hid2emb(conved.permute(0, 2, 1))
#conved_emb = [batch size, trg len, emb dim]
combined = (conved_emb + embedded) * self.scale
#combined = [batch size, trg len, emb dim]
energy = torch.matmul(combined, encoder_conved.permute(0, 2, 1))
#energy = [batch size, trg len, src len]
attention = F.softmax(energy, dim=2)
#attention = [batch size, trg len, src len]
attended_encoding = torch.matmul(attention, encoder_combined)
#attended_encoding = [batch size, trg len, emd dim]
#convert from emb dim -> hid dim
attended_encoding = self.attn_emb2hid(attended_encoding)
#attended_encoding = [batch size, trg len, hid dim]
#apply residual connection
attended_combined = (conved + attended_encoding.permute(0, 2, 1)) * self.scale
#attended_combined = [batch size, hid dim, trg len]
return attention, attended_combined
def forward(self, trg, encoder_conved, encoder_combined):
#trg = [batch size, trg len]
#encoder_conved = encoder_combined = [batch size, src len, emb dim]
trg = trg.transpose(0, 1)
batch_size = trg.shape[0]
trg_len = trg.shape[1]
device = trg.device
#create position tensor
pos = torch.arange(0, trg_len).unsqueeze(0).repeat(batch_size, 1).to(device)
#pos = [batch size, trg len]
#embed tokens and positions
tok_embedded = self.tok_embedding(trg)
pos_embedded = self.pos_embedding(pos)
#tok_embedded = [batch size, trg len, emb dim]
#pos_embedded = [batch size, trg len, emb dim]
#combine embeddings by elementwise summing
embedded = self.dropout(tok_embedded + pos_embedded)
#embedded = [batch size, trg len, emb dim]
#pass embedded through linear layer to go through emb dim -> hid dim
conv_input = self.emb2hid(embedded)
#conv_input = [batch size, trg len, hid dim]
#permute for convolutional layer
conv_input = conv_input.permute(0, 2, 1)
#conv_input = [batch size, hid dim, trg len]
batch_size = conv_input.shape[0]
hid_dim = conv_input.shape[1]
for i, conv in enumerate(self.convs):
#apply dropout
conv_input = self.dropout(conv_input)
#need to pad so decoder can't "cheat"
padding = torch.zeros(batch_size,
hid_dim,
self.kernel_size - 1).fill_(self.trg_pad_idx).to(device)
padded_conv_input = torch.cat((padding, conv_input), dim = 2)
#padded_conv_input = [batch size, hid dim, trg len + kernel size - 1]
#pass through convolutional layer
conved = conv(padded_conv_input)
#conved = [batch size, 2 * hid dim, trg len]
#pass through GLU activation function
conved = F.glu(conved, dim = 1)
#conved = [batch size, hid dim, trg len]
#calculate attention
attention, conved = self.calculate_attention(embedded,
conved,
encoder_conved,
encoder_combined)
#attention = [batch size, trg len, src len]
#apply residual connection
conved = (conved + conv_input) * self.scale
#conved = [batch size, hid dim, trg len]
#set conv_input to conved for next loop iteration
conv_input = conved
conved = self.hid2emb(conved.permute(0, 2, 1))
#conved = [batch size, trg len, emb dim]
output = self.fc_out(self.dropout(conved))
#output = [batch size, trg len, output dim]
return output, attention
class ConvSeq2Seq(nn.Module):
def __init__(self, vocab_size, emb_dim, hid_dim, enc_layers, dec_layers, enc_kernel_size, dec_kernel_size, enc_max_length, dec_max_length, dropout, pad_idx, device):
super().__init__()
enc = Encoder(emb_dim, hid_dim, enc_layers, enc_kernel_size, dropout, device, enc_max_length)
dec = Decoder(vocab_size, emb_dim, hid_dim, dec_layers, dec_kernel_size, dropout, pad_idx, device, dec_max_length)
self.encoder = enc
self.decoder = dec
def forward_encoder(self, src):
encoder_conved, encoder_combined = self.encoder(src)
return encoder_conved, encoder_combined
def forward_decoder(self, trg, memory):
encoder_conved, encoder_combined = memory
output, attention = self.decoder(trg, encoder_conved, encoder_combined)
return output, (encoder_conved, encoder_combined)
def forward(self, src, trg):
#src = [batch size, src len]
#trg = [batch size, trg len - 1] (<eos> token sliced off the end)
#calculate z^u (encoder_conved) and (z^u + e) (encoder_combined)
#encoder_conved is output from final encoder conv. block
#encoder_combined is encoder_conved plus (elementwise) src embedding plus
# positional embeddings
encoder_conved, encoder_combined = self.encoder(src)
#encoder_conved = [batch size, src len, emb dim]
#encoder_combined = [batch size, src len, emb dim]
#calculate predictions of next words
#output is a batch of predictions for each word in the trg sentence
#attention a batch of attention scores across the src sentence for
# each word in the trg sentence
output, attention = self.decoder(trg, encoder_conved, encoder_combined)
#output = [batch size, trg len - 1, output dim]
#attention = [batch size, trg len - 1, src len]
return output#, attention
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
class Encoder(nn.Module):
def __init__(self, emb_dim, enc_hid_dim, dec_hid_dim, dropout):
super().__init__()
self.rnn = nn.GRU(emb_dim, enc_hid_dim, bidirectional = True)
self.fc = nn.Linear(enc_hid_dim * 2, dec_hid_dim)
self.dropout = nn.Dropout(dropout)
def forward(self, src):
"""
src: src_len x batch_size x img_channel
outputs: src_len x batch_size x hid_dim
hidden: batch_size x hid_dim
"""
embedded = self.dropout(src)
outputs, hidden = self.rnn(embedded)
hidden = torch.tanh(self.fc(torch.cat((hidden[-2,:,:], hidden[-1,:,:]), dim = 1)))
return outputs, hidden
class Attention(nn.Module):
def __init__(self, enc_hid_dim, dec_hid_dim):
super().__init__()
self.attn = nn.Linear((enc_hid_dim * 2) + dec_hid_dim, dec_hid_dim)
self.v = nn.Linear(dec_hid_dim, 1, bias = False)
def forward(self, hidden, encoder_outputs):
"""
hidden: batch_size x hid_dim
encoder_outputs: src_len x batch_size x hid_dim,
outputs: batch_size x src_len
"""
batch_size = encoder_outputs.shape[1]
src_len = encoder_outputs.shape[0]
hidden = hidden.unsqueeze(1).repeat(1, src_len, 1)
encoder_outputs = encoder_outputs.permute(1, 0, 2)
energy = torch.tanh(self.attn(torch.cat((hidden, encoder_outputs), dim = 2)))
attention = self.v(energy).squeeze(2)
return F.softmax(attention, dim = 1)
class Decoder(nn.Module):
def __init__(self, output_dim, emb_dim, enc_hid_dim, dec_hid_dim, dropout, attention):
super().__init__()
self.output_dim = output_dim
self.attention = attention
self.embedding = nn.Embedding(output_dim, emb_dim)
self.rnn = nn.GRU((enc_hid_dim * 2) + emb_dim, dec_hid_dim)
self.fc_out = nn.Linear((enc_hid_dim * 2) + dec_hid_dim + emb_dim, output_dim)
self.dropout = nn.Dropout(dropout)
def forward(self, input, hidden, encoder_outputs):
"""
inputs: batch_size
hidden: batch_size x hid_dim
encoder_outputs: src_len x batch_size x hid_dim
"""
input = input.unsqueeze(0)
embedded = self.dropout(self.embedding(input))
a = self.attention(hidden, encoder_outputs)
a = a.unsqueeze(1)
encoder_outputs = encoder_outputs.permute(1, 0, 2)
weighted = torch.bmm(a, encoder_outputs)
weighted = weighted.permute(1, 0, 2)
rnn_input = torch.cat((embedded, weighted), dim = 2)
output, hidden = self.rnn(rnn_input, hidden.unsqueeze(0))
assert (output == hidden).all()
embedded = embedded.squeeze(0)
output = output.squeeze(0)
weighted = weighted.squeeze(0)
prediction = self.fc_out(torch.cat((output, weighted, embedded), dim = 1))
return prediction, hidden.squeeze(0), a.squeeze(1)
class Seq2Seq(nn.Module):
def __init__(self, vocab_size, encoder_hidden, decoder_hidden, img_channel, decoder_embedded, dropout=0.1):
super().__init__()
attn = Attention(encoder_hidden, decoder_hidden)
self.encoder = Encoder(img_channel, encoder_hidden, decoder_hidden, dropout)
self.decoder = Decoder(vocab_size, decoder_embedded, encoder_hidden, decoder_hidden, dropout, attn)
def forward_encoder(self, src):
"""
src: timestep x batch_size x channel
hidden: batch_size x hid_dim
encoder_outputs: src_len x batch_size x hid_dim
"""
encoder_outputs, hidden = self.encoder(src)
return (hidden, encoder_outputs)
def forward_decoder(self, tgt, memory):
"""
tgt: timestep x batch_size
hidden: batch_size x hid_dim
encoder_outputs: src_len x batch_size x hid_dim
output: batch_size x 1 x vocab_size
"""
tgt = tgt[-1]
hidden, encoder_outputs = memory
output, hidden, _ = self.decoder(tgt, hidden, encoder_outputs)
output = output.unsqueeze(1)
return output, (hidden, encoder_outputs)
def forward(self, src, trg):
"""
src: time_step x batch_size
trg: time_step x batch_size
outputs: batch_size x time_step x vocab_size
"""
batch_size = src.shape[1]
trg_len = trg.shape[0]
trg_vocab_size = self.decoder.output_dim
device = src.device
outputs = torch.zeros(trg_len, batch_size, trg_vocab_size).to(device)
encoder_outputs, hidden = self.encoder(src)
for t in range(trg_len):
input = trg[t]
output, hidden, _ = self.decoder(input, hidden, encoder_outputs)
outputs[t] = output
outputs = outputs.transpose(0, 1).contiguous()
return outputs
def expand_memory(self, memory, beam_size):
hidden, encoder_outputs = memory
hidden = hidden.repeat(beam_size, 1)
encoder_outputs = encoder_outputs.repeat(1, beam_size, 1)
return (hidden, encoder_outputs)
def get_memory(self, memory, i):
hidden, encoder_outputs = memory
hidden = hidden[[i]]
encoder_outputs = encoder_outputs[:, [i],:]
return (hidden, encoder_outputs)
from einops import rearrange
from torchvision import models
import math
import torch
from torch import nn
class LanguageTransformer(nn.Module):
def __init__(self, vocab_size,
d_model, nhead,
num_encoder_layers, num_decoder_layers,
dim_feedforward, max_seq_length,
pos_dropout, trans_dropout):
super().__init__()
self.d_model = d_model
self.embed_tgt = nn.Embedding(vocab_size, d_model)
self.pos_enc = PositionalEncoding(d_model, pos_dropout, max_seq_length)
# self.learned_pos_enc = LearnedPositionalEncoding(d_model, pos_dropout, max_seq_length)
self.transformer = nn.Transformer(d_model, nhead,
num_encoder_layers, num_decoder_layers,
dim_feedforward, trans_dropout)
self.fc = nn.Linear(d_model, vocab_size)
def forward(self, src, tgt, src_key_padding_mask=None, tgt_key_padding_mask=None, memory_key_padding_mask=None):
"""
Shape:
- src: (W, N, C)
- tgt: (T, N)
- src_key_padding_mask: (N, S)
- tgt_key_padding_mask: (N, T)
- memory_key_padding_mask: (N, S)
- output: (N, T, E)
"""
tgt_mask = self.gen_nopeek_mask(tgt.shape[0]).to(src.device)
src = self.pos_enc(src*math.sqrt(self.d_model))
# src = self.learned_pos_enc(src*math.sqrt(self.d_model))
tgt = self.pos_enc(self.embed_tgt(tgt) * math.sqrt(self.d_model))
output = self.transformer(src, tgt, tgt_mask=tgt_mask, src_key_padding_mask=src_key_padding_mask,
tgt_key_padding_mask=tgt_key_padding_mask.float(), memory_key_padding_mask=memory_key_padding_mask)
# output = rearrange(output, 't n e -> n t e')
output = output.transpose(0, 1)
return self.fc(output)
def gen_nopeek_mask(self, length):
mask = (torch.triu(torch.ones(length, length)) == 1).transpose(0, 1)
mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
return mask
def forward_encoder(self, src):
src = self.pos_enc(src*math.sqrt(self.d_model))
memory = self.transformer.encoder(src)
return memory
def forward_decoder(self, tgt, memory):
tgt_mask = self.gen_nopeek_mask(tgt.shape[0]).to(tgt.device)
tgt = self.pos_enc(self.embed_tgt(tgt) * math.sqrt(self.d_model))
output = self.transformer.decoder(tgt, memory, tgt_mask=tgt_mask)
# output = rearrange(output, 't n e -> n t e')
output = output.transpose(0, 1)
return self.fc(output), memory
def expand_memory(self, memory, beam_size):
memory = memory.repeat(1, beam_size, 1)
return memory
def get_memory(self, memory, i):
memory = memory[:, [i], :]
return memory
class PositionalEncoding(nn.Module):
def __init__(self, d_model, dropout=0.1, max_len=100):
super(PositionalEncoding, self).__init__()
self.dropout = nn.Dropout(p=dropout)
pe = torch.zeros(max_len, d_model)
position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
pe = pe.unsqueeze(0).transpose(0, 1)
self.register_buffer('pe', pe)
def forward(self, x):
x = x + self.pe[:x.size(0), :]
return self.dropout(x)
class LearnedPositionalEncoding(nn.Module):
def __init__(self, d_model, dropout=0.1, max_len=100):
super(LearnedPositionalEncoding, self).__init__()
self.dropout = nn.Dropout(p=dropout)
self.pos_embed = nn.Embedding(max_len, d_model)
self.layernorm = LayerNorm(d_model)
def forward(self, x):
seq_len = x.size(0)
pos = torch.arange(seq_len, dtype=torch.long, device=x.device)
pos = pos.unsqueeze(-1).expand(x.size()[:2])
x = x + self.pos_embed(pos)
return self.dropout(self.layernorm(x))
class LayerNorm(nn.Module):
"A layernorm module in the TF style (epsilon inside the square root)."
def __init__(self, d_model, variance_epsilon=1e-12):
super().__init__()
self.gamma = nn.Parameter(torch.ones(d_model))
self.beta = nn.Parameter(torch.zeros(d_model))
self.variance_epsilon = variance_epsilon
def forward(self, x):
u = x.mean(-1, keepdim=True)
s = (x - u).pow(2).mean(-1, keepdim=True)
x = (x - u) / torch.sqrt(s + self.variance_epsilon)
return self.gamma * x + self.beta
from vietocr.optim.optim import ScheduledOptim
from vietocr.optim.labelsmoothingloss import LabelSmoothingLoss
from torch.optim import Adam, SGD, AdamW
from torch import nn
from vietocr.tool.translate import build_model
from vietocr.tool.translate import translate, batch_translate_beam_search
from vietocr.tool.utils import download_weights
from vietocr.tool.logger import Logger
from vietocr.loader.aug import ImgAugTransform, ImgAugTransformV2
import yaml
import torch
from vietocr.loader.dataloader_v1 import DataGen
from vietocr.loader.dataloader import OCRDataset, ClusterRandomSampler, Collator
from torch.utils.data import DataLoader
from einops import rearrange
from torch.optim.lr_scheduler import CosineAnnealingLR, CyclicLR, OneCycleLR
import torchvision
from vietocr.tool.utils import compute_accuracy
from PIL import Image
import numpy as np
import os
import matplotlib.pyplot as plt
import time
class Trainer():
def __init__(self, config, pretrained=True, augmentor=ImgAugTransformV2()):
self.config = config
self.model, self.vocab = build_model(config)
self.device = config['device']
self.num_iters = config['trainer']['iters']
self.beamsearch = config['predictor']['beamsearch']
self.data_root = config['dataset']['data_root']
self.train_annotation = config['dataset']['train_annotation']
self.valid_annotation = config['dataset']['valid_annotation']
self.dataset_name = config['dataset']['name']
self.batch_size = config['trainer']['batch_size']
self.print_every = config['trainer']['print_every']
self.valid_every = config['trainer']['valid_every']
self.image_aug = config['aug']['image_aug']
self.masked_language_model = config['aug']['masked_language_model']
self.checkpoint = config['trainer']['checkpoint']
self.export_weights = config['trainer']['export']
self.metrics = config['trainer']['metrics']
logger = config['trainer']['log']
if logger:
self.logger = Logger(logger)
if pretrained:
weight_file = download_weights(config['pretrain'], quiet=config['quiet'])
self.load_weights(weight_file)
self.iter = 0
self.optimizer = AdamW(self.model.parameters(), betas=(0.9, 0.98), eps=1e-09)
self.scheduler = OneCycleLR(self.optimizer, total_steps=self.num_iters, **config['optimizer'])
# self.optimizer = ScheduledOptim(
# Adam(self.model.parameters(), betas=(0.9, 0.98), eps=1e-09),
# #config['transformer']['d_model'],
# 512,
# **config['optimizer'])
self.criterion = LabelSmoothingLoss(len(self.vocab), padding_idx=self.vocab.pad, smoothing=0.1)
transforms = None
if self.image_aug:
transforms = augmentor
self.train_gen = self.data_gen('train_{}'.format(self.dataset_name),
self.data_root, self.train_annotation, self.masked_language_model, transform=transforms)
if self.valid_annotation:
self.valid_gen = self.data_gen('valid_{}'.format(self.dataset_name),
self.data_root, self.valid_annotation, masked_language_model=False)
self.train_losses = []
def train(self):
total_loss = 0
total_loader_time = 0
total_gpu_time = 0
best_acc = 0
data_iter = iter(self.train_gen)
for i in range(self.num_iters):
self.iter += 1
start = time.time()
try:
batch = next(data_iter)
except StopIteration:
data_iter = iter(self.train_gen)
batch = next(data_iter)
total_loader_time += time.time() - start
start = time.time()
loss = self.step(batch)
total_gpu_time += time.time() - start
total_loss += loss
self.train_losses.append((self.iter, loss))
if self.iter % self.print_every == 0:
info = 'iter: {:06d} - train loss: {:.3f} - lr: {:.2e} - load time: {:.2f} - gpu time: {:.2f}'.format(self.iter,
total_loss/self.print_every, self.optimizer.param_groups[0]['lr'],
total_loader_time, total_gpu_time)
total_loss = 0
total_loader_time = 0
total_gpu_time = 0
print(info)
self.logger.log(info)
if self.valid_annotation and self.iter % self.valid_every == 0:
val_loss = self.validate()
acc_full_seq, acc_per_char = self.precision(self.metrics)
info = 'iter: {:06d} - valid loss: {:.3f} - acc full seq: {:.4f} - acc per char: {:.4f}'.format(self.iter, val_loss, acc_full_seq, acc_per_char)
print(info)
self.logger.log(info)
if acc_full_seq > best_acc:
self.save_weights(self.export_weights)
best_acc = acc_full_seq
def validate(self):
self.model.eval()
total_loss = []
with torch.no_grad():
for step, batch in enumerate(self.valid_gen):
batch = self.batch_to_device(batch)
img, tgt_input, tgt_output, tgt_padding_mask = batch['img'], batch['tgt_input'], batch['tgt_output'], batch['tgt_padding_mask']
outputs = self.model(img, tgt_input, tgt_padding_mask)
# loss = self.criterion(rearrange(outputs, 'b t v -> (b t) v'), rearrange(tgt_output, 'b o -> (b o)'))
outputs = outputs.flatten(0,1)
tgt_output = tgt_output.flatten()
loss = self.criterion(outputs, tgt_output)
total_loss.append(loss.item())
del outputs
del loss
total_loss = np.mean(total_loss)
self.model.train()
return total_loss
def predict(self, sample=None):
pred_sents = []
actual_sents = []
img_files = []
for batch in self.valid_gen:
batch = self.batch_to_device(batch)
if self.beamsearch:
translated_sentence = batch_translate_beam_search(batch['img'], self.model)
prob = None
else:
translated_sentence, prob = translate(batch['img'], self.model)
pred_sent = self.vocab.batch_decode(translated_sentence.tolist())
actual_sent = self.vocab.batch_decode(batch['tgt_output'].tolist())
img_files.extend(batch['filenames'])
pred_sents.extend(pred_sent)
actual_sents.extend(actual_sent)
if sample is not None and len(pred_sents) > sample:
break
return pred_sents, actual_sents, img_files, prob
def precision(self, sample=None):
pred_sents, actual_sents, _, _ = self.predict(sample=sample)
acc_full_seq = compute_accuracy(actual_sents, pred_sents, mode='full_sequence')
acc_per_char = compute_accuracy(actual_sents, pred_sents, mode='per_char')
return acc_full_seq, acc_per_char
def visualize_prediction(self, sample=16, errorcase=False, fontname='serif', fontsize=16):
pred_sents, actual_sents, img_files, probs = self.predict(sample)
if errorcase:
wrongs = []
for i in range(len(img_files)):
if pred_sents[i]!= actual_sents[i]:
wrongs.append(i)
pred_sents = [pred_sents[i] for i in wrongs]
actual_sents = [actual_sents[i] for i in wrongs]
img_files = [img_files[i] for i in wrongs]
probs = [probs[i] for i in wrongs]
img_files = img_files[:sample]
fontdict = {
'family':fontname,
'size':fontsize
}
for vis_idx in range(0, len(img_files)):
img_path = img_files[vis_idx]
pred_sent = pred_sents[vis_idx]
actual_sent = actual_sents[vis_idx]
prob = probs[vis_idx]
img = Image.open(open(img_path, 'rb'))
plt.figure()
plt.imshow(img)
plt.title('prob: {:.3f} - pred: {} - actual: {}'.format(prob, pred_sent, actual_sent), loc='left', fontdict=fontdict)
plt.axis('off')
plt.show()
def visualize_dataset(self, sample=16, fontname='serif'):
n = 0
for batch in self.train_gen:
for i in range(self.batch_size):
img = batch['img'][i].numpy().transpose(1,2,0)
sent = self.vocab.decode(batch['tgt_input'].T[i].tolist())
plt.figure()
plt.title('sent: {}'.format(sent), loc='center', fontname=fontname)
plt.imshow(img)
plt.axis('off')
n += 1
if n >= sample:
plt.show()
return
def load_checkpoint(self, filename):
checkpoint = torch.load(filename)
optim = ScheduledOptim(
Adam(self.model.parameters(), betas=(0.9, 0.98), eps=1e-09),
self.config['transformer']['d_model'], **self.config['optimizer'])
self.optimizer.load_state_dict(checkpoint['optimizer'])
self.model.load_state_dict(checkpoint['state_dict'])
self.iter = checkpoint['iter']
self.train_losses = checkpoint['train_losses']
def save_checkpoint(self, filename):
state = {'iter':self.iter, 'state_dict': self.model.state_dict(),
'optimizer': self.optimizer.state_dict(), 'train_losses': self.train_losses}
path, _ = os.path.split(filename)
os.makedirs(path, exist_ok=True)
torch.save(state, filename)
def load_weights(self, filename):
state_dict = torch.load(filename, map_location=torch.device(self.device))
for name, param in self.model.named_parameters():
if name not in state_dict:
print('{} not found'.format(name))
elif state_dict[name].shape != param.shape:
print('{} shape mismatch: required {} but found {}'.format(name, param.shape, state_dict[name].shape))
del state_dict[name]
self.model.load_state_dict(state_dict, strict=False)
def save_weights(self, filename):
path, _ = os.path.split(filename)
os.makedirs(path, exist_ok=True)
torch.save(self.model.state_dict(), filename)
def batch_to_device(self, batch):
img = batch['img'].to(self.device, non_blocking=True)
tgt_input = batch['tgt_input'].to(self.device, non_blocking=True)
tgt_output = batch['tgt_output'].to(self.device, non_blocking=True)
tgt_padding_mask = batch['tgt_padding_mask'].to(self.device, non_blocking=True)
batch = {
'img': img, 'tgt_input':tgt_input,
'tgt_output':tgt_output, 'tgt_padding_mask':tgt_padding_mask,
'filenames': batch['filenames']
}
return batch
def data_gen(self, lmdb_path, data_root, annotation, masked_language_model=True, transform=None):
dataset = OCRDataset(lmdb_path=lmdb_path,
root_dir=data_root, annotation_path=annotation,
vocab=self.vocab, transform=transform,
image_height=self.config['dataset']['image_height'],
image_min_width=self.config['dataset']['image_min_width'],
image_max_width=self.config['dataset']['image_max_width'])
sampler = ClusterRandomSampler(dataset, self.batch_size, True)
collate_fn = Collator(masked_language_model)
gen = DataLoader(
dataset,
batch_size=self.batch_size,
sampler=sampler,
collate_fn = collate_fn,
shuffle=False,
drop_last=False,
**self.config['dataloader'])
return gen
def data_gen_v1(self, lmdb_path, data_root, annotation):
data_gen = DataGen(data_root, annotation, self.vocab, 'cpu',
image_height = self.config['dataset']['image_height'],
image_min_width = self.config['dataset']['image_min_width'],
image_max_width = self.config['dataset']['image_max_width'])
return data_gen
def step(self, batch):
self.model.train()
batch = self.batch_to_device(batch)
img, tgt_input, tgt_output, tgt_padding_mask = batch['img'], batch['tgt_input'], batch['tgt_output'], batch['tgt_padding_mask']
outputs = self.model(img, tgt_input, tgt_key_padding_mask=tgt_padding_mask)
# loss = self.criterion(rearrange(outputs, 'b t v -> (b t) v'), rearrange(tgt_output, 'b o -> (b o)'))
outputs = outputs.view(-1, outputs.size(2))#flatten(0, 1)
tgt_output = tgt_output.view(-1)#flatten()
loss = self.criterion(outputs, tgt_output)
self.optimizer.zero_grad()
loss.backward()
torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1)
self.optimizer.step()
self.scheduler.step()
loss_item = loss.item()
return loss_item
from vietocr.model.backbone.cnn import CNN
from vietocr.model.seqmodel.transformer import LanguageTransformer
from vietocr.model.seqmodel.seq2seq import Seq2Seq
from vietocr.model.seqmodel.convseq2seq import ConvSeq2Seq
from torch import nn
class VietOCR(nn.Module):
def __init__(self, vocab_size,
backbone,
cnn_args,
transformer_args, seq_modeling='transformer'):
super(VietOCR, self).__init__()
self.cnn = CNN(backbone, **cnn_args)
self.seq_modeling = seq_modeling
if seq_modeling == 'transformer':
self.transformer = LanguageTransformer(vocab_size, **transformer_args)
elif seq_modeling == 'seq2seq':
self.transformer = Seq2Seq(vocab_size, **transformer_args)
elif seq_modeling == 'convseq2seq':
self.transformer = ConvSeq2Seq(vocab_size, **transformer_args)
else:
raise ValueError('Unsupported seq_modeling: {}'.format(seq_modeling))
def forward(self, img, tgt_input, tgt_key_padding_mask):
"""
Shape:
- img: (N, C, H, W)
- tgt_input: (T, N)
- tgt_key_padding_mask: (N, T)
- output: b t v
"""
src = self.cnn(img)
if self.seq_modeling == 'transformer':
outputs = self.transformer(src, tgt_input, tgt_key_padding_mask=tgt_key_padding_mask)
elif self.seq_modeling == 'seq2seq':
outputs = self.transformer(src, tgt_input)
elif self.seq_modeling == 'convseq2seq':
outputs = self.transformer(src, tgt_input)
return outputs
class Vocab():
def __init__(self, chars):
self.pad = 0
self.go = 1
self.eos = 2
self.mask_token = 3
self.chars = chars
self.c2i = {c:i+4 for i, c in enumerate(chars)}
self.i2c = {i+4:c for i, c in enumerate(chars)}
self.i2c[0] = '<pad>'
self.i2c[1] = '<sos>'
self.i2c[2] = '<eos>'
self.i2c[3] = '*'
def encode(self, chars):
return [self.go] + [self.c2i[c] for c in chars] + [self.eos]
def decode(self, ids):
first = 1 if self.go in ids else 0
last = ids.index(self.eos) if self.eos in ids else None
sent = ''.join([self.i2c[i] for i in ids[first:last]])
return sent
def __len__(self):
return len(self.c2i) + 4
def batch_decode(self, arr):
texts = [self.decode(ids) for ids in arr]
return texts
def __str__(self):
return self.chars
import torch
from torch import nn
class LabelSmoothingLoss(nn.Module):
def __init__(self, classes, padding_idx, smoothing=0.0, dim=-1):
super(LabelSmoothingLoss, self).__init__()
self.confidence = 1.0 - smoothing
self.smoothing = smoothing
self.cls = classes
self.dim = dim
self.padding_idx = padding_idx
def forward(self, pred, target):
pred = pred.log_softmax(dim=self.dim)
with torch.no_grad():
# true_dist = pred.data.clone()
true_dist = torch.zeros_like(pred)
true_dist.fill_(self.smoothing / (self.cls - 2))
true_dist.scatter_(1, target.data.unsqueeze(1), self.confidence)
true_dist[:, self.padding_idx] = 0
mask = torch.nonzero(target.data == self.padding_idx, as_tuple=False)
if mask.dim() > 0:
true_dist.index_fill_(0, mask.squeeze(), 0.0)
return torch.mean(torch.sum(-true_dist * pred, dim=self.dim))
class ScheduledOptim():
'''A simple wrapper class for learning rate scheduling'''
def __init__(self, optimizer, d_model, init_lr, n_warmup_steps):
assert n_warmup_steps > 0, 'must be greater than 0'
self._optimizer = optimizer
self.init_lr = init_lr
self.d_model = d_model
self.n_warmup_steps = n_warmup_steps
self.n_steps = 0
def step(self):
"Step with the inner optimizer"
self._update_learning_rate()
self._optimizer.step()
def zero_grad(self):
"Zero out the gradients with the inner optimizer"
self._optimizer.zero_grad()
def _get_lr_scale(self):
d_model = self.d_model
n_steps, n_warmup_steps = self.n_steps, self.n_warmup_steps
return (d_model ** -0.5) * min(n_steps ** (-0.5), n_steps * n_warmup_steps ** (-1.5))
def state_dict(self):
optimizer_state_dict = {
'init_lr':self.init_lr,
'd_model':self.d_model,
'n_warmup_steps':self.n_warmup_steps,
'n_steps':self.n_steps,
'_optimizer':self._optimizer.state_dict(),
}
return optimizer_state_dict
def load_state_dict(self, state_dict):
self.init_lr = state_dict['init_lr']
self.d_model = state_dict['d_model']
self.n_warmup_steps = state_dict['n_warmup_steps']
self.n_steps = state_dict['n_steps']
self._optimizer.load_state_dict(state_dict['_optimizer'])
def _update_learning_rate(self):
''' Learning rate scheduling per step '''
self.n_steps += 1
for param_group in self._optimizer.param_groups:
lr = self.init_lr*self._get_lr_scale()
self.lr = lr
param_group['lr'] = lr
import argparse
from PIL import Image
from vietocr.tool.predictor import Predictor
from vietocr.tool.config import Cfg
def main():
parser = argparse.ArgumentParser()
parser.add_argument('--img', required=True, help='path to the input image')
parser.add_argument('--config', required=True, help='path to the config yml file')
args = parser.parse_args()
config = Cfg.load_config_from_file(args.config)
detector = Predictor(config)
img = Image.open(args.img)
s = detector.predict(img)
print(s)
if __name__ == '__main__':
main()
./image/036170002830.jpeg HOÀNG THỊ THOI
./image/079193002341.jpeg TRỊNH THỊ THÚY HẰNG
./image/001099025107.jpeg NGUYỄN VĂN BÌNH
./image/060085000115.jpeg NGUYỄN MINH TOÀN
./image/026301003919.jpeg NGUYỄN THỊ KIỀU TRANG
./image/079084000809.jpeg LÊ NGỌC PHƯƠNG KHANH
./image/038144000109.jpeg ĐÀO THỊ TƠ
./image/072183002222.jpeg NGUYỄN THANH PHƯỚC
./image/038078002355.jpeg HÀ ĐÌNH LỢI
./image/038089010274.jpeg HÀ VĂN LUÂN
from vietocr.loader.dataloader_v1 import DataGen
from vietocr.model.vocab import Vocab
def test_loader():
chars = 'aAàÀảẢãÃáÁạẠăĂằẰẳẲẵẴắẮặẶâÂầẦẩẨẫẪấẤậẬbBcCdDđĐeEèÈẻẺẽẼéÉẹẸêÊềỀểỂễỄếẾệỆfFgGhHiIìÌỉỈĩĨíÍịỊjJkKlLmMnNoOòÒỏỎõÕóÓọỌôÔồỒổỔỗỖốỐộỘơƠờỜởỞỡỠớỚợỢpPqQrRsStTuUùÙủỦũŨúÚụỤưƯừỪửỬữỮứỨựỰvVwWxXyYỳỲỷỶỹỸýÝỵỴzZ0123456789!"#$%&\'()*+,-./:;<=>?@[\\]^_`{|}~ '
vocab = Vocab(chars)
s_gen = DataGen('./vietocr/tests/', 'sample.txt', vocab, 'cpu', 32, 512)
iterator = s_gen.gen(30)
for batch in iterator:
assert batch['img'].shape[1]==3, 'image must have 3 channels'
assert batch['img'].shape[2]==32, 'the height must be 32'
print(batch['img'].shape, batch['tgt_input'].shape, batch['tgt_output'].shape, batch['tgt_padding_mask'].shape)
if __name__ == '__main__':
test_loader()
import yaml
from vietocr.tool.utils import download_config
url_config = {
'vgg_transformer':'vgg-transformer.yml',
'resnet_transformer':'resnet_transformer.yml',
'resnet_fpn_transformer':'resnet_fpn_transformer.yml',
'vgg_seq2seq':'vgg-seq2seq.yml',
'vgg_convseq2seq':'vgg_convseq2seq.yml',
'vgg_decoderseq2seq':'vgg_decoderseq2seq.yml',
'base':'base.yml',
}
class Cfg(dict):
def __init__(self, config_dict):
super(Cfg, self).__init__(**config_dict)
self.__dict__ = self
@staticmethod
def load_config_from_file(fname):
#base_config = download_config(url_config['base'])
base_config = {}
with open(fname, encoding='utf-8') as f:
config = yaml.safe_load(f)
base_config.update(config)
return Cfg(base_config)
@staticmethod
def load_config_from_name(name):
base_config = download_config(url_config['base'])
config = download_config(url_config[name])
base_config.update(config)
return Cfg(base_config)
def save(self, fname):
with open(fname, 'w', encoding='utf-8') as outfile:
yaml.dump(dict(self), outfile, default_flow_style=False, allow_unicode=True)
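# --- Hedged usage sketch: loading a YAML config, overriding a field and saving a copy.
# The file paths are placeholders, not files shipped with the project.
cfg = Cfg.load_config_from_file('path/to/config.yml')
cfg['device'] = 'cpu'                 # dict-style access
print(cfg.device)                     # attribute-style access also works (self.__dict__ = self)
cfg.save('path/to/config_copy.yml')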
import sys
import os
import lmdb # install lmdb by "pip install lmdb"
import cv2
import numpy as np
from tqdm import tqdm
def checkImageIsValid(imageBin):
isvalid = True
imgH = None
imgW = None
    imageBuf = np.frombuffer(imageBin, dtype=np.uint8)
try:
img = cv2.imdecode(imageBuf, cv2.IMREAD_GRAYSCALE)
imgH, imgW = img.shape[0], img.shape[1]
if imgH * imgW == 0:
isvalid = False
except Exception as e:
isvalid = False
return isvalid, imgH, imgW
def writeCache(env, cache):
with env.begin(write=True) as txn:
for k, v in cache.items():
txn.put(k.encode(), v)
def createDataset(outputPath, root_dir, annotation_path):
"""
Create LMDB dataset for CRNN training.
ARGS:
outputPath : LMDB output path
imagePathList : list of image path
labelList : list of corresponding groundtruth texts
lexiconList : (optional) list of lexicon lists
checkValid : if true, check the validity of every image
"""
annotation_path = os.path.join(root_dir, annotation_path)
with open(annotation_path, 'r', encoding='utf-8') as ann_file:
lines = ann_file.readlines()
annotations = [l.strip().split('\t') for l in lines]
nSamples = len(annotations)
env = lmdb.open(outputPath, map_size=1099511627776)
cache = {}
cnt = 0
error = 0
pbar = tqdm(range(nSamples), ncols = 100, desc='Create {}'.format(outputPath))
for i in pbar:
imageFile, label = annotations[i]
imagePath = os.path.join(root_dir, imageFile)
if not os.path.exists(imagePath):
error += 1
continue
with open(imagePath, 'rb') as f:
imageBin = f.read()
isvalid, imgH, imgW = checkImageIsValid(imageBin)
if not isvalid:
error += 1
continue
imageKey = 'image-%09d' % cnt
labelKey = 'label-%09d' % cnt
pathKey = 'path-%09d' % cnt
dimKey = 'dim-%09d' % cnt
cache[imageKey] = imageBin
cache[labelKey] = label.encode()
cache[pathKey] = imageFile.encode()
cache[dimKey] = np.array([imgH, imgW], dtype=np.int32).tobytes()
cnt += 1
if cnt % 1000 == 0:
writeCache(env, cache)
cache = {}
    nSamples = cnt  # cnt samples were written, with keys 0 .. cnt-1
cache['num-samples'] = str(nSamples).encode()
writeCache(env, cache)
    if error > 0:
        print('Skipped {} invalid or missing images'.format(error))
print('Created dataset with %d samples' % nSamples)
sys.stdout.flush()
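# --- Hedged usage sketch: building an LMDB dataset from an annotation file in which
# each line is "<relative image path>\t<label>". All paths below are placeholders.
if __name__ == '__main__':
    createDataset(outputPath='path/to/train_lmdb',
                  root_dir='path/to/data_root',
                  annotation_path='train_annotation.txt')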
import os
class Logger():
def __init__(self, fname):
path, _ = os.path.split(fname)
        if path:
            os.makedirs(path, exist_ok=True)
self.logger = open(fname, 'w')
def log(self, string):
self.logger.write(string+'\n')
self.logger.flush()
def close(self):
self.logger.close()
from vietocr.tool.translate import build_model, translate, translate_beam_search, process_input, predict
from vietocr.tool.utils import download_weights
import torch
from collections import defaultdict
class Predictor():
def __init__(self, config):
device = config['device']
model, vocab = build_model(config)
        if config['weights'].startswith('http'):
            weights = download_weights(config['weights'])
        else:
            weights = config['weights']
model.load_state_dict(torch.load(weights, map_location=torch.device(device)))
self.config = config
self.model = model
self.vocab = vocab
self.device = device
def predict(self, img, return_prob=False):
img = process_input(img, self.config['dataset']['image_height'],
self.config['dataset']['image_min_width'], self.config['dataset']['image_max_width'])
img = img.to(self.config['device'])
if self.config['predictor']['beamsearch']:
sent = translate_beam_search(img, self.model)
s = sent
prob = None
else:
s, prob = translate(img, self.model)
s = s[0].tolist()
prob = prob[0]
s = self.vocab.decode(s)
if return_prob:
return s, prob
else:
return s
def predict_batch(self, imgs, return_prob=False):
bucket = defaultdict(list)
bucket_idx = defaultdict(list)
bucket_pred = {}
sents, probs = [0]*len(imgs), [0]*len(imgs)
for i, img in enumerate(imgs):
img = process_input(img, self.config['dataset']['image_height'],
self.config['dataset']['image_min_width'], self.config['dataset']['image_max_width'])
bucket[img.shape[-1]].append(img)
bucket_idx[img.shape[-1]].append(i)
for k, batch in bucket.items():
batch = torch.cat(batch, 0).to(self.device)
s, prob = translate(batch, self.model)
prob = prob.tolist()
s = s.tolist()
s = self.vocab.batch_decode(s)
bucket_pred[k] = (s, prob)
for k in bucket_pred:
idx = bucket_idx[k]
sent, prob = bucket_pred[k]
for i, j in enumerate(idx):
sents[j] = sent[i]
probs[j] = prob[i]
if return_prob:
return sents, probs
else:
return sents
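# --- Hedged usage sketch: running the Predictor on a single crop and on a small batch.
# The config name comes from the url_config table above; the image path is a placeholder
# and downloading the config and weights requires network access.
from PIL import Image
from vietocr.tool.config import Cfg

config = Cfg.load_config_from_name('vgg_seq2seq')
config['device'] = 'cpu'
config['predictor']['beamsearch'] = False

detector = Predictor(config)
img = Image.open('path/to/line_image.jpeg')
print(detector.predict(img, return_prob=True))   # (text, confidence)
print(detector.predict_batch([img, img]))        # list of texts, bucketed by width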
import torch
import numpy as np
import math
from PIL import Image
from torch.nn.functional import log_softmax, softmax
from vietocr.model.transformerocr import VietOCR
from vietocr.model.vocab import Vocab
from vietocr.model.beam import Beam
def batch_translate_beam_search(img, model, beam_size=4, candidates=1, max_seq_length=128, sos_token=1, eos_token=2):
# img: NxCxHxW
model.eval()
device = img.device
sents = []
with torch.no_grad():
src = model.cnn(img)
        # print(src.shape)
memories = model.transformer.forward_encoder(src)
for i in range(src.size(0)):
# memory = memories[:,i,:].repeat(1, beam_size, 1) # TxNxE
memory = model.transformer.get_memory(memories, i)
sent = beamsearch(memory, model, device, beam_size, candidates, max_seq_length, sos_token, eos_token)
sents.append(sent)
sents = np.asarray(sents)
return sents
def translate_beam_search(img, model, beam_size=4, candidates=1, max_seq_length=128, sos_token=1, eos_token=2):
# img: 1xCxHxW
model.eval()
device = img.device
with torch.no_grad():
src = model.cnn(img)
memory = model.transformer.forward_encoder(src) #TxNxE
sent = beamsearch(memory, model, device, beam_size, candidates, max_seq_length, sos_token, eos_token)
return sent
def beamsearch(memory, model, device, beam_size=4, candidates=1, max_seq_length=128, sos_token=1, eos_token=2):
# memory: Tx1xE
model.eval()
beam = Beam(beam_size=beam_size, min_length=0, n_top=candidates, ranker=None, start_token_id=sos_token, end_token_id=eos_token)
with torch.no_grad():
# memory = memory.repeat(1, beam_size, 1) # TxNxE
memory = model.transformer.expand_memory(memory, beam_size)
for _ in range(max_seq_length):
tgt_inp = beam.get_current_state().transpose(0,1).to(device) # TxN
decoder_outputs, memory = model.transformer.forward_decoder(tgt_inp, memory)
log_prob = log_softmax(decoder_outputs[:,-1, :].squeeze(0), dim=-1)
beam.advance(log_prob.cpu())
if beam.done():
break
scores, ks = beam.sort_finished(minimum=1)
hypothesises = []
for i, (times, k) in enumerate(ks[:candidates]):
hypothesis = beam.get_hypothesis(times, k)
hypothesises.append(hypothesis)
    return [sos_token] + [int(i) for i in hypothesises[0][:-1]]
def translate(img, model, max_seq_length=128, sos_token=1, eos_token=2):
"data: BxCXHxW"
model.eval()
device = img.device
with torch.no_grad():
src = model.cnn(img)
memory = model.transformer.forward_encoder(src)
translated_sentence = [[sos_token]*len(img)]
char_probs = [[1]*len(img)]
max_length = 0
while max_length <= max_seq_length and not all(np.any(np.asarray(translated_sentence).T==eos_token, axis=1)):
tgt_inp = torch.LongTensor(translated_sentence).to(device)
# output = model(img, tgt_inp, tgt_key_padding_mask=None)
# output = model.transformer(src, tgt_inp, tgt_key_padding_mask=None)
output, memory = model.transformer.forward_decoder(tgt_inp, memory)
output = softmax(output, dim=-1)
output = output.to('cpu')
values, indices = torch.topk(output, 5)
indices = indices[:, -1, 0]
indices = indices.tolist()
values = values[:, -1, 0]
values = values.tolist()
char_probs.append(values)
translated_sentence.append(indices)
max_length += 1
del output
translated_sentence = np.asarray(translated_sentence).T
char_probs = np.asarray(char_probs).T
char_probs = np.multiply(char_probs, translated_sentence>3)
char_probs = np.sum(char_probs, axis=-1)/(char_probs>0).sum(-1)
return translated_sentence, char_probs
def build_model(config):
vocab = Vocab(config['vocab'])
device = config['device']
model = VietOCR(len(vocab),
config['backbone'],
config['cnn'],
config['transformer'],
config['seq_modeling'])
model = model.to(device)
return model, vocab
def resize(w, h, expected_height, image_min_width, image_max_width):
new_w = int(expected_height * float(w) / float(h))
round_to = 10
new_w = math.ceil(new_w/round_to)*round_to
new_w = max(new_w, image_min_width)
new_w = min(new_w, image_max_width)
return new_w, expected_height
def process_image(image, image_height, image_min_width, image_max_width):
img = image.convert('RGB')
w, h = img.size
new_w, image_height = resize(w, h, image_height, image_min_width, image_max_width)
img = img.resize((new_w, image_height), Image.LANCZOS)
img = np.asarray(img).transpose(2,0, 1)
img = img/255
return img
def process_input(image, image_height, image_min_width, image_max_width):
img = process_image(image, image_height, image_min_width, image_max_width)
img = img[np.newaxis, ...]
img = torch.FloatTensor(img)
return img
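# --- Hedged illustration of the preprocessing contract: the crop is resized to a fixed
# height, its width is rounded up to a multiple of 10 and clamped to [min, max], and the
# result is a normalised 1xCxHxW float tensor. The 32/32/512 values are example
# parameters for the sketch, not necessarily the project's configured ones.
from PIL import Image
import numpy as np

demo = Image.fromarray(np.zeros((48, 300, 3), dtype=np.uint8))   # dummy 300x48 crop
print(resize(300, 48, expected_height=32, image_min_width=32, image_max_width=512))  # (200, 32)
tensor = process_input(demo, image_height=32, image_min_width=32, image_max_width=512)
print(tensor.shape)   # torch.Size([1, 3, 32, 200])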
def predict(filename, config):
img = Image.open(filename)
    img = process_input(img, config['dataset']['image_height'],
            config['dataset']['image_min_width'], config['dataset']['image_max_width'])
img = img.to(config['device'])
model, vocab = build_model(config)
    s = translate(img, model)[0][0].tolist()
s = vocab.decode(s)
return s
import os
import gdown
import yaml
import numpy as np
import uuid
import requests
import tempfile
from tqdm import tqdm
def download_weights(uri, cached=None, md5=None, quiet=False):
if uri.startswith('http'):
return download(url=uri, quiet=quiet)
return uri
def download(url, quiet=False):
tmp_dir = tempfile.gettempdir()
filename = url.split('/')[-1]
full_path = os.path.join(tmp_dir, filename)
if os.path.exists(full_path):
        print('Model weight {} exists. Skipping download.'.format(full_path))
return full_path
with requests.get(url, stream=True) as r:
r.raise_for_status()
with open(full_path, 'wb') as f:
for chunk in tqdm(r.iter_content(chunk_size=8192)):
# If you have chunk encoded response uncomment if
# and set chunk_size parameter to None.
#if chunk:
f.write(chunk)
return full_path
def download_config(id):
url = 'https://vocr.vn/data/vietocr/config/{}'.format(id)
r = requests.get(url)
config = yaml.safe_load(r.text)
return config
def compute_accuracy(ground_truth, predictions, mode='full_sequence'):
"""
Computes accuracy
:param ground_truth:
:param predictions:
:param display: Whether to print values to stdout
:param mode: if 'per_char' is selected then
single_label_accuracy = correct_predicted_char_nums_of_single_sample / single_label_char_nums
avg_label_accuracy = sum(single_label_accuracy) / label_nums
if 'full_sequence' is selected then
single_label_accuracy = 1 if the prediction result is exactly the same as label else 0
avg_label_accuracy = sum(single_label_accuracy) / label_nums
:return: avg_label_accuracy
"""
if mode == 'per_char':
accuracy = []
for index, label in enumerate(ground_truth):
prediction = predictions[index]
total_count = len(label)
correct_count = 0
try:
for i, tmp in enumerate(label):
if tmp == prediction[i]:
correct_count += 1
except IndexError:
continue
finally:
try:
accuracy.append(correct_count / total_count)
except ZeroDivisionError:
if len(prediction) == 0:
accuracy.append(1)
else:
accuracy.append(0)
avg_accuracy = np.mean(np.array(accuracy).astype(np.float32), axis=0)
elif mode == 'full_sequence':
try:
correct_count = 0
for index, label in enumerate(ground_truth):
prediction = predictions[index]
if prediction == label:
correct_count += 1
avg_accuracy = correct_count / len(ground_truth)
except ZeroDivisionError:
if not predictions:
avg_accuracy = 1
else:
avg_accuracy = 0
else:
raise NotImplementedError('Other accuracy compute mode has not been implemented')
return avg_accuracy
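# --- Hedged usage sketch: comparing the two accuracy modes on toy strings.
gt = ['HÀ VĂN LUÂN', 'ĐÀO THỊ TƠ']
pred = ['HÀ VĂN LUÂN', 'ĐÀO THI TƠ']
print(compute_accuracy(gt, pred, mode='full_sequence'))  # 0.5  -> one exact match out of two
print(compute_accuracy(gt, pred, mode='per_char'))       # 0.95 -> mean of 11/11 and 9/10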
import argparse
from vietocr.model.trainer import Trainer
from vietocr.tool.config import Cfg
def main():
parser = argparse.ArgumentParser()
    parser.add_argument('--config', required=True, help='path to the config yml file')
parser.add_argument('--checkpoint', required=False, help='your checkpoint')
args = parser.parse_args()
config = Cfg.load_config_from_file(args.config)
trainer = Trainer(config)
if args.checkpoint:
trainer.load_checkpoint(args.checkpoint)
trainer.train()
if __name__ == '__main__':
main()
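# Hedged example invocation of the training script above (script name and paths are placeholders):
#   python train.py --config path/to/config.yml --checkpoint path/to/checkpoint.pth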