forked from jaberg/hyperopt
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathgraphviz.py
More file actions
76 lines (61 loc) · 2.38 KB
/
graphviz.py
File metadata and controls
76 lines (61 loc) · 2.38 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
"""
Use graphviz's dot language to express the relationship between hyperparamters
in a search space.
"""
import StringIO
from pyll_utils import expr_to_config
def dot_hyperparameters(expr):
    """
    Return a dot language specification of a graph which describes the
    relationship between hyperparameters.

    Each hyperparameter within the pyll expression `expr` is represented by
    a rectangular node, and each value of each choice node that creates a
    conditional variable in the search space is represented by an
    elliptical node.  A conjunction of several conditions gets its own
    elliptical node (labelled with the conditions joined by ' & ') that the
    individual condition nodes point to.

    The direction of the arrows corresponds to the sequence of events
    in an ancestral sampling process.

    Double quotes inside hyperparameter names or condition values are
    backslash-escaped so that the output is always a valid dot file.

    Parameters
    ----------
    expr
        The pyll search-space expression; it is handed unchanged to
        `expr_to_config`.

    Returns
    -------
    A string holding the complete ``digraph { ... }`` specification, one
    statement per line, terminated by a newline.

    E.g.:

    >>> with open('foo.dot', 'w') as f:
    ...     f.write(dot_hyperparameters(search_space()))

    Then later from the shell, type e.g.

        dot -Tpng foo.dot > foo.png && eog foo.png

    Graphviz has other tools too: http://www.graphviz.org
    """
    conditions = ()
    hps = {}
    # `hps` is passed in empty and read back below, so expr_to_config is
    # expected to fill it in place (its return value is not used).
    expr_to_config(expr, conditions, hps)

    # Statements are collected in a list and joined at the end; this yields
    # the same text as printing line by line into a buffer, without relying
    # on the Python-2-only `print >> stream` statement.
    lines = ["digraph {"]
    edges = set()

    def quote(label):
        # A dot quoted ID may contain anything except an unescaped double
        # quote, so that is the only character that needs escaping.
        # Formatting through a 1-tuple stringifies any object safely.
        return '"%s"' % ('%s' % (label,)).replace('"', '\\"')

    def cond_label(cond):
        # Text used for the node of a single condition, e.g. "x=1".
        return '%s%s%s' % (cond.name, cond.op, cond.val)

    def var_node(a):
        lines.append('%s [ shape=box];' % quote(a))

    def cond_node(a):
        lines.append('%s [ shape=ellipse];' % quote(a))

    def edge(a, b):
        # Each distinct edge is emitted once, even if several
        # hyperparameters share the same condition chain.
        text = '%s -> %s;' % (quote(a), quote(b))
        if text not in edges:
            lines.append(text)
            edges.add(text)

    for hp, dct in hps.items():
        # create the node
        var_node(hp)

        # create an edge from anything it depends on
        for and_conds in dct['conditions']:
            if len(and_conds) > 1:
                # Several conditions must hold together: one node for the
                # conjunction, fed by one node per individual condition.
                parent_label = ' & '.join(
                    [cond_label(cond) for cond in and_conds])
                cond_node(parent_label)
                edge(parent_label, hp)
                for cond in and_conds:
                    sub_parent_label = cond_label(cond)
                    cond_node(sub_parent_label)
                    edge(cond.name, sub_parent_label)
                    edge(sub_parent_label, parent_label)
            elif len(and_conds) == 1:
                # A single condition links its variable straight to `hp`.
                only_cond = and_conds[0]
                parent_label = cond_label(only_cond)
                edge(only_cond.name, parent_label)
                cond_node(parent_label)
                edge(parent_label, hp)
            # An empty condition tuple means `hp` is unconditional:
            # nothing to draw beyond its own node.

    lines.append("}")
    return '\n'.join(lines) + '\n'