# Listing metadata from the original paste: 182 lines, 6.9 KiB, Python.
import hashlib
import hmac
import json
import os
import secrets
import socket
import threading
import time
from collections import defaultdict


def send_json(conn, data):
    """Serialize ``data`` as JSON and write every byte of it to ``conn``.

    The payload is sent as-is: no delimiter or length prefix is added, so
    message boundaries are not preserved on the TCP stream.
    """
    payload = json.dumps(data).encode()
    conn.sendall(payload)


def recv_json(conn):
    """Read one ``recv`` worth of bytes from ``conn`` and decode it as JSON.

    Returns None when the peer has closed the connection (``recv`` returned
    no bytes); otherwise returns whatever JSON value the bytes decode to.

    NOTE(review): one ``recv(4096)`` call is assumed to contain exactly one
    complete message. A message larger than 4096 bytes, one split across TCP
    segments, or one coalesced with the next message will raise
    ``json.JSONDecodeError``. Fixing that needs explicit framing (newline or
    length prefix), which would change the wire protocol for every client.
    """
    chunk = conn.recv(4096)
    return json.loads(chunk.decode()) if chunk else None


class Coordinator:
    """Rendezvous server that brokers NAT hole punching between clients.

    Clients prove knowledge of the admin password with a salted SHA-256
    challenge and receive a short-lived token. With that token they can
    register the services they provide, or ask for a provider of a named
    service. In that case the coordinator forwards a 'punch_request' to the
    provider and returns the provider's address to the requester. Every
    message is one JSON object sent with ``send_json`` / read with
    ``recv_json``.
    """

    # Lifetime of an issued token, in seconds (one hour).
    _TOKEN_TTL = 3600

    def __init__(self, host='0.0.0.0', port=5000, password='admin_password'):
        """Initialise server state; no sockets are opened until ``start``.

        Args:
            host: Address the listening socket binds to.
            port: TCP port the listening socket binds to.
            password: Admin password clients must prove knowledge of. The
                default keeps the value that used to be hard-coded; pass a
                real secret in any deployment.
        """
        self.host = host  # listen address
        self.port = port  # listen port
        self._password = password
        # Process-wide salt and matching hash, kept for compatibility.
        # Authentication uses a fresh salt per connection (see handle_client),
        # so these two attributes are not consulted by this class.
        self.salt = secrets.token_hex(8)
        self.stored_hash = self._expected_hash(self.salt)
        # token -> {'ip': client IP, 'expiry': epoch seconds}
        self.tokens = {}
        # token -> {'services': list, 'addr': (host, port), 'conn': socket}
        self.services = defaultdict(dict)
        # Active connection pool (currently unused).
        self.active_connections = {}
        # Guards self.tokens and self.services, which every client thread touches.
        self.lock = threading.Lock()

    def _expected_hash(self, salt):
        """Return the hex SHA-256 of ``salt + password`` that a client must send."""
        return hashlib.sha256((salt + self._password).encode()).hexdigest()

    def _hash_matches(self, salt, supplied):
        """Return True if ``supplied`` equals the expected hash for ``salt``.

        Non-string values (JSON numbers, lists, null) never match. The
        comparison is constant-time so response timing does not leak how many
        leading characters were correct.
        """
        if not isinstance(supplied, str):
            return False
        return hmac.compare_digest(supplied.encode(), self._expected_hash(salt).encode())

    def _issue_token(self, ip):
        """Create, store and return a new token bound to ``ip`` for one hour."""
        token = secrets.token_hex(8)
        with self.lock:
            self.tokens[token] = {
                'ip': ip,
                'expiry': time.time() + self._TOKEN_TTL,  # token expiry time
            }
        return token

    def handle_client(self, conn, addr):
        """Serve one client connection until it closes or sends bad data.

        Reads JSON requests in a loop and dispatches on their ``action`` field
        ('login', 'auth', 'register_service', 'request_service'). Requests with
        any other action are ignored without a reply. When the loop ends, the
        socket is closed and every token issued on this connection is revoked.
        Every service registration tied to the connection is also removed.
        """
        print(f"New connection from {addr}")
        issued = []  # every token handed out on this connection, for cleanup
        # Fresh salt per connection, so a hash captured on one connection is
        # useless on another.
        salt = secrets.token_hex(8)

        try:
            while True:
                data = recv_json(conn)
                # None = peer closed. Empty or non-object payloads are protocol
                # errors; dropping the connection also avoids calling .get on
                # a list/str/number.
                if not data or not isinstance(data, dict):
                    break

                action = data.get('action')

                # Login step: hand the admin account this connection's salt.
                if action == 'login':
                    if data.get('account') == 'admin':
                        send_json(conn, {'status': 'salt', 'salt': salt})
                    else:
                        send_json(conn, {'status': 'error', 'message': 'Invalid account'})

                # Auth step: check the salted hash, then issue a 1-hour token.
                elif action == 'auth':
                    if self._hash_matches(salt, data.get('hash')):
                        token = self._issue_token(addr[0])
                        # Record before replying so cleanup still happens if
                        # the send fails.
                        issued.append(token)
                        response = {'status': 'success', 'token': token, 'message': 'Login successful'}
                        send_json(conn, response)
                    else:
                        send_json(conn, {'status': 'error', 'message': 'Authentication failed'})

                elif action == 'register_service':
                    self._register_services(conn, addr, data)

                elif action == 'request_service':
                    self._request_service(conn, addr, data)
        except (OSError, ValueError):
            # OSError covers resets and broken pipes on this socket. ValueError
            # covers json.JSONDecodeError and UnicodeDecodeError raised by
            # recv_json. Either way the connection is simply closed below.
            pass
        finally:
            conn.close()
            print(f"Connection closed: {addr}")
            self._release(conn, issued)

    def _register_services(self, conn, addr, data):
        """Handle 'register_service': store the caller's service list under its token."""
        connector_token = data.get('token')
        if not self.validate_token(connector_token, addr[0]):
            send_json(conn, {'status': 'error', 'message': 'Invalid token'})
            return

        services = data.get('services', [])
        if not isinstance(services, list):
            # A string would make lookups do substring matching, and a number
            # would make them raise TypeError.
            send_json(conn, {'status': 'error', 'message': 'Invalid services'})
            return

        with self.lock:
            self.services[connector_token] = {
                'services': services,  # names of the services offered
                'addr': addr,          # client (host, port) as seen by us
                'conn': conn,          # socket used to push punch requests
            }
        send_json(conn, {'status': 'success', 'message': 'Services registered'})

    def _request_service(self, conn, addr, data):
        """Handle 'request_service': notify a provider and return its address.

        On success the requester's token is revoked (one rendezvous per token).
        """
        connector_token = data.get('token')
        if not self.validate_token(connector_token, addr[0]):
            send_json(conn, {'status': 'error', 'message': 'Invalid token'})
            return

        service_name = data.get('service_name')
        with self.lock:
            provider_token = self._find_provider_locked(service_name)
            # Copy under the lock: the provider's own thread deletes its entry
            # as soon as it disconnects.
            provider = None if provider_token is None else dict(self.services[provider_token])
        if provider is None:
            send_json(conn, {'status': 'error', 'message': 'Service not available'})
            return

        # Ask the provider to start NAT hole punching towards the requester.
        punch_msg = {
            'action': 'punch_request',
            'connector_addr': addr,
            'service_name': service_name,  # requested service name
        }
        # NOTE(review): this writes to the provider's socket from the
        # requester's thread while the provider's own handler thread may also
        # be writing to it; those sends are not serialized.
        try:
            send_json(provider['conn'], punch_msg)
        except OSError:
            # The provider's socket died after the lookup. Report that to the
            # requester instead of tearing down the requester's connection.
            send_json(conn, {'status': 'error', 'message': 'Service not available'})
            return

        send_json(conn, {
            'status': 'success',
            'provider_addr': provider['addr']
        })

        # Tokens are single-use for rendezvous: revoke it now.
        with self.lock:
            self.tokens.pop(connector_token, None)

    def _release(self, conn, issued):
        """Revoke tokens issued on a closed connection and drop its registrations."""
        with self.lock:
            for token in issued:
                self.tokens.pop(token, None)
                self.services.pop(token, None)
            # Registrations made over this socket with a token issued on
            # another connection would otherwise keep pointing at a closed
            # socket.
            stale = [t for t, info in self.services.items() if info.get('conn') is conn]
            for t in stale:
                del self.services[t]

    def validate_token(self, token, ip):
        """Return True if ``token`` exists, was issued to ``ip``, and is unexpired.

        Expired tokens are deleted when encountered so the table does not grow
        without bound. Non-string tokens (possible from JSON input, and
        unhashable when they are lists/objects) are rejected.
        """
        if not isinstance(token, str):
            return False
        with self.lock:
            token_info = self.tokens.get(token)
            if token_info is None:
                return False
            if token_info['expiry'] <= time.time():
                del self.tokens[token]
                return False
            return token_info['ip'] == ip

    def _find_provider_locked(self, service_name):
        """Return the first token whose registration lists ``service_name``, or None.

        Caller must hold ``self.lock``. Entries lacking a 'services' key (e.g.
        empty dicts created by ``defaultdict`` on a stray lookup) are skipped.
        """
        for token, info in self.services.items():
            if service_name in info.get('services', ()):
                return token
        return None

    def find_service_provider(self, service_name):
        """Return the token of a provider offering ``service_name``, or None.

        Takes the lock so the scan cannot race with registrations or
        disconnects in other client threads.
        """
        with self.lock:
            return self._find_provider_locked(service_name)

    def start(self):
        """Bind, listen, and serve each accepted connection on a daemon thread.

        Blocks forever. The listening socket is closed if the accept loop is
        left via an exception (e.g. KeyboardInterrupt).
        """
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as server:
            server.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            server.bind((self.host, self.port))
            server.listen(5)
            print(f"Coordinator listening on {self.host}:{self.port}")

            while True:
                conn, addr = server.accept()
                # One independent thread per connection.
                client_thread = threading.Thread(
                    target=self.handle_client,
                    args=(conn, addr),
                    daemon=True,
                )
                client_thread.start()


if __name__ == "__main__":
    # Script entry point: serve on the default host/port (0.0.0.0:5000)
    # until the process is stopped.
    coordinator = Coordinator()
    coordinator.start()