Commit 18fa3a4

Feat: Add Amazon Bedrock support (#97)
* Add Amazon Bedrock support
* Add a sample script to test the Amazon Bedrock integration
* Add the latest Claude 3.5 Sonnet v1 & v2 models
* Add a factory function for Bedrock completion instead of creating one for each model
* Update README.md to explain the Bedrock option
* Clean up
1 parent a8043a6 commit 18fa3a4

File tree

7 files changed: +172 −8 lines


examples/using_amazon_bedrock.py

Lines changed: 19 additions & 0 deletions
@@ -0,0 +1,19 @@
+from nano_graphrag import GraphRAG, QueryParam
+
+graph_func = GraphRAG(
+    working_dir="../bedrock_example",
+    using_amazon_bedrock=True,
+    best_model_id="us.anthropic.claude-3-sonnet-20240229-v1:0",
+    cheap_model_id="us.anthropic.claude-3-haiku-20240307-v1:0",
+)
+
+with open("../tests/mock_data.txt") as f:
+    graph_func.insert(f.read())
+
+prompt = "What are the top themes in this story?"
+
+# Perform a global graphrag search
+print(graph_func.query(prompt, param=QueryParam(mode="global")))
+
+# Perform a local graphrag search (arguably the better and more scalable mode)
+print(graph_func.query(prompt, param=QueryParam(mode="local")))
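Note: running this example assumes working AWS credentials and access to the referenced Bedrock models. A minimal, hedged sketch of the environment it relies on (the AWS_REGION fallback comes from the _llm.py diff below; the credential chain is standard boto3 behavior, not something this commit configures):

import os

# The Bedrock helpers added in this commit read AWS_REGION and fall back to "us-east-1".
os.environ.setdefault("AWS_REGION", "us-east-1")

# Credentials resolve through the usual boto3 chain: `aws configure`,
# the AWS_ACCESS_KEY_ID / AWS_SECRET_ACCESS_KEY environment variables, or an IAM role.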

nano_graphrag/_llm.py

Lines changed: 116 additions & 0 deletions
@@ -1,5 +1,8 @@
+import json
 import numpy as np
+from typing import Optional, List, Any, Callable
 
+import aioboto3
 from openai import AsyncOpenAI, AsyncAzureOpenAI, APIConnectionError, RateLimitError
 
 from tenacity import (
@@ -15,6 +18,7 @@
 
 global_openai_async_client = None
 global_azure_openai_async_client = None
+global_amazon_bedrock_async_client = None
 
 
 def get_openai_async_client_instance():
@@ -31,6 +35,13 @@ def get_azure_openai_async_client_instance():
     return global_azure_openai_async_client
 
 
+def get_amazon_bedrock_async_client_instance():
+    global global_amazon_bedrock_async_client
+    if global_amazon_bedrock_async_client is None:
+        global_amazon_bedrock_async_client = aioboto3.Session()
+    return global_amazon_bedrock_async_client
+
+
 @retry(
     stop=stop_after_attempt(5),
     wait=wait_exponential(multiplier=1, min=4, max=10),
@@ -64,6 +75,82 @@ async def openai_complete_if_cache(
     return response.choices[0].message.content
 
 
+@retry(
+    stop=stop_after_attempt(5),
+    wait=wait_exponential(multiplier=1, min=4, max=10),
+    retry=retry_if_exception_type((RateLimitError, APIConnectionError)),
+)
+async def amazon_bedrock_complete_if_cache(
+    model, prompt, system_prompt=None, history_messages=[], **kwargs
+) -> str:
+    amazon_bedrock_async_client = get_amazon_bedrock_async_client_instance()
+    hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
+    messages = []
+    messages.extend(history_messages)
+    messages.append({"role": "user", "content": [{"text": prompt}]})
+    if hashing_kv is not None:
+        args_hash = compute_args_hash(model, messages)
+        if_cache_return = await hashing_kv.get_by_id(args_hash)
+        if if_cache_return is not None:
+            return if_cache_return["return"]
+
+    inference_config = {
+        "temperature": 0,
+        "maxTokens": 4096 if "max_tokens" not in kwargs else kwargs["max_tokens"],
+    }
+
+    async with amazon_bedrock_async_client.client(
+        "bedrock-runtime",
+        region_name=os.getenv("AWS_REGION", "us-east-1")
+    ) as bedrock_runtime:
+        if system_prompt:
+            response = await bedrock_runtime.converse(
+                modelId=model, messages=messages, inferenceConfig=inference_config,
+                system=[{"text": system_prompt}]
+            )
+        else:
+            response = await bedrock_runtime.converse(
+                modelId=model, messages=messages, inferenceConfig=inference_config,
+            )
+
+    if hashing_kv is not None:
+        await hashing_kv.upsert(
+            {args_hash: {"return": response["output"]["message"]["content"][0]["text"], "model": model}}
+        )
+        await hashing_kv.index_done_callback()
+    return response["output"]["message"]["content"][0]["text"]
+
+
+def create_amazon_bedrock_complete_function(model_id: str) -> Callable:
+    """
+    Factory function to dynamically create completion functions for Amazon Bedrock
+
+    Args:
+        model_id (str): Amazon Bedrock model identifier (e.g., "us.anthropic.claude-3-sonnet-20240229-v1:0")
+
+    Returns:
+        Callable: Generated completion function
+    """
+    async def bedrock_complete(
+        prompt: str,
+        system_prompt: Optional[str] = None,
+        history_messages: List[Any] = [],
+        **kwargs
+    ) -> str:
+        return await amazon_bedrock_complete_if_cache(
+            model_id,
+            prompt,
+            system_prompt=system_prompt,
+            history_messages=history_messages,
+            **kwargs
+        )
+
+    # Set function name for easier debugging
+    bedrock_complete.__name__ = f"{model_id}_complete"
+
+    return bedrock_complete
+
+
 async def gpt_4o_complete(
     prompt, system_prompt=None, history_messages=[], **kwargs
 ) -> str:
@@ -88,6 +175,35 @@ async def gpt_4o_mini_complete(
     )
 
 
+@wrap_embedding_func_with_attrs(embedding_dim=1024, max_token_size=8192)
+@retry(
+    stop=stop_after_attempt(5),
+    wait=wait_exponential(multiplier=1, min=4, max=10),
+    retry=retry_if_exception_type((RateLimitError, APIConnectionError)),
+)
+async def amazon_bedrock_embedding(texts: list[str]) -> np.ndarray:
+    amazon_bedrock_async_client = get_amazon_bedrock_async_client_instance()
+
+    async with amazon_bedrock_async_client.client(
+        "bedrock-runtime",
+        region_name=os.getenv("AWS_REGION", "us-east-1")
+    ) as bedrock_runtime:
+        embeddings = []
+        for text in texts:
+            body = json.dumps(
+                {
+                    "inputText": text,
+                    "dimensions": 1024,
+                }
+            )
+            response = await bedrock_runtime.invoke_model(
+                modelId="amazon.titan-embed-text-v2:0", body=body,
+            )
+            response_body = await response.get("body").read()
+            embeddings.append(json.loads(response_body))
+    return np.array([dp["embedding"] for dp in embeddings])
+
+
 @wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=8192)
 @retry(
     stop=stop_after_attempt(5),
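
The new helpers can also be driven on their own. A minimal sketch (not part of the commit; it assumes AWS credentials and Bedrock model access are already in place) of calling the factory-made completion function and the Titan embedding directly:

import asyncio

from nano_graphrag._llm import (
    amazon_bedrock_embedding,
    create_amazon_bedrock_complete_function,
)

# The factory binds a single Bedrock model id to an async completion function.
haiku_complete = create_amazon_bedrock_complete_function(
    "us.anthropic.claude-3-haiku-20240307-v1:0"
)

async def main():
    answer = await haiku_complete(
        "Explain what a knowledge graph is in one sentence.",
        system_prompt="Be brief.",
    )
    print(answer)

    # Titan v2 embeds each text separately; the wrapper stacks the results
    # into a (len(texts), 1024) numpy array.
    vectors = await amazon_bedrock_embedding(["hello", "world"])
    print(vectors.shape)  # expected: (2, 1024)

asyncio.run(main())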

nano_graphrag/_op.py

Lines changed: 5 additions & 2 deletions
@@ -293,6 +293,7 @@ async def extract_entities(
     knwoledge_graph_inst: BaseGraphStorage,
     entity_vdb: BaseVectorStorage,
     global_config: dict,
+    using_amazon_bedrock: bool=False,
 ) -> Union[BaseGraphStorage, None]:
     use_llm_func: callable = global_config["best_model_func"]
     entity_extract_max_gleaning = global_config["entity_extract_max_gleaning"]
@@ -320,12 +321,14 @@ async def _process_single_content(chunk_key_dp: tuple[str, TextChunkSchema]):
         content = chunk_dp["content"]
         hint_prompt = entity_extract_prompt.format(**context_base, input_text=content)
         final_result = await use_llm_func(hint_prompt)
+        if isinstance(final_result, list):
+            final_result = final_result[0]["text"]
 
-        history = pack_user_ass_to_openai_messages(hint_prompt, final_result)
+        history = pack_user_ass_to_openai_messages(hint_prompt, final_result, using_amazon_bedrock)
         for now_glean_index in range(entity_extract_max_gleaning):
             glean_result = await use_llm_func(continue_prompt, history_messages=history)
 
-            history += pack_user_ass_to_openai_messages(continue_prompt, glean_result)
+            history += pack_user_ass_to_openai_messages(continue_prompt, glean_result, using_amazon_bedrock)
             final_result += glean_result
             if now_glean_index == entity_extract_max_gleaning - 1:
                 break
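
A note on the new isinstance branch (an inference from the Bedrock Converse response shape, not something the commit states): Converse represents message content as a list of blocks such as [{"text": "..."}], whereas the OpenAI path yields a bare string, so the first block's text is unwrapped before the gleaning loop builds history. A hypothetical illustration:

# Shape that may surface from the Bedrock path:
final_result = [{"text": "entity extraction output"}]
if isinstance(final_result, list):
    final_result = final_result[0]["text"]
# final_result is now the bare string "entity extraction output"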

nano_graphrag/_utils.py

Lines changed: 11 additions & 5 deletions
@@ -162,11 +162,17 @@ def load_json(file_name):
 
 
 # it's dirty to type, so it's a good way to have fun
-def pack_user_ass_to_openai_messages(*args: str):
-    roles = ["user", "assistant"]
-    return [
-        {"role": roles[i % 2], "content": content} for i, content in enumerate(args)
-    ]
+def pack_user_ass_to_openai_messages(prompt: str, generated_content: str, using_amazon_bedrock: bool):
+    if using_amazon_bedrock:
+        return [
+            {"role": "user", "content": [{"text": prompt}]},
+            {"role": "assistant", "content": [{"text": generated_content}]},
+        ]
+    else:
+        return [
+            {"role": "user", "content": prompt},
+            {"role": "assistant", "content": generated_content},
+        ]
 
 
 def is_float_regex(value):
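
A quick illustration (assuming it runs inside this repo) of the two message shapes the updated helper emits:

from nano_graphrag._utils import pack_user_ass_to_openai_messages

# OpenAI style: content is a plain string.
pack_user_ass_to_openai_messages("Hi", "Hello!", False)
# -> [{"role": "user", "content": "Hi"},
#     {"role": "assistant", "content": "Hello!"}]

# Bedrock Converse style: content is a list of text blocks.
pack_user_ass_to_openai_messages("Hi", "Hello!", True)
# -> [{"role": "user", "content": [{"text": "Hi"}]},
#     {"role": "assistant", "content": [{"text": "Hello!"}]}]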

nano_graphrag/graphrag.py

Lines changed: 14 additions & 0 deletions
@@ -9,6 +9,8 @@
 
 
 from ._llm import (
+    amazon_bedrock_embedding,
+    create_amazon_bedrock_complete_function,
     gpt_4o_complete,
     gpt_4o_mini_complete,
     openai_embedding,
@@ -107,6 +109,9 @@ class GraphRAG:
 
     # LLM
     using_azure_openai: bool = False
+    using_amazon_bedrock: bool = False
+    best_model_id: str = "us.anthropic.claude-3-sonnet-20240229-v1:0"
+    cheap_model_id: str = "us.anthropic.claude-3-haiku-20240307-v1:0"
     best_model_func: callable = gpt_4o_complete
     best_model_max_token_size: int = 32768
     best_model_max_async: int = 16
@@ -145,6 +150,14 @@ def __post_init__(self):
                 "Switched the default openai funcs to Azure OpenAI if you didn't set any of it"
             )
 
+        if self.using_amazon_bedrock:
+            self.best_model_func = create_amazon_bedrock_complete_function(self.best_model_id)
+            self.cheap_model_func = create_amazon_bedrock_complete_function(self.cheap_model_id)
+            self.embedding_func = amazon_bedrock_embedding
+            logger.info(
+                "Switched the default openai funcs to Amazon Bedrock"
+            )
+
         if not os.path.exists(self.working_dir) and self.always_create_working_dir:
             logger.info(f"Creating working directory {self.working_dir}")
             os.makedirs(self.working_dir)
@@ -298,6 +311,7 @@ async def ainsert(self, string_or_strings):
                 knwoledge_graph_inst=self.chunk_entity_relation_graph,
                 entity_vdb=self.entities_vdb,
                 global_config=asdict(self),
+                using_amazon_bedrock=self.using_amazon_bedrock,
             )
             if maybe_new_kg is None:
                 logger.warning("No new entities found")
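
The net effect of the new flag, restated as a hedged sketch (the assert leans on the __name__ that the factory assigns; nothing here goes beyond what the diff already does):

from nano_graphrag import GraphRAG

rag = GraphRAG(working_dir="./bedrock_example", using_amazon_bedrock=True)

# __post_init__ swaps in Bedrock-backed completion and embedding functions.
assert rag.best_model_func.__name__ == (
    "us.anthropic.claude-3-sonnet-20240229-v1:0_complete"
)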

readme.md

Lines changed: 5 additions & 0 deletions
@@ -73,6 +73,9 @@ pip install nano-graphrag
 > [!TIP]
 > If you're using Azure OpenAI API, refer to the [.env.example](./.env.example.azure) to set your azure openai. Then pass `GraphRAG(...,using_azure_openai=True,...)` to enable.
 
+> [!TIP]
+> If you're using Amazon Bedrock API, please ensure your credentials are properly set through commands like `aws configure`. Then enable it by configuring like this: `GraphRAG(...,using_amazon_bedrock=True, best_model_id="us.anthropic.claude-3-sonnet-20240229-v1:0", cheap_model_id="us.anthropic.claude-3-haiku-20240307-v1:0",...)`. Refer to an [example script](./examples/using_amazon_bedrock.py).
+
 > [!TIP]
 >
 > If you don't have any key, check out this [example](./examples/no_openai_key_at_all.py) that using `transformers` and `ollama` . If you like to use another LLM or Embedding Model, check [Advances](#Advances).
@@ -167,9 +170,11 @@ Below are the components you can use:
 | Type            |                             What                             |                       Where                       |
 | :-------------- | :----------------------------------------------------------: | :-----------------------------------------------: |
 | LLM             |                            OpenAI                            |                      Built-in                     |
+|                 |                        Amazon Bedrock                        |                      Built-in                     |
 |                 |                           DeepSeek                           |              [examples](./examples)               |
 |                 |                           `ollama`                           |              [examples](./examples)               |
 | Embedding       |                            OpenAI                            |                      Built-in                     |
+|                 |                        Amazon Bedrock                        |                      Built-in                     |
 |                 |                     Sentence-transformers                    |              [examples](./examples)               |
 | Vector DataBase | [`nano-vectordb`](https://github.com/gusye1234/nano-vectordb) |                     Built-in                      |
 |                 |       [`hnswlib`](https://github.com/nmslib/hnswlib)        |         Built-in, [examples](./examples)          |

requirements.txt

Lines changed: 2 additions & 1 deletion
@@ -8,4 +8,5 @@ hnswlib
 xxhash
 tenacity
 dspy-ai
-neo4j
+neo4j
+aioboto3
