Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
Y
ydl_ai_recommender
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
闫发泽
ydl_ai_recommender
Commits
cfdc63bd
Commit
cfdc63bd
authored
Dec 08, 2022
by
柴鹏飞
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
基于 u2u 的推荐
parent
6e946e5a
Show whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
710 additions
and
80 deletions
+710
-80
environment.yaml
environment.yaml
+1
-1
db_data_manager.py
src/core/db_data_manager.py
+36
-2
main.py
src/core/main.py
+11
-3
order_data_manager.py
src/core/order_data_manager.py
+41
-3
__init__.py
src/core/profile/__init__.py
+207
-60
profile_manager.py
src/core/profile_manager.py
+76
-8
recommender.py
src/core/recommender.py
+118
-0
test.py
src/core/test.py
+176
-0
user_similarity.py
src/core/user_similarity.py
+29
-3
mysql_client.py
src/data/mysql_client.py
+15
-0
No files found.
environment.yaml
View file @
cfdc63bd
...
@@ -4,7 +4,7 @@ channels:
...
@@ -4,7 +4,7 @@ channels:
-
pytorch
-
pytorch
-
defaults
-
defaults
dependencies
:
dependencies
:
-
python==3.
8
-
python==3.
9
-
ipykernel
-
ipykernel
-
faiss-cpu
-
faiss-cpu
-
pip
-
pip
...
...
src/core/db_data_manager.py
View file @
cfdc63bd
...
@@ -3,6 +3,7 @@
...
@@ -3,6 +3,7 @@
import
os
import
os
import
json
import
json
import
logging
import
logging
from
datetime
import
datetime
,
timedelta
import
pandas
as
pd
import
pandas
as
pd
...
@@ -35,7 +36,7 @@ class DBDataManager():
...
@@ -35,7 +36,7 @@ class DBDataManager():
self
.
logger
.
info
(
'开始保存
%
s 到本地'
,
name
)
self
.
logger
.
info
(
'开始保存
%
s 到本地'
,
name
)
with
open
(
os
.
path
.
join
(
self
.
local_file_dir
,
name
),
mode
,
encoding
=
'utf-8'
)
as
f
:
with
open
(
os
.
path
.
join
(
self
.
local_file_dir
,
name
),
mode
,
encoding
=
'utf-8'
)
as
f
:
f
.
write
(
'
\n
'
.
join
(
lines
))
f
.
write
(
'
\n
'
.
join
(
lines
))
self
.
logger
.
info
(
'
%
s 保存成功,共保存
%
s 行
内人
'
,
name
,
len
(
lines
))
self
.
logger
.
info
(
'
%
s 保存成功,共保存
%
s 行
数据
'
,
name
,
len
(
lines
))
def
_save_json_data
(
self
,
data
,
name
):
def
_save_json_data
(
self
,
data
,
name
):
...
@@ -63,10 +64,42 @@ class DBDataManager():
...
@@ -63,10 +64,42 @@ class DBDataManager():
_
,
all_data
=
self
.
client
.
query
(
sql
)
_
,
all_data
=
self
.
client
.
query
(
sql
)
df
=
pd
.
DataFrame
(
all_data
)
df
=
pd
.
DataFrame
(
all_data
)
df
.
to_excel
(
os
.
path
.
join
(
self
.
local_file_dir
,
'all_order_info.xlsx'
),
index
=
None
)
df
.
to_excel
(
os
.
path
.
join
(
self
.
local_file_dir
,
'all_order_info.xlsx'
),
index
=
None
)
# self._save_json_data(all_data, 'all_order_info.json')
def
_load_order_data
(
self
,
conditions
=
None
):
select_fields
=
[
'main_order_id'
,
'uid'
,
'supplier_id'
,
'price'
,
'standard_order_type'
]
select_fields
.
append
(
'DATE_FORMAT(update_time, "
%
Y-
%
m-
%
d") AS update_time'
)
sql
=
'SELECT {} FROM ods.ods_ydl_standard_order'
.
format
(
', '
.
join
(
select_fields
))
if
conditions
:
sql
+=
' WHERE {}'
.
format
(
'AND '
.
join
(
conditions
))
self
.
logger
.
info
(
'开始执行sql
%
s'
,
sql
)
cnt
,
data
=
self
.
client
.
query
(
sql
)
self
.
logger
.
info
(
'sql执行成功,共获取
%
s 条数据'
,
cnt
)
return
data
def
load_test_data
(
self
,
days
=
5
):
now
=
datetime
.
now
()
start_time
=
now
-
timedelta
(
days
=
days
)
conditions
=
[
'create_time >= "{}"'
.
format
(
start_time
.
strftime
(
'
%
Y-
%
m-
%
d'
)),
]
order_data
=
self
.
_load_order_data
(
conditions
=
conditions
)
df
=
pd
.
DataFrame
(
order_data
)
df
.
to_excel
(
os
.
path
.
join
(
self
.
local_file_dir
,
'test_order_info.xlsx'
),
index
=
None
)
select_fields
=
[
'*'
]
sql
=
'SELECT {} FROM ads.ads_register_user_profiles'
.
format
(
', '
.
join
(
select_fields
))
sql
+=
' WHERE uid IN (SELECT DISTINCT uid FROM ods.ods_ydl_standard_order WHERE create_time >= "{}")'
.
format
(
start_time
.
strftime
(
'
%
Y-
%
m-
%
d'
))
_
,
all_data
=
self
.
client
.
query
(
sql
)
df
=
pd
.
DataFrame
(
all_data
)
df
.
to_excel
(
os
.
path
.
join
(
self
.
local_file_dir
,
'test_profile.xlsx'
),
index
=
None
)
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
manager
=
DBDataManager
()
manager
=
DBDataManager
()
manager
.
load_test_data
()
# manager.update_local_data()
# manager.update_local_data()
# print(manager.make_index())
# print(manager.make_index())
\ No newline at end of file
src/core/main.py
View file @
cfdc63bd
...
@@ -29,10 +29,12 @@ parser.add_argument('--index_last_date', default=None, type=str, help='构建索
...
@@ -29,10 +29,12 @@ parser.add_argument('--index_last_date', default=None, type=str, help='构建索
parser
.
add_argument
(
parser
.
add_argument
(
'-t'
,
'--task'
,
type
=
str
,
required
=
True
,
'-t'
,
'--task'
,
type
=
str
,
required
=
True
,
choices
=
(
'load_db_data'
,
'make_profile_index'
),
help
=
'执行任务名称'
choices
=
(
'load_db_data'
,
'make_profile_index'
,
'do_test'
),
help
=
'执行任务名称'
)
)
parser
.
add_argument
(
'--output_dir'
,
default
=
'outputs'
,
type
=
str
,
help
=
'模型训练中间结果和训练好的模型保存目录'
)
parser
.
add_argument
(
'--test_start_date'
,
default
=
'-3'
,
type
=
str
,
help
=
'测试任务 - 开始日期'
)
parser
.
add_argument
(
'--max_seq_length'
,
default
=
128
,
type
=
int
,
help
=
'tokenization 之后序列最大长度。超过会被截断,小于会补齐'
)
parser
.
add_argument
(
'--test_end_date'
,
default
=
'0'
,
type
=
str
,
help
=
'测试任务 - 结束日期'
)
parser
.
add_argument
(
'--batch_size'
,
default
=
128
,
type
=
int
,
help
=
'训练时一个 batch 包含多少条数据'
)
parser
.
add_argument
(
'--batch_size'
,
default
=
128
,
type
=
int
,
help
=
'训练时一个 batch 包含多少条数据'
)
parser
.
add_argument
(
'--learning_rate'
,
default
=
1e-3
,
type
=
float
,
help
=
'Adam 优化器的学习率'
)
parser
.
add_argument
(
'--learning_rate'
,
default
=
1e-3
,
type
=
float
,
help
=
'Adam 优化器的学习率'
)
parser
.
add_argument
(
'--add_special_tokens'
,
default
=
True
,
type
=
bool
,
help
=
'bert encode 时前后是否添加特殊token'
)
parser
.
add_argument
(
'--add_special_tokens'
,
default
=
True
,
type
=
bool
,
help
=
'bert encode 时前后是否添加特殊token'
)
...
@@ -46,6 +48,7 @@ args = parser.parse_args()
...
@@ -46,6 +48,7 @@ args = parser.parse_args()
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
if
args
.
task
==
'load_db_data'
:
if
args
.
task
==
'load_db_data'
:
# 从数据库中导出信息
manager
=
DBDataManager
()
manager
=
DBDataManager
()
manager
.
update_order_info
()
manager
.
update_order_info
()
manager
.
update_profile
()
manager
.
update_profile
()
...
@@ -53,3 +56,7 @@ if __name__ == '__main__':
...
@@ -53,3 +56,7 @@ if __name__ == '__main__':
if
args
.
task
==
'make_profile_index'
:
if
args
.
task
==
'make_profile_index'
:
manager
=
ProfileManager
()
manager
=
ProfileManager
()
manager
.
make_embeddings
()
manager
.
make_embeddings
()
manager
.
make_virtual_embedding
()
if
args
.
task
==
'make_similarity'
:
pass
\ No newline at end of file
src/core/order_data_manager.py
View file @
cfdc63bd
...
@@ -11,13 +11,51 @@ from ydl_ai_recommender.src.utils import get_data_path
...
@@ -11,13 +11,51 @@ from ydl_ai_recommender.src.utils import get_data_path
class
OrderDataManager
():
class
OrderDataManager
():
def
__init__
(
self
)
->
None
:
def
__init__
(
self
,
client
=
None
)
->
None
:
self
.
local_file_dir
=
get_data_path
()
self
.
local_file_dir
=
get_data_path
()
self
.
client
=
client
self
.
logger
=
logging
.
getLogger
(
__name__
)
self
.
logger
=
logging
.
getLogger
(
__name__
)
def
_fetch_data_from_db
(
self
,
conditions
=
None
):
if
self
.
client
is
None
:
self
.
logger
.
error
(
'未连接数据库'
)
raise
condition_sql
=
''
if
conditions
:
condition_sql
=
' WHERE '
+
' AND '
.
join
(
conditions
)
select_fields
=
[
'main_order_id'
,
'uid'
,
'supplier_id'
,
'price'
,
'standard_order_type'
]
select_fields
.
append
(
'DATE_FORMAT(update_time, "
%
Y-
%
m-
%
d") AS update_time'
)
sql
=
'SELECT {} FROM ods.ods_ydl_standard_order'
.
format
(
', '
.
join
(
select_fields
))
sql
+=
condition_sql
_
,
all_data
=
self
.
client
.
query
(
sql
)
return
all_data
def
update_order_data
(
self
):
""" 从数据库中拉取最新订单数据并保存 """
all_data
=
self
.
_fetch_data_from_db
()
df
=
pd
.
DataFrame
(
all_data
)
df
.
to_excel
(
os
.
path
.
join
(
self
.
local_file_dir
,
'all_order_info.xlsx'
),
index
=
None
)
def
update_test_order_data
(
self
,
conditions
):
""" 从数据库中拉取指定条件订单用于测试 """
all_data
=
self
.
_fetch_data_from_db
(
conditions
)
df
=
pd
.
DataFrame
(
all_data
)
df
.
to_excel
(
os
.
path
.
join
(
self
.
local_file_dir
,
'test_order_info.xlsx'
),
index
=
None
)
def
load_raw_data
(
self
):
def
load_raw_data
(
self
):
df
=
pd
.
read_excel
(
os
.
path
.
join
(
self
.
local_file_dir
,
'all_order_info.xlsx'
))
df
=
pd
.
read_excel
(
os
.
path
.
join
(
self
.
local_file_dir
,
'all_order_info.xlsx'
),
dtype
=
str
)
return
df
def
load_test_order_data
(
self
):
df
=
pd
.
read_excel
(
os
.
path
.
join
(
self
.
local_file_dir
,
'test_order_info.xlsx'
),
dtype
=
str
)
return
df
return
df
...
@@ -49,7 +87,7 @@ class OrderDataManager():
...
@@ -49,7 +87,7 @@ class OrderDataManager():
latest_time
=
max
([
info
[
1
]
for
info
in
infos
])
latest_time
=
max
([
info
[
1
]
for
info
in
infos
])
supplier_values
.
append
([
supplier_id
,
value
,
latest_time
])
supplier_values
.
append
([
supplier_id
,
value
,
latest_time
])
index
[
uid
]
=
sorted
(
supplier_values
,
key
=
lambda
x
:
(
x
[
1
],
x
[
2
]),
reverse
=
True
)
index
[
uid
]
=
sorted
(
supplier_values
,
key
=
lambda
x
:
(
x
[
2
],
x
[
1
]),
reverse
=
True
)
with
open
(
os
.
path
.
join
(
self
.
local_file_dir
,
'user_doctor_index.json'
),
'w'
,
encoding
=
'utf-8'
)
as
f
:
with
open
(
os
.
path
.
join
(
self
.
local_file_dir
,
'user_doctor_index.json'
),
'w'
,
encoding
=
'utf-8'
)
as
f
:
json
.
dump
(
index
,
f
,
ensure_ascii
=
False
)
json
.
dump
(
index
,
f
,
ensure_ascii
=
False
)
...
...
src/core/profile/__init__.py
View file @
cfdc63bd
# -*- coding: utf-8 -*-
# -*- coding: utf-8 -*-
import
json
import
json
from
typing
import
Dict
,
List
,
Any
from
typing
import
Dict
,
List
,
Any
,
Union
import
pandas
as
pd
import
pandas
as
pd
from
.country_code_profile
import
CountryCodeProfile
#
from .country_code_profile import CountryCodeProfile
from
.profile
import
ChannelIdTypeProfile
#
from .profile import ChannelIdTypeProfile
class
FfromLoginProfile
():
class
BaseProfile
():
def
__init__
(
self
)
->
None
:
self
.
dim
:
int
=
0
def
convert
(
self
,
value
):
raise
NotImplemented
def
inconvert
(
self
,
embedding
:
List
[
Union
[
int
,
float
]])
->
str
:
raise
NotImplemented
class
CountryCodeProfile
(
BaseProfile
):
def
__init__
(
self
)
->
None
:
super
()
.
__init__
()
self
.
dim
=
3
def
convert
(
self
,
value
):
try
:
value
=
int
(
value
)
except
Exception
as
e
:
return
[
0
,
0
,
1
]
if
value
==
86
:
return
[
1
,
0
,
0
]
else
:
return
[
0
,
1
,
0
]
def
inconvert
(
self
,
embedding
):
if
embedding
[
0
]
==
1
:
return
'china'
elif
embedding
[
1
]
==
2
:
return
'abroad'
else
:
return
'unknown'
class
ChannelIdTypeProfile
(
BaseProfile
):
def
__init__
(
self
)
->
None
:
super
()
.
__init__
()
self
.
dim
=
3
def
convert
(
self
,
value
):
try
:
value
=
int
(
value
)
except
Exception
as
e
:
return
[
0
,
0
,
1
]
if
value
==
1
:
return
[
1
,
0
,
0
]
elif
value
==
2
:
return
[
0
,
1
,
0
]
else
:
return
[
0
,
0
,
1
]
def
inconvert
(
self
,
embedding
):
if
embedding
[
0
]
==
1
:
return
'ios'
elif
embedding
[
1
]
==
2
:
return
'android'
else
:
return
'other'
class
FfromLoginProfile
(
BaseProfile
):
"""
"""
登录来源
登录来源
主要是不同android品牌特征,android和ios特征在其他画像中已有,这里不重复构建
主要是不同android品牌特征,android和ios特征在其他画像中已有,这里不重复构建
"""
"""
def
__init__
(
self
)
->
None
:
def
__init__
(
self
)
->
None
:
pass
super
()
.
__init__
()
self
.
brand_list
=
[
'huawei'
,
'vivo'
,
'oppo'
,
'xiaomi'
]
self
.
dim
=
len
(
self
.
brand_list
)
+
1
def
convert
(
self
,
value
):
def
convert
(
self
,
value
):
ret
=
[
0
,
0
,
0
,
0
,
0
]
ret
=
[
0
,
0
,
0
,
0
,
0
]
...
@@ -24,25 +89,30 @@ class FfromLoginProfile():
...
@@ -24,25 +89,30 @@ class FfromLoginProfile():
except
Exception
as
e
:
except
Exception
as
e
:
return
ret
return
ret
if
'huawei'
in
value
:
for
i
,
v
in
enumerate
(
self
.
brand_list
):
ret
[
0
]
=
1
if
v
in
value
:
elif
'vivo'
in
value
:
ret
[
i
]
=
1
ret
[
1
]
=
1
break
elif
'oppo'
in
value
:
ret
[
2
]
=
1
elif
'xiaomi'
in
value
:
ret
[
3
]
=
1
else
:
else
:
ret
[
4
]
=
1
ret
[
self
.
dim
-
1
]
=
1
return
ret
return
ret
def
inconvert
(
self
,
embedding
):
for
emb
,
name
in
zip
(
embedding
,
self
.
brand_list
):
if
emb
==
1
:
return
name
return
'other'
class
UserPreferenceCateProfile
():
class
UserPreferenceCateProfile
(
BaseProfile
):
def
__init__
(
self
)
->
None
:
def
__init__
(
self
)
->
None
:
cate_list
=
[
'个人成长'
,
'亲子教育'
,
'人际关系'
,
'婚姻家庭'
,
'心理健康'
,
'恋爱情感'
,
'情绪压力'
,
'职场发展'
]
super
()
.
__init__
()
self
.
dim
=
8
self
.
cate_list
=
[
'个人成长'
,
'亲子教育'
,
'人际关系'
,
'婚姻家庭'
,
'心理健康'
,
'恋爱情感'
,
'情绪压力'
,
'职场发展'
]
self
.
cate_index
=
{
self
.
cate_index
=
{
cate
:
index
for
index
,
cate
in
enumerate
(
cate_list
)
cate
:
index
for
index
,
cate
in
enumerate
(
self
.
cate_list
)
}
}
def
convert
(
self
,
value
):
def
convert
(
self
,
value
):
...
@@ -61,13 +131,21 @@ class UserPreferenceCateProfile():
...
@@ -61,13 +131,21 @@ class UserPreferenceCateProfile():
return
ret
return
ret
def
inconvert
(
self
,
embedding
):
ret
=
{}
for
emb
,
name
in
zip
(
embedding
,
self
.
cate_list
):
ret
[
name
]
=
emb
return
json
.
dumps
(
ret
,
ensure_ascii
=
False
)
class
NumClassProfile
():
class
NumClassProfile
(
BaseProfile
):
def
__init__
(
self
,
split_points
,
mode
=
'le'
)
->
None
:
def
__init__
(
self
,
split_points
,
mode
=
'le'
)
->
None
:
"""
"""
mode : le 前开后闭; be: 前闭后开;默认前开后闭
mode : le 前开后闭; be: 前闭后开;默认前开后闭
"""
"""
super
()
.
__init__
()
self
.
dim
=
len
(
split_points
)
+
1
self
.
split_points
=
split_points
self
.
split_points
=
split_points
self
.
mode
=
mode
self
.
mode
=
mode
...
@@ -94,31 +172,87 @@ class NumClassProfile():
...
@@ -94,31 +172,87 @@ class NumClassProfile():
return
ret
return
ret
class
AidiCstBiasPriceProfile
():
def
inconvert
(
self
,
embedding
):
ret
=
''
# 确保embedding中有包含1的值
if
embedding
.
count
(
1
)
==
0
:
return
''
index
=
embedding
.
index
(
1
)
if
index
==
0
:
if
self
.
mode
==
'be'
:
ret
=
'< {}'
.
format
(
self
.
split_points
[
0
])
else
:
ret
=
'<= {}'
.
format
(
self
.
split_points
[
0
])
elif
index
==
(
self
.
dim
-
1
):
if
self
.
mode
==
'be'
:
ret
=
'>= {}'
.
format
(
self
.
split_points
[
-
1
])
else
:
ret
=
'> {}'
.
format
(
self
.
split_points
[
-
1
])
else
:
if
self
.
mode
==
'be'
:
ret
=
'{} < {}'
.
format
(
self
.
split_points
[
index
-
1
],
self
.
split_points
[
index
])
else
:
ret
=
'{} <= {}'
.
format
(
self
.
split_points
[
index
-
1
],
self
.
split_points
[
index
])
return
ret
class
AidiCstBiasPriceProfile
(
BaseProfile
):
""" 用户偏好价格 """
def
__init__
(
self
)
->
None
:
super
()
.
__init__
()
self
.
dim
=
6
self
.
price_groups
=
[
'[0, 50)'
,
'[50, 100)'
,
'[100, 200)'
,
'[200, 500)'
,
'[500, 1000)'
,
'[1000, )'
,
]
def
convert
(
self
,
value
):
def
convert
(
self
,
value
):
ret
=
[
0
]
*
6
ret
=
[
0
]
*
6
if
pd
.
isnull
(
value
):
if
pd
.
isnull
(
value
):
return
ret
return
ret
for
v
in
value
:
json_object
=
json
.
loads
(
value
)
for
v
in
json_object
:
try
:
try
:
ret
[
v
[
'level'
]
-
1
]
=
1
ret
[
v
[
'level'
]
-
1
]
=
1
except
Exception
:
except
Exception
:
pass
pass
return
ret
return
ret
def
inconvert
(
self
,
embedding
):
ret
=
''
# 确保embedding中有包含1的值
if
embedding
.
count
(
1
)
==
0
:
return
''
class
MultiChoiceProfile
():
index
=
embedding
.
index
(
1
)
ret
=
self
.
price_groups
[
index
]
return
ret
class
MultiChoiceProfile
(
BaseProfile
):
def
__init__
(
self
,
option_dict
:
Dict
[
Any
,
int
])
->
None
:
def
__init__
(
self
,
option_dict
:
Dict
[
Any
,
int
])
->
None
:
super
()
.
__init__
()
self
.
dim
=
6
self
.
option_dict
=
option_dict
self
.
option_dict
=
option_dict
self
.
re_option_dict
=
{
v
:
k
for
k
,
v
in
self
.
option_dict
.
items
()}
def
convert
(
self
,
value
:
List
):
def
convert
(
self
,
value
:
List
):
ret
=
[
0
]
*
len
(
self
.
option_dict
)
ret
=
[
0
]
*
len
(
self
.
option_dict
)
if
pd
.
isnull
(
value
):
if
pd
.
isnull
(
value
):
return
ret
return
ret
value
=
json
.
loads
(
value
)
for
v
in
value
:
for
v
in
value
:
try
:
try
:
i
=
self
.
option_dict
[
v
]
i
=
self
.
option_dict
[
v
]
...
@@ -128,71 +262,84 @@ class MultiChoiceProfile():
...
@@ -128,71 +262,84 @@ class MultiChoiceProfile():
return
ret
return
ret
class
CityProfile
():
def
inconvert
(
self
,
embedding
):
ret
=
[]
for
index
,
emb
in
enumerate
(
embedding
):
if
emb
==
1
:
ret
.
append
(
self
.
re_option_dict
.
get
(
index
,
''
))
return
'[{}]'
.
format
(
', '
.
join
(
map
(
str
,
ret
)))
class
CityProfile
(
BaseProfile
):
""" 基于邮编的城市编码 """
def
__init__
(
self
,
level
=
2
)
->
None
:
"""
level: 级别,2-省/直辖市 ; 3-区; 4-区;6-投递区
"""
super
()
.
__init__
()
self
.
level
=
level
self
.
dim
=
self
.
level
*
10
def
__init__
(
self
)
->
None
:
self
.
default_city_codes
=
[
'330500'
,
'640200'
,
'130200'
,
'620200'
,
'321000'
,
'530600'
,
'650500'
,
'410200'
,
'511400'
,
'450200'
,
'610800'
,
'220400'
,
'430400'
,
'320500'
,
'410400'
,
'341100'
,
'420300'
,
'410500'
,
'640100'
,
'440100'
,
'420500'
,
'650200'
,
'441900'
,
'211200'
,
'210400'
,
'140700'
,
'131000'
,
'440700'
,
'340100'
,
'350200'
,
'371100'
,
'370900'
,
'130500'
,
'451300'
,
'331100'
,
'320700'
,
'710100'
,
'500000'
,
'610300'
,
'370100'
,
'610500'
,
'450300'
,
'520100'
,
'140200'
,
'320400'
,
'210500'
,
'440300'
,
'610700'
,
'341800'
,
'210300'
,
'340200'
,
'120100'
,
'340500'
,
'210200'
,
'222400'
,
'370600'
,
'110100'
,
'441200'
,
'230500'
,
'510100'
,
'330700'
,
'330600'
,
'370300'
,
'230600'
,
'450100'
,
'340300'
,
'651800'
,
'340800'
,
'430200'
,
'421300'
,
'220600'
,
'150200'
,
'433100'
,
'440500'
,
'620100'
,
'710200'
,
'130100'
,
'131100'
,
'150600'
,
'430100'
,
'150100'
,
'130400'
,
'140600'
,
'140300'
,
'410300'
,
'620600'
,
'330300'
,
'321100'
,
'320900'
,
'630100'
,
'320100'
,
'410900'
,
'510400'
,
'620800'
,
'610600'
,
'220300'
,
'420600'
,
'510700'
,
'130300'
,
'411400'
,
'310000'
,
'341200'
,
'370500'
,
'710500'
,
'231100'
,
'152900'
,
'371500'
,
'220100'
,
'360700'
,
'150500'
,
'331000'
,
'360600'
,
'371000'
,
'341600'
,
'130600'
,
'230100'
,
'410800'
,
'370700'
,
'410700'
,
'430800'
,
'410100'
,
'210800'
,
'330400'
,
'460200'
,
'650100'
,
'310100'
,
'350500'
,
'360400'
,
'320300'
,
'500100'
,
'360900'
,
'610100'
,
'350100'
,
'350400'
,
'530100'
,
'320600'
,
'130900'
,
'371300'
,
'421200'
,
'210700'
,
'220200'
,
'130700'
,
'320800'
,
'420100'
,
'110000'
,
'150400'
,
'442000'
,
'469002'
,
'360100'
,
'150800'
,
'441300'
,
'460100'
,
'610200'
,
'210100'
,
'210900'
,
'371400'
,
'621000'
,
'141000'
,
'330100'
,
'220700'
,
'371700'
,
'370800'
,
'211400'
,
'330200'
,
'140400'
,
'120000'
,
'231200'
,
'140100'
,
'431100'
,
'320200'
,
'451000'
,
'370200'
,
'511900'
,
'361100'
,
'610400'
,
'440600'
,
'411100'
,
'231000'
,
'360300'
]
self
.
city_codes
=
self
.
default_city_codes
self
.
city_code_dict
=
{
code
:
index
for
index
,
code
in
enumerate
(
self
.
city_codes
)
}
def
convert
(
self
,
value
):
def
convert
(
self
,
value
):
ret
=
[
0
]
*
len
(
self
.
city_code_dict
)
ret
=
[
0
]
*
self
.
dim
if
pd
.
isnull
(
value
):
if
pd
.
isnull
(
value
):
return
ret
return
ret
value
=
str
(
value
)
value
=
str
(
value
)
try
:
try
:
i
=
self
.
city_code_dict
[
value
]
for
i
,
_n
in
enumerate
(
value
[:
self
.
level
]):
ret
[
i
]
=
1
n
=
int
(
_n
)
ret
[
i
*
10
+
n
]
=
1
except
Exception
as
e
:
except
Exception
as
e
:
pass
pass
return
ret
return
ret
def
inconvert
(
self
,
embedding
):
# 邮编固定都是6
ret
=
[
0
]
*
6
for
index
,
_emb
in
enumerate
(
embedding
):
emb
=
int
(
_emb
)
if
emb
==
1
:
ret
[
int
(
index
/
10
)]
=
index
%
10
return
''
.
join
(
map
(
str
,
ret
))
class
AidiCstBiasCityProfile
(
CityProfile
):
class
AidiCstBiasCityProfile
(
CityProfile
):
""" 支持多个城市编码 """
def
__init__
(
self
)
->
None
:
def
__init__
(
self
,
level
=
2
)
->
None
:
super
()
.
__init__
()
super
()
.
__init__
(
level
=
level
)
def
convert
(
self
,
value
):
ret
=
[
0
]
*
len
(
self
.
city_code_dict
)
if
pd
.
isnull
(
value
):
def
convert
(
self
,
value_object
):
ret
=
[
0
]
*
self
.
dim
if
pd
.
isnull
(
value_object
):
return
ret
return
ret
if
not
value
:
if
not
value
_object
:
return
ret
return
ret
if
isinstance
(
value
,
str
):
if
isinstance
(
value
_object
,
str
):
try
:
try
:
value
=
json
.
loads
(
value
)
value
_object
=
json
.
loads
(
value_object
)
except
Exception
as
e
:
except
Exception
as
e
:
pass
pass
for
v
in
value
:
if
isinstance
(
value_object
,
dict
):
for
value
in
value_object
.
get
(
'in'
,
[]):
try
:
try
:
ret
[
self
.
city_code_dict
[
v
]]
=
1
for
i
,
_n
in
enumerate
(
value
[:
self
.
level
]):
n
=
int
(
_n
)
ret
[
i
*
10
+
n
]
=
1
except
Exception
as
e
:
except
Exception
as
e
:
pass
pass
return
ret
return
ret
...
...
src/core/profile_manager.py
View file @
cfdc63bd
...
@@ -15,16 +15,49 @@ class ProfileManager():
...
@@ -15,16 +15,49 @@ class ProfileManager():
订单用户画像数据管理
订单用户画像数据管理
"""
"""
def
__init__
(
self
)
->
None
:
def
__init__
(
self
,
client
=
None
)
->
None
:
self
.
local_file_dir
=
get_data_path
()
self
.
local_file_dir
=
get_data_path
()
self
.
profile_file_path
=
os
.
path
.
join
(
self
.
local_file_dir
,
'all_profile.json'
)
self
.
profile_file_path
=
os
.
path
.
join
(
self
.
local_file_dir
,
'all_profile.json'
)
self
.
client
=
client
self
.
logger
=
logging
.
getLogger
(
__name__
)
self
.
logger
=
logging
.
getLogger
(
__name__
)
def
_fetch_data_from_db
(
self
,
conditions
=
None
):
if
self
.
client
is
None
:
self
.
logger
.
error
(
'未连接数据库'
)
raise
condition_sql
=
''
if
conditions
:
condition_sql
=
' WHERE '
+
' AND '
.
join
(
conditions
)
sql
=
'SELECT * FROM ads.ads_register_user_profiles'
sql
+=
' WHERE uid IN (SELECT DISTINCT uid FROM ods.ods_ydl_standard_order{})'
.
format
(
condition_sql
)
_
,
all_data
=
self
.
client
.
query
(
sql
)
return
all_data
def
update_profile
(
self
):
""" 从数据库中拉取最新画像特征并保存 """
all_data
=
self
.
_fetch_data_from_db
()
df
=
pd
.
DataFrame
(
all_data
)
df
.
to_excel
(
os
.
path
.
join
(
self
.
local_file_dir
,
'all_profile.xlsx'
),
index
=
None
)
def
update_test_profile
(
self
,
conditions
):
""" 从数据库中拉取指定条件画像信息用于测试 """
all_data
=
self
.
_fetch_data_from_db
(
conditions
)
df
=
pd
.
DataFrame
(
all_data
)
df
.
to_excel
(
os
.
path
.
join
(
self
.
local_file_dir
,
'test_profile.xlsx'
),
index
=
None
)
def
_load_profile_data
(
self
):
def
_load_profile_data
(
self
):
return
pd
.
read_excel
(
os
.
path
.
join
(
self
.
local_file_dir
,
'all_profile.xlsx'
))
return
pd
.
read_excel
(
os
.
path
.
join
(
self
.
local_file_dir
,
'all_profile.xlsx'
),
dtype
=
str
)
# with open(self.profile_file_path, 'r', encoding='utf-8') as f:
# return json.load(f)
def
load_test_profile_data
(
self
):
return
pd
.
read_excel
(
os
.
path
.
join
(
self
.
local_file_dir
,
'test_profile.xlsx'
),
dtype
=
str
)
def
profile_to_embedding
(
self
,
profile
):
def
profile_to_embedding
(
self
,
profile
):
...
@@ -36,6 +69,18 @@ class ProfileManager():
...
@@ -36,6 +69,18 @@ class ProfileManager():
embedding
.
extend
(
converter
.
convert
(
profile
[
name
]))
embedding
.
extend
(
converter
.
convert
(
profile
[
name
]))
return
embedding
return
embedding
def
embedding_to_profile
(
self
,
embedding
):
"""
向量转换为用户画像
"""
ret
=
{}
si
=
0
for
[
name
,
converter
]
in
profile_converters
:
ei
=
si
+
converter
.
dim
ret
[
name
]
=
converter
.
inconvert
(
embedding
[
si
:
ei
])
si
=
ei
return
ret
def
make_embeddings
(
self
):
def
make_embeddings
(
self
):
user_profiles
=
self
.
_load_profile_data
()
user_profiles
=
self
.
_load_profile_data
()
...
@@ -61,10 +106,33 @@ class ProfileManager():
...
@@ -61,10 +106,33 @@ class ProfileManager():
user_ids
=
[]
user_ids
=
[]
embeddings
=
[]
embeddings
=
[]
with
open
(
os
.
path
.
join
(
self
.
local_file_dir
,
'user_embeddings_ids.txt'
),
'r'
,
encoding
=
'utf-8'
)
as
f
:
user_ids
=
[
line
.
strip
()
for
line
in
f
]
with
open
(
os
.
path
.
join
(
self
.
local_file_dir
,
'user_embeddings.json'
),
'r'
,
encoding
=
'utf-8'
)
as
f
:
embeddings
=
json
.
load
(
f
)
v_embedding_set
=
{}
for
user_id
,
embedding
in
zip
(
user_ids
,
embeddings
):
key
=
'_'
.
join
(
map
(
str
,
embedding
))
if
key
not
in
v_embedding_set
:
v_embedding_set
[
key
]
=
{
'embedding'
:
embedding
,
'user_ids'
:
[],
}
v_embedding_set
[
key
][
'user_ids'
]
.
append
(
str
(
user_id
))
v_embedding_list
=
[]
with
open
(
os
.
path
.
join
(
self
.
local_file_dir
,
'virtual_user_embeddings_ids.txt'
),
'w'
,
encoding
=
'utf-8'
)
as
f
:
for
info
in
v_embedding_set
.
values
():
f
.
write
(
','
.
join
(
info
[
'user_ids'
])
+
'
\n
'
)
v_embedding_list
.
append
(
info
[
'embedding'
])
with
open
(
os
.
path
.
join
(
self
.
local_file_dir
,
'virtual_user_embeddings.json'
),
'w'
,
encoding
=
'utf-8'
)
as
f
:
json
.
dump
(
v_embedding_list
,
f
,
ensure_ascii
=
False
)
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
manager
=
ProfileManager
()
manager
=
ProfileManager
()
manager
.
make_embeddings
()
# manager.make_embeddings()
# manager.update_local_data()
manager
.
make_virtual_embedding
()
# print(manager.make_index())
\ No newline at end of file
\ No newline at end of file
src/core/recommender.py
0 → 100644
View file @
cfdc63bd
# -*- coding: utf-8 -*-
import
os
import
json
from
typing
import
List
,
Dict
import
faiss
import
numpy
as
np
from
ydl_ai_recommender.src.core.profile_manager
import
ProfileManager
from
ydl_ai_recommender.src.data.mysql_client
import
MySQLClient
from
ydl_ai_recommender.src.utils
import
get_conf_path
,
get_data_path
class
Recommender
():
def
__init__
(
self
)
->
None
:
pass
def
recommend
(
self
,
user
)
->
List
:
raise
NotImplemented
class
UserCFRecommender
(
Recommender
):
def
__init__
(
self
,
top_n
=
5
,
k
=
5
,
is_lazy
=
True
)
->
None
:
super
()
.
__init__
()
# 召回 top_n 个相似用户
self
.
top_n
=
top_n
# 每个召回的用户取 k 个相关咨询师
self
.
k
=
k
if
is_lazy
is
False
:
self
.
client
=
MySQLClient
.
create_from_config_file
(
get_conf_path
())
self
.
manager
=
ProfileManager
()
self
.
local_file_dir
=
get_data_path
()
self
.
load_data
()
def
load_data
(
self
):
order_user_embedding
=
[]
order_user_ids
=
[]
order_user_counselor_index
=
{}
default_counselor
=
[]
with
open
(
os
.
path
.
join
(
self
.
local_file_dir
,
'user_embeddings_ids.txt'
),
'r'
,
encoding
=
'utf-8'
)
as
f
:
order_user_ids
=
[
line
.
strip
()
for
line
in
f
]
with
open
(
os
.
path
.
join
(
self
.
local_file_dir
,
'user_embeddings.json'
),
'r'
,
encoding
=
'utf-8'
)
as
f
:
order_user_embedding
=
json
.
load
(
f
)
with
open
(
os
.
path
.
join
(
self
.
local_file_dir
,
'user_doctor_index.json'
),
encoding
=
'utf-8'
)
as
f
:
order_user_counselor_index
=
json
.
load
(
f
)
with
open
(
os
.
path
.
join
(
self
.
local_file_dir
,
'top50_supplier.txt'
),
'r'
,
encoding
=
'utf-8'
)
as
f
:
default_counselor
=
[
line
.
strip
()
for
line
in
f
]
self
.
order_user_embedding
=
order_user_embedding
self
.
order_user_ids
=
order_user_ids
self
.
order_user_counselor_index
=
order_user_counselor_index
self
.
default_counselor
=
default_counselor
self
.
index
=
faiss
.
IndexFlatL2
(
len
(
self
.
order_user_embedding
[
0
]))
self
.
index
.
add
(
np
.
array
(
self
.
order_user_embedding
))
def
get_user_profile
(
self
,
user_id
):
sql
=
'SELECT * FROM ads.ads_register_user_profiles'
sql
+=
' WHERE uid={}'
.
format
(
user_id
)
_
,
all_data
=
self
.
client
.
query
(
sql
)
if
len
(
all_data
)
==
0
:
return
[]
return
all_data
[
0
]
def
user_token
(
self
,
user_profile
):
return
self
.
manager
.
profile_to_embedding
(
user_profile
)
def
_recommend
(
self
,
user_embedding
):
D
,
I
=
self
.
index
.
search
(
np
.
array
([
user_embedding
]),
self
.
k
)
counselors
=
[]
for
idx
,
score
in
zip
(
I
[
0
],
D
[
0
]):
# 相似用户uid
similar_user_id
=
self
.
order_user_ids
[
idx
]
similar_user_counselor
=
self
.
order_user_counselor_index
.
get
(
similar_user_id
,
[])
recommend_data
=
[{
'counselor'
:
str
(
user
[
0
]),
'score'
:
float
(
score
),
'from'
:
'similar_users {}'
.
format
(
similar_user_id
),
}
for
user
in
similar_user_counselor
[:
self
.
top_n
]]
counselors
.
extend
(
recommend_data
)
return
counselors
def
recommend_with_profile
(
self
,
user_profile
):
user_embedding
=
self
.
user_token
(
user_profile
)
counselors
=
self
.
_recommend
(
user_embedding
)
return
counselors
def
recommend
(
self
,
user_id
):
"""
根据用户画像,推荐咨询师
若获取不到用户画像,推荐默认咨询师(订单最多的)
"""
user_profile
=
self
.
get_user_profile
(
user_id
)
if
not
user_profile
:
return
[]
return
self
.
recommend_with_profile
(
user_profile
)
if
__name__
==
'__main__'
:
recommender
=
UserCFRecommender
()
print
(
recommender
.
recommend
(
'10957910'
))
\ No newline at end of file
src/core/test.py
0 → 100644
View file @
cfdc63bd
# -*- coding: utf-8 -*-
import
re
import
json
import
logging
import
argparse
from
datetime
import
datetime
,
timedelta
from
ydl_ai_recommender.src.core.order_data_manager
import
OrderDataManager
from
ydl_ai_recommender.src.core.profile_manager
import
ProfileManager
from
ydl_ai_recommender.src.core.recommender
import
UserCFRecommender
from
ydl_ai_recommender.src.data.mysql_client
import
MySQLClient
from
ydl_ai_recommender.src.utils
import
get_conf_path
logging
.
basicConfig
(
format
=
'
%(asctime)
s -
%(levelname)
s -
%(name)
s -
%(message)
s'
,
datefmt
=
'
%
m/
%
d/
%
Y
%
H:
%
M:
%
S'
,
level
=
logging
.
INFO
)
logger
=
logging
.
getLogger
(
__name__
)
def
main
(
args
):
# 构建用户画像字典,不用每次都从数据库中获取
profile_manager
=
ProfileManager
()
df
=
profile_manager
.
load_test_profile_data
()
user_profile_dict
=
{}
for
_
,
row
in
df
.
iterrows
():
user_profile_dict
[
row
[
'uid'
]]
=
row
manager
=
OrderDataManager
()
# 加载训练订单数据,为后面判断用户是否为新用户
train_orders
=
manager
.
load_raw_data
()
old_users
=
set
(
train_orders
[
'uid'
])
logger
.
info
(
'订单用户数
%
s '
,
len
(
old_users
))
# 加载测试数据
test_orders
=
manager
.
load_test_order_data
()
logger
.
info
(
'加载测试数据成功,共加载
%
s 条'
,
len
(
test_orders
))
recommender
=
UserCFRecommender
(
top_n
=
args
.
top_n
,
k
=
args
.
k
)
result_detail
=
[]
for
index
,
order_info
in
test_orders
.
iterrows
():
if
args
.
max_test
>
0
:
if
index
>=
args
.
max_test
:
break
uid
=
order_info
[
'uid'
]
profile
=
user_profile_dict
.
get
(
uid
)
if
profile
is
None
:
continue
recommend_result
=
recommender
.
recommend_with_profile
(
profile
)
recall_resons
=
[]
for
rr
in
recommend_result
:
if
rr
[
'counselor'
]
==
order_info
[
'supplier_id'
]:
recall_resons
.
append
(
rr
[
'from'
])
result_detail
.
append
({
'uid'
:
uid
,
'supplier_id'
:
order_info
[
'supplier_id'
],
'is_old_user'
:
uid
in
old_users
,
'recall_counselors'
:
recommend_result
,
'is_recall'
:
len
(
recall_resons
)
>
0
,
'recall_reason'
:
'|'
.
join
(
recall_resons
),
})
# 结果报告
metrics
=
{
'all_test_cnt'
:
0
,
'all_recall_cnt'
:
0
,
'old_user_test_cnt'
:
0
,
'old_user_recall_cnt'
:
0
,
'new_user_test_cnt'
:
0
,
'new_user_recall_cnt'
:
0
,
'same_user_recall_cnt'
:
0
,
'similar_user_recall_cnt'
:
0
,
}
for
rd
in
result_detail
:
metrics
[
'all_test_cnt'
]
+=
1
if
rd
[
'is_recall'
]:
metrics
[
'all_recall_cnt'
]
+=
1
is_same_user
,
is_similar_user
=
False
,
False
for
counselor
in
rd
[
'recall_counselors'
]:
from_id
=
counselor
[
'from'
]
.
split
(
' '
)[
1
]
if
from_id
==
rd
[
'uid'
]:
is_same_user
=
True
if
from_id
!=
rd
[
'uid'
]:
is_similar_user
=
True
if
is_same_user
:
metrics
[
'same_user_recall_cnt'
]
+=
1
if
is_similar_user
:
metrics
[
'similar_user_recall_cnt'
]
+=
1
if
rd
[
'is_old_user'
]:
metrics
[
'old_user_test_cnt'
]
+=
1
if
rd
[
'is_recall'
]:
metrics
[
'old_user_recall_cnt'
]
+=
1
else
:
metrics
[
'new_user_test_cnt'
]
+=
1
if
rd
[
'is_recall'
]:
metrics
[
'new_user_recall_cnt'
]
+=
1
logger
.
info
(
'=='
*
20
+
' 测试结果 '
+
'=='
*
20
)
logger
.
info
(
'--'
*
45
)
logger
.
info
(
'{:<10}{:<10}{:<10}{:<10}'
.
format
(
''
,
'样本数'
,
'召回数'
,
'召回率'
))
logger
.
info
(
'{:<10}{:<10}{:<10}{:<10.2
%
}'
.
format
(
'整体 '
,
metrics
[
'all_test_cnt'
],
metrics
[
'all_recall_cnt'
],
metrics
[
'all_recall_cnt'
]
/
metrics
[
'all_test_cnt'
]))
logger
.
info
(
'{:<10}{:<10}{:<10}{:<10.2
%
}'
.
format
(
'老用户'
,
metrics
[
'old_user_test_cnt'
],
metrics
[
'old_user_recall_cnt'
],
metrics
[
'old_user_recall_cnt'
]
/
metrics
[
'old_user_test_cnt'
]))
logger
.
info
(
'{:<10}{:<10}{:<10}{:<10.2
%
}'
.
format
(
'新用户'
,
metrics
[
'new_user_test_cnt'
],
metrics
[
'new_user_recall_cnt'
],
metrics
[
'new_user_recall_cnt'
]
/
metrics
[
'new_user_test_cnt'
]))
logger
.
info
(
''
)
logger
.
info
(
'--'
*
45
)
logger
.
info
(
'用户自己召回数 {} 占总召回比例 {:.2
%
}'
.
format
(
metrics
[
'same_user_recall_cnt'
],
metrics
[
'same_user_recall_cnt'
]
/
metrics
[
'all_recall_cnt'
]))
logger
.
info
(
'相似用户召回数 {} 占总召回比例 {:.2
%
}'
.
format
(
metrics
[
'similar_user_recall_cnt'
],
metrics
[
'similar_user_recall_cnt'
]
/
metrics
[
'all_recall_cnt'
]))
with
open
(
'result_detail.json'
,
'w'
,
encoding
=
'utf-8'
)
as
f
:
json
.
dump
(
result_detail
,
f
,
ensure_ascii
=
False
,
indent
=
2
)
def
update_test_data
(
args
):
start_date
=
''
if
re
.
match
(
r'20\d\d-[01]\d-\d\d'
,
args
.
start_date
):
start_date
=
args
.
start_date
elif
re
.
match
(
r'-\d+'
,
args
.
start_date
):
now
=
datetime
.
now
()
start_date
=
(
now
-
timedelta
(
days
=
int
(
args
.
start_date
[
1
:])))
.
strftime
(
'
%
Y-
%
m-
%
d'
)
else
:
logger
.
error
(
'args.start_date 参数格式错误,
%
s'
,
args
.
start_date
)
raise
conditions
=
[
'create_time >= "{}"'
.
format
(
start_date
)]
client
=
MySQLClient
.
create_from_config_file
(
get_conf_path
())
# 订单数据
manager
=
OrderDataManager
(
client
)
manager
.
update_test_order_data
(
conditions
=
conditions
)
# 用户画像数据
manager
=
ProfileManager
(
client
)
manager
.
update_test_profile
(
conditions
=
conditions
)
logger
.
info
(
'测试数据更新完成'
)
if
__name__
==
'__main__'
:
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
'--k'
,
default
=
5
,
type
=
int
,
help
=
'召回相似用户的数量'
)
parser
.
add_argument
(
'--top_n'
,
default
=
5
,
type
=
int
,
help
=
'每个相似用户召回的咨询师数量'
)
parser
.
add_argument
(
'--max_test'
,
default
=
0
,
type
=
int
,
help
=
'最多测试数据量'
)
parser
.
add_argument
(
'--do_update_test_data'
,
default
=
False
,
action
=
'store_true'
,
help
=
'是否更新测试数据'
)
parser
.
add_argument
(
'--start_date'
,
default
=
'-1'
,
type
=
str
,
help
=
'测试订单创建的开始时间,可以是"
%
Y-
%
m-
%
d"格式,也可以是 -3 表示前3天'
)
args
=
parser
.
parse_args
()
if
args
.
do_update_test_data
:
logger
.
info
(
'更新测试数据'
)
update_test_data
(
args
)
main
(
args
)
\ No newline at end of file
src/core/user_similarity.py
View file @
cfdc63bd
...
@@ -8,6 +8,13 @@ from itertools import combinations
...
@@ -8,6 +8,13 @@ from itertools import combinations
from
ydl_ai_recommender.src.utils
import
get_data_path
from
ydl_ai_recommender.src.utils
import
get_data_path
logging
.
basicConfig
(
format
=
'
%(asctime)
s -
%(levelname)
s -
%(name)
s -
%(message)
s'
,
datefmt
=
'
%
m/
%
d/
%
Y
%
H:
%
M:
%
S'
,
level
=
logging
.
INFO
)
class
BasicUserSimilarity
():
class
BasicUserSimilarity
():
def
__init__
(
self
)
->
None
:
def
__init__
(
self
)
->
None
:
...
@@ -24,7 +31,7 @@ class BasicUserSimilarity():
...
@@ -24,7 +31,7 @@ class BasicUserSimilarity():
counselor_user_index
=
{}
counselor_user_index
=
{}
for
user
,
counselors
in
user_counselor_index
.
items
():
for
user
,
counselors
in
user_counselor_index
.
items
():
user_like_set
[
user
]
=
len
(
counselors
)
user_like_set
[
user
]
=
set
([
c
[
0
]
for
c
in
counselors
]
)
for
[
counselor
,
_
,
_
]
in
counselors
:
for
[
counselor
,
_
,
_
]
in
counselors
:
if
counselor
not
in
counselor_user_index
:
if
counselor
not
in
counselor_user_index
:
...
@@ -32,20 +39,38 @@ class BasicUserSimilarity():
...
@@ -32,20 +39,38 @@ class BasicUserSimilarity():
counselor_user_index
[
counselor
]
.
append
(
user
)
counselor_user_index
[
counselor
]
.
append
(
user
)
# 两个用户与同一个咨询师有订单,就认为两个用户相似
# 两个用户与同一个咨询师有订单,就认为两个用户相似
self
.
logger
.
info
(
'开始构建用户相似性关系'
)
self
.
logger
.
info
(
'开始构建用户
之间
相似性关系'
)
relations
=
{}
relations
=
{}
user_index
=
{}
for
users
in
counselor_user_index
.
values
():
for
users
in
counselor_user_index
.
values
():
for
[
_u1
,
_u2
]
in
combinations
(
users
,
2
):
for
[
_u1
,
_u2
]
in
combinations
(
users
,
2
):
u1
,
u2
=
min
(
_u1
,
_u2
),
max
(
_u1
,
_u2
)
u1
,
u2
=
min
(
_u1
,
_u2
),
max
(
_u1
,
_u2
)
key
=
'{}_{}'
.
format
(
u1
,
u2
)
key
=
'{}_{}'
.
format
(
u1
,
u2
)
if
key
in
relations
:
if
key
in
relations
:
continue
continue
relations
[
key
]
=
1.0
/
(
user_like_set
[
u1
]
*
user_like_set
[
u2
])
sim
=
len
(
user_like_set
[
u1
]
&
user_like_set
[
u2
])
/
(
len
(
user_like_set
[
u1
])
*
len
(
user_like_set
[
u2
]))
relations
[
key
]
=
sim
if
u1
not
in
user_index
:
user_index
[
u1
]
=
{}
if
u2
not
in
user_index
:
user_index
[
u2
]
=
{}
if
u2
not
in
user_index
[
u1
]:
user_index
[
u1
][
u2
]
=
sim
if
u1
not
in
user_index
[
u2
]:
user_index
[
u2
][
u1
]
=
sim
user_counselor_index
=
{}
self
.
logger
.
info
(
'用户相似性关系构建完成,共有
%
s 对关系'
,
len
(
relations
))
self
.
logger
.
info
(
'用户相似性关系构建完成,共有
%
s 对关系'
,
len
(
relations
))
with
open
(
os
.
path
.
join
(
self
.
local_file_dir
,
'user_similarity.json'
),
'w'
,
encoding
=
'utf-8'
)
as
f
:
with
open
(
os
.
path
.
join
(
self
.
local_file_dir
,
'user_similarity.json'
),
'w'
,
encoding
=
'utf-8'
)
as
f
:
json
.
dump
(
relations
,
f
,
ensure_ascii
=
False
,
indent
=
2
)
json
.
dump
(
relations
,
f
,
ensure_ascii
=
False
,
indent
=
2
)
def
recall
(
self
,
user
,
N
=
10
,
top_k
=
10
):
pass
bs
=
BasicUserSimilarity
()
bs
=
BasicUserSimilarity
()
bs
.
compute_similarity
()
bs
.
compute_similarity
()
\ No newline at end of file
src/data/mysql_client.py
View file @
cfdc63bd
...
@@ -21,11 +21,26 @@ class MySQLClient():
...
@@ -21,11 +21,26 @@ class MySQLClient():
self
.
cursor
=
self
.
connection
.
cursor
()
self
.
cursor
=
self
.
connection
.
cursor
()
self
.
_log_info
(
'数据库连接成功'
)
self
.
_log_info
(
'数据库连接成功'
)
def
_connect
(
self
):
self
.
connection
=
pymysql
.
connect
(
host
=
self
.
host
,
port
=
self
.
port
,
user
=
self
.
user
,
password
=
self
.
password
,
charset
=
'utf8mb4'
,
cursorclass
=
pymysql
.
cursors
.
DictCursor
)
self
.
cursor
=
self
.
connection
.
cursor
()
self
.
_log_info
(
'数据库连接成功'
)
def
_log_info
(
self
,
text
,
*
args
,
**
params
):
def
_log_info
(
self
,
text
,
*
args
,
**
params
):
if
self
.
logger
:
if
self
.
logger
:
self
.
logger
.
info
(
text
,
*
args
,
**
params
)
self
.
logger
.
info
(
text
,
*
args
,
**
params
)
def
query
(
self
,
sql
):
def
query
(
self
,
sql
):
if
self
.
cursor
is
None
:
self
.
_connect
()
sql
+=
' limit 1000'
sql
+=
' limit 1000'
self
.
_log_info
(
'begin execute sql:
%
s'
,
sql
)
self
.
_log_info
(
'begin execute sql:
%
s'
,
sql
)
row_count
=
self
.
cursor
.
execute
(
sql
)
row_count
=
self
.
cursor
.
execute
(
sql
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment