Commit 5b07caca by 柴鹏飞

重构相关代码,增加日志记录

parent cfdc63bd
......@@ -2,7 +2,6 @@
import re
import json
import logging
import argparse
from datetime import datetime, timedelta
......@@ -12,19 +11,138 @@ from ydl_ai_recommender.src.core.profile_manager import ProfileManager
from ydl_ai_recommender.src.core.recommender import UserCFRecommender
from ydl_ai_recommender.src.data.mysql_client import MySQLClient
from ydl_ai_recommender.src.utils import get_conf_path
from ydl_ai_recommender.src.utils.log import create_logger
logging.basicConfig(
format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
datefmt='%m/%d/%Y %H:%M:%S',
level=logging.INFO
)
logger = create_logger(__name__, 'test.log')
logger = logging.getLogger(__name__)
def load_local_user_profile():
    """Load user profiles from the local offline snapshot.

    The offline copy keeps test runs reproducible and avoids re-running
    the SQL export on every evaluation.

    Returns:
        dict: mapping of uid -> profile row.
    """
    manager = ProfileManager()
    profile_df = manager.load_test_profile_data()
    profiles = {row['uid']: row for _, row in profile_df.iterrows()}
    logger.info('用户画像数 %s ', len(profiles))
    return profiles
def load_test_data():
    """Load the training-order uid set and the test orders.

    Returns:
        tuple: ``(old_users, test_orders)`` — the set of uids present in
        the raw training orders (used later to classify a test user as
        old vs. new) and the test-order data.
    """
    manager = OrderDataManager()
    # uids that existed when the recall model was built count as "old" users
    raw_orders = manager.load_raw_data()
    known_uids = set(raw_orders['uid'])
    logger.info('订单用户数 %s ', len(known_uids))
    # load the held-out test orders
    test_orders = manager.load_test_order_data()
    logger.info('加载测试数据成功,共加载 %s 条', len(test_orders))
    return known_uids, test_orders
def evaluation(result_detail):
# Aggregate recall metrics over the per-order test results and log a report.
# Each item of result_detail carries: uid, is_old_user, is_recall and
# recall_counselors (entries whose 'from' field looks like
# 'similar_users <uid>').
# NOTE(review): reads the module-level globals ``logger`` and ``args`` (the
# latter is only bound when run as a script) — calling this from another
# module raises NameError on ``args``.
metrics = {
'all_test_cnt': 0,
'all_recall_cnt': 0,
'old_user_test_cnt': 0,
'old_user_recall_cnt': 0,
'new_user_test_cnt': 0,
'new_user_recall_cnt': 0,
'same_user_recall_cnt': 0,
'similar_user_recall_cnt': 0,
}
for rd in result_detail:
metrics['all_test_cnt'] += 1
if rd['is_recall']:
metrics['all_recall_cnt'] += 1
# Attribute the hit: did it come from the user themself or a similar user?
is_same_user, is_similar_user = False, False
for counselor in rd['recall_counselors']:
# 'from' is formatted like 'similar_users <uid>'; token 1 is the uid.
# NOTE(review): fallback entries carry from == 'top_50' (no space) and
# would raise IndexError here — confirm they can never appear in this path.
from_id = counselor['from'].split(' ')[1]
if from_id == rd['uid']:
is_same_user = True
if from_id != rd['uid']:
is_similar_user = True
if is_same_user:
metrics['same_user_recall_cnt'] += 1
if is_similar_user:
metrics['similar_user_recall_cnt'] += 1
# Bucket the sample by whether the user appeared in the training orders.
if rd['is_old_user']:
metrics['old_user_test_cnt'] += 1
if rd['is_recall']:
metrics['old_user_recall_cnt'] += 1
else:
metrics['new_user_test_cnt'] += 1
if rd['is_recall']:
metrics['new_user_recall_cnt'] += 1
logger.info('==' * 20 + ' 测试结果 ' + '==' * 20)
logger.info('')
logger.info('相关参数配置: 相似用户数(k) %s ;每个相似用户召回咨询师数(top_n) %s', args.k, args.top_n)
logger.info('--' * 45)
logger.info('')
logger.info('{:<13}{:<7}{:<7}{:<7}'.format('', '样本数', '召回数', '召回率'))
# NOTE(review): the ratios below raise ZeroDivisionError whenever the
# corresponding *_cnt bucket is empty (e.g. no new users, or zero recalls).
logger.info('{:<10}{:<10}{:<10}{:<10.2%}'.format('整体\u3000', metrics['all_test_cnt'], metrics['all_recall_cnt'], metrics['all_recall_cnt'] / metrics['all_test_cnt']))
logger.info('{:<10}{:<10}{:<10}{:<10.2%}'.format('老用户', metrics['old_user_test_cnt'], metrics['old_user_recall_cnt'], metrics['old_user_recall_cnt'] / metrics['old_user_test_cnt']))
logger.info('{:<10}{:<10}{:<10}{:<10.2%}'.format('新用户', metrics['new_user_test_cnt'], metrics['new_user_recall_cnt'], metrics['new_user_recall_cnt'] / metrics['new_user_test_cnt']))
logger.info('--' * 45)
logger.info('')
logger.info('用户自己召回数 {} 占总召回比例 {:.2%}'.format(metrics['same_user_recall_cnt'], metrics['same_user_recall_cnt'] / metrics['all_recall_cnt']))
logger.info('相似用户召回数 {} 占总召回比例 {:.2%}'.format(metrics['similar_user_recall_cnt'], metrics['similar_user_recall_cnt'] / metrics['all_recall_cnt']))
def do_test(args):
    """Run the recall test over the offline test orders and report metrics.

    For every test order, recommend counselors for the ordering user and
    check whether the counselor actually booked shows up in the list.
    Per-order details are dumped to ``result_detail.json``.

    Args:
        args: parsed CLI namespace providing ``top_n``, ``k``, ``mode``
            and ``max_test``.
    """
    user_profile_dict = load_local_user_profile()
    old_users, test_orders = load_test_data()
    recommender = UserCFRecommender(top_n=args.top_n, k=args.k)

    result_detail = []
    for index, order_info in test_orders.iterrows():
        # Optional cap on the number of evaluated orders.
        # NOTE(review): treats the DataFrame index as positional — confirm.
        if args.max_test > 0 and index >= args.max_test:
            break
        uid = order_info['uid']
        profile = user_profile_dict.get(uid)
        if profile is None:
            # No offline profile for this user; skip the sample.
            continue
        # mode 0 mirrors production behaviour: deduplicated recommendations.
        recommend_result = recommender.recommend_with_profile(profile, is_merge=(args.mode == 0))
        hit_sources = [
            item['from']
            for item in recommend_result
            if item['counselor'] == order_info['supplier_id']
        ]
        result_detail.append({
            'uid': uid,
            'supplier_id': order_info['supplier_id'],
            'is_old_user': uid in old_users,
            'recall_counselors': recommend_result,
            'is_recall': len(hit_sources) > 0,
            'recall_reason': '|'.join(hit_sources),
        })
    # Summarize and log the metrics.
    evaluation(result_detail)
    with open('result_detail.json', 'w', encoding='utf-8') as f:
        json.dump(result_detail, f, ensure_ascii=False, indent=2)
def main(args):
def test(args):
# 构建用户画像字典,不用每次都从数据库中获取
profile_manager = ProfileManager()
......@@ -58,7 +176,7 @@ def main(args):
if profile is None:
continue
recommend_result = recommender.recommend_with_profile(profile)
recommend_result = recommender.recommend_with_profile(profile, is_merge=False)
recall_resons = []
for rr in recommend_result:
if rr['counselor'] == order_info['supplier_id']:
......@@ -114,17 +232,23 @@ def main(args):
logger.info('==' * 20 + ' 测试结果 ' + '==' * 20)
logger.info('')
logger.info('相关参数配置: 相似用户数(k) %s ;每个相似用户召回咨询师数(top_n) %s', args.k, args.top_n)
logger.info('--' * 45)
logger.info('{:<10}{:<10}{:<10}{:<10}'.format('', '样本数', '召回数', '召回率'))
logger.info('{:<10}{:<10}{:<10}{:<10.2%}'.format('整体 ', metrics['all_test_cnt'], metrics['all_recall_cnt'], metrics['all_recall_cnt'] / metrics['all_test_cnt']))
logger.info('')
logger.info('{:<13}{:<7}{:<7}{:<7}'.format('', '样本数', '召回数', '召回率'))
logger.info('{:<10}{:<10}{:<10}{:<10.2%}'.format('整体\u3000', metrics['all_test_cnt'], metrics['all_recall_cnt'], metrics['all_recall_cnt'] / metrics['all_test_cnt']))
logger.info('{:<10}{:<10}{:<10}{:<10.2%}'.format('老用户', metrics['old_user_test_cnt'], metrics['old_user_recall_cnt'], metrics['old_user_recall_cnt'] / metrics['old_user_test_cnt']))
logger.info('{:<10}{:<10}{:<10}{:<10.2%}'.format('新用户', metrics['new_user_test_cnt'], metrics['new_user_recall_cnt'], metrics['new_user_recall_cnt'] / metrics['new_user_test_cnt']))
logger.info('')
logger.info('--' * 45)
logger.info('')
logger.info('用户自己召回数 {} 占总召回比例 {:.2%}'.format(metrics['same_user_recall_cnt'], metrics['same_user_recall_cnt'] / metrics['all_recall_cnt']))
logger.info('相似用户召回数 {} 占总召回比例 {:.2%}'.format(metrics['similar_user_recall_cnt'], metrics['similar_user_recall_cnt'] / metrics['all_recall_cnt']))
# 召回位置指标
with open('result_detail.json', 'w', encoding='utf-8') as f:
json.dump(result_detail, f, ensure_ascii=False, indent=2)
......@@ -132,6 +256,7 @@ def main(args):
def update_test_data(args):
""" 更新测试数据,没有就新建 """
start_date = ''
if re.match(r'20\d\d-[01]\d-\d\d', args.start_date):
start_date = args.start_date
......@@ -160,6 +285,7 @@ def update_test_data(args):
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--k', default=5, type=int, help='召回相似用户的数量')
parser.add_argument('--mode', default=0, type=int, help='模式:0-推荐的咨询师列表去重(默认,与实际线上一样);1-推荐的咨询师列表没有去重')
parser.add_argument('--top_n', default=5, type=int, help='每个相似用户召回的咨询师数量')
parser.add_argument('--max_test', default=0, type=int, help='最多测试数据量')
......@@ -167,9 +293,12 @@ if __name__ == '__main__':
parser.add_argument('--start_date', default='-1', type=str, help='测试订单创建的开始时间,可以是"%Y-%m-%d"格式,也可以是 -3 表示前3天')
args = parser.parse_args()
logger.info('')
if args.do_update_test_data:
logger.info('更新测试数据')
logger.info('测试数据创建时间 %s', args.start_date)
update_test_data(args)
main(args)
\ No newline at end of file
logger.info('开始执行测试任务')
do_test(args)
\ No newline at end of file
# -*- coding: utf-8 -*-
import os
import argparse
from datetime import datetime
from ydl_ai_recommender.src.core.order_data_manager import OrderDataManager
from ydl_ai_recommender.src.core.profile_manager import ProfileManager
from ydl_ai_recommender.src.data.mysql_client import MySQLClient
from ydl_ai_recommender.src.utils import get_conf_path, get_project_path
from ydl_ai_recommender.src.utils.log import create_logger
# Module logger: console + file handler writing to log/update.log.
logger = create_logger(__name__, 'update.log')
parser = argparse.ArgumentParser(description='壹点灵 咨询师推荐 算法召回 离线更新数据模型')
# Task selector: load_db_data (refresh local data from MySQL),
# make_embedding (build user-feature embeddings and order indexes),
# make_virtual_embedding (build virtual user-feature embeddings).
parser.add_argument(
'-t', '--task', type=str, required=True,
choices=('load_db_data', 'make_embedding', 'make_virtual_embedding'), help='执行任务名称'
)
parser.add_argument('--index_last_date', default=None, type=str, help='构建索引最后日期,超过该日期的数据不使用')
# NOTE(review): arguments are parsed at import time; moving parse_args()
# under the __main__ guard would make this module importable without CLI args.
args = parser.parse_args()
if __name__ == '__main__':
    if args.task == 'load_db_data':
        logger.info('')
        # Create a fresh timestamped data directory and repoint the 'data'
        # symlink at it, so every refresh is isolated and switchable.
        now = datetime.now()
        really_data_dir = os.path.join(get_project_path(), 'data_{}'.format(now.strftime('%Y%m%d_%H%M%S')))
        default_data_dir = os.path.join(get_project_path(), 'data')
        # Check whether the default data path already exists.
        if os.path.exists(default_data_dir):
            if os.path.islink(default_data_dir):
                # A previous run left a symlink; drop it and relink below.
                os.unlink(default_data_dir)
            else:
                logger.error('%s 目录已经存在!请备份后删除该目录再重新执行本操作', default_data_dir)
                # BUG FIX: execution previously fell through, and os.symlink()
                # below would raise FileExistsError; abort explicitly instead.
                raise SystemExit(1)
        os.mkdir(really_data_dir)
        logger.info('创建数据保存目录成功')
        # Create the symlink data -> data_<timestamp>.
        os.symlink(really_data_dir, default_data_dir)
        logger.info('创建软连接成功 %s -> %s', really_data_dir, default_data_dir)
        # TODO: remove stale historical data_* directories
        logger.info('开始从数据库中更新数据')
        client = MySQLClient.create_from_config_file(get_conf_path())
        logger.info('开始从数据库中更新画像数据')
        profile_manager = ProfileManager(client)
        profile_manager.update_profile()
        logger.info('开始从数据库中更新订单数据')
        order_data_manager = OrderDataManager(client)
        order_data_manager.update_order_data()
        logger.info('更新数据完成')
    elif args.task == 'make_embedding':
        # Tasks are mutually exclusive (argparse choices), so elif is safe.
        logger.info('')
        logger.info('开始构建用户特征 embedding')
        manager = ProfileManager()
        manager.make_embeddings()
        logger.info('用户特征 embedding 构建完成')
        logger.info('开始构建订单相关索引')
        manager = OrderDataManager()
        manager.make_index()
        logger.info('订单相关索引 构建完成')
    elif args.task == 'make_virtual_embedding':
        logger.info('')
        logger.info('开始构建用户特征虚拟embedding')
        manager = ProfileManager()
        manager.make_virtual_embedding()
        logger.info('用户特征虚拟 embedding 构建完成')
\ No newline at end of file
# -*- coding: utf-8 -*-
import os
import json
import logging
from datetime import datetime, timedelta
import pandas as pd
from ydl_ai_recommender.src.data.mysql_client import MySQLClient
from ydl_ai_recommender.src.utils import get_conf_path, get_data_path
class DBDataManager():
    """
    Export/update database data into local files (xlsx / json / plain text).
    """

    def __init__(self) -> None:
        self.local_file_dir = get_data_path()
        self.logger = logging.getLogger(__name__)
        self.client = MySQLClient.create_from_config_file(get_conf_path())

    def _load_local_plain_file(self, name):
        """Return all lines (trailing newlines kept) of a local text file."""
        with open(os.path.join(self.local_file_dir, name), 'r', encoding='utf-8') as f:
            return f.readlines()

    def _load_local_json_file(self, name):
        """Parse and return a local JSON file."""
        with open(os.path.join(self.local_file_dir, name), 'r', encoding='utf-8') as f:
            return json.load(f)

    def _save_plain_data(self, lines, name, mode='w'):
        """Write ``lines`` newline-joined to a local file (mode 'w' or 'a')."""
        self.logger.info('开始保存 %s 到本地', name)
        with open(os.path.join(self.local_file_dir, name), mode, encoding='utf-8') as f:
            f.write('\n'.join(lines))
        self.logger.info('%s 保存成功,共保存 %s 行数据', name, len(lines))

    def _save_json_data(self, data, name):
        """Serialize ``data`` as JSON (non-ASCII preserved) to a local file."""
        self.logger.info('开始保存 %s 到本地', name)
        with open(os.path.join(self.local_file_dir, name), 'w', encoding='utf-8') as f:
            json.dump(data, f, ensure_ascii=False)
        self.logger.info('%s 保存成功,共保存 %s 个对象', name, len(data))

    def update_profile(self):
        """Export profiles of users that have orders to all_profile.xlsx."""
        # TODO: narrow the profile field list instead of SELECT *
        select_fields = ['*']
        sql = 'SELECT {} FROM ads.ads_register_user_profiles'.format(', '.join(select_fields))
        sql += ' WHERE uid IN (SELECT DISTINCT uid FROM ods.ods_ydl_standard_order)'
        _, all_data = self.client.query(sql)
        df = pd.DataFrame(all_data)
        df.to_excel(os.path.join(self.local_file_dir, 'all_profile.xlsx'), index=None)
        # self._save_json_data(all_data, 'all_profile.json')

    def update_order_info(self):
        """Export all orders to all_order_info.xlsx.

        Delegates to _load_order_data() so the field list / SQL stays in
        one place (previously duplicated here verbatim).
        """
        all_data = self._load_order_data()
        df = pd.DataFrame(all_data)
        df.to_excel(os.path.join(self.local_file_dir, 'all_order_info.xlsx'), index=None)

    def _load_order_data(self, conditions=None):
        """Query order rows, optionally filtered by the given SQL conditions.

        Args:
            conditions: list of SQL boolean expressions, ANDed together.

        Returns:
            list of row dicts from the orders table.
        """
        select_fields = ['main_order_id', 'uid', 'supplier_id', 'price', 'standard_order_type']
        select_fields.append('DATE_FORMAT(update_time, "%Y-%m-%d") AS update_time')
        sql = 'SELECT {} FROM ods.ods_ydl_standard_order'.format(', '.join(select_fields))
        if conditions:
            # BUG FIX: was 'AND '.join(...), which produced "cond1AND cond2"
            # (no space before AND) for more than one condition.
            sql += ' WHERE {}'.format(' AND '.join(conditions))
        self.logger.info('开始执行sql %s', sql)
        cnt, data = self.client.query(sql)
        self.logger.info('sql执行成功,共获取 %s 条数据', cnt)
        return data

    def load_test_data(self, days=5):
        """Export the last ``days`` days of orders plus the matching user
        profiles as local test fixtures (test_order_info.xlsx /
        test_profile.xlsx)."""
        now = datetime.now()
        start_time = now - timedelta(days=days)
        conditions = [
            'create_time >= "{}"'.format(start_time.strftime('%Y-%m-%d')),
        ]
        order_data = self._load_order_data(conditions=conditions)
        df = pd.DataFrame(order_data)
        df.to_excel(os.path.join(self.local_file_dir, 'test_order_info.xlsx'), index=None)
        select_fields = ['*']
        sql = 'SELECT {} FROM ads.ads_register_user_profiles'.format(', '.join(select_fields))
        sql += ' WHERE uid IN (SELECT DISTINCT uid FROM ods.ods_ydl_standard_order WHERE create_time >= "{}")'.format(start_time.strftime('%Y-%m-%d'))
        _, all_data = self.client.query(sql)
        df = pd.DataFrame(all_data)
        df.to_excel(os.path.join(self.local_file_dir, 'test_profile.xlsx'), index=None)
if __name__ == '__main__':
# Ad-hoc entry point: refresh the local test fixtures (recent orders plus
# the matching user profiles) into xlsx files under the data directory.
manager = DBDataManager()
manager.load_test_data()
# NOTE(review): update_local_data() is not defined on DBDataManager in this
# file — the commented calls below look like leftovers from an older API.
# manager.update_local_data()
# print(manager.make_index())
\ No newline at end of file
# -*- coding: utf-8 -*-
import logging
import argparse
from ydl_ai_recommender.src.core.db_data_manager import DBDataManager
from ydl_ai_recommender.src.core.profile_manager import ProfileManager
# Root logging config: timestamped INFO-level messages to the console.
logging.basicConfig(
format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
datefmt='%m/%d/%Y %H:%M:%S',
level=logging.INFO
)
parser = argparse.ArgumentParser(description='壹点灵 咨询师推荐 算法召回')
parser.add_argument('--index_last_date', default=None, type=str, help='构建索引最后日期,超过该日期的数据不使用')
# TODO: automatic data backup
# TODO: build new index data
# TODO: update existing index
# TODO: performance testing
parser.add_argument(
'-t', '--task', type=str, required=True,
choices=('load_db_data', 'make_profile_index', 'do_test'), help='执行任务名称'
)
parser.add_argument('--test_start_date', default='-3', type=str, help='测试任务 - 开始日期')
parser.add_argument('--test_end_date', default='0', type=str, help='测试任务 - 结束日期')
# NOTE(review): the arguments below look like training hyper-parameters
# copied from another script; none are referenced by the tasks in this file.
parser.add_argument('--batch_size', default=128, type=int, help='训练时一个 batch 包含多少条数据')
parser.add_argument('--learning_rate', default=1e-3, type=float, help='Adam 优化器的学习率')
# NOTE(review): type=bool is an argparse pitfall — any non-empty string
# (including "False") parses as True; action='store_true' would be correct.
parser.add_argument('--add_special_tokens', default=True, type=bool, help='bert encode 时前后是否添加特殊token')
parser.add_argument('--num_train_epochs', default=2, type=int, help='训练时运行的 epoch 数量', )
parser.add_argument('--seed', type=int, default=42, help='随机数种子,保证结果可复现')
parser.add_argument('--do_train', action='store_true', default=False)
parser.add_argument('--do_test', action='store_true', default=False)
parser.add_argument('--do_predict', action='store_true', default=False)
args = parser.parse_args()
if __name__ == '__main__':
if args.task == 'load_db_data':
# Export order and profile data from the database to local xlsx files.
manager = DBDataManager()
manager.update_order_info()
manager.update_profile()
if args.task == 'make_profile_index':
# Build the profile embeddings (real and virtual) used for recall.
manager = ProfileManager()
manager.make_embeddings()
manager.make_virtual_embedding()
# NOTE(review): 'make_similarity' is not among the --task choices above, so
# argparse rejects it before this branch can ever run; meanwhile the
# 'do_test' choice has no handler. One of the two looks like a typo/leftover
# — confirm the intended task name.
if args.task == 'make_similarity':
pass
\ No newline at end of file
......@@ -2,19 +2,19 @@
import os
import json
import logging
from collections import Counter
import pandas as pd
from ydl_ai_recommender.src.utils import get_data_path
from ydl_ai_recommender.src.utils.log import create_logger
class OrderDataManager():
def __init__(self, client=None) -> None:
self.local_file_dir = get_data_path()
self.client = client
self.logger = logging.getLogger(__name__)
self.logger = create_logger(__name__, 'order_data_manager.log')
def _fetch_data_from_db(self, conditions=None):
......@@ -65,8 +65,11 @@ class OrderDataManager():
用户-咨询师 索引
top50 咨询师列表 用于冷启动
"""
self.logger.info('')
self.logger.info('开始构建 用户-咨询师 索引')
df = self.load_raw_data()
self.logger.info('加载数据完成,共加载 %s 条数据', len(df))
self.logger.info('本地订单加载数据完成,共加载 %s 条数据', len(df))
user_order = {}
for index, row in df.iterrows():
......@@ -89,17 +92,22 @@ class OrderDataManager():
index[uid] = sorted(supplier_values, key=lambda x: (x[2], x[1]), reverse=True)
self.logger.info('用户-咨询师 索引构建完成,共构建 %s 条数据', len(index))
with open(os.path.join(self.local_file_dir, 'user_doctor_index.json'), 'w', encoding='utf-8') as f:
json.dump(index, f, ensure_ascii=False)
# 订单最多的咨询师
self.logger.info('用户-咨询师 索引数据已保存,共有用户 %s', len(index))
# 订单最多的咨询师
supplier_cnter = Counter(df['supplier_id'])
top50_supplier = []
for key, _ in supplier_cnter.most_common(50):
top50_supplier.append(str(key))
self.logger.info('top50 订单量咨询师统计完成')
with open(os.path.join(self.local_file_dir, 'top50_supplier.txt'), 'w', encoding='utf-8') as f:
f.write('\n'.join(top50_supplier))
self.logger.info('top50 订单量咨询师列表已保存')
if __name__ == '__main__':
......
......@@ -8,7 +8,7 @@ import pandas as pd
from ydl_ai_recommender.src.utils import get_data_path
from ydl_ai_recommender.src.core.profile import profile_converters
from ydl_ai_recommender.src.utils.log import create_logger
class ProfileManager():
"""
......@@ -19,7 +19,7 @@ class ProfileManager():
self.local_file_dir = get_data_path()
self.profile_file_path = os.path.join(self.local_file_dir, 'all_profile.json')
self.client = client
self.logger = logging.getLogger(__name__)
self.logger = create_logger(__name__, 'profile_manager.log')
def _fetch_data_from_db(self, conditions=None):
......
......@@ -57,7 +57,11 @@ class UserCFRecommender(Recommender):
self.order_user_embedding = order_user_embedding
self.order_user_ids = order_user_ids
self.order_user_counselor_index = order_user_counselor_index
self.default_counselor = default_counselor
self.default_counselor = [{
'counselor': str(user),
'score': index + 1,
'from': 'top_50',
} for index, user in enumerate(default_counselor)]
self.index = faiss.IndexFlatL2(len(self.order_user_embedding[0]))
self.index.add(np.array(self.order_user_embedding))
......@@ -66,15 +70,25 @@ class UserCFRecommender(Recommender):
def get_user_profile(self, user_id):
    """Fetch a single user's profile row from the database.

    Args:
        user_id: uid used in the WHERE clause.

    Returns:
        The first matching row, or [] when the user is missing or the
        query fails.

    NOTE(review): user_id is interpolated directly into the SQL string;
    acceptable for trusted internal ids, but a parameterized query would
    be safer.
    """
    sql = 'SELECT * FROM ads.ads_register_user_profiles'
    sql += ' WHERE uid={}'.format(user_id)
    try:
        _, all_data = self.client.query(sql)
        if len(all_data) == 0:
            return []
        return all_data[0]
    except Exception:
        # BUG FIX: was self.logging.exception(...) — there is no ``logging``
        # attribute, so the handler itself raised AttributeError.
        # NOTE(review): assumes the instance carries a ``logger`` attribute
        # like the sibling manager classes — confirm against __init__.
        self.logger.exception("Exception occurred")
        return []
def user_token(self, user_profile):
# Convert a profile row into its embedding vector via the profile manager
# (presumably a ProfileManager — confirm self.manager's type).
return self.manager.profile_to_embedding(user_profile)
def _recommend_top(self):
# Cold-start fallback: the precomputed top-50 (by order count) counselors.
return self.default_counselor
def _recommend(self, user_embedding):
D, I = self.index.search(np.array([user_embedding]), self.k)
counselors = []
......@@ -86,30 +100,46 @@ class UserCFRecommender(Recommender):
recommend_data = [{
'counselor': str(user[0]),
'score': float(score),
'score': float(score) * (index + 1),
'from': 'similar_users {}'.format(similar_user_id),
} for user in similar_user_counselor[:self.top_n]]
} for index, user in enumerate(similar_user_counselor[:self.top_n])]
counselors.extend(recommend_data)
counselors.sort(key=lambda x: x['score'])
return counselors
def recommend_with_profile(self, user_profile):
def recommend_with_profile(self, user_profile, count=0, is_merge=True):
    """Recommend counselors for a user profile.

    Args:
        user_profile: profile row convertible to an embedding.
        count: when > 0, pad the candidates with the default (top-50)
            counselors and cap the final list at ``count`` items;
            0 means no padding and no cap.
        is_merge: drop duplicate counselors, keeping the first occurrence.

    Returns:
        list of recommendation dicts ({'counselor', 'score', 'from'}).
    """
    embedding = self.user_token(user_profile)
    candidates = self._recommend(embedding)
    if count > 0:
        # Pad with the global top counselors so short lists can be filled.
        candidates = candidates + self._recommend_top()
    if is_merge:
        # De-duplicate by counselor id, first occurrence wins.
        seen = set()
        deduped = []
        for item in candidates:
            key = item['counselor']
            if key in seen:
                continue
            seen.add(key)
            deduped.append(item)
        candidates = deduped
    return candidates[:count] if count > 0 else candidates
def recommend(self, user_id):
def recommend(self, user_id, count=0, is_merge=True):
"""
根据用户画像,推荐咨询师
若获取不到用户画像,推荐默认咨询师(订单最多的)
"""
user_profile = self.get_user_profile(user_id)
if not user_profile:
return []
return self._recommend_top()
return self.recommend_with_profile(user_profile)
return self.recommend_with_profile(user_profile, count, is_merge)
if __name__ == '__main__':
......
......@@ -6,20 +6,14 @@ import logging
from itertools import combinations
from ydl_ai_recommender.src.utils import get_data_path
logging.basicConfig(
format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
datefmt='%m/%d/%Y %H:%M:%S',
level=logging.INFO
)
from ydl_ai_recommender.src.utils.log import create_logger
class BasicUserSimilarity():
def __init__(self) -> None:
self.local_file_dir = get_data_path()
self.logger = logging.getLogger(__name__)
self.logger = create_logger(__name__, 'basic_user_similarity')
def compute_similarity(self):
......
# -*- coding: utf-8 -*-
import logging
import configparser
import pymysql
from ydl_ai_recommender.src.utils.log import create_logger
class MySQLClient():
def __init__(self, host, port, user, password) -> None:
self.logger = logging.getLogger(__name__)
self.logger = create_logger(__name__, 'mysql_client.log', is_rotating=True)
self.connection = pymysql.connect(
host=host,
port=port,
......@@ -35,13 +36,13 @@ class MySQLClient():
def _log_info(self, text, *args, **params):
if self.logger:
self.logger.info(text, *args, **params)
self.logger.debug(text, *args, **params)
def query(self, sql):
if self.cursor is None:
self._connect()
sql += ' limit 1000'
# sql += ' limit 1000'
self._log_info('begin execute sql: %s', sql)
row_count = self.cursor.execute(sql)
data = self.cursor.fetchall()
......
# -*- coding: utf-8 -*-
import os
import logging
import logging.handlers
from ydl_ai_recommender.src.utils import get_project_path
def create_logger(name, filename=None, is_rotating=False):
    """Create (or return the already-configured) logger ``name``.

    Console output at INFO level; optionally also a file handler under
    <project>/log/ at DEBUG level.

    Args:
        name: logger name, typically __name__.
        filename: log file name; no file handler is attached when None.
        is_rotating: use a daily TimedRotatingFileHandler (14 backups)
            instead of a plain FileHandler.

    Returns:
        logging.Logger: the configured logger.
    """
    logger = logging.getLogger(name)
    # BUG FIX: calling this twice with the same name used to attach a second
    # set of handlers, duplicating every log line. Return the logger as-is
    # when it is already configured.
    if logger.handlers:
        return logger
    logger.setLevel(logging.DEBUG)
    formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(name)s - %(message)s')

    # Console handler: INFO and above.
    console_handler = logging.StreamHandler()
    console_handler.setLevel(logging.INFO)
    console_handler.setFormatter(formatter)
    logger.addHandler(console_handler)

    if filename:
        log_dir = os.path.join(get_project_path(), 'log')
        # FileHandler raises FileNotFoundError if the directory is missing.
        os.makedirs(log_dir, exist_ok=True)
        full_filename = os.path.join(log_dir, filename)
        if is_rotating:
            file_handler = logging.handlers.TimedRotatingFileHandler(
                full_filename, when='D', backupCount=14, encoding='utf-8')
        else:
            file_handler = logging.FileHandler(full_filename, encoding='utf-8')
        file_handler.setLevel(logging.DEBUG)
        file_handler.setFormatter(formatter)
        logger.addHandler(file_handler)
    return logger
\ No newline at end of file
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment