Commit 448f26e8 by chai pengfei

Initial commit

parents
# OpenAIBot
将 OpenAI 的能力包装成各类机器人
# 安装
```
git clone https://github.com/pfchai/OpenAIBot.git
cd OpenAIBot
# 进入虚拟环境
pip install -r requirements.txt
```
# 启动服务
```
# 配置
cp config.yaml.example config.yaml
# 补充相关配置信息
vi config.yaml
# 启动服务
flask run
# 指定端口,外网可用
flask run --host 0.0.0.0 --port 8000
```
# ToDos
- [x] 基于 GPT3 对话机器人
- [ ] 基于 ChatGPT 对话机器人
- [x] 支持飞书机器人
- [x] 支持企业微信机器人
- [x] 服务支持配置多个机器人
# -*- coding: utf-8 -*-
import yaml
import logging
from logging.config import dictConfig
from flask import Flask
from flask import request, jsonify
from .platform.feishu import EchoServer as FeishuEchoServer
from .platform.feishu import ChatGPTServer as FeishuChatGPTServer
from .platform.feishu import YDLGPTServer as FeishuYDLGPTServer
from .platform.wework import EchoServer as WeworkEchoServer
from .platform.wework import ChatGPTServer as WeworkChatGPTServer
from .platform.wework import YDLGPTServer as WeworkYDLGPTServer
dictConfig({
'version': 1,
'formatters': {
'default': {
'format': '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
}
},
'handlers': {
'console': {
'class': 'logging.StreamHandler',
'level': 'DEBUG',
'formatter': 'default'
}
},
'root': {
'level': 'DEBUG',
'handlers': ['console'],
}
})
def create_server(app, feishu_bots, wework_bots):
if feishu_bots:
@app.route('/feishu/<name>', methods=['GET', 'POST'])
def feishu_server(name):
if name not in feishu_bots:
return 'url error'
if request.method == 'POST':
app.logger.info(request)
res = feishu_bots[name].handle(request)
app.logger.info(res)
if isinstance(res, str):
return res
else:
return jsonify(res)
return '<p>Hello, World!</p>'
if wework_bots:
@app.route('/wework/<name>', methods=['GET', 'POST'])
def wework_server(name):
if name not in wework_bots:
return 'url error'
if request.method == 'POST':
app.logger.info(request)
res = wework_bots[name].handle(request)
app.logger.info(res)
if isinstance(res, str):
return res
else:
return jsonify(res)
else:
return wework_bots[name].client.vertify_url(request)
def create_bots():
feishu_bots, wework_bots = {}, {}
with open('config.yaml') as f:
configs = yaml.load_all(f, Loader=yaml.FullLoader)
for config in configs:
if config['platform'] == 'feishu':
if config['bot'] not in ('echo', 'chatgpt', 'ydl_gpt'):
raise
if config['bot'] == 'echo':
feishu_bots[config['name']] = FeishuEchoServer(config)
if config['bot'] == 'chatgpt':
feishu_bots[config['name']] = FeishuChatGPTServer(config)
if config['bot'] == 'ydl_gpt':
feishu_bots[config['name']] = FeishuYDLGPTServer(config)
if config['platform'] == 'wework':
if config['bot'] not in ('echo', 'chatgpt', 'ydl_gpt'):
raise
if config['bot'] == 'echo':
wework_bots[config['name']] = WeworkEchoServer(config)
if config['bot'] == 'chatgpt':
wework_bots[config['name']] = WeworkChatGPTServer(config)
if config['bot'] == 'ydl_gpt':
wework_bots[config['name']] = WeworkYDLGPTServer(config)
return feishu_bots, wework_bots
app = Flask(__name__)
app.config['timeout'] = 120
app.logger.setLevel(logging.DEBUG)
feishu_bots, wework_bots = create_bots()
create_server(app, feishu_bots, wework_bots)
# -*- coding: utf-8 -*-
from .gpt3.gpt3_chat_bot import GPT3ChatBot
from .ydl_gpt import YDLGPTBot
\ No newline at end of file
# -*- coding: utf-8 -*-
import json
class ConversationManager():
def __init__(self):
self.conversations = {}
def add_conversation(self, key: str, history: list = []) -> None:
self.conversations[key] = history
def get_conversation(self, key: str) -> list:
if key not in self.conversations:
self.add_conversation(key)
return self.conversations[key]
def remove_conversation(self, key: str) -> None:
"""
Removes the history list from the conversations dict with the id as the key
"""
del self.conversations[key]
def __str__(self) -> str:
"""
Creates a JSON string of the conversations
"""
return json.dumps(self.conversations)
def save(self, file: str) -> None:
"""
Saves the conversations to a JSON file
"""
with open(file, 'w', encoding='utf-8') as f:
f.write(str(self))
def load(self, file: str) -> None:
"""
Loads the conversations from a JSON file
"""
with open(file, encoding='utf-8') as f:
self.conversations = json.loads(f.read())
# -*- coding: utf-8 -*-
import os
import re
import json
import logging
import openai
import tiktoken
from .prompt import Prompt
from .conversation_manager import ConversationManager
ENCODER = tiktoken.get_encoding('gpt2')
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
class GPT3ChatBot():
def __init__(self, api_key: str = None, engine: str = None, proxy: str = None):
openai.api_key = api_key or os.getenv('OPENAI_API_KEY')
openai.proxy = proxy or os.getenv("OPENAI_API_PROXY")
self.engine = engine or os.getenv('GPT_ENGINE') or 'text-davinci-003'
self.max_token = 4000
self.prompt = Prompt(max_token=self.max_token)
self.conversation_manager = ConversationManager()
def _get_completion(self, prompt: str, temperature: float = 0.5, stream: bool = False):
return openai.Completion.create(
engine=self.engine,
prompt=prompt,
temperature=temperature,
stop=['\n\n\n'],
stream=stream
)
def parse_completion(self, completion: dict) -> str:
if completion.get('choices') is None:
raise Exception('ChatGPT API returned no choices')
if len(completion['choices']) == 0:
raise Exception('ChatGPT API returned no choices')
if completion['choices'][0].get('text') is None:
raise Exception('ChatGPT API returned no text')
response = re.sub(r'<\|im_end\|>', '', completion['choices'][0]['text'])
completion['choices'][0]['text'] = response
return response
def ask(self, query, temperature: float = 0.5, conversation_id: str = None, user: str = 'User'):
if conversation_id is None:
history = []
else:
history = self.conversation_manager.get_conversation(conversation_id)
prompt, history = self.prompt.encode(query, history, user)
completion = None
try:
logger.debug('请求 OpenAI 接口 %s', prompt)
completion = openai.Completion.create(
engine=self.engine,
prompt=prompt,
temperature=temperature,
max_tokens=self.max_token - len(prompt),
stop=['\n\n\n'],
stream=False
)
logger.debug('OpenAI 接口返回 %s', json.dumps(completion, ensure_ascii=False))
except Exception as e:
logger.error('OpenAI 接口返回异常,%s', e, exc_info=True)
return ''
response = self.parse_completion(completion)
if conversation_id is not None:
history.append([query, response])
self.conversation_manager.add_conversation(conversation_id, history)
return response
class AsyncGPT3ChatBot(GPT3ChatBot):
async def ask(self, query, temperature: float = 0.5, conversation_id: str = None, user: str = 'User'):
if conversation_id is None:
history = []
else:
history = self.conversation_manager.get_conversation(conversation_id)
prompt, history = self.prompt.encode(query, history, user)
completion = None
try:
logger.debug('请求 OpenAI 接口 %s', prompt)
completion = openai.Completion.create(
engine=self.engine,
prompt=prompt,
temperature=temperature,
max_tokens=self.max_token - len(prompt),
stop=['\n\n\n'],
stream=False
)
logger.debug('OpenAI 接口返回 %s', json.dumps(completion, ensure_ascii=False))
except Exception as e:
logger.error('OpenAI 接口返回异常,%s', e, exc_info=True)
return ''
response = self.parse_completion(completion)
if conversation_id is not None:
history.append([query, response])
self.conversation_manager.add_conversation(conversation_id, history)
return response
# -*- coding: utf-8 -*-
from datetime import date
import tiktoken
ENCODER = tiktoken.get_encoding('gpt2')
HEADER_PROMPT_TEMPLATE = '''{prologue} {date}
{user}: Hello
ChatGPT: Hello! How can I help you today? <|im_end|>
'''
ROUND_PROMPT_TEMPLATE = '''{user}: {query}
ChatGPT: {response}<|im_end|>
'''
QUERY_TEMPLATE = '''{user}: {query}
ChatGPT:'''
class Prompt():
def __init__(self, max_token=4000, buffer: int = None, prologue=None):
buffer = buffer or 800
self.max_token = max_token - buffer
self.buffer = buffer
self.default_prologue = prologue or (
'You are ChatGPT, a large language model trained by OpenAI. '
'Respond conversationally. Do not answer as the user.'
)
def encode_header(self, user: str) -> str:
return HEADER_PROMPT_TEMPLATE.format(
prologue=self.default_prologue,
date=str(date.today()),
user=user,
)
def encode_history_round(self, query: str, response: str = None, user: str = 'User') -> str:
return ROUND_PROMPT_TEMPLATE.format(
user=user,
query=query,
response=response
)
def encode_history(self, history: list, user: str = 'User') -> str:
return ''.join([
self.encode_history_round(query, response, user)
for [query, response] in history
])
def encode_query(self, query: str, user: str = 'User') -> str:
return QUERY_TEMPLATE.format(
user=user,
query=query
)
def encode(self, query: str, history: list = [], user: str = 'User'):
prompt = self.encode_header(user)
prompt += self.encode_history(history, user)
prompt += self.encode_query(query, user)
prompt = ENCODER.encode(prompt)
if (len(prompt) > self.max_token) and (len(history) > 0):
return self.encode(query, history[1:], user)
return prompt, history
if __name__ == '__main__':
prompt = Prompt()
query = '帮我写一段快速排序的代码'
history = []
pt, _ = prompt.encode(query, history)
print(pt)
history = [['你好', '你好,请问有什么可以帮你的吗']]
pt, _ = prompt.encode(query, history)
print(len(pt))
history = [['你好', '你好,请问有什么可以帮你的吗'] for _ in range(200)]
pt, _ = prompt.encode(query, history)
print(len(pt))
# -*- coding: utf-8 -*-
import requests
import logging
class YDLGPTBot():
def __init__(self, url, app_id, scene) -> None:
self.url = url
self.app_id = app_id
self.scene = scene
self.headers = {'Content-Type': 'application/json'}
def ask(self, text, sender_id='', chat_id=''):
body = {
'dto': {
'appId': self.app_id,
'userId': sender_id,
'conversationId': chat_id,
'input': text,
'scene': self.scene,
}
}
logging.debug(body)
response = requests.post(self.url, headers=self.headers, json=body)
res_data = response.json()
logging.debug(res_data)
reply_text = res_data['data']['aiSide']
return reply_text
---
name: feishu1
platform: feishu
bot: echo
FEISHU_APP_ID:
FEISHU_APP_SECRET:
FEISHU_ENCRYPT_KEY:
---
name: feishu2
platform: feishu
bot: chatgpt
FEISHU_APP_ID:
FEISHU_APP_SECRET:
FEISHU_ENCRYPT_KEY:
# chatgpt 需要
OPENAI_API_KEY:
---
name: wework1
platform: wework
bot: echo
WEWORK_TOKEN:
WEWORK_ENCODING_AES_KEY:
WEWORK_CORP_ID:
WEWORK_SECRET:
WEWORK_AGENTID:
# -*- coding: utf-8 -*-
from .client import Client
from .server import BaseServer
from .server import EchoServer
from .server import ChatGPTServer
from .server import YDLGPTServer
# -*- coding: utf-8 -*-
import json
import time
import hashlib
import logging
import requests
from requests_toolbelt import MultipartEncoder
from .tool import AESCipher
logger = logging.getLogger(__name__)
class Client():
def __init__(self, config):
self.app_id = config['FEISHU_APP_ID']
self.app_secret = config['FEISHU_APP_SECRET']
self.encrypt_key = config['FEISHU_ENCRYPT_KEY']
self.is_valid_message = False
self.tenant_access_token = ''
self.expiry_time = time.time()
self.id_map = {}
self.message_ids = set()
self.update_tenant_access_token()
self.get_bot_info()
self.cipher = AESCipher(self.encrypt_key)
def _valid_message(self, request):
headers = request.headers
timestamp = headers['X-Lark-Request-Timestamp']
nonce = headers['X-Lark-Request-Nonce']
bytes_b1 = (timestamp + nonce + self.encrypt_key).encode('utf-8')
bytes_b = bytes_b1 + request.data
h = hashlib.sha256(bytes_b)
signature = h.hexdigest()
return request.headers['X-Lark-Signature'] == signature
def get_bot_info(self):
"""
获取机器人信息
"""
url = 'https://open.feishu.cn/open-apis/bot/v3/info'
headers = {
'Authorization': 'Bearer {}'.format(self.tenant_access_token),
'Content-Type': 'application/json; charset=utf-8'
}
response = requests.request('GET', url, headers=headers)
data = response.json()
logger.debug('获取bot信息成功')
logger.debug('--' * 20)
logger.debug(data)
self.bot_open_id = data.get('bot', {}).get('open_id')
return data
def update_tenant_access_token(self):
logger.debug('更新 tenant_access_token')
url = 'https://open.feishu.cn/open-apis/auth/v3/tenant_access_token/internal'
headers = {
'Content-Type': 'application/json; charset=utf-8'
}
body = {
'app_id': self.app_id,
'app_secret': self.app_secret,
}
response = requests.request('POST', url, headers=headers, data=json.dumps(body))
data = response.json()
self.expiry_time = data['expire'] + time.time() - 10
self.tenant_access_token = data['tenant_access_token']
logger.debug('tenant_access_token 更新成功')
return
def upload_img(self, image, image_file_type='file'):
""" 上传图片 """
url = 'https://open.feishu.cn/open-apis/im/v1/images'
if image_file_type == 'file':
form = {'image_type': 'message', 'image': (open(image, 'rb'))}
elif image_file_type == 'net':
_response = requests.get(image)
form = {'image_type': 'message', 'image': _response.content}
else:
logger.error('不支持的图片文件类型')
return
multi_form = MultipartEncoder(form)
headers = {
'Authorization': 'Bearer {}'.format(self.tenant_access_token),
'Content-Type': multi_form.content_type
}
response = requests.request("POST", url, headers=headers, data=multi_form)
logger.debug(response.content)
data = response.json()
return data.get('data', {}).get('image_key')
# 判断机器人是否被 @
def is_be_at(self, msg):
for mention in msg.get('mentions', []):
if mention['id']['open_id'] == self.bot_open_id:
return mention['key']
return False
def _reply(self, message_id, data):
if time.time() > self.expiry_time:
self.update_tenant_access_token()
url = 'https://open.feishu.cn/open-apis/im/v1/messages/{}/reply'.format(message_id)
headers = {
'Authorization': 'Bearer {}'.format(self.tenant_access_token),
'Content-Type': 'application/json; charset=utf-8'
}
logger.debug('--' * 20)
logger.debug('\n'+json.dumps(data, ensure_ascii=False))
response = requests.request('POST', url, headers=headers, data=json.dumps(data))
logger.debug(response.text)
return response
def parse_message(self, request):
content = request.json
if 'encrypt' in content:
decrypt_str = self.cipher.decrypt_string(content['encrypt'])
content = json.loads(decrypt_str)
return content
def reply_text(self, message_id, chat_id, text):
if time.time() > self.expiry_time:
self.update_tenant_access_token()
data = {
'receive_id': chat_id,
'content': json.dumps({'text': text}),
'msg_type': 'text',
}
return self._reply(message_id, data)
def reply_img(self, message_id, chat_id, image):
if time.time() > self.expiry_time:
self.update_tenant_access_token()
if image.startswith('http'):
image_file_type = 'net'
else:
image_file_type = 'file'
img_key = self.upload_img(image, image_file_type=image_file_type)
logger.info('image upload, key = %s', img_key)
data = {
'receive_id': chat_id,
'content': json.dumps({'image_key': img_key}),
'msg_type': 'image',
}
return self._reply(message_id, data)
def handle_p2p(self, message):
"""
单聊消息处理
"""
msg = message['event']['message']
msg_content = json.loads(msg['content'])
msg_text = msg_content['text']
self.reply_text(msg['message_id'], msg['chat_id'], '收到消息:' + msg_text)
return 'success'
def handle_group(self, message):
"""
处理群消息
"""
msg = message['event']['message']
msg_content = json.loads(msg['content'])
at_key = self.is_be_at(msg)
if at_key is False:
return 'ignore group message'
msg_text = msg_content['text'].replace(at_key + ' ', '')
self.reply_text(msg['message_id'], msg['chat_id'], '收到消息:' + msg_text)
return 'success'
def handle(self, request):
if self.is_valid_message:
if not self.valid(request):
logger.info('received message is invalid')
received_message = self.parse_message(request)
logger.debug(received_message)
if not received_message:
logger.error('received message is None')
# 飞书认证逻辑
if 'challenge' in received_message:
return {'challenge': received_message['challenge']}
r_event = received_message['event']
msg = r_event['message']
msg_id = msg['message_id']
# 重复消息忽略
if msg_id in self.message_ids:
return 'ignore'
else:
self.message_ids.add(msg_id)
# 判断是单聊还是群消息
if msg['chat_type'] == 'p2p':
return self.handle_p2p(received_message)
elif msg['chat_type'] == 'group':
return self.handle_group(received_message)
else:
return 'not support chat_type'
# -*- coding: utf-8 -*-
import os
import re
import json
import logging
import openai
from .client import Client
from ...bot import GPT3ChatBot, YDLGPTBot
logger = logging.getLogger()
class BaseServer():
def handle(self):
"""
处理消息
"""
raise NotImplementedError
class EchoServer(BaseServer):
def __init__(self, config):
self.is_valid_message = False
self.client = Client(config)
self.message_ids = set()
def handle_p2p(self, message):
"""
单聊消息处理
"""
msg = message['event']['message']
msg_content = json.loads(msg['content'])
msg_text = msg_content['text']
self.client.reply_text(msg['message_id'], msg['chat_id'], '收到消息:' + msg_text)
return 'success'
def handle_group(self, message):
"""
处理群消息
"""
msg = message['event']['message']
msg_content = json.loads(msg['content'])
at_key = self.client.is_be_at(msg)
if at_key is False:
return 'ignore group message'
msg_text = msg_content['text'].replace(at_key + ' ', '')
self.client.reply_text(msg['message_id'], msg['chat_id'], '收到消息:' + msg_text)
return 'success'
def handle(self, request):
if self.is_valid_message:
if not self.client.valid(request):
logger.info('received message is invalid')
received_message = self.client.parse_message(request)
logger.debug(received_message)
if not received_message:
logger.error('received message is None')
# 飞书认证逻辑
if 'challenge' in received_message:
return {'challenge': received_message['challenge']}
r_event = received_message['event']
msg = r_event['message']
msg_id = msg['message_id']
# 重复消息忽略
if msg_id in self.message_ids:
return 'ignore'
else:
self.message_ids.add(msg_id)
# 判断是单聊还是群消息
if msg['chat_type'] == 'p2p':
return self.handle_p2p(received_message)
elif msg['chat_type'] == 'group':
return self.handle_group(received_message)
else:
return 'not support chat_type'
class ChatGPTServer(BaseServer):
def __init__(self, config):
super().__init__()
self.chatbot = GPT3ChatBot(api_key=config['OPENAI_API_KEY'], engine=config.get('ENGINE'), proxy=config.get('PROXY'))
self.support_image_generation = True
self.is_valid_message = False
self.client = Client(config)
openai.api_key = config['OPENAI_API_KEY']
self.message_ids = set()
def ask(self, text, sender_id=None):
if text == 'Brazil':
self.chatbot.reset()
return 'bot is reset'
response = self.chatbot.ask(text, conversation_id=sender_id)
logger.debug('chatgpt response', response)
return response
def process(self, message, msg_text):
event = message['event']
msg = event['message']
sender_id = event['sender']['sender_id']['open_id']
if msg_text.startswith('作图 '):
if self.support_image_generation is False:
self.client.reply_text(msg['message_id'], msg['chat_id'], '暂不提供AI作图能力')
return 'success'
prompt = re.sub(r'^作图 ', '', msg_text)
try:
response = openai.Image.create(prompt=prompt, n=1, size='512x512')
except Exception as e:
logger.error(e)
print(e)
return 'generate image error'
img_url = response['data'][0]['url']
self.client.reply_img(msg['message_id'], msg['chat_id'], img_url)
return 'success'
chatgpt_res_text = '获取chatgpt回复消息失败'
try:
chatgpt_res_text = self.ask(msg_text, sender_id=sender_id)
except Exception as e:
logger.error(e)
print(e)
self.client.reply_text(msg['message_id'], msg['chat_id'], '服务出了点问题,请重试')
return 'error'
self.client.reply_text(msg['message_id'], msg['chat_id'], chatgpt_res_text)
return 'success'
def handle_p2p(self, message):
"""
单聊消息处理
"""
event = message['event']
msg = event['message']
msg_content = json.loads(msg['content'])
msg_text = msg_content['text']
return self.process(message, msg_text)
def handle_group(self, message):
"""
处理群消息
"""
event = message['event']
msg = event['message']
msg_content = json.loads(msg['content'])
at_key = self.client.is_be_at(msg)
if at_key is False:
return 'ignore group message'
msg_text = msg_content['text'].replace(at_key + ' ', '')
return self.process(message, msg_text)
def handle(self, request):
if self.is_valid_message:
if not self.client.valid(request):
logger.info('received message is invalid')
received_message = self.client.parse_message(request)
logger.debug(received_message)
if not received_message:
logger.error('received message is None')
# 飞书认证逻辑
if 'challenge' in received_message:
return {'challenge': received_message['challenge']}
r_event = received_message['event']
msg = r_event['message']
msg_id = msg['message_id']
# 重复消息忽略
if msg_id in self.message_ids:
return 'ignore'
else:
self.message_ids.add(msg_id)
# 判断是单聊还是群消息
if msg['chat_type'] == 'p2p':
return self.handle_p2p(received_message)
elif msg['chat_type'] == 'group':
return self.handle_group(received_message)
else:
return 'not support chat_type'
class YDLGPTServer(BaseServer):
def __init__(self, config):
super().__init__()
self.chatbot = YDLGPTBot(url=config['CUSTOM_CHATGPT_URL'], app_id=config['YDL_APP_ID'], scene=config['YDL_SCENE'])
self.client = Client(config)
self.is_valid_message = False
self.message_ids = set()
def process(self, message, msg_text):
event = message['event']
msg = event['message']
sender_id = event['sender']['sender_id']['open_id']
chatgpt_res_text = '获取chatgpt回复消息失败'
try:
chatgpt_res_text = self.chatbot.ask(msg_text, sender_id=sender_id, chat_id=msg['chat_id'])
except Exception as e:
logger.exception(e)
self.client.reply_text(msg['message_id'], msg['chat_id'], '服务出了点问题,请重试')
return 'error'
self.client.reply_text(msg['message_id'], msg['chat_id'], chatgpt_res_text)
return 'success'
def handle_p2p(self, message):
"""
单聊消息处理
"""
event = message['event']
msg = event['message']
msg_content = json.loads(msg['content'])
if 'text' in msg_content:
msg_text = msg_content['text']
elif 'content' in msg_content:
msgs = []
for items in msg_content['content']:
for item in items:
msgs.append(item['text'])
msg_text = '\n'.join(msgs)
else:
logger.error(msg_content)
raise
return self.process(message, msg_text)
def handle_group(self, message):
"""
处理群消息
"""
event = message['event']
msg = event['message']
msg_content = json.loads(msg['content'])
at_key = self.client.is_be_at(msg)
if at_key is False:
return 'ignore group message'
msg_text = msg_content['text'].replace(at_key + ' ', '')
return self.process(message, msg_text)
def handle(self, request):
if self.is_valid_message:
if not self.client.valid(request):
logger.info('received message is invalid')
received_message = self.client.parse_message(request)
logger.debug(received_message)
if not received_message:
logger.error('received message is None')
# 飞书认证逻辑
if 'challenge' in received_message:
return {'challenge': received_message['challenge']}
r_event = received_message['event']
msg = r_event['message']
msg_id = msg['message_id']
# 重复消息忽略
if msg_id in self.message_ids:
return 'ignore'
else:
self.message_ids.add(msg_id)
# 判断是单聊还是群消息
if msg['chat_type'] == 'p2p':
return self.handle_p2p(received_message)
elif msg['chat_type'] == 'group':
return self.handle_group(received_message)
else:
return 'not support chat_type'
# -*- coding: utf-8 -*-
import base64
import hashlib
from Crypto.Cipher import AES
class AESCipher(object):
def __init__(self, key):
self.bs = AES.block_size
self.key = hashlib.sha256(AESCipher.str_to_bytes(key)).digest()
@staticmethod
def str_to_bytes(data):
u_type = type(b"".decode('utf8'))
if isinstance(data, u_type):
return data.encode('utf8')
return data
@staticmethod
def _unpad(s):
return s[:-ord(s[len(s) - 1:])]
def decrypt(self, enc):
iv = enc[:AES.block_size]
cipher = AES.new(self.key, AES.MODE_CBC, iv)
return self._unpad(cipher.decrypt(enc[AES.block_size:]))
def decrypt_string(self, enc):
enc = base64.b64decode(enc)
return self.decrypt(enc).decode('utf8')
# 企业微信机器人
主要的几个文档
- [开发指南](https://developer.work.weixin.qq.com/document/path/90664)
- [发送应用消息](https://developer.work.weixin.qq.com/document/path/90236)
- [接收消息](https://developer.work.weixin.qq.com/document/path/90238)
几个需要注意的点
- 接受消息的url,必须是与企业相关的域名;
- 如果要发送消息,必须保证发送消息的服务器的ip在 “企业可信IP” 列表中;
#!/usr/bin/env python
# -*- encoding:utf-8 -*-
""" 对企业微信发送给企业后台的消息加解密示例代码.
@copyright: Copyright (c) 1998-2014 Tencent Inc.
"""
# ------------------------------------------------------------------------
import logging
import base64
import random
import hashlib
import time
import struct
from Crypto.Cipher import AES
import xml.etree.cElementTree as ET
import socket
from .import ierror
"""
关于Crypto.Cipher模块,ImportError: No module named 'Crypto'解决方案
请到官方网站 https://www.dlitz.net/software/pycrypto/ 下载pycrypto。
下载后,按照README中的“Installation”小节的提示进行pycrypto安装。
"""
class FormatException(Exception):
pass
def throw_exception(message, exception_class=FormatException):
"""my define raise exception function"""
raise exception_class(message)
class SHA1:
"""计算企业微信的消息签名接口"""
def getSHA1(self, token, timestamp, nonce, encrypt):
"""用SHA1算法生成安全签名
@param token: 票据
@param timestamp: 时间戳
@param encrypt: 密文
@param nonce: 随机字符串
@return: 安全签名
"""
try:
sortlist = [token, timestamp, nonce, encrypt]
sortlist.sort()
sha = hashlib.sha1()
sha.update("".join(sortlist).encode())
return ierror.WXBizMsgCrypt_OK, sha.hexdigest()
except Exception as e:
logger = logging.getLogger()
logger.error(e)
return ierror.WXBizMsgCrypt_ComputeSignature_Error, None
class XMLParse:
"""提供提取消息格式中的密文及生成回复消息格式的接口"""
# xml消息模板
AES_TEXT_RESPONSE_TEMPLATE = """<xml>
<Encrypt><![CDATA[%(msg_encrypt)s]]></Encrypt>
<MsgSignature><![CDATA[%(msg_signaturet)s]]></MsgSignature>
<TimeStamp>%(timestamp)s</TimeStamp>
<Nonce><![CDATA[%(nonce)s]]></Nonce>
</xml>"""
def extract(self, xmltext):
"""提取出xml数据包中的加密消息
@param xmltext: 待提取的xml字符串
@return: 提取出的加密消息字符串
"""
try:
xml_tree = ET.fromstring(xmltext)
encrypt = xml_tree.find("Encrypt")
return ierror.WXBizMsgCrypt_OK, encrypt.text
except Exception as e:
logger = logging.getLogger()
logger.error(e)
return ierror.WXBizMsgCrypt_ParseXml_Error, None
def generate(self, encrypt, signature, timestamp, nonce):
"""生成xml消息
@param encrypt: 加密后的消息密文
@param signature: 安全签名
@param timestamp: 时间戳
@param nonce: 随机字符串
@return: 生成的xml字符串
"""
resp_dict = {
'msg_encrypt': encrypt,
'msg_signaturet': signature,
'timestamp': timestamp,
'nonce': nonce,
}
resp_xml = self.AES_TEXT_RESPONSE_TEMPLATE % resp_dict
return resp_xml
class PKCS7Encoder():
"""提供基于PKCS7算法的加解密接口"""
block_size = 32
def encode(self, text):
""" 对需要加密的明文进行填充补位
@param text: 需要进行填充补位操作的明文
@return: 补齐明文字符串
"""
text_length = len(text)
# 计算需要填充的位数
amount_to_pad = self.block_size - (text_length % self.block_size)
if amount_to_pad == 0:
amount_to_pad = self.block_size
# 获得补位所用的字符
pad = chr(amount_to_pad)
return text + (pad * amount_to_pad).encode()
def decode(self, decrypted):
"""删除解密后明文的补位字符
@param decrypted: 解密后的明文
@return: 删除补位字符后的明文
"""
pad = ord(decrypted[-1])
if pad < 1 or pad > 32:
pad = 0
return decrypted[:-pad]
class Prpcrypt(object):
"""提供接收和推送给企业微信消息的加解密接口"""
def __init__(self, key):
# self.key = base64.b64decode(key+"=")
self.key = key
# 设置加解密模式为AES的CBC模式
self.mode = AES.MODE_CBC
def encrypt(self, text, receiveid):
"""对明文进行加密
@param text: 需要加密的明文
@return: 加密得到的字符串
"""
# 16位随机字符串添加到明文开头
text = text.encode()
text = self.get_random_str() + struct.pack("I", socket.htonl(len(text))) + text + receiveid.encode()
# 使用自定义的填充方式对明文进行补位填充
pkcs7 = PKCS7Encoder()
text = pkcs7.encode(text)
# 加密
cryptor = AES.new(self.key, self.mode, self.key[:16])
try:
ciphertext = cryptor.encrypt(text)
# 使用BASE64对加密后的字符串进行编码
return ierror.WXBizMsgCrypt_OK, base64.b64encode(ciphertext)
except Exception as e:
logger = logging.getLogger()
logger.error(e)
return ierror.WXBizMsgCrypt_EncryptAES_Error, None
def decrypt(self, text, receiveid):
"""对解密后的明文进行补位删除
@param text: 密文
@return: 删除填充补位后的明文
"""
try:
cryptor = AES.new(self.key, self.mode, self.key[:16])
# 使用BASE64对密文进行解码,然后AES-CBC解密
plain_text = cryptor.decrypt(base64.b64decode(text))
except Exception as e:
logger = logging.getLogger()
logger.error(e)
return ierror.WXBizMsgCrypt_DecryptAES_Error, None
try:
pad = plain_text[-1]
# 去掉补位字符串
# pkcs7 = PKCS7Encoder()
# plain_text = pkcs7.encode(plain_text)
# 去除16位随机字符串
content = plain_text[16:-pad]
xml_len = socket.ntohl(struct.unpack("I", content[: 4])[0])
xml_content = content[4: xml_len + 4]
from_receiveid = content[xml_len + 4:]
except Exception as e:
logger = logging.getLogger()
logger.error(e)
return ierror.WXBizMsgCrypt_IllegalBuffer, None
if from_receiveid.decode('utf8') != receiveid:
return ierror.WXBizMsgCrypt_ValidateCorpid_Error, None
return 0, xml_content
def get_random_str(self):
""" 随机生成16位字符串
@return: 16位字符串
"""
return str(random.randint(1000000000000000, 9999999999999999)).encode()
class WXBizMsgCrypt(object):
# 构造函数
def __init__(self, sToken, sEncodingAESKey, sReceiveId):
try:
self.key = base64.b64decode(sEncodingAESKey + "=")
assert len(self.key) == 32
except:
throw_exception("[error]: EncodingAESKey unvalid !", FormatException)
# return ierror.WXBizMsgCrypt_IllegalAesKey,None
self.m_sToken = sToken
self.m_sReceiveId = sReceiveId
# 验证URL
# @param sMsgSignature: 签名串,对应URL参数的msg_signature
# @param sTimeStamp: 时间戳,对应URL参数的timestamp
# @param sNonce: 随机串,对应URL参数的nonce
# @param sEchoStr: 随机串,对应URL参数的echostr
# @param sReplyEchoStr: 解密之后的echostr,当return返回0时有效
# @return:成功0,失败返回对应的错误码
def VerifyURL(self, sMsgSignature, sTimeStamp, sNonce, sEchoStr):
sha1 = SHA1()
ret, signature = sha1.getSHA1(self.m_sToken, sTimeStamp, sNonce, sEchoStr)
if ret != 0:
return ret, None
if not signature == sMsgSignature:
return ierror.WXBizMsgCrypt_ValidateSignature_Error, None
pc = Prpcrypt(self.key)
ret, sReplyEchoStr = pc.decrypt(sEchoStr, self.m_sReceiveId)
return ret, sReplyEchoStr
def EncryptMsg(self, sReplyMsg, sNonce, timestamp=None):
# 将企业回复用户的消息加密打包
# @param sReplyMsg: 企业号待回复用户的消息,xml格式的字符串
# @param sTimeStamp: 时间戳,可以自己生成,也可以用URL参数的timestamp,如为None则自动用当前时间
# @param sNonce: 随机串,可以自己生成,也可以用URL参数的nonce
# sEncryptMsg: 加密后的可以直接回复用户的密文,包括msg_signature, timestamp, nonce, encrypt的xml格式的字符串,
# return:成功0,sEncryptMsg,失败返回对应的错误码None
pc = Prpcrypt(self.key)
ret, encrypt = pc.encrypt(sReplyMsg, self.m_sReceiveId)
encrypt = encrypt.decode('utf8')
if ret != 0:
return ret, None
if timestamp is None:
timestamp = str(int(time.time()))
# 生成安全签名
sha1 = SHA1()
ret, signature = sha1.getSHA1(self.m_sToken, timestamp, sNonce, encrypt)
if ret != 0:
return ret, None
xmlParse = XMLParse()
return ret, xmlParse.generate(encrypt, signature, timestamp, sNonce)
def DecryptMsg(self, sPostData, sMsgSignature, sTimeStamp, sNonce):
# 检验消息的真实性,并且获取解密后的明文
# @param sMsgSignature: 签名串,对应URL参数的msg_signature
# @param sTimeStamp: 时间戳,对应URL参数的timestamp
# @param sNonce: 随机串,对应URL参数的nonce
# @param sPostData: 密文,对应POST请求的数据
# xml_content: 解密后的原文,当return返回0时有效
# @return: 成功0,失败返回对应的错误码
# 验证安全签名
xmlParse = XMLParse()
ret, encrypt = xmlParse.extract(sPostData)
if ret != 0:
return ret, None
sha1 = SHA1()
ret, signature = sha1.getSHA1(self.m_sToken, sTimeStamp, sNonce, encrypt)
if ret != 0:
return ret, None
if not signature == sMsgSignature:
return ierror.WXBizMsgCrypt_ValidateSignature_Error, None
pc = Prpcrypt(self.key)
ret, xml_content = pc.decrypt(encrypt, self.m_sReceiveId)
return ret, xml_content
# -*- coding: utf-8 -*-
from .client import Client
from .server import BaseServer
from .server import EchoServer
from .server import ChatGPTServer
from .server import YDLGPTServer
# -*- coding: utf-8 -*-
import os
import time
import json
import logging
import requests
from .WXBizMsgCrypt import WXBizMsgCrypt
logger = logging.getLogger(__name__)
class Client():
def __init__(self, config):
self.token = config['WEWORK_TOKEN']
self.encoding_aes_key = config['WEWORK_ENCODING_AES_KEY']
self.corp_id = config['WEWORK_CORP_ID']
self.secret = config['WEWORK_SECRET']
self.agentid = int(config['WEWORK_AGENTID'])
self.access_token = ''
self.expiry_time = time.time()
self.wxcpt = WXBizMsgCrypt(self.token, self.encoding_aes_key, self.corp_id)
self.update_access_token()
def vertify_url(self, request):
logger.info('request for vertify url')
logger.info(request.args)
msg_signature = request.args.get('msg_signature')
timestamp = request.args.get('timestamp')
nonce = request.args.get('nonce')
echo_str = request.args.get('echostr')
ret, echo_str = self.wxcpt.VerifyURL(msg_signature, timestamp, nonce, echo_str)
if ret != 0:
logger.error('ERR: Vertify URL ret ', ret)
return echo_str
def update_access_token(self):
logger.debug('更新 access_token')
url = f'https://qyapi.weixin.qq.com/cgi-bin/gettoken?corpid={self.corp_id}&corpsecret={self.secret}'
response = requests.request('GET', url)
data = response.json()
self.expiry_time = data['expires_in'] + time.time() - 10
self.access_token = data['access_token']
logger.debug('access_token 更新成功')
return
def decrypt_msg(self, request):
msg_signature = request.args.get('msg_signature')
timestamp = request.args.get('timestamp')
nonce = request.args.get('nonce')
ret, msg = self.wxcpt.DecryptMsg(request.data, msg_signature, timestamp, nonce)
if ret != 0:
logger.error('decrypt msg error')
return None
return msg
def encrypt_msg(self, resp_data, req_nonce, timestamp):
# timestamp = time.time()
ret, encrypted_msg = self.wxcpt.EncryptMsg(resp_data, req_nonce, timestamp)
if ret != 0:
return None
return encrypted_msg
def send_msg(self, content: str, to_user):
if time.time() > self.expiry_time:
self.update_access_token()
data = {
'touser': to_user,
'msgtype': 'text',
'agentid': self.agentid,
'text': {
'content': content
},
'safe': 0,
'enable_id_trans': 0,
'enable_duplicate_check': 0,
'duplicate_check_interval': 1800
}
logger.debug('--' * 20)
logger.debug('\n'+json.dumps(data, ensure_ascii=False))
url = f'https://qyapi.weixin.qq.com/cgi-bin/message/send?access_token={self.access_token}&debug=1'
response = requests.request('POST', url, data=json.dumps(data))
res_data = response.json()
if res_data['errcode'] != 0:
logger.error('send message error: %s', res_data['errmsg'])
logger.debug(response.text)
return response
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#########################################################################
# Author: jonyqin
# Created Time: Thu 11 Sep 2014 01:53:58 PM CST
# File Name: ierror.py
# Description:定义错误码含义
#########################################################################
WXBizMsgCrypt_OK = 0
WXBizMsgCrypt_ValidateSignature_Error = -40001
WXBizMsgCrypt_ParseXml_Error = -40002
WXBizMsgCrypt_ComputeSignature_Error = -40003
WXBizMsgCrypt_IllegalAesKey = -40004
WXBizMsgCrypt_ValidateCorpid_Error = -40005
WXBizMsgCrypt_EncryptAES_Error = -40006
WXBizMsgCrypt_DecryptAES_Error = -40007
WXBizMsgCrypt_IllegalBuffer = -40008
WXBizMsgCrypt_EncodeBase64_Error = -40009
WXBizMsgCrypt_DecodeBase64_Error = -40010
WXBizMsgCrypt_GenReturnXml_Error = -40011
# -*- coding: utf-8 -*-
import os
import logging
import xml.etree.cElementTree as ET
from concurrent.futures import ThreadPoolExecutor
import requests
from .client import Client
from ...bot import GPT3ChatBot, YDLGPTBot
logger = logging.getLogger(__name__)
thread_pool = ThreadPoolExecutor(max_workers=16)
class BaseServer():
def handle(self):
raise NotImplementedError
class EchoServer(BaseServer):
def __init__(self, config):
self.client = Client(config)
def handle_p2p(self, tree, request):
"""
单聊消息处理
"""
content = tree.find('Content').text
to_user = tree.find('ToUserName').text
from_user = tree.find('FromUserName').text
create_time = tree.find('CreateTime').text
# msg_type = tree.find('MsgType').text
resp_data = f'to_user: {to_user}, from_user: {from_user}, create_time: {create_time}, content: {content}'
self.client.send_msg(resp_data, from_user)
return ''
def handle(self, request):
msg = self.client.decrypt_msg(request)
if msg is None:
return 'None msg'
tree = ET.fromstring(msg)
if not tree:
return 'None'
return self.handle_p2p(tree, request)
class ChatGPTServer(BaseServer):
def __init__(self, config):
self.client = Client(config)
self.chatgpt_url = config['CUSTOM_CHATGPT_URL']
def ask(self, content, to_user, from_user):
headers = {'Content-Type': 'application/json'}
body = {'botDTO': {
'appId': 'yunjiaoqiwei',
'userId': to_user,
'conversationId': from_user,
'input': content,
'scene': 'psy-chat-bot-common'
}}
response = requests.post(self.chatgpt_url, headers=headers, json=body)
res_data = response.json()
logger.debug('chatgpt response', res_data)
reply_text = res_data['data']
self.client.send_msg(reply_text, from_user)
def handle_p2p(self, tree, request):
"""
单聊消息处理
"""
content = tree.find('Content').text
to_user = tree.find('ToUserName').text
from_user = tree.find('FromUserName').text
thread_pool.submit(self.ask, content, to_user, from_user)
return ''
def handle(self, request):
msg = self.client.decrypt_msg(request)
if msg is None:
return 'None msg'
tree = ET.fromstring(msg)
if not tree:
return 'None'
return self.handle_p2p(tree, request)
class YDLGPTServer(BaseServer):
def __init__(self, config):
self.client = Client(config)
self.chatbot = YDLGPTBot(url=config['CUSTOM_CHATGPT_URL'], app_id=config['YDL_APP_ID'], scene=config['YDL_SCENE'])
self.chatgpt_url = config['CUSTOM_CHATGPT_URL']
def ask(self, content, to_user, from_user):
chatgpt_res_text = self.chatbot.ask(content, sender_id=to_user, chat_id=from_user)
self.client.send_msg(chatgpt_res_text, from_user)
def handle_p2p(self, tree, request):
"""
单聊消息处理
"""
content = tree.find('Content').text
to_user = tree.find('ToUserName').text
from_user = tree.find('FromUserName').text
thread_pool.submit(self.ask, content, to_user, from_user)
return ''
def handle(self, request):
msg = self.client.decrypt_msg(request)
if msg is None:
return 'None msg'
tree = ET.fromstring(msg)
if not tree:
return 'None'
return self.handle_p2p(tree, request)
\ No newline at end of file
#!/bin/bash
# source activate aibot
# export GPT_ENGINE='text-davinci-003'
# nohup flask run --host '0.0.0.0' --port 4202 >> run.log 2>&1 &
nohup flask run --port 4202 >> run.log 2>&1 &
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment