Skip to content

Commit c061781

Browse files
Add multi-turn self-refine for entity relationship extractor (#73)
1 parent cddcda7 commit c061781

File tree

7 files changed

+570
-58
lines changed

7 files changed

+570
-58
lines changed

docs/benchmark-dspy-entity-extraction.md

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
1-
# Main Takeaways
1+
# Chain Of Thought Prompting with DSPy-AI (v2.4.16)
2+
## Main Takeaways
23
- Time difference: 156.99 seconds
34
- Execution time with DSPy-AI: 304.38 seconds
45
- Execution time without DSPy-AI: 147.39 seconds
56
- Entities extracted: 22 (without DSPy-AI) vs 37 (with DSPy-AI)
67
- Relationships extracted: 21 (without DSPy-AI) vs 36 (with DSPy-AI)
78

89

9-
# Results
10+
## Results
1011
```markdown
1112
> python examples/benchmarks/dspy_entity.py
1213

@@ -264,4 +265,12 @@ Relationships:
264265
"朱元璋早年为刘德放牛,这段经历对他的成长有重要影响。"
265266
- "朱元璋" -> "吴老太":
266267
"朱元璋曾希望托吴老太找一个媳妇,显示了他对家庭的渴望。"
267-
```
268+
```
269+
270+
# Self-Refine with DSPy-AI (v2.5.6)
271+
## Main Takeaways
272+
- Time difference: 66.24 seconds
273+
- Execution time with DSPy-AI: 211.04 seconds
274+
- Execution time without DSPy-AI: 144.80 seconds
275+
- Entities extracted: 38 (without DSPy-AI) vs 16 (with DSPy-AI)
276+
- Relationships extracted: 38 (without DSPy-AI) vs 16 (with DSPy-AI)

examples/benchmarks/dspy_entity.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,8 @@
77
import time
88
import shutil
99
from nano_graphrag.entity_extraction.extract import extract_entities_dspy
10-
from nano_graphrag._storage import NetworkXStorage, BaseKVStorage
10+
from nano_graphrag.base import BaseKVStorage
11+
from nano_graphrag._storage import NetworkXStorage
1112
from nano_graphrag._utils import compute_mdhash_id, compute_args_hash
1213
from nano_graphrag._op import extract_entities
1314

@@ -106,14 +107,12 @@ def print_extraction_results(graph_storage: NetworkXStorage):
106107
async def run_benchmark(text: str):
107108
print("\nRunning benchmark with DSPy-AI:")
108109
system_prompt = """
109-
You are a world-class AI system, capable of complex rationale and reflection.
110-
Reason through the query, and then provide your final response.
111-
If you detect that you made a mistake in your rationale at any point, correct yourself.
112-
Think carefully.
110+
You are an expert system specialized in entity and relationship extraction from complex texts.
111+
Your task is to thoroughly analyze the given text and extract all relevant entities and their relationships with utmost precision and completeness.
113112
"""
114113
system_prompt_dspy = f"{system_prompt} Time: {time.time()}."
115-
lm = dspy.OpenAI(
116-
model="deepseek-chat",
114+
lm = dspy.LM(
115+
model="deepseek/deepseek-chat",
117116
model_type="chat",
118117
api_provider="openai",
119118
api_key=os.environ["DEEPSEEK_API_KEY"],
@@ -127,7 +126,6 @@ async def run_benchmark(text: str):
127126
print(f"Execution time with DSPy-AI: {time_with_dspy:.2f} seconds")
128127
print_extraction_results(graph_storage_with_dspy)
129128

130-
import pdb; pdb.set_trace()
131129
print("Running benchmark without DSPy-AI:")
132130
system_prompt_no_dspy = f"{system_prompt} Time: {time.time()}."
133131
graph_storage_without_dspy, time_without_dspy = await benchmark_entity_extraction(text, system_prompt_no_dspy, use_dspy=False)
@@ -148,7 +146,7 @@ async def run_benchmark(text: str):
148146

149147

150148
if __name__ == "__main__":
151-
with open("./examples/data/test.txt", encoding="utf-8-sig") as f:
149+
with open("./tests/zhuyuanzhang.txt", encoding="utf-8-sig") as f:
152150
text = f.read()
153151

154152
asyncio.run(run_benchmark(text=text))

examples/using_dspy_entity_extraction.py

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -130,19 +130,12 @@ def query():
130130

131131

132132
if __name__ == "__main__":
133-
system_prompt = """
134-
You are a world-class AI system, capable of complex rationale and reflection.
135-
Reason through the query, and then provide your final response.
136-
If you detect that you made a mistake in your rationale at any point, correct yourself.
137-
Think carefully.
138-
"""
139-
lm = dspy.OpenAI(
140-
model="deepseek-chat",
133+
lm = dspy.LM(
134+
model="deepseek/deepseek-chat",
141135
model_type="chat",
142136
api_provider="openai",
143137
api_key=os.environ["DEEPSEEK_API_KEY"],
144138
base_url=os.environ["DEEPSEEK_BASE_URL"],
145-
system_prompt=system_prompt,
146139
temperature=1.0,
147140
max_tokens=8192
148141
)

nano_graphrag/entity_extraction/extract.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ async def generate_dataset(
2121
save_dataset: bool = True,
2222
global_config: dict = {},
2323
) -> list[dspy.Example]:
24-
entity_extractor = TypedEntityRelationshipExtractor()
24+
entity_extractor = TypedEntityRelationshipExtractor(num_refine_turns=1, self_refine=True)
2525

2626
if global_config.get("use_compiled_dspy_entity_relationship", False):
2727
entity_extractor.load(global_config["entity_relationship_module_path"])
@@ -84,7 +84,7 @@ async def extract_entities_dspy(
8484
entity_vdb: BaseVectorStorage,
8585
global_config: dict,
8686
) -> Union[BaseGraphStorage, None]:
87-
entity_extractor = TypedEntityRelationshipExtractor()
87+
entity_extractor = TypedEntityRelationshipExtractor(num_refine_turns=1, self_refine=True)
8888

8989
if global_config.get("use_compiled_dspy_entity_relationship", False):
9090
entity_extractor.load(global_config["entity_relationship_module_path"])

nano_graphrag/entity_extraction/module.py

Lines changed: 142 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
1-
from typing import Union
21
import dspy
32
from pydantic import BaseModel, Field
43
from nano_graphrag._utils import clean_str
4+
from nano_graphrag._utils import logger
55

66

77
"""
@@ -75,6 +75,14 @@ class Entity(BaseModel):
7575
description="Importance score of the entity. Should be between 0 and 1 with 1 being the most important.",
7676
)
7777

78+
def to_dict(self):
    """Return this entity as a plain dict with normalized field values.

    The name and type are upper-cased and then passed through ``clean_str``;
    the description is passed through ``clean_str`` unchanged in case; the
    importance score is coerced to ``float``.
    """
    name = clean_str(self.entity_name.upper())
    kind = clean_str(self.entity_type.upper())
    summary = clean_str(self.description)
    score = float(self.importance_score)
    return dict(
        entity_name=name,
        entity_type=kind,
        description=summary,
        importance_score=score,
    )
85+
7886

7987
class Relationship(BaseModel):
8088
src_id: str = Field(..., description="The name of the source entity.")
@@ -96,6 +104,15 @@ class Relationship(BaseModel):
96104
description="The order of the relationship. 1 for direct relationships, 2 for second-order, 3 for third-order.",
97105
)
98106

107+
def to_dict(self):
    """Return this relationship as a plain dict with normalized field values.

    Source and target ids are upper-cased and then passed through
    ``clean_str``; the description is passed through ``clean_str``; the weight
    is coerced to ``float`` and the order to ``int``.
    """
    source = clean_str(self.src_id.upper())
    target = clean_str(self.tgt_id.upper())
    summary = clean_str(self.description)
    strength = float(self.weight)
    hops = int(self.order)
    return dict(
        src_id=source,
        tgt_id=target,
        description=summary,
        weight=strength,
        order=hops,
    )
115+
99116

100117
class CombinedExtraction(dspy.Signature):
101118
"""
@@ -134,8 +151,85 @@ class CombinedExtraction(dspy.Signature):
134151
entity_types: list[str] = dspy.InputField(
135152
desc="List of entity types used for extraction."
136153
)
137-
entities_relationships: list[Union[Entity, Relationship]] = dspy.OutputField(
138-
desc="List of entities and relationships extracted from the text."
154+
entities: list[Entity] = dspy.OutputField(
155+
desc="List of entities extracted from the text and the entity types."
156+
)
157+
relationships: list[Relationship] = dspy.OutputField(
158+
desc="List of relationships extracted from the text and the entity types."
159+
)
160+
161+
162+
class CritiqueCombinedExtraction(dspy.Signature):
    """
    Critique the current extraction of entities and relationships from a given text.
    Focus on completeness, accuracy, and adherence to the provided entity types and extraction guidelines.

    Critique Guidelines:
    1. Evaluate if all relevant entities from the input text are captured and correctly typed.
    2. Check if entity descriptions are comprehensive and follow the provided guidelines.
    3. Assess the completeness of relationship extractions, including higher-order relationships.
    4. Verify that relationship descriptions are detailed and follow the provided guidelines.
    5. Identify any inconsistencies, errors, or missed opportunities in the current extraction.
    6. Suggest specific improvements or additions to enhance the quality of the extraction.
    """

    # NOTE(review): the docstring above and every `desc` below are presumably
    # sent to the language model as prompt text by DSPy -- treat any edit to
    # them as a behavior change (confirm against the DSPy version in use).

    # --- Inputs: the source text, the allowed types, and the extraction under review.
    input_text: str = dspy.InputField(desc="The original text from which entities and relationships were extracted.")
    entity_types: list[str] = dspy.InputField(desc="List of valid entity types for this extraction task.")
    current_entities: list[Entity] = dspy.InputField(desc="List of currently extracted entities to be critiqued.")
    current_relationships: list[Relationship] = dspy.InputField(desc="List of currently extracted relationships to be critiqued.")

    # --- Outputs: one free-text critique per artifact kind.
    # NOTE(review): the trailing ".." in both descriptions looks like a typo;
    # it is kept verbatim here because it is part of the prompt text.
    entity_critique: str = dspy.OutputField(desc="Detailed critique of the current entities, highlighting areas for improvement for completeness and accuracy..")
    relationship_critique: str = dspy.OutputField(desc="Detailed critique of the current relationships, highlighting areas for improvement for completeness and accuracy..")
194+
195+
196+
class RefineCombinedExtraction(dspy.Signature):
    """
    Refine the current extraction of entities and relationships based on the provided critique.
    Improve completeness, accuracy, and adherence to the extraction guidelines.

    Refinement Guidelines:
    1. Address all points raised in the entity and relationship critiques.
    2. Add missing entities and relationships identified in the critique.
    3. Improve entity and relationship descriptions as suggested.
    4. Ensure all refinements still adhere to the original extraction guidelines.
    5. Maintain consistency between entities and relationships during refinement.
    6. Focus on enhancing the overall quality and comprehensiveness of the extraction.
    """

    # NOTE(review): the docstring above and every `desc` below are presumably
    # sent to the language model as prompt text by DSPy -- treat any edit to
    # them as a behavior change (confirm against the DSPy version in use).

    # --- Inputs: the source text, the allowed types, and the extraction to refine.
    input_text: str = dspy.InputField(desc="The original text from which entities and relationships were extracted.")
    entity_types: list[str] = dspy.InputField(desc="List of valid entity types for this extraction task.")
    current_entities: list[Entity] = dspy.InputField(desc="List of currently extracted entities to be refined.")
    current_relationships: list[Relationship] = dspy.InputField(desc="List of currently extracted relationships to be refined.")

    # --- Inputs: the critiques that guide the refinement.
    entity_critique: str = dspy.InputField(desc="Detailed critique of the current entities to guide refinement.")
    relationship_critique: str = dspy.InputField(desc="Detailed critique of the current relationships to guide refinement.")

    # --- Outputs: the refined lists, in the same shapes as the inputs.
    refined_entities: list[Entity] = dspy.OutputField(desc="List of refined entities, addressing the entity critique and improving upon the current entities.")
    refined_relationships: list[Relationship] = dspy.OutputField(desc="List of refined relationships, addressing the relationship critique and improving upon the current relationships.")
140234

141235

@@ -159,7 +253,7 @@ def forward(self, **kwargs):
159253

160254
except Exception as e:
161255
if isinstance(e, self.exception_types):
162-
return dspy.Prediction(entities_relationships=[])
256+
return dspy.Prediction(entities=[], relationships=[])
163257

164258
raise e
165259

@@ -168,46 +262,63 @@ class TypedEntityRelationshipExtractor(dspy.Module):
168262
def __init__(
    self,
    lm: dspy.LM = None,
    max_retries: int = 3,
    entity_types: list[str] = ENTITY_TYPES,
    self_refine: bool = False,
    num_refine_turns: int = 1
):
    """Set up the extraction predictor and, optionally, the self-refine predictors.

    Args:
        lm: Language model stored on the instance for use in ``forward``.
        max_retries: Passed as ``max_retries`` to every
            ``dspy.TypedChainOfThought`` built here.
        entity_types: Entity types stored on the instance and supplied to the
            predictors as the ``entity_types`` input.
        self_refine: When true, the critique and refine predictors are
            created as ``self.critique`` and ``self.refine``.
        num_refine_turns: Number of critique/refine turns, stored for use in
            ``forward``.
    """
    super().__init__()
    self.lm = lm
    self.entity_types = entity_types
    self.self_refine = self_refine
    self.num_refine_turns = num_refine_turns

    # Build the typed predictor first, then wrap it so that the listed
    # exception types are handled by the wrapper instead of propagating.
    base_extractor = dspy.TypedChainOfThought(signature=CombinedExtraction, max_retries=max_retries)
    self.extractor = TypedEntityRelationshipExtractorException(
        base_extractor, exception_types=(ValueError,)
    )

    # The critique/refine predictors exist only when self-refinement is on.
    if not self.self_refine:
        return

    self.critique = dspy.TypedChainOfThought(
        signature=CritiqueCombinedExtraction,
        max_retries=max_retries
    )
    self.refine = dspy.TypedChainOfThought(
        signature=RefineCombinedExtraction,
        max_retries=max_retries
    )
183290

184291
def forward(self, input_text: str) -> dspy.Prediction:
    """Extract entities and relationships from ``input_text``.

    Runs the wrapped extractor once. When ``self.self_refine`` is set, it then
    runs up to ``self.num_refine_turns`` critique -> refine turns, each turn
    replacing the current entities/relationships with the refined ones.

    Args:
        input_text: The text to extract from.

    Returns:
        A ``dspy.Prediction`` with ``entities`` and ``relationships``, each a
        list of plain dicts produced by the items' ``to_dict()`` methods.
    """
    active_lm = self.lm if self.lm is not None else dspy.settings.lm

    # Every LM call (extraction, critique, refine) runs under the same
    # context so that an explicitly supplied ``self.lm`` applies to all of
    # them, not just to the first extraction.
    with dspy.context(lm=active_lm):
        extraction_result = self.extractor(
            input_text=input_text, entity_types=self.entity_types
        )

        current_entities: list[Entity] = extraction_result.entities
        current_relationships: list[Relationship] = extraction_result.relationships

        if self.self_refine:
            for turn in range(self.num_refine_turns):
                try:
                    critique_result = self.critique(
                        input_text=input_text,
                        entity_types=self.entity_types,
                        current_entities=current_entities,
                        current_relationships=current_relationships,
                    )
                    refined_result = self.refine(
                        input_text=input_text,
                        entity_types=self.entity_types,
                        current_entities=current_entities,
                        current_relationships=current_relationships,
                        entity_critique=critique_result.entity_critique,
                        relationship_critique=critique_result.relationship_critique,
                    )
                except ValueError as exc:
                    # The initial extractor is wrapped to tolerate ValueError;
                    # the critique/refine predictors are not. A failed turn
                    # must not throw away the extraction obtained so far, so
                    # stop refining and keep the last good result.
                    logger.warning(
                        "self-refine turn %d failed (%s); keeping previous extraction",
                        turn + 1,
                        exc,
                    )
                    break

                logger.debug(
                    "entities: %d | refined_entities: %d",
                    len(current_entities),
                    len(refined_result.refined_entities),
                )
                logger.debug(
                    "relationships: %d | refined_relationships: %d",
                    len(current_relationships),
                    len(refined_result.refined_relationships),
                )
                current_entities = refined_result.refined_entities
                current_relationships = refined_result.refined_relationships

    entities = [entity.to_dict() for entity in current_entities]
    relationships = [relationship.to_dict() for relationship in current_relationships]

    return dspy.Prediction(entities=entities, relationships=relationships)

0 commit comments

Comments
 (0)