Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
Y
ydl_ai_recommender
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
闫发泽
ydl_ai_recommender
Commits
9f715802
Commit
9f715802
authored
Dec 21, 2022
by
柴鹏飞
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
新增咨询师->咨询师索引关系
parent
65e1c61d
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
267 additions
and
72 deletions
+267
-72
test.py
bin/test.py
+109
-2
update.py
bin/update.py
+4
-4
indexer.py
src/core/indexer.py
+113
-61
recommender.py
src/core/recommender.py
+40
-4
recommend_service.py
src/service/recommend_service.py
+1
-1
No files found.
bin/test.py
View file @
9f715802
...
...
@@ -76,7 +76,7 @@ def evaluation(result_detail):
for
top_n
,
counselor
in
enumerate
(
rd
[
'recall_counselors'
]):
if
counselor
[
'counselor'
]
==
rd
[
'supplier_id'
]:
from_info
=
counselor
[
'from'
]
.
split
(
' '
)
if
from_info
[
0
]
==
'
top_100
'
:
if
from_info
[
0
]
==
'
default
'
:
metrics
[
'default_recall_cnt'
]
+=
1
else
:
from_id
=
from_info
[
1
]
...
...
@@ -122,6 +122,8 @@ def evaluation(result_detail):
for
i
,
n
in
enumerate
(
top_n_list
):
logger
.
info
(
'top {:<2} 的召回数 {:<4} 召回率 {:.2
%
}'
.
format
(
n
,
metrics
[
'top_n_recall_cnt'
][
i
],
_sd
(
metrics
[
'top_n_recall_cnt'
][
i
],
metrics
[
'all_test_cnt'
])))
return
metrics
def
do_test
(
args
):
user_profile_dict
=
load_local_user_profile
()
...
...
@@ -147,7 +149,7 @@ def do_test(args):
continue
is_merge
=
args
.
mode
==
0
size
=
10
if
args
.
mode
==
0
else
0
size
=
10
0
if
args
.
mode
==
0
else
0
recommend_result
=
recommender
.
recommend_with_profile
(
profile
,
size
=
size
,
is_merge
=
is_merge
)
recall_resons
=
[]
for
rr
in
recommend_result
:
...
...
@@ -216,6 +218,104 @@ def update_test_data(args):
logger
.
info
(
'测试数据更新完成'
)
def
batch_test
(
args
):
user_profile_dict
=
load_local_user_profile
()
try
:
old_users
,
test_orders
=
load_test_data
()
except
Exception
as
e
:
logger
.
error
(
'测试数据加载出错,请确认测试数据已经下载到本地,或启动命令中增加参数 "--do_update_test_data"'
,
exc_info
=
True
)
return
def
_test
(
case
):
logger
.
info
(
'##'
*
20
+
' 测试case '
+
'##'
*
20
)
logger
.
info
(
json
.
dumps
(
case
,
ensure_ascii
=
False
))
logger
.
info
(
'##'
*
20
+
' 测试case '
+
'##'
*
20
)
recommender
=
UserCFRecommender
(
**
case
)
result_detail
=
[]
for
index
,
order_info
in
test_orders
.
iterrows
():
if
args
.
max_test
>
0
:
if
index
>=
args
.
max_test
:
break
uid
=
order_info
[
'uid'
]
profile
=
user_profile_dict
.
get
(
uid
)
if
profile
is
None
:
continue
is_merge
=
args
.
mode
==
0
size
=
100
if
args
.
mode
==
0
else
0
recommend_result
=
recommender
.
recommend_with_profile
(
profile
,
size
=
size
,
is_merge
=
is_merge
)
recall_resons
=
[]
for
rr
in
recommend_result
:
if
rr
[
'counselor'
]
==
order_info
[
'supplier_id'
]:
recall_resons
.
append
(
rr
[
'from'
])
result_detail
.
append
({
'uid'
:
uid
,
'supplier_id'
:
order_info
[
'supplier_id'
],
'is_old_user'
:
uid
in
old_users
,
'recall_counselors'
:
recommend_result
,
'is_recall'
:
len
(
recall_resons
)
>
0
,
'recall_reason'
:
'|'
.
join
(
recall_resons
),
'update_time'
:
order_info
[
'update_time'
],
})
return
result_detail
def
_evaluation
(
result_detail
):
metrics
=
evaluation
(
result_detail
)
metrics
[
'all_recall_ratio'
]
=
_sd
(
metrics
[
'all_recall_cnt'
],
metrics
[
'all_test_cnt'
])
metrics
[
'old_user_recall_ratio'
]
=
_sd
(
metrics
[
'old_user_recall_cnt'
],
metrics
[
'old_user_test_cnt'
])
metrics
[
'new_user_recall_ratio'
]
=
_sd
(
metrics
[
'new_user_recall_cnt'
],
metrics
[
'new_user_test_cnt'
])
metrics
[
'top1'
]
=
_sd
(
metrics
[
'top_n_recall_cnt'
][
0
],
metrics
[
'all_test_cnt'
])
metrics
[
'top10'
]
=
_sd
(
metrics
[
'top_n_recall_cnt'
][
3
],
metrics
[
'all_test_cnt'
])
metrics
[
'top50'
]
=
_sd
(
metrics
[
'top_n_recall_cnt'
][
5
],
metrics
[
'all_test_cnt'
])
return
metrics
test_cases
=
[
{
'top_n'
:
2
,
'k'
:
50
,
'u2c'
:
'combination'
,
'c2c'
:
None
,
'is_use_db'
:
False
},
{
'top_n'
:
5
,
'k'
:
20
,
'u2c'
:
'combination'
,
'c2c'
:
None
,
'is_use_db'
:
False
},
{
'top_n'
:
10
,
'k'
:
10
,
'u2c'
:
'combination'
,
'c2c'
:
None
,
'is_use_db'
:
False
},
{
'top_n'
:
20
,
'k'
:
5
,
'u2c'
:
'combination'
,
'c2c'
:
None
,
'is_use_db'
:
False
},
{
'top_n'
:
2
,
'k'
:
50
,
'u2c'
:
'chat'
,
'c2c'
:
None
,
'is_use_db'
:
False
},
{
'top_n'
:
5
,
'k'
:
20
,
'u2c'
:
'chat'
,
'c2c'
:
None
,
'is_use_db'
:
False
},
{
'top_n'
:
10
,
'k'
:
10
,
'u2c'
:
'chat'
,
'c2c'
:
None
,
'is_use_db'
:
False
},
{
'top_n'
:
20
,
'k'
:
5
,
'u2c'
:
'chat'
,
'c2c'
:
None
,
'is_use_db'
:
False
},
{
'top_n'
:
2
,
'k'
:
50
,
'u2c'
:
'order'
,
'c2c'
:
None
,
'is_use_db'
:
False
},
{
'top_n'
:
5
,
'k'
:
20
,
'u2c'
:
'order'
,
'c2c'
:
None
,
'is_use_db'
:
False
},
{
'top_n'
:
10
,
'k'
:
10
,
'u2c'
:
'order'
,
'c2c'
:
None
,
'is_use_db'
:
False
},
{
'top_n'
:
20
,
'k'
:
5
,
'u2c'
:
'order'
,
'c2c'
:
None
,
'is_use_db'
:
False
},
{
'top_n'
:
2
,
'k'
:
50
,
'u2c'
:
'combination'
,
'c2c'
:
True
,
'is_use_db'
:
False
},
{
'top_n'
:
5
,
'k'
:
20
,
'u2c'
:
'combination'
,
'c2c'
:
True
,
'is_use_db'
:
False
},
{
'top_n'
:
10
,
'k'
:
10
,
'u2c'
:
'combination'
,
'c2c'
:
True
,
'is_use_db'
:
False
},
{
'top_n'
:
20
,
'k'
:
5
,
'u2c'
:
'combination'
,
'c2c'
:
True
,
'is_use_db'
:
False
},
{
'top_n'
:
2
,
'k'
:
50
,
'u2c'
:
'chat'
,
'c2c'
:
True
,
'is_use_db'
:
False
},
{
'top_n'
:
5
,
'k'
:
20
,
'u2c'
:
'chat'
,
'c2c'
:
True
,
'is_use_db'
:
False
},
{
'top_n'
:
10
,
'k'
:
10
,
'u2c'
:
'chat'
,
'c2c'
:
True
,
'is_use_db'
:
False
},
{
'top_n'
:
20
,
'k'
:
5
,
'u2c'
:
'chat'
,
'c2c'
:
True
,
'is_use_db'
:
False
},
{
'top_n'
:
2
,
'k'
:
50
,
'u2c'
:
'order'
,
'c2c'
:
True
,
'is_use_db'
:
False
},
{
'top_n'
:
5
,
'k'
:
20
,
'u2c'
:
'order'
,
'c2c'
:
True
,
'is_use_db'
:
False
},
{
'top_n'
:
10
,
'k'
:
10
,
'u2c'
:
'order'
,
'c2c'
:
True
,
'is_use_db'
:
False
},
{
'top_n'
:
20
,
'k'
:
5
,
'u2c'
:
'order'
,
'c2c'
:
True
,
'is_use_db'
:
False
},
]
all_test_result
=
[]
for
case
in
test_cases
:
result_detail
=
_test
(
case
)
metrics
=
_evaluation
(
result_detail
)
all_test_result
.
append
((
case
,
metrics
))
with
open
(
'batch_test_result.tsv'
,
'w'
,
encoding
=
'utf-8'
)
as
f
:
for
(
case
,
metrics
)
in
all_test_result
:
items
=
[
case
[
'k'
],
case
[
'top_n'
],
case
[
'u2c'
],
case
[
'c2c'
]]
items
.
extend
([
metrics
[
'all_recall_ratio'
],
metrics
[
'old_user_recall_ratio'
],
metrics
[
'new_user_recall_ratio'
]])
items
.
extend
([
metrics
[
'top1'
],
metrics
[
'top10'
],
metrics
[
'top50'
]])
f
.
write
(
'
\t
'
.
join
(
map
(
lambda
x
:
str
(
x
),
items
))
+
'
\n
'
)
if
__name__
==
'__main__'
:
parser
=
argparse
.
ArgumentParser
()
...
...
@@ -227,6 +327,8 @@ if __name__ == '__main__':
parser
.
add_argument
(
'--save_test_result'
,
default
=
False
,
action
=
'store_true'
,
help
=
'保存测试详情结果'
)
parser
.
add_argument
(
'--show_result_by_day'
,
default
=
False
,
action
=
'store_true'
,
help
=
'测试结果是否按天展示'
)
parser
.
add_argument
(
'--do_batch_test'
,
default
=
False
,
action
=
'store_true'
,
help
=
'批量测试'
)
parser
.
add_argument
(
'--do_update_test_data'
,
default
=
False
,
action
=
'store_true'
,
help
=
'是否更新测试数据'
)
parser
.
add_argument
(
'--start_date'
,
default
=
'-1'
,
type
=
str
,
help
=
'测试订单创建的开始时间,可以是"
%
Y-
%
m-
%
d"格式,也可以是 -3 表示前3天'
)
...
...
@@ -238,5 +340,9 @@ if __name__ == '__main__':
logger
.
info
(
'测试数据创建时间
%
s'
,
args
.
start_date
)
update_test_data
(
args
)
if
args
.
do_batch_test
:
logger
.
info
(
'执行批量测试'
)
batch_test
(
args
)
else
:
logger
.
info
(
'开始执行测试任务,测试模式为
%
s'
,
args
.
mode
)
do_test
(
args
)
\ No newline at end of file
bin/update.py
View file @
9f715802
...
...
@@ -91,10 +91,10 @@ if __name__ == '__main__':
if
args
.
task
==
'make_index'
:
indexers
=
[
[
'[用户->咨询师]兜底关系索引'
,
UserCounselorDefaultIndexer
()],
[
'基于订单数据的[用户->咨询师]关系索引'
,
UserCounselorOrderIndexer
()],
[
'基于询单数据的[用户->咨询师]关系索引'
,
UserCounselorChatIndexer
()],
[
'基于多种数据组合的[用户->咨询师]关系索引'
,
UserCounselorCombinationIndexer
()],
[
'[用户->咨询师]兜底关系索引'
,
UserCounselorDefaultIndexer
(
logger
=
logger
)],
[
'基于订单数据的[用户->咨询师]关系索引'
,
UserCounselorOrderIndexer
(
logger
=
logger
)],
[
'基于询单数据的[用户->咨询师]关系索引'
,
UserCounselorChatIndexer
(
logger
=
logger
)],
[
'基于多种数据组合的[用户->咨询师]关系索引'
,
UserCounselorCombinationIndexer
(
0.8
,
0.2
,
logger
=
logger
)],
]
logger
.
info
(
''
)
...
...
src/core/indexer.py
View file @
9f715802
...
...
@@ -3,6 +3,7 @@
import
os
import
json
from
collections
import
Counter
from
itertools
import
combinations
from
datetime
import
datetime
,
timedelta
from
typing
import
Dict
,
List
,
Tuple
...
...
@@ -22,16 +23,34 @@ class Indexer():
self
.
logger
=
logger
self
.
local_file_dir
=
get_data_path
()
self
.
index_file
=
''
self
.
index_data
=
{}
def
make_index
(
self
)
->
Dict
[
str
,
List
]:
raise
NotImplementedError
def
load_index_data
(
self
):
if
not
os
.
path
.
exists
(
self
.
index_file
):
self
.
logger
.
error
(
'
%
s 不存在,确认索引是否已经构建完成'
,
self
.
index_file
)
raise
'文件不存在'
index
=
[]
with
open
(
self
.
index_file
,
'r'
,
encoding
=
'utf-8'
)
as
f
:
index
=
json
.
load
(
f
)
self
.
index_data
=
index
def
index
(
self
,
q
:
str
,
count
:
i
nt
=
0
)
->
List
[
Tuple
[
str
,
float
]]:
def
index
(
self
,
q
=
''
,
cou
nt
=
0
)
->
List
[
Tuple
[
str
,
float
]]:
"""
返回值类型:[[相似id, score], [相似id, score], ...]
"""
raise
NotImplementedError
if
len
(
self
.
index_data
)
==
0
:
self
.
logger
.
error
(
'未加载索引数据,使用`index`函数之前,确认对应已执行执行 `load_index_data()` 方法'
)
raise
def
make_index
(
self
)
->
Dict
[
str
,
List
]:
raise
NotImplementedError
if
count
==
0
:
return
self
.
index_data
.
get
(
q
,
[])
else
:
return
self
.
index_data
.
get
(
q
,
[])[:
count
]
class
UserCounselorDefaultIndexer
(
Indexer
):
...
...
@@ -40,7 +59,7 @@ class UserCounselorDefaultIndexer(Indexer):
"""
def
__init__
(
self
,
logger
=
None
)
->
None
:
super
()
.
__init__
(
logger
)
self
.
data_manager
=
OrderDataManager
(
logger
)
self
.
data_manager
=
OrderDataManager
(
self
.
logger
)
self
.
index_file
=
os
.
path
.
join
(
self
.
local_file_dir
,
'index_list.txt'
)
self
.
count
=
100
self
.
index_data
=
[]
...
...
@@ -88,7 +107,7 @@ class UserCounselorOrderIndexer(Indexer):
def
__init__
(
self
,
logger
=
None
)
->
None
:
super
()
.
__init__
(
logger
)
self
.
data_manager
=
OrderDataManager
(
logger
)
self
.
data_manager
=
OrderDataManager
(
self
.
logger
)
self
.
index_file
=
os
.
path
.
join
(
self
.
local_file_dir
,
'user_counselor_order_index.json'
)
self
.
index_data
=
{}
self
.
now
=
datetime
.
now
()
...
...
@@ -100,13 +119,13 @@ class UserCounselorOrderIndexer(Indexer):
date
=
datetime
.
strptime
(
dt
,
'
%
Y-
%
m-
%
d'
)
if
(
self
.
now
-
date
)
<=
timedelta
(
days
=
7
):
w
[
0
]
=
max
(
1.
,
w
[
0
],
price
/
400
)
w
[
0
]
=
max
(
w
[
0
],
min
(
1.
,
price
/
400
)
)
elif
(
self
.
now
-
date
)
<=
timedelta
(
days
=
30
):
w
[
1
]
=
max
(
1.
,
w
[
1
],
price
/
400
)
w
[
1
]
=
max
(
w
[
1
],
min
(
1.
,
price
/
400
)
)
elif
(
self
.
now
-
date
)
<=
timedelta
(
days
=
180
):
w
[
2
]
=
max
(
1.
,
w
[
2
],
price
/
400
)
w
[
2
]
=
max
(
w
[
2
],
min
(
1.
,
price
/
400
)
)
else
:
w
[
3
]
=
max
(
1.
,
w
[
3
],
price
/
400
)
w
[
3
]
=
max
(
w
[
3
],
min
(
1.
,
price
/
400
)
)
value
=
w
[
0
]
*
0.5
+
w
[
1
]
*
0.25
+
w
[
2
]
*
0.15
+
w
[
3
]
*
0.1
return
value
...
...
@@ -142,22 +161,6 @@ class UserCounselorOrderIndexer(Indexer):
self
.
logger
.
info
(
'基于订单数据的[用户->咨询师]关系索引数据已保存,共有用户
%
s'
,
len
(
index
))
def
load_index_data
(
self
):
index
=
[]
with
open
(
self
.
index_file
,
'r'
,
encoding
=
'utf-8'
)
as
f
:
index
=
json
.
load
(
f
)
self
.
index_data
=
index
def
index
(
self
,
q
=
''
,
count
=
0
)
->
List
[
Tuple
[
str
,
float
]]:
if
len
(
self
.
index_data
)
==
0
:
self
.
logger
.
error
(
'未加载索引数据,使用`index`函数之前,确认对应已执行执行 `load_index_data()` 方法'
)
raise
if
count
==
0
:
return
self
.
index_data
.
get
(
q
,
[])
else
:
return
self
.
index_data
.
get
(
q
,
[])[:
count
]
class
UserCounselorChatIndexer
(
Indexer
):
"""
...
...
@@ -166,7 +169,7 @@ class UserCounselorChatIndexer(Indexer):
def
__init__
(
self
,
logger
=
None
)
->
None
:
super
()
.
__init__
(
logger
)
self
.
data_manager
=
ChatDataManager
(
logger
)
self
.
data_manager
=
ChatDataManager
(
self
.
logger
)
self
.
index_file
=
os
.
path
.
join
(
self
.
local_file_dir
,
'user_counselor_chat_index.json'
)
self
.
index_data
=
{}
self
.
now
=
datetime
.
now
()
...
...
@@ -178,13 +181,13 @@ class UserCounselorChatIndexer(Indexer):
date
=
datetime
.
strptime
(
dt
,
'
%
Y-
%
m-
%
d'
)
if
(
self
.
now
-
date
)
<=
timedelta
(
days
=
7
):
w
[
0
]
=
max
(
1.
,
w
[
0
],
(
u2d
+
d2u
)
/
20
)
w
[
0
]
=
max
(
w
[
0
],
min
(
1.
,
(
u2d
+
d2u
)
/
20
)
)
elif
(
self
.
now
-
date
)
<=
timedelta
(
days
=
30
):
w
[
1
]
=
max
(
1.
,
w
[
1
],
(
u2d
+
d2u
)
/
20
)
w
[
1
]
=
max
(
w
[
1
],
min
(
1.
,
(
u2d
+
d2u
)
/
20
)
)
elif
(
self
.
now
-
date
)
<=
timedelta
(
days
=
180
):
w
[
2
]
=
max
(
1.
,
w
[
2
],
(
u2d
+
d2u
)
/
20
)
w
[
2
]
=
max
(
w
[
2
],
min
(
1.
,
(
u2d
+
d2u
)
/
20
)
)
else
:
w
[
3
]
=
max
(
1.
,
w
[
3
],
(
u2d
+
d2u
)
/
20
)
w
[
3
]
=
max
(
w
[
3
],
min
(
1.
,
(
u2d
+
d2u
)
/
20
)
)
value
=
w
[
0
]
*
0.5
+
w
[
1
]
*
0.25
+
w
[
2
]
*
0.15
+
w
[
3
]
*
0.1
return
value
...
...
@@ -221,29 +224,12 @@ class UserCounselorChatIndexer(Indexer):
self
.
logger
.
info
(
'基于询单数据的[用户->咨询师]关系索引数据已保存,共有用户
%
s'
,
len
(
index
))
def
load_index_data
(
self
):
index
=
[]
with
open
(
self
.
index_file
,
'r'
,
encoding
=
'utf-8'
)
as
f
:
index
=
json
.
load
(
f
)
self
.
index_data
=
index
def
index
(
self
,
q
=
''
,
count
=
0
)
->
List
[
Tuple
[
str
,
float
]]:
if
len
(
self
.
index_data
)
==
0
:
self
.
logger
.
error
(
'未加载索引数据,使用`index`函数之前,确认对应已执行执行 `load_index_data()` 方法'
)
raise
if
count
==
0
:
return
self
.
index_data
.
get
(
q
,
[])
else
:
return
self
.
index_data
.
get
(
q
,
[])[:
count
]
class
UserCounselorCombinationIndexer
(
Indexer
):
"""
基于多种数据组合的[用户->咨询师]关系索引
"""
def
__init__
(
self
,
order_w
=
0.
6
,
chat_w
=
0.4
,
logger
=
None
)
->
None
:
def
__init__
(
self
,
order_w
=
0.
8
,
chat_w
=
0.2
,
logger
=
None
)
->
None
:
super
()
.
__init__
(
logger
)
self
.
order_w
=
order_w
self
.
chat_w
=
chat_w
...
...
@@ -278,21 +264,83 @@ class UserCounselorCombinationIndexer(Indexer):
return
index
def
load_index_data
(
self
):
index
=
[]
with
open
(
self
.
index_file
,
'r'
,
encoding
=
'utf-8'
)
as
f
:
index
=
json
.
load
(
f
)
self
.
index_data
=
index
def
index
(
self
,
q
=
''
,
count
=
0
)
->
List
[
Tuple
[
str
,
float
]]:
if
len
(
self
.
index_data
)
==
0
:
self
.
logger
.
error
(
'未加载索引数据,使用`index`函数之前,确认对应已执行执行 `load_index_data()` 方法'
)
raise
class
CounselorCounselorCFIndexer
(
Indexer
):
"""
基于协同过滤的[咨询师->咨询师]关系索引
"""
def
__init__
(
self
,
order_w
=
0.8
,
chat_w
=
0.2
,
logger
=
None
)
->
None
:
super
()
.
__init__
(
logger
)
self
.
order_w
=
order_w
self
.
chat_w
=
chat_w
self
.
index_file
=
os
.
path
.
join
(
self
.
local_file_dir
,
'counselor_counselor_cf_index.json'
)
self
.
index_data
=
{}
if
count
==
0
:
return
self
.
index_data
.
get
(
q
,
[])
def
load_pair_by_chat
(
self
):
data_manager
=
ChatDataManager
(
self
.
logger
)
df
=
data_manager
.
load_raw_data
()
counselor_user_dict
=
{}
for
_
,
row
in
df
.
iterrows
():
uid
,
supplier_id
=
row
[
'uid'
],
row
[
'doctor_id'
]
if
supplier_id
not
in
counselor_user_dict
:
counselor_user_dict
[
supplier_id
]
=
set
()
counselor_user_dict
[
supplier_id
]
.
add
(
uid
)
return
counselor_user_dict
def
load_pair_by_order
(
self
):
data_manager
=
OrderDataManager
(
self
.
logger
)
df
=
data_manager
.
load_raw_data
()
counselor_user_dict
=
{}
for
_
,
row
in
df
.
iterrows
():
uid
,
supplier_id
=
row
[
'uid'
],
row
[
'supplier_id'
]
if
supplier_id
not
in
counselor_user_dict
:
counselor_user_dict
[
supplier_id
]
=
set
()
counselor_user_dict
[
supplier_id
]
.
add
(
uid
)
return
counselor_user_dict
def
make_index
(
self
)
->
Dict
[
str
,
List
]:
self
.
logger
.
info
(
''
)
self
.
logger
.
info
(
'开始构建基于协同过滤的[咨询师->咨询师]关系索引'
)
counselor_user_dict
=
self
.
load_pair_by_order
()
self
.
logger
.
info
(
'基于订单的[用户->咨询师]数据加载完成'
)
_counselor_user_dict2
=
self
.
load_pair_by_chat
()
self
.
logger
.
info
(
'基于询单的[用户->咨询师]数据加载完成'
)
for
key
,
val
in
_counselor_user_dict2
.
items
():
if
key
not
in
counselor_user_dict
:
counselor_user_dict
[
key
]
=
val
else
:
return
self
.
index_data
.
get
(
q
,
[])[:
count
]
counselor_user_dict
[
key
]
.
update
(
val
)
self
.
logger
.
info
(
'数据合并完成,共有咨询师
%
s'
,
len
(
counselor_user_dict
))
index
=
{}
for
[
_u1
,
_u2
]
in
combinations
(
counselor_user_dict
.
keys
(),
2
):
u1
,
u2
=
min
(
_u1
,
_u2
),
max
(
_u1
,
_u2
)
sim
=
len
(
counselor_user_dict
[
u1
]
&
counselor_user_dict
[
u2
])
/
(
len
(
counselor_user_dict
[
u1
])
*
len
(
counselor_user_dict
[
u2
]))
if
u1
not
in
index
:
index
[
u1
]
=
[]
if
u2
not
in
index
:
index
[
u2
]
=
[]
index
[
u1
]
.
append
((
u2
,
sim
))
index
[
u2
]
.
append
((
u1
,
sim
))
self
.
logger
.
info
(
'开始咨询师相似性排序'
)
for
key
,
val
in
index
.
items
():
# 根据相似性得分排序后,取前100个
index
[
key
]
=
sorted
(
val
,
key
=
lambda
x
:
x
[
1
],
reverse
=
True
)[:
100
]
self
.
logger
.
info
(
'基于协同过滤的[咨询师->咨询师]关系索引构建完成,共构建
%
s 条数据'
,
len
(
index
))
with
open
(
self
.
index_file
,
'w'
,
encoding
=
'utf-8'
)
as
f
:
json
.
dump
(
index
,
f
,
ensure_ascii
=
False
)
return
index
if
__name__
==
'__main__'
:
...
...
@@ -307,3 +355,6 @@ if __name__ == '__main__':
indexer
=
UserCounselorCombinationIndexer
()
indexer
.
make_index
()
indexer
=
CounselorCounselorCFIndexer
()
indexer
.
make_index
()
\ No newline at end of file
src/core/recommender.py
View file @
9f715802
...
...
@@ -7,8 +7,13 @@ from typing import List, Dict
import
faiss
import
numpy
as
np
from
ydl_ai_recommender.src.core.indexer
import
UserCounselorDefaultIndexer
from
ydl_ai_recommender.src.core.indexer
import
UserCounselorCombinationIndexer
from
ydl_ai_recommender.src.core.indexer
import
(
UserCounselorChatIndexer
,
UserCounselorOrderIndexer
,
UserCounselorDefaultIndexer
,
UserCounselorCombinationIndexer
,
CounselorCounselorCFIndexer
,
)
from
ydl_ai_recommender.src.core.profile
import
encode_profile
from
ydl_ai_recommender.src.data.mysql_client
import
MySQLClient
from
ydl_ai_recommender.src.utils
import
get_conf_path
,
get_data_path
...
...
@@ -26,7 +31,15 @@ class Recommender():
class
UserCFRecommender
(
Recommender
):
def
__init__
(
self
,
top_n
=
5
,
k
=
20
,
is_use_db
=
True
)
->
None
:
def
__init__
(
self
,
top_n
=
5
,
k
=
20
,
is_use_db
=
True
,
u2c
=
'combination'
,
c2c
=
None
)
->
None
:
"""
params:
top_n: 每个召回的用户获取的相关咨询师个数
k: 召回的相似用户数
is_use_db: 是否使用数据库
u2c: [用户->咨询师] 索引方法
c2c: [咨询师->咨询师] 索引方法,None 表示不使用咨询师拓展
"""
super
()
.
__init__
()
# 召回 top_n 个相似用户
self
.
top_n
=
top_n
...
...
@@ -41,9 +54,21 @@ class UserCFRecommender(Recommender):
self
.
default_indexer
=
UserCounselorDefaultIndexer
(
self
.
logger
)
self
.
default_indexer
.
load_index_data
()
if
u2c
==
'chat'
:
self
.
indexer
=
UserCounselorChatIndexer
(
self
.
logger
)
elif
u2c
==
'order'
:
self
.
indexer
=
UserCounselorOrderIndexer
(
self
.
logger
)
else
:
self
.
indexer
=
UserCounselorCombinationIndexer
(
self
.
logger
)
self
.
indexer
.
load_index_data
()
if
c2c
:
self
.
c2c_indexer
=
CounselorCounselorCFIndexer
(
self
.
logger
)
self
.
c2c_indexer
.
load_index_data
()
else
:
self
.
c2c_indexer
=
None
self
.
local_file_dir
=
get_data_path
()
self
.
load_data
()
...
...
@@ -100,7 +125,18 @@ class UserCFRecommender(Recommender):
'score'
:
score
/
max
(
0.01
,
float
(
simi_score
)),
'from'
:
'similar_users {}'
.
format
(
similar_user_id
),
}
for
(
c_id
,
score
)
in
similar_user_counselor
]
counselors
.
extend
(
recommend_data
)
supplement_data
=
[]
if
self
.
c2c_indexer
:
for
ro
in
recommend_data
:
supplement_data
.
extend
([{
'counselor'
:
sc_id
,
'score'
:
ro
[
'score'
]
*
score
,
'from'
:
'{} supplement {}'
.
format
(
ro
[
'from'
],
ro
[
'counselor'
]),
}
for
(
sc_id
,
score
)
in
self
.
c2c_indexer
.
index
(
ro
[
'counselor'
],
count
=
int
(
self
.
top_n
))])
# } for (sc_id, score) in self.c2c_indexer.index(ro['counselor'], count=int(self.top_n / len(recommend_data)))])
counselors
.
extend
(
recommend_data
+
supplement_data
)
counselors
.
sort
(
key
=
lambda
x
:
x
[
'score'
],
reverse
=
True
)
return
counselors
...
...
src/service/recommend_service.py
View file @
9f715802
...
...
@@ -14,7 +14,7 @@ from ydl_ai_recommender.src.core.recommender import UserCFRecommender
logger
=
create_logger
(
__name__
,
'service.log'
,
is_rotating
=
True
)
recommender
=
UserCFRecommender
(
top_n
=
5
,
k
=
5
)
recommender
=
UserCFRecommender
(
top_n
=
5
,
k
=
20
,
c2c
=
True
)
class
RecommendHandler
(
tornado
.
web
.
RequestHandler
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment