Commit cc451504 authored by Cao Duc Anh

fix: training error

parent d3981e01
device: cuda
device: cuda:0
classes: ["khac", "phan_dong", "thu_ghet", "khieu_dam"]
model_checkpoint: /src/phobert-base/checkpoint_best.pth
chunk_size: 64
......@@ -23,7 +23,7 @@ phobert_base:
max_token_length: 256
training:
epoch: 200
epoch: 100
batch_size: 8
load_data_worker: 2
k_fold: 5
......
# Compose definition for the `server` service (vn-text-moderation image).
version: '3.8'
services:
  server:
    image: registry.vivas.vn/vietnam_text_moderation/vn-text-moderation
    volumes:
      # Bind-mount the host-side config.yaml over /src/config.yaml in the container.
      - ./config.yaml:/src/config.yaml
    deploy:
      resources:
        reservations:
          devices:
            # Reserve the NVIDIA GPU with device id '0' for this service.
            - driver: nvidia
              device_ids: ['0']
              capabilities: [gpu]
    tty: true
    restart: always
\ No newline at end of file
......@@ -14258,7 +14258,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.14"
"version": "3.11.9"
}
},
"nbformat": 4,
......@@ -38,3 +38,51 @@
10.3.2.100 - - [15/Jul/2024:02:13:07 +0000] "POST /text-classify HTTP/1.1" 200 2318 "-" "PostmanRuntime/7.37.3" "-"
10.3.2.100 - - [15/Jul/2024:02:13:22 +0000] "POST /text-classify HTTP/1.1" 200 2318 "-" "PostmanRuntime/7.37.3" "-"
10.3.2.100 - - [22/Jul/2024:03:13:02 +0000] "POST /text-classify HTTP/1.1" 200 2301 "-" "PostmanRuntime/7.37.3" "-"
10.3.2.100 - - [23/Aug/2024:03:51:47 +0000] "POST /text-classify HTTP/1.1" 200 2327 "-" "PostmanRuntime/7.41.1" "-"
10.3.2.100 - - [23/Aug/2024:04:08:00 +0000] "POST /text-classify HTTP/1.1" 200 2295 "-" "PostmanRuntime/7.41.1" "-"
10.3.2.100 - - [23/Aug/2024:04:08:14 +0000] "POST /text-classify HTTP/1.1" 200 2295 "-" "PostmanRuntime/7.41.1" "-"
10.3.2.100 - - [23/Aug/2024:04:08:18 +0000] "POST /text-classify HTTP/1.1" 200 2295 "-" "PostmanRuntime/7.41.1" "-"
10.3.2.100 - - [23/Aug/2024:04:09:35 +0000] "POST /text-classify HTTP/1.1" 200 2325 "-" "PostmanRuntime/7.41.1" "-"
10.3.2.100 - - [23/Aug/2024:04:16:14 +0000] "POST /text-classify HTTP/1.1" 200 2325 "-" "PostmanRuntime/7.41.1" "-"
10.3.2.100 - - [23/Aug/2024:04:16:17 +0000] "POST /text-classify HTTP/1.1" 200 2324 "-" "PostmanRuntime/7.41.1" "-"
10.3.2.100 - - [23/Aug/2024:04:16:20 +0000] "POST /text-classify HTTP/1.1" 200 2324 "-" "PostmanRuntime/7.41.1" "-"
10.3.2.100 - - [23/Aug/2024:04:17:14 +0000] "POST /text-classify HTTP/1.1" 200 2328 "-" "curl/7.81.0" "-"
10.3.2.100 - - [23/Aug/2024:04:17:15 +0000] "POST /text-classify HTTP/1.1" 200 2323 "-" "curl/7.81.0" "-"
10.3.2.100 - - [23/Aug/2024:04:17:16 +0000] "POST /text-classify HTTP/1.1" 200 2323 "-" "curl/7.81.0" "-"
10.3.2.100 - - [23/Aug/2024:04:17:43 +0000] "POST /text-classify HTTP/1.1" 200 2328 "-" "PostmanRuntime/7.41.1" "-"
10.3.2.100 - - [23/Aug/2024:04:17:46 +0000] "POST /text-classify HTTP/1.1" 499 0 "-" "PostmanRuntime/7.41.1" "-"
10.3.2.100 - - [23/Aug/2024:04:17:48 +0000] "POST /text-classify HTTP/1.1" 200 2324 "-" "PostmanRuntime/7.41.1" "-"
10.3.2.100 - - [23/Aug/2024:04:17:50 +0000] "POST /text-classify HTTP/1.1" 200 2327 "-" "PostmanRuntime/7.41.1" "-"
10.3.2.100 - - [23/Aug/2024:04:17:53 +0000] "POST /text-classify HTTP/1.1" 200 2328 "-" "PostmanRuntime/7.41.1" "-"
10.3.2.100 - - [23/Aug/2024:04:18:19 +0000] "POST /text-classify HTTP/1.1" 200 2326 "-" "PostmanRuntime/7.41.1" "-"
10.3.2.100 - - [23/Aug/2024:04:18:23 +0000] "POST /text-classify HTTP/1.1" 200 2323 "-" "PostmanRuntime/7.41.1" "-"
10.3.2.100 - - [23/Aug/2024:04:18:25 +0000] "POST /text-classify HTTP/1.1" 200 2328 "-" "PostmanRuntime/7.41.1" "-"
10.3.2.100 - - [23/Aug/2024:04:18:33 +0000] "POST /text-classify HTTP/1.1" 200 2320 "-" "PostmanRuntime/7.41.1" "-"
10.3.2.100 - - [23/Aug/2024:04:25:46 +0000] "POST /text-classify HTTP/1.1" 200 2295 "-" "PostmanRuntime/7.41.1" "-"
10.3.2.100 - - [23/Aug/2024:04:25:48 +0000] "POST /text-classify HTTP/1.1" 200 2295 "-" "PostmanRuntime/7.41.1" "-"
10.3.2.100 - - [23/Aug/2024:04:25:49 +0000] "POST /text-classify HTTP/1.1" 200 2295 "-" "PostmanRuntime/7.41.1" "-"
10.3.2.100 - - [23/Aug/2024:04:25:50 +0000] "POST /text-classify HTTP/1.1" 200 2295 "-" "PostmanRuntime/7.41.1" "-"
10.3.2.100 - - [23/Aug/2024:04:25:51 +0000] "POST /text-classify HTTP/1.1" 200 2295 "-" "PostmanRuntime/7.41.1" "-"
10.3.2.100 - - [23/Aug/2024:04:25:57 +0000] "POST /text-classify HTTP/1.1" 200 2295 "-" "PostmanRuntime/7.41.1" "-"
10.3.2.208 - - [26/Aug/2024:01:16:15 +0000] "GET / HTTP/1.1" 404 22 "-" "Avast Antivirus" "-"
10.3.2.208 - - [26/Aug/2024:01:16:17 +0000] "GET / HTTP/1.1" 404 22 "-" "-" "-"
10.3.2.100 - - [26/Aug/2024:04:01:36 +0000] "POST /text-classify HTTP/1.1" 200 2295 "-" "PostmanRuntime/7.41.1" "-"
10.3.2.111 - - [26/Aug/2024:04:02:52 +0000] "POST /text-classify HTTP/1.1" 200 113 "-" "Go-http-client/1.1" "-"
10.3.2.111 - - [26/Aug/2024:04:03:12 +0000] "POST /api/v1/text-classifier HTTP/1.1" 404 22 "-" "PostmanRuntime/7.37.3" "-"
10.3.2.111 - - [26/Aug/2024:04:03:21 +0000] "POST /text-classifier HTTP/1.1" 404 22 "-" "PostmanRuntime/7.37.3" "-"
10.3.2.111 - - [26/Aug/2024:04:03:30 +0000] "POST /text-classifiy HTTP/1.1" 404 22 "-" "PostmanRuntime/7.37.3" "-"
10.3.2.111 - - [26/Aug/2024:04:03:40 +0000] "POST /text-classify HTTP/1.1" 200 113 "-" "PostmanRuntime/7.37.3" "-"
10.3.2.111 - - [26/Aug/2024:04:12:37 +0000] "POST /text-classify HTTP/1.1" 200 113 "-" "PostmanRuntime/7.37.3" "-"
10.3.3.94 - - [26/Aug/2024:06:49:30 +0000] "POST /text-classify HTTP/1.1" 200 2286 "-" "PostmanRuntime/7.41.2" "-"
10.3.2.100 - - [26/Aug/2024:06:58:36 +0000] "POST /text-classify HTTP/1.1" 200 2295 "-" "PostmanRuntime/7.41.1" "-"
10.3.2.100 - - [26/Aug/2024:07:17:50 +0000] "POST /text-classify HTTP/1.1" 200 2295 "-" "PostmanRuntime/7.41.1" "-"
10.3.2.100 - - [26/Aug/2024:07:19:45 +0000] "POST /text-classify HTTP/1.1" 200 2295 "-" "PostmanRuntime/7.41.1" "-"
10.3.2.100 - - [26/Aug/2024:07:19:47 +0000] "POST /text-classify HTTP/1.1" 200 2295 "-" "PostmanRuntime/7.41.1" "-"
10.3.2.100 - - [26/Aug/2024:07:19:48 +0000] "POST /text-classify HTTP/1.1" 200 2295 "-" "PostmanRuntime/7.41.1" "-"
10.3.3.94 - - [26/Aug/2024:07:55:25 +0000] "POST /text-classify HTTP/1.1" 200 113 "-" "PostmanRuntime/7.41.2" "-"
10.3.3.94 - - [26/Aug/2024:07:55:38 +0000] "POST /text-classify HTTP/1.1" 200 288 "-" "PostmanRuntime/7.41.2" "-"
10.3.3.94 - - [26/Aug/2024:07:56:33 +0000] "POST /text-classify HTTP/1.1" 200 103 "-" "PostmanRuntime/7.41.2" "-"
10.3.3.94 - - [26/Aug/2024:07:56:44 +0000] "POST /text-classify HTTP/1.1" 200 282 "-" "PostmanRuntime/7.41.2" "-"
10.3.2.100 - - [26/Aug/2024:08:15:07 +0000] "POST /text-classify HTTP/1.1" 200 2295 "-" "PostmanRuntime/7.41.1" "-"
10.3.2.100 - - [27/Aug/2024:02:52:53 +0000] "POST /text-classify HTTP/1.1" 200 2295 "-" "curl/7.81.0" "-"
10.3.2.100 - - [27/Aug/2024:02:53:05 +0000] "POST /text-classify HTTP/1.1" 200 2295 "-" "PostmanRuntime/7.41.2" "-"
This diff is collapsed.
......@@ -6,7 +6,7 @@ from contextlib import asynccontextmanager
from typing import List
from bert_model import BERTClassifier
from utils import get_data_from_yaml, split_chunk
from utils import get_data_from_yaml, split_chunk, seed_everything
from minio_client import download_latest_model
from constants import PHOBERTBASE_DIR, MAX_TOKEN_LENGTH, DEVICE, CLASSES, MODEL_CHECKPOINT, INFER_LENGTH, CHUNK_SIZE
from logger import logger
......@@ -37,9 +37,12 @@ def infer(text, model, tokenizer, class_names, max_len=MAX_TOKEN_LENGTH+2):
model = BERTClassifier(model_bert=PHOBERTBASE_DIR, n_classes=len(CLASSES))
model.to(DEVICE)
print(DEVICE)
@asynccontextmanager
async def lifespan(app: FastAPI):
# This ensures that any random number generation in NumPy, PyTorch (CPU and GPU), and cuDNN will be consistent
seed_everything(86)
global model
try:
update_model = download_latest_model()
......
import yaml
import re
import numpy as np
import torch
def get_data_from_yaml(filename):
try:
......@@ -42,3 +44,14 @@ def split_chunk(text, max_words=200):
paragraphs.append(' '.join(current_paragraph))
return paragraphs
def seed_everything(seed_value):
    """Seed the NumPy and PyTorch random generators and make cuDNN deterministic.

    Args:
        seed_value: Integer seed applied to NumPy, the PyTorch CPU generator
            and, when CUDA is available, every CUDA device generator.
    """
    np.random.seed(seed_value)
    torch.manual_seed(seed_value)

    if torch.cuda.is_available():
        print("Torch available. Start seed_everything.")
        torch.cuda.manual_seed(seed_value)
        torch.cuda.manual_seed_all(seed_value)

    # These flags are plain settings and have no effect when CUDA is absent,
    # so they are applied unconditionally.
    torch.backends.cudnn.deterministic = True
    # benchmark must stay off: the cuDNN autotuner may pick a different
    # convolution algorithm from run to run, which defeats reproducibility.
    torch.backends.cudnn.benchmark = False
\ No newline at end of file
......@@ -180,43 +180,49 @@ def train_model(skf, train_df, tokenizer, use_pretrain=False):
logger.info(f'Fold {fold+1} Epoch {cur_epoch+1}/{EPOCH}')
logger.info('-'*30)
# Train ----------------------------------------------------------------------------------
model.train()
losses = []
correct = 0
for data in train_loader:
if stop_training_flag.is_set():
logger.info("Training stopped.")
writer.close()
return "Training stopped"
if data is not None:
input_ids = data['input_ids'].to(DEVICE)
attention_mask = data['attention_masks'].to(DEVICE)
targets = data['targets'].to(DEVICE)
optimizer.zero_grad()
outputs = model(
input_ids=input_ids,
attention_mask=attention_mask
)
loss = criterion(outputs, targets)
_, pred = torch.max(outputs, dim=1)
try:
model.train()
losses = []
correct = 0
correct += torch.sum(pred == targets)
losses.append(loss.item())
loss.backward()
nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
optimizer.step()
lr_scheduler.step()
else:
logger.warning("Warning: Empty data received from data loader. Skipping this iteration.")
train_acc = correct.double()/len(train_loader.dataset)
train_loss = np.mean(losses)
logger.info(f'Train Accuracy: {train_acc} Loss: {train_loss}')
writer.add_scalar("Accuracy/train", train_acc, cur_epoch)
writer.add_scalar("Loss/train", train_loss, cur_epoch)
for data in train_loader:
if stop_training_flag.is_set():
logger.info("Training stopped.")
writer.close()
return "Training stopped"
if data is not None:
input_ids = data['input_ids'].to(DEVICE)
attention_mask = data['attention_masks'].to(DEVICE)
targets = data['targets'].to(DEVICE)
optimizer.zero_grad()
outputs = model(
input_ids=input_ids,
attention_mask=attention_mask
)
loss = criterion(outputs, targets)
_, pred = torch.max(outputs, dim=1)
correct += torch.sum(pred == targets)
losses.append(loss.item())
loss.backward()
nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
optimizer.step()
lr_scheduler.step()
else:
logger.warning("Warning: Empty data received from data loader. Skipping this iteration.")
train_acc = correct.double()/len(train_loader.dataset)
train_loss = np.mean(losses)
logger.info(f'Train Accuracy: {train_acc} Loss: {train_loss}')
writer.add_scalar("Accuracy/train", train_acc, cur_epoch)
writer.add_scalar("Loss/train", train_loss, cur_epoch)
except Exception as e:
logger.error("Training error: ", e)
release_training_vram()
is_training = False
# End train ----------------------------------------------------------------------------------
if stop_training_flag.is_set():
logger.info("Training stopped.")
......@@ -413,6 +419,7 @@ async def stop_training():
global stop_training_flag, train_thread, is_training
if not train_thread or not train_thread.is_alive():
is_training = False
raise HTTPException(status_code=400, detail="No training in progress")
stop_training_flag.set()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment