AwinSimpleP2P/coordinator.py
2025-05-31 04:32:13 +08:00

260 lines
10 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import socket
import json
import time
import logging
from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor
import threading
# 设置日志配置
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
class CoordinatorServer:
def __init__(self, host='0.0.0.0', port=5000):
"""
初始化协调服务器。
:param host: 服务器绑定的主机地址,默认为 '0.0.0.0'
:param port: 服务器监听的端口号,默认为 5000。
"""
self.host = host
self.port = port
self.clients = defaultdict(dict) # 客户端信息字典
self.providers = defaultdict(dict) # 服务提供端信息字典
self.services = defaultdict(tuple) # 服务名称与提供者信息的映射
self.udp_sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
self.udp_sock.bind((host, port))
self.udp_sock.settimeout(1) # 设置 UDP 套接字接收超时时间为 1 秒
self.running = True
self.executor = ThreadPoolExecutor(max_workers=10) # 创建线程池
logger.debug(f"协调服务器运行在 {host}:{port}")
# 启动定时清理任务
self.executor.submit(self.cleanup_expired_services)
def handle_register(self, data, addr):
"""
处理服务注册请求。
:param data: 包含服务注册信息的字典。
:param addr: 发送请求的客户端地址。
"""
try:
provider_id = data['provider_id']
services = data['services']
# 记录客户端信息
self.providers[provider_id] = {
'addr': addr,
'services': services,
'last_seen': time.time()
}
# 遍历 services 字典,记录每个服务的名称和端口
for service_name, service_port in services.items():
self.services[service_name] = (addr, service_port, provider_id)
logger.info(f"注册来自{addr}{services}")
# 直接回复注册成功消息
response = {'status': 'success', 'message': '服务注册成功'}
self.udp_sock.sendto(json.dumps(response).encode(), addr)
except Exception as e:
response = {'status': 'error', 'message': str(e)}
self.udp_sock.sendto(json.dumps(response).encode(), addr)
def handle_request(self, data, addr):
"""
处理服务请求。
:param data: 包含服务请求信息的字典。
:param addr: 发送请求的客户端地址。
"""
try:
service_name = data['service_name']
client_id = data['client_id']
if service_name not in self.services:
response = {'status': 'error', 'message': '服务未找到'}
self.udp_sock.sendto(json.dumps(response).encode(), addr)
# 记录请求客户端信息
self.clients[client_id] = {
'addr': addr,
'service_name': service_name,
'last_seen': time.time()
}
# 获取服务提供者的信息
provider_addr, internal_port, provider_id = self.services[service_name]
logger.debug(f"服务请求: {service_name} 来自 {addr}, 提供者 {provider_addr}")
response = {
'status': 'success',
'provider_addr': provider_addr,
'internal_port': internal_port,
'provider_id': provider_id
}
self.udp_sock.sendto(json.dumps(response).encode(), addr)
except Exception as e:
response = {'status': 'error', 'message': str(e)}
self.udp_sock.sendto(json.dumps(response).encode(), addr)
def handle_punch_request(self, data, addr):
"""
处理打洞请求。
:param data: 包含打洞请求信息的字典。
:param addr: 发送请求的客户端地址。
"""
try:
client_id = data['client_id']
provider_id = data['provider_id']
# 获取目标客户端信息
if provider_id not in self.providers:
response = {'status': 'error', 'message': '目标提供端未找到'}
self.udp_sock.sendto(json.dumps(response).encode(), addr)
provider_addr = self.providers[provider_id]['addr']
logger.info(f"打洞请求: {addr} -> {provider_id}")
# 通知双方对方的地址
self.udp_sock.sendto(json.dumps({
'action': 'punch_request',
'client_id': client_id,
'client_addr': addr
}).encode(), provider_addr)
self.udp_sock.sendto(json.dumps({
'status': 'success',
'provider_addr': provider_addr
}).encode(), addr)
except Exception as e:
logger.error(f"处理打洞请求时出错: {e}")
response = {'status': 'error', 'message': str(e)}
self.udp_sock.sendto(json.dumps(response).encode(), addr)
def handle_stop_provider(self, data, addr):
"""
处理停止服务请求。
:param data: 包含停止服务信息的字典。
:param addr: 发送请求的客户端地址。
"""
try:
service_name = data['service_name']
self.services.pop(service_name, None)
response = {'status': 'success', 'message': '服务停止成功'}
self.udp_sock.sendto(json.dumps(response).encode(), addr)
except Exception as e:
response = {'status': 'error', 'message': str(e)}
self.udp_sock.sendto(json.dumps(response).encode(), addr)
def handle_heartbeat(self, data, addr):
"""
处理心跳包。
:param data: 包含心跳信息的字典。
:param addr: 发送心跳的客户端地址。
"""
try:
provider_id = data['provider_id']
self.providers[provider_id]['last_seen'] = time.time()
response = {'status': 'success', 'message': '心跳更新成功'}
self.udp_sock.sendto(json.dumps(response).encode(), addr)
except Exception as e:
response = {'status': 'error', 'message': str(e)}
self.udp_sock.sendto(json.dumps(response).encode(), addr)
def cleanup_expired_services(self):
"""
定时清理过期服务。
每20秒检查一次移除超过30秒未更新心跳的服务。
"""
while self.running:
time.sleep(60) # 每60秒检查一次
current_time = time.time()
expired_providers = []
for provider_id, provider in self.providers.items():
if current_time - provider['last_seen'] > 60: # 心跳包最后更新时间大于30秒
expired_providers.append(provider_id)
logger.info(f"服务过期: {provider['addr']}")
for provider_id in expired_providers:
provider = self.providers[provider_id]
for service_name in provider['services'].keys(): # 使用新的 services 字段
self.services.pop(service_name, None)
self.providers.pop(provider_id, None)
def run(self):
"""
运行协调服务器。
"""
logger.info(f"协调服务器已启动,端口{self.udp_sock.getsockname()[1]},等待连接...")
action_handlers = {
'register': self.handle_register, # 服务注册处理行为
'request': self.handle_request, # 服务请求处理行为
'punch_request': self.handle_punch_request, # 打洞处理行为
'stop_provider': self.handle_stop_provider, # 停止服务行为
'heartbeat': self.handle_heartbeat # 心跳处理行为
}
while self.running:
try:
data, addr = self.udp_sock.recvfrom(4096)
try:
message = json.loads(data.decode())
action = message.get('action')
handler = action_handlers.get(action)
if handler: # 如果存在对应的处理行为,则执行它
self.executor.submit(handler, message, addr).result()
else: # 如果没有对应的处理行为,则返回错误响应
self.udp_sock.sendto(json.dumps({
'status': 'error',
'message': '无效操作'}).encode(), addr)
except json.JSONDecodeError:
self.udp_sock.sendto(json.dumps({
'status': 'error',
'message': '无效的JSON数据'
}).encode(), addr)
except socket.timeout: # 捕获 UDP 接收超时异常
pass # 不做任何处理,允许主线程继续执行
except Exception as e:
logger.debug(f"服务器错误: {str(e)}")
def start(self):
"""
启动协调服务器。
"""
# 创建线程运行 run 方法
server_thread = threading.Thread(target=self.run)
server_thread.start()
try:
# 主线程捕获键盘打断信号
while self.running:
time.sleep(1) # 防止主线程空转
except KeyboardInterrupt:
logger.info("检测到键盘打断,准备退出...")
self.running = False
# 通知所有提供端停止服务
for provider_id, provider_info in self.providers.items():
try:
self.udp_sock.sendto(json.dumps({
'action': 'stop_provider',
}).encode(), provider_info['addr'])
logger.info(f"已通知提供端 {provider_id} 停止服务")
except Exception as e:
logger.error(f"通知提供端 {provider_id} 停止服务时出错: {e}")
server_thread.join() # 等待服务器线程退出
# 关闭线程池和套接字
self.executor.shutdown()
self.udp_sock.close()
logger.info("协调服务器已安全退出")
if __name__ == '__main__':
server = CoordinatorServer()
server.start()