闫发泽 / ydl_ai_recommender · Commits

Commit cfdc63bd
authored Dec 08, 2022 by 柴鹏飞
u2u (user-to-user) based recommendation （基于 u2u 的推荐）
parent 6e946e5a
Showing 10 changed files with 504 additions and 24 deletions (+504 -24)
environment.yaml                  +2    -3
src/core/db_data_manager.py       +36   -2
src/core/main.py                  +12   -5
src/core/order_data_manager.py    +41   -3
src/core/profile/__init__.py      +0    -0
src/core/profile_manager.py       +76   -8
src/core/recommender.py           +118  -0
src/core/test.py                  +176  -0
src/core/user_similarity.py       +28   -3
src/data/mysql_client.py          +15   -0
environment.yaml

@@ -4,9 +4,9 @@ channels:
   - pytorch
   - defaults
 dependencies:
-  - python==3.8
+  - python==3.9
   - ipykernel
   - faiss-cpu
   - pip
   - pip:
     - -r requirements.txt
\ No newline at end of file
src/core/db_data_manager.py

@@ -3,6 +3,7 @@
 import os
 import json
 import logging
+from datetime import datetime, timedelta
 
 import pandas as pd

@@ -35,7 +36,7 @@ class DBDataManager():
         self.logger.info('开始保存 %s 到本地', name)
         with open(os.path.join(self.local_file_dir, name), mode, encoding='utf-8') as f:
             f.write('\n'.join(lines))
-        self.logger.info('%s 保存成功,共保存 %s 行内人', name, len(lines))
+        self.logger.info('%s 保存成功,共保存 %s 行数据', name, len(lines))
 
     def _save_json_data(self, data, name):

@@ -63,10 +64,42 @@ class DBDataManager():
         _, all_data = self.client.query(sql)
         df = pd.DataFrame(all_data)
         df.to_excel(os.path.join(self.local_file_dir, 'all_order_info.xlsx'), index=None)
         # self._save_json_data(all_data, 'all_order_info.json')
 
+    def _load_order_data(self, conditions=None):
+        select_fields = ['main_order_id', 'uid', 'supplier_id', 'price', 'standard_order_type']
+        select_fields.append('DATE_FORMAT(update_time, "%Y-%m-%d") AS update_time')
+        sql = 'SELECT {} FROM ods.ods_ydl_standard_order'.format(', '.join(select_fields))
+        if conditions:
+            sql += ' WHERE {}'.format('AND '.join(conditions))
+        self.logger.info('开始执行sql %s', sql)
+        cnt, data = self.client.query(sql)
+        self.logger.info('sql执行成功,共获取 %s 条数据', cnt)
+        return data
+
+    def load_test_data(self, days=5):
+        now = datetime.now()
+        start_time = now - timedelta(days=days)
+        conditions = [
+            'create_time >= "{}"'.format(start_time.strftime('%Y-%m-%d')),
+        ]
+        order_data = self._load_order_data(conditions=conditions)
+        df = pd.DataFrame(order_data)
+        df.to_excel(os.path.join(self.local_file_dir, 'test_order_info.xlsx'), index=None)
+
+        select_fields = ['*']
+        sql = 'SELECT {} FROM ads.ads_register_user_profiles'.format(', '.join(select_fields))
+        sql += ' WHERE uid IN (SELECT DISTINCT uid FROM ods.ods_ydl_standard_order WHERE create_time >= "{}")'.format(start_time.strftime('%Y-%m-%d'))
+        _, all_data = self.client.query(sql)
+        df = pd.DataFrame(all_data)
+        df.to_excel(os.path.join(self.local_file_dir, 'test_profile.xlsx'), index=None)
+
 
 if __name__ == '__main__':
     manager = DBDataManager()
+    manager.load_test_data()
     # manager.update_local_data()
     # print(manager.make_index())
\ No newline at end of file
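Editor's note: for reference, a minimal sketch (with a hypothetical "today") of the SQL string that the new _load_order_data builds for the default 5-day window used by load_test_data:

    from datetime import datetime, timedelta

    # Hypothetical date so the output is concrete; the real code uses datetime.now().
    start_time = datetime(2022, 12, 8) - timedelta(days=5)
    conditions = ['create_time >= "{}"'.format(start_time.strftime('%Y-%m-%d'))]

    select_fields = ['main_order_id', 'uid', 'supplier_id', 'price', 'standard_order_type']
    select_fields.append('DATE_FORMAT(update_time, "%Y-%m-%d") AS update_time')
    sql = 'SELECT {} FROM ods.ods_ydl_standard_order'.format(', '.join(select_fields))
    if conditions:
        sql += ' WHERE {}'.format('AND '.join(conditions))

    print(sql)
    # (wrapped) SELECT main_order_id, uid, supplier_id, price, standard_order_type,
    #   DATE_FORMAT(update_time, "%Y-%m-%d") AS update_time
    #   FROM ods.ods_ydl_standard_order WHERE create_time >= "2022-12-03"
    # Note: with two or more conditions, 'AND '.join(...) would produce e.g. 'aAND b'
    # (no space before AND); the single condition used here does not hit that.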
src/core/main.py

@@ -29,10 +29,12 @@ parser.add_argument('--index_last_date', default=None, type=str, help='构建索
 parser.add_argument(
     '-t', '--task', type=str, required=True,
-    choices=('load_db_data', 'make_profile_index'), help='执行任务名称'
+    choices=('load_db_data', 'make_profile_index', 'do_test'), help='执行任务名称'
 )
-parser.add_argument('--output_dir', default='outputs', type=str, help='模型训练中间结果和训练好的模型保存目录')
-parser.add_argument('--max_seq_length', default=128, type=int, help='tokenization 之后序列最大长度。超过会被截断,小于会补齐')
+parser.add_argument('--test_start_date', default='-3', type=str, help='测试任务 - 开始日期')
+parser.add_argument('--test_end_date', default='0', type=str, help='测试任务 - 结束日期')
 parser.add_argument('--batch_size', default=128, type=int, help='训练时一个 batch 包含多少条数据')
 parser.add_argument('--learning_rate', default=1e-3, type=float, help='Adam 优化器的学习率')
 parser.add_argument('--add_special_tokens', default=True, type=bool, help='bert encode 时前后是否添加特殊token')

@@ -46,10 +48,15 @@ args = parser.parse_args()
 if __name__ == '__main__':
     if args.task == 'load_db_data':
+        # 从数据库中导出信息
         manager = DBDataManager()
         manager.update_order_info()
         manager.update_profile()
 
     if args.task == 'make_profile_index':
         manager = ProfileManager()
-        manager.make_embeddings()
\ No newline at end of file
+        manager.make_embeddings()
+        manager.make_virtual_embedding()
+
+    if args.task == 'make_similarity':
+        pass
\ No newline at end of file
src/core/order_data_manager.py

@@ -11,13 +11,51 @@ from ydl_ai_recommender.src.utils import get_data_path
 
 class OrderDataManager():
 
-    def __init__(self) -> None:
+    def __init__(self, client=None) -> None:
         self.local_file_dir = get_data_path()
+        self.client = client
 
         self.logger = logging.getLogger(__name__)
 
+    def _fetch_data_from_db(self, conditions=None):
+        if self.client is None:
+            self.logger.error('未连接数据库')
+            raise
+
+        condition_sql = ''
+        if conditions:
+            condition_sql = ' WHERE ' + ' AND '.join(conditions)
+
+        select_fields = ['main_order_id', 'uid', 'supplier_id', 'price', 'standard_order_type']
+        select_fields.append('DATE_FORMAT(update_time, "%Y-%m-%d") AS update_time')
+        sql = 'SELECT {} FROM ods.ods_ydl_standard_order'.format(', '.join(select_fields))
+        sql += condition_sql
+        _, all_data = self.client.query(sql)
+        return all_data
+
+    def update_order_data(self):
+        """ 从数据库中拉取最新订单数据并保存 """
+        all_data = self._fetch_data_from_db()
+        df = pd.DataFrame(all_data)
+        df.to_excel(os.path.join(self.local_file_dir, 'all_order_info.xlsx'), index=None)
+
+    def update_test_order_data(self, conditions):
+        """ 从数据库中拉取指定条件订单用于测试 """
+        all_data = self._fetch_data_from_db(conditions)
+        df = pd.DataFrame(all_data)
+        df.to_excel(os.path.join(self.local_file_dir, 'test_order_info.xlsx'), index=None)
+
     def load_raw_data(self):
-        df = pd.read_excel(os.path.join(self.local_file_dir, 'all_order_info.xlsx'))
+        df = pd.read_excel(os.path.join(self.local_file_dir, 'all_order_info.xlsx'), dtype=str)
+        return df
+
+    def load_test_order_data(self):
+        df = pd.read_excel(os.path.join(self.local_file_dir, 'test_order_info.xlsx'), dtype=str)
         return df

@@ -49,7 +87,7 @@ class OrderDataManager():
             latest_time = max([info[1] for info in infos])
             supplier_values.append([supplier_id, value, latest_time])
-        index[uid] = sorted(supplier_values, key=lambda x: (x[1], x[2]), reverse=True)
+        index[uid] = sorted(supplier_values, key=lambda x: (x[2], x[1]), reverse=True)
 
         with open(os.path.join(self.local_file_dir, 'user_doctor_index.json'), 'w', encoding='utf-8') as f:
             json.dump(index, f, ensure_ascii=False)
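Editor's note: the functional change in the second hunk is the sort key for each user's counselor list. Entries of the form [supplier_id, value, latest_time] are now ranked by latest order time first, then value, instead of value first. A small sketch with illustrative data:

    supplier_values = [
        ['s1', 3, '2022-11-01'],   # higher value, older order
        ['s2', 1, '2022-12-05'],   # lower value, most recent order
    ]

    old_rank = sorted(supplier_values, key=lambda x: (x[1], x[2]), reverse=True)
    new_rank = sorted(supplier_values, key=lambda x: (x[2], x[1]), reverse=True)

    print([s[0] for s in old_rank])  # ['s1', 's2'] -- value first
    print([s[0] for s in new_rank])  # ['s2', 's1'] -- recency first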
src/core/profile/__init__.py

This diff is collapsed (+0 -0).
src/core/profile_manager.py

@@ -15,16 +15,49 @@ class ProfileManager():
     订单用户画像数据管理
     """
 
-    def __init__(self) -> None:
+    def __init__(self, client=None) -> None:
         self.local_file_dir = get_data_path()
         self.profile_file_path = os.path.join(self.local_file_dir, 'all_profile.json')
+        self.client = client
 
         self.logger = logging.getLogger(__name__)
 
+    def _fetch_data_from_db(self, conditions=None):
+        if self.client is None:
+            self.logger.error('未连接数据库')
+            raise
+
+        condition_sql = ''
+        if conditions:
+            condition_sql = ' WHERE ' + ' AND '.join(conditions)
+
+        sql = 'SELECT * FROM ads.ads_register_user_profiles'
+        sql += ' WHERE uid IN (SELECT DISTINCT uid FROM ods.ods_ydl_standard_order{})'.format(condition_sql)
+        _, all_data = self.client.query(sql)
+        return all_data
+
+    def update_profile(self):
+        """ 从数据库中拉取最新画像特征并保存 """
+        all_data = self._fetch_data_from_db()
+        df = pd.DataFrame(all_data)
+        df.to_excel(os.path.join(self.local_file_dir, 'all_profile.xlsx'), index=None)
+
+    def update_test_profile(self, conditions):
+        """ 从数据库中拉取指定条件画像信息用于测试 """
+        all_data = self._fetch_data_from_db(conditions)
+        df = pd.DataFrame(all_data)
+        df.to_excel(os.path.join(self.local_file_dir, 'test_profile.xlsx'), index=None)
+
     def _load_profile_data(self):
-        return pd.read_excel(os.path.join(self.local_file_dir, 'all_profile.xlsx'))
+        return pd.read_excel(os.path.join(self.local_file_dir, 'all_profile.xlsx'), dtype=str)
         # with open(self.profile_file_path, 'r', encoding='utf-8') as f:
         #     return json.load(f)
 
+    def load_test_profile_data(self):
+        return pd.read_excel(os.path.join(self.local_file_dir, 'test_profile.xlsx'), dtype=str)
+
     def profile_to_embedding(self, profile):

@@ -36,6 +69,18 @@ class ProfileManager():
             embedding.extend(converter.convert(profile[name]))
         return embedding
 
+    def embedding_to_profile(self, embedding):
+        """
+        向量转换为用户画像
+        """
+        ret = {}
+        si = 0
+        for [name, converter] in profile_converters:
+            ei = si + converter.dim
+            ret[name] = converter.inconvert(embedding[si:ei])
+            si = ei
+        return ret
+
     def make_embeddings(self):
         user_profiles = self._load_profile_data()

@@ -61,10 +106,33 @@ class ProfileManager():
         user_ids = []
         embeddings = []
+        with open(os.path.join(self.local_file_dir, 'user_embeddings_ids.txt'), 'r', encoding='utf-8') as f:
+            user_ids = [line.strip() for line in f]
+        with open(os.path.join(self.local_file_dir, 'user_embeddings.json'), 'r', encoding='utf-8') as f:
+            embeddings = json.load(f)
+
+        v_embedding_set = {}
+        for user_id, embedding in zip(user_ids, embeddings):
+            key = '_'.join(map(str, embedding))
+            if key not in v_embedding_set:
+                v_embedding_set[key] = {
+                    'embedding': embedding,
+                    'user_ids': [],
+                }
+            v_embedding_set[key]['user_ids'].append(str(user_id))
+
+        v_embedding_list = []
+        with open(os.path.join(self.local_file_dir, 'virtual_user_embeddings_ids.txt'), 'w', encoding='utf-8') as f:
+            for info in v_embedding_set.values():
+                f.write(','.join(info['user_ids']) + '\n')
+                v_embedding_list.append(info['embedding'])
+        with open(os.path.join(self.local_file_dir, 'virtual_user_embeddings.json'), 'w', encoding='utf-8') as f:
+            json.dump(v_embedding_list, f, ensure_ascii=False)
 
 
 if __name__ == '__main__':
     manager = ProfileManager()
-    manager.make_embeddings()
\ No newline at end of file
+    # manager.make_embeddings()
+    # manager.update_local_data()
+    manager.make_virtual_embedding()
+    # print(manager.make_index())
\ No newline at end of file
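Editor's note: the block added in the last hunk (presumably the body of the new make_virtual_embedding that main.py and the __main__ guard now call) collapses users whose profile embeddings are identical into one "virtual" user, so the downstream index keeps one vector per distinct embedding. A minimal sketch of that grouping with toy data:

    user_ids = ['u1', 'u2', 'u3']
    embeddings = [[0, 1, 0], [0, 1, 0], [1, 0, 0]]   # u1 and u2 share the same embedding

    v_embedding_set = {}
    for user_id, embedding in zip(user_ids, embeddings):
        key = '_'.join(map(str, embedding))          # '0_1_0' identifies the vector
        if key not in v_embedding_set:
            v_embedding_set[key] = {'embedding': embedding, 'user_ids': []}
        v_embedding_set[key]['user_ids'].append(str(user_id))

    print([info['user_ids'] for info in v_embedding_set.values()])
    # [['u1', 'u2'], ['u3']] -- two virtual users instead of three index rows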
src/core/recommender.py
new file mode 100644

# -*- coding: utf-8 -*-

import os
import json
from typing import List, Dict

import faiss
import numpy as np

from ydl_ai_recommender.src.core.profile_manager import ProfileManager
from ydl_ai_recommender.src.data.mysql_client import MySQLClient
from ydl_ai_recommender.src.utils import get_conf_path, get_data_path


class Recommender():

    def __init__(self) -> None:
        pass

    def recommend(self, user) -> List:
        raise NotImplemented


class UserCFRecommender(Recommender):

    def __init__(self, top_n=5, k=5, is_lazy=True) -> None:
        super().__init__()
        # 召回 top_n 个相似用户
        self.top_n = top_n
        # 每个召回的用户取 k 个相关咨询师
        self.k = k

        if is_lazy is False:
            self.client = MySQLClient.create_from_config_file(get_conf_path())
        self.manager = ProfileManager()
        self.local_file_dir = get_data_path()

        self.load_data()

    def load_data(self):
        order_user_embedding = []
        order_user_ids = []
        order_user_counselor_index = {}
        default_counselor = []
        with open(os.path.join(self.local_file_dir, 'user_embeddings_ids.txt'), 'r', encoding='utf-8') as f:
            order_user_ids = [line.strip() for line in f]
        with open(os.path.join(self.local_file_dir, 'user_embeddings.json'), 'r', encoding='utf-8') as f:
            order_user_embedding = json.load(f)
        with open(os.path.join(self.local_file_dir, 'user_doctor_index.json'), encoding='utf-8') as f:
            order_user_counselor_index = json.load(f)
        with open(os.path.join(self.local_file_dir, 'top50_supplier.txt'), 'r', encoding='utf-8') as f:
            default_counselor = [line.strip() for line in f]

        self.order_user_embedding = order_user_embedding
        self.order_user_ids = order_user_ids
        self.order_user_counselor_index = order_user_counselor_index
        self.default_counselor = default_counselor

        self.index = faiss.IndexFlatL2(len(self.order_user_embedding[0]))
        self.index.add(np.array(self.order_user_embedding))

    def get_user_profile(self, user_id):
        sql = 'SELECT * FROM ads.ads_register_user_profiles'
        sql += ' WHERE uid={}'.format(user_id)
        _, all_data = self.client.query(sql)
        if len(all_data) == 0:
            return []
        return all_data[0]

    def user_token(self, user_profile):
        return self.manager.profile_to_embedding(user_profile)

    def _recommend(self, user_embedding):
        D, I = self.index.search(np.array([user_embedding]), self.k)
        counselors = []
        for idx, score in zip(I[0], D[0]):
            # 相似用户uid
            similar_user_id = self.order_user_ids[idx]
            similar_user_counselor = self.order_user_counselor_index.get(similar_user_id, [])
            recommend_data = [{
                'counselor': str(user[0]),
                'score': float(score),
                'from': 'similar_users {}'.format(similar_user_id),
            } for user in similar_user_counselor[:self.top_n]]
            counselors.extend(recommend_data)
        return counselors

    def recommend_with_profile(self, user_profile):
        user_embedding = self.user_token(user_profile)
        counselors = self._recommend(user_embedding)
        return counselors

    def recommend(self, user_id):
        """
        根据用户画像,推荐咨询师
        若获取不到用户画像,推荐默认咨询师(订单最多的)
        """
        user_profile = self.get_user_profile(user_id)
        if not user_profile:
            return []
        return self.recommend_with_profile(user_profile)


if __name__ == '__main__':
    recommender = UserCFRecommender()
    print(recommender.recommend('10957910'))
\ No newline at end of file
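Editor's note: UserCFRecommender keeps all order-user embeddings in a flat L2 faiss index and, for a query profile, looks up the k nearest users before expanding each into its top_n counselors. A self-contained sketch of just the faiss part, with toy 3-dimensional vectors:

    import faiss
    import numpy as np

    # Toy stand-in for user_embeddings.json: 4 users, 3-dimensional profile vectors.
    order_user_embedding = np.array(
        [[0., 1., 0.], [0., 1., 1.], [1., 0., 0.], [1., 1., 0.]], dtype='float32'
    )
    index = faiss.IndexFlatL2(order_user_embedding.shape[1])   # exact L2 search
    index.add(order_user_embedding)

    query = np.array([[0., 1., 0.1]], dtype='float32')
    D, I = index.search(query, 2)   # squared L2 distances and row ids of the 2 nearest users
    print(I[0], D[0])               # [0 1] [~0.01 ~0.81]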
src/core/test.py
new file mode 100644

# -*- coding: utf-8 -*-

import re
import json
import logging
import argparse
from datetime import datetime, timedelta

from ydl_ai_recommender.src.core.order_data_manager import OrderDataManager
from ydl_ai_recommender.src.core.profile_manager import ProfileManager
from ydl_ai_recommender.src.core.recommender import UserCFRecommender
from ydl_ai_recommender.src.data.mysql_client import MySQLClient
from ydl_ai_recommender.src.utils import get_conf_path

logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
    datefmt='%m/%d/%Y %H:%M:%S',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)


def main(args):
    # 构建用户画像字典,不用每次都从数据库中获取
    profile_manager = ProfileManager()
    df = profile_manager.load_test_profile_data()
    user_profile_dict = {}
    for _, row in df.iterrows():
        user_profile_dict[row['uid']] = row

    manager = OrderDataManager()
    # 加载训练订单数据,为后面判断用户是否为新用户
    train_orders = manager.load_raw_data()
    old_users = set(train_orders['uid'])
    logger.info('订单用户数 %s ', len(old_users))

    # 加载测试数据
    test_orders = manager.load_test_order_data()
    logger.info('加载测试数据成功,共加载 %s 条', len(test_orders))

    recommender = UserCFRecommender(top_n=args.top_n, k=args.k)

    result_detail = []
    for index, order_info in test_orders.iterrows():
        if args.max_test > 0:
            if index >= args.max_test:
                break
        uid = order_info['uid']
        profile = user_profile_dict.get(uid)
        if profile is None:
            continue
        recommend_result = recommender.recommend_with_profile(profile)

        recall_resons = []
        for rr in recommend_result:
            if rr['counselor'] == order_info['supplier_id']:
                recall_resons.append(rr['from'])

        result_detail.append({
            'uid': uid,
            'supplier_id': order_info['supplier_id'],
            'is_old_user': uid in old_users,
            'recall_counselors': recommend_result,
            'is_recall': len(recall_resons) > 0,
            'recall_reason': '|'.join(recall_resons),
        })

    # 结果报告
    metrics = {
        'all_test_cnt': 0,
        'all_recall_cnt': 0,
        'old_user_test_cnt': 0,
        'old_user_recall_cnt': 0,
        'new_user_test_cnt': 0,
        'new_user_recall_cnt': 0,
        'same_user_recall_cnt': 0,
        'similar_user_recall_cnt': 0,
    }
    for rd in result_detail:
        metrics['all_test_cnt'] += 1
        if rd['is_recall']:
            metrics['all_recall_cnt'] += 1
            is_same_user, is_similar_user = False, False
            for counselor in rd['recall_counselors']:
                from_id = counselor['from'].split(' ')[1]
                if from_id == rd['uid']:
                    is_same_user = True
                if from_id != rd['uid']:
                    is_similar_user = True
            if is_same_user:
                metrics['same_user_recall_cnt'] += 1
            if is_similar_user:
                metrics['similar_user_recall_cnt'] += 1
        if rd['is_old_user']:
            metrics['old_user_test_cnt'] += 1
            if rd['is_recall']:
                metrics['old_user_recall_cnt'] += 1
        else:
            metrics['new_user_test_cnt'] += 1
            if rd['is_recall']:
                metrics['new_user_recall_cnt'] += 1

    logger.info('==' * 20 + ' 测试结果 ' + '==' * 20)
    logger.info('--' * 45)
    logger.info('{:<10}{:<10}{:<10}{:<10}'.format('', '样本数', '召回数', '召回率'))
    logger.info('{:<10}{:<10}{:<10}{:<10.2%}'.format('整体 ', metrics['all_test_cnt'], metrics['all_recall_cnt'], metrics['all_recall_cnt'] / metrics['all_test_cnt']))
    logger.info('{:<10}{:<10}{:<10}{:<10.2%}'.format('老用户', metrics['old_user_test_cnt'], metrics['old_user_recall_cnt'], metrics['old_user_recall_cnt'] / metrics['old_user_test_cnt']))
    logger.info('{:<10}{:<10}{:<10}{:<10.2%}'.format('新用户', metrics['new_user_test_cnt'], metrics['new_user_recall_cnt'], metrics['new_user_recall_cnt'] / metrics['new_user_test_cnt']))
    logger.info('')
    logger.info('--' * 45)
    logger.info('用户自己召回数 {} 占总召回比例 {:.2%}'.format(metrics['same_user_recall_cnt'], metrics['same_user_recall_cnt'] / metrics['all_recall_cnt']))
    logger.info('相似用户召回数 {} 占总召回比例 {:.2%}'.format(metrics['similar_user_recall_cnt'], metrics['similar_user_recall_cnt'] / metrics['all_recall_cnt']))

    with open('result_detail.json', 'w', encoding='utf-8') as f:
        json.dump(result_detail, f, ensure_ascii=False, indent=2)


def update_test_data(args):
    start_date = ''
    if re.match(r'20\d\d-[01]\d-\d\d', args.start_date):
        start_date = args.start_date
    elif re.match(r'-\d+', args.start_date):
        now = datetime.now()
        start_date = (now - timedelta(days=int(args.start_date[1:]))).strftime('%Y-%m-%d')
    else:
        logger.error('args.start_date 参数格式错误,%s', args.start_date)
        raise

    conditions = ['create_time >= "{}"'.format(start_date)]
    client = MySQLClient.create_from_config_file(get_conf_path())
    # 订单数据
    manager = OrderDataManager(client)
    manager.update_test_order_data(conditions=conditions)
    # 用户画像数据
    manager = ProfileManager(client)
    manager.update_test_profile(conditions=conditions)
    logger.info('测试数据更新完成')


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--k', default=5, type=int, help='召回相似用户的数量')
    parser.add_argument('--top_n', default=5, type=int, help='每个相似用户召回的咨询师数量')
    parser.add_argument('--max_test', default=0, type=int, help='最多测试数据量')
    parser.add_argument('--do_update_test_data', default=False, action='store_true', help='是否更新测试数据')
    parser.add_argument('--start_date', default='-1', type=str, help='测试订单创建的开始时间,可以是"%Y-%m-%d"格式,也可以是 -3 表示前3天')

    args = parser.parse_args()

    if args.do_update_test_data:
        logger.info('更新测试数据')
        update_test_data(args)

    main(args)
\ No newline at end of file
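Editor's note: test.py replays recent orders; an order counts as recalled when the counselor the user actually booked appears in the recommendation list, and recall rates are reported separately for old and new users. A minimal sketch of the headline metric over result_detail-style records (illustrative values):

    result_detail = [
        {'is_recall': True,  'is_old_user': True},
        {'is_recall': False, 'is_old_user': False},
        {'is_recall': True,  'is_old_user': False},
    ]

    all_cnt = len(result_detail)
    recall_cnt = sum(1 for rd in result_detail if rd['is_recall'])
    print('overall recall {:.2%}'.format(recall_cnt / all_cnt))   # overall recall 66.67%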
src/core/user_similarity.py

@@ -8,6 +8,13 @@ from itertools import combinations
 
 from ydl_ai_recommender.src.utils import get_data_path
 
+logging.basicConfig(
+    format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
+    datefmt='%m/%d/%Y %H:%M:%S',
+    level=logging.INFO,
+)
+
 
 class BasicUserSimilarity():
 
     def __init__(self) -> None:

@@ -24,7 +31,7 @@ class BasicUserSimilarity():
         counselor_user_index = {}
         for user, counselors in user_counselor_index.items():
-            user_like_set[user] = len(counselors)
+            user_like_set[user] = set([c[0] for c in counselors])
             for [counselor, _, _] in counselors:
                 if counselor not in counselor_user_index:

@@ -32,19 +39,37 @@ class BasicUserSimilarity():
                 counselor_user_index[counselor].append(user)
 
         # 两个用户与同一个咨询师有订单,就认为两个用户相似
-        self.logger.info('开始构建用户相似性关系')
+        self.logger.info('开始构建用户之间相似性关系')
         relations = {}
+        user_index = {}
         for users in counselor_user_index.values():
             for [_u1, _u2] in combinations(users, 2):
                 u1, u2 = min(_u1, _u2), max(_u1, _u2)
                 key = '{}_{}'.format(u1, u2)
                 if key in relations:
                     continue
-                relations[key] = 1.0 / (user_like_set[u1] * user_like_set[u2])
+                sim = len(user_like_set[u1] & user_like_set[u2]) / (len(user_like_set[u1]) * len(user_like_set[u2]))
+                relations[key] = sim
+                if u1 not in user_index:
+                    user_index[u1] = {}
+                if u2 not in user_index:
+                    user_index[u2] = {}
+                if u2 not in user_index[u1]:
+                    user_index[u1][u2] = sim
+                if u1 not in user_index[u2]:
+                    user_index[u2][u1] = sim
+
+        user_counselor_index = {}
 
         self.logger.info('用户相似性关系构建完成,共有 %s 对关系', len(relations))
         with open(os.path.join(self.local_file_dir, 'user_similarity.json'), 'w', encoding='utf-8') as f:
             json.dump(relations, f, ensure_ascii=False, indent=2)
 
+    def recall(self, user, N=10, top_k=10):
+        pass
+
 
 bs = BasicUserSimilarity()
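Editor's note: the similarity definition changes here. Previously any two users sharing at least one counselor scored 1 / (|likes(u1)| * |likes(u2)|), regardless of how much their order histories overlap; now the score is |likes(u1) ∩ likes(u2)| / (|likes(u1)| * |likes(u2)|), and a per-user neighbour map user_index is filled in alongside. A quick numeric sketch (counselor ids illustrative):

    u1 = {'c1', 'c2', 'c3'}
    u2 = {'c2', 'c9'}   # one counselor shared with u1
    u3 = {'c1', 'c2'}   # two counselors shared with u1

    def old_sim(a, b):
        return 1.0 / (len(a) * len(b))

    def new_sim(a, b):
        return len(a & b) / (len(a) * len(b))

    print(old_sim(u1, u2), old_sim(u1, u3))   # 0.1667 0.1667 -- overlap size ignored
    print(new_sim(u1, u2), new_sim(u1, u3))   # 0.1667 0.3333 -- more shared counselors, higher score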
src/data/mysql_client.py

@@ -21,11 +21,26 @@ class MySQLClient():
         self.cursor = self.connection.cursor()
         self._log_info('数据库连接成功')
 
+    def _connect(self):
+        self.connection = pymysql.connect(
+            host=self.host,
+            port=self.port,
+            user=self.user,
+            password=self.password,
+            charset='utf8mb4',
+            cursorclass=pymysql.cursors.DictCursor
+        )
+        self.cursor = self.connection.cursor()
+        self._log_info('数据库连接成功')
+
     def _log_info(self, text, *args, **params):
         if self.logger:
             self.logger.info(text, *args, **params)
 
     def query(self, sql):
+        if self.cursor is None:
+            self._connect()
         sql += ' limit 1000'
         self._log_info('begin execute sql: %s', sql)
         row_count = self.cursor.execute(sql)
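Editor's note: the new _connect helper plus the `if self.cursor is None` guard make query() connect lazily on first use. For reference, a standalone sketch of the same pymysql pattern (host and credentials are placeholders):

    import pymysql

    connection = pymysql.connect(
        host='127.0.0.1',      # placeholder
        port=3306,
        user='reader',         # placeholder credentials
        password='secret',
        charset='utf8mb4',
        cursorclass=pymysql.cursors.DictCursor,   # rows come back as dicts keyed by column name
    )
    with connection.cursor() as cursor:
        row_count = cursor.execute('SELECT 1 AS one')
        rows = cursor.fetchall()                  # [{'one': 1}]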