420 lines
13 KiB
Python
420 lines
13 KiB
Python
import schedule
|
|
import time
|
|
import paho.mqtt.client as mqtt
|
|
from urllib.parse import urlparse
|
|
import requests
|
|
import json
|
|
import socket
|
|
import CloudFlare
|
|
import logging
|
|
import smtplib
|
|
import ssl
|
|
import io
|
|
|
|
logger = logging.getLogger('dns_updater')
|
|
logger.setLevel(logging.DEBUG)
|
|
formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s', datefmt='%m/%d/%Y:%H:%M:%S')
|
|
|
|
|
|
ch = logging.StreamHandler()
|
|
ch.setFormatter(formatter)
|
|
logger.addHandler(ch)
|
|
|
|
log_capture_string = io.StringIO()
|
|
ch2 = logging.StreamHandler(log_capture_string)
|
|
ch2.setFormatter(formatter)
|
|
logger.addHandler(ch2)
|
|
|
|
OPENWRT_PATH = "/cgi-bin/luci/"
|
|
INTERVAL = 60
|
|
userdata = dict()
|
|
mqtt_data = dict()
|
|
mail_data = dict()
|
|
openwrt = dict()
|
|
hosts = dict()
|
|
dns = dict()
|
|
firewall = dict()
|
|
|
|
def readFile():
|
|
with open("settings.json") as json_file:
|
|
global userdata
|
|
global hosts
|
|
global dns
|
|
global firewall
|
|
global mqtt_data
|
|
global mail_data
|
|
global openwrt
|
|
global OPENWRT_PATH
|
|
global INTERVAL
|
|
data = json.load(json_file)
|
|
|
|
userdata = {"apikey": data["apikey"], "username": data["user"]}
|
|
INTERVAL = data["interval"]
|
|
|
|
mqtt_data = data["mqtt"]
|
|
mail_data = data["mail"]
|
|
openwrt = data["openwrt"]
|
|
OPENWRT_PATH = "http://" + openwrt["host"] + OPENWRT_PATH
|
|
|
|
if "hosts" in data:
|
|
z = hosts.copy()
|
|
z.update(data["hosts"])
|
|
hosts = z
|
|
|
|
if "dns" in data:
|
|
z = dns.copy()
|
|
z.update(data["dns"])
|
|
dns = z
|
|
|
|
if "firewall" in data:
|
|
z = firewall.copy()
|
|
z.update(data["firewall"])
|
|
firewall = z
|
|
|
|
readFile()
|
|
|
|
def saveFile():
|
|
logger.info("Saving settings.json file...")
|
|
with open('settings.json', 'wt') as out:
|
|
global userdata
|
|
global hosts
|
|
global dns
|
|
global INTERVAL
|
|
obj = {
|
|
"apikey": userdata["apikey"],
|
|
"user": userdata["username"],
|
|
"interval": INTERVAL,
|
|
"dns": dns,
|
|
"hosts": hosts,
|
|
"firewall": firewall,
|
|
"mqtt": mqtt_data,
|
|
"openwrt": openwrt,
|
|
"mail": mail_data,
|
|
}
|
|
res = json.dump(obj, out, sort_keys=True, indent=4, separators=(',', ': '))
|
|
logger.info("File saved")
|
|
|
|
schedule.every(2).minutes.do(saveFile)
|
|
|
|
|
|
def getPublicIP():
|
|
r = requests.get("https://ipinfo.io/")
|
|
data = json.loads(r.text)
|
|
|
|
return data.get("ip")
|
|
|
|
def do_dns_update(cf, zone_name, zone_id, dns_name, ip_address, ip_address_type):
|
|
"""Cloudflare API code - example"""
|
|
logger.info("Update %s to %s" % (dns_name+"."+zone_name, ip_address))
|
|
|
|
try:
|
|
prefix = "" if dns_name == "@" else dns_name+"."
|
|
params = {'name':prefix+zone_name, 'match':'all', 'type':ip_address_type}
|
|
dns_records = cf.zones.dns_records.get(zone_id, params=params)
|
|
except CloudFlare.exceptions.CloudFlareAPIError as e:
|
|
logger.error('/zones/dns_records %s - %d %s - api call failed' % (dns_name, e, e))
|
|
|
|
updated = False
|
|
changed = False
|
|
|
|
# update the record - unless it's already correct
|
|
for dns_record in dns_records:
|
|
old_ip_address = dns_record['content']
|
|
old_ip_address_type = dns_record['type']
|
|
|
|
if ip_address_type not in ['A', 'AAAA']:
|
|
# we only deal with A / AAAA records
|
|
continue
|
|
|
|
if ip_address_type != old_ip_address_type:
|
|
# only update the correct address type (A or AAAA)
|
|
# we don't see this becuase of the search params above
|
|
logger.debug('IGNORED: %s %s ; wrong address family' % (dns_name, old_ip_address))
|
|
continue
|
|
|
|
if ip_address == old_ip_address:
|
|
logger.debug('UNCHANGED: %s %s' % (dns_name, ip_address))
|
|
updated = True
|
|
continue
|
|
|
|
# Yes, we need to update this record - we know it's the same address type
|
|
|
|
dns_record_id = dns_record['id']
|
|
dns_record = {
|
|
'name':dns_name,
|
|
'type':ip_address_type,
|
|
'content':ip_address
|
|
}
|
|
try:
|
|
dns_record = cf.zones.dns_records.put(zone_id, dns_record_id, data=dns_record)
|
|
changed = True
|
|
except CloudFlare.exceptions.CloudFlareAPIError as e:
|
|
logger.error('/zones.dns_records.put %s - %d %s - api call failed' % (dns_name, e, e))
|
|
logger.debug('UPDATED: %s %s -> %s' % (dns_name, old_ip_address, ip_address))
|
|
updated = True
|
|
|
|
if updated:
|
|
return changed
|
|
|
|
# no exsiting dns record to update - so create dns record
|
|
dns_record = {
|
|
'name':dns_name,
|
|
'type':ip_address_type,
|
|
'content':ip_address
|
|
}
|
|
try:
|
|
dns_record = cf.zones.dns_records.post(zone_id, data=dns_record)
|
|
changed = True
|
|
except CloudFlare.exceptions.CloudFlareAPIError as e:
|
|
logger.error('/zones.dns_records.post %s - %d %s - api call failed' % (dns_name, e, e))
|
|
logger.debug('CREATED: %s %s' % (dns_name, ip_address))
|
|
return changed
|
|
|
|
|
|
if openwrt["enabled"] == True:
|
|
jar = requests.cookies.RequestsCookieJar()
|
|
fw_login = requests.get(OPENWRT_PATH + "/rpc/auth", cookies=jar, json={
|
|
"id": 1,
|
|
"method": "login",
|
|
"params": [
|
|
openwrt["username"],
|
|
openwrt["password"]
|
|
]
|
|
})
|
|
|
|
if "result" not in fw_login.json() or fw_login.json()["result"] is None:
|
|
exit("Incorrect OpenWrt Login")
|
|
|
|
o = urlparse(OPENWRT_PATH)
|
|
jar.set('sysauth', fw_login.json()["result"], domain=o.netloc, path=o.path)
|
|
|
|
def setFirewall(section, key, new_ipv6):
|
|
global jar
|
|
|
|
logger.info("Updating firewall " + section + "." + key + "=" + new_ipv6)
|
|
|
|
if not new_ipv6:
|
|
logger.debug("Empty IPv6... Skipping...")
|
|
return
|
|
|
|
r = requests.get(OPENWRT_PATH + "/rpc/uci", cookies=jar, json={
|
|
"id": 1,
|
|
"method": "get_all",
|
|
"params": [
|
|
"firewall",
|
|
section
|
|
]
|
|
})
|
|
data = r.json()
|
|
|
|
if not "result" in data or data["result"] is None:
|
|
logger.warning("Unknown firewall section %s... Skipping..." % (section))
|
|
return
|
|
|
|
result = data["result"]
|
|
|
|
if not "family" in result or result["family"] is None:
|
|
logger.debug("No family set in firewall section %s... Skipping..." % (section))
|
|
return
|
|
|
|
if not key in result or result[key] is None:
|
|
logger.debug("No %s set in firewall section %s... Skipping..." % (key, section))
|
|
return
|
|
|
|
if result["family"] != "ipv6":
|
|
logger.debuginfo("Section %s is no ipv6... Skipping..." % (section))
|
|
return
|
|
|
|
if result[key] == new_ipv6:
|
|
logger.debug("Section %s has same ipv6... Skipping..." % (section))
|
|
return
|
|
|
|
r = requests.get(OPENWRT_PATH + "/rpc/uci", cookies=jar, json={
|
|
"id": 1,
|
|
"method": "set",
|
|
"params": [
|
|
"firewall",
|
|
section,
|
|
key,
|
|
new_ipv6
|
|
]
|
|
})
|
|
logger.debug("updated = %s" % r.json()["result"])
|
|
return True
|
|
|
|
def commitFirewall():
|
|
global jar
|
|
r = requests.get(OPENWRT_PATH + "/rpc/uci", cookies=jar, json={
|
|
"id": 1,
|
|
"method": "commit",
|
|
"params": [
|
|
"firewall",
|
|
]
|
|
})
|
|
|
|
def updateFirewall():
|
|
global hosts
|
|
global dns
|
|
global firewall
|
|
global commitFirewall
|
|
global setFirewall
|
|
logger.info("Updating firewall")
|
|
changed = False
|
|
for section in firewall:
|
|
for key in firewall[section]:
|
|
value = firewall[section][key]
|
|
ipv6 = hosts.get(value, dict()).get("ipv6", "")
|
|
changed = setFirewall(section, key, ipv6)
|
|
if changed == True:
|
|
logger.info("Firewall set... Commiting...")
|
|
commitFirewall()
|
|
logger.info("Firewall commited")
|
|
else:
|
|
logger.info("Nothing changed. Skipping firewall commit...")
|
|
return changed
|
|
|
|
def updateDNS():
|
|
global hosts
|
|
global userdata
|
|
global dns
|
|
global getPublicIP
|
|
global do_dns_update
|
|
|
|
logger.info("Updating DNS")
|
|
|
|
|
|
cf = CloudFlare.CloudFlare(email=userdata["username"], token=userdata["apikey"])
|
|
|
|
PUBLIC = getPublicIP()
|
|
if not PUBLIC or PUBLIC is None:
|
|
logger.error("EMPTY PUBLIC IP?!")
|
|
return False
|
|
|
|
changed = False
|
|
|
|
for domain in dns:
|
|
params = {'name':domain}
|
|
zones = cf.zones.get(params=params)
|
|
|
|
zone = zones[0]
|
|
zone_name = zone['name']
|
|
zone_id = zone['id']
|
|
for subdomain in dns[domain]:
|
|
ipv4_host = dns[domain][subdomain].get("ipv4", "")
|
|
ipv6_host = dns[domain][subdomain].get("ipv6", "")
|
|
|
|
if ipv4_host.lower() == "public":
|
|
ipv4 = PUBLIC
|
|
else:
|
|
ipv4 = hosts.get(ipv4_host, dict()).get("ipv4", "")
|
|
|
|
ipv6 = hosts.get(ipv6_host, dict()).get("ipv6", "")
|
|
|
|
dns_records = []
|
|
|
|
if ipv4:
|
|
if do_dns_update(cf, zone_name, zone_id, subdomain, ipv4, "A"):
|
|
changed = True
|
|
|
|
if ipv6:
|
|
if do_dns_update(cf, zone_name, zone_id, subdomain, ipv6, "AAAA"):
|
|
changed = True
|
|
return changed
|
|
|
|
def sendMail(mail_data, title, content):
|
|
logger.info("Sending mail to %s" % mail_data["receipent"])
|
|
# Create a secure SSL context
|
|
context = ssl.create_default_context()
|
|
|
|
# Try to log in to server and send email
|
|
try:
|
|
server = smtplib.SMTP(mail_data["host"], mail_data["port"])
|
|
server.ehlo() # Can be omitted
|
|
server.starttls(context=context) # Secure the connection
|
|
server.ehlo() # Can be omitted
|
|
server.login(mail_data["username"], mail_data["password"])
|
|
|
|
message = "From: " + mail_data["username"] + "\n"
|
|
message += "To: " + mail_data["receipent"] + "\n"
|
|
message += "Subject: "+title+"\n"
|
|
message += "\n"
|
|
message += content
|
|
|
|
server.sendmail(mail_data["username"], mail_data["receipent"], message)
|
|
logger.debug("Mail send")
|
|
except Exception as e:
|
|
logger.error(e)
|
|
finally:
|
|
server.quit()
|
|
|
|
def updateAll():
|
|
global updateDNS
|
|
global updateFirewall
|
|
global openwrt
|
|
global log_capture_string
|
|
global mail_data
|
|
|
|
log_capture_string.truncate(0)
|
|
log_capture_string.seek(0)
|
|
|
|
changed = False
|
|
|
|
if openwrt["enabled"] == True:
|
|
if updateFirewall():
|
|
changed = True
|
|
else:
|
|
logger.info("Firewall update is disabled")
|
|
|
|
if updateDNS():
|
|
changed = True
|
|
|
|
if mail_data and "enabled" in mail_data and mail_data["enabled"] == True:
|
|
if changed == True:
|
|
sendMail(mail_data, "Info: DNS/Firewall Address Server", log_capture_string.getvalue())
|
|
else:
|
|
logger.debug("No email. nothing changed")
|
|
else:
|
|
logger.info("Email is disabled")
|
|
|
|
schedule.every(INTERVAL).seconds.do(updateAll)
|
|
|
|
|
|
|
|
def on_connect(client, userdata, flags, rc):
|
|
logger.debug("Connected with result code "+str(rc))
|
|
if rc != 0:
|
|
logger.error("ERROR: Please check MQTT Login")
|
|
client.subscribe("network/+/hostname")
|
|
client.subscribe("network/+/ipv4")
|
|
client.subscribe("network/+/ipv6")
|
|
|
|
def on_message(client, userdata, msg):
|
|
global hosts
|
|
pl = str((msg.payload).decode("utf-8")).lower()
|
|
logger.debug(msg.topic + ": " + pl)
|
|
|
|
if msg.topic.startswith("network/") and (msg.topic.endswith("/ipv4") or msg.topic.endswith("/ipv6")):
|
|
host = msg.topic.split("/")[1]
|
|
|
|
if not host in hosts:
|
|
hosts[host] = dict()
|
|
|
|
hosts[host]["ipv4" if msg.topic.endswith("/ipv4") else "ipv6"] = pl
|
|
|
|
client = mqtt.Client()
|
|
client.username_pw_set(username=mqtt_data["username"], password=mqtt_data["password"])
|
|
client.connect(mqtt_data["host"], mqtt_data["port"], 60)
|
|
client.on_connect = on_connect
|
|
client.on_message = on_message
|
|
|
|
client.loop_start()
|
|
|
|
try:
|
|
while True:
|
|
schedule.run_pending()
|
|
time.sleep(1)
|
|
except KeyboardInterrupt as ex:
|
|
saveFile()
|
|
|
|
# https://htmlpreview.github.io/?https://raw.githubusercontent.com/openwrt/luci/master/documentation/api/modules/luci.model.uci.html#Cursor.get
|
|
# https://wiki.teltonika.lt/view/UCI_command_usage |