Commit 5b96e8f0 authored by Cao Duc Anh

all1

parent 3bb240fb
minio_data
model
runs
phobert-base
newfile.log
postgres_data
*__pycache__*
\ No newline at end of file
@@ -187,18 +187,18 @@ Status 400 Bad Request
## System components
### MinIO
Manages models and training data.
-### PostgreSQL
+### Postgres DB
-Stores data labeled by content moderators on the social network, added to continuously so the model can evolve.
+Stores data labeled by content moderators on the social network, added to continuously so the model can be improved.
### Adminer
-Data administration tool for the PostgreSQL DB.
+Data administration tool for the Postgres DB.
-### DataManager
+### NLP Data
-Server that automatically initializes the MinIO buckets and database tables and provides an API for adding data.
+Automatically checks and initializes the MinIO buckets and database tables at system startup. Provides an API for adding labeled data to the database to support model improvement.
-### NlpCore
+### NLP Infer
-Server providing a text-classification API.
+NLP Infer is the component responsible for processing and classifying text content. It provides endpoints so that external systems and other applications can use text classification.
### Nginx
-Load balancing: distributes user requests across the NlpCore instances.
+Load balancing: distributes user requests across the NLP Core instances.
-### NlpTraining
+### NLP Training
-Server providing a model-training API.
+NLP Training is the component that trains the AI model to improve its accuracy. It provides endpoints to control (start or stop) the training process.
### Tensorboard
Interface for monitoring metrics during model training.
...
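For illustration, a minimal sketch of calling the classification service through Nginx from Python. The /text-classify path appears in the access logs later in this commit; the host, port, and the request/response field names are assumptions, not taken from the code:

import requests

# Host/port and the "text" payload field are assumptions for illustration.
resp = requests.post(
    "http://localhost/text-classify",
    json={"text": "Văn bản cần phân loại"},
)
resp.raise_for_status()
print(resp.json())  # e.g. predicted class and confidence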
@@ -28,3 +28,5 @@ training:
  load_data_worker: 2
  k_fold: 5
  test_ratio: 0.1
+log_file: /src/server.log
@@ -49,7 +49,7 @@ services:
    ports:
      - 8080:8080
-  datamanager:
+  nlpdata:
    image: vn-text-moderation-data
    restart: always
    env_file:
...
@@ -35,3 +35,5 @@
10.3.2.100 - - [10/Jul/2024:10:23:35 +0000] "POST /text-classify HTTP/1.1" 200 2318 "-" "PostmanRuntime/7.37.3" "-"
10.3.2.100 - - [11/Jul/2024:06:33:03 +0000] "POST /text-classify HTTP/1.1" 200 2318 "-" "PostmanRuntime/7.37.3" "-"
10.3.2.100 - - [11/Jul/2024:06:41:13 +0000] "POST /text-classify HTTP/1.1" 200 2318 "-" "PostmanRuntime/7.37.3" "-"
10.3.2.100 - - [15/Jul/2024:02:13:07 +0000] "POST /text-classify HTTP/1.1" 200 2318 "-" "PostmanRuntime/7.37.3" "-"
10.3.2.100 - - [15/Jul/2024:02:13:22 +0000] "POST /text-classify HTTP/1.1" 200 2318 "-" "PostmanRuntime/7.37.3" "-"
@@ -7323,3 +7323,91 @@
2024/07/11 06:40:56 [notice] 1#1: start worker process 29
2024/07/11 06:40:56 [notice] 1#1: start worker process 30
2024/07/11 06:40:56 [notice] 1#1: start worker process 31
2024/07/11 09:35:05 [notice] 1#1: signal 3 (SIGQUIT) received, shutting down
2024/07/11 09:35:05 [notice] 20#20: gracefully shutting down
2024/07/11 09:35:05 [notice] 21#21: gracefully shutting down
2024/07/11 09:35:05 [notice] 22#22: gracefully shutting down
2024/07/11 09:35:05 [notice] 23#23: gracefully shutting down
2024/07/11 09:35:05 [notice] 25#25: gracefully shutting down
2024/07/11 09:35:05 [notice] 24#24: gracefully shutting down
2024/07/11 09:35:05 [notice] 20#20: exiting
2024/07/11 09:35:05 [notice] 21#21: exiting
2024/07/11 09:35:05 [notice] 22#22: exiting
2024/07/11 09:35:05 [notice] 23#23: exiting
2024/07/11 09:35:05 [notice] 25#25: exiting
2024/07/11 09:35:05 [notice] 24#24: exiting
2024/07/11 09:35:05 [notice] 20#20: exit
2024/07/11 09:35:05 [notice] 21#21: exit
2024/07/11 09:35:05 [notice] 22#22: exit
2024/07/11 09:35:05 [notice] 23#23: exit
2024/07/11 09:35:05 [notice] 25#25: exit
2024/07/11 09:35:05 [notice] 24#24: exit
2024/07/11 09:35:05 [notice] 28#28: gracefully shutting down
2024/07/11 09:35:05 [notice] 27#27: gracefully shutting down
2024/07/11 09:35:05 [notice] 28#28: exiting
2024/07/11 09:35:05 [notice] 31#31: gracefully shutting down
2024/07/11 09:35:05 [notice] 29#29: gracefully shutting down
2024/07/11 09:35:05 [notice] 26#26: gracefully shutting down
2024/07/11 09:35:05 [notice] 27#27: exiting
2024/07/11 09:35:05 [notice] 30#30: gracefully shutting down
2024/07/11 09:35:05 [notice] 31#31: exiting
2024/07/11 09:35:05 [notice] 28#28: exit
2024/07/11 09:35:05 [notice] 29#29: exiting
2024/07/11 09:35:05 [notice] 26#26: exiting
2024/07/11 09:35:05 [notice] 27#27: exit
2024/07/11 09:35:05 [notice] 31#31: exit
2024/07/11 09:35:05 [notice] 26#26: exit
2024/07/11 09:35:05 [notice] 29#29: exit
2024/07/11 09:35:05 [notice] 30#30: exiting
2024/07/11 09:35:05 [notice] 30#30: exit
2024/07/11 09:35:05 [notice] 1#1: signal 17 (SIGCHLD) received from 25
2024/07/11 09:35:05 [notice] 1#1: worker process 25 exited with code 0
2024/07/11 09:35:05 [notice] 1#1: signal 29 (SIGIO) received
2024/07/11 09:35:05 [notice] 1#1: signal 17 (SIGCHLD) received from 28
2024/07/11 09:35:05 [notice] 1#1: worker process 28 exited with code 0
2024/07/11 09:35:05 [notice] 1#1: signal 29 (SIGIO) received
2024/07/11 09:35:05 [notice] 1#1: signal 17 (SIGCHLD) received from 29
2024/07/11 09:35:05 [notice] 1#1: worker process 29 exited with code 0
2024/07/11 09:35:05 [notice] 1#1: worker process 26 exited with code 0
2024/07/11 09:35:05 [notice] 1#1: signal 29 (SIGIO) received
2024/07/11 09:35:05 [notice] 1#1: signal 17 (SIGCHLD) received from 26
2024/07/11 09:35:05 [notice] 1#1: signal 17 (SIGCHLD) received from 20
2024/07/11 09:35:05 [notice] 1#1: worker process 20 exited with code 0
2024/07/11 09:35:05 [notice] 1#1: signal 29 (SIGIO) received
2024/07/11 09:35:05 [notice] 1#1: signal 17 (SIGCHLD) received from 31
2024/07/11 09:35:05 [notice] 1#1: worker process 31 exited with code 0
2024/07/11 09:35:05 [notice] 1#1: signal 29 (SIGIO) received
2024/07/11 09:35:05 [notice] 1#1: signal 17 (SIGCHLD) received from 23
2024/07/11 09:35:05 [notice] 1#1: worker process 23 exited with code 0
2024/07/11 09:35:05 [notice] 1#1: signal 29 (SIGIO) received
2024/07/11 09:35:05 [notice] 1#1: signal 17 (SIGCHLD) received from 27
2024/07/11 09:35:05 [notice] 1#1: worker process 27 exited with code 0
2024/07/11 09:35:05 [notice] 1#1: signal 29 (SIGIO) received
2024/07/11 09:35:05 [notice] 1#1: signal 17 (SIGCHLD) received from 24
2024/07/11 09:35:05 [notice] 1#1: worker process 21 exited with code 0
2024/07/11 09:35:05 [notice] 1#1: worker process 24 exited with code 0
2024/07/11 09:35:05 [notice] 1#1: signal 29 (SIGIO) received
2024/07/11 09:35:05 [notice] 1#1: signal 17 (SIGCHLD) received from 21
2024/07/11 09:35:05 [notice] 1#1: signal 17 (SIGCHLD) received from 30
2024/07/11 09:35:05 [notice] 1#1: worker process 30 exited with code 0
2024/07/11 09:35:05 [notice] 1#1: signal 17 (SIGCHLD) received from 22
2024/07/11 09:35:05 [notice] 1#1: worker process 22 exited with code 0
2024/07/11 09:35:05 [notice] 1#1: exit
2024/07/11 09:36:00 [notice] 1#1: using the "epoll" event method
2024/07/11 09:36:00 [notice] 1#1: nginx/1.25.0
2024/07/11 09:36:00 [notice] 1#1: built by gcc 10.2.1 20210110 (Debian 10.2.1-6)
2024/07/11 09:36:00 [notice] 1#1: OS: Linux 6.5.0-17-generic
2024/07/11 09:36:00 [notice] 1#1: getrlimit(RLIMIT_NOFILE): 1048576:1048576
2024/07/11 09:36:00 [notice] 1#1: start worker processes
2024/07/11 09:36:00 [notice] 1#1: start worker process 20
2024/07/11 09:36:00 [notice] 1#1: start worker process 21
2024/07/11 09:36:00 [notice] 1#1: start worker process 22
2024/07/11 09:36:00 [notice] 1#1: start worker process 23
2024/07/11 09:36:00 [notice] 1#1: start worker process 24
2024/07/11 09:36:00 [notice] 1#1: start worker process 25
2024/07/11 09:36:00 [notice] 1#1: start worker process 26
2024/07/11 09:36:00 [notice] 1#1: start worker process 27
2024/07/11 09:36:00 [notice] 1#1: start worker process 28
2024/07/11 09:36:00 [notice] 1#1: start worker process 29
2024/07/11 09:36:00 [notice] 1#1: start worker process 30
2024/07/11 09:36:00 [notice] 1#1: start worker process 31
from utils import get_data_from_yaml
config = get_data_from_yaml("/src/config.yaml")
DEVICE = config.get("device")
CLASSES = config.get("classes")
MINIO_SERVER = config.get("minio")["server"]
MINIO_DATA_LABELED = config.get("minio")["data_labeled"]
MINIO_MODEL_TRAINED = config.get("minio")["model_trained"]
VNCORENLP_DIR = config.get("vncorenlp")["save_dir"]
PHOBERTBASE_DIR = config.get("phobert_base")["save_dir"]
MAX_TOKEN_LENGTH = config.get("phobert_base")["max_token_length"]
MODEL_CHECKPOINT = config.get("model_checkpoint")
CHUNK_SIZE = config.get("chunk_size")
INFER_LENGTH = config.get("limit_infer_length")
LOG_FILE = config.get("log_file")
\ No newline at end of file
import logging
import logging.handlers
import sys
from constants import LOG_FILE

# Define a function that configures the logger
def configure_logger(log_file=LOG_FILE):
    # Create the logger
    logger = logging.getLogger('NLP_Infer')
    logger.setLevel(logging.DEBUG)
    # Handler that writes log records to a file
    file_handler = logging.FileHandler(log_file)
    file_handler.setLevel(logging.DEBUG)
    # Handler that prints log records to stderr
    stream_handler = logging.StreamHandler(sys.stderr)
    stream_handler.setLevel(logging.DEBUG)
    # Define the log record format
    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    file_handler.setFormatter(formatter)
    stream_handler.setFormatter(formatter)
    # Attach both handlers to the logger
    logger.addHandler(file_handler)
    logger.addHandler(stream_handler)
    return logger

# Configure the module-level logger
logger = configure_logger()
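One note on the module above: it imports logging.handlers but only attaches a plain FileHandler, so the log file grows without bound. A minimal rotating alternative, assuming an arbitrary 10 MB cap and 3 backups (an editorial sketch, not part of the commit):

from logging.handlers import RotatingFileHandler

# Hypothetical rotation policy: 10 MB per file, keep 3 backups.
rotating_handler = RotatingFileHandler(LOG_FILE, maxBytes=10 * 1024 * 1024, backupCount=3)
rotating_handler.setFormatter(logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s'))
# Drop-in replacement for file_handler inside configure_logger().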
import os
from minio import Minio
from minio.error import S3Error
from constants import MINIO_SERVER, MINIO_MODEL_TRAINED, MODEL_CHECKPOINT
from logger import logger

minio_client = Minio(
    endpoint=MINIO_SERVER,
    access_key=os.getenv("MINIO_ROOT_USER"),
    secret_key=os.getenv("MINIO_ROOT_PASSWORD"),
    secure=False
)

def download_latest_model():
    # List objects in the bucket
    objects = minio_client.list_objects(MINIO_MODEL_TRAINED)
    latest_obj = None
    latest_time = None
    # Inference serves the most recently modified "best" checkpoint
    for obj in objects:
        if "best" in obj.object_name:
            if latest_time is None or obj.last_modified > latest_time:
                latest_time = obj.last_modified
                latest_obj = obj
    if latest_obj is not None:
        try:
            minio_client.fget_object(MINIO_MODEL_TRAINED, latest_obj.object_name, MODEL_CHECKPOINT)
        except S3Error as exc:
            logger.error(f"Error occurred: {exc}")
        return latest_obj.object_name
    else:
        raise Exception("No *best* models found in the bucket")
\ No newline at end of file
@@ -3,57 +3,14 @@ from pydantic import BaseModel
import torch
from transformers import AutoTokenizer
from contextlib import asynccontextmanager
-from minio import Minio
-from minio.error import S3Error
-import os
from typing import List
from bert_model import BERTClassifier
from utils import get_data_from_yaml, split_chunk
from preprocess import preprocess_text
+from minio_client import download_latest_model
-config = get_data_from_yaml("/src/config.yaml")
+from constants import PHOBERTBASE_DIR, MAX_TOKEN_LENGTH, DEVICE, CLASSES, MODEL_CHECKPOINT, INFER_LENGTH, CHUNK_SIZE
-DEVICE = config.get("device")
+from logger import logger
-CLASSES = config.get("classes")
-MINIO_SERVER = config.get("minio")["server"]
-MINIO_DATA_LABELED = config.get("minio")["data_labeled"]
-MINIO_MODEL_TRAINED = config.get("minio")["model_trained"]
-VNCORENLP_DIR = config.get("vncorenlp")["save_dir"]
-PHOBERTBASE_DIR = config.get("phobert_base")["save_dir"]
-MAX_TOKEN_LENGTH = config.get("phobert_base")["max_token_length"]
-MODEL_CHECKPOINT = config.get("model_checkpoint")
-CHUNK_SIZE = config.get("chunk_size")
-INFER_LENGTH = config.get("limit_infer_length")
-minio_client = Minio(
-    endpoint=MINIO_SERVER,
-    access_key=os.getenv("MINIO_ROOT_USER"),
-    secret_key=os.getenv("MINIO_ROOT_PASSWORD"),
-    secure=False
-)
-def download_latest_model():
-    # List objects in the bucket
-    objects = minio_client.list_objects(MINIO_MODEL_TRAINED)
-    latest_obj = None
-    latest_time = None
-    for obj in objects:
-        if "best" in obj.object_name:
-            if latest_time is None or obj.last_modified > latest_time:
-                latest_time = obj.last_modified
-                latest_obj = obj
-    if latest_obj is not None:
-        try:
-            minio_client.fget_object(MINIO_MODEL_TRAINED, latest_obj.object_name, MODEL_CHECKPOINT)
-        except S3Error as exc:
-            print(f"Error occurred: {exc}")
-        return latest_obj.object_name
-    else:
-        raise Exception("No *best* models found in the bucket")
tokenizer = AutoTokenizer.from_pretrained(PHOBERTBASE_DIR, local_files_only=True, use_fast=False)
@@ -88,9 +45,9 @@ async def lifespan(app: FastAPI):
        update_model = download_latest_model()
        model.load_state_dict(torch.load(MODEL_CHECKPOINT))
        model.eval()
-        print(f"Model updated: {update_model}")
+        logger.info(f"Model updated: {update_model}")
    except Exception as e:
-        print(f"An error occurred: {e}")
+        logger.error(f"An error occurred: {e}")
    yield
    torch.cuda.empty_cache()
    torch.cuda.ipc_collect()
...
@@ -5,4 +5,5 @@ DB_SERVER = config.get("sqldb")["server"]
DB_TABLENAME = config.get("sqldb")["table"]
MINIO_SERVER = config.get("minio")["server"]
MINIO_DATA_LABELED = config.get("minio")["data_labeled"]
-MINIO_MODEL_TRAINED = config.get("minio")["model_trained"]
\ No newline at end of file
+MINIO_MODEL_TRAINED = config.get("minio")["model_trained"]
+LOG_FILE = config.get("log_file")
\ No newline at end of file
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
from sqlalchemy.orm import sessionmaker
import os
-from config import DB_SERVER
+from constants import DB_SERVER
DATABASE_URL = f'postgresql+asyncpg://{os.getenv("POSTGRES_USER")}:{os.getenv("POSTGRES_PASSWORD")}@{DB_SERVER}/{os.getenv("POSTGRES_DB")}'
...
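The rest of database.py is collapsed in the diff, but main.py imports engine, get_db, and SessionLocal from it, so the elided part plausibly looks like the sketch below (names taken from the imports; the implementation itself is an assumption):

engine = create_async_engine(DATABASE_URL, echo=False)
SessionLocal = sessionmaker(engine, class_=AsyncSession, expire_on_commit=False)

# FastAPI dependency: yields one AsyncSession per request and closes it afterwards
async def get_db():
    async with SessionLocal() as session:
        yield session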
import logging
import logging.handlers
import sys
from constants import LOG_FILE

# Define a function that configures the logger
def configure_logger(log_file=LOG_FILE):
    # Create the logger
    logger = logging.getLogger('NLP_Data')
    logger.setLevel(logging.DEBUG)
    # Handler that writes log records to a file
    file_handler = logging.FileHandler(log_file)
    file_handler.setLevel(logging.DEBUG)
    # Handler that prints log records to stderr
    stream_handler = logging.StreamHandler(sys.stderr)
    stream_handler.setLevel(logging.DEBUG)
    # Define the log record format
    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    file_handler.setFormatter(formatter)
    stream_handler.setFormatter(formatter)
    # Attach both handlers to the logger
    logger.addHandler(file_handler)
    logger.addHandler(stream_handler)
    return logger

# Configure the module-level logger
logger = configure_logger()
from minio import Minio
import os
from constants import MINIO_SERVER, MINIO_DATA_LABELED, MINIO_MODEL_TRAINED
from logger import logger

minio_client = Minio(
    endpoint=MINIO_SERVER,
    access_key=os.getenv("MINIO_ROOT_USER"),
    secret_key=os.getenv("MINIO_ROOT_PASSWORD"),
    secure=False
)

bucket_names = [MINIO_DATA_LABELED, MINIO_MODEL_TRAINED]

# Defaults allow the no-argument call check_bucket() used in main.py's lifespan.
def check_bucket(minio_client=minio_client, bucket_names=bucket_names):
    # Check if each bucket exists
    for bucket_name in bucket_names:
        found = minio_client.bucket_exists(bucket_name)
        if not found:
            # Create bucket
            minio_client.make_bucket(bucket_name)
            logger.info(f"Bucket '{bucket_name}' created successfully.")
        else:
            logger.info(f"Bucket '{bucket_name}' already exists.")
\ No newline at end of file
from sqlalchemy import Column, Integer, String
from sqlalchemy.ext.declarative import declarative_base
-from config import DB_TABLENAME
+from constants import DB_TABLENAME
Base = declarative_base()
...
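The LabeledData model body is collapsed here; given the imports, the DB_TABLENAME constant, and the LabeledData name imported by main.py, it plausibly looks like this (the column names are assumptions):

class LabeledData(Base):
    __tablename__ = DB_TABLENAME

    id = Column(Integer, primary_key=True, index=True)
    text = Column(String)   # assumed: the moderated text content
    label = Column(String)  # assumed: the moderator-assigned label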
from fastapi import FastAPI, Depends, HTTPException
from contextlib import asynccontextmanager
from sqlalchemy.ext.asyncio import AsyncSession
-from sqlalchemy.future import select
-from sqlalchemy import MetaData, Table, inspect
-from minio import Minio
-from minio.error import S3Error
-import os
from models import Base, LabeledData
from schemas import LabeledDataCreate
-from sqlalchemy import create_engine, MetaData, Table
-from sqlalchemy.orm import sessionmaker
-import pandas as pd
-from io import StringIO, BytesIO
from database import engine, get_db, SessionLocal
-from config import MINIO_SERVER, MINIO_MODEL_TRAINED, MINIO_DATA_LABELED
+from constants import MINIO_SERVER, MINIO_MODEL_TRAINED, MINIO_DATA_LABELED
-from utils import check_bucket
+from minio_client import check_bucket
-minio_client = Minio(
-    endpoint=MINIO_SERVER,
-    access_key=os.getenv("MINIO_ROOT_USER"),
-    secret_key=os.getenv("MINIO_ROOT_PASSWORD"),
-    secure=False
-)
-bucket_names = [MINIO_DATA_LABELED, MINIO_MODEL_TRAINED]
@asynccontextmanager
async def lifespan(app: FastAPI):
-    check_bucket(minio_client=minio_client, bucket_names=bucket_names)
+    check_bucket()
    async with engine.begin() as conn:
        await conn.run_sync(Base.metadata.create_all)
    yield
...
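The data-ingestion endpoint itself is collapsed; per the README, NLP Data exposes an API for adding labeled records. A minimal sketch of what such a route could look like (the path, and the assumption that LabeledDataCreate fields map one-to-one onto LabeledData columns, are hypothetical):

@app.post("/labeled-data")
async def add_labeled_data(item: LabeledDataCreate, db: AsyncSession = Depends(get_db)):
    # Persist one labeled record and return it with its generated id
    record = LabeledData(**item.dict())
    db.add(record)
    await db.commit()
    await db.refresh(record)
    return record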
@@ -8,14 +8,3 @@ def get_data_from_yaml(filename):
        raise IOError(f"Error opening file: {filename}")
    return data
-def check_bucket(minio_client, bucket_names):
-    # Check if bucket exists
-    for bucket_name in bucket_names:
-        found = minio_client.bucket_exists(bucket_name)
-        if not found:
-            # Create bucket
-            minio_client.make_bucket(bucket_name)
-            print(f"Bucket '{bucket_name}' created successfully.")
-        else:
-            print(f"Bucket '{bucket_name}' already exists.")
\ No newline at end of file
from utils import get_data_from_yaml
config = get_data_from_yaml("/src/config.yaml")
DEVICE = config.get("device")
CLASSES = config.get("classes")
MINIO_SERVER = config.get("minio")["server"]
MINIO_DATA_LABELED = config.get("minio")["data_labeled"]
MINIO_MODEL_TRAINED = config.get("minio")["model_trained"]
VNCORENLP_DIR = config.get("vncorenlp")["save_dir"]
PHOBERTBASE_DIR = config.get("phobert_base")["save_dir"]
MAX_TOKEN_LENGTH = config.get("phobert_base")["max_token_length"]
MODEL_CHECKPOINT = config.get("model_checkpoint")
CHUNK_SIZE = config.get("chunk_size")
EPOCH = config.get("training")["epoch"]
K_FOLD = config.get("training")["k_fold"]
TEST_RATIO = config.get("training")["test_ratio"]
BATCH_SIZE = config.get("training")["batch_size"]
LOAD_DATA_WORKER = config.get("training")["load_data_worker"]
LOG_FILE = config.get("log_file")
\ No newline at end of file
import logging
import logging.handlers
import sys
from constants import LOG_FILE

# Define a function that configures the logger
def configure_logger(log_file=LOG_FILE):
    # Create the logger
    logger = logging.getLogger('NLP_Training')
    logger.setLevel(logging.DEBUG)
    # Handler that writes log records to a file
    file_handler = logging.FileHandler(log_file)
    file_handler.setLevel(logging.DEBUG)
    # Handler that prints log records to stderr
    stream_handler = logging.StreamHandler(sys.stderr)
    stream_handler.setLevel(logging.DEBUG)
    # Define the log record format
    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    file_handler.setFormatter(formatter)
    stream_handler.setFormatter(formatter)
    # Attach both handlers to the logger
    logger.addHandler(file_handler)
    logger.addHandler(stream_handler)
    return logger

# Configure the module-level logger
logger = configure_logger()
import os
from minio import Minio
from minio.error import S3Error
from constants import MINIO_SERVER, MINIO_DATA_LABELED, MINIO_MODEL_TRAINED, MODEL_CHECKPOINT
from logger import logger

minio_client = Minio(
    endpoint=MINIO_SERVER,
    access_key=os.getenv("MINIO_ROOT_USER"),
    secret_key=os.getenv("MINIO_ROOT_PASSWORD"),
    secure=False
)

bucket_names = [MINIO_DATA_LABELED, MINIO_MODEL_TRAINED]

def download_latest_model():
    # List objects in the bucket
    objects = minio_client.list_objects(MINIO_MODEL_TRAINED)
    latest_obj = None
    latest_time = None
    # Training resumes from the newest "last" checkpoint (inference uses "best")
    for obj in objects:
        if "last" in obj.object_name:
            if latest_time is None or obj.last_modified > latest_time:
                latest_time = obj.last_modified
                latest_obj = obj
    if latest_obj is not None:
        try:
            minio_client.fget_object(MINIO_MODEL_TRAINED, latest_obj.object_name, MODEL_CHECKPOINT)
        except S3Error as exc:
            logger.error(f"Error occurred: {exc}")
        return latest_obj.object_name
    else:
        raise Exception("No *last* models found in the bucket")
\ No newline at end of file
-from fastapi import FastAPI, HTTPException
+import datetime
-from pydantic import BaseModel
import threading
-import torch
import numpy as np
+import pandas as pd
+from io import BytesIO
+from fastapi import FastAPI, HTTPException
+from contextlib import asynccontextmanager
+from pydantic import BaseModel
from transformers import AutoTokenizer
+from transformers import get_linear_schedule_with_warmup, AutoTokenizer
from sklearn.model_selection import train_test_split
-import pandas as pd
-import os
from gensim.utils import simple_preprocess
from sklearn.model_selection import StratifiedKFold
-from contextlib import asynccontextmanager
+import torch
import torch.nn as nn
from torch.optim import AdamW
from torch.utils.data import Dataset, DataLoader
from torch.utils.tensorboard import SummaryWriter
-from transformers import get_linear_schedule_with_warmup, AutoTokenizer, AutoModel
+from typing import List
-from io import BytesIO
-import pandas as pd
-import sys
-import logging
-import datetime
-from typing import List
-from minio import Minio
from minio.error import S3Error
from bert_model import BERTClassifier
-from utils import get_data_from_yaml, seed_everything
+from utils import seed_everything
from preprocess import preprocess_row
+from minio_client import minio_client, download_latest_model
+from constants import PHOBERTBASE_DIR, CLASSES, DEVICE, MODEL_CHECKPOINT, EPOCH, \
+    MAX_TOKEN_LENGTH, BATCH_SIZE, LOAD_DATA_WORKER, MINIO_MODEL_TRAINED, \
+    MINIO_DATA_LABELED, TEST_RATIO, K_FOLD
+from logger import logger

# Global variables to manage the training thread and the stop flag
train_thread = None
is_training = False
stop_training_flag = threading.Event()
-logger = logging.getLogger()
-logger.setLevel(logging.INFO)
-handler = logging.StreamHandler(sys.stderr)
-formatter = logging.Formatter('%(asctime)s %(levelname)s: %(message)s')
-handler.setFormatter(formatter)
-logger.addHandler(handler)
-config = get_data_from_yaml("/src/config.yaml")
-DEVICE = config.get("device")
-CLASSES = config.get("classes")
-MINIO_SERVER = config.get("minio")["server"]
-MINIO_DATA_LABELED = config.get("minio")["data_labeled"]
-MINIO_MODEL_TRAINED = config.get("minio")["model_trained"]
-VNCORENLP_DIR = config.get("vncorenlp")["save_dir"]
-PHOBERTBASE_DIR = config.get("phobert_base")["save_dir"]
-MAX_TOKEN_LENGTH = config.get("phobert_base")["max_token_length"]
-MODEL_CHECKPOINT = config.get("model_checkpoint")
-CHUNK_SIZE = config.get("chunk_size")
-EPOCH = config.get("training")["epoch"]
-K_FOLD = config.get("training")["k_fold"]
-TEST_RATIO = config.get("training")["test_ratio"]
-BATCH_SIZE = config.get("training")["batch_size"]
-LOAD_DATA_WORKER = config.get("training")["load_data_worker"]
-minio_client = Minio(
-    endpoint=MINIO_SERVER,
-    access_key=os.getenv("MINIO_ROOT_USER"),
-    secret_key=os.getenv("MINIO_ROOT_PASSWORD"),
-    secure=False
-)
-bucket_names = [MINIO_DATA_LABELED, MINIO_MODEL_TRAINED]
-def download_latest_model():
-    # List objects in the bucket
-    objects = minio_client.list_objects(MINIO_MODEL_TRAINED)
-    latest_obj = None
-    latest_time = None
-    for obj in objects:
-        if "last" in obj.object_name:
-            if latest_time is None or obj.last_modified > latest_time:
-                latest_time = obj.last_modified
-                latest_obj = obj
-    if latest_obj is not None:
-        try:
-            minio_client.fget_object(MINIO_MODEL_TRAINED, latest_obj.object_name, MODEL_CHECKPOINT)
-        except S3Error as exc:
-            print(f"Error occurred: {exc}")
-        return latest_obj.object_name
-    else:
-        raise Exception("No *last* models found in the bucket")
def release_training_vram():
    # Release GPU memory held by the training process
    torch.cuda.empty_cache()
@@ -214,7 +155,7 @@ def train_model(skf, train_df, tokenizer, use_pretrain=False):
    for fold in range(skf.n_splits):
        if stop_training_flag.is_set():
-            print("Training stopped.")
+            logger.info("Training stopped.")
            writer.close()
            return "Training stopped"
        logger.info(f'-----------Fold: {fold+1} ------------------')
@@ -233,7 +174,7 @@ def train_model(skf, train_df, tokenizer, use_pretrain=False):
        best_acc = 0
        for e in range(epoch_each_fold):
            if stop_training_flag.is_set():
-                print("Training stopped.")
+                logger.info("Training stopped.")
                writer.close()
                return "Training stopped"
            logger.info(f'Fold {fold+1} Epoch {cur_epoch+1}/{EPOCH}')
@@ -245,7 +186,7 @@ def train_model(skf, train_df, tokenizer, use_pretrain=False):
            for data in train_loader:
                if stop_training_flag.is_set():
-                    print("Training stopped.")
+                    logger.info("Training stopped.")
                    writer.close()
                    return "Training stopped"
                if data is not None:
@@ -278,7 +219,7 @@ def train_model(skf, train_df, tokenizer, use_pretrain=False):
            writer.add_scalar("Loss/train", train_loss, cur_epoch)
            # End train ----------------------------------------------------------------------------------
            if stop_training_flag.is_set():
-                print("Training stopped.")
+                logger.info("Training stopped.")
                writer.close()
                return "Training stopped"
            # Valid --------------------------------------------------------------------------------------
@@ -289,7 +230,7 @@ def train_model(skf, train_df, tokenizer, use_pretrain=False):
            with torch.no_grad():
                for data in valid_loader:
                    if stop_training_flag.is_set():
-                        print("Training stopped.")
+                        logger.info("Training stopped.")
                        writer.close()
                        return "Training stopped"
                    input_ids = data['input_ids'].to(DEVICE)
@@ -315,7 +256,7 @@ def train_model(skf, train_df, tokenizer, use_pretrain=False):
            # End valid ----------------------------------------------------------------------------------
            if stop_training_flag.is_set():
-                print("Training stopped.")
+                logger.info("Training stopped.")
                writer.close()
                return "Training stopped"
@@ -333,11 +274,11 @@ def train_model(skf, train_df, tokenizer, use_pretrain=False):
                    '/src/phobert_best.pth' # Path to the file you want to upload
                )
            except S3Error as exc:
-                print(f"Error occurred: {exc}")
+                logger.error(f"Error occurred: {exc}")
-            print(f"File uploaded successfully. {checkpoint_best}")
+            logger.info(f"File uploaded successfully. {checkpoint_best}")
            if stop_training_flag.is_set():
-                print("Training stopped.")
+                logger.info("Training stopped.")
                writer.close()
                return "Training stopped"
@@ -350,12 +291,12 @@ def train_model(skf, train_df, tokenizer, use_pretrain=False):
                    '/src/phobert_last.pth' # Path to the file you want to upload
                )
            except S3Error as exc:
-                print(f"Error occurred: {exc}")
+                logger.error(f"Error occurred: {exc}")
-            print(f"File uploaded successfully. {checkpoint_last}")
+            logger.info(f"File uploaded successfully. {checkpoint_last}")
            cur_epoch = cur_epoch + 1
-    print("Training completed.")
+    logger.info("Training completed.")
    writer.close()
    release_training_vram()
    is_training = False
@@ -372,7 +313,7 @@ class TrainingResponse(BaseModel):
@asynccontextmanager
async def lifespan(app: FastAPI):
-    print(f'CUDA available: {torch.cuda.is_available()}')
+    logger.info(f'CUDA available: {torch.cuda.is_available()}')
    yield
    torch.cuda.empty_cache()
    torch.cuda.ipc_collect()
@@ -389,7 +330,7 @@ async def start_training(request: TrainingRequest):
    use_pretrain = False
    if request.pretrain == "":
-        print("Training from zero")
+        logger.info("Training from zero")
    elif request.pretrain == "latest":
        try:
            latest_model = download_latest_model()
@@ -397,7 +338,7 @@ async def start_training(request: TrainingRequest):
            is_training = False
            raise HTTPException(status_code=500, detail=str(exc))
        use_pretrain = True
-        print(f"Training from latest: {latest_model}")
+        logger.info(f"Training from latest: {latest_model}")
    else:
        try:
            minio_client.fget_object(MINIO_MODEL_TRAINED, request.pretrain, MODEL_CHECKPOINT)
@@ -405,7 +346,7 @@ async def start_training(request: TrainingRequest):
            is_training = False
            raise HTTPException(status_code=500, detail=str(exc))
        use_pretrain = True
-        print(f"Training from : {request.pretrain}")
+        logger.info(f"Training from : {request.pretrain}")
    objects = minio_client.list_objects(MINIO_DATA_LABELED, recursive=True)
...
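For reference, a minimal sketch of driving the training control API from Python. The route path is hypothetical; the "pretrain" field and its three accepted values come from the start_training handler above ("" trains from scratch, "latest" resumes from the newest *last* checkpoint in MinIO, and any other value names a specific checkpoint object):

import requests

# Hypothetical host, port, and route path.
resp = requests.post("http://localhost:8000/start-training", json={"pretrain": "latest"})
print(resp.status_code, resp.json())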
@@ -2,6 +2,8 @@ import numpy as np
import torch
import yaml
+from logger import logger

def get_data_from_yaml(filename):
    try:
        with open(filename, 'r') as f:
@@ -16,8 +18,8 @@ def seed_everything(seed_value):
    torch.manual_seed(seed_value)
    if torch.cuda.is_available():
-        print("Torch available")
+        logger.info("Torch available. Start seed_everything.")
        torch.cuda.manual_seed(seed_value)
        torch.cuda.manual_seed_all(seed_value)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = True
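One observation on the hunk above: torch.backends.cudnn.deterministic = True is usually paired with benchmark = False, since benchmark mode lets cuDNN autotune algorithms that can vary between runs. A strictly reproducible variant (an editorial suggestion, not the author's code):

# For fully deterministic cuDNN behavior, disable autotuning:
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False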