Coverage for api/roles/services.py: 93.83%

162 statements  

« prev     ^ index     » next       coverage.py v7.9.2, created at 2026-01-25 13:05 +0000

1from typing import Dict, List 

2from models.roles import Roles 

3from models.role_mapper import RoleMapper 

4from core.rbac import check_user_has_super_role 

5from sqlalchemy.ext.asyncio import AsyncSession 

6from sqlalchemy import select, func, delete, and_ 

7from models.role_attributes import RoleAttributes 

8from models.role_attributes_mapper import RoleAttributesMapper 

9from utils.custom_exception import ServerException, ConflictException, NotFoundException 

10from .schema import ( 

11 RoleResponse, RoleCreate, RoleUpdate, RolesListResponse, 

12 RoleAttributeMappingBatchResponse, AttributeMappingResult, 

13 RoleAttributesGroupedResponse, RoleAttributesGroup, RoleAttributeDetail, 

14 PermissionCheckResponse 

15) 

16 

async def get_all_roles(db: AsyncSession) -> RolesListResponse:
    """Return every role, ordered alphabetically by name.

    Args:
        db: Async database session used for the read.

    Returns:
        A RolesListResponse holding one RoleResponse (id, name, description)
        per row of the Roles table.

    Raises:
        ServerException: if the query or the response construction fails.
    """
    try:
        ordered = select(Roles).order_by(Roles.name.asc())
        fetched = (await db.execute(ordered)).scalars().all()
        return RolesListResponse(
            roles=[
                RoleResponse(id=r.id, name=r.name, description=r.description)
                for r in fetched
            ]
        )
    except Exception as e:
        raise ServerException(f"Failed to retrieve roles: {str(e)}")

37 

async def create_role(db: AsyncSession, role_data: RoleCreate) -> RoleResponse:
    """Create and persist a new role.

    Args:
        db: Async database session. It is committed on success and rolled back
            on an unexpected failure, so it stays usable for the caller.
        role_data: Name and description of the role to create.

    Returns:
        The newly persisted role as a RoleResponse.

    Raises:
        ConflictException: if a role with the same name already exists.
        ServerException: for any other failure (chained to the original error).
    """
    try:
        # NOTE(review): this check-then-insert is not atomic; two concurrent
        # requests can both pass it. Only a DB unique constraint on name (if
        # one exists -- confirm) prevents duplicates, and that would surface
        # here as a ServerException rather than a ConflictException.
        existing_role = await db.execute(
            select(Roles).where(Roles.name == role_data.name)
        )
        if existing_role.scalar_one_or_none():
            raise ConflictException("Role name already exists")

        role = Roles(name=role_data.name, description=role_data.description)
        db.add(role)
        await db.commit()
        await db.refresh(role)

        return RoleResponse(id=role.id, name=role.name, description=role.description)

    except ConflictException:
        raise
    except Exception as e:
        # A failed flush/commit leaves the session unusable until it is rolled
        # back. Rolling back also discards the pending Roles object.
        try:
            await db.rollback()
        except Exception:
            pass  # best effort: the original failure is what gets reported
        raise ServerException(f"Failed to create role: {str(e)}") from e

65 

async def update_role(db: AsyncSession, role_id: str, role_data: RoleUpdate) -> RoleResponse:
    """Update an existing role with the fields explicitly set on ``role_data``.

    Only the keys returned by ``role_data.model_dump(exclude_unset=True)`` are
    written, so fields the client omitted keep their current values.

    Args:
        db: Async database session. It is committed on success and rolled back
            on an unexpected failure.
        role_id: Identifier of the role to update.
        role_data: Partial update payload.

    Returns:
        The updated role as a RoleResponse.

    Raises:
        NotFoundException: if no role has ``role_id``.
        ConflictException: if a different role already uses the requested name.
        ServerException: for any other failure (chained to the original error).
    """
    try:
        role_result = await db.execute(select(Roles).where(Roles.id == role_id))
        role = role_result.scalar_one_or_none()
        if not role:
            raise NotFoundException("Role not found")

        # Check uniqueness only when the name actually changes; the role's own
        # row is excluded from the lookup.
        if role_data.name and role_data.name != role.name:
            existing_role = await db.execute(
                select(Roles).where(Roles.name == role_data.name, Roles.id != role_id)
            )
            if existing_role.scalar_one_or_none():
                raise ConflictException("Role name already exists")

        # NOTE(review): a field explicitly set to None is written as None too.
        # Confirm that RoleUpdate rejects None for non-nullable columns.
        update_data = role_data.model_dump(exclude_unset=True)
        for field, value in update_data.items():
            setattr(role, field, value)

        await db.commit()
        await db.refresh(role)

        return RoleResponse(id=role.id, name=role.name, description=role.description)

    except (ConflictException, NotFoundException):
        raise
    except Exception as e:
        # A failed commit leaves the session unusable until it is rolled back;
        # this also drops the attribute changes made above.
        try:
            await db.rollback()
        except Exception:
            pass  # best effort: the original failure is what gets reported
        raise ServerException(f"Failed to update role: {str(e)}") from e

100 

async def delete_role(db: AsyncSession, role_id: str) -> bool:
    """Delete a role together with its attribute mappings.

    Deletion is refused while any RoleMapper row still references the role.

    Args:
        db: Async database session. It is committed on success and rolled back
            on an unexpected failure, so partial deletes are not left pending.
        role_id: Identifier of the role to delete.

    Returns:
        True once the deletion has been committed.

    Raises:
        NotFoundException: if no role has ``role_id``.
        ConflictException: if the role is still assigned to at least one user.
        ServerException: for any other failure (chained to the original error).
    """
    try:
        role_result = await db.execute(select(Roles).where(Roles.id == role_id))
        if not role_result.scalar_one_or_none():
            raise NotFoundException("Role not found")

        user_count = await db.execute(
            select(func.count(RoleMapper.user_id)).where(RoleMapper.role_id == role_id)
        )
        if user_count.scalar() > 0:
            raise ConflictException("Cannot delete role that is assigned to users")

        # The role's attribute mappings are removed before the role row itself
        # (presumably because of an FK to Roles -- confirm against the models).
        await db.execute(
            delete(RoleAttributesMapper).where(RoleAttributesMapper.role_id == role_id)
        )
        await db.execute(delete(Roles).where(Roles.id == role_id))
        await db.commit()

        return True

    except (ConflictException, NotFoundException):
        raise
    except Exception as e:
        # Undo any bulk delete that already ran in this transaction, and make
        # the session usable again after a failed statement or commit.
        try:
            await db.rollback()
        except Exception:
            pass  # best effort: the original failure is what gets reported
        raise ServerException(f"Failed to delete role: {str(e)}") from e

132 

async def get_role_attribute_mapping(db: AsyncSession, role_id: str) -> RoleAttributesGroupedResponse:
    """Return every attribute with this role's value, grouped by group and category.

    All attributes are listed, via a LEFT OUTER JOIN against this role's
    mapping rows. An attribute with no mapping row for the role, or a NULL
    value, is reported as False. Attributes without a group or category are
    filed under "default" or "uncategorized" respectively. Groups and
    categories appear in order of first occurrence when attributes are sorted
    by id.

    Args:
        db: Async database session used for the reads.
        role_id: Identifier of the role whose values are reported.

    Returns:
        A RoleAttributesGroupedResponse with one RoleAttributesGroup per group.

    Raises:
        NotFoundException: if no role has ``role_id``.
        ServerException: for any other failure.
    """
    try:
        lookup = await db.execute(select(Roles).where(Roles.id == role_id))
        if not lookup.scalar_one_or_none():
            raise NotFoundException("Role not found")

        # The role filter lives in the ON clause (not WHERE) so that attributes
        # without a mapping for this role still come back, with a NULL value.
        join_condition = and_(
            RoleAttributes.id == RoleAttributesMapper.attributes_id,
            RoleAttributesMapper.role_id == role_id,
        )
        stmt = (
            select(
                RoleAttributes.name,
                RoleAttributes.group,
                RoleAttributes.category,
                RoleAttributesMapper.value,
            )
            .select_from(RoleAttributes)
            .outerjoin(RoleAttributesMapper, join_condition)
            .order_by(RoleAttributes.id)
        )
        fetched_rows = (await db.execute(stmt)).all()

        tree: Dict[str, Dict[str, List[RoleAttributeDetail]]] = {}
        for entry in fetched_rows:
            group_key = entry.group or "default"
            category_key = entry.category or "uncategorized"
            if group_key not in tree:
                tree[group_key] = {}
            if category_key not in tree[group_key]:
                tree[group_key][category_key] = []
            tree[group_key][category_key].append(
                RoleAttributeDetail(
                    name=entry.name,
                    value=False if entry.value is None else entry.value,
                )
            )

        return RoleAttributesGroupedResponse(
            groups=[
                RoleAttributesGroup(group=group_key, categories=bucket)
                for group_key, bucket in tree.items()
            ]
        )

    except NotFoundException:
        raise
    except Exception as e:
        raise ServerException(f"Failed to get role attributes: {str(e)}")

188 

async def update_role_attribute_mapping(db: AsyncSession, role_id: str, attributes_data: Dict[str, bool]) -> RoleAttributeMappingBatchResponse:
    """Batch upsert a role's attribute values and report a result per attribute.

    Args:
        db: Async database session. It is committed once at the end, or rolled
            back on an unexpected failure.
        role_id: Identifier of the role whose mappings are written.
        attributes_data: Mapping of attribute *name* to the desired value.

    Returns:
        A RoleAttributeMappingBatchResponse with one AttributeMappingResult per
        requested name. ``AttributeMappingResult.attribute_id`` carries the
        attribute name. Unknown names come first, then the known ones; each part
        keeps request order. ``success_count + failed_count`` always equals
        ``total_attributes``.

    Raises:
        NotFoundException: if no role has ``role_id``.
        ServerException: for any other failure (chained to the original error).
    """
    try:
        role_result = await db.execute(select(Roles).where(Roles.id == role_id))
        if not role_result.scalar_one_or_none():
            raise NotFoundException("Role not found")

        results = []
        success_count = 0
        failed_count = 0

        attribute_names = list(attributes_data)
        name_to_id_map = {}
        # attributes_id -> existing mapping rows for this role. A list is used
        # so that duplicate rows, if any, are all kept consistent.
        existing_mappings = {}

        # Guard against issuing an empty IN () query.
        if attribute_names:
            existing_attributes = await db.execute(
                select(RoleAttributes.id, RoleAttributes.name).where(
                    RoleAttributes.name.in_(attribute_names)
                )
            )
            for row in existing_attributes:
                name_to_id_map[row.name] = row.id

        if name_to_id_map:
            # Load all of this role's relevant mappings in one query, instead
            # of one SELECT per attribute.
            mapping_result = await db.execute(
                select(RoleAttributesMapper).where(
                    and_(
                        RoleAttributesMapper.role_id == role_id,
                        RoleAttributesMapper.attributes_id.in_(list(name_to_id_map.values())),
                    )
                )
            )
            for mapping in mapping_result.scalars().all():
                existing_mappings.setdefault(mapping.attributes_id, []).append(mapping)

        # Report unknown names in request order. Iterating a set here would make
        # the response order vary between runs.
        for invalid_name in attribute_names:
            if invalid_name in name_to_id_map:
                continue
            results.append(AttributeMappingResult(
                attribute_id=invalid_name,  # keep the name for error reporting
                status="failed",
                message="Invalid attribute name"
            ))
            failed_count += 1

        for attribute_name, value in attributes_data.items():
            attribute_id = name_to_id_map.get(attribute_name)
            # `is None`, not falsiness: a legitimate id such as 0 must not be
            # skipped silently.
            if attribute_id is None:
                continue  # already reported as invalid above

            try:
                mappings = existing_mappings.get(attribute_id)
                if mappings:
                    for mapping in mappings:
                        mapping.value = value
                else:
                    db.add(RoleAttributesMapper(
                        role_id=role_id,
                        attributes_id=attribute_id,
                        value=value
                    ))

                results.append(AttributeMappingResult(
                    attribute_id=attribute_name,
                    status="success",
                    message="Updated successfully"
                ))
                success_count += 1

            except Exception as e:
                results.append(AttributeMappingResult(
                    attribute_id=attribute_name,
                    status="failed",
                    message=f"Failed to process: {str(e)}"
                ))
                failed_count += 1

        await db.commit()

        return RoleAttributeMappingBatchResponse(
            results=results,
            total_attributes=len(attributes_data),
            success_count=success_count,
            failed_count=failed_count
        )

    except NotFoundException:
        raise
    except Exception as e:
        # Discard the pending inserts/updates and make the session usable again.
        try:
            await db.rollback()
        except Exception:
            pass  # best effort: the original failure is what gets reported
        raise ServerException(f"Failed to update role attributes mapping: {str(e)}") from e

285 

async def check_user_permissions(
    db: AsyncSession,
    user_id: str,
    required_attributes: List[str] = None
) -> PermissionCheckResponse:
    """Report which permission attributes a user holds.

    Args:
        db: Async database session used for the reads.
        user_id: User whose permissions are checked.
        required_attributes: Attribute names to check. When None or empty,
            every attribute name in RoleAttributes is checked instead.

    Returns:
        A PermissionCheckResponse whose ``permissions`` maps each checked name
        to a bool.
        - Users for whom ``check_user_has_super_role`` is true get True for every
          checked name, including names that are not defined attributes.
        - For everyone else, a name is True only if at least one of the user's
          roles maps that attribute with value True.

    Raises:
        ServerException: for any failure (chained to the original error).
    """
    try:
        # Only list every attribute when the caller did not name any.
        if required_attributes:
            checked_names = required_attributes
        else:
            all_attributes_result = await db.execute(
                select(RoleAttributes.name).order_by(RoleAttributes.id)
            )
            checked_names = [row[0] for row in all_attributes_result.fetchall()]

        if await check_user_has_super_role(user_id, db):
            return PermissionCheckResponse(
                permissions={attr: True for attr in checked_names}
            )

        # One join across all of the user's roles. The previous version used
        # scalar_one_or_none() on RoleMapper, which raised for any user with
        # more than one role.
        granted_query = (
            select(RoleAttributes.name)
            .join(
                RoleAttributesMapper,
                RoleAttributes.id == RoleAttributesMapper.attributes_id
            )
            .join(
                RoleMapper,
                RoleMapper.role_id == RoleAttributesMapper.role_id
            )
            .where(
                and_(
                    RoleMapper.user_id == user_id,
                    RoleAttributesMapper.value == True  # noqa: E712 -- SQL expression
                )
            )
            .distinct()
        )
        granted_result = await db.execute(granted_query)
        user_attributes_set = {row[0] for row in granted_result.fetchall()}

        return PermissionCheckResponse(
            permissions={attr: attr in user_attributes_set for attr in checked_names}
        )

    except Exception as e:
        raise ServerException(f"Failed to check user permissions: {str(e)}") from e