diff --git a/src/main/java/com/olympus/hermione/services/Neo4JUitilityService.java b/src/main/java/com/olympus/hermione/services/Neo4JUitilityService.java index 0efdc70..1b98497 100644 --- a/src/main/java/com/olympus/hermione/services/Neo4JUitilityService.java +++ b/src/main/java/com/olympus/hermione/services/Neo4JUitilityService.java @@ -85,10 +85,42 @@ public class Neo4JUitilityService { relationshipObject.add("properties", properties); } + enrichRelationshipsWithNodeTypes(session, relationshipsMap); + JsonArray relationshipsArray = new JsonArray(); relationshipsMap.values().forEach(relationshipsArray::add); return relationshipsArray; } + private void enrichRelationshipsWithNodeTypes(Session session, Map relationshipsMap) { + String query = "MATCH (n1)-[r]->(n2) RETURN DISTINCT type(r) AS relType, labels(n1) AS startLabels, labels(n2) AS endLabels"; + List records = session.run(query).list(); + + for (Record record : records) { + String relType = record.get("relType").asString(); + List startLabels = record.get("startLabels").asList(Value::asString); + List endLabels = record.get("endLabels").asList(Value::asString); + + JsonObject relationshipObject = relationshipsMap.get(":`"+relType+"`"); + if (relationshipObject != null) { + JsonArray startNodeTypes = relationshipObject.has("startNodeTypes") ? relationshipObject.get("startNodeTypes").getAsJsonArray() : new JsonArray(); + startLabels.forEach(label -> { + if (!startNodeTypes.contains(new JsonPrimitive(label))) { + startNodeTypes.add(new JsonPrimitive(label)); + } + }); + relationshipObject.add("startNodeTypes", startNodeTypes); + + JsonArray endNodeTypes = relationshipObject.has("endNodeTypes") ? relationshipObject.get("endNodeTypes").getAsJsonArray() : new JsonArray(); + endLabels.forEach(label -> { + if (!endNodeTypes.contains(new JsonPrimitive(label))) { + endNodeTypes.add(new JsonPrimitive(label)); + } + }); + relationshipObject.add("endNodeTypes", endNodeTypes); + } + } + } + }