|
1 |
| -import ipaddress |
2 |
| -import socket |
3 | 1 | from logging import Logger
|
4 | 2 | from typing import Any
|
5 |
| -from urllib.parse import urlparse |
6 | 3 |
|
7 | 4 | from flask import current_app as app
|
8 | 5 |
|
|
17 | 14 | repair_json_with_best_structure,
|
18 | 15 | )
|
19 | 16 | from unstract.prompt_service.utils.log import publish_log
|
| 17 | +from unstract.sdk.adapters.url_validator import URLValidator |
20 | 18 | from unstract.sdk.constants import LogLevel
|
21 | 19 | from unstract.sdk.exceptions import RateLimitError as SdkRateLimitError
|
22 | 20 | from unstract.sdk.exceptions import SdkError
|
|
26 | 24 | from unstract.sdk.llm import LLM
|
27 | 25 |
|
28 | 26 |
|
29 |
| -def _is_safe_public_url(url: str) -> bool: |
30 |
| - """Validate webhook URL for SSRF protection. |
31 |
| -
|
32 |
| - Only allows HTTPS and blocks private/loopback/internal addresses. |
33 |
| - Resolves all DNS records (A/AAAA) to prevent DNS rebinding attacks. |
34 |
| - """ |
35 |
| - try: |
36 |
| - p = urlparse(url) |
37 |
| - if p.scheme not in ("https",): # Only allow HTTPS for security |
38 |
| - return False |
39 |
| - host = p.hostname or "" |
40 |
| - # Block obvious local hosts |
41 |
| - if host in ("localhost",): |
42 |
| - return False |
43 |
| - |
44 |
| - addrs: set[str] = set() |
45 |
| - # If literal IP, validate directly; else resolve all records (A/AAAA) |
46 |
| - try: |
47 |
| - ipaddress.ip_address(host) |
48 |
| - addrs.add(host) |
49 |
| - except ValueError: |
50 |
| - try: |
51 |
| - for family, _type, _proto, _canonname, sockaddr in socket.getaddrinfo( |
52 |
| - host, None, type=socket.SOCK_STREAM |
53 |
| - ): |
54 |
| - addr = sockaddr[0] |
55 |
| - addrs.add(addr) |
56 |
| - except Exception: |
57 |
| - return False |
58 |
| - |
59 |
| - if not addrs: |
60 |
| - return False |
61 |
| - |
62 |
| - # Validate all resolved addresses |
63 |
| - for addr in addrs: |
64 |
| - try: |
65 |
| - ip = ipaddress.ip_address(addr) |
66 |
| - except ValueError: |
67 |
| - return False |
68 |
| - if ( |
69 |
| - ip.is_private |
70 |
| - or ip.is_loopback |
71 |
| - or ip.is_link_local |
72 |
| - or ip.is_reserved |
73 |
| - or ip.is_multicast |
74 |
| - ): |
75 |
| - return False |
76 |
| - return True |
77 |
| - except Exception: |
78 |
| - return False |
79 |
| - |
80 |
| - |
81 | 27 | class AnswerPromptService:
|
82 | 28 | @staticmethod
|
83 | 29 | def extract_variable(
|
@@ -342,23 +288,25 @@ def handle_json(
|
342 | 288 | app.logger.warning(
|
343 | 289 | "Postprocessing webhook enabled but URL missing; skipping."
|
344 | 290 | )
|
345 |
| - elif not _is_safe_public_url(webhook_url): |
346 |
| - app.logger.warning( |
347 |
| - "Postprocessing webhook URL is not allowed; skipping." |
348 |
| - ) |
349 | 291 | else:
|
350 |
| - try: |
351 |
| - processed_data, updated_highlight_data = postprocess_data( |
352 |
| - parsed_data, |
353 |
| - webhook_enabled=True, |
354 |
| - webhook_url=webhook_url, |
355 |
| - highlight_data=highlight_data, |
356 |
| - timeout=60, |
357 |
| - ) |
358 |
| - except Exception as e: |
| 292 | + is_valid, error_message = URLValidator.validate_url(webhook_url) |
| 293 | + if not is_valid: |
359 | 294 | app.logger.warning(
|
360 |
| - f"Postprocessing webhook failed: {e}. Using unprocessed data." |
| 295 | + f"Postprocessing webhook URL validation failed: {error_message}; skipping." |
361 | 296 | )
|
| 297 | + else: |
| 298 | + try: |
| 299 | + processed_data, updated_highlight_data = postprocess_data( |
| 300 | + parsed_data, |
| 301 | + webhook_enabled=True, |
| 302 | + webhook_url=webhook_url, |
| 303 | + highlight_data=highlight_data, |
| 304 | + timeout=60, |
| 305 | + ) |
| 306 | + except Exception as e: |
| 307 | + app.logger.warning( |
| 308 | + f"Postprocessing webhook failed: {e}. Using unprocessed data." |
| 309 | + ) |
362 | 310 |
|
363 | 311 | structured_output[prompt_key] = processed_data
|
364 | 312 |
|
|
0 commit comments