Commit 448f26e8 by chai pengfei

Initial commit

parents
# OpenAIBot
将 OpenAI 的能力包装成各类机器人
# 安装
```
git clone https://github.com/pfchai/OpenAIBot.git
cd OpenAIBot
# 进入虚拟环境
pip install -r requirements.txt
```
# 启动服务
```
# 配置
cp config.yaml.example config.yaml
# 补充相关配置信息
vi config.yaml
# 启动服务
flask run
# 指定端口,外网可用
flask run --host 0.0.0.0 --port 8000
```
# ToDos
- [x] 基于 GPT3 对话机器人
- [ ] 基于 ChatGPT 对话机器人
- [x] 支持飞书机器人
- [x] 支持企业微信机器人
- [x] 服务支持配置多个机器人
# -*- coding: utf-8 -*-
import yaml
import logging
from logging.config import dictConfig
from flask import Flask
from flask import request, jsonify
from .platform.feishu import EchoServer as FeishuEchoServer
from .platform.feishu import ChatGPTServer as FeishuChatGPTServer
from .platform.feishu import YDLGPTServer as FeishuYDLGPTServer
from .platform.wework import EchoServer as WeworkEchoServer
from .platform.wework import ChatGPTServer as WeworkChatGPTServer
from .platform.wework import YDLGPTServer as WeworkYDLGPTServer
dictConfig({
'version': 1,
'formatters': {
'default': {
'format': '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
}
},
'handlers': {
'console': {
'class': 'logging.StreamHandler',
'level': 'DEBUG',
'formatter': 'default'
}
},
'root': {
'level': 'DEBUG',
'handlers': ['console'],
}
})
def create_server(app, feishu_bots, wework_bots):
if feishu_bots:
@app.route('/feishu/<name>', methods=['GET', 'POST'])
def feishu_server(name):
if name not in feishu_bots:
return 'url error'
if request.method == 'POST':
app.logger.info(request)
res = feishu_bots[name].handle(request)
app.logger.info(res)
if isinstance(res, str):
return res
else:
return jsonify(res)
return '<p>Hello, World!</p>'
if wework_bots:
@app.route('/wework/<name>', methods=['GET', 'POST'])
def wework_server(name):
if name not in wework_bots:
return 'url error'
if request.method == 'POST':
app.logger.info(request)
res = wework_bots[name].handle(request)
app.logger.info(res)
if isinstance(res, str):
return res
else:
return jsonify(res)
else:
return wework_bots[name].client.vertify_url(request)
def create_bots():
feishu_bots, wework_bots = {}, {}
with open('config.yaml') as f:
configs = yaml.load_all(f, Loader=yaml.FullLoader)
for config in configs:
if config['platform'] == 'feishu':
if config['bot'] not in ('echo', 'chatgpt', 'ydl_gpt'):
raise
if config['bot'] == 'echo':
feishu_bots[config['name']] = FeishuEchoServer(config)
if config['bot'] == 'chatgpt':
feishu_bots[config['name']] = FeishuChatGPTServer(config)
if config['bot'] == 'ydl_gpt':
feishu_bots[config['name']] = FeishuYDLGPTServer(config)
if config['platform'] == 'wework':
if config['bot'] not in ('echo', 'chatgpt', 'ydl_gpt'):
raise
if config['bot'] == 'echo':
wework_bots[config['name']] = WeworkEchoServer(config)
if config['bot'] == 'chatgpt':
wework_bots[config['name']] = WeworkChatGPTServer(config)
if config['bot'] == 'ydl_gpt':
wework_bots[config['name']] = WeworkYDLGPTServer(config)
return feishu_bots, wework_bots
app = Flask(__name__)
app.config['timeout'] = 120
app.logger.setLevel(logging.DEBUG)
feishu_bots, wework_bots = create_bots()
create_server(app, feishu_bots, wework_bots)
# -*- coding: utf-8 -*-
from .gpt3.gpt3_chat_bot import GPT3ChatBot
from .ydl_gpt import YDLGPTBot
\ No newline at end of file
# -*- coding: utf-8 -*-
import json
class ConversationManager():
def __init__(self):
self.conversations = {}
def add_conversation(self, key: str, history: list = []) -> None:
self.conversations[key] = history
def get_conversation(self, key: str) -> list:
if key not in self.conversations:
self.add_conversation(key)
return self.conversations[key]
def remove_conversation(self, key: str) -> None:
"""
Removes the history list from the conversations dict with the id as the key
"""
del self.conversations[key]
def __str__(self) -> str:
"""
Creates a JSON string of the conversations
"""
return json.dumps(self.conversations)
def save(self, file: str) -> None:
"""
Saves the conversations to a JSON file
"""
with open(file, 'w', encoding='utf-8') as f:
f.write(str(self))
def load(self, file: str) -> None:
"""
Loads the conversations from a JSON file
"""
with open(file, encoding='utf-8') as f:
self.conversations = json.loads(f.read())
# -*- coding: utf-8 -*-
import os
import re
import json
import logging
import openai
import tiktoken
from .prompt import Prompt
from .conversation_manager import ConversationManager
ENCODER = tiktoken.get_encoding('gpt2')
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
class GPT3ChatBot():
def __init__(self, api_key: str = None, engine: str = None, proxy: str = None):
openai.api_key = api_key or os.getenv('OPENAI_API_KEY')
openai.proxy = proxy or os.getenv("OPENAI_API_PROXY")
self.engine = engine or os.getenv('GPT_ENGINE') or 'text-davinci-003'
self.max_token = 4000
self.prompt = Prompt(max_token=self.max_token)
self.conversation_manager = ConversationManager()
def _get_completion(self, prompt: str, temperature: float = 0.5, stream: bool = False):
return openai.Completion.create(
engine=self.engine,
prompt=prompt,
temperature=temperature,
stop=['\n\n\n'],
stream=stream
)
def parse_completion(self, completion: dict) -> str:
if completion.get('choices') is None:
raise Exception('ChatGPT API returned no choices')
if len(completion['choices']) == 0:
raise Exception('ChatGPT API returned no choices')
if completion['choices'][0].get('text') is None:
raise Exception('ChatGPT API returned no text')
response = re.sub(r'<\|im_end\|>', '', completion['choices'][0]['text'])
completion['choices'][0]['text'] = response
return response
def ask(self, query, temperature: float = 0.5, conversation_id: str = None, user: str = 'User'):
if conversation_id is None:
history = []
else:
history = self.conversation_manager.get_conversation(conversation_id)
prompt, history = self.prompt.encode(query, history, user)
completion = None
try:
logger.debug('请求 OpenAI 接口 %s', prompt)
completion = openai.Completion.create(
engine=self.engine,
prompt=prompt,
temperature=temperature,
max_tokens=self.max_token - len(prompt),
stop=['\n\n\n'],
stream=False
)
logger.debug('OpenAI 接口返回 %s', json.dumps(completion, ensure_ascii=False))
except Exception as e:
logger.error('OpenAI 接口返回异常,%s', e, exc_info=True)
return ''
response = self.parse_completion(completion)
if conversation_id is not None:
history.append([query, response])
self.conversation_manager.add_conversation(conversation_id, history)
return response
class AsyncGPT3ChatBot(GPT3ChatBot):
async def ask(self, query, temperature: float = 0.5, conversation_id: str = None, user: str = 'User'):
if conversation_id is None:
history = []
else:
history = self.conversation_manager.get_conversation(conversation_id)
prompt, history = self.prompt.encode(query, history, user)
completion = None
try:
logger.debug('请求 OpenAI 接口 %s', prompt)
completion = openai.Completion.create(
engine=self.engine,
prompt=prompt,
temperature=temperature,
max_tokens=self.max_token - len(prompt),
stop=['\n\n\n'],
stream=False
)
logger.debug('OpenAI 接口返回 %s', json.dumps(completion, ensure_ascii=False))
except Exception as e:
logger.error('OpenAI 接口返回异常,%s', e, exc_info=True)
return ''
response = self.parse_completion(completion)
if conversation_id is not None:
history.append([query, response])
self.conversation_manager.add_conversation(conversation_id, history)
return response
# -*- coding: utf-8 -*-
from datetime import date
import tiktoken
ENCODER = tiktoken.get_encoding('gpt2')
HEADER_PROMPT_TEMPLATE = '''{prologue} {date}
{user}: Hello
ChatGPT: Hello! How can I help you today? <|im_end|>
'''
ROUND_PROMPT_TEMPLATE = '''{user}: {query}
ChatGPT: {response}<|im_end|>
'''
QUERY_TEMPLATE = '''{user}: {query}
ChatGPT:'''
class Prompt():
def __init__(self, max_token=4000, buffer: int = None, prologue=None):
buffer = buffer or 800
self.max_token = max_token - buffer
self.buffer = buffer
self.default_prologue = prologue or (
'You are ChatGPT, a large language model trained by OpenAI. '
'Respond conversationally. Do not answer as the user.'
)
def encode_header(self, user: str) -> str:
return HEADER_PROMPT_TEMPLATE.format(
prologue=self.default_prologue,
date=str(date.today()),
user=user,
)
def encode_history_round(self, query: str, response: str = None, user: str = 'User') -> str:
return ROUND_PROMPT_TEMPLATE.format(
user=user,
query=query,
response=response
)
def encode_history(self, history: list, user: str = 'User') -> str:
return ''.join([
self.encode_history_round(query, response, user)
for [query, response] in history
])
def encode_query(self, query: str, user: str = 'User') -> str:
return QUERY_TEMPLATE.format(
user=user,
query=query
)
def encode(self, query: str, history: list = [], user: str = 'User'):
prompt = self.encode_header(user)
prompt += self.encode_history(history, user)
prompt += self.encode_query(query, user)
prompt = ENCODER.encode(prompt)
if (len(prompt) > self.max_token) and (len(history) > 0):
return self.encode(query, history[1:], user)
return prompt, history
if __name__ == '__main__':
prompt = Prompt()
query = '帮我写一段快速排序的代码'
history = []
pt, _ = prompt.encode(query, history)
print(pt)
history = [['你好', '你好,请问有什么可以帮你的吗']]
pt, _ = prompt.encode(query, history)
print(len(pt))
history = [['你好', '你好,请问有什么可以帮你的吗'] for _ in range(200)]
pt, _ = prompt.encode(query, history)
print(len(pt))
# -*- coding: utf-8 -*-
import requests
import logging
class YDLGPTBot():
def __init__(self, url, app_id, scene) -> None:
self.url = url
self.app_id = app_id
self.scene = scene
self.headers = {'Content-Type': 'application/json'}
def ask(self, text, sender_id='', chat_id=''):
body = {
'dto': {
'appId': self.app_id,
'userId': sender_id,
'conversationId': chat_id,
'input': text,
'scene': self.scene,
}
}
logging.debug(body)
response = requests.post(self.url, headers=self.headers, json=body)
res_data = response.json()
logging.debug(res_data)
reply_text = res_data['data']['aiSide']
return reply_text
---
name: feishu1
platform: feishu
bot: echo
FEISHU_APP_ID:
FEISHU_APP_SECRET:
FEISHU_ENCRYPT_KEY:
---
name: feishu2
platform: feishu
bot: chatgpt
FEISHU_APP_ID:
FEISHU_APP_SECRET:
FEISHU_ENCRYPT_KEY:
# chatgpt 需要
OPENAI_API_KEY:
---
name: wework1
platform: wework
bot: echo
WEWORK_TOKEN:
WEWORK_ENCODING_AES_KEY:
WEWORK_CORP_ID:
WEWORK_SECRET:
WEWORK_AGENTID:
# -*- coding: utf-8 -*-
from .client import Client
from .server import BaseServer
from .server import EchoServer
from .server import ChatGPTServer
from .server import YDLGPTServer
# -*- coding: utf-8 -*-
import json
import time
import hashlib
import logging
import requests
from requests_toolbelt import MultipartEncoder
from .tool import AESCipher
logger = logging.getLogger(__name__)
class Client():
def __init__(self, config):
self.app_id = config['FEISHU_APP_ID']
self.app_secret = config['FEISHU_APP_SECRET']
self.encrypt_key = config['FEISHU_ENCRYPT_KEY']
self.is_valid_message = False
self.tenant_access_token = ''
self.expiry_time = time.time()
self.id_map = {}
self.message_ids = set()
self.update_tenant_access_token()
self.get_bot_info()
self.cipher = AESCipher(self.encrypt_key)
def _valid_message(self, request):
headers = request.headers
timestamp = headers['X-Lark-Request-Timestamp']
nonce = headers['X-Lark-Request-Nonce']
bytes_b1 = (timestamp + nonce + self.encrypt_key).encode('utf-8')
bytes_b = bytes_b1 + request.data
h = hashlib.sha256(bytes_b)
signature = h.hexdigest()
return request.headers['X-Lark-Signature'] == signature
def get_bot_info(self):
"""
获取机器人信息
"""
url = 'https://open.feishu.cn/open-apis/bot/v3/info'
headers = {
'Authorization': 'Bearer {}'.format(self.tenant_access_token),
'Content-Type': 'application/json; charset=utf-8'
}
response = requests.request('GET', url, headers=headers)
data = response.json()
logger.debug('获取bot信息成功')
logger.debug('--' * 20)
logger.debug(data)
self.bot_open_id = data.get('bot', {}).get('open_id')
return data
def update_tenant_access_token(self):
logger.debug('更新 tenant_access_token')
url = 'https://open.feishu.cn/open-apis/auth/v3/tenant_access_token/internal'
headers = {
'Content-Type': 'application/json; charset=utf-8'
}
body = {
'app_id': self.app_id,
'app_secret': self.app_secret,
}
response = requests.request('POST', url, headers=headers, data=json.dumps(body))
data = response.json()
self.expiry_time = data['expire'] + time.time() - 10
self.tenant_access_token = data['tenant_access_token']
logger.debug('tenant_access_token 更新成功')
return
def upload_img(self, image, image_file_type='file'):
""" 上传图片 """
url = 'https://open.feishu.cn/open-apis/im/v1/images'
if image_file_type == 'file':
form = {'image_type': 'message', 'image': (open(image, 'rb'))}
elif image_file_type == 'net':
_response = requests.get(image)
form = {'image_type': 'message', 'image': _response.content}
else:
logger.error('不支持的图片文件类型')
return
multi_form = MultipartEncoder(form)
headers = {
'Authorization': 'Bearer {}'.format(self.tenant_access_token),
'Content-Type': multi_form.content_type
}
response = requests.request("POST", url, headers=headers, data=multi_form)
logger.debug(response.content)
data = response.json()
return data.get('data', {}).get('image_key')
# 判断机器人是否被 @
def is_be_at(self, msg):
for mention in msg.get('mentions', []):
if mention['id']['open_id'] == self.bot_open_id:
return mention['key']
return False
def _reply(self, message_id, data):
if time.time() > self.expiry_time:
self.update_tenant_access_token()
url = 'https://open.feishu.cn/open-apis/im/v1/messages/{}/reply'.format(message_id)
headers = {
'Authorization': 'Bearer {}'.format(self.tenant_access_token),
'Content-Type': 'application/json; charset=utf-8'
}
logger.debug('--' * 20)
logger.debug('\n'+json.dumps(data, ensure_ascii=False))
response = requests.request('POST', url, headers=headers, data=json.dumps(data))
logger.debug(response.text)
return response
def parse_message(self, request):
content = request.json
if 'encrypt' in content:
decrypt_str = self.cipher.decrypt_string(content['encrypt'])
content = json.loads(decrypt_str)
return content
def reply_text(self, message_id, chat_id, text):
if time.time() > self.expiry_time:
self.update_tenant_access_token()
data = {
'receive_id': chat_id,
'content': json.dumps({'text': text}),
'msg_type': 'text',
}
return self._reply(message_id, data)
def reply_img(self, message_id, chat_id, image):
if time.time() > self.expiry_time:
self.update_tenant_access_token()
if image.startswith('http'):
image_file_type = 'net'
else:
image_file_type = 'file'
img_key = self.upload_img(image, image_file_type=image_file_type)
logger.info('image upload, key = %s', img_key)
data = {
'receive_id': chat_id,
'content': json.dumps({'image_key': img_key}),
'msg_type': 'image',
}
return self._reply(message_id, data)
def handle_p2p(self, message):
"""
单聊消息处理
"""
msg = message['event']['message']
msg_content = json.loads(msg['content'])
msg_text = msg_content['text']
self.reply_text(msg['message_id'], msg['chat_id'], '收到消息:' + msg_text)
return 'success'
def handle_group(self, message):
"""
处理群消息
"""
msg = message['event']['message']
msg_content = json.loads(msg['content'])
at_key = self.is_be_at(msg)
if at_key is False:
return 'ignore group message'
msg_text = msg_content['text'].replace(at_key + ' ', '')
self.reply_text(msg['message_id'], msg['chat_id'], '收到消息:' + msg_text)
return 'success'
def handle(self, request):
if self.is_valid_message:
if not self.valid(request):
logger.info('received message is invalid')
received_message = self.parse_message(request)
logger.debug(received_message)
if not received_message:
logger.error('received message is None')
# 飞书认证逻辑
if 'challenge' in received_message:
return {'challenge': received_message['challenge']}
r_event = received_message['event']
msg = r_event['message']
msg_id = msg['message_id']
# 重复消息忽略
if msg_id in self.message_ids:
return 'ignore'
else:
self.message_ids.add(msg_id)
# 判断是单聊还是群消息
if msg['chat_type'] == 'p2p':
return self.handle_p2p(received_message)
elif msg['chat_type'] == 'group':
return self.handle_group(received_message)
else:
return 'not support chat_type'
# -*- coding: utf-8 -*-
import os
import re
import json
import logging
import openai
from .client import Client
from ...bot import GPT3ChatBot, YDLGPTBot
logger = logging.getLogger()
class BaseServer():
def handle(self):
"""
处理消息
"""
raise NotImplementedError
class EchoServer(BaseServer):
def __init__(self, config):
self.is_valid_message = False
self.client = Client(config)
self.message_ids = set()
def handle_p2p(self, message):
"""
单聊消息处理
"""
msg = message['event']['message']
msg_content = json.loads(msg['content'])
msg_text = msg_content['text']
self.client.reply_text(msg['message_id'], msg['chat_id'], '收到消息:' + msg_text)
return 'success'
def handle_group(self, message):
"""
处理群消息
"""
msg = message['event']['message']
msg_content = json.loads(msg['content'])
at_key = self.client.is_be_at(msg)
if at_key is False:
return 'ignore group message'
msg_text = msg_content['text'].replace(at_key + ' ', '')
self.client.reply_text(msg['message_id'], msg['chat_id'], '收到消息:' + msg_text)
return 'success'
def handle(self, request):
if self.is_valid_message:
if not self.client.valid(request):
logger.info('received message is invalid')
received_message = self.client.parse_message(request)
logger.debug(received_message)
if not received_message:
logger.error('received message is None')
# 飞书认证逻辑
if 'challenge' in received_message:
return {'challenge': received_message['challenge']}
r_event = received_message['event']
msg = r_event['message']
msg_id = msg['message_id']
# 重复消息忽略
if msg_id in self.message_ids:
return 'ignore'
else:
self.message_ids.add(msg_id)
# 判断是单聊还是群消息
if msg['chat_type'] == 'p2p':
return self.handle_p2p(received_message)
elif msg['chat_type'] == 'group':
return self.handle_group(received_message)
else:
return 'not support chat_type'
class ChatGPTServer(BaseServer):
def __init__(self, config):
super().__init__()
self.chatbot = GPT3ChatBot(api_key=config['OPENAI_API_KEY'], engine=config.get('ENGINE'), proxy=config.get('PROXY'))
self.support_image_generation = True
self.is_valid_message = False
self.client = Client(config)
openai.api_key = config['OPENAI_API_KEY']
self.message_ids = set()
def ask(self, text, sender_id=None):
if text == 'Brazil':
self.chatbot.reset()
return 'bot is reset'
response = self.chatbot.ask(text, conversation_id=sender_id)
logger.debug('chatgpt response', response)
return response
def process(self, message, msg_text):
event = message['event']
msg = event['message']
sender_id = event['sender']['sender_id']['open_id']
if msg_text.startswith('作图 '):
if self.support_image_generation is False:
self.client.reply_text(msg['message_id'], msg['chat_id'], '暂不提供AI作图能力')
return 'success'
prompt = re.sub(r'^作图 ', '', msg_text)
try:
response = openai.Image.create(prompt=prompt, n=1, size='512x512')
except Exception as e:
logger.error(e)
print(e)
return 'generate image error'
img_url = response['data'][0]['url']
self.client.reply_img(msg['message_id'], msg['chat_id'], img_url)
return 'success'
chatgpt_res_text = '获取chatgpt回复消息失败'
try:
chatgpt_res_text = self.ask(msg_text, sender_id=sender_id)
except Exception as e:
logger.error(e)
print(e)
self.client.reply_text(msg['message_id'], msg['chat_id'], '服务出了点问题,请重试')
return 'error'
self.client.reply_text(msg['message_id'], msg['chat_id'], chatgpt_res_text)
return 'success'
def handle_p2p(self, message):
"""
单聊消息处理
"""
event = message['event']
msg = event['message']
msg_content = json.loads(msg['content'])
msg_text = msg_content['text']
return self.process(message, msg_text)
def handle_group(self, message):
"""
处理群消息
"""
event = message['event']
msg = event['message']
msg_content = json.loads(msg['content'])
at_key = self.client.is_be_at(msg)
if at_key is False:
return 'ignore group message'
msg_text = msg_content['text'].replace(at_key + ' ', '')
return self.process(message, msg_text)
def handle(self, request):
if self.is_valid_message:
if not self.client.valid(request):
logger.info('received message is invalid')
received_message = self.client.parse_message(request)
logger.debug(received_message)
if not received_message:
logger.error('received message is None')
# 飞书认证逻辑
if 'challenge' in received_message:
return {'challenge': received_message['challenge']}
r_event = received_message['event']
msg = r_event['message']
msg_id = msg['message_id']
# 重复消息忽略
if msg_id in self.message_ids:
return 'ignore'
else:
self.message_ids.add(msg_id)
# 判断是单聊还是群消息
if msg['chat_type'] == 'p2p':
return self.handle_p2p(received_message)
elif msg['chat_type'] == 'group':
return self.handle_group(received_message)
else:
return 'not support chat_type'
class YDLGPTServer(BaseServer):
def __init__(self, config):
super().__init__()
self.chatbot = YDLGPTBot(url=config['CUSTOM_CHATGPT_URL'], app_id=config['YDL_APP_ID'], scene=config['YDL_SCENE'])
self.client = Client(config)
self.is_valid_message = False
self.message_ids = set()
def process(self, message, msg_text):
event = message['event']
msg = event['message']
sender_id = event['sender']['sender_id']['open_id']
chatgpt_res_text = '获取chatgpt回复消息失败'
try:
chatgpt_res_text = self.chatbot.ask(msg_text, sender_id=sender_id, chat_id=msg['chat_id'])
except Exception as e:
logger.exception(e)
self.client.reply_text(msg['message_id'], msg['chat_id'], '服务出了点问题,请重试')
return 'error'
self.client.reply_text(msg['message_id'], msg['chat_id'], chatgpt_res_text)
return 'success'
def handle_p2p(self, message):
"""
单聊消息处理
"""
event = message['event']
msg = event['message']
msg_content = json.loads(msg['content'])
if 'text' in msg_content:
msg_text = msg_content['text']
elif 'content' in msg_content:
msgs = []
for items in msg_content['content']:
for item in items:
msgs.append(item['text'])
msg_text = '\n'.join(msgs)
else:
logger.error(msg_content)
raise
return self.process(message, msg_text)
def handle_group(self, message):
"""
处理群消息
"""
event = message['event']
msg = event['message']
msg_content = json.loads(msg['content'])
at_key = self.client.is_be_at(msg)
if at_key is False:
return 'ignore group message'
msg_text = msg_content['text'].replace(at_key + ' ', '')
return self.process(message, msg_text)
def handle(self, request):
if self.is_valid_message:
if not self.client.valid(request):
logger.info('received message is invalid')
received_message = self.client.parse_message(request)
logger.debug(received_message)
if not received_message:
logger.error('received message is None')
# 飞书认证逻辑
if 'challenge' in received_message:
return {'challenge': received_message['challenge']}
r_event = received_message['event']
msg = r_event['message']
msg_id = msg['message_id']
# 重复消息忽略
if msg_id in self.message_ids:
return 'ignore'
else:
self.message_ids.add(msg_id)
# 判断是单聊还是群消息
if msg['chat_type'] == 'p2p':
return self.handle_p2p(received_message)
elif msg['chat_type'] == 'group':
return self.handle_group(received_message)
else:
return 'not support chat_type'
# -*- coding: utf-8 -*-
import base64
import hashlib
from Crypto.Cipher import AES
class AESCipher(object):
def __init__(self, key):
self.bs = AES.block_size
self.key = hashlib.sha256(AESCipher.str_to_bytes(key)).digest()
@staticmethod
def str_to_bytes(data):
u_type = type(b"".decode('utf8'))
if isinstance(data, u_type):
return data.encode('utf8')
return data
@staticmethod
def _unpad(s):
return s[:-ord(s[len(s) - 1:])]
def decrypt(self, enc):
iv = enc[:AES.block_size]
cipher = AES.new(self.key, AES.MODE_CBC, iv)
return self._unpad(cipher.decrypt(enc[AES.block_size:]))
def decrypt_string(self, enc):
enc = base64.b64decode(enc)
return self.decrypt(enc).decode('utf8')
# 企业微信机器人
主要的几个文档
- [开发指南](https://developer.work.weixin.qq.com/document/path/90664)
- [发送应用消息](https://developer.work.weixin.qq.com/document/path/90236)
- [接收消息](https://developer.work.weixin.qq.com/document/path/90238)
几个需要注意的点
- 接受消息的url,必须是与企业相关的域名;
- 如果要发送消息,必须保证发送消息的服务器的ip在 “企业可信IP” 列表中;
# -*- coding: utf-8 -*-
from .client import Client
from .server import BaseServer
from .server import EchoServer
from .server import ChatGPTServer
from .server import YDLGPTServer
# -*- coding: utf-8 -*-
import os
import time
import json
import logging
import requests
from .WXBizMsgCrypt import WXBizMsgCrypt
logger = logging.getLogger(__name__)
class Client():
def __init__(self, config):
self.token = config['WEWORK_TOKEN']
self.encoding_aes_key = config['WEWORK_ENCODING_AES_KEY']
self.corp_id = config['WEWORK_CORP_ID']
self.secret = config['WEWORK_SECRET']
self.agentid = int(config['WEWORK_AGENTID'])
self.access_token = ''
self.expiry_time = time.time()
self.wxcpt = WXBizMsgCrypt(self.token, self.encoding_aes_key, self.corp_id)
self.update_access_token()
def vertify_url(self, request):
logger.info('request for vertify url')
logger.info(request.args)
msg_signature = request.args.get('msg_signature')
timestamp = request.args.get('timestamp')
nonce = request.args.get('nonce')
echo_str = request.args.get('echostr')
ret, echo_str = self.wxcpt.VerifyURL(msg_signature, timestamp, nonce, echo_str)
if ret != 0:
logger.error('ERR: Vertify URL ret ', ret)
return echo_str
def update_access_token(self):
logger.debug('更新 access_token')
url = f'https://qyapi.weixin.qq.com/cgi-bin/gettoken?corpid={self.corp_id}&corpsecret={self.secret}'
response = requests.request('GET', url)
data = response.json()
self.expiry_time = data['expires_in'] + time.time() - 10
self.access_token = data['access_token']
logger.debug('access_token 更新成功')
return
def decrypt_msg(self, request):
msg_signature = request.args.get('msg_signature')
timestamp = request.args.get('timestamp')
nonce = request.args.get('nonce')
ret, msg = self.wxcpt.DecryptMsg(request.data, msg_signature, timestamp, nonce)
if ret != 0:
logger.error('decrypt msg error')
return None
return msg
def encrypt_msg(self, resp_data, req_nonce, timestamp):
# timestamp = time.time()
ret, encrypted_msg = self.wxcpt.EncryptMsg(resp_data, req_nonce, timestamp)
if ret != 0:
return None
return encrypted_msg
def send_msg(self, content: str, to_user):
if time.time() > self.expiry_time:
self.update_access_token()
data = {
'touser': to_user,
'msgtype': 'text',
'agentid': self.agentid,
'text': {
'content': content
},
'safe': 0,
'enable_id_trans': 0,
'enable_duplicate_check': 0,
'duplicate_check_interval': 1800
}
logger.debug('--' * 20)
logger.debug('\n'+json.dumps(data, ensure_ascii=False))
url = f'https://qyapi.weixin.qq.com/cgi-bin/message/send?access_token={self.access_token}&debug=1'
response = requests.request('POST', url, data=json.dumps(data))
res_data = response.json()
if res_data['errcode'] != 0:
logger.error('send message error: %s', res_data['errmsg'])
logger.debug(response.text)
return response
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#########################################################################
# Author: jonyqin
# Created Time: Thu 11 Sep 2014 01:53:58 PM CST
# File Name: ierror.py
# Description:定义错误码含义
#########################################################################
WXBizMsgCrypt_OK = 0
WXBizMsgCrypt_ValidateSignature_Error = -40001
WXBizMsgCrypt_ParseXml_Error = -40002
WXBizMsgCrypt_ComputeSignature_Error = -40003
WXBizMsgCrypt_IllegalAesKey = -40004
WXBizMsgCrypt_ValidateCorpid_Error = -40005
WXBizMsgCrypt_EncryptAES_Error = -40006
WXBizMsgCrypt_DecryptAES_Error = -40007
WXBizMsgCrypt_IllegalBuffer = -40008
WXBizMsgCrypt_EncodeBase64_Error = -40009
WXBizMsgCrypt_DecodeBase64_Error = -40010
WXBizMsgCrypt_GenReturnXml_Error = -40011
# -*- coding: utf-8 -*-
import os
import logging
import xml.etree.cElementTree as ET
from concurrent.futures import ThreadPoolExecutor
import requests
from .client import Client
from ...bot import GPT3ChatBot, YDLGPTBot
logger = logging.getLogger(__name__)
thread_pool = ThreadPoolExecutor(max_workers=16)
class BaseServer():
def handle(self):
raise NotImplementedError
class EchoServer(BaseServer):
def __init__(self, config):
self.client = Client(config)
def handle_p2p(self, tree, request):
"""
单聊消息处理
"""
content = tree.find('Content').text
to_user = tree.find('ToUserName').text
from_user = tree.find('FromUserName').text
create_time = tree.find('CreateTime').text
# msg_type = tree.find('MsgType').text
resp_data = f'to_user: {to_user}, from_user: {from_user}, create_time: {create_time}, content: {content}'
self.client.send_msg(resp_data, from_user)
return ''
def handle(self, request):
msg = self.client.decrypt_msg(request)
if msg is None:
return 'None msg'
tree = ET.fromstring(msg)
if not tree:
return 'None'
return self.handle_p2p(tree, request)
class ChatGPTServer(BaseServer):
def __init__(self, config):
self.client = Client(config)
self.chatgpt_url = config['CUSTOM_CHATGPT_URL']
def ask(self, content, to_user, from_user):
headers = {'Content-Type': 'application/json'}
body = {'botDTO': {
'appId': 'yunjiaoqiwei',
'userId': to_user,
'conversationId': from_user,
'input': content,
'scene': 'psy-chat-bot-common'
}}
response = requests.post(self.chatgpt_url, headers=headers, json=body)
res_data = response.json()
logger.debug('chatgpt response', res_data)
reply_text = res_data['data']
self.client.send_msg(reply_text, from_user)
def handle_p2p(self, tree, request):
"""
单聊消息处理
"""
content = tree.find('Content').text
to_user = tree.find('ToUserName').text
from_user = tree.find('FromUserName').text
thread_pool.submit(self.ask, content, to_user, from_user)
return ''
def handle(self, request):
msg = self.client.decrypt_msg(request)
if msg is None:
return 'None msg'
tree = ET.fromstring(msg)
if not tree:
return 'None'
return self.handle_p2p(tree, request)
class YDLGPTServer(BaseServer):
def __init__(self, config):
self.client = Client(config)
self.chatbot = YDLGPTBot(url=config['CUSTOM_CHATGPT_URL'], app_id=config['YDL_APP_ID'], scene=config['YDL_SCENE'])
self.chatgpt_url = config['CUSTOM_CHATGPT_URL']
def ask(self, content, to_user, from_user):
chatgpt_res_text = self.chatbot.ask(content, sender_id=to_user, chat_id=from_user)
self.client.send_msg(chatgpt_res_text, from_user)
def handle_p2p(self, tree, request):
"""
单聊消息处理
"""
content = tree.find('Content').text
to_user = tree.find('ToUserName').text
from_user = tree.find('FromUserName').text
thread_pool.submit(self.ask, content, to_user, from_user)
return ''
def handle(self, request):
msg = self.client.decrypt_msg(request)
if msg is None:
return 'None msg'
tree = ET.fromstring(msg)
if not tree:
return 'None'
return self.handle_p2p(tree, request)
\ No newline at end of file
#!/bin/bash
# source activate aibot
# export GPT_ENGINE='text-davinci-003'
# nohup flask run --host '0.0.0.0' --port 4202 >> run.log 2>&1 &
nohup flask run --port 4202 >> run.log 2>&1 &
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment