summaryrefslogtreecommitdiff
path: root/runtime/production_rule.py
blob: c0c648199dff99edfef96d346187cdcd2e10c18b (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
"""Methods for parsing, mapping against, and executing structure /RULEs.
"""
# pylint: disable=no-name-in-module,import-error
from collections import defaultdict
from runtime.pattern import Pattern
import runtime.utils as utils

MAP_TYPES = ["/MUST_MAP", "/TRY_MAP", "/NO_MAP"]
ACTION_TYPES = ["/REMOVE", "/SUBTRACT", "/INSERT"]
NODE_TYPES = MAP_TYPES + ACTION_TYPES

class ProductionRule:
    """Represents a single /RULE in the TripletStructure.
    """
    def __init__(self, runtime, rule):
        """Initializes a ProductionRule given the corresponding node.

        NOTE: The structure can be modified arbitrarily once a ProductionRule
        is initialized; all relevant information is copied into the object
        itself.
        """
        self.runtime = runtime
        self.ts = runtime.ts
        self.name = rule
        self.is_backtracking = None
        self.parse_rule()
        self.assign_variables()
        self.prepare_constraints()
        self.facts = [fact for node in self.all_nodes
                      for fact in self.ts.lookup(node, None, None)]
        self.indexed_facts = dict({
            node: self.ts.lookup(node, None, None).copy()
            for node in self.all_nodes
        })

    def parse_rule(self):
        """Parses the relevant nodes to the rule (eg. MUST_MAP, etc.)
        """
        rule_node, solver = self.name, self.runtime.solver

        self.nodes_by_type = defaultdict(list)
        self.nodes_by_type["/NO_MAP"] = defaultdict(list)
        self.map_nodes = []
        self.all_nodes = set({rule_node})
        pattern = [(0, rule_node, "/RULE"), (0, 1, 2)]
        for assignment in solver.assignments(pattern):
            self.all_nodes.add(assignment[0])
            value, key = assignment[1], assignment[2]
            try:
                node_type = next(node_type for node_type in NODE_TYPES
                                 if key.startswith(node_type))
            except StopIteration:
                continue
            if node_type != "/NO_MAP":
                assert node_type == key
                self.nodes_by_type[key].append(value)
            else:
                index = int(key.split("/NO_MAP")[1].strip("_") or 0)
                self.nodes_by_type[node_type][index].append(value)
            if node_type in MAP_TYPES:
                self.map_nodes.append(value)
            self.all_nodes.add(value)

        self.equal = defaultdict(set)
        for assignment in solver.assignments([(0, rule_node, "/RULE"),
                                              (0, 1, "/="),
                                              (0, 2, "/=")]):
            # NOTE that assignments will call this for each permutation of the
            # two, so we don't actually need to special-case that.
            self.equal[assignment[1]].add(assignment[2])

        self.maybe_equal = defaultdict(set)
        for assignment in solver.assignments([(0, rule_node, "/RULE"),
                                              (0, 1, "/MAYBE="),
                                              (0, 2, "/MAYBE=")]):
            # NOTE see above.
            self.maybe_equal[assignment[1]].add(assignment[2])

    def assign_variables(self):
        """Gives each node in the mapping a variable name/ID/number.

        Populates node_to_variable, variable_to_node, and
        maybe_equal_variables.
        """
        # Every node that we might map against gets a variable name/ID.
        # Variable ID -> Node
        self.node_to_variable = dict()
        self.variable_to_node = dict()
        self.maybe_equal_variables = dict()
        # It's tempting here to give variables just to nodes marked /MUST_MAP,
        # but then you quickly run into problems because we often want to match
        # for the /NO_MAP nodes, etc. Plus other parts of the code will often
        # use variable numbers and node names interchangeably, so it's just
        # easier to support variable numbers for all relevant nodes. Unused
        # variable numbers don't have any negative impact.
        for node in sorted(self.all_nodes):
            variable = len(self.node_to_variable)
            for equivalent in self.equal[node]:
                try:
                    variable = self.node_to_variable[equivalent]
                    break
                except KeyError:
                    continue
            self.node_to_variable[node] = variable
            # The last one will take priority; we sort the iteration order so
            # this is deterministic.
            self.variable_to_node[variable] = node
            self.maybe_equal_variables[variable] = set({variable})

        for key, values in self.maybe_equal.items():
            key = self.node_to_variable[key]
            values = set(map(self.node_to_variable.get, values))
            self.maybe_equal_variables[key].update(values)

    def prepare_constraints(self):
        """Extracts the constraints corresponding to the rule.

        MUST be called after parse_rule().
        """
        ts = self.ts
        node_to_variable = utils.Translator(self.node_to_variable)

        # We will keep track of which nodes have constraints on them to avoid
        # free nodes.
        constrained = set()

        no_map = dict({
            node: index for index in self.nodes_by_type["/NO_MAP"].keys()
            for node in self.nodes_by_type["/NO_MAP"][index]})
        no_map_nodes = set(no_map.keys())

        def pattern():
            return Pattern(self.runtime, [], self.maybe_equal_variables,
                           self.variable_to_node)
        self.must_pattern = pattern()
        self.try_pattern = pattern()
        self.never_patterns = defaultdict(pattern)
        relevant_facts = (fact for node in self.map_nodes
                          for fact in ts.lookup(node, None, None))
        for fact in relevant_facts:
            constrained.update(fact)
            constraint = node_to_variable.translate_tuple(fact)
            arguments = set(fact)
            if arguments & no_map_nodes:
                index = next(no_map[argument] for argument in fact
                             if argument in no_map)
                self.never_patterns[index].add_constraint(constraint)
            elif arguments & set(self.nodes_by_type["/TRY_MAP"]):
                self.try_pattern.add_constraint(constraint)
            elif arguments & set(self.nodes_by_type["/INSERT"]):
                # We can try or fail to map against INSERT nodes, but they
                # should never be must_maps.
                pass
            else:
                assert set(fact) & set(self.nodes_by_type["/MUST_MAP"])
                self.must_pattern.add_constraint(constraint)

        assert not no_map_nodes & set(self.nodes_by_type["/TRY_MAP"])
        assert set(self.map_nodes) <= constrained

    def invalid(self, assignment):
        """True iff @assignment allows some of the /NEVER_MAPs to map.
        """
        return any(not utils.is_empty(pattern.assignments(assignment))
                   for pattern in self.never_patterns.values())
generated by cgit on debian on lair
contact matthew@masot.net with questions or feedback