Coverage for ivatar/utils.py: 68%

188 statements  

« prev     ^ index     » next       coverage.py v7.11.3, created at 2025-11-14 00:08 +0000

"""
Reusable helpers for ivatar: random strings, e-mail addresses and IP
addresses, OpenID URL variations, mystery-man image generation, trusted-URL
checks, animated GIF resizing, a ``urlopen`` wrapper that normalizes URL
errors, and a Bluesky profile client.
"""

4 

5import contextlib 

6import http.client 

7import random 

8import string 

9import logging 

10from io import BytesIO 

11from PIL import Image, ImageDraw, ImageSequence 

12from urllib.parse import urlparse 

13from urllib.error import URLError 

14import requests 

15from ivatar.settings import DEBUG, URL_TIMEOUT 

16from urllib.request import urlopen as urlopen_orig 

17 

# All helpers in this module log through the "ivatar" logger
logger = logging.getLogger("ivatar")

# Optional Bluesky credentials: they stay None unless ivatar.settings can be
# imported and defines them.
BLUESKY_IDENTIFIER = BLUESKY_APP_PASSWORD = None
try:
    from ivatar.settings import BLUESKY_IDENTIFIER, BLUESKY_APP_PASSWORD
except Exception:  # missing settings must not break importing this module
    pass

25 

26 

def urlopen(url, timeout=URL_TIMEOUT):
    """
    Open ``url`` with ``urllib.request.urlopen`` and normalize URL errors.

    When ``DEBUG`` is set, TLS hostname and certificate verification are
    disabled (development use only).

    url: URL string or request object passed to ``urllib.request.urlopen``.
    timeout: socket timeout in seconds; defaults to ``URL_TIMEOUT``.
    Returns the response object from ``urllib.request.urlopen``.
    Raises URLError when the URL is invalid or malformed (the original
    exception is chained as ``__cause__``); any other exception propagates
    unchanged.
    """
    ctx = None
    if DEBUG:
        import ssl

        ctx = ssl.create_default_context()
        ctx.check_hostname = False
        ctx.verify_mode = ssl.CERT_NONE

    try:
        return urlopen_orig(url, timeout=timeout, context=ctx)
    except http.client.InvalidURL as exc:
        logger.warning(
            "Invalid URL detected (possible injection attempt): %r - %s", url, exc
        )
        # Re-raise as URLError to maintain compatibility with existing error handling
        raise URLError(f"Invalid URL: {exc}") from exc
    except (ValueError, UnicodeError) as exc:
        logger.warning("Malformed URL detected: %r - %s", url, exc)
        raise URLError(f"Malformed URL: {exc}") from exc

52 

53 

class Bluesky:
    """
    Handle Bluesky client access with persistent session management.

    The session obtained by ``login`` is stored on the class, so all
    instances reuse it until it is close to its local refresh time or is
    dropped with ``clear_shared_session``.
    """

    identifier = ""
    app_password = ""
    service = "https://bsky.social"
    session = None  # createSession payload used by this instance
    _shared_session = None  # Class-level shared session
    _session_expires_at = None  # epoch seconds at which the shared session is refreshed

    def __init__(
        self,
        identifier: str = BLUESKY_IDENTIFIER,
        app_password: str = BLUESKY_APP_PASSWORD,
        service: str = "https://bsky.social",
    ):
        """
        identifier: account identifier; defaults to the BLUESKY_IDENTIFIER
            setting (None when it is not configured).
        app_password: app password; defaults to the BLUESKY_APP_PASSWORD
            setting (None when it is not configured).
        service: base URL of the Bluesky service.
        """
        self.identifier = identifier
        self.app_password = app_password
        self.service = service

    def _is_session_valid(self) -> bool:
        """
        Return True if a shared session exists and is more than five minutes
        away from its refresh time.
        """
        if not self._shared_session or not self._session_expires_at:
            return False

        import time

        # Add 5 minute buffer before actual expiration
        return time.time() < (self._session_expires_at - 300)

    def login(self):
        """
        Login to Bluesky, reusing the class-wide session while it is valid.

        Exceptions from the createSession request (including HTTPError from
        ``raise_for_status``) propagate to the caller.
        """
        # Use shared session if available and valid
        if self._is_session_valid():
            self.session = self._shared_session
            logger.debug("Reusing existing Bluesky session")
            return

        import time

        logger.debug("Creating new Bluesky session")
        auth_response = requests.post(
            f"{self.service}/xrpc/com.atproto.server.createSession",
            json={"identifier": self.identifier, "password": self.app_password},
            timeout=URL_TIMEOUT,
        )
        auth_response.raise_for_status()
        self.session = auth_response.json()

        # Store on the class, not the instance: otherwise the session is never
        # shared and clear_shared_session() cannot invalidate it.
        cls = type(self)
        cls._shared_session = self.session
        # Sessions typically expire in 24 hours, but we'll refresh every 12 hours
        cls._session_expires_at = time.time() + (12 * 60 * 60)

        logger.debug(
            "Created new Bluesky session, expires at: %s",
            time.strftime(
                "%Y-%m-%d %H:%M:%S", time.localtime(cls._session_expires_at)
            ),
        )

    @classmethod
    def clear_shared_session(cls):
        """
        Clear the shared session (useful for testing)
        """
        cls._shared_session = None
        cls._session_expires_at = None
        logger.debug("Cleared shared Bluesky session")

    def normalize_handle(self, handle: str) -> str:
        """
        Return the normalized handle: leading '@' characters are removed
        first, then leading and trailing spaces.
        """
        return handle.lstrip("@").strip(" ")

    def _fetch_profile(self, handle: str):
        """
        Perform one authenticated getProfile request and return its JSON.

        Raises requests.exceptions.HTTPError on a non-2xx response.
        """
        profile_response = requests.get(
            f"{self.service}/xrpc/app.bsky.actor.getProfile",
            headers={"Authorization": f'Bearer {self.session["accessJwt"]}'},
            params={"actor": handle},
            timeout=URL_TIMEOUT,
        )
        profile_response.raise_for_status()
        return profile_response.json()

    def _make_profile_request(self, handle: str):
        """
        Make a profile request to Bluesky API with automatic retry on session expiration

        On HTTP 401 the shared session is dropped, a new login is made and the
        request is retried once. Returns the profile dict, or None if any step
        fails (failures are logged as warnings).
        """
        try:
            return self._fetch_profile(handle)
        except requests.exceptions.HTTPError as exc:
            if exc.response is None or exc.response.status_code != 401:
                logger.warning("Bluesky profile fetch failed with HTTP error: %s", exc)
                return None
        except Exception as exc:
            logger.warning("Bluesky profile fetch failed with error: %s", exc)
            return None

        # HTTP 401: session expired, try to login again and retry once
        logger.warning("Bluesky session expired, re-authenticating")
        self.clear_shared_session()
        try:
            self.login()
            return self._fetch_profile(handle)
        except Exception as exc:
            logger.warning(
                "Bluesky profile fetch failed after re-authentication: %s", exc
            )
            return None

    def get_profile(self, handle: str) -> "dict | None":
        """
        Return the profile for ``handle`` (decoded getProfile JSON), or None
        if the request fails. Logs in first when there is no valid session;
        exceptions from that initial login propagate.
        """
        if not self.session or not self._is_session_valid():
            self.login()
        return self._make_profile_request(handle)

    def get_avatar(self, handle: str):
        """
        Get avatar URL for a handle, or None if the profile could not be
        fetched or contains no avatar.
        """
        profile = self.get_profile(handle)
        return profile.get("avatar") if profile else None

187 

188 

def random_string(length=10):
    """
    Return a random string of ``length`` lowercase ASCII letters and digits.

    Characters are drawn with ``random.SystemRandom`` (backed by the OS
    randomness source). The generator and alphabet are created once per
    call rather than once per character.
    """
    rng = random.SystemRandom()
    alphabet = string.ascii_lowercase + string.digits
    return "".join(rng.choice(alphabet) for _ in range(length))

197 

198 

def generate_random_email():
    """
    Build a throwaway address shaped like ``<local>@<domain>.<tld>``.

    The local part and domain use ``random_string``'s default length and the
    TLD is two characters long, matching the pattern used in test_views.py.
    """
    local_part, domain_name, top_level = (
        random_string(),
        random_string(),
        random_string(2),
    )
    return "{}@{}.{}".format(local_part, domain_name, top_level)

207 

208 

def random_ip_address():
    """
    Build a dotted-quad IPv4 string whose four octets are each drawn from
    1..254 (inclusive) with the module-level ``random`` generator.
    """
    octets = (random.randint(1, 254) for _ in range(4))
    return ".".join(map(str, octets))

214 

215 

def openid_variations(openid):
    """
    Return the various OpenID variations, ALWAYS in the same order:
    - http w/ trailing slash
    - http w/o trailing slash
    - https w/ trailing slash
    - https w/o trailing slash

    Only the leading scheme is rewritten; "http://" or "https://" appearing
    later in the URL (e.g. in a query string) is left untouched. An input
    without an http(s) scheme yields identical http/https variants.
    """
    http_prefix, https_prefix = "http://", "https://"

    # Make the 'base' version: http w/ trailing slash
    if openid.startswith(https_prefix):
        openid = http_prefix + openid[len(https_prefix):]
    if not openid.endswith("/"):  # also safe for an empty string
        openid = f"{openid}/"

    if openid.startswith(http_prefix):
        https_slash = https_prefix + openid[len(http_prefix):]
    else:
        https_slash = openid

    return (openid, openid[:-1], https_slash, https_slash[:-1])

236 

237 

def _mm_channel(base_hex, delta):
    """
    Return ``int(base_hex, 16) + delta`` clamped to 0..255, formatted as a
    two-digit lowercase hex string (e.g. "0a", "ff").
    """
    value = int(base_hex, 16) + delta
    return f"{max(0, min(value, 255)):02x}"


def mm_ng(idhash, size=80, add_red=0, add_green=0, add_blue=0):
    """
    Return an MM (mystery man) image, based on a given hash
    add some red, green or blue, if specified

    idhash: hex string; its first two characters set the gray level of the
        background (replaced by "e0" when the hash starts with "f").
    size: width and height of the square image in pixels.
    add_red / add_green / add_blue: amounts added to the respective
        background channel; each result is clamped to 0..255, so negative
        values darken the channel instead of producing an invalid colour.
    Returns a ``size`` x ``size`` RGB PIL image.
    """
    # Make sure the lightest bg color we paint is e0, else
    # we do not see the MM any more
    if idhash[0] == "f":
        idhash = "e0"

    # The head circle spans 60% of the image width
    circle_size = size * 0.6
    start_x = int(size * 0.2)
    end_x = start_x + circle_size
    start_y = int(size * 0.05)
    end_y = start_y + circle_size

    # All channels start from the same hash byte, which always results in a
    # "gray-ish" background, then each is tinted. Web notation, e.g. '#d3d3d3'
    base = idhash[:2]
    bg_color = "#" + "".join(
        _mm_channel(base, delta) for delta in (add_red, add_green, add_blue)
    )

    image = Image.new("RGB", (size, size))
    draw = ImageDraw.Draw(image)

    # Draw background
    draw.rectangle(((0, 0), (size, size)), fill=bg_color)

    # Draw MMs head
    draw.ellipse((start_x, start_y, end_x, end_y), fill="white")

    # Draw MMs 'body': a triangle whose apex sits at the head's horizontal
    # center and whose base spans the bottom edge
    draw.polygon(
        (
            (start_x + circle_size / 2, size / 2.5),
            (size * 0.15, size),
            (size - size * 0.15, size),
        ),
        fill="white",
    )

    return image

311 

312 

def is_trusted_url(url, url_filters):
    """
    Check if a URL is valid and considered a trusted URL.
    If the URL is malformed, returns False.

    ``url_filters`` is an iterable of dicts. The URL is trusted when it
    satisfies every condition present in at least one filter:

    - ``schemes``: the URL scheme must be one of these
    - ``host_equals``: the netloc must equal this string
    - ``host_suffix``: the netloc must end with this string
    - ``path_prefix``: the path must start with this string
    - ``url_prefix``: the whole URL must start with this string

    A filter with none of these keys matches any URL. Host checks use
    ``urlparse(url).netloc``, which includes any port or userinfo.

    Based on: https://developer.mozilla.org/en-US/docs/Mozilla/Add-ons/WebExtensions/API/events/UrlFilter
    """
    try:
        parsed = urlparse(url)
    except ValueError:
        # urlparse rejects e.g. an unbalanced IPv6 bracket ("http://[::1")
        return False

    scheme, netloc, path = parsed.scheme, parsed.netloc, parsed.path

    # urlparse does not treat a backslash as a separator, but browsers do for
    # http(s) URLs: "https://evil.com\.example.com/" is opened on evil.com
    # while its netloc still ends with ".example.com". Never trust such hosts.
    if "\\" in netloc:
        return False

    for ufilter in url_filters:
        if "schemes" in ufilter and scheme not in ufilter["schemes"]:
            continue
        if "host_equals" in ufilter and netloc != ufilter["host_equals"]:
            continue
        if "host_suffix" in ufilter and not netloc.endswith(ufilter["host_suffix"]):
            continue
        if "path_prefix" in ufilter and not path.startswith(ufilter["path_prefix"]):
            continue
        if "url_prefix" in ufilter and not url.startswith(ufilter["url_prefix"]):
            continue
        return True

    return False

356 

357 

def resize_animated_gif(input_pil: Image, size: list) -> BytesIO:
    """
    Shrink every frame of an animated GIF to fit within ``size`` and return
    the re-encoded GIF in a BytesIO.

    Frames are scaled with ``thumbnail`` (aspect ratio preserved). Each
    frame keeps its own disposal method, and its own duration when every
    frame declares one; otherwise the duration found in ``input_pil.info``
    is used for all frames, as before.
    """
    frames = []
    durations = []
    disposals = []
    for frame in ImageSequence.Iterator(input_pil):
        # The iterator yields ``input_pil`` itself, seeked to each frame, so
        # per-frame metadata must be captured before the next seek.
        durations.append(frame.info.get("duration"))
        disposals.append(frame.disposal_method)
        new_frame = frame.copy()
        new_frame.thumbnail(size)
        frames.append(new_frame)

    # After iterating, ``input_pil`` sits on its last frame, so its ``info``
    # and ``disposal_method`` describe that frame only. Keep ``info`` for
    # file-wide options but pass per-frame lists for timing and disposal.
    save_params = dict(input_pil.info)
    if None not in durations:
        save_params["duration"] = durations
    save_params["disposal"] = disposals

    output = BytesIO()
    frames[0].save(
        output,
        format="gif",
        save_all=True,
        optimize=False,
        append_images=frames[1:],
        **save_params,
    )
    return output