260 lines
10 KiB
Python
260 lines
10 KiB
Python
import socket
|
||
import json
|
||
import time
|
||
import logging
|
||
from collections import defaultdict
|
||
from concurrent.futures import ThreadPoolExecutor
|
||
import threading
|
||
|
||
# 设置日志配置
|
||
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
||
logger = logging.getLogger(__name__)
|
||
|
||
|
||
class CoordinatorServer:
|
||
def __init__(self, host='0.0.0.0', port=5000):
|
||
"""
|
||
初始化协调服务器。
|
||
:param host: 服务器绑定的主机地址,默认为 '0.0.0.0'。
|
||
:param port: 服务器监听的端口号,默认为 5000。
|
||
"""
|
||
self.host = host
|
||
self.port = port
|
||
self.clients = defaultdict(dict) # 客户端信息字典
|
||
self.providers = defaultdict(dict) # 服务提供端信息字典
|
||
self.services = defaultdict(tuple) # 服务名称与提供者信息的映射
|
||
self.udp_sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
|
||
self.udp_sock.bind((host, port))
|
||
self.udp_sock.settimeout(1) # 设置 UDP 套接字接收超时时间为 1 秒
|
||
self.running = True
|
||
self.executor = ThreadPoolExecutor(max_workers=10) # 创建线程池
|
||
logger.debug(f"协调服务器运行在 {host}:{port}")
|
||
|
||
# 启动定时清理任务
|
||
self.executor.submit(self.cleanup_expired_services)
|
||
|
||
def handle_register(self, data, addr):
|
||
"""
|
||
处理服务注册请求。
|
||
:param data: 包含服务注册信息的字典。
|
||
:param addr: 发送请求的客户端地址。
|
||
"""
|
||
try:
|
||
provider_id = data['provider_id']
|
||
services = data['services']
|
||
|
||
# 记录客户端信息
|
||
self.providers[provider_id] = {
|
||
'addr': addr,
|
||
'services': services,
|
||
'last_seen': time.time()
|
||
}
|
||
|
||
# 遍历 services 字典,记录每个服务的名称和端口
|
||
for service_name, service_port in services.items():
|
||
self.services[service_name] = (addr, service_port, provider_id)
|
||
|
||
logger.info(f"注册来自{addr}:{services}")
|
||
|
||
# 直接回复注册成功消息
|
||
response = {'status': 'success', 'message': '服务注册成功'}
|
||
self.udp_sock.sendto(json.dumps(response).encode(), addr)
|
||
except Exception as e:
|
||
response = {'status': 'error', 'message': str(e)}
|
||
self.udp_sock.sendto(json.dumps(response).encode(), addr)
|
||
|
||
def handle_request(self, data, addr):
|
||
"""
|
||
处理服务请求。
|
||
:param data: 包含服务请求信息的字典。
|
||
:param addr: 发送请求的客户端地址。
|
||
"""
|
||
try:
|
||
service_name = data['service_name']
|
||
client_id = data['client_id']
|
||
|
||
if service_name not in self.services:
|
||
response = {'status': 'error', 'message': '服务未找到'}
|
||
self.udp_sock.sendto(json.dumps(response).encode(), addr)
|
||
|
||
# 记录请求客户端信息
|
||
self.clients[client_id] = {
|
||
'addr': addr,
|
||
'service_name': service_name,
|
||
'last_seen': time.time()
|
||
}
|
||
|
||
# 获取服务提供者的信息
|
||
provider_addr, internal_port, provider_id = self.services[service_name]
|
||
|
||
logger.debug(f"服务请求: {service_name} 来自 {addr}, 提供者 {provider_addr}")
|
||
|
||
response = {
|
||
'status': 'success',
|
||
'provider_addr': provider_addr,
|
||
'internal_port': internal_port,
|
||
'provider_id': provider_id
|
||
}
|
||
self.udp_sock.sendto(json.dumps(response).encode(), addr)
|
||
except Exception as e:
|
||
response = {'status': 'error', 'message': str(e)}
|
||
self.udp_sock.sendto(json.dumps(response).encode(), addr)
|
||
|
||
def handle_punch_request(self, data, addr):
|
||
"""
|
||
处理打洞请求。
|
||
:param data: 包含打洞请求信息的字典。
|
||
:param addr: 发送请求的客户端地址。
|
||
"""
|
||
try:
|
||
client_id = data['client_id']
|
||
provider_id = data['provider_id']
|
||
|
||
# 获取目标客户端信息
|
||
if provider_id not in self.providers:
|
||
response = {'status': 'error', 'message': '目标提供端未找到'}
|
||
self.udp_sock.sendto(json.dumps(response).encode(), addr)
|
||
|
||
provider_addr = self.providers[provider_id]['addr']
|
||
|
||
logger.info(f"打洞请求: {addr} -> {provider_id}")
|
||
|
||
# 通知双方对方的地址
|
||
self.udp_sock.sendto(json.dumps({
|
||
'action': 'punch_request',
|
||
'client_id': client_id,
|
||
'client_addr': addr
|
||
}).encode(), provider_addr)
|
||
|
||
self.udp_sock.sendto(json.dumps({
|
||
'status': 'success',
|
||
'provider_addr': provider_addr
|
||
}).encode(), addr)
|
||
except Exception as e:
|
||
logger.error(f"处理打洞请求时出错: {e}")
|
||
response = {'status': 'error', 'message': str(e)}
|
||
self.udp_sock.sendto(json.dumps(response).encode(), addr)
|
||
|
||
def handle_stop_provider(self, data, addr):
|
||
"""
|
||
处理停止服务请求。
|
||
:param data: 包含停止服务信息的字典。
|
||
:param addr: 发送请求的客户端地址。
|
||
"""
|
||
try:
|
||
service_name = data['service_name']
|
||
self.services.pop(service_name, None)
|
||
response = {'status': 'success', 'message': '服务停止成功'}
|
||
self.udp_sock.sendto(json.dumps(response).encode(), addr)
|
||
except Exception as e:
|
||
response = {'status': 'error', 'message': str(e)}
|
||
self.udp_sock.sendto(json.dumps(response).encode(), addr)
|
||
|
||
def handle_heartbeat(self, data, addr):
|
||
"""
|
||
处理心跳包。
|
||
:param data: 包含心跳信息的字典。
|
||
:param addr: 发送心跳的客户端地址。
|
||
"""
|
||
try:
|
||
provider_id = data['provider_id']
|
||
self.providers[provider_id]['last_seen'] = time.time()
|
||
response = {'status': 'success', 'message': '心跳更新成功'}
|
||
self.udp_sock.sendto(json.dumps(response).encode(), addr)
|
||
except Exception as e:
|
||
response = {'status': 'error', 'message': str(e)}
|
||
self.udp_sock.sendto(json.dumps(response).encode(), addr)
|
||
|
||
def cleanup_expired_services(self):
|
||
"""
|
||
定时清理过期服务。
|
||
每20秒检查一次,移除超过30秒未更新心跳的服务。
|
||
"""
|
||
while self.running:
|
||
time.sleep(60) # 每60秒检查一次
|
||
current_time = time.time()
|
||
expired_providers = []
|
||
for provider_id, provider in self.providers.items():
|
||
if current_time - provider['last_seen'] > 60: # 心跳包最后更新时间大于30秒
|
||
expired_providers.append(provider_id)
|
||
logger.info(f"服务过期: {provider['addr']}")
|
||
for provider_id in expired_providers:
|
||
provider = self.providers[provider_id]
|
||
for service_name in provider['services'].keys(): # 使用新的 services 字段
|
||
self.services.pop(service_name, None)
|
||
self.providers.pop(provider_id, None)
|
||
|
||
def run(self):
|
||
"""
|
||
运行协调服务器。
|
||
"""
|
||
logger.info(f"协调服务器已启动,端口{self.udp_sock.getsockname()[1]},等待连接...")
|
||
action_handlers = {
|
||
'register': self.handle_register, # 服务注册处理行为
|
||
'request': self.handle_request, # 服务请求处理行为
|
||
'punch_request': self.handle_punch_request, # 打洞处理行为
|
||
'stop_provider': self.handle_stop_provider, # 停止服务行为
|
||
'heartbeat': self.handle_heartbeat # 心跳处理行为
|
||
}
|
||
|
||
while self.running:
|
||
try:
|
||
data, addr = self.udp_sock.recvfrom(4096)
|
||
try:
|
||
message = json.loads(data.decode())
|
||
action = message.get('action')
|
||
|
||
handler = action_handlers.get(action)
|
||
if handler: # 如果存在对应的处理行为,则执行它
|
||
self.executor.submit(handler, message, addr).result()
|
||
else: # 如果没有对应的处理行为,则返回错误响应
|
||
self.udp_sock.sendto(json.dumps({
|
||
'status': 'error',
|
||
'message': '无效操作'}).encode(), addr)
|
||
except json.JSONDecodeError:
|
||
self.udp_sock.sendto(json.dumps({
|
||
'status': 'error',
|
||
'message': '无效的JSON数据'
|
||
}).encode(), addr)
|
||
except socket.timeout: # 捕获 UDP 接收超时异常
|
||
pass # 不做任何处理,允许主线程继续执行
|
||
except Exception as e:
|
||
logger.debug(f"服务器错误: {str(e)}")
|
||
|
||
def start(self):
|
||
"""
|
||
启动协调服务器。
|
||
"""
|
||
# 创建线程运行 run 方法
|
||
server_thread = threading.Thread(target=self.run)
|
||
server_thread.start()
|
||
|
||
try:
|
||
# 主线程捕获键盘打断信号
|
||
while self.running:
|
||
time.sleep(1) # 防止主线程空转
|
||
except KeyboardInterrupt:
|
||
logger.info("检测到键盘打断,准备退出...")
|
||
self.running = False
|
||
|
||
# 通知所有提供端停止服务
|
||
for provider_id, provider_info in self.providers.items():
|
||
try:
|
||
self.udp_sock.sendto(json.dumps({
|
||
'action': 'stop_provider',
|
||
}).encode(), provider_info['addr'])
|
||
logger.info(f"已通知提供端 {provider_id} 停止服务")
|
||
except Exception as e:
|
||
logger.error(f"通知提供端 {provider_id} 停止服务时出错: {e}")
|
||
server_thread.join() # 等待服务器线程退出
|
||
|
||
# 关闭线程池和套接字
|
||
self.executor.shutdown()
|
||
self.udp_sock.close()
|
||
logger.info("协调服务器已安全退出")
|
||
|
||
|
||
if __name__ == '__main__':
|
||
server = CoordinatorServer()
|
||
server.start()
|