Add passwd_file as last argument

This commit is contained in:
2024-06-09 22:57:19 +02:00
parent 74d1d7771c
commit 34ba10be03
2 changed files with 62 additions and 30 deletions

View File

@@ -242,6 +242,14 @@ def main():
logging.error('Need passwd-file as fourth argument') logging.error('Need passwd-file as fourth argument')
sys.exit() sys.exit()
#Make sure the file specified is to be found..
if not passwd_file.startswith('/'):
passwd_file = os.path.abspath(passwd_file)
if not os.path.isfile(passwd_file):
logging.error('Cannot find passwd-file %s', passwd_file)
sys.exit()
rrdupdater = UpdateRRD(rrdfile) rrdupdater = UpdateRRD(rrdfile)
client = routerstats_client(client_host, client_port, passwd_file) client = routerstats_client(client_host, client_port, passwd_file)
while True: while True:

View File

@@ -221,36 +221,44 @@ def load_start_pos(logfile):
return tmp_start_pos return tmp_start_pos
return None return None
def check_login(answer):
with open('passwd.client', 'r') as passwd_file:
passwd = passwd_file.readline()
passwd = passwd.rstrip() #Remove that newline
try:
answer = answer.decode('utf-8')
except UnicodeDecodeError as error:
logging.error('Could not decode %s as unicode: %s', answer, str(error))
if answer == passwd:
return True
return False
class RequestHandler(socketserver.BaseRequestHandler): class RequestHandler(socketserver.BaseRequestHandler):
'''derived BaseRequestHandler''' '''derived BaseRequestHandler'''
def login(self): def set_passwd_file(self, filename):
self.request.send(b'Hello') self.passwd_file = filename
def check_login(self, answer):
with open(self.passwd_file, 'r') as passwd_file:
passwd = passwd_file.readline()
passwd = passwd.rstrip() #Remove that newline
try: try:
answer = self.request.recv(1024) answer = answer.decode('utf-8')
except TimeoutError: except UnicodeDecodeError as error:
#Client did not even bother to reply... logging.error('Could not decode %s as unicode: %s', answer, str(error))
logging.warning('Timed out during auth') if answer == passwd:
self.request.send(b'timeout') return True
return return False
if not check_login(answer):
logging.warning('Wrong passphrase') def login(self):
self.request.send(b'auth error') try:
return self.request.send(b'Hello')
self.request.send(b'Welcome') try:
logging.info('Client ' + str(self.client_address[0]) + ' logged in') answer = self.request.recv(1024)
return True except TimeoutError:
#Client did not even bother to reply...
logging.warning('Timed out during auth')
self.request.send(b'timeout')
return
if not self.check_login(answer):
logging.warning('Wrong passphrase')
self.request.send(b'auth error')
return
self.request.send(b'Welcome')
logging.info('Client ' + str(self.client_address[0]) + ' logged in')
return True
except BrokenPipeError:
#Client gone and came back, bad idea.
logging.warning('Broken pipe, closing socket')
return False
def handle(self): def handle(self):
logging.info('Connected to ' + str(self.client_address[0])) logging.info('Connected to ' + str(self.client_address[0]))
@@ -305,7 +313,7 @@ class RequestHandler(socketserver.BaseRequestHandler):
#Long time, no see, time to pingpong the client:) #Long time, no see, time to pingpong the client:)
if self.ping_client() != True: if self.ping_client() != True:
break break
logging.debug('Request abandoned') logging.info('Request abandoned')
def send(self, tosend): def send(self, tosend):
'''Wrap sendall''' '''Wrap sendall'''
@@ -352,7 +360,7 @@ class RequestHandler(socketserver.BaseRequestHandler):
logging.error('Peer gone?: ' + str(error)) logging.error('Peer gone?: ' + str(error))
return False return False
def socket_server(file_parser_result_queue, overflowqueue, socket_server_signal_queue): def socket_server(file_parser_result_queue, overflowqueue, socket_server_signal_queue, passwd_file):
'''Socket server sending whatever data is in the queue to any client connecting''' '''Socket server sending whatever data is in the queue to any client connecting'''
#Multiple connections here is probably a horrible idea:) #Multiple connections here is probably a horrible idea:)
setproctitle('routerstats-collector socket_server') setproctitle('routerstats-collector socket_server')
@@ -365,6 +373,7 @@ def socket_server(file_parser_result_queue, overflowqueue, socket_server_signal_
server.timeout = 1 server.timeout = 1
with server: with server:
server.RequestHandlerClass.set_queue(server.RequestHandlerClass, file_parser_result_queue, overflowqueue, socket_server_signal_queue) server.RequestHandlerClass.set_queue(server.RequestHandlerClass, file_parser_result_queue, overflowqueue, socket_server_signal_queue)
server.RequestHandlerClass.set_passwd_file(server.RequestHandlerClass, passwd_file)
logging.info('Socket up at ' + host + ':' + str(port)) logging.info('Socket up at ' + host + ':' + str(port))
while True: while True:
try: try:
@@ -410,6 +419,21 @@ def main():
if os.path.isfile(file_to_follow) is False: if os.path.isfile(file_to_follow) is False:
logging.error('Could not find file ' + file_to_follow) logging.error('Could not find file ' + file_to_follow)
sys.exit() sys.exit()
try:
passwd_file = sys.argv[2]
except IndexError:
logging.error('Need passwd-file as second argument')
sys.exit()
if not passwd_file.startswith('/'):
passwd_file = os.path.abspath(passwd_file)
logging.debug('Setting passwd-file to %s', passwd_file)
if not os.path.isfile(passwd_file):
logging.error('Could not find file %s', passwd_file)
sys.exit()
file_parser_result_queue = Queue() file_parser_result_queue = Queue()
file_parser_signal_queue = Queue() file_parser_signal_queue = Queue()
overflowqueue = Queue() overflowqueue = Queue()
@@ -433,7 +457,7 @@ def main():
#This means any "malicious" connections will wipe the history #This means any "malicious" connections will wipe the history
#We're fine with this #We're fine with this
socket_server_process = Process(target=socket_server, daemon=True, args=(file_parser_result_queue, overflowqueue, socket_server_signal_queue)) socket_server_process = Process(target=socket_server, daemon=True, args=(file_parser_result_queue, overflowqueue, socket_server_signal_queue, passwd_file))
socket_server_process.start() socket_server_process.start()
logging.debug('Socket server started as pid ' + str(socket_server_process.pid)) logging.debug('Socket server started as pid ' + str(socket_server_process.pid))
started_processes.append((socket_server_process, socket_server_signal_queue)) started_processes.append((socket_server_process, socket_server_signal_queue))