1
- from typing import Union
2
1
import dspy
3
2
from pydantic import BaseModel , Field
4
3
from nano_graphrag ._utils import clean_str
4
+ from nano_graphrag ._utils import logger
5
5
6
6
7
7
"""
@@ -75,6 +75,14 @@ class Entity(BaseModel):
75
75
description = "Importance score of the entity. Should be between 0 and 1 with 1 being the most important." ,
76
76
)
77
77
78
+ def to_dict (self ):
79
+ return {
80
+ "entity_name" : clean_str (self .entity_name .upper ()),
81
+ "entity_type" : clean_str (self .entity_type .upper ()),
82
+ "description" : clean_str (self .description ),
83
+ "importance_score" : float (self .importance_score ),
84
+ }
85
+
78
86
79
87
class Relationship (BaseModel ):
80
88
src_id : str = Field (..., description = "The name of the source entity." )
@@ -96,6 +104,15 @@ class Relationship(BaseModel):
96
104
description = "The order of the relationship. 1 for direct relationships, 2 for second-order, 3 for third-order." ,
97
105
)
98
106
107
+ def to_dict (self ):
108
+ return {
109
+ "src_id" : clean_str (self .src_id .upper ()),
110
+ "tgt_id" : clean_str (self .tgt_id .upper ()),
111
+ "description" : clean_str (self .description ),
112
+ "weight" : float (self .weight ),
113
+ "order" : int (self .order ),
114
+ }
115
+
99
116
100
117
class CombinedExtraction (dspy .Signature ):
101
118
"""
@@ -134,8 +151,85 @@ class CombinedExtraction(dspy.Signature):
134
151
entity_types : list [str ] = dspy .InputField (
135
152
desc = "List of entity types used for extraction."
136
153
)
137
- entities_relationships : list [Union [Entity , Relationship ]] = dspy .OutputField (
138
- desc = "List of entities and relationships extracted from the text."
154
+ entities : list [Entity ] = dspy .OutputField (
155
+ desc = "List of entities extracted from the text and the entity types."
156
+ )
157
+ relationships : list [Relationship ] = dspy .OutputField (
158
+ desc = "List of relationships extracted from the text and the entity types."
159
+ )
160
+
161
+
162
+ class CritiqueCombinedExtraction (dspy .Signature ):
163
+ """
164
+ Critique the current extraction of entities and relationships from a given text.
165
+ Focus on completeness, accuracy, and adherence to the provided entity types and extraction guidelines.
166
+
167
+ Critique Guidelines:
168
+ 1. Evaluate if all relevant entities from the input text are captured and correctly typed.
169
+ 2. Check if entity descriptions are comprehensive and follow the provided guidelines.
170
+ 3. Assess the completeness of relationship extractions, including higher-order relationships.
171
+ 4. Verify that relationship descriptions are detailed and follow the provided guidelines.
172
+ 5. Identify any inconsistencies, errors, or missed opportunities in the current extraction.
173
+ 6. Suggest specific improvements or additions to enhance the quality of the extraction.
174
+ """
175
+
176
+ input_text : str = dspy .InputField (
177
+ desc = "The original text from which entities and relationships were extracted."
178
+ )
179
+ entity_types : list [str ] = dspy .InputField (
180
+ desc = "List of valid entity types for this extraction task."
181
+ )
182
+ current_entities : list [Entity ] = dspy .InputField (
183
+ desc = "List of currently extracted entities to be critiqued."
184
+ )
185
+ current_relationships : list [Relationship ] = dspy .InputField (
186
+ desc = "List of currently extracted relationships to be critiqued."
187
+ )
188
+ entity_critique : str = dspy .OutputField (
189
+ desc = "Detailed critique of the current entities, highlighting areas for improvement for completeness and accuracy.."
190
+ )
191
+ relationship_critique : str = dspy .OutputField (
192
+ desc = "Detailed critique of the current relationships, highlighting areas for improvement for completeness and accuracy.."
193
+ )
194
+
195
+
196
+ class RefineCombinedExtraction (dspy .Signature ):
197
+ """
198
+ Refine the current extraction of entities and relationships based on the provided critique.
199
+ Improve completeness, accuracy, and adherence to the extraction guidelines.
200
+
201
+ Refinement Guidelines:
202
+ 1. Address all points raised in the entity and relationship critiques.
203
+ 2. Add missing entities and relationships identified in the critique.
204
+ 3. Improve entity and relationship descriptions as suggested.
205
+ 4. Ensure all refinements still adhere to the original extraction guidelines.
206
+ 5. Maintain consistency between entities and relationships during refinement.
207
+ 6. Focus on enhancing the overall quality and comprehensiveness of the extraction.
208
+ """
209
+
210
+ input_text : str = dspy .InputField (
211
+ desc = "The original text from which entities and relationships were extracted."
212
+ )
213
+ entity_types : list [str ] = dspy .InputField (
214
+ desc = "List of valid entity types for this extraction task."
215
+ )
216
+ current_entities : list [Entity ] = dspy .InputField (
217
+ desc = "List of currently extracted entities to be refined."
218
+ )
219
+ current_relationships : list [Relationship ] = dspy .InputField (
220
+ desc = "List of currently extracted relationships to be refined."
221
+ )
222
+ entity_critique : str = dspy .InputField (
223
+ desc = "Detailed critique of the current entities to guide refinement."
224
+ )
225
+ relationship_critique : str = dspy .InputField (
226
+ desc = "Detailed critique of the current relationships to guide refinement."
227
+ )
228
+ refined_entities : list [Entity ] = dspy .OutputField (
229
+ desc = "List of refined entities, addressing the entity critique and improving upon the current entities."
230
+ )
231
+ refined_relationships : list [Relationship ] = dspy .OutputField (
232
+ desc = "List of refined relationships, addressing the relationship critique and improving upon the current relationships."
139
233
)
140
234
141
235
@@ -159,7 +253,7 @@ def forward(self, **kwargs):
159
253
160
254
except Exception as e :
161
255
if isinstance (e , self .exception_types ):
162
- return dspy .Prediction (entities_relationships = [])
256
+ return dspy .Prediction (entities = [], relationships = [])
163
257
164
258
raise e
165
259
@@ -168,46 +262,63 @@ class TypedEntityRelationshipExtractor(dspy.Module):
168
262
def __init__ (
169
263
self ,
170
264
lm : dspy .LM = None ,
171
- reasoning : dspy .OutputField = None ,
172
265
max_retries : int = 3 ,
266
+ entity_types : list [str ] = ENTITY_TYPES ,
267
+ self_refine : bool = False ,
268
+ num_refine_turns : int = 1
173
269
):
174
270
super ().__init__ ()
175
271
self .lm = lm
176
- self .entity_types = ENTITY_TYPES
177
- self .extractor = dspy .TypedChainOfThought (
178
- signature = CombinedExtraction , reasoning = reasoning , max_retries = max_retries
179
- )
272
+ self .entity_types = entity_types
273
+ self .self_refine = self_refine
274
+ self .num_refine_turns = num_refine_turns
275
+
276
+ self .extractor = dspy .TypedChainOfThought (signature = CombinedExtraction , max_retries = max_retries )
180
277
self .extractor = TypedEntityRelationshipExtractorException (
181
278
self .extractor , exception_types = (ValueError ,)
182
279
)
280
+
281
+ if self .self_refine :
282
+ self .critique = dspy .TypedChainOfThought (
283
+ signature = CritiqueCombinedExtraction ,
284
+ max_retries = max_retries
285
+ )
286
+ self .refine = dspy .TypedChainOfThought (
287
+ signature = RefineCombinedExtraction ,
288
+ max_retries = max_retries
289
+ )
183
290
184
291
def forward (self , input_text : str ) -> dspy .Prediction :
185
292
with dspy .context (lm = self .lm if self .lm is not None else dspy .settings .lm ):
186
293
extraction_result = self .extractor (
187
294
input_text = input_text , entity_types = self .entity_types
188
295
)
296
+
297
+ current_entities : list [Entity ] = extraction_result .entities
298
+ current_relationships : list [Relationship ] = extraction_result .relationships
299
+
300
+ if self .self_refine :
301
+ for _ in range (self .num_refine_turns ):
302
+ critique_result = self .critique (
303
+ input_text = input_text ,
304
+ entity_types = self .entity_types ,
305
+ current_entities = current_entities ,
306
+ current_relationships = current_relationships
307
+ )
308
+ refined_result = self .refine (
309
+ input_text = input_text ,
310
+ entity_types = self .entity_types ,
311
+ current_entities = current_entities ,
312
+ current_relationships = current_relationships ,
313
+ entity_critique = critique_result .entity_critique ,
314
+ relationship_critique = critique_result .relationship_critique
315
+ )
316
+ logger .debug (f"entities: { len (current_entities )} | refined_entities: { len (refined_result .refined_entities )} " )
317
+ logger .debug (f"relationships: { len (current_relationships )} | refined_relationships: { len (refined_result .refined_relationships )} " )
318
+ current_entities = refined_result .refined_entities
319
+ current_relationships = refined_result .refined_relationships
189
320
190
- entities = [
191
- dict (
192
- entity_name = clean_str (entity .entity_name .upper ()),
193
- entity_type = clean_str (entity .entity_type .upper ()),
194
- description = clean_str (entity .description ),
195
- importance_score = float (entity .importance_score ),
196
- )
197
- for entity in extraction_result .entities_relationships
198
- if isinstance (entity , Entity )
199
- ]
200
-
201
- relationships = [
202
- dict (
203
- src_id = clean_str (relationship .src_id .upper ()),
204
- tgt_id = clean_str (relationship .tgt_id .upper ()),
205
- description = clean_str (relationship .description ),
206
- weight = float (relationship .weight ),
207
- order = int (relationship .order ),
208
- )
209
- for relationship in extraction_result .entities_relationships
210
- if isinstance (relationship , Relationship )
211
- ]
321
+ entities = [entity .to_dict () for entity in current_entities ]
322
+ relationships = [relationship .to_dict () for relationship in current_relationships ]
212
323
213
324
return dspy .Prediction (entities = entities , relationships = relationships )
0 commit comments