Commit 5b96e8f0 authored by Cao Duc Anh's avatar Cao Duc Anh

all1

parent 3bb240fb
minio_data
model
runs
phobert-base
newfile.log
postgres_data
*__pycache__*
\ No newline at end of file
@@ -187,18 +187,18 @@ Status 400 Bad Request
## System components
### MinIO
Manages the models and the training data.
### PostgreSQL
Stores data labeled by content moderators on the social network platform, continuously supplemented to evolve the model.
### Postgres DB
Stores data labeled by content moderators on the social network platform, continuously supplemented to improve the model.
### Adminer
Data administration tool for the PostgreSQL DB.
### DataManager
Server that automatically initializes the MinIO buckets and the database tables, and provides an API for adding data.
### NlpCore
Server that provides the text classification API.
Data administration tool for the Postgres DB.
### NLP Data
Automatically checks and initializes the MinIO buckets and the database tables at system startup. Provides an API for adding labeled data to the database to support model improvement.
### NLP Infer
NLP Infer is the component responsible for processing and classifying text content. It provides an endpoint so that external systems and other applications can use the text classification feature.
### Nginx
Performs load balancing: distributes user requests across the NlpCore instances.
### NlpTraining
Server that provides the model training API.
Performs load balancing: distributes user requests across the NLP Infer instances.
### NLP Training
NLP Training is the component that trains the AI model to improve its accuracy. It provides an endpoint to control (start or stop) the training process.
### Tensorboard
Interface for monitoring metrics during model training.
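
For illustration, a minimal client sketch against the classification endpoint that Nginx fronts. Only the `/text-classify` path is confirmed by the access logs in this commit; the host, payload shape, and response format are assumptions:

```python
# Hypothetical client for the text classification endpoint behind Nginx.
# Only the /text-classify path appears in this commit's access logs; the
# address and the {"text": ...} payload shape are assumptions.
import requests

resp = requests.post(
    "http://localhost/text-classify",                 # assumed Nginx address
    json={"text": "Sample content to be moderated"},  # assumed payload shape
    timeout=30,
)
resp.raise_for_status()
print(resp.json())
```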
......
@@ -28,3 +28,5 @@ training:
load_data_worker: 2
k_fold: 5
test_ratio: 0.1
log_file: /src/server.log
@@ -49,7 +49,7 @@ services:
ports:
- 8080:8080
datamanager:
nlpdata:
image: vn-text-moderation-data
restart: always
env_file:
......
@@ -35,3 +35,5 @@
10.3.2.100 - - [10/Jul/2024:10:23:35 +0000] "POST /text-classify HTTP/1.1" 200 2318 "-" "PostmanRuntime/7.37.3" "-"
10.3.2.100 - - [11/Jul/2024:06:33:03 +0000] "POST /text-classify HTTP/1.1" 200 2318 "-" "PostmanRuntime/7.37.3" "-"
10.3.2.100 - - [11/Jul/2024:06:41:13 +0000] "POST /text-classify HTTP/1.1" 200 2318 "-" "PostmanRuntime/7.37.3" "-"
10.3.2.100 - - [15/Jul/2024:02:13:07 +0000] "POST /text-classify HTTP/1.1" 200 2318 "-" "PostmanRuntime/7.37.3" "-"
10.3.2.100 - - [15/Jul/2024:02:13:22 +0000] "POST /text-classify HTTP/1.1" 200 2318 "-" "PostmanRuntime/7.37.3" "-"
@@ -7323,3 +7323,91 @@
2024/07/11 06:40:56 [notice] 1#1: start worker process 29
2024/07/11 06:40:56 [notice] 1#1: start worker process 30
2024/07/11 06:40:56 [notice] 1#1: start worker process 31
2024/07/11 09:35:05 [notice] 1#1: signal 3 (SIGQUIT) received, shutting down
2024/07/11 09:35:05 [notice] 20#20: gracefully shutting down
2024/07/11 09:35:05 [notice] 21#21: gracefully shutting down
2024/07/11 09:35:05 [notice] 22#22: gracefully shutting down
2024/07/11 09:35:05 [notice] 23#23: gracefully shutting down
2024/07/11 09:35:05 [notice] 25#25: gracefully shutting down
2024/07/11 09:35:05 [notice] 24#24: gracefully shutting down
2024/07/11 09:35:05 [notice] 20#20: exiting
2024/07/11 09:35:05 [notice] 21#21: exiting
2024/07/11 09:35:05 [notice] 22#22: exiting
2024/07/11 09:35:05 [notice] 23#23: exiting
2024/07/11 09:35:05 [notice] 25#25: exiting
2024/07/11 09:35:05 [notice] 24#24: exiting
2024/07/11 09:35:05 [notice] 20#20: exit
2024/07/11 09:35:05 [notice] 21#21: exit
2024/07/11 09:35:05 [notice] 22#22: exit
2024/07/11 09:35:05 [notice] 23#23: exit
2024/07/11 09:35:05 [notice] 25#25: exit
2024/07/11 09:35:05 [notice] 24#24: exit
2024/07/11 09:35:05 [notice] 28#28: gracefully shutting down
2024/07/11 09:35:05 [notice] 27#27: gracefully shutting down
2024/07/11 09:35:05 [notice] 28#28: exiting
2024/07/11 09:35:05 [notice] 31#31: gracefully shutting down
2024/07/11 09:35:05 [notice] 29#29: gracefully shutting down
2024/07/11 09:35:05 [notice] 26#26: gracefully shutting down
2024/07/11 09:35:05 [notice] 27#27: exiting
2024/07/11 09:35:05 [notice] 30#30: gracefully shutting down
2024/07/11 09:35:05 [notice] 31#31: exiting
2024/07/11 09:35:05 [notice] 28#28: exit
2024/07/11 09:35:05 [notice] 29#29: exiting
2024/07/11 09:35:05 [notice] 26#26: exiting
2024/07/11 09:35:05 [notice] 27#27: exit
2024/07/11 09:35:05 [notice] 31#31: exit
2024/07/11 09:35:05 [notice] 26#26: exit
2024/07/11 09:35:05 [notice] 29#29: exit
2024/07/11 09:35:05 [notice] 30#30: exiting
2024/07/11 09:35:05 [notice] 30#30: exit
2024/07/11 09:35:05 [notice] 1#1: signal 17 (SIGCHLD) received from 25
2024/07/11 09:35:05 [notice] 1#1: worker process 25 exited with code 0
2024/07/11 09:35:05 [notice] 1#1: signal 29 (SIGIO) received
2024/07/11 09:35:05 [notice] 1#1: signal 17 (SIGCHLD) received from 28
2024/07/11 09:35:05 [notice] 1#1: worker process 28 exited with code 0
2024/07/11 09:35:05 [notice] 1#1: signal 29 (SIGIO) received
2024/07/11 09:35:05 [notice] 1#1: signal 17 (SIGCHLD) received from 29
2024/07/11 09:35:05 [notice] 1#1: worker process 29 exited with code 0
2024/07/11 09:35:05 [notice] 1#1: worker process 26 exited with code 0
2024/07/11 09:35:05 [notice] 1#1: signal 29 (SIGIO) received
2024/07/11 09:35:05 [notice] 1#1: signal 17 (SIGCHLD) received from 26
2024/07/11 09:35:05 [notice] 1#1: signal 17 (SIGCHLD) received from 20
2024/07/11 09:35:05 [notice] 1#1: worker process 20 exited with code 0
2024/07/11 09:35:05 [notice] 1#1: signal 29 (SIGIO) received
2024/07/11 09:35:05 [notice] 1#1: signal 17 (SIGCHLD) received from 31
2024/07/11 09:35:05 [notice] 1#1: worker process 31 exited with code 0
2024/07/11 09:35:05 [notice] 1#1: signal 29 (SIGIO) received
2024/07/11 09:35:05 [notice] 1#1: signal 17 (SIGCHLD) received from 23
2024/07/11 09:35:05 [notice] 1#1: worker process 23 exited with code 0
2024/07/11 09:35:05 [notice] 1#1: signal 29 (SIGIO) received
2024/07/11 09:35:05 [notice] 1#1: signal 17 (SIGCHLD) received from 27
2024/07/11 09:35:05 [notice] 1#1: worker process 27 exited with code 0
2024/07/11 09:35:05 [notice] 1#1: signal 29 (SIGIO) received
2024/07/11 09:35:05 [notice] 1#1: signal 17 (SIGCHLD) received from 24
2024/07/11 09:35:05 [notice] 1#1: worker process 21 exited with code 0
2024/07/11 09:35:05 [notice] 1#1: worker process 24 exited with code 0
2024/07/11 09:35:05 [notice] 1#1: signal 29 (SIGIO) received
2024/07/11 09:35:05 [notice] 1#1: signal 17 (SIGCHLD) received from 21
2024/07/11 09:35:05 [notice] 1#1: signal 17 (SIGCHLD) received from 30
2024/07/11 09:35:05 [notice] 1#1: worker process 30 exited with code 0
2024/07/11 09:35:05 [notice] 1#1: signal 17 (SIGCHLD) received from 22
2024/07/11 09:35:05 [notice] 1#1: worker process 22 exited with code 0
2024/07/11 09:35:05 [notice] 1#1: exit
2024/07/11 09:36:00 [notice] 1#1: using the "epoll" event method
2024/07/11 09:36:00 [notice] 1#1: nginx/1.25.0
2024/07/11 09:36:00 [notice] 1#1: built by gcc 10.2.1 20210110 (Debian 10.2.1-6)
2024/07/11 09:36:00 [notice] 1#1: OS: Linux 6.5.0-17-generic
2024/07/11 09:36:00 [notice] 1#1: getrlimit(RLIMIT_NOFILE): 1048576:1048576
2024/07/11 09:36:00 [notice] 1#1: start worker processes
2024/07/11 09:36:00 [notice] 1#1: start worker process 20
2024/07/11 09:36:00 [notice] 1#1: start worker process 21
2024/07/11 09:36:00 [notice] 1#1: start worker process 22
2024/07/11 09:36:00 [notice] 1#1: start worker process 23
2024/07/11 09:36:00 [notice] 1#1: start worker process 24
2024/07/11 09:36:00 [notice] 1#1: start worker process 25
2024/07/11 09:36:00 [notice] 1#1: start worker process 26
2024/07/11 09:36:00 [notice] 1#1: start worker process 27
2024/07/11 09:36:00 [notice] 1#1: start worker process 28
2024/07/11 09:36:00 [notice] 1#1: start worker process 29
2024/07/11 09:36:00 [notice] 1#1: start worker process 30
2024/07/11 09:36:00 [notice] 1#1: start worker process 31
from utils import get_data_from_yaml
config = get_data_from_yaml("/src/config.yaml")
DEVICE = config.get("device")
CLASSES = config.get("classes")
MINIO_SERVER = config.get("minio")["server"]
MINIO_DATA_LABELED = config.get("minio")["data_labeled"]
MINIO_MODEL_TRAINED = config.get("minio")["model_trained"]
VNCORENLP_DIR = config.get("vncorenlp")["save_dir"]
PHOBERTBASE_DIR = config.get("phobert_base")["save_dir"]
MAX_TOKEN_LENGTH = config.get("phobert_base")["max_token_length"]
MODEL_CHECKPOINT = config.get("model_checkpoint")
CHUNK_SIZE = config.get("chunk_size")
INFER_LENGTH = config.get("limit_infer_length")
LOG_FILE = config.get("log_file")
\ No newline at end of file
import logging
import logging.handlers
import sys
from constants import LOG_FILE
# Define a function to configure the logger
def configure_logger(log_file=LOG_FILE):
    # Create the logger
    logger = logging.getLogger('NLP_Infer')
    logger.setLevel(logging.DEBUG)
    # Define a handler that writes logs to a file
    file_handler = logging.FileHandler(log_file)
    file_handler.setLevel(logging.DEBUG)
    # Define a handler that prints logs to stderr
    stream_handler = logging.StreamHandler(sys.stderr)
    stream_handler.setLevel(logging.DEBUG)
    # Define the log format
    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    file_handler.setFormatter(formatter)
    stream_handler.setFormatter(formatter)
    # Add the handlers to the logger
    logger.addHandler(file_handler)
    logger.addHandler(stream_handler)
    return logger
# Use the function to configure the logger
logger = configure_logger()
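This logger module is duplicated in each service (`NLP_Infer`, `NLP_Data`, `NLP_Training`) with only the logger name changed. Usage elsewhere in this commit is just the module-level import, for example:

```python
# Minimal usage sketch of the shared logger; matches the
# `from logger import logger` imports used throughout this commit.
from logger import logger

logger.info("service started")
logger.error("something went wrong")
```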
import os
from minio import Minio
from minio.error import S3Error
from constants import MINIO_SERVER, MINIO_MODEL_TRAINED, MODEL_CHECKPOINT
from logger import logger
minio_client = Minio(
endpoint=MINIO_SERVER,
access_key=os.getenv("MINIO_ROOT_USER"),
secret_key=os.getenv("MINIO_ROOT_PASSWORD"),
secure=False
)
def download_latest_model():
    # List objects in the model bucket and pick the most recently
    # modified checkpoint whose name contains "best".
    objects = minio_client.list_objects(MINIO_MODEL_TRAINED)
    latest_obj = None
    latest_time = None
    for obj in objects:
        if "best" in obj.object_name:
            if latest_time is None or obj.last_modified > latest_time:
                latest_time = obj.last_modified
                latest_obj = obj
    if latest_obj is not None:
        try:
            minio_client.fget_object(MINIO_MODEL_TRAINED, latest_obj.object_name, MODEL_CHECKPOINT)
        except S3Error as exc:
            logger.error(f"Error occurred: {exc}")
            raise  # propagate so a failed download is not reported as success
        return latest_obj.object_name
    else:
        raise Exception("No *best* models found in the bucket")
\ No newline at end of file
@@ -3,57 +3,14 @@ from pydantic import BaseModel
import torch
from transformers import AutoTokenizer
from contextlib import asynccontextmanager
from minio import Minio
from minio.error import S3Error
import os
from typing import List
from bert_model import BERTClassifier
from utils import get_data_from_yaml, split_chunk
from preprocess import preprocess_text
config = get_data_from_yaml("/src/config.yaml")
DEVICE = config.get("device")
CLASSES = config.get("classes")
MINIO_SERVER = config.get("minio")["server"]
MINIO_DATA_LABELED = config.get("minio")["data_labeled"]
MINIO_MODEL_TRAINED = config.get("minio")["model_trained"]
VNCORENLP_DIR = config.get("vncorenlp")["save_dir"]
PHOBERTBASE_DIR = config.get("phobert_base")["save_dir"]
MAX_TOKEN_LENGTH = config.get("phobert_base")["max_token_length"]
MODEL_CHECKPOINT = config.get("model_checkpoint")
CHUNK_SIZE = config.get("chunk_size")
INFER_LENGTH = config.get("limit_infer_length")
minio_client = Minio(
endpoint=MINIO_SERVER,
access_key=os.getenv("MINIO_ROOT_USER"),
secret_key=os.getenv("MINIO_ROOT_PASSWORD"),
secure=False
)
def download_latest_model():
# List objects in the bucket
objects = minio_client.list_objects(MINIO_MODEL_TRAINED)
latest_obj = None
latest_time = None
for obj in objects:
if "best" in obj.object_name:
if latest_time is None or obj.last_modified > latest_time:
latest_time = obj.last_modified
latest_obj = obj
if latest_obj is not None:
try:
minio_client.fget_object(MINIO_MODEL_TRAINED, latest_obj.object_name, MODEL_CHECKPOINT)
except S3Error as exc:
print(f"Error occurred: {exc}")
return latest_obj.object_name
else:
raise Exception("No *best* models found in the bucket")
from minio_client import download_latest_model
from constants import PHOBERTBASE_DIR, MAX_TOKEN_LENGTH, DEVICE, CLASSES, MODEL_CHECKPOINT, INFER_LENGTH, CHUNK_SIZE
from logger import logger
tokenizer = AutoTokenizer.from_pretrained(PHOBERTBASE_DIR, local_files_only=True, use_fast=False)
@@ -88,9 +45,9 @@ async def lifespan(app: FastAPI):
update_model = download_latest_model()
model.load_state_dict(torch.load(MODEL_CHECKPOINT))
model.eval()
print(f"Model updated: {update_model}")
logger.info(f"Model updated: {update_model}")
except Exception as e:
print(f"An error occurred: {e}")
logger.error(f"An error occurred: {e}")
yield
torch.cuda.empty_cache()
torch.cuda.ipc_collect()
......
@@ -5,4 +5,5 @@ DB_SERVER = config.get("sqldb")["server"]
DB_TABLENAME = config.get("sqldb")["table"]
MINIO_SERVER = config.get("minio")["server"]
MINIO_DATA_LABELED = config.get("minio")["data_labeled"]
MINIO_MODEL_TRAINED = config.get("minio")["model_trained"]
\ No newline at end of file
MINIO_MODEL_TRAINED = config.get("minio")["model_trained"]
LOG_FILE = config.get("log_file")
\ No newline at end of file
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
from sqlalchemy.orm import sessionmaker
import os
from config import DB_SERVER
from constants import DB_SERVER
DATABASE_URL = f'postgresql+asyncpg://{os.getenv("POSTGRES_USER")}:{os.getenv("POSTGRES_PASSWORD")}@{DB_SERVER}/{os.getenv("POSTGRES_DB")}'
......
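The engine and session wiring in `database.py` is collapsed in this diff. A sketch of what the imports above and the `engine, get_db, SessionLocal` imports in `main.py` imply; this is an assumption, not the actual collapsed code:

```python
# Assumed continuation of database.py; the real body is collapsed in this
# diff, so the options here are inferred from the imports elsewhere.
engine = create_async_engine(DATABASE_URL)
SessionLocal = sessionmaker(engine, class_=AsyncSession, expire_on_commit=False)

async def get_db():
    # FastAPI dependency that yields one async session per request
    async with SessionLocal() as session:
        yield session
```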
import logging
import logging.handlers
import sys
from constants import LOG_FILE
# Define a function to configure the logger
def configure_logger(log_file=LOG_FILE):
    # Create the logger
    logger = logging.getLogger('NLP_Data')
    logger.setLevel(logging.DEBUG)
    # Define a handler that writes logs to a file
    file_handler = logging.FileHandler(log_file)
    file_handler.setLevel(logging.DEBUG)
    # Define a handler that prints logs to stderr
    stream_handler = logging.StreamHandler(sys.stderr)
    stream_handler.setLevel(logging.DEBUG)
    # Define the log format
    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    file_handler.setFormatter(formatter)
    stream_handler.setFormatter(formatter)
    # Add the handlers to the logger
    logger.addHandler(file_handler)
    logger.addHandler(stream_handler)
    return logger
# Use the function to configure the logger
logger = configure_logger()
from minio import Minio
import os
from constants import MINIO_SERVER, MINIO_DATA_LABELED, MINIO_MODEL_TRAINED
from logger import logger
minio_client = Minio(
endpoint=MINIO_SERVER,
access_key=os.getenv("MINIO_ROOT_USER"),
secret_key=os.getenv("MINIO_ROOT_PASSWORD"),
secure=False
)
bucket_names = [MINIO_DATA_LABELED, MINIO_MODEL_TRAINED]
def check_bucket(minio_client=minio_client, bucket_names=bucket_names):
    # Check whether each bucket exists and create it if it does not.
    # Defaults bind the module-level client and bucket list so the
    # no-argument call in main.py's lifespan works.
    for bucket_name in bucket_names:
        found = minio_client.bucket_exists(bucket_name)
        if not found:
            # Create bucket
            minio_client.make_bucket(bucket_name)
            logger.info(f"Bucket '{bucket_name}' created successfully.")
        else:
            logger.info(f"Bucket '{bucket_name}' already exists.")
\ No newline at end of file
from sqlalchemy import Column, Integer, String
from sqlalchemy.ext.declarative import declarative_base
from config import DB_TABLENAME
from constants import DB_TABLENAME
Base = declarative_base()
......
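The `LabeledData` model body is collapsed in this diff. A hypothetical sketch consistent with the imports above (`Column`, `Integer`, `String`) and the table-name constant; the actual columns are not shown and these field names are assumptions:

```python
# Hypothetical reconstruction of the collapsed LabeledData model; only the
# class name, base class, and table-name constant are confirmed by this diff.
class LabeledData(Base):
    __tablename__ = DB_TABLENAME
    id = Column(Integer, primary_key=True, index=True)
    text = Column(String)   # assumed: the moderated text content
    label = Column(String)  # assumed: the moderator-assigned label
```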
from fastapi import FastAPI, Depends, HTTPException
from contextlib import asynccontextmanager
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from sqlalchemy import MetaData, Table, inspect
from minio import Minio
from minio.error import S3Error
import os
from models import Base, LabeledData
from schemas import LabeledDataCreate
from sqlalchemy import create_engine, MetaData, Table
from sqlalchemy.orm import sessionmaker
import pandas as pd
from io import StringIO, BytesIO
from database import engine, get_db, SessionLocal
from config import MINIO_SERVER, MINIO_MODEL_TRAINED, MINIO_DATA_LABELED
from utils import check_bucket
minio_client = Minio(
endpoint=MINIO_SERVER,
access_key=os.getenv("MINIO_ROOT_USER"),
secret_key=os.getenv("MINIO_ROOT_PASSWORD"),
secure=False
)
bucket_names = [MINIO_DATA_LABELED, MINIO_MODEL_TRAINED]
from constants import MINIO_SERVER, MINIO_MODEL_TRAINED, MINIO_DATA_LABELED
from minio_client import check_bucket
@asynccontextmanager
async def lifespan(app: FastAPI):
check_bucket(minio_client=minio_client, bucket_names=bucket_names)
check_bucket()
async with engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
yield
......
@@ -8,14 +8,3 @@ def get_data_from_yaml(filename):
raise IOError(f"Error opening file: {filename}")
return data
def check_bucket(minio_client, bucket_names):
# Check if bucket exists
for bucket_name in bucket_names:
found = minio_client.bucket_exists(bucket_name)
if not found:
# Create bucket
minio_client.make_bucket(bucket_name)
print(f"Bucket '{bucket_name}' created successfully.")
else:
print(f"Bucket '{bucket_name}' already exists.")
\ No newline at end of file
from utils import get_data_from_yaml
config = get_data_from_yaml("/src/config.yaml")
DEVICE = config.get("device")
CLASSES = config.get("classes")
MINIO_SERVER = config.get("minio")["server"]
MINIO_DATA_LABELED = config.get("minio")["data_labeled"]
MINIO_MODEL_TRAINED = config.get("minio")["model_trained"]
VNCORENLP_DIR = config.get("vncorenlp")["save_dir"]
PHOBERTBASE_DIR = config.get("phobert_base")["save_dir"]
MAX_TOKEN_LENGTH = config.get("phobert_base")["max_token_length"]
MODEL_CHECKPOINT = config.get("model_checkpoint")
CHUNK_SIZE = config.get("chunk_size")
EPOCH = config.get("training")["epoch"]
K_FOLD = config.get("training")["k_fold"]
TEST_RATIO = config.get("training")["test_ratio"]
BATCH_SIZE = config.get("training")["batch_size"]
LOAD_DATA_WORKER = config.get("training")["load_data_worker"]
LOG_FILE = config.get("log_file")
\ No newline at end of file
import logging
import logging.handlers
import sys
from constants import LOG_FILE
# Define a function to configure the logger
def configure_logger(log_file=LOG_FILE):
    # Create the logger
    logger = logging.getLogger('NLP_Training')
    logger.setLevel(logging.DEBUG)
    # Define a handler that writes logs to a file
    file_handler = logging.FileHandler(log_file)
    file_handler.setLevel(logging.DEBUG)
    # Define a handler that prints logs to stderr
    stream_handler = logging.StreamHandler(sys.stderr)
    stream_handler.setLevel(logging.DEBUG)
    # Define the log format
    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    file_handler.setFormatter(formatter)
    stream_handler.setFormatter(formatter)
    # Add the handlers to the logger
    logger.addHandler(file_handler)
    logger.addHandler(stream_handler)
    return logger
# Use the function to configure the logger
logger = configure_logger()
import os
from minio import Minio
from minio.error import S3Error
from constants import MINIO_SERVER, MINIO_DATA_LABELED, MINIO_MODEL_TRAINED, MODEL_CHECKPOINT
from logger import logger
minio_client = Minio(
endpoint=MINIO_SERVER,
access_key=os.getenv("MINIO_ROOT_USER"),
secret_key=os.getenv("MINIO_ROOT_PASSWORD"),
secure=False
)
bucket_names = [MINIO_DATA_LABELED, MINIO_MODEL_TRAINED]
def download_latest_model():
    # List objects in the model bucket and pick the most recently modified
    # checkpoint whose name contains "last" (the resumable checkpoint;
    # the inference service pulls the "best" checkpoint instead).
    objects = minio_client.list_objects(MINIO_MODEL_TRAINED)
    latest_obj = None
    latest_time = None
    for obj in objects:
        if "last" in obj.object_name:
            if latest_time is None or obj.last_modified > latest_time:
                latest_time = obj.last_modified
                latest_obj = obj
    if latest_obj is not None:
        try:
            minio_client.fget_object(MINIO_MODEL_TRAINED, latest_obj.object_name, MODEL_CHECKPOINT)
        except S3Error as exc:
            logger.error(f"Error occurred: {exc}")
            raise  # propagate so a failed download is not reported as success
        return latest_obj.object_name
    else:
        raise Exception("No *last* models found in the bucket")
\ No newline at end of file
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import datetime
import threading
import torch
import numpy as np
import pandas as pd
from io import BytesIO
from fastapi import FastAPI, HTTPException
from contextlib import asynccontextmanager
from pydantic import BaseModel
from transformers import AutoTokenizer
from transformers import get_linear_schedule_with_warmup, AutoTokenizer
from sklearn.model_selection import train_test_split
import pandas as pd
import os
from gensim.utils import simple_preprocess
from sklearn.model_selection import StratifiedKFold
from contextlib import asynccontextmanager
import torch
import torch.nn as nn
from torch.optim import AdamW
from torch.utils.data import Dataset, DataLoader
from torch.utils.tensorboard import SummaryWriter
from transformers import get_linear_schedule_with_warmup, AutoTokenizer, AutoModel
from io import BytesIO
import pandas as pd
import sys
import logging
import datetime
from typing import List
from minio import Minio
from typing import List
from minio.error import S3Error
from bert_model import BERTClassifier
from utils import get_data_from_yaml, seed_everything
from utils import seed_everything
from preprocess import preprocess_row
from minio_client import minio_client, download_latest_model
from constants import PHOBERTBASE_DIR, CLASSES, DEVICE, MODEL_CHECKPOINT, EPOCH, \
MAX_TOKEN_LENGTH, BATCH_SIZE, LOAD_DATA_WORKER, MINIO_MODEL_TRAINED, \
MINIO_DATA_LABELED, TEST_RATIO, K_FOLD
from logger import logger
# Global variables to manage the training thread and the stop flag
train_thread = None
is_training = False
stop_training_flag = threading.Event()
logger = logging.getLogger()
logger.setLevel(logging.INFO)
handler = logging.StreamHandler(sys.stderr)
formatter = logging.Formatter('%(asctime)s %(levelname)s: %(message)s')
handler.setFormatter(formatter)
logger.addHandler(handler)
config = get_data_from_yaml("/src/config.yaml")
DEVICE = config.get("device")
CLASSES = config.get("classes")
MINIO_SERVER = config.get("minio")["server"]
MINIO_DATA_LABELED = config.get("minio")["data_labeled"]
MINIO_MODEL_TRAINED = config.get("minio")["model_trained"]
VNCORENLP_DIR = config.get("vncorenlp")["save_dir"]
PHOBERTBASE_DIR = config.get("phobert_base")["save_dir"]
MAX_TOKEN_LENGTH = config.get("phobert_base")["max_token_length"]
MODEL_CHECKPOINT = config.get("model_checkpoint")
CHUNK_SIZE = config.get("chunk_size")
EPOCH = config.get("training")["epoch"]
K_FOLD = config.get("training")["k_fold"]
TEST_RATIO = config.get("training")["test_ratio"]
BATCH_SIZE = config.get("training")["batch_size"]
LOAD_DATA_WORKER = config.get("training")["load_data_worker"]
minio_client = Minio(
endpoint=MINIO_SERVER,
access_key=os.getenv("MINIO_ROOT_USER"),
secret_key=os.getenv("MINIO_ROOT_PASSWORD"),
secure=False
)
bucket_names = [MINIO_DATA_LABELED, MINIO_MODEL_TRAINED]
def download_latest_model():
# List objects in the bucket
objects = minio_client.list_objects(MINIO_MODEL_TRAINED)
latest_obj = None
latest_time = None
for obj in objects:
if "last" in obj.object_name:
if latest_time is None or obj.last_modified > latest_time:
latest_time = obj.last_modified
latest_obj = obj
if latest_obj is not None:
try:
minio_client.fget_object(MINIO_MODEL_TRAINED, latest_obj.object_name, MODEL_CHECKPOINT)
except S3Error as exc:
print(f"Error occurred: {exc}")
return latest_obj.object_name
else:
raise Exception("No *last* models found in the bucket")
def release_training_vram():
    # Release the GPU memory held by the training process
    torch.cuda.empty_cache()
@@ -214,7 +155,7 @@ def train_model(skf, train_df, tokenizer, use_pretrain=False):
for fold in range(skf.n_splits):
if stop_training_flag.is_set():
print("Training stopped.")
logger.info("Training stopped.")
writer.close()
return "Training stopped"
logger.info(f'-----------Fold: {fold+1} ------------------')
@@ -233,7 +174,7 @@ def train_model(skf, train_df, tokenizer, use_pretrain=False):
best_acc = 0
for e in range(epoch_each_fold):
if stop_training_flag.is_set():
print("Training stopped.")
logger.info("Training stopped.")
writer.close()
return "Training stopped"
logger.info(f'Fold {fold+1} Epoch {cur_epoch+1}/{EPOCH}')
@@ -245,7 +186,7 @@ def train_model(skf, train_df, tokenizer, use_pretrain=False):
for data in train_loader:
if stop_training_flag.is_set():
print("Training stopped.")
logger.info("Training stopped.")
writer.close()
return "Training stopped"
if data is not None:
@@ -278,7 +219,7 @@ def train_model(skf, train_df, tokenizer, use_pretrain=False):
writer.add_scalar("Loss/train", train_loss, cur_epoch)
# End train ----------------------------------------------------------------------------------
if stop_training_flag.is_set():
print("Training stopped.")
logger.info("Training stopped.")
writer.close()
return "Training stopped"
# Valid --------------------------------------------------------------------------------------
@@ -289,7 +230,7 @@ def train_model(skf, train_df, tokenizer, use_pretrain=False):
with torch.no_grad():
for data in valid_loader:
if stop_training_flag.is_set():
print("Training stopped.")
logger.info("Training stopped.")
writer.close()
return "Training stopped"
input_ids = data['input_ids'].to(DEVICE)
@@ -315,7 +256,7 @@ def train_model(skf, train_df, tokenizer, use_pretrain=False):
# End valid ----------------------------------------------------------------------------------
if stop_training_flag.is_set():
print("Training stopped.")
logger.info("Training stopped.")
writer.close()
return "Training stopped"
@@ -333,11 +274,11 @@ def train_model(skf, train_df, tokenizer, use_pretrain=False):
'/src/phobert_best.pth' # Path to the file you want to upload
)
except S3Error as exc:
print(f"Error occurred: {exc}")
print(f"File uploaded successfully. {checkpoint_best}")
logger.error(f"Error occurred: {exc}")
logger.info(f"File uploaded successfully. {checkpoint_best}")
if stop_training_flag.is_set():
print("Training stopped.")
logger.info("Training stopped.")
writer.close()
return "Training stopped"
@@ -350,12 +291,12 @@ def train_model(skf, train_df, tokenizer, use_pretrain=False):
'/src/phobert_last.pth' # Path to the file you want to upload
)
except S3Error as exc:
print(f"Error occurred: {exc}")
print(f"File uploaded successfully. {checkpoint_last}")
logger.error(f"Error occurred: {exc}")
logger.info(f"File uploaded successfully. {checkpoint_last}")
cur_epoch = cur_epoch + 1
print("Training completed.")
logger.info("Training completed.")
writer.close()
release_training_vram()
is_training = False
@@ -372,7 +313,7 @@ class TrainingResponse(BaseModel):
@asynccontextmanager
async def lifespan(app: FastAPI):
print(f'CUDA available: {torch.cuda.is_available()}')
logger.info(f'CUDA available: {torch.cuda.is_available()}')
yield
torch.cuda.empty_cache()
torch.cuda.ipc_collect()
@@ -389,7 +330,7 @@ async def start_training(request: TrainingRequest):
use_pretrain = False
if request.pretrain == "":
print("Training from zero")
logger.info("Training from zero")
elif request.pretrain == "latest":
try:
latest_model = download_latest_model()
@@ -397,7 +338,7 @@ async def start_training(request: TrainingRequest):
is_training = False
raise HTTPException(status_code=500, detail=str(exc))
use_pretrain = True
print(f"Training from latest: {latest_model}")
logger.info(f"Training from latest: {latest_model}")
else:
try:
minio_client.fget_object(MINIO_MODEL_TRAINED, request.pretrain, MODEL_CHECKPOINT)
@@ -405,7 +346,7 @@ async def start_training(request: TrainingRequest):
is_training = False
raise HTTPException(status_code=500, detail=str(exc))
use_pretrain = True
print(f"Training from : {request.pretrain}")
logger.info(f"Training from : {request.pretrain}")
objects = minio_client.list_objects(MINIO_DATA_LABELED, recursive=True)
......
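For reference, a hypothetical client call to the training control endpoint sketched above. The route path and port are not shown in this diff and are assumptions; the `pretrain` field follows the three cases handled in `start_training`:

```python
# Hypothetical call to the training endpoint; the URL path and port are
# assumptions (not shown in this diff). `pretrain` accepts "" (train from
# scratch), "latest" (resume the newest *last* checkpoint), or a specific
# object name in the model bucket, per the branches in start_training().
import requests

resp = requests.post(
    "http://localhost:8000/start-training",  # assumed route and port
    json={"pretrain": "latest"},
    timeout=30,
)
print(resp.status_code, resp.json())
```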
@@ -2,6 +2,8 @@ import numpy as np
import torch
import yaml
from logger import logger
def get_data_from_yaml(filename):
try:
with open(filename, 'r') as f:
@@ -16,8 +18,8 @@ def seed_everything(seed_value):
torch.manual_seed(seed_value)
if torch.cuda.is_available():
print("Torch available")
logger.info("Torch available. Start seed_everything.")
torch.cuda.manual_seed(seed_value)
torch.cuda.manual_seed_all(seed_value)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.benchmark = True