# Source code for metalpipe.node
"""
Node module
===========
The ``node`` module contains the ``MetalNode`` class, which is the foundation
for MetalPipe.
"""
import time
import datetime
import uuid
import importlib
import logging
import os
import threading
import pprint
import sys
import copy
import random
import functools
import csv
import re
import io
import yaml
import types
import inspect
import prettytable
import requests
import graphviz
from timed_dict.timed_dict import TimedDict
from metalpipe.message.batch import BatchStart, BatchEnd
from metalpipe.message.message import MetalPipeMessage
from metalpipe.node_queue.queue import MetalPipeQueue
from metalpipe.message.canary import Canary
from metalpipe.utils.set_attributes import set_kwarg_attributes
from metalpipe.utils.data_structures import Row, MySQLTypeSystem
from metalpipe.utils import data_structures as ds
# from metalpipe.metalpipe_recorder import RedisFixturizer
from metalpipe.utils.helpers import (
load_function,
replace_by_path,
remap_dictionary,
set_value,
get_value,
to_bool,
aggregate_values,
)
# Default capacity handed to each ``MetalPipeQueue`` created by ``add_edge``;
# can be overridden with the ``DEFAULT_MAX_QUEUE_SIZE`` environment variable.
DEFAULT_MAX_QUEUE_SIZE = int(os.environ.get("DEFAULT_MAX_QUEUE_SIZE", 128))
# Seconds the monitor thread sleeps between polls of the pipeline.
MONITOR_INTERVAL = 1
# The status table is logged once every this many monitor polls.
STATS_COUNTER_MODULO = 4
# Logjam score at or above which a node's name is highlighted in the table.
LOGJAM_THRESHOLD = 0.25
# Polling delay, in seconds, used by ``wait_for_pipeline_finish``.
SHORT_DELAY = 0.1
# When true, ``global_start`` runs its experimental Prometheus initialization.
PROMETHEUS = False
def no_op(*_args, **_kwargs):
    """
    Accept any arguments, do nothing, and return ``None``.

    Serves as the default ``get_runtime_attrs`` callable.
    """
class bcolors:
    """
    This class holds the values for the various colors that are used in the
    tables that monitor the status of the nodes.

    Each value is an ANSI escape sequence. ``thread_monitor`` prefixes a
    table cell with one of these and appends ``ENDC`` to reset formatting.
    """

    HEADER = "\033[95m"  # bright magenta
    OKBLUE = "\033[94m"  # bright blue
    OKGREEN = "\033[92m"  # bright green; used for the "success" status
    WARNING = "\033[93m"  # bright yellow; used for the "running" status
    FAIL = "\033[91m"  # bright red; used for "error" status and logjams
    ENDC = "\033[0m"  # reset all attributes
    BOLD = "\033[1m"  # bold
    UNDERLINE = "\033[4m"  # underline
class MetalNode:
    """
    The foundational class of `MetalPipe`. This class is inherited by all
    nodes in a computation graph.

    Order of operations:

    1. Child class ``__init__`` function
    2. ``MetalNode`` ``__init__`` function
    3. ``preflight_function`` (the ``get_runtime_attrs`` callable specified
       in the initialization params; it is called at the top of ``start``)
    4. ``setup``
    5. start

    These methods have the following intended uses:

    1. ``__init__`` Sets attribute values and calls the ``MetalNode``
       ``__init__`` method.
    2. ``get_runtime_attrs`` Sets any attribute values that are to be
       determined at runtime, e.g. by checking environment variables or
       reading values from a database. The ``get_runtime_attrs`` should
       return a dictionary of attributes -> values, or else ``None``.
    3. ``setup`` Sets the state of the ``MetalNode`` and/or creates any
       attributes that require information available only at runtime.

    Args:
        send_batch_markers: If ``True``, then a ``BatchStart`` marker will
            be sent when a new input is received, and a ``BatchEnd`` will be
            sent after the input has been processed. The intention is that a
            number of items will be emitted for each input received. For
            example, we might emit a table row-by-row for each input.
            NOTE(review): ``MetalNode.__init__`` does not declare this
            parameter (``StreamMySQLTable.__init__`` does) -- confirm where
            it is consumed.
        get_runtime_attrs: A function that returns a dictionary-like object.
            The keys and values will be saved to this ``MetalNode`` object's
            attributes. The function is executed one time, upon starting the
            node.
        get_runtime_attrs_args: A tuple of arguments to be passed to the
            ``get_runtime_attrs`` function upon starting the node.
        get_runtime_attrs_kwargs: A dictionary of kwargs passed to the
            ``get_runtime_attrs`` function.
        runtime_attrs_destinations: If set, this is a dictionary mapping
            the keys returned from the ``get_runtime_attrs`` function to the
            names of the attributes to which the values will be saved.
        throttle: For each input received, a delay of ``throttle`` seconds
            will be added.
        keep_alive: If ``True``, keep the node's thread alive after
            everything has been processed.
        name: The name of the node. Defaults to a randomly generated hash.
            Note that this hash is not consistent from one run to the next.
        input_mapping: When the node receives a dictionary-like object,
            this dictionary will cause the keys of the dictionary to be
            remapped to new keys.
        retain_input: If ``True``, then combine the dictionary-like input
            with the output. If keys clash, the output value will be kept.
        input_message_keypath: Read the value in this keypath as the content
            of the incoming message.
        key: ``None``, a string, or a list. Selects which part of the
            incoming message the ``__message__`` property returns.
        output_key: When ``None``, ``_process_item`` requires every output
            to be a dictionary (or a ``NothingToSeeHere``).
        max_errors: ``_process_item`` terminates the pipeline once the
            node's error count exceeds this value.
        max_messages_received: If not ``None``, ``start`` marks the node as
            finished once more than this many messages have been received.
        post_process_function: Name of a function applied to each output.
            Components are separated by double underscores; all but the
            last name the module to import.
        post_process_keypath: Dot-separated keypath of the output value that
            ``post_process_function`` reads and replaces.
        post_process_function_kwargs: Extra kwargs passed to
            ``post_process_function``.
        break_test: Passed to ``load_function``. The loaded callable is
            called after each output with ``output_message`` and
            ``input_message``, and its result is stored as the node's
            ``finished`` flag.
    """
def __init__(
    self,
    *args,
    batch=False,
    get_runtime_attrs=no_op,
    get_runtime_attrs_args=None,
    get_runtime_attrs_kwargs=None,
    runtime_attrs_destinations=None,
    input_mapping=None,
    retain_input=True,
    throttle=0,
    keep_alive=True,
    max_errors=0,
    max_messages_received=None,
    name=None,
    input_message_keypath=None,
    key=None,
    messages_received_counter=0,
    prefer_existing_value=False,
    messages_sent_counter=0,
    post_process_function=None,
    post_process_keypath=None,
    summary="",
    fixturize=False,
    post_process_function_kwargs=None,
    output_key=None,
    break_test=None,
    **kwargs
):
    """
    Initialize the bookkeeping state shared by every node.

    Positional ``args``, ``batch`` and any extra ``kwargs`` are accepted
    but not used here. ``None`` defaults for container arguments are
    replaced by fresh empty containers so instances never share state.

    Args:
        post_process_function: Name of a function to run on each output.
            A bare name is looked up in this module's globals; otherwise
            the name is split on ``"__"``, the leading components form the
            module path and the last one is the attribute to fetch.
        post_process_keypath: Dot-separated path, stored as a list.
        break_test: If given, resolved with ``load_function``.
        fixturize: If true, a ``RedisFixturizer`` is created.

    Raises:
        KeyError: If a bare ``post_process_function`` name is not a
            module-level global.
        ImportError: If the post-process module, or the fixturizer module,
            cannot be imported.
    """
    self.name = name or uuid.uuid4().hex
    self.input_mapping = input_mapping or {}
    self.input_queue_list = []
    self.output_queue_list = []
    self.input_node_list = []
    self.input_message_keypath = input_message_keypath or []
    self.output_node_list = []
    self.max_messages_received = max_messages_received
    self.global_dict = None  # We'll add a dictionary upon startup
    self.thread_dict = {}
    self.kill_thread = False
    self.prefer_existing_value = prefer_existing_value
    self.accumulator = {}
    self.output_key = output_key
    self.fixturize = fixturize
    self.keep_alive = keep_alive
    # Keep the input dictionary and send it downstream
    self.retain_input = retain_input
    if break_test is not None:
        self.break_test = load_function(break_test)
    else:
        self.break_test = None
    self.throttle = throttle
    self.get_runtime_attrs = get_runtime_attrs
    self.get_runtime_attrs_args = get_runtime_attrs_args or tuple()
    self.cleanup_called = False
    self.get_runtime_attrs_kwargs = get_runtime_attrs_kwargs or {}
    self.runtime_attrs_destinations = runtime_attrs_destinations or {}
    self.key = key
    self.messages_received_counter = messages_received_counter
    self.messages_sent_counter = messages_sent_counter
    self.instantiated_at = datetime.datetime.now()
    self.started_at = None
    self.stopped_at = None
    self.finished = False
    self.error_counter = 0
    self.status = "stopped"  # running, error, success
    self.max_errors = max_errors
    # Function to be run on result
    self.post_process_function_name = post_process_function
    self.post_process_function_kwargs = post_process_function_kwargs or {}
    self.summary = summary
    self.prometheus_objects = None
    self.logjam_score = {"polled": 0.0, "logjam": 0.0}

    # Get post process function if one is named
    if self.post_process_function_name is not None:
        components = self.post_process_function_name.split("__")
        if len(components) == 1:
            self.post_process_function = globals()[components[0]]
        else:
            module = importlib.import_module(".".join(components[:-1]))
            self.post_process_function = getattr(module, components[-1])
    else:
        self.post_process_function = None

    self.post_process_keypath = (
        post_process_keypath.split(".")
        if post_process_keypath is not None
        else None
    )

    if self.fixturize:
        # The module-level import of ``RedisFixturizer`` is commented out,
        # so the bare name raised NameError here. Import it lazily so the
        # dependency is only needed when fixturizing is requested.
        from metalpipe.metalpipe_recorder import RedisFixturizer

        self.fixturizer = RedisFixturizer()
    else:
        self.fixturizer = None
[docs] def setup(self):
"""
For classes that require initialization at runtime, which can't be done
when the class's ``__init__`` function is called. The ``MetalNode`` base
class's setup function is just a logging call.
It should be unusual to have to make use of ``setup`` because in practice,
initialization can be done in the ``__init__`` function.
"""
logging.debug(
"No ``setup`` method for {class_name}.".format(
class_name=self.__class__.__name__
)
)
pass
def __gt__(self, other):
    """
    Convenience method so that we can link two nodes by ``node1 > node2``.
    This just calls ``add_edge``.

    Args:
        other (``MetalNode``): The node that will receive ``self``'s output.

    Returns:
        (``MetalNode``): ``other``, not a boolean.
    """
    self.add_edge(other)
    return other
@property
def is_source(self):
    """
    Whether this node is a source, i.e. has no input queues.

    Returns:
        (bool): ``True`` if the node has no inputs, ``False`` otherwise.
    """
    return not self.input_queue_list
@property
def is_sink(self):
    """
    Whether this node is a sink, i.e. has no output queues.

    Returns:
        (bool): ``True`` if the node has no output nodes, ``False``
        otherwise.
    """
    return not self.output_queue_list
def add_edge(self, target, **kwargs):
    """
    Connect ``self`` to ``target`` with a new ``MetalPipeQueue``.

    The queue records ``self`` as its ``source_node`` and ``target`` as its
    ``target_node``. It is appended to ``self.output_queue_list`` and to
    ``target.input_queue_list``, and each node is added to the other's
    node list.

    Args:
        target (``MetalNode``): The node to which ``self`` will be connected.
        **kwargs: ``max_queue_size`` sets the queue's capacity; it defaults
            to ``DEFAULT_MAX_QUEUE_SIZE``.
    """
    queue_capacity = kwargs.get("max_queue_size", DEFAULT_MAX_QUEUE_SIZE)
    connector = MetalPipeQueue(queue_capacity)
    connector.source_node = self
    connector.target_node = target

    # Register the relationship on both endpoints.
    self.output_node_list.append(target)
    self.output_queue_list.append(connector)
    target.input_node_list.append(self)
    target.input_queue_list.append(connector)
def _get_message_content(self, one_item):
    """
    Extract the payload of ``one_item``.

    If ``self.input_message_keypath`` is non-empty, the payload is the value
    found at that keypath inside ``one_item.message_content``; otherwise it
    is ``one_item.message_content`` itself. A payload that is a dictionary
    whose only key is ``"__value__"`` is unwrapped to that key's value.
    """
    keypath = self.input_message_keypath
    if len(keypath) > 0:
        content = get_value(one_item.message_content, keypath)
    else:
        content = one_item.message_content

    is_wrapped_value = (
        isinstance(content, dict)
        and len(content) == 1
        and "__value__" in content
    )
    return content["__value__"] if is_wrapped_value else content
def wait_for_pipeline_finish(self):
    """
    Block the calling thread until ``self.pipeline_finished`` is true,
    sleeping ``SHORT_DELAY`` seconds between checks.
    """
    while True:
        if self.pipeline_finished:
            return
        time.sleep(SHORT_DELAY)
def start(self):
    """
    Starts the node. This is called by ``MetalNode.global_start()`` (via
    ``stream``) and is a generator of ``(output, input_item)`` pairs.

    The node's main loop is contained in this method. It:

    1. records the timestamp to the node's ``started_at`` attribute.
    #. calls ``get_runtime_attrs`` and stores each returned key/value as an
       attribute. A key is renamed through ``runtime_attrs_destinations``
       when it has an entry there; otherwise the key itself is the
       attribute name.
    #. calls the ``setup`` method for the class (a no-op by default).
    #. if the node is a source, successively yields all the results of the
       node's ``generator`` method, paired with ``None``.
    #. if the node is not a source, loops over the input queues until
       ``self.finished`` is set, getting the next message from each.
    #. extracts the message content (see ``_get_message_content``) and
       skips ``None`` and ``NothingToSeeHere`` contents.
    #. calls ``_process_item`` and yields each result paired with the input
       item. A single input may produce zero, one, or many outputs.
    #. after each output, calls ``break_test`` (if defined) and stores its
       result in ``self.finished``.
    #. finally calls ``_cleanup``; if that returns a generator, its items
       are yielded too.

    Raises:
        Exception: If ``runtime_attrs_destinations`` is neither ``None`` nor
            a dict-like object.
    """
    self.started_at = datetime.datetime.now()
    logging.debug(
        "Starting node: {node}".format(node=self.__class__.__name__)
    )

    # ``get_runtime_attrs`` returns a dict-like object whose keys and
    # values are stored as attributes of the ``MetalNode`` object.
    if self.get_runtime_attrs is not None:
        pre_flight_results = (
            self.get_runtime_attrs(
                *self.get_runtime_attrs_args,
                **self.get_runtime_attrs_kwargs
            )
            or {}
        )
        # ``__init__`` replaces a missing mapping with ``{}``, so an
        # unmapped key must fall back to its own name; indexing the mapping
        # directly raised KeyError for every key in the default case.
        destinations = self.runtime_attrs_destinations or {}
        if not hasattr(destinations, "get"):
            raise Exception(
                "There is a ``get_runtime_attrs``, but the "
                "``runtime_attrs_destinations`` is neither None nor a "
                "dict-like object."
            )
        for key, value in pre_flight_results.items():
            setattr(self, destinations.get(key, key), value)

    # We have to separate the pre-flight function, the setup of the
    # class, and any necessary startup functions (such as connecting
    # to a database).
    self.setup()

    # Bound up front so the cleanup loop below can always reference it,
    # including for source nodes, which never read an input item.
    one_item = None

    if self.is_source and not isinstance(self, (DynamicClassMediator,)):
        for output in self.generator():
            if self.fixturizer:
                self.fixturizer.record_source_node(self, output)
            yield output, None
    else:
        logging.debug(
            "About to enter loop for reading input queue in {node}.".format(
                node=str(self)
            )
        )
        while not self.finished:
            for input_queue in self.input_queue_list:
                one_item = input_queue.get()
                if one_item is None:
                    continue

                # Keep track of where the message came from, useful for
                # managing streaming joins, e.g.
                message_source = input_queue.source_node
                self.messages_received_counter += 1
                if (
                    self.max_messages_received is not None
                    and self.messages_received_counter
                    > self.max_messages_received
                ):
                    self.finished = True
                    break

                # The ``throttle`` keyword introduces a delay in seconds
                time.sleep(self.throttle)

                # Retrieve the ``message_content``
                message_content = self._get_message_content(one_item)

                # If we receive ``None`` or a ``NothingToSeeHere``, continue.
                if message_content is None or isinstance(
                    message_content, (NothingToSeeHere,)
                ):
                    continue

                # Record the message and its source in the node's attributes
                self.message = message_content
                self.message_source = message_source

                # Otherwise, process the message as usual, by calling
                # the ``MetalNode`` object's ``process_item`` method.
                for output in self._process_item():
                    if self.fixturizer:
                        self.fixturizer.record_worker_node(
                            self, one_item, output
                        )
                    yield output, one_item  # yield previous message

                    # Run ``break_test`` if it has been defined; a truthy
                    # result ends the outer ``while`` loop.
                    if self.break_test is not None:
                        break_test_result = self.break_test(
                            output_message=output,
                            input_message=self.__message__,
                        )
                        logging.debug(
                            "NODE BREAK TEST: " + str(break_test_result)
                        )
                        self.finished = break_test_result

    self.log_info(
        "checking whether cleanup is a generator. " + str(self.name)
    )
    cleanup_output = self._cleanup()
    if isinstance(cleanup_output, (types.GeneratorType,)):
        self.log_info("GeneratorType found. Calling cleanup.")
        for i in cleanup_output:
            yield i, one_item
def cleanup(self):
    """
    Hook for releasing resources (closing files, shutting down database
    connections) when the node stops. Subclasses that need this should
    override it.

    The base implementation does nothing.
    """
def _cleanup(self):
    """
    Run the node's ``cleanup`` hook, logging before and after, and set
    ``self.cleanup_called``.

    Returns:
        Whatever ``cleanup`` returns. ``start`` checks whether this is a
        generator and, if so, iterates it. The result used to be dropped,
        so a generator-style ``cleanup`` was created but never executed.
    """
    self.log_info("Cleanup called after shutdown.")
    cleanup_result = self.cleanup()
    self.log_info("Cleanup executed.")
    self.cleanup_called = True
    return cleanup_result
def log_info(self, message=""):
    """
    Log ``message`` prefixed with this node's name.

    Despite the method's name, the record is emitted at DEBUG level.
    """
    text = "{node_name}: {message}".format(
        node_name=self.name, message=message
    )
    logging.debug(text)
def terminate_pipeline(self, error=False):
    """
    Mark every connected node that is not yet finished as finished, and
    stamp it with the current time in ``stopped_at``. May be called on any
    node of the pipeline. Nodes that are already finished are left as they
    are.

    Args:
        error (bool): Not yet implemented.
    """
    self.log_info("terminate_pipeline called..." + str(self.name))
    still_active = (
        node for node in self.all_connected() if not node.finished
    )
    for node in still_active:
        node.stopped_at = datetime.datetime.now()
        node.finished = True
@property
def __message__(self):
    """
    The part of the incoming message selected by ``self.key``.

    * ``key is None``: the entire message.
    * ``key`` is a string: ``self.message[key]``.
    * ``key`` is a list: the value at that keypath, via ``get_value``.

    Nodes should access the content of their incoming message via this
    property.

    Raises:
        Exception: If ``key`` is none of the above types.
    """
    selector = self.key
    if selector is None:
        return self.message
    if isinstance(selector, str):
        return self.message[selector]
    if isinstance(selector, list):
        return get_value(self.message, selector)
    raise Exception("Bad type for input key.")
def _process_item(self, *args, **kwargs):
    """
    Wrap the node's ``process_item`` method, providing a place for logging,
    error handling and the experimental Prometheus instrumentation.

    Each value produced by ``process_item`` is validated, optionally
    post-processed, and yielded. Any exception raised while doing so is
    logged and counted instead of propagating. Once ``error_counter``
    exceeds ``max_errors`` the whole pipeline is terminated and the node's
    status is set to ``"error"``.

    Yields:
        Each output of ``process_item``, after post-processing.
    """
    # If we're using prometheus, then record an observation
    if self.prometheus_objects is not None:
        self.prometheus_objects["incoming_message_summary"].observe(
            random.random()
        )
    message_arrival_time = time.time()

    try:
        for out in self.process_item(*args, **kwargs):
            if (
                not isinstance(out, (dict, NothingToSeeHere))
                and self.output_key is None
            ):
                logging.debug(
                    "Exception raised due to no key" + str(self.name)
                )
                # ``name`` must be passed by keyword: the template uses a
                # named field, so a positional argument raised KeyError
                # and hid this message.
                raise Exception(
                    "Either message must be a dictionary or `output_key` "
                    "must be specified. {name}".format(name=self.name)
                )

            # Apply post_process_function if it's defined
            if self.post_process_function is not None:
                set_value(
                    out,
                    self.post_process_keypath,
                    self.post_process_function(
                        get_value(out, self.post_process_keypath),
                        **self.post_process_function_kwargs
                    ),
                )

            if self.prometheus_objects is not None:
                self.prometheus_objects["outgoing_message_summary"].set(
                    time.time() - message_arrival_time
                )
            yield out
    except Exception as err:
        self.error_counter += 1
        logging.error(
            "message: "
            + str(err.args)
            + str(self.__class__.__name__)
            + str(self.name)
        )
        if self.error_counter > self.max_errors:
            self.terminate_pipeline(error=True)
            self.status = "error"
        else:
            logging.warning("oops")
def stream(self):
    """
    Called in each ``MetalNode`` thread.

    Drives ``start`` and puts every output on each of the node's output
    queues. ``messages_sent_counter`` is incremented once per queue
    delivery.

    On return the node's ``stopped_at`` is set. Its ``status`` becomes
    ``"success"`` unless something (e.g. ``_process_item`` after too many
    errors) already set it to ``"error"`` during the run. That earlier
    status used to be overwritten, hiding the failure from the monitor.

    Raises:
        Exception: Any exception escaping ``start`` or a queue ``put`` is
            re-raised after the status is set to ``"error"``.
    """
    self.status = "running"
    try:
        for output, previous_message in self.start():
            logging.debug(
                "In MetalNode.stream.stream() --> " + str(output)
            )
            for output_queue in self.output_queue_list:
                self.messages_sent_counter += 1
                output_queue.put(
                    output,
                    block=True,
                    timeout=None,
                    previous_message=previous_message,
                )
    except Exception:
        self.status = "error"
        self.stopped_at = datetime.datetime.now()
        raise
    if self.status != "error":
        self.status = "success"
    self.stopped_at = datetime.datetime.now()
@property
def time_running(self):
    """
    Wall-clock time elapsed since the node was started.

    Returns:
        ``None`` while the status is ``"stopped"``. Otherwise a
        ``datetime.timedelta``, measured up to now if the node is still
        running or has no ``stopped_at`` yet, and up to ``stopped_at``
        otherwise.
    """
    if self.status == "stopped":
        return None
    still_going = self.status == "running" or self.stopped_at is None
    end_time = datetime.datetime.now() if still_going else self.stopped_at
    return end_time - self.started_at
def all_connected(self, seen=None):
    """
    Collect every node reachable from ``self`` through input or output
    edges. This lets callers loop over a whole pipeline while holding a
    handle on only one node; ``global_start`` relies on it, for example.

    For a ``DynamicClassMediator`` the result is the union of the nodes
    connected to each object stored in its ``node_dict``.

    Args:
        seen (set): Nodes already identified as connected to ``self``.

    Returns:
        (set of ``MetalNode``): All the nodes connected to ``self``. Except
        in the ``DynamicClassMediator`` case, this includes ``self``.
    """
    seen = seen or set()
    if isinstance(self, (DynamicClassMediator,)):
        for node_dict in self.node_dict.values():
            seen = seen.union(node_dict["obj"].all_connected(seen=seen))
        return seen

    if self not in seen:
        seen.add(self)
    for neighbor in self.input_node_list + self.output_node_list:
        if neighbor not in seen:
            seen.add(neighbor)
            seen = seen.union(neighbor.all_connected(seen=seen))
    return seen
def broadcast(self, broadcast_message):
    """
    Put ``broadcast_message`` on every input queue of every connected
    node.
    """
    every_input_queue = (
        input_queue
        for node in self.all_connected()
        for input_queue in node.input_queue_list
    )
    for input_queue in every_input_queue:
        input_queue.put(broadcast_message)
@property
def logjam(self):
    """
    The logjam score: the fraction of monitor polls in which this node was
    counted as a logjam.

    ``thread_monitor`` counts a node as a logjam in a given poll when it is
    not a source, all of its input queues are approximately full, and none
    of its output queues are. The intention is to alert the user, or
    rebalance the load, when the score crosses a threshold (rebalancing is
    not yet implemented).

    Returns:
        (float): Logjam score; ``0.0`` if the node has never been polled.
    """
    times_polled = self.logjam_score["polled"]
    if times_polled == 0:
        return 0.0
    return self.logjam_score["logjam"] / times_polled
def global_start(
    self,
    prometheus=False,
    pipeline_name=None,
    max_time=None,
    fixturize=False,
):
    """
    Starts every node connected to ``self``. Mainly, it:

    1. shares a run id, pipeline name, globals dictionary and thread
       registry across all nodes
    #. runs ``stream()`` for each node in its own non-daemon thread
    #. starts a daemon thread running ``thread_monitor``
    #. optionally starts some experimental code for monitoring

    Args:
        prometheus (bool): Currently ignored; the module-level
            ``PROMETHEUS`` flag decides whether Prometheus is set up.
        pipeline_name (str): Name stored on every node. If omitted, one
            random name is generated and shared by all nodes.
        max_time: Passed to ``thread_monitor``.
        fixturize (bool): Stored on every node as ``fixturize``.
    """

    def prometheus_init():
        """
        Experimental code for enabling Prometheus monitoring.
        """
        from prometheus_client import start_http_server, Summary, Gauge

        for node in self.all_connected():
            node.prometheus_objects = {}
            summary = Summary(
                node.name + "_incoming", "Summary of incoming messages"
            )
            node.prometheus_objects["incoming_message_summary"] = summary
            node.prometheus_objects["outgoing_message_summary"] = Gauge(
                node.name + "_outgoing", "Summary of outgoing messages"
            )
        start_http_server(8000)

    if PROMETHEUS:
        prometheus_init()

    global_dict = {}
    run_id = uuid.uuid4().hex
    # Resolve the name once so every node in the pipeline shares it; a
    # fresh random name used to be generated for each node.
    pipeline_name = pipeline_name or uuid.uuid4().hex

    for node in self.all_connected():
        node.pipeline_name = pipeline_name
        node.run_id = run_id
        node.fixturize = fixturize
        node.global_dict = global_dict  # Establishing shared globals
        node.thread_dict = self.thread_dict
        # Set before the thread starts so a node that finishes immediately
        # cannot have its final status overwritten with "running".
        node.status = "running"
        logging.debug("global_start:" + str(self))
        thread = threading.Thread(
            target=MetalNode.stream, args=(node,), daemon=False
        )
        self.thread_dict[node.name] = thread
        thread.start()

    monitor_thread = threading.Thread(
        target=MetalNode.thread_monitor,
        args=(self,),
        kwargs={"max_time": max_time},
        daemon=True,
    )
    monitor_thread.start()
@property
def input_queue_size(self):
    """
    Total number of items currently waiting across all of this node's
    input queues.
    """
    waiting = 0
    for input_queue in self.input_queue_list:
        waiting += input_queue.queue.qsize()
    return waiting
def draw_pipeline(self):
    """
    Draw the pipeline structure using graphviz.

    Every connected node becomes a box labelled with its name, and every
    node-to-output-node link becomes an edge. The graph is rendered to
    ``pipeline_drawing.gv`` and opened in a viewer.
    """
    diagram = graphviz.Digraph()
    pipeline_nodes = self.all_connected()
    for node in pipeline_nodes:
        label = node.name
        diagram.node(label, label, shape="box")
    for node in pipeline_nodes:
        for downstream in node.output_node_list:
            diagram.edge(node.name, downstream.name)
    diagram.render("pipeline_drawing.gv", view=True)
@property
def pipeline_finished(self):
    """
    ``True`` when every connected node has its ``finished`` flag set.

    This is a read-only property computed on each access. A log line is
    emitted on every access, whatever the result.
    """
    every_node_done = True
    for node in self.all_connected():
        if not node.finished:
            every_node_done = False
            break
    self.log_info("finished. " + str(self.name))
    return every_node_done
def thread_monitor(self, max_time=None):
    """
    Poll the pipeline until every node is finished.

    Every ``MONITOR_INTERVAL`` seconds this loop:

    1. finishes all nodes if ``max_time`` seconds have elapsed;
    #. every ``STATS_COUNTER_MODULO`` polls, logs a status table and notes
       whether any node is in the ``"error"`` status;
    #. terminates the pipeline if an error was noted;
    #. updates each node's logjam statistics.

    ``pipeline_finished`` is a read-only property derived from each node's
    ``finished`` flag, so this method ends the loop by marking nodes as
    finished rather than by assigning to the property, which raised
    ``AttributeError`` and killed the monitor thread.

    The method ends by calling ``sys.exit`` (status 1 after an error, 0
    otherwise). When run in a non-main thread, as ``global_start`` does,
    that only ends the monitor thread.

    Args:
        max_time: Maximum number of seconds to let the pipeline run, or
            ``None`` for no limit.

    Raises:
        ValueError: If a node reports a status other than ``running``,
            ``stopped``, ``error`` or ``success``.
    """
    counter = 0
    error = False
    time_started = time.time()

    while not self.pipeline_finished:
        logging.debug("MONITOR THREAD")
        time.sleep(MONITOR_INTERVAL)
        counter += 1

        if max_time is not None:
            print("checking max_time...")
            if time.time() - time_started >= max_time:
                print("finished because of max_time")
                for node in self.all_connected():
                    node.finished = True
                continue

        if counter % STATS_COUNTER_MODULO == 0:
            table = prettytable.PrettyTable(
                [
                    "Node",
                    "Class",
                    "Received",
                    "Sent",
                    "Queued",
                    "Status",
                    "Time",
                ]
            )
            for node in sorted(self.all_connected(), key=lambda x: x.name):
                if node.status == "running":
                    status_color = bcolors.WARNING
                elif node.status == "stopped":
                    status_color = ""
                elif node.status == "error":
                    status_color = bcolors.FAIL
                    error = True
                elif node.status == "success":
                    status_color = bcolors.OKGREEN
                else:
                    # An explicit raise is not stripped under ``-O``.
                    raise ValueError(
                        "Unknown node status: {status}".format(
                            status=node.status
                        )
                    )

                if node.logjam >= LOGJAM_THRESHOLD:
                    logjam_color = bcolors.FAIL
                else:
                    logjam_color = ""

                table.add_row(
                    [
                        logjam_color + node.name + bcolors.ENDC,
                        node.__class__.__name__,
                        node.messages_received_counter,
                        node.messages_sent_counter,
                        node.input_queue_size,
                        status_color + node.status + bcolors.ENDC,
                        node.time_running,
                    ]
                )
            self.log_info("\n" + str(table))

        if error:
            logging.error("Terminating due to error.")
            self.terminate_pipeline(error=True)
            break

        # Check for blocked nodes
        for node in self.all_connected():
            input_queue_full = [
                input_queue.approximately_full()
                for input_queue in node.input_queue_list
            ]
            output_queue_full = [
                output_queue.approximately_full()
                for output_queue in node.output_queue_list
            ]
            logjam = (
                not node.is_source
                and all(input_queue_full)
                and not any(output_queue_full)
            )
            node.logjam_score["polled"] += 1
            logging.debug(
                "LOGJAM SCORE: {logjam}".format(logjam=str(node.logjam))
            )
            if logjam:
                node.logjam_score["logjam"] += 1
            logging.debug(
                "LOGJAM {logjam} {name}".format(
                    logjam=logjam, name=node.name
                )
            )

    self.log_info("Pipeline finished.")
    self.log_info("Sending terminate signal to nodes.")
    self.log_info("Messages that are being processed will complete.")

    if error:
        self.log_info("Abnormal exit")
        sys.exit(1)
    else:
        self.log_info("Normal exit.")
        sys.exit(0)
class CounterOfThings(MetalNode):
    """
    Source node that emits the integers 1 through 10 and then fails an
    assertion.
    """

    def bar__init__(self, *args, start=0, end=None, **kwargs):
        """
        Store ``start`` and ``end`` and run ``MetalNode.__init__``.

        This method is named ``bar__init__``, not ``__init__``, so Python
        does not call it when the class is instantiated.
        """
        self.start = start
        self.end = end
        super().__init__(*args, **kwargs)

    def generator(self):
        """
        Just start counting integers
        """
        counter = 1
        while True:
            yield counter
            counter += 1
            assert counter <= 10
class FunctionOfMessage(MetalNode):
    """
    Node that resolves a function by name at construction time and stores
    it as ``self.function``.

    Args:
        function_name (str): A bare name, looked up in this module's
            globals, or a ``"__"``-separated path whose leading components
            name the module to import and whose last component is the
            attribute to fetch from it.
    """

    def __init__(self, function_name, *args, **kwargs):
        self.function_name = function_name
        components = self.function_name.split("__")
        if len(components) == 1:
            # This branch used to bind ``function_obj`` while the line
            # below read ``function``, raising NameError for bare names.
            function = globals()[components[0]]
        else:
            module = importlib.import_module(".".join(components[:-1]))
            function = getattr(module, components[-1])
        self.function = function
        super().__init__(*args, **kwargs)
class InsertData(MetalNode):
    """
    Insert fixed key/value pairs into each incoming message.

    Args:
        overwrite (bool): If true, existing keys are always replaced.
        overwrite_if_null (bool): If true, keys whose current value is
            ``None`` are replaced even when ``overwrite`` is false.
        value_dict (dict): The key/value pairs to insert.
    """

    def __init__(
        self, overwrite=True, overwrite_if_null=True, value_dict=None, **kwargs
    ):
        self.overwrite = overwrite
        self.overwrite_if_null = overwrite_if_null
        self.value_dict = value_dict or {}
        super().__init__(**kwargs)

    def process_item(self):
        """
        Update the incoming message in place and yield it.

        A key is written when it is absent, when ``overwrite`` is set, or
        when its current value is ``None`` and ``overwrite_if_null`` is set.
        """
        message = self.__message__
        logging.debug("INSERT DATA: " + str(message))
        for key, value in self.value_dict.items():
            if (
                key not in message
                or self.overwrite
                # Identity test: ``== None`` can be fooled by objects that
                # define a permissive ``__eq__``.
                or (message.get(key) is None and self.overwrite_if_null)
            ):
                message[key] = value
        yield message
class RandomSample(MetalNode):
    """
    Lets through only a random sample of incoming messages. Might be useful
    for testing, or when only approximate results are necessary.

    NOTE(review): no ``process_item`` is defined in this class as shown, so
    the sampling itself is not implemented here -- confirm whether it lives
    elsewhere.

    Args:
        sample (float): Stored as ``self.sample``.
    """

    def __init__(self, sample=0.1, *args, **kwargs):
        self.sample = sample
        # ``MetalNode.__init__`` creates the name, queue lists and counters
        # that the rest of the pipeline machinery relies on; it was never
        # called, and the standard node kwargs (``name`` etc.) were
        # rejected.
        super().__init__(*args, **kwargs)
class SubstituteRegex(MetalNode):
    """
    Apply a regular-expression substitution to one field of the message.

    Args:
        match_regex (str): Pattern to compile; compiled once at
            construction.
        substitute_string (str): Replacement passed to ``sub``.
    """

    def __init__(
        self, match_regex=None, substitute_string=None, *args, **kwargs
    ):
        self.match_regex = match_regex
        self.substitute_string = substitute_string
        self.regex_obj = re.compile(match_regex)
        super().__init__(*args, **kwargs)

    def process_item(self):
        """
        Yield ``self.message[self.key]`` with every match of the pattern
        replaced by ``substitute_string``.
        """
        target_text = self.message[self.key]
        yield self.regex_obj.sub(self.substitute_string, target_text)
class CSVToDictionaryList(MetalNode):
    """
    Parse an incoming CSV string into a list of row dictionaries.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def process_item(self):
        """
        Read the incoming message as CSV text with a header row and yield
        a single list containing one dictionary per data row.
        """
        reader = csv.DictReader(io.StringIO(self.__message__))
        yield list(reader)
class SequenceEmitter(MetalNode):
    """
    Emits ``sequence`` ``max_sequences`` times, or forever if
    ``max_sequences`` is ``None``.

    Each item is emitted as ``{output_key: item}``. An item that is a
    dictionary with both ``"value"`` and ``"type"`` keys is first converted
    by the converter named in ``"type"`` (``int``/``integer``,
    ``str``/``string``, ``float`` or ``bool``, case-insensitive).

    Args:
        sequence: Iterable of items to emit.
        max_sequences (int): Number of passes over ``sequence``; ``None``
            repeats indefinitely.
    """

    def __init__(self, sequence, *args, max_sequences=1, **kwargs):
        self.sequence = sequence
        self.max_sequences = max_sequences
        super().__init__(*args, **kwargs)

    def _emit_sequence(self):
        """Yield the wrapped items; shared by both public entry points."""
        type_dict = {
            "int": int,
            "integer": int,
            "str": str,
            "string": str,
            "float": float,
            "bool": to_bool,
        }
        counter = 0
        # ``None`` means "no limit"; comparing an int with ``None`` raises
        # TypeError, so it has to be tested for explicitly.
        while self.max_sequences is None or counter < self.max_sequences:
            for item in self.sequence:
                if (
                    isinstance(item, dict)
                    and "value" in item
                    and "type" in item
                ):
                    item = type_dict[item["type"].lower()](item["value"])
                yield {self.output_key: item}
            counter += 1

    def generator(self):
        """
        Emit the sequence ``max_sequences`` times.
        """
        yield from self._emit_sequence()

    def process_item(self):
        """
        Emit the sequence ``max_sequences`` times.
        """
        yield from self._emit_sequence()
class GetEnvironmentVariables(MetalNode):
    '''
    This node reads environment variables and stores them in the message.

    The required keyword argument for this node is
    ``environment_variables``, a list of environment variable names. By
    default each value is stored in the outgoing message under a key with
    the same name as the variable, e.g. ``FOO_VAR`` is stored as
    ``{"FOO_VAR": whatever}``.

    Optionally, the ``mappings`` keyword argument maps environment variable
    names to new names. E.g. if ``mappings = {"FOO_VAR": "bar_var"}``, then
    the value of ``FOO_VAR`` is stored as ``{"bar_var": whatever}``.

    If an environment variable is not defined, its value is ``None``.

    Args:
        mappings (dict): An optional dictionary mapping environment variable
            names to new names.
        environment_variables (list): A list of environment variable names.
    '''

    def __init__(self, mappings=None, environment_variables=None, **kwargs):
        self.environment_mappings = mappings or {}
        self.environment_variables = environment_variables or []
        super().__init__(**kwargs)

    def _read_environment(self):
        """Build the outgoing dictionary from the current environment."""
        environment = {}
        for variable in self.environment_variables:
            message_key = self.environment_mappings.get(variable, variable)
            environment[message_key] = os.environ.get(variable, None)
        return environment

    def generator(self):
        """Yield the environment dictionary once."""
        yield self._read_environment()

    def process_item(self):
        """Yield the environment dictionary once."""
        yield self._read_environment()
class SimpleTransforms(MetalNode):
    """
    Apply a list of path-based transforms to each message using
    ``replace_by_path``.

    Args:
        missing_keypath_action: Stored on the node.
        starting_path: Stored on the node. ``process_item`` reads
            ``starting_path`` from each transform, not from this attribute.
        transform_mapping (list): One dictionary per transform. Keys read:
            ``path`` (required), ``target_value``, ``target_function``,
            ``starting_path``, ``function_args``, ``function_kwargs``.
        target_value: Accepted but unused by this class.
        keypath: Accepted but unused by this class.
    """

    def __init__(
        self,
        missing_keypath_action="ignore",
        starting_path=None,
        transform_mapping=None,
        target_value=None,
        keypath=None,
        **kwargs
    ):
        self.missing_keypath_action = missing_keypath_action
        self.transform_mapping = transform_mapping or []
        self.functions_dict = {}
        self.starting_path = starting_path

        for transform in self.transform_mapping:
            # Not doing the transforms; only loading the right functions here
            full_function_name = transform.get("target_function", None)
            if full_function_name is None:
                continue
            components = full_function_name.split("__")
            if len(components) == 1:
                # This branch used to bind ``function_obj`` while the
                # registration below read ``function``, so bare names
                # raised NameError or reused a previous transform's
                # function.
                function = globals()[components[0]]
            else:
                module = importlib.import_module(".".join(components[:-1]))
                function = getattr(module, components[-1])
            self.functions_dict[full_function_name] = function

        super().__init__(**kwargs)

    def process_item(self):
        """
        Run every configured transform against ``self.message`` and yield
        the message.
        """
        logging.debug("TRANSFORM " + str(self.name))
        logging.debug(self.name + " " + str(self.message))
        for transform in self.transform_mapping:
            path = transform["path"]
            target_value = transform.get("target_value", None)
            function_name = transform.get("target_function", None)
            starting_path = transform.get("starting_path", None)
            if function_name is not None:
                function = self.functions_dict[function_name]
            else:
                function = None
            function_kwargs = transform.get("function_kwargs", None)
            function_args = transform.get("function_args", None)
            logging.debug(self.name + " calling replace_by_path:")
            replace_by_path(
                self.message,
                tuple(path),
                target_value=target_value,
                function=function,
                function_args=function_args,
                starting_path=starting_path,
                function_kwargs=function_kwargs,
            )
            logging.debug(
                "after SimpleTransform: " + self.name + str(self.message)
            )
        yield self.message
class Serializer(MetalNode):
    """
    Takes an iterable thing as input, and successively yields its items.

    Args:
        values (bool): If true, iterate over the input's ``.values()``
            instead of over the input itself.
    """

    def __init__(self, values=False, *args, **kwargs):
        self.values = values
        super().__init__(**kwargs)

    def process_item(self):
        """
        Yield ``None`` for a ``None`` message; otherwise yield the
        message's items (or its values) one by one.
        """
        payload = self.__message__
        if payload is None:
            yield None
        elif self.values:
            yield from payload.values()
        else:
            for item in payload:
                logging.debug(self.name + " " + str(item))
                yield item
class AggregateValues(MetalNode):
    """
    Yield the result of calling ``aggregate_values`` on each message.

    Args:
        values: Forwarded to ``aggregate_values`` as ``values``.
        tail_path: Forwarded to ``aggregate_values`` as its second
            argument.
    """

    def __init__(self, values=False, tail_path=None, **kwargs):
        self.tail_path = tail_path
        self.values = values
        super().__init__(**kwargs)

    def process_item(self):
        """Aggregate the incoming message and yield the single result."""
        aggregated = aggregate_values(
            self.__message__, self.tail_path, values=self.values
        )
        logging.debug(
            " ".join(["aggregate_values", self.name, str(aggregated)])
        )
        yield aggregated
[docs]class Filter(MetalNode):
"""
Applies tests to each message and filters out messages that don't pass
Built-in tests:
key_exists
value_is_true
value_is_not_none
Example:
{'test': 'key_exists',
'key': mykey}
"""
def __init__(
self, test=None, test_keypath=None, value=True, *args, **kwargs
):
self.test = test
self.value = value
self.test_keypath = test_keypath or []
super(Filter, self).__init__(*args, **kwargs)
@staticmethod
def _key_exists(message, key):
return key in message
@staticmethod
def _value_is_not_none(message, key):
logging.debug(
"value_is_not_none: {message} {key}".format(
message=str(message), key=key
)
)
return get_value(message, key) is not None
@staticmethod
def _value_is_true(message, key):
return to_bool(message.get(key, False))
[docs] def process_item(self):
if self.test in ["key_exists", "value_is_not_none", "value_is_true"]:
result = (
getattr(self, "_" + self.test)(
self.__message__, self.test_keypath
)
== self.value
)
else:
raise Exception("Unknown test: {test_name}".format(test_name=test))
if result:
logging.debug("Sending message through")
yield self.message
else:
logging.debug("Blocking message: " + str(self.__message__))
yield NothingToSeeHere()
class StreamMySQLTable(MetalNode):
    """
    Source node that streams every row of one MySQL table.

    Rows are fetched one at a time with a ``DictCursor``. When
    ``to_row_obj`` is true each row is converted with ``Row.from_dict``;
    when ``send_batch_markers`` is true the rows are bracketed by
    ``BatchStart`` (carrying the table schema) and ``BatchEnd``.

    NOTE(review): ``MySQLdb`` is not among the imports visible at the top
    of this module -- confirm it is imported before this node is used.
    """

    def __init__(
        self,
        *args,
        host="localhost",
        user=None,
        table=None,
        password=None,
        database=None,
        port=3306,
        to_row_obj=False,
        send_batch_markers=True,
        **kwargs
    ):
        self.host = host
        self.user = user
        self.to_row_obj = to_row_obj
        self.password = password
        self.database = database
        self.port = port
        self.table = table
        # ``generator`` reads this flag; it was previously never stored.
        self.send_batch_markers = send_batch_markers
        super(StreamMySQLTable, self).__init__(**kwargs)

    def setup(self):
        """Open the connection and cursor, then load the table schema."""
        self.db = MySQLdb.connect(
            # ``host`` was accepted by ``__init__`` but never used before.
            host=self.host,
            passwd=self.password,
            db=self.database,
            user=self.user,
            port=self.port,
        )
        self.cursor = MySQLdb.cursors.DictCursor(self.db)
        # NOTE(review): the table name is interpolated straight into the SQL
        # text; only use table names from trusted configuration.
        self.table_schema_query = (
            """SELECT column_name, column_type """
            """FROM information_schema.columns """
            """WHERE table_name='{table}';""".format(table=self.table)
        )
        self.table_schema = self.get_schema()
        # Need a mapping from header to MYSQL TYPE
        for mapping in self.table_schema:
            column = mapping["column_name"]
            type_string = mapping["column_type"]
            this_type = ds.MySQLTypeSystem.type_mapping(type_string)
            # Unfinished experimental code
            # Start here:
            # store the type_mapping
            # use it to cast the data into the MySQLTypeSchema
            # ensure that the generator is emitting MySQLTypeSchema objects

    def get_schema(self):
        """Run ``self.table_schema_query`` and return all fetched rows."""
        self.cursor.execute(self.table_schema_query)
        table_schema = self.cursor.fetchall()
        return table_schema

    def generator(self):
        """Yield each row of ``self.table``, optionally between markers."""
        if self.send_batch_markers:
            yield BatchStart(schema=self.table_schema)
        self.cursor.execute(
            """SELECT * FROM {table};""".format(table=self.table)
        )
        result = self.cursor.fetchone()
        while result is not None:
            if self.to_row_obj:
                result = Row.from_dict(result, type_system=MySQLTypeSystem)
            yield result
            result = self.cursor.fetchone()
        if self.send_batch_markers:
            yield BatchEnd()
class PrinterOfThings(MetalNode):
    """
    Pass-through node that prints each message to stdout.

    Printing is skipped when ``disable`` is true; ``pretty`` switches from
    ``print(str(...))`` to ``pprint``. The message is yielded onward either
    way.

    NOTE(review): ``self.prepend`` is never assigned in this class;
    presumably the ``set_kwarg_attributes`` decorator stores it -- confirm.
    """

    @set_kwarg_attributes()
    def __init__(
        self, disable=False, pretty=False, prepend="printer: ", **kwargs
    ):
        self.pretty = pretty
        self.disable = disable
        super().__init__(**kwargs)
        logging.debug("Initialized printer...")

    def process_item(self):
        """Print the current message (unless disabled) and forward it."""
        if not self.disable:
            print(self.prepend)
            if not self.pretty:
                print(str(self.__message__))
            else:
                pprint.pprint(self.__message__, indent=2)
            print("\n")
            print("------------")
        yield self.message
class ConstantEmitter(MetalNode):
    """
    Source node that emits the same ``thing`` a fixed number of times.

    Before each emission the generator sleeps for ``delay`` seconds; it
    stops after ``max_loops`` emissions.
    """

    def __init__(self, thing=None, max_loops=5, delay=0.5, **kwargs):
        self.thing = thing
        self.delay = delay
        self.max_loops = max_loops
        super(ConstantEmitter, self).__init__(**kwargs)

    def generator(self):
        """Yield ``self.thing`` ``self.max_loops`` times, pausing between."""
        # An unreachable fault-injection branch (``random.random() < -0.1``
        # can never be true) used to live in this loop; it was removed.
        counter = 0
        while counter < self.max_loops:
            time.sleep(self.delay)
            yield self.thing
            counter += 1
class TimeWindowAccumulator(MetalNode):
    """
    Every N seconds, put the latest M seconds data on the queue.

    NOTE(review): this class is a stub. The ``__init__`` body is ``pass``
    (the parent initializer is not called here) and no processing method is
    defined in this class.
    """

    # NOTE(review): presumably the decorator stores ``time_window`` and
    # ``send_interval`` as attributes -- confirm against its definition.
    @set_kwarg_attributes()
    def __init__(self, time_window=None, send_interval=None, **kwargs):
        pass
class LocalFileReader(MetalNode):
    """
    Read the local file named by each incoming message.

    The path opened is ``directory + "/" + message``, using ``read_mode``.
    With ``serialize`` true the file is emitted line by line, bracketed by
    ``BatchStart``/``BatchEnd`` when ``send_batch_markers`` is true;
    otherwise the whole content is emitted as one message.

    NOTE(review): the attributes read in ``process_item`` are never
    assigned here; presumably the ``set_kwarg_attributes`` decorator stores
    the keyword arguments on the instance -- confirm.
    """

    @set_kwarg_attributes()
    def __init__(
        self,
        directory=".",
        send_batch_markers=True,
        serialize=False,
        read_mode="r",
        **kwargs
    ):
        super(LocalFileReader, self).__init__(**kwargs)

    def process_item(self):
        """Yield the file's lines (serialized) or its full content."""
        filename = "/".join([self.directory, self.message])
        with open(filename, self.read_mode) as file_obj:
            if self.serialize:
                if self.send_batch_markers:
                    yield BatchStart()
                for line in file_obj:
                    yield line
                # The end marker used to be sent without this guard, so a
                # ``BatchEnd`` could appear without a matching ``BatchStart``.
                if self.send_batch_markers:
                    yield BatchEnd()
            else:
                yield file_obj.read()
class CSVReader(MetalNode):
    """
    Parse each incoming message as CSV text and emit one item per row.

    Rows come from ``csv.DictReader``; when ``to_row_obj`` is true each is
    converted with ``Row.from_dict``. When ``send_batch_markers`` is true
    the rows are bracketed by ``BatchStart`` and ``BatchEnd``.

    NOTE(review): presumably the ``set_kwarg_attributes`` decorator stores
    the keyword arguments as attributes -- confirm.
    """

    @set_kwarg_attributes()
    def __init__(self, send_batch_markers=True, to_row_obj=True, **kwargs):
        super().__init__(**kwargs)

    def process_item(self):
        """Yield the parsed rows of the current message."""
        # Build the reader before emitting anything, as the original did.
        records = csv.DictReader(io.StringIO(self.message))
        if self.send_batch_markers:
            yield BatchStart()
        for record in records:
            yield Row.from_dict(record) if self.to_row_obj else record
        if self.send_batch_markers:
            yield BatchEnd()
class LocalDirectoryWatchdog(MetalNode):
    """
    Source node that polls ``directory`` every ``check_interval`` seconds
    and emits the path of each entry whose modification time is newer than
    the newest one recorded so far.

    The high-water mark starts at construction time.
    """

    def __init__(self, directory=".", check_interval=3, **kwargs):
        self.directory = directory
        self.latest_arrival = time.time()
        self.check_interval = check_interval
        super().__init__(**kwargs)

    def generator(self):
        """Yield ``directory/filename`` for newly modified entries."""
        while self.keep_alive:
            logging.debug("sleeping...")
            time.sleep(self.check_interval)
            newest_seen = None
            for entry in os.listdir(self.directory):
                entry_path = "/".join([self.directory, entry])
                modified_at = os.path.getmtime(entry_path)
                if modified_at <= self.latest_arrival:
                    continue
                yield entry_path
                if newest_seen is None or modified_at > newest_seen:
                    newest_seen = modified_at
                    logging.debug("time_in_interval: " + str(newest_seen))
            # Advance the mark only after the whole listing was scanned.
            if newest_seen is not None:
                self.latest_arrival = newest_seen
class StreamingJoin(MetalNode):
    """
    Joins two streams on a key, using exact match only. MVP.

    ``streams`` maps a source node name to the path of the join key inside
    that source's messages. One ``TimedDict`` buffer with a timeout of
    ``window`` is created per stream.

    NOTE(review): the join itself is not implemented; ``process_item``
    currently yields a placeholder.
    """

    def __init__(self, window=30, streams=None, *args, **kwargs):
        # The default ``None`` used to crash on ``.keys()`` below.
        if streams is None:
            streams = {}
        self.window = window
        self.streams = streams
        self.stream_paths = streams
        self.buffers = {
            stream_name: TimedDict(timeout=self.window)
            for stream_name in self.stream_paths
        }
        super(StreamingJoin, self).__init__(*args, **kwargs)

    def process_item(self):
        """Look up the join key of the current message (join is a stub)."""
        # Kept although unused: the lookup fails loudly for a message from
        # a source that has no configured path.
        value_to_match = get_value(
            self.message, self.stream_paths[self.message_source.name]
        )
        # Check for matches in all other streams.
        # If complete set of matches, yield the merged result
        # If not, add it to the `TimedDict`.
        yield ("hi")
class DynamicClassMediator(MetalNode):
    """
    Base class for composite nodes built by ``class_factory``.

    ``class_factory`` attaches ``node_dict`` and ``raw_config`` to each
    generated subclass. On construction every inner node class is
    instantiated, the configured edges are wired, and the queue and node
    lists of the single source and single sink are exposed as this node's
    own.
    """

    def __init__(self, *args, **kwargs):
        super(DynamicClassMediator, self).__init__(**kwargs)
        # Instantiate every inner node with the same keyword arguments.
        for node_name, node_dict in self.node_dict.items():
            cls_obj = node_dict["cls_obj"]
            node_obj = cls_obj(**kwargs)
            node_dict["obj"] = node_obj
        for edge in self.raw_config["edges"]:
            source_node_obj = self.node_dict[edge["from"]]["obj"]
            target_node_obj = self.node_dict[edge["to"]]["obj"]
            # NOTE(review): presumably ``>`` is overloaded on nodes to
            # connect source to target -- confirm in ``MetalNode``.
            source_node_obj > target_node_obj

        def bind_methods():
            # Re-bind each public plain function defined on
            # ``DynamicClassMediator`` directly onto this instance.
            for attr_name in dir(DynamicClassMediator):
                if attr_name.startswith("_"):
                    continue
                attr_obj = getattr(DynamicClassMediator, attr_name)
                if not isinstance(attr_obj, types.FunctionType):
                    continue
                setattr(self, attr_name, types.MethodType(attr_obj, self))

        bind_methods()
        # NOTE(review): ``get_source``/``get_sink`` return ``None`` when no
        # node qualifies, in which case the attribute access below fails.
        source = self.get_source()
        self.input_queue_list = source.input_queue_list
        sink = self.get_sink()
        self.output_queue_list = sink.output_queue_list
        self.output_node_list = sink.output_node_list
        self.input_node_list = source.input_node_list

    def get_sink(self):
        """
        Return the only sink node, or ``None`` if there is none.

        Raises:
            Exception: if more than one sink exists.
        """
        sinks = self.sink_list()
        if len(sinks) > 1:
            raise Exception(
                "`DynamicClassMediator` may have no more than one sink."
            )
        elif len(sinks) == 0:
            return None
        return sinks[0]

    def get_source(self):
        """
        Return the only source node, or ``None`` if there is none.

        Raises:
            Exception: if more than one source exists.
        """
        sources = self.source_list()
        if len(sources) > 1:
            raise Exception(
                "`DynamicClassMediator` may have no more than one source."
            )
        elif len(sources) == 0:
            return None
        return sources[0]

    def sink_list(self):
        """Return the inner nodes whose ``output_queue_list`` is empty."""
        sink_nodes = []
        for node_name, node_dict in self.node_dict.items():
            node_obj = node_dict["obj"]
            if len(node_obj.output_queue_list) == 0:
                sink_nodes.append(node_obj)
        return sink_nodes

    def source_list(self):
        """Return the inner nodes whose ``is_source`` is truthy."""
        source_nodes = [
            node_dict["obj"]
            for node_dict in self.node_dict.values()
            if node_dict["obj"].is_source
        ]
        return source_nodes
def get_node_dict(node_config):
    """
    Build a ``{node name: info}`` mapping from ``node_config["nodes"]``.

    Each info dict holds the node class (looked up by name in this module's
    globals) under ``class``, its ``frozen_arguments`` and, under
    ``remapping``, the entry's ``arg_mapping``; both default to ``{}``.
    Every class is also instantiated once with its frozen arguments and the
    instance is discarded.
    """
    result = {}
    for entry in node_config["nodes"]:
        node_class = globals()[entry["class"]]
        node_name = entry["name"]
        info = result[node_name] = {}
        info["class"] = node_class
        frozen_arguments = entry.get("frozen_arguments", {})
        info["frozen_arguments"] = frozen_arguments
        # Trial instantiation; the object itself is not kept.
        node_class(**frozen_arguments)
        info["remapping"] = entry.get("arg_mapping", {})
    return result
def kwarg_remapper(f, **kwarg_mapping):
    """
    Wrap ``f`` so it can be called with renamed keyword arguments.

    ``kwarg_mapping`` maps a parameter name of ``f`` to the name callers
    use. Parameters of ``f`` that are not keys of the mapping keep their
    own name. Keyword arguments with any other name are dropped;
    positional arguments are forwarded untouched.
    """
    logging.debug("kwarg_mapping:" + str(kwarg_mapping))
    external_to_internal = {
        external: internal for internal, external in kwarg_mapping.items()
    }
    # Unmapped parameters map to themselves; this overlay wins on clashes.
    for name in inspect.signature(f).parameters:
        if name not in kwarg_mapping:
            external_to_internal[name] = name

    def remapped_function(*args, **kwargs):
        remapped_kwargs = {
            external_to_internal[key]: value
            for key, value in kwargs.items()
            if key in external_to_internal
        }
        logging.debug("renamed function with kwargs: " + str(remapped_kwargs))
        return f(*args, **remapped_kwargs)

    return remapped_function
def template_class(
    class_name, parent_class, kwargs_remapping, frozen_arguments_mapping
):
    """
    Create a subclass of ``parent_class`` with some constructor arguments
    frozen and others renamed.

    Args:
        class_name: name of the generated class.
        parent_class: the class to derive from, or its name as a string, in
            which case it is looked up in this module's globals.
        kwargs_remapping: passed to ``kwarg_remapper`` as its keyword
            mapping; ``None`` means no renaming.
        frozen_arguments_mapping: keyword arguments bound in advance to the
            parent ``__init__``; ``None`` means none.

    Returns:
        The new class, whose ``__init__`` is the remapped, partially applied
        parent initializer.
    """
    kwargs_remapping = kwargs_remapping or {}
    # ``None`` used to crash on ``**`` unpacking below.
    frozen_arguments_mapping = frozen_arguments_mapping or {}
    # Resolve a class given by name first; previously ``__init__`` was taken
    # from the string itself before this lookup happened.
    if isinstance(parent_class, str):
        parent_class = globals()[parent_class]
    frozen_init = functools.partial(
        parent_class.__init__, **frozen_arguments_mapping
    )
    cls = type(class_name, (parent_class,), {})
    setattr(cls, "__init__", kwarg_remapper(frozen_init, **kwargs_remapping))
    return cls
def class_factory(raw_config):
    """
    Build a ``DynamicClassMediator`` subclass from ``raw_config``.

    The new class is named ``raw_config["name"]`` and receives
    ``node_dict``, ``class_name``, ``edge_list_dict`` and ``raw_config`` as
    class attributes. For every node a templated class is generated and
    stored under ``cls_obj`` in that node's entry. The class is also
    registered in this module's globals under its own name.
    """
    new_class = type(raw_config["name"], (DynamicClassMediator,), {})
    new_class.node_dict = get_node_dict(raw_config)
    new_class.class_name = raw_config["name"]
    new_class.edge_list_dict = raw_config.get("edges", [])
    new_class.raw_config = raw_config
    for node_name, entry in new_class.node_dict.items():
        node_cls = template_class(
            node_name,
            entry["class"],
            entry["remapping"],
            entry["frozen_arguments"],
        )
        node_cls.raw_config = raw_config
        entry["cls_obj"] = node_cls
    # Make the generated class reachable by name from this module.
    globals()[new_class.__name__] = new_class
    return new_class
class Remapper(MetalNode):
    """
    Pass each incoming message through ``remap_dictionary`` with the
    configured ``mapping`` and emit the result.
    """

    def __init__(self, mapping=None, **kwargs):
        self.remapping_dict = mapping or {}
        super().__init__(**kwargs)

    def process_item(self):
        """Yield the remapped form of the current message."""
        logging.debug(
            "Remapper {node}:".format(node=self.name) + str(self.__message__)
        )
        yield remap_dictionary(self.__message__, self.remapping_dict)
class BatchMessages(MetalNode):
    """
    Collect incoming messages and emit them as lists of ``batch_size``.

    For every message that does not complete a batch a ``NothingToSeeHere``
    placeholder is yielded instead. ``cleanup`` emits whatever is still
    buffered.

    NOTE(review): ``batch_size`` defaults to ``None``, which makes the
    modulo in ``process_item`` fail -- callers apparently must set it.
    NOTE(review): ``timeout`` is stored but not used within this class.
    """

    def __init__(
        self, batch_size=None, batch_list=None, counter=0, timeout=5, **kwargs
    ):
        self.batch_size = batch_size
        self.timeout = timeout
        # The ``counter`` argument was previously ignored (always 0).
        self.counter = counter
        self.batch_list = batch_list or []
        super(BatchMessages, self).__init__(**kwargs)

    def process_item(self):
        """Buffer the message; yield the full batch or a placeholder."""
        self.counter += 1
        self.batch_list.append(self.__message__)
        logging.debug(self.name + " " + str(self.__message__))
        out = NothingToSeeHere()
        if self.counter % self.batch_size == 0:
            out = self.batch_list
            logging.debug("BatchMessages: " + str(out))
            # Start a fresh list so the emitted batch is not mutated later.
            self.batch_list = []
        yield out

    def cleanup(self):
        """Yield the messages still buffered (possibly an empty list)."""
        self.log_info(self.name + " in cleanup, sending remainder of batch...")
        yield self.batch_list
# Running this module as a script intentionally does nothing.
if __name__ == "__main__":
    pass