summaryrefslogtreecommitdiff
path: root/runtime/pattern.py
diff options
context:
space:
mode:
Diffstat (limited to 'runtime/pattern.py')
-rw-r--r--runtime/pattern.py114
1 files changed, 114 insertions, 0 deletions
diff --git a/runtime/pattern.py b/runtime/pattern.py
new file mode 100644
index 0000000..d2cdf4b
--- /dev/null
+++ b/runtime/pattern.py
@@ -0,0 +1,114 @@
+"""Methods to simplify looking for patterns in a Structure.
+"""
+# pylint: disable=import-error
+from collections import defaultdict
+import runtime.utils as utils
+
+class Pattern:
+ """Represents an existential query/pattern to be matched against.
+ """
+ def __init__(self, runtime, constraints, maybe_equal, variable_names):
+ """Initializes the Pattern.
+
+ Arguments
+ =========
+ - @constraints should be a list of 3-tuples with strings (representing
+ nodes) or integers (representing variables) as elements.
+ - @maybe_equal should be a dictionary mapping {variable:
+ maybe_equivalent_variables}.
+ - @variable_names should be the corresponding names (in the structure)
+ of all variables.
+ """
+ self.runtime = runtime
+ self.constraints = constraints
+ self.maybe_equal = maybe_equal
+ self.variable_names = variable_names
+
+ def named_assignment_to_vars(self, assignment):
+ """Converts an assignment mapping node names to mapping numbers."""
+ names_to_vars = dict({name: var
+ for var, name in self.variable_names.items()})
+ return dict({names_to_vars[k]: v for k, v in assignment.items()})
+
+ def n_variables(self):
+ """Returns the number of variables to solve for in the pattern."""
+ return len(set(
+ arg for fact in self.constraints for arg in fact
+ if isinstance(arg, int)))
+
+ def assignments(self, partial_assignment=None):
+ """Yields assignments satisfying the pattern.
+
+ Each assignment is a dict {variable: node}. @partial_assignment, if
+ provided, can be used to initialize some of the variables (see eg.
+ production_rule.py:typecheck_with_facts).
+ """
+ if not self.constraints and partial_assignment is not None:
+ yield partial_assignment
+ return
+ assert self.constraints
+
+ partial_assignment = partial_assignment or dict()
+ partial_assignment = utils.Translator(partial_assignment)
+
+ constraints = partial_assignment.translate_tuples(self.constraints)
+ assignments = self.runtime.solver.assignments(
+ constraints, self.maybe_equal)
+ for assignment in assignments:
+ assignment = partial_assignment.concatenated_with(assignment)
+ if self.valid_maybe_equals(assignment):
+ yield assignment
+
+ def equivalence_class(self, member):
+ """Returns the equivalence class corresponding to variable @member.
+ """
+ if self.maybe_equal and member in self.maybe_equal:
+ return self.maybe_equal[member]
+ return set({member})
+
+ def valid_maybe_equals(self, assignment):
+ """Ensures any variables assigned together are marked maybe_equal.
+
+ This happens particularly when using @partial_assignment.
+ """
+ preimages = defaultdict(set)
+ for variable, node in assignment.items():
+ preimages[node].add(variable)
+ for variables in preimages.values():
+ equivalence_class = self.equivalence_class(next(iter(variables)))
+ if not variables <= equivalence_class:
+ return False
+ return True
+
+ def add_constraint(self, constraint):
+ """Adds more constraints to the Pattern.
+ """
+ self.constraints.append(constraint)
+
+ def is_assignment(self, assignment):
+ """True iff @assignment is a valid assignment. to the pattern."""
+ return len(assignment) == self.n_variables() and self.is_partial(assignment)
+
+ def is_partial(self, partial, check_eq=True):
+ """If @partial can be extended to a valid assignment, returns True.
+
+ NOTE: This is an *IF* not if*F*.
+ If @check_eq=True, then it will always return False if @partial
+ contradicts the equality constraints of the pattern.
+ """
+ assert not partial or isinstance(list(partial.keys())[0], int)
+ if check_eq and not self.valid_maybe_equals(partial):
+ return False
+ partial = partial or dict()
+ partial = utils.Translator(partial)
+ constraints = partial.translate_tuples(self.constraints)
+ for constraint in constraints:
+ constraint = tuple(arg if isinstance(arg, str) else None
+ for arg in constraint)
+ if not self.runtime.ts.lookup(*constraint, read_direct=True):
+ return False
+ return True
+
+ def __str__(self):
+ """Human-readable version of the Pattern."""
+ return str(self.constraints)
generated by cgit on debian on lair
contact matthew@masot.net with questions or feedback