forked from srl-labs/clab-io-draw
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdrawio2clab.py
309 lines (264 loc) · 13.5 KB
/
drawio2clab.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
import argparse, os
import xml.etree.ElementTree as ET
from lib.Yaml_processor import YAMLProcessor
def report_error(message):
    """Write *message* to standard output, prefixed with ``Error: ``."""
    print("Error: {}".format(message))
def parse_xml(file_path, diagram_name=None):
    """
    Loads an XML file and returns the mxGraphModel/root element of one of its diagrams.

    When diagram_name is given, the first 'diagram' child of the document root whose
    'name' attribute equals it is selected; otherwise the first 'diagram' child is used.
    A message is printed and None is returned when the selected diagram does not exist
    or holds no mxGraphModel/root. A named diagram that cannot be found is NOT replaced
    by the first diagram.
    """
    document = ET.parse(file_path).getroot()

    if diagram_name:
        # Only the first diagram carrying the requested name is ever inspected.
        selected = next(
            (candidate for candidate in document.findall('diagram')
             if candidate.get('name') == diagram_name),
            None,
        )
        if selected is None:
            print(f"Diagram named '{diagram_name}' not found.")
            return None
        model_root = selected.find('.//mxGraphModel/root')
        if model_root is None:
            print(f"mxGraphModel/root not found in diagram '{diagram_name}'.")
        return model_root

    # No name requested: fall back to the first diagram in the file.
    first = document.find('diagram')
    if first is None:
        print("No diagrams found in the file.")
        return None
    model_root = first.find('.//mxGraphModel/root')
    if model_root is None:
        print("mxGraphModel/root not found in the first diagram.")
    return model_root
def extract_nodes(mxGraphModel, default_kind):
    """
    Extracts node details from the mxGraphModel, keyed by element ID.

    Two sources are scanned:
    - 'object' elements: the label comes from the object's 'label' attribute or, when
      that is empty, from the 'value' attribute of its child mxCell. The object's ID is
      used as the key and the optional 'type', 'mgmt-ipv4', 'group', 'labels' and 'kind'
      attributes are captured ('kind' falls back to default_kind).
    - standalone mxCell elements with vertex='1' whose style contains 'image=data' and
      whose 'value' is non-empty. These only get 'label' and 'kind' (= default_kind).

    Elements without a usable label are skipped. Returns a dict mapping ID -> details.
    """
    node_details = {}
    # mxCell elements living inside an <object>. ElementTree has no parent pointers and
    # does not support the XPath 'ancestor::' axis, so membership is recorded while
    # walking the objects and consulted in the second pass.
    wrapped_cells = set()

    # Process all objects which might contain nodes or represent nodes directly
    for obj in mxGraphModel.findall(".//object"):
        wrapped_cells.update(obj.iter('mxCell'))

        node_label = obj.get('label', '').strip()
        if not node_label:
            # The mxCell is a child element, not an attribute, so it must be looked up
            # with find(); compare against None because a childless Element is falsy.
            cell = obj.find('mxCell')
            if cell is None:
                continue
            node_label = cell.get('value', '').strip()
        if not node_label:
            continue

        node_details[obj.get('id')] = {
            'label': node_label,
            'type': obj.get('type', None),
            'mgmt-ipv4': obj.get('mgmt-ipv4', None),
            'group': obj.get('group', None),
            'labels': obj.get('labels', None),
            'kind': obj.get('kind', default_kind)
        }

    # Process all mxCell elements that have vertex='1'
    for cell in mxGraphModel.findall(".//mxCell[@vertex='1']"):
        node_id = cell.get('id')
        # Cells wrapped by an object are represented by that object; never add them twice.
        if cell in wrapped_cells or node_id in node_details:
            continue
        # Only vertices drawn with an embedded image are treated as nodes.
        if "image=data" not in cell.get('style', ''):
            continue
        node_label = cell.get('value', '').strip()
        if node_label:
            node_details[node_id] = {
                'label': node_label,
                'kind': default_kind
            }

    return node_details
def extract_links(mxGraphModel, node_details):
    """
    Extracts link information from the mxGraphModel, keyed by link ID.

    Links are mxCell elements that carry 'source', 'target' and 'edge' attributes. Each
    one is converted with extract_link_info(); cells for which that returns nothing
    (no usable ID) are skipped. Edge cells nested in an 'object' element fall back to
    the object's ID when they have no ID of their own.

    Returns an empty dict when mxGraphModel is None.
    """
    links_info = {}
    # Compare against None explicitly: truth-testing an Element is deprecated and is
    # False for an element that merely has no children.
    if mxGraphModel is None:
        return links_info

    edge_query = ".//mxCell[@source][@target][@edge]"

    # Process links defined directly within mxGraphModel
    for mxCell in mxGraphModel.findall(edge_query):
        link_info = extract_link_info(mxCell, node_details)
        if link_info:
            links_info[link_info['id']] = link_info

    # Process links defined within objects
    for object_elem in mxGraphModel.findall(".//object"):
        # Use the object's ID as a fallback if the mxCell lacks an ID
        object_id = object_elem.get('id')
        for mxCell in object_elem.findall(edge_query):
            link_info = extract_link_info(mxCell, node_details, fallback_id=object_id)
            if link_info:
                links_info[link_info['id']] = link_info

    return links_info
def extract_link_info(mxCell, node_details, fallback_id=None):
    """
    Builds the description of a single link from an edge mxCell.

    The result holds the link ID (the cell's own 'id', else fallback_id), the labels of
    the source and target nodes as found in node_details ("Unknown" when a node is not
    listed), the x/y of the cell's mxGeometry child (each defaulting to 0, or None for
    both when there is no geometry) and an empty 'labels' list.
    Returns None when neither the cell nor fallback_id supplies an ID.
    """
    geometry = mxCell.find('mxGeometry')
    if geometry is None:
        x = y = None
    else:
        x = float(geometry.get('x', 0))
        y = float(geometry.get('y', 0))

    def label_of(node_id):
        # Endpoints that were not recognised as nodes are reported as "Unknown".
        return node_details.get(node_id, {}).get('label', "Unknown")

    source_label = label_of(mxCell.get('source'))
    target_label = label_of(mxCell.get('target'))

    link_id = mxCell.get('id') or fallback_id
    if not link_id:
        return None

    return {
        'id': link_id,
        'source': source_label,
        'target': target_label,
        'geometry': {'x': x, 'y': y},
        'labels': []
    }
def extract_link_labels(mxGraphModel, links_info):
    """
    Attaches label cells to the links they belong to.

    Every mxCell with a 'value' attribute whose 'parent' is a key of links_info is
    treated as a label of that link. Cells with an empty value or without an mxGeometry
    child are ignored. links_info is modified in place: each accepted label is appended
    to the link's 'labels' list as a dict with 'value', 'x_position' and 'y_position'
    (the geometry's x/y, defaulting to 0).
    """
    for cell in mxGraphModel.findall(".//mxCell[@value]"):
        owner_id = cell.get('parent')
        if owner_id not in links_info:
            continue

        text = cell.get('value')
        geometry = cell.find("mxGeometry")
        if not text or geometry is None:
            continue

        x_position = float(geometry.get('x', 0))
        y_position = float(geometry.get('y', 0))
        entry = {
            'value': text,
            'x_position': x_position,
            'y_position': y_position
        }
        links_info[owner_id]['labels'].append(entry)
def aggregate_node_information(node_details):
    """
    Returns a copy of node_details in which every node's 'kind' has been settled.

    A 'kind' that is set and differs from 'nokia_srlinux' is kept as is. Otherwise a
    node whose label contains 'client' becomes 'linux', and any other node keeps its
    current 'kind' (or gets 'nokia_srlinux' when the key is absent).
    The input mapping and its per-node dicts are left unmodified.
    """
    def settle_kind(details):
        current = details.get('kind')
        # Anything other than the stock 'nokia_srlinux' counts as an explicit choice.
        if current and current != 'nokia_srlinux':
            return details['kind']
        if 'client' in details['label']:
            return 'linux'
        return details.get('kind', 'nokia_srlinux')

    return {
        node_id: {**details, 'kind': settle_kind(details)}
        for node_id, details in node_details.items()
    }
def compile_link_information(links_info, style='block'):
    """
    Turns the collected link data into a list of {'endpoints': [source, target]} entries.

    For each link the labels are ordered by 'x_position'; the first one names the source
    interface and the last one the target interface, so with three or more labels the
    ones in between are ignored. Links with fewer than two labels are reported through
    report_error() and left out.
    The 'style' parameter only influences ordering of the result: with 'block' the list
    is sorted by the source node name, otherwise by the full endpoints pair.
    """
    compiled_links = []

    for link_id, info in links_info.items():
        ordered = sorted(info['labels'], key=lambda label: label['x_position'])

        if len(ordered) < 2:
            # Both interface names are needed to describe a link; skip it.
            report_error(f"Not enough labels for link {link_id}. At least 2 labels are required.")
            continue

        first, *_, last = ordered
        compiled_links.append({
            'endpoints': [
                f"{info['source']}:{first['value']}",
                f"{info['target']}:{last['value']}",
            ]
        })

    def sort_key(link):
        if style == 'block':
            return link['endpoints'][0].split(':')[0]
        return link['endpoints']

    compiled_links.sort(key=sort_key)
    return compiled_links
def filter_nodes(links_info, node_details):
    """
    Keeps only the nodes that appear as an endpoint of at least one link.

    Matching is done on labels: a node is retained when its 'label' equals the 'source'
    or 'target' recorded for any entry of links_info. Returns a new dict; node_details
    itself is not modified.
    """
    # Gather every label used as a link endpoint.
    endpoint_labels = set()
    endpoint_labels.update(link['source'] for link in links_info.values())
    endpoint_labels.update(link['target'] for link in links_info.values())

    connected = {}
    for node_id, details in node_details.items():
        if details['label'] in endpoint_labels:
            connected[node_id] = details
    return connected
def generate_yaml_structure(filtered_node_details, compiled_links, input_file):
    """
    Assembles the topology dictionary that is later written out as YAML.

    The topology name is the input file's base name without its extension. Every kind
    used by a node gets one entry under 'kinds' with preset values for 'nokia_srlinux',
    'linux' and 'vr-sros', and {'image': None} for any other kind. Each node is stored
    under its label with its 'kind', its 'type' when that is set to a non-empty value,
    and 'mgmt-ipv4', 'group' and 'labels' when they are not None.
    """
    # Preset 'kinds' entries for the kinds this tool knows about.
    known_kinds = {
        'nokia_srlinux': {'image': 'ghcr.io/nokia/srlinux', 'type': 'ixrd3'},
        'linux': {'image': 'ghcr.io/hellt/network-multitool'},
        'vr-sros': {'image': 'registry.srlinux.dev/pub/vr-sros:23.10'},
    }
    optional_attrs = ('mgmt-ipv4', 'group', 'labels')

    kinds = {}
    nodes = {}
    for details in filtered_node_details.values():
        kind = details['kind']
        if kind not in kinds:
            # Unknown kinds still get an entry so the image can be filled in by hand.
            kinds[kind] = dict(known_kinds.get(kind, {'image': None}))

        entry = {'kind': kind}
        if details.get('type'):
            entry['type'] = details['type']
        # None values are dropped so no 'null' entries end up in the output.
        entry.update(
            (attr, details.get(attr))
            for attr in optional_attrs
            if details.get(attr) is not None
        )
        nodes[details['label']] = entry

    topology = {'kinds': kinds, 'nodes': nodes, 'links': compiled_links}
    name, _extension = os.path.splitext(os.path.basename(input_file))
    return {'name': name, 'topology': topology}
def main(input_file, output_file, style='block', diagram_name=None, default_kind='nokia_srlinux'):
    """
    Converts a .drawio XML file into a YAML network topology file.

    Orchestrates the whole pipeline: parse the diagram, extract nodes, links and link
    labels, settle node kinds, compile the links, drop nodes that take part in no link,
    build the YAML structure and write it with YAMLProcessor.

    Args:
        input_file: Path of the draw.io XML file to read.
        output_file: Path of the YAML file to write. When empty, the input path with
            its extension replaced by ".yaml" is used.
        style: 'flow' or 'block'; passed to the link compilation step and selects how
            YAMLProcessor.save_yaml is called.
        diagram_name: Name of the diagram (tab) to convert; the first one when omitted.
        default_kind: Kind assigned to nodes that do not declare one.

    Raises:
        SystemExit: With status 1 when no usable diagram could be read from input_file.
    """
    if not output_file:
        output_file = os.path.splitext(input_file)[0] + ".yaml"

    root = parse_xml(input_file, diagram_name)
    if root is None:
        # parse_xml() returns None (after printing the specific reason) when the diagram
        # or its mxGraphModel/root is missing. Stop here instead of failing later with
        # an AttributeError on None.
        report_error(f"No usable diagram found in '{input_file}'; no YAML file was written.")
        raise SystemExit(1)

    node_details = extract_nodes(root, default_kind)
    links_info = extract_links(root, node_details)
    extract_link_labels(root, links_info)

    node_details = aggregate_node_information(node_details)
    compiled_links = compile_link_information(links_info, style)
    filtered_nodes = filter_nodes(links_info, node_details)
    yaml_data = generate_yaml_structure(filtered_nodes, compiled_links, input_file)

    processor = YAMLProcessor()
    if style == "flow":
        processor.save_yaml(yaml_data, output_file)
    else:
        processor.save_yaml(yaml_data, output_file, flow_style='block')
# Command-line entry point: parse the arguments and hand them to main().
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Parse a draw.io XML file and generate a YAML file with a specified style.")

    parser.add_argument("-i", "--input", dest="input_file", required=True, help="The input XML file to be parsed.")
    parser.add_argument("-o", "--output", dest="output_file", required=False, help="The output YAML file.")
    # The help text states the default that is actually configured here ('flow').
    parser.add_argument("--style", dest="style", choices=['block', 'flow'], default="flow", help="The style for YAML endpoints. Choose 'block' or 'flow'. Default is 'flow'.")
    parser.add_argument("--diagram-name", dest="diagram_name", required=False, help="The name of the diagram (tab) to be parsed.")
    parser.add_argument("--default-kind", dest="default_kind", default="nokia_srlinux", help="The default kind for nodes. Default is 'nokia_srlinux'.")

    args = parser.parse_args()

    main(args.input_file, args.output_file, args.style, args.diagram_name, args.default_kind)